Flex attention with dropout #77
You are correct: we don't currently have post-softmax dropout implemented. We have it tracked as a feature request, but we have seen decreasing adoption of it across the industry, so it isn't a high priority.
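For reference, post-softmax attention dropout (the variant described as missing above) zeroes entries of the softmax-normalized attention weights and rescales the survivors by 1/(1 - p). The sketch below is a plain eager PyTorch version, not a flex_attention feature; `attention_with_post_softmax_dropout` is a hypothetical helper name. `torch.nn.functional.scaled_dot_product_attention` exposes a `dropout_p` argument for this kind of dropout. The comparison uses `dropout_p=0.0` only, because the two implementations draw different random masks.

```python
import math
import torch
import torch.nn.functional as F

def attention_with_post_softmax_dropout(q, k, v, p, generator=None):
    """Hypothetical helper: softmax(QK^T / sqrt(d)) with dropout on the weights."""
    scale = 1.0 / math.sqrt(q.size(-1))
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    if p > 0.0:
        # Drop each attention weight independently with probability p ...
        keep = torch.rand(weights.shape, generator=generator, device=weights.device) >= p
        # ... and rescale the survivors so each weight is unchanged in expectation.
        weights = weights * keep / (1.0 - p)
    return weights @ v

B, H, S, D = 2, 4, 16, 8
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# With p = 0 the helper reduces to plain attention and should match SDPA.
ref = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
out = attention_with_post_softmax_dropout(q, k, v, p=0.0)
print(torch.allclose(out, ref, atol=1e-4))
```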
Thank you for the reply. In that case, I would just like to know whether it is possible to implement pre-softmax dropout within the current framework. The main question is whether I can use the `rand` function inside `mask_mod` or `score_mod`, and if so, whether the forward and backward passes will compute the same mask. Another question: can I avoid calling `create_block_mask` again for every forward pass?
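One way to approach both questions, sketched under assumptions (the names `drop_mask` and `resample_dropout_mask_` are hypothetical): instead of calling `torch.rand` inside `score_mod`, draw the mask outside it and let `score_mod` only read it. Because forward and backward read the same stored tensor, they see the same mask. Refilling that tensor in place before each training step gives a fresh mask per step while keeping the captured tensor object the same, which may help avoid recompilation under `torch.compile` (not verified here). Since the dropout lives entirely in `score_mod`, no `create_block_mask` call is involved. The check below calls `score_mod` eagerly over the full index grid rather than through `flex_attention`, so it runs on CPU.

```python
import torch

B, H, S = 1, 4, 64
dropout_prob = 0.1

# Allocated once; the score_mod below reads from this tensor (True = drop).
drop_mask = torch.zeros((B, H, S, S), dtype=torch.bool)

def resample_dropout_mask_(mask, p, generator=None):
    # Refill in place so the same tensor object is reused across steps.
    mask.copy_(torch.rand(mask.shape, generator=generator, device=mask.device) < p)
    return mask

def dropout(score, b, h, q_idx, kv_idx):
    # Read-only lookup: forward and backward both see the same stored mask.
    return torch.where(drop_mask[b, h, q_idx, kv_idx], -float("inf"), score)

# Eager check over the full (b, h, q, kv) grid using broadcast index tensors.
b = torch.arange(B).view(B, 1, 1, 1)
h = torch.arange(H).view(1, H, 1, 1)
q_idx = torch.arange(S).view(1, 1, S, 1)
kv_idx = torch.arange(S).view(1, 1, 1, S)

gen = torch.Generator().manual_seed(0)
first = resample_dropout_mask_(drop_mask, dropout_prob, gen).clone()
scores = torch.randn(B, H, S, S)
masked = dropout(scores, b, h, q_idx, kv_idx)

# Next training step: new mask, same tensor object.
mask_id = id(drop_mask)
resample_dropout_mask_(drop_mask, dropout_prob, gen)
```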
So the naive way to implement this is:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from functools import partial

B, H, S, D = 1, 4, 256, 64
dropout_prob = 0.1

# One Bernoulli draw per (batch, head, query, key) score; True means "drop".
# The mask is indexed by (q_idx, kv_idx), so its last two dims are (S, S).
drop_mask = torch.rand((B, H, S, S), device="cuda") < dropout_prob

def dropout(score, b, h, q_idx, kv_idx):
    # Dropped scores become -inf, so they get zero weight after the softmax.
    return torch.where(drop_mask[b, h, q_idx, kv_idx], -float("inf"), score)

if __name__ == "__main__":
    make_tensor = partial(torch.randn, (B, H, S, D), device="cuda",
                          dtype=torch.float16, requires_grad=True)
    query, key, value = make_tensor(), make_tensor(), make_tensor()
    compiled_flex = torch.compile(flex_attention, fullgraph=True)
    out = compiled_flex(query, key, value, score_mod=dropout)
    print(out)
```

There are probably some other fun things you can do to reduce the extra memory needed to store the mask, but this is the most straightforward approach.
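One possible way to cut the mask memory is to avoid storing the (B, H, S, S) mask at all and instead derive each keep/drop decision from a deterministic hash of (seed, b, h, q_idx, kv_idx), so the same bits can be recomputed wherever they are needed. The sketch below uses a common 32-bit integer mixing function evaluated in eager int64 arithmetic; the function name `hash_dropout_mask` and the hash constants are our own choices, and we have not checked whether this exact integer code compiles inside a flex_attention `score_mod`.

```python
import torch

MASK32 = 0xFFFFFFFF

def hash_dropout_mask(seed, b, h, q_idx, kv_idx, H, S, p):
    """Hypothetical helper: True where the score should be dropped.

    Pure function of its inputs, so recomputing it yields identical bits
    without keeping a (B, H, S, S) tensor around.
    """
    # Flatten the 4-D position into one integer (int64 avoids overflow).
    x = ((b * H + h) * S + q_idx) * S + kv_idx
    x = (x + seed * 0x9E3779B9) & MASK32
    # Integer mixing rounds (xor-shift, multiply by an odd constant), kept to 32 bits.
    x = (((x >> 16) ^ x) * 0x45D9F3B) & MASK32
    x = (((x >> 16) ^ x) * 0x45D9F3B) & MASK32
    x = (x >> 16) ^ x
    # Map to [0, 1) and drop with probability p.
    return x.to(torch.float64) / 2**32 < p

B, H, S = 1, 4, 256
b = torch.arange(B).view(B, 1, 1, 1)
h = torch.arange(H).view(1, H, 1, 1)
q_idx = torch.arange(S).view(1, 1, S, 1)
kv_idx = torch.arange(S).view(1, 1, 1, S)

m1 = hash_dropout_mask(1234, b, h, q_idx, kv_idx, H, S, p=0.1)
m2 = hash_dropout_mask(1234, b, h, q_idx, kv_idx, H, S, p=0.1)
print(torch.equal(m1, m2), m1.float().mean().item())
```

Changing the seed for each training step would give a new mask per step, while forward and backward can recompute the same mask from the same seed.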
Hi,
I found the flex attention package really useful and flexible. However, it seems that flex attention does not support dropout, which is widely used. Will this be supported in the future?
I also considered implementing dropout through the mask, although that is not equivalent to applying dropout after the softmax. Even in that setting, however, I am not sure how to make the implementation correct, because the dropout mask cannot simply be generated on the fly: it must be identical in the forward and backward passes.
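To make "not equivalent" concrete, this small eager sketch applies the same drop pattern two ways to one row of scores: (a) pre-softmax, by setting dropped scores to -inf, so the remaining weights renormalize to sum to 1; and (b) post-softmax, by zeroing dropped weights and rescaling the survivors by 1/(1 - p), so the row sum is 1 only in expectation. The scores and the fixed drop pattern are arbitrary illustrative choices.

```python
import torch

p = 0.5
scores = torch.tensor([1.0, 2.0, 3.0, 4.0])
drop = torch.tensor([False, True, False, True])  # illustrative fixed pattern

# (a) Pre-softmax masking: dropped scores get zero weight, the rest renormalize.
pre = torch.softmax(torch.where(drop, -float("inf"), scores), dim=-1)

# (b) Post-softmax dropout: zero the dropped weights, rescale the survivors.
post = torch.softmax(scores, dim=-1) * (~drop) / (1.0 - p)

print(pre.sum().item())   # 1 up to rounding
print(post.sum().item())  # about 0.54 for these values
```

For a single row, (a) equals (b) renormalized to sum to 1: the survivors keep the same relative weights, and only the normalization differs.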
Can anyone elaborate on this? Thank you so much!