Fix capturing hparams for loggers that don't support serializing non-primitives #1281

Merged 4 commits on Apr 12, 2024
3 changes: 2 additions & 1 deletion litgpt/pretrain.py
Contributor:
Does finetuning also need it?

Also, please apply the change to the extensions/thunder/pretrain.py copy.

Contributor Author:

Finetuning doesn't capture the hparams for the logger. It's only in the training script because it is my personal preference; so far I was the only user of this script who used wandb, until today when Guangtao from TinyLlama noticed this bug.

@@ -26,6 +26,7 @@
 from litgpt.utils import (
     CLI,
     CycleIterator,
+    capture_hparams,
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
@@ -87,7 +88,7 @@ def setup(
         logger_name: The name of the logger to send metrics to.
         seed: The random seed to use for reproducibility.
     """
-    hparams = locals()
+    hparams = capture_hparams()
     data = TinyLlama() if data is None else data
     if model_config is not None and model_name is not None:
         raise ValueError("Only one of `model_name` or `model_config` can be set.")
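For context (this sketch is not part of the diff), here is roughly what the `hparams = capture_hparams()` change produces compared to the old `hparams = locals()`: primitives and `Path`s are kept, dataclasses are converted with `asdict`, and everything else is stringified, so loggers that serialize hparams to YAML/JSON don't receive raw objects. The `DemoArgs` dataclass and the `setup` signature below are hypothetical stand-ins, and the sketch assumes litgpt with this PR installed.

```python
# Hypothetical illustration, not part of the PR.
from dataclasses import dataclass
from pathlib import Path

from litgpt.utils import capture_hparams  # added to litgpt/utils.py by this PR


@dataclass
class DemoArgs:  # stand-in for a config dataclass such as TrainArgs
    max_steps: int = 1000


def setup(lr: float = 3e-4, out_dir: Path = Path("out"), train: DemoArgs = DemoArgs()):
    # Previously `hparams = locals()` kept the DemoArgs instance (and any other
    # non-primitive local) as a raw object, which some loggers cannot serialize.
    hparams = capture_hparams()
    print(hparams)
    # Roughly: {'lr': 0.0003, 'out_dir': PosixPath('out'), 'train': {'max_steps': 1000}}


setup()
```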
18 changes: 17 additions & 1 deletion litgpt/utils.py
@@ -1,11 +1,12 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

 """Utility functions for training and inference."""
+import inspect
 import math
 import pickle
 import shutil
 import sys
-from dataclasses import asdict
+from dataclasses import asdict, is_dataclass
 from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
@@ -404,6 +405,21 @@ def CLI(*args: Any, **kwargs: Any) -> Any:
     return CLI(*args, **kwargs)


+def capture_hparams() -> Dict[str, Any]:
+    """Captures the local variables ('hyperparameters') from where this function gets called."""
+    caller_frame = inspect.currentframe().f_back
+    locals_of_caller = caller_frame.f_locals
+    hparams = {}
+    for name, value in locals_of_caller.items():
+        if value is None or isinstance(value, (int, float, str, bool, Path)):
+            hparams[name] = value
+        elif is_dataclass(value):
+            hparams[name] = asdict(value)
+        else:
+            hparams[name] = str(value)
+    return hparams
+
+
 def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
     """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
     from jsonargparse import capture_parser
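As a side note on the mechanism (not a comment from the PR thread): `capture_hparams` reaches the caller's frame via `inspect.currentframe().f_back` and reads its `f_locals`. A minimal standalone sketch of that trick, with made-up names and no litgpt dependency:

```python
import inspect


def snapshot_caller_locals() -> dict:
    # currentframe() is this function's own frame; f_back is whoever called it.
    caller_frame = inspect.currentframe().f_back
    # Return a plain dict of the caller's local variables.
    return dict(caller_frame.f_locals)


def some_caller():
    lr = 3e-4
    run_name = "demo"
    return snapshot_caller_locals()


print(some_caller())  # e.g. {'lr': 0.0003, 'run_name': 'demo'}
```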
23 changes: 23 additions & 0 deletions tests/test_utils.py
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from dataclasses import asdict

 import os
 from contextlib import redirect_stderr
@@ -18,9 +19,11 @@
 from lightning_utilities.core.imports import RequirementCache

 from litgpt import GPT
+from litgpt.args import TrainArgs
 from litgpt.utils import (
     CLI,
     CycleIterator,
+    capture_hparams,
     check_valid_checkpoint_dir,
     choose_logger,
     chunked_cross_entropy,
@@ -219,6 +222,26 @@ def test_copy_config_files(fake_checkpoint_dir, tmp_path):
     assert expected.issubset(contents)


+def test_capture_hparams():
+    integer = 1
+    string = "string"
+    boolean = True
+    none = None
+    path = Path("/path")
+    dataclass = TrainArgs()
+    other = torch.nn.Linear(1, 1)
+    hparams = capture_hparams()
+    assert hparams == {
+        "integer": integer,
+        "string": string,
+        "boolean": boolean,
+        "none": none,
+        "path": path,
+        "dataclass": asdict(dataclass),
+        "other": str(other),
+    }
+
+
 def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
     save_hyperparameters(_test_function, out_dir)
