[Question] How to match Flash Attention 2 performance? #98
CuDNN version and nvidia-smi output:
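For reference, the versions PyTorch itself was built against can be read directly from Python (a quick check, not a substitute for the nvidia-smi output; note that torch reports the cuDNN build it links, which can differ from any cuDNN installed system-wide):

```python
import torch

# Versions as seen by PyTorch; the cuDNN number here is the build PyTorch links,
# which may differ from a system-wide cuDNN installation.
print("torch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)
print("cuDNN (build):", torch.backends.cudnn.version())   # e.g. 8907 -> 8.9.7
print("GPU:", torch.cuda.get_device_name(0))
```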
I've attached both the benchmarking and CuDNN wrapper code to this post. I suspect the benchmarking code is off, so I'll switch to something simpler (like the PyTorch profiler) and see what the results are.
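A minimal timing harness along these lines might be a reasonable middle ground (just a sketch; the shapes, dtype, and causal flag are placeholders, not the configuration from my runs):

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

# Placeholder shapes/dtype -- not the exact configuration from my runs.
B, H, S, D = 4, 16, 4096, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))

def fwd():
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def fwd_bwd():
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()

# torch.utils.benchmark synchronizes CUDA and handles warmup, which avoids the
# most common sources of misleading timings in a hand-rolled loop.
for label, fn in [("forward", fwd), ("forward + backward", fwd_bwd)]:
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn}, label=label)
    print(timer.timeit(50))
```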
You could try improvising on this: install the FAV2 pip package inside the container and go from there. Try out the latest container (24.07), just in case.
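Roughly, that comparison could look like this (a sketch; the install flag and shapes are assumptions, and the container may already ship flash-attn):

```python
# Inside the container (exact flags may vary):
#   pip install flash-attn --no-build-isolation
import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seqlen, nheads, headdim) tensors in fp16/bf16.
B, S, H, D = 4, 4096, 16, 64   # placeholder shapes
q, k, v = (torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))

out = flash_attn_func(q, k, v, causal=True)   # FAV2 numbers to compare cuDNN against
out.sum().backward()                          # the backward pass is where the gap shows up
```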
I wrote a helper that lets someone use CuDNN attention within PyTorch seamlessly.
However, while this gets better forward-pass performance, it gets far worse backward-pass performance. Any thoughts on why this might be the case? I'm hoping there is some obvious deficiency in my code.
(Unit is ms).
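For anyone comparing: a quick way to get a cuDNN baseline without a custom wrapper is the cuDNN SDPA backend that recent PyTorch builds expose (a sketch, assuming a build new enough to have SDPBackend.CUDNN_ATTENTION; this is not the wrapper attached above, just a sanity check against it):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Placeholder shapes, not the configuration behind the ms numbers above.
B, H, S, D = 4, 16, 4096, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))

# Force the cuDNN backend for scaled_dot_product_attention; this helps separate
# "cuDNN kernel speed" from "overhead added by the wrapper".
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()   # exercises the cuDNN backward kernel too
```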