
Avoid duplicate reads of the same index when loading.
PiperOrigin-RevId: 702354575
cpgaffney1 authored and Orbax Authors committed Dec 3, 2024
1 parent 9baf7db commit c7af62a
Showing 6 changed files with 179 additions and 102 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - [emergency checkpoint] Fix local restore by re-mapping device ids directly
   instead of inferring them from how process indexes changed across restarts
   with some false assumptions.
+- Avoid duplicate reads of the same index when loading.
 
 ### Changed
 - Coordination service now supports barrier reuse - eliminate some barrier name
32 changes: 30 additions & 2 deletions checkpoint/orbax/checkpoint/_src/arrays/numpy_utils.py
@@ -22,12 +22,16 @@
 NdSlice = types.NdSlice
 Index = types.Index
 
+HashableSlice = types.HashableSlice
+HashableIndex = types.HashableIndex
+
 
 def int_tuple_from_slice(s: slice) -> tuple[int, ...]:
   """Represents a slice as a tuple of integers."""
-  ints = s.start, s.stop, s.step
+  start, stop, step = s.start, s.stop, s.step
+  step = step or 1
   try:
-    return tuple(int(x) for x in ints)
+    return (int(start), int(stop), int(step))
   except:
     raise ValueError(f'Slice {s} is not concrete.') from None
@@ -52,6 +56,30 @@ def resolve_slice(xs: NdSlice, shape: Shape) -> NdSlice:
                for x, n in zip(() if xs is Ellipsis else xs, shape))
 
 
+def to_hashable_index(
+    idx: Index, *, shape: Shape | None = None
+) -> HashableIndex:
+  """Converts an Index into a hashable form.
+
+  Optionally resolves the slices to a concrete index if the shape is
+  provided. If it is not, conversion may fail if the slices are not concrete.
+
+  Args:
+    idx: The index to convert.
+    shape: Global array shape.
+
+  Returns:
+    A hashable index.
+  """
+  idx = resolve_slice(idx, shape) if shape else idx
+  return tuple([int_tuple_from_slice(s) for s in idx])
+
+
+def from_hashable_index(idx: HashableIndex) -> Index:
+  """Converts a hashable index back into an Index of slices."""
+  return tuple([slice(s[0], s[1], s[2]) for s in idx])
+
+
 def dissolve_slice(
     xs: NdSlice,
     shape: Shape,
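The two new helpers form a round trip between slice-based indices and tuple keys. A minimal illustrative sketch, not part of the commit, assuming the module is imported as it is in serialization.py below; the example index and shape are made up:

```python
from orbax.checkpoint._src.arrays import numpy_utils as np_utils

# One shard of a hypothetical (8, 2) array: rows 0-3, all columns.
index = (slice(0, 4), slice(None))

# Providing the global shape resolves open-ended slices to concrete bounds,
# so the result can serve as a dict key.
hashable = np_utils.to_hashable_index(index, shape=(8, 2))
assert hashable == ((0, 4, 1), (0, 2, 1))

# The inverse restores plain slices, now with explicit start/stop/step.
assert np_utils.from_hashable_index(hashable) == (slice(0, 4, 1), slice(0, 2, 1))
```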
5 changes: 5 additions & 0 deletions checkpoint/orbax/checkpoint/_src/arrays/types.py
@@ -36,3 +36,8 @@ class NumpyShapeDtypeStruct:
   """Abstract representation of a Numpy array."""
   shape: Shape
   dtype: np.dtype
+
+
+# Slice objects are not hashable before Python 3.12.
+HashableSlice = tuple[int, int, int]
+HashableIndex = tuple[HashableSlice, ...]
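The comment above is the whole motivation for the tuple encoding; a quick illustration, not part of the commit, of why raw slices cannot key a dict on older interpreters (device names are placeholders):

```python
import sys

s = slice(0, 4, 1)
key = (s.start, s.stop, s.step)  # the HashableSlice encoding: (0, 4, 1)
owners = {key: ['device_0', 'device_1']}  # works on any Python version

if sys.version_info < (3, 12):
  try:
    hash(s)  # slice objects only became hashable in Python 3.12
  except TypeError:
    pass  # expected before 3.12
```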
11 changes: 4 additions & 7 deletions checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
@@ -32,8 +32,9 @@
 Shape = types.Shape
 Index = types.Index
 OptionalAxisAndShape = tuple[int | None, Shape | None]
-# Slice objects are not hashable before python 3.12.
-HashableSlice = tuple[int | None, int | None, int | None]
+
+HashableIndex = types.HashableIndex
+HashableSlice = types.HashableSlice
 
 
 @dataclasses.dataclass(frozen=True)
@@ -138,10 +139,6 @@ def to_fragments(self) -> fragments.Fragments:
     return result
 
 
-def _hashable_slices(slices: Index) -> tuple[HashableSlice, ...]:
-  return tuple([(s.start, s.stop, s.step) for s in slices])
-
-
 @functools.lru_cache(maxsize=4096)
 def _sharding_num_replicas(
     sharding: jax.sharding.Sharding, global_shape: Shape
@@ -174,7 +171,7 @@ def _sharding_num_replicas(
   """
   counts = collections.defaultdict(int)
   for index in sharding.devices_indices_map(global_shape).values():
-    counts[_hashable_slices(index)] += 1
+    counts[numpy_utils.to_hashable_index(index, shape=global_shape)] += 1
   num_replicas = next(iter(counts.values()))
   assert all(count == num_replicas for count in counts.values())
   return num_replicas
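To see why keying counts on the hashable index yields the replica count, consider this sketch, not from the commit, using a hypothetical (4, 2) 'x'/'y' mesh of eight devices sharded only along 'x':

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()).reshape(4, 2)  # assumes 8 devices available
mesh = Mesh(devices, ('x', 'y'))
sharding = NamedSharding(mesh, PartitionSpec('x'))  # replicated along 'y'

# Both devices in each 'y' column get the same index of an (8, 8) array,
# e.g. (slice(0, 2, None), slice(None, None, None)): 4 distinct indices,
# each appearing twice, so _sharding_num_replicas would return 2.
indices = list(sharding.devices_indices_map((8, 8)).values())
assert len(indices) == 8
assert len(set(map(str, indices))) == 4  # str() since slices may be unhashable
```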
192 changes: 107 additions & 85 deletions checkpoint/orbax/checkpoint/_src/serialization/serialization.py
@@ -15,12 +15,12 @@
 """Array serialization and deserialization."""
 
 import asyncio
-from collections.abc import Awaitable, Mapping
+import collections
+from collections.abc import Mapping
 import contextlib
 import functools
 import os
 import re
-from typing import Any, AsyncIterator, Callable, Dict, Optional, Protocol, Sequence, Union
+from typing import Any, AsyncIterator, Dict, Optional, Protocol, Sequence, Union
 
 from absl import logging
 import humanize
@@ -29,6 +29,7 @@
 import jax.numpy as jnp
 import numpy as np
 from orbax.checkpoint._src.arrays import fragments
+from orbax.checkpoint._src.arrays import numpy_utils as np_utils
 from orbax.checkpoint._src.arrays import types
 from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.serialization import replica_slices
@@ -51,25 +52,6 @@
 Shape = types.Shape
 
 
-def _get_device_to_index_map(
-    global_shape: Shape, sharding: jax.sharding.Sharding
-) -> Mapping[jax.Device, Index]:
-  return sharding.devices_indices_map(global_shape)
-
-
-async def create_async_array_from_callback(
-    global_shape: Shape,
-    sharding: jax.sharding.Sharding,
-    data_callback: Callable[[Index, jax.Device], Awaitable[jax.Array]],
-) -> jax.Array:
-  device_to_index_map = _get_device_to_index_map(global_shape, sharding)
-  addressable_da = sharding._addressable_device_assignment  # pylint: disable=protected-access
-  future_arrays = [data_callback(device_to_index_map[d], d)
-                   for d in addressable_da]
-  dbs = await asyncio.gather(*future_arrays)
-  return jax.make_array_from_single_device_arrays(global_shape, sharding, dbs)
-
-
 def _get_metadata(arr: jax.Array, local_shape: Shape):
   return {
       'compressor': {'id': 'zstd'},
@@ -427,16 +409,15 @@ def estimate_read_memory_footprint(t: ts.TensorStore,
   return num_bytes
 
 
-async def _read_and_device_put_shard(
-    device: jax.Device,
+async def _read_shard(
     t: ts.TensorStore,
+    *,
     new_shard_shape: Sequence[int],
     dtype: jnp.dtype,
     requested_domain: ts.IndexDomain,
     restricted_domain: ts.IndexDomain,
-    dll: Optional[layout.DeviceLocalLayout],
-) -> jax.Array:
-  """Reads a single shard from TensorStore and places it on device."""
+) -> np.ndarray:
+  """Reads a single shard from TensorStore into host memory."""
   # This may be needed because the shape the array was saved with is smaller
   # than the requested shape of the array in which it will be reloaded. So
   # the extra values will be filled with 0s.
@@ -448,29 +429,22 @@ async def _read_and_device_put_shard(
   # Cast while reloading on process to avoid 2 copies on device if the
   # casting is done on device.
   out = out.astype(dtype)
-  # Convert to jnp array so that layouts are initialized properly for
-  # sub-byte dtypes.
-  # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to
-  # make this work.
-  if out.dtype == jnp.int4:
-    out = jnp.asarray(out)  # type: ignore
-  return jax.device_put(
-      out, Layout(dll, jax.sharding.SingleDeviceSharding(device))
-  )
+  return out
 
 
-async def _read_array_index_callback(
+async def _read_array_index_and_device_put(
+    devices: list[jax.Device],
     index: Index,
-    device: jax.Device,
     t: ts.TensorStore,
-    shape: Shape,
     *,
+    global_shape: Shape,
     new_shard_shape: Shape,
     dtype: jnp.dtype,
     byte_limiter: ByteLimiter,
     strict: bool,
-    ddl: Optional[layout.DeviceLocalLayout],
-) -> jax.Array:
-  """Callback that reads an array index and places on device."""
+    dll: Optional[layout.DeviceLocalLayout],
+) -> list[jax.Array]:
+  """Callback that reads an array index and places it on the devices."""
   for sl in index:
     if sl.step is not None and sl.step != 1:
       raise ValueError(
@@ -479,45 +453,100 @@ async def _read_array_index_callback(
       )
 
   if strict:
-    if t.shape == shape:
-      domain = ts.IndexDomain(shape=shape)[ts.d[:][index]]
+    if t.shape == global_shape:
+      domain = ts.IndexDomain(shape=global_shape)[ts.d[:][index]]
       requested_domain = domain
       restricted_domain = domain
     else:
       raise ValueError(
-          f'Requested shape: {shape} is not compatible with the stored shape:'
-          f' {t.shape}. Truncating/padding is disabled by setting of'
-          ' `strict=True`. When using standard Orbax APIs, this behavior can be'
-          ' modified by specifying `strict=False` in `ArrayRestoreArgs` for any'
-          ' array in which padding/truncation is desired.'
+          f'Requested shape: {global_shape} is not compatible with the stored'
+          f' shape: {t.shape}. Truncating/padding is disabled by setting'
+          ' `strict=True`. When using standard Orbax APIs, this behavior can'
+          ' be modified by specifying `strict=False` in `ArrayRestoreArgs` for'
+          ' any array in which padding/truncation is desired.'
       )
   else:
-    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
+    requested_domain = ts.IndexTransform(input_shape=global_shape)[index].domain
     restricted_domain = t.domain.intersect(requested_domain)

   requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
+  result = []
   # Limit the bytes read for every shard.
+  # Perform the read for the index once, and place it on all relevant devices
+  # within the `reserved_bytes` context.
+  # TODO(b/381111280): This de-duplication of reads does not fully solve the
+  # problem of read amplification, since we can still run into problems if
+  # we are resharding. See b/381111280 for details.
   async with reserved_bytes(byte_limiter, requested_bytes):
     try:
-      result = await _read_and_device_put_shard(
-          device,
-          t,
-          new_shard_shape,
-          dtype,
-          requested_domain,
-          restricted_domain,
-          ddl,
+      shard = await _read_shard(
+          t=t,
+          new_shard_shape=new_shard_shape,
+          dtype=dtype,
+          requested_domain=requested_domain,
+          restricted_domain=restricted_domain,
      )
     except BaseException as e:
       raise Exception(  # pylint: disable=broad-exception-raised
           f'Encountered error while reading array index: {index}. See full'
           f' TensorStore details: {t.spec}.'
       ) from e
+    for device in devices:
+      result.append(
+          jax.device_put(
+              shard, Layout(dll, jax.sharding.SingleDeviceSharding(device))
+          )
+      )
+  return result


+def _get_device_to_index_map(
+    global_shape: Shape, sharding: jax.sharding.Sharding
+) -> Mapping[jax.Device, Index]:
+  return sharding.devices_indices_map(global_shape)
+
+
+async def read_and_create_array(
+    t: ts.TensorStore,
+    *,
+    global_shape: Shape,
+    new_shard_shape: Shape,
+    sharding: jax.sharding.Sharding,
+    dtype: jnp.dtype,
+    byte_limiter: ByteLimiter,
+    strict: bool,
+    dll: Optional[layout.DeviceLocalLayout],
+) -> jax.Array:
+  """Reads shards from TensorStore and creates a jax.Array."""
+  local_indices_devices_map: dict[types.HashableIndex, list[jax.Device]] = (
+      collections.defaultdict(list)
+  )
+  for d, idx in _get_device_to_index_map(global_shape, sharding).items():
+    if d in sharding._addressable_device_assignment:  # pylint: disable=protected-access
+      local_indices_devices_map[
+          np_utils.to_hashable_index(idx, shape=global_shape)
+      ].append(d)
+
+  read_array_coros = [
+      _read_array_index_and_device_put(
+          devices,
+          np_utils.from_hashable_index(idx),
+          t,
+          global_shape=global_shape,
+          new_shard_shape=new_shard_shape,
+          dtype=dtype,
+          byte_limiter=byte_limiter,
+          strict=strict,
+          dll=dll,
+      )
+      for idx, devices in local_indices_devices_map.items()
+  ]
+  dbs = sum(await asyncio.gather(*read_array_coros), [])
+  return jax.make_array_from_single_device_arrays(global_shape, sharding, dbs)
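The defaultdict grouping above is the heart of the fix: all local devices that need the same index share one key, so each distinct index is read from TensorStore exactly once and then device_put to every holder. A standalone sketch of the grouping, with made-up device names, not from the commit:

```python
import collections

# Hashable indices as produced by np_utils.to_hashable_index. Here all four
# local devices hold the same (fully replicated) shard of a length-8 array.
device_to_index = {
    'tpu_0': ((0, 8, 1),),
    'tpu_1': ((0, 8, 1),),
    'tpu_2': ((0, 8, 1),),
    'tpu_3': ((0, 8, 1),),
}

local_indices_devices_map = collections.defaultdict(list)
for device, hashable_index in device_to_index.items():
  local_indices_devices_map[hashable_index].append(device)

# One TensorStore read instead of four, followed by four device transfers.
assert len(local_indices_devices_map) == 1
assert sorted(local_indices_devices_map[((0, 8, 1),)]) == [
    'tpu_0', 'tpu_1', 'tpu_2', 'tpu_3'
]
```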


 async def async_deserialize(
-    user_in_sharding: jax.sharding.Sharding | Layout,
+    user_sharding: jax.sharding.Sharding | Layout,
     tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
     global_shape: Optional[Shape] = None,
     dtype: Optional[jnp.dtype] = None,
@@ -530,39 +559,32 @@ async def async_deserialize(
   """Reads an array using TensorStore."""
   byte_limiter = byte_limiter or get_byte_limiter()
   context = context or ts_utils.get_ts_context(use_ocdbt=False)
-  in_sharding = (
-      user_in_sharding.sharding
-      if isinstance(user_in_sharding, Layout)
-      else user_in_sharding
+  sharding = (
+      user_sharding.sharding
+      if isinstance(user_sharding, Layout)
+      else user_sharding
   )
-  if not isinstance(in_sharding, jax.sharding.Sharding):
+  if not isinstance(sharding, jax.sharding.Sharding):
     raise ValueError(
         'sharding passed to deserialization should be specified, concrete and'
-        f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
+        f' an instance of `jax.sharding.Sharding`. Got {sharding}'
+    )
   dll = (
-      user_in_sharding.device_local_layout
-      if isinstance(user_in_sharding, Layout)
+      user_sharding.device_local_layout
+      if isinstance(user_sharding, Layout)
       else None
   )
   t = await ts.open(
       tensorstore_spec,
       open=True,
       assume_metadata=assume_metadata,
       context=context,
   )
-  shape = t.shape if global_shape is None else global_shape
-  new_shard_shape = in_sharding.shard_shape(tuple(shape))
-  return await create_async_array_from_callback(
-      tuple(shape),
-      in_sharding,
-      functools.partial(
-          _read_array_index_callback,
-          t=t,
-          shape=shape,
-          new_shard_shape=new_shard_shape,
-          dtype=dtype,
-          byte_limiter=byte_limiter,
-          strict=strict,
-          ddl=dll,
-      ),
+  global_shape = tuple(t.shape if global_shape is None else global_shape)
+  new_shard_shape = sharding.shard_shape(global_shape)
+  return await read_and_create_array(
+      t,
+      global_shape=global_shape,
+      new_shard_shape=new_shard_shape,
+      sharding=sharding,
+      dtype=dtype,
+      byte_limiter=byte_limiter,
+      strict=strict,
+      dll=dll,
   )
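For callers, the visible change is the renamed `user_sharding` parameter; the read de-duplication happens internally. A hypothetical usage sketch, where the TensorStore spec, path, shape, and dtype are assumptions rather than anything from the commit:

```python
import asyncio

import jax
import jax.numpy as jnp
from orbax.checkpoint._src.serialization import serialization

async def restore_array() -> jax.Array:
  sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
  spec = {
      'driver': 'zarr',
      'kvstore': {'driver': 'file', 'path': '/tmp/ckpt/my_array'},
  }
  # Each distinct index is read once, even when several local devices
  # (here just one) map to the same index.
  return await serialization.async_deserialize(
      sharding,
      spec,
      global_shape=(8, 8),
      dtype=jnp.float32,
  )

arr = asyncio.run(restore_array())
```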
