Skip to content

Commit

Permalink
Add user-facing TreeMetadata object returned by `PyTreeCheckpointHa…
Browse files Browse the repository at this point in the history
…ndler.metadata`. This object mimics an ordinary PyTree to make the change unnoticeable to most users, but also has additional accessible properties not included in any tree mapping operations.

PiperOrigin-RevId: 690728276
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Nov 25, 2024
1 parent 3a32a23 commit 1bdfd6e
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 31 deletions.
6 changes: 6 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix namedtuple empty value typestr when experimental support_rich_types is
disabled again after enabling it.

### Changed
- Add user-facing `TreeMetadata` object returned by
`PyTreeCheckpointHandler.metadata`. This object mimics an ordinary PyTree to
make the change unnoticeable to most users, but also has additional accessible
properties not included in any tree mapping operations.


## [0.10.1] - 2024-11-22

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint._src.metadata import tree as tree_metadata
import tensorstore as ts


Expand Down Expand Up @@ -770,7 +770,8 @@ def _read_metadata_file(
json.loads(path.read_text())
)

def metadata(self, directory: epath.Path) -> Optional[PyTree]:

def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
"""Returns tree metadata.
The result will be a PyTree matching the structure of the saved checkpoint.
Expand All @@ -797,8 +798,11 @@ def metadata(self, directory: epath.Path) -> Optional[PyTree]:
tree containing metadata.
"""
is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory)
return self._read_metadata_file(directory).as_user_metadata(
directory, self._type_handler_registry, use_ocdbt=is_ocdbt_checkpoint
return tree_metadata.TreeMetadata.build(
self._read_metadata_file(directory),
directory=directory,
type_handler_registry=self._type_handler_registry,
use_ocdbt=is_ocdbt_checkpoint,
)

def finalize(self, directory: epath.Path) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def test_metadata(self):
)
metadata = handler.metadata(self.directory)
self.assertDictEqual(
metadata.state,
metadata.state.tree,
{
'a': value_metadata.ScalarMetadata(
name='a', directory=self.directory / 'state', dtype=jnp.int64
Expand Down Expand Up @@ -709,7 +709,7 @@ def test_metadata_handler_registry(self):
)
metadata = handler.metadata(self.directory)
self.assertDictEqual(
metadata.state,
metadata.state.tree,
{
'a': value_metadata.ScalarMetadata(
name='a', directory=self.directory / 'state', dtype=jnp.int64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.serialization import serialization
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint._src.metadata import tree as tree_metadata
import tensorstore as ts


Expand Down Expand Up @@ -982,7 +982,7 @@ def _process_metadata_and_aggregate_leaves(value_meta, value):
use_zarr3,
)

def metadata(self, directory: epath.Path) -> Optional[PyTree]:
def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
"""Returns tree metadata.
The result will be a PyTree matching the structure of the saved checkpoint.
Expand Down Expand Up @@ -1053,6 +1053,12 @@ class PyTreeSaveArgs(CheckpointArgs):
ocdbt_target_data_file_size: Optional[int] = None
enable_pinned_host_transfer: bool = True

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
self.item = self.item.tree
if isinstance(self.save_args, tree_metadata.TreeMetadata):
self.save_args = self.save_args.tree


@register_with_handler(PyTreeCheckpointHandler, for_restore=True)
@dataclasses.dataclass
Expand Down Expand Up @@ -1090,3 +1096,11 @@ class PyTreeRestoreArgs(CheckpointArgs):
transforms: Optional[PyTree] = None
transforms_default_to_original: bool = True
legacy_transform_fn: Optional[LegacyTransformFn] = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
self.item = self.item.tree
if isinstance(self.restore_args, tree_metadata.TreeMetadata):
self.restore_args = self.restore_args.tree
if isinstance(self.transforms, tree_metadata.TreeMetadata):
self.transforms = self.transforms.tree
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint._src.metadata import tree as tree_metadata


PyTree = Any
Expand Down Expand Up @@ -230,7 +231,7 @@ def _replace_strict(
),
)

def metadata(self, directory: epath.Path) -> PyTree:
def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata:
"""Returns metadata about the saved item."""
return self._impl.metadata(directory)

Expand Down Expand Up @@ -258,6 +259,12 @@ class StandardSaveArgs(CheckpointArgs):
item: PyTree
save_args: Optional[PyTree] = None

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
self.item = self.item.tree
if isinstance(self.save_args, tree_metadata.TreeMetadata):
self.save_args = self.save_args.tree


@register_with_handler(StandardCheckpointHandler, for_restore=True)
@dataclasses.dataclass
Expand All @@ -282,3 +289,7 @@ class StandardRestoreArgs(CheckpointArgs):

item: Optional[PyTree] = None
strict: bool = True

def __post_init__(self):
if isinstance(self.item, tree_metadata.TreeMetadata):
self.item = self.item.tree
210 changes: 210 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dataclasses
import enum
import functools
import inspect
import operator
from typing import Any, Dict, Hashable, List, Optional, Tuple, TypeAlias, TypeVar, Union

Expand Down Expand Up @@ -515,3 +516,212 @@ def serialize_tree(
)

return tree_utils.serialize_tree(tree, keep_empty_nodes=True)


class _SupportedTreeType(enum.Enum):
"""Enumerates allowed serialized tree types."""

DICT = 'dict'
LIST = 'list'
TUPLE = 'tuple'

@classmethod
def create(cls, tree: PyTree):
if isinstance(tree, dict):
return cls.DICT
elif isinstance(tree, list):
return cls.LIST
elif isinstance(tree, tuple):
return cls.TUPLE
else:
raise ValueError(f'Unsupported tree type: {type(tree)}')


@jax.tree_util.register_pytree_with_keys_class
class TreeMetadata:
"""User-facing metadata representation of a PyTree.
The object should be treated as a regular PyTree that can be mapped over, with
values of `ocp.metadata.value.Metadata`. Additional properties can be accessed
as public attributes, but are not included in tree mapping functions. To
directly access the underlying PyTree, which matches the checkpoint structure,
use the `tree` property.
Here is a typical example usage::
with ocp.StandardCheckpointer() as ckptr:
# `metadata` is a `TreeMetadata` object, but can be treated as a regular
# PyTree. In this case, it corresponds to a "serialized" representation of
# the checkpoint tree. This means that all custom nodes are converted to
# standardized containers like list, tuple, and dict.
metadata = ckptr.metadata('/path/to/existing/checkpoint')
# Access array properties.
metadata['step'].shape
metadata['step'].dtype
# Access a list element.
metadata['opt_state'][0]
# Get all the shapes of the tree elements.
shapes = jax.tree.map(lambda x: x.shape, metadata)
If the checkpoint structure is standardized as a list or a tuple, the metadata
object can be indexed like a regular sequence::
with ocp.StandardCheckpointer() as ckptr:
metadata = ckptr.metadata('/path/to/existing/checkpoint')
metadata[0].shape
shapes = jax.tree.map(lambda x: x.shape, metadata)
Note that if we manually construct a target tree with the same structure as
the checkpoint, we will run into an error if we try to tree map over it at the
same time as the metadata object. To do this, instead access the `tree`
property.
Properties of the `TreeMetadata` object, such as `custom` and `tree`, can be
accessed directly::
with ocp.StandardCheckpointer() as ckptr:
metadata = ckptr.metadata('/path/to/existing/checkpoint')
metadata.custom
metadata.tree
"""

def __init__(
self,
*,
tree: PyTree,
custom: PyTree | None = None,
):
self._tree = tree
self._custom = custom
self._tree_type = _SupportedTreeType.create(tree)

def __repr__(self):
properties_repr = ''.join(
[f' {k}={v}\n' for k, v in self.properties().items()]
)
return f'TreeMetadata(\n{properties_repr})'

@property
def tree(self) -> PyTree:
return self._tree

@property
def custom(self) -> PyTree | None:
return self._custom

def tree_flatten(self):
flat_with_keys, aux_data = self.tree_flatten_with_keys()
tree_keys, tree_values = zip(*flat_with_keys)
return (
tree_values,
dict(
tree_keys=tree_keys,
**aux_data,
),
)

def tree_flatten_with_keys(self):
if isinstance(self._tree, dict):
tree_keys = [jax.tree_util.DictKey(k) for k in self._tree.keys()]
tree_values = self._tree.values()
else:
tree_keys = [jax.tree_util.SequenceKey(i) for i in range(len(self._tree))]
tree_values = self._tree
return (
list(zip(tree_keys, tree_values)),
dict(tree_type=self._tree_type, **self.properties(include_tree=False)),
)

@classmethod
def tree_unflatten(cls, aux_data, flat_tree):
tree_type = aux_data.pop('tree_type')
tree_keys = aux_data.pop('tree_keys')
match tree_type:
case _SupportedTreeType.DICT:
return cls(
tree={
tree_utils.get_key_name(k): v
for k, v in zip(tree_keys, flat_tree)
},
**aux_data,
)
case _SupportedTreeType.LIST:
return cls(tree=list(flat_tree), **aux_data)
case _SupportedTreeType.TUPLE:
return cls(tree=tuple(flat_tree), **aux_data)
case _:
raise ValueError(f'Unsupported tree type: {tree_type}')

def __getitem__(self, key: str | int) -> Any:
"""Retrieves the value associated with the given key in the metadata tree.
If the container is a dict, the key should be a dict key. If the container
is a list or tuple, the key should be an integer index.
Args:
key: The key to retrieve.
Returns:
The value associated with the given key.
"""
return self.tree[key]

def __contains__(self, key: str | int) -> bool:
"""Checks if the given key is present in the metadata tree.
If the container is a dict, the key should be a dict key. If the container
is a list or tuple, the key should be an integer index.
Args:
key: The key to check.
Returns:
True if the key is present in the tree, False otherwise.
"""
return key in self.tree

def __len__(self) -> int:
return len(self.tree)

def __iter__(self):
return iter(self.tree)

def get(self, key: str, default=None):
try:
return self.__getitem__(key)
except KeyError:
return default
except IndexError:
return default

def keys(self):
return self.tree.keys()

def values(self):
return self.tree.values()

def items(self):
return self.tree.items()

def properties(self, *, include_tree: bool = True) -> dict[str, Any]:
result = {
name: getattr(self, name)
for name, member in inspect.getmembers(type(self))
if isinstance(member, property)
}
if not include_tree:
result.pop('tree')
return result

@classmethod
def build(
cls,
internal_tree_metadata: InternalTreeMetadata,
*,
directory: epath.Path,
type_handler_registry: types.TypeHandlerRegistry,
use_ocdbt: bool,
) -> TreeMetadata:
"""Builds the TreeMetadata."""
return cls(
tree=internal_tree_metadata.as_user_metadata(
directory, type_handler_registry, use_ocdbt=use_ocdbt
),
)
Loading

0 comments on commit 1bdfd6e

Please sign in to comment.