Save the hyperparameters to the checkpoint (#1012)
awaelchli committed Mar 15, 2024
1 parent 03161f3 commit 9ffb47f
Showing 12 changed files with 88 additions and 26 deletions.
11 changes: 8 additions & 3 deletions finetune/adapter.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )


@@ -143,8 +144,10 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_adapter_checkpoint(fabric, model, save_path)
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -223,7 +226,9 @@ def fit(
             checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_checkpoint(fabric, model, checkpoint_file)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # the adapter "kv cache" cannot be initialized under `inference_mode`
11 changes: 8 additions & 3 deletions finetune/adapter_v2.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )


@@ -143,8 +144,10 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_adapter_v2_checkpoint(fabric, model, save_path)
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -223,7 +226,9 @@ def fit(
             checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_v2_checkpoint(fabric, model, checkpoint_file)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # the adapter "kv cache" cannot be initialized under `inference_mode`
11 changes: 8 additions & 3 deletions finetune/full.py
@@ -34,6 +34,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )


@@ -137,8 +138,10 @@ def main(
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     fabric.save(save_path, {"model": state["model"]})
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -242,7 +245,9 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             fabric.print(f"Saving checkpoint to {str(checkpoint_file.parent)!r}")
             fabric.save(checkpoint_file, state)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # FSDP has issues with `inference_mode`
11 changes: 8 additions & 3 deletions finetune/lora.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )


@@ -174,8 +175,10 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_lora_checkpoint(fabric, model, save_path)
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -254,7 +257,9 @@ def fit(
             checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_lora_checkpoint(fabric, model, checkpoint_file)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # FSDP has issues with `inference_mode`
8 changes: 5 additions & 3 deletions lit_gpt/pretrain.py
@@ -23,7 +23,7 @@
 from lit_gpt.args import EvalArgs, TrainArgs
 from lit_gpt.data import LitDataModule, TinyLlama
 from lit_gpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
-from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files
+from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files, save_hyperparameters


 def setup(
@@ -269,8 +269,10 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
             fabric.save(checkpoint_file, state)
-            if tokenizer_dir is not None:
-                copy_config_files(tokenizer_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                save_hyperparameters(setup, checkpoint_file.parent)
+                if tokenizer_dir is not None:
+                    copy_config_files(tokenizer_dir, checkpoint_file.parent)


 @torch.no_grad()
9 changes: 9 additions & 0 deletions lit_gpt/utils.py
@@ -395,6 +395,15 @@ def CLI(*args: Any, **kwargs: Any) -> Any:
     return CLI(*args, **kwargs)


+def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
+    """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
+    from jsonargparse import capture_parser
+
+    parser = capture_parser(lambda: CLI(function))
+    config = parser.parse_args()
+    parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
+
+
 def parse_devices(devices: Union[str, int]) -> int:
     if devices in (-1, "auto"):
         return torch.cuda.device_count() or 1
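
As an illustration (not part of this commit): `save_hyperparameters` re-parses `sys.argv` with jsonargparse's `capture_parser` instead of calling the target function, so whatever command line launched the script is what ends up in `hyperparameters.yaml` next to the checkpoint. Below is a minimal sketch of reading that file back, assuming PyYAML and the `out_dir / "final"` layout used by the finetuning scripts; the `load_hyperparameters` helper is hypothetical.

from pathlib import Path

import yaml


def load_hyperparameters(checkpoint_dir: Path) -> dict:
    # Read back the CLI arguments that `save_hyperparameters` recorded next to a checkpoint.
    with open(checkpoint_dir / "hyperparameters.yaml", "r") as file:
        return yaml.full_load(file)


# The keys mirror the parameters of the script's `setup` function
# (e.g. `checkpoint_dir`, `out_dir`, `train`, `eval` for the finetuning scripts).
hparams = load_hyperparameters(Path("out/final"))
print(hparams)
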
5 changes: 3 additions & 2 deletions tests/test_adapter.py
@@ -68,7 +68,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)

     out_dir = tmp_path / "out"
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -93,6 +93,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -166,7 +167,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
     monkeypatch.setattr(module, "fit", train_mock)

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
5 changes: 3 additions & 2 deletions tests/test_adapter_v2.py
@@ -91,7 +91,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa

     out_dir = tmp_path / "out"
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -116,6 +116,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -255,7 +256,7 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
     monkeypatch.setattr(module, "fit", train_mock)

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
7 changes: 4 additions & 3 deletions tests/test_full.py
@@ -40,19 +40,20 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
         eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
     )
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]):
         module.setup(**setup_kwargs)

     out_dir_contents = set(os.listdir(out_dir))
     checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
     assert checkpoint_dirs.issubset(out_dir_contents)
     assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
     for checkpoint_dir in checkpoint_dirs:
-        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
+        assert set(os.listdir(out_dir / checkpoint_dir)) == {
             "lit_model.pth",
             "lit_config.json",
             "tokenizer_config.json",
             "tokenizer.json",
+            "hyperparameters.yaml",
         }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -65,7 +66,7 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
     setup_kwargs["train"].max_steps = 8
     setup_kwargs["resume"] = True
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]):
         module.setup(**setup_kwargs)
     logs = stdout.getvalue()
     assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs
5 changes: 3 additions & 2 deletions tests/test_lora.py
@@ -198,7 +198,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):

     out_dir = tmp_path / "out"
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -223,6 +223,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -624,7 +625,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
     monkeypatch.setattr(module, "fit", train_mock)

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
8 changes: 6 additions & 2 deletions tests/test_pretrain.py
@@ -14,7 +14,11 @@
 @RunIf(min_cuda_gpus=2, standalone=True)
 # Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
-def test_pretrain(tmp_path, monkeypatch):
+# If we were to use `save_hyperparameters()`, we would have to patch `sys.argv` or otherwise
+# the CLI would capture pytest args, but unfortunately patching would mess with subprocess
+# launching, so we need to mock `save_hyperparameters()`
+@mock.patch("lit_gpt.pretrain.save_hyperparameters")
+def test_pretrain(_, tmp_path):
     from lit_gpt import pretrain
     from lit_gpt.args import EvalArgs, TrainArgs
     from lit_gpt.config import Config
@@ -44,7 +48,7 @@ def test_pretrain(tmp_path, monkeypatch):
     assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
     for checkpoint_dir in checkpoint_dirs:
         # the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here
-        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {"lit_model.pth"}
+        assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth"}

     # logs only appear on rank 0
     logs = stdout.getvalue()
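
For context (an illustrative note, not part of the commit): because `save_hyperparameters` re-parses `sys.argv`, the single-process script tests above patch `sys.argv` to a fake command line, while this multi-GPU standalone test mocks the function away, since patching `sys.argv` here would interfere with how the standalone subprocesses are launched. A minimal sketch of the argv-patching approach, with a hypothetical script name:

import sys
from unittest import mock

# Under pytest, sys.argv is something like ["pytest", "tests/test_pretrain.py", ...],
# which the jsonargparse-based CLI would pick up as if they were the script's arguments.
# Patching in a plausible command line keeps `save_hyperparameters` usable in a test:
with mock.patch("sys.argv", ["pretrain.py"]):
    print(sys.argv)  # a call to the script's setup() here would record a default CLI invocation
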
23 changes: 23 additions & 0 deletions tests/test_utils.py
@@ -3,11 +3,14 @@
 import os
 from contextlib import redirect_stderr
 from io import StringIO
+from pathlib import Path
+from unittest import mock

 import pytest
 import torch
 import torch.nn.functional as F
+import yaml

 from conftest import RunIf
 from lightning import Fabric

@@ -222,3 +225,23 @@ def test_copy_config_files(fake_checkpoint_dir, tmp_path):
     }
     contents = set(os.listdir(tmp_path))
     assert expected.issubset(contents)
+
+
+def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
+    from lit_gpt.utils import save_hyperparameters
+
+    save_hyperparameters(_test_function, out_dir)
+
+
+def test_save_hyperparameters(tmp_path):
+    from lit_gpt.utils import CLI
+
+    with mock.patch("sys.argv", ["any.py", "--out_dir", str(tmp_path), "--foo", "True"]):
+        CLI(_test_function)
+
+    with open(tmp_path / "hyperparameters.yaml", "r") as file:
+        hparams = yaml.full_load(file)
+
+    assert hparams["out_dir"] == str(tmp_path)
+    assert hparams["foo"] is True
+    assert hparams["bar"] == 1
