Skip to content

Commit

Permalink
add optional field to calculator to output only requested
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Nov 19, 2024
1 parent aa298ac commit 984d87c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 1 deletion.
15 changes: 15 additions & 0 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
max_neighbors: int = 50,
cpu: bool = True,
seed: int | None = None,
only_output: List[str] | None = None,
) -> None:
"""
OCP-ASE Calculator
Expand Down Expand Up @@ -170,6 +171,20 @@ def __init__(
self.config["checkpoint"] = checkpoint_path
del config["dataset"]["src"]

# some models that are published have configs that include tasks
# which are not output by the model
if only_output is not None:
assert isinstance(
only_output, list
), "only output must be a list of targets to output"
for key in only_output:
assert (
key in config["outputs"]
), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}"
remove_outputs = set(config["outputs"].keys()) - set(only_output)
for key in remove_outputs:
config["outputs"].pop(key)

self.trainer = registry.get_trainer_class(config["trainer"])(
task=config.get("task", {}),
model=config["model"],
Expand Down
3 changes: 2 additions & 1 deletion src/fairchem/core/trainers/ocp_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def _forward(self, batch):
)
else:
raise AttributeError(
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}"
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n"
+ "If this is being called from OCPCalculator consider using only_output=[..]"
)

### not all models are consistent with the output shape
Expand Down
3 changes: 3 additions & 0 deletions tests/core/common/__snapshots__/test_ase_calculator.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# serializer version: 1
# name: test_energy_with_is2re_model
1.09
# ---
# name: test_relaxation_final_energy
0.92
# ---
22 changes: 22 additions & 0 deletions tests/core/common/test_ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def atoms() -> Atoms:
"PaiNN-S2EF-OC20-All",
"GemNet-OC-Large-S2EF-OC20-All+MD",
"SCN-S2EF-OC20-All+MD",
"PaiNN-IS2RE-OC20-All",
# Equiformer v2 # already tested in test_relaxation_final_energy
# "EquiformerV2-153M-S2EF-OC20-All+MD"
# eSCNm # already tested in test_random_seed_final_energy
Expand All @@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path):
_ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True)


def test_energy_with_is2re_model(atoms, tmp_path, snapshot):
random.seed(1)
torch.manual_seed(1)

with pytest.raises(AttributeError): # noqa
calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
)
atoms.set_calculator(calc)
atoms.get_potential_energy()

calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
only_output=["energy"],
)
atoms.set_calculator(calc)
assert snapshot == round(atoms.get_potential_energy(), 2)


# test relaxation with EqV2
def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None:
random.seed(1)
Expand Down

0 comments on commit 984d87c

Please sign in to comment.