[Feature Request] Balancing computation with zigzag blocking #2
Do you mean Striped Attention when you say zigzag blocking? Or is it something simpler which still gives you a much better utilization (lower latency) of the GPUs?
@andreaskoepf Oh... I haven't read Striped Attention before... (from the name I thought it was some sparse-attention-mask version of ring attention, like window attention, my bad...) but from a quick look, it seems that Striped Attention shards the tokens round-robin, so rank r holds tokens r, r+n, r+2n, ...

I was thinking about doing something different: split the sequence into 2n blocks and give rank i the pair of blocks i and 2n-1-i, which may be able to fold the causal mask from a lower triangle, where later ranks hold far more unmasked entries, into a layout where every rank holds the same number of unmasked entries.

I'm not sure which could give better performance...
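For concreteness, here is a toy sketch of which token indices each rank would own under the three layouts being compared (the helper names are made up for illustration; this is not code from this repo):

```python
# Toy sketch of which token indices each rank owns under the three
# sharding layouts discussed in this thread. Helper names are hypothetical.

def naive_shard(seq_len: int, world_size: int) -> list[list[int]]:
    # Rank r owns one contiguous block of seq_len // world_size tokens.
    t = seq_len // world_size
    return [list(range(r * t, (r + 1) * t)) for r in range(world_size)]

def striped_shard(seq_len: int, world_size: int) -> list[list[int]]:
    # Striped Attention: rank r owns tokens r, r + world_size, r + 2*world_size, ...
    return [list(range(r, seq_len, world_size)) for r in range(world_size)]

def zigzag_shard(seq_len: int, world_size: int) -> list[list[int]]:
    # Zig zag: cut the sequence into 2 * world_size blocks; rank r owns
    # block r plus its mirror block 2 * world_size - 1 - r.
    c = seq_len // (2 * world_size)
    block = lambda i: list(range(i * c, (i + 1) * c))
    return [block(r) + block(2 * world_size - 1 - r) for r in range(world_size)]

for name, fn in [("naive", naive_shard), ("striped", striped_shard), ("zigzag", zigzag_shard)]:
    print(f"{name:8s}", fn(8, 4))
# naive    [[0, 1], [2, 3], [4, 5], [6, 7]]
# striped  [[0, 4], [1, 5], [2, 6], [3, 7]]
# zigzag   [[0, 7], [1, 6], [2, 5], [3, 4]]
```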
@zhuzilin this is great work! Your zig zag pattern looks to be the optimal sharding for Ring Attention: it homogeneously spreads the computation across all ranks. Wanted to share my analysis in case it's helpful for others.

**Setup**: the queries, keys, and values ($q$, $k$, $v$) share a common sequence length $S$. Ring attention divides the sequence length into $W$ shards, one per rank, so each rank holds $T = S/W$ positions, and then passes the $k$/$v$ shards around the ring for $W$ iterations so that every query eventually attends to every causally visible key. The strategies below differ only in which positions each rank owns.

**Efficiency**

The first iteration of ring attention is the same for all sharding strategies. Because the queries and keys come from the same rank in this step, the mask is the ordinary lower-triangular causal mask and every rank attends to $T(T+1)/2$ positions.

The strategies all differ in how long the subsequent steps take, in which the queries attend to the keys from different ranks. We use the maximum number of positions that any rank attends to on a given iteration as a rough proxy for how much time that iteration takes.

**Naive**

For naive ring attention, rank $r$ owns the contiguous positions $\{rT, \ldots, (r+1)T - 1\}$, and the computation is always bottlenecked by (at least) rank $W-1$: its queries come after every other rank's keys, so it attends to the full $T^2$ positions on every subsequent iteration, while rank $0$ attends to none. The iteration time is determined by the slowest rank, so the time is set by rank $W-1$'s $T^2$ operations.

**Striped Attention**

For striped ring attention, rank $r$ owns the strided positions $\{r, r+W, r+2W, \ldots\}$, and once again rank $W-1$ is the slowest, though only barely: its query offsets exceed the offsets of any keys it receives, so it attends to $T(T+1)/2$ positions on every iteration. Rank $0$, at the other extreme, attends to only $T(T-1)/2$ positions on each subsequent iteration. The iteration time is now set by rank $W-1$'s $T(T+1)/2$ operations.

**Zig Zag Attention**

For zig zag ring attention, the sequence is cut into $2W$ chunks of size $T/2$, and rank $r$ owns chunk $r$ together with its mirror, chunk $2W-1-r$. Every rank performs the same amount of attention operations on every iteration: $2 \times (T/2)^2 = T^2/2$. There are two possible scenarios for every rank $r$ receiving the chunks originally owned by rank $j$:

- $j < r$: both of rank $r$'s chunks come after the received chunk $j$ and before the received chunk $2W-1-j$, so each query chunk attends fully to chunk $j$ and not at all to chunk $2W-1-j$: $2 \times (T/2)^2$ positions.
- $j > r$: only rank $r$'s later chunk $2W-1-r$ comes after the received chunks, and it attends fully to both of them: again $2 \times (T/2)^2$ positions.

The compute is thus homogeneous on every rank. This improves upon striped attention by reducing the maximum operations per iteration by an approximate factor of $(T+1)/T$ (from $T(T+1)/2$ down to $T^2/2$), and it removes the remaining imbalance between ranks entirely.

**Minimal Example**

Take $W = 2$ and $S = 4$, so each rank owns $T = 2$ positions. The initial shards for each strategy are:
Rank | Naive | Striped | Zig Zag
---|---|---|---
0 | {0, 1} | {0, 2} | {0, 3}
1 | {2, 3} | {1, 3} | {1, 2}
On the second iteration, rank 0's queries are attending to rank 1's keys, and vice versa for rank 1's queries. The aggregate number of positions each rank attends to on this second iteration, in each case, is:
Rank | Naive | Striped | Zig Zag |
---|---|---|---|
0 | 0 | 1 | 2 |
1 | 4 | 3 | 2 |
The naive strategy is maximally imbalanced, the striped strategy is somewhat imbalanced, and zig zag is perfectly balanced.
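To sanity-check these tables, here is a small standalone brute-force count of the causal (query, key) pairs each rank computes on each ring iteration. It is a sketch under my reading of the three strategies, not code from this repository:

```python
# Brute-force check of the per-iteration work for naive, striped, and
# zig zag ring attention. Standalone sketch; not this repository's code.

def shards(strategy: str, seq_len: int, world_size: int) -> list[list[int]]:
    t = seq_len // world_size
    if strategy == "naive":
        return [list(range(r * t, (r + 1) * t)) for r in range(world_size)]
    if strategy == "striped":
        return [list(range(r, seq_len, world_size)) for r in range(world_size)]
    # zig zag: rank r owns chunk r and chunk 2 * world_size - 1 - r
    c = seq_len // (2 * world_size)
    chunk = lambda i: list(range(i * c, (i + 1) * c))
    return [chunk(r) + chunk(2 * world_size - 1 - r) for r in range(world_size)]

def ops_per_iteration(strategy: str, seq_len: int = 4, world_size: int = 2) -> None:
    owned = shards(strategy, seq_len, world_size)
    for it in range(world_size):
        # On iteration `it`, rank r holds the k/v shard that started on rank r - it.
        counts = [
            sum(1 for q in owned[r] for k in owned[(r - it) % world_size] if k <= q)
            for r in range(world_size)
        ]
        print(f"{strategy:8s} iter {it}: per-rank positions attended = {counts}")

for s in ("naive", "striped", "zigzag"):
    ops_per_iteration(s)
```

For $W = 2$, $S = 4$ this reproduces the tables above: iteration 1 gives `[0, 4]` (naive), `[1, 3]` (striped), and `[2, 2]` (zig zag), while iteration 0 costs `[3, 3]` in every case.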
EDIT: I noticed LLaMA3 uses zig-zag chunking as well, page 11 of the paper in the section on context parallelism. Though they don't actually use the form of ring attention implemented here, for GQA and attention masking reasons.
Currently the implementation will split the input sequence into n blocks, e.g. 4 GPUs will split it into:

`b0 | b1 | b2 | b3`

However, this will result in uneven calculation: due to the causal attention mask, the GPU that has `b3` will do around 4 times more calculation than the GPU that has `b0`.

If we instead split the input sequence into 2n blocks, e.g. 4 GPUs will split it into:

`b0,b7 | b1,b6 | b2,b5 | b3,b4`

then all GPUs will have the same amount of calculation, and theoretically the latency should decrease by half.
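For what it's worth, the proposed 2n-block split could look along these lines in PyTorch (a sketch assuming GPU i takes blocks i and 2n-1-i, as above; `extract_zigzag_local` is a hypothetical name, not this repo's API):

```python
import torch

def extract_zigzag_local(x: torch.Tensor, rank: int, world_size: int,
                         seq_dim: int = 1) -> torch.Tensor:
    # Cut the full sequence into 2 * world_size equal blocks along seq_dim
    # and give rank r the pair (block r, block 2 * world_size - 1 - r),
    # so the causal-attention work is the same on every rank.
    blocks = x.chunk(2 * world_size, dim=seq_dim)
    return torch.cat([blocks[rank], blocks[2 * world_size - 1 - rank]], dim=seq_dim)

# Example: a (batch, seq, dim) tensor with seq_len=8 split across 4 GPUs.
x = torch.arange(8).view(1, 8, 1)
print([extract_zigzag_local(x, r, 4).flatten().tolist() for r in range(4)])
# [[0, 7], [1, 6], [2, 5], [3, 4]]
```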