diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index e10efcf8..23445f1b 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -460,6 +460,7 @@ def as_user_metadata( is_ocdbt_checkpoint=use_ocdbt, use_zarr3=self.use_zarr3, ts_context=ts_context, + write_shape=value_meta.write_shape, ) flat_restore_types[keypath] = value_meta.value_type diff --git a/checkpoint/orbax/checkpoint/_src/metadata/value.py b/checkpoint/orbax/checkpoint/_src/metadata/value.py index 69464273..1e71acd1 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/value.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/value.py @@ -22,6 +22,7 @@ from etils import epath import jax from jax import numpy as jnp +from orbax.checkpoint._src.arrays import types as arrays_types from orbax.checkpoint._src.metadata import sharding as sharding_metadata @@ -55,6 +56,7 @@ class StorageMetadata: """Metadata describing how arrays are stored in a checkpoint.""" chunk_shape: Optional[tuple[int, ...]] + write_shape: arrays_types.Shape | None = None @dataclasses.dataclass diff --git a/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py b/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py index 7b8411f8..954e049f 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/value_metadata_entry.py @@ -19,6 +19,7 @@ import dataclasses from typing import Any, Dict +from orbax.checkpoint._src.arrays import types as arrays_types from orbax.checkpoint._src.metadata import empty_values from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.serialization import types @@ -26,6 +27,7 @@ _VALUE_TYPE = 'value_type' _SKIP_DESERIALIZE = 'skip_deserialize' +_WRITE_SHAPE = 'write_shape' @dataclasses.dataclass @@ -41,12 +43,16 @@ class ValueMetadataEntry: value_type: str skip_deserialize: bool = False + write_shape: arrays_types.Shape | None = None def to_json(self) -> Dict[str, Any]: - return { + json_dict = { _VALUE_TYPE: self.value_type, _SKIP_DESERIALIZE: self.skip_deserialize, } + if self.write_shape is not None: + json_dict[_WRITE_SHAPE] = self.write_shape + return json_dict @classmethod def from_json( @@ -60,6 +66,11 @@ def from_json( pytree_metadata_options, ), skip_deserialize=json_dict[_SKIP_DESERIALIZE], + write_shape=( + tuple(json_dict[_WRITE_SHAPE]) + if _WRITE_SHAPE in json_dict + else None + ), ) @classmethod @@ -69,6 +80,7 @@ def build( save_arg: types.SaveArgs, ) -> ValueMetadataEntry: """Builds a ValueMetadataEntry.""" + # TODO(niket): Add support for `write_shape`. del save_arg if info.value_typestr is None: raise AssertionError( diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index c6125608..0bae33d0 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -503,7 +503,8 @@ def _array_metadata_from_tensorstore( dtype=jnp.dtype(t.dtype.name), sharding=sharding, storage=value_metadata.StorageMetadata( - chunk_shape=t.chunk_layout.read_chunk_template.shape + chunk_shape=t.chunk_layout.read_chunk_template.shape, + write_shape=info.write_shape, ), ) diff --git a/checkpoint/orbax/checkpoint/_src/serialization/types.py b/checkpoint/orbax/checkpoint/_src/serialization/types.py index 518514e7..77beff4e 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/types.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/types.py @@ -26,6 +26,7 @@ import jax.numpy as jnp import numpy as np from orbax.checkpoint import future +from orbax.checkpoint._src.arrays import types as arrays_types from orbax.checkpoint._src.metadata import empty_values from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.metadata import value as value_metadata @@ -116,6 +117,8 @@ class ParamInfo: raise_array_data_missing_error: Only used for restoring. See documentation in `tensorstore_utils.py`. Comes from tree metadata and should be the same across all parameters. + write_shape: + Shape of the array shard. Used in the subchunking context. """ name: Optional[str] = None @@ -130,6 +133,7 @@ class ParamInfo: value_typestr: Optional[str] = None enable_pinned_host_transfer: bool = True raise_array_data_missing_error: bool = True + write_shape: arrays_types.Shape | None = None @dataclasses.dataclass