diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py
index acbf3640..25adba96 100644
--- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py
+++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import dataclasses
+import numbers
 from typing import Any, List, Optional
 
 from absl import logging
@@ -91,6 +92,14 @@ def _validate_save_state(
   ):
     if item is None:
       raise ValueError('Must provide item to save.')
+    # Reject a bare array or scalar up front and point the caller at
+    # ArrayCheckpointHandler / ArraySave. The tuple form of isinstance also
+    # works on Python versions older than 3.10.
+    if isinstance(item, (jax.Array, numbers.Number)):
+      raise ValueError(
+          'StandardCheckpointHandler / StandardSave does not support single '
+          'arrays or scalars. Use ArrayCheckpointHandler / ArraySave'
+      )
     if save_args is None:
       save_args = jax.tree.map(lambda x: None, item)
 
diff --git a/checkpoint/orbax/checkpoint/single_host_test.py b/checkpoint/orbax/checkpoint/single_host_test.py
index 2eeea371..ab98484f 100644
--- a/checkpoint/orbax/checkpoint/single_host_test.py
+++ b/checkpoint/orbax/checkpoint/single_host_test.py
@@ -23,6 +23,7 @@
 import numpy as np
 from orbax.checkpoint import test_utils
 from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
+from orbax.checkpoint._src.handlers import standard_checkpoint_handler_test_utils
 from orbax.checkpoint._src.serialization import type_handlers
 import tensorstore as ts
 
@@ -147,6 +148,18 @@ def test_chunk_byte_size(self):
     np.testing.assert_array_equal(x, restored_tree['x'])
     assert isinstance(restored_tree['x'], jax.Array)
 
+  @parameterized.parameters({'x': jnp.array([1, 2])}, {'x': 1})
+  def test_save_singular_array_with_standard_checkpoint_handler(self, x):
+    """Saving a lone array or scalar raises a ValueError naming ArraySave."""
+    handler = standard_checkpoint_handler_test_utils.StandardCheckpointHandler()
+    with self.assertRaisesRegex(
+        ValueError, '.*Use ArrayCheckpointHandler / ArraySave.*'
+    ):
+      handler.save(
+          self.ckpt_dir,
+          args=standard_checkpoint_handler_test_utils.StandardSaveArgs(x),
+      )
+
   @parameterized.product(
       dtype=[
           jnp.bfloat16,