Skip to content

Commit

Permalink
Merge branch 'main' into fix_pytorch_2_4_0
Browse files Browse the repository at this point in the history
  • Loading branch information
misko authored Dec 4, 2024
2 parents cd8f2f3 + 3f695ef commit 53dad8d
Show file tree
Hide file tree
Showing 17 changed files with 1,068 additions and 241 deletions.
2 changes: 1 addition & 1 deletion packages/fairchem-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [

[project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev]
dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi", "umap-learn", "vdict"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "umap-learn", "vdict"]
adsorbml = ["dscribe","x3dase","scikit-image"]

[project.scripts]
Expand Down
13 changes: 13 additions & 0 deletions src/fairchem/core/common/relaxation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

from .ml_relaxation import ml_relax
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch

__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"]
108 changes: 77 additions & 31 deletions src/fairchem/core/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

import copy
import logging
from typing import ClassVar
from types import MappingProxyType
from typing import TYPE_CHECKING

import torch
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.singlepoint import SinglePointCalculator as sp
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.geometry import wrap_positions

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
Expand All @@ -33,51 +35,93 @@
from fairchem.core.models.model_registry import model_name_to_local_file
from fairchem.core.preprocessing import AtomsToGraphs

if TYPE_CHECKING:
from pathlib import Path

def batch_to_atoms(batch):
from torch_geometric.data import Batch


# system level model predictions have different shapes than expected by ASE
ASE_PROP_RESHAPE = MappingProxyType(
{"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
)


def batch_to_atoms(
batch: Batch,
results: dict[str, torch.Tensor] | None = None,
wrap_pos: bool = True,
eps: float = 1e-7,
) -> list[Atoms]:
"""Convert a data batch to ase Atoms
Args:
batch: data batch
results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
are given no calculator will be added to the atoms objects.
wrap_pos: wrap positions back into the cell.
eps: Small number to prevent slightly negative coordinates from being wrapped.
Returns:
list of Atoms
"""
n_systems = batch.natoms.shape[0]
natoms = batch.natoms.tolist()
numbers = torch.split(batch.atomic_numbers, natoms)
fixed = torch.split(batch.fixed.to(torch.bool), natoms)
forces = torch.split(batch.force, natoms)
if results is not None:
results = {
key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist()
if len(val) == len(batch)
else [v.cpu().detach().numpy() for v in torch.split(val, natoms)]
for key, val in results.items()
}

positions = torch.split(batch.pos, natoms)
tags = torch.split(batch.tags, natoms)
cells = batch.cell
energies = batch.energy.view(-1).tolist()

atoms_objects = []
for idx in range(n_systems):
pos = positions[idx].cpu().detach().numpy()
cell = cells[idx].cpu().detach().numpy()

# TODO take pbc from data
if wrap_pos:
pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)

atoms = Atoms(
numbers=numbers[idx].tolist(),
positions=positions[idx].cpu().detach().numpy(),
cell=cell,
positions=pos,
tags=tags[idx].tolist(),
cell=cells[idx].cpu().detach().numpy(),
constraint=FixAtoms(mask=fixed[idx].tolist()),
pbc=[True, True, True],
)
calc = sp(
atoms=atoms,
energy=energies[idx],
forces=forces[idx].cpu().detach().numpy(),
)
atoms.set_calculator(calc)

if results is not None:
calc = SinglePointCalculator(
atoms=atoms, **{key: val[idx] for key, val in results.items()}
)
atoms.set_calculator(calc)

atoms_objects.append(atoms)

return atoms_objects


class OCPCalculator(Calculator):
implemented_properties: ClassVar[list[str]] = ["energy", "forces"]
"""ASE based calculator using an OCP model"""

_reshaped_props = ASE_PROP_RESHAPE

def __init__(
self,
config_yml: str | None = None,
checkpoint_path: str | None = None,
checkpoint_path: str | Path | None = None,
model_name: str | None = None,
local_cache: str | None = None,
trainer: str | None = None,
cutoff: int = 6,
max_neighbors: int = 50,
cpu: bool = True,
seed: int | None = None,
) -> None:
Expand All @@ -96,16 +140,12 @@ def __init__(
Directory to save pretrained model checkpoints.
trainer (str):
OCP trainer to be used. "forces" for S2EF, "energy" for IS2RE.
cutoff (int):
Cutoff radius to be used for data preprocessing.
max_neighbors (int):
Maximum amount of neighbors to store for a given atom.
cpu (bool):
Whether to load and run the model on CPU. Set `False` for GPU.
"""
setup_imports()
setup_logging()
Calculator.__init__(self)
super().__init__()

if model_name is not None:
if checkpoint_path is not None:
Expand Down Expand Up @@ -165,9 +205,8 @@ def __init__(
### backwards compatability with OCP v<2.0
config = update_config(config)

# Save config so obj can be transported over network (pkl)
self.config = copy.deepcopy(config)
self.config["checkpoint"] = checkpoint_path
self.config["checkpoint"] = str(checkpoint_path)
del config["dataset"]["src"]

self.trainer = registry.get_trainer_class(config["trainer"])(
Expand Down Expand Up @@ -199,14 +238,13 @@ def __init__(
self.trainer.set_seed(seed)

self.a2g = AtomsToGraphs(
max_neigh=max_neighbors,
radius=cutoff,
r_energy=False,
r_forces=False,
r_distances=False,
r_edges=False,
r_pbc=True,
r_edges=not self.trainer.model.otf_graph, # otf graph should not be a property of the model...
)
self.implemented_properties = list(self.config["outputs"].keys())

def load_checkpoint(
self, checkpoint_path: str, checkpoint: dict | None = None
Expand All @@ -217,6 +255,8 @@ def load_checkpoint(
Args:
checkpoint_path: string
Path to trained model
checkpoint: dict
A pretrained checkpoint dict
"""
try:
self.trainer.load_checkpoint(
Expand All @@ -225,14 +265,20 @@ def load_checkpoint(
except NotImplementedError:
logging.warning("Unable to load checkpoint!")

def calculate(self, atoms: Atoms, properties, system_changes) -> None:
Calculator.calculate(self, atoms, properties, system_changes)
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
def calculate(self, atoms: Atoms | Batch, properties, system_changes) -> None:
"""Calculate implemented properties for a single Atoms object or a Batch of them."""
super().calculate(atoms, properties, system_changes)
if isinstance(atoms, Atoms):
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
else:
batch = atoms

predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True)

for key in predictions:
_pred = predictions[key]
_pred = _pred.item() if _pred.numel() == 1 else _pred.cpu().numpy()
if key in OCPCalculator._reshaped_props:
_pred = _pred.reshape(OCPCalculator._reshaped_props.get(key)).squeeze()
self.results[key] = _pred
117 changes: 79 additions & 38 deletions src/fairchem/core/common/relaxation/ml_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,85 +10,126 @@
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from torch_geometric.data import Batch

from fairchem.core.common.typing import assert_is_instance
from fairchem.core.datasets.lmdb_dataset import data_list_collater

from .optimizers.lbfgs_torch import LBFGS, TorchCalc
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch
from .optimizers.lbfgs_torch import LBFGS

if TYPE_CHECKING:
from fairchem.core.trainers import BaseTrainer


def ml_relax(
batch,
model,
batch: Batch,
model: BaseTrainer,
steps: int,
fmax: float,
relax_opt,
save_full_traj,
device: str = "cuda:0",
transform=None,
early_stop_batch: bool = False,
relax_opt: dict[str] | None = None,
relax_cell: bool = False,
relax_volume: bool = False,
save_full_traj: bool = True,
transform: torch.nn.Module | None = None,
mask_converged: bool = True,
):
"""
Runs ML-based relaxations.
"""Runs ML-based relaxations.
Args:
batch: object
model: object
steps: int
Max number of steps in the structure relaxation.
fmax: float
Structure relaxation terminates when the max force
of the system is no bigger than fmax.
relax_opt: str
Optimizer and corresponding parameters to be used for structure relaxations.
save_full_traj: bool
Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
batch: a data batch object.
model: a trainer object with model.
steps: Max number of steps in the structure relaxation.
fmax: Structure relaxation terminates when the max force of the system is no bigger than fmax.
relax_opt: Optimizer parameters to be used for structure relaxations.
relax_cell: if true will use stress predictions to relax crystallographic cell.
The model given must predict stress
relax_volume: if true will relax the cell isotropically. the given model must predict stress.
save_full_traj: Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
mask_converged: whether to mask batches where all atoms are below convergence threshold
cumulative_mask: if true, once system is masked then it remains masked even if new predictions give forces
above threshold, ie. once masked always masked. Note if this is used make sure to check convergence with
the same fmax always
"""
relax_opt = relax_opt or {}
# if not pbc is set, ignore it when comparing batches
if not hasattr(batch, "pbc"):
OptimizableBatch.ignored_changes = {"pbc"}

batches = deque([batch])
relaxed_batches = []
while batches:
batch = batches.popleft()
oom = False
ids = batch.sid
calc = TorchCalc(model, transform)

# clone the batch otherwise you can not run batch.to_data_list
# see https://github.com/pyg-team/pytorch_geometric/issues/8439#issuecomment-1826747915
if relax_cell or relax_volume:
optimizable = OptimizableUnitCellBatch(
batch.clone(),
trainer=model,
transform=transform,
mask_converged=mask_converged,
hydrostatic_strain=relax_volume,
)
else:
optimizable = OptimizableBatch(
batch.clone(),
trainer=model,
transform=transform,
mask_converged=mask_converged,
)

# Run ML-based relaxation
traj_dir = relax_opt.get("traj_dir", None)
traj_dir = relax_opt.get("traj_dir")
relax_opt.update({"traj_dir": Path(traj_dir) if traj_dir is not None else None})

optimizer = LBFGS(
batch,
calc,
maxstep=relax_opt.get("maxstep", 0.2),
memory=relax_opt["memory"],
damping=relax_opt.get("damping", 1.2),
alpha=relax_opt.get("alpha", 80.0),
device=device,
optimizable_batch=optimizable,
save_full_traj=save_full_traj,
traj_dir=Path(traj_dir) if traj_dir is not None else None,
traj_names=ids,
early_stop_batch=early_stop_batch,
**relax_opt,
)

e: RuntimeError | None = None
try:
relaxed_batch = optimizer.run(fmax=fmax, steps=steps)
relaxed_batches.append(relaxed_batch)
optimizer.run(fmax=fmax, steps=steps)
relaxed_batches.append(optimizable.batch)
except RuntimeError as err:
e = err
oom = True
torch.cuda.empty_cache()

if oom:
# move OOM recovery code outside of except clause to allow tensors to be freed.
# move OOM recovery code outside off except clause to allow tensors to be freed.
data_list = batch.to_data_list()
if len(data_list) == 1:
raise assert_is_instance(e, RuntimeError)
logging.info(
f"Failed to relax batch with size: {len(data_list)}, splitting into two..."
)
mid = len(data_list) // 2
batches.appendleft(data_list_collater(data_list[:mid]))
batches.appendleft(data_list_collater(data_list[mid:]))
batches.appendleft(
data_list_collater(data_list[:mid], otf_graph=optimizable.otf_graph)
)
batches.appendleft(
data_list_collater(data_list[mid:], otf_graph=optimizable.otf_graph)
)

# reset for good measure
OptimizableBatch.ignored_changes = {}

relaxed_batch = Batch.from_data_list(relaxed_batches)

# Batch.from_data_list is not intended to be used with a list of batches, so when sid is a list of str
# it will be incorrectly collated as a list of lists for each batch.
# but we can not use to_data_list in the relaxed batches (since they have been changed, see linked comment above).
# So instead just manually fix it for now. Remove this once pyg dependency is removed
if isinstance(relaxed_batch.sid, list):
relaxed_batch.sid = [sid for sid_list in relaxed_batch.sid for sid in sid_list]

return Batch.from_data_list(relaxed_batches)
return relaxed_batch
Loading

0 comments on commit 53dad8d

Please sign in to comment.