From f1cbbbb4a717bd977d07f2d8dfd1d09b6bd7e124 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 2 Aug 2024 14:01:11 -0600 Subject: [PATCH] fix rank2 head and add to e2e test --- .../equiformer_v2/prediction_heads/rank2.py | 43 ++++++++++--------- .../test_configs/test_equiformerv2_hydra.yml | 3 ++ 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py index d85896503..aaa39ef22 100644 --- a/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py +++ b/src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py @@ -85,10 +85,10 @@ def forward(self, edge_distance_vec, x_edge, edge_index, data): # node_outer: nAtoms, 9 => average across all atoms at the structure level if self.extensive: - stress = scatter(node_outer, data.batch, dim=0, reduce="sum") + r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="sum") else: - stress = scatter(node_outer, data.batch, dim=0, reduce="mean") - return stress + r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="mean") + return r2_tensor class Rank2DecompositionEdgeBlock(nn.Module): @@ -237,44 +237,47 @@ def __init__( backbone: BackboneInterface, output_name: str, decompose: bool = False, - use_source_target_embedding_stress: bool = False, + edge_level_mlp: bool = False, + use_source_target_embedding: bool = False, extensive: bool = False, ): """ Args: backbone: Backbone model that the head is attached to decompose: Wether to decompose the rank2 tensor into isotropic and anisotropic components - use_source_target_embedding_stress: Whether to use both source and target atom embeddings + use_source_target_embedding: Whether to use both source and target atom embeddings extensive: Whether to do sum-pooling (extensive) vs mean pooling (intensive). """ super().__init__() self.output_name = output_name self.decompose = decompose + self.use_source_target_embedding = use_source_target_embedding + self.sphharm_norm = get_normalization_layer( backbone.norm_type, lmax=max(backbone.lmax_list), num_channels=1, ) - if use_source_target_embedding_stress: - stress_sphere_channels = self.sphere_channels * 2 + if use_source_target_embedding: + r2_tensor_sphere_channels = backbone.sphere_channels * 2 else: - stress_sphere_channels = self.sphere_channels + r2_tensor_sphere_channels = backbone.sphere_channels - self.xedge_layer_norm = nn.LayerNorm(stress_sphere_channels) + self.xedge_layer_norm = nn.LayerNorm(r2_tensor_sphere_channels) if decompose: self.block = Rank2DecompositionEdgeBlock( - emb_size=stress_sphere_channels, + emb_size=r2_tensor_sphere_channels, num_layers=2, - edge_level=self.edge_level_mlp_stress, + edge_level=edge_level_mlp, extensive=extensive, ) else: self.block = Rank2Block( - emb_size=stress_sphere_channels, + emb_size=r2_tensor_sphere_channels, num_layers=2, - edge_level=self.edge_level_mlp_stress, + edge_level=edge_level_mlp, extensive=extensive, ) @@ -300,11 +303,11 @@ def forward( ).detach() # layer norm because sphharm_weights_edge values become large and causes infs with amp - sphharm_weights_edge = self.stress_sph_norm( + sphharm_weights_edge = self.sphharm_norm( sphharm_weights_edge[:, :, None] ).squeeze() - if self.use_source_target_embedding_stress: + if self.use_source_target_embedding: x_source = node_emb.expand_edge(graph.edge_index[0]).embedding x_target = node_emb.expand_edge(graph.edge_index[1]).embedding x_edge = torch.cat((x_source, x_target), dim=2) @@ -314,10 +317,10 @@ def forward( x_edge = torch.einsum("abc, ab->ac", x_edge, sphharm_weights_edge) # layer norm because x_edge values become large and causes infs with amp - x_edge = self.stress_xedge_layer_norm(x_edge) + x_edge = self.xedge_layer_norm(x_edge) - if self.decompose_stress: - tensor_0, tensor_2 = self.stress_block( + if self.decompose: + tensor_0, tensor_2 = self.block( graph.edge_distance_vec, x_edge, graph.edge_index[1], data ) @@ -330,9 +333,9 @@ def forward( f"{self.output_name}_anisotropic": tensor_2, } else: - stress = self.stress_block( + out_tensor = self.block( graph.edge_distance_vec, x_edge, graph.edge_index[1], data ) - output = {self.output_name: stress.reshape((-1, 3))} + output = {self.output_name: out_tensor.reshape((-1, 3))} return output diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml index 4c00fe6a2..dbe3e4899 100644 --- a/tests/core/models/test_configs/test_equiformerv2_hydra.yml +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -52,6 +52,9 @@ model: module: equiformer_v2_energy_head forces: module: equiformer_v2_force_head + stress: + module: rank2_symmetric_head + output_name: "stress" dataset: train: