diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md
index 3c713b44..3178ffae 100644
--- a/checkpoint/CHANGELOG.md
+++ b/checkpoint/CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - [emergency checkpoint] Fix local restore by re-mapping device ids directly
   instead of inferring them from how process indexes changed across restarts
   with some false assumptions.
+- Avoid duplicate reads of the same index when loading.
 
 ### Changed
 - Coordination service now supports barrier reuse - eliminate some barrier name
diff --git a/checkpoint/orbax/checkpoint/_src/arrays/numpy_utils.py b/checkpoint/orbax/checkpoint/_src/arrays/numpy_utils.py
index 6c2b511d..be12ad31 100644
--- a/checkpoint/orbax/checkpoint/_src/arrays/numpy_utils.py
+++ b/checkpoint/orbax/checkpoint/_src/arrays/numpy_utils.py
@@ -22,12 +22,16 @@
 NdSlice = types.NdSlice
 Index = types.Index
 
+HashableSlice = types.HashableSlice
+HashableIndex = types.HashableIndex
+
 
 def int_tuple_from_slice(s: slice) -> tuple[int, ...]:
   """Represents a slice as a tuple of integers."""
-  ints = s.start, s.stop, s.step
+  start, stop, step = s.start, s.stop, s.step
+  step = step or 1
   try:
-    return tuple(int(x) for x in ints)
+    return (int(start), int(stop), int(step))
   except:
     raise ValueError(f'Slice {s} is not concrete.') from None
 
@@ -52,6 +56,30 @@ def resolve_slice(xs: NdSlice, shape: Shape) -> NdSlice:
       for x, n in zip(() if xs is Ellipsis else xs, shape))
 
 
+def to_hashable_index(
+    idx: Index, *, shape: Shape | None = None
+) -> HashableIndex:
+  """Converts an Index into a hashable form.
+
+  Optionally resolves the slices to a concrete index if the shape is provided.
+  If not, conversion may fail if the slices are not concrete.
+
+  Args:
+    idx: The index to convert.
+    shape: Global array shape.
+
+  Returns:
+    A hashable index.
+  """
+  idx = resolve_slice(idx, shape) if shape else idx
+
+  return tuple([int_tuple_from_slice(s) for s in idx])
+
+
+def from_hashable_index(idx: HashableIndex) -> Index:
+  return tuple([slice(s[0], s[1], s[2]) for s in idx])
+
+
 def dissolve_slice(
     xs: NdSlice,
     shape: Shape,
diff --git a/checkpoint/orbax/checkpoint/_src/arrays/types.py b/checkpoint/orbax/checkpoint/_src/arrays/types.py
index 93d047c8..55651259 100644
--- a/checkpoint/orbax/checkpoint/_src/arrays/types.py
+++ b/checkpoint/orbax/checkpoint/_src/arrays/types.py
@@ -36,3 +36,8 @@ class NumpyShapeDtypeStruct:
   """Abstract representation of a Numpy array."""
   shape: Shape
   dtype: np.dtype
+
+
+# Slice objects are not hashable before python 3.12.
+HashableSlice = tuple[int, int, int]
+HashableIndex = tuple[HashableSlice, ...]
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
index 8be6cfa0..7a026c2b 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
@@ -32,8 +32,9 @@
 Shape = types.Shape
 Index = types.Index
 OptionalAxisAndShape = tuple[int | None, Shape| None]
-# Slice objects are not hashable before python 3.12.
-HashableSlice = tuple[int | None, int | None, int | None]
+
+HashableIndex = types.HashableIndex
+HashableSlice = types.HashableSlice
 
 
 @dataclasses.dataclass(frozen=True)
@@ -138,10 +139,6 @@ def to_fragments(self) -> fragments.Fragments:
     return result
 
 
-def _hashable_slices(slices: Index) -> tuple[HashableSlice, ...]:
-  return tuple([(s.start, s.stop, s.step) for s in slices])
-
-
 @functools.lru_cache(maxsize=4096)
 def _sharding_num_replicas(
     sharding: jax.sharding.Sharding, global_shape: Shape
@@ -174,7 +171,7 @@
   """
   counts = collections.defaultdict(int)
   for index in sharding.devices_indices_map(global_shape).values():
-    counts[_hashable_slices(index)] += 1
+    counts[numpy_utils.to_hashable_index(index, shape=global_shape)] += 1
   num_replicas = next(iter(counts.values()))
   assert all(count == num_replicas for count in counts.values())
   return num_replicas
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py
index 9d5379d8..4e571235 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/serialization.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization.py
@@ -15,12 +15,12 @@
 """Array serialization and deserialization."""
 
 import asyncio
-from collections.abc import Awaitable, Mapping
+import collections
+from collections.abc import Mapping
 import contextlib
-import functools
 import os
 import re
-from typing import Any, AsyncIterator, Callable, Dict, Optional, Protocol, Sequence, Union
+from typing import Any, AsyncIterator, Dict, Optional, Protocol, Sequence, Union
 
 from absl import logging
 import humanize
@@ -29,6 +29,7 @@
 import jax.numpy as jnp
 import numpy as np
 from orbax.checkpoint._src.arrays import fragments
+from orbax.checkpoint._src.arrays import numpy_utils as np_utils
 from orbax.checkpoint._src.arrays import types
 from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.serialization import replica_slices
@@ -51,25 +52,6 @@
 Shape = types.Shape
 
 
-def _get_device_to_index_map(
-    global_shape: Shape, sharding: jax.sharding.Sharding
-) -> Mapping[jax.Device, Index]:
-  return sharding.devices_indices_map(global_shape)
-
-
-async def create_async_array_from_callback(
-    global_shape: Shape,
-    sharding: jax.sharding.Sharding,
-    data_callback: Callable[[Index, jax.Device], Awaitable[jax.Array]],
-) -> jax.Array:
-  device_to_index_map = _get_device_to_index_map(global_shape, sharding)
-  addressable_da = sharding._addressable_device_assignment  # pylint: disable=protected-access
-  future_arrays = [data_callback(device_to_index_map[d], d)
-                   for d in addressable_da]
-  dbs = await asyncio.gather(*future_arrays)
-  return jax.make_array_from_single_device_arrays(global_shape, sharding, dbs)
-
-
 def _get_metadata(arr: jax.Array, local_shape: Shape):
   return {
       'compressor': {'id': 'zstd'},
@@ -427,16 +409,15 @@ def estimate_read_memory_footprint(t: ts.TensorStore,
   return num_bytes
 
 
-async def _read_and_device_put_shard(
-    device: jax.Device,
+async def _read_shard(
     t: ts.TensorStore,
+    *,
     new_shard_shape: Sequence[int],
     dtype: jnp.dtype,
     requested_domain: ts.IndexDomain,
     restricted_domain: ts.IndexDomain,
-    dll: Optional[layout.DeviceLocalLayout],
-) -> jax.Array:
-  """Reads a single shard from TensorStore and places it on device."""
+) -> np.ndarray:
+  """Reads a single shard from TensorStore into host memory."""
   # This maybe needed because the shape the array was saved with is smaller
   # than the requested shape of the array in which it will be reloaded. So
   # the extra values will be filled with 0s.
@@ -448,29 +429,22 @@ async def _read_and_device_put_shard(
   # Cast while reloading on process to avoid 2 copies on device if the
   # casting is done on device.
   out = out.astype(dtype)
-  # Convert to jnp array so that layouts are initialized properly for
-  # sub-byte dtypes.
-  # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to
-  # make this work.
-  if out.dtype == jnp.int4:
-    out = jnp.asarray(out)  # type: ignore
-  return jax.device_put(
-      out, Layout(dll, jax.sharding.SingleDeviceSharding(device))
-  )
+  return out
 
 
-async def _read_array_index_callback(
+async def _read_array_index_and_device_put(
+    devices: list[jax.Device],
     index: Index,
-    device: jax.Device,
     t: ts.TensorStore,
-    shape: Shape,
+    *,
+    global_shape: Shape,
     new_shard_shape: Shape,
     dtype: jnp.dtype,
     byte_limiter: ByteLimiter,
     strict: bool,
-    ddl: Optional[layout.DeviceLocalLayout],
-) -> jax.Array:
-  """Callback that reads an array index and places on device."""
+    dll: Optional[layout.DeviceLocalLayout],
+) -> list[jax.Array]:
+  """Callback that reads an array index and places on the devices."""
   for sl in index:
     if sl.step is not None and sl.step != 1:
       raise ValueError(
@@ -479,45 +453,100 @@
       )
 
   if strict:
-    if t.shape == shape:
-      domain = ts.IndexDomain(shape=shape)[ts.d[:][index]]
+    if t.shape == global_shape:
+      domain = ts.IndexDomain(shape=global_shape)[ts.d[:][index]]
       requested_domain = domain
       restricted_domain = domain
     else:
       raise ValueError(
-          f'Requested shape: {shape} is not compatible with the stored shape:'
-          f' {t.shape}. Truncating/padding is disabled by setting of'
-          ' `strict=True`. When using standard Orbax APIs, this behavior can be'
-          ' modified by specifying `strict=False` in `ArrayRestoreArgs` for any'
-          ' array in which padding/truncation is desired.'
+          f'Requested shape: {global_shape} is not compatible with the stored'
+          f' shape: {t.shape}. Truncating/padding is disabled by setting of'
+          ' `strict=True`. When using standard Orbax APIs, this behavior can'
+          ' be modified by specifying `strict=False` in `ArrayRestoreArgs` for'
+          ' any array in which padding/truncation is desired.'
       )
   else:
-    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
+    requested_domain = ts.IndexTransform(input_shape=global_shape)[index].domain
    restricted_domain = t.domain.intersect(requested_domain)
 
   requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
+  result = []
   # Limit the bytes read for every shard.
+  # Perform read for index once, and place it on all relevant devices within
+  # the `reserved_bytes` context.
+  # TODO(b/381111280) This de-duplication of reads does not fully solve the
+  # problem of read amplification, since we can still run into problems if
+  # we are resharding. See b/381111280 for details.
   async with reserved_bytes(byte_limiter, requested_bytes):
     try:
-      result = await _read_and_device_put_shard(
-          device,
-          t,
-          new_shard_shape,
-          dtype,
-          requested_domain,
-          restricted_domain,
-          ddl,
+      shard = await _read_shard(
+          t=t,
+          new_shard_shape=new_shard_shape,
+          dtype=dtype,
+          requested_domain=requested_domain,
+          restricted_domain=restricted_domain,
      )
     except BaseException as e:
       raise Exception(  # pylint: disable=broad-exception-raised
           f'Encountered error while reading array index: {index}. See full'
           f' TensorStore details: {t.spec}.'
      ) from e
+    for device in devices:
+      result.append(
+          jax.device_put(
+              shard, Layout(dll, jax.sharding.SingleDeviceSharding(device))
+          )
+      )
   return result
 
 
+def _get_device_to_index_map(
+    global_shape: Shape, sharding: jax.sharding.Sharding
+) -> Mapping[jax.Device, Index]:
+  return sharding.devices_indices_map(global_shape)
+
+
+async def read_and_create_array(
+    t: ts.TensorStore,
+    *,
+    global_shape: Shape,
+    new_shard_shape: Shape,
+    sharding: jax.sharding.Sharding,
+    dtype: jnp.dtype,
+    byte_limiter: ByteLimiter,
+    strict: bool,
+    dll: Optional[layout.DeviceLocalLayout],
+) -> jax.Array:
+  """Read shards from TensorStore and create a jax.Array."""
+  local_indices_devices_map: dict[types.HashableIndex, list[jax.Device]] = (
+      collections.defaultdict(list)
+  )
+  for d, idx in _get_device_to_index_map(global_shape, sharding).items():
+    if d in sharding._addressable_device_assignment:  # pylint: disable=protected-access
+      local_indices_devices_map[
+          np_utils.to_hashable_index(idx, shape=global_shape)
+      ].append(d)
+
+  read_array_coros = [
+      _read_array_index_and_device_put(
+          devices,
+          np_utils.from_hashable_index(idx),
+          t,
+          global_shape=global_shape,
+          new_shard_shape=new_shard_shape,
+          dtype=dtype,
+          byte_limiter=byte_limiter,
+          strict=strict,
+          dll=dll,
+      )
+      for idx, devices in local_indices_devices_map.items()
+  ]
+  dbs = sum(await asyncio.gather(*read_array_coros), [])
+  return jax.make_array_from_single_device_arrays(global_shape, sharding, dbs)
+
+
 async def async_deserialize(
-    user_in_sharding: jax.sharding.Sharding | Layout,
+    user_sharding: jax.sharding.Sharding | Layout,
     tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
     global_shape: Optional[Shape] = None,
     dtype: Optional[jnp.dtype] = None,
@@ -530,39 +559,32 @@ async def async_deserialize(
   """Reads an array using TensorStore."""
   byte_limiter = byte_limiter or get_byte_limiter()
   context = context or ts_utils.get_ts_context(use_ocdbt=False)
-  in_sharding = (
-      user_in_sharding.sharding
-      if isinstance(user_in_sharding, Layout)
-      else user_in_sharding
+  sharding = (
+      user_sharding.sharding
+      if isinstance(user_sharding, Layout)
+      else user_sharding
   )
-  if not isinstance(in_sharding, jax.sharding.Sharding):
+  if not isinstance(sharding, jax.sharding.Sharding):
     raise ValueError(
         'sharding passed to deserialization should be specified, concrete and'
-        f' an instance of `jax.sharding.Sharding`. Got {in_sharding}')
-  dll = (
-      user_in_sharding.device_local_layout
-      if isinstance(user_in_sharding, Layout)
-      else None
-  )
+        f' an instance of `jax.sharding.Sharding`. Got {sharding}'
+    )
+  dll = user_sharding.device_local_layout if isinstance(user_sharding, Layout) else None
   t = await ts.open(
       tensorstore_spec,
       open=True,
       assume_metadata=assume_metadata,
       context=context,
   )
-  shape = t.shape if global_shape is None else global_shape
-  new_shard_shape = in_sharding.shard_shape(tuple(shape))
-  return await create_async_array_from_callback(
-      tuple(shape),
-      in_sharding,
-      functools.partial(
-          _read_array_index_callback,
-          t=t,
-          shape=shape,
-          new_shard_shape=new_shard_shape,
-          dtype=dtype,
-          byte_limiter=byte_limiter,
-          strict=strict,
-          ddl=dll,
-      ),
+  global_shape = tuple(t.shape if global_shape is None else global_shape)
+  new_shard_shape = sharding.shard_shape(global_shape)
+  return await read_and_create_array(
+      t,
+      global_shape=global_shape,
+      new_shard_shape=new_shard_shape,
+      sharding=sharding,
+      dtype=dtype,
+      byte_limiter=byte_limiter,
+      strict=strict,
+      dll=dll,
   )
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py
index 6fecfcff..0c6aa21c 100644
--- a/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py
+++ b/checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import logging
 import math
 import os
 import pathlib
@@ -187,24 +186,26 @@ async def deserialize_with_byte_limit():
       r.block_until_ready()
 
     tm.start()
+    _, start_memory_usage = tm.get_traced_memory()
     asyncio_utils.run_sync(deserialize_with_byte_limit())
-    unused_current, peak = tm.get_traced_memory()
+    _, peak_memory_usage = tm.get_traced_memory()
     # NB: some padding + tensorstore overhead. It should always be
     # less than array size (2048 * 4096 * 4 = 32M)
-    self.assertLess(peak, 10_000_000)
+    self.assertLess(peak_memory_usage - start_memory_usage, 10_000_000)
 
     deserialize_wo_limit = serialization.async_deserialize(
         sharding, tspec, inp_shape
     )
     tm.clear_traces()
+    _, start_memory_usage = tm.get_traced_memory()
     # NB: call block_until_ready() is important here and above
     # because otherwise this leads to racing condition and segfault with
     # tensorstore attempting to dealloc using tracemalloc which is already
     # destroyed.
     asyncio_utils.run_sync(deserialize_wo_limit).block_until_ready()
-    unused_current, peak = tm.get_traced_memory()
+    _, peak_memory_usage = tm.get_traced_memory()
     # We load entire array in memory here.
-    self.assertGreater(peak, 30_000_000)
+    self.assertGreater(peak_memory_usage - start_memory_usage, 30_000_000)
     tm.stop()
 
   def test_checkpointing_jax_array(self):
@@ -259,9 +260,6 @@ def cb3(_):
         tspecs,
     )
-    logging.info(m1.addressable_shards)
-    logging.info(m2.addressable_shards)
-    logging.info(m3.addressable_shards)
     self.assertIsInstance(m1, jax.Array)
     self.assertArraysEqual(
         np.asarray(m1.addressable_shards[0].data),
@@ -662,6 +660,32 @@ def test_incomplete_write(self):
     ):
       deserialize([sharding], [tspec])
 
+  @parameterized.named_parameters(
+      dict(testcase_name='fully_replicated', pspec=(None, None)),
+      dict(testcase_name='partially_replicated', pspec=('x', None)),
+      dict(testcase_name='fully_sharded', pspec=('x', 'y')),
+  )
+  def test_dedup_loading(self, pspec):
+    data = np.arange(2_048 * 4_096, dtype=np.float32).reshape(2_048, 4_096)
+    global_shape = data.shape
+    global_mesh = create_global_mesh((2, 2), ('x', 'y'))
+    sharding = NamedSharding(global_mesh, P(*pspec))
+    array = jax.make_array_from_callback(
+        global_shape, sharding, lambda idx: data[idx]
+    )
+    ckpt_paths = [str(self.ckpt_dir)]
+    tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths)
+    serialize([array], tspecs)
+
+    tm.start()
+    _, start_memory_usage = tm.get_traced_memory()
+    deserialize([sharding], tspecs, [global_shape])
+    _, peak_memory_usage = tm.get_traced_memory()
+    tm.clear_traces()
+    # Array size (2048 * 4096 * 4 = 32M)
+    delta = 2_000_000  # Empirically chosen wiggle room.
+    self.assertLess(peak_memory_usage - start_memory_usage, 32_000_000 + delta)
+
 
 if __name__ == '__main__':
  absltest.main()
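
Reviewer note, not part of the patch: a minimal standalone sketch of the
hashable-index round-trip that `to_hashable_index`/`from_hashable_index`
implement. Slice objects are unhashable before Python 3.12, so each slice is
lowered to a concrete (start, stop, step) tuple that can key a dict. The
`resolve_slice` step is inlined here by resolving None endpoints against the
shape; everything below is illustrative and independent of the Orbax code.

HashableSlice = tuple[int, int, int]
HashableIndex = tuple[HashableSlice, ...]


def to_hashable_index(
    idx: tuple[slice, ...], shape: tuple[int, ...]
) -> HashableIndex:
  # Resolve None start/stop/step against the global shape so the result is
  # concrete and hashable.
  return tuple(
      (s.start or 0, n if s.stop is None else s.stop, s.step or 1)
      for s, n in zip(idx, shape)
  )


def from_hashable_index(idx: HashableIndex) -> tuple[slice, ...]:
  return tuple(slice(start, stop, step) for start, stop, step in idx)


assert from_hashable_index(
    to_hashable_index((slice(None), slice(0, 4)), shape=(8, 4))
) == (slice(0, 8, 1), slice(0, 4, 1))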
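
A second sketch, also illustrative only: the shape of the de-duplication that
`read_and_create_array` performs. Devices are bucketed by the hashable index
they request, each unique index is fetched once, and the shard is fanned out
to every device in its bucket. `load_deduped`, `fetch`, and the string device
names are stand-ins, not Orbax or JAX APIs.

import collections

HashableIndex = tuple[tuple[int, int, int], ...]


def load_deduped(
    device_to_index: dict[str, HashableIndex], fetch
) -> dict[str, object]:
  # Bucket devices by the index they need, so replicas share one read.
  buckets: dict[HashableIndex, list[str]] = collections.defaultdict(list)
  for device, idx in device_to_index.items():
    buckets[idx].append(device)
  out = {}
  for idx, devices in buckets.items():
    shard = fetch(idx)  # one read per unique index, not one per device
    for device in devices:
      out[device] = shard  # fan the same buffer out to each device
  return out


reads = []
result = load_deduped(
    # Fully replicated case: both devices request the same index.
    {'dev0': ((0, 8, 1),), 'dev1': ((0, 8, 1),)},
    fetch=lambda idx: reads.append(idx) or list(idx),
)
assert len(reads) == 1  # the shard was read exactly once
assert result['dev0'] is result['dev1']  # and shared between the devices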