diff --git a/checkpoint/orbax/checkpoint/path/step.py b/checkpoint/orbax/checkpoint/path/step.py index 2385c74f..96f2e708 100644 --- a/checkpoint/orbax/checkpoint/path/step.py +++ b/checkpoint/orbax/checkpoint/path/step.py @@ -185,9 +185,12 @@ def latest_step_metadata( root_path: epath.PathLike, name_format: NameFormat[MetadataT] ) -> Optional[MetadataT]: """Returns step.MetadataT of the latest step in `root_path`.""" - root_path = epath.Path(root_path) return max( - name_format.find_all(root_path), + sorted( + name_format.find_all(root_path), + key=lambda metadata: metadata.path.name, + reverse=True, + ), default=None, key=lambda metadata: metadata.step, )