Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Aug 8, 2024
1 parent c3b6d29 commit 254ea9c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:


@registry.register_model("hydra")
class HydraModel(nn.Module, GraphModelMixin):
class HydraModel(nn.Module):
def __init__(
self,
backbone: dict,
Expand Down
4 changes: 2 additions & 2 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def __init__(self, backbone):
backbone.use_grid_mlp,
backbone.use_sep_s2_act,
)
self.apply(backbone._init_weights)
self.apply(backbone.init_weights)
self.apply(backbone._uniform_init_rad_func_linear_weights)

def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
Expand Down Expand Up @@ -881,7 +881,7 @@ def __init__(self, backbone):
backbone.use_sep_s2_act,
alpha_drop=0.0,
)
self.apply(backbone._init_weights)
self.apply(backbone.init_weights)
self.apply(backbone._uniform_init_rad_func_linear_weights)

def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Rank2Block(nn.Module):
Args:
emb_size (int): Size of edge embedding used to compute outer product
num_layers (int): Number of layers of the MLP
edge_level (bool): Apply MLP to edges' outer product
edge_level (bool): If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling
extensive (bool): Whether to sum or average the outer products
"""

Expand Down Expand Up @@ -324,7 +324,7 @@ def forward(
graph.edge_distance_vec, x_edge, graph.edge_index[1], data
)

if self.extensive: # legacy, may be interesting to try
if self.block.extensive: # legacy, may be interesting to try
tensor_0 = tensor_0 / self.avg_num_nodes
tensor_2 = tensor_2 / self.avg_num_nodes

Expand Down

0 comments on commit 254ea9c

Please sign in to comment.