Skip to content

Latest commit

 

History

History

softmax

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 

Softmax

0x00 说明

包含以下内容：

  • softmax_f32_kernel (grid-level memory fence)
  • softmax_f32x4_kernel (grid-level memory fence)
  • softmax_f32_per_token_kernel (per token)
  • softmax_f32x4_per_token_kernel (per token)
  • safe_softmax_f32_per_token_kernel (per token)
  • safe_softmax_f32x4_per_token_kernel (per token)
  • safe_softmax_f16_f32_per_token_kernel (per token)
  • safe_softmax_f16x2_f32_per_token_kernel (per token)
  • safe_softmax_f16x8_pack_f32_per_token_kernel (per token)
  • online_safe_softmax_f32_per_token_kernel (per token, online softmax)
  • online_safe_softmax_f32x4_pack_per_token_kernel (per token, online softmax)
  • PyTorch bindings

测试

# 只测试Ada架构；若不指定，则默认编译所有架构（Volta, Ampere, Ada, Hopper, ...），耗时较长
export TORCH_CUDA_ARCH_LIST=Ada 
python3 softmax.py

输出：

----------------------------------------------------------------------------------------------------
                                             N=16384
----------------------------------------------------------------------------------------------------
          out_f32(fence): ['0.00011554  ', '1.172e-05   ', '3.789e-05   '], time:0.00707126ms
        out_f32x4(fence): ['0.00011554  ', '1.172e-05   ', '3.789e-05   '], time:0.00714874ms
              out_f32_th: ['0.00011554  ', '1.172e-05   ', '3.789e-05   '], time:0.00871110ms
----------------------------------------------------------------------------------------------------
                                             S=4096, H=256
----------------------------------------------------------------------------------------------------
            out_f32(per): ['0.00489144  ', '0.00030952  ', '0.00112878  '], time:0.01259184ms
          out_f32x4(per): ['0.00489144  ', '0.00030952  ', '0.00112878  '], time:0.01004362ms
           out_f32(safe): ['0.00489144  ', '0.00030952  ', '0.00112878  '], time:0.01583433ms
    out_f32(safe+online): ['0.00489144  ', '0.00030952  ', '0.00112878  '], time:0.01357031ms
  out_f32x4(safe+online): ['0.00489145  ', '0.00030952  ', '0.00112878  '], time:0.01050377ms
         out_f32x4(safe): ['0.00489144  ', '0.00030952  ', '0.00112878  '], time:0.01027584ms
         out_f32_th(per): ['0.00489144  ', '0.00030952  ', '0.00112878  '], time:0.01042914ms
----------------------------------------------------------------------------------------------------
        out_f16f32(safe): ['0.00489044  ', '0.00030971  ', '0.00112915  '], time:0.01418757ms
      out_f16x2f32(safe): ['0.00489044  ', '0.00030971  ', '0.00112915  '], time:0.00781608ms
  out_f16x8packf32(safe): ['0.00489044  ', '0.00030971  ', '0.00112915  '], time:0.00523329ms
         out_f16_th(per): ['0.00489044  ', '0.00030971  ', '0.00112915  '], time:0.00563836ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=512
----------------------------------------------------------------------------------------------------
            out_f32(per): ['0.00042486  ', '0.00308358  ', '0.00113099  '], time:0.02372313ms
          out_f32x4(per): ['0.00042486  ', '0.00308358  ', '0.00113099  '], time:0.02219534ms
           out_f32(safe): ['0.00042486  ', '0.00308358  ', '0.00113099  '], time:0.03100491ms
    out_f32(safe+online): ['0.00042486  ', '0.00308358  ', '0.00113099  '], time:0.02549100ms
  out_f32x4(safe+online): ['0.00042486  ', '0.00308358  ', '0.00113099  '], time:0.02228165ms
         out_f32x4(safe): ['0.00042486  ', '0.00308358  ', '0.00113099  '], time:0.02230835ms
         out_f32_th(per): ['0.00042486  ', '0.00308358  ', '0.00113099  '], time:0.02294350ms
----------------------------------------------------------------------------------------------------
        out_f16f32(safe): ['0.00042486  ', '0.00308418  ', '0.00113106  '], time:0.02967048ms
      out_f16x2f32(safe): ['0.00042486  ', '0.00308418  ', '0.00113106  '], time:0.01563406ms
  out_f16x8packf32(safe): ['0.00042486  ', '0.00308418  ', '0.00113106  '], time:0.01033092ms
         out_f16_th(per): ['0.00042486  ', '0.00308418  ', '0.00113106  '], time:0.01410413ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=1024
----------------------------------------------------------------------------------------------------
            out_f32(per): ['0.00015042  ', '0.00127817  ', '0.00087939  '], time:0.06144118ms
          out_f32x4(per): ['0.00015042  ', '0.00127817  ', '0.00087939  '], time:0.04208207ms
           out_f32(safe): ['0.00015042  ', '0.00127817  ', '0.00087939  '], time:0.08846235ms
    out_f32(safe+online): ['0.00015042  ', '0.00127817  ', '0.00087939  '], time:0.06275535ms
  out_f32x4(safe+online): ['0.00015042  ', '0.00127817  ', '0.00087939  '], time:0.04195666ms
         out_f32x4(safe): ['0.00015042  ', '0.00127817  ', '0.00087939  '], time:0.04199767ms
         out_f32_th(per): ['0.00015042  ', '0.00127817  ', '0.00087939  '], time:0.04214501ms
----------------------------------------------------------------------------------------------------
        out_f16f32(safe): ['0.00015044  ', '0.00127792  ', '0.00087929  '], time:0.07461023ms
      out_f16x2f32(safe): ['0.00015044  ', '0.00127792  ', '0.00087929  '], time:0.02805471ms
  out_f16x8packf32(safe): ['0.00015044  ', '0.00127792  ', '0.00087929  '], time:0.02210021ms
         out_f16_th(per): ['0.00015044  ', '0.00127792  ', '0.00087929  '], time:0.02429175ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=2048
----------------------------------------------------------------------------------------------------
          out_f32x4(per): ['0.00014777  ', '0.00018938  ', '9.769e-05   '], time:0.08160353ms
         out_f32x4(safe): ['0.00014777  ', '0.00018938  ', '9.769e-05   '], time:0.08181977ms
         out_f32_th(per): ['0.00014777  ', '0.00018938  ', '9.769e-05   '], time:0.10212374ms
----------------------------------------------------------------------------------------------------
      out_f16x2f32(safe): ['0.0001477   ', '0.00018942  ', '9.769e-05   '], time:0.07831120ms
  out_f16x8packf32(safe): ['0.0001477   ', '0.00018942  ', '9.769e-05   '], time:0.04206920ms
         out_f16_th(per): ['0.0001477   ', '0.00018942  ', '9.769e-05   '], time:0.05331278ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=4096
----------------------------------------------------------------------------------------------------
          out_f32x4(per): ['4.063e-05   ', '0.00038625  ', '0.00019391  '], time:0.16202784ms
         out_f32x4(safe): ['4.063e-05   ', '0.00038625  ', '0.00019391  '], time:0.16271973ms
         out_f32_th(per): ['4.063e-05   ', '0.00038625  ', '0.00019391  '], time:0.19028711ms
----------------------------------------------------------------------------------------------------
  out_f16x8packf32(safe): ['4.065e-05   ', '0.00038624  ', '0.00019383  '], time:0.08193207ms
         out_f16_th(per): ['4.065e-05   ', '0.00038624  ', '0.00019383  '], time:0.10132599ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
                                             S=4096, H=8192
----------------------------------------------------------------------------------------------------
  out_f16x8packf32(safe): ['0.00044656  ', '1.872e-05   ', '0.00054884  '], time:0.16337919ms
         out_f16_th(per): ['0.00044656  ', '1.872e-05   ', '0.00054884  '], time:0.18709970ms
----------------------------------------------------------------------------------------------------
                                             S=8192, H=8192
----------------------------------------------------------------------------------------------------
  out_f16x8packf32(safe): ['4.601e-05   ', '9.853e-05   ', '1.711e-05   '], time:0.32324409ms
         out_f16_th(per): ['4.601e-05   ', '9.853e-05   ', '1.711e-05   '], time:0.36632204ms
----------------------------------------------------------------------------------------------------