Merge Thunder pretrain scripts (#1206)
carmocca authored Apr 4, 2024
1 parent 64bd9eb commit 508027a
Showing 9 changed files with 57 additions and 508 deletions.
40 changes: 17 additions & 23 deletions extensions/thunder/README.md
@@ -546,18 +546,18 @@ We provide a version of the main pre-training script [that integrates Thunder](p
| Setting | Compiler/JIT | Devices | ms/iter @ step 10 | Memory (GB) |
|----------------------|--------------|---------|-------------------|-------------|
| Fully-sharded ZeRO 3 | Eager | 8 | 460.88 | 22.13 |
-| Fully-sharded ZeRO 3 | Inductor | 8 | 318.71 | 17.08 |
-| Fully-sharded ZeRO 3 | Thunder | 8 | 345.02 | 18.28 |
+| Fully-sharded ZeRO 3 | Inductor | 8 | Error | Error |
+| Fully-sharded ZeRO 3 | Thunder | 8 | 332.48 | 21.40 |
| | | | | |
| Replicated | Eager | 8 | 535.28 | 32.05 |
-| Replicated | Inductor | 8 | 348.19 | 27.01 |
-| Replicated | Thunder | 8 | OOM | OOM |
+| Replicated | Inductor | 8 | Error | Error |
+| Replicated | Thunder | 8 | 368.25 | 27.42 |
| | | | | |
| - | Eager | 1 | 449.88 | 29.85 |
-| - | Inductor | 1 | 320.22 | 24.81 |
-| - | Thunder | 1 | 322.83 | 26.37 |
+| - | Inductor | 1 | Error | Error |
+| - | Thunder | 1 | 323.78 | 27.42 |
| | | | | |
-| Unsloth | Thunder | 1 | 331.93 | 25.19 |
+| Unsloth | Thunder | 1 | 334.98 | 25.19 |

<details>
<summary>Reproduction details</summary>
@@ -566,12 +566,7 @@ Config:

```yaml
out_dir: out/pretrain-thunder
-data:
-  class_path: litgpt.data.TinyStories
-  init_args:
-    path: data
-    num_workers: 0
-    seed: 42
+data: TinyStories
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf
logger_name: csv
```
@@ -581,25 +576,24 @@ Commands:
```bash
python extensions/thunder/pretrain.py --config config.yaml --compiler null --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --compiler torch --train.global_batch_size 32
-python extensions/thunder/pretrain.py --config config.yaml --train.global_batch_size 32
+python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --train.global_batch_size 32

python extensions/thunder/pretrain.py --config config.yaml --compiler null --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --compiler torch --strategy ddp
-python extensions/thunder/pretrain.py --config config.yaml --strategy ddp
+python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --strategy ddp

python extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1
python extensions/thunder/pretrain.py --config config.yaml --compiler torch --devices 1
-python extensions/thunder/pretrain.py --config config.yaml --devices 1
+python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --devices 1

-python extensions/thunder/unsloth/pretrain.py --config config.yaml --devices 1
+python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile, nvfuser, torch]' --devices 1
```

Gradient accumulation is disabled in the FSDP setting because Thunder does not support skipping the backward synchronization yet.
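
For reference, here is a minimal sketch (not part of this diff) of the accumulation pattern that `Fabric.no_backward_sync` enables; the helper name, the `accumulation_steps` value, and the `(input_ids, targets)` batch format are illustrative assumptions. Because Thunder FSDP cannot skip the synchronization yet, the FSDP benchmarks effectively run with an accumulation factor of 1:

```python
# Illustrative sketch only: gradient accumulation normally defers the backward
# all-reduce via `no_backward_sync`; Thunder FSDP cannot skip it yet.
from typing import Iterable, Tuple

import lightning as L
import torch


def train_with_accumulation(
    fabric: L.Fabric,
    model: torch.nn.Module,            # module returned by `fabric.setup(...)`
    optimizer: torch.optim.Optimizer,  # optimizer set up with Fabric
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    accumulation_steps: int = 4,       # hypothetical value
) -> None:
    for step, (input_ids, targets) in enumerate(batches):
        is_accumulating = (step + 1) % accumulation_steps != 0
        # While accumulating, skip the cross-rank gradient synchronization.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
            fabric.backward(loss / accumulation_steps)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
```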

-The CUDA devices are all NVIDIA A100-SXM4-40GB.
+`torch.compile` fails to compile the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601

-The Unsloth example does not support distributed yet.
-The Unsloth example requires commenting out this line in Lightning Fabric: https://github.com/Lightning-AI/pytorch-lightning/blob/fadd2fc/src/lightning/fabric/wrappers.py#L233
+The CUDA devices are all NVIDIA A100-SXM4-40GB.

```text
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
@@ -608,9 +602,9 @@ CUDA used to build PyTorch: 12.1
CUDA runtime version: 12.3.107
Nvidia driver version: 545.23.08
pytorch-triton==3.0.0+989adb9a29
-torch==2.3.0.dev20240314+cu121
-lightning-thunder==0.1.0
-nvfuser_cu121==0.1.7.dev20240315
+torch==2.4.0.dev20240326+cu121
+lightning-thunder==0.2.0.dev20240404
+nvfuser_cu121==0.2.0.dev20240327
```

</details>
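
As orientation for the pretrain.py changes below, here is a simplified sketch of the compiler switch the updated script applies to `forward_and_loss`; the wrapper name `compile_forward_and_loss` is hypothetical, and `executors` is assumed to already contain resolved Thunder `Executor` objects (see `_validate_executors`):

```python
# Simplified sketch of the Thunder / Inductor / eager switch in pretrain.py.
from typing import Callable, Optional, Sequence

import torch
import thunder


def compile_forward_and_loss(
    forward_and_loss: Callable,
    compiler: Optional[str],
    executors: Sequence["thunder.Executor"],
) -> Callable:
    if compiler == "thunder":
        return thunder.jit(forward_and_loss, executors=executors)
    if compiler == "torch":
        return torch.compile(forward_and_loss)
    return forward_and_loss  # eager baseline (`--compiler null`)
```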
55 changes: 35 additions & 20 deletions extensions/thunder/pretrain.py
@@ -3,11 +3,12 @@
import math
import os
import pprint
+import sys
import time
from datetime import timedelta
from functools import partial
from pathlib import Path
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union, List

import lightning as L
import torch
@@ -21,7 +22,7 @@
from litgpt import Tokenizer
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import DataModule, TinyLlama
-from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
+from litgpt.model import GPT, CausalSelfAttention, Config, LLaMAMLP, Block
from litgpt.utils import (
CLI,
CycleIterator,
@@ -35,6 +36,10 @@
save_hyperparameters,
)

+# support running without installing as a package
+wd = Path(__file__).parent.resolve()
+sys.path.append(str(wd))


def setup(
model_name: Optional[str] = None,
@@ -64,6 +69,7 @@ def setup(
logger_name: Literal["wandb", "tensorboard", "csv"] = "tensorboard",
seed: int = 42,
compiler: Optional[Literal["thunder", "torch"]] = "thunder",
+executors: Optional[List[str]] = ("sdpa", "torchcompile", "nvfuser", "torch"),
strategy: Literal["auto", "ddp", "fsdp"] = "fsdp",
):
"""Pretrain a model.
@@ -88,6 +94,7 @@ def setup(
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
compiler: If desired, the compiler/JIT to use.
+executors: If using Thunder, the executors to enable.
strategy: If desired, the strategy to use.
"""
hparams = locals()
@@ -108,19 +115,16 @@

if devices > 1:
if compiler == "thunder":
executors = ("sdpa", "torchcompile", "nvfuser", "torch")
global jit
jit = lambda model: model # the strategy will call `jit`
if strategy == "fsdp":
from extensions.thunder.strategies import ThunderFSDPStrategy

strategy = ThunderFSDPStrategy(
sharding_strategy="ZERO3", bucketing_strategy="BLOCK", executors=executors, state_dict_type="full"
sharding_strategy="ZERO3", bucketing_strategy="BLOCK", state_dict_type="full", jit=False,
)
elif strategy == "ddp":
from extensions.thunder.strategies import ThunderDDPStrategy

-strategy = ThunderDDPStrategy()
+strategy = ThunderDDPStrategy(jit=False)
else:
if strategy == "fsdp":
strategy = FSDPStrategy(
@@ -131,6 +135,10 @@
fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-true", loggers=[logger])
fabric.launch()

+if compiler is not None:
+global forward_and_loss
+forward_and_loss = jit(forward_and_loss, executors) if compiler == "thunder" else torch.compile(forward_and_loss)

fabric.print(pprint.pformat(hparams))
if logger_name in ("tensorboard", "wandb"):
fabric.logger.log_hyperparams(hparams)
@@ -188,9 +196,10 @@ def main(
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters: {num_parameters(model):,}")

-if compiler is not None:
-model = jit(model) if compiler == "thunder" else torch.compile(model)
model = fabric.setup(model)
if compiler == "thunder":
# avoid `Tensor.register_hook` which is unsupported
model._register_backward_hook = lambda *_: None
optimizer = torch.optim.AdamW(
model.parameters(),
lr=train.learning_rate,
@@ -267,7 +276,8 @@ def fit(
total_t0 = time.perf_counter()
val_loss = "n/a"

-warmup_iters = train.lr_warmup_steps * train.gradient_accumulation_iters(devices)
+warmup_iters = train.warmup_iters(devices, max_iters, train_dataloader)

for train_data in train_iterator:
if state["iter_num"] >= max_iters:
break
@@ -285,8 +295,7 @@

is_accumulating = state["iter_num"] % train.gradient_accumulation_iters(devices) != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
-logits = model(input_ids)
-loss = chunked_cross_entropy(logits, targets)
+loss = forward_and_loss(model, input_ids, targets)
fabric.backward(loss / train.gradient_accumulation_iters(devices))

running_loss.update(loss.detach())
@@ -359,6 +368,13 @@ def fit(
save_config(model.config, checkpoint_file.parent)


+def forward_and_loss(model: nn.Module, input_ids: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+logits = model(input_ids)
+# disable chunk_size to enable the unsloth cross entropy kernel
+loss = chunked_cross_entropy(logits, targets, chunk_size=0)
+return loss


@torch.no_grad()
def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor:
fabric.barrier()
@@ -371,8 +387,7 @@ def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max
break
input_ids = batch[:, 0 : model.max_seq_length].contiguous().long()
targets = batch[:, 1 : (model.max_seq_length + 1)].contiguous().long()
-logits = model(input_ids)
-loss = chunked_cross_entropy(logits, targets)
+loss = forward_and_loss(model, input_ids, targets)
losses.append(loss)

val_loss = torch.stack(losses).mean()
@@ -454,14 +469,14 @@ def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resu
raise ValueError("\n".join(issues))


-def jit(fn: Callable) -> Any:
+def jit(fn: Callable, executors: List[str]) -> Any:
+assert executors is not None
import thunder
-from thunder.executors.sdpaex import sdpa_ex
-from thunder.executors.torch_compile import torch_compile_executor
+from unsloth.executor import unsloth_ex # import for registration # noqa: F401
+from strategies.utils import _validate_executors

-return thunder.jit(
-fn, executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor]
-)
+executors = _validate_executors(executors)
+return thunder.jit(fn, executors=executors)


if __name__ == "__main__":
2 changes: 0 additions & 2 deletions extensions/thunder/strategies/thunder_ddp.py
@@ -140,8 +140,6 @@ def setup_module(self, module: Module) -> Module:
ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)
# update the compile data state
cd.fn = ddp_module
-assert hasattr(cd, "_processed_function") # sanity check
-cd._processed_function = ddp_module
cd.process_group_for_ddp = ddp_module.process_group_for_ddp
return module
else:
2 changes: 0 additions & 2 deletions extensions/thunder/strategies/thunder_fsdp.py
@@ -177,8 +177,6 @@ def setup_module(self, module: Module) -> Module:
)
# update the compile data state
cd.fn = fsdp_module
-assert hasattr(cd, "_processed_function") # sanity check
-cd._processed_function = fsdp_module
cd.process_group_for_ddp = fsdp_module.process_group_for_ddp
return module
else:
4 changes: 2 additions & 2 deletions extensions/thunder/strategies/utils.py
@@ -1,10 +1,10 @@
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union, Sequence

if TYPE_CHECKING:
from thunder import Executor


-def _validate_executors(executors: Optional[Tuple[Union["Executor", str], ...]]) -> Optional[Tuple["Executor", ...]]:
+def _validate_executors(executors: Optional[Sequence[Union["Executor", str]]]) -> Optional[Tuple["Executor", ...]]:
"""Converts string executors into it's respective ``Executor`` object."""
if executors is None:
return None
2 changes: 1 addition & 1 deletion extensions/thunder/unsloth/executor.py
@@ -18,7 +18,7 @@

import kernels

unsloth_ex = OperatorExecutor("unsloth_ex", version="0.1")
unsloth_ex = OperatorExecutor("unsloth", version="0.1")
register_executor(unsloth_ex)

