diff --git a/docs/core/fine-tuning/fine-tuning-oxides.md b/docs/core/fine-tuning/fine-tuning-oxides.md index 77a9350d3b..39c39cad40 100644 --- a/docs/core/fine-tuning/fine-tuning-oxides.md +++ b/docs/core/fine-tuning/fine-tuning-oxides.md @@ -205,6 +205,7 @@ from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', 'optim.loss_force', # the checkpoint setting causes an error + 'optim.load_balancing', 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, 'optim.eval_every': 10, diff --git a/docs/tutorials/advanced/fine-tuning-in-python.md b/docs/tutorials/advanced/fine-tuning-in-python.md index 1d14219c88..0eeb8e5485 100644 --- a/docs/tutorials/advanced/fine-tuning-in-python.md +++ b/docs/tutorials/advanced/fine-tuning-in-python.md @@ -75,7 +75,7 @@ We start by making the config.yml. We build this from the calculator checkpoint. from fairchem.core.common.tutorial_utils import generate_yml_config yml = generate_yml_config(checkpoint_path, 'config.yml', - delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes', + delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes','optim.load_balancing', 'optim.loss_force', # the checkpoint setting causes an error 'dataset', 'test_dataset', 'val_dataset'], update={'gpus': 1, diff --git a/src/fairchem/core/common/data_parallel.py b/src/fairchem/core/common/data_parallel.py index 4d5836b786..89c3b67445 100644 --- a/src/fairchem/core/common/data_parallel.py +++ b/src/fairchem/core/common/data_parallel.py @@ -9,20 +9,23 @@ import heapq import logging -from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal import numba import numpy as np -import numpy.typing as npt import torch -from torch.utils.data import BatchSampler, DistributedSampler, Sampler +import torch.distributed +from torch.utils.data import BatchSampler, Dataset, 
@numba.njit
def _balanced_partition(sizes: NDArray[np.int_], num_parts: int):
    """
    Greedily partition the given set by always inserting
    the largest element into the smallest partition.

    Args:
        sizes: one size per element; elements are identified by their
            position in this array.
        num_parts: number of partitions to seed the heap with.

    Returns:
        A list of index lists, one per heap entry. The heap is seeded from
        ``sort_idx[:num_parts]``, so there are at most ``num_parts`` entries.
        The lists are returned in heap order, not sorted by total size.
    """
    sort_idx = np.argsort(-sizes)  # Sort in descending order
    # Each heap entry is (total size of the partition, indices assigned to it);
    # the ``num_parts`` largest elements seed one partition each.
    heap = [(sizes[idx], [idx]) for idx in sort_idx[:num_parts]]
    heapq.heapify(heap)
    for idx in sort_idx[num_parts:]:
        # Pop the currently lightest partition and give it the next-largest element.
        smallest_part = heapq.heappop(heap)
        new_size = smallest_part[0] + sizes[idx]
        new_idx = smallest_part[1] + [
            idx
        ]  # TODO should this be append to save time/space
        heapq.heappush(heap, (new_size, new_idx))
    return [part[1] for part in heap]
def _ensure_supported(dataset: Any):
    """Check that ``dataset`` can provide per-sample atom counts from metadata.

    Args:
        dataset: candidate dataset object.

    Returns:
        The same ``dataset`` object, unchanged.

    Raises:
        UnsupportedDatasetError: if ``dataset`` is not a torch ``Dataset``, or
            it does not report a ``natoms`` metadata attribute.
    """
    if not isinstance(dataset, Dataset):
        raise UnsupportedDatasetError("BalancedBatchSampler requires a dataset.")

    # A plain torch Dataset does not define ``metadata_hasattr``. Treat that as
    # "unsupported" so the caller's ``on_error`` policy applies, instead of
    # leaking an AttributeError.
    metadata_hasattr = getattr(dataset, "metadata_hasattr", None)
    if not callable(metadata_hasattr) or not metadata_hasattr("natoms"):
        raise UnsupportedDatasetError(
            "BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms."
        )

    logging.debug(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
    return dataset


class BalancedBatchSampler(BatchSampler):
    """Batch sampler that re-partitions each global batch across ranks by atom count.

    Batches come from a ``StatefulDistributedSampler``. When balancing is
    active, every rank all-gathers the (index, size) pairs of the current step
    and keeps the partition assigned to its own rank.
    """

    def __init__(
        self,
        dataset: Dataset,
        *,
        batch_size: int,
        num_replicas: int,
        rank: int,
        device: torch.device,
        seed: int,
        mode: bool | Literal["atoms"] = "atoms",
        shuffle: bool = True,
        on_error: Literal["warn_and_balance", "warn_and_no_balance", "raise"] = "raise",
        drop_last: bool = False,
    ):
        """
        Initializes a BalancedBatchSampler object.

        Args:
            dataset (Dataset): The dataset to sample from.
            batch_size (int): The size of each batch.
            num_replicas (int): The number of processes participating in distributed training.
            rank (int): The rank of the current process in distributed training.
            device (torch.device): The device on which the index/size tensors are
                created before they are all-gathered.
            seed (int): Seed forwarded to the underlying StatefulDistributedSampler.
            mode (str or bool, optional): The mode to use for balancing the batches.
                ``"atoms"`` (case-insensitive) or ``True`` balance by number of atoms;
                ``False`` disables balancing. Defaults to "atoms".
            shuffle (bool, optional): Whether to shuffle the samples. Defaults to True.
            on_error (Literal["warn_and_balance", "warn_and_no_balance", "raise"], optional): The action to take when an error occurs (i.e., when we have an invalid dataset). Defaults to "raise".
                - "warn_and_balance": Raise a warning and balance the batch by manually loading the data samples and counting the number of nodes (this is slow).
                - "warn_and_no_balance": Raise a warning and do not do any balancing.
                - "raise": Raise an error.
            drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to False.

        Raises:
            ValueError: if ``mode`` is not ``"atoms"``/``True``/``False``, or if the
                dataset is unsupported and ``on_error`` is not a known policy.
            UnsupportedDatasetError: if the dataset is unsupported and
                ``on_error == "raise"``.
        """
        self.disabled = False
        self.on_error = on_error

        if mode is False:
            logging.warning(f"Disabled BalancedBatchSampler because {mode=}.")
            self.disabled = True
        elif mode is not True and not (
            isinstance(mode, str) and mode.lower() == "atoms"
        ):
            # ``True`` is documented as equivalent to "atoms"; it must not reach
            # ``.lower()``, which does not exist on bool.
            raise ValueError(
                f"Only mode='atoms' or mode=True is supported, got {mode=}."
            )
        elif num_replicas == 1:
            logging.warning(f"Disabled BalancedBatchSampler because {num_replicas=}.")
            self.disabled = True

        try:
            dataset = _ensure_supported(dataset)
        except UnsupportedDatasetError as error:
            if self.on_error == "raise":
                raise
            if self.on_error == "warn_and_balance":
                logging.warning(
                    f"Failed to get data sizes from metadata, loading data to get sizes (THIS IS SLOW). {error}"
                )
            elif self.on_error == "warn_and_no_balance":
                logging.warning(
                    f"Failed to get data sizes, falling back to uniform partitioning. {error}"
                )
            else:
                raise ValueError(f"Unknown on_error={self.on_error}") from error

        sampler = StatefulDistributedSampler(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            # NOTE(review): this argument sits on a line elided between two diff
            # hunks; reproduced from the pre-change call -- confirm.
            drop_last=drop_last,
            batch_size=batch_size,
            seed=seed,
        )

        super().__init__(sampler, batch_size=batch_size, drop_last=drop_last)
        self.device = device

        logging.info(
            f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
        )

    def _get_natoms(self, batch_idx: list[int]):
        """Return the atom count of every sample in ``batch_idx``.

        Uses dataset metadata when available. Otherwise, if
        ``on_error == "warn_and_balance"``, loads each sample and reads its
        ``num_nodes``. Returns ``None`` when sizes cannot be obtained.
        """
        dataset = self.sampler.dataset
        # Same defensive lookup as in _ensure_supported: the dataset may have
        # been accepted through a warn_* policy without this method.
        metadata_hasattr = getattr(dataset, "metadata_hasattr", None)
        if callable(metadata_hasattr) and metadata_hasattr("natoms"):
            return np.array(dataset.get_metadata("natoms", batch_idx)).reshape(-1)
        if self.on_error == "warn_and_balance":
            return np.array([dataset[idx].num_nodes for idx in batch_idx])
        return None

    def set_epoch_and_start_iteration(self, epoch: int, start_iteration: int) -> None:
        """Set the epoch and, if the sampler supports it, the step to resume from.

        Raises:
            NotImplementedError: if the underlying sampler is not a
                StatefulDistributedSampler and ``start_iteration`` is nonzero.
        """
        if not isinstance(self.sampler, StatefulDistributedSampler):
            if start_iteration != 0:
                # The sampler now lives on ``self.sampler``; the former
                # ``single_sampler`` attribute no longer exists.
                raise NotImplementedError(
                    f"{type(self.sampler)} does not support resuming from a nonzero step."
                )
            self.sampler.set_epoch(epoch)
        else:
            self.sampler.set_epoch_and_start_iteration(epoch, start_iteration)

    def set_epoch(self, epoch: int) -> None:
        """Forward ``epoch`` to the underlying sampler if it is a DistributedSampler."""
        if isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)

    @staticmethod
    def _dist_enabled():
        """Return True when torch.distributed is available and initialized."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    @override
    def __iter__(self):
        """Yield batches of dataset indices, balanced across ranks when enabled."""
        if self.disabled or not self._dist_enabled():
            yield from super().__iter__()
            return

        for batch_idx in super().__iter__():
            sizes = self._get_natoms(batch_idx)
            if sizes is None:  # on_error == "warn_and_no_balance" is set
                yield batch_idx
                continue

            # Row 0 holds the sample indices, row 1 the matching sizes.
            idx_sizes = torch.stack(
                [
                    torch.tensor(batch_idx, device=self.device),
                    torch.tensor(sizes, device=self.device),
                ]
            )
            idx_sizes_all = distutils.all_gather(idx_sizes, device=self.device)
            idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu()
            if gp_utils.initialized():
                # NOTE(review): the body of this branch sits on a line elided
                # between two diff hunks; reproduced from the upstream file --
                # confirm before applying.
                idx_sizes_all = torch.unique(input=idx_sizes_all, dim=1)
            idx_all = idx_sizes_all[0]
            sizes_all = idx_sizes_all[1]

            local_idx_balanced = _balanced_partition(
                sizes_all.numpy(),
                num_parts=self.sampler.num_replicas,
            )
            # Since DistributedSampler pads the last batch
            # this should always have an entry for each replica.
            yield idx_all[local_idx_balanced[self.sampler.rank]]
def init_env_rank_and_launch_test(
    rank: int,
    pg_setup_params: PGConfig,
    mp_output_dict: dict[int, object],
    test_method: callable,
    args: list[object],
    kwargs: dict[str, object],
) -> None:
    """Export rendezvous environment variables for one worker, then run the test.

    No process group is created here; only MASTER_ADDR, MASTER_PORT,
    WORLD_SIZE, LOCAL_RANK and RANK are set. The return value of
    ``test_method(*args, **kwargs)`` is stored in ``mp_output_dict[rank]``.
    """
    env = os.environ
    rank_text = str(rank)
    env["MASTER_ADDR"] = "localhost"
    env["MASTER_PORT"] = pg_setup_params.port
    env["WORLD_SIZE"] = str(pg_setup_params.world_size)
    env["LOCAL_RANK"] = rank_text
    env["RANK"] = rank_text
    outcome = test_method(*args, **kwargs)  # pyre-fixme
    mp_output_dict[rank] = outcome


def init_pg_and_rank_and_launch_test(
    rank: int,
    pg_setup_params: PGConfig,
    mp_output_dict: dict[int, object],
    test_method: callable,
    args: list[object],
    kwargs: dict[str, object],
) -> None:
    """Create the default process group (and optionally graph parallel), then run the test.

    Sets MASTER_ADDR, MASTER_PORT, WORLD_SIZE and LOCAL_RANK, initializes
    torch.distributed with the backend from ``pg_setup_params``, calls
    ``setup_gp`` when ``pg_setup_params.use_gp`` is truthy, and stores the
    return value of ``test_method(*args, **kwargs)`` in ``mp_output_dict[rank]``.
    """
    env = os.environ
    env["MASTER_ADDR"] = "localhost"
    env["MASTER_PORT"] = pg_setup_params.port
    env["WORLD_SIZE"] = str(pg_setup_params.world_size)
    env["LOCAL_RANK"] = str(rank)

    # setup default process group
    process_group_options = {
        "rank": rank,
        "world_size": pg_setup_params.world_size,
        "backend": pg_setup_params.backend,
        # setting up timeout for distributed collectives
        "timeout": timedelta(seconds=10),
    }
    dist.init_process_group(**process_group_options)

    # setup gp
    if pg_setup_params.use_gp:
        gp_config = {
            "gp_gpus": pg_setup_params.gp_group_size,
            "distributed_backend": pg_setup_params.backend,
        }
        setup_gp(gp_config)

    outcome = test_method(*args, **kwargs)  # pyre-fixme
    mp_output_dict[rank] = outcome
setup default process group - dist.init_process_group( - rank=rank, - world_size=pg_setup_params.world_size, - backend=pg_setup_params.backend, - timeout=timedelta(seconds=10), # setting up timeout for distributed collectives - ) - # setup gp - if pg_setup_params.use_gp: - config = { - "gp_gpus": pg_setup_params.gp_group_size, - "distributed_backend": pg_setup_params.backend, - } - setup_gp(config) - mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme diff --git a/src/fairchem/core/datasets/__init__.py b/src/fairchem/core/datasets/__init__.py index 1fd4b51cd5..dc3f7d0e4d 100644 --- a/src/fairchem/core/datasets/__init__.py +++ b/src/fairchem/core/datasets/__init__.py @@ -5,23 +5,19 @@ from __future__ import annotations from .ase_datasets import AseDBDataset, AseReadDataset, AseReadMultiStructureDataset +from .base_dataset import create_dataset from .lmdb_database import LMDBDatabase from .lmdb_dataset import ( LmdbDataset, - SinglePointLmdbDataset, - TrajectoryLmdbDataset, data_list_collater, ) -from .oc22_lmdb_dataset import OC22LmdbDataset __all__ = [ "AseDBDataset", "AseReadDataset", "AseReadMultiStructureDataset", "LmdbDataset", - "SinglePointLmdbDataset", - "TrajectoryLmdbDataset", - "data_list_collater", - "OC22LmdbDataset", "LMDBDatabase", + "create_dataset", + "data_list_collater", ] diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py index a258f42832..15c22322db 100644 --- a/src/fairchem/core/datasets/ase_datasets.py +++ b/src/fairchem/core/datasets/ase_datasets.py @@ -20,13 +20,12 @@ import ase import numpy as np -import torch.nn from torch import tensor -from torch.utils.data import Dataset from tqdm import tqdm from fairchem.core.common.registry import registry from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.datasets.lmdb_database import LMDBDatabase from 
fairchem.core.datasets.target_metadata_guesser import guess_property_metadata from fairchem.core.modules.transforms import DataTransforms @@ -60,7 +59,7 @@ def apply_one_tags( return atoms -class AseAtomsDataset(Dataset, ABC): +class AseAtomsDataset(BaseDataset, ABC): """ This is an abstract Dataset that includes helpful utilities for turning ASE atoms objects into OCP-usable data objects. This should not be instantiated directly @@ -81,7 +80,7 @@ def __init__( config: dict, atoms_transform: Callable[[ase.Atoms, Any, ...], ase.Atoms] = apply_one_tags, ) -> None: - self.config = config + super().__init__(config) a2g_args = config.get("a2g_args", {}) or {} @@ -96,19 +95,13 @@ def __init__( self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - self.lin_ref = None - if self.config.get("lin_ref", False): - lin_ref = torch.tensor( - np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - ) - self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) - self.atoms_transform = atoms_transform if self.config.get("keep_in_memory", False): self.__getitem__ = cache(self.__getitem__) self.ids = self._load_dataset_get_ids(config) + self.num_samples = len(self.ids) if len(self.ids) == 0: raise ValueError( @@ -116,9 +109,6 @@ def __init__( f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" ) - def __len__(self) -> int: - return len(self.ids) - def __getitem__(self, idx): # Handle slicing if isinstance(idx, slice): @@ -174,11 +164,7 @@ def _load_dataset_get_ids(self, config): def get_relaxed_energy(self, identifier): raise NotImplementedError("IS2RE-Direct is not implemented with this dataset.") - def close_db(self) -> None: - # This method is sometimes called by a trainer - pass - - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: metadata = {} if num_samples < len(self): 
@@ -197,6 +183,18 @@ def get_metadata(self, num_samples: int = 100) -> dict: return metadata + def get_metadata(self, attr, idx): + # try the parent method + metadata = super().get_metadata(attr, idx) + if metadata is not None: + return metadata + # try to resolve it here + if attr != "natoms": + return None + if isinstance(idx, (list, np.ndarray)): + return np.array([self.get_metadata(attr, i) for i in idx]) + return len(self.get_atoms(idx)) + @registry.register_dataset("ase_read") class AseReadDataset(AseAtomsDataset): @@ -399,7 +397,7 @@ def get_atoms(self, idx: str) -> ase.Atoms: return atoms - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: return {} def get_relaxed_energy(self, identifier) -> float: @@ -556,17 +554,17 @@ def connect_db( return ase.db.connect(address, **connect_args) - def close_db(self) -> None: + def __del__(self): for db in self.dbs: if hasattr(db, "close"): db.close() - def get_metadata(self, num_samples: int = 100) -> dict: + def sample_property_metadata(self, num_samples: int = 100) -> dict: logging.warning( "You specific a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" ) if self.dbs[0].metadata == {}: - return super().get_metadata(num_samples) + return super().sample_property_metadata(num_samples) return copy.deepcopy(self.dbs[0].metadata) diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py new file mode 100644 index 0000000000..2ca26596c3 --- /dev/null +++ b/src/fairchem/core/datasets/base_dataset.py @@ -0,0 +1,227 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
T_co = TypeVar("T_co", covariant=True)


class DatasetMetadata(NamedTuple):
    """Per-sample metadata arrays read from ``metadata.npz`` files."""

    # Number of atoms of each sample; None when not provided.
    natoms: ArrayLike | None = None


class UnsupportedDatasetError(ValueError):
    """Raised when a dataset does not provide what a consumer requires."""


class BaseDataset(Dataset[T_co], metaclass=ABCMeta):
    """Base Dataset class for all OCP datasets."""

    def __init__(self, config: dict):
        """Initialize

        Stores ``config``, resolves ``config["src"]`` into ``self.paths`` and
        loads the optional linear reference coefficients into ``self.lin_ref``.
        Subclasses are expected to set ``self.num_samples``, which ``__len__``
        and ``indices`` read.

        Args:
            config (dict): dataset configuration
        """
        self.config = config
        self.paths = []

        if "src" in self.config:
            if isinstance(config["src"], str):
                self.paths = [Path(self.config["src"])]
            else:
                self.paths = tuple(Path(path) for path in config["src"])

        self.lin_ref = None
        if self.config.get("lin_ref", False):
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
            # load lin_ref files from trusted sources.
            lin_ref = torch.tensor(
                np.load(self.config["lin_ref"], allow_pickle=True)["coeff"]
            )
            self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False)

    def __len__(self) -> int:
        return self.num_samples

    def metadata_hasattr(self, attr) -> bool:
        """Return True if metadata was loaded and has a field named ``attr``."""
        if self._metadata is None:
            return False
        return hasattr(self._metadata, attr)

    @cached_property
    def indices(self):
        """All sample indices, ``0 .. num_samples - 1``, as an integer array."""
        return np.arange(self.num_samples, dtype=int)

    @cached_property
    def _metadata(self) -> DatasetMetadata | None:
        """Load and concatenate the metadata files for this dataset.

        ``config["metadata_path"]`` takes precedence. Otherwise a
        ``metadata.npz`` is looked up next to each file path, or inside each
        directory path, in ``self.paths``. Returns ``None`` (after logging a
        warning) when no metadata file is found.
        """
        # logic to read metadata file here
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
        # metadata files from trusted sources.
        metadata_npzs = []
        if self.config.get("metadata_path", None) is not None:
            metadata_npzs.append(
                np.load(self.config["metadata_path"], allow_pickle=True)
            )

        else:
            for path in self.paths:
                if path.is_file():
                    metadata_file = path.parent / "metadata.npz"
                else:
                    metadata_file = path / "metadata.npz"
                if metadata_file.is_file():
                    metadata_npzs.append(np.load(metadata_file, allow_pickle=True))

        if len(metadata_npzs) == 0:
            logging.warning(
                f"Could not find dataset metadata.npz files in '{self.paths}'"
            )
            return None

        metadata = DatasetMetadata(
            **{
                field: np.concatenate([metadata[field] for metadata in metadata_npzs])
                for field in DatasetMetadata._fields
            }
        )

        assert metadata.natoms.shape[0] == len(
            self
        ), "Loaded metadata and dataset size mismatch."

        return metadata

    def get_metadata(self, attr, idx):
        """Return metadata field ``attr`` at ``idx``.

        A list ``idx`` yields a list with one entry per index; any other
        ``idx`` is used to index the metadata array directly. Returns ``None``
        when no metadata is loaded.
        """
        if self._metadata is not None:
            metadata_attr = getattr(self._metadata, attr)
            if isinstance(idx, list):
                return [metadata_attr[_idx] for _idx in idx]
            return metadata_attr[idx]
        return None


class Subset(Subset_, BaseDataset):
    """A pytorch subset that also takes metadata if given."""

    def __init__(
        self,
        dataset: BaseDataset,
        indices: Sequence[int],
        metadata: DatasetMetadata | None = None,
    ) -> None:
        super().__init__(dataset, indices)
        self.metadata = metadata
        self.indices = indices
        self.num_samples = len(indices)
        self.config = dataset.config

    @cached_property
    def _metadata(self) -> DatasetMetadata | None:
        return self.dataset._metadata

    def get_metadata(self, attr, idx):
        """Return metadata of the wrapped dataset for subset position(s) ``idx``.

        Subset positions are first mapped to indices of the wrapped dataset.
        """
        if isinstance(idx, list):
            # Pass a flat list of parent indices. A nested list made the parent
            # return a single array wrapped in a one-element list instead of
            # one value per requested index.
            return self.dataset.get_metadata(attr, [self.indices[i] for i in idx])
        return self.dataset.get_metadata(attr, self.indices[idx])


def create_dataset(config: dict[str, Any], split: str) -> Subset:
    """Create a dataset from a config dictionary

    Args:
        config (dict): dataset config dictionary
        split (str): name of split

    Returns:
        Subset: dataset subset class

    Raises:
        ValueError: if ``max_atoms`` is set but the dataset has no ``natoms``
            metadata, if more than one of ``first_n``/``sample_n``/``no_shuffle``
            is set, or if more samples are requested than are available.
    """
    # Initialize the dataset
    dataset_cls = registry.get_dataset_class(config.get("format", "lmdb"))
    assert issubclass(dataset_cls, Dataset), f"{dataset_cls} is not a Dataset"

    # remove information about other splits, only keep specified split
    # this may only work with the mt config not main config
    current_split_config = config.copy()
    if "splits" in current_split_config:
        current_split_config.pop("splits")
        current_split_config.update(config["splits"][split])

    seed = current_split_config.get("seed", 0)
    if split != "train":
        seed += (
            1  # if we use same dataset for train / val , make sure its diff sampling
        )

    g = torch.Generator()
    g.manual_seed(seed)

    dataset = dataset_cls(current_split_config)
    # Get indices of the dataset
    indices = dataset.indices
    max_atoms = current_split_config.get("max_atoms", None)
    if max_atoms is not None:
        if not dataset.metadata_hasattr("natoms"):
            raise ValueError("Cannot use max_atoms without dataset metadata")
        indices = indices[dataset.get_metadata("natoms", indices) <= max_atoms]

    # Apply dataset level transforms
    # TODO is no_shuffle mutually exclusive though? or what is the purpose of no_shuffle?
    first_n = current_split_config.get("first_n")
    sample_n = current_split_config.get("sample_n")
    no_shuffle = current_split_config.get("no_shuffle")
    # Reject configs that set more than one of the mutually exclusive arguments
    if sum(arg is not None for arg in (first_n, sample_n, no_shuffle)) > 1:
        raise ValueError(
            "sample_n, first_n, no_shuffle are mutually exclusive arguments. Only one can be provided."
        )
    if first_n is not None:
        max_index = first_n
    elif sample_n is not None:
        # shuffle by default, user can disable to optimize if they have confidence in dataset
        # shuffle all datasets by default to avoid biasing the sampling in concat dataset
        # TODO only shuffle if split is train
        max_index = sample_n
        indices = indices[randperm(len(indices), generator=g)]
    else:
        max_index = len(indices)
        indices = (
            indices if no_shuffle else indices[randperm(len(indices), generator=g)]
        )

    if max_index > len(indices):
        msg = (
            f"Cannot take {max_index} data points from a dataset of only length {len(indices)}.\n"
            "Make sure to set first_n or sample_n to a number <= the total samples in dataset."
        )
        if max_atoms is not None:
            # Drop the trailing period and continue the sentence with a space.
            msg = msg[:-1] + f" that are smaller than the given max_atoms {max_atoms}."
        raise ValueError(msg)

    indices = indices[:max_index]

    return Subset(dataset, indices, metadata=dataset._metadata)
pathlib import Path + + from torch_geometric.data.data import BaseData + T_co = TypeVar("T_co", covariant=True) @registry.register_dataset("lmdb") @registry.register_dataset("single_point_lmdb") @registry.register_dataset("trajectory_lmdb") -class LmdbDataset(Dataset[T_co]): - metadata_path: Path +class LmdbDataset(BaseDataset): sharded: bool r"""Dataset class to load from LMDB files containing relaxation @@ -50,20 +51,21 @@ class LmdbDataset(Dataset[T_co]): """ def __init__(self, config) -> None: - super().__init__() - self.config = config + super().__init__(config) assert not self.config.get( "train_on_oc20_total_energies", False ), "For training on total energies set dataset=oc22_lmdb" - self.path = Path(self.config["src"]) + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." + self.path = self.paths[0] + if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" - self.metadata_path = self.path / "metadata.npz" - self._keys = [] self.envs = [] for db_path in db_paths: @@ -86,7 +88,6 @@ def __init__(self, config) -> None: self._keylen_cumulative = np.cumsum(keylens).tolist() self.num_samples = sum(keylens) else: - self.metadata_path = self.path.parent / "metadata.npz" self.env = self.connect_db(self.path) # If "length" encoded as ascii is present, use that @@ -113,19 +114,15 @@ def __init__(self, config) -> None: self.indices, self.config.get("total_shards", 1) ) # limit each process to see a subset of data based off defined shard - self.available_indices = self.shards[self.config.get("shard", 0)] - self.num_samples = len(self.available_indices) + self.indices = self.shards[self.config.get("shard", 0)] + self.num_samples = len(self.indices) self.key_mapping = self.config.get("key_mapping", None) self.transforms = DataTransforms(self.config.get("transforms", {})) - def __len__(self) -> int: - return self.num_samples - def __getitem__(self, idx: int) 
-> T_co: # if sharding, remap idx to appropriate idx of the sharded set - if self.sharded: - idx = self.available_indices[idx] + idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. db_idx = bisect.bisect(self._keylen_cumulative, idx) @@ -165,14 +162,14 @@ def connect_db(self, lmdb_path: Path | None = None) -> lmdb.Environment: max_readers=1, ) - def close_db(self) -> None: + def __del__(self): if not self.path.is_file(): for env in self.envs: env.close() else: self.env.close() - def get_metadata(self, num_samples: int = 100): + def sample_property_metadata(self, num_samples: int = 100): # This will interogate the classic OCP LMDB format to determine # which properties are present and attempt to guess their shapes # and whether they are intensive or extensive. @@ -214,26 +211,6 @@ def get_metadata(self, num_samples: int = 100): } -class SinglePointLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "SinglePointLmdbDataset is deprecated and will be removed in the future." - "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - -class TrajectoryLmdbDataset(LmdbDataset[BaseData]): - def __init__(self, config, transform=None) -> None: - super().__init__(config) - warnings.warn( - "TrajectoryLmdbDataset is deprecated and will be removed in the future." 
- "Please use 'LmdbDataset' instead.", - stacklevel=3, - ) - - def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> BaseData: batch = Batch.from_data_list(data_list) diff --git a/src/fairchem/core/datasets/oc22_lmdb_dataset.py b/src/fairchem/core/datasets/oc22_lmdb_dataset.py index 0c6d4e8bfb..867a72726f 100644 --- a/src/fairchem/core/datasets/oc22_lmdb_dataset.py +++ b/src/fairchem/core/datasets/oc22_lmdb_dataset.py @@ -9,22 +9,21 @@ import bisect import pickle -from pathlib import Path import lmdb import numpy as np import torch -from torch.utils.data import Dataset from fairchem.core.common.registry import registry from fairchem.core.common.typing import assert_is_instance as aii from fairchem.core.common.utils import pyg2_data_transform from fairchem.core.datasets._utils import rename_data_object_keys +from fairchem.core.datasets.base_dataset import BaseDataset from fairchem.core.modules.transforms import DataTransforms @registry.register_dataset("oc22_lmdb") -class OC22LmdbDataset(Dataset): +class OC22LmdbDataset(BaseDataset): r"""Dataset class to load from LMDB files containing relaxation trajectories or single point computations. @@ -43,10 +42,13 @@ class OC22LmdbDataset(Dataset): """ def __init__(self, config, transform=None) -> None: - super().__init__() - self.config = config + super().__init__(config) + + assert ( + len(self.paths) == 1 + ), f"{type(self)} does not support a list of src paths." 
+ self.path = self.paths[0] - self.path = Path(self.config["src"]) self.data2train = self.config.get("data2train", "all") if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) @@ -114,19 +116,11 @@ def __init__(self, config, transform=None) -> None: if self.train_on_oc20_total_energies: with open(config["oc20_ref"], "rb") as fp: self.oc20_ref = pickle.load(fp) - if self.config.get("lin_ref", False): - coeff = np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] - self.lin_ref = torch.nn.Parameter(torch.tensor(coeff), requires_grad=False) - self.subsample = aii(self.config.get("subsample", False), bool) def __len__(self) -> int: - if self.subsample: - return min(self.subsample, self.num_samples) return self.num_samples def __getitem__(self, idx): - if self.data2train != "all": - idx = self.indices[idx] if not self.path.is_file(): # Figure out which db this should be indexed from. db_idx = bisect.bisect(self._keylen_cumulative, idx) diff --git a/src/fairchem/core/scripts/make_lmdb_sizes.py b/src/fairchem/core/scripts/make_lmdb_sizes.py index 682fb58e65..ebf2122aeb 100644 --- a/src/fairchem/core/scripts/make_lmdb_sizes.py +++ b/src/fairchem/core/scripts/make_lmdb_sizes.py @@ -15,7 +15,7 @@ from tqdm import tqdm from fairchem.core.common.typing import assert_is_instance -from fairchem.core.datasets import SinglePointLmdbDataset, TrajectoryLmdbDataset +from fairchem.core.datasets.lmdb_dataset import LmdbDataset def get_data(index): @@ -28,14 +28,13 @@ def get_data(index): return index, natoms, neighbors -def main(args) -> None: +def make_lmdb_sizes(args) -> None: path = assert_is_instance(args.data_path, str) global dataset + dataset = LmdbDataset({"src": path}) if os.path.isdir(path): - dataset = TrajectoryLmdbDataset({"src": path}) outpath = os.path.join(path, "metadata.npz") elif os.path.isfile(path): - dataset = SinglePointLmdbDataset({"src": path}) outpath = os.path.join(os.path.dirname(path), "metadata.npz") output_indices = 
range(len(dataset)) @@ -63,7 +62,7 @@ def main(args) -> None: np.savez(outpath, natoms=sorted_natoms, neighbors=sorted_neighbors) -if __name__ == "__main__": +def get_lmdb_sizes_parser(): parser = argparse.ArgumentParser() parser.add_argument( "--data-path", @@ -77,5 +76,10 @@ def main(args) -> None: type=int, help="Num of workers to parallelize across", ) + return parser + + +if __name__ == "__main__": + parser = get_lmdb_sizes_parser() args: argparse.Namespace = parser.parse_args() - main(args) + make_lmdb_sizes(args) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index fcac13502b..c21409863e 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -7,6 +7,7 @@ from __future__ import annotations +import copy import datetime import errno import logging @@ -38,6 +39,7 @@ save_checkpoint, update_config, ) +from fairchem.core.datasets.base_dataset import create_dataset from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss @@ -241,12 +243,16 @@ def load_logger(self) -> None: def get_sampler( self, dataset, batch_size: int, shuffle: bool ) -> BalancedBatchSampler: - if "load_balancing" in self.config["optim"]: - balancing_mode = self.config["optim"]["load_balancing"] - force_balancing = True + balancing_mode = self.config["optim"].get("load_balancing", None) + on_error = self.config["optim"].get("load_balancing_on_error", None) + if balancing_mode is not None: + if on_error is None: + on_error = "raise" else: balancing_mode = "atoms" - force_balancing = False + + if on_error is None: + on_error = "warn_and_no_balance" if gp_utils.initialized(): num_replicas = gp_utils.get_dp_world_size() @@ -262,7 +268,7 @@ def get_sampler( device=self.device, mode=balancing_mode, shuffle=shuffle, - force_balancing=force_balancing, + on_error=on_error, 
seed=self.config["cmd"]["seed"], ) @@ -283,15 +289,26 @@ def load_datasets(self) -> None: self.val_loader = None self.test_loader = None + # This is hacky and scheduled to be removed next BE week + # move ['X_split_settings'] to ['splits'][X] + def convert_settings_to_split_settings(config, split_name): + config = copy.deepcopy(config) # make sure we dont modify the original + if f"{split_name}_split_settings" in config: + config["splits"] = { + split_name: config.pop(f"{split_name}_split_settings") + } + return config + # load train, val, test datasets if "src" in self.config["dataset"]: logging.info( f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}" ) - self.train_dataset = registry.get_dataset_class( - self.config["dataset"].get("format", "lmdb") - )(self.config["dataset"]) + self.train_dataset = create_dataset( + convert_settings_to_split_settings(self.config["dataset"], "train"), + "train", + ) self.train_sampler = self.get_sampler( self.train_dataset, self.config["optim"]["batch_size"], @@ -302,6 +319,16 @@ def load_datasets(self) -> None: self.train_sampler, ) + if ( + "first_n" in self.config["dataset"] + or "sample_n" in self.config["dataset"] + or "max_atom" in self.config["dataset"] + ): + logging.warn( + "Dataset attributes (first_n/sample_n/max_atom) passed to all datasets! 
Please don't do this, its dangerous!\n" + + "Add them under each dataset 'train_split_settings'/'val_split_settings'/'test_split_settings'" + ) + if "src" in self.config["val_dataset"]: if self.config["val_dataset"].get("use_train_settings", True): val_config = self.config["dataset"].copy() @@ -309,9 +336,9 @@ def load_datasets(self) -> None: else: val_config = self.config["val_dataset"] - self.val_dataset = registry.get_dataset_class( - val_config.get("format", "lmdb") - )(val_config) + self.val_dataset = create_dataset( + convert_settings_to_split_settings(val_config, "val"), "val" + ) self.val_sampler = self.get_sampler( self.val_dataset, self.config["optim"].get( @@ -331,9 +358,9 @@ def load_datasets(self) -> None: else: test_config = self.config["test_dataset"] - self.test_dataset = registry.get_dataset_class( - test_config.get("format", "lmdb") - )(test_config) + self.test_dataset = create_dataset( + convert_settings_to_split_settings(test_config, "test"), "test" + ) self.test_sampler = self.get_sampler( self.test_dataset, self.config["optim"].get( @@ -443,7 +470,9 @@ def load_model(self) -> None: self.logger.log_summary({"num_params": self.model.num_params}) if distutils.initialized() and not self.config["noddp"]: - self.model = DistributedDataParallel(self.model, device_ids=[self.device]) + self.model = DistributedDataParallel( + self.model, device_ids=None if self.cpu else [self.device] + ) @property def _unwrapped_model(self): diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 9055d2d625..72c005893d 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -227,12 +227,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None: if checkpoint_every == -1: self.save(checkpoint_file="checkpoint.pt", training_state=True) - self.train_dataset.close_db() - if self.config.get("val_dataset", False): - self.val_dataset.close_db() - if 
self.config.get("test_dataset", False): - self.test_dataset.close_db() - def _forward(self, batch): out = self.model(batch.to(self.device)) @@ -648,7 +642,9 @@ def run_relaxations(self, split="val"): ) gather_results["chunk_idx"] = np.cumsum( [gather_results["chunk_idx"][i] for i in idx] - )[:-1] # np.split does not need last idx, assumes n-1:end + )[ + :-1 + ] # np.split does not need last idx, assumes n-1:end full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz" diff --git a/tests/core/common/test_data_parallel_batch_sampler.py b/tests/core/common/test_data_parallel_batch_sampler.py index 6205042652..6bd8effe26 100644 --- a/tests/core/common/test_data_parallel_batch_sampler.py +++ b/tests/core/common/test_data_parallel_batch_sampler.py @@ -1,9 +1,16 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + from __future__ import annotations -import functools -import tempfile from contextlib import contextmanager from pathlib import Path +import functools +import tempfile from typing import TypeVar import numpy as np @@ -13,11 +20,13 @@ from fairchem.core.common.data_parallel import ( BalancedBatchSampler, StatefulDistributedSampler, + UnsupportedDatasetError, + _balanced_partition, ) +from fairchem.core.datasets.base_dataset import BaseDataset, DatasetMetadata DATA = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -SIZE_ATOMS = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] -SIZE_NEIGHBORS = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4] +SIZE_ATOMS = [2, 20, 3, 51, 10, 11, 41, 31, 13, 14] T_co = TypeVar("T_co", covariant=True) @@ -28,23 +37,57 @@ def _temp_file(name: str): yield Path(tmpdir) / name +@pytest.fixture() +def valid_dataset(): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return DatasetMetadata(natoms=np.array(SIZE_ATOMS)) + + def __init__(self, data) -> None: + 
super().__init__(config={}) + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def get_metadata(self, attr, idx): + assert attr == "natoms" + metadata_attr = getattr(self._metadata, attr) + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + + return _Dataset(DATA) + + @pytest.fixture() def valid_path_dataset(): - class _Dataset(Dataset[T_co]): + class _Dataset(BaseDataset): + @functools.cached_property + def _metadata(self) -> DatasetMetadata: + return self.metadata + def __init__(self, data, fpath: Path) -> None: + super().__init__(config={}) self.data = data - self.metadata_path = fpath + self.metadata = DatasetMetadata(natoms=np.load(fpath)["natoms"]) def __len__(self): return len(self.data) def __getitem__(self, idx): - return self.data[idx] + metadata_attr = getattr(self._metadata, "natoms") + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] with _temp_file("metadata.npz") as file: np.savez( natoms=np.array(SIZE_ATOMS), - neighbors=np.array(SIZE_NEIGHBORS), file=file, ) yield _Dataset(DATA, file) @@ -52,8 +95,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_path_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data self.metadata_path = Path("/tmp/does/not/exist.np") @@ -68,8 +113,10 @@ def __getitem__(self, idx): @pytest.fixture() def invalid_dataset(): - class _Dataset(Dataset): + class _Dataset(BaseDataset): + def __init__(self, data) -> None: + super().__init__(config={}) self.data = data def __len__(self): @@ -81,99 +128,68 @@ def __getitem__(self, idx): return _Dataset(DATA) -def test_lowercase(invalid_dataset) -> None: - sampler = BalancedBatchSampler( - dataset=invalid_dataset, +def test_lowercase(valid_dataset) -> None: + _ = BalancedBatchSampler( + 
dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="ATOMS", - throw_on_error=False, - seed=0 - ) - assert sampler.mode == "atoms" - - sampler = BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="NEIGHBORS", - throw_on_error=False, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.mode == "neighbors" def test_invalid_mode(invalid_dataset) -> None: with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." + ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='natoms'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="natoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - ValueError, match="Must be one of 'atoms', 'neighbors', or a boolean." + ValueError, + match="Only mode='atoms' or mode=True is supported, got mode='neighbors'.", ): - BalancedBatchSampler( + _ = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, - mode="nneighbors", - throw_on_error=True, - seed=0 + mode="neighbors", + on_error="raise", + seed=0, ) def test_invalid_dataset(invalid_dataset) -> None: - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", - ): - BalancedBatchSampler( - dataset=invalid_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 - ) - with pytest.raises( - RuntimeError, - match="does not have a metadata_path attribute. 
Batches will not be balanced, which can incur significant overhead!", - ): - BalancedBatchSampler( + with pytest.raises(UnsupportedDatasetError): + sampler = BalancedBatchSampler( dataset=invalid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) def test_invalid_path_dataset(invalid_path_dataset) -> None: with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -182,13 +198,11 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=True, - seed=0 + on_error="raise", + seed=0, ) with pytest.raises( - RuntimeError, - match="Metadata file .+ does not exist. Batches will not be balanced, which can incur significant overhead!", + UnsupportedDatasetError, ): BalancedBatchSampler( dataset=invalid_path_dataset, @@ -197,70 +211,59 @@ def test_invalid_path_dataset(invalid_path_dataset) -> None: num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - force_balancing=False, - seed=0 + on_error="raise", + seed=0, ) -def test_valid_dataset(valid_path_dataset) -> None: +def test_valid_dataset(valid_dataset, valid_path_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode="atoms", - throw_on_error=True, - seed=0 - ) - assert (sampler.sizes == np.array(SIZE_ATOMS)).all() - - sampler = BalancedBatchSampler( - dataset=valid_path_dataset, - batch_size=1, - rank=0, - num_replicas=2, - device=None, - mode="neighbors", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert (sampler.sizes == np.array(SIZE_NEIGHBORS)).all() + assert ( + 
sampler._get_natoms(list(range(len(SIZE_ATOMS)))) == np.array(SIZE_ATOMS) + ).all() -def test_disabled(valid_path_dataset) -> None: +def test_disabled(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=2, device=None, mode=False, - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_single_node(valid_path_dataset) -> None: +def test_single_node(valid_dataset) -> None: sampler = BalancedBatchSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=1, rank=0, num_replicas=1, device=None, mode="atoms", - throw_on_error=True, - seed=0 + on_error="raise", + seed=0, ) - assert sampler.balance_batches is False + assert sampler.disabled or not sampler._dist_enabled() -def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: +def test_stateful_distributed_sampler_noshuffle(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -272,12 +275,12 @@ def test_stateful_distributed_sampler_noshuffle(valid_path_dataset) -> None: def test_stateful_distributed_sampler_vs_distributed_sampler( - valid_path_dataset, + valid_dataset, ) -> None: for seed in [0, 100, 200]: for batch_size in range(1, 4): stateful_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=2, @@ -286,7 +289,7 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( drop_last=True, ) sampler = DistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, rank=0, num_replicas=2, seed=seed, @@ -296,10 +299,10 @@ def test_stateful_distributed_sampler_vs_distributed_sampler( assert list(stateful_sampler) == list(sampler) -def 
test_stateful_distributed_sampler(valid_path_dataset) -> None: +def test_stateful_distributed_sampler(valid_dataset) -> None: for batch_size in range(1, 4): sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -309,7 +312,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: offset_step = 2 loaded_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, seed=0, @@ -319,7 +322,7 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(loaded_sampler) == original_order[offset_step * batch_size :] diff_sampler = StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=0, num_replicas=1, @@ -328,14 +331,14 @@ def test_stateful_distributed_sampler(valid_path_dataset) -> None: assert list(diff_sampler) != original_order -def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: - fullset = set(range(len(valid_path_dataset))) +def test_stateful_distributed_sampler_numreplicas(valid_dataset) -> None: + fullset = set(range(len(valid_dataset))) for drop_last in [True, False]: for num_replicas in range(1, 4): for batch_size in [1]: samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -360,14 +363,14 @@ def test_stateful_distributed_sampler_numreplicas(valid_path_dataset) -> None: def test_stateful_distributed_sampler_numreplicas_drop_last( - valid_path_dataset, + valid_dataset, ) -> None: - fullset = set(range(len(valid_path_dataset))) + fullset = set(range(len(valid_dataset))) for num_replicas in range(1, 4): for batch_size in range(1, 4): samplers = [ StatefulDistributedSampler( - dataset=valid_path_dataset, + dataset=valid_dataset, batch_size=batch_size, rank=rank, seed=0, @@ -387,3 +390,15 @@ def 
test_stateful_distributed_sampler_numreplicas_drop_last( ) assert len(concat_idxs) == len(np.unique(concat_idxs)) assert len(concat_idxs) == (len(fullset) // num_replicas) * num_replicas + + +def test_balancedbatchsampler_partition(valid_dataset) -> None: + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS), 4) + == [[1, 9, 5, 0], [7, 8, 2], [3], [6, 4]] + ) + # test case with local batch size = 1, GPU0(rank0) always gets smallest + # we cant say anything about the remaining elements because it is a heap + assert np.array( + _balanced_partition(np.array(SIZE_ATOMS)[[3, 6, 7, 1]], 4)[0] == [3] + ) diff --git a/tests/core/common/test_gp_utils.py b/tests/core/common/test_gp_utils.py index 9743d35a2f..05c7475d2c 100644 --- a/tests/core/common/test_gp_utils.py +++ b/tests/core/common/test_gp_utils.py @@ -7,42 +7,112 @@ gather_from_model_parallel_region, scatter_to_model_parallel_region, ) -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) def _dummy_call(x): return x -@pytest.mark.parametrize("world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])]) # noqa: PT006 + +@pytest.mark.parametrize( + "world_size, input, expected_output", [(1, 5, [5]), (3, 0, [0, 0, 0])] +) # noqa: PT006 def test_basic_setup(world_size: int, input: torch.Tensor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True) - output = spawn_multi_process(config, _dummy_call, input) + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=True + ) + output = spawn_multi_process( + config, _dummy_call, init_pg_and_rank_and_launch_test, input + ) assert output == expected_output -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - 
(2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1]), torch.Tensor([2,3])]), - (2, 2, torch.Tensor([0,1,2]), [torch.Tensor([0,1]), torch.Tensor([2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1]), torch.Tensor([2, 3])], + ), + (2, 2, torch.Tensor([0, 1, 2]), [torch.Tensor([0, 1]), torch.Tensor([2])]), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0]), torch.Tensor([1]), torch.Tensor([2])], + ), + ], ) -def test_scatter_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_to_model_parallel_region, input) +def test_scatter_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, + scatter_to_model_parallel_region, + init_pg_and_rank_and_launch_test, + input, + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) + def scatter_gather_fn(input: torch.Tensor, dim: int = 0): x = scatter_to_model_parallel_region(input, dim) return gather_from_model_parallel_region(x, dim) -@pytest.mark.parametrize("world_size, gp_size, input, expected_output", # noqa: PT006 - [(2, 1, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2,3]), [torch.Tensor([0,1,2,3]), torch.Tensor([0,1,2,3])]), - (2, 2, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), torch.Tensor([0,1,2])]), - (3, 3, torch.Tensor([0,1,2]), [torch.Tensor([0,1,2]), 
torch.Tensor([0,1,2]), torch.Tensor([0,1,2])])] + +@pytest.mark.parametrize( + "world_size, gp_size, input, expected_output", # noqa: PT006 + [ + ( + 2, + 1, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2, 3]), + [torch.Tensor([0, 1, 2, 3]), torch.Tensor([0, 1, 2, 3])], + ), + ( + 2, + 2, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ( + 3, + 3, + torch.Tensor([0, 1, 2]), + [torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2]), torch.Tensor([0, 1, 2])], + ), + ], ) -def test_gather_tensors(world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list): - config = PGConfig(backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True) - output = spawn_multi_process(config, scatter_gather_fn, input) +def test_gather_tensors( + world_size: int, gp_size: int, input: torch.Tesnor, expected_output: list +): + config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=gp_size, use_gp=True + ) + output = spawn_multi_process( + config, scatter_gather_fn, init_pg_and_rank_and_launch_test, input + ) for out, expected_out in zip(output, expected_output): assert torch.equal(out, expected_out) diff --git a/tests/core/datasets/conftest.py b/tests/core/datasets/conftest.py new file mode 100644 index 0000000000..eb7be94994 --- /dev/null +++ b/tests/core/datasets/conftest.py @@ -0,0 +1,28 @@ +import numpy as np +import pytest +from ase import build +from ase.calculators.singlepoint import SinglePointCalculator + + +@pytest.fixture(scope="module") +def structures(): + structures = [ + build.molecule("H2O", vacuum=4), + build.bulk("Cu"), + build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), + ] + for atoms in structures: + calc = SinglePointCalculator( + atoms, + energy=1, + forces=atoms.positions, + # there is an issue with ASE db when writing a db with 3x3 stress if is flattened to (9,) and then + # errors when 
trying to read it + stress=np.random.random((6,)), + ) + atoms.calc = calc + atoms.info["extensive_property"] = 3 * len(atoms) + atoms.info["tensor_property"] = np.random.random((6, 6)) + + structures[2].set_pbc(True) + return structures diff --git a/tests/core/datasets/test_ase_datasets.py b/tests/core/datasets/test_ase_datasets.py index 01bd4ea2fc..7b114d877f 100644 --- a/tests/core/datasets/test_ase_datasets.py +++ b/tests/core/datasets/test_ase_datasets.py @@ -15,26 +15,6 @@ ) from fairchem.core.datasets.lmdb_database import LMDBDatabase -structures = [ - build.molecule("H2O", vacuum=4), - build.bulk("Cu"), - build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), -] -for atoms in structures: - calc = SinglePointCalculator( - atoms, - energy=1, - forces=atoms.positions, - # there is an issue with ASE db when writing a db with 3x3 stress it is flattened to (9,) and then - # errors when trying to read it - stress=np.random.random((6,)), - ) - atoms.calc = calc - atoms.info["extensive_property"] = 3 * len(atoms) - atoms.info["tensor_property"] = np.random.random((6, 6)) - -structures[2].set_pbc(True) - @pytest.fixture( params=[ @@ -46,7 +26,7 @@ "aselmdb_dataset", ], ) -def ase_dataset(request, tmp_path_factory): +def ase_dataset(request, structures, tmp_path_factory): tmp_path = tmp_path_factory.mktemp("dataset") mult = 1 a2g_args = { @@ -110,7 +90,7 @@ def ase_dataset(request, tmp_path_factory): return dataset, mult -def test_ase_dataset(ase_dataset): +def test_ase_dataset(ase_dataset, structures): dataset, mult = ase_dataset assert len(dataset) == mult * len(structures) for data in dataset: @@ -121,7 +101,7 @@ def test_ase_dataset(ase_dataset): assert isinstance(data.extensive_property, int) -def test_ase_read_dataset(tmp_path) -> None: +def test_ase_read_dataset(tmp_path, structures): # unfortunately there is currently no clean (already implemented) way to save atoms.info when saving # individual structures - so test separately for i, structure in 
enumerate(structures): @@ -137,13 +117,16 @@ def test_ase_read_dataset(tmp_path) -> None: assert len(dataset) == len(structures) data = dataset[0] del data - dataset.close_db() -def test_ase_metadata_guesser(ase_dataset) -> None: +def test_ase_get_metadata(ase_dataset): + assert ase_dataset[0].get_metadata("natoms", [0])[0] == 3 + + +def test_ase_metadata_guesser(ase_dataset): dataset, _ = ase_dataset - metadata = dataset.get_metadata() + metadata = dataset.sample_property_metadata() # Confirm energy metadata guessed properly assert metadata["targets"]["energy"]["extensive"] is False @@ -171,7 +154,7 @@ def test_ase_metadata_guesser(ase_dataset) -> None: assert metadata["targets"]["info.tensor_property"]["type"] == "per-image" -def test_db_add_delete(tmp_path) -> None: +def test_db_add_delete(tmp_path, structures): database = db.connect(tmp_path / "asedb.db") for _i, atoms in enumerate(structures): database.write(atoms, data=atoms.info) @@ -192,10 +175,9 @@ def test_db_add_delete(tmp_path) -> None: dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) assert len(dataset) == orig_len + len(new_structures) - 1 - dataset.close_db() -def test_ase_multiread_dataset(tmp_path) -> None: +def test_ase_multiread_dataset(tmp_path): atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)] energies = np.linspace(1, 0, len(atoms_objects)) @@ -224,13 +206,17 @@ def test_ase_multiread_dataset(tmp_path) -> None: f.write(f"{tmp_path / 'test.traj'} {len(atoms_objects)}") dataset = AseReadMultiStructureDataset( - config={"index_file": str(tmp_path / "test_index_file")}, + config={ + "src": str(tmp_path), + "index_file": str(tmp_path / "test_index_file"), + }, ) assert len(dataset) == len(atoms_objects) dataset = AseReadMultiStructureDataset( config={ + "src": str(tmp_path), "index_file": str(tmp_path / "test_index_file"), "a2g_args": { "r_energy": True, diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py 
new file mode 100644 index 0000000000..d90271c53d --- /dev/null +++ b/tests/core/datasets/test_create_dataset.py @@ -0,0 +1,180 @@ +import os +import numpy as np +import pytest + +from fairchem.core.datasets import LMDBDatabase, create_dataset +from fairchem.core.datasets.base_dataset import BaseDataset +import tempfile +from fairchem.core.trainers.base_trainer import BaseTrainer + + +@pytest.fixture() +def lmdb_database(structures): + with tempfile.TemporaryDirectory() as tmpdirname: + num_atoms = [] + asedb_fn = f"{tmpdirname}/asedb.lmdb" + with LMDBDatabase(asedb_fn) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + num_atoms.append(len(atoms)) + np.savez(f"{tmpdirname}/metadata.npz", natoms=num_atoms) + yield asedb_fn + + +def test_real_dataset_config(lmdb_database): + class TestTrainer(BaseTrainer): + def __init__(self, config): + self.config = config + + def train(self, x): + return None + + def get_sampler(self, *args, **kwargs): + return None + + def get_dataloader(self, *args, **kwargs): + return None + + config = { + "model_attributes": {}, + "optim": {"batch_size": 0}, + "dataset": { + "format": "ase_db", + "src": str(lmdb_database), + "first_n": 2, + "key_mapping": { + "y": "energy", + "force": "forces", + }, + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0, "stdev": 2.887317180633545}, + } + }, + }, + "val_dataset": {"src": str(lmdb_database)}, + "test_dataset": {}, + "relax_dataset": None, + } + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 2 + + # modify the config for split and see if it works as expected + config["dataset"].pop("first_n") + config["dataset"]["train_split_settings"] = {"first_n": 2} + + t = TestTrainer(config) + t.load_datasets() + assert len(t.train_dataset) == 2 + assert len(t.val_dataset) == 3 + + 
+@pytest.mark.parametrize("max_atoms", [3, None]) +@pytest.mark.parametrize( + "key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)] +) +def test_create_dataset(key, value, max_atoms, structures, lmdb_database): + # now create a config + config = { + "format": "ase_db", + "src": str(lmdb_database), + key: value, + "max_atoms": max_atoms, + } + + dataset = create_dataset(config, split="train") + if max_atoms is not None: + structures = [s for s in structures if len(s) <= max_atoms] + assert all( + natoms <= max_atoms + for natoms in dataset.metadata.natoms[range(len(dataset))] + ) + if key == "first_n": # this assumes first_n are not shuffled + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures[:value], dataset) + ) + assert all( + np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures[:value], dataset) + ) + elif key == "sample_n": + assert len(dataset) == value + else: # no shuffle all of them are in there + assert all( + np.allclose(a1.cell.array, a2.cell.numpy()) + for a1, a2 in zip(structures, dataset) + ) + assert all( + np.allclose(a1.numbers, a2.atomic_numbers) + for a1, a2 in zip(structures, dataset) + ) + + +# make sure we cant sample more than the number of elements in the dataset with sample_n +def test_sample_n_dataset(lmdb_database): + with pytest.raises(ValueError): + _ = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 100, + }, + split="train", + ) + + +def test_diff_seed_sample_dataset(lmdb_database): + dataset_a = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": str(lmdb_database), + "sample_n": 3, + "seed": 0, + }, + split="train", + ) + assert (dataset_a.indices == dataset_b.indices).all() + dataset_b = create_dataset( + config={ + "format": "ase_db", + "src": 
str(lmdb_database), + "sample_n": 3, + "seed": 1, + }, + split="train", + ) + assert not (dataset_a.indices == dataset_b.indices).all() + + +def test_del_dataset(): + class _Dataset(BaseDataset): + def __init__(self, fn) -> None: + super().__init__(config={}) + self.fn = fn + open(self.fn, "a").close() + + def __del__(self): + os.remove(self.fn) + + with tempfile.TemporaryDirectory() as tmpdirname: + fn = tmpdirname + "/test" + d = _Dataset(fn) + del d + assert not os.path.exists(fn) diff --git a/tests/core/datasets/test_lmdb_dataset.py b/tests/core/datasets/test_lmdb_dataset.py new file mode 100644 index 0000000000..f922e32ce3 --- /dev/null +++ b/tests/core/datasets/test_lmdb_dataset.py @@ -0,0 +1,29 @@ +from fairchem.core.datasets.base_dataset import create_dataset + +import numpy as np + +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes + + +def test_load_lmdb_dataset(tutorial_dataset_path): + + lmdb_path = str(tutorial_dataset_path / "s2ef/val_20") + + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args(["--data-path", lmdb_path]) + make_lmdb_sizes(args) + + config = { + "format": "lmdb", + "src": lmdb_path, + } + + dataset = create_dataset(config, split="val") + + assert dataset.get_metadata("natoms", 0) == dataset[0].natoms + + all_metadata_natoms = np.array(dataset.get_metadata("natoms", range(len(dataset)))) + all_natoms = np.array([datapoint.natoms for datapoint in dataset]) + + assert (all_natoms == all_metadata_natoms).all() diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 2f9f80efbf..3af1bc0da7 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -13,7 +13,13 @@ from fairchem.core._cli import Runner from fairchem.core.common.flags import flags +from fairchem.core.common.test_utils import ( + PGConfig, + init_env_rank_and_launch_test, + spawn_multi_process, +) from fairchem.core.common.utils import 
build_config, setup_logging +from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes setup_logging() @@ -107,6 +113,7 @@ def _run_main( update_run_args_with=None, save_checkpoint_to=None, save_predictions_to=None, + world_size=0, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" @@ -114,6 +121,7 @@ def _run_main( yaml_config = yaml.safe_load(yaml_file) if update_dict_with is not None: yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" with open(str(config_yaml), "w") as yaml_file: yaml.dump(yaml_config, yaml_file) @@ -133,7 +141,19 @@ def _run_main( for arg_name, arg_value in run_args.items(): setattr(args, arg_name, arg_value) config = build_config(args, override_args) - Runner()(config) + + if world_size > 0: + pg_config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + spawn_multi_process( + pg_config, + Runner(distributed=True), + init_env_rank_and_launch_test, + config, + ) + else: + Runner()(config) if save_checkpoint_to is not None: checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") @@ -249,6 +269,71 @@ def test_train_and_predict( tutorial_val_src=tutorial_val_src, ) + @pytest.mark.parametrize( + ("world_size", "ddp"), + [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_ddp(self, world_size, ddp, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) + + @pytest.mark.parametrize( + ("world_size", "ddp"), 
+ [ + pytest.param(2, True), + pytest.param(0, False), + ], + ) + def test_balanced_batch_sampler_ddp( + self, world_size, ddp, configs, tutorial_val_src, torch_deterministic + ): + # make dataset metadata + parser = get_lmdb_sizes_parser() + args, override_args = parser.parse_known_args( + ["--data-path", str(tutorial_val_src)] + ) + make_lmdb_sizes(args) + + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + if not ddp: + extra_args["no_ddp"] = True + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1, "load_balancing": "atoms"}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + world_size=world_size, + ) + # train for a few steps and confirm same seeds get same results def test_different_seeds( self, @@ -326,9 +411,9 @@ class TestSmallDatasetOptim: @pytest.mark.parametrize( ("model_name", "expected_energy_mae", "expected_force_mae"), [ - pytest.param("gemnet_oc", 0.4, 0.06, id="gemnet_oc"), - pytest.param("escn", 0.4, 0.06, id="escn"), - pytest.param("equiformer_v2", 0.4, 0.06, id="equiformer_v2"), + pytest.param("gemnet_oc", 0.41, 0.06, id="gemnet_oc"), + pytest.param("escn", 0.41, 0.06, id="escn"), + pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"), ], ) def test_train_optimization( diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py index 500b53d628..3194dd2df7 100644 --- a/tests/core/models/test_equiformer_v2.py +++ b/tests/core/models/test_equiformer_v2.py @@ -18,7 +18,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel from fairchem.core.common.registry import registry -from fairchem.core.common.test_utils import PGConfig, spawn_multi_process +from fairchem.core.common.test_utils import ( + PGConfig, + 
init_pg_and_rank_and_launch_test, + spawn_multi_process, +) from fairchem.core.common.utils import load_state_dict, setup_imports from fairchem.core.datasets import data_list_collater from fairchem.core.models.equiformer_v2.so3 import ( @@ -137,7 +141,9 @@ def test_energy_force_shape(self, snapshot): def test_ddp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=1, gp_group_size=1, use_gp=False) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 1 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape @@ -148,7 +154,9 @@ def test_ddp(self, snapshot): def test_gp(self, snapshot): data_dist = self.data.clone().detach() config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - output = spawn_multi_process(config, _runner, data_dist) + output = spawn_multi_process( + config, _runner, init_pg_and_rank_and_launch_test, data_dist + ) assert len(output) == 2 energy, forces = output[0]["energy"], output[0]["forces"] assert snapshot == energy.shape