(OTF) Normalization and element references (#715)
* denorm targets in _forward only

* linear reference class

* atomref in normalizer

* raise input error

* clean up normalizer interface

* add element refs

* add element refs correctly

* ruff

* fix save_checkpoint

* reference and dereference

* 2xnorm linref trainer add

* clean-up

* otf linear reference fit

* fix tensor device

* otf element references and normalizers

* use only present elements when fitting

* lint

* _forward norm and derefd values

* fix list of paths in src

* total mean and std

* fitted flag to avoid refitting normalizers/references on rerun

* allow passing lstsq driver

* element ref unit tests

* remove superfluous type

* lint fix

* allow setting batch_size explicitly

* test applying element refs

* normalizer tests

* increase distributed timeout

* save normalizers and linear refs in otf_fit

* remove debug code

* fix removing refs

* swap otf_fit for fit, and save all normalizers in one file

* log loading and saving normalizers

* fit references and normalizer scripts

* lint fixes

* allow absent optim key in config

* lin-ref description

* read files based on extension

* pass seed

* rename dataset fixture

* check if file is none

* pass generator correctly

* separate method for norms and refs

* add normalizer code back

* fix Generator construction

* import order

* log warnings if multiple inputs are passed

* raise Error if duplicate references or norms are set

* use len batch

* assert element reference targets are scalar

* fix name and rename method

* load and save norms and refs using same logic

* fix creating normalizer

* remove print statements

* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)

* adding new notebook for using fairchem models with NEBs

* adding md tutorials

* blocking code cells that arent needed or take too long

* warn instead of error when duplicate norm/ref target names

* allow timeout to be read from config

* test seed noseed ref fits

* lotsa refactoring

* lotsa fixing

* more fixing...

* num_workers zero to prevent mp issues

* add otf norms smoke test and fixes

* allow overriding normalization fit values

* update tests

* fix normalizer loading

* use rmsd instead of only stdev

* fix tests

* correct rmsd calc and fix loading

* clean up norm loading and log values

* log linear reference metrics

* load element references state dict

* fix loading and tests

* fix imports in scripts

* fix test?

* fix test

* use numpy as default to fit references

* minor fixes

* rm torch_tempdir fixture

---------

Co-authored-by: Brook Wander <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
3 people authored Aug 5, 2024
1 parent 08b8c1e commit 029d4d3
Showing 19 changed files with 1,572 additions and 265 deletions.
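The headline feature is fitting element references and normalization statistics on the fly (OTF) from the training data instead of requiring precomputed values. As a rough sketch of the idea, and not the fairchem implementation itself (function and variable names below are illustrative): a linear element reference solves a least-squares problem over per-structure element counts, and a normalizer then stores the mean and RMSD of the referenced residuals.

```python
# Illustrative sketch only; names (fit_linear_reference, max_z) are hypothetical,
# not the API introduced in this commit.
from __future__ import annotations

import torch


def fit_linear_reference(
    atomic_numbers: list[torch.Tensor],  # per-structure 1D tensors of atomic numbers
    energies: torch.Tensor,              # per-structure target energies, shape (n,)
    max_z: int = 100,
    driver: str | None = None,           # optional LAPACK driver passed to lstsq
) -> torch.Tensor:
    """Solve energies ~ comp @ refs, where comp[i, z] counts atoms of element z in structure i."""
    comp = torch.zeros(len(atomic_numbers), max_z + 1, dtype=energies.dtype)
    for i, z in enumerate(atomic_numbers):
        comp[i].index_add_(0, z.long(), torch.ones(len(z), dtype=energies.dtype))
    # in practice only the columns for elements actually present in the dataset would be kept
    refs = torch.linalg.lstsq(comp, energies.unsqueeze(-1), driver=driver).solution
    return refs.squeeze(-1)


# A normalizer for the referenced target would then store the mean and RMSD of the
# residuals (rather than only the standard deviation):
#   residual = energies - comp @ refs
#   mean, rmsd = residual.mean(), residual.square().mean().sqrt()
```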
16 changes: 14 additions & 2 deletions src/fairchem/core/common/distutils.py
@@ -10,7 +10,8 @@
import logging
import os
import subprocess
from typing import TypeVar
from datetime import timedelta
from typing import Any, TypeVar

import torch
import torch.distributed as dist
@@ -27,6 +28,7 @@ def os_environ_get_or_throw(x: str) -> str:


def setup(config) -> None:
timeout = timedelta(minutes=config.get("timeout", 30))
if config["submit"]:
node_list = os.environ.get("SLURM_STEP_NODELIST")
if node_list is None:
@@ -72,6 +74,7 @@ def setup(config) -> None:
init_method=config["init_method"],
world_size=config["world_size"],
rank=config["rank"],
timeout=timeout,
)
except subprocess.CalledProcessError as e: # scontrol failed
raise e
@@ -95,10 +98,11 @@ def setup(config) -> None:
rank=world_rank,
world_size=world_size,
init_method="env://",
timeout=timeout,
)
else:
config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"]))
dist.init_process_group(backend=config.get("backend", "nccl"))
dist.init_process_group(backend=config.get("backend", "nccl"), timeout=timeout)


def cleanup() -> None:
@@ -135,6 +139,14 @@ def broadcast(
dist.broadcast(tensor, src, group, async_op)


def broadcast_object_list(
object_list: list[Any], src: int, group=dist.group.WORLD, device: str | None = None
) -> None:
if get_world_size() == 1:
return
dist.broadcast_object_list(object_list, src, group, device)


def all_reduce(
data, group=dist.group.WORLD, average: bool = False, device=None
) -> torch.Tensor:
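The distutils.py changes thread a configurable process-group timeout through every `dist.init_process_group` call, with the previous 30-minute value kept as the default when the key is absent. A small, hypothetical config fragment showing the new key as read by the diff above (the value is in minutes):

```python
from datetime import timedelta

# hypothetical trainer config fragment; only the "timeout" key is new in this commit
config = {"backend": "nccl", "timeout": 60}

timeout = timedelta(minutes=config.get("timeout", 30))  # falls back to 30 minutes
```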
17 changes: 9 additions & 8 deletions src/fairchem/core/datasets/ase_datasets.py
@@ -13,7 +13,7 @@
import os
import warnings
from abc import ABC, abstractmethod
from functools import cache, reduce
from functools import cache
from glob import glob
from pathlib import Path
from typing import Any, Callable
@@ -467,13 +467,14 @@ class AseDBDataset(AseAtomsDataset):

def _load_dataset_get_ids(self, config: dict) -> list[int]:
if isinstance(config["src"], list):
if os.path.isdir(config["src"][0]):
filepaths = reduce(
lambda x, y: x + y,
(glob(f"{path}/*") for path in config["src"]),
)
else:
filepaths = config["src"]
filepaths = []
for path in config["src"]:
if os.path.isdir(path):
filepaths.extend(glob(f"{path}/*"))
elif os.path.isfile(path):
filepaths.append(path)
else:
raise RuntimeError(f"Error reading dataset in {path}!")
elif os.path.isfile(config["src"]):
filepaths = [config["src"]]
elif os.path.isdir(config["src"]):
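The AseDBDataset change replaces the all-or-nothing handling of a list-valued `src` with per-path handling, so directories and individual files can now be mixed, and an unreadable path fails loudly. A hypothetical dataset config accepted by the new logic (the paths are illustrative):

```python
# directories are globbed for their contents, plain files are used directly,
# and anything that is neither a file nor a directory raises a RuntimeError
dataset_config = {
    "src": [
        "/data/ase_dbs/train/",       # a directory of ASE database files
        "/data/extra/more_train.db",  # a single database file
    ]
}
```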
Empty file.
113 changes: 113 additions & 0 deletions src/fairchem/core/modules/normalization/_load_utils.py
@@ -0,0 +1,113 @@
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable

import torch

from fairchem.core.common.utils import save_checkpoint

if TYPE_CHECKING:
from pathlib import Path

from torch.nn import Module
from torch.utils.data import Dataset


def _load_check_duplicates(config: dict, name: str) -> dict[str, torch.nn.Module]:
"""Attempt to load a single file with normalizers/element references and check config for duplicate targets.
Args:
config: configuration dictionary
name: Name of module to use for logging
Returns:
dictionary of normalizer or element reference modules
"""
modules = {}
if "file" in config:
modules = torch.load(config["file"])
logging.info(f"Loaded {name} for the following targets: {list(modules.keys())}")
# make sure that element-refs are not specified both as fit and file
fit_targets = config["fit"]["targets"] if "fit" in config else []
duplicates = list(
filter(
lambda x: x in fit_targets,
list(config) + list(modules.keys()),
)
)
if len(duplicates) > 0:
logging.warning(
f"{name} values for the following targets {duplicates} have been specified to be fit and also read"
f" from a file. The files read from file will be used instead of fitting."
)
duplicates = list(filter(lambda x: x in modules, config))
if len(duplicates) > 0:
logging.warning(
f"Duplicate {name} values for the following targets {duplicates} where specified in the file "
f"{config['file']} and an explicitly set file. The normalization values read from "
f"{config['file']} will be used."
)
return modules


def _load_from_config(
config: dict,
name: str,
fit_fun: Callable[[list[str], Dataset, Any, ...], dict[str, Module]],
create_fun: Callable[[str | Path], Module],
dataset: Dataset,
checkpoint_dir: str | Path | None = None,
**fit_kwargs,
) -> dict[str, torch.nn.Module]:
"""Load or fit normalizers or element references from config
If a fit is done, a fitted key with value true is added to the config to avoid re-fitting
once a checkpoint has been saved.
Args:
config: configuration dictionary
name: Name of module to use for logging
fit_fun: Function to fit modules
create_fun: Function to create a module from file
dataset: dataset used to fit modules on the fly
checkpoint_dir: directory to save modules. If not given, modules won't be saved.
Returns:
dictionary of normalizer or element reference modules
"""
modules = _load_check_duplicates(config, name)
for target in config:
if target == "fit" and not config["fit"].get("fitted", False):
# remove values for output targets that have already been read from files
targets = [
target for target in config["fit"]["targets"] if target not in modules
]
fit_kwargs.update(
{k: v for k, v in config["fit"].items() if k != "targets"}
)
modules.update(fit_fun(targets=targets, dataset=dataset, **fit_kwargs))
config["fit"]["fitted"] = True
# if a single file for all outputs is not provided,
# then check if a single file is provided for a specific output
elif target != "file":
modules[target] = create_fun(**config[target])
# save the linear references for possible subsequent use
if checkpoint_dir is not None:
path = save_checkpoint(
modules,
checkpoint_dir,
f"{name}.pt",
)
logging.info(
f"{name} checkpoint for targets {list(modules.keys())} have been saved to: {path}"
)

return modules
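The new `_load_from_config` helper accepts a per-target config in which values are either fit on the fly (under a `fit` key) or read from files, warns about duplicate specifications, and re-saves the result to the checkpoint directory. A hypothetical normalizer config it could consume — the `fit`/`targets`/`file` keys follow the code above, while the concrete target names, file name, fit options, and the functions in the commented-out call are only illustrative:

```python
# Illustrative config; the accepted fit options depend on the fit_fun passed in.
normalizer_config = {
    # fit normalization statistics for these targets from the training dataset
    "fit": {"targets": ["energy"], "batch_size": 64},
    # read a previously fitted normalizer for another target from its own file
    "forces": {"file": "forces_norm.pt"},
}

# modules = _load_from_config(
#     normalizer_config,
#     name="normalizers",
#     fit_fun=fit_normalizers,       # hypothetical fitting function
#     create_fun=create_normalizer,  # hypothetical factory consuming {"file": ...}
#     dataset=train_dataset,
#     checkpoint_dir="checkpoints/",
# )
```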
