Make logger configurable (#1063)
awaelchli authored and rasbt committed Mar 18, 2024
1 parent abc2f71 commit 8bc8534
Showing 9 changed files with 67 additions and 53 deletions.
10 changes: 6 additions & 4 deletions litgpt/finetune/full.py
@@ -5,11 +5,10 @@
import time
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union, Literal

import lightning as L
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy
from torch.utils.data import DataLoader
from torchmetrics import RunningMean
@@ -31,6 +30,7 @@
parse_devices,
copy_config_files,
save_hyperparameters,
choose_logger,
)


@@ -42,6 +42,7 @@ def setup(
data: Optional[LitDataModule] = None,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir: Path = Path("out/finetune/full"),
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
@@ -58,8 +59,10 @@
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
config = Config.from_name(name=checkpoint_dir.name)

precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", resume=resume, log_interval=train.log_interval)

if devices > 1:
strategy = FSDPStrategy(
@@ -72,9 +75,8 @@
else:
strategy = "auto"

logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=train.log_interval)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger)
fabric.launch(main, devices, resume, seed, Config.from_name(name=checkpoint_dir.name), data, checkpoint_dir, out_dir, train, eval)
fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval)


def main(
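Taken together, the new wiring in setup reduces to the pattern below (a minimal sketch assembled from the hunks above; the default checkpoint name and a single device are placeholders, and "csv" mirrors the finetuning default):

    from pathlib import Path

    import lightning as L

    from litgpt.config import Config
    from litgpt.utils import choose_logger

    checkpoint_dir = Path("checkpoints/stabilityai/stablelm-base-alpha-3b")
    out_dir = Path("out/finetune/full")

    # Build the model config once so its name can label the logger run.
    config = Config.from_name(name=checkpoint_dir.name)

    # "csv" is the finetuning default; "tensorboard" and "wandb" are the other options.
    logger = choose_logger("csv", out_dir, name=f"finetune-{config.name}", log_interval=1)

    # The chosen logger is handed to Fabric alongside the strategy and precision settings.
    fabric = L.Fabric(devices=1, loggers=logger)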
41 changes: 16 additions & 25 deletions litgpt/finetune/lora.py
@@ -9,7 +9,6 @@

import lightning as L
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor
@@ -33,6 +32,7 @@
parse_devices,
copy_config_files,
save_hyperparameters,
choose_logger,
)


@@ -53,6 +53,7 @@ def setup(
data: Optional[LitDataModule] = None,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir: Path = Path("out/lora"),
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
@@ -69,8 +70,21 @@
pprint(locals())
data = Alpaca() if data is None else data
devices = parse_devices(devices)
config = Config.from_name(
name=checkpoint_dir.name,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_query=lora_query,
lora_key=lora_key,
lora_value=lora_value,
lora_projection=lora_projection,
lora_mlp=lora_mlp,
lora_head=lora_head,
)

precision = precision or get_default_supported_precision(training=True)
logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)

plugins = None
if quantize is not None and quantize.startswith("bnb."):
@@ -96,31 +110,8 @@
else:
strategy = "auto"

logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=train.log_interval)
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)

fabric.launch(
main,
devices,
seed,
Config.from_name(
name=checkpoint_dir.name,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_query=lora_query,
lora_key=lora_key,
lora_value=lora_value,
lora_projection=lora_projection,
lora_mlp=lora_mlp,
lora_head=lora_head,
),
data,
checkpoint_dir,
out_dir,
train,
eval,
)
fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)


def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
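The LoRA entry point gains the same logger_name switch, so changing trackers no longer means editing the script. Called programmatically, that looks roughly like the sketch below (illustrative; it assumes the StableLM checkpoint has already been downloaded and converted, and leaves all LoRA hyperparameters at their defaults):

    from pathlib import Path

    from litgpt.finetune.lora import setup

    # Log to TensorBoard instead of the default CSV logger.
    setup(
        checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
        out_dir=Path("out/lora"),
        logger_name="tensorboard",
    )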
23 changes: 3 additions & 20 deletions litgpt/pretrain.py
@@ -12,10 +12,8 @@
import lightning as L
import torch
import torch.nn as nn
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchmetrics.aggregation import RunningMean
from typing_extensions import Literal
@@ -33,19 +31,20 @@
copy_config_files,
save_hyperparameters,
save_config,
choose_logger,
)


def setup(
model_name: Optional[str] = None,
model_config: Optional[Config] = None,
logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
resume: Union[bool, Path] = False,
devices: Union[int, str] = "auto",
seed: int = 42,
data: Optional[LitDataModule] = None,
out_dir: Path = Path("out/pretrain"),
tokenizer_dir: Optional[Path] = None,
logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
@@ -75,13 +74,7 @@ def setup(
# in case the dataset requires the Tokenizer
tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None

logger = choose_logger(
out_dir,
logger_name,
name=f"pretrain-{config.name}",
resume=resume,
log_interval=train.log_interval
)
logger = choose_logger(logger_name, out_dir, name=f"pretrain-{config.name}", resume=resume, log_interval=train.log_interval)

if devices > 1:
strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD")
@@ -352,16 +345,6 @@ def init_weights(module: nn.Module, n_layer: int, n_embd: int):
nn.init.normal_(param, mean=0.0, std=(1 / math.sqrt(n_embd) / n_layer))


def choose_logger(out_dir: Path, logger_name: str, name: str, resume: Union[bool, Path], log_interval: int, *args, **kwargs):
if logger_name == "csv":
return CSVLogger(root_dir=(out_dir / "logs"), name="csv", flush_logs_every_n_steps=log_interval, *args, **kwargs)
if logger_name == "tensorboard":
return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", *args, **kwargs)
if logger_name == "wandb":
return WandbLogger(project="pretrain", name=name, resume=(resume is not False), *args, **kwargs)
raise ValueError(f"`logger={logger_name}` is not a valid option.")


def init_out_dir(out_dir: Path) -> Path:
if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
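A practical side effect of routing everything through the shared helper is a uniform log layout: CSV and TensorBoard runs now land under out_dir / "logs" / logger_name, which is exactly what the updated tests below assert. A tiny sketch of where metrics end up under the defaults:

    from pathlib import Path

    # Pretraining defaults to TensorBoard:
    print(Path("out/pretrain") / "logs" / "tensorboard" / "version_0")

    # Finetuning defaults to CSV:
    print(Path("out/finetune/full") / "logs" / "csv" / "version_0" / "metrics.csv")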
21 changes: 20 additions & 1 deletion litgpt/utils.py
@@ -9,14 +9,16 @@
from dataclasses import asdict
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, TypeVar, Union, Literal

import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from lightning.pytorch.loggers import WandbLogger
from torch.serialization import normalize_storage_type
from typing_extensions import Self

@@ -417,3 +419,20 @@ def parse_devices(devices: Union[str, int]) -> int:
if isinstance(devices, int) and devices > 0:
return devices
raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")


def choose_logger(
logger_name: Literal["csv", "tensorboard", "wandb"],
out_dir: Path,
name: str,
log_interval: int = 1,
resume: Optional[bool] = None,
**kwargs: Any,
):
if logger_name == "csv":
return CSVLogger(root_dir=(out_dir / "logs"), name="csv", flush_logs_every_n_steps=log_interval, **kwargs)
if logger_name == "tensorboard":
return TensorBoardLogger(root_dir=(out_dir / "logs"), name="tensorboard", **kwargs)
if logger_name == "wandb":
return WandbLogger(project=name, resume=resume, **kwargs)
raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
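
A quick usage sketch of the new helper, mirroring the test added in tests/test_utils.py (it assumes tensorboard and wandb are installed; in the W&B case, name becomes the project and resume is forwarded to the logger):

    from pathlib import Path

    from litgpt.utils import choose_logger

    out_dir = Path("out/demo")

    csv_logger = choose_logger("csv", out_dir, name="demo", log_interval=10)   # writes to out_dir/logs/csv
    tb_logger = choose_logger("tensorboard", out_dir, name="demo")             # writes to out_dir/logs/tensorboard
    wandb_logger = choose_logger("wandb", out_dir, name="demo", resume=False)  # W&B project "demo"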
2 changes: 1 addition & 1 deletion tests/test_full.py
@@ -56,7 +56,7 @@ def test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
"hyperparameters.yaml",
"prompt_style.json",
}
assert (out_dir / "version_0" / "metrics.csv").is_file()
assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file()

logs = stdout.getvalue()
assert logs.count("optimizer.step") == 6
2 changes: 1 addition & 1 deletion tests/test_lora.py
@@ -226,7 +226,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
"hyperparameters.yaml",
"prompt_style.json",
}
assert (out_dir / "version_0" / "metrics.csv").is_file()
assert (out_dir / "logs" / "csv" / "version_0" / "metrics.csv").is_file()

logs = stdout.getvalue()
assert logs.count("optimizer.step") == 6
2 changes: 2 additions & 0 deletions tests/test_pretrain.py
@@ -52,6 +52,8 @@ def test_pretrain(_, tmp_path):
# the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here
assert set(os.listdir(out_dir / checkpoint_dir)) == {"lit_model.pth", "lit_config.json"}

assert (out_dir / "logs" / "tensorboard" / "version_0").is_dir()

# logs only appear on rank 0
logs = stdout.getvalue()
assert logs.count("optimizer.step") == 4
17 changes: 17 additions & 0 deletions tests/test_utils.py
@@ -10,10 +10,14 @@
import torch
import torch.nn.functional as F
import yaml
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.loggers import WandbLogger

from conftest import RunIf
from lightning import Fabric

from lightning_utilities.core.imports import RequirementCache


def test_find_multiple():
from litgpt.utils import find_multiple
@@ -245,3 +249,16 @@ def test_save_hyperparameters(tmp_path):
assert hparams["out_dir"] == str(tmp_path)
assert hparams["foo"] is True
assert hparams["bar"] == 1


def test_choose_logger(tmp_path):
from litgpt.utils import choose_logger

assert isinstance(choose_logger("csv", out_dir=tmp_path, name="csv"), CSVLogger)
if RequirementCache("tensorboard"):
assert isinstance(choose_logger("tensorboard", out_dir=tmp_path, name="tb"), TensorBoardLogger)
if RequirementCache("wandb"):
assert isinstance(choose_logger("wandb", out_dir=tmp_path, name="wandb"), WandbLogger)

with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
choose_logger("foo", out_dir=tmp_path, name="foo")
2 changes: 1 addition & 1 deletion tutorials/pretrain_tinyllama.md
@@ -136,7 +136,7 @@ GPU memory. For more tips to avoid out-of-memory issues, please also see the mor
[Dealing with out-of-memory (OOM) errors](oom.md) guide.

Last, logging is kept minimal in the script, but for long-running experiments we recommend switching to a proper experiment tracker.
As an example, we included WandB (set `use_wandb=True`) to show how you can integrate any experiment tracking framework.
As an example, we included WandB (set `--logger_name=wandb`) to show how you can integrate any experiment tracking framework.
For reference, [here are the loss curves for our reproduction](https://api.wandb.ai/links/awaelchli/y7pzdpwy).

## Resume training
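In code, the tutorial's suggestion maps onto the new logger_name argument of litgpt.pretrain.setup; roughly as follows (illustrative sketch with a placeholder model name, remaining arguments left at their defaults):

    from litgpt.pretrain import setup

    # Track the pretraining run with Weights & Biases instead of the default TensorBoard.
    setup(model_name="tiny-llama-1.1b", logger_name="wandb")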
