diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index cd339b887..2caa7acc8 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -83,7 +83,12 @@ def expected_energy_forces(): forces = torch.tensor([-1.2720219900e-07, 8.2126695133e-07, -3.8776403244e-07]) return energy, forces -def init(backend="nccl"): +def expected_energy_forces_cuda(): + energy = torch.tensor([0.0001747000]) + forces = torch.tensor([-4.5273015559e-08, 9.0246174977e-07, -3.8560736471e-07]) + return energy, forces + +def init(backend: str): if not torch.distributed.is_initialized(): init_local_distributed_process_group(backend=backend) @@ -108,7 +113,7 @@ def test_escn_baseline_cuda(self): model = load_model().cuda() ddp_model = DistributedDataParallel(model) output = ddp_model(data) - expected_energy, expected_forces = expected_energy_forces() + expected_energy, expected_forces = expected_energy_forces_cuda() torch.set_printoptions(precision=8) assert torch.allclose(output["energy"].cpu(), expected_energy) assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces)