zl程序教程

您现在的位置是:首页 >  其它

当前栏目

alphaFold2 | 补充Evoformer之outer productor mean

补充 outer AlphaFold2 Mean
2023-06-13 09:13:13 时间

<<AlphaFold2专题>>

alphaFold2 | 解决问题及背景(一)

alphaFold2 | 模型框架搭建(二)

alphaFold2 | 模型细节之特征提取(三)

alphaFold2 | 模型细节之Evoformer(四)

  • 文章转自微信公众号:机器学习炼丹术
  • 作者:陈亦新(欢迎交流共同进步) 补充一下下图这个结构的计算过程:
class OuterMean(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim = None,
        eps = 1e-5
    ):
        super().__init__()
        self.eps = eps
        self.norm = nn.LayerNorm(dim)
        hidden_dim = default(hidden_dim, dim)

        self.left_proj = nn.Linear(dim, hidden_dim)
        self.right_proj = nn.Linear(dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, mask = None):
        x = self.norm(x)
        left = self.left_proj(x)
        right = self.right_proj(x)
        outer = rearrange(left, 'b m i d -> b m i () d') * rearrange(right, 'b m j d -> b m () j d')

        if exists(mask):
            # masked mean, if there are padding in the rows of the MSA
            mask = rearrange(mask, 'b m i -> b m i () ()') * rearrange(mask, 'b m j -> b m () j ()')
            outer = outer.masked_fill(~mask, 0.)
            outer = outer.mean(dim = 1) / (mask.sum(dim = 1) + self.eps)
        else:
            outer = outer.mean(dim = 1)

        return self.proj_out(outer)

上面代码是一个矩阵的操作,我们可以将其具体到单个元素来看

这张图中,左边msa特征中,画出来的分别是第i个氨基酸和第j个氨基酸的特征。这两个特征分别是(s,c)的形状,s表示msa特征的氨基酸序列数,c是特征数量。而右边pair特征当中,第i个氨基酸和第j个氨基酸构成的对的特征,其实就是

c_z

长度的一维特征。

通过这样的方式,实现了从msa特征当中更新pair特征的方式。