Skip to content

Commit

Permalink
Add check to max num atoms (#817)
Browse files Browse the repository at this point in the history
* add assert for max_num_elements

* add test to make sure we are properly checking for max_num_elements

* fix post merge
  • Loading branch information
misko authored Aug 22, 2024
1 parent df89330 commit 1bee0d7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
self.dtype = data.pos.dtype
self.device = data.pos.device
atomic_numbers = data.atomic_numbers.long()
assert (
atomic_numbers.max().item() < self.max_num_elements
), "Atomic number exceeds that given in model config"
graph = self.generate_graph(
data,
enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly,
Expand Down
3 changes: 3 additions & 0 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ def forward(self, data):

start_time = time.time()
atomic_numbers = data.atomic_numbers.long()
assert (
atomic_numbers.max().item() < self.max_num_elements
), "Atomic number exceeds that given in model config"
num_atoms = len(atomic_numbers)
graph = self.generate_graph(data)

Expand Down
20 changes: 20 additions & 0 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,26 @@ def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
input_yaml=configs["equiformer_v2"],
)

def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic):
    """A model configured with too few elements must refuse to train.

    Configures the equiformer_v2 backbone with ``max_num_elements=2``,
    which is smaller than the atomic numbers present in the tutorial
    dataset, and verifies that the model-side assertion aborts the run.
    """
    with tempfile.TemporaryDirectory() as scratch_name:
        scratch_dir = Path(scratch_name)
        # The whole run setup stays inside pytest.raises: the
        # AssertionError is expected to surface from the training call.
        with pytest.raises(AssertionError):
            _run_main(
                rundir=str(scratch_dir),
                update_dict_with={
                    "optim": {"max_epochs": 1},
                    # Deliberately undersized element table to trip the check.
                    "model": {"backbone": {"max_num_elements": 2}},
                    "dataset": oc20_lmdb_train_and_val_from_paths(
                        train_src=str(tutorial_val_src),
                        val_src=str(tutorial_val_src),
                        test_src=str(tutorial_val_src),
                    ),
                },
                update_run_args_with={"seed": 0},
                input_yaml=configs["equiformer_v2_hydra"],
            )

@pytest.mark.parametrize(
("world_size", "ddp"),
[
Expand Down

0 comments on commit 1bee0d7

Please sign in to comment.