Merge remote-tracking branch 'origin/main' into Llama
dirkgr committed Nov 1, 2023
2 parents 581538c + fd2425f commit d88e02d
Showing 4 changed files with 84 additions and 4 deletions.
3 changes: 2 additions & 1 deletion configs/mcli/v1_5-mix-medium-mitch-ish.yaml
@@ -29,4 +29,5 @@ command: |-
   scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
     --run_name=v1_5-mix-mitch-ish \
     --wandb.name=v1_5-mix-mitch-ish-mcli \
-    --global_train_batch_size=2160
+    --global_train_batch_size=2160 \
+    --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish}'
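
The new `--load_path` flag leans on the `path.last_checkpoint` resolver introduced in `olmo/config.py` below. As a rough sketch of what it resolves to (the step number is hypothetical, and listing S3 requires credentials), the interpolation is roughly equivalent to calling the new helper directly:

```python
# Rough equivalent of '${path.last_checkpoint:s3://...}' (illustrative only;
# the real resolution happens inside the OmegaConf resolver when the YAML loads).
from olmo.util import find_latest_checkpoint

load_path = find_latest_checkpoint("s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish")
print(load_path)  # e.g. s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish/step556000
```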
22 changes: 22 additions & 0 deletions olmo/config.py
@@ -96,8 +96,22 @@ def path_choose(*paths) -> str:
             else:
                 return ""
 
+        # Finds the latest checkpoint in a folder.
+        def path_last_checkpoint(path) -> str:
+            from .util import find_latest_checkpoint
+
+            latest_checkpoint = find_latest_checkpoint(path)
+            if latest_checkpoint is None:
+                if validate_paths:
+                    raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
+                else:
+                    return ""
+            else:
+                return str(latest_checkpoint)
+
         om.register_new_resolver("path.glob", path_glob, replace=True)
         om.register_new_resolver("path.choose", path_choose, replace=True)
+        om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
 
     @classmethod
     def new(cls: Type[C], **kwargs) -> C:
@@ -730,6 +744,14 @@ class TrainConfig(BaseConfig):
     load_path: Optional[str] = None
     """
     The path to a training checkpoint to restore/resume from.
+
+    Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes
+    a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
+    For example,
+
+    ```bash
+    --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
+    ```
     """
 
     load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
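
To see the resolver wiring in action, here is a minimal self-contained sketch (the toy resolver and paths are hypothetical stand-ins; the real resolver is registered by `BaseConfig` exactly as in the hunk above):

```python
from pathlib import Path
from omegaconf import OmegaConf as om

# Toy stand-in for path_last_checkpoint; assumes the folder exists and
# contains directories named stepNNN or stepNNN-unsharded.
def toy_last_checkpoint(path: str) -> str:
    def step_of(p: Path) -> int:
        return int(p.name.replace("step", "").replace("-unsharded", ""))

    candidates = [p for p in Path(path).glob("step*") if p.is_dir()]
    return str(max(candidates, key=step_of)) if candidates else ""

# Same registration call the config code uses, so the interpolation below resolves.
om.register_new_resolver("path.last_checkpoint", toy_last_checkpoint, replace=True)

cfg = om.create({"load_path": "${path.last_checkpoint:/tmp/run-001}"})
print(cfg.load_path)  # e.g. "/tmp/run-001/step2000"
```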
11 changes: 8 additions & 3 deletions olmo/train.py
@@ -43,6 +43,7 @@
     move_to_device,
     peak_gpu_memory,
     syncronize_flag,
+    upload,
 )
 
 __all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]
@@ -737,9 +738,13 @@ def on_trace_ready(p):
                 output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=32)
                 log.info(f"Profile by total CPU time at step {p.step_num}:\n{output}")
 
-                p.export_chrome_trace(str(profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
-                p.export_stacks(str(profiler_output_dir / f"{p.step_num}.gpu.stacks"), "self_cuda_time_total")
-                p.export_stacks(str(profiler_output_dir / f"{p.step_num}.cpu.stacks"), "self_cpu_time_total")
+                p.export_chrome_trace(
+                    str(trace_path := (profiler_output_dir / f"{p.step_num}.chrome_trace.json.gz"))
+                )
+                if self.cfg.remote_save_folder is not None:
+                    upload_folder = f"{self.cfg.remote_save_folder.rstrip('/')}/profiler"
+                    log.info(f"Tracing complete, uploading results to '{upload_folder}'...")
+                    upload(trace_path, f"{upload_folder}/{trace_path.name}")
 
             from torch.profiler import ProfilerActivity

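
To make the new upload step concrete, a small illustration of how the remote destination key is built, mirroring the code above (the folder and step values are made up):

```python
from pathlib import Path

remote_save_folder = "s3://my-bucket/my-run/"           # hypothetical cfg.remote_save_folder
trace_path = Path("profiler/420.chrome_trace.json.gz")  # hypothetical exported trace

# Same key construction as in on_trace_ready above.
upload_folder = f"{remote_save_folder.rstrip('/')}/profiler"
print(f"{upload_folder}/{trace_path.name}")
# -> s3://my-bucket/my-run/profiler/420.chrome_trace.json.gz
```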
52 changes: 52 additions & 0 deletions olmo/util.py
@@ -471,6 +471,35 @@ def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
         return f.read(num_bytes)
 
 
+def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
+    if is_url(dir):
+        from urllib.parse import urlparse
+
+        parsed = urlparse(str(dir))
+        if parsed.scheme == "gs":
+            raise NotImplementedError
+        elif parsed.scheme == "s3":
+            return _s3_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
+        elif parsed.scheme == "file":
+            return find_latest_checkpoint(str(dir).replace("file://", "", 1))
+        else:
+            raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
+    else:
+        latest_step = 0
+        latest_checkpoint: Optional[Path] = None
+        for path in Path(dir).glob("step*"):
+            if path.is_dir():
+                try:
+                    step = int(path.name.replace("step", "").replace("-unsharded", ""))
+                except ValueError:
+                    continue
+                # We prioritize sharded checkpoints over unsharded checkpoints.
+                if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
+                    latest_step = step
+                    latest_checkpoint = path
+        return latest_checkpoint
+
+
 def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
     from google.cloud import storage as gcs

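A quick usage sketch for the local branch of `find_latest_checkpoint` (the run folder and its contents are hypothetical):

```python
from olmo.util import find_latest_checkpoint

# Suppose /data/checkpoints/run-001 contains:
#   step1000/  step1000-unsharded/  step2000/
# step2000 wins: it has the highest step, and at equal step counts a sharded
# checkpoint is preferred over its "-unsharded" twin.
print(find_latest_checkpoint("/data/checkpoints/run-001"))
# e.g. /data/checkpoints/run-001/step2000
```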
@@ -613,6 +642,29 @@ def _s3_get_bytes_range(
         raise OlmoNetworkError("Failed to get bytes range from s3") from err
 
 
+def _s3_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
+    if not prefix.endswith("/"):
+        prefix = f"{prefix}/"
+    response = _get_s3_client().list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
+    assert not response["IsTruncated"]  # need to handle this if it happens
+    latest_step = 0
+    latest_checkpoint: Optional[str] = None
+    for item in response["CommonPrefixes"]:
+        checkpoint_prefix = item["Prefix"].strip("/")
+        checkpoint_name = os.path.split(checkpoint_prefix)[-1]
+        if not checkpoint_name.startswith("step"):
+            continue
+        try:
+            step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
+        except ValueError:
+            continue
+        # We prioritize sharded checkpoints over unsharded ones.
+        if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
+            latest_step = step
+            latest_checkpoint = f"s3://{bucket_name}/{checkpoint_prefix}"
+    return latest_checkpoint
+
+
 def default_thread_count() -> int:
     return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))

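
For reference, a minimal sketch of the S3 listing that `_s3_find_latest_checkpoint` relies on (the bucket and prefix are hypothetical): with `Delimiter="/"`, `list_objects` groups each immediate "subdirectory" under the prefix into one `CommonPrefixes` entry.

```python
import boto3

client = boto3.client("s3")
response = client.list_objects(Bucket="my-bucket", Prefix="checkpoints/run-001/", Delimiter="/")
for item in response.get("CommonPrefixes", []):
    print(item["Prefix"])
# e.g.:
#   checkpoints/run-001/step1000/
#   checkpoints/run-001/step1000-unsharded/
#   checkpoints/run-001/step2000/
```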
