diff --git a/.github/azure-gpu-test.yml b/.github/azure-gpu-test.yml index b9b5b32c35..b5a5c9889e 100644 --- a/.github/azure-gpu-test.yml +++ b/.github/azure-gpu-test.yml @@ -13,7 +13,7 @@ pr: jobs: - job: testing - timeoutInMinutes: "20" + timeoutInMinutes: "30" cancelTimeoutInMinutes: "2" pool: "lit-rtx-3090" variables: @@ -67,4 +67,4 @@ jobs: env: PL_RUN_CUDA_TESTS: "1" displayName: "Standalone tests" - timeoutInMinutes: "5" + timeoutInMinutes: "10" diff --git a/extensions/thunder/README.md b/extensions/thunder/README.md index 715e7745fa..a494248907 100644 --- a/extensions/thunder/README.md +++ b/extensions/thunder/README.md @@ -461,12 +461,12 @@ from extensions.thunder.strategies import ThunderFSDPStrategy, ThunderDDPStrateg strategy = ThunderFSDPStrategy( sharding_strategy="ZERO3", bucketing_strategy="BLOCK", - executors=("sdpa", "torchcompile", "nvfuser", "torch"), + executors=("sdpa", "torchcompile_cat", "nvfuser", "torch"), state_dict_type="full", ) # replicated data parallel -strategy = ThunderDDPStrategy(executors=("sdpa", "torchcompile", "nvfuser", "torch")) +strategy = ThunderDDPStrategy(executors=("sdpa", "torchcompile_cat", "nvfuser", "torch")) fabric = L.Fabric(devices=devices, strategy=strategy) fabric.launch() @@ -482,12 +482,10 @@ Thunder allows you to define a priority list of executors that can map operators ```python import thunder -from thunder.executors.sdpaex import sdpa_ex -from thunder.executors.torch_compile import torch_compile_executor model = thunder.jit( model, - executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor] + executors=["sdpa", "torchcompile_cat", "nvfuser", "torch"] ) ``` @@ -507,11 +505,11 @@ We can enable this executor by passing it to the list of executors available. Th `NvFuser` creates its fusion regions. ```python -from unsloth.executor import unsloth_ex +import thunder model = thunder.jit( model, - executors=[sdpa_ex, unsloth_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor] + executors=["sdpa", "unsloth", "torchcompile_cat", "nvfuser", "torch"] ) ``` @@ -543,21 +541,24 @@ Given the Unsloth results below, these hand-written kernels do not seem to be wo We provide a version of the main pre-training script [that integrates Thunder](pretrain.py) that uses TinyLlama, a 1.1B parameter LLM. -| Setting | Compiler/JIT | Devices | ms/iter @ step 10 | Memory (GB) | -|----------------------|--------------|---------|-------------------|---------------| -| Fully-sharded ZeRO 3 | Eager | 8 | 460.88 | 22.13 | -| Fully-sharded ZeRO 3 | Inductor | 8 | Not supported | Not supported | -| Fully-sharded ZeRO 3 | Thunder | 8 | 332.48 | 21.40 | -| | | | | | -| Replicated | Eager | 8 | 535.28 | 32.05 | -| Replicated | Inductor | 8 | Not supported | Not supported | -| Replicated | Thunder | 8 | 368.25 | 27.42 | -| | | | | | -| - | Eager | 1 | 449.88 | 29.85 | -| - | Inductor | 1 | Not supported | Not supported | -| - | Thunder | 1 | 323.78 | 27.42 | -| | | | | | -| Unsloth | Thunder | 1 | 334.98 | 25.19 | +| Setting | Compiler | Executors | Devices | ms/iter @ step 10 | Memory (GB) | +|----------------------|----------|----------------------------------------|---------|-------------------|---------------| +| Fully-sharded ZeRO 3 | Eager | - | 8 | 456.57 | 22.13 | +| Fully-sharded ZeRO 3 | torch | - | 8 | Not supported | Not supported | +| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile | 8 | Not supported | Not supported | +| Fully-sharded ZeRO 3 | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 333.56 | 21.40 | +| | | | | | | +| Replicated | Eager | - | 8 | 569.46 | 32.04 | +| Replicated | torch | - | 8 | Not supported | Not supported | +| Replicated | Thunder | sdpa, torchcompile | 8 | 426.44 | 22.19 | +| Replicated | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 8 | 356.01 | 27.42 | +| | | | | | | +| - | Eager | - | 1 | 447.65 | 29.84 | +| - | torch | - | 1 | Not supported | Not supported | +| - | Thunder | sdpa, torchcompile | 1 | 373.37 | 22.19 | +| - | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 322.25 | 27.42 | +| | | | | | | +| Unsloth | Thunder | sdpa, torchcompile_cat, nvfuser, torch | 1 | 331.92 | 25.19 |
Reproduction details @@ -567,45 +568,47 @@ Config: ```yaml out_dir: out/pretrain-thunder data: TinyStories -tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf +tokenizer_dir: checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0 logger_name: csv ``` Commands: ```bash +litgpt download --repo_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tokenizer_only true + python extensions/thunder/pretrain.py --config config.yaml --compiler null --train.global_batch_size 32 -python extensions/thunder/pretrain.py --config config.yaml --compiler torch --train.global_batch_size 32 -python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --train.global_batch_size 32 +python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --train.global_batch_size 32 +python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --train.global_batch_size 32 python extensions/thunder/pretrain.py --config config.yaml --compiler null --strategy ddp -python extensions/thunder/pretrain.py --config config.yaml --compiler torch --strategy ddp -python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --strategy ddp +python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --strategy ddp +python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --strategy ddp python extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1 -python extensions/thunder/pretrain.py --config config.yaml --compiler torch --devices 1 -python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile, nvfuser, torch]' --devices 1 +python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --devices 1 +python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --devices 1 -python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile, nvfuser, torch]' --devices 1 +python extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile_cat, nvfuser, torch]' --devices 1 ``` Gradient accumulation is disabled in the FSDP setting because Thunder does not support skipping the backward synchronization yet. -`torch.compile` does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601 +`--compiler torch` (`torch.compile` without `thunder`) is not include because it does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601 The CUDA devices are all NVIDIA A100-SXM4-40GB. ```text -Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime) +Python version: 3.10.12 [GCC 11.4.0] (64-bit runtime) Is debug build: False CUDA used to build PyTorch: 12.1 CUDA runtime version: 12.3.107 Nvidia driver version: 545.23.08 -pytorch-triton==3.0.0+989adb9a29 -torch==2.4.0.dev20240326+cu121 +pytorch-triton==3.0.0+45fff310c8 +torch==2.4.0.dev20240427+cu121 lightning==2.3.0.dev20240328 -lightning-thunder==0.2.0.dev20240404 -nvfuser_cu121==0.2.0.dev20240327 +lightning-thunder==0.2.0.dev20240505 +nvfuser_cu121==0.2.3.dev20240428 ```
diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py index 24f140c9df..6aa77a745f 100644 --- a/extensions/thunder/pretrain.py +++ b/extensions/thunder/pretrain.py @@ -482,9 +482,7 @@ def jit(fn: Callable, executors: List[str]) -> Any: assert executors is not None import thunder from unsloth.executor import unsloth_ex # import for registration # noqa: F401 - from strategies.utils import _validate_executors - executors = _validate_executors(executors) return thunder.jit(fn, executors=executors) diff --git a/extensions/thunder/strategies/thunder_ddp.py b/extensions/thunder/strategies/thunder_ddp.py index 29b2af8980..9717e9b3a1 100644 --- a/extensions/thunder/strategies/thunder_ddp.py +++ b/extensions/thunder/strategies/thunder_ddp.py @@ -28,8 +28,6 @@ from torch.nn import Module from typing_extensions import override -from .utils import _validate_executors - if TYPE_CHECKING: from thunder import Executor @@ -74,7 +72,7 @@ def __init__( if not jit and executors is not None: raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`") self.jit = jit - self.executors = _validate_executors(executors) + self.executors = executors self._num_nodes = 1 self._process_group_backend: Optional[str] = process_group_backend self._timeout: Optional[timedelta] = timeout diff --git a/extensions/thunder/strategies/thunder_fsdp.py b/extensions/thunder/strategies/thunder_fsdp.py index 55b30bdf66..fe1719b29c 100644 --- a/extensions/thunder/strategies/thunder_fsdp.py +++ b/extensions/thunder/strategies/thunder_fsdp.py @@ -32,8 +32,6 @@ from torch.optim import Optimizer from typing_extensions import override -from .utils import _validate_executors - if TYPE_CHECKING: from thunder import Executor from thunder.distributed import FSDPBucketingStrategy, FSDPType @@ -122,7 +120,7 @@ def __init__( if not jit and executors is not None: raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`") self.jit = jit - self.executors = _validate_executors(executors) + self.executors = executors self._state_dict_type = state_dict_type self._fsdp_kwargs = kwargs diff --git a/extensions/thunder/strategies/utils.py b/extensions/thunder/strategies/utils.py deleted file mode 100644 index b7132cdbf1..0000000000 --- a/extensions/thunder/strategies/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import TYPE_CHECKING, Optional, Tuple, Union, Sequence - -if TYPE_CHECKING: - from thunder import Executor - - -def _validate_executors(executors: Optional[Sequence[Union["Executor", str]]]) -> Optional[Tuple["Executor", ...]]: - """Converts string executors into it's respective ``Executor`` object.""" - if executors is None: - return None - from thunder import get_all_executors - - final = [] - issues = [] - all = get_all_executors() - for executor in executors: - if isinstance(executor, str): - for existing in all: - if executor == existing.name: - final.append(existing) - break - else: - issues.append(executor) - else: - final.append(executor) - if issues: - raise ValueError(f"Did not find the executors {issues} in {all}") - return tuple(final) diff --git a/litgpt/lora.py b/litgpt/lora.py index 8fee63cbb6..7c4ae423e0 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -215,6 +215,7 @@ def __init__( """ super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + self.head_size = head_size self.n_head = n_head self.n_query_groups = n_query_groups if isinstance(enable_lora, bool): @@ -258,30 +259,34 @@ def __init__( # https://github.com/cloneofsimo/lora self.scaling = self.lora_alpha / self.r - # Compute the indices - # Indices are needed to properly pad weight updates with zeros in `zero_pad` method. - q_per_kv = self.n_head // self.n_query_groups - total_qkv = q_per_kv + 2 - head_size = out_features // (self.n_query_groups * total_qkv) - ind = range(out_features) + self.reset_parameters() + + @property + def lora_ind(self) -> torch.Tensor: + """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used.""" + # Indices are needed to properly pad weight updates with zeros. + if not hasattr(self, "_lora_ind"): + enable_q, enable_k, enable_v = self.enable_lora + qkv_group_size = self.n_head // self.n_query_groups + 2 + candidate_indices = range(self.linear.out_features) lora_ind = [] if enable_q: - q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2] + q_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size < qkv_group_size - 2] lora_ind.extend(q_ind) if enable_k: - k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2] + k_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 2] lora_ind.extend(k_ind) if enable_v: - v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1] + v_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 1] lora_ind.extend(v_ind) - self._lora_ind = torch.tensor(lora_ind) - self._lora_ind_cache = {self._lora_ind.device: self._lora_ind} - self.reset_parameters() - + self.register_buffer( + "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False + ) + return self._lora_ind def zero_pad(self, x: torch.Tensor) -> torch.Tensor: - """Properly pad weight updates with zeros. + """Properly pad the last dimension of weight updates with zeros. If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, then the weights update should be: @@ -332,20 +337,9 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: # ⚬ enable_lora: [True, False, True] # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but - # only for key updates (this is where lora_ind comes in handy) - # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors - # for example when we want to merge/unmerge LoRA weights and pretrained weights - x = x.transpose(0, 1) - result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384) - result = result.view(-1, self.linear.out_features) # (4096, 384) - - # `lora_ind` is constant, so we want to avoid copying it (and incurring an expensive cudaStreamSynchronize) - # every time this method is called. So instead we simply cache a copy on each device that needs it. - if (lora_ind := self._lora_ind_cache.get(result.device)) is None: - self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device) - - result = result.index_copy(1, lora_ind, x.reshape(-1, sum(self.qkv_shapes))) # (4096, 256) - return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384) + # only for key updates (this is where self.lora_ind comes in handy) + result = x.new_zeros(*x.shape[:-1], self.linear.out_features) # (64, 64, 384) + return result.index_copy_(dim=-1, index=self.lora_ind, source=x) # (64, 64, 384) def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries. @@ -379,7 +373,8 @@ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T) weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1) return torch.cat( - [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T) + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], + dim=1, # (B, C_output', T) ) # (B, C_output, T) def get_lora_AB(self) -> torch.Tensor: @@ -391,10 +386,8 @@ def get_lora_AB(self) -> torch.Tensor: lora = self.conv1d( self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).squeeze( - 0 - ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) - return self.zero_pad(lora * self.scaling) # (256, 128) after zero_pad (384, 128) + ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128) def merge(self) -> None: """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" @@ -432,9 +425,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: after_B = self.conv1d( after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).transpose( - -2, -1 - ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) return pretrained + lora diff --git a/pyproject.toml b/pyproject.toml index ab7b9b26f3..40029ae13b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ test = [ "transformers>=4.38.0", # numerical comparisons "einops>=0.7.0", "protobuf>=4.23.4", - "lightning-thunder==0.2.0.dev20240404; python_version >= '3.10'", + "lightning-thunder==0.2.0.dev20240505; python_version >= '3.10'", ] all = [ "bitsandbytes==0.42.0", # quantization diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..2f22d66b14 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,7 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import warnings + +import pytest + +warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*") diff --git a/tests/conftest.py b/tests/conftest.py index fdfe2295eb..fa22e514c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ import os import shutil -import warnings from pathlib import Path from typing import List, Optional @@ -50,6 +49,14 @@ def restore_default_dtype(): torch.set_default_dtype(torch.float32) +@pytest.fixture(autouse=True) +def destroy_process_group(): + import torch.distributed + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + class MockTokenizer: """A dummy tokenizer that encodes each character as its ASCII code.""" @@ -149,7 +156,3 @@ def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.C bold=True, purple=True, # oh yeah, branded pytest messages ) - - -# Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel) -warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*") diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 2028a78b83..9e724ab0e8 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -9,7 +9,7 @@ import pytest import torch import yaml -from conftest import RunIf +from tests.conftest import RunIf from lightning import Fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 33f00a3166..5f63697e9e 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -8,7 +8,7 @@ import pytest import torch import yaml -from conftest import RunIf +from tests.conftest import RunIf from lightning import Fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer diff --git a/tests/test_ci.py b/tests/test_ci.py index d553b53e16..e1db31aeaf 100644 --- a/tests/test_ci.py +++ b/tests/test_ci.py @@ -1,6 +1,6 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. -from conftest import RunIf +from tests.conftest import RunIf from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index ca4ee9881e..e5b1b889c0 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -2,14 +2,11 @@ import os from dataclasses import asdict -from pathlib import Path from unittest.mock import ANY -from urllib.request import urlretrieve import pytest import torch import yaml -from conftest import RunIf from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.falcon import FalconConfig, FalconForCausalLM from transformers.models.gemma import GemmaConfig, GemmaForCausalLM @@ -27,6 +24,7 @@ copy_weights_phi, qkv_split, ) +from tests.conftest import RunIf def test_convert_lit_checkpoint(tmp_path): diff --git a/tests/test_generate.py b/tests/test_generate.py index 5f950ddcfa..7cd0dca9db 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -121,7 +121,9 @@ def test_generate_different_results_with_different_top_p(): torch.manual_seed(123) input_idx = torch.randint(10, size=(1,)) + torch.manual_seed(123) output1 = generate.generate(model, input_idx, 20, top_p=1.0) + torch.manual_seed(123) output2 = generate.generate(model, input_idx, 20, top_p=0.1) assert not torch.equal(output1, output2) diff --git a/tests/test_generate_sequentially.py b/tests/test_generate_sequentially.py index 4bc3665f97..b0bed4797e 100644 --- a/tests/test_generate_sequentially.py +++ b/tests/test_generate_sequentially.py @@ -11,7 +11,7 @@ import pytest import torch import yaml -from conftest import RunIf +from tests.conftest import RunIf from lightning import Fabric from litgpt import Config diff --git a/tests/test_generate_tp.py b/tests/test_generate_tp.py index eb0505219c..039dd0ea4b 100644 --- a/tests/test_generate_tp.py +++ b/tests/test_generate_tp.py @@ -7,12 +7,12 @@ import pytest import torch import yaml -from conftest import RunIf -from test_generate_sequentially import find_forward_hooks from litgpt import GPT, Config from litgpt.generate.tp import tensor_parallel, tensor_parallel_linear from litgpt.scripts.download import download_from_hub +from tests.conftest import RunIf +from tests.test_generate_sequentially import find_forward_hooks def test_tensor_parallel_linear(): diff --git a/tests/test_lora.py b/tests/test_lora.py index d131411d9c..d283f1cf44 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -9,7 +9,6 @@ import pytest import torch import yaml -from conftest import RunIf from lightning import Fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer @@ -22,11 +21,12 @@ import litgpt.finetune.lora as module from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca -from litgpt.lora import GPT as LoRAGPT from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention from litgpt.lora import Config, LoRALinear, LoRAQKVLinear, lora_filter, mark_only_lora_as_trainable, merge_lora_weights +from litgpt.lora import GPT as LoRAGPT from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_hf_llama +from tests.conftest import RunIf def test_lora_layer_replacement(): @@ -107,7 +107,7 @@ def test_lora_mqa_gqa(): assert attn.linear.weight.shape == (24, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (16, 2) - assert torch.equal(attn._lora_ind, torch.tensor(lora_ind)) + assert torch.equal(attn.lora_ind, torch.tensor(lora_ind)) x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 24) bsz, ctx_len, in_dim = 2, 30, 8 @@ -128,7 +128,7 @@ def test_lora_mqa_gqa(): assert attn.linear.weight.shape == (12, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (10, 2) - assert torch.equal(attn._lora_ind, torch.tensor(lora_ind)) + assert torch.equal(attn.lora_ind, torch.tensor(lora_ind)) x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 12) bsz, ctx_len, in_dim = 2, 30, 8 @@ -149,7 +149,7 @@ def test_lora_mqa_gqa(): assert attn.linear.weight.shape == (16, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (12, 2) - assert torch.equal(attn._lora_ind, torch.tensor(lora_ind)) + assert torch.equal(attn.lora_ind, torch.tensor(lora_ind)) x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 16) bsz, ctx_len, in_dim = 2, 30, 8 @@ -733,3 +733,28 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa logs = stdout.getvalue() assert "of trainable parameters: 512" in logs assert "of non-trainable parameters: 1,888" in logs + + +@RunIf(standalone=True, min_cuda_gpus=2) +def test_lora_model_fsdp_init(): + config = Config( + n_layer=1, + n_head=2, + n_embd=8, + block_size=8, + vocab_size=8, + lora_r=8, + lora_alpha=8, + lora_dropout=0.1, + lora_query=True, + lora_value=False, + lora_projection=True, + ) + fabric = Fabric(devices=2, strategy="fsdp", precision="16-true") + fabric.launch() + with fabric.init_module(empty_init=True): + model = LoRAGPT(config) + x = torch.randint(0, config.padded_vocab_size, size=(2, config.block_size), dtype=torch.int64, device=fabric.device) + model = fabric.setup(model) + y = model(x) + assert y.shape == torch.Size([2, 8, 512]) diff --git a/tests/test_model.py b/tests/test_model.py index 49584aeb87..1cad36a8db 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,12 +2,9 @@ from copy import deepcopy from functools import partial -from pathlib import Path -from urllib.request import urlretrieve import pytest import torch -from conftest import RunIf from lightning import Fabric from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.init import _materialize_meta_tensors @@ -37,6 +34,7 @@ copy_weights_hf_llama, copy_weights_phi, ) +from tests.conftest import RunIf @torch.inference_mode() @@ -417,64 +415,6 @@ def test_against_hf_mixtral(): torch.testing.assert_close(ours_y, theirs_y) -@torch.inference_mode() -@pytest.mark.parametrize( - ("device", "dtype"), - [ - (torch.device("cpu"), torch.float32), - pytest.param( - torch.device("cuda"), - torch.float16, - marks=[ - # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input - # is slightly different - pytest.mark.xfail(raises=AssertionError, strict=False), - RunIf(min_cuda_gpus=1), - ], - ), - ], -) -def test_against_hf_h2o_danube(device, dtype): - torch.set_default_dtype(dtype) - - ours_config = Config.from_name( - "Danube2-1.8b-chat", - padded_vocab_size=10000, - n_layer=2, - n_embd=16, - n_head=8, - n_query_groups=2, - intermediate_size=43, - ) - T = 5 - theirs_config = MistralConfig( - vocab_size=ours_config.padded_vocab_size, - hidden_size=ours_config.n_embd, - num_attention_heads=ours_config.n_head, - num_hidden_layers=ours_config.n_layer, - intermediate_size=ours_config.intermediate_size, - max_position_embeddings=T, - rms_norm_eps=ours_config.norm_eps, - num_key_value_heads=ours_config.n_query_groups, - rope_theta=ours_config.rope_base, - ) - assert ours_config.intermediate_size == theirs_config.intermediate_size - - theirs_model = MistralForCausalLM(theirs_config).to(device) - theirs_state_dict = theirs_model.state_dict() - state_dict = {} - copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) - ours_model = GPT(ours_config).to(device) - ours_model.load_state_dict(state_dict) - - # test end to end - x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) - assert x.size(1) == T - ours_y = ours_model(x) - theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float - torch.testing.assert_close(ours_y, theirs_y) - - @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), diff --git a/tests/test_pretrain.py b/tests/test_pretrain.py index d252524e87..9a67b85917 100644 --- a/tests/test_pretrain.py +++ b/tests/test_pretrain.py @@ -3,21 +3,19 @@ import os from contextlib import redirect_stdout from io import StringIO -from pathlib import Path from unittest import mock from unittest.mock import ANY, Mock import pytest import torch -from conftest import RunIf from lightning.fabric.strategies import FSDPStrategy, SingleDeviceStrategy from torch.utils.data import DataLoader -from test_utils import test_init_out_dir from litgpt import pretrain from litgpt.args import EvalArgs, TrainArgs from litgpt.config import Config from litgpt.pretrain import initialize_weights +from tests.conftest import RunIf @RunIf(min_cuda_gpus=2, standalone=True) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 20f2c84e0c..206db11c28 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -112,6 +112,6 @@ def test_save_load_prompt_style(tmp_path): save_prompt_style(CustomPromptStyle(), checkpoint_dir) with open(checkpoint_dir / "prompt_style.yaml", "r", encoding="utf-8") as file: contents = yaml.safe_load(file) - assert contents == {"class_path": "test_prompts.CustomPromptStyle"} + assert contents == {"class_path": "tests.test_prompts.CustomPromptStyle"} loaded = load_prompt_style(checkpoint_dir) assert isinstance(loaded, CustomPromptStyle) diff --git a/tests/test_thunder_ddp.py b/tests/test_thunder_ddp.py index 566e883ac3..2dbc208889 100644 --- a/tests/test_thunder_ddp.py +++ b/tests/test_thunder_ddp.py @@ -3,7 +3,7 @@ import pytest import torch -from conftest import RunIf +from tests.conftest import RunIf from lightning import Fabric # support running without installing as a package diff --git a/tests/test_thunder_fsdp.py b/tests/test_thunder_fsdp.py index 321cdac7a6..84de117574 100644 --- a/tests/test_thunder_fsdp.py +++ b/tests/test_thunder_fsdp.py @@ -1,12 +1,11 @@ import os -import re import sys from pathlib import Path from typing import Optional, Tuple, Union import pytest import torch -from conftest import RunIf +from tests.conftest import RunIf from lightning.fabric import Fabric from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 @@ -15,35 +14,20 @@ sys.path.append(str(wd)) from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy -from extensions.thunder.strategies.utils import _validate_executors @RunIf(thunder=True) def test_thunder_strategy_input_parsing(): - from thunder import pythonex from thunder.distributed import FSDPBucketingStrategy, FSDPType strategy = ThunderFSDPStrategy(bucketing_strategy="BlOcK", executors=("python",), sharding_strategy="zero3") assert strategy.bucketing_strategy is FSDPBucketingStrategy.BLOCK - assert strategy.executors == (pythonex,) assert strategy.sharding_strategy is FSDPType.ZERO3 with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"): ThunderFSDPStrategy(jit=False, executors=("python",)) -@RunIf(thunder=True) -def test_validate_executors(): - from thunder import pythonex, pytorch_executor - - assert _validate_executors(None) is None - assert _validate_executors((pythonex, pytorch_executor)) == (pythonex, pytorch_executor) - assert _validate_executors(("python", "torch")) == (pythonex, pytorch_executor) - assert _validate_executors(("python", pytorch_executor)) == (pythonex, pytorch_executor) - with pytest.raises(ValueError, match=re.escape("not find the executors ['foo', 'bar'] in")): - assert _validate_executors(("python", "foo", pytorch_executor, "bar")) - - @RunIf(thunder=True) def test_save_checkpoint_invalid_settings_raise(tmp_path): strategy = ThunderFSDPStrategy(state_dict_type="full") @@ -263,8 +247,6 @@ def set_up_planner(self, state_dict, metadata, is_coordinator): @RunIf(min_cuda_gpus=2, thunder=True, standalone=True) def test_save_load_sharded_checkpoint(tmp_path): - pytest.skip("Temporarily disabled, often exceeds 5 min timeout") - strategy = ThunderFSDPStrategy(state_dict_type="sharded", broadcast_from=0) fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) fabric.launch() diff --git a/tests/test_thunder_pretrain.py b/tests/test_thunder_pretrain.py index 358c0d5c59..30f9d71afb 100644 --- a/tests/test_thunder_pretrain.py +++ b/tests/test_thunder_pretrain.py @@ -6,7 +6,7 @@ from unittest.mock import Mock import torch -from conftest import RunIf +from tests.conftest import RunIf from torch.utils.data import DataLoader from litgpt import Config diff --git a/tests/test_unsloth_executor.py b/tests/test_unsloth_executor.py index 797d1f6f53..15b1c7c673 100644 --- a/tests/test_unsloth_executor.py +++ b/tests/test_unsloth_executor.py @@ -1,10 +1,10 @@ import pytest import torch -from conftest import RunIf from litgpt import GPT, Config from litgpt.model import apply_rope, build_rope_cache from litgpt.utils import chunked_cross_entropy +from tests.conftest import RunIf @RunIf(min_cuda_gpus=1, thunder=True) diff --git a/tests/test_utils.py b/tests/test_utils.py index cbb5230621..9770bf98e7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,7 +11,7 @@ import torch import torch.nn.functional as F import yaml -from conftest import RunIf +from tests.conftest import RunIf from lightning import Fabric from lightning.fabric.loggers import CSVLogger, TensorBoardLogger from lightning.fabric.plugins import BitsandbytesPrecision