Merge pull request #488 from allenai/shanea/optimize-unsharding-2
[Storage Cleaner] Speed up unsharding of some legacy checkpoints
2015aroras authored Mar 7, 2024
2 parents a737306 + 158da6c commit 752353b
Showing 2 changed files with 213 additions and 113 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-- Refactor torch.load monkey patching for legacy checkpoint unsharding in anticipation of unsharding implementation change.
+- Changed legacy checkpoint unsharding to use processes and shared memory instead of threads

## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02

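For readers skimming the checkpoint.py diff below: the change moves shard loading out of threads and into worker processes that publish each shard's tensors into POSIX shared memory, which the main process then reads back. The following is a minimal, illustrative sketch of that pattern; the function names (write_shard_to_shm, read_shard_from_shm) and the "shard-{rank}" naming scheme are made up for this example and are not the commit's API. It relies on POSIX shared-memory semantics (a segment outlives the worker until someone unlinks it), and Python's resource tracker may warn about the segments at interpreter exit.

from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context, shared_memory

import numpy as np


def write_shard_to_shm(name: str, values: np.ndarray) -> int:
    # Worker process: publish a flat fp32 buffer under a well-known name.
    shm = shared_memory.SharedMemory(create=True, size=values.nbytes, name=name)
    np.ndarray(values.shape, dtype=values.dtype, buffer=shm.buf)[:] = values
    shm.close()  # detach here; the segment itself lives on until it is unlinked
    return values.size


def read_shard_from_shm(name: str, size: int) -> np.ndarray:
    # Main process: attach to the segment, copy it out, then free it.
    shm = shared_memory.SharedMemory(name=name)
    out = np.ndarray((size,), dtype=np.float32, buffer=shm.buf).copy()
    shm.close()
    shm.unlink()
    return out


if __name__ == "__main__":
    # Spawned (not forked) workers, mirroring the commit's choice of start method.
    executor = ProcessPoolExecutor(mp_context=get_context("spawn"))
    futures = {
        rank: executor.submit(write_shard_to_shm, f"shard-{rank}", np.full(4, rank, dtype=np.float32))
        for rank in range(2)
    }
    shards = {rank: read_shard_from_shm(f"shard-{rank}", fut.result()) for rank, fut in futures.items()}
    executor.shutdown()
    print(shards)  # {0: array of zeros, 1: array of ones}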
324 changes: 212 additions & 112 deletions olmo/checkpoint.py
@@ -5,18 +5,19 @@
import shutil
from abc import ABCMeta, abstractmethod
from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from functools import reduce
+from multiprocessing import shared_memory
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast

import numpy as np
import torch
import torch.distributed.checkpoint as dist_cp
-from numpy import ndarray
+import torch.multiprocessing as mp
from packaging import version
from torch.distributed import _remote_device
from torch.distributed._shard._utils import narrow_tensor_by_index
@@ -41,6 +42,8 @@
except ModuleNotFoundError:
from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore

+from olmo import util

from .aliases import PathOrStr
from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
from .optim import Optimizer, fix_optim_state_dict
@@ -913,6 +916,162 @@ def unshard_checkpoint(
full_state_dict if load_trainer_state else None,
)

def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
key = tuple() if key is None else key
if isinstance(state, (list, tuple, set)):
for i, sub_state in enumerate(state):
self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
elif isinstance(state, dict):
for name in state.keys():
self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
elif isinstance(state, ShardedTensor):
self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
return
else:
return

def _get_shard_placement_and_rank_sizes(
self, shards_metadata: List[ShardMetadata], world_size: int
) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
def shard_size(shard_md):
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]

rank_sizes = [0 for _ in range(world_size)]
shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
for shard_md in shards_metadata:
shard_rank = cast(_remote_device, shard_md.placement).rank()
assert shard_rank is not None
if shard_rank >= world_size:
raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")

shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
rank_sizes[shard_rank] += shard_size(shard_md)

return shard_placement, rank_sizes

def _copy_sharded_tensor_to_shared_mem(
self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
) -> Any:
shard0_md = sharded_tensor.metadata()
shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
shard0_md.shards_metadata, world_size
)

rank_size = rank_sizes[rank]
assert rank_size >= 0
if rank_size == 0:
return

assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
numpy_type = np.float32

sharded_memory_name = "-".join(key + (str(rank),))

shm = shared_memory.SharedMemory(
create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
)
np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)

for local_shard in sharded_tensor.local_shards():
shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
assert shard_rank == rank

src = local_shard.tensor.flatten()
shard_offset = shard_placement[local_shard.metadata][1]

np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()

shm.close()

def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
shard_number = int(shard_filepath.name[4:-3])
log.info("Starting unsharding shard number %d to shared memory", shard_number)

with self._patch_sharded_tensor_load():
shard = torch.load(shard_filepath, map_location="cpu")
log.debug("Done loading shard number %d", shard_number)

self._copy_sharded_tensors_to_shared_mem(
shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
)
log.info("Done unsharding shard number %d to shared memory", shard_number)

def _unshard_using_sharded_mem(
self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
) -> Any:
return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))

def _unshard_state_using_shared_mem(
self, state: Any, world_size: int, device: torch.device, key: Tuple
) -> Any:
if isinstance(state, (list, tuple, set)):
return state.__class__(
self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
for i, sub_state in enumerate(state)
)
elif isinstance(state, dict):
return {
name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
for name in state.keys()
}
elif isinstance(state, ShardedTensor):
return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
elif isinstance(state, torch.Tensor):
return state.to(device=device)
else:
return state

def _unshard_tensor_using_shared_mem(
self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
) -> torch.Tensor:
shard0_md = sharded_tensor.metadata()

def shard_size(shard_md):
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]

shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
shard0_md.shards_metadata, world_size
)

assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
numpy_type = np.float32

out = torch.empty(
*sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
)
dims = len(sharded_tensor.metadata().size)
for shard_md, (rank, rank_offset) in shard_placement.items():
if rank >= world_size:
raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")

sharded_memory_name = "-".join(key + (str(rank),))
shm = shared_memory.SharedMemory(name=sharded_memory_name)

rank_size = rank_sizes[rank]
assert rank_size >= 0
if rank_size == 0:
continue

np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)

tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
tensor = tensor.view(shard_md.shard_sizes)

out_narrow_view = out
for dim in range(dims):
out_narrow_view = out_narrow_view.narrow(
dim,
shard_md.shard_offsets[dim],
shard_md.shard_sizes[dim],
)

out_narrow_view.copy_(tensor)

shm.close()
shm.unlink()

return out

@contextmanager
def _patch_sharded_tensor_load(self):
"""
@@ -948,127 +1107,68 @@ def _rebuild_from_type_v2_monkey(func, new_type, args, state):
torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2

    def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
+        """
+        The current unsharding implementation consists of:
+        1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
+        2. Loading 1 shard on the main process as a base unsharded object.
+        3. Using the sharded tensors in shared memory to populate the base unsharded object.
+        This implementation replaced a prior implementation that instead loaded
+        all shards using threads, because that implementation turned out to
+        be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
+        The current implementation is slower than the old one in many scenarios,
+        but is significantly faster in the above mentioned case (e.g. 30 minutes)
+        if there are enough CPUs.
+        """
        input_dir = Path(input_dir)
        skip_keys = skip_keys or set()

-        with self._patch_sharded_tensor_load():
-            # We load in threads because it's faster.
-            executor = ThreadPoolExecutor()
-            shards_dict = {}
-            for shard_name in input_dir.glob("rank*.pt"):
-                log.info("Loading %s ...", shard_name)
-                shard_number = int(shard_name.name[4:-3])  # shard names look like "rankXX.pt"
-                shards_dict[shard_number] = executor.submit(torch.load, shard_name, map_location="cpu")
-            shards = [None] * len(shards_dict)
-            for rank, shard_future in shards_dict.items():
-                shard = shard_future.result()
-                for key in skip_keys:
-                    if key in shard:
-                        del shard[key]
-                shards[rank] = shard
-            assert all(shard is not None for shard in shards)
-            executor.shutdown()
-            del shards_dict
-
-        log.info("Unsharding from %d shards ...", len(shards))
-
-        unsharded_state_dict = self._unshard_object(shards, device=device)
-        # At this point in time we need 2x memory :-(
-        del shards
-
-        return unsharded_state_dict
-
-    def _unshard_object(self, os: List[Any], device: torch.device) -> Any:
-        rank0_item = os[0]
-        assert all(type(o) is type(rank0_item) for o in os)
-        if isinstance(rank0_item, str):
-            assert all(o == rank0_item for o in os)
-            return rank0_item
-        elif isinstance(rank0_item, (list, tuple, set)):
-            assert all(len(o) == len(rank0_item) for o in os)
-            return rank0_item.__class__(self._unshard_object(o, device=device) for o in zip(*os))
-        elif isinstance(rank0_item, dict):
-            assert all(o.keys() == rank0_item.keys() for o in os)
-            return {key: self._unshard_object([o[key] for o in os], device=device) for key in rank0_item.keys()}
-        elif isinstance(rank0_item, ShardedTensor):
-            return self._gather(os, device=device)
-        else:
-            assert all(self._objects_are_equal(o, rank0_item) for o in os)
-            return rank0_item
-
-    def _gather(self, shards: List[ShardedTensor], device: torch.device) -> torch.Tensor:
-        world_size = len(shards)
-        shard0_md = shards[0].metadata()
-        # Make sure all shards agree on the metadata
-        assert all(shard.metadata() == shard0_md for shard in shards)
-        # Make sure the nth shard expects to be the nth shard.
-        assert all(
-            shard_md.placement.rank() == rank  # type: ignore
-            for rank, shard_md in enumerate(shard0_md.shards_metadata)
-        )
-
-        def shard_size(shard_md):
-            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  # type: ignore[attr-defined]
-
-        rank_sizes = [0 for _ in range(world_size)]
-        max_rank_size = 0
-        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
-        for shard_md in shard0_md.shards_metadata:
-            shard_rank = cast(_remote_device, shard_md.placement).rank()
-            assert shard_rank is not None
-
-            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
-            rank_sizes[shard_rank] += shard_size(shard_md)
-            max_rank_size = max(max_rank_size, rank_sizes[shard_rank])
-
-        gather_list: List[torch.Tensor] = [torch.empty((max_rank_size,)) for _ in range(world_size)]
-
-        datas = []
-        with torch.no_grad():
-            for shard in shards:
-                data = torch.empty(max_rank_size)
-
-                for local_shard in shard.local_shards():
-                    src = local_shard.tensor.flatten()
-                    shard_offset = shard_placement[local_shard.metadata][1]
-                    data[shard_offset : shard_offset + src.numel()].copy_(src)
-
-                datas.append(data)
-
-        # torch.gather in a nutshell
-        for rank, data in enumerate(datas):
-            gather_list[rank].copy_(data)
-
-        full_size = shard0_md.size
-        out = torch.empty(*full_size, dtype=shard0_md.tensor_properties.dtype, device=device)
-        dims = len(full_size)
-        for shard_md in shard0_md.shards_metadata:
-            rank, rank_offset = shard_placement[shard_md]
-            tensor = gather_list[rank]
-            tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
-            tensor = tensor.view(shard_md.shard_sizes)
-
-            out_narrow_view = out
-            for dim in range(dims):
-                out_narrow_view = out_narrow_view.narrow(
-                    dim,
-                    shard_md.shard_offsets[dim],
-                    shard_md.shard_sizes[dim],
-                )
-
-            out_narrow_view.copy_(tensor)
-
-        return out
-
-    def _objects_are_equal(self, a: Any, b: Any) -> bool:
-        if type(a) is not type(b):
-            return False
-        if isinstance(a, ndarray):
-            return np.array_equal(a, b)
-        elif isinstance(a, torch.Tensor):
-            return torch.equal(a, b)
-        else:
-            return a == b
+        shard_filepaths = list(input_dir.glob("rank*.pt"))
+        world_size = len(shard_filepaths)
+        if world_size == 0:
+            raise RuntimeError("No shards found for unsharding")
+
+        log.info("Number of shards: %d", world_size)
+        shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
+        min_ram_required_estimate_gb = shard_size_gb * world_size
+        log.info(
+            "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
+        )
+
+        log.info("Copying sharded tensors to shared memory using multiple processes")
+        # Copy sharded data to shared memory using multiple processes, so this process can load
+        # from memory rather than disk. We spawn a new process instead of forking since shared memory
+        # appears to get deleted when forked processes end for some reason.
+        executor = ProcessPoolExecutor(
+            mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
+        )
+        futures = []
+        for shard_filepath in shard_filepaths:
+            shard_rank = int(shard_filepath.name[4:-3])
+
+            if shard_rank >= world_size:
+                raise RuntimeError(
+                    f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
+                )
+
+            futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
+
+        for f in as_completed(futures):
+            f.result()
+        executor.shutdown()
+
+        log.info("Loading a shard on the main process to be unsharded state")
+        with self._patch_sharded_tensor_load():
+            state = torch.load(shard_filepaths[0], map_location="cpu")
+
+        for key in skip_keys:
+            if key in state:
+                del state[key]
+
+        log.info("Unsharding from %d shards ...", world_size)
+        return self._unshard_using_sharded_mem(state, world_size, device, input_dir)


@dataclass
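A side note on the placement logic used by the new _unshard_tensor_using_shared_mem above: each shard is a contiguous block described by per-dimension offsets and sizes, and Tensor.narrow is applied once per dimension to obtain a writable view of the destination region of the full tensor. The following is a self-contained sketch of that copy step; the two-shard layout and tensor sizes are made up for illustration and are not taken from the commit.

import torch

# Full tensor to reconstruct, plus two fake "shards" with (offsets, sizes) metadata,
# mimicking ShardMetadata.shard_offsets / shard_sizes.
out = torch.empty(4, 6)
shards = [
    ((0, 0), (4, 3), torch.zeros(4, 3)),  # left half
    ((0, 3), (4, 3), torch.ones(4, 3)),   # right half
]

for offsets, sizes, shard in shards:
    view = out
    # Narrow one dimension at a time; the result is still a view into `out`,
    # so copy_ writes the shard into the correct region of the full tensor.
    for dim, (off, size) in enumerate(zip(offsets, sizes)):
        view = view.narrow(dim, off, size)
    view.copy_(shard)

assert torch.equal(out[:, :3], torch.zeros(4, 3))
assert torch.equal(out[:, 3:], torch.ones(4, 3))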
