diff --git a/configs/ocp_example.yml b/configs/ocp_example.yml index b979a7a324..a988b4ef1a 100644 --- a/configs/ocp_example.yml +++ b/configs/ocp_example.yml @@ -12,7 +12,7 @@ dataset: # Can use 'single_point_lmdb' or 'trajectory_lmdb' for backward compatibility. # 'single_point_lmdb' was for training IS2RE models, and 'trajectory_lmdb' was # for training S2EF models. - format: lmdb # 'lmdb' or 'oc22_lmdb' + format: lmdb # 'lmdb', 'oc22_lmdb', or 'ase_db' # Directory containing training set LMDBs src: data/s2ef/all/train/ # If we want to rename a target value stored in the data object, specify the mapping here. @@ -34,9 +34,11 @@ dataset: irrep_dim: 0 anisotropic_stress: irrep_dim: 2 - # If we want to normalize targets, i.e. subtract the mean and - # divide by standard deviation, then specify the 'mean' and 'stdev' here. + # If we want to normalize targets, there are a couple of ways to specify normalization values. + # Normalization values are applied as: (target - mean) / rmsd + # Note root mean squared difference (rmsd) is equal to stdev if mean != 0, and equal to rms if mean == 0. # Statistics will by default be applied to the validation and test set. + # 1) Specify the 'mean' and 'stdev' explicitly here. normalizer: energy: mean: -0.7554450631141663 @@ -49,6 +51,52 @@ dataset: stdev: 674.1657344451734 anisotropic_stress: stdev: 143.72764771869745 + # 2) Estimate the values on-the-fly (OTF) from training data + normalizer: + fit: + targets: + forces: { mean: 0.0 } # values can be explicitly set, i.e. if you need RMS forces instead of stdev forces + stress_isotropic: { } # to estimate both mean and rmsd set to {} or None + stress_anisotropic: { } + batch_size: 64 + num_batches: 5000 # if num_batches is not given, the whole dataset will be used + # 3) Specify a single .pt file with a dict of target names and Normalizer modules + # (this is the format that OTF values are saved in) + # see Normalizer module in fairchem.core.modules.normalization.normalizer + normalizer: + file: normalizers.pt + # 4) Specify an individual file, either .pt or .npz, with keys 'mean' and 'rmsd' or 'stdev' + normalizer: + energy: + file: energy_norm.pt + forces: + file: forces_norm.npz + isotropic_stress: + file: isostress_norm.npz + anisotropic_stress: + file: anisostress_norm.npz + # If we want to train on total energies and use a per-element linear reference + # normalization scheme, we can estimate those from the data or specify the path to the per-element references. + # 1) Fit element references from data + element_references: + fit: + targets: + - energy + batch_size: 64 + num_batches: 5000 # if num_batches is not given, the whole dataset will be used + # 2) Specify a file with key energy and a LinearReference object. This is the format OTF references are saved in. + # see fairchem.core.modules.normalization.element_references for references. + element_references: + file: element_references.pt + # 3) Legacy files in npz format can be specified as well. They must have the element references + # under the key coeff + element_references: + energy: + file: element_ref.npz + # 4) For backwards compatibility only, linear references can be set as follows. Setting the references + # file as follows is a legacy setting and only works with oc22_lmdb and ase_lmdb datasets + lin_ref: element_ref.npz + # If we want to train OC20 on total energy, a path to OC20 reference # energies `oc20_ref` must be specified to unreference existing OC20 data.
# download at https://dl.fbaipublicfiles.com/opencatalystproject/data/oc22/oc20_ref.pkl @@ -56,10 +104,7 @@ dataset: # OC22 defaults to total energy, so these flags are not necessary. train_on_oc20_total_energies: False # True or False oc20_ref: None # path to oc20_ref - # If we want to train on total energies and use a linear reference - # normalization scheme, we must specify the path to the per-element - # coefficients in a `.npz` format. - lin_ref: False # True or False + val: # Directory containing val set LMDBs src: data/s2ef/all/val_id/ diff --git a/docs/core/fine-tuning/fine-tuning-oxides.md b/docs/core/fine-tuning/fine-tuning-oxides.md index 77a9350d3b..39c39cad40 100644 --- a/docs/core/fine-tuning/fine-tuning-oxides.md +++ b/docs/core/fine-tuning/fine-tuning-oxides.md @@ -205,6 +205,7 @@ from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', 'optim.loss_force', # the checkpoint setting causes an error + 'optim.load_balancing', 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, 'optim.eval_every': 10, diff --git a/docs/core/install.md b/docs/core/install.md index 8ad523f326..5eb4569f82 100644 --- a/docs/core/install.md +++ b/docs/core/install.md @@ -44,28 +44,28 @@ You can also install `pytorch` and `torch_geometric` dependencies from PyPI to s similarly by selecting the appropriate versions in the official [PyG docs](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) -## Install fairchem-core +## Standard installation of fairchem-core Install `fairchem-core` from PyPi ```bash pip install fairchem-core ``` -## Additional packages - +### Additional packages `fairchem` is a namespace package, meaning all packages are installed separately. If you need to install other packages you can do so by: ```bash pip install fairchem-{package-to-install} ``` +Available `fairchem` packages are `fairchem-core`, `fairchem-data-oc`, `fairchem-demo-ocpapi`, `fairchem-applications-cattsunami` -## Development install - +## Development installation If you plan to make contributions you will need to fork and clone (for Windows users, please see the next section) the repo, set up the environment, and install fairchem-core from source in editable mode with dev dependencies, ```bash git clone https://github.com/FAIR-Chem/fairchem.git cd fairchem pip install -e packages/fairchem-core[dev] +pytest tests/core ``` And similarly for any other namespace package: diff --git a/docs/legacy_tutorials/OCP_Tutorial.md b/docs/legacy_tutorials/OCP_Tutorial.md index 8b5d4d522a..19fd93f6bc 100644 --- a/docs/legacy_tutorials/OCP_Tutorial.md +++ b/docs/legacy_tutorials/OCP_Tutorial.md @@ -1807,7 +1807,7 @@ Similarly, to predict forces, we pass edge features through a fully-connected la @registry.register_model("simple") class SimpleAtomEdgeModel(torch.nn.Module): - def __init__(self, num_atoms, bond_feat_dim, num_targets, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): + def __init__(self, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): super().__init__() self.radial_basis = RadialBasis( diff --git a/docs/tutorials/advanced/fine-tuning-in-python.md b/docs/tutorials/advanced/fine-tuning-in-python.md index 1d14219c88..0eeb8e5485 100644 --- a/docs/tutorials/advanced/fine-tuning-in-python.md +++ b/docs/tutorials/advanced/fine-tuning-in-python.md @@ -75,7 +75,7 @@ We start by making the config.yml. We build this from the calculator checkpoint.
from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', - delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', + delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes','optim.load_balancing', 'optim.loss_force', # the checkpoint setting causes an error 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, diff --git a/src/fairchem/core/common/data_parallel.py b/src/fairchem/core/common/data_parallel.py index 4d5836b786..89c3b67445 100644 --- a/src/fairchem/core/common/data_parallel.py +++ b/src/fairchem/core/common/data_parallel.py @@ -9,20 +9,23 @@ import heapq import logging -from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal import numba import numpy as np -import numpy.typing as npt import torch -from torch.utils.data import BatchSampler, DistributedSampler, Sampler +import torch.distributed +from torch.utils.data import BatchSampler, Dataset, DistributedSampler +from typing_extensions import override from fairchem.core.common import distutils, gp_utils from fairchem.core.datasets import data_list_collater +from fairchem.core.datasets.base_dataset import ( + UnsupportedDatasetError, +) if TYPE_CHECKING: - from pathlib import Path - + from numpy.typing import NDArray from torch_geometric.data import Batch, Data @@ -35,30 +38,24 @@ def __call__(self, data_list: list[Data]) -> Batch: @numba.njit -def balanced_partition(sizes: npt.NDArray[np.int_], num_parts: int): +def _balanced_partition(sizes: NDArray[np.int_], num_parts: int): """ Greedily partition the given set by always inserting the largest element into the smallest partition. """ sort_idx = np.argsort(-sizes) # Sort in descending order - heap: list[tuple[list[int], list[int]]] = [ - (sizes[idx], [idx]) for idx in sort_idx[:num_parts] - ] + heap = [(sizes[idx], [idx]) for idx in sort_idx[:num_parts]] heapq.heapify(heap) for idx in sort_idx[num_parts:]: smallest_part = heapq.heappop(heap) new_size = smallest_part[0] + sizes[idx] - new_idx = smallest_part[1] + [idx] + new_idx = smallest_part[1] + [ + idx + ] # TODO should this be append to save time/space heapq.heappush(heap, (new_size, new_idx)) return [part[1] for part in heap] -@runtime_checkable -class _HasMetadata(Protocol): - @property - def metadata_path(self) -> Path: ... - - class StatefulDistributedSampler(DistributedSampler): """ More fine-grained state DataSampler that uses training iteration and epoch @@ -105,56 +102,83 @@ def set_epoch_and_start_iteration(self, epoch, start_iter): self.start_iter = start_iter -class BalancedBatchSampler(Sampler): - def _load_dataset(self, dataset, mode: Literal["atoms", "neighbors"]): - errors: list[str] = [] - if not isinstance(dataset, _HasMetadata): - errors.append(f"Dataset {dataset} does not have a metadata_path attribute.") - return None, errors - if not dataset.metadata_path.exists(): - errors.append(f"Metadata file {dataset.metadata_path} does not exist.") - return None, errors +def _ensure_supported(dataset: Any): + if not isinstance(dataset, Dataset): + raise UnsupportedDatasetError("BalancedBatchSampler requires a dataset.") + + if not dataset.metadata_hasattr("natoms"): + raise UnsupportedDatasetError( + "BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms." 
+ ) - key = {"atoms": "natoms", "neighbors": "neighbors"}[mode] - sizes = np.load(dataset.metadata_path)[key] + logging.debug(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}") + return dataset - return sizes, errors +class BalancedBatchSampler(BatchSampler): def __init__( self, - dataset, + dataset: Dataset, + *, batch_size: int, num_replicas: int, rank: int, device: torch.device, seed: int, - mode: str | bool = "atoms", + mode: bool | Literal["atoms"] = "atoms", shuffle: bool = True, + on_error: Literal["warn_and_balance", "warn_and_no_balance", "raise"] = "raise", drop_last: bool = False, - force_balancing: bool = False, - throw_on_error: bool = False, - ) -> None: - if mode is True: - mode = "atoms" - - if isinstance(mode, str): - mode = mode.lower() - if mode not in ("atoms", "neighbors"): - raise ValueError( - f"Invalid mode {mode}. Must be one of 'atoms', 'neighbors', or a boolean." - ) + ): + """ + Initializes a BalancedBatchSampler object. - self.dataset = dataset - self.batch_size = batch_size - self.num_replicas = num_replicas - self.rank = rank - self.device = device - self.mode = mode - self.shuffle = shuffle - self.drop_last = drop_last + Args: + dataset (Dataset): The dataset to sample from. + batch_size (int): The size of each batch. + num_replicas (int): The number of processes participating in distributed training. + rank (int): The rank of the current process in distributed training. + device (torch.device): The device to use for the batches. + mode (str or bool, optional): The mode to use for balancing the batches. Defaults to "atoms". + shuffle (bool, optional): Whether to shuffle the samples. Defaults to True. + on_error (Literal["warn_and_balance", "warn_and_no_balance", "raise"], optional): The action to take when an error occurs (i.e., when we have an invalid dataset). Defaults to "raise". + - "warn_and_balance": Raise a warning and balance the batch by manually loading the data samples and counting the number of nodes (this is slow). + - "warn_and_no_balance": Raise a warning and do not do any balancing. + - "raise": Raise an error. + drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to False. + """ + self.disabled = False + self.on_error = on_error + + if mode is False: + logging.warning(f"Disabled BalancedBatchSampler because {mode=}.") + self.disabled = True + elif mode.lower() != "atoms": + raise ValueError( + f"Only mode='atoms' or mode=True is supported, got {mode=}." + ) + elif num_replicas == 1: + logging.warning(f"Disabled BalancedBatchSampler because {num_replicas=}.") + self.disabled = True + + try: + dataset = _ensure_supported(dataset) + except UnsupportedDatasetError as error: + if self.on_error == "raise": + raise error + if self.on_error == "warn_and_balance": + logging.warning( + f"Failed to get data sizes from metadata, loading data to get sizes (THIS IS SLOW). {error}" + ) + elif self.on_error == "warn_and_no_balance": + logging.warning( + f"Failed to get data sizes, falling back to uniform partitioning. 
{error}" + ) + else: + raise ValueError(f"Unknown on_error={self.on_error}") from error - self.single_sampler = StatefulDistributedSampler( - self.dataset, + sampler = StatefulDistributedSampler( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, @@ -162,82 +186,59 @@ def __init__( batch_size=batch_size, seed=seed, ) - self.batch_sampler = BatchSampler( - self.single_sampler, - batch_size, - drop_last=drop_last, - ) - - self.sizes = None - self.balance_batches = False - if self.num_replicas <= 1: - logging.info("Batch balancing is disabled for single GPU training.") - return - - if self.mode is False: - logging.info( - "Batch balancing is disabled because `optim.load_balancing` is `False`" - ) - return - - self.sizes, errors = self._load_dataset(dataset, self.mode) - if self.sizes is None: - self.balance_batches = force_balancing - if force_balancing: - errors.append( - "BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead! " - "You can disable balancing by setting `optim.load_balancing` to `False`." - ) - else: - errors.append( - "Batches will not be balanced, which can incur significant overhead!" - ) - else: - self.balance_batches = True - - if errors: - msg = "BalancedBatchSampler: " + " ".join(errors) - if throw_on_error: - raise RuntimeError(msg) + super().__init__(sampler, batch_size=batch_size, drop_last=drop_last) + self.device = device - logging.warning(msg) + logging.info( + f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}" + ) - def __len__(self) -> int: - return len(self.batch_sampler) + def _get_natoms(self, batch_idx: list[int]): + if self.sampler.dataset.metadata_hasattr("natoms"): + return np.array( + self.sampler.dataset.get_metadata("natoms", batch_idx) + ).reshape(-1) + if self.on_error == "warn_and_balance": + return np.array([self.sampler.dataset[idx].num_nodes for idx in batch_idx]) + return None def set_epoch_and_start_iteration(self, epoch: int, start_iteration: int) -> None: - if not hasattr(self.single_sampler, "set_epoch_and_start_iteration"): + if not isinstance(self.sampler, StatefulDistributedSampler): if start_iteration != 0: raise NotImplementedError( f"{type(self.single_sampler)} does not support resuming from a nonzero step." 
) - self.single_sampler.set_epoch(epoch) + self.sampler.set_epoch(epoch) else: - self.single_sampler.set_epoch_and_start_iteration(epoch, start_iteration) + self.sampler.set_epoch_and_start_iteration(epoch, start_iteration) + + def set_epoch(self, epoch: int) -> None: + if isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + @staticmethod + def _dist_enabled(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + @override def __iter__(self): - if not self.balance_batches: - yield from self.batch_sampler + if self.disabled or not self._dist_enabled(): + yield from super().__iter__() return - for batch_idx in self.batch_sampler: - if self.sizes is None: - # Unfortunately, we need to load the data to know the image sizes - data_list = [self.dataset[idx] for idx in batch_idx] - - if self.mode == "atoms": - sizes = [data.num_nodes for data in data_list] - elif self.mode == "neighbors": - sizes = [data.edge_index.shape[1] for data in data_list] - else: - raise NotImplementedError( - f"Unknown load balancing mode: {self.mode}" - ) - else: - sizes = [self.sizes[idx] for idx in batch_idx] - - idx_sizes = torch.stack([torch.tensor(batch_idx), torch.tensor(sizes)]) + for batch_idx in super().__iter__(): + sizes = self._get_natoms(batch_idx) + if sizes is None: # on_error == "warn_and_no_balance" is set + yield batch_idx + continue + + idx_sizes = torch.stack( + [ + torch.tensor(batch_idx, device=self.device), + torch.tensor(sizes, device=self.device), + ] + ) idx_sizes_all = distutils.all_gather(idx_sizes, device=self.device) idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu() if gp_utils.initialized(): @@ -245,9 +246,10 @@ def __iter__(self): idx_all = idx_sizes_all[0] sizes_all = idx_sizes_all[1] - local_idx_balanced = balanced_partition( - sizes_all.numpy(), num_parts=self.num_replicas + local_idx_balanced = _balanced_partition( + sizes_all.numpy(), + num_parts=self.sampler.num_replicas, ) # Since DistributedSampler pads the last batch # this should always have an entry for each replica. 
- yield idx_all[local_idx_balanced[self.rank]] + yield idx_all[local_idx_balanced[self.sampler.rank]] diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 919f7ba66d..f6bf88ccaf 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -10,7 +10,8 @@ import logging import os import subprocess -from typing import TypeVar +from datetime import timedelta +from typing import Any, TypeVar import torch import torch.distributed as dist @@ -27,6 +28,7 @@ def os_environ_get_or_throw(x: str) -> str: def setup(config) -> None: + timeout = timedelta(minutes=config.get("timeout", 30)) if config["submit"]: node_list = os.environ.get("SLURM_STEP_NODELIST") if node_list is None: @@ -72,6 +74,7 @@ def setup(config) -> None: init_method=config["init_method"], world_size=config["world_size"], rank=config["rank"], + timeout=timeout, ) except subprocess.CalledProcessError as e: # scontrol failed raise e @@ -95,10 +98,11 @@ def setup(config) -> None: rank=world_rank, world_size=world_size, init_method="env://", + timeout=timeout, ) else: config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"])) - dist.init_process_group(backend="nccl") + dist.init_process_group(backend=config.get("backend", "nccl"), timeout=timeout) def cleanup() -> None: @@ -135,6 +139,14 @@ def broadcast( dist.broadcast(tensor, src, group, async_op) +def broadcast_object_list( + object_list: list[Any], src: int, group=dist.group.WORLD, device: str | None = None +) -> None: + if get_world_size() == 1: + return + dist.broadcast_object_list(object_list, src, group, device) + + def all_reduce( data, group=dist.group.WORLD, average: bool = False, device=None ) -> torch.Tensor: @@ -144,7 +156,7 @@ def all_reduce( if not isinstance(data, torch.Tensor): tensor = torch.tensor(data) if device is not None: - tensor = tensor.cuda(device) + tensor = tensor.to(device) dist.all_reduce(tensor, group=group) if average: tensor /= get_world_size() @@ -162,7 +174,7 @@ def all_gather(data, group=dist.group.WORLD, device=None) -> list[torch.Tensor]: if not isinstance(data, torch.Tensor): tensor = torch.tensor(data) if device is not None: - tensor = tensor.cuda(device) + tensor = tensor.to(device) tensor_list = [tensor.new_zeros(tensor.shape) for _ in range(get_world_size())] dist.all_gather(tensor_list, tensor, group=group) if not isinstance(data, torch.Tensor): diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py index 8aaf822105..130daba2d5 100644 --- a/src/fairchem/core/common/test_utils.py +++ b/src/fairchem/core/common/test_utils.py @@ -44,9 +44,55 @@ class PGConfig: use_gp: bool = True +def init_env_rank_and_launch_test( + rank: int, + pg_setup_params: PGConfig, + mp_output_dict: dict[int, object], + test_method: callable, + args: list[object], + kwargs: dict[str, object], +) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = pg_setup_params.port + os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["RANK"] = str(rank) + mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme + + +def init_pg_and_rank_and_launch_test( + rank: int, + pg_setup_params: PGConfig, + mp_output_dict: dict[int, object], + test_method: callable, + args: list[object], + kwargs: dict[str, object], +) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = pg_setup_params.port + os.environ["WORLD_SIZE"] = 
str(pg_setup_params.world_size) + os.environ["LOCAL_RANK"] = str(rank) + # setup default process group + dist.init_process_group( + rank=rank, + world_size=pg_setup_params.world_size, + backend=pg_setup_params.backend, + timeout=timedelta(seconds=10), # setting up timeout for distributed collectives + ) + # setup gp + if pg_setup_params.use_gp: + config = { + "gp_gpus": pg_setup_params.gp_group_size, + "distributed_backend": pg_setup_params.backend, + } + setup_gp(config) + mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme + + def spawn_multi_process( config: PGConfig, test_method: callable, + init_and_launch: callable, *test_method_args: Any, **test_method_kwargs: Any, ) -> list[Any]: @@ -72,7 +118,7 @@ def spawn_multi_process( torch.multiprocessing.spawn( # torch.multiprocessing.spawn sends rank as the first param # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn - _init_pg_and_rank_and_launch_test, + init_and_launch, args=( config, mp_output_dict, @@ -84,32 +130,3 @@ def spawn_multi_process( ) return [mp_output_dict[i] for i in range(config.world_size)] - - -def _init_pg_and_rank_and_launch_test( - rank: int, - pg_setup_params: PGConfig, - mp_output_dict: dict[int, object], - test_method: callable, - args: list[object], - kwargs: dict[str, object], -) -> None: - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = pg_setup_params.port - os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) - os.environ["LOCAL_RANK"] = str(rank) - # setup default process group - dist.init_process_group( - rank=rank, - world_size=pg_setup_params.world_size, - backend=pg_setup_params.backend, - timeout=timedelta(seconds=10), # setting up timeout for distributed collectives - ) - # setup gp - if pg_setup_params.use_gp: - config = { - "gp_gpus": pg_setup_params.gp_group_size, - "distributed_backend": pg_setup_params.backend, - } - setup_gp(config) - mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme diff --git a/src/fairchem/core/datasets/__init__.py b/src/fairchem/core/datasets/__init__.py index 1fd4b51cd5..dc3f7d0e4d 100644 --- a/src/fairchem/core/datasets/__init__.py +++ b/src/fairchem/core/datasets/__init__.py @@ -5,23 +5,19 @@ from __future__ import annotations from .ase_datasets import AseDBDataset, AseReadDataset, AseReadMultiStructureDataset +from .base_dataset import create_dataset from .lmdb_database import LMDBDatabase from .lmdb_dataset import ( LmdbDataset, - SinglePointLmdbDataset, - TrajectoryLmdbDataset, data_list_collater, ) -from .oc22_lmdb_dataset import OC22LmdbDataset __all__ = [ "AseDBDataset", "AseReadDataset", "AseReadMultiStructureDataset", "LmdbDataset", - "SinglePointLmdbDataset", - "TrajectoryLmdbDataset", - "data_list_collater", - "OC22LmdbDataset", "LMDBDatabase", + "create_dataset", + "data_list_collater", ] diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py index 0b1a40e72d..fc117bf5c4 100644 --- a/src/fairchem/core/datasets/ase_datasets.py +++ b/src/fairchem/core/datasets/ase_datasets.py @@ -13,20 +13,19 @@ import os import warnings from abc import ABC, abstractmethod -from functools import cache, reduce +from functools import cache from glob import glob from pathlib import Path from typing import Any, Callable import ase import numpy as np -import torch.nn from torch import tensor -from torch.utils.data import Dataset from tqdm import tqdm from fairchem.core.common.registry import registry from fairchem.core.datasets._utils import 
rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.datasets.lmdb_database import LMDBDatabase from fairchem.core.datasets.target_metadata_guesser import guess_property_metadata from fairchem.core.modules.transforms import DataTransforms @@ -60,7 +59,7 @@ def apply_one_tags( return atoms -class AseAtomsDataset(Dataset, ABC): +class AseAtomsDataset(BaseDataset, ABC): """ This is an abstract Dataset that includes helpful utilities for turning ASE atoms objects into OCP-usable data objects. This should not be instantiated directly @@ -81,7 +80,7 @@ def __init__( config: dict, atoms_transform: Callable[[ase.Atoms, Any, ...], ase.Atoms] = apply_one_tags, ) -> None: - self.config = config + super().__init__(config) a2g_args = config.get("a2g_args", {}) or {} @@ -96,19 +95,13 @@ def __init__( self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - self.lin_ref = None - if self.config.get("lin_ref", False): - lin_ref = torch.tensor( - np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - ) - self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) - self.atoms_transform = atoms_transform if self.config.get("keep_in_memory", False): self.__getitem__ = cache(self.__getitem__) self.ids = self._load_dataset_get_ids(config) + self.num_samples = len(self.ids) if len(self.ids) == 0: raise ValueError( @@ -116,9 +109,6 @@ def __init__( f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" ) - def __len__(self) -> int: - return len(self.ids) - def __getitem__(self, idx): # Handle slicing if isinstance(idx, slice): @@ -177,11 +167,7 @@ def get_relaxed_energy(self, identifier): "the r_data_keys argument under a2g_args." 
) - def close_db(self) -> None: - # This method is sometimes called by a trainer - pass - - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: metadata = {} if num_samples < len(self): @@ -200,6 +186,18 @@ def get_metadata(self, num_samples: int = 100) -> dict: return metadata + def get_metadata(self, attr, idx): + # try the parent method + metadata = super().get_metadata(attr, idx) + if metadata is not None: + return metadata + # try to resolve it here + if attr != "natoms": + return None + if isinstance(idx, (list, np.ndarray)): + return np.array([self.get_metadata(attr, i) for i in idx]) + return len(self.get_atoms(idx)) + @registry.register_dataset("ase_read") class AseReadDataset(AseAtomsDataset): @@ -402,7 +400,7 @@ def get_atoms(self, idx: str) -> ase.Atoms: return atoms - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: return {} def get_relaxed_energy(self, identifier) -> float: @@ -472,13 +470,14 @@ class AseDBDataset(AseAtomsDataset): def _load_dataset_get_ids(self, config: dict) -> list[int]: if isinstance(config["src"], list): - if os.path.isdir(config["src"][0]): - filepaths = reduce( - lambda x, y: x + y, - (glob(f"{path}/*") for path in config["src"]), - ) - else: - filepaths = config["src"] + filepaths = [] + for path in config["src"]: + if os.path.isdir(path): + filepaths.extend(glob(f"{path}/*")) + elif os.path.isfile(path): + filepaths.append(path) + else: + raise RuntimeError(f"Error reading dataset in {path}!") elif os.path.isfile(config["src"]): filepaths = [config["src"]] elif os.path.isdir(config["src"]): @@ -559,16 +558,16 @@ def connect_db( return ase.db.connect(address, **connect_args) - def close_db(self) -> None: + def __del__(self): for db in self.dbs: if hasattr(db, "close"): db.close() - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: logging.warning( "You specific a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" ) if self.dbs[0].metadata == {}: - return super().get_metadata(num_samples) + return super().sample_property_metadata(num_samples) return copy.deepcopy(self.dbs[0].metadata) diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py new file mode 100644 index 0000000000..2ca26596c3 --- /dev/null +++ b/src/fairchem/core/datasets/base_dataset.py @@ -0,0 +1,227 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +import logging +from abc import ABCMeta +from functools import cached_property +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + NamedTuple, + TypeVar, +) + +import numpy as np +import torch +from torch import randperm +from torch.utils.data import Dataset +from torch.utils.data import Subset as Subset_ + +from fairchem.core.common.registry import registry + +if TYPE_CHECKING: + from collections.abc import Sequence + + from numpy.typing import ArrayLike + + +T_co = TypeVar("T_co", covariant=True) + + +class DatasetMetadata(NamedTuple): + natoms: ArrayLike | None = None + + +class UnsupportedDatasetError(ValueError): + pass + + +class BaseDataset(Dataset[T_co], metaclass=ABCMeta): + """Base Dataset class for all OCP datasets.""" + + def __init__(self, config: dict): + """Initialize + + Args: + config (dict): dataset configuration + """ + self.config = config + self.paths = [] + + if "src" in self.config: + if isinstance(config["src"], str): + self.paths = [Path(self.config["src"])] + else: + self.paths = tuple(Path(path) for path in config["src"]) + + self.lin_ref = None + if self.config.get("lin_ref", False): + lin_ref = torch.tensor( + np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] + ) + self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) + + def __len__(self) -> int: + return self.num_samples + + def metadata_hasattr(self, attr) -> bool: + if self._metadata is None: + return False + return hasattr(self._metadata, attr) + + @cached_property + def indices(self): + return np.arange(self.num_samples, dtype=int) + + @cached_property + def _metadata(self) -> DatasetMetadata: + # logic to read metadata file here + metadata_npzs = [] + if self.config.get("metadata_path", None) is not None: + metadata_npzs.append( + np.load(self.config["metadata_path"], allow_pickle=True) + ) + + else: + for path in self.paths: + if path.is_file(): + metadata_file = path.parent / "metadata.npz" + else: + metadata_file = path / "metadata.npz" + if metadata_file.is_file(): + metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) + + if len(metadata_npzs) == 0: + logging.warning( + f"Could not find dataset metadata.npz files in '{self.paths}'" + ) + return None + + metadata = DatasetMetadata( + **{ + field: np.concatenate([metadata[field] for metadata in metadata_npzs]) + for field in DatasetMetadata._fields + } + ) + + assert metadata.natoms.shape[0] == len( + self + ), "Loaded metadata and dataset size mismatch." 
+ + return metadata + + def get_metadata(self, attr, idx): + if self._metadata is not None: + metadata_attr = getattr(self._metadata, attr) + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + return None + + +class Subset(Subset_, BaseDataset): + """A pytorch subset that also takes metadata if given.""" + + def __init__( + self, + dataset: BaseDataset, + indices: Sequence[int], + metadata: DatasetMetadata | None = None, + ) -> None: + super().__init__(dataset, indices) + self.metadata = metadata + self.indices = indices + self.num_samples = len(indices) + self.config = dataset.config + + @cached_property + def _metadata(self) -> DatasetMetadata: + return self.dataset._metadata + + def get_metadata(self, attr, idx): + if isinstance(idx, list): + return self.dataset.get_metadata(attr, [[self.indices[i] for i in idx]]) + return self.dataset.get_metadata(attr, self.indices[idx]) + + +def create_dataset(config: dict[str, Any], split: str) -> Subset: + """Create a dataset from a config dictionary + + Args: + config (dict): dataset config dictionary + split (str): name of split + + Returns: + Subset: dataset subset class + """ + # Initialize the dataset + dataset_cls = registry.get_dataset_class(config.get("format", "lmdb")) + assert issubclass(dataset_cls, Dataset), f"{dataset_cls} is not a Dataset" + + # remove information about other splits, only keep specified split + # this may only work with the mt config not main config + current_split_config = config.copy() + if "splits" in current_split_config: + current_split_config.pop("splits") + current_split_config.update(config["splits"][split]) + + seed = current_split_config.get("seed", 0) + if split != "train": + seed += ( + 1 # if we use same dataset for train / val , make sure its diff sampling + ) + + g = torch.Generator() + g.manual_seed(seed) + + dataset = dataset_cls(current_split_config) + # Get indices of the dataset + indices = dataset.indices + max_atoms = current_split_config.get("max_atoms", None) + if max_atoms is not None: + if not dataset.metadata_hasattr("natoms"): + raise ValueError("Cannot use max_atoms without dataset metadata") + indices = indices[dataset.get_metadata("natoms", indices) <= max_atoms] + + # Apply dataset level transforms + # TODO is no_shuffle mutually exclusive though? or what is the purpose of no_shuffle? + first_n = current_split_config.get("first_n") + sample_n = current_split_config.get("sample_n") + no_shuffle = current_split_config.get("no_shuffle") + # this is true if at most one of the mutually exclusive arguments are set + if sum(arg is not None for arg in (first_n, sample_n, no_shuffle)) > 1: + raise ValueError( + "sample_n, first_n, no_shuffle are mutually exclusive arguments. Only one can be provided." + ) + if first_n is not None: + max_index = first_n + elif sample_n is not None: + # shuffle by default, user can disable to optimize if they have confidence in dataset + # shuffle all datasets by default to avoid biasing the sampling in concat dataset + # TODO only shuffle if split is train + max_index = sample_n + indices = indices[randperm(len(indices), generator=g)] + else: + max_index = len(indices) + indices = ( + indices if no_shuffle else indices[randperm(len(indices), generator=g)] + ) + + if max_index > len(indices): + msg = ( + f"Cannot take {max_index} data points from a dataset of only length {len(indices)}.\n" + f"Make sure to set first_n or sample_n to a number =< the total samples in dataset." 
+ ) + if max_atoms is not None: + msg = msg[:-1] + f"that are smaller than the given max_atoms {max_atoms}." + raise ValueError(msg) + + indices = indices[:max_index] + + return Subset(dataset, indices, metadata=dataset._metadata) diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py index 91ced220ea..ca1fcc2b77 100644 --- a/src/fairchem/core/datasets/lmdb_dataset.py +++ b/src/fairchem/core/datasets/lmdb_dataset.py @@ -9,32 +9,33 @@ import bisect import logging import pickle -import warnings -from pathlib import Path -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar import lmdb import numpy as np import torch -from torch.utils.data import Dataset from torch_geometric.data import Batch -from torch_geometric.data.data import BaseData from fairchem.core.common.registry import registry from fairchem.core.common.typing import assert_is_instance from fairchem.core.common.utils import pyg2_data_transform from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.datasets.target_metadata_guesser import guess_property_metadata from fairchem.core.modules.transforms import DataTransforms +if TYPE_CHECKING: + from pathlib import Path + + from torch_geometric.data.data import BaseData + T_co = TypeVar("T_co", covariant=True) @registry.register_dataset("lmdb") @registry.register_dataset("single_point_lmdb") @registry.register_dataset("trajectory_lmdb") -class LmdbDataset(Dataset[T_co]): - metadata_path: Path +class LmdbDataset(BaseDataset): sharded: bool r"""Dataset class to load from LMDB files containing relaxation @@ -50,20 +51,21 @@ class LmdbDataset(Dataset[T_co]): """ def __init__(self, config) -> None: - super().__init__() - self.config = config + super().__init__(config) assert not self.config.get( "train_on_oc20_total_energies", False ), "For training on total energies set dataset=oc22_lmdb" - self.path = Path(self.config["src"]) + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." + self.path = self.paths[0] + if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" - self.metadata_path = self.path / "metadata.npz" - self._keys = [] self.envs = [] for db_path in db_paths: @@ -86,7 +88,6 @@ def __init__(self, config) -> None: self._keylen_cumulative = np.cumsum(keylens).tolist() self.num_samples = sum(keylens) else: - self.metadata_path = self.path.parent / "metadata.npz" self.env = self.connect_db(self.path) # If "length" encoded as ascii is present, use that @@ -113,19 +114,15 @@ def __init__(self, config) -> None: self.indices, self.config.get("total_shards", 1) ) # limit each process to see a subset of data based off defined shard - self.available_indices = self.shards[self.config.get("shard", 0)] - self.num_samples = len(self.available_indices) + self.indices = self.shards[self.config.get("shard", 0)] + self.num_samples = len(self.indices) self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - def __len__(self) -> int: - return self.num_samples - def __getitem__(self, idx: int) -> T_co: # if sharding, remap idx to appropriate idx of the sharded set - if self.sharded: - idx = self.available_indices[idx] + idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. 
db_idx = bisect.bisect(self._keylen_cumulative, idx) @@ -165,14 +162,14 @@ def connect_db(self, lmdb_path: Path | None = None) -> lmdb.Environment: max_readers=1, ) - def close_db(self) -> None: + def __del__(self): if not self.path.is_file(): for env in self.envs: env.close() else: self.env.close() - def get_metadata(self, num_samples: int = 100): + def sample_property_metadata(self, num_samples: int = 100): # This will interogate the classic OCP LMDB format to determine # which properties are present and attempt to guess their shapes # and whether they are intensive or extensive. @@ -214,26 +211,6 @@ def get_metadata(self, num_samples: int = 100): } -class SinglePointLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "SinglePointLmdbDataset is deprecated and will be removed in the future." - "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - -class TrajectoryLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "TrajectoryLmdbDataset is deprecated and will be removed in the future." - "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> BaseData: batch = Batch.from_data_list(data_list) diff --git a/src/fairchem/core/datasets/oc22_lmdb_dataset.py b/src/fairchem/core/datasets/oc22_lmdb_dataset.py index 0c6d4e8bfb..867a72726f 100644 --- a/src/fairchem/core/datasets/oc22_lmdb_dataset.py +++ b/src/fairchem/core/datasets/oc22_lmdb_dataset.py @@ -9,22 +9,21 @@ import bisect import pickle -from pathlib import Path import lmdb import numpy as np import torch -from torch.utils.data import Dataset from fairchem.core.common.registry import registry from fairchem.core.common.typing import assert_is_instance as aii from fairchem.core.common.utils import pyg2_data_transform from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.modules.transforms import DataTransforms @registry.register_dataset("oc22_lmdb") -class OC22LmdbDataset(Dataset): +class OC22LmdbDataset(BaseDataset): r"""Dataset class to load from LMDB files containing relaxation trajectories or single point computations. @@ -43,10 +42,13 @@ class OC22LmdbDataset(Dataset): """ def __init__(self, config, transform=None) -> None: - super().__init__() - self.config = config + super().__init__(config) + + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." + self.path = self.paths[0] - self.path = Path(self.config["src"]) self.data2train = self.config.get("data2train", "all") if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) @@ -114,19 +116,11 @@ def __init__(self, config, transform=None) -> None: if self.train_on_oc20_total_energies: with open(config["oc20_ref"], "rb") as fp: self.oc20_ref = pickle.load(fp) - if self.config.get("lin_ref", False): - coeff = np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - self.lin_ref = torch.nn.Parameter(torch.tensor(coeff), requires_grad=False) - self.subsample = aii(self.config.get("subsample", False), bool) def __len__(self) -> int: - if self.subsample: - return min(self.subsample, self.num_samples) return self.num_samples def __getitem__(self, idx): - if self.data2train != "all": - idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. 
db_idx = bisect.bisect(self._keylen_cumulative, idx) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 42790643a9..8ce8f3fcb1 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -8,27 +8,42 @@ from __future__ import annotations import logging +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING import torch -import torch.nn as nn +from torch import nn from torch_geometric.nn import radius_graph +from fairchem.core.common.registry import registry from fairchem.core.common.utils import ( compute_neighbors, get_pbc_distances, radius_graph_pbc, ) +if TYPE_CHECKING: + from torch_geometric.data import Batch + + +@dataclass +class GraphData: + """Class to keep graph attributes nicely packaged.""" + + edge_index: torch.Tensor + edge_distance: torch.Tensor + edge_distance_vec: torch.Tensor + cell_offsets: torch.Tensor + offset_distances: torch.Tensor + neighbors: torch.Tensor + batch_full: torch.Tensor # used for GP functionality + atomic_numbers_full: torch.Tensor # used for GP functionality + node_offset: int = 0 # used for GP functionality -class BaseModel(nn.Module): - def __init__(self, num_atoms=None, bond_feat_dim=None, num_targets=None) -> None: - super().__init__() - self.num_atoms = num_atoms - self.bond_feat_dim = bond_feat_dim - self.num_targets = num_targets - def forward(self, data): - raise NotImplementedError +class GraphModelMixin: + """Mixin Model class implementing some general convenience properties and methods.""" def generate_graph( self, @@ -38,10 +53,12 @@ def generate_graph( use_pbc=None, otf_graph=None, enforce_max_neighbors_strictly=None, + use_pbc_single=False, ): cutoff = cutoff or self.cutoff max_neighbors = max_neighbors or self.max_neighbors use_pbc = use_pbc or self.use_pbc + use_pbc_single = use_pbc_single or self.use_pbc_single otf_graph = otf_graph or self.otf_graph if enforce_max_neighbors_strictly is not None: @@ -69,12 +86,47 @@ def generate_graph( if use_pbc: if otf_graph: - edge_index, cell_offsets, neighbors = radius_graph_pbc( - data, - cutoff, - max_neighbors, - enforce_max_neighbors_strictly, - ) + if use_pbc_single: + ( + edge_index_per_system, + cell_offsets_per_system, + neighbors_per_system, + ) = list( + zip( + *[ + radius_graph_pbc( + data[idx], + cutoff, + max_neighbors, + enforce_max_neighbors_strictly, + ) + for idx in range(len(data)) + ] + ) + ) + + # atom indexs in the edge_index need to be offset + atom_index_offset = data.natoms.cumsum(dim=0).roll(1) + atom_index_offset[0] = 0 + edge_index = torch.hstack( + [ + edge_index_per_system[idx] + atom_index_offset[idx] + for idx in range(len(data)) + ] + ) + cell_offsets = torch.vstack(cell_offsets_per_system) + neighbors = torch.hstack(neighbors_per_system) + else: + ## TODO this is the original call, but blows up with memory + ## using two different samples + ## sid='mp-675045-mp-675045-0-7' (MPTRAJ) + ## sid='75396' (OC22) + edge_index, cell_offsets, neighbors = radius_graph_pbc( + data, + cutoff, + max_neighbors, + enforce_max_neighbors_strictly, + ) out = get_pbc_distances( data.pos, @@ -109,13 +161,16 @@ def generate_graph( ) neighbors = compute_neighbors(data, edge_index) - return ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - cell_offset_distances, - neighbors, + return GraphData( + edge_index=edge_index, + edge_distance=edge_dist, + edge_distance_vec=distance_vec, + cell_offsets=cell_offsets, + offset_distances=cell_offset_distances, + 
neighbors=neighbors, + node_offset=0, + batch_full=data.batch, + atomic_numbers_full=data.atomic_numbers.long(), ) @property @@ -130,3 +185,90 @@ def no_weight_decay(self) -> list: if "embedding" in name or "frequencies" in name or "bias" in name: no_wd_list.append(name) return no_wd_list + + +class HeadInterface(metaclass=ABCMeta): + @abstractmethod + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Head forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + emb: dict[str->torch.Tensor] + Embeddings of the input as generated by the backbone + + Returns + ------- + outputs: dict[str->torch.Tensor] + Return one or more targets generated by this head + """ + return + + +class BackboneInterface(metaclass=ABCMeta): + @abstractmethod + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + """Backbone forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + + Returns + ------- + embedding: dict[str->torch.Tensor] + Return backbone embeddings for the given input + """ + return + + +@registry.register_model("hydra") +class HydraModel(nn.Module, GraphModelMixin): + def __init__( + self, + backbone: dict, + heads: dict, + otf_graph: bool = True, + ): + super().__init__() + self.otf_graph = otf_graph + + backbone_model_name = backbone.pop("model") + self.backbone: BackboneInterface = registry.get_model_class( + backbone_model_name + )( + **backbone, + ) + + # Iterate through outputs_cfg and create heads + self.output_heads: dict[str, HeadInterface] = {} + + head_names_sorted = sorted(heads.keys()) + for head_name in head_names_sorted: + head_config = heads[head_name] + if "module" not in head_config: + raise ValueError( + f"{head_name} head does not specify module to use for the head" + ) + + module_name = head_config.pop("module") + self.output_heads[head_name] = registry.get_model_class(module_name)( + self.backbone, + **head_config, + ) + + self.output_heads = torch.nn.ModuleDict(self.output_heads) + + def forward(self, data: Batch): + emb = self.backbone(data) + # Predict all output properties for all structures in the batch for now. + out = {} + for k in self.output_heads: + out.update(self.output_heads[k](data, emb)) + + return out diff --git a/src/fairchem/core/models/dimenet_plus_plus.py b/src/fairchem/core/models/dimenet_plus_plus.py index 296a77bbba..f555448261 100644 --- a/src/fairchem/core/models/dimenet_plus_plus.py +++ b/src/fairchem/core/models/dimenet_plus_plus.py @@ -34,6 +34,8 @@ from __future__ import annotations +import typing + import torch from torch import nn from torch_geometric.nn.inits import glorot_orthogonal @@ -49,7 +51,10 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch try: import sympy as sym @@ -57,7 +62,7 @@ sym = None -class InteractionPPBlock(torch.nn.Module): +class InteractionPPBlock(nn.Module): def __init__( self, hidden_channels: int, @@ -90,11 +95,11 @@ def __init__( self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) # Residual layers before and after skip connection. 
- self.layers_before_skip = torch.nn.ModuleList( + self.layers_before_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)] ) self.lin = nn.Linear(hidden_channels, hidden_channels) - self.layers_after_skip = torch.nn.ModuleList( + self.layers_after_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)] ) @@ -153,7 +158,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji): return h -class OutputPPBlock(torch.nn.Module): +class OutputPPBlock(nn.Module): def __init__( self, num_radial: int, @@ -169,7 +174,7 @@ def __init__( self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) - self.lins = torch.nn.ModuleList() + self.lins = nn.ModuleList() for _ in range(num_layers): self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) @@ -193,7 +198,7 @@ def forward(self, x, rbf, i, num_nodes: int | None = None): return self.lin(x) -class DimeNetPlusPlus(torch.nn.Module): +class DimeNetPlusPlus(nn.Module): r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet. Args: @@ -241,7 +246,6 @@ def __init__( act = activation_resolver(act) super().__init__() - self.cutoff = cutoff if sym is None: @@ -256,7 +260,7 @@ def __init__( self.emb = EmbeddingBlock(num_radial, hidden_channels, act) - self.output_blocks = torch.nn.ModuleList( + self.output_blocks = nn.ModuleList( [ OutputPPBlock( num_radial, @@ -270,7 +274,7 @@ def __init__( ] ) - self.interaction_blocks = torch.nn.ModuleList( + self.interaction_blocks = nn.ModuleList( [ InteractionPPBlock( hidden_channels, @@ -330,14 +334,41 @@ def forward(self, z, pos, batch=None): raise NotImplementedError +@registry.register_model("dimenetplusplus_energy_and_force_head") +class DimeNetPlusPlusWrapEnergyAndForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.regress_forces = backbone.regress_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + outputs = { + "energy": ( + emb["P"].sum(dim=0) + if data.batch is None + else scatter(emb["P"], data.batch, dim=0) + ) + } + if self.regress_forces: + outputs["forces"] = -1 * ( + torch.autograd.grad( + outputs["energy"], + data.pos, + grad_outputs=torch.ones_like(outputs["energy"]), + create_graph=True, + )[0] + ) + return outputs + + @registry.register_model("dimenetplusplus") -class DimeNetPlusPlusWrap(DimeNetPlusPlus, BaseModel): +class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin): def __init__( self, - num_atoms: int, - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, hidden_channels: int = 128, num_blocks: int = 4, @@ -353,16 +384,16 @@ def __init__( num_after_skip: int = 2, num_output_layers: int = 3, ) -> None: - self.num_targets = num_targets self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.max_neighbors = 50 super().__init__( hidden_channels=hidden_channels, - out_channels=num_targets, + out_channels=1, num_blocks=num_blocks, int_emb_size=int_emb_size, basis_emb_size=basis_emb_size, @@ -380,22 +411,15 @@ def __init__( def _forward(self, data): pos = data.pos batch = data.batch - ( - edge_index, - dist, - _, - cell_offsets, - offsets, - 
neighbors, - ) = self.generate_graph(data) - - data.edge_index = edge_index - data.cell_offsets = cell_offsets - data.neighbors = neighbors - j, i = edge_index + graph = self.generate_graph(data) + + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( - edge_index, + graph.edge_index, data.cell_offsets, num_nodes=data.atomic_numbers.size(0), ) @@ -405,8 +429,8 @@ def _forward(self, data): pos_j = pos[idx_j].detach() if self.use_pbc: pos_ji, pos_kj = ( - pos[idx_j].detach() - pos_i + offsets[idx_ji], - pos[idx_k].detach() - pos_j + offsets[idx_kj], + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], ) else: pos_ji, pos_kj = ( @@ -418,8 +442,8 @@ def _forward(self, data): b = torch.cross(pos_ji, pos_kj).norm(dim=-1) angle = torch.atan2(b, a) - rbf = self.rbf(dist) - sbf = self.sbf(dist, angle, idx_kj) + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) # Embedding block. x = self.emb(data.atomic_numbers.long(), rbf, i, j) @@ -441,16 +465,13 @@ def forward(self, data): outputs = {"energy": energy} if self.regress_forces: - forces = ( - -1 - * ( - torch.autograd.grad( - energy, - data.pos, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - ) + forces = -1 * ( + torch.autograd.grad( + energy, + data.pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] ) outputs["forces"] = forces @@ -459,3 +480,57 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("dimenetplusplus_backbone") +class DimeNetPlusPlusWrapBackbone(DimeNetPlusPlusWrap, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + if self.regress_forces: + data.pos.requires_grad_(True) + pos = data.pos + graph = self.generate_graph(data) + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index + + _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( + graph.edge_index, + data.cell_offsets, + num_nodes=data.atomic_numbers.size(0), + ) + + # Calculate angles. + pos_i = pos[idx_i].detach() + pos_j = pos[idx_j].detach() + if self.use_pbc: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], + ) + else: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i, + pos[idx_k].detach() - pos_j, + ) + + a = (pos_ji * pos_kj).sum(dim=-1) + b = torch.cross(pos_ji, pos_kj).norm(dim=-1) + angle = torch.atan2(b, a) + + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) + + # Embedding block. + x = self.emb(data.atomic_numbers.long(), rbf, i, j) + P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) + + # Interaction blocks. 
+ for interaction_block, output_block in zip( + self.interaction_blocks, self.output_blocks[1:] + ): + x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) + P += output_block(x, rbf, i, num_nodes=pos.size(0)) + + return {"P": P, "edge_embedding": x, "edge_idx": i} diff --git a/src/fairchem/core/models/equiformer_v2/__init__.py b/src/fairchem/core/models/equiformer_v2/__init__.py index 424b64f9ed..720f890f65 100644 --- a/src/fairchem/core/models/equiformer_v2/__init__.py +++ b/src/fairchem/core/models/equiformer_v2/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .equiformer_v2_oc20 import EquiformerV2_OC20 as EquiformerV2 +from .equiformer_v2 import EquiformerV2 __all__ = ["EquiformerV2"] diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py similarity index 72% rename from src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py rename to src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 8edf81319c..06a0280e98 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -10,13 +10,15 @@ from fairchem.core.common import gp_utils from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): pass +import typing + from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding @@ -42,13 +44,18 @@ TransBlockV2, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + + from fairchem.core.models.base import GraphData + # Statistics of IS2RE 100K _AVG_NUM_NODES = 77.81317 _AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100 @registry.register_model("equiformer_v2") -class EquiformerV2_OC20(BaseModel): +class EquiformerV2(nn.Module, GraphModelMixin): """ Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation @@ -108,10 +115,8 @@ class EquiformerV2_OC20(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = True, max_neighbors: int = 500, @@ -165,6 +170,7 @@ def __init__( raise ImportError self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.regress_forces = regress_forces self.otf_graph = otf_graph self.max_neighbors = max_neighbors @@ -436,23 +442,12 @@ def forward(self, data): self.dtype = data.pos.dtype self.device = data.pos.device atomic_numbers = data.atomic_numbers.long() - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, ) - data_batch_full = data.batch data_batch = data.batch - atomic_numbers_full = atomic_numbers - node_offset = 0 if gp_utils.initialized(): ( atomic_numbers, @@ -462,12 +457,17 @@ def forward(self, data): edge_distance, edge_distance_vec, ) = self._init_gp_partitions( - atomic_numbers_full, - data_batch_full, - edge_index, - edge_distance, - 
edge_distance_vec, + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + ############################################################### # Entering Graph Parallel Region # after this point, if using gp, then node, edge tensors are split @@ -485,7 +485,9 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -496,7 +498,6 @@ def forward(self, data): ############################################################### # Init per node representations using an atomic number based embedding - offset = 0 x = SO3_Embedding( len(atomic_numbers), self.lmax_list, @@ -519,27 +520,27 @@ def forward(self, data): offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) # Edge encoding (distance and atom edge) - edge_distance = self.distance_expansion(edge_distance) + graph.edge_distance = self.distance_expansion(graph.edge_distance) if self.share_atom_edge_embedding and self.use_atom_edge_embedding: - source_element = atomic_numbers_full[ - edge_index[0] + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] ] # Source atom atomic number - target_element = atomic_numbers_full[ - edge_index[1] + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] ] # Target atom atomic number source_embedding = self.source_embedding(source_element) target_embedding = self.target_embedding(target_element) - edge_distance = torch.cat( - (edge_distance, source_embedding, target_embedding), dim=1 + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 ) # Edge-degree embedding edge_degree = self.edge_degree_embedding( - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, len(atomic_numbers), - node_offset, + graph.node_offset, ) x.embedding = x.embedding + edge_degree.embedding @@ -550,11 +551,11 @@ def forward(self, data): for i in range(self.num_layers): x = self.blocks[i]( x, # SO3_Embedding - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, batch=data_batch, # for GraphDropPath - node_offset=node_offset, + node_offset=graph.node_offset, ) # Final layer norm @@ -572,7 +573,7 @@ def forward(self, data): device=node_energy.device, dtype=node_energy.dtype, ) - energy.index_add_(0, data_batch_full, node_energy.view(-1)) + energy.index_add_(0, graph.batch_full, node_energy.view(-1)) energy = energy / self.avg_num_nodes # Add the per-atom linear references to the energy. 
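A recurring pattern in this refactor is visible in the hunk above: the old six-element tuple returned by generate_graph() is replaced by a single graph object whose fields are accessed by name (edge_index, edge_distance, edge_distance_vec, cell_offsets, neighbors, atomic_numbers_full, batch_full, node_offset). A minimal sketch of a model written directly against the new interface follows; the attribute names are taken from the usages in this patch, while the toy model itself and the assumption that generate_graph() only needs the attributes set in __init__ are illustrative, not part of the patch.

```python
import torch.nn as nn

from fairchem.core.models.base import GraphModelMixin


class EdgeStatsModel(nn.Module, GraphModelMixin):
    """Toy example: consume the graph object returned by generate_graph()."""

    def __init__(self, cutoff: float = 6.0, max_neighbors: int = 50):
        super().__init__()
        # generate_graph() reads these attributes from the module (as the
        # models in this patch do); anything else it needs is assumed to
        # have a usable default.
        self.cutoff = cutoff
        self.max_neighbors = max_neighbors
        self.use_pbc = True
        self.use_pbc_single = False
        self.otf_graph = True

    def forward(self, data):
        graph = self.generate_graph(data)
        j, i = graph.edge_index        # source / target node indices
        d = graph.edge_distance        # (nEdges,)
        v = graph.edge_distance_vec    # (nEdges, 3)
        unit = v / d[:, None]          # unit edge vectors, as in the models above
        # cell_offsets, neighbors, atomic_numbers_full, batch_full and
        # node_offset are available the same way.
        return {"mean_edge_length": d.mean(), "num_edges": d.numel()}
```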
@@ -594,8 +595,8 @@ def forward(self, data): with torch.cuda.amp.autocast(False): energy = energy.to(self.energy_lin_ref.dtype).index_add( 0, - data_batch_full, - self.energy_lin_ref[atomic_numbers_full], + graph.batch_full, + self.energy_lin_ref[graph.atomic_numbers_full], ) outputs = {"energy": energy} @@ -605,10 +606,10 @@ def forward(self, data): if self.regress_forces: forces = self.force_block( x, - atomic_numbers_full, - edge_distance, - edge_index, - node_offset=node_offset, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + node_offset=graph.node_offset, ) forces = forces.embedding.narrow(1, 1, 3) forces = forces.view(-1, 3).contiguous() @@ -678,3 +679,209 @@ def no_weight_decay(self) -> set: no_wd_list.append(global_parameter_name) return set(no_wd_list) + + +@registry.register_model("equiformer_v2_backbone") +class EquiformerV2Backbone(EquiformerV2, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + self.device = data.pos.device + atomic_numbers = data.atomic_numbers.long() + graph = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) + + data_batch = data.batch + if gp_utils.initialized(): + ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) = self._init_gp_partitions( + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, + ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + + ############################################################### + # Entering Graph Parallel Region + # after this point, if using gp, then node, edge tensors are split + # across the graph parallel ranks, some full tensors such as + # atomic_numbers_full are required because we need to index into the + # full graph when computing edge embeddings or reducing nodes from neighbors + # + # all tensors that do not have the suffix "_full" refer to the partial tensors. 
+ # if not using gp, the full values are equal to the partial values + # ie: atomic_numbers_full == atomic_numbers + ############################################################### + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + for i in range(self.num_resolutions): + self.SO3_rotation[i].set_wigner(edge_rot_mat) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x = SO3_Embedding( + len(atomic_numbers), + self.lmax_list, + self.sphere_channels, + self.device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l = 0, m = 0 coefficients for each resolution + for i in range(self.num_resolutions): + if self.num_resolutions == 1: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) + else: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # Edge encoding (distance and atom edge) + graph.edge_distance = self.distance_expansion(graph.edge_distance) + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] + ] # Source atom atomic number + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] + ] # Target atom atomic number + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + len(atomic_numbers), + graph.node_offset, + ) + x.embedding = x.embedding + edge_degree.embedding + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) + + # Final layer norm + x.embedding = self.norm(x.embedding) + + return {"node_embedding": x, "graph": graph} + + +@registry.register_model("equiformer_v2_energy_head") +class EquiformerV2EnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.avg_num_nodes = backbone.avg_num_nodes + self.energy_block = FeedForwardNetwork( + backbone.sphere_channels, + backbone.ffn_hidden_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_grid, + backbone.ffn_activation, + backbone.use_gate_act, + backbone.use_grid_mlp, + backbone.use_sep_s2_act, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): + node_energy = self.energy_block(emb["node_embedding"]) + node_energy = node_energy.embedding.narrow(1, 0, 1) + if 
gp_utils.initialized(): + node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0) + energy = torch.zeros( + len(data.natoms), + device=node_energy.device, + dtype=node_energy.dtype, + ) + energy.index_add_(0, data.batch, node_energy.view(-1)) + return {"energy": energy / self.avg_num_nodes} + + +@registry.register_model("equiformer_v2_force_head") +class EquiformerV2ForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor]): + forces = self.force_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + forces = forces.embedding.narrow(1, 1, 3) + forces = forces.view(-1, 3).contiguous() + if gp_utils.initialized(): + forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) + return {"forces": forces} diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 0ec66b9dba..d6367fa9ad 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -10,13 +10,17 @@ import contextlib import logging import time +import typing import torch import torch.nn as nn +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.escn.so3 import ( CoefficientMapping, SO3_Embedding, @@ -36,13 +40,14 @@ @registry.register_model("escn") -class eSCN(BaseModel): +class eSCN(nn.Module, GraphModelMixin): """Equivariant Spherical Channel Network Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs Args: use_pbc (bool): Use periodic boundary conditions + use_pbc_single (bool): Process batch PBC graphs one at a time regress_forces (bool): Compute forces otf_graph (bool): Compute graph On The Fly (OTF) max_neighbors (int): Maximum number of neighbors per atom @@ -64,10 +69,8 @@ class eSCN(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, max_neighbors: int = 40, @@ -79,7 +82,6 @@ def __init__( sphere_channels: int = 128, hidden_channels: int = 256, edge_channels: int = 128, - use_grid: bool = True, num_sphere_samples: int = 128, distance_function: str = "gaussian", basis_width_scalar: float = 1.0, @@ -100,6 +102,7 @@ def __init__( self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.show_timing_info = show_timing_info @@ -232,22 
+235,16 @@ def forward(self, data): start_time = time.time() atomic_numbers = data.atomic_numbers.long() num_atoms = len(atomic_numbers) - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() @@ -290,8 +287,8 @@ def forward(self, data): x_message = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -304,8 +301,8 @@ def forward(self, data): x = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -421,6 +418,149 @@ def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) +@registry.register_model("escn_backbone") +class eSCNBackbone(eSCN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + device = data.pos.device + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + + atomic_numbers = data.atomic_numbers.long() + num_atoms = len(atomic_numbers) + + graph = self.generate_graph(data) + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + self.SO3_edge_rot = nn.ModuleList() + for i in range(self.num_resolutions): + self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + offset = 0 + x = SO3_Embedding( + num_atoms, + self.lmax_list, + self.sphere_channels, + device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l=0,m=0 coefficients for each resolution + for i in range(self.num_resolutions): + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer + mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + if i > 0: + x_message = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + 
+ # Residual layer for all layers past the first + x.embedding = x.embedding + x_message.embedding + + else: + # No residual for the first layer + x = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + + # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. + # These values are fed into the output blocks. + x_pt = torch.tensor([], device=device) + offset = 0 + # Compute the embedding values at every sampled point on the sphere + for i in range(self.num_resolutions): + num_coefficients = int((x.lmax_list[i] + 1) ** 2) + x_pt = torch.cat( + [ + x_pt, + torch.einsum( + "abc, pb->apc", + x.embedding[:, offset : offset + num_coefficients], + self.sphharm_weights[i], + ).contiguous(), + ], + dim=2, + ) + offset = offset + num_coefficients + + x_pt = x_pt.view(-1, self.sphere_channels_all) + + return {"sphere_values": x_pt, "sphere_points": self.sphere_points} + + +@registry.register_model("escn_energy_head") +class eSCNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + # Output blocks for energy and forces + self.energy_block = EnergyBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + node_energy = self.energy_block(emb["sphere_values"]) + energy = torch.zeros(len(data.natoms), device=data.pos.device) + energy.index_add_(0, data.batch, node_energy.view(-1)) + # Scale energy to help balance numerical precision w.r.t. forces + return {"energy": energy * 0.001} + + +@registry.register_model("escn_force_head") +class eSCNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = ForceBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + return {"forces": self.force_block(emb["sphere_values"], emb["sphere_points"])} + + class LayerBlock(torch.nn.Module): """ Layer block: Perform one layer (message passing and aggregation) of the GNN diff --git a/src/fairchem/core/models/gemnet/gemnet.py b/src/fairchem/core/models/gemnet/gemnet.py index e719c219b8..f5537b9535 100644 --- a/src/fairchem/core/models/gemnet/gemnet.py +++ b/src/fairchem/core/models/gemnet/gemnet.py @@ -7,14 +7,20 @@ from __future__ import annotations +import typing + import numpy as np import torch +import torch.nn as nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -28,17 +34,12 @@ @registry.register_model("gemnet_t") -class GemNetT(BaseModel): +class GemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -94,9 +95,6 @@ class GemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -120,6 +118,7 @@ def __init__( extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, output_init: str = "HeOrthogonal", activation: str = "swish", num_elements: int = 83, @@ -132,7 +131,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -146,6 +144,7 @@ def __init__( self.regress_forces = regress_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # GemNet variants self.direct_forces = direct_forces @@ -235,7 +234,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -421,18 +420,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. - V_st = -distance_vec / D_st[:, None] + V_st = -graph.edge_distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -447,10 +438,10 @@ def generate_interaction_graph(self, data): V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -530,7 +521,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -549,7 +540,7 @@ def forward(self, data): ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -557,11 +548,11 @@ def forward(self, data): if self.extensive: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} @@ -569,30 +560,18 @@ def forward(self, data): if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - 
else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t @@ -601,3 +580,129 @@ def forward(self, data): @property def num_params(self): return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_t_backbone") +class GemNetTBackbone(GemNetT, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + edge_index, + neighbors, + D_st, + V_st, + id_swap, + id3_ba, + id3_ca, + id3_ragged_idx, + ) = self.generate_interaction_graph(data) + idx_s, idx_t = edge_index + + # Calculate triplet angles + cosφ_cab = inner_product_normalized(V_st[id3_ca], V_st[id3_ba]) + rad_cbf3, cbf3 = self.cbf_basis3(D_st, cosφ_cab, id3_ca) + + rbf = self.radial_basis(D_st) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, rbf, idx_s, idx_t) # (nEdges, emb_size_edge) + + rbf3 = self.mlp_rbf3(rbf) + cbf3 = self.mlp_cbf3(rad_cbf3, cbf3, id3_ca, id3_ragged_idx) + + rbf_h = self.mlp_rbf_h(rbf) + rbf_out = self.mlp_rbf_out(rbf) + + E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + rbf3=rbf3, + cbf3=cbf3, + id3_ragged_idx=id3_ragged_idx, + id_swap=id_swap, + id3_ba=id3_ba, + id3_ca=id3_ca, + rbf_h=rbf_h, + idx_s=idx_s, + idx_t=idx_t, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + F_st += F + E_t += E + return { + "F_st": F_st, + "E_t": E_t, + "edge_vec": V_st, + "edge_idx": idx_t, + "node_embedding": h, + "edge_embedding": m, + } + + +@registry.register_model("gemnet_t_energy_and_grad_force_head") +class GemNetTEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.extensive = backbone.extensive + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t} + + if self.regress_forces and not self.direct_forces: + outputs["forces"] = -torch.autograd.grad( + E_t.sum(), data.pos, create_graph=True + )[0] + # (nAtoms, 3) + return outputs + + +@registry.register_model("gemnet_t_force_head") +class GemNetTForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # map forces in edge directions + F_st_vec = emb["F_st"][:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.size(0), + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (nAtoms, 3) diff --git 
a/src/fairchem/core/models/gemnet_gp/gemnet.py b/src/fairchem/core/models/gemnet_gp/gemnet.py index 81fbd40694..97af540de2 100644 --- a/src/fairchem/core/models/gemnet_gp/gemnet.py +++ b/src/fairchem/core/models/gemnet_gp/gemnet.py @@ -9,13 +9,14 @@ import numpy as np import torch +from torch import nn from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common import gp_utils from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -29,17 +30,12 @@ @registry.register_model("gp_gemnet_t") -class GraphParallelGemNetT(BaseModel): +class GraphParallelGemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. num_radial: int @@ -95,9 +91,6 @@ class GraphParallelGemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -121,6 +114,7 @@ def __init__( extensive: bool = True, otf_graph: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, output_init: str = "HeOrthogonal", activation: str = "swish", scale_num_blocks: bool = False, @@ -134,7 +128,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -150,6 +143,7 @@ def __init__( self.regress_forces = regress_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # GemNet variants self.direct_forces = direct_forces @@ -239,7 +233,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -415,18 +409,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
- V_st = -distance_vec / D_st[:, None] + V_st = -graph.distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -441,10 +427,10 @@ def generate_interaction_graph(self, data): V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -563,7 +549,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -585,7 +571,7 @@ def forward(self, data): ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -601,41 +587,29 @@ def forward(self, data): E_t = gp_utils.gather_from_model_parallel_region(E_t, dim=0) E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} if self.regress_forces: if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t_full, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index e1176d00c9..c9dd9e13ed 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -7,9 +7,11 @@ from __future__ import annotations import logging +import typing import numpy as np import torch +import torch.nn as nn from torch_scatter import segment_coo from fairchem.core.common.registry import registry @@ -18,7 +20,7 @@ get_max_neighbors_mask, scatter_det, ) -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .initializers import get_initializer @@ -40,17 +42,15 @@ repeat_blocks, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + @registry.register_model("gemnet_oc") -class GemNetOC(BaseModel): +class GemNetOC(nn.Module, GraphModelMixin): """ Arguments --------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -108,6 +108,8 @@ class GemNetOC(BaseModel): If False predict forces based on negative gradient of energy potential. use_pbc: bool Whether to use periodic boundary conditions. + use_pbc_single: + Process batch PBC graphs one at a time scale_backprop_forces: bool Whether to scale up the energy and then scales down the forces to prevent NaNs and infs in backpropagated forces. @@ -179,9 +181,6 @@ class GemNetOC(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -206,6 +205,7 @@ def __init__( regress_forces: bool = True, direct_forces: bool = False, use_pbc: bool = True, + use_pbc_single: bool = False, scale_backprop_forces: bool = False, cutoff: float = 6.0, cutoff_qint: float | None = None, @@ -249,11 +249,11 @@ def __init__( super().__init__() if len(kwargs) > 0: logging.warning(f"Unrecognized arguments: {list(kwargs.keys())}") - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive + self.activation = activation self.atom_edge_interaction = atom_edge_interaction self.edge_atom_interaction = edge_atom_interaction self.atom_interaction = atom_interaction @@ -272,6 +272,7 @@ def __init__( ) self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.direct_forces = direct_forces self.forces_coupled = forces_coupled @@ -357,7 +358,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) - self.out_energy = Dense(emb_size_atom, num_targets, bias=False, activation=None) + self.out_energy = Dense(emb_size_atom, 1, bias=False, activation=None) if direct_forces: out_mlp_F = [ Dense( @@ -373,9 +374,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) - self.out_forces = Dense( - emb_size_edge, num_targets, bias=False, activation=None - ) + self.out_forces = Dense(emb_size_edge, 1, bias=False, activation=None) out_initializer = get_initializer(output_init) self.out_energy.reset_parameters(out_initializer) @@ -870,15 +869,7 @@ def subselect_edges( def generate_graph_dict(self, data, cutoff, max_neighbors): """Generate a radius/nearest neighbor graph.""" otf_graph = cutoff > 6 or max_neighbors > 50 or self.otf_graph - - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - num_neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, cutoff=cutoff, max_neighbors=max_neighbors, @@ -886,15 +877,15 @@ def generate_graph_dict(self, data, cutoff, max_neighbors): ) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
- edge_vector = -distance_vec / edge_dist[:, None] - cell_offsets = -cell_offsets # a - c + offset + edge_vector = -graph.edge_distance_vec / graph.edge_distance[:, None] + cell_offsets = -graph.cell_offsets # a - c + offset graph = { - "edge_index": edge_index, - "distance": edge_dist, + "edge_index": graph.edge_index, + "distance": graph.edge_distance, "vector": edge_vector, "cell_offset": cell_offsets, - "num_neighbors": num_neighbors, + "num_neighbors": graph.neighbors, } # Mask interaction edges if required @@ -1285,11 +1276,11 @@ def forward(self, data): if self.extensive: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) E_t = E_t.squeeze(1) # (num_molecules) outputs = {"energy": E_t} @@ -1308,19 +1299,19 @@ def forward(self, data): dim=0, dim_size=int(nEdges / 2), reduce="mean", - ) # (nEdges/2, num_targets) - F_st = F_st[id_undir] # (nEdges, num_targets) + ) # (nEdges/2, 1) + F_st = F_st[id_undir] # (nEdges, 1) # map forces in edge directions F_st_vec = F_st[:, :, None] * main_graph["vector"][:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter_det( F_st_vec, idx_t, dim=0, dim_size=num_atoms, reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) else: F_t = self.force_scaler.calc_forces_and_update(E_t, pos) @@ -1333,3 +1324,233 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_oc_backbone") +class GemNetOCBackbone(GemNetOC, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + num_atoms = atomic_numbers.shape[0] + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + main_graph, + a2a_graph, + a2ee2a_graph, + qint_graph, + id_swap, + trip_idx_e2e, + trip_idx_a2e, + trip_idx_e2a, + quad_idx, + ) = self.get_graphs_and_indices(data) + _, idx_t = main_graph["edge_index"] + + ( + basis_rad_raw, + basis_atom_update, + basis_output, + bases_qint, + bases_e2e, + bases_a2e, + bases_e2a, + basis_a2a_rad, + ) = self.get_bases( + main_graph=main_graph, + a2a_graph=a2a_graph, + a2ee2a_graph=a2ee2a_graph, + qint_graph=qint_graph, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + num_atoms=num_atoms, + ) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, basis_rad_raw, main_graph["edge_index"]) + # (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[0](h, m, basis_output, idx_t) + # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + xs_E, xs_F = [x_E], [x_F] + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + bases_qint=bases_qint, + bases_e2e=bases_e2e, + bases_a2e=bases_a2e, + bases_e2a=bases_e2a, + basis_a2a_rad=basis_a2a_rad, + basis_atom_update=basis_atom_update, + edge_index_main=main_graph["edge_index"], + a2ee2a_graph=a2ee2a_graph, + a2a_graph=a2a_graph, + id_swap=id_swap, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t) + # (nAtoms, 
emb_size_atom), (nEdges, emb_size_edge) + xs_E.append(x_E) + xs_F.append(x_F) + + return { + "xs_E": xs_E, + "xs_F": xs_F, + "edge_vec": main_graph["vector"], + "edge_idx": idx_t, + "num_neighbors": main_graph["num_neighbors"], + } + + +@registry.register_model("gemnet_oc_energy_and_grad_force_head") +class GemNetOCEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__( + self, + backbone: BackboneInterface, + num_global_out_layers: int, + output_init: str = "HeOrthogonal", + ): + super().__init__() + self.extensive = backbone.extensive + + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + self.force_scaler = backbone.force_scaler + + out_mlp_E = [ + Dense( + backbone.atom_emb.emb_size * (len(backbone.int_blocks) + 1), + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) + + self.out_energy = Dense( + backbone.atom_emb.emb_size, + 1, + bias=False, + activation=None, + ) + + out_initializer = get_initializer(output_init) + self.out_energy.reset_parameters(out_initializer) + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # Global output block for final predictions + x_E = self.out_mlp_E(torch.cat(emb["xs_E"], dim=-1)) + with torch.cuda.amp.autocast(False): + E_t = self.out_energy(x_E.float()) + + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t.squeeze(1)} # (num_molecules) + + if self.regress_forces and not self.direct_forces: + F_t = self.force_scaler.calc_forces_and_update(outputs["energy"], data.pos) + outputs["forces"] = F_t.squeeze(1) + return outputs + + +@registry.register_model("gemnet_oc_force_head") +class GemNetOCForceHead(nn.Module, HeadInterface): + def __init__( + self, backbone, num_global_out_layers: int, output_init: str = "HeOrthogonal" + ): + super().__init__() + + self.direct_forces = backbone.direct_forces + self.forces_coupled = backbone.forces_coupled + + emb_size_edge = backbone.edge_emb.dense.linear.out_features + if self.direct_forces: + out_mlp_F = [ + Dense( + emb_size_edge * (len(backbone.int_blocks) + 1), + emb_size_edge, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + emb_size_edge, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) + self.out_forces = Dense( + emb_size_edge, + 1, + bias=False, + activation=None, + ) + out_initializer = get_initializer(output_init) + self.out_forces.reset_parameters(out_initializer) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1)) + with torch.cuda.amp.autocast(False): + F_st = self.out_forces(x_F.float()) + + if self.forces_coupled: # enforce F_st = F_ts + nEdges = emb["edge_idx"].shape[0] + id_undir = repeat_blocks( + emb["num_neighbors"] // 2, + repeats=2, + continuous_indexing=True, + ) + F_st = scatter_det( + F_st, + id_undir, + dim=0, + dim_size=int(nEdges / 2), + reduce="mean", + ) # (nEdges/2, 1) + 
F_st = F_st[id_undir] # (nEdges, 1) + + # map forces in edge directions + F_st_vec = F_st[:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter_det( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.long().shape[0], + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (num_atoms, 3) + return {} diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index 8843f02b2e..33425e8d8d 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -32,15 +32,19 @@ from __future__ import annotations import math +import typing import torch from torch import nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_geometric.nn import MessagePassing from torch_scatter import scatter, segment_coo from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.gemnet.layers.base_layers import ScaledSiLU from fairchem.core.models.gemnet.layers.embedding_block import AtomEmbedding from fairchem.core.models.gemnet.layers.radial_basis import RadialBasis @@ -51,7 +55,7 @@ @registry.register_model("painn") -class PaiNN(BaseModel): +class PaiNN(nn.Module, GraphModelMixin): r"""PaiNN model based on the description in Schütt et al. (2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra, https://arxiv.org/abs/2102.03150. @@ -59,9 +63,6 @@ class PaiNN(BaseModel): def __init__( self, - num_atoms: int, - bond_feat_dim: int, - num_targets: int, hidden_channels: int = 512, num_layers: int = 6, num_rbf: int = 128, @@ -72,6 +73,7 @@ def __init__( regress_forces: bool = True, direct_forces: bool = True, use_pbc: bool = True, + use_pbc_single: bool = False, otf_graph: bool = True, num_elements: int = 83, scale_file: str | None = None, @@ -91,6 +93,7 @@ def __init__( self.direct_forces = direct_forces self.otf_graph = otf_graph self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single # Borrowed from GemNet. self.symmetric_edge_symmetrization = False @@ -310,23 +313,16 @@ def symmetrize_edges( ) def generate_graph_values(self, data): - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # Unit vectors pointing from edge_index[1] to edge_index[0], # i.e., edge_index[0] - edge_index[1] divided by the norm. 
# make sure that the distances are not close to zero before dividing - mask_zero = torch.isclose(edge_dist, torch.tensor(0.0), atol=1e-6) - edge_dist[mask_zero] = 1.0e-6 - edge_vector = distance_vec / edge_dist[:, None] + mask_zero = torch.isclose(graph.edge_distance, torch.tensor(0.0), atol=1e-6) + graph.edge_distance[mask_zero] = 1.0e-6 + edge_vector = graph.edge_distance_vec / graph.edge_distance[:, None] - empty_image = neighbors == 0 + empty_image = graph.neighbors == 0 if torch.any(empty_image): raise ValueError( f"An image has no neighbors: id={data.id[empty_image]}, " @@ -342,11 +338,11 @@ def generate_graph_values(self, data): [edge_vector], id_swap, ) = self.symmetrize_edges( - edge_index, - cell_offsets, - neighbors, + graph.edge_index, + graph.cell_offsets, + graph.neighbors, data.batch, - [edge_dist], + [graph.edge_distance], [edge_vector], ) @@ -436,6 +432,50 @@ def __repr__(self) -> str: ) +@registry.register_model("painn_backbone") +class PaiNNBackbone(PaiNN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data) -> dict[str, torch.Tensor]: + pos = data.pos + z = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos = pos.requires_grad_(True) + + ( + edge_index, + neighbors, + edge_dist, + edge_vector, + id_swap, + ) = self.generate_graph_values(data) + + assert z.dim() == 1 + assert z.dtype == torch.long + + edge_rbf = self.radial_basis(edge_dist) # rbf * envelope + + x = self.atom_emb(z) + vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) + + #### Interaction blocks ############################################### + + for i in range(self.num_layers): + dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) + + x = x + dx + vec = vec + dvec + x = x * self.inv_sqrt_2 + + dx, dvec = self.update_layers[i](x, vec) + + x = x + dx + vec = vec + dvec + x = getattr(self, "upd_out_scalar_scale_%d" % i)(x) + + return {"node_embedding": x, "node_vec": vec} + + class PaiNNMessage(MessagePassing): def __init__( self, @@ -625,3 +665,53 @@ def forward(self, x, v): x = self.act(x) return x, v + + +@registry.register_model("painn_energy_head") +class PaiNNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.out_energy = nn.Sequential( + nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2), + ScaledSiLU(), + nn.Linear(backbone.hidden_channels // 2, 1), + ) + + nn.init.xavier_uniform_(self.out_energy[0].weight) + self.out_energy[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_energy[2].weight) + self.out_energy[2].bias.data.fill_(0) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + per_atom_energy = self.out_energy(emb["node_embedding"]).squeeze(1) + return {"energy": scatter(per_atom_energy, data.batch, dim=0)} + + +@registry.register_model("painn_force_head") +class PaiNNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + if self.direct_forces: + self.out_forces = PaiNNOutput(backbone.hidden_channels) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + forces = self.out_forces(emb["node_embedding"], emb["node_vec"]) + else: + forces = ( + -1 + * torch.autograd.grad( + emb["node_embedding"], + data.pos, + grad_outputs=torch.ones_like(emb["node_embedding"]), + create_graph=True, + )[0] + ) + return {"forces": forces} 
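The backbone/head decomposition added for the other models applies to PaiNN as well: painn_backbone produces per-node scalars and vectors, and the registered heads consume them. Below is a minimal sketch of the calling convention, composing the classes defined above by hand; in practice the trainer builds these from a heads config via the registry names, and the constructor arguments shown are assumed defaults rather than recommended settings.

```python
from fairchem.core.models.painn.painn import (
    PaiNNBackbone,
    PaiNNEnergyHead,
    PaiNNForceHead,
)

# Assumed-usable defaults; a trainer would normally construct these from the
# model config using the "painn_backbone" / "painn_*_head" registry names.
backbone = PaiNNBackbone(hidden_channels=512, num_layers=6, otf_graph=True)
heads = {
    "energy": PaiNNEnergyHead(backbone),
    "forces": PaiNNForceHead(backbone),
}


def predict(batch):
    # The backbone returns {"node_embedding": x, "node_vec": vec}; each head
    # maps (batch, emb) -> {target_name: tensor}.
    emb = backbone(batch)
    out = {}
    for head in heads.values():
        out.update(head(batch, emb))
    return out
```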
diff --git a/src/fairchem/core/models/schnet.py b/src/fairchem/core/models/schnet.py index 2f89c17e1f..878aee746a 100644 --- a/src/fairchem/core/models/schnet.py +++ b/src/fairchem/core/models/schnet.py @@ -13,11 +13,11 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin @registry.register_model("schnet") -class SchNetWrap(SchNet, BaseModel): +class SchNetWrap(SchNet, GraphModelMixin): r"""Wrapper around the continuous-filter convolutional neural network SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" `_. Each layer uses interaction @@ -28,11 +28,9 @@ class SchNetWrap(SchNet, BaseModel): h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), Args: - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets (int): Number of targets to predict. use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions. (default: :obj:`True`) + use_pbc_single (bool,optional): Process batch PBC graphs one at a time regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating energy with respect to positions. (default: :obj:`True`) @@ -54,10 +52,8 @@ class SchNetWrap(SchNet, BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, + use_pbc_single: bool = False, regress_forces: bool = True, otf_graph: bool = False, hidden_channels: int = 128, @@ -67,9 +63,10 @@ def __init__( cutoff: float = 10.0, readout: str = "add", ) -> None: - self.num_targets = num_targets + self.num_targets = 1 self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.max_neighbors = 50 @@ -88,25 +85,17 @@ def _forward(self, data): z = data.atomic_numbers.long() pos = data.pos batch = data.batch - - ( - edge_index, - edge_weight, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) if self.use_pbc: assert z.dim() == 1 assert z.dtype == torch.long - edge_attr = self.distance_expansion(edge_weight) + edge_attr = self.distance_expansion(graph.edge_distance) h = self.embedding(z) for interaction in self.interactions: - h = h + interaction(h, edge_index, edge_weight, edge_attr) + h = h + interaction(h, graph.edge_index, graph.edge_distance, edge_attr) h = self.lin1(h) h = self.act(h) @@ -125,16 +114,13 @@ def forward(self, data): outputs = {"energy": energy} if self.regress_forces: - forces = ( - -1 - * ( - torch.autograd.grad( - energy, - data.pos, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - ) + forces = -1 * ( + torch.autograd.grad( + energy, + data.pos, + grad_outputs=torch.ones_like(energy), + create_graph=True, + )[0] ) outputs["forces"] = forces diff --git a/src/fairchem/core/models/scn/scn.py b/src/fairchem/core/models/scn/scn.py index bf8454f212..299fa48584 100644 --- a/src/fairchem/core/models/scn/scn.py +++ b/src/fairchem/core/models/scn/scn.py @@ -18,7 +18,7 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.models.scn.sampling import 
CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -33,12 +33,13 @@ @registry.register_model("scn") -class SphericalChannelNetwork(BaseModel): +class SphericalChannelNetwork(nn.Module, GraphModelMixin): """Spherical Channel Network Paper: Spherical Channels for Modeling Atomic Interactions Args: use_pbc (bool): Use periodic boundary conditions + use_pbc_single (bool): Process batch PBC graphs one at a time regress_forces (bool): Compute forces otf_graph (bool): Compute graph On The Fly (OTF) max_num_neighbors (int): Maximum number of neighbors per atom @@ -75,10 +76,8 @@ class SphericalChannelNetwork(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, + use_pbc_single: bool = True, regress_forces: bool = True, otf_graph: bool = False, max_num_neighbors: int = 20, @@ -110,6 +109,7 @@ def __init__( self.regress_forces = regress_forces self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph self.show_timing_info = show_timing_info @@ -262,15 +262,7 @@ def _forward_helper(self, data): atomic_numbers = data.atomic_numbers.long() num_atoms = len(atomic_numbers) pos = data.pos - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures @@ -278,12 +270,12 @@ def _forward_helper(self, data): # Calculate which message block each edge should use. Based on edge distance rank. edge_rank = self._rank_edge_distances( - edge_distance, edge_index, self.max_num_neighbors + graph.edge_distance, graph.edge_index, self.max_num_neighbors ) # Reorder edges so that they are grouped by distance rank (lowest to highest) last_cutoff = -0.1 - message_block_idx = torch.zeros(len(edge_distance), device=pos.device) + message_block_idx = torch.zeros(len(graph.edge_distance), device=pos.device) edge_distance_reorder = torch.tensor([], device=self.device) edge_index_reorder = torch.tensor([], device=self.device) edge_distance_vec_reorder = torch.tensor([], device=self.device) @@ -297,21 +289,21 @@ def _forward_helper(self, data): edge_distance_reorder = torch.cat( [ edge_distance_reorder, - torch.masked_select(edge_distance, mask), + torch.masked_select(graph.edge_distance, mask), ], dim=0, ) edge_index_reorder = torch.cat( [ edge_index_reorder, - torch.masked_select(edge_index, mask.view(1, -1).repeat(2, 1)).view( - 2, -1 - ), + torch.masked_select( + graph.edge_index, mask.view(1, -1).repeat(2, 1) + ).view(2, -1), ], dim=1, ) edge_distance_vec_mask = torch.masked_select( - edge_distance_vec, mask.view(-1, 1).repeat(1, 3) + graph.edge_distance_vec, mask.view(-1, 1).repeat(1, 3) ).view(-1, 3) edge_distance_vec_reorder = torch.cat( [edge_distance_vec_reorder, edge_distance_vec_mask], dim=0 diff --git a/src/fairchem/core/modules/normalization/__init__.py b/src/fairchem/core/modules/normalization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/fairchem/core/modules/normalization/_load_utils.py b/src/fairchem/core/modules/normalization/_load_utils.py new file mode 100644 index 0000000000..0825886db9 --- /dev/null +++ b/src/fairchem/core/modules/normalization/_load_utils.py @@ -0,0 +1,113 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. 
+ +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Callable + +import torch + +from fairchem.core.common.utils import save_checkpoint + +if TYPE_CHECKING: + from pathlib import Path + + from torch.nn import Module + from torch.utils.data import Dataset + + +def _load_check_duplicates(config: dict, name: str) -> dict[str, torch.nn.Module]: + """Attempt to load a single file with normalizers/element references and check config for duplicate targets. + + Args: + config: configuration dictionary + name: Name of module to use for logging + + Returns: + dictionary of normalizer or element reference modules + """ + modules = {} + if "file" in config: + modules = torch.load(config["file"]) + logging.info(f"Loaded {name} for the following targets: {list(modules.keys())}") + # make sure that element-refs are not specified both as fit and file + fit_targets = config["fit"]["targets"] if "fit" in config else [] + duplicates = list( + filter( + lambda x: x in fit_targets, + list(config) + list(modules.keys()), + ) + ) + if len(duplicates) > 0: + logging.warning( + f"{name} values for the following targets {duplicates} have been specified to be fit and also read" + f" from a file. The files read from file will be used instead of fitting." + ) + duplicates = list(filter(lambda x: x in modules, config)) + if len(duplicates) > 0: + logging.warning( + f"Duplicate {name} values for the following targets {duplicates} where specified in the file " + f"{config['file']} and an explicitly set file. The normalization values read from " + f"{config['file']} will be used." + ) + return modules + + +def _load_from_config( + config: dict, + name: str, + fit_fun: Callable[[list[str], Dataset, Any, ...], dict[str, Module]], + create_fun: Callable[[str | Path], Module], + dataset: Dataset, + checkpoint_dir: str | Path | None = None, + **fit_kwargs, +) -> dict[str, torch.nn.Module]: + """Load or fit normalizers or element references from config + + If a fit is done, a fitted key with value true is added to the config to avoid re-fitting + once a checkpoint has been saved. + + Args: + config: configuration dictionary + name: Name of module to use for logging + fit_fun: Function to fit modules + create_fun: Function to create a module from file + checkpoint_dir: directory to save modules. If not given, modules won't be saved. 
+ + Returns: + dictionary of normalizer or element reference modules + + """ + modules = _load_check_duplicates(config, name) + for target in config: + if target == "fit" and not config["fit"].get("fitted", False): + # remove values for output targets that have already been read from files + targets = [ + target for target in config["fit"]["targets"] if target not in modules + ] + fit_kwargs.update( + {k: v for k, v in config["fit"].items() if k != "targets"} + ) + modules.update(fit_fun(targets=targets, dataset=dataset, **fit_kwargs)) + config["fit"]["fitted"] = True + # if a single file for all outputs is not provided, + # then check if a single file is provided for a specific output + elif target != "file": + modules[target] = create_fun(**config[target]) + # save the linear references for possible subsequent use + if checkpoint_dir is not None: + path = save_checkpoint( + modules, + checkpoint_dir, + f"{name}.pt", + ) + logging.info( + f"{name} checkpoint for targets {list(modules.keys())} have been saved to: {path}" + ) + + return modules diff --git a/src/fairchem/core/modules/normalization/element_references.py b/src/fairchem/core/modules/normalization/element_references.py new file mode 100644 index 0000000000..e41dbe588c --- /dev/null +++ b/src/fairchem/core/modules/normalization/element_references.py @@ -0,0 +1,290 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from fairchem.core.datasets import data_list_collater + +from ._load_utils import _load_from_config + +if TYPE_CHECKING: + from torch_geometric.data import Batch + + +class LinearReferences(nn.Module): + """Represents an elemental linear references model for a target property. + + In an elemental reference associates a value with each chemical element present in the dataset. + Elemental references define a chemical composition model, i.e. a rough approximation of a target + property (energy) using elemental references is done by summing the elemental references multiplied + by the number of times the corresponding element is present. + + Elemental references energies can be taken as: + - the energy of a chemical species in its elemental state + (i.e. lowest energy polymorph of single element crystal structures for solids) + - fitting a linear model to a dataset, where the features are the counts of each element in each data point. + see the function fit_linear references below for details + + Training GNNs to predict the difference between DFT and the predictions of a chemical composition + model represent a useful normalization scheme that can improve model accuracy. See for example the + "Alternative reference scheme" section of the OC22 manuscript: https://arxiv.org/pdf/2206.08917 + """ + + def __init__( + self, + element_references: torch.Tensor | None = None, + max_num_elements: int = 118, + ): + """ + Args: + element_references (Tensor): tensor with linear reference values + max_num_elements (int): max number of elements - 118 is a stretch + metrics (dict): dictionary with accuracy metrics in predicting values for structures used in fitting. 
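
As a rough sketch of what the `LinearReferences` module computes: the composition offset of each structure is the sum of per-atom reference values, accumulated with `index_add` over the batch assignment. The values below are invented for illustration and this is not the library implementation:

```python
import torch

# invented per-element reference energies, indexed by atomic number
element_refs = torch.zeros(119)
element_refs[1], element_refs[6], element_refs[8] = -3.39, -7.28, -4.93

atomic_numbers = torch.tensor([8, 1, 1, 6, 8, 8])   # H2O followed by CO2
batch_index = torch.tensor([0, 0, 0, 1, 1, 1])      # structure each atom belongs to

per_atom = element_refs[atomic_numbers]
per_structure = torch.zeros(2).index_add(0, batch_index, per_atom)  # roughly [-11.71, -17.14]
# dereference(): subtract this offset from raw targets before training;
# forward(): add it back onto model predictions afterwards.
```
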
+ """ + super().__init__() + self.register_buffer( + name="element_references", + tensor=element_references + if element_references is not None + else torch.zeros(max_num_elements + 1), + ) + + def _apply_refs( + self, target: torch.Tensor, batch: Batch, sign: int, reshaped: bool = True + ) -> torch.Tensor: + """Apply references batch-wise""" + indices = batch.atomic_numbers.to( + dtype=torch.int, device=self.element_references.device + ) + elemrefs = self.element_references[indices].to(dtype=target.dtype) + # this option should not exist, all tensors should have compatible shapes in dataset and trainer outputs + if reshaped: + elemrefs = elemrefs.view(batch.natoms.sum(), -1) + + return target.index_add(0, batch.batch, elemrefs, alpha=sign) + + @torch.autocast(device_type="cuda", enabled=False) + def dereference( + self, target: torch.Tensor, batch: Batch, reshaped: bool = True + ) -> torch.Tensor: + """Remove linear references""" + return self._apply_refs(target, batch, -1, reshaped=reshaped) + + @torch.autocast(device_type="cuda", enabled=False) + def forward( + self, target: torch.Tensor, batch: Batch, reshaped: bool = True + ) -> torch.Tensor: + """Add linear references""" + return self._apply_refs(target, batch, 1, reshaped=reshaped) + + +def create_element_references( + file: str | Path | None = None, + state_dict: dict | None = None, +) -> LinearReferences: + """Create an element reference module. + + Args: + type (str): type of reference (only linear implemented) + file (str or Path): path to pt or npz file + state_dict (dict): a state dict of a element reference module + + Returns: + LinearReference + """ + if file is not None and state_dict is not None: + logging.warning( + "Both a file and a state_dict for element references was given." + "The references will be read from the file and the provided state_dict will be ignored." + ) + + # path takes priority if given + if file is not None: + extension = Path(file).suffix + if extension == ".pt": + # try to load a pt file + state_dict = torch.load(file) + elif extension == ".npz": + state_dict = {} + with np.load(file) as values: + # legacy linref files + if "coeff" in values: + state_dict["element_references"] = torch.tensor(values["coeff"]) + else: + state_dict["element_references"] = torch.tensor( + values["element_references"] + ) + else: + raise RuntimeError( + f"Element references file with extension '{extension}' is not supported." + ) + + if "element_references" not in state_dict: + raise RuntimeError("Unable to load linear element references!") + + return LinearReferences(element_references=state_dict["element_references"]) + + +@torch.no_grad() +def fit_linear_references( + targets: list[str], + dataset: Dataset, + batch_size: int, + num_batches: int | None = None, + num_workers: int = 0, + max_num_elements: int = 118, + log_metrics: bool = True, + use_numpy: bool = True, + driver: str | None = None, + shuffle: bool = True, + seed: int = 0, +) -> dict[str, LinearReferences]: + """Fit a set linear references for a list of targets using a given number of batches. + + Args: + targets: list of target names + dataset: data set to fit linear references with + batch_size: size of batch + num_batches: number of batches to use in fit. If not given will use all batches + num_workers: number of workers to use in data loader. + Note setting num_workers > 1 leads to finicky multiprocessing issues when using this function + in distributed mode. 
The issue has to do with pickling the functions in load_references_from_config + see function below... + max_num_elements: max number of elements in dataset. If not given will use an ambitious value of 118 + log_metrics: if true will compute MAE, RMSE and R2 score of fit and log. + use_numpy: use numpy.linalg.lstsq instead of torch. This tends to give better solutions. + driver: backend used to solve linear system. See torch.linalg.lstsq docs. Ignored if use_numpy=True + shuffle: whether to shuffle when loading the dataset + seed: random seed used to shuffle the sampler if shuffle=True + + Returns: + dict of fitted LinearReferences objects + """ + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=partial(data_list_collater, otf_graph=True), + num_workers=num_workers, + persistent_workers=num_workers > 0, + generator=torch.Generator().manual_seed(seed), + ) + + num_batches = num_batches if num_batches is not None else len(data_loader) + if num_batches > len(data_loader): + logging.warning( + f"The given num_batches {num_batches} is larger than total batches of size {batch_size} in dataset. " + f"num_batches will be ignored and the whole dataset will be used." + ) + num_batches = len(data_loader) + + max_num_elements += 1 # + 1 since H starts at index 1 + # solving linear system happens on CPU, which allows handling poorly conditioned and + # rank deficient matrices, unlike torch lstsq on GPU + composition_matrix = torch.zeros( + num_batches * batch_size, + max_num_elements, + ) + + target_vectors = { + target: torch.zeros(num_batches * batch_size) for target in targets + } + + logging.info( + f"Fitting linear references using {num_batches * batch_size} samples in {num_batches} " + f"batches of size {batch_size}." 
+ ) + for i, batch in tqdm( + enumerate(data_loader), total=num_batches, desc="Fitting linear references" + ): + if i == 0: + assert all( + len(batch[target].squeeze().shape) == 1 for target in targets + ), "element references can only be used for scalar targets" + elif i == num_batches: + break + + next_batch_size = len(batch) if i == len(data_loader) - 1 else batch_size + for target in targets: + target_vectors[target][ + i * batch_size : i * batch_size + next_batch_size + ] = batch[target].to(torch.float64) + for j, data in enumerate(batch.to_data_list()): + composition_matrix[i * batch_size + j] = torch.bincount( + data.atomic_numbers.int(), + minlength=max_num_elements, + ).to(torch.float64) + + # reduce the composition matrix to only features that are non-zero to improve rank + mask = composition_matrix.sum(axis=0) != 0.0 + reduced_composition_matrix = composition_matrix[:, mask] + elementrefs = {} + + for target in targets: + coeffs = torch.zeros(max_num_elements) + + if use_numpy: + solution = torch.tensor( + np.linalg.lstsq( + reduced_composition_matrix.numpy(), + target_vectors[target].numpy(), + rcond=None, + )[0] + ) + else: + lstsq = torch.linalg.lstsq( + reduced_composition_matrix, target_vectors[target], driver=driver + ) + solution = lstsq.solution + + coeffs[mask] = solution + elementrefs[target] = LinearReferences(coeffs) + + if log_metrics is True: + y = target_vectors[target] + y_pred = torch.matmul(reduced_composition_matrix, solution) + y_mean = target_vectors[target].mean() + N = len(target_vectors[target]) + ss_res = ((y - y_pred) ** 2).sum() + ss_tot = ((y - y_mean) ** 2).sum() + mae = (abs(y - y_pred)).sum() / N + rmse = (((y - y_pred) ** 2).sum() / N).sqrt() + r2 = 1 - (ss_res / ss_tot) + logging.info( + f"Training accuracy metrics for fitted linear element references: mae={mae}, rmse={rmse}, r2 score={r2}" + ) + + return elementrefs + + +def load_references_from_config( + config: dict[str, Any], + dataset: Dataset, + seed: int = 0, + checkpoint_dir: str | Path | None = None, +) -> dict[str, LinearReferences]: + """Create a dictionary with element references from a config.""" + return _load_from_config( + config, + "element_references", + fit_linear_references, + create_element_references, + dataset, + checkpoint_dir, + seed=seed, + ) diff --git a/src/fairchem/core/modules/normalization/normalizer.py b/src/fairchem/core/modules/normalization/normalizer.py new file mode 100644 index 0000000000..f16db7d398 --- /dev/null +++ b/src/fairchem/core/modules/normalization/normalizer.py @@ -0,0 +1,290 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
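
The fitting routine above boils down to an ordinary least-squares problem: a composition matrix (atom counts per structure) multiplied by a vector of per-element reference energies should approximate the targets. Below is a self-contained sketch with invented numbers, not the library code; the patch does the same thing with up to 118 element columns and wraps the solution in a `LinearReferences` module:

```python
import numpy as np

# toy composition matrix: each row counts the atoms of each element in one structure
# columns stand for a hypothetical 3-element vocabulary [H, O, Pt]
composition = np.array(
    [
        [2, 1, 0],   # H2O
        [0, 0, 4],   # Pt4 cluster
        [2, 2, 0],   # H2O2
        [4, 2, 1],   # mixed slab + adsorbate
    ],
    dtype=float,
)
total_energy = np.array([-14.2, -24.4, -18.9, -36.0])   # invented total energies

# per-element reference energies: least-squares solution of composition @ refs ~= energy
refs, *_ = np.linalg.lstsq(composition, total_energy, rcond=None)

# a model trained on "referenced" energies only needs to learn the residual
residual = total_energy - composition @ refs
print(refs, residual)
```
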
+""" + +from __future__ import annotations + +import logging +import warnings +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from fairchem.core.datasets import data_list_collater + +from ._load_utils import _load_from_config + +if TYPE_CHECKING: + from collections.abc import Mapping + + from fairchem.core.modules.normalization.element_references import LinearReferences + + +class Normalizer(nn.Module): + """Normalize/denormalize a tensor and optionally add a atom reference offset.""" + + def __init__( + self, + mean: float | torch.Tensor = 0.0, + rmsd: float | torch.Tensor = 1.0, + ): + """tensor is taken as a sample to calculate the mean and rmsd""" + super().__init__() + + if isinstance(mean, float): + mean = torch.tensor(mean) + if isinstance(rmsd, float): + rmsd = torch.tensor(rmsd) + + self.register_buffer(name="mean", tensor=mean) + self.register_buffer(name="rmsd", tensor=rmsd) + + @torch.autocast(device_type="cuda", enabled=False) + def norm(self, tensor: torch.Tensor) -> torch.Tensor: + return (tensor - self.mean) / self.rmsd + + @torch.autocast(device_type="cuda", enabled=False) + def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor: + return normed_tensor * self.rmsd + self.mean + + def forward(self, normed_tensor: torch.Tensor) -> torch.Tensor: + return self.denorm(normed_tensor) + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): + # check if state dict is legacy state dicts + if "std" in state_dict: + state_dict = { + "mean": torch.tensor(state_dict["mean"]), + "rmsd": state_dict["std"], + } + + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + +def create_normalizer( + file: str | Path | None = None, + state_dict: dict | None = None, + tensor: torch.Tensor | None = None, + mean: float | torch.Tensor | None = None, + rmsd: float | torch.Tensor | None = None, + stdev: float | torch.Tensor | None = None, +) -> Normalizer: + """Build a target data normalizers with optional atom ref + + Only one of file, state_dict, tensor, or (mean and rmsd) will be used to create a normalizer. + If more than one set of inputs are given priority will be given following the order in which they are listed above. + + Args: + file (str or Path): path to pt or npz file. + state_dict (dict): a state dict for Normalizer module + tensor (Tensor): a tensor with target values used to compute mean and std + mean (float | Tensor): mean of target data + rmsd (float | Tensor): rmsd of target data, rmsd from mean = stdev, rmsd from 0 = rms + stdev: standard deviation (deprecated, use rmsd instead) + + Returns: + Normalizer + """ + if stdev is not None: + warnings.warn( + "Use of 'stdev' is deprecated, use 'rmsd' instead", DeprecationWarning + ) + if rmsd is not None: + logging.warning( + "Both 'stdev' and 'rmsd' values where given to create a normalizer, rmsd values will be used." + ) + + # old configs called it stdev, using this in the function signature reduces overhead code elsewhere + if stdev is not None and rmsd is None: + rmsd = stdev + + # path takes priority if given + if file is not None: + if state_dict is not None or tensor is not None or mean is not None: + logging.warning( + "A file to a normalizer has been given. 
Normalization values will be read from it, and all other inputs" + " will be ignored." + ) + extension = Path(file).suffix + if extension == ".pt": + # try to load a pt file + state_dict = torch.load(file) + elif extension == ".npz": + # try to load an NPZ file + values = np.load(file) + mean = values.get("mean") + rmsd = values.get("rmsd") or values.get("std") # legacy files + tensor = None # set to None since values read from file are prioritized + else: + raise RuntimeError( + f"Normalizer file with extension '{extension}' is not supported." + ) + + # state dict is second priority + if state_dict is not None: + if tensor is not None or mean is not None: + logging.warning( + "The state_dict provided will be used to set normalization values. All other inputs will be ignored." + ) + normalizer = Normalizer() + normalizer.load_state_dict(state_dict) + return normalizer + + # if not then read target value tensor + if tensor is not None: + if mean is not None: + logging.warning( + "Normalization values will be computed from input tensor, all other inputs will be ignored." + ) + mean = torch.mean(tensor) + rmsd = torch.std(tensor) + elif mean is not None and rmsd is not None: + if not isinstance(mean, torch.Tensor): + mean = torch.tensor(mean) + if not isinstance(rmsd, torch.Tensor): + rmsd = torch.tensor(rmsd) + + # if mean and rmsd are still None than raise an error + if mean is None or rmsd is None: + raise ValueError( + "Incorrect inputs. One of the following sets of inputs must be given: ", + "a file path to a .pt or .npz file, or mean and rmsd values, or a tensor of target values", + ) + + return Normalizer(mean=mean, rmsd=rmsd) + + +@torch.no_grad() +def fit_normalizers( + targets: list[str], + dataset: Dataset, + batch_size: int, + override_values: dict[str, dict[str, float]] | None = None, + rmsd_correction: int | None = None, + element_references: dict | None = None, + num_batches: int | None = None, + num_workers: int = 0, + shuffle: bool = True, + seed: int = 0, +) -> dict[str, Normalizer]: + """Estimate mean and rmsd from data to create normalizers + + Args: + targets: list of target names + dataset: data set to fit linear references with + batch_size: size of batch + override_values: dictionary with target names and values to override. i.e. {"forces": {"mean": 0.0}} will set + the forces mean to zero. + rmsd_correction: correction to use when computing mean in std/rmsd. See docs for torch.std. + If not given, will always use 0 when mean == 0, and 1 otherwise. + element_references: + num_batches: number of batches to use in fit. If not given will use all batches + num_workers: number of workers to use in data loader + Note setting num_workers > 1 leads to finicky multiprocessing issues when using this function + in distributed mode. The issue has to do with pickling the functions in load_normalizers_from_config + see function below... 
+ shuffle: whether to shuffle when loading the dataset + seed: random seed used to shuffle the sampler if shuffle=True + + Returns: + dict of normalizer objects + """ + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=partial(data_list_collater, otf_graph=True), + num_workers=num_workers, + persistent_workers=num_workers > 0, + generator=torch.Generator().manual_seed(seed), + ) + + num_batches = num_batches if num_batches is not None else len(data_loader) + if num_batches > len(data_loader): + logging.warning( + f"The given num_batches {num_batches} is larger than total batches of size {batch_size} in dataset. " + f"num_batches will be ignored and the whole dataset will be used." + ) + num_batches = len(data_loader) + + element_references = element_references or {} + target_vectors = defaultdict(list) + + logging.info( + f"Estimating mean and rmsd for normalization using {num_batches * batch_size} samples in {num_batches} batches " + f"of size {batch_size}." + ) + for i, batch in tqdm( + enumerate(data_loader), total=num_batches, desc="Estimating mean and rmsd" + ): + if i == num_batches: + break + + for target in targets: + target_vector = batch[target] + if target in element_references: + target_vector = element_references[target].dereference( + target_vector, batch, reshaped=False + ) + target_vectors[target].append(target_vector) + + normalizers = {} + for target in targets: + target_vector = torch.cat(target_vectors[target], dim=0) + values = {"mean": target_vector.mean()} + if target in override_values: + for name, val in override_values[target].items(): + values[name] = torch.tensor(val) + # calculate root mean square deviation + if "rmsd" not in values: + if rmsd_correction is None: + rmsd_correction = 0 if values["mean"] == 0.0 else 1 + values["rmsd"] = ( + ((target_vector - values["mean"]) ** 2).sum() + / max(len(target_vector) - rmsd_correction, 1) + ).sqrt() + normalizers[target] = create_normalizer(**values) + + return normalizers + + +def load_normalizers_from_config( + config: dict[str, Any], + dataset: Dataset, + seed: int = 0, + checkpoint_dir: str | Path | None = None, + element_references: dict[str, LinearReferences] | None = None, +) -> dict[str, Normalizer]: + """Create a dictionary with element references from a config.""" + # edit the config slightly to extract override args + if "fit" in config: + override_values = { + target: vals + for target, vals in config["fit"]["targets"].items() + if isinstance(vals, dict) + } + config["fit"]["override_values"] = override_values + config["fit"]["targets"] = list(config["fit"]["targets"].keys()) + + return _load_from_config( + config, + "normalizers", + fit_normalizers, + create_normalizer, + dataset, + checkpoint_dir, + seed=seed, + element_references=element_references, + ) diff --git a/src/fairchem/core/modules/normalizer.py b/src/fairchem/core/modules/normalizer.py deleted file mode 100644 index 75f34e83f4..0000000000 --- a/src/fairchem/core/modules/normalizer.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. 
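
For the normalizers themselves, the round trip is `(target - mean) / rmsd` on the way into the loss and the inverse on the way out. A toy illustration follows (invented values; the small sample-size correction `fit_normalizers` applies is ignored here), including the forces-style case where the mean is pinned to zero so the scale becomes a plain RMS:

```python
import torch

targets = torch.tensor([-1.5, 0.5, 2.0, 3.0])     # toy target values

mean = targets.mean()
rmsd = ((targets - mean) ** 2).mean().sqrt()      # equals the (population) stdev here

normed = (targets - mean) / rmsd                  # what Normalizer.norm() does for the loss
restored = normed * rmsd + mean                   # what denorm()/forward() do for predictions
assert torch.allclose(restored, targets)

# forces-style override {"mean": 0.0}: the scale is the RMS of the raw values
rms = (targets ** 2).mean().sqrt()
normed_forces_style = targets / rms
```
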
-""" - -from __future__ import annotations - -import torch - - -class Normalizer: - """Normalize a Tensor and restore it later.""" - - def __init__( - self, - tensor: torch.Tensor | None = None, - mean=None, - std=None, - device=None, - ) -> None: - """tensor is taken as a sample to calculate the mean and std""" - if tensor is None and mean is None: - return - - if device is None: - device = "cpu" - - self.mean: torch.Tensor - self.std: torch.Tensor - if tensor is not None: - self.mean = torch.mean(tensor, dim=0).to(device) - self.std = torch.std(tensor, dim=0).to(device) - return - - if mean is not None and std is not None: - self.mean = torch.tensor(mean).to(device) - self.std = torch.tensor(std).to(device) - - def to(self, device) -> None: - self.mean = self.mean.to(device) - self.std = self.std.to(device) - - def norm(self, tensor: torch.Tensor) -> torch.Tensor: - return (tensor - self.mean) / self.std - - def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor: - return normed_tensor * self.std + self.mean - - def state_dict(self): - return {"mean": self.mean, "std": self.std} - - def load_state_dict(self, state_dict) -> None: - self.mean = state_dict["mean"].to(self.mean.device) - self.std = state_dict["std"].to(self.mean.device) diff --git a/src/fairchem/core/modules/transforms.py b/src/fairchem/core/modules/transforms.py index 3a86be468c..52675fd28f 100644 --- a/src/fairchem/core/modules/transforms.py +++ b/src/fairchem/core/modules/transforms.py @@ -19,10 +19,12 @@ def __call__(self, data_object): return data_object for transform_fn in self.config: - # TODO: Normalization information used in the trainers. Ignore here - # for now. - if transform_fn == "normalizer": + # TODO: Normalization information used in the trainers. Ignore here for now + # TODO: if we dont use them here, these should not be defined as "transforms" in the config + # TODO: add them as another entry under dataset, maybe "standardize"? + if transform_fn in ("normalizer", "element_references"): continue + data_object = eval(transform_fn)(data_object, self.config[transform_fn]) return data_object diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index 8d1618addf..2283c40b8a 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -13,6 +13,7 @@ import ase.io.trajectory import numpy as np import torch +from ase.geometry import wrap_positions from torch_geometric.data import Data from fairchem.core.common.utils import collate @@ -163,10 +164,16 @@ def convert(self, atoms: ase.Atoms, sid=None): """ # set the atomic numbers, positions, and cell + positions = np.array(atoms.get_positions(), copy=True) + pbc = np.array(atoms.pbc, copy=True) + cell = np.array(atoms.get_cell(complete=True), copy=True) + positions = wrap_positions(positions, cell, pbc=pbc, eps=0) + atomic_numbers = torch.Tensor(atoms.get_atomic_numbers()) - positions = torch.Tensor(atoms.get_positions()) - cell = torch.Tensor(np.array(atoms.get_cell())).view(1, 3, 3) + positions = torch.from_numpy(positions).float() + cell = torch.from_numpy(cell).view(1, 3, 3).float() natoms = positions.shape[0] + # initialized to torch.zeros(natoms) if tags missing. 
# https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags tags = torch.Tensor(atoms.get_tags()) @@ -187,13 +194,16 @@ def convert(self, atoms: ase.Atoms, sid=None): # optionally include other properties if self.r_edges: # run internal functions to get padded indices and distances - split_idx_dist = self._get_neighbors_pymatgen(atoms) + atoms_copy = atoms.copy() + atoms_copy.set_positions(positions) + split_idx_dist = self._get_neighbors_pymatgen(atoms_copy) edge_index, edge_distances, cell_offsets = self._reshape_features( *split_idx_dist ) data.edge_index = edge_index data.cell_offsets = cell_offsets + del atoms_copy if self.r_energy: energy = atoms.get_potential_energy(apply_constraint=False) data.energy = energy diff --git a/src/fairchem/core/scripts/fit_normalizers.py b/src/fairchem/core/scripts/fit_normalizers.py new file mode 100644 index 0000000000..0cfa2f2db5 --- /dev/null +++ b/src/fairchem/core/scripts/fit_normalizers.py @@ -0,0 +1,119 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import load_config, save_checkpoint +from fairchem.core.modules.normalization.element_references import ( + create_element_references, +) +from fairchem.core.modules.normalization.normalizer import fit_normalizers + + +def fit_norms( + config: dict, + output_path: str | Path, + linref_file: str | Path | None = None, + linref_target: str = "energy", +) -> None: + """Fit dataset mean and std using the standard config + + Args: + config: config + output_path: output path + linref_file: path to fitted linear references. IF these are used in training they must be used to compute mean/std + linref_target: target using linear references, basically always energy. + """ + output_path = Path(output_path).resolve() + elementrefs = ( + {linref_target: create_element_references(linref_file)} + if linref_file is not None + else {} + ) + + try: + # load the training dataset + train_dataset = registry.get_dataset_class( + config["dataset"]["train"].get("format", "lmdb") + )(config["dataset"]["train"]) + except KeyError as err: + raise ValueError("Train dataset is not specified in config!") from err + + try: + norm_config = config["dataset"]["train"]["transforms"]["normalizer"]["fit"] + except KeyError as err: + raise ValueError( + "The provided config does not specify a 'fit' block for 'normalizer'!" 
+ ) from err + + targets = list(norm_config["targets"].keys()) + override_values = { + target: vals + for target, vals in norm_config["targets"].items() + if isinstance(vals, dict) + } + + normalizers = fit_normalizers( + targets=targets, + override_values=override_values, + element_references=elementrefs, + dataset=train_dataset, + batch_size=norm_config.get("batch_size", 32), + num_batches=norm_config.get("num_batches"), + num_workers=config.get("optim", {}).get("num_workers", 16), + ) + path = save_checkpoint( + normalizers, + output_path, + "normalizers.pt", + ) + logging.info(f"normalizers have been saved to {path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to configuration yaml file", + ) + parser.add_argument( + "--out-path", + default=".", + type=str, + help="Output path to save normalizers", + ) + parser.add_argument( + "--linref-path", + type=str, + help="Path to linear references used.", + ) + parser.add_argument( + "--linref-target", + default="energy", + type=str, + help="target for which linear references are used.", + ) + args = parser.parse_args() + config, dup_warning, dup_error = load_config(args.config) + + if len(dup_warning) > 0: + logging.warning( + f"The following keys in the given config have duplicates: {dup_warning}." + ) + if len(dup_error) > 0: + raise RuntimeError( + f"The following include entries in the config have duplicates: {dup_error}" + ) + + fit_norms(config, args.out_path, args.linref_path) diff --git a/src/fairchem/core/scripts/fit_references.py b/src/fairchem/core/scripts/fit_references.py new file mode 100644 index 0000000000..f7f0c84dd7 --- /dev/null +++ b/src/fairchem/core/scripts/fit_references.py @@ -0,0 +1,91 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import load_config, save_checkpoint +from fairchem.core.modules.normalization.element_references import fit_linear_references + + +def fit_linref(config: dict, output_path: str | Path) -> None: + """Fit linear references using the standard config + + Args: + config: config + output_path: output path + """ + # load the training dataset + output_path = Path(output_path).resolve() + + try: + # load the training dataset + train_dataset = registry.get_dataset_class( + config["dataset"]["train"].get("format", "lmdb") + )(config["dataset"]["train"]) + except KeyError as err: + raise ValueError("Train dataset is not specified in config!") from err + + try: + elementref_config = config["dataset"]["train"]["transforms"][ + "element_references" + ]["fit"] + except KeyError as err: + raise ValueError( + "The provided config does not specify a 'fit' block for 'element_refereces'!" 
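
The fitting script above writes a plain dict mapping each target name to a fitted `Normalizer` module. A quick way to inspect the resulting file, assuming it was produced by this script and `fairchem` is importable:

```python
import torch

# newer torch versions may additionally need weights_only=False to unpickle modules
normalizers = torch.load("normalizers.pt")
for target, module in normalizers.items():
    print(target, float(module.mean), float(module.rmsd))
```
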
+ ) from err + + element_refs = fit_linear_references( + targets=elementref_config["targets"], + dataset=train_dataset, + batch_size=elementref_config.get("batch_size", 32), + num_batches=elementref_config.get("num_batches"), + num_workers=config.get("optim", {}).get("num_workers", 16), + max_num_elements=elementref_config.get("max_num_elements", 118), + driver=elementref_config.get("driver", None), + ) + + for target, references in element_refs.items(): + path = save_checkpoint( + references.state_dict(), + output_path, + f"{target}_linref.pt", + ) + logging.info(f"{target} linear references have been saved to: {path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to configuration yaml file", + ) + parser.add_argument( + "--out-path", + default=".", + type=str, + help="Output path to save linear references", + ) + args = parser.parse_args() + config, dup_warning, dup_error = load_config(args.config) + + if len(dup_warning) > 0: + logging.warning( + f"The following keys in the given config have duplicates: {dup_warning}." + ) + if len(dup_error) > 0: + raise RuntimeError( + f"The following include entries in the config have duplicates: {dup_error}" + ) + + fit_linref(config, args.out_path) diff --git a/src/fairchem/core/scripts/make_lmdb_sizes.py b/src/fairchem/core/scripts/make_lmdb_sizes.py index 682fb58e65..ebf2122aeb 100644 --- a/src/fairchem/core/scripts/make_lmdb_sizes.py +++ b/src/fairchem/core/scripts/make_lmdb_sizes.py @@ -15,7 +15,7 @@ from tqdm import tqdm from fairchem.core.common.typing import assert_is_instance -from fairchem.core.datasets import SinglePointLmdbDataset, TrajectoryLmdbDataset +from fairchem.core.datasets.lmdb_dataset import LmdbDataset def get_data(index): @@ -28,14 +28,13 @@ def get_data(index): return index, natoms, neighbors -def main(args) -> None: +def make_lmdb_sizes(args) -> None: path = assert_is_instance(args.data_path, str) global dataset + dataset = LmdbDataset({"src": path}) if os.path.isdir(path): - dataset = TrajectoryLmdbDataset({"src": path}) outpath = os.path.join(path, "metadata.npz") elif os.path.isfile(path): - dataset = SinglePointLmdbDataset({"src": path}) outpath = os.path.join(os.path.dirname(path), "metadata.npz") output_indices = range(len(dataset)) @@ -63,7 +62,7 @@ def main(args) -> None: np.savez(outpath, natoms=sorted_natoms, neighbors=sorted_neighbors) -if __name__ == "__main__": +def get_lmdb_sizes_parser(): parser = argparse.ArgumentParser() parser.add_argument( "--data-path", @@ -77,5 +76,10 @@ def main(args) -> None: type=int, help="Num of workers to parallelize across", ) + return parser + + +if __name__ == "__main__": + parser = get_lmdb_sizes_parser() args: argparse.Namespace = parser.parse_args() - main(args) + make_lmdb_sizes(args) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index cc575272be..a5f8690955 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -7,6 +7,7 @@ from __future__ import annotations +import copy import datetime import errno import logging @@ -38,10 +39,15 @@ save_checkpoint, update_config, ) +from fairchem.core.datasets.base_dataset import create_dataset from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss -from fairchem.core.modules.normalizer import 
Normalizer +from fairchem.core.modules.normalization.element_references import ( + LinearReferences, + load_references_from_config, +) +from fairchem.core.modules.normalization.normalizer import load_normalizers_from_config from fairchem.core.modules.scaling.compat import load_scales_compat from fairchem.core.modules.scaling.util import ensure_fitted from fairchem.core.modules.scheduler import LRScheduler @@ -183,6 +189,11 @@ def __init__( if distutils.is_master(): logging.info(yaml.dump(self.config, default_flow_style=False)) + self.elementrefs = {} + self.normalizers = {} + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None self.load() @abstractmethod @@ -206,6 +217,7 @@ def load(self) -> None: self.load_seed_from_config() self.load_logger() self.load_datasets() + self.load_references_and_normalizers() self.load_task() self.load_model() self.load_loss() @@ -241,12 +253,16 @@ def load_logger(self) -> None: def get_sampler( self, dataset, batch_size: int, shuffle: bool ) -> BalancedBatchSampler: - if "load_balancing" in self.config["optim"]: - balancing_mode = self.config["optim"]["load_balancing"] - force_balancing = True + balancing_mode = self.config["optim"].get("load_balancing", None) + on_error = self.config["optim"].get("load_balancing_on_error", None) + if balancing_mode is not None: + if on_error is None: + on_error = "raise" else: balancing_mode = "atoms" - force_balancing = False + + if on_error is None: + on_error = "warn_and_no_balance" if gp_utils.initialized(): num_replicas = gp_utils.get_dp_world_size() @@ -262,7 +278,7 @@ def get_sampler( device=self.device, mode=balancing_mode, shuffle=shuffle, - force_balancing=force_balancing, + on_error=on_error, seed=self.config["cmd"]["seed"], ) @@ -283,15 +299,26 @@ def load_datasets(self) -> None: self.val_loader = None self.test_loader = None + # This is hacky and scheduled to be removed next BE week + # move ['X_split_settings'] to ['splits'][X] + def convert_settings_to_split_settings(config, split_name): + config = copy.deepcopy(config) # make sure we dont modify the original + if f"{split_name}_split_settings" in config: + config["splits"] = { + split_name: config.pop(f"{split_name}_split_settings") + } + return config + # load train, val, test datasets if "src" in self.config["dataset"]: logging.info( f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}" ) - self.train_dataset = registry.get_dataset_class( - self.config["dataset"].get("format", "lmdb") - )(self.config["dataset"]) + self.train_dataset = create_dataset( + convert_settings_to_split_settings(self.config["dataset"], "train"), + "train", + ) self.train_sampler = self.get_sampler( self.train_dataset, self.config["optim"].get("batch_size", 1), @@ -302,6 +329,16 @@ def load_datasets(self) -> None: self.train_sampler, ) + if ( + "first_n" in self.config["dataset"] + or "sample_n" in self.config["dataset"] + or "max_atom" in self.config["dataset"] + ): + logging.warn( + "Dataset attributes (first_n/sample_n/max_atom) passed to all datasets! 
Please don't do this, its dangerous!\n" + + "Add them under each dataset 'train_split_settings'/'val_split_settings'/'test_split_settings'" + ) + if "src" in self.config["val_dataset"]: if self.config["val_dataset"].get("use_train_settings", True): val_config = self.config["dataset"].copy() @@ -309,9 +346,9 @@ def load_datasets(self) -> None: else: val_config = self.config["val_dataset"] - self.val_dataset = registry.get_dataset_class( - val_config.get("format", "lmdb") - )(val_config) + self.val_dataset = create_dataset( + convert_settings_to_split_settings(val_config, "val"), "val" + ) self.val_sampler = self.get_sampler( self.val_dataset, self.config["optim"].get( @@ -331,9 +368,9 @@ def load_datasets(self) -> None: else: test_config = self.config["test_dataset"] - self.test_dataset = registry.get_dataset_class( - test_config.get("format", "lmdb") - )(test_config) + self.test_dataset = create_dataset( + convert_settings_to_split_settings(test_config, "test"), "test" + ) self.test_sampler = self.get_sampler( self.test_dataset, self.config["optim"].get( @@ -368,20 +405,68 @@ def load_datasets(self) -> None: self.relax_sampler, ) - def load_task(self): - # Normalizer for the dataset. - + def load_references_and_normalizers(self): + """Load or create element references and normalizers from config""" # Is it troublesome that we assume any normalizer info is in train? What if there is no # training dataset? What happens if we just specify a test - normalizer = self.config["dataset"].get("transforms", {}).get("normalizer", {}) - self.normalizers = {} - if normalizer: - for target in normalizer: - self.normalizers[target] = Normalizer( - mean=normalizer[target].get("mean", 0), - std=normalizer[target].get("stdev", 1), + + elementref_config = ( + self.config["dataset"].get("transforms", {}).get("element_references") + ) + norms_config = self.config["dataset"].get("transforms", {}).get("normalizer") + elementrefs, normalizers = {}, {} + if distutils.is_master(): + if elementref_config is not None: + # put them in a list to allow broadcasting python objects + elementrefs = load_references_from_config( + elementref_config, + dataset=self.train_dataset, + seed=self.config["cmd"]["seed"], + checkpoint_dir=self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None, ) + if norms_config is not None: + normalizers = load_normalizers_from_config( + norms_config, + dataset=self.train_dataset, + seed=self.config["cmd"]["seed"], + checkpoint_dir=self.config["cmd"]["checkpoint_dir"] + if not self.is_debug + else None, + element_references=elementrefs, + ) + + # log out the values that will be used. + for output, normalizer in normalizers.items(): + logging.info( + f"Normalization values for output {output}: mean={normalizer.mean.item()}, rmsd={normalizer.rmsd.item()}." 
+ ) + + # put them in a list to broadcast them + elementrefs, normalizers = [elementrefs], [normalizers] + distutils.broadcast_object_list( + object_list=elementrefs, src=0, device=self.device + ) + distutils.broadcast_object_list( + object_list=normalizers, src=0, device=self.device + ) + # make sure element refs and normalizers are on this device + self.elementrefs.update( + { + output: elementref.to(self.device) + for output, elementref in elementrefs[0].items() + } + ) + self.normalizers.update( + { + output: normalizer.to(self.device) + for output, normalizer in normalizers[0].items() + } + ) + + def load_task(self): self.output_targets = {} for target_name in self.config["outputs"]: self.output_targets[target_name] = self.config["outputs"][target_name] @@ -423,19 +508,7 @@ def load_model(self) -> None: if distutils.is_master(): logging.info(f"Loading model: {self.config['model']}") - # TODO: depreicated, remove. - bond_feat_dim = None - bond_feat_dim = self.config["model_attributes"].get("num_gaussians", 50) - - loader = self.train_loader or self.val_loader or self.test_loader self.model = registry.get_model_class(self.config["model"])( - loader.dataset[0].x.shape[-1] - if loader - and hasattr(loader.dataset[0], "x") - and loader.dataset[0].x is not None - else None, - bond_feat_dim, - 1, **self.config["model_attributes"], ).to(self.device) @@ -455,7 +528,9 @@ def load_model(self) -> None: self.logger.log_summary({"num_params": self.model.num_params}) if distutils.initialized() and not self.config["noddp"]: - self.model = DistributedDataParallel(self.model, device_ids=[self.device]) + self.model = DistributedDataParallel( + self.model, device_ids=None if self.cpu else [self.device] + ) @property def _unwrapped_model(self): @@ -533,9 +608,20 @@ def load_checkpoint( target_key = key if target_key in self.normalizers: - self.normalizers[target_key].load_state_dict( + mkeys = self.normalizers[target_key].load_state_dict( checkpoint["normalizers"][key] ) + assert len(mkeys.missing_keys) == 0 + assert len(mkeys.unexpected_keys) == 0 + + for key, state_dict in checkpoint.get("elementrefs", {}).items(): + elementrefs = LinearReferences( + max_num_elements=len(state_dict["element_references"]) - 1 + ) + mkeys = elementrefs.load_state_dict(state_dict) + self.elementrefs[key] = elementrefs + assert len(mkeys.missing_keys) == 0 + assert len(mkeys.unexpected_keys) == 0 if self.scaler and checkpoint["amp"]: self.scaler.load_state_dict(checkpoint["amp"]) @@ -632,30 +718,40 @@ def save( training_state: bool = True, ) -> str | None: if not self.is_debug and distutils.is_master(): + state = { + "state_dict": self.model.state_dict(), + "normalizers": { + key: value.state_dict() for key, value in self.normalizers.items() + }, + "elementrefs": { + key: value.state_dict() for key, value in self.elementrefs.items() + }, + "config": self.config, + "val_metrics": metrics, + "amp": self.scaler.state_dict() if self.scaler else None, + } if training_state: - return save_checkpoint( + state.update( { "epoch": self.epoch, "step": self.step, - "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), - "scheduler": self.scheduler.scheduler.state_dict() - if self.scheduler.scheduler_type != "Null" - else None, - "normalizers": { - key: value.state_dict() - for key, value in self.normalizers.items() - }, + "scheduler": ( + self.scheduler.scheduler.state_dict() + if self.scheduler.scheduler_type != "Null" + else None + ), "config": self.config, - "val_metrics": metrics, "ema": 
self.ema.state_dict() if self.ema else None, - "amp": self.scaler.state_dict() if self.scaler else None, "best_val_metric": self.best_val_metric, "primary_metric": self.evaluation_metrics.get( "primary_metric", self.evaluator.task_primary_metric[self.name], ), }, + ) + ckpt_path = save_checkpoint( + state, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) @@ -664,22 +760,13 @@ def save( self.ema.store() self.ema.copy_to() ckpt_path = save_checkpoint( - { - "state_dict": self.model.state_dict(), - "normalizers": { - key: value.state_dict() - for key, value in self.normalizers.items() - }, - "config": self.config, - "val_metrics": metrics, - "amp": self.scaler.state_dict() if self.scaler else None, - }, + state, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) if self.ema: self.ema.restore() - return ckpt_path + return ckpt_path return None def update_best( diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 95ca9e94ef..12524aec54 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -11,6 +11,7 @@ import os from collections import defaultdict from itertools import chain +from typing import TYPE_CHECKING import numpy as np import torch @@ -25,6 +26,9 @@ from fairchem.core.modules.scaling.util import ensure_fitted from fairchem.core.trainers.base_trainer import BaseTrainer +if TYPE_CHECKING: + from torch_geometric.data import Batch + @registry.register_trainer("ocp") @registry.register_trainer("energy") @@ -148,7 +152,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None: # Get a batch. batch = next(train_loader_iter) - # Forward, loss, backward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) @@ -227,16 +230,21 @@ def train(self, disable_eval_tqdm: bool = False) -> None: if checkpoint_every == -1: self.save(checkpoint_file="checkpoint.pt", training_state=True) - self.train_dataset.close_db() - if self.config.get("val_dataset", False): - self.val_dataset.close_db() - if self.config.get("test_dataset", False): - self.test_dataset.close_db() + def _denorm_preds(self, target_key: str, prediction: torch.Tensor, batch: Batch): + """Convert model output from a batch into raw prediction by denormalizing and adding references""" + # denorm the outputs + if target_key in self.normalizers: + prediction = self.normalizers[target_key](prediction) + + # add element references + if target_key in self.elementrefs: + prediction = self.elementrefs[target_key](prediction, batch) + + return prediction def _forward(self, batch): out = self.model(batch.to(self.device)) - ### TODO: Move into BaseModel in OCP 2.0 outputs = {} batch_size = batch.natoms.numel() num_atoms_in_batch = batch.natoms.sum() @@ -260,10 +268,7 @@ def _forward(self, batch): for subtarget_key in self.output_targets[target_key]["decomposition"]: irreps = self.output_targets[subtarget_key]["irrep_dim"] - _pred = out[subtarget_key] - - if self.normalizers.get(subtarget_key, False): - _pred = self.normalizers[subtarget_key].denorm(_pred) + _pred = self._denorm_preds(subtarget_key, out[subtarget_key], batch) ## Fill in the corresponding irreps prediction ## Reshape irrep prediction to (batch_size, irrep_dim) @@ -284,7 +289,6 @@ def _forward(self, batch): pred = pred.view(num_atoms_in_batch, -1) else: pred = pred.view(batch_size, -1) - outputs[target_key] = pred return outputs @@ -313,8 +317,6 @@ def _compute_loss(self, 
out, batch): natoms = natoms[mask] num_atoms_in_batch = natoms.numel() - if self.normalizers.get(target_name, False): - target = self.normalizers[target_name].norm(target) ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1 if self.output_targets[target_name]["level"] == "atom": @@ -322,6 +324,14 @@ def _compute_loss(self, out, batch): else: target = target.view(batch_size, -1) + # to keep the loss coefficient weights balanced we remove linear references + # subtract element references from target data + if target_name in self.elementrefs: + target = self.elementrefs[target_name].dereference(target, batch) + # normalize the targets data + if target_name in self.normalizers: + target = self.normalizers[target_name].norm(target) + mult = loss_info["coefficient"] loss.append( mult @@ -379,11 +389,8 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None): else: target = target.view(batch_size, -1) + out[target_name] = self._denorm_preds(target_name, out[target_name], batch) targets[target_name] = target - if self.normalizers.get(target_name, False): - out[target_name] = self.normalizers[target_name].denorm( - out[target_name] - ) targets["natoms"] = natoms out["natoms"] = natoms @@ -391,7 +398,7 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None): return evaluator.eval(out, targets, prev_metrics=metrics) # Takes in a new data source and generates predictions on it. - @torch.no_grad() + @torch.no_grad def predict( self, data_loader, @@ -425,7 +432,7 @@ def predict( predictions = defaultdict(list) - for _i, batch in tqdm( + for _, batch in tqdm( enumerate(data_loader), total=len(data_loader), position=rank, @@ -436,9 +443,7 @@ def predict( out = self._forward(batch) for target_key in self.config["outputs"]: - pred = out[target_key] - if self.normalizers.get(target_key, False): - pred = self.normalizers[target_key].denorm(pred) + pred = self._denorm_preds(target_key, out[target_key], batch) if per_image: ### Save outputs in desired precision, default float16 @@ -455,7 +460,8 @@ def predict( else: dtype = torch.float16 - pred = pred.cpu().detach().to(dtype) + pred = pred.detach().cpu().to(dtype) + ### Split predictions into per-image predictions if self.config["outputs"][target_key]["level"] == "atom": batch_natoms = batch.natoms @@ -516,7 +522,7 @@ def predict( return predictions - @torch.no_grad() + @torch.no_grad def run_relaxations(self, split="val"): ensure_fitted(self._unwrapped_model) diff --git a/tests/core/common/test_data_parallel_batch_sampler.py b/tests/core/common/test_data_parallel_batch_sampler.py index 6205042652..6bd8effe26 100644 --- a/tests/core/common/test_data_parallel_batch_sampler.py +++ b/tests/core/common/test_data_parallel_batch_sampler.py @@ -1,9 +1,16 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
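
The trainer changes above establish a fixed ordering: element references are removed and then targets are normalized before the loss, while predictions are denormalized and then re-referenced. A toy round trip with invented numbers, standing in for the `Normalizer` and `LinearReferences` modules:

```python
import torch

raw_energy = torch.tensor([-120.0, -95.0])            # total energies from the dataset
composition_offset = torch.tensor([-118.0, -93.5])    # what the linear references would predict
mean, rmsd = torch.tensor(-1.25), torch.tensor(0.75)  # fitted normalization values

# training side (_compute_loss): dereference, then normalize
target_for_loss = ((raw_energy - composition_offset) - mean) / rmsd

# prediction side (_denorm_preds): denormalize, then add references back
model_out = target_for_loss                           # pretend the model is perfect
recovered = model_out * rmsd + mean + composition_offset
assert torch.allclose(recovered, raw_energy)
```
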
+""" + from __future__ import annotations -import functools -import tempfile from contextlib import contextmanager from pathlib import Path +import functools +import tempfile from typing import TypeVar import numpy as np @@ -13,11 +20,13 @@ from fairchem.core.common.data_parallel import ( BalancedBatchSampler, StatefulDistributedSampler, + UnsupportedDatasetError, + _balanced_partition, ) +from fairchem.core.datasets.base_dataset import BaseDataset, DatasetMetadata DATA = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -SIZE_ATOMS = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] -SIZE_NEIGHBORS = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4] +SIZE_ATOMS = [2, 20, 3, 51, 10, 11, 41, 31, 13, 14] T_co = TypeVar("T_co", covariant=True) @@ -28,23 +37,57 @@ def _temp_file(name: str): yield Path(tmpdir) / name +@pytest.fixture() +def valid_dataset(): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return DatasetMetadata(natoms=np.array(SIZE_ATOMS)) + + def __init__(self, data) -> None: + super().__init__(config={}) + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def get_metadata(self, attr, idx): + assert attr == "natoms" + metadata_attr = getattr(self._metadata, attr) + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + + return _Dataset(DATA) + + @pytest.fixture() def valid_path_dataset(): - class _Dataset(Dataset[T_co]): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return self.metadata + def __init__(self, data, fpath: Path) -> None: + super().__init__(config={}) self.data = data - self.metadata_path = fpath + self.metadata = DatasetMetadata(natoms=np.load(fpath)["natoms"]) def __len__(self): return len(self.data) def __getitem__(self, idx): - return self.data[idx] + metadata_attr = getattr(self._metadata, "natoms") + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] with _temp_file("metadata.npz") as file: np.savez( natoms=np.array(SIZE_ATOMS), - neighbors=np.array(SIZE_NEIGHBORS), file=file, ) yield _Dataset(DATA, file) @@ -52,8 +95,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_path_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data self.metadata_path = Path("/tmp/does/not/exist.np") @@ -68,8 +113,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data def __len__(self): @@ -81,99 +128,68 @@ def __getitem__(self, idx): return _Dataset(DATA) -def test_lowercase(invalid_dataset) -> None: - sampler = BalancedBatchSampler( - dataset=invalid_dataset, +def test_lowercase(valid_dataset) -> None: + _ = BalancedBatchSampler( + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="ATOMS", - throw_on_error=False, - seed=0 - ) - assert sampler.mode == "atoms" - - sampler = BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="NEIGHBORS", - throw_on_error=False, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.mode == "neighbors" def test_invalid_mode(invalid_dataset) -> None: with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." 
+ ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='natoms'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="natoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." + ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='neighbors'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, - mode="nneighbors", - throw_on_error=True, - seed=0 + mode="neighbors", + on_error="raise", + seed=0, ) def test_invalid_dataset(invalid_dataset) -> None: - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", - ): - BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 - ) - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. Batches will not be balanced, which can incur significant overhead!", - ): - BalancedBatchSampler( + with pytest.raises(UnsupportedDatasetError): + sampler = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) def test_invalid_path_dataset(invalid_path_dataset) -> None: with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -182,13 +198,11 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. 
Batches will not be balanced, which can incur significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -197,70 +211,59 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) -def test_valid_dataset(valid_path_dataset) -> None: +def test_valid_dataset(valid_dataset, valid_path_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - seed=0 - ) - assert (sampler.sizes == np.array(SIZE_ATOMS)).all() - - sampler = BalancedBatchSampler( - dataset=valid_path_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="neighbors", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert (sampler.sizes == np.array(SIZE_NEIGHBORS)).all() + assert ( + sampler._get_natoms(list(range(len(SIZE_ATOMS)))) == np.array(SIZE_ATOMS) + ).all() -def test_disabled(valid_path_dataset) -> None: +def test_disabled(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode=False, - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_single_node(valid_path_dataset) -> None: +def test_single_node(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=1, device=None, mode="atoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: +def test_stateful_distributed_sampler_noshuffle(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -272,12 +275,12 @@ def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: def test_stateful_distributed_sampler_vs_distributed_sampler( - valid_path_dataset, + valid_dataset, ) -> None: for seed in [0, 100, 200]: for batch_size in range(1, 4): stateful_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=2, @@ -286,7 +289,7 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( drop_last=True, ) sampler = DistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, rank=0, num_replicas=2, seed=seed, @@ -296,10 +299,10 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( assert list(stateful_sampler) == list(sampler) -def test_stateful_distributed_sampler(valid_path_dataset) -> None: +def test_stateful_distributed_sampler(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -309,7 +312,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: offset_step = 2 loaded_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, seed=0, @@ 
-319,7 +322,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(loaded_sampler) == original_order[offset_step * batch_size :] diff_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -328,14 +331,14 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(diff_sampler) != original_order -def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: - fullset = set(range(len(valid_path_dataset))) +def test_stateful_distributed_sampler_numreplicas(valid_dataset) -> None: + fullset = set(range(len(valid_dataset))) for drop_last in [True, False]: for num_replicas in range(1, 4): for batch_size in [1]: samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -360,14 +363,14 @@ def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: def test_stateful_distributed_sampler_numreplicas_drop_last( - valid_path_dataset, + valid_dataset, ) -> None: - fullset = set(range(len(valid_path_dataset))) + fullset = set(range(len(valid_dataset))) for num_replicas in range(1, 4): for batch_size in range(1, 4): samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -387,3 +390,15 @@ def test_stateful_distributed_sampler_numreplicas_drop_last( ) assert len(concat_idxs) == len(np.unique(concat_idxs)) assert len(concat_idxs) == (len(fullset) // num_replicas) * num_replicas + + +def test_balancedbatchsampler_partition(valid_dataset) -> None: + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS), 4) + == [[1, 9, 5, 0], [7, 8, 2], [3], [6, 4]] + ) + # test case with local batch size = 1, GPU0(rank0) always gets smallest + # we cant say anything about the remaining elements because it is a heap + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS)[[3, 6, 7, 1]], 4)[0] == [3] + ) diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py index 9743d35a2f..05c7475d2c 100644 --- a/tests/core/common/test_gp_utils.py +++ b/tests/core/common/test_gp_utils.py @@ -7,42 +7,112 @@ gather_from_model_parallel_region, scatter_to_model_parallel_region, ) -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) def _dummy_call(x): return x -@pytest.mark.parametrize("world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])]) # noqa: PT006 + +@pytest.mark.parametrize( + "world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])] +) # noqa: PT006 def test_basic_setup(world_size: int, input: torch.Tensor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True) - output = spawn_multi_process(config, _dummy_call, input) + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True + ) + output = spawn_multi_process( + config, _dummy_call, init_pg_and_rank_and_launch_test, input + ) assert output == expected_output -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1]), torch.Tensor([2,3])]), - (2, 2, torch.Tensor([0,1,2]), 
[torch.Tensor([0,1]), torch.Tensor([2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1]), torch.Tensor([2, 3])], + ), + (2, 2, torch.Tensor([0, 1, 2]), [torch.Tensor([0, 1]), torch.Tensor([2])]), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])], + ), + ], ) -def test_scatter_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_to_model_parallel_region, input) +def test_scatter_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, + scatter_to_model_parallel_region, + init_pg_and_rank_and_launch_test, + input, + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) + def scatter_gather_fn(input: torch.Tensor, dim: int = 0): x = scatter_to_model_parallel_region(input, dim) return gather_from_model_parallel_region(x, dim) -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), torch.Tensor([0,1,2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), torch.Tensor([0,1,2]), torch.Tensor([0,1,2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ], ) -def test_gather_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_gather_fn, input) +def test_gather_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, scatter_gather_fn, init_pg_and_rank_and_launch_test, input + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) diff --git a/tests/core/datasets/conftest.py b/tests/core/datasets/conftest.py new file mode 100644 index 0000000000..eb7be94994 --- /dev/null +++ b/tests/core/datasets/conftest.py @@ -0,0 +1,28 @@ +import numpy as np +import pytest +from ase import build +from ase.calculators.singlepoint import SinglePointCalculator + + +@pytest.fixture(scope="module") +def structures(): + structures = [ + 
build.molecule("H2O", vacuum=4), + build.bulk("Cu"), + build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), + ] + for atoms in structures: + calc = SinglePointCalculator( + atoms, + energy=1, + forces=atoms.positions, + # there is an issue with ASE db when writing a db with 3x3 stress if is flattened to (9,) and then + # errors when trying to read it + stress=np.random.random((6,)), + ) + atoms.calc = calc + atoms.info["extensive_property"] = 3 * len(atoms) + atoms.info["tensor_property"] = np.random.random((6, 6)) + + structures[2].set_pbc(True) + return structures diff --git a/tests/core/datasets/test_ase_datasets.py b/tests/core/datasets/test_ase_datasets.py index 3e401f5f4b..676805c653 100644 --- a/tests/core/datasets/test_ase_datasets.py +++ b/tests/core/datasets/test_ase_datasets.py @@ -15,26 +15,6 @@ ) from fairchem.core.datasets.lmdb_database import LMDBDatabase -structures = [ - build.molecule("H2O", vacuum=4), - build.bulk("Cu"), - build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), -] -for atoms in structures: - calc = SinglePointCalculator( - atoms, - energy=1, - forces=atoms.positions, - # there is an issue with ASE db when writing a db with 3x3 stress it is flattened to (9,) and then - # errors when trying to read it - stress=np.random.random((6,)), - ) - atoms.calc = calc - atoms.info["extensive_property"] = 3 * len(atoms) - atoms.info["tensor_property"] = np.random.random((6, 6)) - -structures[2].set_pbc(True) - @pytest.fixture( params=[ @@ -46,7 +26,7 @@ "aselmdb_dataset", ], ) -def ase_dataset(request, tmp_path_factory): +def ase_dataset(request, structures, tmp_path_factory): tmp_path = tmp_path_factory.mktemp("dataset") mult = 1 a2g_args = { @@ -110,7 +90,7 @@ def ase_dataset(request, tmp_path_factory): return dataset, mult -def test_ase_dataset(ase_dataset): +def test_ase_dataset(ase_dataset, structures): dataset, mult = ase_dataset assert len(dataset) == mult * len(structures) for data in dataset: @@ -121,7 +101,7 @@ def test_ase_dataset(ase_dataset): assert isinstance(data.extensive_property, int) -def test_ase_read_dataset(tmp_path) -> None: +def test_ase_read_dataset(tmp_path, structures): # unfortunately there is currently no clean (already implemented) way to save atoms.info when saving # individual structures - so test separately for i, structure in enumerate(structures): @@ -137,13 +117,16 @@ def test_ase_read_dataset(tmp_path) -> None: assert len(dataset) == len(structures) data = dataset[0] del data - dataset.close_db() -def test_ase_metadata_guesser(ase_dataset) -> None: +def test_ase_get_metadata(ase_dataset): + assert ase_dataset[0].get_metadata("natoms", [0])[0] == 3 + + +def test_ase_metadata_guesser(ase_dataset): dataset, _ = ase_dataset - metadata = dataset.get_metadata() + metadata = dataset.sample_property_metadata() # Confirm energy metadata guessed properly assert metadata["targets"]["energy"]["extensive"] is False @@ -171,7 +154,7 @@ def test_ase_metadata_guesser(ase_dataset) -> None: assert metadata["targets"]["info.tensor_property"]["type"] == "per-image" -def test_db_add_delete(tmp_path) -> None: +def test_db_add_delete(tmp_path, structures): database = db.connect(tmp_path / "asedb.db") for _i, atoms in enumerate(structures): database.write(atoms, data=atoms.info) @@ -192,10 +175,9 @@ def test_db_add_delete(tmp_path) -> None: dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) assert len(dataset) == orig_len + len(new_structures) - 1 - dataset.close_db() -def test_ase_multiread_dataset(tmp_path) -> None: +def 
test_ase_multiread_dataset(tmp_path): atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)] energies = np.linspace(1, 0, len(atoms_objects)) @@ -224,13 +206,17 @@ def test_ase_multiread_dataset(tmp_path) -> None: f.write(f"{tmp_path / 'test.traj'} {len(atoms_objects)}") dataset = AseReadMultiStructureDataset( - config={"index_file": str(tmp_path / "test_index_file")}, + config={ + "src": str(tmp_path), + "index_file": str(tmp_path / "test_index_file"), + }, ) assert len(dataset) == len(atoms_objects) dataset = AseReadMultiStructureDataset( config={ + "src": str(tmp_path), "index_file": str(tmp_path / "test_index_file"), "a2g_args": { "r_energy": True, diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py new file mode 100644 index 0000000000..d90271c53d --- /dev/null +++ b/tests/core/datasets/test_create_dataset.py @@ -0,0 +1,180 @@ +import os +import numpy as np +import pytest + +from fairchem.core.datasets import LMDBDatabase, create_dataset +from fairchem.core.datasets.base_dataset import BaseDataset +import tempfile +from fairchem.core.trainers.base_trainer import BaseTrainer + + +@pytest.fixture() +def lmdb_database(structures): + with tempfile.TemporaryDirectory() as tmpdirname: + num_atoms = [] + asedb_fn = f"{tmpdirname}/asedb.lmdb" + with LMDBDatabase(asedb_fn) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + num_atoms.append(len(atoms)) + np.savez(f"{tmpdirname}/metadata.npz", natoms=num_atoms) + yield asedb_fn + + +def test_real_dataset_config(lmdb_database): + class TestTrainer(BaseTrainer): + def __init__(self, config): + self.config = config + + def train(self, x): + return None + + def get_sampler(self, *args, **kwargs): + return None + + def get_dataloader(self, *args, **kwargs): + return None + + config = { + "model_attributes": {}, + "optim": {"batch_size": 0}, + "dataset": { + "format": "ase_db", + "src": str(lmdb_database), + "first_n": 2, + "key_mapping": { + "y": "energy", + "force": "forces", + }, + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0, "stdev": 2.887317180633545}, + } + }, + }, + "val_dataset": {"src": str(lmdb_database)}, + "test_dataset": {}, + "relax_dataset": None, + } + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 2 + + # modify the config for split and see if it works as expected + config["dataset"].pop("first_n") + config["dataset"]["train_split_settings"] = {"first_n": 2} + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 3 + + +@pytest.mark.parametrize("max_atoms", [3, None]) +@pytest.mark.parametrize( + "key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)] +) +def test_create_dataset(key, value, max_atoms, structures, lmdb_database): + # now create a config + config = { + "format": "ase_db", + "src": str(lmdb_database), + key: value, + "max_atoms": max_atoms, + } + + dataset = create_dataset(config, split="train") + if max_atoms is not None: + structures = [s for s in structures if len(s) <= max_atoms] + assert all( + natoms <= max_atoms + for natoms in dataset.metadata.natoms[range(len(dataset))] + ) + if key == "first_n": # this assumes first_n are not shuffled + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures[:value], dataset) + ) + assert all( + 
np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures[:value], dataset) + ) + elif key == "sample_n": + assert len(dataset) == value + else: # no shuffle all of them are in there + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures, dataset) + ) + assert all( + np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures, dataset) + ) + + +# make sure we cant sample more than the number of elements in the dataset with sample_n +def test_sample_n_dataset(lmdb_database): + with pytest.raises(ValueError): + _ = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 100, + }, + split="train", + ) + + +def test_diff_seed_sample_dataset(lmdb_database): + dataset_a = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + assert (dataset_a.indices == dataset_b.indices).all() + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 1, + }, + split="train", + ) + assert not (dataset_a.indices == dataset_b.indices).all() + + +def test_del_dataset(): + class _Dataset(BaseDataset): + def __init__(self, fn) -> None: + super().__init__(config={}) + self.fn = fn + open(self.fn, "a").close() + + def __del__(self): + os.remove(self.fn) + + with tempfile.TemporaryDirectory() as tmpdirname: + fn = tmpdirname + "/test" + d = _Dataset(fn) + del d + assert not os.path.exists(fn) diff --git a/tests/core/datasets/test_lmdb_dataset.py b/tests/core/datasets/test_lmdb_dataset.py new file mode 100644 index 0000000000..f922e32ce3 --- /dev/null +++ b/tests/core/datasets/test_lmdb_dataset.py @@ -0,0 +1,29 @@ +from fairchem.core.datasets.base_dataset import create_dataset + +import numpy as np + +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes + + +def test_load_lmdb_dataset(tutorial_dataset_path): + + lmdb_path = str(tutorial_dataset_path / "s2ef/val_20") + + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args(["--data-path", lmdb_path]) + make_lmdb_sizes(args) + + config = { + "format": "lmdb", + "src": lmdb_path, + } + + dataset = create_dataset(config, split="val") + + assert dataset.get_metadata("natoms", 0) == dataset[0].natoms + + all_metadata_natoms = np.array(dataset.get_metadata("natoms", range(len(dataset)))) + all_natoms = np.array([datapoint.natoms for datapoint in dataset]) + + assert (all_natoms == all_metadata_natoms).all() diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 8387d6e053..9a68c4771c 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -7,13 +7,20 @@ from pathlib import Path import numpy as np +import numpy.testing as npt import pytest import yaml from tensorboard.backend.event_processing.event_accumulator import EventAccumulator from fairchem.core._cli import Runner from fairchem.core.common.flags import flags +from fairchem.core.common.test_utils import ( + PGConfig, + init_env_rank_and_launch_test, + spawn_multi_process, +) from fairchem.core.common.utils import build_config, setup_logging +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes setup_logging() @@ -21,9 +28,32 @@ @pytest.fixture() def configs(): return { + 
"scn": Path("tests/core/models/test_configs/test_scn.yml"), "escn": Path("tests/core/models/test_configs/test_escn.yml"), - "gemnet": Path("tests/core/models/test_configs/test_gemnet.yml"), + "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), + "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), + "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), + "gemnet_dt_hydra": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" + ), + "gemnet_dt_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" + ), + "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), + "gemnet_oc_hydra": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" + ), + "gemnet_oc_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" + ), + "dimenet++": Path("tests/core/models/test_configs/test_dpp.yml"), + "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), + "painn": Path("tests/core/models/test_configs/test_painn.yml"), + "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), + "equiformer_v2_hydra": Path( + "tests/core/models/test_configs/test_equiformerv2_hydra.yml" + ), } @@ -37,21 +67,56 @@ def tutorial_val_src(tutorial_dataset_path): return tutorial_dataset_path / "s2ef/val_20" -def oc20_lmdb_train_and_val_from_paths(train_src, val_src, test_src=None): +def oc20_lmdb_train_and_val_from_paths( + train_src, val_src, test_src=None, otf_norms=False +): datasets = {} if train_src is not None: datasets["train"] = { "src": train_src, - "normalize_labels": True, - "target_mean": -0.7554450631141663, - "target_std": 2.887317180633545, - "grad_target_mean": 0.0, - "grad_target_std": 2.887317180633545, + "format": "lmdb", + "key_mapping": {"y": "energy", "force": "forces"}, } + if otf_norms is True: + datasets["train"].update( + { + "transforms": { + "element_references": { + "fit": { + "targets": ["energy"], + "batch_size": 4, + "num_batches": 10, + "driver": "gelsd", + } + }, + "normalizer": { + "fit": { + "targets": {"energy": None, "forces": {"mean": 0.0}}, + "batch_size": 4, + "num_batches": 10, + } + }, + } + } + ) + else: + datasets["train"].update( + { + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0.0, "stdev": 2.887317180633545}, + } + } + } + ) if val_src is not None: - datasets["val"] = {"src": val_src} + datasets["val"] = {"src": val_src, "format": "lmdb"} if test_src is not None: - datasets["test"] = {"src": test_src} + datasets["test"] = {"src": test_src, "format": "lmdb"} return datasets @@ -84,6 +149,7 @@ def _run_main( update_run_args_with=None, save_checkpoint_to=None, save_predictions_to=None, + world_size=0, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" @@ -91,9 +157,9 @@ def _run_main( yaml_config = yaml.safe_load(yaml_file) if update_dict_with is not None: yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" with open(str(config_yaml), "w") as yaml_file: yaml.dump(yaml_config, yaml_file) - run_args = { "run_dir": rundir, "logdir": f"{rundir}/logs", @@ -110,7 +176,19 @@ def _run_main( for arg_name, arg_value in run_args.items(): setattr(args, arg_name, arg_value) config = build_config(args, override_args) - Runner()(config) + + if world_size > 0: + pg_config = PGConfig( + 
backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + spawn_multi_process( + pg_config, + Runner(distributed=True), + init_env_rank_and_launch_test, + config, + ) + else: + Runner()(config) if save_checkpoint_to is not None: checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") @@ -125,11 +203,6 @@ def _run_main( ) -@pytest.fixture(scope="class") -def torch_tempdir(tmpdir_factory): - return tmpdir_factory.mktemp("torch_tempdir") - - """ These tests are intended to be as quick as possible and test only that the network is runnable and outputs training+validation to tensorboard output These should catch errors such as shape mismatches or otherways to code wise break a network @@ -137,12 +210,7 @@ def torch_tempdir(tmpdir_factory): class TestSmoke: - def smoke_test_train( - self, - model_name, - input_yaml, - tutorial_val_src, - ): + def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False): with tempfile.TemporaryDirectory() as tempdirname: # first train a very simple model, checkpoint train_rundir = Path(tempdirname) / "train" @@ -153,11 +221,12 @@ def smoke_test_train( rundir=str(train_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), test_src=str(tutorial_val_src), + otf_norms=otf_norms, ), }, save_checkpoint_to=checkpoint_path, @@ -174,11 +243,12 @@ def smoke_test_train( rundir=str(predictions_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), test_src=str(tutorial_val_src), + otf_norms=otf_norms, ), }, update_run_args_with={ @@ -188,30 +258,150 @@ def smoke_test_train( save_predictions_to=predictions_filename, ) + if otf_norms is True: + norm_path = glob.glob( + str(train_rundir / "checkpoints" / "*" / "normalizers.pt") + ) + assert len(norm_path) == 1 + assert os.path.isfile(norm_path[0]) + ref_path = glob.glob( + str(train_rundir / "checkpoints" / "*" / "element_references.pt") + ) + assert len(ref_path) == 1 + assert os.path.isfile(ref_path[0]) + # verify predictions from train and predict are identical energy_from_train = np.load(training_predictions_filename)["energy"] energy_from_checkpoint = np.load(predictions_filename)["energy"] - assert np.isclose(energy_from_train, energy_from_checkpoint).all() + npt.assert_allclose( + energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6 + ) + # not all models are tested with otf normalization estimation + # only gemnet_oc, escn, equiformer, and their hydra versions @pytest.mark.parametrize( - "model_name", + ("model_name", "otf_norms"), [ - pytest.param("gemnet", id="gemnet"), - pytest.param("escn", id="escn"), - pytest.param("equiformer_v2", id="equiformer_v2"), + ("schnet", False), + ("scn", False), + ("gemnet_dt", False), + ("gemnet_dt_hydra", False), + ("gemnet_dt_hydra_grad", False), + ("gemnet_oc", False), + ("gemnet_oc", True), + ("gemnet_oc_hydra", False), + ("gemnet_oc_hydra", True), + ("gemnet_oc_hydra_grad", False), + ("dimenet++", False), + ("dimenet++_hydra", False), + ("painn", False), + ("painn_hydra", False), + ("escn", False), + ("escn", True), + ("escn_hydra", False), + ("escn_hydra", True), + ("equiformer_v2", False), + 
("equiformer_v2", True), + ("equiformer_v2_hydra", False), + ("equiformer_v2_hydra", True), ], ) def test_train_and_predict( self, model_name, + otf_norms, configs, tutorial_val_src, ): self.smoke_test_train( - model_name=model_name, input_yaml=configs[model_name], tutorial_val_src=tutorial_val_src, + otf_norms=otf_norms, + ) + + def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"use_pbc_single": True}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + ) + + @pytest.mark.parametrize( + ("world_size", "ddp"), + [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_ddp(self, world_size, ddp, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) + + @pytest.mark.parametrize( + ("world_size", "ddp"), + [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_balanced_batch_sampler_ddp( + self, world_size, ddp, configs, tutorial_val_src, torch_deterministic + ): + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args( + ["--data-path", str(tutorial_val_src)] ) + make_lmdb_sizes(args) + + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1, "load_balancing": "atoms"}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) # train for a few steps and confirm same seeds get same results def test_different_seeds( @@ -290,9 +480,9 @@ class TestSmallDatasetOptim: @pytest.mark.parametrize( ("model_name", "expected_energy_mae", "expected_force_mae"), [ - pytest.param("gemnet", 0.4, 0.06, id="gemnet"), - pytest.param("escn", 0.4, 0.06, id="escn"), - pytest.param("equiformer_v2", 0.4, 0.06, id="equiformer_v2"), + pytest.param("gemnet_oc", 0.41, 0.06, id="gemnet_oc"), + pytest.param("escn", 0.41, 0.06, id="escn"), + pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"), ], ) def test_train_optimization( diff --git a/tests/core/models/test_configs/test_dpp.yml b/tests/core/models/test_configs/test_dpp.yml new file mode 100755 index 0000000000..a79294bd15 --- /dev/null +++ b/tests/core/models/test_configs/test_dpp.yml @@ -0,0 +1,50 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + 
train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: dimenetplusplus #_bbwheads + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. + +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_dpp_hydra.yml b/tests/core/models/test_configs/test_dpp_hydra.yml new file mode 100755 index 0000000000..1120cc905f --- /dev/null +++ b/tests/core/models/test_configs/test_dpp_hydra.yml @@ -0,0 +1,55 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: hydra + backbone: + model: dimenetplusplus_backbone + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 1 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + heads: + energy: + module: dimenetplusplus_energy_and_force_head + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. 
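+# For a rough sense of that scaling (illustrative note, not taken from the original run): +# lr_milestones and warmup_steps are counted in optimizer steps, and the number of steps +# per epoch scales inversely with the global batch size, so e.g. halving num_gpus * batch_size +# roughly doubles the steps per epoch, and the milestone and warmup values would need to be +# roughly doubled to keep the same schedule in terms of epochs.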
+ +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_equiformerv2.yml b/tests/core/models/test_configs/test_equiformerv2.yml index 54d5e61c95..8c5c200fdf 100644 --- a/tests/core/models/test_configs/test_equiformerv2.yml +++ b/tests/core/models/test_configs/test_equiformerv2.yml @@ -1,6 +1,53 @@ +trainer: forces + +logger: + name: tensorboard +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold + primary_metric: forces_mae -trainer: forces +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae model: name: equiformer_v2 @@ -45,47 +92,3 @@ model: proj_drop: 0.0 weight_init: 'normal' # ['uniform', 'normal'] - -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - -logger: - name: tensorboard - -task: - dataset: lmdb - type: regression - metric: mae - primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 - - -optim: - batch_size: 5 - eval_batch_size: 2 - num_workers: 0 - lr_initial: 0.0025 - optimizer: AdamW - optimizer_params: {"amsgrad": True,weight_decay: 0.0} - eval_every: 190 - max_epochs: 50 - force_coefficient: 20 - scheduler: "Null" - energy_coefficient: 1 - clip_grad_norm: 20 - loss_energy: mae - loss_force: l2mae diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml new file mode 100644 index 0000000000..4c00fe6a2e --- /dev/null +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -0,0 +1,98 @@ + + +trainer: forces + +model: + name: hydra + backbone: + model: equiformer_v2_backbone + use_pbc: True + regress_forces: True + otf_graph: True + + enforce_max_neighbors_strictly: False + + max_neighbors: 1 + max_radius: 12.0 + max_num_elements: 90 + + num_layers: 1 + sphere_channels: 4 + attn_hidden_channels: 4 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. + num_heads: 1 + attn_alpha_channels: 4 # Not used when `use_s2_act_attn` is True. + attn_value_channels: 4 + ffn_hidden_channels: 8 + norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] + + lmax_list: [1] + mmax_list: [1] + grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. 
+ + num_sphere_samples: 128 + + edge_channels: 32 + use_atom_edge_embedding: True + distance_function: 'gaussian' + num_distance_basis: 16 # not used + + attn_activation: 'silu' + use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. + ffn_activation: 'silu' # ['silu', 'swiglu'] + use_gate_act: False # [True, False] Switch between gate activation and S2 activation + use_grid_mlp: False # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. + + alpha_drop: 0.0 # [0.0, 0.1] + drop_path_rate: 0.0 # [0.0, 0.05] + proj_drop: 0.0 + + weight_init: 'normal' # ['uniform', 'normal'] + heads: + energy: + module: equiformer_v2_energy_head + forces: + module: equiformer_v2_force_head + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_escn.yml b/tests/core/models/test_configs/test_escn.yml index 5148e409e5..5848587cdd 100644 --- a/tests/core/models/test_configs/test_escn.yml +++ b/tests/core/models/test_configs/test_escn.yml @@ -1,31 +1,37 @@ trainer: forces -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - logger: name: tensorboard -task: - dataset: lmdb - type: regression - metric: mae +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 model: name: escn diff --git a/tests/core/models/test_configs/test_escn_hydra.yml b/tests/core/models/test_configs/test_escn_hydra.yml new file mode 100644 index 0000000000..ba5db1f53e --- /dev/null +++ b/tests/core/models/test_configs/test_escn_hydra.yml @@ -0,0 +1,67 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces 
+ train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: escn_backbone + num_layers: 2 + max_neighbors: 10 + cutoff: 12.0 + sphere_channels: 8 + hidden_channels: 8 + lmax_list: [2] + mmax_list: [2] + num_sphere_samples: 64 + distance_function: "gaussian" + regress_forces: True + use_pbc: True + basis_width_scalar: 2.0 + otf_graph: True + heads: + energy: + module: escn_energy_head + forces: + module: escn_force_head + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_gemnet_dt.yml b/tests/core/models/test_configs/test_gemnet_dt.yml new file mode 100644 index 0000000000..b04b6dfda0 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt.yml @@ -0,0 +1,79 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: gemnet_t + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml new file mode 100644 index 0000000000..a612741470 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml @@ -0,0 +1,86 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_t_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 
+ emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + heads: + energy: + module: gemnet_t_energy_and_grad_force_head + forces: + module: gemnet_t_force_head + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml new file mode 100644 index 0000000000..83d46bdd4d --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml @@ -0,0 +1,84 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_t_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: False + heads: + energy_and_forces: + module: gemnet_t_energy_and_grad_force_head + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet.yml b/tests/core/models/test_configs/test_gemnet_oc.yml similarity index 75% rename from tests/core/models/test_configs/test_gemnet.yml rename to tests/core/models/test_configs/test_gemnet_oc.yml index a720583608..f1c0d01c3a 100644 --- a/tests/core/models/test_configs/test_gemnet.yml +++ b/tests/core/models/test_configs/test_gemnet_oc.yml @@ -1,34 +1,37 @@ - - - trainer: forces -dataset: - train: - src: tutorial_dset/s2ef/train_100/ - normalize_labels: True - target_mean: -0.7554450631141663 - target_std: 2.887317180633545 - grad_target_mean: 0.0 - grad_target_std: 2.887317180633545 - val: - format: lmdb - src: tutorial_dset/s2ef/val_20/ - logger: name: tensorboard -task: - dataset: lmdb - type: regression - 
metric: mae +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold primary_metric: forces_mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - prediction_dtype: float32 model: name: gemnet_oc diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml new file mode 100644 index 0000000000..97343e90e6 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml @@ -0,0 +1,112 @@ + + + +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_oc_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip_in: 4 + emb_size_trip_out: 4 + emb_size_quad_in: 2 + emb_size_quad_out: 2 + emb_size_aint_in: 4 + emb_size_aint_out: 4 + emb_size_rbf: 2 + emb_size_cbf: 2 + emb_size_sbf: 4 + num_before_skip: 1 + num_after_skip: 1 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: True + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + heads: + energy: + module: gemnet_oc_energy_and_grad_force_head + num_global_out_layers: 2 + forces: + module: gemnet_oc_force_head + num_global_out_layers: 2 + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 10 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml new file mode 100644 index 0000000000..334c3cb4db --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml @@ -0,0 +1,109 @@ + + + +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 
2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_oc_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip_in: 4 + emb_size_trip_out: 4 + emb_size_quad_in: 2 + emb_size_quad_out: 2 + emb_size_aint_in: 4 + emb_size_aint_out: 4 + emb_size_rbf: 2 + emb_size_cbf: 2 + emb_size_sbf: 4 + num_before_skip: 1 + num_after_skip: 1 + num_concat: 1 + num_atom: 3 + num_output_afteratom: 3 + cutoff: 12.0 + cutoff_qint: 12.0 + cutoff_aeaint: 12.0 + cutoff_aint: 12.0 + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + sbf: + name: legendre_outer + extensive: True + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt + + regress_forces: True + direct_forces: False + forces_coupled: False + + quad_interaction: True + atom_edge_interaction: True + edge_atom_interaction: True + atom_interaction: True + + num_atom_emb_layers: 2 + num_global_out_layers: 2 + qint_tags: [1, 2] + heads: + energy: + module: gemnet_oc_energy_and_grad_force_head + num_global_out_layers: 2 + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 10 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_painn.yml b/tests/core/models/test_configs/test_painn.yml new file mode 100644 index 0000000000..c1f24d0bb5 --- /dev/null +++ b/tests/core/models/test_configs/test_painn.yml @@ -0,0 +1,50 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: painn #_bbwheads + hidden_channels: 32 + num_layers: 6 + num_rbf: 32 + cutoff: 12.0 + max_neighbors: 5 + scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt + regress_forces: True + direct_forces: True + use_pbc: True + +optim: + batch_size: 32 + eval_batch_size: 32 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + optimizer: AdamW + optimizer_params: + amsgrad: True + weight_decay: 0. 
# 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2 + lr_initial: 1.e-4 + lr_gamma: 0.8 + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_painn_hydra.yml b/tests/core/models/test_configs/test_painn_hydra.yml new file mode 100644 index 0000000000..0b39aa1731 --- /dev/null +++ b/tests/core/models/test_configs/test_painn_hydra.yml @@ -0,0 +1,58 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: hydra + backbone: + model: painn_backbone #_bbwheads + hidden_channels: 32 + num_layers: 6 + num_rbf: 32 + cutoff: 12.0 + max_neighbors: 5 + scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt + regress_forces: True + direct_forces: True + use_pbc: True + heads: + energy: + module: painn_energy_head + forces: + module: painn_force_head + + +optim: + batch_size: 32 + eval_batch_size: 32 + load_balancing: atoms + eval_every: 5000 + num_workers: 2 + optimizer: AdamW + optimizer_params: + amsgrad: True + weight_decay: 0. # 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2 + lr_initial: 1.e-4 + lr_gamma: 0.8 + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_schnet.yml b/tests/core/models/test_configs/test_schnet.yml new file mode 100755 index 0000000000..97faf3962a --- /dev/null +++ b/tests/core/models/test_configs/test_schnet.yml @@ -0,0 +1,45 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: schnet + hidden_channels: 1024 + num_filters: 256 + num_interactions: 5 + num_gaussians: 200 + cutoff: 6.0 + use_pbc: True + +# *** Important note *** +# The total number of gpus used for this run was 64. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. + +optim: + batch_size: 20 + eval_batch_size: 20 + eval_every: 10000 + num_workers: 16 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 313907 + - 523179 + - 732451 + warmup_steps: 209271 + warmup_factor: 0.2 + max_epochs: 15 diff --git a/tests/core/models/test_configs/test_scn.yml b/tests/core/models/test_configs/test_scn.yml new file mode 100755 index 0000000000..c080c48557 --- /dev/null +++ b/tests/core/models/test_configs/test_scn.yml @@ -0,0 +1,59 @@ +# A total of 64 32GB GPUs were used for training. 
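+# Note: that GPU count refers to the original production SCN training run; this file is a +# down-scaled smoke-test config (2 interaction blocks, small channel sizes), and the +# step-based lr_milestones below appear to be carried over from that larger run.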
+trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: scn + num_interactions: 2 + hidden_channels: 16 + sphere_channels: 8 + sphere_channels_reduce: 8 + num_sphere_samples: 8 + num_basis_functions: 8 + distance_function: "gaussian" + show_timing_info: False + max_num_neighbors: 40 + cutoff: 8.0 + lmax: 4 + num_bands: 2 + use_grid: True + regress_forces: True + use_pbc: True + basis_width_scalar: 2.0 + otf_graph: True + +optim: + batch_size: 2 + eval_batch_size: 1 + num_workers: 2 + lr_initial: 0.0004 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + eval_every: 5000 + lr_gamma: 0.3 + lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma + - 260000 + - 340000 + - 420000 + - 500000 + - 800000 + - 1000000 + warmup_steps: 100 + warmup_factor: 0.2 + max_epochs: 12 + clip_grad_norm: 100 + ema_decay: 0.999 diff --git a/tests/core/models/test_dimenetpp.py b/tests/core/models/test_dimenetpp.py index 76a546037b..d1daec728b 100644 --- a/tests/core/models/test_dimenetpp.py +++ b/tests/core/models/test_dimenetpp.py @@ -47,9 +47,6 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("dimenetplusplus")( - None, - 32, - 1, cutoff=6.0, regress_forces=True, use_pbc=False, diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 34ed79ba2b..3194dd2df7 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -18,7 +18,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel from fairchem.core.common.registry import registry -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) from fairchem.core.common.utils import load_state_dict, setup_imports from fairchem.core.datasets import data_list_collater from fairchem.core.models.equiformer_v2.so3 import ( @@ -59,9 +63,6 @@ def _load_model(): checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("equiformer_v2")( - None, - -1, - 1, use_pbc=True, regress_forces=True, otf_graph=True, @@ -140,7 +141,9 @@ def test_energy_force_shape(self, snapshot): def test_ddp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=1, gp_group_size=1, use_gp=False) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 1 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape @@ -151,7 +154,9 @@ def test_ddp(self, snapshot): def test_gp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 2 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape @@ -225,4 +230,3 @@ def sign(x): embedding._l_primary(c) lp = embedding.embedding.clone() (test_matrix_lp == lp).all() - diff --git 
a/tests/core/models/test_gemnet.py b/tests/core/models/test_gemnet.py index 3fa0c6babc..b4c5414cc4 100644 --- a/tests/core/models/test_gemnet.py +++ b/tests/core/models/test_gemnet.py @@ -47,9 +47,6 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("gemnet_t")( - None, - -1, - 1, cutoff=6.0, num_spherical=7, num_radial=128, diff --git a/tests/core/models/test_gemnet_oc.py b/tests/core/models/test_gemnet_oc.py index d84669750f..7729c14483 100644 --- a/tests/core/models/test_gemnet_oc.py +++ b/tests/core/models/test_gemnet_oc.py @@ -58,9 +58,6 @@ def load_model(request) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, diff --git a/tests/core/models/test_gemnet_oc_scaling_mismatch.py b/tests/core/models/test_gemnet_oc_scaling_mismatch.py index 8f1c36d277..29ea40c0fa 100644 --- a/tests/core/models/test_gemnet_oc_scaling_mismatch.py +++ b/tests/core/models/test_gemnet_oc_scaling_mismatch.py @@ -35,9 +35,6 @@ def test_no_scaling_mismatch(self) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -111,9 +108,6 @@ def test_scaling_mismatch(self) -> None: checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu")) model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -189,9 +183,6 @@ def test_no_file_exists(self) -> None: with pytest.raises(ValueError): registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, @@ -245,9 +236,6 @@ def test_not_fitted(self) -> None: setup_imports() model = registry.get_model_class("gemnet_oc")( - None, - -1, - 1, num_spherical=7, num_radial=128, num_blocks=4, diff --git a/tests/core/models/test_schnet.py b/tests/core/models/test_schnet.py index aa704604f7..3dd21be4e1 100644 --- a/tests/core/models/test_schnet.py +++ b/tests/core/models/test_schnet.py @@ -46,7 +46,7 @@ def load_model(request) -> None: setup_imports() model = registry.get_model_class("schnet")( - None, 32, 1, cutoff=6.0, regress_forces=True, use_pbc=True + cutoff=6.0, regress_forces=True, use_pbc=True ) request.cls.model = model diff --git a/tests/core/modules/conftest.py b/tests/core/modules/conftest.py new file mode 100644 index 0000000000..1b1e4ab7e6 --- /dev/null +++ b/tests/core/modules/conftest.py @@ -0,0 +1,48 @@ +from itertools import product +from random import choice +import pytest +import numpy as np +from pymatgen.core.periodic_table import Element +from pymatgen.core import Structure + +from fairchem.core.datasets import LMDBDatabase, AseDBDataset + + +@pytest.fixture(scope="session") +def dummy_element_refs(): + # create some dummy elemental energies from ionic radii (ignore deuterium and tritium included in pmg) + return np.concatenate( + [[0], [e.average_ionic_radius for e in Element if e.name not in ("D", "T")]] + ) + + +@pytest.fixture(scope="session") +def max_num_elements(dummy_element_refs): + return len(dummy_element_refs) - 1 + + +@pytest.fixture(scope="session") +def dummy_binary_dataset(tmpdir_factory, dummy_element_refs): + # a dummy dataset with binaries with energy that depends on composition only plus noise + all_binaries = list(product(list(Element), repeat=2)) + rng = np.random.default_rng(seed=0) + + tmpdir 
= tmpdir_factory.mktemp("dataset") + with LMDBDatabase(tmpdir / "dummy.aselmdb") as db: + for _ in range(1000): + elements = choice(all_binaries) + structure = Structure.from_prototype("cscl", species=elements, a=2.0) + energy = ( + sum(e.average_ionic_radius for e in elements) + + 0.05 * rng.random() * dummy_element_refs.mean() + ) + atoms = structure.to_ase_atoms() + db.write(atoms, data={"energy": energy, "forces": rng.random((2, 3))}) + + dataset = AseDBDataset( + config={ + "src": str(tmpdir / "dummy.aselmdb"), + "a2g_args": {"r_data_keys": ["energy", "forces"]}, + } + ) + return dataset diff --git a/tests/core/modules/test_element_references.py b/tests/core/modules/test_element_references.py new file mode 100644 index 0000000000..62928b623c --- /dev/null +++ b/tests/core/modules/test_element_references.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import numpy as np +import numpy.testing as npt +import pytest +import torch + +from fairchem.core.datasets import data_list_collater +from fairchem.core.modules.normalization.element_references import ( + LinearReferences, + create_element_references, + fit_linear_references, +) + + +@pytest.fixture(scope="session", params=(True, False)) +def element_refs(dummy_binary_dataset, max_num_elements, request): + return fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + shuffle=False, + max_num_elements=max_num_elements, + seed=0, + use_numpy=request.param, + ) + + +def test_apply_linear_references( + element_refs, dummy_binary_dataset, dummy_element_refs +): + max_noise = 0.05 * dummy_element_refs.mean() + + # check that removing element refs keeps only values within max noise + batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True) + energy = batch.energy.clone().view(len(batch), -1) + deref_energy = element_refs["energy"].dereference(energy, batch) + assert all(deref_energy <= max_noise) + + # and check that we recover the total energy from applying references + ref_energy = element_refs["energy"](deref_energy, batch) + assert torch.allclose(ref_energy, energy) + + +def test_create_element_references(element_refs, tmp_path): + # test from state dict + sdict = element_refs["energy"].state_dict() + + refs = create_element_references(state_dict=sdict) + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # test from saved stated dict + torch.save(sdict, tmp_path / "linref.pt") + refs = create_element_references(file=tmp_path / "linref.pt") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # from a legacy numpy npz file + np.savez( + tmp_path / "linref.npz", coeff=element_refs["energy"].element_references.numpy() + ) + refs = create_element_references(file=tmp_path / "linref.npz") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + # from a numpy npz file + np.savez( + tmp_path / "linref.npz", + element_references=element_refs["energy"].element_references.numpy(), + ) + + refs = create_element_references(file=tmp_path / "linref.npz") + assert isinstance(refs, LinearReferences) + assert torch.allclose( + element_refs["energy"].element_references, refs.element_references + ) + + +def test_fit_linear_references( + element_refs, dummy_binary_dataset, max_num_elements, dummy_element_refs +): + # create the 
composition matrix + energy = np.array([d.energy for d in dummy_binary_dataset]) + cmatrix = np.vstack( + [ + np.bincount(d.atomic_numbers.int().numpy(), minlength=max_num_elements + 1) + for d in dummy_binary_dataset + ] + ) + mask = cmatrix.sum(axis=0) != 0.0 + + # fit using numpy + element_refs_np = np.zeros(max_num_elements + 1) + element_refs_np[mask] = np.linalg.lstsq(cmatrix[:, mask], energy, rcond=None)[0] + + # length is max_num_elements + 1, since H starts at 1 + assert len(element_refs["energy"].element_references) == max_num_elements + 1 + # first element is dummy, should always be zero + assert element_refs["energy"].element_references[0] == 0.0 + # elements not present should be zero + npt.assert_allclose(element_refs["energy"].element_references.numpy()[~mask], 0.0) + # torch fit vs numpy fit + npt.assert_allclose( + element_refs_np, element_refs["energy"].element_references.numpy(), atol=1e-5 + ) + # close enough to ground truth w/out noise + npt.assert_allclose( + dummy_element_refs[mask], + element_refs["energy"].element_references.numpy()[mask], + atol=5e-2, + ) + + +def test_fit_seed_no_seed(dummy_binary_dataset, max_num_elements): + refs_seed = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=0, + ) + refs_seed1 = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=0, + ) + refs_noseed = fit_linear_references( + ["energy"], + dataset=dummy_binary_dataset, + batch_size=16, + num_batches=len(dummy_binary_dataset) // 16 - 2, + shuffle=True, + max_num_elements=max_num_elements, + seed=1, + ) + + assert torch.allclose( + refs_seed["energy"].element_references, + refs_seed1["energy"].element_references, + atol=1e-6, + ) + assert not torch.allclose( + refs_seed["energy"].element_references, + refs_noseed["energy"].element_references, + atol=1e-6, + ) diff --git a/tests/core/modules/test_normalizer.py b/tests/core/modules/test_normalizer.py new file mode 100644 index 0000000000..b0d4a44040 --- /dev/null +++ b/tests/core/modules/test_normalizer.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import numpy as np +import pytest +import torch + +from fairchem.core.datasets import data_list_collater +from fairchem.core.modules.normalization.normalizer import ( + Normalizer, + create_normalizer, + fit_normalizers, +) + + +@pytest.fixture(scope="session") +def normalizers(dummy_binary_dataset): + return fit_normalizers( + ["energy", "forces"], + override_values={"forces": {"mean": 0.0}}, + dataset=dummy_binary_dataset, + batch_size=16, + shuffle=False, + ) + + +def test_norm_denorm(normalizers, dummy_binary_dataset, dummy_element_refs): + batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True) + # test norm and denorm + for target, normalizer in normalizers.items(): + normed = normalizer.norm(batch[target]) + assert torch.allclose( + (batch[target] - normalizer.mean) / normalizer.rmsd, normed + ) + assert torch.allclose( + normalizer.rmsd * normed + normalizer.mean, normalizer(normed) + ) + + +def test_create_normalizers(normalizers, dummy_binary_dataset, tmp_path): + # test that forces mean was overriden + assert normalizers["forces"].mean.item() == 0.0 + + # test from state dict + sdict = normalizers["energy"].state_dict() + + norm = create_normalizer(state_dict=sdict) 
+ assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # test from saved stated dict + torch.save(sdict, tmp_path / "norm.pt") + norm = create_normalizer(file=tmp_path / "norm.pt") + assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # from a legacy numpy npz file + np.savez( + tmp_path / "norm.npz", + mean=normalizers["energy"].mean.numpy(), + std=normalizers["energy"].rmsd.numpy(), + ) + norm = create_normalizer(file=tmp_path / "norm.npz") + assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # from a new npz file + np.savez( + tmp_path / "norm.npz", + mean=normalizers["energy"].mean.numpy(), + rmsd=normalizers["energy"].rmsd.numpy(), + ) + norm = create_normalizer(file=tmp_path / "norm.npz") + assert isinstance(norm, Normalizer) + assert norm.state_dict() == sdict + + # from tensor directly + batch = data_list_collater(list(dummy_binary_dataset), otf_graph=True) + norm = create_normalizer(tensor=batch.energy) + assert isinstance(norm, Normalizer) + # assert norm.state_dict() == sdict + # not sure why the above fails + new_sdict = norm.state_dict() + for key in sdict: + assert torch.allclose(new_sdict[key], sdict[key]) + + # passing values directly + norm = create_normalizer( + mean=batch.energy.mean().item(), rmsd=batch.energy.std().item() + ) + assert isinstance(norm, Normalizer) + # assert norm.state_dict() == sdict + new_sdict = norm.state_dict() + for key in sdict: + assert torch.allclose(new_sdict[key], sdict[key]) + + # bad construction + with pytest.raises(ValueError): + create_normalizer(mean=1.0) diff --git a/tests/core/preprocessing/test_atoms_to_graphs.py b/tests/core/preprocessing/test_atoms_to_graphs.py index ec1c34ab20..5c07a45243 100644 --- a/tests/core/preprocessing/test_atoms_to_graphs.py +++ b/tests/core/preprocessing/test_atoms_to_graphs.py @@ -15,7 +15,7 @@ from ase.neighborlist import NeighborList, NewPrimitiveNeighborList from fairchem.core.preprocessing import AtomsToGraphs - +from fairchem.core.modules.evaluator import min_diff @pytest.fixture(scope="class") def atoms_to_graphs_internals(request) -> None: @@ -110,7 +110,8 @@ def test_convert(self) -> None: # positions act_positions = self.atoms.get_positions() positions = data.pos.numpy() - np.testing.assert_allclose(act_positions, positions) + mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc) + np.testing.assert_allclose(mindiff, 0, atol=1e-6) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) test_energy = data.energy @@ -142,7 +143,8 @@ def test_convert_all(self) -> None: # positions act_positions = self.atoms.get_positions() positions = data_list[0].pos.numpy() - np.testing.assert_allclose(act_positions, positions) + mindiff = min_diff(act_positions, positions, self.atoms.get_cell(), self.atoms.pbc) + np.testing.assert_allclose(mindiff, 0, atol=1e-6) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) test_energy = data_list[0].energy