From 9ffb47f344596733e0f77efcb0b4d18a3543dea1 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 6 Mar 2024 15:52:04 +0100
Subject: [PATCH] Save the hyperparameters to the checkpoint (#1012)

---
 finetune/adapter.py      | 11 ++++++++---
 finetune/adapter_v2.py   | 11 ++++++++---
 finetune/full.py         | 11 ++++++++---
 finetune/lora.py         | 11 ++++++++---
 lit_gpt/pretrain.py      |  8 +++++---
 lit_gpt/utils.py         |  9 +++++++++
 tests/test_adapter.py    |  5 +++--
 tests/test_adapter_v2.py |  5 +++--
 tests/test_full.py       |  7 ++++---
 tests/test_lora.py       |  5 +++--
 tests/test_pretrain.py   |  8 ++++++--
 tests/test_utils.py      | 23 +++++++++++++++++++++++
 12 files changed, 88 insertions(+), 26 deletions(-)

diff --git a/finetune/adapter.py b/finetune/adapter.py
index 0cc46bbe7e..7e9ae5a814 100644
--- a/finetune/adapter.py
+++ b/finetune/adapter.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )


@@ -143,8 +144,10 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_adapter_checkpoint(fabric, model, save_path)
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -223,7 +226,9 @@ def fit(
             checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_checkpoint(fabric, model, checkpoint_file)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # the adapter "kv cache" cannot be initialized under `inference_mode`
diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py
index 0340013dd6..0c9bd235c6 100644
--- a/finetune/adapter_v2.py
+++ b/finetune/adapter_v2.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )


@@ -143,8 +144,10 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_adapter_v2_checkpoint(fabric, model, save_path)
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -223,7 +226,9 @@ def fit(
             checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_v2_checkpoint(fabric, model, checkpoint_file)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # the adapter "kv cache" cannot be initialized under `inference_mode`
diff --git a/finetune/full.py b/finetune/full.py
index 276abeb622..f93f783e3f 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -34,6 +34,7 @@
     CycleIterator,
     parse_devices,
    copy_config_files,
+    save_hyperparameters,
 )


@@ -137,8 +138,10 @@ def main(
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     fabric.save(save_path, {"model": state["model"]})
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -242,7 +245,9 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             fabric.print(f"Saving checkpoint to {str(checkpoint_file.parent)!r}")
             fabric.save(checkpoint_file, state)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # FSDP has issues with `inference_mode`
diff --git a/finetune/lora.py b/finetune/lora.py
index 3f9d612844..f399e9beb3 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )


@@ -174,8 +175,10 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_path = out_dir / "final" / "lit_model.pth"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_lora_checkpoint(fabric, model, save_path)
-    # Copy checkpoint files from original checkpoint dir
-    copy_config_files(checkpoint_dir, save_path.parent)
+    if fabric.global_rank == 0:
+        # Copy checkpoint files from original checkpoint dir
+        copy_config_files(checkpoint_dir, save_path.parent)
+        save_hyperparameters(setup, save_path.parent)


 def fit(
@@ -254,7 +257,9 @@ def fit(
             checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_lora_checkpoint(fabric, model, checkpoint_file)
-            copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                copy_config_files(checkpoint_dir, checkpoint_file.parent)
+                save_hyperparameters(setup, checkpoint_file.parent)


 # FSDP has issues with `inference_mode`
diff --git a/lit_gpt/pretrain.py b/lit_gpt/pretrain.py
index 62ad969b2d..468a46784a 100644
--- a/lit_gpt/pretrain.py
+++ b/lit_gpt/pretrain.py
@@ -23,7 +23,7 @@
 from lit_gpt.args import EvalArgs, TrainArgs
 from lit_gpt.data import LitDataModule, TinyLlama
 from lit_gpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
-from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files
+from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files, save_hyperparameters


 def setup(
@@ -269,8 +269,10 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
             fabric.save(checkpoint_file, state)
-            if tokenizer_dir is not None:
-                copy_config_files(tokenizer_dir, checkpoint_file.parent)
+            if fabric.global_rank == 0:
+                save_hyperparameters(setup, checkpoint_file.parent)
+                if tokenizer_dir is not None:
+                    copy_config_files(tokenizer_dir, checkpoint_file.parent)


 @torch.no_grad()
diff --git a/lit_gpt/utils.py b/lit_gpt/utils.py
index b21da4ca80..8b0ff82973 100644
--- a/lit_gpt/utils.py
+++ b/lit_gpt/utils.py
@@ -395,6 +395,15 @@ def CLI(*args: Any, **kwargs: Any) -> Any:
     return CLI(*args, **kwargs)


+def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
+    """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
+    from jsonargparse import capture_parser
+
+    parser = capture_parser(lambda: CLI(function))
+    config = parser.parse_args()
+    parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
+
+
 def parse_devices(devices: Union[str, int]) -> int:
     if devices in (-1, "auto"):
         return torch.cuda.device_count() or 1
diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index 0b20105e8b..7415cdc5c2 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -68,7 +68,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
     out_dir = tmp_path / "out"

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -93,6 +93,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -166,7 +167,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
     monkeypatch.setattr(module, "fit", train_mock)

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py
index b78b1e91e1..283a97f90d 100644
--- a/tests/test_adapter_v2.py
+++ b/tests/test_adapter_v2.py
@@ -91,7 +91,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa
     out_dir = tmp_path / "out"

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -116,6 +116,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -255,7 +256,7 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
     monkeypatch.setattr(module, "fit", train_mock)

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
diff --git a/tests/test_full.py b/tests/test_full.py
index 4db91adbb9..f09862e6ea 100644
--- a/tests/test_full.py
+++ b/tests/test_full.py
@@ -40,7 +40,7 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
         eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),
     )
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]):
         module.setup(**setup_kwargs)

     out_dir_contents = set(os.listdir(out_dir))
@@ -48,11 +48,12 @@
     assert checkpoint_dirs.issubset(out_dir_contents)
     assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
     for checkpoint_dir in checkpoint_dirs:
-        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {
+        assert set(os.listdir(out_dir / checkpoint_dir)) == {
             "lit_model.pth",
             "lit_config.json",
             "tokenizer_config.json",
             "tokenizer.json",
+            "hyperparameters.yaml",
         }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -65,7 +66,7 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
     setup_kwargs["train"].max_steps = 8
     setup_kwargs["resume"] = True
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]):
         module.setup(**setup_kwargs)
     logs = stdout.getvalue()
     assert f"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}" in logs
diff --git a/tests/test_lora.py b/tests/test_lora.py
index ed7ae8c129..aae1152168 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -198,7 +198,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
     out_dir = tmp_path / "out"

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -223,6 +223,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()

@@ -624,7 +625,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
     monkeypatch.setattr(module, "fit", train_mock)

     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["full.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
diff --git a/tests/test_pretrain.py b/tests/test_pretrain.py
index ba3418bffe..6628618fc8 100644
--- a/tests/test_pretrain.py
+++ b/tests/test_pretrain.py
@@ -14,7 +14,11 @@
 @RunIf(min_cuda_gpus=2, standalone=True)
 # Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
-def test_pretrain(tmp_path, monkeypatch):
+# If we were to use `save_hyperparameters()`, we would have to patch `sys.argv` or otherwise
+# the CLI would capture pytest args, but unfortunately patching would mess with subprocess
+# launching, so we need to mock `save_hyperparameters()`
+@mock.patch("lit_gpt.pretrain.save_hyperparameters")
+def test_pretrain(_, tmp_path):
     from lit_gpt import pretrain
     from lit_gpt.args import EvalArgs, TrainArgs
     from lit_gpt.config import Config
@@ -44,7 +48,7 @@ def test_pretrain(tmp_path, monkeypatch):
     assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
     for checkpoint_dir in checkpoint_dirs:
         # the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here
-        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {"lit_model.pth"}
+        assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth"}

     # logs only appear on rank 0
     logs = stdout.getvalue()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index cf4e390e85..f4b8ae51a0 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,11 +3,14 @@
 import os
 from contextlib import redirect_stderr
 from io import StringIO
+from pathlib import Path
 from unittest import mock

 import pytest
 import torch
 import torch.nn.functional as F
+import yaml
+
 from conftest import RunIf
 from lightning import Fabric

@@ -222,3 +225,23 @@ def test_copy_config_files(fake_checkpoint_dir, tmp_path):
     }
     contents = set(os.listdir(tmp_path))
     assert expected.issubset(contents)
+
+
+def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
+    from lit_gpt.utils import save_hyperparameters
+
+    save_hyperparameters(_test_function, out_dir)
+
+
+def test_save_hyperparameters(tmp_path):
+    from lit_gpt.utils import CLI
+
+    with mock.patch("sys.argv", ["any.py", "--out_dir", str(tmp_path), "--foo", "True"]):
+        CLI(_test_function)
+
+    with open(tmp_path / "hyperparameters.yaml", "r") as file:
+        hparams = yaml.full_load(file)
+
+    assert hparams["out_dir"] == str(tmp_path)
+    assert hparams["foo"] is True
+    assert hparams["bar"] == 1