Skip to content

Commit

Permalink
Manual submit of #1323.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695816989
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Nov 12, 2024
1 parent 174d6aa commit 2a10358
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
import numbers
from typing import Any, List, Optional

from absl import logging
Expand Down Expand Up @@ -91,6 +92,11 @@ def _validate_save_state(
):
if item is None:
raise ValueError('Must provide item to save.')
if isinstance(item, jax.Array | numbers.Number):
raise ValueError(
'StandardCheckpointHandler / StandardSave does not support single '
'arrays or scalars. Use ArrayCheckpointHandler / ArraySave'
)
if save_args is None:
save_args = jax.tree.map(lambda x: None, item)

Expand Down
12 changes: 12 additions & 0 deletions checkpoint/orbax/checkpoint/single_host_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.handlers import pytree_checkpoint_handler
from orbax.checkpoint._src.handlers import standard_checkpoint_handler_test_utils
from orbax.checkpoint._src.serialization import type_handlers
import tensorstore as ts

Expand Down Expand Up @@ -147,6 +148,17 @@ def test_chunk_byte_size(self):
np.testing.assert_array_equal(x, restored_tree['x'])
assert isinstance(restored_tree['x'], jax.Array)

@parameterized.parameters({'x': jnp.array([1, 2])}, {'x': 1})
def test_save_singular_array_with_standard_checkpoint_handler(self, x):
handler = standard_checkpoint_handler_test_utils.StandardCheckpointHandler()
with self.assertRaisesRegex(
ValueError, '.*Use ArrayCheckpointHandler / ArraySave.*'
):
handler.save(
self.ckpt_dir,
args=standard_checkpoint_handler_test_utils.StandardSaveArgs(x),
)

@parameterized.product(
dtype=[
jnp.bfloat16,
Expand Down

0 comments on commit 2a10358

Please sign in to comment.