From df8933067743f46cc1c5aae176e3474007080527 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Wed, 21 Aug 2024 16:04:19 -0700 Subject: [PATCH 1/2] update to use abs run_dir paths by default (#820) --- src/fairchem/core/common/flags.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fairchem/core/common/flags.py b/src/fairchem/core/common/flags.py index ac4bd2f84a..266e1e640b 100644 --- a/src/fairchem/core/common/flags.py +++ b/src/fairchem/core/common/flags.py @@ -8,6 +8,7 @@ from __future__ import annotations import argparse +import os from pathlib import Path @@ -48,7 +49,7 @@ def add_core_args(self) -> None: ) self.parser.add_argument( "--run-dir", - default="./", + default=os.path.abspath("./"), type=str, help="Directory to store checkpoint/log/result directory", ) From 1bee0d71cc96e15293c49382c6ad80840ca5dc57 Mon Sep 17 00:00:00 2001 From: Misko Date: Wed, 21 Aug 2024 17:07:47 -0700 Subject: [PATCH 2/2] Add check to max num atoms (#817) * add assert for max_num_atoms * add test to make sure we are properly checking for max_num_elements * fix post merge --- .../models/equiformer_v2/equiformer_v2.py | 3 +++ src/fairchem/core/models/escn/escn.py | 3 +++ tests/core/e2e/test_s2ef.py | 20 +++++++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 978d4c226d..61b62be16f 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -397,6 +397,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: self.dtype = data.pos.dtype self.device = data.pos.device atomic_numbers = data.atomic_numbers.long() + assert ( + atomic_numbers.max().item() < self.max_num_elements + ), "Atomic number exceeds that given in model config" graph = self.generate_graph( data, enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 54b1992f4a..6eb95947ae 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -235,6 +235,9 @@ def forward(self, data): start_time = time.time() atomic_numbers = data.atomic_numbers.long() + assert ( + atomic_numbers.max().item() < self.max_num_elements + ), "Atomic number exceeds that given in model config" num_atoms = len(atomic_numbers) graph = self.generate_graph(data) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 6b83749c0d..2f7dfa3730 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -170,6 +170,26 @@ def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic): input_yaml=configs["equiformer_v2"], ) + def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + with pytest.raises(AssertionError): + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"backbone": {"max_num_elements": 2}}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2_hydra"], + ) + @pytest.mark.parametrize( ("world_size", "ddp"), [