diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index c113cc1c8..26e4b86bc 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -259,9 +259,7 @@ def forward(self, data): ) # Initialize the WignerD matrices and other values for spherical harmonic calculations - self.SO3_edge_rot = nn.ModuleList() - for i in range(self.num_resolutions): - self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) + self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0]) ############################################################### # Initialize node embeddings @@ -290,11 +288,11 @@ def forward(self, data): ############################################################### # Update spherical node embeddings ############################################################### - + x_message = x.embedding for i in range(self.num_layers): if i > 0: - x_message = self.layer_blocks[i]( - x, + x_message_new = self.layer_blocks[i]( + x_message, atomic_numbers, graph.edge_distance, graph.edge_index, @@ -302,17 +300,18 @@ def forward(self, data): ) # Residual layer for all layers past the first - x.embedding = x.embedding + x_message.embedding + x_message = x_message + x_message_new else: # No residual for the first layer - x = self.layer_blocks[i]( - x, + x_message = self.layer_blocks[i]( + x_message, atomic_numbers, graph.edge_distance, graph.edge_index, self.SO3_edge_rot, ) + x.embedding = x_message # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. # These values are fed into the output blocks. 
@@ -448,9 +447,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ) # Initialize the WignerD matrices and other values for spherical harmonic calculations - self.SO3_edge_rot = nn.ModuleList() - for i in range(self.num_resolutions): - self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) + self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0]) ############################################################### # Initialize node embeddings @@ -640,12 +637,12 @@ def __init__( def forward( self, - x, - atomic_numbers, - edge_distance, - edge_index, - SO3_edge_rot, - ): + x: torch.Tensor, + atomic_numbers: torch.Tensor, + edge_distance: torch.Tensor, + edge_index: torch.Tensor, + SO3_edge_rot: SO3_Rotation, + ) -> torch.Tensor: # Compute messages by performing message block x_message = self.message_block( x, @@ -654,25 +651,33 @@ def forward( edge_index, SO3_edge_rot, ) + print(f"x_message: {x_message.mean()}") # Compute point-wise spherical non-linearity on aggregated messages - max_lmax = max(self.lmax_list) # Project to grid - x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) - x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) + # x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) + to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message) + + # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x) x_grid = torch.cat([x_grid, x_grid_message], dim=3) # Perform point-wise convolution x_grid = self.act(self.fc1_sphere(x_grid)) x_grid = self.act(self.fc2_sphere(x_grid)) x_grid = self.fc3_sphere(x_grid) + print(f"x_grid: {x_grid.mean()}") # Project back to spherical harmonic coefficients - x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) + # x_message._from_grid(x_grid, 
self.SO3_grid["lmax_lmax"]) + from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + print(f"x_message_final: {x_message_final.mean()}") # Return aggregated messages - return x_message + return x_message_final class MessageBlock(torch.nn.Module): @@ -748,12 +753,12 @@ def __init__( def forward( self, - x, - atomic_numbers, - edge_distance, - edge_index, - SO3_edge_rot, - ): + x: torch.Tensor, + atomic_numbers: torch.Tensor, + edge_distance: torch.Tensor, + edge_index: torch.Tensor, + SO3_edge_rot: SO3_Rotation, + ) -> torch.Tensor: ############################################################### # Compute messages ############################################################### @@ -768,30 +773,36 @@ def forward( # Copy embeddings for each edge's source and target nodes x_source = x.clone() x_target = x.clone() - x_source._expand_edge(edge_index[0, :]) - x_target._expand_edge(edge_index[1, :]) + x_source = x_source[edge_index[0, :]] + x_target = x_target[edge_index[1, :]] # Rotate the irreps to align with the edge - x_source._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) - x_target._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) + x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) + x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute messages - x_source.embedding = self.so2_block_source(x_source.embedding, x_edge) - x_target.embedding = self.so2_block_target(x_target.embedding, x_edge) + x_source = self.so2_block_source(x_source, x_edge) + x_target = self.so2_block_target(x_target, x_edge) # Add together the source and target results - x_target.embedding = x_source.embedding + x_target.embedding + x_target = x_source + x_target # Point-wise spherical non-linearity - x_target._grid_act(self.SO3_grid["lmax_mmax"], self.act, 
self.mappingReduced.res_size) + to_grid_mat = self.SO3_grid["lmax_mmax"].get_to_grid_mat() + from_grid_mat = self.SO3_grid["lmax_mmax"].get_from_grid_mat() + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_target) + x_grid = self.act(x_grid) + x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Rotate back the irreps - x_target._rotate_inv(SO3_edge_rot, self.mappingReduced.res_size) + x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute the sum of the incoming neighboring messages for each target node - x_target._reduce_edge(edge_index[1], len(x.embedding)) + new_embedding = torch.fill(x.clone(), 0) + new_embedding.index_add_(0, edge_index[1], x_target) + # x_target._reduce_edge(edge_index[1], len(x.embedding)) - return x_target + return new_embedding class SO2Block(torch.nn.Module): diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 2caa7acc8..817a5e674 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -93,7 +93,7 @@ def init(backend: str): init_local_distributed_process_group(backend=backend) class TestESCNCompiles: - def test_escn_baseline_cpu(self): + def test_escn_baseline_cpu(self, tol=1e-5): init('gloo') data = load_data() data = data_list_collater([data]) @@ -102,11 +102,11 @@ def test_escn_baseline_cpu(self): output = ddp_model(data) expected_energy, expected_forces = expected_energy_forces() torch.set_printoptions(precision=8) - assert torch.allclose(output["energy"], expected_energy) - assert torch.allclose(output["forces"].mean(0), expected_forces) + assert torch.allclose(output["energy"], expected_energy, atol=tol) + assert torch.allclose(output["forces"].mean(0), expected_forces, atol=tol) @skip_if_no_cuda - def test_escn_baseline_cuda(self): + def test_escn_baseline_cuda(self, tol=1e-5): init('nccl') data = load_data() data = data_list_collater([data]).to("cuda") @@ -115,8 +115,8 @@ def 
test_escn_baseline_cuda(self): output = ddp_model(data) expected_energy, expected_forces = expected_energy_forces_cuda() torch.set_printoptions(precision=8) - assert torch.allclose(output["energy"].cpu(), expected_energy) - assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces) + assert torch.allclose(output["energy"].cpu(), expected_energy, atol=tol) + assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces, atol=tol) def test_escn_compiles(self): init("gloo")