From 3a38d9a84c9d34f013503fc5feaa7468f4045434 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 31 Oct 2023 13:51:25 -0700 Subject: [PATCH] make deterministic, prioritize sharded --- olmo/util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index baadca9be..8b7da6321 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -488,7 +488,8 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]: else: latest_step = 0 latest_checkpoint: Optional[Path] = None - for path in Path(dir).glob("step*"): + # Sorting here guarantees that we prioritize sharded checkpoints over unsharded checkpoints. + for path in sorted(Path(dir).glob("step*")): if path.is_dir(): try: step = int(path.name.replace("step", "").replace("-unsharded", "")) @@ -649,7 +650,8 @@ def _s3_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]: assert not response["IsTruncated"] # need to handle this if it happens latest_step = 0 latest_checkpoint: Optional[str] = None - for item in response["CommonPrefixes"]: + # Sorting here guarantees that we prioritize sharded checkpoints over unsharded checkpoints. + for item in sorted(response["CommonPrefixes"], key=lambda x: x["Prefix"]): prefix = item["Prefix"].strip("/") checkpoint_name = os.path.split(prefix)[-1] if not checkpoint_name.startswith("step"):