Skip to content

Commit

Permalink
pass cuda test
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 29, 2024
1 parent d6ac5b7 commit 4017073
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ def expected_energy_forces():
forces = torch.tensor([-1.2720219900e-07, 8.2126695133e-07, -3.8776403244e-07])
return energy, forces

def init(backend="nccl"):
def expected_energy_forces_cuda():
energy = torch.tensor([0.0001747000])
forces = torch.tensor([-4.5273015559e-08, 9.0246174977e-07, -3.8560736471e-07])
return energy, forces

def init(backend: str):
if not torch.distributed.is_initialized():
init_local_distributed_process_group(backend=backend)

Expand All @@ -108,7 +113,7 @@ def test_escn_baseline_cuda(self):
model = load_model().cuda()
ddp_model = DistributedDataParallel(model)
output = ddp_model(data)
expected_energy, expected_forces = expected_energy_forces()
expected_energy, expected_forces = expected_energy_forces_cuda()
torch.set_printoptions(precision=8)
assert torch.allclose(output["energy"].cpu(), expected_energy)
assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces)
Expand Down

0 comments on commit 4017073

Please sign in to comment.