Skip to content

Commit

Permalink
PathLike pytype cleanup
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 487305127
  • Loading branch information
Orbax Authors authored and copybara-github committed Nov 9, 2022
1 parent 7bba53b commit 4d83dfb
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 21 deletions.
8 changes: 4 additions & 4 deletions orbax/checkpoint/abstract_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""AbstractCheckpointer."""

import abc
from typing import Any, Optional, Union
from typing import Any, Optional
from etils import epath


Expand All @@ -30,7 +30,7 @@ class AbstractCheckpointer(abc.ABC):
"""

@abc.abstractmethod
def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs):
def save(self, directory: epath.PathLike, item: Any, *args, **kwargs):
"""Saves the given item to the provided directory.
Args:
Expand All @@ -44,7 +44,7 @@ def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs):

@abc.abstractmethod
def restore(self,
directory: Union[str, epath.Path],
directory: epath.PathLike,
*args,
item: Optional[Any] = None,
**kwargs) -> Any:
Expand All @@ -66,7 +66,7 @@ def restore(self,
pass

@abc.abstractmethod
def structure(self, directory: Union[str, epath.Path]) -> Optional[Any]:
def structure(self, directory: epath.PathLike) -> Optional[Any]:
"""The structure of the saved object at `directory`.
Delegates to underlying handler.
Expand Down
6 changes: 3 additions & 3 deletions orbax/checkpoint/async_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import asyncio
import functools
from typing import Any, Union, Optional
from typing import Any, Optional

from absl import logging
from etils import epath
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(self, handler: AsyncCheckpointHandler, timeout_secs: int = 300):
self._handler = handler
AsyncManager.__init__(self, timeout_secs=timeout_secs)

def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs):
def save(self, directory: epath.PathLike, item: Any, *args, **kwargs):
"""Saves the given item to the provided directory.
Delegates to the underlying CheckpointHandler. Ensures save operation
Expand Down Expand Up @@ -82,7 +82,7 @@ def save(self, directory: Union[str, epath.Path], item: Any, *args, **kwargs):
functools.partial(utils.ensure_atomic_save, tmpdir, directory))

def restore(self,
directory: Union[str, epath.Path],
directory: epath.PathLike,
*args,
item: Optional[Any] = None,
**kwargs) -> Any:
Expand Down
8 changes: 4 additions & 4 deletions orbax/checkpoint/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Synchronous Checkpointer implementation."""

from typing import Any, Optional, Union
from typing import Any, Optional

from absl import logging
from etils import epath
Expand All @@ -36,7 +36,7 @@ def __init__(self, handler: CheckpointHandler):
self._handler = handler

def save(self,
directory: Union[str, epath.Path],
directory: epath.PathLike,
item: Any,
*args,
force: bool = False,
Expand Down Expand Up @@ -76,7 +76,7 @@ def save(self,
multihost_utils.sync_global_devices('Checkpointer:save')

def restore(self,
directory: Union[str, epath.Path],
directory: epath.PathLike,
*args,
item: Optional[Any] = None,
**kwargs) -> Any:
Expand All @@ -89,7 +89,7 @@ def restore(self,
logging.info('Restoring item from %s.', directory)
return self._handler.restore(directory, *args, item=item, **kwargs)

def structure(self, directory: Union[str, epath.Path]) -> Optional[Any]:
def structure(self, directory: epath.PathLike) -> Optional[Any]:
"""See superclass documentation."""
directory = epath.Path(directory)
try:
Expand Down
19 changes: 9 additions & 10 deletions orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import os
import time
from typing import Iterator, List, Optional, Tuple, Union
from typing import Iterator, List, Optional, Tuple

from absl import logging
from etils import epath
Expand All @@ -32,7 +32,6 @@
_GCS_PATH_PREFIX = 'gs://'
CheckpointDirs = Tuple[str, str]
PyTree = type(jax.tree_util.tree_structure(None))
Path = Union[str, epath.Path]


def _wrap(func):
Expand Down Expand Up @@ -86,7 +85,7 @@ def rmtree(path: epath.Path):
Leaf = str


def pytree_structure(directory: Path) -> PyTree:
def pytree_structure(directory: epath.PathLike) -> PyTree:
"""Reconstruct state dict from saved model format in `directory`."""
directory = epath.Path(directory)

Expand Down Expand Up @@ -138,7 +137,7 @@ def to_state_dict(pytree):
return _rebuild_ts_specs(state_dict)


def cleanup_tmp_directories(directory: Path):
def cleanup_tmp_directories(directory: epath.PathLike):
"""Cleanup steps in `directory` with tmp files, as these are not finalized."""
directory = epath.Path(directory)
if jax.process_index() == 0:
Expand All @@ -154,7 +153,7 @@ def is_gcs_path(path: epath.Path):


def get_save_directory(step: int,
directory: Path,
directory: epath.PathLike,
name: Optional[str] = None,
step_prefix: Optional[str] = None) -> epath.Path:
"""Returns the standardized path to a save directory for a single item."""
Expand All @@ -168,7 +167,7 @@ def get_save_directory(step: int,
return result


def create_tmp_directory(final_dir: Path) -> epath.Path:
def create_tmp_directory(final_dir: epath.PathLike) -> epath.Path:
"""Creates a temporary directory for saving at the given path."""
# Share a timestamp across devices.
final_dir = epath.Path(final_dir)
Expand Down Expand Up @@ -208,7 +207,7 @@ def is_scalar(x):
return isinstance(x, (int, float, np.number))


def is_checkpoint_finalized(path: Path) -> bool:
def is_checkpoint_finalized(path: epath.PathLike) -> bool:
"""Determines if the checkpoint path is finalized.
Path takes the form:
Expand Down Expand Up @@ -239,7 +238,7 @@ def is_checkpoint_finalized(path: Path) -> bool:
return True


def checkpoint_steps(checkpoint_dir: Path) -> List[int]:
def checkpoint_steps(checkpoint_dir: epath.PathLike) -> List[int]:
checkpoint_dir = epath.Path(checkpoint_dir)
return [
int(os.fspath(s.name))
Expand All @@ -248,7 +247,7 @@ def checkpoint_steps(checkpoint_dir: Path) -> List[int]:
]


def tmp_checkpoints(checkpoint_dir: Path) -> List[str]:
def tmp_checkpoints(checkpoint_dir: epath.PathLike) -> List[str]:
checkpoint_dir = epath.Path(checkpoint_dir)
return [
s.name for s in checkpoint_dir.iterdir() if not is_checkpoint_finalized(s)
Expand Down Expand Up @@ -287,7 +286,7 @@ def _wait_for_new_checkpoint(checkpoint_dir: epath.Path,
return checkpoint_step


def checkpoints_iterator(checkpoint_dir: Path,
def checkpoints_iterator(checkpoint_dir: epath.PathLike,
min_interval_secs=0,
timeout=None,
timeout_fn=None) -> Iterator[int]:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
'importlib_resources',
'etils',
'flax',
'importlib_resources',
'jax',
'jaxlib',
'numpy',
Expand Down

0 comments on commit 4d83dfb

Please sign in to comment.