Skip to content

Commit

Permalink
Fix a bug in Orbax checkpointing where None values PyTree are not han…
Browse files Browse the repository at this point in the history
…dled correctly. This is caused by a recent update in jax where None values are no longer considered as a leaf node: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-34-october-4-2024.

PiperOrigin-RevId: 691828986
  • Loading branch information
Orbax Authors committed Oct 31, 2024
1 parent 9a28c7a commit 3803f95
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed
- Fix a bug in Orbax checkpointing where None values PyTree are not handled
correctly. This is caused by a recent update in jax where None values are no
longer considered as a leaf node:
https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-34-october-4-2024.


### Added
- Local type handler registries.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def _group_value(
):
nonlocal grouped
tuple_key = tree_utils.tuple_path_from_keypath(keypath)

if info is None or arg is None or value is None:
return

if info.skip_deserialize:
return

Expand Down Expand Up @@ -226,6 +230,7 @@ def _group_value(
param_infos,
tree,
args,
is_leaf=lambda x: x is None,
)
return list(grouped.values())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,10 @@ def _process_aggregated_value(meta_or_value, args):

flat_aggregate = tree_utils.to_flat_dict(
jax.tree_util.tree_map(
_process_aggregated_value, metadata, restore_args
lambda x, y: None if x is None else _process_aggregated_value(x, y),
metadata,
restore_args,
is_leaf=lambda x: x is None,
),
)

Expand Down Expand Up @@ -814,7 +817,12 @@ def _maybe_set_default_restore_types(value_meta: Any, arg: RestoreArgs):
# If metadata file was missing in the checkpoint, we need to decide
# restore_type based on RestoreArgs.
structure = jax.tree.map(
_maybe_set_default_restore_types, structure, checkpoint_restore_args
lambda x, y: None
if x is None
else _maybe_set_default_restore_types(x, y),
structure,
checkpoint_restore_args,
is_leaf=lambda x: x is None,
)

restored_item = asyncio_utils.run_sync(
Expand Down

0 comments on commit 3803f95

Please sign in to comment.