Merge branch 'main' into rgao_fix_local_n_gpurun
rayg1234 authored Dec 4, 2024
2 parents b8c3905 + 816cf00 commit 0e52700
Showing 19 changed files with 1,071 additions and 244 deletions.
2 changes: 1 addition & 1 deletion packages/env.cpu.yml
@@ -4,7 +4,7 @@ channels:
- defaults
dependencies:
- cpuonly
- - pytorch>=2.4
+ - pytorch==2.4.0
- ase
- e3nn>=0.5
- numpy >=1.26.0,<2.0.0
2 changes: 1 addition & 1 deletion packages/env.gpu.yml
@@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- pytorch-cuda=12.1
- - pytorch>=2.4
+ - pytorch==2.4.0
- ase
- e3nn>=0.5
- numpy >=1.26.0,<2.0.0
4 changes: 2 additions & 2 deletions packages/fairchem-core/pyproject.toml
@@ -9,7 +9,7 @@ license = {text = "MIT License"}
dynamic = ["version", "readme"]
requires-python = ">=3.9, <3.13"
dependencies = [
"torch>=2.4",
"torch==2.4",
"numpy >=1.26.0, <2.0.0",
"lmdb",
"ase",
@@ -28,7 +28,7 @@ dependencies = [

[project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev]
dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi", "umap-learn", "vdict"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "umap-learn", "vdict"]
adsorbml = ["dscribe","x3dase","scikit-image"]

[project.scripts]
13 changes: 13 additions & 0 deletions src/fairchem/core/common/relaxation/__init__.py
@@ -0,0 +1,13 @@
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

from .ml_relaxation import ml_relax
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch

__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"]
108 changes: 77 additions & 31 deletions src/fairchem/core/common/relaxation/ase_utils.py
@@ -14,13 +14,15 @@

import copy
import logging
from typing import ClassVar
from types import MappingProxyType
from typing import TYPE_CHECKING

import torch
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.singlepoint import SinglePointCalculator as sp
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.geometry import wrap_positions

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
@@ -33,51 +35,93 @@
from fairchem.core.models.model_registry import model_name_to_local_file
from fairchem.core.preprocessing import AtomsToGraphs

if TYPE_CHECKING:
from pathlib import Path

def batch_to_atoms(batch):
from torch_geometric.data import Batch


# system level model predictions have different shapes than expected by ASE
ASE_PROP_RESHAPE = MappingProxyType(
{"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
)


def batch_to_atoms(
batch: Batch,
results: dict[str, torch.Tensor] | None = None,
wrap_pos: bool = True,
eps: float = 1e-7,
) -> list[Atoms]:
"""Convert a data batch to ase Atoms
Args:
batch: data batch
results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
are given no calculator will be added to the atoms objects.
wrap_pos: wrap positions back into the cell.
eps: Small number to prevent slightly negative coordinates from being wrapped.
Returns:
list of Atoms
"""
n_systems = batch.natoms.shape[0]
natoms = batch.natoms.tolist()
numbers = torch.split(batch.atomic_numbers, natoms)
fixed = torch.split(batch.fixed.to(torch.bool), natoms)
forces = torch.split(batch.force, natoms)
if results is not None:
results = {
key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist()
if len(val) == len(batch)
else [v.cpu().detach().numpy() for v in torch.split(val, natoms)]
for key, val in results.items()
}

positions = torch.split(batch.pos, natoms)
tags = torch.split(batch.tags, natoms)
cells = batch.cell
energies = batch.energy.view(-1).tolist()

atoms_objects = []
for idx in range(n_systems):
pos = positions[idx].cpu().detach().numpy()
cell = cells[idx].cpu().detach().numpy()

# TODO take pbc from data
if wrap_pos:
pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)

atoms = Atoms(
numbers=numbers[idx].tolist(),
positions=positions[idx].cpu().detach().numpy(),
cell=cell,
positions=pos,
tags=tags[idx].tolist(),
cell=cells[idx].cpu().detach().numpy(),
constraint=FixAtoms(mask=fixed[idx].tolist()),
pbc=[True, True, True],
)
calc = sp(
atoms=atoms,
energy=energies[idx],
forces=forces[idx].cpu().detach().numpy(),
)
atoms.set_calculator(calc)

if results is not None:
calc = SinglePointCalculator(
atoms=atoms, **{key: val[idx] for key, val in results.items()}
)
atoms.set_calculator(calc)

atoms_objects.append(atoms)

return atoms_objects


class OCPCalculator(Calculator):
implemented_properties: ClassVar[list[str]] = ["energy", "forces"]
"""ASE based calculator using an OCP model"""

_reshaped_props = ASE_PROP_RESHAPE

def __init__(
self,
config_yml: str | None = None,
checkpoint_path: str | None = None,
checkpoint_path: str | Path | None = None,
model_name: str | None = None,
local_cache: str | None = None,
trainer: str | None = None,
cutoff: int = 6,
max_neighbors: int = 50,
cpu: bool = True,
seed: int | None = None,
) -> None:
@@ -96,16 +140,12 @@ def __init__(
Directory to save pretrained model checkpoints.
trainer (str):
OCP trainer to be used. "forces" for S2EF, "energy" for IS2RE.
cutoff (int):
Cutoff radius to be used for data preprocessing.
max_neighbors (int):
Maximum amount of neighbors to store for a given atom.
cpu (bool):
Whether to load and run the model on CPU. Set `False` for GPU.
"""
setup_imports()
setup_logging()
Calculator.__init__(self)
super().__init__()

if model_name is not None:
if checkpoint_path is not None:
@@ -165,9 +205,8 @@ def __init__(
### backwards compatability with OCP v<2.0
config = update_config(config)

# Save config so obj can be transported over network (pkl)
self.config = copy.deepcopy(config)
self.config["checkpoint"] = checkpoint_path
self.config["checkpoint"] = str(checkpoint_path)
del config["dataset"]["src"]

self.trainer = registry.get_trainer_class(config["trainer"])(
@@ -199,14 +238,13 @@ def __init__(
self.trainer.set_seed(seed)

self.a2g = AtomsToGraphs(
max_neigh=max_neighbors,
radius=cutoff,
r_energy=False,
r_forces=False,
r_distances=False,
r_edges=False,
r_pbc=True,
r_edges=not self.trainer.model.otf_graph, # otf graph should not be a property of the model...
)
self.implemented_properties = list(self.config["outputs"].keys())

def load_checkpoint(
self, checkpoint_path: str, checkpoint: dict | None = None
@@ -217,6 +255,8 @@ def load_checkpoint(
Args:
checkpoint_path: string
Path to trained model
checkpoint: dict
A pretrained checkpoint dict
"""
try:
self.trainer.load_checkpoint(
@@ -225,14 +265,20 @@ def __init__(
except NotImplementedError:
logging.warning("Unable to load checkpoint!")

def calculate(self, atoms: Atoms, properties, system_changes) -> None:
Calculator.calculate(self, atoms, properties, system_changes)
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
def calculate(self, atoms: Atoms | Batch, properties, system_changes) -> None:
"""Calculate implemented properties for a single Atoms object or a Batch of them."""
super().calculate(atoms, properties, system_changes)
if isinstance(atoms, Atoms):
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
else:
batch = atoms

predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True)

for key in predictions:
_pred = predictions[key]
_pred = _pred.item() if _pred.numel() == 1 else _pred.cpu().numpy()
if key in OCPCalculator._reshaped_props:
_pred = _pred.reshape(OCPCalculator._reshaped_props.get(key)).squeeze()
self.results[key] = _pred
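
The updated batch_to_atoms now takes an optional results dict and only attaches a SinglePointCalculator when predictions are supplied. The sketch below is not part of the commit; it is a minimal, hypothetical usage example that builds a toy one-system Batch carrying the fields the function reads (natoms, atomic_numbers, fixed, force, pos, tags, cell, energy) and passes model-style predictions as results.

import torch
from torch_geometric.data import Batch, Data

from fairchem.core.common.relaxation.ase_utils import batch_to_atoms

# Toy single-system batch: 2 Cu atoms in a 3.6 A cubic cell, with the
# attributes batch_to_atoms reads from a data batch.
data = Data(
    natoms=torch.tensor([2]),
    atomic_numbers=torch.tensor([29.0, 29.0]),
    fixed=torch.zeros(2),
    force=torch.zeros(2, 3),
    pos=torch.tensor([[0.0, 0.0, 0.0], [1.8, 1.8, 1.8]]),
    tags=torch.zeros(2, dtype=torch.long),
    cell=3.6 * torch.eye(3).unsqueeze(0),
    energy=torch.tensor([-3.5]),
)
batch = Batch.from_data_list([data])

# Per-system values (energy) are flattened per system; per-atom values
# (forces) are split back into one array per system before being handed
# to each Atoms object's SinglePointCalculator.
results = {"energy": torch.tensor([-3.5]), "forces": torch.zeros(2, 3)}

atoms_list = batch_to_atoms(batch, results=results)
print(atoms_list[0].get_potential_energy())  # -3.5, from the attached calculator
print(atoms_list[0].get_forces())            # zeros, shape (2, 3)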
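
The OCPCalculator changes drop the cutoff/max_neighbors arguments, derive implemented_properties from the model's configured outputs, and let calculate() accept either an ASE Atoms object or a pre-collated Batch. Below is a minimal, hypothetical usage sketch, not part of the commit; the checkpoint path is a placeholder and the available output keys depend on the loaded model.

from ase.build import bulk

from fairchem.core.common.relaxation.ase_utils import OCPCalculator

# Placeholder checkpoint path -- substitute any trained fairchem checkpoint,
# or pass model_name=... and local_cache=... to use a pretrained model.
calc = OCPCalculator(checkpoint_path="/path/to/checkpoint.pt", cpu=True, seed=0)

atoms = bulk("Cu", "fcc", a=3.6)
atoms.calc = calc

# Both calls dispatch through calculate() -> trainer.predict(); the result
# keys follow the model's configured outputs (e.g. energy, forces).
print(atoms.get_potential_energy())
print(atoms.get_forces())

# calculate() also accepts a pre-collated torch_geometric Batch in place of Atoms.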
(diffs for the remaining changed files were not loaded)
