Skip to content


update the two-track model used for yeast PPI screening
Browse files Browse the repository at this point in the history
  • Loading branch information
Minkyung Baek committed Nov 2, 2021
1 parent db3126e commit fcaac86
Show file tree
Hide file tree
Showing 14 changed files with 3,707 additions and 2 deletions.
10 changes: 8 additions & 2 deletions
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ While the code is licensed under the MIT License, the trained weights and data f
tar xfz weights.tar.gz
[Update Nov/02/2021] The archive now also includes the weights for the RoseTTAFold-2track model used for yeast PPI screening. If you want to use it, please re-download the weights. The original RoseTTAFold weights are unchanged.

4. Download and install third-party software.
Expand Down Expand Up @@ -65,6 +67,10 @@ cd example
# For complex modeling
# please see README file under example/complex_modeling/README for details.
python network/ -i paired.a3m -o complex -Ls 218 310
# For PPI screening using faster 2-track version (example input and output are at example/complex_2track)
python network_2track/ -msa [paired MSA file in a3m format] -npz [output npz file name] -L1 [Length of first chain]
e.g. python network_2track/ -msa input.a3m -npz complex.npz -L1 218

## Expected outputs
Expand All @@ -90,5 +96,5 @@ The codes in network/equivariant_attention is from the original SE(3)-Transforme

## References

M Baek, et al., Accurate prediction of protein structures and interactions using a 3-track network, bioRxiv (2021). [link](

M. Baek, et al., Accurate prediction of protein structures and interactions using a three-track neural network, Science (2021). [link](
I.R. Humphreys, J. Pei, M. Baek, A. Krishnakumar, et al., Structures of core eukaryotic protein complexes, bioRxiv (2021). [link](
Binary file added example/complex_2track/complex.npz
Binary file not shown.
2,354 changes: 2,354 additions & 0 deletions example/complex_2track/input.a3m

Large diffs are not rendered by default.

256 changes: 256 additions & 0 deletions network_2track/
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from Transformer import _get_clones, EncoderLayer, Encoder, InterEncoderLayer, InterEncoder, SpecialEncoderLayer, SpecialEncoder
import Transformer
from resnet import ResidualNetwork

# Attention module based on AlphaFold2's idea written by Minkyung Baek
# - Iterative MSA feature extraction
# - 1) MSA2Pair: extract pairwise feature from MSA --> added to previous residue-pair features
# architecture design inspired by CopulaNet paper
# - 2) MSA2MSA: process MSA features using Transformer (or Performer) encoder. (Attention over L first followed by attention over N)
# - 3) Pair2MSA: Update MSA features using pair feature
# - 4) Pair2Pair: process pair features using Transformer (or Performer) encoder.

class MSA2Pair(nn.Module):
    """Update residue-pair features from MSA embeddings.

    The MSA embedding is projected down to ``n_feat_proj`` channels, turned
    into a pairwise feature by an outer product averaged over the N
    sequences, and concatenated with the (layer-normalised) incoming pair
    features plus tiled per-residue 1D features.  The concatenation has
    ``n_feat_out*2 + n_feat_proj*4`` channels and is passed channel-first
    through ``ResidualNetwork`` to produce the new pair features.

    Args:
        n_feat: channel size of the incoming MSA embedding.
        n_feat_out: channel size of the pair features (input and output).
        n_feat_proj: reduced channel size used for the outer product.
        n_resblock: first argument handed to ``ResidualNetwork``
            (presumably the number of residual blocks).
        p_drop: dropout probability forwarded to ``ResidualNetwork``.
    """

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32,
                 n_resblock=1, p_drop=0.1):
        super(MSA2Pair, self).__init__()
        # project down embedding dimension (n_feat --> n_feat_proj)
        self.norm_1 = nn.LayerNorm(n_feat)
        self.proj_1 = nn.Linear(n_feat, n_feat_proj)
        self.norm_2d = nn.LayerNorm(n_feat_proj*n_feat_proj)

        # project down to output dimension (pair feature dimension)
        self.proj_2 = nn.Linear(n_feat_proj**2, n_feat_out)

        # ResNet to update pair features
        self.norm_down = nn.LayerNorm(n_feat_proj)
        self.norm_orig = nn.LayerNorm(n_feat_out)
        self.norm_new = nn.LayerNorm(n_feat_out)
        # input channels: pair_orig + new pair (2 * n_feat_out)
        #               + left/right tiled 1D features (2 * 2 * n_feat_proj)
        self.update = ResidualNetwork(n_resblock, n_feat_out*2+n_feat_proj*4, n_feat_out, n_feat_out, p_drop=p_drop)

    def forward(self, msa, pair_orig):
        """Return updated pair features.

        Args:
            msa: MSA embeddings, (B, N, L, n_feat).
            pair_orig: current pair embeddings, (B, L, L, n_feat_out).

        Returns:
            Updated pair embeddings, (B, L, L, n_feat_out).
        """
        B, N, L, _ = msa.shape

        # project down to reduce memory
        msa = self.norm_1(msa)
        x_down = self.proj_1(msa)  # (B, N, L, n_feat_proj)

        # outer product over the channel axes, averaged over the N sequences
        pair = torch.einsum('abij,ablm->ailjm', x_down, x_down/float(N))
        pair = pair.reshape(B, L, L, -1)  # (B, L, L, n_feat_proj**2)
        pair = self.norm_2d(pair)
        pair = self.proj_2(pair)  # (B, L, L, n_feat_out)

        # 1D features: mean over the MSA plus the query (first) sequence
        x_down = self.norm_down(x_down)
        feat_1d = x_down.mean(1)  # (B, L, n_feat_proj)
        query = x_down[:, 0]      # (B, L, n_feat_proj)
        feat_1d = torch.cat((feat_1d, query), dim=-1)  # (B, L, 2*n_feat_proj)

        # tile 1D features along rows / columns; expand() creates views and
        # the concatenation below materialises them
        left = feat_1d.unsqueeze(2).expand(-1, -1, L, -1)
        right = feat_1d.unsqueeze(1).expand(-1, L, -1, -1)

        # update original pair features through convolutions after concat
        pair_orig = self.norm_orig(pair_orig)
        pair = self.norm_new(pair)
        pair = torch.cat((pair_orig, pair, left, right), -1)
        pair = pair.permute(0, 3, 1, 2).contiguous()  # channel-first for the conv net
        pair = self.update(pair)
        pair = pair.permute(0, 2, 3, 1).contiguous()  # back to (B, L, L, C)

        return pair

# MSA2MSA: passes the MSA embeddings through two Encoder stacks in sequence,
# swapping the N and L axes around the first one.
# NOTE(review): in this copy of the file both EncoderLayer(...) calls below end
# with a trailing comma and have no closing parenthesis -- the remaining
# argument line(s) appear to be missing. performer_N_opts / performer_L_opts are
# not used anywhere in the visible code; presumably they were forwarded on the
# missing lines. Restore from upstream before use.
# NOTE(review): the file header says "attention over L first followed by
# attention over N", while the comments here label the first encoder as N and
# the second as L -- confirm which axis Encoder actually attends over.
class MSA2MSA(nn.Module):
def __init__(self, n_layer=1, n_att_head=8, n_feat=256, r_ff=4, p_drop=0.1,
performer_N_opts=None, performer_L_opts=None):
super(MSA2MSA, self).__init__()
# attention along N
# feed-forward width is n_feat * r_ff
enc_layer_1 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
heads=n_att_head, p_drop=p_drop,
self.encoder_1 = Encoder(enc_layer_1, n_layer, n_feat)
# attention along L
enc_layer_2 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
heads=n_att_head, p_drop=p_drop,
self.encoder_2 = Encoder(enc_layer_2, n_layer, n_feat)

def forward(self, x):
# Input: MSA embeddings (B, N, L, K)
# Output: updated MSA embeddings (B, N, L, K)
# B, N, L are unpacked here but not used again in this method
B, N, L, _ = x.shape
# attention along N
# swap the N and L axes: (B, N, L, K) -> (B, L, N, K)
x = x.permute(0,2,1,3).contiguous()
x = self.encoder_1(x)
# swap back: (B, L, N, K) -> (B, N, L, K)
x = x.permute(0,2,1,3).contiguous()
# attention along L
x = self.encoder_2(x)
return x

# Pair2MSA: thin wrapper that builds a SpecialEncoder and applies it to
# (pair, msa) to produce updated MSA features.
# NOTE(review): the SpecialEncoderLayer(...) call below stops after the
# "d_in=..., d_out=...," line with no closing parenthesis in this copy of the
# file -- the remaining argument line(s) appear to be missing. r_ff and p_drop
# are not used in the visible code; presumably they were passed on the missing
# line(s). Restore from upstream before use.
class Pair2MSA(nn.Module):
def __init__(self, n_layer=1, n_att_head=4, n_feat_in=128, n_feat_out=256, r_ff=4, p_drop=0.1):
super(Pair2MSA, self).__init__()
enc_layer = SpecialEncoderLayer(heads=n_att_head, \
d_in=n_feat_in, d_out=n_feat_out,\
self.encoder = SpecialEncoder(enc_layer, n_layer, n_feat_out)

def forward(self, pair, msa):
# pair features come first, MSA features second
out = self.encoder(pair, msa) # (B, N, L, K)
return out

# Pair2Pair: processes pair features with two Encoder stacks, the second one
# applied with the two L axes swapped.
# NOTE(review): in this copy of the file the __init__ signature ends with a
# trailing comma (its last parameter line and the closing "):" are missing),
# and both EncoderLayer(...) calls lack their final argument line(s) and
# closing parenthesis. Restore from upstream before use.
class Pair2Pair(nn.Module):
def __init__(self, n_layer=1, n_att_head=8, n_feat=128, r_ff=4, p_drop=0.1,
super(Pair2Pair, self).__init__()
# feed-forward width is n_feat * r_ff for both encoder layers
enc_layer_1 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
heads=n_att_head, p_drop=p_drop,
self.encoder_1 = Encoder(enc_layer_1, n_layer, n_feat)
enc_layer_2 = EncoderLayer(d_model=n_feat, d_ff=n_feat*r_ff,
heads=n_att_head, p_drop=p_drop,
self.encoder_2 = Encoder(enc_layer_2, n_layer, n_feat)
def forward(self, x):
# Input: residue pair embeddings (B, L, L, C)
# Output: residue pair embeddings (B, L, L, C)
# attention over column
# B and L are unpacked here but not used again in this method
B, L = x.shape[:2]
x = self.encoder_1(x) # attention over column
# swap the two residue axes so the second encoder sees the transposed map
x = x.permute(0,2,1,3).contiguous()
x = self.encoder_2(x) # attention over row
# swap back to the original axis order
x = x.permute(0, 2, 1, 3).contiguous()
return x

# IterBlock: one round of MSA/pair feature exchange with its own set of
# sub-modules: MSA2MSA -> MSA2Pair -> Pair2Pair -> Pair2MSA.
# NOTE(review): in this copy of the file the MSA2MSA(...) and Pair2Pair(...)
# constructor calls end with a trailing comma and no closing parenthesis -- the
# remaining argument line(s) appear to be missing. performer_L_opts /
# performer_N_opts are not used in the visible code; presumably they were
# forwarded on the missing lines. Restore from upstream before use.
class IterBlock(nn.Module):
def __init__(self, n_layer=1, d_msa=64, d_pair=128, n_head_msa=4, n_head_pair=8, r_ff=4,
n_resblock=1, p_drop=0.1, performer_L_opts=None, performer_N_opts=None):
super(IterBlock, self).__init__()

self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
r_ff=r_ff, p_drop=p_drop,
# the projection size used inside MSA2Pair is fixed at 32 here
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
n_resblock=n_resblock, p_drop=p_drop)
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
# Pair2MSA uses the pair head count (n_head_pair), not n_head_msa
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=n_head_pair,
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)

def forward(self, msa, pair):
# input:
# msa: initial MSA embeddings (N, L, d_msa)
# pair: initial residue pair embeddings (L, L, d_pair)
# (the sub-modules unpack a leading batch axis, i.e. (B, N, L, d_msa) and
# (B, L, L, d_pair))
# output: (msa, pair) tuple with both updated

# 1. process MSA features
msa = self.msa2msa(msa)

# 2. update pair features using given MSA
pair = self.msa2pair(msa, pair)

# 3. process pair features
pair = self.pair2pair(pair)

# 4. update MSA features using updated pair features
msa = self.pair2msa(pair, msa)

return msa, pair

# IterBlockShare: same four-step update as IterBlock, but a single set of
# sub-modules is applied n_module times in a loop, so the same weights are
# reused on every iteration.
# NOTE(review): in this copy of the file the MSA2MSA(...) and Pair2Pair(...)
# constructor calls end with a trailing comma and no closing parenthesis -- the
# remaining argument line(s) appear to be missing. performer_L_opts /
# performer_N_opts are not used in the visible code; presumably they were
# forwarded on the missing lines. Restore from upstream before use.
class IterBlockShare(nn.Module):
def __init__(self, n_module=4, n_layer=1, d_msa=64, d_pair=128,
n_head_msa=4, n_head_pair=8, r_ff=4,
n_resblock=1, p_drop=0.1,
performer_L_opts=None, performer_N_opts=None):
super(IterBlockShare, self).__init__()
# number of times forward() repeats the shared update
self.n_module = n_module
self.msa2msa = MSA2MSA(n_layer=n_layer, n_att_head=n_head_msa, n_feat=d_msa,
r_ff=r_ff, p_drop=p_drop,
# the projection size used inside MSA2Pair is fixed at 32 here
self.msa2pair = MSA2Pair(n_feat=d_msa, n_feat_out=d_pair, n_feat_proj=32,
n_resblock=n_resblock, p_drop=p_drop)
self.pair2pair = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,
# Pair2MSA uses the pair head count (n_head_pair), not n_head_msa
self.pair2msa = Pair2MSA(n_layer=n_layer, n_att_head=n_head_pair,
n_feat_in=d_pair, n_feat_out=d_msa, r_ff=r_ff, p_drop=p_drop)

def forward(self, msa, pair):
# input:
# msa: initial MSA embeddings (N, L, d_msa)
# pair: initial residue pair embeddings (L, L, d_pair)
# (the sub-modules unpack a leading batch axis, i.e. (B, N, L, d_msa) and
# (B, L, L, d_pair))
# output: (msa, pair) tuple after n_module shared-weight iterations

for i_m in range(self.n_module):
# 1. process MSA features
msa = self.msa2msa(msa)

# 2. update pair features using given MSA
pair = self.msa2pair(msa, pair)

# 3. process pair features
pair = self.pair2pair(pair)

# 4. update MSA features using updated pair features
msa = self.pair2msa(pair, msa)

return msa, pair

# IterativeFeatureExtractor: runs an initial Pair2Pair pass over the pair
# features, then n_diff_module independent IterBlock copies, then one
# IterBlockShare that repeats a shared block (n_module - n_diff_module) times.
# NOTE(review): in this copy of the file the Pair2Pair(...) and
# IterBlockShare(...) constructor calls end with a trailing comma and no closing
# parenthesis -- the remaining argument line(s) appear to be missing. The
# IterBlock(...) call as shown forwards only n_layer, d_msa and d_pair, whereas
# the IterBlockShare call also forwards the head counts; lines were probably
# lost there as well. Restore from upstream before use.
class IterativeFeatureExtractor(nn.Module):
def __init__(self, n_module=4, n_diff_module=2, n_layer=4, d_msa=256, d_pair=128,
n_head_msa=8, n_head_pair=8, r_ff=4,
n_resblock=1, p_drop=0.1,
performer_L_opts=None, performer_N_opts=None):
super(IterativeFeatureExtractor, self).__init__()
self.n_module = n_module
self.n_diff_module = n_diff_module
# iterations left over for the weight-sharing block
self.n_share_module = n_module - n_diff_module
self.initial = Pair2Pair(n_layer=n_layer, n_att_head=n_head_pair,
n_feat=d_pair, r_ff=r_ff, p_drop=p_drop,

if self.n_diff_module > 0:
# n_diff_module copies of one IterBlock, indexed by position in forward()
self.iter_block_1 = _get_clones(IterBlock(n_layer=n_layer,
d_msa=d_msa, d_pair=d_pair,
), n_diff_module)
if self.n_share_module > 0:
self.iter_block_2 = IterBlockShare(n_module=n_module-n_diff_module, n_layer=n_layer,
d_msa=d_msa, d_pair=d_pair,
n_head_msa=n_head_msa, n_head_pair=n_head_pair,

def forward(self, msa, pair):
# input:
# msa: initial MSA embeddings (N, L, d_msa)
# pair: initial residue pair embeddings (L, L, d_pair)
# (the sub-modules unpack a leading batch axis, i.e. (B, N, L, d_msa) and
# (B, L, L, d_pair))
# output: (msa, pair) tuple after all iteration blocks

# pair_s is created but never read or returned in this method
pair_s = list()
pair = self.initial(pair)
if self.n_diff_module > 0:
for i_m in range(self.n_diff_module):
# extract features from MSA & update original pair features
msa, pair = self.iter_block_1[i_m](msa, pair)

if self.n_share_module > 0:
msa, pair = self.iter_block_2(msa, pair)

return msa, pair
28 changes: 28 additions & 0 deletions network_2track/
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from resnet import ResidualNetwork

# predict distance map from pair features
# based on simple 2D ResNet

class DistanceNetwork(nn.Module):
    """Predict inter-residue geometry logits from pair features.

    Four separate ``ResidualNetwork`` heads produce logits for distance,
    omega, theta and phi.  Distance and omega are predicted from the
    symmetrised pair map; theta and phi are predicted from the raw
    (non-symmetric) map.

    Args:
        n_block: first argument handed to each ``ResidualNetwork``
            (presumably the number of residual blocks).
        n_feat: channel size of the incoming pair features.
        block_type: forwarded to ``ResidualNetwork``.
        p_drop: dropout probability forwarded to ``ResidualNetwork``.
        n_bin_dist: output size of the distance head (default 37).
        n_bin_omega: output size of the omega head (default 37).
        n_bin_theta: output size of the theta head (default 37).
        n_bin_phi: output size of the phi head (default 19).
    """

    def __init__(self, n_block, n_feat, block_type='orig', p_drop=0.1,
                 n_bin_dist=37, n_bin_omega=37, n_bin_theta=37, n_bin_phi=19):
        super(DistanceNetwork, self).__init__()
        # heads are created in the same order as before so that parameter
        # initialisation and state_dict layout are unchanged
        self.resnet_dist = ResidualNetwork(n_block, n_feat, n_feat, n_bin_dist, block_type=block_type, p_drop=p_drop)
        self.resnet_omega = ResidualNetwork(n_block, n_feat, n_feat, n_bin_omega, block_type=block_type, p_drop=p_drop)
        self.resnet_theta = ResidualNetwork(n_block, n_feat, n_feat, n_bin_theta, block_type=block_type, p_drop=p_drop)
        self.resnet_phi = ResidualNetwork(n_block, n_feat, n_feat, n_bin_phi, block_type=block_type, p_drop=p_drop)

    def forward(self, x):
        """Return ``(logits_dist, logits_omega, logits_theta, logits_phi)``.

        Args:
            x: pair info, channel-first, (1, C, L, L).
        """
        # predict theta, phi (non-symmetric) from the raw pair map
        logits_theta = self.resnet_theta(x)
        logits_phi = self.resnet_phi(x)

        # predict dist, omega from the map averaged with its transpose over
        # the two residue axes
        x = 0.5 * (x + x.permute(0, 1, 3, 2))
        logits_dist = self.resnet_dist(x)
        logits_omega = self.resnet_omega(x)

        return logits_dist, logits_omega, logits_theta, logits_phi

0 comments on commit fcaac86

Please sign in to comment.