diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py
index b13c8e5cd..97702873a 100644
--- a/src/fairchem/core/models/escn/escn_exportable.py
+++ b/src/fairchem/core/models/escn/escn_exportable.py
@@ -500,7 +500,6 @@ def forward(
             edge_index,
             wigner,
         )
-        print(f"x_message: {x_message.mean()}")

         # Compute point-wise spherical non-linearity on aggregated messages
@@ -517,14 +516,12 @@ def forward(
         x_grid = self.act(self.fc1_sphere(x_grid))
         x_grid = self.act(self.fc2_sphere(x_grid))
         x_grid = self.fc3_sphere(x_grid)
-        print(f"x_grid: {x_grid.mean()}")

         # Project back to spherical harmonic coefficients
         # x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"])
         from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])]
         x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)
-        print(f"x_message_final: {x_message_final.mean()}")

         # Return aggregated messages
         return x_message_final
diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index c0458fbb9..8e53eb702 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -226,10 +226,6 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
         torch._dynamo.config.assume_static_by_default = False
         torch._dynamo.config.automatic_dynamic_shapes = True
         torch._dynamo.config.verbose = True
-        # torch._logging.set_logs(dynamo = logging.INFO)
-        # torch._dynamo.reset()
-        # explain_output = torch._dynamo.explain(message_block)(*args[0])
-        # print(explain_output)

         compiled_model = torch.compile(message_block, dynamic=True)
         compiled_output = compiled_model(*args[0])
@@ -240,6 +236,72 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
         assert torch.allclose(compiled_output, regular_out, atol=tol)
         assert torch.allclose(exported_output, regular_out, atol=tol)

+    def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None:
+        random.seed(1)
+
+        sphere_channels = 128
+        hidden_channels = 128
+        edge_channels = 128
+        lmax, mmax = 4, 2
+        distance_expansion = GaussianSmearing(0.0, 8.0, int(8.0 / 0.02), 1.0)
+        SO3_grid = torch.nn.ModuleDict()
+        SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax)
+        SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax)
+        mappingReduced = CoefficientMapping([lmax], [mmax])
+        layer_block = escn_exportable.LayerBlock(
+            layer_idx = 0,
+            sphere_channels = sphere_channels,
+            hidden_channels = hidden_channels,
+            edge_channels = edge_channels,
+            lmax_list = [lmax],
+            mmax_list = [mmax],
+            distance_expansion = distance_expansion,
+            max_num_elements = 90,
+            SO3_grid = SO3_grid,
+            act = torch.nn.SiLU(),
+            mappingReduced = mappingReduced
+        )
+
+        # generate inputs
+        batch_sizes = [34, 35, 35]
+        num_edges = [680, 700, 680]
+        num_coefs = 25
+        run_args = []
+        for b,edges in zip(batch_sizes, num_edges):
+            x = torch.rand([b, num_coefs, sphere_channels])
+            atom_n = torch.randint(1, 90, (b,))
+            edge_d = torch.rand([edges])
+            edge_indx = torch.randint(0, b, (2, edges))
+            wigner = torch.rand([edges, num_coefs, num_coefs])
+            run_args.append((x, atom_n, edge_d, edge_indx, wigner))
+
+        torch._dynamo.config.optimize_ddp = False
+        torch._dynamo.config.assume_static_by_default = False
+        torch._dynamo.config.automatic_dynamic_shapes = True
+        torch._dynamo.config.verbose = True
+        # torch._logging.set_logs(dynamo = logging.INFO)
+        # torch._dynamo.reset()
+        # explain_output = torch._dynamo.explain(message_block)(*args[0])
+        # print(explain_output)
+
+        batch_dim = Dim("batch_dim")
+        edges_dim = Dim("edges_dim")
+        dynamic_shapes1 = {
+            "x": {0: batch_dim, 1: None, 2: None},
+            "atomic_numbers": {0: batch_dim},
+            "edge_distance": {0: edges_dim},
+            "edge_index": {0: None, 1: edges_dim},
+            "wigner": {0: edges_dim, 1: None, 2: None}
+        }
+        exported_prog = export(layer_block, args=run_args[0], dynamic_shapes=dynamic_shapes1)
+        for run_arg in run_args:
+            exported_output = exported_prog(*run_arg)
+            compiled_model = torch.compile(layer_block, dynamic=True)
+            compiled_output = compiled_model(*run_arg)
+            regular_out = layer_block(*run_arg)
+            assert torch.allclose(compiled_output, regular_out, atol=tol)
+            assert torch.allclose(exported_output, regular_out, atol=tol)
+
     def test_escn_compiles(self):
         init("gloo")
         data = load_data()