fix thunder test
rasbt committed May 22, 2024
1 parent 393d516 commit 379e54e
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions extensions/thunder/pretrain.py
@@ -8,7 +8,7 @@
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Optional, Tuple, Union, List
+from typing import Any, Callable, Optional, Tuple, Union, List, Dict

 import lightning as L
 import torch
@@ -30,6 +30,7 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
+    instantiate_torch_optimizer,
     num_parameters,
     parse_devices,
     reset_parameters,
@@ -55,16 +56,13 @@ def setup(
         global_batch_size=512,
         micro_batch_size=4,
         max_tokens=int(3e12),  # 3 trillion
-        learning_rate=4e-4,
-        weight_decay=1e-1,
-        beta1=0.9,
-        beta2=0.95,
         max_norm=1.0,
         min_lr=4e-5,
         lr_warmup_steps=2000,
         tie_embeddings=False,
     ),
     eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),
+    optimizer: Union[str, Dict] = "AdamW",
     devices: Union[int, str] = "auto",
     tokenizer_dir: Optional[Path] = None,
     logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
@@ -93,6 +91,7 @@ def setup(
         tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
             module require this.
         logger_name: The name of the logger to send metrics to.
+        optimizer: An optimizer name (such as "AdamW") or config.
         seed: The random seed to use for reproducibility.
         compiler: If desired, the compiler/JIT to use.
         executors: If using Thunder, the executors to enable.
@@ -157,6 +156,7 @@ def setup(
         tokenizer,
         train,
         eval,
+        optimizer,
         compiler,
     )

@@ -174,6 +174,7 @@ def main(
     tokenizer: Optional[Tokenizer],
     train: TrainArgs,
     eval: EvalArgs,
+    optimizer: Union[str, Dict],
     compiler: Optional[Literal["thunder", "torch"]],
 ) -> None:
     validate_args(train, eval, initial_checkpoint_dir, resume)
@@ -201,13 +202,7 @@ def main(
     if compiler == "thunder":
         # avoid `Tensor.register_hook` which is unsupported
         model._register_backward_hook = lambda *_: None
-    optimizer = torch.optim.AdamW(
-        model.parameters(),
-        lr=train.learning_rate,
-        weight_decay=train.weight_decay,
-        betas=(train.beta1, train.beta2),
-        fused=True,
-    )
+    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())
     optimizer = fabric.setup_optimizers(optimizer)

     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
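Note: the functional change is that main() no longer hard-codes torch.optim.AdamW with hyperparameters read from TrainArgs; instead, setup() takes an optimizer argument (a name string or a config dict) and passes it through to instantiate_torch_optimizer. The helper itself is not part of this diff. As a rough sketch of what such a helper could do, assuming a hypothetical dict layout with "class_path" and "init_args" keys (not the actual litgpt implementation):

import importlib
from typing import Dict, Iterable, Union

import torch


def instantiate_torch_optimizer_sketch(
    optimizer: Union[str, Dict], parameters: Iterable, **kwargs
) -> torch.optim.Optimizer:
    # Sketch only: resolve a string or dict spec into a torch optimizer instance.
    if isinstance(optimizer, str):
        # e.g. "AdamW" -> torch.optim.AdamW(parameters, **kwargs)
        optimizer_cls = getattr(torch.optim, optimizer)
        return optimizer_cls(parameters, **kwargs)
    # e.g. {"class_path": "torch.optim.AdamW", "init_args": {"lr": 4e-4}}
    # (hypothetical layout chosen for this sketch)
    module_name, class_name = optimizer["class_path"].rsplit(".", 1)
    optimizer_cls = getattr(importlib.import_module(module_name), class_name)
    init_args = {**optimizer.get("init_args", {}), **kwargs}
    return optimizer_cls(parameters, **init_args)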
