-
Notifications
You must be signed in to change notification settings - Fork 265
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add script to port old equiv2 checkpoint+yaml to hydra version (#846)
* add script to port old equiv2 checkpoint+yaml to hydra version * fix up comments * lint * move script and add forces to test
- Loading branch information
Showing
3 changed files
with
237 additions
and
0 deletions.
There are no files selected for viewing
88 changes: 88 additions & 0 deletions
88
src/fairchem/core/models/equiformer_v2/eqv2_to_eqv2_hydra.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import os | ||
from collections import OrderedDict | ||
from copy import deepcopy | ||
|
||
import torch | ||
import yaml | ||
|
||
|
||
def convert_checkpoint_and_config_to_hydra( | ||
yaml_fn, checkpoint_fn, new_yaml_fn, new_checkpoint_fn | ||
): | ||
assert not os.path.exists(new_yaml_fn), "Output yaml cannot already exist!" | ||
assert not os.path.exists( | ||
new_checkpoint_fn | ||
), "Output checkpoint cannot already exist!" | ||
|
||
def eqv2_state_dict_to_hydra_state_dict(eqv2_state_dict): | ||
hydra_state_dict = OrderedDict() | ||
for og_key in list(eqv2_state_dict.keys()): | ||
if "force_block" in og_key or "energy_block" in og_key: | ||
key = og_key.replace( | ||
"force_block", "output_heads.forces.force_block" | ||
).replace("energy_block", "output_heads.energy.energy_block") | ||
else: | ||
offset = 0 | ||
if og_key[: len("module.")] == "module.": | ||
offset += len("module.") | ||
key = og_key[:offset] + "backbone." + og_key[offset:] | ||
hydra_state_dict[key] = eqv2_state_dict[og_key] | ||
return hydra_state_dict | ||
|
||
def convert_configs_to_hydra(yaml_config, checkpoint_config): | ||
new_model_config = { | ||
"name": "hydra", | ||
"backbone": checkpoint_config["model"].copy(), | ||
"heads": { | ||
"energy": {"module": "equiformer_v2_energy_head"}, | ||
"forces": {"module": "equiformer_v2_force_head"}, | ||
}, | ||
} | ||
assert new_model_config["backbone"]["name"] in ["equiformer_v2"] | ||
new_model_config["backbone"].pop("name") | ||
new_model_config["backbone"]["model"] = "equiformer_v2_backbone" | ||
|
||
# create a new checkpoint config | ||
new_checkpoint_config = deepcopy(checkpoint_config) | ||
new_checkpoint_config["model"] = new_model_config | ||
|
||
# create a new YAML config | ||
new_yaml_config = deepcopy(yaml_config) | ||
new_yaml_config["model"] = new_model_config | ||
|
||
for output_key, output_d in new_yaml_config["outputs"].items(): | ||
if output_d["level"] == "system": | ||
output_d["property"] = "energy" | ||
elif output_d["level"] == "atom": | ||
output_d["property"] = "forces" | ||
else: | ||
logging.warning( | ||
f"Converting output:{output_key} to new equiv2 hydra config \ | ||
failed to find level and could not set property in output correctly" | ||
) | ||
|
||
return new_yaml_config, new_checkpoint_config | ||
|
||
# load existing from disk | ||
with open(yaml_fn) as yaml_f: | ||
yaml_config = yaml.safe_load(yaml_f) | ||
checkpoint = torch.load(checkpoint_fn, map_location="cpu") | ||
|
||
new_checkpoint = checkpoint.copy() | ||
new_yaml_config, new_checkpoint_config = convert_configs_to_hydra( | ||
yaml_config, checkpoint["config"] | ||
) | ||
new_checkpoint["config"] = new_checkpoint_config | ||
new_checkpoint["state_dict"] = eqv2_state_dict_to_hydra_state_dict( | ||
checkpoint["state_dict"] | ||
) | ||
for key in ["ema", "optimizer", "scheduler"]: | ||
new_checkpoint.pop(key) | ||
|
||
# write output | ||
torch.save(new_checkpoint, new_checkpoint_fn) | ||
with open(str(new_yaml_fn), "w") as yaml_file: | ||
yaml.dump(new_yaml_config, yaml_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import annotations | ||
|
||
import argparse | ||
|
||
from fairchem.core.models.equiformer_v2.eqv2_to_eqv2_hydra import ( | ||
convert_checkpoint_and_config_to_hydra, | ||
) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--eqv2-checkpoint", help="path to eqv2 checkpoint", type=str, required=True | ||
) | ||
parser.add_argument( | ||
"--eqv2-yaml", help="path to eqv2 yaml config", type=str, required=True | ||
) | ||
parser.add_argument( | ||
"--hydra-eqv2-checkpoint", | ||
help="path where to output hydra checkpoint", | ||
type=str, | ||
required=True, | ||
) | ||
parser.add_argument( | ||
"--hydra-eqv2-yaml", | ||
help="path where to output hydra yaml", | ||
type=str, | ||
required=True, | ||
) | ||
args = parser.parse_args() | ||
|
||
convert_checkpoint_and_config_to_hydra( | ||
yaml_fn=args.eqv2_yaml, | ||
checkpoint_fn=args.eqv2_checkpoint, | ||
new_yaml_fn=args.hydra_eqv2_yaml, | ||
new_checkpoint_fn=args.hydra_eqv2_checkpoint, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters