Enable bwd for flash_attention #74

Closed
xuzhao9 wants to merge 11 commits from xz9/fix-attn-bwd

@xuzhao9 xuzhao9 commented Nov 22, 2024

Some backends (e.g., flash_attention/xformers_splitk) don't have a backward pass. For those backends, add a fwd_only=True flag, and we will skip the backward pass automatically.

If the user specifies --only xformers_splitk --bwd, we still run the backend because it was explicitly requested. Otherwise, we always skip this backend.
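The selection rule described above could be sketched roughly as follows. This is only an illustration, not the code in this PR: `Backend`, `REGISTRY`, `register_backend`, `select_backends`, and the `dense_reference` backend are hypothetical names, and the sketch assumes the skip applies only when the backward pass is requested.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class Backend:
    """Hypothetical registry entry; not the project's actual data structure."""
    name: str
    fn: Callable[[], str]
    fwd_only: bool = False  # True when the backend has no backward pass


# Hypothetical registry mapping backend name -> Backend.
REGISTRY: Dict[str, Backend] = {}


def register_backend(name: str, fwd_only: bool = False):
    """Hypothetical decorator that records a backend and its fwd_only flag."""
    def decorator(fn):
        REGISTRY[name] = Backend(name=name, fn=fn, fwd_only=fwd_only)
        return fn
    return decorator


@register_backend("dense_reference")  # placeholder backend with a backward pass
def dense_reference() -> str:
    return "dense_reference"


@register_backend("xformers_splitk", fwd_only=True)  # no backward pass
def xformers_splitk() -> str:
    return "xformers_splitk"


def select_backends(mode: str, only: Optional[List[str]] = None) -> List[str]:
    """Return the names of the backends to run.

    mode: "fwd" or "bwd".
    only: backends named explicitly by the user (as with --only), or None.
    """
    if only:
        # Explicitly requested backends run even if they are fwd_only.
        return [name for name in only if name in REGISTRY]
    if mode == "bwd":
        # No explicit selection: drop backends without a backward pass.
        return [b.name for b in REGISTRY.values() if not b.fwd_only]
    return list(REGISTRY)
```

With this sketch, `select_backends("bwd")` leaves out `xformers_splitk`, while `select_backends("bwd", only=["xformers_splitk"])` keeps it because the user asked for it by name.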

Test plan:

python -m unittest test.test_gpu.main -k flash_attention

@adamomainz
Contributor

Can you please share a screenshot of the output or confirm it ran as expected?

@xuzhao9 xuzhao9 marked this pull request as draft November 22, 2024 20:30
@xuzhao9
Contributor Author

xuzhao9 commented Nov 22, 2024

@adamomainz Sorry, it is not ready yet and I am still debugging. I will request review again when it is ready.

@adamomainz
Contributor

@xuzhao9 No problem at all, happy to review again once ready!

@xuzhao9 xuzhao9 changed the title from "Run bwd for flash_attention" to "Enable bwd for flash_attention" Nov 22, 2024
@facebook-github-bot
Contributor

@xuzhao9 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@xuzhao9 xuzhao9 marked this pull request as ready for review November 22, 2024 23:49
@facebook-github-bot
Contributor

@xuzhao9 merged this pull request in 2f12e83.

@xuzhao9 xuzhao9 deleted the xz9/fix-attn-bwd branch December 4, 2024 14:44
3 participants