Is Hydra much slower than Mamba? #9
I tried unit testing a single block and it seems like Hydra is at least 4x slower. Is that expected?
It is not expected to be 4x slower. It could be due to GPU memory constraints, or because Mamba2's Triton autotuning is not optimized for your GPU. Please try the memory-efficient version by passing use_mem_eff_path=True.
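For reference, a minimal sketch of enabling that flag when constructing the block (only the `use_mem_eff_path` flag comes from the comment above; the import path and `d_model` value are assumptions):

```python
# Minimal sketch: build the Hydra block with the memory-efficient path enabled,
# per the suggestion above. Import path and d_model are assumptions.
from hydra.modules.hydra import Hydra

block = Hydra(d_model=128, use_mem_eff_path=True).cuda()
```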
This is the code I use to test:
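(The script itself wasn't preserved in this copy of the thread; a minimal sketch of this kind of head-to-head benchmark, with assumed import paths and hypothetical edge-device sizes, would look like the following.)

```python
# Rough sketch of a Hydra-vs-Mamba2 block benchmark; import paths, sizes, and
# dtype are assumptions, not the reporter's exact script.
import time
import torch

from mamba_ssm import Mamba2            # assumed import, per the mamba_ssm README
from hydra.modules.hydra import Hydra   # assumed import path in the Hydra repo

def avg_ms(block, x, iters=100, warmup=10):
    # Warm up first so Triton autotuning and CUDA init don't skew the timing.
    with torch.no_grad():
        for _ in range(warmup):
            block(x)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            block(x)
        torch.cuda.synchronize()
    return (time.time() - start) / iters * 1e3

# Hypothetical small, edge-device-style sizes (the thread's exact settings are
# not reproduced here).
batch, num_tokens, d_model = 1, 196, 128
x = torch.randn(batch, num_tokens, d_model, device="cuda", dtype=torch.float16)

mamba = Mamba2(d_model=d_model).cuda().half()
hydra = Hydra(d_model=d_model).cuda().half()

print(f"Mamba2: {avg_ms(mamba, x):.3f} ms")
print(f"Hydra:  {avg_ms(hydra, x):.3f} ms")
```

Timing only after a warmup avoids counting Triton autotuning, which the first reply above calls out as a possible confounder.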
These are the results on an RTX 3090:
which show Hydra to be at least 4x slower than Mamba. Could you tell me what's wrong with this test?
Try with these, please.
The num_tokens, channel_size, and num_heads tested above are the actual settings used in my network. It is for an edge device and cannot support large models. Does that mean that Hydra is only competitive in large networks?
@sukjunhwang Do you plan on implementing a more efficient bidirectional scan? At least from programming my own version of Mamba2, my guess is that you could fuse most of the kernels together so that there is only a minor decrease in speed compared to the unidirectional case.
I don't think we are planning to, but I agree that it should be much more optimizable by fusing together more kernels. Would love to see community contributions here!
@Hprairie How do you fuse kernels together to increase speed? How much faster do you get?
You have to rewrite the Mamba2 kernel so that you process both the forward and the backward scan. If you do it right, you will load less data and perform fewer floating-point operations, saving a chunk of time. I'm working on an implementation right now in my free time; I'll send a link when it's finished. However, I'm busy with other work, so it might be a couple of weeks.
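To illustrate the point with a toy example (my sketch in plain PyTorch, not Hprairie's kernel), consider the scalar recurrence h[t] = a[t]·h[t-1] + x[t]:

```python
import torch

def causal_scan(a, x):
    """Left-to-right linear recurrence: h[t] = a[t] * h[t-1] + x[t]."""
    h = torch.zeros_like(x[:, 0])
    out = torch.empty_like(x)
    for t in range(x.shape[1]):
        h = a[:, t] * h + x[:, t]
        out[:, t] = h
    return out

def bidirectional_flip(a, x):
    # Straightforward version: run the causal scan twice. Each torch.flip
    # materializes a full copy of the tensor, which is the overhead noted above.
    fwd = causal_scan(a, x)
    bwd = torch.flip(causal_scan(torch.flip(a, [1]), torch.flip(x, [1])), [1])
    return fwd + bwd

def bidirectional_fused(a, x):
    # One routine traverses the same buffers in both directions, so the input
    # is read in place and no flipped copies are created. A fused GPU kernel
    # applies the analogous idea per tile: load each chunk once, then update
    # the forward and backward states from the same data.
    out = torch.zeros_like(x)
    hf = torch.zeros_like(x[:, 0])
    hb = torch.zeros_like(x[:, 0])
    T = x.shape[1]
    for t in range(T):
        hf = a[:, t] * hf + x[:, t]
        out[:, t] += hf
    for t in reversed(range(T)):
        hb = a[:, t] * hb + x[:, t]
        out[:, t] += hb
    return out

# Sanity check: both formulations agree.
a = torch.rand(2, 16)
x = torch.randn(2, 16)
assert torch.allclose(bidirectional_flip(a, x), bidirectional_fused(a, x), atol=1e-5)
```

The fused variant reads each element of `a` and `x` once for both directions, which is the data-loading saving referred to above.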
Really looking forward to your work!!!
For anyone interested, I have the fwd pass implemented and am working on the bwd pass right now. I have created a repo where the final kernels will live: https://github.com/Hprairie/Bi-Mamba2

TL;DR for benchmarking: the fwd pass has a slight overhead over the causal mode, which is to be expected. However, it is much better than flipping, especially since flipping in PyTorch makes a copy of the tensor. Here is a comparison of the causal mode and the bidirectional mode:

Bidirectional Mamba performance:

| seqlen | Causal Mamba2 | Bi-Mamba2 |
|-------:|--------------:|----------:|
| 2 | 0.109965 | 0.213840 |
| 4 | 0.108999 | 0.209290 |
| 8 | 0.107039 | 0.205770 |
| 16 | 0.106830 | 0.212640 |
| 32 | 0.110320 | 0.213240 |
| 64 | 0.111790 | 0.218340 |
| 128 | 0.126460 | 0.219845 |
| 256 | 0.137930 | 0.218290 |
| 512 | 0.113160 | 0.220190 |
| 1024 | 0.148601 | 0.228861 |
| 2048 | 0.215160 | 0.246040 |
| 4096 | 0.353721 | 0.437680 |
| 8192 | 0.579841 | 0.722961 |

I also want to note that these kernels aren't exactly the same as Hydra. The equation for Hydra is essentially

$$y = \mathrm{shift}(\mathrm{SS}(x)) + \mathrm{flip}(\mathrm{shift}(\mathrm{SS}(\mathrm{flip}(x)))) + D \odot x$$

while this kernel computes

$$y = \mathrm{SS}(x) + \mathrm{flip}(\mathrm{SS}(\mathrm{flip}(x))) + D \odot x$$

i.e., without the shift, so each position also flows through both directional scans. I do this as it makes it really simple to implement; however, later down the line I might add a kernel for Hydra.

EDIT: Benchmarking was done on an AMD 7900 XTX and a 3970X Threadripper. I will benchmark on an A100 and H100 soon.
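To make the difference concrete, here is a rough PyTorch rendering of the two formulations (my sketch, not code from either repo; `ss` stands in for any causal SSM pass, and the function names are hypothetical):

```python
import torch

def shift(y):
    # Delay by one step: position t only sees strictly earlier positions, so
    # the diagonal is carried entirely by the D * x skip term.
    return torch.cat([torch.zeros_like(y[:, :1]), y[:, :-1]], dim=1)

def hydra_mixer(ss, x, D):
    # Hydra's quasiseparable form: shifted forward scan + flipped shifted
    # reverse scan + diagonal term.
    return shift(ss(x)) + torch.flip(shift(ss(torch.flip(x, [1]))), [1]) + D * x

def bi_mamba2_mixer(ss, x, D):
    # The simpler form these kernels compute: no shift, so each position also
    # passes through both directional scans.
    return ss(x) + torch.flip(ss(torch.flip(x, [1])), [1]) + D * x
```

Dropping the shift is what makes the fused kernel simple, at the cost of double-counting the current position across the two directions.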
This is awesome, thanks for the contributions, Hayden!
Awesome!!! Really looking forward to your work!!!