From 3516beacd3ff2ccca51fdc6e2471ff5adcc35e4f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 12 Apr 2024 16:35:17 +0100 Subject: [PATCH 1/2] Fix capturing hparams for loggers that don't support serializing non-primitives (#1281) --- extensions/thunder/pretrain.py | 3 ++- litgpt/pretrain.py | 3 ++- litgpt/utils.py | 18 +++++++++++++++++- tests/test_utils.py | 23 +++++++++++++++++++++++ 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py index fe9c798dc0..24f140c9df 100644 --- a/extensions/thunder/pretrain.py +++ b/extensions/thunder/pretrain.py @@ -26,6 +26,7 @@ from litgpt.utils import ( CLI, CycleIterator, + capture_hparams, choose_logger, chunked_cross_entropy, copy_config_files, @@ -97,7 +98,7 @@ def setup( executors: If using Thunder, the executors to enable. strategy: If desired, the strategy to use. """ - hparams = locals() + hparams = capture_hparams() data = TinyLlama() if data is None else data if model_config is not None and model_name is not None: raise ValueError("Only one of `model_name` or `model_config` can be set.") diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py index 4ab31b414e..f75a93d8c6 100644 --- a/litgpt/pretrain.py +++ b/litgpt/pretrain.py @@ -26,6 +26,7 @@ from litgpt.utils import ( CLI, CycleIterator, + capture_hparams, choose_logger, chunked_cross_entropy, copy_config_files, @@ -87,7 +88,7 @@ def setup( logger_name: The name of the logger to send metrics to. seed: The random seed to use for reproducibility. """ - hparams = locals() + hparams = capture_hparams() data = TinyLlama() if data is None else data if model_config is not None and model_name is not None: raise ValueError("Only one of `model_name` or `model_config` can be set.") diff --git a/litgpt/utils.py b/litgpt/utils.py index fb6a86c107..37ebdfd6f9 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -1,11 +1,12 @@ # Copyright Lightning AI. 
def capture_hparams() -> Dict[str, Any]:
    """Captures the local variables ('hyperparameters') from where this function gets called.

    Every local variable of the calling frame is converted into a value that loggers can serialize:

    - ``None`` and ``int`` / ``float`` / ``str`` / ``bool`` / ``Path`` values are kept unchanged,
    - dataclass *instances* are expanded into a plain dict via ``dataclasses.asdict``,
    - everything else (including dataclass *classes*) is replaced by its ``str()`` representation.

    Returns:
        A new dict mapping each local variable name of the caller to its converted value.
    """
    caller_frame = inspect.currentframe().f_back
    hparams = {}
    for name, value in caller_frame.f_locals.items():
        if value is None or isinstance(value, (int, float, str, bool, Path)):
            hparams[name] = value
        # `is_dataclass` is also True for a dataclass type itself, but `asdict` only accepts
        # instances and raises a TypeError otherwise, so classes fall through to `str()`.
        elif is_dataclass(value) and not isinstance(value, type):
            hparams[name] = asdict(value)
        else:
            hparams[name] = str(value)
    return hparams
def test_capture_hparams():
    """Check how `capture_hparams` converts each kind of local variable of its caller."""
    # The names of these locals are the keys expected in the captured dict, and no other
    # local may be bound before `capture_hparams()` is called, otherwise it gets captured too.
    integer = 1
    string = "string"
    boolean = True
    none = None
    path = Path("/path")
    dataclass = TrainArgs()
    other = torch.nn.Linear(1, 1)
    captured = capture_hparams()
    expected = dict(
        integer=integer,
        string=string,
        boolean=boolean,
        none=none,
        path=path,
        dataclass=asdict(dataclass),
        other=str(other),
    )
    assert captured == expected


class Phi2(PromptStyle):
    """Prompt style that places the prompt between an ``Instruct: `` prefix and an ``Output:`` line."""

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return the formatted prompt; extra keyword arguments are accepted but not used."""
        template = "Instruct: {}\nOutput:"
        return template.format(prompt)