Hypno/add bias #5
base: main
Conversation
The branch was force-pushed from 7e3049e to 4cfda6c.
Thanks for the contribution.
But when I tested the commit I got some errors. Some were about the shapes of tensors and others were about the convention of using a tensor as a branching condition. So make sure to pass the tests; we will set up CI soon.
Also, I have a concern about the shape of the bias. Since the bias and its gradient both have a shape broadcastable to (B, H, Tq, Tk), their memory footprint and traffic are quadratic in the sequence length, while flash attention is designed precisely to avoid memory footprint and traffic that grow quadratically with the sequence length. Several test cases cause OOM even with 40 GB of memory.
So I suggest adding a new kernel.
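For a rough sense of scale, here is a back-of-the-envelope sketch (my own numbers, not from the review) of why a materialized bias is problematic:

```python
# Rough estimate of the memory taken by a dense bf16 bias tensor of shape
# (B, H, Tq, Tk); the gradient of the bias has the same size again.
B, H, T = 2, 16, 8192          # example batch size, heads, sequence length
bytes_per_elem = 2             # bfloat16
bias_bytes = B * H * T * T * bytes_per_elem
print(f"bias alone: {bias_bytes / 2**30:.1f} GiB")  # 4.0 GiB at T=8192
# Doubling T quadruples this, so a few long-sequence test cases, together
# with the corresponding gradients, can easily exceed 40 GB.
```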
Sorry for the additional burden; all tests are green now (at least on my 2060 with 12 GB). With respect to the new kernel and the O(N^2) bias: the main improvement of flash attention is not only the reduced memory consumption, which had already been proposed in 2021, but also the speedup from its hardware-aware handling of memory transfers. In this sense it is beneficial to use flash attention even when a bias is used, purely for speed. Also, bias addition is a feature of flash v1 from Tri Dao's repo.
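For reference, this is the computation the bias feature targets, written as a naive PyTorch sketch (my own illustration, not code from this PR; the causal-mask alignment for M != N is assumed to be top-left and may differ from flag_attn's convention):

```python
import math
import torch

def attention_with_bias_reference(q, k, v, bias, causal=False):
    # q: (B, H, M, D); k, v: (B, H, N, D); bias broadcastable to (B, H, M, N).
    # A fused flash-attention kernel should match this output without ever
    # materializing the full (M, N) score matrix.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale + bias
    if causal:
        M, N = scores.shape[-2], scores.shape[-1]
        keep = torch.ones(M, N, dtype=torch.bool, device=scores.device).tril_()
        scores = scores.masked_fill(~keep, float("-inf"))
    p = torch.softmax(scores.float(), dim=-1).to(v.dtype)
    return torch.matmul(p, v)
```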
Yes, I see the point of speeding up with flash attention even if the bias is already in memory.
I tested the new commit, and most of the test cases passed. But I got some errors; I will take some time to investigate them~
There is an illegal memory access, at least in this case:

```python
import torch
import flag_attn
import triton
B, H, M, N, D = 2, 16, 4000, 4096, 64
causal = True
q = torch.randn(B, H, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
k = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
v = torch.randn(B, H, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
b = torch.randn(B, H, M, N, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
o = flag_attn.flash_attention(q, k, v, b, causal=causal)
go = torch.randn_like(o)
gq, gk, gv, gb = torch.autograd.grad(o, (q, k, v, b), go)
print(gb.isnan().any())
```

Causal masking, a large sequence length, and headdim=64 can reveal the illegal memory access. Other configurations may also trigger it, but I am starting from the case above. The error message is:
The problem is not about loading the kernel code. I checked the memory error with CUDA compute-sanitizer; the log indicates that the error is in the forward kernel. That's really strange~ BTW, the commit id of the triton compiler I used in this test case is
Oh, thanks for letting me know about this. I also found that training a model with this resulted in NaN gradients, but I'm not sure whether it's related (it was with head_dim=32). I will try to dive deeper one of these days and will convert this back to a draft PR.
Actually, head_dim=32 and 128 also result in the illegal memory access, but only for some M and N (I have not figured out the exact conditions for this error). It seems to be a bug in the triton compiler.
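If it helps, a simple sweep along these lines (a hypothetical helper, not part of this PR) could narrow down which configurations fail; note that a real illegal memory access usually corrupts the CUDA context, so each configuration should ideally run in its own process (e.g. under compute-sanitizer):

```python
import itertools
import torch
import flag_attn

# Sweep a few (head_dim, M, N) combinations and report NaN gradients or
# runtime errors for the bias path with causal masking enabled.
for D, M, N in itertools.product([32, 64, 128], [2048, 4000, 4096], [2048, 4000, 4096]):
    q = torch.randn(2, 16, M, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    k = torch.randn(2, 16, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    v = torch.randn(2, 16, N, D, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    b = torch.randn(2, 16, M, N, dtype=torch.bfloat16, device="cuda:0", requires_grad=True)
    try:
        o = flag_attn.flash_attention(q, k, v, b, causal=True)
        gb = torch.autograd.grad(o, (q, k, v, b), torch.randn_like(o))[3]
        print(D, M, N, "nan in grad of bias:", bool(gb.isnan().any()))
    except RuntimeError as e:
        print(D, M, N, "error:", e)
```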
I have found the bug. It is not in the triton compiler but in our implementation: the upper bound of the N dimension when using causal masking was incorrect. See #8 for more details. Thanks for your contribution, without which we might not have found that bug in the code. With this bugfix, this PR for adding bias is okay to continue.
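To illustrate the kind of fix involved (an assumed loop structure, not the actual kernel code in #8): under causal masking, the upper bound of the key loop for a query block has to be clamped to the number of key rows, otherwise the kernel can read past the end of K, V, and the bias when M != N:

```python
def causal_key_upper_bound(start_m, BLOCK_M, N):
    # Highest key column (exclusive) that any query row in the block
    # [start_m, start_m + BLOCK_M) may attend to, assuming top-left causal
    # alignment (flag_attn's actual alignment and tiling may differ).
    unclamped = start_m + BLOCK_M
    # Without this clamp the loop can index key/value/bias tiles beyond N,
    # which shows up as an illegal memory access.
    return min(unclamped, N)
```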
Adds bias to attention.
Many tests fail for me (that's why I'm opening this as a draft PR), especially the BTHD and longer-sequence ones (my GPU has 12 GB), but manual PyTorch reference tests seem to match.
It can be tested with: