Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OCPCalculator output_only option #922

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
trainer: str | None = None,
cpu: bool = True,
seed: int | None = None,
only_output: list[str] | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, better to have only_output as a tuple, since its not meant to be changed at all in the code

) -> None:
"""
OCP-ASE Calculator
Expand Down Expand Up @@ -209,6 +210,20 @@ def __init__(
self.config["checkpoint"] = str(checkpoint_path)
del config["dataset"]["src"]

# some models that are published have configs that include tasks
# which are not output by the model
if only_output is not None:
assert isinstance(
lbluque marked this conversation as resolved.
Show resolved Hide resolved
only_output, list
), "only output must be a list of targets to output"
for key in only_output:
assert (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

key in config["outputs"]
), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}"
remove_outputs = set(config["outputs"].keys()) - set(only_output)
for key in remove_outputs:
config["outputs"].pop(key)

self.trainer = registry.get_trainer_class(config["trainer"])(
task=config.get("task", {}),
model=config["model"],
Expand Down
3 changes: 2 additions & 1 deletion src/fairchem/core/trainers/ocp_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def _forward(self, batch):
)
else:
raise AttributeError(
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}"
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n"
+ "If this is being called from OCPCalculator consider using only_output=[..]"
)

### not all models are consistent with the output shape
Expand Down
3 changes: 3 additions & 0 deletions tests/core/common/__snapshots__/test_ase_calculator.ambr
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# serializer version: 1
# name: test_energy_with_is2re_model
1.09
# ---
# name: test_relaxation_final_energy
0.92
# ---
22 changes: 22 additions & 0 deletions tests/core/common/test_ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def atoms() -> Atoms:
"PaiNN-S2EF-OC20-All",
"GemNet-OC-Large-S2EF-OC20-All+MD",
"SCN-S2EF-OC20-All+MD",
"PaiNN-IS2RE-OC20-All",
# Equiformer v2 # already tested in test_relaxation_final_energy
# "EquiformerV2-153M-S2EF-OC20-All+MD"
# eSCNm # already tested in test_random_seed_final_energy
Expand All @@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path):
_ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True)


def test_energy_with_is2re_model(atoms, tmp_path, snapshot):
random.seed(1)
torch.manual_seed(1)

with pytest.raises(AttributeError): # noqa
calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
)
atoms.set_calculator(calc)
atoms.get_potential_energy()

calc = OCPCalculator(
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
cpu=True,
only_output=["energy"],
)
atoms.set_calculator(calc)
assert snapshot == round(atoms.get_potential_energy(), 2)


# test relaxation with EqV2
def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None:
random.seed(1)
Expand Down
Loading