Skip to content

Commit

Permalink
add gpu test
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 29, 2024
1 parent 672e1ff commit d6ac5b7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,6 @@ def forward(
###############################################################
# Compute messages
###############################################################

# Compute edge scalar features (invariant to rotations)
# Uses atomic numbers and edge distance as inputs
x_edge = self.edge_block(
Expand Down
18 changes: 17 additions & 1 deletion tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from torch.export import export
from torch.export import Dim

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="skipping when no gpu")


def load_data():
atoms = read(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
Expand Down Expand Up @@ -86,7 +89,7 @@ def init(backend="nccl"):

class TestESCNCompiles:
def test_escn_baseline_cpu(self):
init("gloo")
init('gloo')
data = load_data()
data = data_list_collater([data])
model = load_model()
Expand All @@ -97,6 +100,19 @@ def test_escn_baseline_cpu(self):
assert torch.allclose(output["energy"], expected_energy)
assert torch.allclose(output["forces"].mean(0), expected_forces)

@skip_if_no_cuda
def test_escn_baseline_cuda(self):
init('nccl')
data = load_data()
data = data_list_collater([data]).to("cuda")
model = load_model().cuda()
ddp_model = DistributedDataParallel(model)
output = ddp_model(data)
expected_energy, expected_forces = expected_energy_forces()
torch.set_printoptions(precision=8)
assert torch.allclose(output["energy"].cpu(), expected_energy)
assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces)

def test_escn_compiles(self):
init("gloo")
data = load_data()
Expand Down

0 comments on commit d6ac5b7

Please sign in to comment.