Skip to content

Commit

Permalink
Balanced batch sampler+base dataset (#753)
Browse files Browse the repository at this point in the history
* Update BalancedBatchSampler to use datasets' `data_sizes` method
Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error`

* Remove python 3.10 syntax

* Documentation

* Added set_epoch method

* Format

* Changed "resolved dataset" message to be a debug log to reduce log spam

* clean up batchsampler and tests

* base dataset class

* move lin_ref to base dataset

* inherit basedataset for ase dataset

* filter indices prop

* added create_dataset fn

* yaml load fix

* create dataset function instead of filtering in base

* remove filtered_indices

* make create_dataset and LMDBDatabase importable from datasets

* create_dataset cleanup

* test create_dataset

* use metadata.natoms directly and add it to subset

* use self.indices to handle shard

* rename _data_sizes

* fix Subset of metadata

* minor change to metadata, added full path option

* import updates

* implement get_metadata for datasets; add tests for max_atoms and balanced partitioning

* a[:len(a)+1] does not throw error, change to check for this

* off by one fix

* fixing tests

* plug create_dataset into trainer

* remove datasetwithsizes; fix base dataset integration; replace close_db with __del__

* lint

* add/fix test;

* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)

* adding new notebook for using fairchem models with NEBs

* adding md tutorials

* blocking code cells that arent needed or take too long

* Add extra test case for local batch size = 1

* fix example

* fix test case

* reorg changes

* remove metadata_has_sizes in favor of basedataset function metadata_hasattr

* fix data_parallel typo

* fix up some tests

* rename get_metadata to sample_property_metadata

* add slow get_metadata for ase; add tests for get_metadata (ase+lmdb); add test for make lmdb metadata sizes

* add support for different backends and ddp in pytest

* fix tests and balanced batch sampler

* make default dataset lmdb

* lint

* fix tests

* test with world_size=0 by default

* fix tests

* fix tests..

* remove subsample from oc22 dataset

* remove old datasets; add test for noddp

* remove load balancing from docs

* fix docs; add train_split_settings and test for this

---------

Co-authored-by: Nima Shoghi <[email protected]>
Co-authored-by: Nima Shoghi <[email protected]>
Co-authored-by: lbluque <[email protected]>
Co-authored-by: Brandon <[email protected]>
Co-authored-by: Brook Wander <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
  • Loading branch information
8 people authored Aug 2, 2024
1 parent 7ec7397 commit 04a69b0
Show file tree
Hide file tree
Showing 21 changed files with 1,095 additions and 449 deletions.
1 change: 1 addition & 0 deletions docs/core/fine-tuning/fine-tuning-oxides.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ from fairchem.core.common.tutorial_utils import generate_yml_config
yml = generate_yml_config(checkpoint_path, 'config.yml',
delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes',
'optim.loss_force', # the checkpoint setting causes an error
'optim.load_balancing',
'dataset', 'test_dataset', 'val_dataset'],
update={'gpus': 1,
'optim.eval_every': 10,
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/advanced/fine-tuning-in-python.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ We start by making the config.yml. We build this from the calculator checkpoint.
from fairchem.core.common.tutorial_utils import generate_yml_config
yml = generate_yml_config(checkpoint_path, 'config.yml',
delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes',
delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes','optim.load_balancing',
'optim.loss_force', # the checkpoint setting causes an error
'dataset', 'test_dataset', 'val_dataset'],
update={'gpus': 1,
Expand Down
236 changes: 119 additions & 117 deletions src/fairchem/core/common/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@

import heapq
import logging
from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Literal

import numba
import numpy as np
import numpy.typing as npt
import torch
from torch.utils.data import BatchSampler, DistributedSampler, Sampler
import torch.distributed
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
from typing_extensions import override

from fairchem.core.common import distutils, gp_utils
from fairchem.core.datasets import data_list_collater
from fairchem.core.datasets.base_dataset import (
UnsupportedDatasetError,
)

if TYPE_CHECKING:
from pathlib import Path

from numpy.typing import NDArray
from torch_geometric.data import Batch, Data


Expand All @@ -35,30 +38,24 @@ def __call__(self, data_list: list[Data]) -> Batch:


@numba.njit
def balanced_partition(sizes: npt.NDArray[np.int_], num_parts: int):
def _balanced_partition(sizes: NDArray[np.int_], num_parts: int):
"""
Greedily partition the given set by always inserting
the largest element into the smallest partition.
"""
sort_idx = np.argsort(-sizes) # Sort in descending order
heap: list[tuple[list[int], list[int]]] = [
(sizes[idx], [idx]) for idx in sort_idx[:num_parts]
]
heap = [(sizes[idx], [idx]) for idx in sort_idx[:num_parts]]
heapq.heapify(heap)
for idx in sort_idx[num_parts:]:
smallest_part = heapq.heappop(heap)
new_size = smallest_part[0] + sizes[idx]
new_idx = smallest_part[1] + [idx]
new_idx = smallest_part[1] + [
idx
] # TODO should this be append to save time/space
heapq.heappush(heap, (new_size, new_idx))
return [part[1] for part in heap]


@runtime_checkable
class _HasMetadata(Protocol):
@property
def metadata_path(self) -> Path: ...


class StatefulDistributedSampler(DistributedSampler):
"""
More fine-grained state DataSampler that uses training iteration and epoch
Expand Down Expand Up @@ -105,149 +102,154 @@ def set_epoch_and_start_iteration(self, epoch, start_iter):
self.start_iter = start_iter


class BalancedBatchSampler(Sampler):
def _load_dataset(self, dataset, mode: Literal["atoms", "neighbors"]):
errors: list[str] = []
if not isinstance(dataset, _HasMetadata):
errors.append(f"Dataset {dataset} does not have a metadata_path attribute.")
return None, errors
if not dataset.metadata_path.exists():
errors.append(f"Metadata file {dataset.metadata_path} does not exist.")
return None, errors
def _ensure_supported(dataset: Any):
if not isinstance(dataset, Dataset):
raise UnsupportedDatasetError("BalancedBatchSampler requires a dataset.")

if not dataset.metadata_hasattr("natoms"):
raise UnsupportedDatasetError(
"BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms."
)

key = {"atoms": "natoms", "neighbors": "neighbors"}[mode]
sizes = np.load(dataset.metadata_path)[key]
logging.debug(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
return dataset

return sizes, errors

class BalancedBatchSampler(BatchSampler):
def __init__(
self,
dataset,
dataset: Dataset,
*,
batch_size: int,
num_replicas: int,
rank: int,
device: torch.device,
seed: int,
mode: str | bool = "atoms",
mode: bool | Literal["atoms"] = "atoms",
shuffle: bool = True,
on_error: Literal["warn_and_balance", "warn_and_no_balance", "raise"] = "raise",
drop_last: bool = False,
force_balancing: bool = False,
throw_on_error: bool = False,
) -> None:
if mode is True:
mode = "atoms"

if isinstance(mode, str):
mode = mode.lower()
if mode not in ("atoms", "neighbors"):
raise ValueError(
f"Invalid mode {mode}. Must be one of 'atoms', 'neighbors', or a boolean."
)
):
"""
Initializes a BalancedBatchSampler object.
self.dataset = dataset
self.batch_size = batch_size
self.num_replicas = num_replicas
self.rank = rank
self.device = device
self.mode = mode
self.shuffle = shuffle
self.drop_last = drop_last
Args:
dataset (Dataset): The dataset to sample from.
batch_size (int): The size of each batch.
num_replicas (int): The number of processes participating in distributed training.
rank (int): The rank of the current process in distributed training.
device (torch.device): The device to use for the batches.
mode (str or bool, optional): The mode to use for balancing the batches. Defaults to "atoms".
shuffle (bool, optional): Whether to shuffle the samples. Defaults to True.
on_error (Literal["warn_and_balance", "warn_and_no_balance", "raise"], optional): The action to take when an error occurs (i.e., when we have an invalid dataset). Defaults to "raise".
- "warn_and_balance": Raise a warning and balance the batch by manually loading the data samples and counting the number of nodes (this is slow).
- "warn_and_no_balance": Raise a warning and do not do any balancing.
- "raise": Raise an error.
drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to False.
"""
self.disabled = False
self.on_error = on_error

if mode is False:
logging.warning(f"Disabled BalancedBatchSampler because {mode=}.")
self.disabled = True
elif mode.lower() != "atoms":
raise ValueError(
f"Only mode='atoms' or mode=True is supported, got {mode=}."
)
elif num_replicas == 1:
logging.warning(f"Disabled BalancedBatchSampler because {num_replicas=}.")
self.disabled = True

try:
dataset = _ensure_supported(dataset)
except UnsupportedDatasetError as error:
if self.on_error == "raise":
raise error
if self.on_error == "warn_and_balance":
logging.warning(
f"Failed to get data sizes from metadata, loading data to get sizes (THIS IS SLOW). {error}"
)
elif self.on_error == "warn_and_no_balance":
logging.warning(
f"Failed to get data sizes, falling back to uniform partitioning. {error}"
)
else:
raise ValueError(f"Unknown on_error={self.on_error}") from error

self.single_sampler = StatefulDistributedSampler(
self.dataset,
sampler = StatefulDistributedSampler(
dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
drop_last=drop_last,
batch_size=batch_size,
seed=seed,
)
self.batch_sampler = BatchSampler(
self.single_sampler,
batch_size,
drop_last=drop_last,
)

self.sizes = None
self.balance_batches = False

if self.num_replicas <= 1:
logging.info("Batch balancing is disabled for single GPU training.")
return

if self.mode is False:
logging.info(
"Batch balancing is disabled because `optim.load_balancing` is `False`"
)
return

self.sizes, errors = self._load_dataset(dataset, self.mode)
if self.sizes is None:
self.balance_batches = force_balancing
if force_balancing:
errors.append(
"BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead! "
"You can disable balancing by setting `optim.load_balancing` to `False`."
)
else:
errors.append(
"Batches will not be balanced, which can incur significant overhead!"
)
else:
self.balance_batches = True

if errors:
msg = "BalancedBatchSampler: " + " ".join(errors)
if throw_on_error:
raise RuntimeError(msg)
super().__init__(sampler, batch_size=batch_size, drop_last=drop_last)
self.device = device

logging.warning(msg)
logging.info(
f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
)

def __len__(self) -> int:
return len(self.batch_sampler)
def _get_natoms(self, batch_idx: list[int]):
if self.sampler.dataset.metadata_hasattr("natoms"):
return np.array(
self.sampler.dataset.get_metadata("natoms", batch_idx)
).reshape(-1)
if self.on_error == "warn_and_balance":
return np.array([self.sampler.dataset[idx].num_nodes for idx in batch_idx])
return None

def set_epoch_and_start_iteration(self, epoch: int, start_iteration: int) -> None:
if not hasattr(self.single_sampler, "set_epoch_and_start_iteration"):
if not isinstance(self.sampler, StatefulDistributedSampler):
if start_iteration != 0:
raise NotImplementedError(
f"{type(self.single_sampler)} does not support resuming from a nonzero step."
)
self.single_sampler.set_epoch(epoch)
self.sampler.set_epoch(epoch)
else:
self.single_sampler.set_epoch_and_start_iteration(epoch, start_iteration)
self.sampler.set_epoch_and_start_iteration(epoch, start_iteration)

def set_epoch(self, epoch: int) -> None:
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)

@staticmethod
def _dist_enabled():
return torch.distributed.is_available() and torch.distributed.is_initialized()

@override
def __iter__(self):
if not self.balance_batches:
yield from self.batch_sampler
if self.disabled or not self._dist_enabled():
yield from super().__iter__()
return

for batch_idx in self.batch_sampler:
if self.sizes is None:
# Unfortunately, we need to load the data to know the image sizes
data_list = [self.dataset[idx] for idx in batch_idx]

if self.mode == "atoms":
sizes = [data.num_nodes for data in data_list]
elif self.mode == "neighbors":
sizes = [data.edge_index.shape[1] for data in data_list]
else:
raise NotImplementedError(
f"Unknown load balancing mode: {self.mode}"
)
else:
sizes = [self.sizes[idx] for idx in batch_idx]

idx_sizes = torch.stack([torch.tensor(batch_idx), torch.tensor(sizes)])
for batch_idx in super().__iter__():
sizes = self._get_natoms(batch_idx)
if sizes is None: # on_error == "warn_and_no_balance" is set
yield batch_idx
continue

idx_sizes = torch.stack(
[
torch.tensor(batch_idx, device=self.device),
torch.tensor(sizes, device=self.device),
]
)
idx_sizes_all = distutils.all_gather(idx_sizes, device=self.device)
idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu()
if gp_utils.initialized():
idx_sizes_all = torch.unique(input=idx_sizes_all, dim=1)
idx_all = idx_sizes_all[0]
sizes_all = idx_sizes_all[1]

local_idx_balanced = balanced_partition(
sizes_all.numpy(), num_parts=self.num_replicas
local_idx_balanced = _balanced_partition(
sizes_all.numpy(),
num_parts=self.sampler.num_replicas,
)
# Since DistributedSampler pads the last batch
# this should always have an entry for each replica.
yield idx_all[local_idx_balanced[self.rank]]
yield idx_all[local_idx_balanced[self.sampler.rank]]
6 changes: 3 additions & 3 deletions src/fairchem/core/common/distutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def setup(config) -> None:
)
else:
config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"]))
dist.init_process_group(backend="nccl")
dist.init_process_group(backend=config.get("backend", "nccl"))


def cleanup() -> None:
Expand Down Expand Up @@ -144,7 +144,7 @@ def all_reduce(
if not isinstance(data, torch.Tensor):
tensor = torch.tensor(data)
if device is not None:
tensor = tensor.cuda(device)
tensor = tensor.to(device)
dist.all_reduce(tensor, group=group)
if average:
tensor /= get_world_size()
Expand All @@ -162,7 +162,7 @@ def all_gather(data, group=dist.group.WORLD, device=None) -> list[torch.Tensor]:
if not isinstance(data, torch.Tensor):
tensor = torch.tensor(data)
if device is not None:
tensor = tensor.cuda(device)
tensor = tensor.to(device)
tensor_list = [tensor.new_zeros(tensor.shape) for _ in range(get_world_size())]
dist.all_gather(tensor_list, tensor, group=group)
if not isinstance(data, torch.Tensor):
Expand Down
Loading

0 comments on commit 04a69b0

Please sign in to comment.