diff --git a/configs/mcli/v1_5-mix-medium-mitch-ish.yaml b/configs/mcli/v1_5-mix-medium-mitch-ish.yaml
index e15c494eb..8f8e5c493 100644
--- a/configs/mcli/v1_5-mix-medium-mitch-ish.yaml
+++ b/configs/mcli/v1_5-mix-medium-mitch-ish.yaml
@@ -29,4 +29,5 @@ command: |-
   scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
     --run_name=v1_5-mix-mitch-ish \
     --wandb.name=v1_5-mix-mitch-ish-mcli \
-    --global_train_batch_size=2160
+    --global_train_batch_size=2160 \
+    --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish}'
diff --git a/olmo/config.py b/olmo/config.py
index af3327dac..6a8ae2001 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -95,8 +95,22 @@ def path_choose(*paths) -> str:
             else:
                 return ""
 
+        # Finds the latest checkpoint in a folder.
+        def path_last_checkpoint(path) -> str:
+            from .util import find_latest_checkpoint
+
+            latest_checkpoint = find_latest_checkpoint(path)
+            if latest_checkpoint is None:
+                if validate_paths:
+                    raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
+                else:
+                    return ""
+            else:
+                return str(latest_checkpoint)
+
         om.register_new_resolver("path.glob", path_glob, replace=True)
         om.register_new_resolver("path.choose", path_choose, replace=True)
+        om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
 
     @classmethod
     def new(cls: Type[C], **kwargs) -> C:
@@ -720,6 +734,14 @@ class TrainConfig(BaseConfig):
     load_path: Optional[str] = None
     """
     The path to a training checkpoint to restore/resume from.
+
+    Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes
+    a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
+    For example,
+
+    ```bash
+    --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
+    ```
     """
 
     load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
diff --git a/olmo/util.py b/olmo/util.py
index 2b0649021..331274b0b 100644
--- a/olmo/util.py
+++ b/olmo/util.py
@@ -472,6 +472,40 @@ def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> byte
             return f.read(num_bytes)
 
 
+def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
+    """
+    Find the latest checkpoint, i.e. the ``step{N}`` or ``step{N}-unsharded`` sub-directory with the
+    highest ``N``, in a local directory or an ``s3://`` / ``file://`` URL. Returns ``None`` if none exists.
+    """
+    if is_url(dir):
+        from urllib.parse import urlparse
+
+        parsed = urlparse(str(dir))
+        if parsed.scheme == "gs":
+            raise NotImplementedError
+        elif parsed.scheme == "s3":
+            return _s3_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
+        elif parsed.scheme == "file":
+            return find_latest_checkpoint(str(dir).replace("file://", "", 1))
+        else:
+            raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
+    else:
+        # Start below zero so that a lone "step0-unsharded" checkpoint is still found.
+        latest_step = -1
+        latest_checkpoint: Optional[Path] = None
+        for path in Path(dir).glob("step*"):
+            if path.is_dir():
+                try:
+                    step = int(path.name.replace("step", "").replace("-unsharded", ""))
+                except ValueError:
+                    continue
+                # We prioritize sharded checkpoints over unsharded checkpoints.
+                if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
+                    latest_step = step
+                    latest_checkpoint = path
+        return latest_checkpoint
+
+
 def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
     from google.cloud import storage as gcs
 
@@ -614,6 +648,35 @@ def _s3_get_bytes_range(
     raise OlmoNetworkError("Failed to get bytes range from s3") from err
 
 
+def _s3_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
+    """
+    S3 implementation of :func:`find_latest_checkpoint`. Returns the ``s3://`` URL of the latest checkpoint
+    "folder" directly under ``prefix`` in ``bucket_name``, or ``None`` if there is none.
+    """
+    if not prefix.endswith("/"):
+        prefix = f"{prefix}/"
+    response = _get_s3_client().list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
+    assert not response["IsTruncated"]  # need to handle this if it happens
+    # Start below zero so that a lone "step0-unsharded" checkpoint is still found.
+    latest_step = -1
+    latest_checkpoint: Optional[str] = None
+    # "CommonPrefixes" may be missing from the response when nothing matches.
+    for item in response.get("CommonPrefixes", []):
+        checkpoint_prefix = item["Prefix"].strip("/")
+        checkpoint_name = os.path.split(checkpoint_prefix)[-1]
+        if not checkpoint_name.startswith("step"):
+            continue
+        try:
+            step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
+        except ValueError:
+            continue
+        # We prioritize sharded checkpoints over unsharded ones.
+        if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
+            latest_step = step
+            latest_checkpoint = f"s3://{bucket_name}/{checkpoint_prefix}"
+    return latest_checkpoint
+
+
 def is_weight_decay_module(module: nn.Module) -> bool:
     """Returns true if the module should use weight decay."""
     from .model import LayerNormBase