Merge branch 'main' into rgao_add_mean_pool_eqv2
rayg1234 authored Aug 22, 2024
2 parents 37e061e + 1bee0d7 commit 6274d5f
Showing 4 changed files with 28 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/fairchem/core/common/flags.py
@@ -8,6 +8,7 @@
 from __future__ import annotations
 
 import argparse
+import os
 from pathlib import Path

@@ -48,7 +49,7 @@ def add_core_args(self) -> None:
         )
         self.parser.add_argument(
             "--run-dir",
-            default="./",
+            default=os.path.abspath("./"),
             type=str,
             help="Directory to store checkpoint/log/result directory",
         )
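Resolving the default eagerly matters because a relative path is re-interpreted against whatever the current working directory happens to be at the time of use. A minimal sketch of the difference (hypothetical, not part of the commit):

    import os
    import tempfile

    run_dir = "./"                        # old default: meaning depends on the CWD
    abs_run_dir = os.path.abspath("./")   # new default: pinned at argument-parsing time

    os.chdir(tempfile.gettempdir())       # e.g. a launcher later changes directory

    print(os.path.abspath(run_dir))       # now resolves under the temp dir
    print(abs_run_dir)                    # still the directory the job was launched from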
3 changes: 3 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -397,6 +397,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
         self.dtype = data.pos.dtype
         self.device = data.pos.device
         atomic_numbers = data.atomic_numbers.long()
+        assert (
+            atomic_numbers.max().item() < self.max_num_elements
+        ), "Atomic number exceeds that given in model config"
         graph = self.generate_graph(
             data,
             enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly,
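The guard uses a strict inequality because atomic numbers serve as indices into tables sized max_num_elements, so the valid range is 0 to max_num_elements - 1. A hypothetical sketch of the failure mode the assert preempts, assuming the backbone embeds atomic numbers with a table of that size:

    import torch
    import torch.nn as nn

    max_num_elements = 90
    embedding = nn.Embedding(max_num_elements, 128)  # valid indices: 0..89

    atomic_numbers = torch.tensor([1, 8, 92])        # uranium (Z=92) is out of range

    # Without the guard, the lookup fails with an opaque "index out of range"
    # deep inside torch; the assert names the actual misconfiguration instead.
    assert (
        atomic_numbers.max().item() < max_num_elements
    ), "Atomic number exceeds that given in model config"
    out = embedding(atomic_numbers)  # never reached when the assert fires

The same check is added to the eSCN forward pass below.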
3 changes: 3 additions & 0 deletions src/fairchem/core/models/escn/escn.py
@@ -235,6 +235,9 @@ def forward(self, data):
 
         start_time = time.time()
         atomic_numbers = data.atomic_numbers.long()
+        assert (
+            atomic_numbers.max().item() < self.max_num_elements
+        ), "Atomic number exceeds that given in model config"
         num_atoms = len(atomic_numbers)
         graph = self.generate_graph(data)
20 changes: 20 additions & 0 deletions tests/core/e2e/test_s2ef.py
@@ -170,6 +170,26 @@ def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
             input_yaml=configs["equiformer_v2"],
         )
 
+    def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic):
+        with tempfile.TemporaryDirectory() as tempdirname:
+            tempdir = Path(tempdirname)
+            extra_args = {"seed": 0}
+            with pytest.raises(AssertionError):
+                _ = _run_main(
+                    rundir=str(tempdir),
+                    update_dict_with={
+                        "optim": {"max_epochs": 1},
+                        "model": {"backbone": {"max_num_elements": 2}},
+                        "dataset": oc20_lmdb_train_and_val_from_paths(
+                            train_src=str(tutorial_val_src),
+                            val_src=str(tutorial_val_src),
+                            test_src=str(tutorial_val_src),
+                        ),
+                    },
+                    update_run_args_with=extra_args,
+                    input_yaml=configs["equiformer_v2_hydra"],
+                )
+
     @pytest.mark.parametrize(
         ("world_size", "ddp"),
         [
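The test exercises the new assert end-to-end: with max_num_elements forced down to 2 in the backbone config, any real structure in the tutorial dataset contains atomic numbers at or above that limit, so the forward pass must raise. Stripped of the training harness, the pattern reduces to the following hypothetical standalone sketch (forward_guard is an illustrative stand-in, not a fairchem function):

    import pytest

    def forward_guard(max_z: int, max_num_elements: int) -> None:
        # Mirrors the check added to the EquiformerV2 and eSCN forward passes.
        assert max_z < max_num_elements, "Atomic number exceeds that given in model config"

    # pytest.raises succeeds only if the block raises the expected exception.
    with pytest.raises(AssertionError):
        forward_guard(max_z=8, max_num_elements=2)  # oxygen vs. a 2-element limit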
