Merge branch 'main' into galore
rasbt authored May 6, 2024
2 parents bd73193 + 90a16e4 commit b5b472f
Showing 26 changed files with 133 additions and 218 deletions.
4 changes: 2 additions & 2 deletions .github/azure-gpu-test.yml
@@ -13,7 +13,7 @@ pr:

jobs:
- job: testing
timeoutInMinutes: "20"
timeoutInMinutes: "30"
cancelTimeoutInMinutes: "2"
pool: "lit-rtx-3090"
variables:
@@ -67,4 +67,4 @@ jobs:
env:
PL_RUN_CUDA_TESTS: "1"
displayName: "Standalone tests"
timeoutInMinutes: "5"
timeoutInMinutes: "10"
75 changes: 39 additions & 36 deletions extensions/thunder/README.md
@@ -461,12 +461,12 @@ from extensions.thunder.strategies import ThunderFSDPStrategy, ThunderDDPStrategy
strategy = ThunderFSDPStrategy(
sharding_strategy="ZERO3",
bucketing_strategy="BLOCK",
executors=("sdpa", "torchcompile", "nvfuser", "torch"),
executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"),
state_dict_type="full",
)

# replicated data parallel
strategy = ThunderDDPStrategy(executors=("sdpa", "torchcompile", "nvfuser", "torch"))
strategy = ThunderDDPStrategy(executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"))

fabric = L.Fabric(devices=devices, strategy=strategy)
fabric.launch()
@@ -482,12 +482,10 @@ Thunder allows you to define a priority list of executors that can map operators

```python
import thunder
from thunder.executors.sdpaex import sdpa_ex
from thunder.executors.torch_compile import torch_compile_executor

model = thunder.jit(
model,
executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor]
executors=["sdpa", "torchcompile_cat", "nvfuser", "torch"]
)
```
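
To check which executor actually claimed each operation after compilation, you can inspect the last computed trace. This is a rough sketch, assuming the `thunder.last_traces` API available in recent lightning-thunder nightlies; `inputs` is a placeholder for a real batch:

```python
# Run the jitted model once so Thunder records its traces, then print the final
# execution trace, which lists the executor-assigned and fused operations.
out = model(inputs)  # `inputs` stands in for an actual batch of token ids
trace = thunder.last_traces(model)[-1]
print(trace)
```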

@@ -507,11 +505,11 @@ We can enable this executor by passing it to the list of executors available. Th
`NvFuser` creates its fusion regions.

```python
from unsloth.executor import unsloth_ex
import thunder

model = thunder.jit(
model,
executors=[sdpa_ex, unsloth_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor]
executors=["sdpa", "unsloth", "torchcompile_cat", "nvfuser", "torch"]
)
```

@@ -543,21 +541,24 @@ Given the Unsloth results below, these hand-written kernels do not seem to be wo
We provide a version of the main pre-training script [that integrates Thunder](pretrain.py) that uses TinyLlama, a 1.1B parameter LLM.

| Setting | Compiler/JIT | Devices | ms/iter @ step 10 | Memory (GB) |
|----------------------|--------------|---------|-------------------|---------------|
| Fully-sharded ZeRO 3 | Eager | 8 | 460.88 | 22.13 |
| Fully-sharded ZeRO 3 | Inductor | 8 | Not supported | Not supported |
| Fully-sharded ZeRO 3 | Thunder | 8 | 332.48 | 21.40 |
| | | | | |
| Replicated | Eager | 8 | 535.28 | 32.05 |
| Replicated | Inductor | 8 | Not supported | Not supported |
| Replicated | Thunder | 8 | 368.25 | 27.42 |
| | | | | |
| - | Eager | 1 | 449.88 | 29.85 |
| - | Inductor | 1 | Not supported | Not supported |
| - | Thunder | 1 | 323.78 | 27.42 |
| | | | | |
| Unsloth | Thunder | 1 | 334.98 | 25.19 |
| Setting | Compiler | Executors | Devices | ms/iter @ step 10 | Memory (GB) |
|----------------------|----------|----------------------------------------|---------|-------------------|---------------|
| Fully-sharded ZeRO 3 | Eager | - | 8 | 456.57 | 22.13 |
| Fully-sharded ZeRO 3 | torch | - | 8 | Not supported | Not supported |
| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile | 8 | Not supported | Not supported |
| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 333.56 | 21.40 |
| | | | | | |
| Replicated | Eager | - | 8 | 569.46 | 32.04 |
| Replicated | torch | - | 8 | Not supported | Not supported |
| Replicated | Thunder | sdpa, torchcompile | 8 | 426.44 | 22.19 |
| Replicated | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 356.01 | 27.42 |
| | | | | | |
| - | Eager | - | 1 | 447.65 | 29.84 |
| - | torch | - | 1 | Not supported | Not supported |
| - | Thunder | sdpa, torchcompile | 1 | 373.37 | 22.19 |
| - | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 322.25 | 27.42 |
| | | | | | |
| Unsloth | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 331.92 | 25.19 |

<details>
<summary>Reproduction details</summary>
@@ -567,45 +568,47 @@ Config:
```yaml
out_dir: out/pretrain-thunder
data: TinyStories
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf
tokenizer_dir: checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0
logger_name: csv
```
Commands:
```bash
litgpt download --repo_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tokenizer_only true

python extensions/thunder/pretrain.py --config config.yaml --compiler null --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --compiler torch --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --train.global_batch_size 32
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --train.global_batch_size 32

python extensions/thunder/pretrain.py --config config.yaml --compiler null --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --compiler torch --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --strategy ddp
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --strategy ddp

python extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1
python extensions/thunder/pretrain.py --config config.yaml --compiler torch --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --devices 1

python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile, nvfuser, torch]' --devices 1
python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile_cat, nvfuser, torch]' --devices 1
```

Gradient accumulation is disabled in the FSDP setting because Thunder does not support skipping the backward synchronization yet.
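
For reference, this is a minimal sketch of the gradient-accumulation pattern being referred to, using Fabric's `no_backward_sync` to skip the gradient synchronization on all but the last micro-batch (the loop variables, `loss_fn`, and `gradient_accumulation_steps` are illustrative):

```python
for i, (input_ids, targets) in enumerate(train_dataloader):
    is_accumulating = (i + 1) % gradient_accumulation_steps != 0
    # Skip the gradient all-reduce / reduce-scatter while accumulating; this is the
    # synchronization-skipping step that Thunder's FSDP strategy does not support yet.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        logits = model(input_ids)
        loss = loss_fn(logits, targets)  # `loss_fn` is a placeholder, e.g. cross-entropy
        fabric.backward(loss / gradient_accumulation_steps)
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()
```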

`torch.compile` does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601
`--compiler torch` (`torch.compile` without `thunder`) is not included because it does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601

The CUDA devices are all NVIDIA A100-SXM4-40GB.

```text
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python version: 3.10.12 [GCC 11.4.0] (64-bit runtime)
Is debug build: False
CUDA used to build PyTorch: 12.1
CUDA runtime version: 12.3.107
Nvidia driver version: 545.23.08
pytorch-triton==3.0.0+989adb9a29
torch==2.4.0.dev20240326+cu121
pytorch-triton==3.0.0+45fff310c8
torch==2.4.0.dev20240427+cu121
lightning==2.3.0.dev20240328
lightning-thunder==0.2.0.dev20240404
nvfuser_cu121==0.2.0.dev20240327
lightning-thunder==0.2.0.dev20240505
nvfuser_cu121==0.2.3.dev20240428
```

</details>
2 changes: 0 additions & 2 deletions extensions/thunder/pretrain.py
@@ -482,9 +482,7 @@ def jit(fn: Callable, executors: List[str]) -> Any:
assert executors is not None
import thunder
from unsloth.executor import unsloth_ex # import for registration # noqa: F401
from strategies.utils import _validate_executors

executors = _validate_executors(executors)
return thunder.jit(fn, executors=executors)


4 changes: 1 addition & 3 deletions extensions/thunder/strategies/thunder_ddp.py
@@ -28,8 +28,6 @@
from torch.nn import Module
from typing_extensions import override

from .utils import _validate_executors

if TYPE_CHECKING:
from thunder import Executor

@@ -74,7 +72,7 @@ def __init__(
if not jit and executors is not None:
raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
self.jit = jit
self.executors = _validate_executors(executors)
self.executors = executors
self._num_nodes = 1
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
4 changes: 1 addition & 3 deletions extensions/thunder/strategies/thunder_fsdp.py
@@ -32,8 +32,6 @@
from torch.optim import Optimizer
from typing_extensions import override

from .utils import _validate_executors

if TYPE_CHECKING:
from thunder import Executor
from thunder.distributed import FSDPBucketingStrategy, FSDPType
@@ -122,7 +120,7 @@ def __init__(
if not jit and executors is not None:
raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`")
self.jit = jit
self.executors = _validate_executors(executors)
self.executors = executors
self._state_dict_type = state_dict_type
self._fsdp_kwargs = kwargs

28 changes: 0 additions & 28 deletions extensions/thunder/strategies/utils.py

This file was deleted.

63 changes: 27 additions & 36 deletions litgpt/lora.py
@@ -215,6 +215,7 @@ def __init__(
"""
super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
self.head_size = head_size
self.n_head = n_head
self.n_query_groups = n_query_groups
if isinstance(enable_lora, bool):
@@ -258,30 +259,34 @@ def __init__(
# https://github.com/cloneofsimo/lora
self.scaling = self.lora_alpha / self.r

# Compute the indices
# Indices are needed to properly pad weight updates with zeros in `zero_pad` method.
q_per_kv = self.n_head // self.n_query_groups
total_qkv = q_per_kv + 2
head_size = out_features // (self.n_query_groups * total_qkv)
ind = range(out_features)
self.reset_parameters()

@property
def lora_ind(self) -> torch.Tensor:
"""Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
# Indices are needed to properly pad weight updates with zeros.
if not hasattr(self, "_lora_ind"):
enable_q, enable_k, enable_v = self.enable_lora
qkv_group_size = self.n_head // self.n_query_groups + 2
candidate_indices = range(self.linear.out_features)
lora_ind = []
if enable_q:
q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
q_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size < qkv_group_size - 2]
lora_ind.extend(q_ind)
if enable_k:
k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
k_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 2]
lora_ind.extend(k_ind)
if enable_v:
v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
v_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 1]
lora_ind.extend(v_ind)
self._lora_ind = torch.tensor(lora_ind)
self._lora_ind_cache = {self._lora_ind.device: self._lora_ind}
self.reset_parameters()

self.register_buffer(
"_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
)

return self._lora_ind

def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
"""Properly pad weight updates with zeros.
"""Properly pad the last dimension of weight updates with zeros.
If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
then the weights update should be:
@@ -332,20 +337,9 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
# ⚬ enable_lora: [True, False, True]
# Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
# embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
# only for key updates (this is where lora_ind comes in handy)
# Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
# for example when we want to merge/unmerge LoRA weights and pretrained weights
x = x.transpose(0, 1)
result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384)
result = result.view(-1, self.linear.out_features) # (4096, 384)

# `lora_ind` is constant, so we want to avoid copying it (and incurring an expensive cudaStreamSynchronize)
# every time this method is called. So instead we simply cache a copy on each device that needs it.
if (lora_ind := self._lora_ind_cache.get(result.device)) is None:
self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device)

result = result.index_copy(1, lora_ind, x.reshape(-1, sum(self.qkv_shapes))) # (4096, 256)
return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384)
# only for key updates (this is where self.lora_ind comes in handy)
result = x.new_zeros(*x.shape[:-1], self.linear.out_features) # (64, 64, 384)
return result.index_copy_(dim=-1, index=self.lora_ind, source=x) # (64, 64, 384)
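
As a self-contained toy illustration of the `index_copy_`-based padding used above (the shapes and indices here are made up for readability; in `LoRALinear` the index is `self.lora_ind` and the target width is `self.linear.out_features`):

```python
import torch

x = torch.arange(1.0, 5.0).reshape(1, 4)  # a (1, 4) "update" covering only the q and v columns
lora_ind = torch.tensor([0, 1, 6, 7])     # where those columns live in an 8-wide fused qkv output
result = x.new_zeros(1, 8)                # zeros with the full fused-qkv width
result.index_copy_(-1, lora_ind, x)       # -> tensor([[1., 2., 0., 0., 0., 0., 3., 4.]])
```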

def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
@@ -379,7 +373,8 @@ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T)
weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1)
return torch.cat(
[F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T)
[F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)],
dim=1, # (B, C_output', T)
) # (B, C_output, T)

def get_lora_AB(self) -> torch.Tensor:
@@ -391,10 +386,8 @@ def get_lora_AB(self) -> torch.Tensor:
lora = self.conv1d(
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).squeeze(
0
) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
return self.zero_pad(lora * self.scaling) # (256, 128) after zero_pad (384, 128)
).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128)

def merge(self) -> None:
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
@@ -432,9 +425,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
after_B = self.conv1d(
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).transpose(
-2, -1
) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
return pretrained + lora

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ test = [
"transformers>=4.38.0", # numerical comparisons
"einops>=0.7.0",
"protobuf>=4.23.4",
"lightning-thunder==0.2.0.dev20240404; python_version >= '3.10'",
"lightning-thunder==0.2.0.dev20240505; python_version >= '3.10'",
]
all = [
"bitsandbytes==0.42.0", # quantization
7 changes: 7 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import warnings

import pytest

warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")
13 changes: 8 additions & 5 deletions tests/conftest.py
@@ -2,7 +2,6 @@

import os
import shutil
import warnings
from pathlib import Path
from typing import List, Optional

@@ -50,6 +49,14 @@ def restore_default_dtype():
torch.set_default_dtype(torch.float32)


@pytest.fixture(autouse=True)
def destroy_process_group():
import torch.distributed

if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.destroy_process_group()


class MockTokenizer:
"""A dummy tokenizer that encodes each character as its ASCII code."""

@@ -149,7 +156,3 @@ def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.C
bold=True,
purple=True, # oh yeah, branded pytest messages
)


# Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel)
warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")