Commit f07b76b: thunder pretrain
rasbt committed May 9, 2024, 1 parent cd8db2b
Showing 1 changed file with 33 additions and 6 deletions.

extensions/thunder/pretrain.py
@@ -20,7 +20,7 @@
 from typing_extensions import Literal
 
 from litgpt import Tokenizer
-from litgpt.args import EvalArgs, TrainArgs
+from litgpt.args import EvalArgs, TrainArgs, OptimizerArgs
 from litgpt.data import DataModule, TinyLlama
 from litgpt.model import GPT, CausalSelfAttention, Config, LLaMAMLP, Block
 from litgpt.utils import (
@@ -32,6 +32,7 @@
     copy_config_files,
     num_parameters,
     parse_devices,
+    parse_kwargs_from_string,
     reset_parameters,
     save_config,
     save_hyperparameters,
@@ -55,16 +56,20 @@ def setup(
         global_batch_size=512,
         micro_batch_size=4,
         max_tokens=int(3e12),  # 3 trillion
-        learning_rate=4e-4,
-        weight_decay=1e-1,
-        beta1=0.9,
-        beta2=0.95,
         max_norm=1.0,
         min_lr=4e-5,
         lr_warmup_steps=2000,
         tie_embeddings=False,
     ),
     eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
+    optim: OptimizerArgs = OptimizerArgs(
+        optimizer="adamw",
+        learning_rate=4e-4,
+        weight_decay=1e-1,
+        beta1=0.9,
+        beta2=0.95,
+        extra_kwargs=None
+    ),
     devices: Union[int, str] = "auto",
     tokenizer_dir: Optional[Path] = None,
     logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
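
The Adam hyperparameters (learning_rate, weight_decay, beta1, beta2) move out of TrainArgs into the new OptimizerArgs, together with an optimizer name and an extra_kwargs escape hatch. A minimal sketch of constructing it, assuming the field names shown in the signature above and a comma-separated "key=value" string for extra_kwargs (the exact string format is an assumption, not something this diff pins down):

    # Hypothetical usage; field names are taken from the new signature above.
    from litgpt.args import OptimizerArgs

    optim = OptimizerArgs(
        optimizer="galore_adamw",          # any name containing "galore" triggers the defaults below
        learning_rate=4e-4,
        weight_decay=1e-1,
        beta1=0.9,
        beta2=0.95,
        extra_kwargs="rank=8,scale=0.25",  # parsed later via parse_kwargs_from_string
    )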
@@ -89,6 +94,7 @@ def setup(
         data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
         train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
         eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
+        optim: Optimizer-related arguments. See ``litgpt.args.OptimizerArgs`` for details.
         devices: How many devices/GPUs to use. Uses all GPUs by default.
         tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
             modules require this.
@@ -133,6 +139,20 @@ def setup(
         )
     else:
         strategy = "auto"
 
+    if "galore" in optim.optimizer:
+        default_values = {
+            "rank": 8,
+            "update_proj_gap": 200,
+            "scale": 0.25,
+            "proj_type": "std"
+        }
+    elif optim.extra_kwargs is None:
+        optim.extra_kwargs = ""
+        default_values = {}
+    else:
+        default_values = {}
+    optim.extra_kwargs = parse_kwargs_from_string(optim.extra_kwargs, defaults=default_values)
+
     fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-true", loggers=[logger])
     fabric.launch()

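parse_kwargs_from_string itself is imported from litgpt.utils and its implementation is not part of this diff. A rough sketch of the behavior the call above implies (user-supplied "key=value" pairs layered over the defaults), with every detail assumed:

    import ast

    def parse_kwargs_from_string(kwargs_str, defaults):
        # Start from the defaults and let user-supplied "key=value" pairs override them.
        parsed = dict(defaults)
        # (kwargs_str or "") tolerates None, which the GaLore branch above can still pass through.
        for pair in filter(None, (kwargs_str or "").split(",")):
            key, _, value = pair.partition("=")
            try:
                # Interpret numbers and booleans literally ("8" -> 8, "0.25" -> 0.25).
                parsed[key.strip()] = ast.literal_eval(value.strip())
            except (ValueError, SyntaxError):
                # Fall back to the raw string (e.g. proj_type=std).
                parsed[key.strip()] = value.strip()
        return parsed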
@@ -157,6 +177,7 @@ def setup(
         tokenizer,
         train,
         eval,
+        optim,
         compiler,
     )

@@ -174,6 +195,7 @@ def main(
     tokenizer: Optional[Tokenizer],
     train: TrainArgs,
     eval: EvalArgs,
+    optim: OptimizerArgs,
     compiler: Optional[Literal["thunder", "torch"]],
 ) -> None:
     validate_args(train, eval, initial_checkpoint_dir, resume)
@@ -198,6 +220,11 @@
     fabric.print(f"Total parameters: {num_parameters(model):,}")
 
     model = fabric.setup(model)
+
+    if optim.extra_kwargs:
+        raise ValueError("Additional optimizer arguments are currently not supported by Thunder.")
+    if optim.optimizer in ("galore_adamw", "galore_adamw_8bit"):
+        raise ValueError("GaLore is currently not supported by Thunder.")
     if compiler == "thunder":
         # avoid `Tensor.register_hook` which is unsupported
         model._register_backward_hook = lambda *_: None
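
Since this is the Thunder-specific pretraining script, the two new guards reject GaLore and extra optimizer kwargs up front rather than failing inside the compiled training step. For context, a sketch of the kind of optimizer dispatch these guards protect; the real instantiation lives elsewhere in litgpt and is not shown in this diff, so names here are illustrative:

    import torch

    def build_optimizer(model, optim):
        # Plain AdamW path, the only one this script currently accepts under Thunder.
        if optim.optimizer == "adamw":
            return torch.optim.AdamW(
                model.parameters(),
                lr=optim.learning_rate,
                betas=(optim.beta1, optim.beta2),
                weight_decay=optim.weight_decay,
            )
        # "galore_adamw" / "galore_adamw_8bit" would need the external
        # galore-torch package, hence the explicit rejection above.
        raise ValueError(f"Unsupported optimizer: {optim.optimizer}")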
@@ -231,7 +258,7 @@ def main(
         fabric.load(resume, state)
 
     train_time = time.perf_counter()
-    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
+    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval, optim)
     fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
 
     # Save final checkpoint
