diff --git a/orbax/__init__.py b/orbax/__init__.py index 328c2caa..138f6075 100644 --- a/orbax/__init__.py +++ b/orbax/__init__.py @@ -15,4 +15,4 @@ """Orbax API.""" # A new PyPI release will be pushed everytime `__version__` is increased. -__version__ = '0.0.18' +__version__ = '0.0.19' diff --git a/orbax/checkpoint/checkpoint_manager.py b/orbax/checkpoint/checkpoint_manager.py index 98332416..de37ac6a 100644 --- a/orbax/checkpoint/checkpoint_manager.py +++ b/orbax/checkpoint/checkpoint_manager.py @@ -78,6 +78,8 @@ class CheckpointManagerOptions: function. best_mode: one of ['max', 'min']. The best metric is determine on the basis of this value. + keep_checkpoints_without_metrics: If False, checkpoints with metrics present + are eligible for cleanup. Otherwise, they will never be deleted. step_prefix: if provided, step directories will take the form f'{step_prefix}_'. Otherwise, they will simply be an integer . @@ -88,8 +90,15 @@ class CheckpointManagerOptions: keep_period: Optional[int] = None best_fn: Optional[Callable[[PyTree], float]] = None best_mode: str = 'max' + keep_checkpoints_without_metrics: bool = True step_prefix: Optional[str] = None + def __post_init__(self): + if self.best_mode not in ('min', 'max'): + msg = ("`CheckpointManagerOptions.best_mode` must be one of None, 'min' " + "or 'max'. Got {self.dtype}.") + raise ValueError(msg) + @dataclasses.dataclass class CheckpointInfo: @@ -213,6 +222,8 @@ def best_step(self) -> Optional[int]: if not self._checkpoints: return None _, sorted_checkpoints = self._sort_checkpoints_by_metrics(self._checkpoints) + if not sorted_checkpoints: + return None return sorted_checkpoints[-1].step def should_save(self, step: int) -> bool: @@ -584,7 +595,7 @@ def get_metrics(step): for s, t, m in zip(steps, times, metrics) ] - def _add_checkpoint_info(self, step, metrics): + def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]): self._checkpoints.append( CheckpointInfo(step, datetime.datetime.now(tz=datetime.timezone.utc), metrics)) @@ -636,8 +647,12 @@ def _delete_directory(self, step: int): def _remove_old_checkpoints(self): """Keeps the `max_to_keep` most recent checkpoint steps.""" + # Must have set max_to_keep or keep_time_interval. if not self._options.max_to_keep and not self._options.keep_time_interval: return + # Not enough checkpoints accumulated to consider deletion. + if len(self._checkpoints) <= self._options.max_to_keep: + return if self._track_best: # Best steps (to keep) are at the end, after sorting. checkpoints_without_metrics, sorted_checkpoints = self._sort_checkpoints_by_metrics( @@ -647,12 +662,15 @@ def _remove_old_checkpoints(self): checkpoints_without_metrics = [] sorted_checkpoints = self._checkpoints - to_remove = len(sorted_checkpoints) - self._options.max_to_keep - if to_remove <= 0: - return - maybe_delete = sorted_checkpoints[:to_remove] - active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[ - to_remove:] + keep = int(self._options.max_to_keep) + if self._options.keep_checkpoints_without_metrics: + maybe_delete = sorted_checkpoints[:-keep] + active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[ + -keep:] + else: + all_checkpoints = checkpoints_without_metrics + sorted_checkpoints + maybe_delete = all_checkpoints[:-keep] + active_checkpoints = all_checkpoints[-keep:] kept_checkpoints = [] for info in maybe_delete: