Skip to content

Commit

Permalink
Adds a YAML validator to automatically find the last checkpoint (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh authored Oct 31, 2023
1 parent 1099942 commit fd2425f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 1 deletion.
3 changes: 2 additions & 1 deletion configs/mcli/v1_5-mix-medium-mitch-ish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ command: |-
scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
--run_name=v1_5-mix-mitch-ish \
--wandb.name=v1_5-mix-mitch-ish-mcli \
--global_train_batch_size=2160
--global_train_batch_size=2160 \
--load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish}'
22 changes: 22 additions & 0 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,22 @@ def path_choose(*paths) -> str:
else:
return ""

# Finds the latest checkpoint in a folder.
def path_last_checkpoint(path) -> str:
from .util import find_latest_checkpoint

latest_checkpoint = find_latest_checkpoint(path)
if latest_checkpoint is None:
if validate_paths:
raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
else:
return ""
else:
return str(latest_checkpoint)

om.register_new_resolver("path.glob", path_glob, replace=True)
om.register_new_resolver("path.choose", path_choose, replace=True)
om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)

@classmethod
def new(cls: Type[C], **kwargs) -> C:
Expand Down Expand Up @@ -720,6 +734,14 @@ class TrainConfig(BaseConfig):
load_path: Optional[str] = None
"""
The path to a training checkpoint to restore/resume from.
Note that you can make use of the "path.last_checkpoint" Omegaconfig YAML resolver here, which takes
a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
For example,
```bash
--load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
```
"""

load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
Expand Down
52 changes: 52 additions & 0 deletions olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,35 @@ def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> byte
return f.read(num_bytes)


def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
if is_url(dir):
from urllib.parse import urlparse

parsed = urlparse(str(dir))
if parsed.scheme == "gs":
raise NotImplementedError
elif parsed.scheme == "s3":
return _s3_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme == "file":
return find_latest_checkpoint(str(dir).replace("file://", "", 1))
else:
raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
else:
latest_step = 0
latest_checkpoint: Optional[Path] = None
for path in Path(dir).glob("step*"):
if path.is_dir():
try:
step = int(path.name.replace("step", "").replace("-unsharded", ""))
except ValueError:
continue
# We prioritize sharded checkpoints over unsharded checkpoints.
if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
latest_step = step
latest_checkpoint = path
return latest_checkpoint


def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
from google.cloud import storage as gcs

Expand Down Expand Up @@ -614,6 +643,29 @@ def _s3_get_bytes_range(
raise OlmoNetworkError("Failed to get bytes range from s3") from err


def _s3_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
if not prefix.endswith("/"):
prefix = f"{prefix}/"
response = _get_s3_client().list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
assert not response["IsTruncated"] # need to handle this if it happens
latest_step = 0
latest_checkpoint: Optional[str] = None
for item in response["CommonPrefixes"]:
prefix = item["Prefix"].strip("/")
checkpoint_name = os.path.split(prefix)[-1]
if not checkpoint_name.startswith("step"):
continue
try:
step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
except ValueError:
continue
# We prioritize sharded checkpoints over unsharded ones.
if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
latest_step = step
latest_checkpoint = f"s3://ai2-llm/{prefix}"
return latest_checkpoint


def is_weight_decay_module(module: nn.Module) -> bool:
"""Returns true if the module should use weight decay."""
from .model import LayerNormBase
Expand Down

0 comments on commit fd2425f

Please sign in to comment.