diff --git a/src/fairchem/core/common/relaxation/ase_utils.py b/src/fairchem/core/common/relaxation/ase_utils.py index 5a9302d88..d527eb573 100644 --- a/src/fairchem/core/common/relaxation/ase_utils.py +++ b/src/fairchem/core/common/relaxation/ase_utils.py @@ -124,6 +124,7 @@ def __init__( trainer: str | None = None, cpu: bool = True, seed: int | None = None, + only_output: list[str] | None = None, ) -> None: """ OCP-ASE Calculator @@ -209,6 +210,20 @@ def __init__( self.config["checkpoint"] = str(checkpoint_path) del config["dataset"]["src"] + # some models that are published have configs that include tasks + # which are not output by the model + if only_output is not None: + assert isinstance( + only_output, list + ), "only output must be a list of targets to output" + for key in only_output: + assert ( + key in config["outputs"] + ), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}" + remove_outputs = set(config["outputs"].keys()) - set(only_output) + for key in remove_outputs: + config["outputs"].pop(key) + self.trainer = registry.get_trainer_class(config["trainer"])( task=config.get("task", {}), model=config["model"], diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 8e5d17820..caa2f5651 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -301,7 +301,8 @@ def _forward(self, batch): ) else: raise AttributeError( - f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}" + f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n" + + "If this is being called from OCPCalculator consider using only_output=[..]" ) ### not all models are consistent with the output shape diff --git a/tests/core/common/__snapshots__/test_ase_calculator.ambr b/tests/core/common/__snapshots__/test_ase_calculator.ambr index 71435c0eb..49d8fb3e0 100644 --- a/tests/core/common/__snapshots__/test_ase_calculator.ambr +++ b/tests/core/common/__snapshots__/test_ase_calculator.ambr @@ -1,4 +1,7 @@ # serializer version: 1 +# name: test_energy_with_is2re_model + 1.09 +# --- # name: test_relaxation_final_energy 0.92 # --- diff --git a/tests/core/common/test_ase_calculator.py b/tests/core/common/test_ase_calculator.py index 92baa37cb..e126e68d5 100644 --- a/tests/core/common/test_ase_calculator.py +++ b/tests/core/common/test_ase_calculator.py @@ -38,6 +38,7 @@ def atoms() -> Atoms: "PaiNN-S2EF-OC20-All", "GemNet-OC-Large-S2EF-OC20-All+MD", "SCN-S2EF-OC20-All+MD", + "PaiNN-IS2RE-OC20-All", # Equiformer v2 # already tested in test_relaxation_final_energy # "EquiformerV2-153M-S2EF-OC20-All+MD" # eSCNm # already tested in test_random_seed_final_energy @@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path): _ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True) +def test_energy_with_is2re_model(atoms, tmp_path, snapshot): + random.seed(1) + torch.manual_seed(1) + + with pytest.raises(AttributeError): # noqa + calc = OCPCalculator( + checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path), + cpu=True, + ) + atoms.set_calculator(calc) + atoms.get_potential_energy() + + calc = OCPCalculator( + checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path), + cpu=True, + only_output=["energy"], + ) + atoms.set_calculator(calc) + assert snapshot == round(atoms.get_potential_energy(), 2) + + # test relaxation with EqV2 def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None: random.seed(1)