Hypno/add bias #5

Draft · wants to merge 10 commits into main
Conversation

@hypnopump (Author) commented:

Adds bias to attention.


Many tests fail for me (that's why I'm opening this as a draft PR), especially the BTHD and longer-sequence ones (my GPU has 12 GB), but manual PyTorch tests seem to match.

It can be tested with:

# Check same output as torch

import torch as th
from flag_attn import flash_attention

# B, H, T, D = 2, 16, 8192, 128
B, H, T, D = 1, 1, 2048, 16


th.manual_seed(17)
th.cuda.manual_seed(17)
q = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
k = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
v = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
bias = th.randn((B, H, T, T), dtype=th.float16, device="cuda:0").requires_grad_()
go = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0")
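# no-bias sanity check: flash vs. manual attention vs. torch SDPA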
onobias = flash_attention(q, k, v, causal=False)
o_nobias = (q @ k.transpose(-1, -2) / q.shape[-1]**0.5).softmax(dim=-1) @ v
o_th_nobias = th.nn.functional.scaled_dot_product_attention(q, k, v)

o = flash_attention(q, k, v, bias, causal=False)
o_ = (q @ k.transpose(-1, -2) / q.shape[-1]**0.5 + bias).softmax(dim=-1) @ v
o_th = th.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
assert (o - o_th).abs().amax() < 1e-3

gq_th, gk_th, gv_th, gbias_th = th.autograd.grad(
    o_th, (q, k, v, bias), go
)

gq, gk, gv, gbias = th.autograd.grad(
    o, (q, k, v, bias), go
)
assert (gbias - gbias_th).abs().amax() < 1e-3
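
For completeness, the no-bias outputs computed above and the q/k/v gradients can be compared the same way. This extension is not part of the original snippet, and the 1e-3 tolerance is simply carried over from the checks above (it may need loosening in fp16):

# not in the original snippet: also check the no-bias path and the q/k/v gradients
assert (onobias - o_th_nobias).abs().amax() < 1e-3
for g, g_th in zip((gq, gk, gv), (gq_th, gk_th, gv_th)):
    assert (g - g_th).abs().amax() < 1e-3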

hypnopump marked this pull request as ready for review on December 16, 2023, 19:02
@iclementine (Collaborator) left a comment

Thanks for the contribution.

However, when I tested the commit I got some errors: some were about tensor shapes and others about using a tensor as a branching condition. Please make sure the tests pass; we will set up CI soon.

Also, I have a concern about the shape of the bias. Since the bias and its gradient both have a shape broadcastable to (B, H, Tq, Tk), their memory footprint and traffic are quadratic in the sequence length, whereas flash attention is designed precisely to avoid a footprint and traffic that grow quadratically with the sequence length. Several test cases cause OOM even with 40 GB of memory.
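
To put a rough number on the quadratic footprint (my arithmetic, using the commented-out shape from the test snippet above rather than figures from the review):

# fp16 bias of shape (B, H, T, T) at B, H, T = 2, 16, 8192
B, H, T = 2, 16, 8192
bias_bytes = B * H * T * T * 2      # 2 bytes per fp16 element
print(bias_bytes / 2**30, "GiB")    # 4.0 GiB, and the gradient of the bias doubles it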

So I suggest adding a new kernel.

@hypnopump (Author) commented on Dec 18, 2023

Sorry for the additional burden; all tests are green now (at least on my 2060 with 12 GB).

Regarding the new kernel and the O(N^2) bias: the main improvement of flash attention is not only the reduced memory consumption, which had already been proposed in 2021, but also the speedup from its hardware-aware handling of memory transfers. In that sense it is beneficial to use flash attention even when a bias is present, purely for speed.

Also, bias addition is a feature of FlashAttention v1 in Tri Dao's repo.
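
To illustrate the streaming point, here is a minimal PyTorch sketch of an online softmax with an additive bias, written for this write-up rather than taken from the PR's Triton kernel: only (M, block_n) tiles of scores and bias are materialized at a time, so nothing quadratic is added beyond the bias tensor the caller already holds.

import torch

def blockwise_attention_with_bias(q, k, v, bias, block_n=128):
    # reference online softmax over key blocks; scores and bias are only ever
    # touched in (M, block_n) tiles
    B, H, M, D = q.shape
    N = k.shape[2]
    scale = D ** -0.5
    acc = torch.zeros(B, H, M, D, device=q.device)                   # fp32 accumulator
    m_i = torch.full((B, H, M, 1), float("-inf"), device=q.device)   # running row max
    l_i = torch.zeros((B, H, M, 1), device=q.device)                 # running row sum
    for start in range(0, N, block_n):
        kn = k[:, :, start:start + block_n].float()
        vn = v[:, :, start:start + block_n].float()
        bn = bias[:, :, :, start:start + block_n].float()
        s = q.float() @ kn.transpose(-1, -2) * scale + bn            # scores for this tile
        m_new = torch.maximum(m_i, s.amax(dim=-1, keepdim=True))
        alpha = torch.exp(m_i - m_new)                               # rescale earlier partial sums
        p = torch.exp(s - m_new)
        l_i = l_i * alpha + p.sum(dim=-1, keepdim=True)
        acc = acc * alpha + p @ vn
        m_i = m_new
    return (acc / l_i).to(q.dtype)

Its output can be checked against the one-shot softmax reference used in the test snippet above.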

@iclementine (Collaborator) commented:

Yes, I see the point of speeding up with flash attention. If the bias is already materialized in (B, H, T, T), that memory is already taken, and the gradient of the bias requires the same amount of memory as the bias itself, so flash attention does not add extra memory consumption because of the bias.

I tested the new commit and most of the test cases passed, but I got NaN in some cases (on an A100, 40 GB).

        if use_bias:
>           assert gbias_triton_max_diff < 2 * gbias_torch_max_diff + 1e-5
E           assert nan < ((2 * 0.5625) + 1e-05)

test_flash_attention.py:153: AssertionError
============================================ short test summary info ============================================
FAILED test_flash_attention.py::test_attention_fwd_bwd[BHTD-dtype1-True-True-2-4-512-128-100-1.0-0] - assert nan < ((2 * 0.078125) + 1e-05)
FAILED test_flash_attention.py::test_attention_fwd_bwd[BHTD-dtype1-True-True-2-4-512-128-100-2.0-0] - assert nan < ((2 * 0.5625) + 1e-05)

I will take some time to investigate it~

@iclementine (Collaborator) commented on Dec 27, 2023

There is an illegal memory access, at least in this case.

import torch
import flag_attn
import triton


B, H, M, N, D = 2, 16, 4000, 4096, 64
causal = True

q = torch.randn(B, H, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
b = torch.randn(B, H, M, N, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
o = flag_attn.flash_attention(q, k, v, b, causal=causal)
go = torch.randn_like(o)
gq, gk, gv, gb = torch.autograd.grad(o, (q, k, v, b), go)

print(gb.isnan().any())

Causal masking, a large sequence length, and headdim=64 reveal the illegal memory access here; other configurations may also trigger it, but I am starting from the case above.

The error message is

  File "/home/clement/projects/FlagAttention/src/flag_attn/flash.py", line 199, in backward
    _bwd_preprocess[grid](
  File "/home/clement/.virtualenvs/dev/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
  File "/home/clement/.virtualenvs/dev/lib/python3.10/site-packages/triton/compiler/compiler.py", line 692, in __getattribute__
    self._init_handles()
  File "/home/clement/.virtualenvs/dev/lib/python3.10/site-packages/triton/compiler/compiler.py", line 683, in _in
it_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
RuntimeError: Triton Error [CUDA]: unspecified launch failure

The problem is not about loading the kernel code of _bwd_preprocess but about something else. You can also verify it via CUDA memcheck.

I checked the memory error with the CUDA compute-sanitizer. The log (attached as memory_error.txt) indicates that the error is in the forward kernel. That's really strange~
BTW, the commit id of the Triton compiler used for this test is e28a256d71.

@hypnopump (Author) commented on Dec 27, 2023

Oh, thanks for letting me know about this. I also found that training a model with this resulted in NaN gradients, but I'm not sure whether it's related (that was with head_dim=32). I will try to dive deeper one of these days.

I will convert this back to a draft PR.

hypnopump marked this pull request as a draft on December 27, 2023, 15:33
@iclementine (Collaborator) commented:

Actually, head_dim=32 and 128 also result in the illegal memory access, but only for some M and N (I have not figured out the exact conditions for this error).

It seems to be a bug in the Triton compiler (maybe in MaterializeLoadStorePass). I am still working on it.
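
To help narrow down the conditions mentioned above, here is a hedged probe sketch (added for this write-up, not taken from the thread). Because an illegal memory access corrupts the CUDA context, each (M, N, D) configuration is best probed in its own process, e.g. driven from a shell loop or subprocess:

# probe_config.py: check one (M, N, D) configuration for errors or NaN gradients
import sys
import torch
import flag_attn

def probe(M, N, D, causal=True):
    q = torch.randn(1, 2, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    k = torch.randn(1, 2, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    v = torch.randn(1, 2, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    b = torch.randn(1, 2, M, N, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    o = flag_attn.flash_attention(q, k, v, b, causal=causal)
    gq, gk, gv, gb = torch.autograd.grad(o, (q, k, v, b), torch.randn_like(o))
    return "nan-grad" if gb.isnan().any() else "ok"

if __name__ == "__main__":
    M, N, D = (int(x) for x in sys.argv[1:4])
    print(M, N, D, probe(M, N, D))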

@iclementine (Collaborator) commented on Dec 28, 2023

I have found the bug. It is not in the Triton compiler but in our implementation: the upper bound of the N dimension when using causal masking is incorrect. See #8 for more details.
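
To make "the upper bound of the N dimension" concrete, here is a hedged illustration of the kind of bound involved. This is my reading only, assuming the causal mask is aligned to the bottom-right corner when M != N; the actual fix is the one in #8.

def causal_key_upper_bound(start_m: int, BLOCK_M: int, M: int, N: int) -> int:
    # Assumed convention: query i attends to keys j <= i + (N - M). A query block
    # [start_m, start_m + BLOCK_M) then needs key columns j < start_m + BLOCK_M + (N - M);
    # without the clamp to N, the key-block loop can read past the end of K, V and bias.
    return min(start_m + BLOCK_M + (N - M), N)

For the repro above (M=4000, N=4096), the unclamped bound for the last query block would exceed N, which would be consistent with the out-of-bounds accesses the sanitizer reported.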

Thanks for your contribution, without which we might not have found that bug in the code.

With this bugfix, this PR for adding bias is okay to continue.
