
Fix uneven head sequence parallelism bug (#6774) #6797

Open · wants to merge 1 commit into base: master

Conversation

Eugene29 commented:

Here gather_idx < 2 represents is_first_all2all. During the first all2all, uneven_head_all2all will be called if either num_heads % seq_world_size != 0 or get_num_kv_heads() is not None.

During the second all2all, it will call uneven_head_all2all if and only if get_num_kv_heads() is not None, which is always set during the first uneven all2all. This means there is no longer an issue where uneven_head_all2all is taken for the second all2all merely because num_heads % seq_world_size != 0.
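For illustration, here is a minimal, self-contained sketch of that dispatch predicate. This is not the actual DeepSpeed source: the get_num_kv_heads/set_num_kv_heads stubs and the needs_uneven_all2all helper are hypothetical stand-ins for the logic inside single_all_to_all.

```python
# Minimal sketch of the dispatch predicate described above (hypothetical
# stand-ins, not the actual DeepSpeed source).

_num_kv_heads = None  # module-level state, set during the first uneven all2all


def get_num_kv_heads():
    return _num_kv_heads


def set_num_kv_heads(n):
    global _num_kv_heads
    _num_kv_heads = n


def needs_uneven_all2all(num_heads, seq_world_size, gather_idx):
    # gather_idx < 2 marks the first all2all: take the uneven path whenever
    # the heads don't divide evenly across ranks (or kv heads were already
    # registered). On the second all2all (gather_idx == 2), take the uneven
    # path only if the first uneven all2all registered the kv head count;
    # an uneven num_heads alone no longer forces it, which was the bug
    # reported in #6774.
    return get_num_kv_heads() is not None or (num_heads % seq_world_size != 0
                                              and gather_idx < 2)
```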

Eugene29 (Author) commented:

Fix #6774

inkcherry (Contributor) commented:

LGTM, many thanks!

```diff
@@ -155,7 +155,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
     # we only need num_heads once
     num_heads = input.shape[2]

-    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
+    if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and gather_idx < 2):
```
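Using the sketch above, the fixed condition behaves as follows for an uneven head count, e.g. 7 heads across 4 ranks (illustrative values, not taken from the PR):

```python
# 7 heads across 4 ranks: 7 % 4 != 0, so the head split is uneven.
assert needs_uneven_all2all(7, 4, gather_idx=1)       # first all2all: uneven path
assert not needs_uneven_all2all(7, 4, gather_idx=2)   # second all2all alone: even path

set_num_kv_heads(7)                                   # done during the first uneven all2all
assert needs_uneven_all2all(7, 4, gather_idx=2)       # second all2all now: uneven path
```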
inkcherry (Contributor) commented on Dec 2, 2024:

By the way, can we use scatter_idx for this check instead? That way we stay consistent with the other parts of the code, which uniformly use scatter_idx to determine whether the current state is total_head or total_seq_len.
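As a sketch of this suggestion, assuming the usual Ulysses convention where the first all2all scatters the head dimension (scatter_idx == 2) and the second scatters the sequence dimension (scatter_idx == 1), the same fix could be keyed on scatter_idx:

```python
# Hypothetical variant of the fixed condition keyed on scatter_idx rather
# than gather_idx. Assumes scatter_idx == 2 identifies the first
# (head-scattering) all2all; verify against how the surrounding code sets
# scatter_idx before adopting it.
if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and scatter_idx == 2):
    ...  # uneven_head_all2all path
```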

Another contributor replied:

Ditto @inkcherry! A future refactor will introduce better naming to differentiate between N/p (pre/post-attention) and h/p (attention block).

Eugene29 (Author) replied:

Hi, I agree; the scatter_idx booleans are confusing. Maybe I can share my opinion on the refactoring on Wednesday, if that's OK.
