diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 0a6c7064d..bedbf12c5 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -173,14 +173,18 @@ def __init__( ) # Initialize the transformations between spherical and grid representations - self.SO3_grid = nn.ModuleList() - for lval in range(max(self.lmax_list) + 1): - SO3_m_grid = nn.ModuleList() - for m in range(max(self.lmax_list) + 1): - SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) - - self.SO3_grid.append(SO3_m_grid) - + assert self.num_resolutions == 1, "Only one resolution is supported" + self.SO3_grid = nn.ModuleDict() + self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) + self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) + # self.SO3_grid = nn.ModuleList() + # for lval in range(max(self.lmax_list) + 1): + # SO3_m_grid = nn.ModuleList() + # for m in range(max(self.lmax_list) + 1): + # SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) + + # self.SO3_grid.append(SO3_m_grid) + # import pdb;pdb.set_trace() self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list) # Initialize the blocks for each layer of the GNN @@ -655,8 +659,8 @@ def forward( max_lmax = max(self.lmax_list) # Project to grid - x_grid_message = x_message.to_grid(self.SO3_grid, lmax=max_lmax) - x_grid = x.to_grid(self.SO3_grid, lmax=max_lmax) + x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) + x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) x_grid = torch.cat([x_grid, x_grid_message], dim=3) # Perform point-wise convolution @@ -665,7 +669,7 @@ def forward( x_grid = self.fc3_sphere(x_grid) # Project back to spherical harmonic coefficients - x_message._from_grid(x_grid, self.SO3_grid, lmax=max_lmax) + x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) # Return aggregated messages return x_message @@ -780,7 +784,7 @@ def forward( x_target.embedding = x_source.embedding + x_target.embedding # Point-wise spherical non-linearity - x_target._grid_act(self.SO3_grid, self.act, self.mappingReduced.res_size) + x_target._grid_act(self.SO3_grid["lmax_mmax"], self.act, self.mappingReduced.res_size) # Rotate back the irreps x_target._rotate_inv(SO3_edge_rot, self.mappingReduced.res_size) diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index cb0f9d223..0c674e288 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -274,12 +274,8 @@ def _grid_act(self, SO3_grid, act, res_size) -> None: num_coefficients = res_size[i] x_res = self.embedding[:, offset : offset + num_coefficients].contiguous() - to_grid_mat = SO3_grid[self.lmax_list[i]][ - self.mmax_list[i] - ].get_to_grid_mat(self.device) - from_grid_mat = SO3_grid[self.lmax_list[i]][ - self.mmax_list[i] - ].get_from_grid_mat(self.device) + to_grid_mat = SO3_grid.get_to_grid_mat(self.device) + from_grid_mat = SO3_grid.get_from_grid_mat(self.device) x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_res) x_grid = act(x_grid) @@ -293,8 +289,8 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor: if lmax == -1: lmax = max(self.lmax_list) - to_grid_mat_lmax = SO3_grid[lmax][lmax].get_to_grid_mat(self.device) - grid_mapping = SO3_grid[lmax][lmax].mapping + to_grid_mat_lmax = SO3_grid.get_to_grid_mat(self.device) + grid_mapping = SO3_grid.mapping offset = 0 x_grid = torch.tensor([], device=self.device) @@ -316,12 +312,9 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor: return x_grid # Compute irreps from grid representation - def _from_grid(self, x_grid, SO3_grid, lmax: int = -1) -> None: - if lmax == -1: - lmax = max(self.lmax_list) - - from_grid_mat_lmax = SO3_grid[lmax][lmax].get_from_grid_mat(self.device) - grid_mapping = SO3_grid[lmax][lmax].mapping + def _from_grid(self, x_grid, SO3_grid) -> None: + from_grid_mat_lmax = SO3_grid.get_from_grid_mat(self.device) + grid_mapping = SO3_grid.mapping offset = 0 offset_channel = 0 diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 89ba8502a..e6d0e0d8f 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -118,7 +118,10 @@ def test_escn_compiles(self): torch._dynamo.config.optimize_ddp = False # torch._dynamo.explain(model)(data) # assert False - compiled_model(data) + output = compiled_model(data) + expected_energy, expected_forces = expected_energy_forces() + assert torch.allclose(output["energy"], expected_energy) + assert torch.allclose(output["forces"].mean(0), expected_forces) def test_rotation_invariance(self) -> None: random.seed(1)