diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 7b05278f2..c113cc1c8 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -757,7 +757,6 @@ def forward( ############################################################### # Compute messages ############################################################### - # Compute edge scalar features (invariant to rotations) # Uses atomic numbers and edge distance as inputs x_edge = self.edge_block( diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 6a2763266..cd339b887 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -33,6 +33,9 @@ from torch.export import export from torch.export import Dim +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="skipping when no gpu") + + def load_data(): atoms = read( os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), @@ -86,7 +89,7 @@ def init(backend="nccl"): class TestESCNCompiles: def test_escn_baseline_cpu(self): - init("gloo") + init('gloo') data = load_data() data = data_list_collater([data]) model = load_model() @@ -97,6 +100,19 @@ def test_escn_baseline_cpu(self): assert torch.allclose(output["energy"], expected_energy) assert torch.allclose(output["forces"].mean(0), expected_forces) + @skip_if_no_cuda + def test_escn_baseline_cuda(self): + init('nccl') + data = load_data() + data = data_list_collater([data]).to("cuda") + model = load_model().cuda() + ddp_model = DistributedDataParallel(model) + output = ddp_model(data) + expected_energy, expected_forces = expected_energy_forces() + torch.set_printoptions(precision=8) + assert torch.allclose(output["energy"].cpu(), expected_energy) + assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces) + def test_escn_compiles(self): init("gloo") data = load_data()