From ab96b72ec9ba5ea50cef19298f23111ec14cdbf8 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Tue, 12 Nov 2024 08:54:00 -0800 Subject: [PATCH 1/2] Resubmit of PR1304. --- checkpoint/orbax/checkpoint/single_host_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/checkpoint/orbax/checkpoint/single_host_test.py b/checkpoint/orbax/checkpoint/single_host_test.py index 60503e09..949ab561 100644 --- a/checkpoint/orbax/checkpoint/single_host_test.py +++ b/checkpoint/orbax/checkpoint/single_host_test.py @@ -22,7 +22,6 @@ import ml_dtypes import numpy as np from orbax.checkpoint import test_utils -from orbax.checkpoint.args import StandardSave from orbax.checkpoint._src.handlers import pytree_checkpoint_handler from orbax.checkpoint._src.handlers import standard_checkpoint_handler_test_utils from orbax.checkpoint._src.serialization import type_handlers @@ -70,7 +69,7 @@ def test_save_singular_array_with_standard_checkpoint_handler(self, x): handler = standard_checkpoint_handler_test_utils.StandardCheckpointHandler() with self.assertRaisesRegex(ValueError, '.*Use ArrayCheckpointHandler / ArraySave.*'): - handler.save(self.ckpt_dir, args=StandardSave(x)) + handler.save(self.ckpt_dir, args=standard_checkpoint_handler_test_utils.StandardSaveArgs(x)) def test_save_and_restore_zarrv3_jax_array_default_chunk_size(self): handler = PyTreeCheckpointHandler(use_zarr3=True) From 8989abc8acb49b12098f76eec16da2e9219ee163 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Tue, 12 Nov 2024 09:22:04 -0800 Subject: [PATCH 2/2] Resubmit PR 1304. --- .../checkpoint/_src/handlers/standard_checkpoint_handler.py | 6 ++++++ checkpoint/orbax/checkpoint/single_host_test.py | 1 + 2 files changed, 7 insertions(+) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py index acbf3640..25adba96 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py @@ -17,6 +17,7 @@ from __future__ import annotations import dataclasses +import numbers from typing import Any, List, Optional from absl import logging @@ -91,6 +92,11 @@ def _validate_save_state( ): if item is None: raise ValueError('Must provide item to save.') + if isinstance(item, jax.Array | numbers.Number): + raise ValueError( + 'StandardCheckpointHandler / StandardSave does not support single ' + 'arrays or scalars. Use ArrayCheckpointHandler / ArraySave' + ) if save_args is None: save_args = jax.tree.map(lambda x: None, item) diff --git a/checkpoint/orbax/checkpoint/single_host_test.py b/checkpoint/orbax/checkpoint/single_host_test.py index 983c8f24..949ab561 100644 --- a/checkpoint/orbax/checkpoint/single_host_test.py +++ b/checkpoint/orbax/checkpoint/single_host_test.py @@ -23,6 +23,7 @@ import numpy as np from orbax.checkpoint import test_utils from orbax.checkpoint._src.handlers import pytree_checkpoint_handler +from orbax.checkpoint._src.handlers import standard_checkpoint_handler_test_utils from orbax.checkpoint._src.serialization import type_handlers import tensorstore as ts