Add SDPA backend test (#667)
carmocca authored Oct 24, 2023
1 parent 60de2d0 commit 50b1974
Showing 14 changed files with 103 additions and 63 deletions.
2 changes: 0 additions & 2 deletions finetune/adapter.py
@@ -283,8 +283,6 @@ def save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path:


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI
2 changes: 0 additions & 2 deletions finetune/adapter_v2.py
@@ -283,8 +283,6 @@ def save_adapter_v2_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_pa


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI
2 changes: 0 additions & 2 deletions finetune/full.py
@@ -276,8 +276,6 @@ def save_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI
2 changes: 0 additions & 2 deletions finetune/lora.py
@@ -326,8 +326,6 @@ def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Pa


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI
13 changes: 5 additions & 8 deletions lit_gpt/model.py
@@ -200,9 +200,10 @@ def forward(
# split batched computation into three
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

# repeat k and v if necessary
if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!)
# for MHA this is a no-op
# maybe repeat k and v if for the non multi-head attention cases
# training: flash attention requires it
# inference: multi-query would require a full kv cache so avoid it to limit its memory usage
if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)

@@ -289,11 +290,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def build_rope_cache(
seq_len: int,
n_elem: int,
device: Optional[torch.device] = None,
base: int = 10000,
condense_ratio: int = 1,
seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Enhanced Transformer with Rotary Position Embedding.
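The condition change in the first hunk above controls when k and v are expanded for grouped-query attention. As a rough illustration of what that `expand` does (a minimal sketch with made-up shapes, not the repository's actual module), the single key/value head in each group is broadcast across its query heads so the fused SDPA kernels see matching head counts:

```python
import torch

# Hypothetical grouped-query attention shapes: 6 query heads sharing 2 key/value groups.
B, T, head_size = 1, 4, 16
n_head, n_query_groups = 6, 2
q_per_kv = n_head // n_query_groups  # 3 query heads per key/value head

q = torch.randn(B, n_query_groups, q_per_kv, T, head_size)
k = torch.randn(B, n_query_groups, 1, T, head_size)
v = torch.randn(B, n_query_groups, 1, T, head_size)

# expand broadcasts the size-1 dimension without copying memory
k = k.expand(B, n_query_groups, q_per_kv, T, head_size)
v = v.expand(B, n_query_groups, q_per_kv, T, head_size)

# flatten the group dimension so q, k and v all expose n_head heads
q, k, v = (t.reshape(B, -1, T, head_size) for t in (q, k, v))
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(y.shape)  # torch.Size([1, 6, 4, 16])
```

Per the new comment, the expansion is kept during training because flash attention needs it, but skipped during multi-query inference (`n_query_groups == 1` with a kv cache) to avoid materializing a full-size cache.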
2 changes: 0 additions & 2 deletions pretrain/openwebtext.py
@@ -247,8 +247,6 @@ def get_lr(it: int) -> float:


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI
2 changes: 0 additions & 2 deletions pretrain/openwebtext_trainer.py
@@ -190,8 +190,6 @@ def get_lr(it: int) -> float:


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI
2 changes: 0 additions & 2 deletions pretrain/redpajama.py
@@ -319,8 +319,6 @@ def get_lr(it: int) -> float:


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -34,3 +34,9 @@ def __eq__(self, other):
@pytest.fixture()
def tensor_like():
return TensorLike()


@pytest.fixture(autouse=True)
def restore_default_dtype():
# just in case
torch.set_default_dtype(torch.float32)
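Moving this fixture into `conftest.py` makes it apply to every test module, which matters because the new SDPA test below switches the default dtype to float16. A standalone sketch of the pattern (hypothetical test names, not the repository's suite):

```python
import pytest
import torch


@pytest.fixture(autouse=True)
def restore_default_dtype():
    # runs before every test, so a dtype change made by one test cannot leak into the next
    torch.set_default_dtype(torch.float32)


def test_switches_the_default():
    torch.set_default_dtype(torch.float16)
    assert torch.empty(1).dtype is torch.float16


def test_sees_float32_again():
    # passes because the autouse fixture reset the default before this test started
    assert torch.empty(1).dtype is torch.float32
```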
2 changes: 2 additions & 0 deletions tests/test_gptq.py
@@ -1,8 +1,10 @@
import lightning as L
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2


@pytest.mark.skipif(_TORCH_GREATER_EQUAL_2_2, reason="Core dumped")
def test_gptq_blockwise_quantization():
from quantize.gptq import _TRITON_AVAILABLE

114 changes: 90 additions & 24 deletions tests/test_model.py
@@ -1,19 +1,19 @@
import operator
import sys
from functools import partial
from pathlib import Path
from urllib.request import urlretrieve

import pytest
import torch
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_2
from lightning_utilities.core.imports import compare_version

wd = Path(__file__).parent.parent.absolute()
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))


@pytest.fixture(autouse=True)
def restore_default_dtype():
# just in case
torch.set_default_dtype(torch.float32)
import lit_gpt.config as config_module


@torch.inference_mode()
@@ -26,12 +26,14 @@ def restore_default_dtype():
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"), torch.float16, marks=[
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
),
],
)
@@ -97,12 +99,14 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"), torch.float16, marks=[
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
),
],
)
@@ -146,12 +150,14 @@ def test_against_hf_falcon(kwargs, device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"), torch.float16, marks=[
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
),
],
)
@@ -210,12 +216,14 @@ def test_against_original_open_llama_3b(device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"), torch.float16, marks=[
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
),
],
)
@@ -267,10 +275,12 @@ def test_against_hf_llama2(ours_kwargs, device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"), torch.float16, marks=[
torch.device("cuda"),
torch.float16,
marks=[
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
),
],
)
@@ -325,12 +335,14 @@ def test_against_hf_phi(device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"), torch.float16, marks=[
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
),
],
)
@@ -453,3 +465,57 @@ def test_model_kv_cache_amp():
with torch.autocast("cpu", torch.bfloat16):
output = model(encoded.unsqueeze(0), encoded)
assert output.dtype is torch.bfloat16


# https://github.com/pytorch/pytorch/blob/ad3572a5d/torch/testing/_internal/common_cuda.py#L31-L34
SUPPORTS_FLASH_ATTENTION = (
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) and not _IS_WINDOWS
)
SUPPORTS_MEM_EFF_ATTENTION = torch.cuda.is_available()
SUPPORTS_FUSED_ATTENTION = SUPPORTS_FLASH_ATTENTION or SUPPORTS_MEM_EFF_ATTENTION


@pytest.mark.skipif(not SUPPORTS_FUSED_ATTENTION, reason="Unsupported")
@pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
@torch.inference_mode()
def test_sdpa_choice(config):
from torch.backends.cuda import SDPBackend

from lit_gpt import GPT

torch.set_default_dtype(torch.float16)

def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
choice = torch._fused_sdp_choice(q, k, v, mask, is_causal=True)
assert choice == expected
return original_fn(q, k, v, mask)

config["n_layer"] = 1
config = config_module.Config(**config)

try:
with torch.device("cuda"):
model = GPT(config)
x = torch.randint(0, 10, (2, 16), dtype=torch.int32)
except torch.cuda.OutOfMemoryError:
# best effort, if the GPU can load it
pytest.xfail()

for h in model.transformer.h:
h.attn.scaled_dot_product_attention = partial(assert_sdpa_uses_flash, h.attn.scaled_dot_product_attention)

if SUPPORTS_FLASH_ATTENTION:
# flash attention 1 requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or
# equal to 128
expected = (
SDPBackend.FLASH_ATTENTION
if _TORCH_GREATER_EQUAL_2_2 or (config.head_size <= 128 and config.head_size % 8 == 0)
else SDPBackend.MATH
)
with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
model(x)

if SUPPORTS_MEM_EFF_ATTENTION:
expected = SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 else SDPBackend.MATH
with torch.backends.cuda.sdp_kernel(enable_flash=False):
model(x)
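Outside the test harness, the same backend restriction can be exercised directly with `torch.backends.cuda.sdp_kernel`, the context manager the test relies on. A minimal sketch (assuming a CUDA device and a PyTorch 2.x build; the flash kernel additionally wants fp16/bf16 inputs and, before 2.2, a head size that is a multiple of 8 and at most 128):

```python
import torch
from torch.backends.cuda import sdp_kernel

assert torch.cuda.is_available()
# (batch, heads, sequence length, head size)
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the flash kernel; SDPA raises a RuntimeError if it cannot serve these inputs.
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Allow only the memory-efficient kernel instead.
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
```

The test goes one step further and queries the private `torch._fused_sdp_choice` to assert which backend the dispatcher actually picks for each model configuration.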
5 changes: 0 additions & 5 deletions tutorials/finetune_adapter.md
@@ -135,8 +135,3 @@ With only a few modifications, you can prepare and train on your own instruction
--checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b \
--out_dir data/mydata-finetuned
```
## Troubleshooting
If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
`torch.backends.cuda.enable_flash_sdp(False)` in the finetune script (see <https://github.com/Lightning-AI/lit-llama/issues/101>).
5 changes: 0 additions & 5 deletions tutorials/finetune_full.md
@@ -102,8 +102,3 @@ With only a few modifications, you can prepare and train on your own instruction
--checkpoint_dir checkpoints/tiiuae/falcon-7b \
--out_dir data/mydata-finetuned
```
## Troubleshooting
If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
`torch.backends.cuda.enable_flash_sdp(False)` in the finetune script (see <https://github.com/Lightning-AI/lit-llama/issues/101>).
7 changes: 0 additions & 7 deletions tutorials/finetune_lora.md
@@ -194,10 +194,3 @@ python eval/lm_eval_harness.py \
--batch_size 4 \
--save_filepath "results.json"
```

&nbsp;

## Troubleshooting

If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see <https://github.com/Lightning-AI/lit-llama/issues/101>).
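For reference, the workaround removed from these tutorials (and from the commented-out line in the finetune and pretrain scripts above) disabled the flash SDP kernel globally. A minimal sketch of what it looked like at the top of a launch script, under the assumption that it runs before any forward pass:

```python
import torch

# Old workaround for "Expected is_sm80 to be true, but got false": turn off the flash
# kernel so scaled_dot_product_attention falls back to the math/memory-efficient backends.
torch.backends.cuda.enable_flash_sdp(False)
assert not torch.backends.cuda.flash_sdp_enabled()

torch.set_float32_matmul_precision("high")
```

With the new `test_sdpa_choice` test covering backend selection across configurations, this commit drops both the commented-out workaround in the scripts and the matching troubleshooting sections in the tutorials.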
