Skip to content

Commit

Permalink
fix rank2 head and add to e2e test
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Aug 2, 2024
1 parent 0a1faa6 commit f1cbbbb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
43 changes: 23 additions & 20 deletions src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def forward(self, edge_distance_vec, x_edge, edge_index, data):

# node_outer: nAtoms, 9 => average across all atoms at the structure level
if self.extensive:
stress = scatter(node_outer, data.batch, dim=0, reduce="sum")
r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="sum")
else:
stress = scatter(node_outer, data.batch, dim=0, reduce="mean")
return stress
r2_tensor = scatter(node_outer, data.batch, dim=0, reduce="mean")
return r2_tensor


class Rank2DecompositionEdgeBlock(nn.Module):
Expand Down Expand Up @@ -237,44 +237,47 @@ def __init__(
backbone: BackboneInterface,
output_name: str,
decompose: bool = False,
use_source_target_embedding_stress: bool = False,
edge_level_mlp: bool = False,
use_source_target_embedding: bool = False,
extensive: bool = False,
):
"""
Args:
backbone: Backbone model that the head is attached to
decompose: Wether to decompose the rank2 tensor into isotropic and anisotropic components
use_source_target_embedding_stress: Whether to use both source and target atom embeddings
use_source_target_embedding: Whether to use both source and target atom embeddings
extensive: Whether to do sum-pooling (extensive) vs mean pooling (intensive).
"""
super().__init__()
self.output_name = output_name
self.decompose = decompose
self.use_source_target_embedding = use_source_target_embedding

self.sphharm_norm = get_normalization_layer(
backbone.norm_type,
lmax=max(backbone.lmax_list),
num_channels=1,
)

if use_source_target_embedding_stress:
stress_sphere_channels = self.sphere_channels * 2
if use_source_target_embedding:
r2_tensor_sphere_channels = backbone.sphere_channels * 2
else:
stress_sphere_channels = self.sphere_channels
r2_tensor_sphere_channels = backbone.sphere_channels

self.xedge_layer_norm = nn.LayerNorm(stress_sphere_channels)
self.xedge_layer_norm = nn.LayerNorm(r2_tensor_sphere_channels)

if decompose:
self.block = Rank2DecompositionEdgeBlock(
emb_size=stress_sphere_channels,
emb_size=r2_tensor_sphere_channels,
num_layers=2,
edge_level=self.edge_level_mlp_stress,
edge_level=edge_level_mlp,
extensive=extensive,
)
else:
self.block = Rank2Block(
emb_size=stress_sphere_channels,
emb_size=r2_tensor_sphere_channels,
num_layers=2,
edge_level=self.edge_level_mlp_stress,
edge_level=edge_level_mlp,
extensive=extensive,
)

Expand All @@ -300,11 +303,11 @@ def forward(
).detach()

# layer norm because sphharm_weights_edge values become large and causes infs with amp
sphharm_weights_edge = self.stress_sph_norm(
sphharm_weights_edge = self.sphharm_norm(
sphharm_weights_edge[:, :, None]
).squeeze()

if self.use_source_target_embedding_stress:
if self.use_source_target_embedding:
x_source = node_emb.expand_edge(graph.edge_index[0]).embedding
x_target = node_emb.expand_edge(graph.edge_index[1]).embedding
x_edge = torch.cat((x_source, x_target), dim=2)
Expand All @@ -314,10 +317,10 @@ def forward(
x_edge = torch.einsum("abc, ab->ac", x_edge, sphharm_weights_edge)

# layer norm because x_edge values become large and causes infs with amp
x_edge = self.stress_xedge_layer_norm(x_edge)
x_edge = self.xedge_layer_norm(x_edge)

if self.decompose_stress:
tensor_0, tensor_2 = self.stress_block(
if self.decompose:
tensor_0, tensor_2 = self.block(
graph.edge_distance_vec, x_edge, graph.edge_index[1], data
)

Expand All @@ -330,9 +333,9 @@ def forward(
f"{self.output_name}_anisotropic": tensor_2,
}
else:
stress = self.stress_block(
out_tensor = self.block(
graph.edge_distance_vec, x_edge, graph.edge_index[1], data
)
output = {self.output_name: stress.reshape((-1, 3))}
output = {self.output_name: out_tensor.reshape((-1, 3))}

return output
3 changes: 3 additions & 0 deletions tests/core/models/test_configs/test_equiformerv2_hydra.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ model:
module: equiformer_v2_energy_head
forces:
module: equiformer_v2_force_head
stress:
module: rank2_symmetric_head
output_name: "stress"

dataset:
train:
Expand Down

0 comments on commit f1cbbbb

Please sign in to comment.