From 3f695efd4877db6cdb4adde9d5c8ca8008ce2a75 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Tue, 3 Dec 2024 10:52:44 -0800 Subject: [PATCH] OptimizableBatch and stress relaxations (#718) * remove r_edges, radius, max_neigh and add deprecation warning * edit typing and dont use dicts as default * use super() and remove overkill deprecation warning * set implemented_properties from config * make determine step a method * allow calculator to operate on batches * only update if old config is used * reshape properties * no test classes in ase calculator * yaml load fix * use mappingproxy * expressive import * remove duplicated code * optimizable batch class for ase compatible batch relaxations * fix optimizable batch * optimizable goodies * apply force constraints * use optimizable batch instead and remove torchcalc * update ml relaxations to use optimizable batch correctly * force_consistent check for ASE compat * force_consistent check for ASE compat * check force_consistent * init docs in lbfgs * unitcellfilter for batch relaxations * ruff * UnitCellOptimizable as child class instead of filter * allow running unit cell relaxations * ruff * no grad in run_relaxations * make batched_dot and determine_step methods * imports * rename to optimizableunitcellbatch * allow passing energy and forces explicitly to batch to atoms * check convergence in optimizable and allow passing general results to atoms_from_batch * relaxation test * unit tests * move update mask to optimizable * use energy instead of y * all setting/getting positions and convergence in optimizable * more (unfinished) tests * backwards compatible test * minor fixes * code cleanup * add/fix tests * fix lbfgs * assert using norm * add eps to masked batches if using ASE optimizers * match iterations from previous implementation * use float64 for forces * float32 * use energy_relaxed instead of y_relaxed * energy_relaxed and more explicit error msg * default to batch_size 1 if not set in config * keep float64 training * rename y_relaxed -> energy_relaxed * rm expcell batch * convenience commit from no_experimental_resolve * use numatoms tensor for cell factor * remove positions tests (wrapping atoms gives different results) * allow wrapping positions in batch to atoms * fix test * wrap_positions in batch_to_atoms * take a2g properties from model * test lbfgs traj writes * remove comments * use model generate graph * fix cell_factor * fix using model in ddp * fix r_edges in OCPcalculator * write initial and final structure if save_full is false * check unique atoms saved in trajectory * tighter tol * update ASE release comment * remove cumulative mask option * remove left over cumulative_mask * fix batching when sids as str * do not try to fetch energy and forces if no explicit results * accept Path objects * clean up setting defaults * expose ml_relax in relaxation * force set r_pbc True * make relax_opt optional * no ema on inference only * define ema none to avoid issues * lower force threshold to make sure test does not converge * clean up exception msg * allow strings in batch * remove device argument from lbfgs * minor cleanup * fix optimizable import * do not pass device in ml_relax * simplify enforce max neighbors * fix tests (still not testing stress) * pin sphinx autoapi * typo in version --------- Co-authored-by: zulissimeta <122578103+zulissimeta@users.noreply.github.com> Co-authored-by: Zack Ulissi --- packages/fairchem-core/pyproject.toml | 2 +- .../core/common/relaxation/__init__.py | 13 + 
.../core/common/relaxation/ase_utils.py | 108 +++- .../core/common/relaxation/ml_relaxation.py | 117 ++-- .../core/common/relaxation/optimizable.py | 547 ++++++++++++++++++ .../common/relaxation/optimizers/__init__.py | 12 + .../relaxation/optimizers/lbfgs_torch.py | 238 ++++---- src/fairchem/core/datasets/ase_datasets.py | 16 +- src/fairchem/core/models/base.py | 12 +- .../core/preprocessing/atoms_to_graphs.py | 2 +- src/fairchem/core/trainers/base_trainer.py | 7 +- src/fairchem/core/trainers/ocp_trainer.py | 13 +- tests/core/common/conftest.py | 33 ++ tests/core/common/test_ase_calculator.py | 3 + tests/core/common/test_lbfgs_torch.py | 66 +++ tests/core/common/test_optimizable.py | 110 ++++ tests/core/datasets/test_ase_datasets.py | 10 +- 17 files changed, 1068 insertions(+), 241 deletions(-) create mode 100644 src/fairchem/core/common/relaxation/optimizable.py create mode 100644 tests/core/common/conftest.py create mode 100644 tests/core/common/test_lbfgs_torch.py create mode 100644 tests/core/common/test_optimizable.py diff --git a/packages/fairchem-core/pyproject.toml b/packages/fairchem-core/pyproject.toml index ee92db45df..dfd5c671ae 100644 --- a/packages/fairchem-core/pyproject.toml +++ b/packages/fairchem-core/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ [project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev] dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"] -docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi", "umap-learn", "vdict"] +docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "umap-learn", "vdict"] adsorbml = ["dscribe","x3dase","scikit-image"] [project.scripts] diff --git a/src/fairchem/core/common/relaxation/__init__.py b/src/fairchem/core/common/relaxation/__init__.py index e69de29bb2..1700e00405 100644 --- a/src/fairchem/core/common/relaxation/__init__.py +++ b/src/fairchem/core/common/relaxation/__init__.py @@ -0,0 +1,13 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +from .ml_relaxation import ml_relax +from .optimizable import OptimizableBatch, OptimizableUnitCellBatch + +__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"] diff --git a/src/fairchem/core/common/relaxation/ase_utils.py b/src/fairchem/core/common/relaxation/ase_utils.py index 2dacce2cb7..5a9302d88f 100644 --- a/src/fairchem/core/common/relaxation/ase_utils.py +++ b/src/fairchem/core/common/relaxation/ase_utils.py @@ -14,13 +14,15 @@ import copy import logging -from typing import ClassVar +from types import MappingProxyType +from typing import TYPE_CHECKING import torch from ase import Atoms from ase.calculators.calculator import Calculator -from ase.calculators.singlepoint import SinglePointCalculator as sp +from ase.calculators.singlepoint import SinglePointCalculator from ase.constraints import FixAtoms +from ase.geometry import wrap_positions from fairchem.core.common.registry import registry from fairchem.core.common.utils import ( @@ -33,51 +35,93 @@ from fairchem.core.models.model_registry import model_name_to_local_file from fairchem.core.preprocessing import AtomsToGraphs +if TYPE_CHECKING: + from pathlib import Path -def batch_to_atoms(batch): + from torch_geometric.data import Batch + + +# system level model predictions have different shapes than expected by ASE +ASE_PROP_RESHAPE = MappingProxyType( + {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} +) + + +def batch_to_atoms( + batch: Batch, + results: dict[str, torch.Tensor] | None = None, + wrap_pos: bool = True, + eps: float = 1e-7, +) -> list[Atoms]: + """Convert a data batch to ase Atoms + + Args: + batch: data batch + results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results + are given no calculator will be added to the atoms objects. + wrap_pos: wrap positions back into the cell. + eps: Small number to prevent slightly negative coordinates from being wrapped. 
+ + Returns: + list of Atoms + """ n_systems = batch.natoms.shape[0] natoms = batch.natoms.tolist() numbers = torch.split(batch.atomic_numbers, natoms) fixed = torch.split(batch.fixed.to(torch.bool), natoms) - forces = torch.split(batch.force, natoms) + if results is not None: + results = { + key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist() + if len(val) == len(batch) + else [v.cpu().detach().numpy() for v in torch.split(val, natoms)] + for key, val in results.items() + } + positions = torch.split(batch.pos, natoms) tags = torch.split(batch.tags, natoms) cells = batch.cell - energies = batch.energy.view(-1).tolist() atoms_objects = [] for idx in range(n_systems): + pos = positions[idx].cpu().detach().numpy() + cell = cells[idx].cpu().detach().numpy() + + # TODO take pbc from data + if wrap_pos: + pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps) + atoms = Atoms( numbers=numbers[idx].tolist(), - positions=positions[idx].cpu().detach().numpy(), + cell=cell, + positions=pos, tags=tags[idx].tolist(), - cell=cells[idx].cpu().detach().numpy(), constraint=FixAtoms(mask=fixed[idx].tolist()), pbc=[True, True, True], ) - calc = sp( - atoms=atoms, - energy=energies[idx], - forces=forces[idx].cpu().detach().numpy(), - ) - atoms.set_calculator(calc) + + if results is not None: + calc = SinglePointCalculator( + atoms=atoms, **{key: val[idx] for key, val in results.items()} + ) + atoms.set_calculator(calc) + atoms_objects.append(atoms) return atoms_objects class OCPCalculator(Calculator): - implemented_properties: ClassVar[list[str]] = ["energy", "forces"] + """ASE based calculator using an OCP model""" + + _reshaped_props = ASE_PROP_RESHAPE def __init__( self, config_yml: str | None = None, - checkpoint_path: str | None = None, + checkpoint_path: str | Path | None = None, model_name: str | None = None, local_cache: str | None = None, trainer: str | None = None, - cutoff: int = 6, - max_neighbors: int = 50, cpu: bool = True, seed: int | None = None, ) -> None: @@ -96,16 +140,12 @@ def __init__( Directory to save pretrained model checkpoints. trainer (str): OCP trainer to be used. "forces" for S2EF, "energy" for IS2RE. - cutoff (int): - Cutoff radius to be used for data preprocessing. - max_neighbors (int): - Maximum amount of neighbors to store for a given atom. cpu (bool): Whether to load and run the model on CPU. Set `False` for GPU. """ setup_imports() setup_logging() - Calculator.__init__(self) + super().__init__() if model_name is not None: if checkpoint_path is not None: @@ -165,9 +205,8 @@ def __init__( ### backwards compatability with OCP v<2.0 config = update_config(config) - # Save config so obj can be transported over network (pkl) self.config = copy.deepcopy(config) - self.config["checkpoint"] = checkpoint_path + self.config["checkpoint"] = str(checkpoint_path) del config["dataset"]["src"] self.trainer = registry.get_trainer_class(config["trainer"])( @@ -199,14 +238,13 @@ def __init__( self.trainer.set_seed(seed) self.a2g = AtomsToGraphs( - max_neigh=max_neighbors, - radius=cutoff, r_energy=False, r_forces=False, r_distances=False, - r_edges=False, r_pbc=True, + r_edges=not self.trainer.model.otf_graph, # otf graph should not be a property of the model... 
) + self.implemented_properties = list(self.config["outputs"].keys()) def load_checkpoint( self, checkpoint_path: str, checkpoint: dict | None = None @@ -217,6 +255,8 @@ def load_checkpoint( Args: checkpoint_path: string Path to trained model + checkpoint: dict + A pretrained checkpoint dict """ try: self.trainer.load_checkpoint( @@ -225,14 +265,20 @@ def load_checkpoint( except NotImplementedError: logging.warning("Unable to load checkpoint!") - def calculate(self, atoms: Atoms, properties, system_changes) -> None: - Calculator.calculate(self, atoms, properties, system_changes) - data_object = self.a2g.convert(atoms) - batch = data_list_collater([data_object], otf_graph=True) + def calculate(self, atoms: Atoms | Batch, properties, system_changes) -> None: + """Calculate implemented properties for a single Atoms object or a Batch of them.""" + super().calculate(atoms, properties, system_changes) + if isinstance(atoms, Atoms): + data_object = self.a2g.convert(atoms) + batch = data_list_collater([data_object], otf_graph=True) + else: + batch = atoms predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True) for key in predictions: _pred = predictions[key] _pred = _pred.item() if _pred.numel() == 1 else _pred.cpu().numpy() + if key in OCPCalculator._reshaped_props: + _pred = _pred.reshape(OCPCalculator._reshaped_props.get(key)).squeeze() self.results[key] = _pred diff --git a/src/fairchem/core/common/relaxation/ml_relaxation.py b/src/fairchem/core/common/relaxation/ml_relaxation.py index 406b6b1cc3..bf5eb3cac7 100644 --- a/src/fairchem/core/common/relaxation/ml_relaxation.py +++ b/src/fairchem/core/common/relaxation/ml_relaxation.py @@ -10,6 +10,7 @@ import logging from collections import deque from pathlib import Path +from typing import TYPE_CHECKING import torch from torch_geometric.data import Batch @@ -17,70 +18,94 @@ from fairchem.core.common.typing import assert_is_instance from fairchem.core.datasets.lmdb_dataset import data_list_collater -from .optimizers.lbfgs_torch import LBFGS, TorchCalc +from .optimizable import OptimizableBatch, OptimizableUnitCellBatch +from .optimizers.lbfgs_torch import LBFGS + +if TYPE_CHECKING: + from fairchem.core.trainers import BaseTrainer def ml_relax( - batch, - model, + batch: Batch, + model: BaseTrainer, steps: int, fmax: float, - relax_opt, - save_full_traj, - device: str = "cuda:0", - transform=None, - early_stop_batch: bool = False, + relax_opt: dict[str] | None = None, + relax_cell: bool = False, + relax_volume: bool = False, + save_full_traj: bool = True, + transform: torch.nn.Module | None = None, + mask_converged: bool = True, ): - """ - Runs ML-based relaxations. + """Runs ML-based relaxations. + Args: - batch: object - model: object - steps: int - Max number of steps in the structure relaxation. - fmax: float - Structure relaxation terminates when the max force - of the system is no bigger than fmax. - relax_opt: str - Optimizer and corresponding parameters to be used for structure relaxations. - save_full_traj: bool - Whether to save out the full ASE trajectory. If False, only save out initial and final frames. + batch: a data batch object. + model: a trainer object with model. + steps: Max number of steps in the structure relaxation. + fmax: Structure relaxation terminates when the max force of the system is no bigger than fmax. + relax_opt: Optimizer parameters to be used for structure relaxations. + relax_cell: if true will use stress predictions to relax crystallographic cell. 
+ The model given must predict stress. + relax_volume: if true will relax the cell isotropically. The given model must predict stress. + save_full_traj: Whether to save out the full ASE trajectory. If False, only save out initial and final frames. + mask_converged: whether to mask systems in the batch where all atoms are below the convergence threshold. + Once a system is masked it remains masked even if new predictions give forces above the threshold, + i.e. once masked always masked; because of this, always check convergence with the same fmax. """ + relax_opt = relax_opt or {} + # if pbc is not set, ignore it when comparing batches + if not hasattr(batch, "pbc"): + OptimizableBatch.ignored_changes = {"pbc"} + batches = deque([batch]) relaxed_batches = [] while batches: batch = batches.popleft() oom = False ids = batch.sid - calc = TorchCalc(model, transform) + + # clone the batch, otherwise you cannot run batch.to_data_list + # see https://github.com/pyg-team/pytorch_geometric/issues/8439#issuecomment-1826747915 + if relax_cell or relax_volume: + optimizable = OptimizableUnitCellBatch( + batch.clone(), + trainer=model, + transform=transform, + mask_converged=mask_converged, + hydrostatic_strain=relax_volume, + ) + else: + optimizable = OptimizableBatch( + batch.clone(), + trainer=model, + transform=transform, + mask_converged=mask_converged, + ) # Run ML-based relaxation - traj_dir = relax_opt.get("traj_dir", None) + traj_dir = relax_opt.get("traj_dir") + relax_opt.update({"traj_dir": Path(traj_dir) if traj_dir is not None else None}) + optimizer = LBFGS( - batch, - calc, - maxstep=relax_opt.get("maxstep", 0.2), - memory=relax_opt["memory"], - damping=relax_opt.get("damping", 1.2), - alpha=relax_opt.get("alpha", 80.0), - device=device, + optimizable_batch=optimizable, save_full_traj=save_full_traj, - traj_dir=Path(traj_dir) if traj_dir is not None else None, traj_names=ids, - early_stop_batch=early_stop_batch, + **relax_opt, ) e: RuntimeError | None = None try: - relaxed_batch = optimizer.run(fmax=fmax, steps=steps) - relaxed_batches.append(relaxed_batch) + optimizer.run(fmax=fmax, steps=steps) + relaxed_batches.append(optimizable.batch) except RuntimeError as err: e = err oom = True torch.cuda.empty_cache() if oom: - # move OOM recovery code outside of except clause to allow tensors to be freed. data_list = batch.to_data_list() if len(data_list) == 1: raise assert_is_instance(e, RuntimeError) @@ -88,7 +113,23 @@ def ml_relax( f"Failed to relax batch with size: {len(data_list)}, splitting into two..." ) mid = len(data_list) // 2 - batches.appendleft(data_list_collater(data_list[:mid])) - batches.appendleft(data_list_collater(data_list[mid:])) + batches.appendleft( + data_list_collater(data_list[:mid], otf_graph=optimizable.otf_graph) + ) + batches.appendleft( + data_list_collater(data_list[mid:], otf_graph=optimizable.otf_graph) + ) + + # reset for good measure + OptimizableBatch.ignored_changes = {} + + relaxed_batch = Batch.from_data_list(relaxed_batches) + + # Batch.from_data_list is not intended to be used with a list of batches, so when sid is a list of str + # it will be incorrectly collated as a list of lists for each batch. + # but we cannot use to_data_list on the relaxed batches (since they have been changed, see linked comment above). + # So instead just manually fix it for now.
Remove this once pyg dependency is removed + if isinstance(relaxed_batch.sid, list): + relaxed_batch.sid = [sid for sid_list in relaxed_batch.sid for sid in sid_list] - return Batch.from_data_list(relaxed_batches) + return relaxed_batch diff --git a/src/fairchem/core/common/relaxation/optimizable.py b/src/fairchem/core/common/relaxation/optimizable.py new file mode 100644 index 0000000000..c40f461267 --- /dev/null +++ b/src/fairchem/core/common/relaxation/optimizable.py @@ -0,0 +1,547 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Code based on ase.optimize +""" + +from __future__ import annotations + +from functools import cached_property +from types import SimpleNamespace +from typing import TYPE_CHECKING, ClassVar + +import numpy as np +import torch +from ase.calculators.calculator import PropertyNotImplementedError +from ase.stress import voigt_6_to_full_3x3_stress +from torch_scatter import scatter + +from fairchem.core.common.relaxation.ase_utils import batch_to_atoms + +# this can be removed after pinning ASE dependency >= 3.23 +try: + from ase.optimize.optimize import Optimizable +except ImportError: + + class Optimizable: + pass + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ase import Atoms + from numpy.typing import NDArray + from torch_geometric.data import Batch + + from fairchem.core.trainers import BaseTrainer + + +ALL_CHANGES: set[str] = { + "pos", + "atomic_numbers", + "cell", + "pbc", +} + + +def compare_batches( + batch1: Batch | None, + batch2: Batch, + tol: float = 1e-6, + excluded_properties: set[str] | None = None, +) -> list[str]: + """Compare properties between two batches + + Args: + batch1: atoms batch + batch2: atoms batch + tol: tolerance used to compare equality of floating point properties + excluded_properties: list of properties to exclude from comparison + + Returns: + list of system changes, property names that are different between batch1 and batch2 + """ + system_changes = [] + + if batch1 is None: + system_changes = ALL_CHANGES + else: + properties_to_check = set(ALL_CHANGES) + if excluded_properties: + properties_to_check -= set(excluded_properties) + + # Check properties that are not excluded + for prop in ALL_CHANGES: + if prop in properties_to_check: + properties_to_check.remove(prop) + if not torch.allclose( + getattr(batch1, prop), getattr(batch2, prop), atol=tol + ): + system_changes.append(prop) + + return system_changes + + +class OptimizableBatch(Optimizable): + """A Batch version of ase Optimizable Atoms + + This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation + or with ase relaxation classes, e.g. ase.optimize.LBFGS + """ + + ignored_changes: ClassVar[set[str]] = set() + + def __init__( + self, + batch: Batch, + trainer: BaseTrainer, + transform: torch.nn.Module | None = None, + mask_converged: bool = True, + numpy: bool = False, + masked_eps: float = 1e-8, + ): + """Initialize Optimizable Batch + + Args: + batch: A batch of atoms graph data + trainer: An instance of a BaseTrainer derived class + transform: graph transform + mask_converged: if true will mask systems in batch that are already converged + numpy: whether to cast results to numpy arrays + masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero + from zero differences in masked positions at future steps; we add a small number to prevent this.
+ """ + self.batch = batch.to(trainer.device) + self.trainer = trainer + self.transform = transform + self.numpy = numpy + self.mask_converged = mask_converged + self._cached_batch = None + self._update_mask = None + self.torch_results = {} + self.results = {} + self._eps = masked_eps + + self.otf_graph = True # trainer._unwrapped_model.otf_graph + if not self.otf_graph and "edge_index" not in self.batch: + self.update_graph() + + @property + def device(self): + return self.trainer.device + + @property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch.""" + return self.batch.batch + + @property + def converged_mask(self): + if self._update_mask is not None: + return torch.logical_not(self._update_mask) + return None + + @property + def update_mask(self): + if self._update_mask is None: + return torch.ones(len(self.batch), dtype=bool) + return self._update_mask + + def check_state(self, batch: Batch, tol: float = 1e-12) -> bool: + """Check for any system changes since last calculation.""" + return compare_batches( + self._cached_batch, + batch, + tol=tol, + excluded_properties=set(self.ignored_changes), + ) + + def _predict(self) -> None: + """Run prediction if batch has any changes.""" + system_changes = self.check_state(self.batch) + if len(system_changes) > 0: + self.torch_results = self.trainer.predict( + self.batch, per_image=False, disable_tqdm=True + ) + # save only subset of props in simple namespace instead of cloning the whole batch to save memory + changes = ALL_CHANGES - set(self.ignored_changes) + self._cached_batch = SimpleNamespace( + **{prop: self.batch[prop].clone() for prop in changes} + ) + + def get_property(self, name, no_numpy: bool = False) -> torch.Tensor | NDArray: + """Get a predicted property by name.""" + self._predict() + if self.numpy: + self.results = { + key: pred.item() if pred.numel() == 1 else pred.cpu().numpy() + for key, pred in self.torch_results.items() + } + else: + self.results = self.torch_results + + if name not in self.results: + raise PropertyNotImplementedError(f"{name} not present in this calculation") + + return self.results[name] if no_numpy is False else self.torch_results[name] + + def get_positions(self) -> torch.Tensor | NDArray: + """Get the batch positions""" + pos = self.batch.pos.clone() + if self.numpy: + if self.mask_converged: + pos[~self.update_mask[self.batch.batch]] = self._eps + pos = pos.cpu().numpy() + + return pos + + def set_positions(self, positions: torch.Tensor | NDArray) -> None: + """Set the atom positions in the batch.""" + if isinstance(positions, np.ndarray): + positions = torch.tensor(positions) + + positions = positions.to(dtype=torch.float32, device=self.device) + if self.mask_converged and self._update_mask is not None: + mask = self.update_mask[self.batch.batch] + self.batch.pos[mask] = positions[mask] + else: + self.batch.pos = positions + + if not self.otf_graph: + self.update_graph() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get predicted batch forces.""" + forces = self.get_property("forces", no_numpy=no_numpy) + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + if isinstance(forces, np.ndarray): + fixed_idx = fixed_idx.tolist() + forces[fixed_idx] = 0.0 + return forces + + def get_potential_energy(self, **kwargs) -> torch.Tensor | NDArray: + """Get predicted energy as the sum of all batch energies.""" + # ASE 3.22.1 expects a check for 
force_consistent calculations + if kwargs.get("force_consistent", False) is True: + raise PropertyNotImplementedError( + "force_consistent calculations are not implemented" + ) + if ( + len(self.batch) == 1 + ): # unfortunately batch size 1 returns a float, not a tensor + return self.get_property("energy") + return self.get_property("energy").sum() + + def get_potential_energies(self) -> torch.Tensor | NDArray: + """Get the predicted energy for each system in batch.""" + return self.get_property("energy") + + def get_cells(self) -> torch.Tensor: + """Get batch crystallographic cells.""" + return self.batch.cell + + def set_cells(self, cells: torch.Tensor | NDArray) -> None: + """Set batch cells.""" + assert self.batch.cell.shape == cells.shape, "Cell shape mismatch" + if isinstance(cells, np.ndarray): + cells = torch.tensor(cells, dtype=torch.float32, device=self.device) + cells = cells.to(dtype=torch.float32, device=self.device) + self.batch.cell[self.update_mask] = cells[self.update_mask] + + def get_volumes(self) -> torch.Tensor: + """Get a tensor of volumes for each cell in batch""" + cells = self.get_cells() + return torch.linalg.det(cells) + + def iterimages(self) -> Batch: + # XXX document purpose of iterimages - this is just needed to work with ASE optimizers + yield self.batch + + def get_max_forces( + self, forces: torch.Tensor | None = None, apply_constraint: bool = False + ) -> torch.Tensor: + """Get the maximum forces per structure in batch""" + if forces is None: + forces = self.get_forces(apply_constraint=apply_constraint, no_numpy=True) + return scatter((forces**2).sum(axis=1).sqrt(), self.batch_indices, reduce="max") + + def converged( + self, + forces: torch.Tensor | NDArray | None, + fmax: float, + max_forces: torch.Tensor | None = None, + ) -> bool: + """Check if the norms of all predicted forces are below fmax""" + if forces is not None: + if isinstance(forces, np.ndarray): + forces = torch.tensor(forces, device=self.device, dtype=torch.float32) + max_forces = self.get_max_forces(forces) + elif max_forces is None: + max_forces = self.get_max_forces() + + update_mask = max_forces.ge(fmax) + # update cached mask + if self.mask_converged: + if self._update_mask is None: + self._update_mask = update_mask + else: + # some models can have random noise in their predictions, so the mask is updated by + # keeping all previously converged structures masked even if new force predictions + # push them slightly above the threshold + self._update_mask = torch.logical_and(self._update_mask, update_mask) + update_mask = self._update_mask + + return not torch.any(update_mask).item() + + def get_atoms_list(self) -> list[Atoms]: + """Get ase Atoms objects corresponding to the batch""" + self._predict() # in case no predictions have been run + return batch_to_atoms(self.batch, results=self.torch_results) + + def update_graph(self): + """Update the graph if model does not use otf_graph.""" + graph = self.trainer._unwrapped_model.generate_graph(self.batch) + self.batch.edge_index = graph.edge_index + self.batch.cell_offsets = graph.cell_offsets + self.batch.neighbors = graph.neighbors + if self.transform is not None: + self.batch = self.transform(self.batch) + + def __len__(self) -> int: + # TODO: this might be changed in ASE to be 3 * len(self.atoms) + return len(self.batch.pos) + + +class OptimizableUnitCellBatch(OptimizableBatch): + """Modify the supercell and the atom positions in relaxations.
+ + Based on ase UnitCellFilter to work on data batches + """ + + def __init__( + self, + batch: Batch, + trainer: BaseTrainer, + transform: torch.nn.Module | None = None, + numpy: bool = False, + mask_converged: bool = True, + mask: Sequence[bool] | None = None, + cell_factor: float | torch.Tensor | None = None, + hydrostatic_strain: bool = False, + constant_volume: bool = False, + scalar_pressure: float = 0.0, + masked_eps: float = 1e-8, + ): + """Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization. + + For full details see: + E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras, + Phys. Rev. B 59, 235 (1999) + + Args: + batch: A batch of atoms graph data + trainer: An instance of a BaseTrainer derived class + transform: graph transform + numpy: whether to cast results to numpy arrays + mask_converged: if true will mask systems in batch that are already converged + mask: a boolean mask specifying which strain components are allowed to relax + cell_factor: + Factor by which deformation gradient is multiplied to put + it on the same scale as the positions when assembling + the combined position/cell vector. The stress contribution to + the forces is scaled down by the same factor. This can be thought + of as a very simple preconditioner. Default is the number of atoms, + which gives approximately the correct scaling. + hydrostatic_strain: + Constrain the cell by only allowing hydrostatic deformation. + The virial tensor is replaced by np.diag([np.trace(virial)]*3). + constant_volume: + Project out the diagonal elements of the virial tensor to allow + relaxations at constant volume, e.g. for mapping out an + energy-volume curve. Note: this only approximately conserves + the volume and breaks energy/force consistency so can only be + used with optimizers that do not require a line minimisation + (e.g. FIRE). + scalar_pressure: + Applied pressure to use for enthalpy pV term. As above, this + breaks energy/force consistency. + masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero + from zero differences in masked positions at future steps; we add a small number to prevent this. + """ + super().__init__( + batch=batch, + trainer=trainer, + transform=transform, + numpy=numpy, + mask_converged=mask_converged, + masked_eps=masked_eps, + ) + + self.orig_cells = self.get_cells().clone() + self.stress = None + + if mask is None: + mask = torch.eye(3, device=self.device) + + # TODO make sure mask is on GPU + if mask.shape == (6,): + self.mask = torch.tensor( + voigt_6_to_full_3x3_stress(mask.detach().cpu()), + device=self.device, + ) + elif mask.shape == (3, 3): + self.mask = mask + else: + raise ValueError("shape of mask should be (3,3) or (6,)") + + if isinstance(cell_factor, float): + cell_factor = cell_factor * torch.ones( + (3 * len(batch), 1), requires_grad=False + ) + if cell_factor is None: + cell_factor = self.batch.natoms.repeat_interleave(3).unsqueeze(dim=1) + + self.hydrostatic_strain = hydrostatic_strain + self.constant_volume = constant_volume + self.pressure = scalar_pressure * torch.eye(3, device=self.device) + self.cell_factor = cell_factor + self.stress = None + self._batch_trace = torch.vmap(torch.trace) + self._batch_diag = torch.vmap(lambda x: x * torch.eye(3, device=x.device)) + + @cached_property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch.
+ + We augment this to specify the batch indices for augmented positions and forces. + """ + augmented_batch = torch.repeat_interleave( + torch.arange( + len(self.batch), dtype=self.batch.batch.dtype, device=self.device + ), + 3, + ) + return torch.cat([self.batch.batch, augmented_batch]) + + def deform_grad(self): + """Get the cell deformation matrix""" + return torch.transpose( + torch.linalg.solve(self.orig_cells, self.get_cells()), 1, 2 + ) + + def get_positions(self): + """Get positions and cell deformation gradient.""" + cur_deform_grad = self.deform_grad() + natoms = self.batch.num_nodes + pos = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + dtype=self.batch.pos.dtype, + device=self.device, + ) + + # Augmented positions are the atom positions without the applied deformation gradient + pos[:natoms] = torch.linalg.solve( + cur_deform_grad[self.batch.batch, :, :], + self.batch.pos.view(-1, 3, 1), + ).view(-1, 3) + # cell DOFs are the deformation gradient times a scaling factor + pos[natoms:] = self.cell_factor * cur_deform_grad.view(-1, 3) + return pos.cpu().numpy() if self.numpy else pos + + def set_positions(self, positions: torch.Tensor | NDArray): + """Set positions and cell. + + positions has shape (natoms + ncells * 3, 3). The first natoms rows are the positions of the atoms, + and the last ncells * 3 rows are the deformation gradient for each cell. + """ + if isinstance(positions, np.ndarray): + positions = torch.tensor(positions) + + positions = positions.to(dtype=torch.float32, device=self.device) + natoms = self.batch.num_nodes + new_atom_positions = positions[:natoms] + new_deform_grad = (positions[natoms:] / self.cell_factor).view(-1, 3, 3) + + # TODO check that in fact symmetry is preserved setting cells and positions + # Set the new cell from the original cell and the new deformation gradient. Both current and final structures + # should preserve symmetry. + new_cells = torch.bmm(self.orig_cells, torch.transpose(new_deform_grad, 1, 2)) + self.set_cells(new_cells) + + # Set the positions from the ones passed in (which are without the deformation gradient applied) and the new + # deformation gradient. This should also preserve symmetry + new_atom_positions = torch.bmm( + new_atom_positions.view(-1, 1, 3), + torch.transpose( + new_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), 1, 2 + ), + ) + super().set_positions(new_atom_positions.view(-1, 3)) + + def get_potential_energy(self, **kwargs): + """ + Returns the potential energy including the enthalpy PV term. + """ + atoms_energy = super().get_potential_energy(**kwargs) + return atoms_energy + self.pressure[0, 0] * self.get_volumes().sum() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get forces and unit cell stress.""" + stress = self.get_property("stress", no_numpy=True).view(-1, 3, 3) + atom_forces = self.get_property("forces", no_numpy=True) + + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + atom_forces[fixed_idx] = 0.0 + + volumes = self.get_volumes().view(-1, 1, 1) + virial = -volumes * stress + self.pressure.view(-1, 3, 3) + cur_deform_grad = self.deform_grad() + atom_forces = torch.bmm( + atom_forces.view(-1, 1, 3), + cur_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), + ) + virial = torch.linalg.solve( + cur_deform_grad, torch.transpose(virial, dim0=1, dim1=2) + ) + virial = torch.transpose(virial, dim0=1, dim1=2) + + # TODO this does not work yet!
maybe _batch_trace gives an issue + if self.hydrostatic_strain: + virial = self._batch_diag(self._batch_trace(virial) / 3.0) + + # Zero out components corresponding to fixed lattice elements + if (self.mask != 1.0).any(): + virial *= self.mask.view(-1, 3, 3) + + if self.constant_volume: + virial[:, range(3), range(3)] -= self._batch_trace(virial).view(3, -1) / 3.0 + + natoms = self.batch.num_nodes + augmented_forces = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + device=self.device, + dtype=atom_forces.dtype, + ) + augmented_forces[:natoms] = atom_forces.view(-1, 3) + augmented_forces[natoms:] = virial.view(-1, 3) / self.cell_factor + + self.stress = -virial.view(-1, 9) / volumes.view(-1, 1) + + if self.numpy and not no_numpy: + augmented_forces = augmented_forces.cpu().numpy() + + return augmented_forces + + def __len__(self): + return len(self.batch.pos) + 3 * len(self.batch) diff --git a/src/fairchem/core/common/relaxation/optimizers/__init__.py b/src/fairchem/core/common/relaxation/optimizers/__init__.py index e69de29bb2..1c7c27f9f1 100644 --- a/src/fairchem/core/common/relaxation/optimizers/__init__.py +++ b/src/fairchem/core/common/relaxation/optimizers/__init__.py @@ -0,0 +1,12 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from .lbfgs_torch import LBFGS + +__all__ = ["LBFGS"] diff --git a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py index a90f0dce5b..467c4bec41 100644 --- a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py +++ b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py @@ -16,87 +16,66 @@ import torch from torch_scatter import scatter -from fairchem.core.common.relaxation.ase_utils import batch_to_atoms -from fairchem.core.common.utils import radius_graph_pbc - if TYPE_CHECKING: - from torch_geometric.data import Batch + from .optimizable import OptimizableBatch class LBFGS: + """Limited memory BFGS optimizer for batch ML relaxations.""" + def __init__( self, - batch: Batch, - model: TorchCalc, - maxstep: float = 0.01, + optimizable_batch: OptimizableBatch, + maxstep: float = 0.02, memory: int = 100, - damping: float = 0.25, + damping: float = 1.2, alpha: float = 100.0, - force_consistent=None, - device: str = "cuda:0", save_full_traj: bool = True, traj_dir: Path | None = None, - traj_names=None, - early_stop_batch: bool = False, + traj_names: list[str] | None = None, ) -> None: - self.batch = batch - self.model = model + """ + Args: + optimizable_batch: an optimizable batch which includes a model and a batch of data + maxstep: largest step that any atom is allowed to move + memory: Number of steps to be stored in memory + damping: The calculated step is multiplied with this number before being added to the positions.
+ alpha: Initial guess for the Hessian (curvature of energy surface) + save_full_traj: whether to save the full trajectory + traj_dir: path to save trajectories in + traj_names: list of trajectory file names + """ + self.optimizable = optimizable_batch self.maxstep = maxstep self.memory = memory self.damping = damping self.alpha = alpha self.H0 = 1.0 / self.alpha - self.force_consistent = force_consistent - self.device = device self.save_full = save_full_traj self.traj_dir = traj_dir self.traj_names = traj_names - self.early_stop_batch = early_stop_batch - self.otf_graph = True - assert not self.traj_dir or ( - traj_dir and len(traj_names) - ), "Trajectory names should be specified to save trajectories" - logging.info("Step Fmax(eV/A)") - - if not self.otf_graph and "edge_index" not in batch: - self.model.update_graph(self.batch) - - def get_energy_and_forces(self, apply_constraint: bool = True): - energy, forces = self.model.get_energy_and_forces(self.batch, apply_constraint) - return energy, forces - - def set_positions(self, update, update_mask) -> None: - if not self.early_stop_batch: - update = torch.where(update_mask.unsqueeze(1), update, 0.0) - self.batch.pos += update.to(dtype=torch.float32) - - if not self.otf_graph: - self.model.update_graph(self.batch) - - def check_convergence(self, iteration, forces=None, energy=None): - if forces is None or energy is None: - energy, forces = self.get_energy_and_forces() - forces = forces.to(dtype=torch.float64) + self.trajectories = None - max_forces_ = scatter( - (forces**2).sum(axis=1).sqrt(), self.batch.batch, reduce="max" - ) - logging.info( - f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces_.tolist()) - ) + self.fmax = None + self.steps = None - # (batch_size) -> (nAtoms) - max_forces = max_forces_[self.batch.batch] + self.s = deque(maxlen=self.memory) + self.y = deque(maxlen=self.memory) + self.rho = deque(maxlen=self.memory) + self.r0 = None + self.f0 = None - return max_forces.lt(self.fmax), energy, forces + assert not self.traj_dir or ( + traj_dir and len(traj_names) + ), "Trajectory names should be specified to save trajectories" def run(self, fmax, steps): self.fmax = fmax self.steps = steps - self.s = deque(maxlen=self.memory) - self.y = deque(maxlen=self.memory) - self.rho = deque(maxlen=self.memory) + self.s.clear() + self.y.clear() + self.rho.clear() self.r0 = self.f0 = None self.trajectories = None @@ -108,29 +87,33 @@ def run(self, fmax, steps): ] iteration = 0 - converged = False - converged_mask = torch.zeros_like( - self.batch.atomic_numbers, device=self.device - ).bool() - while iteration < steps and not converged: - _converged_mask, energy, forces = self.check_convergence(iteration) - # Models like GemNet-OC can have random noise in their predictions. - # Here we ensure atom positions are not being updated after already - # hitting the desired convergence criteria.
- converged_mask = torch.logical_or(converged_mask, _converged_mask) - converged = torch.all(converged_mask) - update_mask = torch.logical_not(converged_mask) + max_forces = self.optimizable.get_max_forces(apply_constraint=True) + logging.info("Step Fmax(eV/A)") + + while iteration < steps and not self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces + ): + logging.info( + f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces.tolist()) + ) if self.trajectories is not None and ( - self.save_full or converged or iteration == steps - 1 or iteration == 0 + self.save_full is True or iteration == 0 ): - self.write(energy, forces, update_mask) - - if not converged and iteration < steps - 1: - self.step(iteration, forces, update_mask) + self.write() + self.step(iteration) + max_forces = self.optimizable.get_max_forces(apply_constraint=True) iteration += 1 + logging.info( + f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces.tolist()) + ) + + # save after converged or all iterations ran + if iteration > 0 and self.trajectories is not None: + self.write() + # GPU memory usage as per nvidia-smi seems to gradually build up as # batches are processed. This releases unoccupied cached memory. torch.cuda.empty_cache() @@ -142,102 +125,79 @@ def run(self, fmax, steps): traj_fl = Path(self.traj_dir / f"{name}.traj_tmp", mode="w") traj_fl.rename(traj_fl.with_suffix(".traj")) - self.batch.energy, self.batch.force = self.get_energy_and_forces( - apply_constraint=False - ) - return self.batch + # set predicted values to batch + for name, value in self.optimizable.results.items(): + setattr(self.optimizable.batch, name, value) - def step( - self, - iteration: int, - forces: torch.Tensor | None, - update_mask: torch.Tensor, - ) -> None: - def _batched_dot(x: torch.Tensor, y: torch.Tensor): - return scatter((x * y).sum(dim=-1), self.batch.batch, reduce="sum") - - def determine_step(dr): - steplengths = torch.norm(dr, dim=1) - longest_steps = scatter(steplengths, self.batch.batch, reduce="max") - longest_steps = longest_steps[self.batch.batch] - maxstep = longest_steps.new_tensor(self.maxstep) - scale = (longest_steps + 1e-7).reciprocal() * torch.min( - longest_steps, maxstep - ) - dr *= scale.unsqueeze(1) - return dr * self.damping + return self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces + ) - if forces is None: - _, forces = self.get_energy_and_forces() + def determine_step(self, dr): + steplengths = torch.norm(dr, dim=1) + longest_steps = scatter( + steplengths, self.optimizable.batch_indices, reduce="max" + ) + longest_steps = longest_steps[self.optimizable.batch_indices] + maxstep = longest_steps.new_tensor(self.maxstep) + scale = (longest_steps + 1e-7).reciprocal() * torch.min(longest_steps, maxstep) + dr *= scale.unsqueeze(1) + return dr * self.damping + + def _batched_dot(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" + ) - r = self.batch.pos.clone().to(dtype=torch.float64) + def step(self, iteration: int) -> None: + # cast forces and positions to float64 otherwise the algorithm is prone to overflow + forces = self.optimizable.get_forces(apply_constraint=True).to( + dtype=torch.float64 + ) + pos = self.optimizable.get_positions().to(dtype=torch.float64) # Update s, y, rho if iteration > 0: - s0 = r - self.r0 + s0 = pos - self.r0 self.s.append(s0) y0 = -(forces - self.f0) self.y.append(y0) - self.rho.append(1.0 / _batched_dot(y0, s0)) + self.rho.append(1.0 / 
self._batched_dot(y0, s0)) loopmax = min(self.memory, iteration) - alpha = forces.new_empty(loopmax, self.batch.natoms.shape[0]) + alpha = forces.new_empty(loopmax, self.optimizable.batch.natoms.shape[0]) q = -forces for i in range(loopmax - 1, -1, -1): - alpha[i] = self.rho[i] * _batched_dot(self.s[i], q) # b - q -= alpha[i][self.batch.batch, ..., None] * self.y[i] + alpha[i] = self.rho[i] * self._batched_dot(self.s[i], q) # b + q -= alpha[i][self.optimizable.batch_indices, ..., None] * self.y[i] z = self.H0 * q for i in range(loopmax): - beta = self.rho[i] * _batched_dot(self.y[i], z) + beta = self.rho[i] * self._batched_dot(self.y[i], z) z += self.s[i] * ( - alpha[i][self.batch.batch, ..., None] - - beta[self.batch.batch, ..., None] + alpha[i][self.optimizable.batch_indices, ..., None] + - beta[self.optimizable.batch_indices, ..., None] ) # descent direction p = -z - dr = determine_step(p) + dr = self.determine_step(p) + if torch.abs(dr).max() < 1e-7: # Same configuration again (maybe a restart): return - self.set_positions(dr, update_mask) - - self.r0 = r + self.optimizable.set_positions(pos + dr) + self.r0 = pos self.f0 = forces - def write(self, energy, forces, update_mask) -> None: - self.batch.energy, self.batch.force = energy, forces - atoms_objects = batch_to_atoms(self.batch) - update_mask_ = torch.split(update_mask, self.batch.natoms.tolist()) - for atm, traj, mask in zip(atoms_objects, self.trajectories, update_mask_): - if mask[0] or not self.save_full: + def write(self) -> None: + atoms_objects = self.optimizable.get_atoms_list() + for atm, traj, mask in zip( + atoms_objects, self.trajectories, self.optimizable.update_mask + ): + if mask: traj.write(atm) - - -class TorchCalc: - def __init__(self, model, transform=None) -> None: - self.model = model - self.transform = transform - - def get_energy_and_forces(self, atoms, apply_constraint: bool = True): - predictions = self.model.predict(atoms, per_image=False, disable_tqdm=True) - energy = predictions["energy"] - forces = predictions["forces"] - if apply_constraint: - fixed_idx = torch.where(atoms.fixed == 1)[0] - forces[fixed_idx] = 0 - return energy, forces - - def update_graph(self, atoms): - edge_index, cell_offsets, num_neighbors = radius_graph_pbc(atoms, 6, 50) - atoms.edge_index = edge_index - atoms.cell_offsets = cell_offsets - atoms.neighbors = num_neighbors - if self.transform is not None: - atoms = self.transform(atoms) - return atoms diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py index d688b8e798..ebbe1dfac3 100644 --- a/src/fairchem/core/datasets/ase_datasets.py +++ b/src/fairchem/core/datasets/ase_datasets.py @@ -105,7 +105,7 @@ def __init__( if len(self.ids) == 0: raise ValueError( - rf"No valid ase data found!" + "No valid ase data found!\n" f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" ) @@ -142,7 +142,7 @@ def __getitem__(self, idx): data_object = self.transforms(data_object) if self.config.get("include_relaxed_energy", False): - data_object.y_relaxed = self.get_relaxed_energy(self.ids[idx]) + data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx]) return data_object @@ -160,9 +160,12 @@ def _load_dataset_get_ids(self, config): "Every ASE dataset needs to declare a function to load the dataset and return a list of ids."
) - @abstractmethod def get_relaxed_energy(self, identifier): - raise NotImplementedError("IS2RE-Direct is not implemented with this dataset.") + raise NotImplementedError( + "Reading relaxed energy from trajectory or file is not implemented with this dataset. " + "If relaxed energies are saved with the atoms info dictionary, they can be used by passing the keys in " + "the r_data_keys argument under a2g_args." + ) def sample_property_metadata(self, num_samples: int = 100) -> dict: metadata = {} @@ -568,8 +571,3 @@ def sample_property_metadata(self, num_samples: int = 100) -> dict: return super().sample_property_metadata(num_samples) return copy.deepcopy(self.dbs[0].metadata) - - def get_relaxed_energy(self, identifier): - raise NotImplementedError( - "IS2RE-Direct training with an ASE DB is not currently supported." - ) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 32865e0ef8..e6c3e08206 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -63,14 +63,10 @@ def generate_graph( use_pbc_single = use_pbc_single or self.use_pbc_single otf_graph = otf_graph or self.otf_graph - if enforce_max_neighbors_strictly is not None: - pass - elif hasattr(self, "enforce_max_neighbors_strictly"): - # Not all models will have this attribute - enforce_max_neighbors_strictly = self.enforce_max_neighbors_strictly - else: - # Default to old behavior - enforce_max_neighbors_strictly = True + if enforce_max_neighbors_strictly is None: + enforce_max_neighbors_strictly = getattr( + self, "enforce_max_neighbors_strictly", True + ) if not otf_graph: try: diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index f4b5a757b4..fa679a262a 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -250,7 +250,7 @@ def convert(self, atoms: ase.Atoms, sid=None): for data_key in self.r_data_keys: data[data_key] = ( atoms.info[data_key] - if isinstance(atoms.info[data_key], (int, float)) + if isinstance(atoms.info[data_key], (int, float, str)) else torch.Tensor(atoms.info[data_key]) ) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 90cdce0e58..5c21a743aa 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -215,6 +215,7 @@ def __init__( self.test_dataset = None self.best_val_metric = None self.primary_metric = None + self.ema = None self.load(inference_only) @@ -361,7 +362,7 @@ def convert_settings_to_split_settings(config, split_name): ) self.train_sampler = self.get_sampler( self.train_dataset, - self.config["optim"]["batch_size"], + self.config["optim"].get("batch_size", 1), shuffle=True, ) self.train_loader = self.get_dataloader( @@ -392,7 +393,7 @@ def convert_settings_to_split_settings(config, split_name): self.val_sampler = self.get_sampler( self.val_dataset, self.config["optim"].get( - "eval_batch_size", self.config["optim"]["batch_size"] + "eval_batch_size", self.config["optim"].get("batch_size", 1) ), shuffle=False, ) @@ -414,7 +415,7 @@ def convert_settings_to_split_settings(config, split_name): self.test_sampler = self.get_sampler( self.test_dataset, self.config["optim"].get( - "eval_batch_size", self.config["optim"]["batch_size"] + "eval_batch_size", self.config["optim"].get("batch_size", 1) ), shuffle=False, ) diff --git a/src/fairchem/core/trainers/ocp_trainer.py 
b/src/fairchem/core/trainers/ocp_trainer.py index a8976773c6..8e5d178206 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -552,7 +552,7 @@ def predict( return predictions @torch.no_grad - def run_relaxations(self, split="val"): + def run_relaxations(self): ensure_fitted(self._unwrapped_model) # When set to true, uses deterministic CUDA scatter ops, if available. @@ -572,14 +572,14 @@ def run_relaxations(self, split="val"): evaluator_is2rs, metrics_is2rs = Evaluator(task="is2rs"), {} evaluator_is2re, metrics_is2re = Evaluator(task="is2re"), {} - # Need both `pos_relaxed` and `y_relaxed` to compute val IS2R* metrics. + # Need both `pos_relaxed` and `energy_relaxed` to compute val IS2R* metrics. # Else just generate predictions. if ( hasattr(self.relax_dataset[0], "pos_relaxed") and self.relax_dataset[0].pos_relaxed is not None ) and ( - hasattr(self.relax_dataset[0], "y_relaxed") - and self.relax_dataset[0].y_relaxed is not None + hasattr(self.relax_dataset[0], "energy_relaxed") + and self.relax_dataset[0].energy_relaxed is not None ): split = "val" else: @@ -608,9 +608,10 @@ def run_relaxations(self, split="val"): model=self, steps=self.config["task"].get("relaxation_steps", 300), fmax=self.config["task"].get("relaxation_fmax", 0.02), + relax_cell=self.config["task"].get("relax_cell", False), + relax_volume=self.config["task"].get("relax_volume", False), relax_opt=self.config["task"]["relax_opt"], save_full_traj=self.config["task"].get("save_full_traj", True), - device=self.device, transform=None, ) @@ -638,7 +639,7 @@ def run_relaxations(self, split="val"): s_idx += natoms target = { - "energy": relaxed_batch.energy, + "energy": relaxed_batch.energy_relaxed, "positions": relaxed_batch.pos_relaxed[mask], "cell": relaxed_batch.cell, "pbc": torch.tensor([True, True, True]), diff --git a/tests/core/common/conftest.py b/tests/core/common/conftest.py new file mode 100644 index 0000000000..6187cbf3a9 --- /dev/null +++ b/tests/core/common/conftest.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pytest +from ase import build + +from fairchem.core.common.relaxation.ase_utils import OCPCalculator +from fairchem.core.datasets import data_list_collater +from fairchem.core.preprocessing.atoms_to_graphs import AtomsToGraphs + + +@pytest.fixture(scope="session") +def calculator(tmp_path_factory): + dir = tmp_path_factory.mktemp("checkpoints") + return OCPCalculator( + model_name="EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=dir, seed=0 + ) + + +@pytest.fixture() +def atoms_list(): + atoms_list = [ + build.bulk("Cu", "fcc", a=3.8, cubic=True), + build.bulk("NaCl", crystalstructure="rocksalt", a=5.8), + ] + for atoms in atoms_list: + atoms.rattle(stdev=0.05, seed=0) + return atoms_list + + +@pytest.fixture() +def batch(atoms_list): + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + return data_list_collater([a2g.convert(atoms) for atoms in atoms_list]) diff --git a/tests/core/common/test_ase_calculator.py b/tests/core/common/test_ase_calculator.py index 3d62c35e1a..92baa37cbd 100644 --- a/tests/core/common/test_ase_calculator.py +++ b/tests/core/common/test_ase_calculator.py @@ -65,6 +65,9 @@ def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None: cpu=True, ) + assert "energy" in calc.implemented_properties + assert "forces" in calc.implemented_properties + atoms.set_calculator(calc) opt = BFGS(atoms) opt.run(fmax=0.05, steps=100) diff --git a/tests/core/common/test_lbfgs_torch.py 
b/tests/core/common/test_lbfgs_torch.py new file mode 100644 index 0000000000..7bcf743ebb --- /dev/null +++ b/tests/core/common/test_lbfgs_torch.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from itertools import combinations, product + +import numpy as np +import numpy.testing as npt +import pytest +from ase.io import read +from ase.optimize import LBFGS as LBFGS_ASE + +from fairchem.core.common.relaxation import OptimizableBatch +from fairchem.core.common.relaxation.optimizers import LBFGS +from fairchem.core.modules.evaluator import min_diff + + +def test_lbfgs_relaxation(atoms_list, batch, calculator): + """Tests batch relaxation using fairchem LBFGS optimizer.""" + obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False) + + # optimize atoms one-by-one + for atoms in atoms_list: + atoms.calc = calculator + opt = LBFGS_ASE(atoms, damping=0.8, alpha=70.0) + opt.run(0.01, 20) + + # optimize atoms in batch using the fairchem LBFGS + batch_optimizer = LBFGS(obatch, damping=0.8, alpha=70.0) + batch_optimizer.run(0.01, 20) + + # compare energy and atom positions; this needs pretty slack tols but that should be ok + for a1, a2 in zip(atoms_list, obatch.get_atoms_list()): + assert a1.get_potential_energy() / len(a1) == pytest.approx( + a2.get_potential_energy() / len(a2), abs=0.05 + ) + diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc) + npt.assert_allclose(diff, 0, atol=0.01) + + +@pytest.mark.parametrize( + ("save_full_traj", "steps"), list(product((True, False), (0, 1, 5))) +) +def test_lbfgs_write_trajectory(save_full_traj, steps, batch, calculator, tmp_path): + obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False) + batch_optimizer = LBFGS( + obatch, + save_full_traj=save_full_traj, + traj_dir=tmp_path, + traj_names=[f"system-{i}" for i in range(len(batch))], + ) + + batch_optimizer.run(0.001, steps=steps) + + # check that trajectory files were written + traj_files = list(tmp_path.glob("*.traj")) + assert len(traj_files) == len(batch) + + traj_length = ( + 0 if steps == 0 else steps + 1 if save_full_traj else 2 + ) # first and final frame + for file in traj_files: + traj = read(file, ":") + assert len(traj) == traj_length + + # make sure all written frames are unique + for a1, a2 in combinations(traj, r=2): + assert not np.allclose(a1.positions, a2.positions, atol=1e-5) diff --git a/tests/core/common/test_optimizable.py b/tests/core/common/test_optimizable.py new file mode 100644 index 0000000000..7024d91281 --- /dev/null +++ b/tests/core/common/test_optimizable.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import numpy as np +import numpy.testing as npt +import pytest +from ase.optimize import BFGS, FIRE, LBFGS + +try: + from ase.filters import UnitCellFilter +except ModuleNotFoundError: + # older ase version: import UnitCellFilter from its old location + from ase.constraints import UnitCellFilter + +from fairchem.core.common.relaxation import OptimizableBatch, OptimizableUnitCellBatch +from fairchem.core.datasets import data_list_collater +from fairchem.core.modules.evaluator import min_diff + + +@pytest.fixture(params=[FIRE, BFGS, LBFGS]) +def optimizer_cls(request): + return request.param + + +def test_ase_relaxation(atoms_list, batch, calculator, optimizer_cls): + """Tests batch relaxation using ASE optimizers.""" + obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=True) + + # optimize atoms one-by-one + for atoms in atoms_list: + atoms.calc = calculator + opt = optimizer_cls(atoms) + opt.run(0.01, 20) + + #
optimize atoms in batch using ASE + batch_optimizer = optimizer_cls(obatch) + batch_optimizer.run(0.01, 20) + + # compare energy and atom positions; this needs pretty slack tols but that should be ok + for a1, a2 in zip(atoms_list, obatch.get_atoms_list()): + assert a1.get_potential_energy() / len(a1) == pytest.approx( + a2.get_potential_energy() / len(a2), abs=0.05 + ) + diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc) + npt.assert_allclose(diff, 0, atol=0.01) + + +@pytest.mark.parametrize("mask_converged", [False, True]) +def test_batch_relaxation_mask(atoms_list, calculator, mask_converged): + """Test that masking is working as intended!""" + # relax only the first atoms object in the list + atoms = atoms_list[0] + atoms.calc = calculator + opt = LBFGS(atoms) + opt.run(0.01, 50) + assert ((atoms.get_forces() ** 2).sum(axis=1) ** 0.5 <= 0.01).all() + + # now create a batch + batch = data_list_collater([calculator.a2g.convert(atoms) for atoms in atoms_list]) + obatch = OptimizableBatch( + batch, trainer=calculator.trainer, numpy=True, mask_converged=mask_converged + ) + + npt.assert_allclose(batch.pos[batch.batch == 0].cpu().numpy(), atoms.positions) + batch_opt = LBFGS(obatch) + batch_opt.run(0.01, 20) + + if mask_converged: + # assert preconverged structure was not changed at all + npt.assert_allclose(batch.pos[batch.batch == 0].cpu().numpy(), atoms.positions) + assert not np.allclose( + batch.pos[batch.batch == 1].cpu().numpy(), atoms_list[1].positions + ) + else: + # assert that it was changed + assert not np.allclose( + batch.pos[batch.batch == 0].cpu().numpy(), atoms.positions + ) + + +@pytest.mark.skip("Skip until we have a test model that can predict stress") +def test_ase_cell_relaxation(atoms_list, batch, calculator, optimizer_cls): + """Tests batch relaxation using ASE optimizers.""" + cell_factor = batch.natoms.cpu().numpy().mean() + obatch = OptimizableUnitCellBatch( + batch, trainer=calculator.trainer, numpy=True, cell_factor=cell_factor + ) + + # optimize atoms in batch using ASE + batch_optimizer = optimizer_cls(obatch) + batch_optimizer.run(0.01, 20) + + # optimize atoms one-by-one + for atoms in atoms_list: + print(atoms.cell.array) + atoms.calc = calculator + opt = optimizer_cls(UnitCellFilter(atoms, cell_factor=cell_factor)) + opt.run(0.01, 20) + + # compare energy, atom positions and cell + for a1, a2 in zip(atoms_list, obatch.get_atoms_list()): + assert a1.get_potential_energy() / len(a1) == pytest.approx( + a2.get_potential_energy() / len(a2), abs=0.05 + ) + diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc) + npt.assert_allclose(diff, 0, atol=0.05, rtol=0.05) + + cnorm1 = np.linalg.norm(a1.cell.array, axis=1) + cnorm2 = np.linalg.norm(a2.cell.array, axis=1) + npt.assert_allclose(cnorm1, cnorm2, atol=0.01, rtol=0.01) + npt.assert_allclose(a1.cell.array.T, a2.cell.array.T, rtol=0.01, atol=0.01) diff --git a/tests/core/datasets/test_ase_datasets.py b/tests/core/datasets/test_ase_datasets.py index 7b114d877f..676805c653 100644 --- a/tests/core/datasets/test_ase_datasets.py +++ b/tests/core/datasets/test_ase_datasets.py @@ -228,9 +228,9 @@ def test_ase_multiread_dataset(tmp_path): assert len(dataset) == len(atoms_objects) - assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].energy - assert dataset[-1].y_relaxed == dataset[-1].energy + assert hasattr(dataset[0], "energy_relaxed") + assert dataset[0].energy_relaxed != dataset[0].energy + assert dataset[-1].energy_relaxed == dataset[-1].energy dataset =
AseReadDataset( config={ @@ -247,8 +247,8 @@ def test_ase_multiread_dataset(tmp_path): } ) - assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].energy + assert hasattr(dataset[0], "energy_relaxed") + assert dataset[0].energy_relaxed != dataset[0].energy def test_empty_dataset(tmp_path):
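Usage sketch: the new OptimizableBatch wraps a data batch plus a trainer so that stock ASE optimizers (which need ASE >= 3.23 for the Optimizable interface) can drive a whole batch at once, mirroring the fixtures in tests/core/common/conftest.py. The checkpoint path and atoms_list below are hypothetical placeholders, not part of this diff:

    from ase.optimize import LBFGS as LBFGS_ASE

    from fairchem.core.common.relaxation import OptimizableBatch
    from fairchem.core.common.relaxation.ase_utils import OCPCalculator
    from fairchem.core.datasets import data_list_collater

    # atoms_list: any list of ase.Atoms; the checkpoint path is a placeholder
    calc = OCPCalculator(checkpoint_path="path/to/checkpoint.pt", cpu=True)
    batch = data_list_collater(
        [calc.a2g.convert(atoms, sid=str(i)) for i, atoms in enumerate(atoms_list)]
    )

    # numpy=True casts predictions to numpy so plain ASE optimizers can step the batch
    optimizable = OptimizableBatch(batch, trainer=calc.trainer, numpy=True)
    opt = LBFGS_ASE(optimizable)
    opt.run(fmax=0.01, steps=100)

    # relaxed structures come back as ase.Atoms with a SinglePointCalculator attached
    relaxed_atoms = optimizable.get_atoms_list()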
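The fairchem-native path goes through ml_relax, which drives the batched torch LBFGS and, when relax_cell or relax_volume is set, an OptimizableUnitCellBatch. A minimal sketch reusing calc.trainer and batch from the snippet above; the relax_opt values shown are simply the LBFGS defaults, and relax_cell=True assumes the model predicts stress:

    from fairchem.core.common.relaxation import ml_relax

    # relax_cell requires a model that predicts stress
    relaxed_batch = ml_relax(
        batch=batch,
        model=calc.trainer,
        steps=300,
        fmax=0.02,
        relax_opt={"memory": 100, "damping": 1.2, "alpha": 100.0},
        relax_cell=True,
        save_full_traj=False,
    )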
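Cell relaxations can also be driven by plain ASE optimizers through OptimizableUnitCellBatch, analogous to wrapping Atoms in a UnitCellFilter; this sketch follows the (currently skipped) test_ase_cell_relaxation and again assumes a stress-predicting model:

    from ase.optimize import FIRE

    from fairchem.core.common.relaxation import OptimizableUnitCellBatch

    # scale cell degrees of freedom roughly like positions (see the cell_factor docs)
    cell_factor = batch.natoms.cpu().numpy().mean()
    cell_batch = OptimizableUnitCellBatch(
        batch, trainer=calc.trainer, numpy=True, cell_factor=cell_factor
    )
    opt = FIRE(cell_batch)
    opt.run(fmax=0.02, steps=200)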