From 4d83dfb4a759d488c878fa6cf05cd86184ff86ad Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Wed, 9 Nov 2022 11:43:01 -0800 Subject: [PATCH] PathLike pytype cleanup PiperOrigin-RevId: 487305127 --- orbax/checkpoint/abstract_checkpointer.py | 8 ++++---- orbax/checkpoint/async_checkpointer.py | 6 +++--- orbax/checkpoint/checkpointer.py | 8 ++++---- orbax/checkpoint/utils.py | 19 +++++++++---------- setup.py | 1 + 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/orbax/checkpoint/abstract_checkpointer.py b/orbax/checkpoint/abstract_checkpointer.py index 64d1a312..542f3e9f 100644 --- a/orbax/checkpoint/abstract_checkpointer.py +++ b/orbax/checkpoint/abstract_checkpointer.py @@ -15,7 +15,7 @@ """AbstractCheckpointer.""" import abc -from typing import Any, Optional, Union +from typing import Any, Optional from etils import epath @@ -30,7 +30,7 @@ class AbstractCheckpointer(abc.ABC): """ @abc.abstractmethod - def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs): + def save(self, directory: epath.PathLike, item: Any, *args, **kwargs): """Saves the given item to the provided directory. Args: @@ -44,7 +44,7 @@ def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs): @abc.abstractmethod def restore(self, - directory: Union[str, epath.Path], + directory: epath.PathLike, *args, item: Optional[Any] = None, **kwargs) -> Any: @@ -66,7 +66,7 @@ def restore(self, pass @abc.abstractmethod - def structure(self, directory: Union[str, epath.Path]) -> Optional[Any]: + def structure(self, directory: epath.PathLike) -> Optional[Any]: """The structure of the saved object at `directory`. Delegates to underlying handler. diff --git a/orbax/checkpoint/async_checkpointer.py b/orbax/checkpoint/async_checkpointer.py index c2957bf6..56b914e4 100644 --- a/orbax/checkpoint/async_checkpointer.py +++ b/orbax/checkpoint/async_checkpointer.py @@ -16,7 +16,7 @@ import asyncio import functools -from typing import Any, Union, Optional +from typing import Any, Optional from absl import logging from etils import epath @@ -44,7 +44,7 @@ def __init__(self, handler: AsyncCheckpointHandler, timeout_secs: int = 300): self._handler = handler AsyncManager.__init__(self, timeout_secs=timeout_secs) - def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs): + def save(self, directory: epath.PathLike, item: Any, *args, **kwargs): """Saves the given item to the provided directory. Delegates to the underlying CheckpointHandler. Ensures save operation @@ -82,7 +82,7 @@ def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs): functools.partial(utils.ensure_atomic_save, tmpdir, directory)) def restore(self, - directory: Union[str, epath.Path], + directory: epath.PathLike, *args, item: Optional[Any] = None, **kwargs) -> Any: diff --git a/orbax/checkpoint/checkpointer.py b/orbax/checkpoint/checkpointer.py index f3abab1b..cfa16f0a 100644 --- a/orbax/checkpoint/checkpointer.py +++ b/orbax/checkpoint/checkpointer.py @@ -14,7 +14,7 @@ """Synchronous Checkpointer implementation.""" -from typing import Any, Optional, Union +from typing import Any, Optional from absl import logging from etils import epath @@ -36,7 +36,7 @@ def __init__(self, handler: CheckpointHandler): self._handler = handler def save(self, - directory: Union[str, epath.Path], + directory: epath.PathLike, item: Any, *args, force: bool = False, @@ -76,7 +76,7 @@ def save(self, multihost_utils.sync_global_devices('Checkpointer:save') def restore(self, - directory: Union[str, epath.Path], + directory: epath.PathLike, *args, item: Optional[Any] = None, **kwargs) -> Any: @@ -89,7 +89,7 @@ def restore(self, logging.info('Restoring item from %s.', directory) return self._handler.restore(directory, *args, item=item, **kwargs) - def structure(self, directory: Union[str, epath.Path]) -> Optional[Any]: + def structure(self, directory: epath.PathLike) -> Optional[Any]: """See superclass documentation.""" directory = epath.Path(directory) try: diff --git a/orbax/checkpoint/utils.py b/orbax/checkpoint/utils.py index 432655c6..aec97813 100644 --- a/orbax/checkpoint/utils.py +++ b/orbax/checkpoint/utils.py @@ -17,7 +17,7 @@ import functools import os import time -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple from absl import logging from etils import epath @@ -32,7 +32,6 @@ _GCS_PATH_PREFIX = 'gs://' CheckpointDirs = Tuple[str, str] PyTree = type(jax.tree_util.tree_structure(None)) -Path = Union[str, epath.Path] def _wrap(func): @@ -86,7 +85,7 @@ def rmtree(path: epath.Path): Leaf = str -def pytree_structure(directory: Path) -> PyTree: +def pytree_structure(directory: epath.PathLike) -> PyTree: """Reconstruct state dict from saved model format in `directory`.""" directory = epath.Path(directory) @@ -138,7 +137,7 @@ def to_state_dict(pytree): return _rebuild_ts_specs(state_dict) -def cleanup_tmp_directories(directory: Path): +def cleanup_tmp_directories(directory: epath.PathLike): """Cleanup steps in `directory` with tmp files, as these are not finalized.""" directory = epath.Path(directory) if jax.process_index() == 0: @@ -154,7 +153,7 @@ def is_gcs_path(path: epath.Path): def get_save_directory(step: int, - directory: Path, + directory: epath.PathLike, name: Optional[str] = None, step_prefix: Optional[str] = None) -> epath.Path: """Returns the standardized path to a save directory for a single item.""" @@ -168,7 +167,7 @@ def get_save_directory(step: int, return result -def create_tmp_directory(final_dir: Path) -> epath.Path: +def create_tmp_directory(final_dir: epath.PathLike) -> epath.Path: """Creates a temporary directory for saving at the given path.""" # Share a timestamp across devices. final_dir = epath.Path(final_dir) @@ -208,7 +207,7 @@ def is_scalar(x): return isinstance(x, (int, float, np.number)) -def is_checkpoint_finalized(path: Path) -> bool: +def is_checkpoint_finalized(path: epath.PathLike) -> bool: """Determines if the checkpoint path is finalized. Path takes the form: @@ -239,7 +238,7 @@ def is_checkpoint_finalized(path: Path) -> bool: return True -def checkpoint_steps(checkpoint_dir: Path) -> List[int]: +def checkpoint_steps(checkpoint_dir: epath.PathLike) -> List[int]: checkpoint_dir = epath.Path(checkpoint_dir) return [ int(os.fspath(s.name)) @@ -248,7 +247,7 @@ def checkpoint_steps(checkpoint_dir: Path) -> List[int]: ] -def tmp_checkpoints(checkpoint_dir: Path) -> List[str]: +def tmp_checkpoints(checkpoint_dir: epath.PathLike) -> List[str]: checkpoint_dir = epath.Path(checkpoint_dir) return [ s.name for s in checkpoint_dir.iterdir() if not is_checkpoint_finalized(s) @@ -287,7 +286,7 @@ def _wait_for_new_checkpoint(checkpoint_dir: epath.Path, return checkpoint_step -def checkpoints_iterator(checkpoint_dir: Path, +def checkpoints_iterator(checkpoint_dir: epath.PathLike, min_interval_secs=0, timeout=None, timeout_fn=None) -> Iterator[int]: diff --git a/setup.py b/setup.py index 9baaf381..4b01345d 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ 'importlib_resources', 'etils', 'flax', + 'importlib_resources', 'jax', 'jaxlib', 'numpy',