diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index b201a05a..6fbbf549 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -16,6 +16,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix namedtuple empty value typestr when experimental support_rich_types is disabled again after enabling it. +### Changed +- Add user-facing `TreeMetadata` object returned by +`PyTreeCheckpointHandler.metadata`. This object mimics an ordinary PyTree to +make the change unnoticeable to most users, but also has additional accessible +properties not included in any tree mapping operations. + ## [0.10.1] - 2024-11-22 diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index b61c5d93..0a5772f2 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -42,7 +42,6 @@ from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.handlers import async_checkpoint_handler from orbax.checkpoint._src.metadata import empty_values -from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.serialization import serialization @@ -50,6 +49,7 @@ from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.serialization import types from orbax.checkpoint._src.tree import utils as tree_utils +from orbax.checkpoint._src.metadata import tree as tree_metadata import tensorstore as ts @@ -770,7 +770,8 @@ def _read_metadata_file( json.loads(path.read_text()) ) - def metadata(self, directory: epath.Path) -> Optional[PyTree]: + + def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata: """Returns tree metadata. The result will be a PyTree matching the structure of the saved checkpoint. @@ -797,8 +798,11 @@ def metadata(self, directory: epath.Path) -> Optional[PyTree]: tree containing metadata. """ is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory) - return self._read_metadata_file(directory).as_user_metadata( - directory, self._type_handler_registry, use_ocdbt=is_ocdbt_checkpoint + return tree_metadata.TreeMetadata.build( + self._read_metadata_file(directory), + directory=directory, + type_handler_registry=self._type_handler_registry, + use_ocdbt=is_ocdbt_checkpoint, ) def finalize(self, directory: epath.Path) -> None: diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py index a7ec0563..2b43262d 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py @@ -669,7 +669,7 @@ def test_metadata(self): ) metadata = handler.metadata(self.directory) self.assertDictEqual( - metadata.state, + metadata.state.tree, { 'a': value_metadata.ScalarMetadata( name='a', directory=self.directory / 'state', dtype=jnp.int64 @@ -709,7 +709,7 @@ def test_metadata_handler_registry(self): ) metadata = handler.metadata(self.directory) self.assertDictEqual( - metadata.state, + metadata.state.tree, { 'a': value_metadata.ScalarMetadata( name='a', directory=self.directory / 'state', dtype=jnp.int64 diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index 716885f3..fcfaf0c3 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -42,11 +42,11 @@ from orbax.checkpoint._src.handlers import async_checkpoint_handler from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler from orbax.checkpoint._src.metadata import empty_values -from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.serialization import serialization from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.tree import utils as tree_utils +from orbax.checkpoint._src.metadata import tree as tree_metadata import tensorstore as ts @@ -982,7 +982,7 @@ def _process_metadata_and_aggregate_leaves(value_meta, value): use_zarr3, ) - def metadata(self, directory: epath.Path) -> Optional[PyTree]: + def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata: """Returns tree metadata. The result will be a PyTree matching the structure of the saved checkpoint. @@ -1053,6 +1053,12 @@ class PyTreeSaveArgs(CheckpointArgs): ocdbt_target_data_file_size: Optional[int] = None enable_pinned_host_transfer: bool = True + def __post_init__(self): + if isinstance(self.item, tree_metadata.TreeMetadata): + self.item = self.item.tree + if isinstance(self.save_args, tree_metadata.TreeMetadata): + self.save_args = self.save_args.tree + @register_with_handler(PyTreeCheckpointHandler, for_restore=True) @dataclasses.dataclass @@ -1090,3 +1096,11 @@ class PyTreeRestoreArgs(CheckpointArgs): transforms: Optional[PyTree] = None transforms_default_to_original: bool = True legacy_transform_fn: Optional[LegacyTransformFn] = None + + def __post_init__(self): + if isinstance(self.item, tree_metadata.TreeMetadata): + self.item = self.item.tree + if isinstance(self.restore_args, tree_metadata.TreeMetadata): + self.restore_args = self.restore_args.tree + if isinstance(self.transforms, tree_metadata.TreeMetadata): + self.transforms = self.transforms.tree diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py index 25adba96..e3d64ff1 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py @@ -31,6 +31,7 @@ from orbax.checkpoint._src.handlers import async_checkpoint_handler from orbax.checkpoint._src.handlers import pytree_checkpoint_handler from orbax.checkpoint._src.tree import utils as tree_utils +from orbax.checkpoint._src.metadata import tree as tree_metadata PyTree = Any @@ -230,7 +231,7 @@ def _replace_strict( ), ) - def metadata(self, directory: epath.Path) -> PyTree: + def metadata(self, directory: epath.Path) -> tree_metadata.TreeMetadata: """Returns metadata about the saved item.""" return self._impl.metadata(directory) @@ -258,6 +259,12 @@ class StandardSaveArgs(CheckpointArgs): item: PyTree save_args: Optional[PyTree] = None + def __post_init__(self): + if isinstance(self.item, tree_metadata.TreeMetadata): + self.item = self.item.tree + if isinstance(self.save_args, tree_metadata.TreeMetadata): + self.save_args = self.save_args.tree + @register_with_handler(StandardCheckpointHandler, for_restore=True) @dataclasses.dataclass @@ -282,3 +289,7 @@ class StandardRestoreArgs(CheckpointArgs): item: Optional[PyTree] = None strict: bool = True + + def __post_init__(self): + if isinstance(self.item, tree_metadata.TreeMetadata): + self.item = self.item.tree diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index 5a7d9c64..f9cd5e3c 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -21,6 +21,7 @@ import dataclasses import enum import functools +import inspect import operator from typing import Any, Dict, Hashable, List, Optional, Tuple, TypeAlias, TypeVar, Union @@ -515,3 +516,212 @@ def serialize_tree( ) return tree_utils.serialize_tree(tree, keep_empty_nodes=True) + + +class _SupportedTreeType(enum.Enum): + """Enumerates allowed serialized tree types.""" + + DICT = 'dict' + LIST = 'list' + TUPLE = 'tuple' + + @classmethod + def create(cls, tree: PyTree): + if isinstance(tree, dict): + return cls.DICT + elif isinstance(tree, list): + return cls.LIST + elif isinstance(tree, tuple): + return cls.TUPLE + else: + raise ValueError(f'Unsupported tree type: {type(tree)}') + + +@jax.tree_util.register_pytree_with_keys_class +class TreeMetadata: + """User-facing metadata representation of a PyTree. + + The object should be treated as a regular PyTree that can be mapped over, with + values of `ocp.metadata.value.Metadata`. Additional properties can be accessed + as public attributes, but are not included in tree mapping functions. To + directly access the underlying PyTree, which matches the checkpoint structure, + use the `tree` property. + + Here is a typical example usage:: + with ocp.StandardCheckpointer() as ckptr: + # `metadata` is a `TreeMetadata` object, but can be treated as a regular + # PyTree. In this case, it corresponds to a "serialized" representation of + # the checkpoint tree. This means that all custom nodes are converted to + # standardized containers like list, tuple, and dict. + metadata = ckptr.metadata('/path/to/existing/checkpoint') + # Access array properties. + metadata['step'].shape + metadata['step'].dtype + # Access a list element. + metadata['opt_state'][0] + # Get all the shapes of the tree elements. + shapes = jax.tree.map(lambda x: x.shape, metadata) + + If the checkpoint structure is standardized as a list or a tuple, the metadata + object can be indexed like a regular sequence:: + with ocp.StandardCheckpointer() as ckptr: + metadata = ckptr.metadata('/path/to/existing/checkpoint') + metadata[0].shape + shapes = jax.tree.map(lambda x: x.shape, metadata) + + Note that if we manually construct a target tree with the same structure as + the checkpoint, we will run into an error if we try to tree map over it at the + same time as the metadata object. To do this, instead access the `tree` + property. + + Properties of the `TreeMetadata` object, such as `custom` and `tree`, can be + accessed directly:: + with ocp.StandardCheckpointer() as ckptr: + metadata = ckptr.metadata('/path/to/existing/checkpoint') + metadata.custom + metadata.tree + """ + + def __init__( + self, + *, + tree: PyTree, + custom: PyTree | None = None, + ): + self._tree = tree + self._custom = custom + self._tree_type = _SupportedTreeType.create(tree) + + def __repr__(self): + properties_repr = ''.join( + [f' {k}={v}\n' for k, v in self.properties().items()] + ) + return f'TreeMetadata(\n{properties_repr})' + + @property + def tree(self) -> PyTree: + return self._tree + + @property + def custom(self) -> PyTree | None: + return self._custom + + def tree_flatten(self): + flat_with_keys, aux_data = self.tree_flatten_with_keys() + tree_keys, tree_values = zip(*flat_with_keys) + return ( + tree_values, + dict( + tree_keys=tree_keys, + **aux_data, + ), + ) + + def tree_flatten_with_keys(self): + if isinstance(self._tree, dict): + tree_keys = [jax.tree_util.DictKey(k) for k in self._tree.keys()] + tree_values = self._tree.values() + else: + tree_keys = [jax.tree_util.SequenceKey(i) for i in range(len(self._tree))] + tree_values = self._tree + return ( + list(zip(tree_keys, tree_values)), + dict(tree_type=self._tree_type, **self.properties(include_tree=False)), + ) + + @classmethod + def tree_unflatten(cls, aux_data, flat_tree): + tree_type = aux_data.pop('tree_type') + tree_keys = aux_data.pop('tree_keys') + match tree_type: + case _SupportedTreeType.DICT: + return cls( + tree={ + tree_utils.get_key_name(k): v + for k, v in zip(tree_keys, flat_tree) + }, + **aux_data, + ) + case _SupportedTreeType.LIST: + return cls(tree=list(flat_tree), **aux_data) + case _SupportedTreeType.TUPLE: + return cls(tree=tuple(flat_tree), **aux_data) + case _: + raise ValueError(f'Unsupported tree type: {tree_type}') + + def __getitem__(self, key: str | int) -> Any: + """Retrieves the value associated with the given key in the metadata tree. + + If the container is a dict, the key should be a dict key. If the container + is a list or tuple, the key should be an integer index. + + Args: + key: The key to retrieve. + + Returns: + The value associated with the given key. + """ + return self.tree[key] + + def __contains__(self, key: str | int) -> bool: + """Checks if the given key is present in the metadata tree. + + If the container is a dict, the key should be a dict key. If the container + is a list or tuple, the key should be an integer index. + + Args: + key: The key to check. + + Returns: + True if the key is present in the tree, False otherwise. + """ + return key in self.tree + + def __len__(self) -> int: + return len(self.tree) + + def __iter__(self): + return iter(self.tree) + + def get(self, key: str, default=None): + try: + return self.__getitem__(key) + except KeyError: + return default + except IndexError: + return default + + def keys(self): + return self.tree.keys() + + def values(self): + return self.tree.values() + + def items(self): + return self.tree.items() + + def properties(self, *, include_tree: bool = True) -> dict[str, Any]: + result = { + name: getattr(self, name) + for name, member in inspect.getmembers(type(self)) + if isinstance(member, property) + } + if not include_tree: + result.pop('tree') + return result + + @classmethod + def build( + cls, + internal_tree_metadata: InternalTreeMetadata, + *, + directory: epath.Path, + type_handler_registry: types.TypeHandlerRegistry, + use_ocdbt: bool, + ) -> TreeMetadata: + """Builds the TreeMetadata.""" + return cls( + tree=internal_tree_metadata.as_user_metadata( + directory, type_handler_registry, use_ocdbt=use_ocdbt + ), + ) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py index 2081719e..11f43ea8 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py @@ -18,16 +18,17 @@ from absl.testing import parameterized import chex import jax -from orbax.checkpoint._src.metadata import tree as tree_metadata +from orbax.checkpoint._src.metadata import tree as tree_metadata_lib from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.serialization import types from orbax.checkpoint._src.testing import test_tree_utils from orbax.checkpoint._src.tree import utils as tree_utils +from orbax.checkpoint._src.metadata import tree as tree_metadata_lib def _to_param_infos( tree: Any, - pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, + pytree_metadata_options: tree_metadata_lib.PyTreeMetadataOptions, ): return jax.tree.map( # Other properties are not relevant. @@ -48,23 +49,25 @@ class InternalTreeMetadataEntryTest(parameterized.TestCase): @parameterized.product( test_pytree=test_tree_utils.TEST_PYTREES, pytree_metadata_options=[ - tree_metadata.PyTreeMetadataOptions(support_rich_types=False), - tree_metadata.PyTreeMetadataOptions(support_rich_types=True), + tree_metadata_lib.PyTreeMetadataOptions(support_rich_types=False), + tree_metadata_lib.PyTreeMetadataOptions(support_rich_types=True), ], ) def test_as_nested_tree( self, test_pytree: test_tree_utils.TestPyTree, - pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, + pytree_metadata_options: tree_metadata_lib.PyTreeMetadataOptions, ): tree = test_pytree.provide_tree() - original_internal_tree_metadata = tree_metadata.InternalTreeMetadata.build( - param_infos=_to_param_infos(tree, pytree_metadata_options), - pytree_metadata_options=pytree_metadata_options, + original_internal_tree_metadata = ( + tree_metadata_lib.InternalTreeMetadata.build( + param_infos=_to_param_infos(tree, pytree_metadata_options), + pytree_metadata_options=pytree_metadata_options, + ) ) json_object = original_internal_tree_metadata.to_json() restored_internal_tree_metadata = ( - tree_metadata.InternalTreeMetadata.from_json( + tree_metadata_lib.InternalTreeMetadata.from_json( json_object, pytree_metadata_options ) ) @@ -82,12 +85,12 @@ def test_as_nested_tree( test_pytree=test_tree_utils.TEST_PYTREES, pytree_metadata_options_switch=[ ( - tree_metadata.PyTreeMetadataOptions(support_rich_types=False), - tree_metadata.PyTreeMetadataOptions(support_rich_types=True), + tree_metadata_lib.PyTreeMetadataOptions(support_rich_types=False), + tree_metadata_lib.PyTreeMetadataOptions(support_rich_types=True), ), ( - tree_metadata.PyTreeMetadataOptions(support_rich_types=True), - tree_metadata.PyTreeMetadataOptions(support_rich_types=False), + tree_metadata_lib.PyTreeMetadataOptions(support_rich_types=True), + tree_metadata_lib.PyTreeMetadataOptions(support_rich_types=False), ), ], ) @@ -95,8 +98,8 @@ def test_switching_between_support_rich_types( self, test_pytree: test_tree_utils.TestPyTree, pytree_metadata_options_switch: tuple[ - tree_metadata.PyTreeMetadataOptions, - tree_metadata.PyTreeMetadataOptions, + tree_metadata_lib.PyTreeMetadataOptions, + tree_metadata_lib.PyTreeMetadataOptions, ], ): write_pytree_metadata_options, read_pytree_metadata_options = ( @@ -114,20 +117,121 @@ def test_switching_between_support_rich_types( expected_tree_metadata = test_pytree.expected_nested_tree_metadata tree = test_pytree.provide_tree() - original_internal_tree_metadata = tree_metadata.InternalTreeMetadata.build( - param_infos=_to_param_infos(tree, write_pytree_metadata_options), - pytree_metadata_options=write_pytree_metadata_options, + original_internal_tree_metadata = ( + tree_metadata_lib.InternalTreeMetadata.build( + param_infos=_to_param_infos(tree, write_pytree_metadata_options), + pytree_metadata_options=write_pytree_metadata_options, + ) ) json_object = original_internal_tree_metadata.to_json() restored_internal_tree_metadata = ( - tree_metadata.InternalTreeMetadata.from_json( + tree_metadata_lib.InternalTreeMetadata.from_json( json_object, read_pytree_metadata_options - ) - ) + )) restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree() chex.assert_trees_all_equal(restored_tree_metadata, expected_tree_metadata) +class TreeMetadataTest(parameterized.TestCase): + + def _check_tree_property( + self, expected_tree: Any, metadata: tree_metadata_lib.TreeMetadata + ): + if isinstance(expected_tree, dict): + self.assertDictEqual(metadata.tree, expected_tree) + elif isinstance(expected_tree, list): + self.assertListEqual(metadata.tree, expected_tree) + elif isinstance(expected_tree, tuple): + self.assertTupleEqual(metadata.tree, expected_tree) + else: + raise ValueError(f'Unsupported tree type: {type(expected_tree)}') + + @parameterized.parameters(({'a': 1, 'b': 2},), ([1, 2],), ((1, 2),)) + def test_properties(self, tree): + custom = {'foo': 1} + metadata = tree_metadata_lib.TreeMetadata(tree=tree, custom=custom) + self.assertDictEqual(metadata.custom, custom) + self._check_tree_property(tree, metadata) + + + @parameterized.parameters(({'a': 1, 'b': 2},), ([1, 2],), ((1, 2),)) + def test_tree_map(self, tree): + custom = {'foo': 1} + metadata = tree_metadata_lib.TreeMetadata(tree=tree, custom=custom) + metadata = jax.tree.map(lambda x: x + 1, metadata) + self.assertDictEqual(metadata.custom, custom) + self._check_tree_property(jax.tree.map(lambda x: x + 1, tree), metadata) + + @parameterized.parameters(({'a': 1, 'b': 2},), ([1, 2],), ((1, 2),)) + def test_multiple_tree_map(self, tree): + metadata = tree_metadata_lib.TreeMetadata(tree=tree) + with self.assertRaises(ValueError): + _ = jax.tree.map(lambda x, y: x + y, metadata, tree) + + @parameterized.parameters(({'a': 1, 'b': 2},), ([1, 2],), ((1, 2),)) + def test_accessors(self, tree): + metadata = tree_metadata_lib.TreeMetadata(tree=tree) + self.assertLen(metadata, 2) + if isinstance(tree, dict): + self.assertIn('a', metadata) + self.assertIn('b', metadata) + self.assertNotIn('c', metadata) + self.assertEqual(metadata['a'], 1) + self.assertEqual(metadata['b'], 2) + self.assertEqual(metadata.get('a'), 1) + self.assertEqual(metadata.get('b'), 2) + self.assertIsNone(metadata.get('c')) + with self.assertRaises(KeyError): + _ = metadata['c'] + else: + self.assertNotIn(0, metadata) + self.assertIn(1, metadata) + self.assertIn(2, metadata) + self.assertEqual(metadata[0], 1) + self.assertEqual(metadata[1], 2) + self.assertEqual(metadata[-1], 2) + self.assertEqual(metadata.get(0), 1) + self.assertEqual(metadata.get(1), 2) + self.assertEqual(metadata.get(-1), 2) + self.assertIsNone(metadata.get(2)) + with self.assertRaises(IndexError): + _ = metadata[2] + + @parameterized.parameters(({'a': 1, 'b': 2},), ([1, 2],), ((1, 2),)) + def test_tree_flatten(self, tree): + metadata = tree_metadata_lib.TreeMetadata(tree=tree) + flat, treedef = jax.tree.flatten(metadata) + self.assertSequenceEqual(flat, [1, 2]) + unflat = jax.tree.unflatten(treedef, flat) + self.assertIsInstance(unflat, tree_metadata_lib.TreeMetadata) + self._check_tree_property(tree, unflat) + + @parameterized.parameters(({'a': 1, 'c': {'b': 2}},), ([1, 2],), ((1, 2),)) + def test_with_path(self, tree): + metadata = tree_metadata_lib.TreeMetadata(tree=tree) + metadata = jax.tree_util.tree_map_with_path(lambda _, x: x + 1, metadata) + self._check_tree_property( + jax.tree_util.tree_map_with_path(lambda _, x: x + 1, tree), metadata + ) + + flat_with_keys, treedef = jax.tree_util.tree_flatten_with_path(metadata) + keys, values = zip(*flat_with_keys) + expected_keys = ( + list(tree.keys()) + if isinstance(tree, dict) + else [str(i) for i in range(len(tree))] + ) + self.assertSequenceEqual( + expected_keys, [tree_utils.tuple_path_from_keypath(k)[0] for k in keys] + ) + self.assertSequenceEqual(values, [2, 3]) + + flat, _ = jax.tree.flatten(metadata) + unflat = jax.tree.unflatten(treedef, flat) + self.assertIsInstance(unflat, tree_metadata_lib.TreeMetadata) + self._check_tree_property(jax.tree.map(lambda x: x + 1, tree), unflat) + + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/checkpoint_utils.py b/checkpoint/orbax/checkpoint/checkpoint_utils.py index c6a6f30b..baa83fbf 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_utils.py +++ b/checkpoint/orbax/checkpoint/checkpoint_utils.py @@ -28,6 +28,7 @@ from orbax.checkpoint._src.path import step as step_lib from orbax.checkpoint._src.path.snapshot import snapshot as snapshot_lib from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.metadata import tree as tree_metadata PyTree = Any @@ -465,4 +466,10 @@ def _restore_args( sharding_tree = jax.tree.map( lambda x: x.sharding if hasattr(x, 'sharding') else None, target ) - return jax.tree.map(_restore_args, target, sharding_tree) + if isinstance(target, tree_metadata.TreeMetadata): + return tree_metadata.TreeMetadata( + tree=jax.tree.map(_restore_args, target.tree, sharding_tree.tree), + **target.properties(include_tree=False), + ) + else: + return jax.tree.map(_restore_args, target, sharding_tree)