From 936a7be3ceff38ab8da84696ce522c83c8b73cd1 Mon Sep 17 00:00:00 2001 From: anuroopsriram Date: Thu, 14 Sep 2023 15:36:15 -0700 Subject: [PATCH] Support for passing in stats to Equiformer V2 model (#576) Co-authored-by: Abhishek Das --- ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py b/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py index 59fabeda5..b65aa60f5 100644 --- a/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py +++ b/ocpmodels/models/equiformer_v2/equiformer_v2_oc20.py @@ -139,6 +139,8 @@ def __init__( proj_drop: float = 0.0, weight_init: str = "normal", enforce_max_neighbors_strictly: bool = True, + avg_num_nodes: Optional[float] = None, + avg_degree: Optional[float] = None, ): super().__init__() @@ -197,6 +199,9 @@ def __init__( self.drop_path_rate = drop_path_rate self.proj_drop = proj_drop + self.avg_num_nodes = avg_num_nodes or _AVG_NUM_NODES + self.avg_degree = avg_degree or _AVG_DEGREE + self.weight_init = weight_init assert self.weight_init in ["normal", "uniform"] @@ -286,7 +291,7 @@ def __init__( self.max_num_elements, self.edge_channels_list, self.block_use_atom_edge_embedding, - rescale_factor=_AVG_DEGREE, + rescale_factor=self.avg_degree, ) # Initialize the blocks for each layer of EquiformerV2 @@ -480,7 +485,7 @@ def forward(self, data): dtype=node_energy.dtype, ) energy.index_add_(0, data.batch, node_energy.view(-1)) - energy = energy / _AVG_NUM_NODES + energy = energy / self.avg_num_nodes ############################################################### # Force estimation