Skip to content

Commit

Permalink
merge upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Aug 22, 2024
2 parents 33bab2f + c2b0c30 commit 13db14c
Show file tree
Hide file tree
Showing 32 changed files with 1,394 additions and 824 deletions.
2 changes: 2 additions & 0 deletions configs/ocp_hydra_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ logger: wandb

outputs:
energy:
property: energy
shape: 1
level: system
forces:
property: forces
irrep_dim: 1
level: atom
train_on_free_atoms: True
Expand Down
3 changes: 2 additions & 1 deletion src/fairchem/core/common/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import argparse
import os
from pathlib import Path


Expand Down Expand Up @@ -48,7 +49,7 @@ def add_core_args(self) -> None:
)
self.parser.add_argument(
"--run-dir",
default="./",
default=os.path.abspath("./"),
type=str,
help="Directory to store checkpoint/log/result directory",
)
Expand Down
168 changes: 125 additions & 43 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ast
import collections
import copy
import errno
import importlib
import itertools
import json
Expand Down Expand Up @@ -38,6 +39,7 @@
from torch_scatter import scatter, segment_coo, segment_csr

import fairchem.core
from fairchem.core.common.registry import registry
from fairchem.core.modules.loss import AtomwiseL2Loss, L2MAELoss

if TYPE_CHECKING:
Expand Down Expand Up @@ -90,6 +92,15 @@ def save_checkpoint(
return filename


# NOTE(review): appears to enumerate the top-level config sections a
# multitask run must define — not referenced in this chunk; verify usage.
multitask_required_keys = {
    "tasks",
    "datasets",
    "combined_dataset",
    "model",
    "optim",
}


class Complete:
def __call__(self, data):
device = data.edge_index.device
Expand Down Expand Up @@ -391,48 +402,83 @@ def create_dict_from_args(args: list, sep: str = "."):
return return_dict


def load_config(path: str, previous_includes: list | None = None):
if previous_includes is None:
previous_includes = []
def find_relative_file_in_paths(filename, include_paths):
    """Resolve *filename* to an existing path.

    The filename is first checked as given (absolute or relative to the
    working directory); otherwise each directory in *include_paths* is
    tried in order.

    Args:
        filename: file name or path to resolve.
        include_paths: directories to search, in order.

    Returns:
        The first existing path found.

    Raises:
        ValueError: if the file cannot be found anywhere.
    """
    if os.path.exists(filename):
        return filename
    for path in include_paths:
        include_filename = os.path.join(path, filename)
        if os.path.exists(include_filename):
            return include_filename
    # BUGFIX: the f-string had no placeholder, so the error never named the
    # missing file; include it to make the failure actionable.
    raise ValueError(f"Cannot find include YML {filename}")


def load_config(
    path: str,
    files_previously_included: list | None = None,
    include_paths: list | None = None,
):
    """Load a YAML config, recursively resolving any ``includes`` it lists.

    When includes are present this function is called recursively on each
    included file. To prevent cyclic imports we keep track of already
    imported yml files using ``files_previously_included``.

    Args:
        path: path of the config file to load.
        files_previously_included: configs already loaded on this include
            chain; used to detect cycles.
        include_paths: extra directories to search when resolving a
            relative include.

    Returns:
        Tuple of ``(config, duplicates_warning, duplicates_error)`` where
        the last two list keys duplicated across files.

    Raises:
        ValueError: on a cyclic include, or an include that cannot be found.
        AttributeError: if ``includes`` is not a list.
    """
    if include_paths is None:
        include_paths = []
    if files_previously_included is None:
        files_previously_included = []
    path = Path(path)
    if path in files_previously_included:
        raise ValueError(
            f"Cyclic config include detected. {path} included in sequence {files_previously_included}."
        )
    files_previously_included = [*files_previously_included, path]

    with open(path) as fp:
        current_config = yaml.load(fp, Loader=UniqueKeyLoader)

    # Load config from included files.
    includes_listed_in_config = (
        current_config.pop("includes") if "includes" in current_config else []
    )
    if not isinstance(includes_listed_in_config, list):
        raise AttributeError(
            f"Includes must be a list, '{type(includes_listed_in_config)}' provided"
        )

    config_from_includes = {}
    duplicates_warning = []
    duplicates_error = []

    for include in includes_listed_in_config:
        include_filename = find_relative_file_in_paths(
            include, [os.path.dirname(path), *include_paths]
        )
        # BUGFIX: propagate include_paths so nested includes resolve
        # against the same search directories as the top-level config
        # (previously only the top-level call received them).
        include_config, inc_dup_warning, inc_dup_error = load_config(
            include_filename, files_previously_included, include_paths
        )
        duplicates_warning += inc_dup_warning
        duplicates_error += inc_dup_error

        # Duplicates between includes causes an error
        config_from_includes, merge_dup_error = merge_dicts(
            config_from_includes, include_config
        )
        duplicates_error += merge_dup_error

    # Duplicates between included and main file causes warnings
    config_from_includes, merge_dup_warning = merge_dicts(
        config_from_includes, current_config
    )
    duplicates_warning += merge_dup_warning

    return config_from_includes, duplicates_warning, duplicates_error


def build_config(args, args_override):
config, duplicates_warning, duplicates_error = load_config(args.config_yml)
def build_config(args, args_override, include_paths=None):
config, duplicates_warning, duplicates_error = load_config(
args.config_yml, include_paths=include_paths
)
if len(duplicates_warning) > 0:
logging.warning(
f"Overwritten config parameters from included configs "
Expand Down Expand Up @@ -997,34 +1043,53 @@ class _TrainingContext:
task_name = "s2ef"
elif trainer_name in ["energy", "equiformerv2_energy"]:
task_name = "is2re"
elif "multitask" in trainer_name:
task_name = "multitask"
else:
task_name = "ocp"

trainer_cls = registry.get_trainer_class(trainer_name)
assert trainer_cls is not None, "Trainer not found"
trainer = trainer_cls(
task=config.get("task", {}),
model=config["model"],
outputs=config.get("outputs", {}),
dataset=config["dataset"],
optimizer=config["optim"],
loss_functions=config.get("loss_functions", {}),
evaluation_metrics=config.get("evaluation_metrics", {}),
identifier=config["identifier"],
timestamp_id=config.get("timestamp_id", None),
run_dir=config.get("run_dir", "./"),
is_debug=config.get("is_debug", False),
print_every=config.get("print_every", 10),
seed=config.get("seed", 0),
logger=config.get("logger", "wandb"),
local_rank=config["local_rank"],
amp=config.get("amp", False),
cpu=config.get("cpu", False),
slurm=config.get("slurm", {}),
noddp=config.get("noddp", False),
name=task_name,
gp_gpus=config.get("gp_gpus"),
)

trainer_config = {
"model": config["model"],
"optimizer": config["optim"],
"identifier": config["identifier"],
"timestamp_id": config.get("timestamp_id", None),
"run_dir": config.get("run_dir", "./"),
"is_debug": config.get("is_debug", False),
"print_every": config.get("print_every", 10),
"seed": config.get("seed", 0),
"logger": config.get("logger", "wandb"),
"local_rank": config["local_rank"],
"amp": config.get("amp", False),
"cpu": config.get("cpu", False),
"slurm": config.get("slurm", {}),
"noddp": config.get("noddp", False),
"name": task_name,
"gp_gpus": config.get("gp_gpus"),
}

if task_name == "multitask":
trainer_config.update(
{
"tasks": config.get("tasks", {}),
"dataset_configs": config["datasets"],
"combined_dataset_config": config.get("combined_dataset", {}),
"evaluations": config.get("evaluations", {}),
}
)
else:
trainer_config.update(
{
"task": config.get("task", {}),
"outputs": config.get("outputs", {}),
"dataset": config["dataset"],
"loss_functions": config.get("loss_functions", {}),
"evaluation_metrics": config.get("evaluation_metrics", {}),
}
)
trainer = trainer_cls(**trainer_config)

task_cls = registry.get_task_class(config["mode"])
assert task_cls is not None, "Task not found"
Expand Down Expand Up @@ -1370,3 +1435,20 @@ def get_loss_module(loss_name):
raise NotImplementedError(f"Unknown loss function name: {loss_name}")

return loss_fn


def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module:
    """Instantiate a model from the config embedded in a checkpoint and load its weights.

    Args:
        checkpoint_path: path to a checkpoint file whose payload carries the
            model config under ``checkpoint["config"]["model"]``.

    Returns:
        The model with the checkpoint's state dict loaded (strict matching).

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, "Checkpoint file not found", checkpoint_path
        )
    logging.info(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # this assumes the checkpoint also contains the config with the full model in it
    # TODO: need to schematize how we save and load the config from checkpoint
    config = checkpoint["config"]["model"]
    # "name" keys the model class in the registry; the rest are its kwargs.
    name = config.pop("name")
    model = registry.get_model_class(name)(**config)
    matched_dict = match_state_dict(model.state_dict(), checkpoint["state_dict"])
    load_state_dict(model, matched_dict, strict=True)
    return model
21 changes: 17 additions & 4 deletions src/fairchem/core/datasets/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,32 @@
from torch_geometric.data import Data


def rename_data_object_keys(data_object: Data, key_mapping: dict[str, str]) -> Data:
def rename_data_object_keys(
    data_object: Data, key_mapping: dict[str, str | list[str]]
) -> Data:
    """Rename keys on a data object according to ``key_mapping``.

    Args:
        data_object: data object
        key_mapping: maps each existing key to a new key, or to a list of
            new keys, e.g. ``{"energy": ["common_energy", "oc20_energy"]}``.
            A list duplicates the same target/label under several names,
            which is currently required when a single target/label is used
            by multiple tasks.
    """
    for source_key, targets in key_mapping.items():
        # Test data may not carry labels; silently skip absent keys.
        if source_key not in data_object:
            continue
        target_keys = [targets] if isinstance(targets, str) else targets
        for target_key in target_keys:
            if target_key == source_key:
                continue
            assert target_key not in data_object
            data_object[target_key] = data_object[source_key]
        # Drop the original only when it is not itself one of the targets.
        if source_key not in target_keys:
            del data_object[source_key]

    return data_object
Loading

0 comments on commit 13db14c

Please sign in to comment.