forked from RosettaCommons/RFdiffusion
Commit 4c62149 (initial commit, 0 parents)
Joseph Watson committed Mar 30, 2023
Showing 108 changed files with 38,150 additions and 0 deletions.
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn

class DistanceNetwork(nn.Module):
    """Predicts inter-residue geometry logits (dist, omega, theta, phi) from pair features."""
    def __init__(self, n_feat, p_drop=0.1):
        super(DistanceNetwork, self).__init__()
        # symmetric heads: dist and omega (37 bins each);
        # asymmetric heads: theta (37 bins) and phi (19 bins)
        self.proj_symm = nn.Linear(n_feat, 37*2)
        self.proj_asymm = nn.Linear(n_feat, 37+19)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer for final logit prediction
        nn.init.zeros_(self.proj_symm.weight)
        nn.init.zeros_(self.proj_asymm.weight)
        nn.init.zeros_(self.proj_symm.bias)
        nn.init.zeros_(self.proj_asymm.bias)

    def forward(self, x):
        # input: pair info (B, L, L, C)

        # predict theta, phi (non-symmetric)
        logits_asymm = self.proj_asymm(x)
        logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
        logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)

        # predict dist, omega (symmetrized over the two residue dimensions)
        logits_symm = self.proj_symm(x)
        logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
        logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2)
        logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2)

        return logits_dist, logits_omega, logits_theta, logits_phi

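A quick shape check makes the symmetrization explicit. This is a minimal sketch with placeholder sizes (B=1, L=50, n_feat=128), not the widths used by the full model; note that with the zero initialization above, the logits start out all zero.

B, L, n_feat = 1, 50, 128
head = DistanceNetwork(n_feat)
pair = torch.randn(B, L, L, n_feat)
logits_dist, logits_omega, logits_theta, logits_phi = head(pair)
print(logits_dist.shape)   # torch.Size([1, 37, 50, 50])
print(logits_phi.shape)    # torch.Size([1, 19, 50, 50])
# dist and omega logits are symmetric across the two residue axes
assert torch.allclose(logits_dist, logits_dist.transpose(2, 3))
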
class MaskedTokenNetwork(nn.Module):
    """Predicts the identity of masked MSA tokens (21 classes)."""
    def __init__(self, n_feat, p_drop=0.1):
        super(MaskedTokenNetwork, self).__init__()
        self.proj = nn.Linear(n_feat, 21)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer for final logit prediction
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        # input: MSA features (B, N, L, C)
        B, N, L = x.shape[:3]
        # move the class dimension second so the logits fit nn.CrossEntropyLoss
        logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)

        return logits

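The (B, 21, N*L) layout matches what nn.CrossEntropyLoss expects for multi-dimensional targets. A hedged sketch with placeholder sizes (n_feat=256 is illustrative, not the model's actual feature width):

B, N, L, n_feat = 1, 8, 50, 256
head = MaskedTokenNetwork(n_feat)
msa_feat = torch.randn(B, N, L, n_feat)
logits = head(msa_feat)                      # (B, 21, N*L)
target = torch.randint(0, 21, (B, N * L))    # true token indices at masked positions
loss = nn.CrossEntropyLoss()(logits, target)
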
class LDDTNetwork(nn.Module):
    """Predicts per-residue lDDT as logits over n_bin_lddt bins."""
    def __init__(self, n_feat, n_bin_lddt=50):
        super(LDDTNetwork, self).__init__()
        self.proj = nn.Linear(n_feat, n_bin_lddt)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer for final logit prediction
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        logits = self.proj(x) # (B, L, 50)

        return logits.permute(0,2,1)

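One way to collapse the binned logits into a per-residue confidence scalar is an expectation over bin centers. The bin centers below assume the bins uniformly tile [0, 1]; the actual binning used in training may differ.

B, L, n_feat = 1, 50, 128                      # placeholder sizes
head = LDDTNetwork(n_feat)
state = torch.randn(B, L, n_feat)
logits = head(state)                           # (B, 50, L)
prob = torch.softmax(logits, dim=1)            # per-residue distribution over bins
bin_centers = torch.linspace(0.01, 0.99, 50)   # assumed: centers of 50 uniform bins on [0, 1]
plddt = (prob * bin_centers[None, :, None]).sum(dim=1)  # (B, L) expected lDDT
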
class ExpResolvedNetwork(nn.Module):
    """Per-residue binary head: one logit per position from MSA and state features."""
    def __init__(self, d_msa, d_state, p_drop=0.1):
        super(ExpResolvedNetwork, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj = nn.Linear(d_msa+d_state, 1)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer for final logit prediction
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, seq, state):
        B, L = seq.shape[:2]

        # normalize each feature stream, then project the concatenation to one logit per residue
        seq = self.norm_msa(seq)
        state = self.norm_state(state)
        feat = torch.cat((seq, state), dim=-1)
        logits = self.proj(feat)
        return logits.reshape(B, L)

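Since the head emits one logit per residue, it pairs naturally with a binary loss, matching the class name's "experimentally resolved" reading. A minimal sketch, assuming illustrative widths d_msa=256 and d_state=32:

B, L, d_msa, d_state = 1, 50, 256, 32
head = ExpResolvedNetwork(d_msa, d_state)
seq_feat = torch.randn(B, L, d_msa)
state = torch.randn(B, L, d_state)
logits = head(seq_feat, state)                  # (B, L)
resolved = torch.randint(0, 2, (B, L)).float()  # 1 where the residue is resolved
loss = nn.BCEWithLogitsLoss()(logits, resolved)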