
[FA] add persistent variant #77

Closed
manman-ren wants to merge 8 commits

Conversation


@manman-ren commented Nov 25, 2024

Hongtao identified a performance issue with the initial implementation and updated the assignment of tiles to each SM.

Performance with warp specialization:

| (Batch, Heads, SeqLen, Dhead) | triton_tutorial_flash_v2_tma_ws_persistent-tflops | triton_tutorial_flash_v2_tma_ws-tflops | triton_tutorial_flash_v2-tflops |
|---|---|---|---|
| (8, 16, 8192, 128) | 516.164 | 490.451 | 423.905 |

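The PR diff itself is not shown in this thread, but as a rough sketch of the persistent-scheduling idea described above (one resident program per SM, with tiles handed out round-robin), using illustrative names such as `NUM_SMS` and `num_tiles` rather than the PR's actual code:

```python
import torch
import triton
import triton.language as tl

# Illustrative sketch only, not the PR's kernel: a persistent kernel launches
# one program per SM and loops over its assigned tiles, instead of launching
# one program per tile.
@triton.jit
def _persistent_sketch(out_ptr, num_tiles, NUM_SMS: tl.constexpr):
    prog_id = tl.program_id(0)  # one resident program per SM
    # Round-robin assignment: program p handles tiles p, p + NUM_SMS, p + 2*NUM_SMS, ...
    for tile_id in range(prog_id, num_tiles, NUM_SMS):
        # A real flash-attention kernel would decode tile_id into a
        # (batch, head, q-block) triple and compute that tile here.
        tl.store(out_ptr + tile_id, prog_id)

NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count
num_tiles = 512
owner = torch.empty(num_tiles, dtype=torch.int32, device="cuda")
_persistent_sketch[(NUM_SMS,)](owner, num_tiles, NUM_SMS=NUM_SMS)
# owner[i] records which resident program processed tile i.
```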
@facebook-github-bot

@manman-ren has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@xuzhao9 commented Nov 25, 2024

Can you take a look at the CI failure? If the kernel does not work on the Triton versions we test (pytorch and Triton-main), we need to add it to the skip list: https://github.com/pytorch-labs/tritonbench/tree/main/test/test_gpu
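For reference, a hypothetical skip-list entry, assuming the skip list is a YAML file under test/test_gpu that maps an operator name to the benchmark implementations to skip (the actual file name and schema should be verified against the repo):

```yaml
# Hypothetical entry: skip the persistent variant on Triton versions that
# cannot compile it. Verify the real file name and schema under test/test_gpu.
flash_attention:
  - triton_tutorial_flash_v2_tma_ws_persistent
```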

@manman-ren marked this pull request as ready for review November 26, 2024 22:25

@xuzhao9 left a comment


LGTM. This changes the benchmark to fix H (num_heads) and only vary seq_len, which sounds fine to me. For those interested in reproducing the FA3 paper numbers, we can add another option for the input data shapes.
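To make that shape policy concrete, a hypothetical sweep with fixed batch and head count (the seq_len values are illustrative, not tritonbench's defaults):

```python
# Hypothetical input-shape sweep: BATCH and H (num_heads) stay fixed while
# seq_len varies; D_HEAD is 128 as in the table above.
BATCH, H, D_HEAD = 8, 16, 128
shapes = [(BATCH, H, seq_len, D_HEAD) for seq_len in (1024, 2048, 4096, 8192, 16384)]
```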

@facebook-github-bot

@manman-ren has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot

@manman-ren merged this pull request in 0a82d3d.
