From 254ea9c6b642676f4a0385fb3add306eb3a0ead0 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 8 Aug 2024 12:20:18 -0700 Subject: [PATCH] small fixes --- src/fairchem/core/models/base.py | 2 +- src/fairchem/core/models/equiformer_v2/equiformer_v2.py | 4 ++-- .../core/models/equiformer_v2/prediction_heads/rank2.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 8ce8f3fcb..4936c725f 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -228,7 +228,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: @registry.register_model("hydra") -class HydraModel(nn.Module, GraphModelMixin): +class HydraModel(nn.Module): def __init__( self, backbone: dict, diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 37b5c72c0..b1c7214fa 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -836,7 +836,7 @@ def __init__(self, backbone): backbone.use_grid_mlp, backbone.use_sep_s2_act, ) - self.apply(backbone._init_weights) + self.apply(backbone.init_weights) self.apply(backbone._uniform_init_rad_func_linear_weights) def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): @@ -881,7 +881,7 @@ def __init__(self, backbone): backbone.use_sep_s2_act, alpha_drop=0.0, ) - self.apply(backbone._init_weights) + self.apply(backbone.init_weights) self.apply(backbone._uniform_init_rad_func_linear_weights) def forward(self, data: Batch, emb: dict[str, torch.Tensor]): diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index aaa39ef22..2c18e4afe 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -25,7 +25,7 @@ class Rank2Block(nn.Module): Args: emb_size (int): Size of edge embedding used to compute outer product num_layers (int): Number of layers of the MLP - edge_level (bool): Apply MLP to edges' outer product + edge_level (bool): If true apply MLP at edge level before pooling, otherwise use MLP at nodes after pooling extensive (bool): Whether to sum or average the outer products """ @@ -324,7 +324,7 @@ def forward( graph.edge_distance_vec, x_edge, graph.edge_index[1], data ) - if self.extensive: # legacy, may be interesting to try + if self.block.extensive: # legacy, may be interesting to try tensor_0 = tensor_0 / self.avg_num_nodes tensor_2 = tensor_2 / self.avg_num_nodes