From 04a69b0353360fe9616047662fe9de4c2168b742 Mon Sep 17 00:00:00 2001
From: Misko
Date: Fri, 2 Aug 2024 10:19:10 -0700
Subject: [PATCH 1/2] Balanced batch sampler+base dataset (#753)

* Update BalancedBatchSampler to use datasets' `data_sizes` method

Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error`

* Remove python 3.10 syntax
* Documentation
* Added set_epoch method
* Format
* Changed "resolved dataset" message to be a debug log to reduce log spam
* clean up batchsampler and tests
* base dataset class
* move lin_ref to base dataset
* inherit basedataset for ase dataset
* filter indices prop
* added create_dataset fn
* yaml load fix
* create dataset function instead of filtering in base
* remove filtered_indices
* make create_dataset and LMDBDatabase importable from datasets
* create_dataset cleanup
* test create_dataset
* use metadata.natoms directly and add it to subset
* use self.indices to handle shard
* rename _data_sizes
* fix Subset of metadata
* minor change to metadata, added full path option
* import updates
* implement get_metadata for datasets; add tests for max_atoms and balanced partitioning
* a[:len(a)+1] does not throw error, change to check for this
* off by one fix
* fixing tests
* plug create_dataset into trainer
* remove datasetwithsizes; fix base dataset integration; replace close_db with __del__
* lint
* add/fix test;
* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)
* adding new notebook for using fairchem models with NEBs
* adding md tutorials
* blocking code cells that arent needed or take too long
* Add extra test case for local batch size = 1
* fix example
* fix test case
* reorg changes
* remove metadata_has_sizes in favor of basedataset function metadata_hasattr
* fix data_parallel typo
* fix up some tests
* rename get_metadata to sample_property_metadata
* add slow get_metadata for ase; add tests for get_metadata (ase+lmdb); add test for make lmdb metadata sizes
* add support for different backends and ddp in pytest
* fix tests and balanced batch sampler
* make default dataset lmdb
* lint
* fix tests
* test with world_size=0 by default
* fix tests
* fix tests..
* remove subsample from oc22 dataset * remove old datasets; add test for noddp * remove load balancing from docs * fix docs; add train_split_settings and test for this --------- Co-authored-by: Nima Shoghi Co-authored-by: Nima Shoghi Co-authored-by: lbluque Co-authored-by: Brandon Co-authored-by: Brook Wander <73855115+brookwander@users.noreply.github.com> Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> Co-authored-by: Muhammed Shuaibi --- docs/core/fine-tuning/fine-tuning-oxides.md | 1 + .../advanced/fine-tuning-in-python.md | 2 +- src/fairchem/core/common/data_parallel.py | 236 ++++++++--------- src/fairchem/core/common/distutils.py | 6 +- src/fairchem/core/common/test_utils.py | 77 +++--- src/fairchem/core/datasets/__init__.py | 10 +- src/fairchem/core/datasets/ase_datasets.py | 44 ++-- src/fairchem/core/datasets/base_dataset.py | 227 +++++++++++++++++ src/fairchem/core/datasets/lmdb_dataset.py | 61 ++--- .../core/datasets/oc22_lmdb_dataset.py | 22 +- src/fairchem/core/scripts/make_lmdb_sizes.py | 16 +- src/fairchem/core/trainers/base_trainer.py | 91 ++++--- src/fairchem/core/trainers/ocp_trainer.py | 10 +- .../test_data_parallel_batch_sampler.py | 239 ++++++++++-------- tests/core/common/test_gp_utils.py | 110 ++++++-- tests/core/datasets/conftest.py | 28 ++ tests/core/datasets/test_ase_datasets.py | 46 ++-- tests/core/datasets/test_create_dataset.py | 180 +++++++++++++ tests/core/datasets/test_lmdb_dataset.py | 29 +++ tests/core/e2e/test_s2ef.py | 94 ++++++- tests/core/models/test_equiformer_v2.py | 15 +- 21 files changed, 1095 insertions(+), 449 deletions(-) create mode 100644 src/fairchem/core/datasets/base_dataset.py create mode 100644 tests/core/datasets/conftest.py create mode 100644 tests/core/datasets/test_create_dataset.py create mode 100644 tests/core/datasets/test_lmdb_dataset.py diff --git a/docs/core/fine-tuning/fine-tuning-oxides.md b/docs/core/fine-tuning/fine-tuning-oxides.md index 77a9350d3b..39c39cad40 100644 --- a/docs/core/fine-tuning/fine-tuning-oxides.md +++ b/docs/core/fine-tuning/fine-tuning-oxides.md @@ -205,6 +205,7 @@ from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', 'optim.loss_force', # the checkpoint setting causes an error + 'optim.load_balancing', 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, 'optim.eval_every': 10, diff --git a/docs/tutorials/advanced/fine-tuning-in-python.md b/docs/tutorials/advanced/fine-tuning-in-python.md index 1d14219c88..0eeb8e5485 100644 --- a/docs/tutorials/advanced/fine-tuning-in-python.md +++ b/docs/tutorials/advanced/fine-tuning-in-python.md @@ -75,7 +75,7 @@ We start by making the config.yml. We build this from the calculator checkpoint. 
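The docs hunks above drop `optim.load_balancing` from the generated fine-tuning configs because, after this patch, the trainer resolves balancing behaviour from `optim.load_balancing` plus the new `optim.load_balancing_on_error` key instead of the old `force_balancing` / `throw_on_error` flags. A minimal, self-contained sketch that mirrors the resolution logic in `BaseTrainer.get_sampler` from this patch; the config dicts are illustrative, not taken from a real run:

```python
# Mirrors the option-resolution logic added to BaseTrainer.get_sampler in this patch.
def resolve_balancing(optim: dict) -> tuple:
    balancing_mode = optim.get("load_balancing", None)
    on_error = optim.get("load_balancing_on_error", None)
    if balancing_mode is not None:
        if on_error is None:
            on_error = "raise"  # user explicitly asked for balancing: fail loudly
    else:
        balancing_mode = "atoms"  # default balancing mode
    if on_error is None:
        on_error = "warn_and_no_balance"  # default: warn and fall back to unbalanced batches
    return balancing_mode, on_error


print(resolve_balancing({}))                           # ('atoms', 'warn_and_no_balance')
print(resolve_balancing({"load_balancing": "atoms"}))  # ('atoms', 'raise')
print(resolve_balancing({"load_balancing": "atoms",
                         "load_balancing_on_error": "warn_and_balance"}))
```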
from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', - delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', + delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes','optim.load_balancing', 'optim.loss_force', # the checkpoint setting causes an error 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, diff --git a/src/fairchem/core/common/data_parallel.py b/src/fairchem/core/common/data_parallel.py index 4d5836b786..89c3b67445 100644 --- a/src/fairchem/core/common/data_parallel.py +++ b/src/fairchem/core/common/data_parallel.py @@ -9,20 +9,23 @@ import heapq import logging -from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal import numba import numpy as np -import numpy.typing as npt import torch -from torch.utils.data import BatchSampler, DistributedSampler, Sampler +import torch.distributed +from torch.utils.data import BatchSampler, Dataset, DistributedSampler +from typing_extensions import override from fairchem.core.common import distutils, gp_utils from fairchem.core.datasets import data_list_collater +from fairchem.core.datasets.base_dataset import ( + UnsupportedDatasetError, +) if TYPE_CHECKING: - from pathlib import Path - + from numpy.typing import NDArray from torch_geometric.data import Batch, Data @@ -35,30 +38,24 @@ def __call__(self, data_list: list[Data]) -> Batch: @numba.njit -def balanced_partition(sizes: npt.NDArray[np.int_], num_parts: int): +def _balanced_partition(sizes: NDArray[np.int_], num_parts: int): """ Greedily partition the given set by always inserting the largest element into the smallest partition. """ sort_idx = np.argsort(-sizes) # Sort in descending order - heap: list[tuple[list[int], list[int]]] = [ - (sizes[idx], [idx]) for idx in sort_idx[:num_parts] - ] + heap = [(sizes[idx], [idx]) for idx in sort_idx[:num_parts]] heapq.heapify(heap) for idx in sort_idx[num_parts:]: smallest_part = heapq.heappop(heap) new_size = smallest_part[0] + sizes[idx] - new_idx = smallest_part[1] + [idx] + new_idx = smallest_part[1] + [ + idx + ] # TODO should this be append to save time/space heapq.heappush(heap, (new_size, new_idx)) return [part[1] for part in heap] -@runtime_checkable -class _HasMetadata(Protocol): - @property - def metadata_path(self) -> Path: ... - - class StatefulDistributedSampler(DistributedSampler): """ More fine-grained state DataSampler that uses training iteration and epoch @@ -105,56 +102,83 @@ def set_epoch_and_start_iteration(self, epoch, start_iter): self.start_iter = start_iter -class BalancedBatchSampler(Sampler): - def _load_dataset(self, dataset, mode: Literal["atoms", "neighbors"]): - errors: list[str] = [] - if not isinstance(dataset, _HasMetadata): - errors.append(f"Dataset {dataset} does not have a metadata_path attribute.") - return None, errors - if not dataset.metadata_path.exists(): - errors.append(f"Metadata file {dataset.metadata_path} does not exist.") - return None, errors +def _ensure_supported(dataset: Any): + if not isinstance(dataset, Dataset): + raise UnsupportedDatasetError("BalancedBatchSampler requires a dataset.") + + if not dataset.metadata_hasattr("natoms"): + raise UnsupportedDatasetError( + "BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms." 
+ ) - key = {"atoms": "natoms", "neighbors": "neighbors"}[mode] - sizes = np.load(dataset.metadata_path)[key] + logging.debug(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}") + return dataset - return sizes, errors +class BalancedBatchSampler(BatchSampler): def __init__( self, - dataset, + dataset: Dataset, + *, batch_size: int, num_replicas: int, rank: int, device: torch.device, seed: int, - mode: str | bool = "atoms", + mode: bool | Literal["atoms"] = "atoms", shuffle: bool = True, + on_error: Literal["warn_and_balance", "warn_and_no_balance", "raise"] = "raise", drop_last: bool = False, - force_balancing: bool = False, - throw_on_error: bool = False, - ) -> None: - if mode is True: - mode = "atoms" - - if isinstance(mode, str): - mode = mode.lower() - if mode not in ("atoms", "neighbors"): - raise ValueError( - f"Invalid mode {mode}. Must be one of 'atoms', 'neighbors', or a boolean." - ) + ): + """ + Initializes a BalancedBatchSampler object. - self.dataset = dataset - self.batch_size = batch_size - self.num_replicas = num_replicas - self.rank = rank - self.device = device - self.mode = mode - self.shuffle = shuffle - self.drop_last = drop_last + Args: + dataset (Dataset): The dataset to sample from. + batch_size (int): The size of each batch. + num_replicas (int): The number of processes participating in distributed training. + rank (int): The rank of the current process in distributed training. + device (torch.device): The device to use for the batches. + mode (str or bool, optional): The mode to use for balancing the batches. Defaults to "atoms". + shuffle (bool, optional): Whether to shuffle the samples. Defaults to True. + on_error (Literal["warn_and_balance", "warn_and_no_balance", "raise"], optional): The action to take when an error occurs (i.e., when we have an invalid dataset). Defaults to "raise". + - "warn_and_balance": Raise a warning and balance the batch by manually loading the data samples and counting the number of nodes (this is slow). + - "warn_and_no_balance": Raise a warning and do not do any balancing. + - "raise": Raise an error. + drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to False. + """ + self.disabled = False + self.on_error = on_error + + if mode is False: + logging.warning(f"Disabled BalancedBatchSampler because {mode=}.") + self.disabled = True + elif mode.lower() != "atoms": + raise ValueError( + f"Only mode='atoms' or mode=True is supported, got {mode=}." + ) + elif num_replicas == 1: + logging.warning(f"Disabled BalancedBatchSampler because {num_replicas=}.") + self.disabled = True + + try: + dataset = _ensure_supported(dataset) + except UnsupportedDatasetError as error: + if self.on_error == "raise": + raise error + if self.on_error == "warn_and_balance": + logging.warning( + f"Failed to get data sizes from metadata, loading data to get sizes (THIS IS SLOW). {error}" + ) + elif self.on_error == "warn_and_no_balance": + logging.warning( + f"Failed to get data sizes, falling back to uniform partitioning. 
{error}" + ) + else: + raise ValueError(f"Unknown on_error={self.on_error}") from error - self.single_sampler = StatefulDistributedSampler( - self.dataset, + sampler = StatefulDistributedSampler( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, @@ -162,82 +186,59 @@ def __init__( batch_size=batch_size, seed=seed, ) - self.batch_sampler = BatchSampler( - self.single_sampler, - batch_size, - drop_last=drop_last, - ) - - self.sizes = None - self.balance_batches = False - if self.num_replicas <= 1: - logging.info("Batch balancing is disabled for single GPU training.") - return - - if self.mode is False: - logging.info( - "Batch balancing is disabled because `optim.load_balancing` is `False`" - ) - return - - self.sizes, errors = self._load_dataset(dataset, self.mode) - if self.sizes is None: - self.balance_batches = force_balancing - if force_balancing: - errors.append( - "BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead! " - "You can disable balancing by setting `optim.load_balancing` to `False`." - ) - else: - errors.append( - "Batches will not be balanced, which can incur significant overhead!" - ) - else: - self.balance_batches = True - - if errors: - msg = "BalancedBatchSampler: " + " ".join(errors) - if throw_on_error: - raise RuntimeError(msg) + super().__init__(sampler, batch_size=batch_size, drop_last=drop_last) + self.device = device - logging.warning(msg) + logging.info( + f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}" + ) - def __len__(self) -> int: - return len(self.batch_sampler) + def _get_natoms(self, batch_idx: list[int]): + if self.sampler.dataset.metadata_hasattr("natoms"): + return np.array( + self.sampler.dataset.get_metadata("natoms", batch_idx) + ).reshape(-1) + if self.on_error == "warn_and_balance": + return np.array([self.sampler.dataset[idx].num_nodes for idx in batch_idx]) + return None def set_epoch_and_start_iteration(self, epoch: int, start_iteration: int) -> None: - if not hasattr(self.single_sampler, "set_epoch_and_start_iteration"): + if not isinstance(self.sampler, StatefulDistributedSampler): if start_iteration != 0: raise NotImplementedError( f"{type(self.single_sampler)} does not support resuming from a nonzero step." 
) - self.single_sampler.set_epoch(epoch) + self.sampler.set_epoch(epoch) else: - self.single_sampler.set_epoch_and_start_iteration(epoch, start_iteration) + self.sampler.set_epoch_and_start_iteration(epoch, start_iteration) + + def set_epoch(self, epoch: int) -> None: + if isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + @staticmethod + def _dist_enabled(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + @override def __iter__(self): - if not self.balance_batches: - yield from self.batch_sampler + if self.disabled or not self._dist_enabled(): + yield from super().__iter__() return - for batch_idx in self.batch_sampler: - if self.sizes is None: - # Unfortunately, we need to load the data to know the image sizes - data_list = [self.dataset[idx] for idx in batch_idx] - - if self.mode == "atoms": - sizes = [data.num_nodes for data in data_list] - elif self.mode == "neighbors": - sizes = [data.edge_index.shape[1] for data in data_list] - else: - raise NotImplementedError( - f"Unknown load balancing mode: {self.mode}" - ) - else: - sizes = [self.sizes[idx] for idx in batch_idx] - - idx_sizes = torch.stack([torch.tensor(batch_idx), torch.tensor(sizes)]) + for batch_idx in super().__iter__(): + sizes = self._get_natoms(batch_idx) + if sizes is None: # on_error == "warn_and_no_balance" is set + yield batch_idx + continue + + idx_sizes = torch.stack( + [ + torch.tensor(batch_idx, device=self.device), + torch.tensor(sizes, device=self.device), + ] + ) idx_sizes_all = distutils.all_gather(idx_sizes, device=self.device) idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu() if gp_utils.initialized(): @@ -245,9 +246,10 @@ def __iter__(self): idx_all = idx_sizes_all[0] sizes_all = idx_sizes_all[1] - local_idx_balanced = balanced_partition( - sizes_all.numpy(), num_parts=self.num_replicas + local_idx_balanced = _balanced_partition( + sizes_all.numpy(), + num_parts=self.sampler.num_replicas, ) # Since DistributedSampler pads the last batch # this should always have an entry for each replica. 
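The `__iter__` above gathers the per-sample atom counts from every replica and hands them to `_balanced_partition`, which greedily assigns the next-largest sample to the currently lightest partition. A plain-Python sketch of that greedy strategy (the same algorithm as `_balanced_partition`, minus the numba decoration); the sizes reuse the `SIZE_ATOMS` values from the new test so the totals come out roughly balanced:

```python
import heapq

import numpy as np


def balanced_partition(sizes: np.ndarray, num_parts: int):
    """Greedy heuristic: always push the next-largest sample onto the lightest part."""
    sort_idx = np.argsort(-sizes)  # descending by size
    heap = [(int(sizes[i]), [int(i)]) for i in sort_idx[:num_parts]]
    heapq.heapify(heap)
    for idx in sort_idx[num_parts:]:
        total, members = heapq.heappop(heap)  # lightest partition so far
        heapq.heappush(heap, (total + int(sizes[idx]), [*members, int(idx)]))
    return [members for _, members in heap]


natoms = np.array([2, 20, 3, 51, 10, 11, 41, 31, 13, 14])  # SIZE_ATOMS from the new tests
for part in balanced_partition(natoms, num_parts=4):
    print(part, "total atoms:", natoms[part].sum())
```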
- yield idx_all[local_idx_balanced[self.rank]] + yield idx_all[local_idx_balanced[self.sampler.rank]] diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 919f7ba66d..8989840641 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -98,7 +98,7 @@ def setup(config) -> None: ) else: config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"])) - dist.init_process_group(backend="nccl") + dist.init_process_group(backend=config.get("backend", "nccl")) def cleanup() -> None: @@ -144,7 +144,7 @@ def all_reduce( if not isinstance(data, torch.Tensor): tensor = torch.tensor(data) if device is not None: - tensor = tensor.cuda(device) + tensor = tensor.to(device) dist.all_reduce(tensor, group=group) if average: tensor /= get_world_size() @@ -162,7 +162,7 @@ def all_gather(data, group=dist.group.WORLD, device=None) -> list[torch.Tensor]: if not isinstance(data, torch.Tensor): tensor = torch.tensor(data) if device is not None: - tensor = tensor.cuda(device) + tensor = tensor.to(device) tensor_list = [tensor.new_zeros(tensor.shape) for _ in range(get_world_size())] dist.all_gather(tensor_list, tensor, group=group) if not isinstance(data, torch.Tensor): diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py index 8aaf822105..130daba2d5 100644 --- a/src/fairchem/core/common/test_utils.py +++ b/src/fairchem/core/common/test_utils.py @@ -44,9 +44,55 @@ class PGConfig: use_gp: bool = True +def init_env_rank_and_launch_test( + rank: int, + pg_setup_params: PGConfig, + mp_output_dict: dict[int, object], + test_method: callable, + args: list[object], + kwargs: dict[str, object], +) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = pg_setup_params.port + os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["RANK"] = str(rank) + mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme + + +def init_pg_and_rank_and_launch_test( + rank: int, + pg_setup_params: PGConfig, + mp_output_dict: dict[int, object], + test_method: callable, + args: list[object], + kwargs: dict[str, object], +) -> None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = pg_setup_params.port + os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) + os.environ["LOCAL_RANK"] = str(rank) + # setup default process group + dist.init_process_group( + rank=rank, + world_size=pg_setup_params.world_size, + backend=pg_setup_params.backend, + timeout=timedelta(seconds=10), # setting up timeout for distributed collectives + ) + # setup gp + if pg_setup_params.use_gp: + config = { + "gp_gpus": pg_setup_params.gp_group_size, + "distributed_backend": pg_setup_params.backend, + } + setup_gp(config) + mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme + + def spawn_multi_process( config: PGConfig, test_method: callable, + init_and_launch: callable, *test_method_args: Any, **test_method_kwargs: Any, ) -> list[Any]: @@ -72,7 +118,7 @@ def spawn_multi_process( torch.multiprocessing.spawn( # torch.multiprocessing.spawn sends rank as the first param # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn - _init_pg_and_rank_and_launch_test, + init_and_launch, args=( config, mp_output_dict, @@ -84,32 +130,3 @@ def spawn_multi_process( ) return [mp_output_dict[i] for i in range(config.world_size)] - - -def _init_pg_and_rank_and_launch_test( - rank: int, - 
pg_setup_params: PGConfig, - mp_output_dict: dict[int, object], - test_method: callable, - args: list[object], - kwargs: dict[str, object], -) -> None: - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = pg_setup_params.port - os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size) - os.environ["LOCAL_RANK"] = str(rank) - # setup default process group - dist.init_process_group( - rank=rank, - world_size=pg_setup_params.world_size, - backend=pg_setup_params.backend, - timeout=timedelta(seconds=10), # setting up timeout for distributed collectives - ) - # setup gp - if pg_setup_params.use_gp: - config = { - "gp_gpus": pg_setup_params.gp_group_size, - "distributed_backend": pg_setup_params.backend, - } - setup_gp(config) - mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme diff --git a/src/fairchem/core/datasets/__init__.py b/src/fairchem/core/datasets/__init__.py index 1fd4b51cd5..dc3f7d0e4d 100644 --- a/src/fairchem/core/datasets/__init__.py +++ b/src/fairchem/core/datasets/__init__.py @@ -5,23 +5,19 @@ from __future__ import annotations from .ase_datasets import AseDBDataset, AseReadDataset, AseReadMultiStructureDataset +from .base_dataset import create_dataset from .lmdb_database import LMDBDatabase from .lmdb_dataset import ( LmdbDataset, - SinglePointLmdbDataset, - TrajectoryLmdbDataset, data_list_collater, ) -from .oc22_lmdb_dataset import OC22LmdbDataset __all__ = [ "AseDBDataset", "AseReadDataset", "AseReadMultiStructureDataset", "LmdbDataset", - "SinglePointLmdbDataset", - "TrajectoryLmdbDataset", - "data_list_collater", - "OC22LmdbDataset", "LMDBDatabase", + "create_dataset", + "data_list_collater", ] diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py index a258f42832..15c22322db 100644 --- a/src/fairchem/core/datasets/ase_datasets.py +++ b/src/fairchem/core/datasets/ase_datasets.py @@ -20,13 +20,12 @@ import ase import numpy as np -import torch.nn from torch import tensor -from torch.utils.data import Dataset from tqdm import tqdm from fairchem.core.common.registry import registry from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.datasets.lmdb_database import LMDBDatabase from fairchem.core.datasets.target_metadata_guesser import guess_property_metadata from fairchem.core.modules.transforms import DataTransforms @@ -60,7 +59,7 @@ def apply_one_tags( return atoms -class AseAtomsDataset(Dataset, ABC): +class AseAtomsDataset(BaseDataset, ABC): """ This is an abstract Dataset that includes helpful utilities for turning ASE atoms objects into OCP-usable data objects. 
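Stepping back to the `test_utils` refactor above: `spawn_multi_process` now takes the per-rank launcher explicitly, so a test chooses between `init_pg_and_rank_and_launch_test` (initializes a real process group per rank) and `init_env_rank_and_launch_test` (sets the distributed env vars only). A hedged usage sketch, assuming fairchem with this patch is importable; the worker function `_report_rank` is a made-up stand-in:

```python
# Sketch only: requires fairchem with this patch installed.
import torch.distributed as dist

from fairchem.core.common.test_utils import (
    PGConfig,
    init_pg_and_rank_and_launch_test,
    spawn_multi_process,
)


def _report_rank() -> int:
    # Runs inside each spawned worker after the gloo process group is initialized.
    return dist.get_rank()


if __name__ == "__main__":
    config = PGConfig(backend="gloo", world_size=2, gp_group_size=1, use_gp=False)
    # Returns one value per rank, gathered into a list ordered by rank, e.g. [0, 1].
    print(spawn_multi_process(config, _report_rank, init_pg_and_rank_and_launch_test))
```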
This should not be instantiated directly @@ -81,7 +80,7 @@ def __init__( config: dict, atoms_transform: Callable[[ase.Atoms, Any, ...], ase.Atoms] = apply_one_tags, ) -> None: - self.config = config + super().__init__(config) a2g_args = config.get("a2g_args", {}) or {} @@ -96,19 +95,13 @@ def __init__( self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - self.lin_ref = None - if self.config.get("lin_ref", False): - lin_ref = torch.tensor( - np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - ) - self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) - self.atoms_transform = atoms_transform if self.config.get("keep_in_memory", False): self.__getitem__ = cache(self.__getitem__) self.ids = self._load_dataset_get_ids(config) + self.num_samples = len(self.ids) if len(self.ids) == 0: raise ValueError( @@ -116,9 +109,6 @@ def __init__( f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" ) - def __len__(self) -> int: - return len(self.ids) - def __getitem__(self, idx): # Handle slicing if isinstance(idx, slice): @@ -174,11 +164,7 @@ def _load_dataset_get_ids(self, config): def get_relaxed_energy(self, identifier): raise NotImplementedError("IS2RE-Direct is not implemented with this dataset.") - def close_db(self) -> None: - # This method is sometimes called by a trainer - pass - - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: metadata = {} if num_samples < len(self): @@ -197,6 +183,18 @@ def get_metadata(self, num_samples: int = 100) -> dict: return metadata + def get_metadata(self, attr, idx): + # try the parent method + metadata = super().get_metadata(attr, idx) + if metadata is not None: + return metadata + # try to resolve it here + if attr != "natoms": + return None + if isinstance(idx, (list, np.ndarray)): + return np.array([self.get_metadata(attr, i) for i in idx]) + return len(self.get_atoms(idx)) + @registry.register_dataset("ase_read") class AseReadDataset(AseAtomsDataset): @@ -399,7 +397,7 @@ def get_atoms(self, idx: str) -> ase.Atoms: return atoms - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: return {} def get_relaxed_energy(self, identifier) -> float: @@ -556,17 +554,17 @@ def connect_db( return ase.db.connect(address, **connect_args) - def close_db(self) -> None: + def __del__(self): for db in self.dbs: if hasattr(db, "close"): db.close() - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: logging.warning( "You specific a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" ) if self.dbs[0].metadata == {}: - return super().get_metadata(num_samples) + return super().sample_property_metadata(num_samples) return copy.deepcopy(self.dbs[0].metadata) diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py new file mode 100644 index 0000000000..2ca26596c3 --- /dev/null +++ b/src/fairchem/core/datasets/base_dataset.py @@ -0,0 +1,227 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +import logging +from abc import ABCMeta +from functools import cached_property +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + NamedTuple, + TypeVar, +) + +import numpy as np +import torch +from torch import randperm +from torch.utils.data import Dataset +from torch.utils.data import Subset as Subset_ + +from fairchem.core.common.registry import registry + +if TYPE_CHECKING: + from collections.abc import Sequence + + from numpy.typing import ArrayLike + + +T_co = TypeVar("T_co", covariant=True) + + +class DatasetMetadata(NamedTuple): + natoms: ArrayLike | None = None + + +class UnsupportedDatasetError(ValueError): + pass + + +class BaseDataset(Dataset[T_co], metaclass=ABCMeta): + """Base Dataset class for all OCP datasets.""" + + def __init__(self, config: dict): + """Initialize + + Args: + config (dict): dataset configuration + """ + self.config = config + self.paths = [] + + if "src" in self.config: + if isinstance(config["src"], str): + self.paths = [Path(self.config["src"])] + else: + self.paths = tuple(Path(path) for path in config["src"]) + + self.lin_ref = None + if self.config.get("lin_ref", False): + lin_ref = torch.tensor( + np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] + ) + self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) + + def __len__(self) -> int: + return self.num_samples + + def metadata_hasattr(self, attr) -> bool: + if self._metadata is None: + return False + return hasattr(self._metadata, attr) + + @cached_property + def indices(self): + return np.arange(self.num_samples, dtype=int) + + @cached_property + def _metadata(self) -> DatasetMetadata: + # logic to read metadata file here + metadata_npzs = [] + if self.config.get("metadata_path", None) is not None: + metadata_npzs.append( + np.load(self.config["metadata_path"], allow_pickle=True) + ) + + else: + for path in self.paths: + if path.is_file(): + metadata_file = path.parent / "metadata.npz" + else: + metadata_file = path / "metadata.npz" + if metadata_file.is_file(): + metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) + + if len(metadata_npzs) == 0: + logging.warning( + f"Could not find dataset metadata.npz files in '{self.paths}'" + ) + return None + + metadata = DatasetMetadata( + **{ + field: np.concatenate([metadata[field] for metadata in metadata_npzs]) + for field in DatasetMetadata._fields + } + ) + + assert metadata.natoms.shape[0] == len( + self + ), "Loaded metadata and dataset size mismatch." 
+ + return metadata + + def get_metadata(self, attr, idx): + if self._metadata is not None: + metadata_attr = getattr(self._metadata, attr) + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + return None + + +class Subset(Subset_, BaseDataset): + """A pytorch subset that also takes metadata if given.""" + + def __init__( + self, + dataset: BaseDataset, + indices: Sequence[int], + metadata: DatasetMetadata | None = None, + ) -> None: + super().__init__(dataset, indices) + self.metadata = metadata + self.indices = indices + self.num_samples = len(indices) + self.config = dataset.config + + @cached_property + def _metadata(self) -> DatasetMetadata: + return self.dataset._metadata + + def get_metadata(self, attr, idx): + if isinstance(idx, list): + return self.dataset.get_metadata(attr, [[self.indices[i] for i in idx]]) + return self.dataset.get_metadata(attr, self.indices[idx]) + + +def create_dataset(config: dict[str, Any], split: str) -> Subset: + """Create a dataset from a config dictionary + + Args: + config (dict): dataset config dictionary + split (str): name of split + + Returns: + Subset: dataset subset class + """ + # Initialize the dataset + dataset_cls = registry.get_dataset_class(config.get("format", "lmdb")) + assert issubclass(dataset_cls, Dataset), f"{dataset_cls} is not a Dataset" + + # remove information about other splits, only keep specified split + # this may only work with the mt config not main config + current_split_config = config.copy() + if "splits" in current_split_config: + current_split_config.pop("splits") + current_split_config.update(config["splits"][split]) + + seed = current_split_config.get("seed", 0) + if split != "train": + seed += ( + 1 # if we use same dataset for train / val , make sure its diff sampling + ) + + g = torch.Generator() + g.manual_seed(seed) + + dataset = dataset_cls(current_split_config) + # Get indices of the dataset + indices = dataset.indices + max_atoms = current_split_config.get("max_atoms", None) + if max_atoms is not None: + if not dataset.metadata_hasattr("natoms"): + raise ValueError("Cannot use max_atoms without dataset metadata") + indices = indices[dataset.get_metadata("natoms", indices) <= max_atoms] + + # Apply dataset level transforms + # TODO is no_shuffle mutually exclusive though? or what is the purpose of no_shuffle? + first_n = current_split_config.get("first_n") + sample_n = current_split_config.get("sample_n") + no_shuffle = current_split_config.get("no_shuffle") + # this is true if at most one of the mutually exclusive arguments are set + if sum(arg is not None for arg in (first_n, sample_n, no_shuffle)) > 1: + raise ValueError( + "sample_n, first_n, no_shuffle are mutually exclusive arguments. Only one can be provided." + ) + if first_n is not None: + max_index = first_n + elif sample_n is not None: + # shuffle by default, user can disable to optimize if they have confidence in dataset + # shuffle all datasets by default to avoid biasing the sampling in concat dataset + # TODO only shuffle if split is train + max_index = sample_n + indices = indices[randperm(len(indices), generator=g)] + else: + max_index = len(indices) + indices = ( + indices if no_shuffle else indices[randperm(len(indices), generator=g)] + ) + + if max_index > len(indices): + msg = ( + f"Cannot take {max_index} data points from a dataset of only length {len(indices)}.\n" + f"Make sure to set first_n or sample_n to a number =< the total samples in dataset." 
+ ) + if max_atoms is not None: + msg = msg[:-1] + f"that are smaller than the given max_atoms {max_atoms}." + raise ValueError(msg) + + indices = indices[:max_index] + + return Subset(dataset, indices, metadata=dataset._metadata) diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py index 91ced220ea..ca1fcc2b77 100644 --- a/src/fairchem/core/datasets/lmdb_dataset.py +++ b/src/fairchem/core/datasets/lmdb_dataset.py @@ -9,32 +9,33 @@ import bisect import logging import pickle -import warnings -from pathlib import Path -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar import lmdb import numpy as np import torch -from torch.utils.data import Dataset from torch_geometric.data import Batch -from torch_geometric.data.data import BaseData from fairchem.core.common.registry import registry from fairchem.core.common.typing import assert_is_instance from fairchem.core.common.utils import pyg2_data_transform from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.datasets.target_metadata_guesser import guess_property_metadata from fairchem.core.modules.transforms import DataTransforms +if TYPE_CHECKING: + from pathlib import Path + + from torch_geometric.data.data import BaseData + T_co = TypeVar("T_co", covariant=True) @registry.register_dataset("lmdb") @registry.register_dataset("single_point_lmdb") @registry.register_dataset("trajectory_lmdb") -class LmdbDataset(Dataset[T_co]): - metadata_path: Path +class LmdbDataset(BaseDataset): sharded: bool r"""Dataset class to load from LMDB files containing relaxation @@ -50,20 +51,21 @@ class LmdbDataset(Dataset[T_co]): """ def __init__(self, config) -> None: - super().__init__() - self.config = config + super().__init__(config) assert not self.config.get( "train_on_oc20_total_energies", False ), "For training on total energies set dataset=oc22_lmdb" - self.path = Path(self.config["src"]) + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." + self.path = self.paths[0] + if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" - self.metadata_path = self.path / "metadata.npz" - self._keys = [] self.envs = [] for db_path in db_paths: @@ -86,7 +88,6 @@ def __init__(self, config) -> None: self._keylen_cumulative = np.cumsum(keylens).tolist() self.num_samples = sum(keylens) else: - self.metadata_path = self.path.parent / "metadata.npz" self.env = self.connect_db(self.path) # If "length" encoded as ascii is present, use that @@ -113,19 +114,15 @@ def __init__(self, config) -> None: self.indices, self.config.get("total_shards", 1) ) # limit each process to see a subset of data based off defined shard - self.available_indices = self.shards[self.config.get("shard", 0)] - self.num_samples = len(self.available_indices) + self.indices = self.shards[self.config.get("shard", 0)] + self.num_samples = len(self.indices) self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - def __len__(self) -> int: - return self.num_samples - def __getitem__(self, idx: int) -> T_co: # if sharding, remap idx to appropriate idx of the sharded set - if self.sharded: - idx = self.available_indices[idx] + idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. 
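To show how the new `create_dataset` above is meant to be driven: it instantiates the registered dataset `format`, merges the per-split settings from `splits`, applies `max_atoms` filtering (which requires `natoms` metadata) and the mutually exclusive `first_n` / `sample_n` / `no_shuffle` options, then wraps the result in a `Subset`. A hedged config sketch, assuming fairchem with this patch is installed and an LMDB folder exists at the made-up path:

```python
# Sketch only: requires fairchem with this patch and an actual LMDB dataset on disk.
from fairchem.core.datasets import create_dataset

config = {
    "format": "lmdb",               # registry key; defaults to "lmdb" if omitted
    "src": "/path/to/train/lmdbs",  # made-up path
    "splits": {
        "train": {
            "sample_n": 10000,      # randomly sample 10k examples (seeded shuffle)
            "max_atoms": 100,       # drop samples larger than 100 atoms (needs metadata.npz)
        },
        "val": {"first_n": 2000},   # alternatively: take the first 2k examples
    },
}

train_subset = create_dataset(config, split="train")  # returns a Subset with metadata attached
print(len(train_subset))
```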
db_idx = bisect.bisect(self._keylen_cumulative, idx) @@ -165,14 +162,14 @@ def connect_db(self, lmdb_path: Path | None = None) -> lmdb.Environment: max_readers=1, ) - def close_db(self) -> None: + def __del__(self): if not self.path.is_file(): for env in self.envs: env.close() else: self.env.close() - def get_metadata(self, num_samples: int = 100): + def sample_property_metadata(self, num_samples: int = 100): # This will interogate the classic OCP LMDB format to determine # which properties are present and attempt to guess their shapes # and whether they are intensive or extensive. @@ -214,26 +211,6 @@ def get_metadata(self, num_samples: int = 100): } -class SinglePointLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "SinglePointLmdbDataset is deprecated and will be removed in the future." - "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - -class TrajectoryLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "TrajectoryLmdbDataset is deprecated and will be removed in the future." - "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> BaseData: batch = Batch.from_data_list(data_list) diff --git a/src/fairchem/core/datasets/oc22_lmdb_dataset.py b/src/fairchem/core/datasets/oc22_lmdb_dataset.py index 0c6d4e8bfb..867a72726f 100644 --- a/src/fairchem/core/datasets/oc22_lmdb_dataset.py +++ b/src/fairchem/core/datasets/oc22_lmdb_dataset.py @@ -9,22 +9,21 @@ import bisect import pickle -from pathlib import Path import lmdb import numpy as np import torch -from torch.utils.data import Dataset from fairchem.core.common.registry import registry from fairchem.core.common.typing import assert_is_instance as aii from fairchem.core.common.utils import pyg2_data_transform from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.modules.transforms import DataTransforms @registry.register_dataset("oc22_lmdb") -class OC22LmdbDataset(Dataset): +class OC22LmdbDataset(BaseDataset): r"""Dataset class to load from LMDB files containing relaxation trajectories or single point computations. @@ -43,10 +42,13 @@ class OC22LmdbDataset(Dataset): """ def __init__(self, config, transform=None) -> None: - super().__init__() - self.config = config + super().__init__(config) + + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." + self.path = self.paths[0] - self.path = Path(self.config["src"]) self.data2train = self.config.get("data2train", "all") if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) @@ -114,19 +116,11 @@ def __init__(self, config, transform=None) -> None: if self.train_on_oc20_total_energies: with open(config["oc20_ref"], "rb") as fp: self.oc20_ref = pickle.load(fp) - if self.config.get("lin_ref", False): - coeff = np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - self.lin_ref = torch.nn.Parameter(torch.tensor(coeff), requires_grad=False) - self.subsample = aii(self.config.get("subsample", False), bool) def __len__(self) -> int: - if self.subsample: - return min(self.subsample, self.num_samples) return self.num_samples def __getitem__(self, idx): - if self.data2train != "all": - idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. 
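One detail worth calling out from the `LmdbDataset` changes above: sharding now reuses the base class `indices` (the old `available_indices` attribute is gone) and is driven by the optional `shard` / `total_shards` config keys. A hedged sketch of such a config; the path is made up and the code assumes fairchem with this patch is importable:

```python
# Sketch only: requires fairchem with this patch and an LMDB folder on disk.
from fairchem.core.datasets import LmdbDataset

dataset = LmdbDataset(
    {
        "src": "/path/to/oc20/s2ef/train",  # made-up path to a folder of *.lmdb files
        # Optional sharding: this process only sees shard 0 of 4 roughly equal chunks.
        "shard": 0,
        "total_shards": 4,
    }
)
print(len(dataset))                        # number of samples in this shard
print(dataset.metadata_hasattr("natoms"))  # True if a metadata.npz was found next to src
```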
db_idx = bisect.bisect(self._keylen_cumulative, idx) diff --git a/src/fairchem/core/scripts/make_lmdb_sizes.py b/src/fairchem/core/scripts/make_lmdb_sizes.py index 682fb58e65..ebf2122aeb 100644 --- a/src/fairchem/core/scripts/make_lmdb_sizes.py +++ b/src/fairchem/core/scripts/make_lmdb_sizes.py @@ -15,7 +15,7 @@ from tqdm import tqdm from fairchem.core.common.typing import assert_is_instance -from fairchem.core.datasets import SinglePointLmdbDataset, TrajectoryLmdbDataset +from fairchem.core.datasets.lmdb_dataset import LmdbDataset def get_data(index): @@ -28,14 +28,13 @@ def get_data(index): return index, natoms, neighbors -def main(args) -> None: +def make_lmdb_sizes(args) -> None: path = assert_is_instance(args.data_path, str) global dataset + dataset = LmdbDataset({"src": path}) if os.path.isdir(path): - dataset = TrajectoryLmdbDataset({"src": path}) outpath = os.path.join(path, "metadata.npz") elif os.path.isfile(path): - dataset = SinglePointLmdbDataset({"src": path}) outpath = os.path.join(os.path.dirname(path), "metadata.npz") output_indices = range(len(dataset)) @@ -63,7 +62,7 @@ def main(args) -> None: np.savez(outpath, natoms=sorted_natoms, neighbors=sorted_neighbors) -if __name__ == "__main__": +def get_lmdb_sizes_parser(): parser = argparse.ArgumentParser() parser.add_argument( "--data-path", @@ -77,5 +76,10 @@ def main(args) -> None: type=int, help="Num of workers to parallelize across", ) + return parser + + +if __name__ == "__main__": + parser = get_lmdb_sizes_parser() args: argparse.Namespace = parser.parse_args() - main(args) + make_lmdb_sizes(args) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index dce5099452..1c0c975f8a 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -7,6 +7,7 @@ from __future__ import annotations +import copy import datetime import errno import logging @@ -38,6 +39,7 @@ save_checkpoint, update_config, ) +from fairchem.core.datasets.base_dataset import create_dataset from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss @@ -241,12 +243,16 @@ def load_logger(self) -> None: def get_sampler( self, dataset, batch_size: int, shuffle: bool ) -> BalancedBatchSampler: - if "load_balancing" in self.config["optim"]: - balancing_mode = self.config["optim"]["load_balancing"] - force_balancing = True + balancing_mode = self.config["optim"].get("load_balancing", None) + on_error = self.config["optim"].get("load_balancing_on_error", None) + if balancing_mode is not None: + if on_error is None: + on_error = "raise" else: balancing_mode = "atoms" - force_balancing = False + + if on_error is None: + on_error = "warn_and_no_balance" if gp_utils.initialized(): num_replicas = gp_utils.get_dp_world_size() @@ -262,7 +268,7 @@ def get_sampler( device=self.device, mode=balancing_mode, shuffle=shuffle, - force_balancing=force_balancing, + on_error=on_error, seed=self.config["cmd"]["seed"], ) @@ -283,15 +289,26 @@ def load_datasets(self) -> None: self.val_loader = None self.test_loader = None + # This is hacky and scheduled to be removed next BE week + # move ['X_split_settings'] to ['splits'][X] + def convert_settings_to_split_settings(config, split_name): + config = copy.deepcopy(config) # make sure we dont modify the original + if f"{split_name}_split_settings" in config: + config["splits"] = { + split_name: 
config.pop(f"{split_name}_split_settings") + } + return config + # load train, val, test datasets if "src" in self.config["dataset"]: logging.info( f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}" ) - self.train_dataset = registry.get_dataset_class( - self.config["dataset"].get("format", "lmdb") - )(self.config["dataset"]) + self.train_dataset = create_dataset( + convert_settings_to_split_settings(self.config["dataset"], "train"), + "train", + ) self.train_sampler = self.get_sampler( self.train_dataset, self.config["optim"]["batch_size"], @@ -302,6 +319,16 @@ def load_datasets(self) -> None: self.train_sampler, ) + if ( + "first_n" in self.config["dataset"] + or "sample_n" in self.config["dataset"] + or "max_atom" in self.config["dataset"] + ): + logging.warn( + "Dataset attributes (first_n/sample_n/max_atom) passed to all datasets! Please don't do this, its dangerous!\n" + + "Add them under each dataset 'train_split_settings'/'val_split_settings'/'test_split_settings'" + ) + if "src" in self.config["val_dataset"]: if self.config["val_dataset"].get("use_train_settings", True): val_config = self.config["dataset"].copy() @@ -309,9 +336,9 @@ def load_datasets(self) -> None: else: val_config = self.config["val_dataset"] - self.val_dataset = registry.get_dataset_class( - val_config.get("format", "lmdb") - )(val_config) + self.val_dataset = create_dataset( + convert_settings_to_split_settings(val_config, "val"), "val" + ) self.val_sampler = self.get_sampler( self.val_dataset, self.config["optim"].get( @@ -331,9 +358,9 @@ def load_datasets(self) -> None: else: test_config = self.config["test_dataset"] - self.test_dataset = registry.get_dataset_class( - test_config.get("format", "lmdb") - )(test_config) + self.test_dataset = create_dataset( + convert_settings_to_split_settings(test_config, "test"), "test" + ) self.test_sampler = self.get_sampler( self.test_dataset, self.config["optim"].get( @@ -398,15 +425,15 @@ def load_task(self): ][target_name].get("level", "system") if "train_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["train_on_free_atoms"] = ( - self.config[ - "outputs" - ][target_name].get("train_on_free_atoms", True) + self.config["outputs"][target_name].get( + "train_on_free_atoms", True + ) ) if "eval_on_free_atoms" not in self.output_targets[subtarget]: self.output_targets[subtarget]["eval_on_free_atoms"] = ( - self.config[ - "outputs" - ][target_name].get("eval_on_free_atoms", True) + self.config["outputs"][target_name].get( + "eval_on_free_atoms", True + ) ) # TODO: Assert that all targets, loss fn, metrics defined are consistent @@ -429,11 +456,13 @@ def load_model(self) -> None: loader = self.train_loader or self.val_loader or self.test_loader self.model = registry.get_model_class(self.config["model"])( - loader.dataset[0].x.shape[-1] - if loader - and hasattr(loader.dataset[0], "x") - and loader.dataset[0].x is not None - else None, + ( + loader.dataset[0].x.shape[-1] + if loader + and hasattr(loader.dataset[0], "x") + and loader.dataset[0].x is not None + else None + ), bond_feat_dim, 1, **self.config["model_attributes"], @@ -455,7 +484,9 @@ def load_model(self) -> None: self.logger.log_summary({"num_params": self.model.num_params}) if distutils.initialized() and not self.config["noddp"]: - self.model = DistributedDataParallel(self.model, device_ids=[self.device]) + self.model = DistributedDataParallel( + self.model, device_ids=None if self.cpu else [self.device] + ) @property def _unwrapped_model(self): @@ -639,9 +670,11 
@@ def save( "step": self.step, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), - "scheduler": self.scheduler.scheduler.state_dict() - if self.scheduler.scheduler_type != "Null" - else None, + "scheduler": ( + self.scheduler.scheduler.state_dict() + if self.scheduler.scheduler_type != "Null" + else None + ), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 9055d2d625..72c005893d 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -227,12 +227,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None: if checkpoint_every == -1: self.save(checkpoint_file="checkpoint.pt", training_state=True) - self.train_dataset.close_db() - if self.config.get("val_dataset", False): - self.val_dataset.close_db() - if self.config.get("test_dataset", False): - self.test_dataset.close_db() - def _forward(self, batch): out = self.model(batch.to(self.device)) @@ -648,7 +642,9 @@ def run_relaxations(self, split="val"): ) gather_results["chunk_idx"] = np.cumsum( [gather_results["chunk_idx"][i] for i in idx] - )[:-1] # np.split does not need last idx, assumes n-1:end + )[ + :-1 + ] # np.split does not need last idx, assumes n-1:end full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz" diff --git a/tests/core/common/test_data_parallel_batch_sampler.py b/tests/core/common/test_data_parallel_batch_sampler.py index 6205042652..6bd8effe26 100644 --- a/tests/core/common/test_data_parallel_batch_sampler.py +++ b/tests/core/common/test_data_parallel_batch_sampler.py @@ -1,9 +1,16 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
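Back in `load_datasets`, the temporary `convert_settings_to_split_settings` helper shown earlier moves a flat `train_split_settings` / `val_split_settings` / `test_split_settings` block under the `splits` key that `create_dataset` expects. A self-contained sketch of that translation, mirroring the helper from this patch with an illustrative config:

```python
import copy


def convert_settings_to_split_settings(config: dict, split_name: str) -> dict:
    """Move config['<split>_split_settings'] under config['splits'][<split>]."""
    config = copy.deepcopy(config)  # don't mutate the caller's config
    if f"{split_name}_split_settings" in config:
        config["splits"] = {split_name: config.pop(f"{split_name}_split_settings")}
    return config


# Illustrative trainer-style dataset config, not from a real run.
dataset_config = {
    "format": "lmdb",
    "src": "/path/to/train/lmdbs",
    "train_split_settings": {"sample_n": 10000, "max_atoms": 100},
}

print(convert_settings_to_split_settings(dataset_config, "train"))
# -> {'format': 'lmdb', 'src': '/path/to/train/lmdbs',
#     'splits': {'train': {'sample_n': 10000, 'max_atoms': 100}}}
```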
+""" + from __future__ import annotations -import functools -import tempfile from contextlib import contextmanager from pathlib import Path +import functools +import tempfile from typing import TypeVar import numpy as np @@ -13,11 +20,13 @@ from fairchem.core.common.data_parallel import ( BalancedBatchSampler, StatefulDistributedSampler, + UnsupportedDatasetError, + _balanced_partition, ) +from fairchem.core.datasets.base_dataset import BaseDataset, DatasetMetadata DATA = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -SIZE_ATOMS = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] -SIZE_NEIGHBORS = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4] +SIZE_ATOMS = [2, 20, 3, 51, 10, 11, 41, 31, 13, 14] T_co = TypeVar("T_co", covariant=True) @@ -28,23 +37,57 @@ def _temp_file(name: str): yield Path(tmpdir) / name +@pytest.fixture() +def valid_dataset(): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return DatasetMetadata(natoms=np.array(SIZE_ATOMS)) + + def __init__(self, data) -> None: + super().__init__(config={}) + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def get_metadata(self, attr, idx): + assert attr == "natoms" + metadata_attr = getattr(self._metadata, attr) + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + + return _Dataset(DATA) + + @pytest.fixture() def valid_path_dataset(): - class _Dataset(Dataset[T_co]): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return self.metadata + def __init__(self, data, fpath: Path) -> None: + super().__init__(config={}) self.data = data - self.metadata_path = fpath + self.metadata = DatasetMetadata(natoms=np.load(fpath)["natoms"]) def __len__(self): return len(self.data) def __getitem__(self, idx): - return self.data[idx] + metadata_attr = getattr(self._metadata, "natoms") + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] with _temp_file("metadata.npz") as file: np.savez( natoms=np.array(SIZE_ATOMS), - neighbors=np.array(SIZE_NEIGHBORS), file=file, ) yield _Dataset(DATA, file) @@ -52,8 +95,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_path_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data self.metadata_path = Path("/tmp/does/not/exist.np") @@ -68,8 +113,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data def __len__(self): @@ -81,99 +128,68 @@ def __getitem__(self, idx): return _Dataset(DATA) -def test_lowercase(invalid_dataset) -> None: - sampler = BalancedBatchSampler( - dataset=invalid_dataset, +def test_lowercase(valid_dataset) -> None: + _ = BalancedBatchSampler( + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="ATOMS", - throw_on_error=False, - seed=0 - ) - assert sampler.mode == "atoms" - - sampler = BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="NEIGHBORS", - throw_on_error=False, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.mode == "neighbors" def test_invalid_mode(invalid_dataset) -> None: with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." 
+ ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='natoms'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="natoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." + ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='neighbors'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, - mode="nneighbors", - throw_on_error=True, - seed=0 + mode="neighbors", + on_error="raise", + seed=0, ) def test_invalid_dataset(invalid_dataset) -> None: - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", - ): - BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 - ) - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. Batches will not be balanced, which can incur significant overhead!", - ): - BalancedBatchSampler( + with pytest.raises(UnsupportedDatasetError): + sampler = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) def test_invalid_path_dataset(invalid_path_dataset) -> None: with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -182,13 +198,11 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. 
Batches will not be balanced, which can incur significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -197,70 +211,59 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) -def test_valid_dataset(valid_path_dataset) -> None: +def test_valid_dataset(valid_dataset, valid_path_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - seed=0 - ) - assert (sampler.sizes == np.array(SIZE_ATOMS)).all() - - sampler = BalancedBatchSampler( - dataset=valid_path_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="neighbors", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert (sampler.sizes == np.array(SIZE_NEIGHBORS)).all() + assert ( + sampler._get_natoms(list(range(len(SIZE_ATOMS)))) == np.array(SIZE_ATOMS) + ).all() -def test_disabled(valid_path_dataset) -> None: +def test_disabled(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode=False, - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_single_node(valid_path_dataset) -> None: +def test_single_node(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=1, device=None, mode="atoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: +def test_stateful_distributed_sampler_noshuffle(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -272,12 +275,12 @@ def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: def test_stateful_distributed_sampler_vs_distributed_sampler( - valid_path_dataset, + valid_dataset, ) -> None: for seed in [0, 100, 200]: for batch_size in range(1, 4): stateful_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=2, @@ -286,7 +289,7 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( drop_last=True, ) sampler = DistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, rank=0, num_replicas=2, seed=seed, @@ -296,10 +299,10 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( assert list(stateful_sampler) == list(sampler) -def test_stateful_distributed_sampler(valid_path_dataset) -> None: +def test_stateful_distributed_sampler(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -309,7 +312,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: offset_step = 2 loaded_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, seed=0, @@ 
-319,7 +322,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(loaded_sampler) == original_order[offset_step * batch_size :] diff_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -328,14 +331,14 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(diff_sampler) != original_order -def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: - fullset = set(range(len(valid_path_dataset))) +def test_stateful_distributed_sampler_numreplicas(valid_dataset) -> None: + fullset = set(range(len(valid_dataset))) for drop_last in [True, False]: for num_replicas in range(1, 4): for batch_size in [1]: samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -360,14 +363,14 @@ def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: def test_stateful_distributed_sampler_numreplicas_drop_last( - valid_path_dataset, + valid_dataset, ) -> None: - fullset = set(range(len(valid_path_dataset))) + fullset = set(range(len(valid_dataset))) for num_replicas in range(1, 4): for batch_size in range(1, 4): samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -387,3 +390,15 @@ def test_stateful_distributed_sampler_numreplicas_drop_last( ) assert len(concat_idxs) == len(np.unique(concat_idxs)) assert len(concat_idxs) == (len(fullset) // num_replicas) * num_replicas + + +def test_balancedbatchsampler_partition(valid_dataset) -> None: + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS), 4) + == [[1, 9, 5, 0], [7, 8, 2], [3], [6, 4]] + ) + # test case with local batch size = 1, GPU0(rank0) always gets smallest + # we cant say anything about the remaining elements because it is a heap + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS)[[3, 6, 7, 1]], 4)[0] == [3] + ) diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py index 9743d35a2f..05c7475d2c 100644 --- a/tests/core/common/test_gp_utils.py +++ b/tests/core/common/test_gp_utils.py @@ -7,42 +7,112 @@ gather_from_model_parallel_region, scatter_to_model_parallel_region, ) -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) def _dummy_call(x): return x -@pytest.mark.parametrize("world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])]) # noqa: PT006 + +@pytest.mark.parametrize( + "world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])] +) # noqa: PT006 def test_basic_setup(world_size: int, input: torch.Tensor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True) - output = spawn_multi_process(config, _dummy_call, input) + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True + ) + output = spawn_multi_process( + config, _dummy_call, init_pg_and_rank_and_launch_test, input + ) assert output == expected_output -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1]), torch.Tensor([2,3])]), - (2, 2, torch.Tensor([0,1,2]), 
[torch.Tensor([0,1]), torch.Tensor([2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1]), torch.Tensor([2, 3])], + ), + (2, 2, torch.Tensor([0, 1, 2]), [torch.Tensor([0, 1]), torch.Tensor([2])]), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])], + ), + ], ) -def test_scatter_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_to_model_parallel_region, input) +def test_scatter_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, + scatter_to_model_parallel_region, + init_pg_and_rank_and_launch_test, + input, + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) + def scatter_gather_fn(input: torch.Tensor, dim: int = 0): x = scatter_to_model_parallel_region(input, dim) return gather_from_model_parallel_region(x, dim) -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), torch.Tensor([0,1,2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), torch.Tensor([0,1,2]), torch.Tensor([0,1,2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ], ) -def test_gather_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_gather_fn, input) +def test_gather_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, scatter_gather_fn, init_pg_and_rank_and_launch_test, input + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) diff --git a/tests/core/datasets/conftest.py b/tests/core/datasets/conftest.py new file mode 100644 index 0000000000..eb7be94994 --- /dev/null +++ b/tests/core/datasets/conftest.py @@ -0,0 +1,28 @@ +import numpy as np +import pytest +from ase import build +from ase.calculators.singlepoint import SinglePointCalculator + + +@pytest.fixture(scope="module") +def structures(): + structures = [ + 
build.molecule("H2O", vacuum=4), + build.bulk("Cu"), + build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), + ] + for atoms in structures: + calc = SinglePointCalculator( + atoms, + energy=1, + forces=atoms.positions, + # there is an issue with ASE db when writing a db with 3x3 stress if is flattened to (9,) and then + # errors when trying to read it + stress=np.random.random((6,)), + ) + atoms.calc = calc + atoms.info["extensive_property"] = 3 * len(atoms) + atoms.info["tensor_property"] = np.random.random((6, 6)) + + structures[2].set_pbc(True) + return structures diff --git a/tests/core/datasets/test_ase_datasets.py b/tests/core/datasets/test_ase_datasets.py index 01bd4ea2fc..7b114d877f 100644 --- a/tests/core/datasets/test_ase_datasets.py +++ b/tests/core/datasets/test_ase_datasets.py @@ -15,26 +15,6 @@ ) from fairchem.core.datasets.lmdb_database import LMDBDatabase -structures = [ - build.molecule("H2O", vacuum=4), - build.bulk("Cu"), - build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), -] -for atoms in structures: - calc = SinglePointCalculator( - atoms, - energy=1, - forces=atoms.positions, - # there is an issue with ASE db when writing a db with 3x3 stress it is flattened to (9,) and then - # errors when trying to read it - stress=np.random.random((6,)), - ) - atoms.calc = calc - atoms.info["extensive_property"] = 3 * len(atoms) - atoms.info["tensor_property"] = np.random.random((6, 6)) - -structures[2].set_pbc(True) - @pytest.fixture( params=[ @@ -46,7 +26,7 @@ "aselmdb_dataset", ], ) -def ase_dataset(request, tmp_path_factory): +def ase_dataset(request, structures, tmp_path_factory): tmp_path = tmp_path_factory.mktemp("dataset") mult = 1 a2g_args = { @@ -110,7 +90,7 @@ def ase_dataset(request, tmp_path_factory): return dataset, mult -def test_ase_dataset(ase_dataset): +def test_ase_dataset(ase_dataset, structures): dataset, mult = ase_dataset assert len(dataset) == mult * len(structures) for data in dataset: @@ -121,7 +101,7 @@ def test_ase_dataset(ase_dataset): assert isinstance(data.extensive_property, int) -def test_ase_read_dataset(tmp_path) -> None: +def test_ase_read_dataset(tmp_path, structures): # unfortunately there is currently no clean (already implemented) way to save atoms.info when saving # individual structures - so test separately for i, structure in enumerate(structures): @@ -137,13 +117,16 @@ def test_ase_read_dataset(tmp_path) -> None: assert len(dataset) == len(structures) data = dataset[0] del data - dataset.close_db() -def test_ase_metadata_guesser(ase_dataset) -> None: +def test_ase_get_metadata(ase_dataset): + assert ase_dataset[0].get_metadata("natoms", [0])[0] == 3 + + +def test_ase_metadata_guesser(ase_dataset): dataset, _ = ase_dataset - metadata = dataset.get_metadata() + metadata = dataset.sample_property_metadata() # Confirm energy metadata guessed properly assert metadata["targets"]["energy"]["extensive"] is False @@ -171,7 +154,7 @@ def test_ase_metadata_guesser(ase_dataset) -> None: assert metadata["targets"]["info.tensor_property"]["type"] == "per-image" -def test_db_add_delete(tmp_path) -> None: +def test_db_add_delete(tmp_path, structures): database = db.connect(tmp_path / "asedb.db") for _i, atoms in enumerate(structures): database.write(atoms, data=atoms.info) @@ -192,10 +175,9 @@ def test_db_add_delete(tmp_path) -> None: dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) assert len(dataset) == orig_len + len(new_structures) - 1 - dataset.close_db() -def test_ase_multiread_dataset(tmp_path) -> None: +def 
test_ase_multiread_dataset(tmp_path): atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)] energies = np.linspace(1, 0, len(atoms_objects)) @@ -224,13 +206,17 @@ def test_ase_multiread_dataset(tmp_path) -> None: f.write(f"{tmp_path / 'test.traj'} {len(atoms_objects)}") dataset = AseReadMultiStructureDataset( - config={"index_file": str(tmp_path / "test_index_file")}, + config={ + "src": str(tmp_path), + "index_file": str(tmp_path / "test_index_file"), + }, ) assert len(dataset) == len(atoms_objects) dataset = AseReadMultiStructureDataset( config={ + "src": str(tmp_path), "index_file": str(tmp_path / "test_index_file"), "a2g_args": { "r_energy": True, diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py new file mode 100644 index 0000000000..d90271c53d --- /dev/null +++ b/tests/core/datasets/test_create_dataset.py @@ -0,0 +1,180 @@ +import os +import numpy as np +import pytest + +from fairchem.core.datasets import LMDBDatabase, create_dataset +from fairchem.core.datasets.base_dataset import BaseDataset +import tempfile +from fairchem.core.trainers.base_trainer import BaseTrainer + + +@pytest.fixture() +def lmdb_database(structures): + with tempfile.TemporaryDirectory() as tmpdirname: + num_atoms = [] + asedb_fn = f"{tmpdirname}/asedb.lmdb" + with LMDBDatabase(asedb_fn) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + num_atoms.append(len(atoms)) + np.savez(f"{tmpdirname}/metadata.npz", natoms=num_atoms) + yield asedb_fn + + +def test_real_dataset_config(lmdb_database): + class TestTrainer(BaseTrainer): + def __init__(self, config): + self.config = config + + def train(self, x): + return None + + def get_sampler(self, *args, **kwargs): + return None + + def get_dataloader(self, *args, **kwargs): + return None + + config = { + "model_attributes": {}, + "optim": {"batch_size": 0}, + "dataset": { + "format": "ase_db", + "src": str(lmdb_database), + "first_n": 2, + "key_mapping": { + "y": "energy", + "force": "forces", + }, + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0, "stdev": 2.887317180633545}, + } + }, + }, + "val_dataset": {"src": str(lmdb_database)}, + "test_dataset": {}, + "relax_dataset": None, + } + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 2 + + # modify the config for split and see if it works as expected + config["dataset"].pop("first_n") + config["dataset"]["train_split_settings"] = {"first_n": 2} + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 3 + + +@pytest.mark.parametrize("max_atoms", [3, None]) +@pytest.mark.parametrize( + "key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)] +) +def test_create_dataset(key, value, max_atoms, structures, lmdb_database): + # now create a config + config = { + "format": "ase_db", + "src": str(lmdb_database), + key: value, + "max_atoms": max_atoms, + } + + dataset = create_dataset(config, split="train") + if max_atoms is not None: + structures = [s for s in structures if len(s) <= max_atoms] + assert all( + natoms <= max_atoms + for natoms in dataset.metadata.natoms[range(len(dataset))] + ) + if key == "first_n": # this assumes first_n are not shuffled + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures[:value], dataset) + ) + assert all( + 
np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures[:value], dataset) + ) + elif key == "sample_n": + assert len(dataset) == value + else: # no shuffle all of them are in there + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures, dataset) + ) + assert all( + np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures, dataset) + ) + + +# make sure we cant sample more than the number of elements in the dataset with sample_n +def test_sample_n_dataset(lmdb_database): + with pytest.raises(ValueError): + _ = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 100, + }, + split="train", + ) + + +def test_diff_seed_sample_dataset(lmdb_database): + dataset_a = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + assert (dataset_a.indices == dataset_b.indices).all() + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 1, + }, + split="train", + ) + assert not (dataset_a.indices == dataset_b.indices).all() + + +def test_del_dataset(): + class _Dataset(BaseDataset): + def __init__(self, fn) -> None: + super().__init__(config={}) + self.fn = fn + open(self.fn, "a").close() + + def __del__(self): + os.remove(self.fn) + + with tempfile.TemporaryDirectory() as tmpdirname: + fn = tmpdirname + "/test" + d = _Dataset(fn) + del d + assert not os.path.exists(fn) diff --git a/tests/core/datasets/test_lmdb_dataset.py b/tests/core/datasets/test_lmdb_dataset.py new file mode 100644 index 0000000000..f922e32ce3 --- /dev/null +++ b/tests/core/datasets/test_lmdb_dataset.py @@ -0,0 +1,29 @@ +from fairchem.core.datasets.base_dataset import create_dataset + +import numpy as np + +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes + + +def test_load_lmdb_dataset(tutorial_dataset_path): + + lmdb_path = str(tutorial_dataset_path / "s2ef/val_20") + + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args(["--data-path", lmdb_path]) + make_lmdb_sizes(args) + + config = { + "format": "lmdb", + "src": lmdb_path, + } + + dataset = create_dataset(config, split="val") + + assert dataset.get_metadata("natoms", 0) == dataset[0].natoms + + all_metadata_natoms = np.array(dataset.get_metadata("natoms", range(len(dataset)))) + all_natoms = np.array([datapoint.natoms for datapoint in dataset]) + + assert (all_natoms == all_metadata_natoms).all() diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 8387d6e053..aea07201bd 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -9,6 +9,12 @@ import numpy as np import pytest import yaml +from fairchem.core.common.test_utils import ( + PGConfig, + init_env_rank_and_launch_test, + spawn_multi_process, +) +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes from tensorboard.backend.event_processing.event_accumulator import EventAccumulator from fairchem.core._cli import Runner @@ -84,6 +90,7 @@ def _run_main( update_run_args_with=None, save_checkpoint_to=None, save_predictions_to=None, + world_size=0, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" @@ -91,6 +98,7 @@ def _run_main( yaml_config = 
yaml.safe_load(yaml_file) if update_dict_with is not None: yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" with open(str(config_yaml), "w") as yaml_file: yaml.dump(yaml_config, yaml_file) @@ -110,7 +118,19 @@ def _run_main( for arg_name, arg_value in run_args.items(): setattr(args, arg_name, arg_value) config = build_config(args, override_args) - Runner()(config) + + if world_size > 0: + pg_config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + spawn_multi_process( + pg_config, + Runner(distributed=True), + init_env_rank_and_launch_test, + config, + ) + else: + Runner()(config) if save_checkpoint_to is not None: checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") @@ -213,6 +233,72 @@ def test_train_and_predict( tutorial_val_src=tutorial_val_src, ) + @pytest.mark.parametrize( + ("world_size", "ddp"), + [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_ddp(self, world_size, ddp, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) + + @pytest.mark.parametrize( + ("world_size", "ddp"), + [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_balanced_batch_sampler_ddp( + self, world_size, ddp, configs, tutorial_val_src, torch_deterministic + ): + + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args( + ["--data-path", str(tutorial_val_src)] + ) + make_lmdb_sizes(args) + + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1, "load_balancing": "atoms"}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) + # train for a few steps and confirm same seeds get same results def test_different_seeds( self, @@ -290,9 +376,9 @@ class TestSmallDatasetOptim: @pytest.mark.parametrize( ("model_name", "expected_energy_mae", "expected_force_mae"), [ - pytest.param("gemnet", 0.4, 0.06, id="gemnet"), - pytest.param("escn", 0.4, 0.06, id="escn"), - pytest.param("equiformer_v2", 0.4, 0.06, id="equiformer_v2"), + pytest.param("gemnet", 0.41, 0.06, id="gemnet"), + pytest.param("escn", 0.41, 0.06, id="escn"), + pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"), ], ) def test_train_optimization( diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 34ed79ba2b..0034232cd2 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -18,7 +18,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel from fairchem.core.common.registry import registry -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process 
+from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) from fairchem.core.common.utils import load_state_dict, setup_imports from fairchem.core.datasets import data_list_collater from fairchem.core.models.equiformer_v2.so3 import ( @@ -140,7 +144,9 @@ def test_energy_force_shape(self, snapshot): def test_ddp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=1, gp_group_size=1, use_gp=False) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 1 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape @@ -151,7 +157,9 @@ def test_ddp(self, snapshot): def test_gp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 2 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape @@ -225,4 +233,3 @@ def sign(x): embedding._l_primary(c) lp = embedding.embedding.clone() (test_matrix_lp == lp).all() - From 08b8c1ea9f1858d7f8f14df1718f26997c1ca799 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 2 Aug 2024 13:50:24 -0700 Subject: [PATCH 2/2] Move select models to backbone + heads format and add support for hydra (#782) * convert escn to bb + heads * convert dimenet to bb + heads * gemnet_oc to backbone and heads * add additional parameter backbone config to heads * gemnet to bb and heads * pain to bb and heads * add eqv2 bb+heads; move to canonical naming * fix calculator loading by leaving original class in code * fix issues with calculator loading * lint fixes * move dimenet++ heads to one * add test for dimenet * add painn test * hydra and tests for gemnetH dppH painnH * add escnH and equiformerv2H * add gemnetdt gemnetdtH * add smoke test for schnet and scn * remove old examples * typo * fix gemnet with grad forces; add test for this * remove unused params; add backbone and head interface; add typing * remove unused second order output heads * remove OC20 suffix from equiformer * remove comment * rename and lint * fix dimenet test * fix tests * refactor generate graph * refactor generate graph * fix a messy cherry pick * final messy fix * graph data interface in eqv2 * refactor * no bbconfigs * no more headconfigs in inits * rename hydra * fix eqV2 * update test configs * final fixes * fix tutorial * rm comments * fix test --------- Co-authored-by: lbluque Co-authored-by: Luis Barroso-Luque --- docs/legacy_tutorials/OCP_Tutorial.md | 2 +- src/fairchem/core/models/base.py | 137 +++++++- src/fairchem/core/models/dimenet_plus_plus.py | 147 +++++++-- .../core/models/equiformer_v2/__init__.py | 2 +- ...equiformer_v2_oc20.py => equiformer_v2.py} | 297 +++++++++++++++--- src/fairchem/core/models/escn/escn.py | 177 +++++++++-- src/fairchem/core/models/gemnet/gemnet.py | 195 +++++++++--- src/fairchem/core/models/gemnet_gp/gemnet.py | 64 ++-- .../core/models/gemnet_oc/gemnet_oc.py | 287 ++++++++++++++--- src/fairchem/core/models/painn/painn.py | 130 ++++++-- src/fairchem/core/models/schnet.py | 26 +- src/fairchem/core/models/scn/scn.py | 31 +- src/fairchem/core/trainers/base_trainer.py | 14 - tests/core/e2e/test_s2ef.py | 46 ++- 
tests/core/models/test_configs/test_dpp.yml | 50 +++ .../models/test_configs/test_dpp_hydra.yml | 55 ++++ .../test_configs/test_equiformerv2_hydra.yml | 98 ++++++ .../models/test_configs/test_escn_hydra.yml | 67 ++++ .../models/test_configs/test_gemnet_dt.yml | 79 +++++ .../test_configs/test_gemnet_dt_hydra.yml | 86 +++++ .../test_gemnet_dt_hydra_grad.yml | 84 +++++ .../{test_gemnet.yml => test_gemnet_oc.yml} | 0 .../test_configs/test_gemnet_oc_hydra.yml | 112 +++++++ .../test_gemnet_oc_hydra_grad.yml | 109 +++++++ tests/core/models/test_configs/test_painn.yml | 50 +++ .../models/test_configs/test_painn_hydra.yml | 58 ++++ .../core/models/test_configs/test_schnet.yml | 45 +++ tests/core/models/test_configs/test_scn.yml | 59 ++++ tests/core/models/test_dimenetpp.py | 3 - tests/core/models/test_equiformer_v2.py | 3 - tests/core/models/test_gemnet.py | 3 - tests/core/models/test_gemnet_oc.py | 3 - .../models/test_gemnet_oc_scaling_mismatch.py | 12 - tests/core/models/test_schnet.py | 2 +- 34 files changed, 2182 insertions(+), 351 deletions(-) rename src/fairchem/core/models/equiformer_v2/{equiformer_v2_oc20.py => equiformer_v2.py} (72%) create mode 100755 tests/core/models/test_configs/test_dpp.yml create mode 100755 tests/core/models/test_configs/test_dpp_hydra.yml create mode 100644 tests/core/models/test_configs/test_equiformerv2_hydra.yml create mode 100644 tests/core/models/test_configs/test_escn_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml rename tests/core/models/test_configs/{test_gemnet.yml => test_gemnet_oc.yml} (100%) create mode 100644 tests/core/models/test_configs/test_gemnet_oc_hydra.yml create mode 100644 tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml create mode 100644 tests/core/models/test_configs/test_painn.yml create mode 100644 tests/core/models/test_configs/test_painn_hydra.yml create mode 100755 tests/core/models/test_configs/test_schnet.yml create mode 100755 tests/core/models/test_configs/test_scn.yml diff --git a/docs/legacy_tutorials/OCP_Tutorial.md b/docs/legacy_tutorials/OCP_Tutorial.md index 8b5d4d522a..19fd93f6bc 100644 --- a/docs/legacy_tutorials/OCP_Tutorial.md +++ b/docs/legacy_tutorials/OCP_Tutorial.md @@ -1807,7 +1807,7 @@ Similarly, to predict forces, we pass edge features through a fully-connected la @registry.register_model("simple") class SimpleAtomEdgeModel(torch.nn.Module): - def __init__(self, num_atoms, bond_feat_dim, num_targets, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): + def __init__(self, emb_size=64, num_radial=64, cutoff=6.0, env_exponent=5): super().__init__() self.radial_basis = RadialBasis( diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 42790643a9..eb8c9d543c 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -8,27 +8,42 @@ from __future__ import annotations import logging +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING import torch -import torch.nn as nn +from torch import nn from torch_geometric.nn import radius_graph +from fairchem.core.common.registry import registry from fairchem.core.common.utils import ( compute_neighbors, get_pbc_distances, radius_graph_pbc, ) +if TYPE_CHECKING: + from torch_geometric.data import Batch -class BaseModel(nn.Module): - def __init__(self, 
num_atoms=None, bond_feat_dim=None, num_targets=None) -> None: - super().__init__() - self.num_atoms = num_atoms - self.bond_feat_dim = bond_feat_dim - self.num_targets = num_targets - def forward(self, data): - raise NotImplementedError +@dataclass +class GraphData: + """Class to keep graph attributes nicely packaged.""" + + edge_index: torch.Tensor + edge_distance: torch.Tensor + edge_distance_vec: torch.Tensor + cell_offsets: torch.Tensor + offset_distances: torch.Tensor + neighbors: torch.Tensor + batch_full: torch.Tensor # used for GP functionality + atomic_numbers_full: torch.Tensor # used for GP functionality + node_offset: int = 0 # used for GP functionality + + +class GraphModelMixin: + """Mixin Model class implementing some general convenience properties and methods.""" def generate_graph( self, @@ -109,13 +124,16 @@ def generate_graph( ) neighbors = compute_neighbors(data, edge_index) - return ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - cell_offset_distances, - neighbors, + return GraphData( + edge_index=edge_index, + edge_distance=edge_dist, + edge_distance_vec=distance_vec, + cell_offsets=cell_offsets, + offset_distances=cell_offset_distances, + neighbors=neighbors, + node_offset=0, + batch_full=data.batch, + atomic_numbers_full=data.atomic_numbers.long(), ) @property @@ -130,3 +148,90 @@ def no_weight_decay(self) -> list: if "embedding" in name or "frequencies" in name or "bias" in name: no_wd_list.append(name) return no_wd_list + + +class HeadInterface(metaclass=ABCMeta): + @abstractmethod + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Head forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + emb: dict[str->torch.Tensor] + Embeddings of the input as generated by the backbone + + Returns + ------- + outputs: dict[str->torch.Tensor] + Return one or more targets generated by this head + """ + return + + +class BackboneInterface(metaclass=ABCMeta): + @abstractmethod + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + """Backbone forward. + + Arguments + --------- + data: DataBatch + Atomic systems as input + + Returns + ------- + embedding: dict[str->torch.Tensor] + Return backbone embeddings for the given input + """ + return + + +@registry.register_model("hydra") +class HydraModel(nn.Module, GraphModelMixin): + def __init__( + self, + backbone: dict, + heads: dict, + otf_graph: bool = True, + ): + super().__init__() + self.otf_graph = otf_graph + + backbone_model_name = backbone.pop("model") + self.backbone: BackboneInterface = registry.get_model_class( + backbone_model_name + )( + **backbone, + ) + + # Iterate through outputs_cfg and create heads + self.output_heads: dict[str, HeadInterface] = {} + + head_names_sorted = sorted(heads.keys()) + for head_name in head_names_sorted: + head_config = heads[head_name] + if "module" not in head_config: + raise ValueError( + f"{head_name} head does not specify module to use for the head" + ) + + module_name = head_config.pop("module") + self.output_heads[head_name] = registry.get_model_class(module_name)( + self.backbone, + **head_config, + ) + + self.output_heads = torch.nn.ModuleDict(self.output_heads) + + def forward(self, data: Batch): + emb = self.backbone(data) + # Predict all output properties for all structures in the batch for now. 
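A minimal usage sketch (editorial, not part of this patch) of the `hydra` model introduced here: it composes one registered backbone with one registered module per output head. The names `equiformer_v2_backbone`, `equiformer_v2_energy_head`, and `equiformer_v2_force_head` come from later hunks in this same patch; the keyword arguments shown are illustrative assumptions, not a definitive configuration.

```python
from fairchem.core.common.registry import registry

# Build the hydra wrapper: the backbone dict names the registered backbone model,
# and each head dict names the registered head module that consumes its embeddings.
hydra_cls = registry.get_model_class("hydra")
model = hydra_cls(
    backbone={"model": "equiformer_v2_backbone", "otf_graph": True},
    heads={
        "energy": {"module": "equiformer_v2_energy_head"},
        "forces": {"module": "equiformer_v2_force_head"},
    },
)

# model(batch) runs the backbone once and merges each head's output dict,
# e.g. {"energy": ..., "forces": ...} for a torch_geometric Batch of structures.
```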
+ out = {} + for k in self.output_heads: + out.update(self.output_heads[k](data, emb)) + + return out diff --git a/src/fairchem/core/models/dimenet_plus_plus.py b/src/fairchem/core/models/dimenet_plus_plus.py index 296a77bbba..aa08ea0672 100644 --- a/src/fairchem/core/models/dimenet_plus_plus.py +++ b/src/fairchem/core/models/dimenet_plus_plus.py @@ -34,6 +34,8 @@ from __future__ import annotations +import typing + import torch from torch import nn from torch_geometric.nn.inits import glorot_orthogonal @@ -49,7 +51,10 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch try: import sympy as sym @@ -57,7 +62,7 @@ sym = None -class InteractionPPBlock(torch.nn.Module): +class InteractionPPBlock(nn.Module): def __init__( self, hidden_channels: int, @@ -90,11 +95,11 @@ def __init__( self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) # Residual layers before and after skip connection. - self.layers_before_skip = torch.nn.ModuleList( + self.layers_before_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)] ) self.lin = nn.Linear(hidden_channels, hidden_channels) - self.layers_after_skip = torch.nn.ModuleList( + self.layers_after_skip = nn.ModuleList( [ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)] ) @@ -153,7 +158,7 @@ def forward(self, x, rbf, sbf, idx_kj, idx_ji): return h -class OutputPPBlock(torch.nn.Module): +class OutputPPBlock(nn.Module): def __init__( self, num_radial: int, @@ -169,7 +174,7 @@ def __init__( self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) - self.lins = torch.nn.ModuleList() + self.lins = nn.ModuleList() for _ in range(num_layers): self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) @@ -193,7 +198,7 @@ def forward(self, x, rbf, i, num_nodes: int | None = None): return self.lin(x) -class DimeNetPlusPlus(torch.nn.Module): +class DimeNetPlusPlus(nn.Module): r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet. 
Args: @@ -241,7 +246,6 @@ def __init__( act = activation_resolver(act) super().__init__() - self.cutoff = cutoff if sym is None: @@ -256,7 +260,7 @@ def __init__( self.emb = EmbeddingBlock(num_radial, hidden_channels, act) - self.output_blocks = torch.nn.ModuleList( + self.output_blocks = nn.ModuleList( [ OutputPPBlock( num_radial, @@ -270,7 +274,7 @@ def __init__( ] ) - self.interaction_blocks = torch.nn.ModuleList( + self.interaction_blocks = nn.ModuleList( [ InteractionPPBlock( hidden_channels, @@ -330,13 +334,42 @@ def forward(self, z, pos, batch=None): raise NotImplementedError +@registry.register_model("dimenetplusplus_energy_and_force_head") +class DimeNetPlusPlusWrapEnergyAndForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.regress_forces = backbone.regress_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + outputs = { + "energy": ( + emb["P"].sum(dim=0) + if data.batch is None + else scatter(emb["P"], data.batch, dim=0) + ) + } + if self.regress_forces: + outputs["forces"] = ( + -1 + * ( + torch.autograd.grad( + outputs["energy"], + data.pos, + grad_outputs=torch.ones_like(outputs["energy"]), + create_graph=True, + )[0] + ) + ) + return outputs + + @registry.register_model("dimenetplusplus") -class DimeNetPlusPlusWrap(DimeNetPlusPlus, BaseModel): +class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin): def __init__( self, - num_atoms: int, - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, regress_forces: bool = True, hidden_channels: int = 128, @@ -353,7 +386,6 @@ def __init__( num_after_skip: int = 2, num_output_layers: int = 3, ) -> None: - self.num_targets = num_targets self.regress_forces = regress_forces self.use_pbc = use_pbc self.cutoff = cutoff @@ -362,7 +394,7 @@ def __init__( super().__init__( hidden_channels=hidden_channels, - out_channels=num_targets, + out_channels=1, num_blocks=num_blocks, int_emb_size=int_emb_size, basis_emb_size=basis_emb_size, @@ -380,22 +412,15 @@ def __init__( def _forward(self, data): pos = data.pos batch = data.batch - ( - edge_index, - dist, - _, - cell_offsets, - offsets, - neighbors, - ) = self.generate_graph(data) - - data.edge_index = edge_index - data.cell_offsets = cell_offsets - data.neighbors = neighbors - j, i = edge_index + graph = self.generate_graph(data) + + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( - edge_index, + graph.edge_index, data.cell_offsets, num_nodes=data.atomic_numbers.size(0), ) @@ -405,8 +430,8 @@ def _forward(self, data): pos_j = pos[idx_j].detach() if self.use_pbc: pos_ji, pos_kj = ( - pos[idx_j].detach() - pos_i + offsets[idx_ji], - pos[idx_k].detach() - pos_j + offsets[idx_kj], + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], ) else: pos_ji, pos_kj = ( @@ -418,8 +443,8 @@ def _forward(self, data): b = torch.cross(pos_ji, pos_kj).norm(dim=-1) angle = torch.atan2(b, a) - rbf = self.rbf(dist) - sbf = self.sbf(dist, angle, idx_kj) + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) # Embedding block. 
x = self.emb(data.atomic_numbers.long(), rbf, i, j) @@ -459,3 +484,57 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("dimenetplusplus_backbone") +class DimeNetPlusPlusWrapBackbone(DimeNetPlusPlusWrap, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + if self.regress_forces: + data.pos.requires_grad_(True) + pos = data.pos + graph = self.generate_graph(data) + data.edge_index = graph.edge_index + data.cell_offsets = graph.cell_offsets + data.neighbors = graph.neighbors + j, i = graph.edge_index + + _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( + graph.edge_index, + data.cell_offsets, + num_nodes=data.atomic_numbers.size(0), + ) + + # Calculate angles. + pos_i = pos[idx_i].detach() + pos_j = pos[idx_j].detach() + if self.use_pbc: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i + graph.offset_distances[idx_ji], + pos[idx_k].detach() - pos_j + graph.offset_distances[idx_kj], + ) + else: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i, + pos[idx_k].detach() - pos_j, + ) + + a = (pos_ji * pos_kj).sum(dim=-1) + b = torch.cross(pos_ji, pos_kj).norm(dim=-1) + angle = torch.atan2(b, a) + + rbf = self.rbf(graph.edge_distance) + sbf = self.sbf(graph.edge_distance, angle, idx_kj) + + # Embedding block. + x = self.emb(data.atomic_numbers.long(), rbf, i, j) + P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) + + # Interaction blocks. + for interaction_block, output_block in zip( + self.interaction_blocks, self.output_blocks[1:] + ): + x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) + P += output_block(x, rbf, i, num_nodes=pos.size(0)) + + return {"P": P, "edge_embedding": x, "edge_idx": i} diff --git a/src/fairchem/core/models/equiformer_v2/__init__.py b/src/fairchem/core/models/equiformer_v2/__init__.py index 424b64f9ed..720f890f65 100644 --- a/src/fairchem/core/models/equiformer_v2/__init__.py +++ b/src/fairchem/core/models/equiformer_v2/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from .equiformer_v2_oc20 import EquiformerV2_OC20 as EquiformerV2 +from .equiformer_v2 import EquiformerV2 __all__ = ["EquiformerV2"] diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py similarity index 72% rename from src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py rename to src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 8edf81319c..e2625eadaf 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -10,13 +10,15 @@ from fairchem.core.common import gp_utils from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): pass +import typing + from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding @@ -42,13 +44,18 @@ TransBlockV2, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + + from fairchem.core.models.base import GraphData + # Statistics of IS2RE 100K _AVG_NUM_NODES = 77.81317 _AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, 
max_radius = 5, max_neighbors = 100 @registry.register_model("equiformer_v2") -class EquiformerV2_OC20(BaseModel): +class EquiformerV2(nn.Module, GraphModelMixin): """ Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation @@ -108,9 +115,6 @@ class EquiformerV2_OC20(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = True, @@ -436,23 +440,12 @@ def forward(self, data): self.dtype = data.pos.dtype self.device = data.pos.device atomic_numbers = data.atomic_numbers.long() - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, ) - data_batch_full = data.batch data_batch = data.batch - atomic_numbers_full = atomic_numbers - node_offset = 0 if gp_utils.initialized(): ( atomic_numbers, @@ -462,12 +455,17 @@ def forward(self, data): edge_distance, edge_distance_vec, ) = self._init_gp_partitions( - atomic_numbers_full, - data_batch_full, - edge_index, - edge_distance, - edge_distance_vec, + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + ############################################################### # Entering Graph Parallel Region # after this point, if using gp, then node, edge tensors are split @@ -485,7 +483,9 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -496,7 +496,6 @@ def forward(self, data): ############################################################### # Init per node representations using an atomic number based embedding - offset = 0 x = SO3_Embedding( len(atomic_numbers), self.lmax_list, @@ -519,27 +518,27 @@ def forward(self, data): offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) # Edge encoding (distance and atom edge) - edge_distance = self.distance_expansion(edge_distance) + graph.edge_distance = self.distance_expansion(graph.edge_distance) if self.share_atom_edge_embedding and self.use_atom_edge_embedding: - source_element = atomic_numbers_full[ - edge_index[0] + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] ] # Source atom atomic number - target_element = atomic_numbers_full[ - edge_index[1] + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] ] # Target atom atomic number source_embedding = self.source_embedding(source_element) target_embedding = self.target_embedding(target_element) - edge_distance = torch.cat( - (edge_distance, source_embedding, target_embedding), dim=1 + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 ) # Edge-degree embedding edge_degree = self.edge_degree_embedding( - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + 
graph.edge_index, len(atomic_numbers), - node_offset, + graph.node_offset, ) x.embedding = x.embedding + edge_degree.embedding @@ -550,11 +549,11 @@ def forward(self, data): for i in range(self.num_layers): x = self.blocks[i]( x, # SO3_Embedding - atomic_numbers_full, - edge_distance, - edge_index, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, batch=data_batch, # for GraphDropPath - node_offset=node_offset, + node_offset=graph.node_offset, ) # Final layer norm @@ -572,7 +571,7 @@ def forward(self, data): device=node_energy.device, dtype=node_energy.dtype, ) - energy.index_add_(0, data_batch_full, node_energy.view(-1)) + energy.index_add_(0, graph.batch_full, node_energy.view(-1)) energy = energy / self.avg_num_nodes # Add the per-atom linear references to the energy. @@ -594,8 +593,8 @@ def forward(self, data): with torch.cuda.amp.autocast(False): energy = energy.to(self.energy_lin_ref.dtype).index_add( 0, - data_batch_full, - self.energy_lin_ref[atomic_numbers_full], + graph.batch_full, + self.energy_lin_ref[graph.atomic_numbers_full], ) outputs = {"energy": energy} @@ -605,10 +604,10 @@ def forward(self, data): if self.regress_forces: forces = self.force_block( x, - atomic_numbers_full, - edge_distance, - edge_index, - node_offset=node_offset, + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + node_offset=graph.node_offset, ) forces = forces.embedding.narrow(1, 1, 3) forces = forces.view(-1, 3).contiguous() @@ -678,3 +677,209 @@ def no_weight_decay(self) -> set: no_wd_list.append(global_parameter_name) return set(no_wd_list) + + +@registry.register_model("equiformer_v2_backbone") +class EquiformerV2Backbone(EquiformerV2, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + self.device = data.pos.device + atomic_numbers = data.atomic_numbers.long() + graph = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) + + data_batch = data.batch + if gp_utils.initialized(): + ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) = self._init_gp_partitions( + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, + ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + + ############################################################### + # Entering Graph Parallel Region + # after this point, if using gp, then node, edge tensors are split + # across the graph parallel ranks, some full tensors such as + # atomic_numbers_full are required because we need to index into the + # full graph when computing edge embeddings or reducing nodes from neighbors + # + # all tensors that do not have the suffix "_full" refer to the partial tensors. 
+ # if not using gp, the full values are equal to the partial values + # ie: atomic_numbers_full == atomic_numbers + ############################################################### + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + for i in range(self.num_resolutions): + self.SO3_rotation[i].set_wigner(edge_rot_mat) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x = SO3_Embedding( + len(atomic_numbers), + self.lmax_list, + self.sphere_channels, + self.device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l = 0, m = 0 coefficients for each resolution + for i in range(self.num_resolutions): + if self.num_resolutions == 1: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) + else: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # Edge encoding (distance and atom edge) + graph.edge_distance = self.distance_expansion(graph.edge_distance) + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] + ] # Source atom atomic number + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] + ] # Target atom atomic number + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + len(atomic_numbers), + graph.node_offset, + ) + x.embedding = x.embedding + edge_degree.embedding + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) + + # Final layer norm + x.embedding = self.norm(x.embedding) + + return {"node_embedding": x, "graph": graph} + + +@registry.register_model("equiformer_v2_energy_head") +class EquiformerV2EnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.avg_num_nodes = backbone.avg_num_nodes + self.energy_block = FeedForwardNetwork( + backbone.sphere_channels, + backbone.ffn_hidden_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_grid, + backbone.ffn_activation, + backbone.use_gate_act, + backbone.use_grid_mlp, + backbone.use_sep_s2_act, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): + node_energy = self.energy_block(emb["node_embedding"]) + node_energy = node_energy.embedding.narrow(1, 0, 1) + if 
gp_utils.initialized(): + node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0) + energy = torch.zeros( + len(data.natoms), + device=node_energy.device, + dtype=node_energy.dtype, + ) + energy.index_add_(0, data.batch, node_energy.view(-1)) + return {"energy": energy / self.avg_num_nodes} + + +@registry.register_model("equiformer_v2_force_head") +class EquiformerV2ForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + + def forward(self, data: Batch, emb: dict[str, torch.Tensor]): + forces = self.force_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + forces = forces.embedding.narrow(1, 1, 3) + forces = forces.view(-1, 3).contiguous() + if gp_utils.initialized(): + forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) + return {"forces": forces} diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 0ec66b9dba..dfa872c398 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -10,13 +10,17 @@ import contextlib import logging import time +import typing import torch import torch.nn as nn +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.escn.so3 import ( CoefficientMapping, SO3_Embedding, @@ -36,7 +40,7 @@ @registry.register_model("escn") -class eSCN(BaseModel): +class eSCN(nn.Module, GraphModelMixin): """Equivariant Spherical Channel Network Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs @@ -64,9 +68,6 @@ class eSCN(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -79,7 +80,6 @@ def __init__( sphere_channels: int = 128, hidden_channels: int = 256, edge_channels: int = 128, - use_grid: bool = True, num_sphere_samples: int = 128, distance_function: str = "gaussian", basis_width_scalar: float = 1.0, @@ -232,22 +232,16 @@ def forward(self, data): start_time = time.time() atomic_numbers = data.atomic_numbers.long() num_atoms = len(atomic_numbers) - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = 
self._init_edge_rot_mat(data, edge_index, edge_distance_vec) + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() @@ -290,8 +284,8 @@ def forward(self, data): x_message = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -304,8 +298,8 @@ def forward(self, data): x = self.layer_blocks[i]( x, atomic_numbers, - edge_distance, - edge_index, + graph.edge_distance, + graph.edge_index, self.SO3_edge_rot, mappingReduced, ) @@ -421,6 +415,149 @@ def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) +@registry.register_model("escn_backbone") +class eSCNBackbone(eSCN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + device = data.pos.device + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + + atomic_numbers = data.atomic_numbers.long() + num_atoms = len(atomic_numbers) + + graph = self.generate_graph(data) + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + self.SO3_edge_rot = nn.ModuleList() + for i in range(self.num_resolutions): + self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + offset = 0 + x = SO3_Embedding( + num_atoms, + self.lmax_list, + self.sphere_channels, + device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l=0,m=0 coefficients for each resolution + for i in range(self.num_resolutions): + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer + mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + if i > 0: + x_message = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + + # Residual layer for all layers past the first + x.embedding = x.embedding + x_message.embedding + + else: + # No residual for the first layer + x = self.layer_blocks[i]( + x, + atomic_numbers, + graph.edge_distance, + graph.edge_index, + self.SO3_edge_rot, + mappingReduced, + ) + + # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. + # These values are fed into the output blocks. 
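Because the eSCN backbone hands its sampled sphere embeddings to heads as a plain dict (`sphere_values`, `sphere_points`; see the `return` just below), additional heads can be attached without touching the backbone. A minimal sketch of such a head, assuming a hypothetical `escn_scalar_head` name; it reuses the `sphere_channels_all` and `num_sphere_samples` backbone attributes that the energy and force heads below also read.

```python
import torch
from torch import nn

from fairchem.core.common.registry import registry
from fairchem.core.models.base import HeadInterface


@registry.register_model("escn_scalar_head")  # hypothetical name, not part of this patch
class eSCNScalarHead(nn.Module, HeadInterface):
    def __init__(self, backbone):
        super().__init__()
        self.num_sphere_samples = backbone.num_sphere_samples
        self.fc = nn.Linear(backbone.sphere_channels_all, 1)

    def forward(self, data, emb):
        # emb["sphere_values"]: (num_atoms * num_sphere_samples, sphere_channels_all)
        per_point = self.fc(emb["sphere_values"]).view(-1, self.num_sphere_samples)
        node_scalar = per_point.mean(dim=1)  # average over the sampled sphere points
        out = torch.zeros(len(data.natoms), device=node_scalar.device)
        out.index_add_(0, data.batch, node_scalar)  # pool per-atom values per structure
        return {"my_scalar": out}
```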
+ x_pt = torch.tensor([], device=device) + offset = 0 + # Compute the embedding values at every sampled point on the sphere + for i in range(self.num_resolutions): + num_coefficients = int((x.lmax_list[i] + 1) ** 2) + x_pt = torch.cat( + [ + x_pt, + torch.einsum( + "abc, pb->apc", + x.embedding[:, offset : offset + num_coefficients], + self.sphharm_weights[i], + ).contiguous(), + ], + dim=2, + ) + offset = offset + num_coefficients + + x_pt = x_pt.view(-1, self.sphere_channels_all) + + return {"sphere_values": x_pt, "sphere_points": self.sphere_points} + + +@registry.register_model("escn_energy_head") +class eSCNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + # Output blocks for energy and forces + self.energy_block = EnergyBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + node_energy = self.energy_block(emb["sphere_values"]) + energy = torch.zeros(len(data.natoms), device=data.pos.device) + energy.index_add_(0, data.batch, node_energy.view(-1)) + # Scale energy to help balance numerical precision w.r.t. forces + return {"energy": energy * 0.001} + + +@registry.register_model("escn_force_head") +class eSCNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.force_block = ForceBlock( + backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act + ) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + return {"forces": self.force_block(emb["sphere_values"], emb["sphere_points"])} + + class LayerBlock(torch.nn.Module): """ Layer block: Perform one layer (message passing and aggregation) of the GNN diff --git a/src/fairchem/core/models/gemnet/gemnet.py b/src/fairchem/core/models/gemnet/gemnet.py index e719c219b8..59b3eda08f 100644 --- a/src/fairchem/core/models/gemnet/gemnet.py +++ b/src/fairchem/core/models/gemnet/gemnet.py @@ -7,14 +7,20 @@ from __future__ import annotations +import typing + import numpy as np import torch +import torch.nn as nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -28,17 +34,12 @@ @registry.register_model("gemnet_t") -class GemNetT(BaseModel): +class GemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -94,9 +95,6 @@ class GemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -132,7 +130,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -235,7 +232,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -421,18 +418,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. - V_st = -distance_vec / D_st[:, None] + V_st = -graph.edge_distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -447,10 +436,10 @@ def generate_interaction_graph(self, data): V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -530,7 +519,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -549,7 +538,7 @@ def forward(self, data): ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -557,11 +546,11 @@ def forward(self, data): if self.extensive: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} @@ -569,30 +558,18 @@ def forward(self, data): if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t @@ -601,3 +578,129 @@ def forward(self, data): @property def num_params(self): return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_t_backbone") +class GemNetTBackbone(GemNetT, BackboneInterface): + 
@conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + edge_index, + neighbors, + D_st, + V_st, + id_swap, + id3_ba, + id3_ca, + id3_ragged_idx, + ) = self.generate_interaction_graph(data) + idx_s, idx_t = edge_index + + # Calculate triplet angles + cosφ_cab = inner_product_normalized(V_st[id3_ca], V_st[id3_ba]) + rad_cbf3, cbf3 = self.cbf_basis3(D_st, cosφ_cab, id3_ca) + + rbf = self.radial_basis(D_st) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, rbf, idx_s, idx_t) # (nEdges, emb_size_edge) + + rbf3 = self.mlp_rbf3(rbf) + cbf3 = self.mlp_cbf3(rad_cbf3, cbf3, id3_ca, id3_ragged_idx) + + rbf_h = self.mlp_rbf_h(rbf) + rbf_out = self.mlp_rbf_out(rbf) + + E_t, F_st = self.out_blocks[0](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + rbf3=rbf3, + cbf3=cbf3, + id3_ragged_idx=id3_ragged_idx, + id_swap=id_swap, + id3_ba=id3_ba, + id3_ca=id3_ca, + rbf_h=rbf_h, + idx_s=idx_s, + idx_t=idx_t, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + E, F = self.out_blocks[i + 1](h, m, rbf_out, idx_t) + # (nAtoms, 1), (nEdges, 1) + F_st += F + E_t += E + return { + "F_st": F_st, + "E_t": E_t, + "edge_vec": V_st, + "edge_idx": idx_t, + "node_embedding": h, + "edge_embedding": m, + } + + +@registry.register_model("gemnet_t_energy_and_grad_force_head") +class GemNetTEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.extensive = backbone.extensive + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter( + emb["E_t"], data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t} + + if self.regress_forces and not self.direct_forces: + outputs["forces"] = -torch.autograd.grad( + E_t.sum(), data.pos, create_graph=True + )[0] + # (nAtoms, 3) + return outputs + + +@registry.register_model("gemnet_t_force_head") +class GemNetTForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # map forces in edge directions + F_st_vec = emb["F_st"][:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.size(0), + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (nAtoms, 3) diff --git a/src/fairchem/core/models/gemnet_gp/gemnet.py b/src/fairchem/core/models/gemnet_gp/gemnet.py index 81fbd40694..a75756dcc1 100644 --- a/src/fairchem/core/models/gemnet_gp/gemnet.py +++ b/src/fairchem/core/models/gemnet_gp/gemnet.py @@ -9,13 +9,14 @@ import numpy as np import torch +from torch import nn from torch_scatter import scatter from torch_sparse import SparseTensor from fairchem.core.common import gp_utils from 
fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.modules.scaling.compat import load_scales_compat from .layers.atom_update_block import OutputBlock @@ -29,17 +30,12 @@ @registry.register_model("gp_gemnet_t") -class GraphParallelGemNetT(BaseModel): +class GraphParallelGemNetT(nn.Module, GraphModelMixin): """ GemNet-T, triplets-only variant of GemNet Parameters ---------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. num_radial: int @@ -95,9 +91,6 @@ class GraphParallelGemNetT(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -134,7 +127,6 @@ def __init__( if rbf is None: rbf = {"name": "gaussian"} super().__init__() - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive @@ -239,7 +231,7 @@ def __init__( emb_size_edge=emb_size_edge, emb_size_rbf=emb_size_rbf, nHidden=num_atom, - num_targets=num_targets, + num_targets=1, activation=activation, output_init=output_init, direct_forces=direct_forces, @@ -415,18 +407,10 @@ def select_edges( def generate_interaction_graph(self, data): num_atoms = data.atomic_numbers.size(0) - - ( - edge_index, - D_st, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
- V_st = -distance_vec / D_st[:, None] + V_st = -graph.distance_vec / graph.edge_distance[:, None] # Mask interaction edges if required if self.otf_graph or np.isclose(self.cutoff, 6): @@ -441,10 +425,10 @@ def generate_interaction_graph(self, data): V_st, ) = self.select_edges( data=data, - edge_index=edge_index, - cell_offsets=cell_offsets, - neighbors=neighbors, - edge_dist=D_st, + edge_index=graph.edge_index, + cell_offsets=graph.cell_offsets, + neighbors=graph.neighbors, + edge_dist=graph.edge_distance, edge_vector=V_st, cutoff=select_cutoff, ) @@ -563,7 +547,7 @@ def forward(self, data): rbf_out = self.mlp_rbf_out(rbf) E_t, F_st = self.out_blocks[0](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) for i in range(self.num_blocks): # Interaction block @@ -585,7 +569,7 @@ def forward(self, data): ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) E, F = self.out_blocks[i + 1](nAtoms, m, rbf_out, idx_t) - # (nAtoms, num_targets), (nEdges, num_targets) + # (nAtoms, 1), (nEdges, 1) F_st += F E_t += E @@ -601,41 +585,29 @@ def forward(self, data): E_t = gp_utils.gather_from_model_parallel_region(E_t, dim=0) E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) outputs = {"energy": E_t} if self.regress_forces: if self.direct_forces: # map forces in edge directions F_st_vec = F_st[:, :, None] * V_st[:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter( F_st_vec, idx_t_full, dim=0, dim_size=data.atomic_numbers.size(0), reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) F_t = F_t.squeeze(1) # (nAtoms, 3) else: - if self.num_targets > 1: - forces = [] - for i in range(self.num_targets): - # maybe this can be solved differently - forces += [ - -torch.autograd.grad( - E_t[:, i].sum(), pos, create_graph=True - )[0] - ] - F_t = torch.stack(forces, dim=1) - # (nAtoms, num_targets, 3) - else: - F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] - # (nAtoms, 3) + F_t = -torch.autograd.grad(E_t.sum(), pos, create_graph=True)[0] + # (nAtoms, 3) outputs["forces"] = F_t diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index e1176d00c9..0aea3d81ba 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -7,9 +7,11 @@ from __future__ import annotations import logging +import typing import numpy as np import torch +import torch.nn as nn from torch_scatter import segment_coo from fairchem.core.common.registry import registry @@ -18,7 +20,7 @@ get_max_neighbors_mask, scatter_det, ) -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.modules.scaling.compat import load_scales_compat from .initializers import get_initializer @@ -40,17 +42,15 @@ repeat_blocks, ) +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch + @registry.register_model("gemnet_oc") -class GemNetOC(BaseModel): +class GemNetOC(nn.Module, GraphModelMixin): """ Arguments --------- - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets: int - Number of prediction targets. - num_spherical: int Controls maximum frequency. 
num_radial: int @@ -179,9 +179,6 @@ class GemNetOC(BaseModel): def __init__( self, - num_atoms: int | None, - bond_feat_dim: int, - num_targets: int, num_spherical: int, num_radial: int, num_blocks: int, @@ -249,11 +246,11 @@ def __init__( super().__init__() if len(kwargs) > 0: logging.warning(f"Unrecognized arguments: {list(kwargs.keys())}") - self.num_targets = num_targets assert num_blocks > 0 self.num_blocks = num_blocks self.extensive = extensive + self.activation = activation self.atom_edge_interaction = atom_edge_interaction self.edge_atom_interaction = edge_atom_interaction self.atom_interaction = atom_interaction @@ -357,7 +354,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) - self.out_energy = Dense(emb_size_atom, num_targets, bias=False, activation=None) + self.out_energy = Dense(emb_size_atom, 1, bias=False, activation=None) if direct_forces: out_mlp_F = [ Dense( @@ -373,9 +370,7 @@ def __init__( for _ in range(num_global_out_layers) ] self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) - self.out_forces = Dense( - emb_size_edge, num_targets, bias=False, activation=None - ) + self.out_forces = Dense(emb_size_edge, 1, bias=False, activation=None) out_initializer = get_initializer(output_init) self.out_energy.reset_parameters(out_initializer) @@ -870,15 +865,7 @@ def subselect_edges( def generate_graph_dict(self, data, cutoff, max_neighbors): """Generate a radius/nearest neighbor graph.""" otf_graph = cutoff > 6 or max_neighbors > 50 or self.otf_graph - - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - num_neighbors, - ) = self.generate_graph( + graph = self.generate_graph( data, cutoff=cutoff, max_neighbors=max_neighbors, @@ -886,15 +873,15 @@ def generate_graph_dict(self, data, cutoff, max_neighbors): ) # These vectors actually point in the opposite direction. # But we want to use col as idx_t for efficient aggregation. 
- edge_vector = -distance_vec / edge_dist[:, None] - cell_offsets = -cell_offsets # a - c + offset + edge_vector = -graph.edge_distance_vec / graph.edge_distance[:, None] + cell_offsets = -graph.cell_offsets # a - c + offset graph = { - "edge_index": edge_index, - "distance": edge_dist, + "edge_index": graph.edge_index, + "distance": graph.edge_distance, "vector": edge_vector, "cell_offset": cell_offsets, - "num_neighbors": num_neighbors, + "num_neighbors": graph.neighbors, } # Mask interaction edges if required @@ -1285,11 +1272,11 @@ def forward(self, data): if self.extensive: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="add" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) else: E_t = scatter_det( E_t, batch, dim=0, dim_size=nMolecules, reduce="mean" - ) # (nMolecules, num_targets) + ) # (nMolecules, 1) E_t = E_t.squeeze(1) # (num_molecules) outputs = {"energy": E_t} @@ -1308,19 +1295,19 @@ def forward(self, data): dim=0, dim_size=int(nEdges / 2), reduce="mean", - ) # (nEdges/2, num_targets) - F_st = F_st[id_undir] # (nEdges, num_targets) + ) # (nEdges/2, 1) + F_st = F_st[id_undir] # (nEdges, 1) # map forces in edge directions F_st_vec = F_st[:, :, None] * main_graph["vector"][:, None, :] - # (nEdges, num_targets, 3) + # (nEdges, 1, 3) F_t = scatter_det( F_st_vec, idx_t, dim=0, dim_size=num_atoms, reduce="add", - ) # (nAtoms, num_targets, 3) + ) # (nAtoms, 1, 3) else: F_t = self.force_scaler.calc_forces_and_update(E_t, pos) @@ -1333,3 +1320,233 @@ def forward(self, data): @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) + + +@registry.register_model("gemnet_oc_backbone") +class GemNetOCBackbone(GemNetOC, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data: Batch) -> dict[str, torch.Tensor]: + pos = data.pos + atomic_numbers = data.atomic_numbers.long() + num_atoms = atomic_numbers.shape[0] + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + main_graph, + a2a_graph, + a2ee2a_graph, + qint_graph, + id_swap, + trip_idx_e2e, + trip_idx_a2e, + trip_idx_e2a, + quad_idx, + ) = self.get_graphs_and_indices(data) + _, idx_t = main_graph["edge_index"] + + ( + basis_rad_raw, + basis_atom_update, + basis_output, + bases_qint, + bases_e2e, + bases_a2e, + bases_e2a, + basis_a2a_rad, + ) = self.get_bases( + main_graph=main_graph, + a2a_graph=a2a_graph, + a2ee2a_graph=a2ee2a_graph, + qint_graph=qint_graph, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + num_atoms=num_atoms, + ) + + # Embedding block + h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + m = self.edge_emb(h, basis_rad_raw, main_graph["edge_index"]) + # (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[0](h, m, basis_output, idx_t) + # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + xs_E, xs_F = [x_E], [x_F] + + for i in range(self.num_blocks): + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + bases_qint=bases_qint, + bases_e2e=bases_e2e, + bases_a2e=bases_a2e, + bases_e2a=bases_e2a, + basis_a2a_rad=basis_a2a_rad, + basis_atom_update=basis_atom_update, + edge_index_main=main_graph["edge_index"], + a2ee2a_graph=a2ee2a_graph, + a2a_graph=a2a_graph, + id_swap=id_swap, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[i + 1](h, m, basis_output, idx_t) + # (nAtoms, 
emb_size_atom), (nEdges, emb_size_edge) + xs_E.append(x_E) + xs_F.append(x_F) + + return { + "xs_E": xs_E, + "xs_F": xs_F, + "edge_vec": main_graph["vector"], + "edge_idx": idx_t, + "num_neighbors": main_graph["num_neighbors"], + } + + +@registry.register_model("gemnet_oc_energy_and_grad_force_head") +class GemNetOCEnergyAndGradForceHead(nn.Module, HeadInterface): + def __init__( + self, + backbone: BackboneInterface, + num_global_out_layers: int, + output_init: str = "HeOrthogonal", + ): + super().__init__() + self.extensive = backbone.extensive + + self.regress_forces = backbone.regress_forces + self.direct_forces = backbone.direct_forces + self.force_scaler = backbone.force_scaler + + out_mlp_E = [ + Dense( + backbone.atom_emb.emb_size * (len(backbone.int_blocks) + 1), + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + backbone.atom_emb.emb_size, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_E = torch.nn.Sequential(*out_mlp_E) + + self.out_energy = Dense( + backbone.atom_emb.emb_size, + 1, + bias=False, + activation=None, + ) + + out_initializer = get_initializer(output_init) + self.out_energy.reset_parameters(out_initializer) + + @conditional_grad(torch.enable_grad()) + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + # Global output block for final predictions + x_E = self.out_mlp_E(torch.cat(emb["xs_E"], dim=-1)) + with torch.cuda.amp.autocast(False): + E_t = self.out_energy(x_E.float()) + + nMolecules = torch.max(data.batch) + 1 + if self.extensive: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="add" + ) # (nMolecules, 1) + else: + E_t = scatter_det( + E_t, data.batch, dim=0, dim_size=nMolecules, reduce="mean" + ) # (nMolecules, 1) + + outputs = {"energy": E_t.squeeze(1)} # (num_molecules) + + if self.regress_forces and not self.direct_forces: + F_t = self.force_scaler.calc_forces_and_update(outputs["energy"], data.pos) + outputs["forces"] = F_t.squeeze(1) + return outputs + + +@registry.register_model("gemnet_oc_force_head") +class GemNetOCForceHead(nn.Module, HeadInterface): + def __init__( + self, backbone, num_global_out_layers: int, output_init: str = "HeOrthogonal" + ): + super().__init__() + + self.direct_forces = backbone.direct_forces + self.forces_coupled = backbone.forces_coupled + + emb_size_edge = backbone.edge_emb.dense.linear.out_features + if self.direct_forces: + out_mlp_F = [ + Dense( + emb_size_edge * (len(backbone.int_blocks) + 1), + emb_size_edge, + activation=backbone.activation, + ) + ] + [ + ResidualLayer( + emb_size_edge, + activation=backbone.activation, + ) + for _ in range(num_global_out_layers) + ] + self.out_mlp_F = torch.nn.Sequential(*out_mlp_F) + self.out_forces = Dense( + emb_size_edge, + 1, + bias=False, + activation=None, + ) + out_initializer = get_initializer(output_init) + self.out_forces.reset_parameters(out_initializer) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1)) + with torch.cuda.amp.autocast(False): + F_st = self.out_forces(x_F.float()) + + if self.forces_coupled: # enforce F_st = F_ts + nEdges = emb["edge_idx"].shape[0] + id_undir = repeat_blocks( + emb["num_neighbors"] // 2, + repeats=2, + continuous_indexing=True, + ) + F_st = scatter_det( + F_st, + id_undir, + dim=0, + dim_size=int(nEdges / 2), + reduce="mean", + ) # (nEdges/2, 1) + 
F_st = F_st[id_undir] # (nEdges, 1) + + # map forces in edge directions + F_st_vec = F_st[:, :, None] * emb["edge_vec"][:, None, :] + # (nEdges, 1, 3) + F_t = scatter_det( + F_st_vec, + emb["edge_idx"], + dim=0, + dim_size=data.atomic_numbers.long().shape[0], + reduce="add", + ) # (nAtoms, 1, 3) + return {"forces": F_t.squeeze(1)} # (num_atoms, 3) + return {} diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index 8843f02b2e..ec9e9f465c 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -32,15 +32,19 @@ from __future__ import annotations import math +import typing import torch from torch import nn + +if typing.TYPE_CHECKING: + from torch_geometric.data.batch import Batch from torch_geometric.nn import MessagePassing from torch_scatter import scatter, segment_coo from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface from fairchem.core.models.gemnet.layers.base_layers import ScaledSiLU from fairchem.core.models.gemnet.layers.embedding_block import AtomEmbedding from fairchem.core.models.gemnet.layers.radial_basis import RadialBasis @@ -51,7 +55,7 @@ @registry.register_model("painn") -class PaiNN(BaseModel): +class PaiNN(nn.Module, GraphModelMixin): r"""PaiNN model based on the description in Schütt et al. (2021): Equivariant message passing for the prediction of tensorial properties and molecular spectra, https://arxiv.org/abs/2102.03150. @@ -59,9 +63,6 @@ class PaiNN(BaseModel): def __init__( self, - num_atoms: int, - bond_feat_dim: int, - num_targets: int, hidden_channels: int = 512, num_layers: int = 6, num_rbf: int = 128, @@ -310,23 +311,16 @@ def symmetrize_edges( ) def generate_graph_values(self, data): - ( - edge_index, - edge_dist, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) # Unit vectors pointing from edge_index[1] to edge_index[0], # i.e., edge_index[0] - edge_index[1] divided by the norm. 
# make sure that the distances are not close to zero before dividing - mask_zero = torch.isclose(edge_dist, torch.tensor(0.0), atol=1e-6) - edge_dist[mask_zero] = 1.0e-6 - edge_vector = distance_vec / edge_dist[:, None] + mask_zero = torch.isclose(graph.edge_distance, torch.tensor(0.0), atol=1e-6) + graph.edge_distance[mask_zero] = 1.0e-6 + edge_vector = graph.edge_distance_vec / graph.edge_distance[:, None] - empty_image = neighbors == 0 + empty_image = graph.neighbors == 0 if torch.any(empty_image): raise ValueError( f"An image has no neighbors: id={data.id[empty_image]}, " @@ -342,11 +336,11 @@ def generate_graph_values(self, data): [edge_vector], id_swap, ) = self.symmetrize_edges( - edge_index, - cell_offsets, - neighbors, + graph.edge_index, + graph.cell_offsets, + graph.neighbors, data.batch, - [edge_dist], + [graph.edge_distance], [edge_vector], ) @@ -436,6 +430,50 @@ def __repr__(self) -> str: ) +@registry.register_model("painn_backbone") +class PaiNNBackbone(PaiNN, BackboneInterface): + @conditional_grad(torch.enable_grad()) + def forward(self, data) -> dict[str, torch.Tensor]: + pos = data.pos + z = data.atomic_numbers.long() + + if self.regress_forces and not self.direct_forces: + pos = pos.requires_grad_(True) + + ( + edge_index, + neighbors, + edge_dist, + edge_vector, + id_swap, + ) = self.generate_graph_values(data) + + assert z.dim() == 1 + assert z.dtype == torch.long + + edge_rbf = self.radial_basis(edge_dist) # rbf * envelope + + x = self.atom_emb(z) + vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) + + #### Interaction blocks ############################################### + + for i in range(self.num_layers): + dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) + + x = x + dx + vec = vec + dvec + x = x * self.inv_sqrt_2 + + dx, dvec = self.update_layers[i](x, vec) + + x = x + dx + vec = vec + dvec + x = getattr(self, "upd_out_scalar_scale_%d" % i)(x) + + return {"node_embedding": x, "node_vec": vec} + + class PaiNNMessage(MessagePassing): def __init__( self, @@ -625,3 +663,53 @@ def forward(self, x, v): x = self.act(x) return x, v + + +@registry.register_model("painn_energy_head") +class PaiNNEnergyHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + + self.out_energy = nn.Sequential( + nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2), + ScaledSiLU(), + nn.Linear(backbone.hidden_channels // 2, 1), + ) + + nn.init.xavier_uniform_(self.out_energy[0].weight) + self.out_energy[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.out_energy[2].weight) + self.out_energy[2].bias.data.fill_(0) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + per_atom_energy = self.out_energy(emb["node_embedding"]).squeeze(1) + return {"energy": scatter(per_atom_energy, data.batch, dim=0)} + + +@registry.register_model("painn_force_head") +class PaiNNForceHead(nn.Module, HeadInterface): + def __init__(self, backbone): + super().__init__() + self.direct_forces = backbone.direct_forces + + if self.direct_forces: + self.out_forces = PaiNNOutput(backbone.hidden_channels) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.direct_forces: + forces = self.out_forces(emb["node_embedding"], emb["node_vec"]) + else: + forces = ( + -1 + * torch.autograd.grad( + emb["node_embedding"], + data.pos, + grad_outputs=torch.ones_like(emb["node_embedding"]), + create_graph=True, + )[0] + ) + return {"forces": forces} 
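Editor's note on the pattern repeated in every model diff above: the unused positional constructor arguments (num_atoms, bond_feat_dim, num_targets) are dropped, generate_graph now returns a single graph object whose attributes (edge_index, edge_distance, edge_distance_vec, cell_offsets, neighbors) replace the old six-tuple unpacking, and prediction logic moves out of the monolithic model into a registered backbone that emits an embedding dict plus small head modules consuming it via forward(data, emb). The following is a minimal, self-contained sketch of that backbone/head split only; ToyBackbone and ToyEnergyHead are hypothetical stand-ins written for this note (not fairchem code), and only the head signature and the {"energy": ...} output dict mirror what the patch introduces.

    # Hedged sketch of the backbone + head split used throughout this patch.
    # ToyBackbone / ToyEnergyHead are illustrative stand-ins, not fairchem classes.
    from __future__ import annotations

    import torch
    import torch.nn as nn


    class ToyBackbone(nn.Module):
        """Produces per-node embeddings, analogous to the *_backbone models above."""

        def __init__(self, hidden_channels: int = 8):
            super().__init__()
            self.hidden_channels = hidden_channels
            self.atom_emb = nn.Embedding(100, hidden_channels)

        def forward(self, data) -> dict[str, torch.Tensor]:
            # Real backbones also return edge/graph tensors; node features suffice here.
            return {"node_embedding": self.atom_emb(data.atomic_numbers.long())}


    class ToyEnergyHead(nn.Module):
        """Pools node embeddings into one energy per structure, like the *_energy_head classes."""

        def __init__(self, backbone: ToyBackbone):
            super().__init__()
            self.out = nn.Linear(backbone.hidden_channels, 1)

        def forward(self, data, emb: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
            node_energy = self.out(emb["node_embedding"]).squeeze(-1)  # (nAtoms,)
            energy = torch.zeros(int(data.batch.max()) + 1, dtype=node_energy.dtype)
            energy.index_add_(0, data.batch, node_energy)  # sum per structure
            return {"energy": energy}


    if __name__ == "__main__":
        from types import SimpleNamespace

        # Tiny fake batch: one water-like structure with three atoms.
        data = SimpleNamespace(
            atomic_numbers=torch.tensor([1, 8, 1]), batch=torch.tensor([0, 0, 0])
        )
        backbone, head = ToyBackbone(), ToyEnergyHead(backbone)
        print(head(data, backbone(data)))  # -> {"energy": tensor([...])}

Keeping graph construction and message passing in the backbone while heads stay thin is what the hydra test configs further down rely on when they pair one backbone with separate energy and forces heads.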
diff --git a/src/fairchem/core/models/schnet.py b/src/fairchem/core/models/schnet.py index 2f89c17e1f..5ca70a354e 100644 --- a/src/fairchem/core/models/schnet.py +++ b/src/fairchem/core/models/schnet.py @@ -13,11 +13,11 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin @registry.register_model("schnet") -class SchNetWrap(SchNet, BaseModel): +class SchNetWrap(SchNet, GraphModelMixin): r"""Wrapper around the continuous-filter convolutional neural network SchNet from the `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions" `_. Each layer uses interaction @@ -28,9 +28,6 @@ class SchNetWrap(SchNet, BaseModel): h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))), Args: - num_atoms (int): Unused argument - bond_feat_dim (int): Unused argument - num_targets (int): Number of targets to predict. use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions. (default: :obj:`True`) regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating @@ -54,9 +51,6 @@ class SchNetWrap(SchNet, BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -67,7 +61,7 @@ def __init__( cutoff: float = 10.0, readout: str = "add", ) -> None: - self.num_targets = num_targets + self.num_targets = 1 self.regress_forces = regress_forces self.use_pbc = use_pbc self.cutoff = cutoff @@ -88,25 +82,17 @@ def _forward(self, data): z = data.atomic_numbers.long() pos = data.pos batch = data.batch - - ( - edge_index, - edge_weight, - distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) if self.use_pbc: assert z.dim() == 1 assert z.dtype == torch.long - edge_attr = self.distance_expansion(edge_weight) + edge_attr = self.distance_expansion(graph.edge_distance) h = self.embedding(z) for interaction in self.interactions: - h = h + interaction(h, edge_index, edge_weight, edge_attr) + h = h + interaction(h, graph.edge_index, graph.edge_distance, edge_attr) h = self.lin1(h) h = self.act(h) diff --git a/src/fairchem/core/models/scn/scn.py b/src/fairchem/core/models/scn/scn.py index bf8454f212..84806e19e8 100644 --- a/src/fairchem/core/models/scn/scn.py +++ b/src/fairchem/core/models/scn/scn.py @@ -18,7 +18,7 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BaseModel +from fairchem.core.models.base import GraphModelMixin from fairchem.core.models.scn.sampling import CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -33,7 +33,7 @@ @registry.register_model("scn") -class SphericalChannelNetwork(BaseModel): +class SphericalChannelNetwork(nn.Module, GraphModelMixin): """Spherical Channel Network Paper: Spherical Channels for Modeling Atomic Interactions @@ -75,9 +75,6 @@ class SphericalChannelNetwork(BaseModel): def __init__( self, - num_atoms: int, # not used - bond_feat_dim: int, # not used - num_targets: int, # not used use_pbc: bool = True, regress_forces: bool = True, otf_graph: bool = False, @@ -262,15 +259,7 @@ def _forward_helper(self, data): atomic_numbers = data.atomic_numbers.long() num_atoms = 
len(atomic_numbers) pos = data.pos - - ( - edge_index, - edge_distance, - edge_distance_vec, - cell_offsets, - _, # cell offset distances - neighbors, - ) = self.generate_graph(data) + graph = self.generate_graph(data) ############################################################### # Initialize data structures @@ -278,12 +267,12 @@ def _forward_helper(self, data): # Calculate which message block each edge should use. Based on edge distance rank. edge_rank = self._rank_edge_distances( - edge_distance, edge_index, self.max_num_neighbors + graph.edge_distance, graph.edge_index, self.max_num_neighbors ) # Reorder edges so that they are grouped by distance rank (lowest to highest) last_cutoff = -0.1 - message_block_idx = torch.zeros(len(edge_distance), device=pos.device) + message_block_idx = torch.zeros(len(graph.edge_distance), device=pos.device) edge_distance_reorder = torch.tensor([], device=self.device) edge_index_reorder = torch.tensor([], device=self.device) edge_distance_vec_reorder = torch.tensor([], device=self.device) @@ -297,21 +286,21 @@ def _forward_helper(self, data): edge_distance_reorder = torch.cat( [ edge_distance_reorder, - torch.masked_select(edge_distance, mask), + torch.masked_select(graph.edge_distance, mask), ], dim=0, ) edge_index_reorder = torch.cat( [ edge_index_reorder, - torch.masked_select(edge_index, mask.view(1, -1).repeat(2, 1)).view( - 2, -1 - ), + torch.masked_select( + graph.edge_index, mask.view(1, -1).repeat(2, 1) + ).view(2, -1), ], dim=1, ) edge_distance_vec_mask = torch.masked_select( - edge_distance_vec, mask.view(-1, 1).repeat(1, 3) + graph.edge_distance_vec, mask.view(-1, 1).repeat(1, 3) ).view(-1, 3) edge_distance_vec_reorder = torch.cat( [edge_distance_vec_reorder, edge_distance_vec_mask], dim=0 diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 1c0c975f8a..c21409863e 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -450,21 +450,7 @@ def load_model(self) -> None: if distutils.is_master(): logging.info(f"Loading model: {self.config['model']}") - # TODO: depreicated, remove. 
- bond_feat_dim = None - bond_feat_dim = self.config["model_attributes"].get("num_gaussians", 50) - - loader = self.train_loader or self.val_loader or self.test_loader self.model = registry.get_model_class(self.config["model"])( - ( - loader.dataset[0].x.shape[-1] - if loader - and hasattr(loader.dataset[0], "x") - and loader.dataset[0].x is not None - else None - ), - bond_feat_dim, - 1, **self.config["model_attributes"], ).to(self.device) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index aea07201bd..1584becd45 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -27,9 +27,32 @@ @pytest.fixture() def configs(): return { + "scn": Path("tests/core/models/test_configs/test_scn.yml"), "escn": Path("tests/core/models/test_configs/test_escn.yml"), - "gemnet": Path("tests/core/models/test_configs/test_gemnet.yml"), + "escn_hydra": Path("tests/core/models/test_configs/test_escn_hydra.yml"), + "schnet": Path("tests/core/models/test_configs/test_schnet.yml"), + "gemnet_dt": Path("tests/core/models/test_configs/test_gemnet_dt.yml"), + "gemnet_dt_hydra": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra.yml" + ), + "gemnet_dt_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml" + ), + "gemnet_oc": Path("tests/core/models/test_configs/test_gemnet_oc.yml"), + "gemnet_oc_hydra": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra.yml" + ), + "gemnet_oc_hydra_grad": Path( + "tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml" + ), + "dimenet++": Path("tests/core/models/test_configs/test_dpp.yml"), + "dimenet++_hydra": Path("tests/core/models/test_configs/test_dpp_hydra.yml"), + "painn": Path("tests/core/models/test_configs/test_painn.yml"), + "painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"), "equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"), + "equiformer_v2_hydra": Path( + "tests/core/models/test_configs/test_equiformerv2_hydra.yml" + ), } @@ -173,7 +196,7 @@ def smoke_test_train( rundir=str(train_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), @@ -194,7 +217,7 @@ def smoke_test_train( rundir=str(predictions_rundir), input_yaml=input_yaml, update_dict_with={ - "optim": {"max_epochs": 2, "eval_every": 8}, + "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5}, "dataset": oc20_lmdb_train_and_val_from_paths( train_src=str(tutorial_val_src), val_src=str(tutorial_val_src), @@ -216,9 +239,22 @@ def smoke_test_train( @pytest.mark.parametrize( "model_name", [ - pytest.param("gemnet", id="gemnet"), + pytest.param("schnet", id="schnet"), + pytest.param("scn", id="scn"), + pytest.param("gemnet_dt", id="gemnet_dt"), + pytest.param("gemnet_dt_hydra", id="gemnet_dt_hydra"), + pytest.param("gemnet_dt_hydra_grad", id="gemnet_dt_hydra_grad"), + pytest.param("gemnet_oc", id="gemnet_oc"), + pytest.param("gemnet_oc_hydra", id="gemnet_oc_hydra"), + pytest.param("gemnet_oc_hydra_grad", id="gemnet_oc_hydra_grad"), + pytest.param("dimenet++", id="dimenet++"), + pytest.param("dimenet++_hydra", id="dimenet++_hydra"), + pytest.param("painn", id="painn"), + pytest.param("painn_hydra", id="painn_hydra"), pytest.param("escn", id="escn"), + pytest.param("escn_hydra", id="escn_hydra"), pytest.param("equiformer_v2", id="equiformer_v2"), + 
pytest.param("equiformer_v2_hydra", id="equiformer_v2_hydra"), ], ) def test_train_and_predict( @@ -376,7 +412,7 @@ class TestSmallDatasetOptim: @pytest.mark.parametrize( ("model_name", "expected_energy_mae", "expected_force_mae"), [ - pytest.param("gemnet", 0.41, 0.06, id="gemnet"), + pytest.param("gemnet_oc", 0.41, 0.06, id="gemnet_oc"), pytest.param("escn", 0.41, 0.06, id="escn"), pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"), ], diff --git a/tests/core/models/test_configs/test_dpp.yml b/tests/core/models/test_configs/test_dpp.yml new file mode 100755 index 0000000000..a79294bd15 --- /dev/null +++ b/tests/core/models/test_configs/test_dpp.yml @@ -0,0 +1,50 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: dimenetplusplus #_bbwheads + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 3 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. + +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_dpp_hydra.yml b/tests/core/models/test_configs/test_dpp_hydra.yml new file mode 100755 index 0000000000..1120cc905f --- /dev/null +++ b/tests/core/models/test_configs/test_dpp_hydra.yml @@ -0,0 +1,55 @@ +trainer: forces + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +logger: + name: tensorboard + +model: + name: hydra + backbone: + model: dimenetplusplus_backbone + hidden_channels: 4 + out_emb_channels: 4 + num_blocks: 1 + cutoff: 6.0 + num_radial: 6 + num_spherical: 7 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + regress_forces: True + use_pbc: True + heads: + energy: + module: dimenetplusplus_energy_and_force_head + +# *** Important note *** +# The total number of gpus used for this run was 256. +# If the global batch size (num_gpus * batch_size) is modified +# the lr_milestones and warmup_steps need to be adjusted accordingly. 
+ +optim: + batch_size: 5 + eval_batch_size: 2 + eval_every: 1000 + num_workers: 8 + lr_initial: 0.0001 + lr_gamma: 0.1 + lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma + - 130794 + - 196192 + - 261589 + warmup_steps: 130794 + warmup_factor: 0.2 + max_epochs: 7 diff --git a/tests/core/models/test_configs/test_equiformerv2_hydra.yml b/tests/core/models/test_configs/test_equiformerv2_hydra.yml new file mode 100644 index 0000000000..4c00fe6a2e --- /dev/null +++ b/tests/core/models/test_configs/test_equiformerv2_hydra.yml @@ -0,0 +1,98 @@ + + +trainer: forces + +model: + name: hydra + backbone: + model: equiformer_v2_backbone + use_pbc: True + regress_forces: True + otf_graph: True + + enforce_max_neighbors_strictly: False + + max_neighbors: 1 + max_radius: 12.0 + max_num_elements: 90 + + num_layers: 1 + sphere_channels: 4 + attn_hidden_channels: 4 # [64, 96] This determines the hidden size of message passing. Do not necessarily use 96. + num_heads: 1 + attn_alpha_channels: 4 # Not used when `use_s2_act_attn` is True. + attn_value_channels: 4 + ffn_hidden_channels: 8 + norm_type: 'layer_norm_sh' # ['rms_norm_sh', 'layer_norm', 'layer_norm_sh'] + + lmax_list: [1] + mmax_list: [1] + grid_resolution: 18 # [18, 16, 14, None] For `None`, simply comment this line. + + num_sphere_samples: 128 + + edge_channels: 32 + use_atom_edge_embedding: True + distance_function: 'gaussian' + num_distance_basis: 16 # not used + + attn_activation: 'silu' + use_s2_act_attn: False # [False, True] Switch between attention after S2 activation or the original EquiformerV1 attention. + ffn_activation: 'silu' # ['silu', 'swiglu'] + use_gate_act: False # [True, False] Switch between gate activation and S2 activation + use_grid_mlp: False # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs. 
+ + alpha_drop: 0.0 # [0.0, 0.1] + drop_path_rate: 0.0 # [0.0, 0.05] + proj_drop: 0.0 + + weight_init: 'normal' # ['uniform', 'normal'] + heads: + energy: + module: equiformer_v2_energy_head + forces: + module: equiformer_v2_force_head + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_escn_hydra.yml b/tests/core/models/test_configs/test_escn_hydra.yml new file mode 100644 index 0000000000..ba5db1f53e --- /dev/null +++ b/tests/core/models/test_configs/test_escn_hydra.yml @@ -0,0 +1,67 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: escn_backbone + num_layers: 2 + max_neighbors: 10 + cutoff: 12.0 + sphere_channels: 8 + hidden_channels: 8 + lmax_list: [2] + mmax_list: [2] + num_sphere_samples: 64 + distance_function: "gaussian" + regress_forces: True + use_pbc: True + basis_width_scalar: 2.0 + otf_graph: True + heads: + energy: + module: escn_energy_head + forces: + module: escn_force_head + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae diff --git a/tests/core/models/test_configs/test_gemnet_dt.yml b/tests/core/models/test_configs/test_gemnet_dt.yml new file mode 100644 index 0000000000..b04b6dfda0 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt.yml @@ -0,0 +1,79 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: gemnet_t + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + 
emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml new file mode 100644 index 0000000000..a612741470 --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra.yml @@ -0,0 +1,86 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + backbone: + model: gemnet_t_backbone + num_spherical: 3 + num_radial: 8 + num_blocks: 2 + emb_size_atom: 8 + emb_size_edge: 16 + emb_size_trip: 4 + emb_size_rbf: 4 + emb_size_cbf: 4 + emb_size_bil_trip: 4 + num_before_skip: 1 + num_after_skip: 2 + num_concat: 1 + num_atom: 3 + cutoff: 6.0 + max_neighbors: 50 + rbf: + name: gaussian + envelope: + name: polynomial + exponent: 5 + cbf: + name: spherical_harmonics + extensive: True + otf_graph: False + output_init: HeOrthogonal + activation: silu + scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json + + regress_forces: True + direct_forces: True + heads: + energy: + module: gemnet_t_energy_and_grad_force_head + forces: + module: gemnet_t_force_head + +optim: + batch_size: 8 + eval_batch_size: 8 + eval_every: 5000 + num_workers: 2 + lr_initial: 5.e-4 + optimizer: AdamW + optimizer_params: {"amsgrad": True} + scheduler: ReduceLROnPlateau + mode: min + factor: 0.8 + patience: 3 + max_epochs: 80 + force_coefficient: 100 + energy_coefficient: 1 + ema_decay: 0.999 + clip_grad_norm: 10 diff --git a/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml new file mode 100644 index 0000000000..83d46bdd4d --- /dev/null +++ b/tests/core/models/test_configs/test_gemnet_dt_hydra_grad.yml @@ -0,0 +1,84 @@ +trainer: forces + +dataset: + train: + src: tutorial_dset/s2ef/train_100/ + normalize_labels: True + target_mean: -0.7554450631141663 + target_std: 2.887317180633545 + grad_target_mean: 0.0 + grad_target_std: 2.887317180633545 + val: + format: lmdb + src: tutorial_dset/s2ef/val_20/ + +logger: + name: tensorboard + +task: + dataset: lmdb + type: regression + metric: mae + primary_metric: forces_mae + labels: + - potential energy + grad_input: atomic forces + train_on_free_atoms: True + eval_on_free_atoms: True + prediction_dtype: float32 + +model: + name: hydra + 
backbone:
+    model: gemnet_t_backbone
+    num_spherical: 3
+    num_radial: 8
+    num_blocks: 2
+    emb_size_atom: 8
+    emb_size_edge: 16
+    emb_size_trip: 4
+    emb_size_rbf: 4
+    emb_size_cbf: 4
+    emb_size_bil_trip: 4
+    num_before_skip: 1
+    num_after_skip: 2
+    num_concat: 1
+    num_atom: 3
+    cutoff: 6.0
+    max_neighbors: 50
+    rbf:
+      name: gaussian
+    envelope:
+      name: polynomial
+      exponent: 5
+    cbf:
+      name: spherical_harmonics
+    extensive: True
+    otf_graph: False
+    output_init: HeOrthogonal
+    activation: silu
+    scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-dT.json
+
+    regress_forces: True
+    direct_forces: False
+  heads:
+    energy_and_forces:
+      module: gemnet_t_energy_and_grad_force_head
+
+optim:
+  batch_size: 8
+  eval_batch_size: 8
+  eval_every: 5000
+  num_workers: 2
+  lr_initial: 5.e-4
+  optimizer: AdamW
+  optimizer_params: {"amsgrad": True}
+  scheduler: ReduceLROnPlateau
+  mode: min
+  factor: 0.8
+  patience: 3
+  max_epochs: 80
+  force_coefficient: 100
+  energy_coefficient: 1
+  ema_decay: 0.999
+  clip_grad_norm: 10
diff --git a/tests/core/models/test_configs/test_gemnet.yml b/tests/core/models/test_configs/test_gemnet_oc.yml
similarity index 100%
rename from tests/core/models/test_configs/test_gemnet.yml
rename to tests/core/models/test_configs/test_gemnet_oc.yml
diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml
new file mode 100644
index 0000000000..97343e90e6
--- /dev/null
+++ b/tests/core/models/test_configs/test_gemnet_oc_hydra.yml
@@ -0,0 +1,112 @@
+
+
+
+trainer: forces
+
+dataset:
+  train:
+    src: tutorial_dset/s2ef/train_100/
+    normalize_labels: True
+    target_mean: -0.7554450631141663
+    target_std: 2.887317180633545
+    grad_target_mean: 0.0
+    grad_target_std: 2.887317180633545
+  val:
+    format: lmdb
+    src: tutorial_dset/s2ef/val_20/
+
+logger:
+  name: tensorboard
+
+task:
+  dataset: lmdb
+  type: regression
+  metric: mae
+  primary_metric: forces_mae
+  labels:
+    - potential energy
+  grad_input: atomic forces
+  train_on_free_atoms: True
+  eval_on_free_atoms: True
+  prediction_dtype: float32
+
+model:
+  name: hydra
+  backbone:
+    model: gemnet_oc_backbone
+    num_spherical: 3
+    num_radial: 8
+    num_blocks: 2
+    emb_size_atom: 8
+    emb_size_edge: 16
+    emb_size_trip_in: 4
+    emb_size_trip_out: 4
+    emb_size_quad_in: 2
+    emb_size_quad_out: 2
+    emb_size_aint_in: 4
+    emb_size_aint_out: 4
+    emb_size_rbf: 2
+    emb_size_cbf: 2
+    emb_size_sbf: 4
+    num_before_skip: 1
+    num_after_skip: 1
+    num_concat: 1
+    num_atom: 3
+    num_output_afteratom: 3
+    cutoff: 12.0
+    cutoff_qint: 12.0
+    cutoff_aeaint: 12.0
+    cutoff_aint: 12.0
+    max_neighbors: 30
+    max_neighbors_qint: 8
+    max_neighbors_aeaint: 20
+    max_neighbors_aint: 1000
+    rbf:
+      name: gaussian
+    envelope:
+      name: polynomial
+      exponent: 5
+    cbf:
+      name: spherical_harmonics
+    sbf:
+      name: legendre_outer
+    extensive: True
+    output_init: HeOrthogonal
+    activation: silu
+    scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt
+
+    regress_forces: True
+    direct_forces: True
+    forces_coupled: False
+
+    quad_interaction: True
+    atom_edge_interaction: True
+    edge_atom_interaction: True
+    atom_interaction: True
+
+    num_atom_emb_layers: 2
+    num_global_out_layers: 2
+    qint_tags: [1, 2]
+  heads:
+    energy:
+      module: gemnet_oc_energy_and_grad_force_head
+      num_global_out_layers: 2
+    forces:
+      module: gemnet_oc_force_head
+      num_global_out_layers: 2
+
+optim:
+  batch_size: 5
+  eval_batch_size: 2
+  num_workers: 0
+  lr_initial: 0.0025
+  optimizer: AdamW
+  optimizer_params: {"amsgrad": True,weight_decay: 0.0}
+  eval_every: 190
+  max_epochs: 50
+  force_coefficient: 10
+  scheduler: "Null"
+  energy_coefficient: 1
+  clip_grad_norm: 20
+  loss_energy: mae
+  loss_force: l2mae
diff --git a/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml
new file mode 100644
index 0000000000..334c3cb4db
--- /dev/null
+++ b/tests/core/models/test_configs/test_gemnet_oc_hydra_grad.yml
@@ -0,0 +1,109 @@
+
+
+
+trainer: forces
+
+dataset:
+  train:
+    src: tutorial_dset/s2ef/train_100/
+    normalize_labels: True
+    target_mean: -0.7554450631141663
+    target_std: 2.887317180633545
+    grad_target_mean: 0.0
+    grad_target_std: 2.887317180633545
+  val:
+    format: lmdb
+    src: tutorial_dset/s2ef/val_20/
+
+logger:
+  name: tensorboard
+
+task:
+  dataset: lmdb
+  type: regression
+  metric: mae
+  primary_metric: forces_mae
+  labels:
+    - potential energy
+  grad_input: atomic forces
+  train_on_free_atoms: True
+  eval_on_free_atoms: True
+  prediction_dtype: float32
+
+model:
+  name: hydra
+  backbone:
+    model: gemnet_oc_backbone
+    num_spherical: 3
+    num_radial: 8
+    num_blocks: 2
+    emb_size_atom: 8
+    emb_size_edge: 16
+    emb_size_trip_in: 4
+    emb_size_trip_out: 4
+    emb_size_quad_in: 2
+    emb_size_quad_out: 2
+    emb_size_aint_in: 4
+    emb_size_aint_out: 4
+    emb_size_rbf: 2
+    emb_size_cbf: 2
+    emb_size_sbf: 4
+    num_before_skip: 1
+    num_after_skip: 1
+    num_concat: 1
+    num_atom: 3
+    num_output_afteratom: 3
+    cutoff: 12.0
+    cutoff_qint: 12.0
+    cutoff_aeaint: 12.0
+    cutoff_aint: 12.0
+    max_neighbors: 30
+    max_neighbors_qint: 8
+    max_neighbors_aeaint: 20
+    max_neighbors_aint: 1000
+    rbf:
+      name: gaussian
+    envelope:
+      name: polynomial
+      exponent: 5
+    cbf:
+      name: spherical_harmonics
+    sbf:
+      name: legendre_outer
+    extensive: True
+    output_init: HeOrthogonal
+    activation: silu
+    scale_file: configs/s2ef/all/gemnet/scaling_factors/gemnet-oc.pt
+
+    regress_forces: True
+    direct_forces: False
+    forces_coupled: False
+
+    quad_interaction: True
+    atom_edge_interaction: True
+    edge_atom_interaction: True
+    atom_interaction: True
+
+    num_atom_emb_layers: 2
+    num_global_out_layers: 2
+    qint_tags: [1, 2]
+  heads:
+    energy:
+      module: gemnet_oc_energy_and_grad_force_head
+      num_global_out_layers: 2
+
+optim:
+  batch_size: 5
+  eval_batch_size: 2
+  num_workers: 0
+  lr_initial: 0.0025
+  optimizer: AdamW
+  optimizer_params: {"amsgrad": True,weight_decay: 0.0}
+  eval_every: 190
+  max_epochs: 50
+  force_coefficient: 10
+  scheduler: "Null"
+  energy_coefficient: 1
+  clip_grad_norm: 20
+  loss_energy: mae
+  loss_force: l2mae
diff --git a/tests/core/models/test_configs/test_painn.yml b/tests/core/models/test_configs/test_painn.yml
new file mode 100644
index 0000000000..c1f24d0bb5
--- /dev/null
+++ b/tests/core/models/test_configs/test_painn.yml
@@ -0,0 +1,50 @@
+trainer: forces
+
+task:
+  dataset: lmdb
+  type: regression
+  metric: mae
+  primary_metric: forces_mae
+  labels:
+    - potential energy
+  grad_input: atomic forces
+  train_on_free_atoms: True
+  eval_on_free_atoms: True
+  prediction_dtype: float32
+
+logger:
+  name: tensorboard
+
+model:
+  name: painn #_bbwheads
+  hidden_channels: 32
+  num_layers: 6
+  num_rbf: 32
+  cutoff: 12.0
+  max_neighbors: 5
+  scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt
+  regress_forces: True
+  direct_forces: True
+  use_pbc: True
+
+optim:
+  batch_size: 32
+  eval_batch_size: 32
+  load_balancing: atoms
+  eval_every: 5000
+  num_workers: 2
+  optimizer: AdamW
+  optimizer_params:
+    amsgrad: True
+    weight_decay: 0. # 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2
+  lr_initial: 1.e-4
+  lr_gamma: 0.8
+  scheduler: ReduceLROnPlateau
+  mode: min
+  factor: 0.8
+  patience: 3
+  max_epochs: 80
+  force_coefficient: 100
+  energy_coefficient: 1
+  ema_decay: 0.999
+  clip_grad_norm: 10
diff --git a/tests/core/models/test_configs/test_painn_hydra.yml b/tests/core/models/test_configs/test_painn_hydra.yml
new file mode 100644
index 0000000000..0b39aa1731
--- /dev/null
+++ b/tests/core/models/test_configs/test_painn_hydra.yml
@@ -0,0 +1,58 @@
+trainer: forces
+
+task:
+  dataset: lmdb
+  type: regression
+  metric: mae
+  primary_metric: forces_mae
+  labels:
+    - potential energy
+  grad_input: atomic forces
+  train_on_free_atoms: True
+  eval_on_free_atoms: True
+  prediction_dtype: float32
+
+logger:
+  name: tensorboard
+
+model:
+  name: hydra
+  backbone:
+    model: painn_backbone #_bbwheads
+    hidden_channels: 32
+    num_layers: 6
+    num_rbf: 32
+    cutoff: 12.0
+    max_neighbors: 5
+    scale_file: configs/s2ef/all/painn/painn_nb6_scaling_factors.pt
+    regress_forces: True
+    direct_forces: True
+    use_pbc: True
+  heads:
+    energy:
+      module: painn_energy_head
+    forces:
+      module: painn_force_head
+
+
+optim:
+  batch_size: 32
+  eval_batch_size: 32
+  load_balancing: atoms
+  eval_every: 5000
+  num_workers: 2
+  optimizer: AdamW
+  optimizer_params:
+    amsgrad: True
+    weight_decay: 0. # 2e-6 (TF weight decay) / 1e-4 (lr) = 2e-2
+  lr_initial: 1.e-4
+  lr_gamma: 0.8
+  scheduler: ReduceLROnPlateau
+  mode: min
+  factor: 0.8
+  patience: 3
+  max_epochs: 80
+  force_coefficient: 100
+  energy_coefficient: 1
+  ema_decay: 0.999
+  clip_grad_norm: 10
diff --git a/tests/core/models/test_configs/test_schnet.yml b/tests/core/models/test_configs/test_schnet.yml
new file mode 100755
index 0000000000..97faf3962a
--- /dev/null
+++ b/tests/core/models/test_configs/test_schnet.yml
@@ -0,0 +1,45 @@
+trainer: forces
+
+task:
+  dataset: lmdb
+  type: regression
+  metric: mae
+  primary_metric: forces_mae
+  labels:
+    - potential energy
+  grad_input: atomic forces
+  train_on_free_atoms: True
+  eval_on_free_atoms: True
+  prediction_dtype: float32
+
+logger:
+  name: tensorboard
+
+model:
+  name: schnet
+  hidden_channels: 1024
+  num_filters: 256
+  num_interactions: 5
+  num_gaussians: 200
+  cutoff: 6.0
+  use_pbc: True
+
+# *** Important note ***
+# The total number of gpus used for this run was 64.
+# If the global batch size (num_gpus * batch_size) is modified
+# the lr_milestones and warmup_steps need to be adjusted accordingly.
+
+optim:
+  batch_size: 20
+  eval_batch_size: 20
+  eval_every: 10000
+  num_workers: 16
+  lr_initial: 0.0001
+  lr_gamma: 0.1
+  lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma
+    - 313907
+    - 523179
+    - 732451
+  warmup_steps: 209271
+  warmup_factor: 0.2
+  max_epochs: 15
diff --git a/tests/core/models/test_configs/test_scn.yml b/tests/core/models/test_configs/test_scn.yml
new file mode 100755
index 0000000000..c080c48557
--- /dev/null
+++ b/tests/core/models/test_configs/test_scn.yml
@@ -0,0 +1,59 @@
+# A total of 64 32GB GPUs were used for training.
+trainer: forces
+
+task:
+  dataset: lmdb
+  type: regression
+  metric: mae
+  primary_metric: forces_mae
+  labels:
+    - potential energy
+  grad_input: atomic forces
+  train_on_free_atoms: True
+  eval_on_free_atoms: True
+  prediction_dtype: float32
+
+logger:
+  name: tensorboard
+
+model:
+  name: scn
+  num_interactions: 2
+  hidden_channels: 16
+  sphere_channels: 8
+  sphere_channels_reduce: 8
+  num_sphere_samples: 8
+  num_basis_functions: 8
+  distance_function: "gaussian"
+  show_timing_info: False
+  max_num_neighbors: 40
+  cutoff: 8.0
+  lmax: 4
+  num_bands: 2
+  use_grid: True
+  regress_forces: True
+  use_pbc: True
+  basis_width_scalar: 2.0
+  otf_graph: True
+
+optim:
+  batch_size: 2
+  eval_batch_size: 1
+  num_workers: 2
+  lr_initial: 0.0004
+  optimizer: AdamW
+  optimizer_params: {"amsgrad": True}
+  eval_every: 5000
+  lr_gamma: 0.3
+  lr_milestones: # epochs at which lr_initial <- lr_initial * lr_gamma
+    - 260000
+    - 340000
+    - 420000
+    - 500000
+    - 800000
+    - 1000000
+  warmup_steps: 100
+  warmup_factor: 0.2
+  max_epochs: 12
+  clip_grad_norm: 100
+  ema_decay: 0.999
diff --git a/tests/core/models/test_dimenetpp.py b/tests/core/models/test_dimenetpp.py
index 76a546037b..d1daec728b 100644
--- a/tests/core/models/test_dimenetpp.py
+++ b/tests/core/models/test_dimenetpp.py
@@ -47,9 +47,6 @@ def load_model(request) -> None:
     setup_imports()
 
     model = registry.get_model_class("dimenetplusplus")(
-        None,
-        32,
-        1,
         cutoff=6.0,
         regress_forces=True,
         use_pbc=False,
diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py
index 0034232cd2..3194dd2df7 100644
--- a/tests/core/models/test_equiformer_v2.py
+++ b/tests/core/models/test_equiformer_v2.py
@@ -63,9 +63,6 @@ def _load_model():
     checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu"))
 
     model = registry.get_model_class("equiformer_v2")(
-        None,
-        -1,
-        1,
         use_pbc=True,
         regress_forces=True,
         otf_graph=True,
diff --git a/tests/core/models/test_gemnet.py b/tests/core/models/test_gemnet.py
index 3fa0c6babc..b4c5414cc4 100644
--- a/tests/core/models/test_gemnet.py
+++ b/tests/core/models/test_gemnet.py
@@ -47,9 +47,6 @@ def load_model(request) -> None:
     setup_imports()
 
     model = registry.get_model_class("gemnet_t")(
-        None,
-        -1,
-        1,
         cutoff=6.0,
         num_spherical=7,
         num_radial=128,
diff --git a/tests/core/models/test_gemnet_oc.py b/tests/core/models/test_gemnet_oc.py
index d84669750f..7729c14483 100644
--- a/tests/core/models/test_gemnet_oc.py
+++ b/tests/core/models/test_gemnet_oc.py
@@ -58,9 +58,6 @@ def load_model(request) -> None:
     checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu"))
 
     model = registry.get_model_class("gemnet_oc")(
-        None,
-        -1,
-        1,
         num_spherical=7,
         num_radial=128,
         num_blocks=4,
diff --git a/tests/core/models/test_gemnet_oc_scaling_mismatch.py b/tests/core/models/test_gemnet_oc_scaling_mismatch.py
index 8f1c36d277..29ea40c0fa 100644
--- a/tests/core/models/test_gemnet_oc_scaling_mismatch.py
+++ b/tests/core/models/test_gemnet_oc_scaling_mismatch.py
@@ -35,9 +35,6 @@ def test_no_scaling_mismatch(self) -> None:
         checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu"))
 
         model = registry.get_model_class("gemnet_oc")(
-            None,
-            -1,
-            1,
             num_spherical=7,
             num_radial=128,
             num_blocks=4,
@@ -111,9 +108,6 @@ def test_scaling_mismatch(self) -> None:
        checkpoint = torch.load(io.BytesIO(r.content), map_location=torch.device("cpu"))
 
        model = registry.get_model_class("gemnet_oc")(
-            None,
-            -1,
-            1,
             num_spherical=7,
             num_radial=128,
             num_blocks=4,
@@ -189,9 +183,6 @@ def test_no_file_exists(self) -> None:
         with pytest.raises(ValueError):
             registry.get_model_class("gemnet_oc")(
-                None,
-                -1,
-                1,
                 num_spherical=7,
                 num_radial=128,
                 num_blocks=4,
@@ -245,9 +236,6 @@ def test_not_fitted(self) -> None:
         setup_imports()
 
         model = registry.get_model_class("gemnet_oc")(
-            None,
-            -1,
-            1,
             num_spherical=7,
             num_radial=128,
             num_blocks=4,
diff --git a/tests/core/models/test_schnet.py b/tests/core/models/test_schnet.py
index aa704604f7..3dd21be4e1 100644
--- a/tests/core/models/test_schnet.py
+++ b/tests/core/models/test_schnet.py
@@ -46,7 +46,7 @@ def load_model(request) -> None:
     setup_imports()
 
     model = registry.get_model_class("schnet")(
-        None, 32, 1, cutoff=6.0, regress_forces=True, use_pbc=True
+        cutoff=6.0, regress_forces=True, use_pbc=True
    )
    request.cls.model = model
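
A quick usage note on the test hunks above (illustrative only, not part of the patch): the model constructors no longer accept the legacy leading positional arguments (the removed `None`, `32`/`-1`, `1` lines), so everything is passed by keyword. A minimal sketch, assuming the `fairchem.core.common.registry` and `setup_imports` import paths used elsewhere in this patch:

# Sketch only: build a registered model the way the updated test_schnet.py does.
# Assumes fairchem.core is installed; import paths follow the src/fairchem/core
# layout touched elsewhere in this patch.
from fairchem.core.common.registry import registry
from fairchem.core.common.utils import setup_imports

setup_imports()  # register the available model classes with the registry

# Keyword-only construction; the old leading positional arguments are removed.
model = registry.get_model_class("schnet")(
    cutoff=6.0, regress_forces=True, use_pbc=True
)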