Zigzag attention support? #20
Comments
hey David, no problem. could you link me to the paper? did you see the rotation trick from Chris Fifty yet?
Hi,

> could you link me to the paper?

It's used in the Llama3 paper (https://arxiv.org/abs/2407.21783), page 11, in the section on context parallelism. Though they don't actually use the form of ring attention implemented here, for GQA and attention-masking reasons.

> did you see the rotation trick from Chris Fifty yet?

I have not. What is it about?
check out the vq repo. nice! didn't even know Meta was using ring attention 🤣 I'll read the paper tomorrow
guess all the big players will be using some form of sequence parallel attention soon (google, meta, and you at nvidia)
@dwromero could i prompt you for a summary of what zigzag is? is it just another way to permute the sequence for better balancing?
That's right
@dwromero ok, should be an easy add!
🤟🤟🤟
@dwromero oh, there is nothing to zigzag (did you coin that term?), it is just an all gather for the keys and values, with GQA as justification
@dwromero let me break this into two parts, where i first handle the permuting they do, then offer the all gather for the keys / values, both configurable.
@dwromero actually, maybe it should just be a separate self-contained file given how different it is
I actually tried this with TransformerEngine and it works simply by splitting differently. Ran some tests and all seems to match. Do you think that would be sufficient here too? Basically, using a splitting like the one sketched below.
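A minimal sketch of what such a zigzag-style split could look like (this is illustrative, not the exact TransformerEngine code: it assumes the sequence is cut into `2 * world_size` chunks and rank `r` keeps chunks `r` and `2 * world_size - 1 - r`, so every rank gets a mix of cheap early positions and expensive late positions under causal attention):

```python
import torch

def zigzag_split(seq: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # seq: (batch, seq_len, dim), with seq_len divisible by 2 * world_size.
    # rank r keeps chunk r plus the mirrored chunk (2 * world_size - 1 - r),
    # which balances the causal-attention work across ranks.
    chunks = seq.chunk(2 * world_size, dim = 1)
    return torch.cat((chunks[rank], chunks[2 * world_size - 1 - rank]), dim = 1)

# quick check with world_size = 4: rank 0 holds chunks (0, 7), rank 1 holds (1, 6), ...
world_size = 4
x = torch.randn(1, 16, 8)
shards = [zigzag_split(x, r, world_size) for r in range(world_size)]

# undoing the split: first halves in rank order, then second halves in reverse rank order
firsts = [s.chunk(2, dim = 1)[0] for s in shards]
seconds = [s.chunk(2, dim = 1)[1] for s in shards]
assert torch.equal(torch.cat(firsts + seconds[::-1], dim = 1), x)
```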
@dwromero yea that works for sharding the sequence, but you'll need to handle the masking (maybe flex attention can come in handy here). and it seems like they project the keys / values on each rank separately, then do the all gather?
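If flex attention is the route, the masking piece could look roughly like this. It is only a sketch, assuming PyTorch 2.5+ `torch.nn.attention.flex_attention`, the zigzag split from the snippet above, and keys / values all-gathered in plain global order; the helper name `zigzag_causal_mask_mod` is made up for illustration:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def zigzag_causal_mask_mod(rank: int, world_size: int, chunk_size: int):
    # map a local (zigzag) query index back to its global position, then apply
    # ordinary causal masking against the globally ordered keys / values
    def mask_mod(b, h, q_idx, kv_idx):
        in_first_chunk = q_idx < chunk_size
        global_q = torch.where(
            in_first_chunk,
            rank * chunk_size + q_idx,
            (2 * world_size - 1 - rank) * chunk_size + (q_idx - chunk_size),
        )
        return global_q >= kv_idx
    return mask_mod

# usage sketch: q is this rank's zigzag shard, k / v are the full (all-gathered) sequence
rank, world_size, chunk_size = 0, 4, 128
q = torch.randn(1, 8, 2 * chunk_size, 64)
k = torch.randn(1, 8, 2 * world_size * chunk_size, 64)
v = torch.randn_like(k)

block_mask = create_block_mask(
    zigzag_causal_mask_mod(rank, world_size, chunk_size),
    B = None, H = None,
    Q_LEN = q.shape[-2], KV_LEN = k.shape[-2],
    device = q.device,
)
out = flex_attention(q, k, v, block_mask = block_mask)
```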
yea i don't know if i completely buy this. sure GQA can be enough savings that an all gather at 128k is fine, but how about 10 million? yea, this is definitely sequence parallelism in its crudest form, imo
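For concreteness, a rough sketch of that all-gather form of context parallelism could look like the following (illustrative only, not the repo's implementation): each rank projects just its own shard into keys / values, all-gathers them, and attends its local queries against the full keys / values. The causal mask here assumes a plain contiguous split; a zigzag split would need the index remapping shown earlier.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def all_gather_kv_attention(q, k, v, rank: int, world_size: int):
    # q, k, v: (batch, heads, local_seq_len, dim_head), projected from this
    # rank's sequence shard only. with GQA, k / v have fewer heads than q,
    # which is what keeps the all gather (relatively) cheap.
    k_list = [torch.empty_like(k) for _ in range(world_size)]
    v_list = [torch.empty_like(v) for _ in range(world_size)]
    dist.all_gather(k_list, k.contiguous())
    dist.all_gather(v_list, v.contiguous())

    k_full = torch.cat(k_list, dim = -2)  # (batch, kv_heads, global_seq_len, dim_head)
    v_full = torch.cat(v_list, dim = -2)

    # causal mask in global coordinates, assuming a plain contiguous split
    # of the sequence across ranks (a zigzag split would change the q positions)
    local_len = q.shape[-2]
    q_pos = torch.arange(local_len, device = q.device) + rank * local_len
    kv_pos = torch.arange(k_full.shape[-2], device = q.device)
    causal_mask = q_pos[:, None] >= kv_pos[None, :]

    # enable_gqa (PyTorch 2.5+) handles kv_heads < q_heads; otherwise repeat the kv heads
    return F.scaled_dot_product_attention(q, k_full, v_full, attn_mask = causal_mask, enable_gqa = True)
```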
@dwromero made a bit of progress in the linked PR but out of steam, will resume tomorrow morning. feel free to leave comments on anything that doesn't look right
@dwromero alright, think i can knock out the rest this morning. you still there?
@dwromero think it is all there in 0.5.19, you can play around with it by running the
Wow cool! Thank you so much @lucidrains! 💪
@dwromero no problem. if you can get me some nvidia cloud compute, i can throw in the flex attention logic. but not a big priority for now
Hi @lucidrains,
I hope you are doing well. And thank you for yet another useful repo! :)
I was wondering if you have any plans to support the zigzag version of ring attention. It seems to distribute compute better in autoregressive settings and is quite hot at the moment (zhuzilin/ring-flash-attention#2). I could help with that if needed.
David