Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resubmit single array error message #1323

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
import numbers
from typing import Any, List, Optional

from absl import logging
Expand Down Expand Up @@ -91,6 +92,11 @@ def _validate_save_state(
):
if item is None:
raise ValueError('Must provide item to save.')
if isinstance(item, jax.Array | numbers.Number):
raise ValueError(
'StandardCheckpointHandler / StandardSave does not support single '
'arrays or scalars. Use ArrayCheckpointHandler / ArraySave'
)
if save_args is None:
save_args = jax.tree.map(lambda x: None, item)

Expand Down
8 changes: 8 additions & 0 deletions checkpoint/orbax/checkpoint/single_host_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.handlers import standard_checkpoint_handler_test_utils
from orbax.checkpoint._src.serialization import type_handlers
import tensorstore as ts

Expand Down Expand Up @@ -63,6 +64,13 @@ def test_save_and_restore_jax_array(self, use_zarr3):
np.testing.assert_array_equal(x, restored_tree['x'])
assert isinstance(restored_tree['x'], jax.Array)

@parameterized.parameters({'x': jnp.array([1, 2])}, {'x': 1})
def test_save_singular_array_with_standard_checkpoint_handler(self, x):
handler = standard_checkpoint_handler_test_utils.StandardCheckpointHandler()
with self.assertRaisesRegex(ValueError,
'.*Use ArrayCheckpointHandler / ArraySave.*'):
handler.save(self.ckpt_dir, args=standard_checkpoint_handler_test_utils.StandardSaveArgs(x))

def test_save_and_restore_zarrv3_jax_array_default_chunk_size(self):
handler = PyTreeCheckpointHandler(use_zarr3=True)
key = jax.random.PRNGKey(0)
Expand Down
Loading