Skip to content

Commit

Permalink
Add support for tracking checkpoint metrics with Orbax in T5X.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 492213818
  • Loading branch information
cpgaffney1 authored and copybara-github committed Dec 1, 2022
1 parent 42a9dcc commit 1f0d7b0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
2 changes: 1 addition & 1 deletion orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"""Orbax API."""

# A new PyPI release will be pushed everytime `__version__` is increased.
__version__ = '0.0.18'
__version__ = '0.0.19'
32 changes: 25 additions & 7 deletions orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class CheckpointManagerOptions:
function.
best_mode: one of ['max', 'min']. The best metric is determine on the basis of
this value.
keep_checkpoints_without_metrics: If False, checkpoints with metrics present
are eligible for cleanup. Otherwise, they will never be deleted.
step_prefix: if provided, step directories will take the form
f'{step_prefix}_<step>'. Otherwise, they will simply be an integer <step>.
Expand All @@ -88,8 +90,15 @@ class CheckpointManagerOptions:
keep_period: Optional[int] = None
best_fn: Optional[Callable[[PyTree], float]] = None
best_mode: str = 'max'
keep_checkpoints_without_metrics: bool = True
step_prefix: Optional[str] = None

def __post_init__(self):
if self.best_mode not in ('min', 'max'):
msg = ("`CheckpointManagerOptions.best_mode` must be one of None, 'min' "
"or 'max'. Got {self.dtype}.")
raise ValueError(msg)


@dataclasses.dataclass
class CheckpointInfo:
Expand Down Expand Up @@ -213,6 +222,8 @@ def best_step(self) -> Optional[int]:
if not self._checkpoints:
return None
_, sorted_checkpoints = self._sort_checkpoints_by_metrics(self._checkpoints)
if not sorted_checkpoints:
return None
return sorted_checkpoints[-1].step

def should_save(self, step: int) -> bool:
Expand Down Expand Up @@ -584,7 +595,7 @@ def get_metrics(step):
for s, t, m in zip(steps, times, metrics)
]

def _add_checkpoint_info(self, step, metrics):
def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]):
self._checkpoints.append(
CheckpointInfo(step, datetime.datetime.now(tz=datetime.timezone.utc),
metrics))
Expand Down Expand Up @@ -636,8 +647,12 @@ def _delete_directory(self, step: int):

def _remove_old_checkpoints(self):
"""Keeps the `max_to_keep` most recent checkpoint steps."""
# Must have set max_to_keep or keep_time_interval.
if not self._options.max_to_keep and not self._options.keep_time_interval:
return
# Not enough checkpoints accumulated to consider deletion.
if len(self._checkpoints) <= self._options.max_to_keep:
return
if self._track_best:
# Best steps (to keep) are at the end, after sorting.
checkpoints_without_metrics, sorted_checkpoints = self._sort_checkpoints_by_metrics(
Expand All @@ -647,12 +662,15 @@ def _remove_old_checkpoints(self):
checkpoints_without_metrics = []
sorted_checkpoints = self._checkpoints

to_remove = len(sorted_checkpoints) - self._options.max_to_keep
if to_remove <= 0:
return
maybe_delete = sorted_checkpoints[:to_remove]
active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[
to_remove:]
keep = int(self._options.max_to_keep)
if self._options.keep_checkpoints_without_metrics:
maybe_delete = sorted_checkpoints[:-keep]
active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[
-keep:]
else:
all_checkpoints = checkpoints_without_metrics + sorted_checkpoints
maybe_delete = all_checkpoints[:-keep]
active_checkpoints = all_checkpoints[-keep:]

kept_checkpoints = []
for info in maybe_delete:
Expand Down

0 comments on commit 1f0d7b0

Please sign in to comment.