diff --git a/src/fairchem/core/common/relaxation/ase_utils.py b/src/fairchem/core/common/relaxation/ase_utils.py index 5a9302d88..d527eb573 100644 --- a/src/fairchem/core/common/relaxation/ase_utils.py +++ b/src/fairchem/core/common/relaxation/ase_utils.py @@ -124,6 +124,7 @@ def __init__( trainer: str | None = None, cpu: bool = True, seed: int | None = None, + only_output: list[str] | None = None, ) -> None: """ OCP-ASE Calculator @@ -209,6 +210,20 @@ def __init__( self.config["checkpoint"] = str(checkpoint_path) del config["dataset"]["src"] + # some models that are published have configs that include tasks + # which are not output by the model + if only_output is not None: + assert isinstance( + only_output, list + ), "only output must be a list of targets to output" + for key in only_output: + assert ( + key in config["outputs"] + ), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}" + remove_outputs = set(config["outputs"].keys()) - set(only_output) + for key in remove_outputs: + config["outputs"].pop(key) + self.trainer = registry.get_trainer_class(config["trainer"])( task=config.get("task", {}), model=config["model"], diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index fa679a262..27dc664a3 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -185,14 +185,14 @@ def convert(self, atoms: ase.Atoms, sid=None): cell = np.array(atoms.get_cell(complete=True), copy=True) positions = wrap_positions(positions, cell, pbc=pbc, eps=0) - atomic_numbers = torch.Tensor(atoms.get_atomic_numbers()) + atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8) positions = torch.from_numpy(positions).float() cell = torch.from_numpy(cell).view(1, 3, 3).float() natoms = positions.shape[0] # initialized to torch.zeros(natoms) if tags missing. # https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags - tags = torch.Tensor(atoms.get_tags()) + tags = torch.tensor(atoms.get_tags(), dtype=torch.int) # put the minimum data in torch geometric data object data = Data( @@ -228,10 +228,15 @@ def convert(self, atoms: ase.Atoms, sid=None): energy = atoms.get_potential_energy(apply_constraint=False) data.energy = energy if self.r_forces: - forces = torch.Tensor(atoms.get_forces(apply_constraint=False)) + forces = torch.tensor( + atoms.get_forces(apply_constraint=False), dtype=torch.float32 + ) data.forces = forces if self.r_stress: - stress = torch.Tensor(atoms.get_stress(apply_constraint=False, voigt=False)) + stress = torch.tensor( + atoms.get_stress(apply_constraint=False, voigt=False), + dtype=torch.float32, + ) data.stress = stress if self.r_distances and self.r_edges: data.distances = edge_distances @@ -245,13 +250,13 @@ def convert(self, atoms: ase.Atoms, sid=None): fixed_idx[constraint.index] = 1 data.fixed = fixed_idx if self.r_pbc: - data.pbc = torch.tensor(atoms.pbc) + data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool) if self.r_data_keys is not None: for data_key in self.r_data_keys: data[data_key] = ( atoms.info[data_key] if isinstance(atoms.info[data_key], (int, float, str)) - else torch.Tensor(atoms.info[data_key]) + else torch.tensor(atoms.info[data_key]) ) return data diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 8e5d17820..caa2f5651 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -301,7 +301,8 @@ def _forward(self, batch): ) else: raise AttributeError( - f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}" + f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n" + + "If this is being called from OCPCalculator consider using only_output=[..]" ) ### not all models are consistent with the output shape diff --git a/tests/core/common/__snapshots__/test_ase_calculator.ambr b/tests/core/common/__snapshots__/test_ase_calculator.ambr index 71435c0eb..49d8fb3e0 100644 --- a/tests/core/common/__snapshots__/test_ase_calculator.ambr +++ b/tests/core/common/__snapshots__/test_ase_calculator.ambr @@ -1,4 +1,7 @@ # serializer version: 1 +# name: test_energy_with_is2re_model + 1.09 +# --- # name: test_relaxation_final_energy 0.92 # --- diff --git a/tests/core/common/test_ase_calculator.py b/tests/core/common/test_ase_calculator.py index 92baa37cb..e126e68d5 100644 --- a/tests/core/common/test_ase_calculator.py +++ b/tests/core/common/test_ase_calculator.py @@ -38,6 +38,7 @@ def atoms() -> Atoms: "PaiNN-S2EF-OC20-All", "GemNet-OC-Large-S2EF-OC20-All+MD", "SCN-S2EF-OC20-All+MD", + "PaiNN-IS2RE-OC20-All", # Equiformer v2 # already tested in test_relaxation_final_energy # "EquiformerV2-153M-S2EF-OC20-All+MD" # eSCNm # already tested in test_random_seed_final_energy @@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path): _ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True) +def test_energy_with_is2re_model(atoms, tmp_path, snapshot): + random.seed(1) + torch.manual_seed(1) + + with pytest.raises(AttributeError): # noqa + calc = OCPCalculator( + checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path), + cpu=True, + ) + atoms.set_calculator(calc) + atoms.get_potential_energy() + + calc = OCPCalculator( + checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path), + cpu=True, + only_output=["energy"], + ) + atoms.set_calculator(calc) + assert snapshot == round(atoms.get_potential_energy(), 2) + + # test relaxation with EqV2 def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None: random.seed(1)