Is Hydra much slower than Mamba? #9

Open
yxchng opened this issue Aug 3, 2024 · 12 comments

@yxchng commented Aug 3, 2024

I tried unit testing a single block and it seems like Hydra is at least 4x slower. Is that expected?

@sukjunhwang (Collaborator)

It is not expected to be 4x slower. This could be due to GPU memory constraints, or because Mamba2's Triton autotuning is not tuned for your GPU. Please try the memory-efficient version by passing use_mem_eff_path=True.
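For reference, a minimal sketch of constructing the block with this flag enabled (the constructor argument names follow the benchmark code posted later in this thread; the specific sizes here are only illustrative):

```python
import torch
from hydra import Hydra

d_model = 256
hydra = Hydra(
    d_model=d_model,        # model dimension
    d_state=16,             # SSM state expansion factor
    d_conv=3,               # local non-causal convolution width
    expand=2,               # block expansion factor
    use_mem_eff_path=True,  # memory-efficient path suggested above
).to("cuda")

x = torch.randn(1, 1024, d_model, device="cuda")  # (batch, seqlen, d_model)
with torch.no_grad():
    y = hydra(x)            # output has the same shape as x
```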

@yxchng (Author) commented Aug 6, 2024

This is the code I used to test:

import torch
import torch.nn as nn
import flash_attn
from mamba_ssm import Mamba
from hydra import Hydra

HH = [1,2,4,8,16]
# HH = [2,4,8,16,32]
CC = [32,64,128,256,512]
TT = [198656,70656,18432,5120,994]
num_trials = 1000



with torch.no_grad():
    for i in range(len(HH)):
        H = HH[i]
        C = CC[i]
        T = TT[i]
        print('num_tokens:', T, 'channel_size:', C, 'num_heads:', H)


        scale = (C // H) ** -0.5
        x = torch.randn(T,C).cuda()


        # Cumulative sequence boundaries for flash_attn_varlen: every setting uses
        # 1024-token sequences, except T=994, which is a single sequence.
        if T == 994:
            cu_seqlens = torch.tensor([0, 994], dtype=torch.int32).cuda()
        else:
            cu_seqlens = torch.arange(0, T + 1, 1024, dtype=torch.int32).cuda()

        # Flash attention baseline: time the QKV projection + varlen attention together.
        qkv_proj = torch.nn.Linear(C, C * 3).cuda()
        total = 0
        count = 0
        for trial in range(num_trials):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            qkv = qkv_proj(x)
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=1024,
                dropout_p=0.0,
                softmax_scale=scale,
            ).reshape(-1, C)
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            if trial > 100:  # skip the first 100 iterations as warmup
                total += elap
                count += 1
        print('flash attention', total / count)

        hydra = Hydra(
            d_model=C, # Model dimension d_model
            d_state=1,  # SSM state expansion factor
            d_conv=3,    # Local non-causal convolution width
            expand=2,    # Block expansion factor
            use_mem_eff_path=True
        ).to("cuda")

        total = 0
        count = 0
        for trial in range(num_trials):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            y = hydra(x.unsqueeze(0))
            y = y.view(-1, C)  # flatten (1, T, C) -> (T, C) to match the attention output
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            if trial > 100:  # skip warmup
                total += elap
                count += 1

        print('hydra', total / count)

        mamba = Mamba(
            d_model=C, # Model dimension d_model
            d_state=1,  # SSM state expansion factor
            d_conv=4,    # Local convolution width
            expand=2,    # Block expansion factor
        ).to("cuda")

        total = 0
        count = 0
        for trial in range(num_trials):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            y = mamba(x.unsqueeze(0))
            y = y.view(-1, C)  # flatten (1, T, C) -> (T, C) to match the attention output
            end.record()
            torch.cuda.synchronize()
            elap = start.elapsed_time(end)
            if trial > 100:  # skip warmup
                total += elap
                count += 1

        print('mamba', total / count)

These are the results on an RTX 3090:

num_tokens: 198656 channel_size: 32 num_heads: 1
flash attention 1.4398738327493126
hydra 5.19853590009475
mamba 1.2737170286783255
num_tokens: 70656 channel_size: 64 num_heads: 2
flash attention 1.2244021445678523
hydra 3.9700514180243878
mamba 0.9193982243272698
num_tokens: 18432 channel_size: 128 num_heads: 4
flash attention 0.5656079320409009
hydra 2.4510892373701356
mamba 0.6655043213067782
num_tokens: 5120 channel_size: 256 num_heads: 8
flash attention 0.37106161046611585
hydra 2.1477266787421847
mamba 0.5596496043136309
num_tokens: 994 channel_size: 512 num_heads: 16
flash attention 0.2605083365328983
hydra 2.1441056672670156
mamba 0.4217387858169097

These results show Hydra to be at least 4x slower than Mamba.

Is there anything wrong with this test?

@sukjunhwang (Collaborator)

Please try the following (see the sketch after this list):

  1. Turn AMP on, and convert x to bf16 or fp16.
  2. Compare the modules with practical settings: d_state >= 16.
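
For readers following the thread, a minimal sketch of what applying both suggestions to the timing loop above could look like (the sizes here are illustrative, not the exact settings under discussion):

```python
import torch
from hydra import Hydra

C, T = 256, 5120
x = torch.randn(1, T, C, device="cuda").to(torch.bfloat16)  # suggestion 1: bf16 input

hydra = Hydra(
    d_model=C,
    d_state=16,             # suggestion 2: a practical state size
    d_conv=3,
    expand=2,
    use_mem_eff_path=True,
).to("cuda")

with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):  # suggestion 1: AMP
    y = hydra(x)
```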

@yxchng (Author) commented Aug 6, 2024

  1. I tried this; it is a little faster, but still much slower than Mamba.
  2. For my application, d_state doesn't make much difference, so I only use d_state=1. Even with d_state=16, the gap is still about 2x.

The num_tokens, channel_size, and num_heads values tested above are the actual settings used in my network. It targets an edge device and cannot support large models.

Does that mean that Hydra is only competitive in large networks?

@Hprairie commented Aug 6, 2024

@sukjunhwang Do you plan on implementing a more efficient bidirectional scan? From programming my own version of Mamba2, my guess is that you could fuse most of the kernels together so that there is only a minor decrease in speed compared to the unidirectional case.

@albertfgu (Collaborator)

I don't think we are planning to, but I agree that it should be much more optimizable by fusing together more kernels. Would love to see community contributions here!

@yxchng (Author) commented Aug 11, 2024

@Hprairie How do you fuse kernels together to increase speed? How much faster do you get?

@Hprairie

You have to rewrite the Mamba2 kernel so that it processes both the forward and backward scans. If you do it right, you load less data and perform fewer floating-point operations, saving a chunk of time. I'm working on an implementation in my free time and will send a link when it's finished; however, I'm busy with other work, so it might be a couple of weeks.
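
For context, the unfused baseline that such a kernel replaces simply runs the causal scan twice and flips the sequence for the backward direction. A rough sketch, where `ssm` is a stand-in for any causal scan module (e.g. a Mamba2 block), not an actual API from the repo:

```python
import torch

def bidirectional_scan(ssm, x):
    """Naive bidirectional scan: two causal passes plus two flips.

    Each torch.flip materializes a copy of the tensor, and the scan kernel is
    launched twice; this is the overhead a fused bidirectional kernel avoids.
    x: (batch, seqlen, d_model)
    """
    fwd = ssm(x)                                               # forward-in-time scan
    bwd = torch.flip(ssm(torch.flip(x, dims=[1])), dims=[1])   # backward-in-time scan
    return fwd + bwd
```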

@Sh3lock commented Aug 13, 2024

> You have to rewrite the Mamba2 kernel so that it processes both the forward and backward scans. If you do it right, you load less data and perform fewer floating-point operations, saving a chunk of time. I'm working on an implementation in my free time and will send a link when it's finished; however, I'm busy with other work, so it might be a couple of weeks.

Really looking forward to your work!!!

@Hprairie commented Aug 14, 2024

For anyone interested, I have the fwd pass implemented and am working on the bwd pass right now. I have created a repo where I will put the final kernels; a link can be found here: https://github.com/Hprairie/Bi-Mamba2

TL;DR for benchmarking: the fwd pass has a slight overhead over the causal mode, which is to be expected. However, it is much better than flipping, especially since flipping in PyTorch makes a copy of the tensor.

Here is a comparison of a causal mode and a bidirectional mode:

Bidirectional Mamba Performance:
    seqlen  Causal Mamba2  Bi-Mamba2
0      2.0       0.109965   0.213840
1      4.0       0.108999   0.209290
2      8.0       0.107039   0.205770
3     16.0       0.106830   0.212640
4     32.0       0.110320   0.213240
5     64.0       0.111790   0.218340
6    128.0       0.126460   0.219845
7    256.0       0.137930   0.218290
8    512.0       0.113160   0.220190
9   1024.0       0.148601   0.228861
10  2048.0       0.215160   0.246040
11  4096.0       0.353721   0.437680
12  8192.0       0.579841   0.722961

I also want to note that these kernels aren't exactly the same as Hydra. The equation for Hydra is essentially:

$$ y = \operatorname{shift}(\operatorname{SS}(x)) + \operatorname{flip}(\operatorname{shift}(\operatorname{SS}(\operatorname{flip}(x)))) + Dx $$

While this kernel does this:

$$ y = \operatorname{SS}(x) + \operatorname{flip}(\operatorname{SS}(\operatorname{flip}(x))) + Dx $$

I do this because it makes the kernel much simpler to implement; later down the line I might add a kernel that matches Hydra exactly.
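
To make the difference concrete, here is a schematic sketch of the two formulations (not the actual kernels), written with a generic causal scan `SS`, a per-channel skip `D`, and a hypothetical `shift` helper (shift right by one along the sequence, zero-padded; the exact shift Hydra uses may differ):

```python
import torch
import torch.nn.functional as F

def shift(x):
    # (B, L, D) -> (B, L, D): pad one zero step at the front, drop the last step
    return F.pad(x, (0, 0, 1, 0))[:, :-1]

def hydra_form(SS, x, D):
    # y = shift(SS(x)) + flip(shift(SS(flip(x)))) + D * x
    return shift(SS(x)) + torch.flip(shift(SS(torch.flip(x, [1]))), [1]) + D * x

def bi_mamba2_form(SS, x, D):
    # y = SS(x) + flip(SS(flip(x))) + D * x
    return SS(x) + torch.flip(SS(torch.flip(x, [1])), [1]) + D * x
```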

EDIT: Benchmarking was done on an AMD 7900 XTX and 3970X Threadripper. I will benchmark on an A100 and H100 soon.

@albertfgu (Collaborator)

This is awesome, thanks for the contributions Hayden!

@YicongHong

> For anyone interested, I have the fwd pass implemented and am working on the bwd pass right now. I have created a repo where I will put the final kernels; a link can be found here: https://github.com/Hprairie/Bi-Mamba2

Awesome!!! Really looking forward to your work!!!
