[FA] add persistent variant #77
Conversation
@manman-ren has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Can you take a look at the CI failure? If the kernel does not work on the Triton versions we test (the PyTorch-bundled Triton and Triton main), we need to add it to the skip list: https://github.com/pytorch-labs/tritonbench/tree/main/test/test_gpu
Force-pushed from ba794c5 to de3628b.
LGTM. We are now fixing H (num_heads) here and only varying seq_len, but that sounds fine to me. For those interested in reproducing the FA3 paper numbers, we can add another option for input data shapes.
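For context, a hypothetical sketch of the kind of shape option discussed above: the benchmark in this PR fixes H (num_heads) and sweeps seq_len, and a second option could generate a different set of shapes. The function name and the concrete values below are illustrative, not taken from tritonbench.

```python
# Hypothetical helper, not tritonbench's actual API: fix num_heads and
# sweep seq_len, yielding (Batch, Heads, SeqLen, Dhead) tuples.
def sweep_seq_len(batch=4, num_heads=32, d_head=128,
                  seq_lens=(512, 1024, 2048, 4096, 8192, 16384)):
    for seq_len in seq_lens:
        yield (batch, num_heads, seq_len, d_head)
```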
@manman-ren has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@manman-ren merged this pull request in 0a82d3d.
Hongtao identified the performance issue with the initial implementation and updated the assignment of tiles to each SM.
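Below is a minimal, illustrative sketch (not the kernel from this PR) of the persistent-scheduling idea referred to above: one Triton program is launched per SM, and each program strides over the flattened list of output tiles, so tiles are assigned round-robin across SMs. The kernel name and variables are hypothetical; a real flash-attention kernel would decode each tile index into (batch, head, query-block) coordinates and do the attention computation in the loop body.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def persistent_tile_demo(out_ptr, num_tiles, NUM_SMS: tl.constexpr):
    pid = tl.program_id(0)
    # One program per SM; each program handles tiles pid, pid + NUM_SMS, ...
    for tile_id in range(pid, num_tiles, NUM_SMS):
        # Stand-in for real work: record which program owned this tile.
        tl.store(out_ptr + tile_id, pid)


num_sms = torch.cuda.get_device_properties(0).multi_processor_count
num_tiles = 4096
owner = torch.empty(num_tiles, dtype=torch.int32, device="cuda")
persistent_tile_demo[(num_sms,)](owner, num_tiles, NUM_SMS=num_sms)
```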
Performance with warp specialization
| (Batch, Heads, SeqLen, Dhead) | triton_tutorial_flash_v2_tma_ws_persistent-tflops | triton_tutorial_flash_v2_tma_ws-tflops | triton_tutorial_flash_v2-tflops |
| --- | --- | --- | --- |