From ae6fadd1834bd72ca3ed50b7899af9e0b6bc9da4 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 31 Oct 2023 12:06:08 -0700 Subject: [PATCH] Add a YAML validator to automatically find the last checkpoint --- configs/mcli/v1_5-mix-medium-mitch-ish.yaml | 3 +- olmo/config.py | 14 ++++++ olmo/util.py | 50 +++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/configs/mcli/v1_5-mix-medium-mitch-ish.yaml b/configs/mcli/v1_5-mix-medium-mitch-ish.yaml index e15c494eb..8f8e5c493 100644 --- a/configs/mcli/v1_5-mix-medium-mitch-ish.yaml +++ b/configs/mcli/v1_5-mix-medium-mitch-ish.yaml @@ -29,4 +29,5 @@ command: |- scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \ --run_name=v1_5-mix-mitch-ish \ --wandb.name=v1_5-mix-mitch-ish-mcli \ - --global_train_batch_size=2160 + --global_train_batch_size=2160 \ + --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish}' diff --git a/olmo/config.py b/olmo/config.py index af3327dac..a5964ec0c 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -95,8 +95,22 @@ def path_choose(*paths) -> str: else: return "" + # Finds the latest checkpoint in a folder. 
# ---------------------------------------------------------------------------
# olmo/config.py
# ---------------------------------------------------------------------------
# NOTE(review): in the patch this resolver is added right after `path_choose`
# and reads `validate_paths`, which is not defined in the visible hunk --
# presumably it comes from the enclosing resolver-registration routine.
# Confirm the nesting/indentation when applying.
def path_last_checkpoint(path) -> str:
    """Resolve ``${path.last_checkpoint:<dir>}`` to the newest checkpoint under ``path``.

    Returns:
        The checkpoint location as a string, or ``""`` when nothing is found
        and ``validate_paths`` is falsy.

    Raises:
        FileNotFoundError: no checkpoint was found and ``validate_paths`` is truthy.
    """
    from .util import find_latest_checkpoint

    latest_checkpoint = find_latest_checkpoint(path)
    if latest_checkpoint is not None:
        return str(latest_checkpoint)
    if validate_paths:
        raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
    return ""


om.register_new_resolver("path.glob", path_glob, replace=True)
om.register_new_resolver("path.choose", path_choose, replace=True)
om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)


# ---------------------------------------------------------------------------
# olmo/util.py
# ---------------------------------------------------------------------------
def _checkpoint_step(name: str) -> Optional[int]:
    """Return the step encoded in a ``step<N>`` / ``step<N>-unsharded`` name, else ``None``."""
    if not name.startswith("step"):
        return None
    try:
        return int(name.replace("step", "").replace("-unsharded", ""))
    except ValueError:
        return None


def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
    """Find the checkpoint with the highest step number directly under ``dir``.

    Checkpoints are sub-directories named ``step<N>`` or ``step<N>-unsharded``.
    ``dir`` may be a local path, a ``file://`` URL or an ``s3://`` URL.

    Returns:
        The path/URL of the latest checkpoint, or ``None`` if there is none.

    Raises:
        NotImplementedError: for ``gs://`` and any other unsupported URL scheme.
    """
    if is_url(dir):
        from urllib.parse import urlparse

        parsed = urlparse(str(dir))
        if parsed.scheme == "s3":
            return _s3_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
        if parsed.scheme == "file":
            return find_latest_checkpoint(str(dir).replace("file://", "", 1))
        # Covers "gs" as well as unknown schemes; the message names the actual
        # operation instead of the copy-pasted "file size".
        raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")

    # Start below zero so that a "step0" checkpoint can be selected.
    latest_step = -1
    latest_checkpoint: Optional[Path] = None
    for path in Path(dir).glob("step*"):
        if not path.is_dir():
            continue
        step = _checkpoint_step(path.name)
        if step is not None and step > latest_step:
            latest_step = step
            latest_checkpoint = path
    return latest_checkpoint


def _s3_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
    """Find the checkpoint "folder" with the highest step number under an S3 prefix.

    Args:
        bucket_name: bucket to list.
        prefix: key prefix of the run folder, with or without a trailing ``/``.

    Returns:
        ``s3://<bucket_name>/<checkpoint prefix>`` for the latest checkpoint,
        or ``None`` if no ``step*`` folder exists under the prefix.
    """
    # An empty prefix means the bucket root; turning it into "/" would match nothing.
    if prefix and not prefix.endswith("/"):
        prefix = f"{prefix}/"

    # Paginate rather than assert on "IsTruncated": asserts vanish under -O and
    # a truncated first page could hide the newest checkpoint.
    paginator = _get_s3_client().get_paginator("list_objects")

    latest_step = -1  # below zero so that "step0" can be selected
    latest_checkpoint: Optional[str] = None
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter="/"):
        # S3 omits "CommonPrefixes" entirely when there are no sub-folders.
        for item in page.get("CommonPrefixes", []):
            checkpoint_prefix = item["Prefix"].strip("/")
            step = _checkpoint_step(os.path.split(checkpoint_prefix)[-1])
            if step is not None and step > latest_step:
                latest_step = step
                # Use the bucket we actually listed, not a hard-coded one.
                latest_checkpoint = f"s3://{bucket_name}/{checkpoint_prefix}"
    return latest_checkpoint