diff --git a/finetune/adapter.py b/finetune/adapter.py
index 0cc46bbe7e..98952959f0 100644
--- a/finetune/adapter.py
+++ b/finetune/adapter.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )
 
 
@@ -145,6 +146,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_adapter_checkpoint(fabric, model, save_path)
     # Copy checkpoint files from original checkpoint dir
     copy_config_files(checkpoint_dir, save_path.parent)
+    save_hyperparameters(setup, save_path.parent)
 
 
 def fit(
@@ -224,6 +226,7 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_checkpoint(fabric, model, checkpoint_file)
             copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            save_hyperparameters(setup, checkpoint_file.parent)
 
 
 # the adapter "kv cache" cannot be initialized under `inference_mode`
diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py
index 0340013dd6..53eade6fe5 100644
--- a/finetune/adapter_v2.py
+++ b/finetune/adapter_v2.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )
 
 
@@ -145,6 +146,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_adapter_v2_checkpoint(fabric, model, save_path)
     # Copy checkpoint files from original checkpoint dir
     copy_config_files(checkpoint_dir, save_path.parent)
+    save_hyperparameters(setup, save_path.parent)
 
 
 def fit(
@@ -224,6 +226,7 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_v2_checkpoint(fabric, model, checkpoint_file)
             copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            save_hyperparameters(setup, checkpoint_file.parent)
 
 
 # the adapter "kv cache" cannot be initialized under `inference_mode`
diff --git a/finetune/lora.py b/finetune/lora.py
index 3f9d612844..8c81ca6064 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -33,6 +33,7 @@
     CycleIterator,
     parse_devices,
     copy_config_files,
+    save_hyperparameters,
 )
 
 
@@ -176,6 +177,7 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     save_lora_checkpoint(fabric, model, save_path)
     # Copy checkpoint files from original checkpoint dir
     copy_config_files(checkpoint_dir, save_path.parent)
+    save_hyperparameters(setup, save_path.parent)
 
 
 def fit(
@@ -255,6 +257,7 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_lora_checkpoint(fabric, model, checkpoint_file)
             copy_config_files(checkpoint_dir, checkpoint_file.parent)
+            save_hyperparameters(setup, checkpoint_file.parent)
 
 
 # FSDP has issues with `inference_mode`
diff --git a/lit_gpt/pretrain.py b/lit_gpt/pretrain.py
index 62ad969b2d..4037759bbc 100644
--- a/lit_gpt/pretrain.py
+++ b/lit_gpt/pretrain.py
@@ -23,7 +23,7 @@
 from lit_gpt.args import EvalArgs, TrainArgs
 from lit_gpt.data import LitDataModule, TinyLlama
 from lit_gpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
-from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files
+from lit_gpt.utils import CLI, CycleIterator, chunked_cross_entropy, num_parameters, parse_devices, copy_config_files, save_hyperparameters
 
 
 def setup(
@@ -269,6 +269,7 @@ def fit(
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
             fabric.save(checkpoint_file, state)
+            save_hyperparameters(setup, checkpoint_file.parent)
             if tokenizer_dir is not None:
                 copy_config_files(tokenizer_dir, checkpoint_file.parent)
 
diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index 0b20105e8b..d9fe9eae61 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -68,7 +68,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
 
     out_dir = tmp_path / "out"
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -93,6 +93,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()
 
diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py
index b78b1e91e1..92b8a35a2b 100644
--- a/tests/test_adapter_v2.py
+++ b/tests/test_adapter_v2.py
@@ -91,7 +91,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa
 
     out_dir = tmp_path / "out"
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["adapter_v2.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -116,6 +116,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_pa
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()
 
diff --git a/tests/test_lora.py b/tests/test_lora.py
index ed7ae8c129..62925f229c 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -198,7 +198,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
 
     out_dir = tmp_path / "out"
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["lora.py"]):
         module.setup(
             data=Alpaca(
                 download_dir=alpaca_path.parent,
@@ -223,6 +223,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
         "lit_config.json",
         "tokenizer_config.json",
         "tokenizer.json",
+        "hyperparameters.yaml",
     }
     assert (out_dir / "version_0" / "metrics.csv").is_file()
 
diff --git a/tests/test_pretrain.py b/tests/test_pretrain.py
index ba3418bffe..4655de4885 100644
--- a/tests/test_pretrain.py
+++ b/tests/test_pretrain.py
@@ -27,7 +27,7 @@ def test_pretrain(tmp_path, monkeypatch):
 
     out_dir = tmp_path / "out"
     stdout = StringIO()
-    with redirect_stdout(stdout):
+    with redirect_stdout(stdout), mock.patch("sys.argv", ["pretrain.py"]):
         pretrain.setup(
             devices=2,
             model=model_config,
@@ -44,7 +44,7 @@
     assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
     for checkpoint_dir in checkpoint_dirs:
         # the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here
-        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {"lit_model.pth"}
+        assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth", "hyperparameters.yaml"}
 
     # logs only appear on rank 0
     logs = stdout.getvalue()
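
The patch imports `save_hyperparameters` from `lit_gpt.utils` but its implementation is not part of this diff. As a rough, hypothetical sketch of the behaviour the updated tests depend on (they patch `sys.argv` because the hyperparameters are captured from the command line at checkpoint time), something along these lines would produce the `hyperparameters.yaml` file the assertions expect. The `_sketch` name and the naive `--key=value` parsing are illustrative only; the real helper in `lit_gpt/utils.py` may work differently, e.g. by reusing the CLI parser.

# Hypothetical sketch, not the implementation shipped in lit_gpt/utils.py.
import inspect
import sys
from pathlib import Path
from typing import Any, Callable, Dict

import yaml


def save_hyperparameters_sketch(function: Callable, checkpoint_dir: Path) -> None:
    # Start from the keyword defaults of the entry point, e.g. `setup(...)`.
    hparams: Dict[str, Any] = {}
    for name, param in inspect.signature(function).parameters.items():
        if param.default is not inspect.Parameter.empty:
            default = param.default
            # Keep YAML-friendly primitives; stringify everything else (Paths, dataclasses, ...).
            hparams[name] = default if isinstance(default, (bool, int, float, str, type(None))) else str(default)
    # Fold in `--key=value` overrides from the command line the script was launched with;
    # this is why the tests above patch `sys.argv`.
    for arg in sys.argv[1:]:
        if arg.startswith("--") and "=" in arg:
            key, value = arg[2:].split("=", 1)
            hparams[key] = value
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    with open(checkpoint_dir / "hyperparameters.yaml", "w") as f:
        yaml.safe_dump(hparams, f)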