From 864b9bc8a9abc65357aab1e26ec0d0068efa1734 Mon Sep 17 00:00:00 2001 From: Niket Kumar Bhumihar Date: Wed, 11 Dec 2024 21:40:51 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 705359803 --- checkpoint/orbax/checkpoint/_src/metadata/tree.py | 1 + checkpoint/orbax/checkpoint/_src/metadata/value.py | 2 ++ .../_src/metadata/value_metadata_entry.py | 14 +++++++++++++- .../checkpoint/_src/serialization/type_handlers.py | 4 +++- .../orbax/checkpoint/_src/serialization/types.py | 8 ++++++++ 5 files changed, 27 insertions(+), 2 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index e10efcf8..23445f1b 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -460,6 +460,7 @@ def as_user_metadata( is_ocdbt_checkpoint=use_ocdbt, use_zarr3=self.use_zarr3, ts_context=ts_context, + write_shape=value_meta.write_shape, ) flat_restore_types[keypath] = value_meta.value_type diff --git a/checkpoint/orbax/checkpoint/_src/metadata/value.py b/checkpoint/orbax/checkpoint/_src/metadata/value.py index 69464273..1e71acd1 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/value.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/value.py @@ -22,6 +22,7 @@ from etils import epath import jax from jax import numpy as jnp +from orbax.checkpoint._src.arrays import types as arrays_types from orbax.checkpoint._src.metadata import sharding as sharding_metadata @@ -55,6 +56,7 @@ class StorageMetadata: """Metadata describing how arrays are stored in a checkpoint.""" chunk_shape: Optional[tuple[int, ...]] + write_shape: arrays_types.Shape | None = None @dataclasses.dataclass diff --git a/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py b/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py index 7b8411f8..954e049f 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py @@ -19,6 +19,7 @@ import dataclasses from typing import Any, Dict +from orbax.checkpoint._src.arrays import types as arrays_types from orbax.checkpoint._src.metadata import empty_values from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.serialization import types @@ -26,6 +27,7 @@ _VALUE_TYPE = 'value_type' _SKIP_DESERIALIZE = 'skip_deserialize' +_WRITE_SHAPE = 'write_shape' @dataclasses.dataclass @@ -41,12 +43,16 @@ class ValueMetadataEntry: value_type: str skip_deserialize: bool = False + write_shape: arrays_types.Shape | None = None def to_json(self) -> Dict[str, Any]: - return { + json_dict = { _VALUE_TYPE: self.value_type, _SKIP_DESERIALIZE: self.skip_deserialize, } + if self.write_shape is not None: + json_dict[_WRITE_SHAPE] = self.write_shape + return json_dict @classmethod def from_json( @@ -60,6 +66,11 @@ def from_json( pytree_metadata_options, ), skip_deserialize=json_dict[_SKIP_DESERIALIZE], + write_shape=( + tuple(json_dict[_WRITE_SHAPE]) + if _WRITE_SHAPE in json_dict + else None + ), ) @classmethod @@ -69,6 +80,7 @@ def build( save_arg: types.SaveArgs, ) -> ValueMetadataEntry: """Builds a ValueMetadataEntry.""" + # TODO(niket): Add support for `write_shape`. del save_arg if info.value_typestr is None: raise AssertionError( diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index c6125608..d892e927 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -228,6 +228,7 @@ def _build_array_tspec_write( dtype=dtype, target_dtype=(arg.dtype if arg is not None else None), chunk_byte_size=(arg.chunk_byte_size if arg is not None else None), + shard_axes=(arg.shard_axes if arg is not None else None), use_zarr3=info.use_zarr3, use_ocdbt=use_ocdbt, process_id=process_index, @@ -503,7 +504,8 @@ def _array_metadata_from_tensorstore( dtype=jnp.dtype(t.dtype.name), sharding=sharding, storage=value_metadata.StorageMetadata( - chunk_shape=t.chunk_layout.read_chunk_template.shape + chunk_shape=t.chunk_layout.read_chunk_template.shape, + write_shape=info.write_shape, ), ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/types.py b/checkpoint/orbax/checkpoint/_src/serialization/types.py index 518514e7..fbe7128b 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/types.py @@ -26,6 +26,7 @@ import jax.numpy as jnp import numpy as np from orbax.checkpoint import future +from orbax.checkpoint._src.arrays import types as arrays_types from orbax.checkpoint._src.metadata import empty_values from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.metadata import value as value_metadata @@ -116,6 +117,8 @@ class ParamInfo: raise_array_data_missing_error: Only used for restoring. See documentation in `tensorstore_utils.py`. Comes from tree metadata and should be the same across all parameters. + write_shape: + Shape of the array shard. Used in the subchunking context. """ name: Optional[str] = None @@ -130,6 +133,7 @@ class ParamInfo: value_typestr: Optional[str] = None enable_pinned_host_transfer: bool = True raise_array_data_missing_error: bool = True + write_shape: arrays_types.Shape | None = None @dataclasses.dataclass @@ -153,11 +157,15 @@ class SaveArgs: specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape are automatically set to the chosen shape. This uses a greedy algorithm that prioritizes splitting the largest dimensions first. + shard_axes: An optional list of axes that should be prioritized when + sharding array for storage. If empty, storage sharding implementation will + prioritize axes which are already sharded. """ aggregate: bool = False dtype: Optional[jnp.dtype] = None chunk_byte_size: Optional[int] = None + shard_axes: tuple[int, ...] = tuple() def __post_init__(self): if self.aggregate: