包含以下内容：
- `softmax_f32_kernel` (grid-level memory fence)
- `softmax_f32x4_kernel` (grid-level memory fence)
- `softmax_f32_per_token_kernel` (per token)
- `softmax_f32x4_per_token_kernel` (per token)
- `safe_softmax_f32_per_token_kernel` (per token)
- `safe_softmax_f32x4_per_token_kernel` (per token)
- `safe_softmax_f16_f32_per_token_kernel` (per token)
- `safe_softmax_f16x2_f32_per_token_kernel` (per token)
- `safe_softmax_f16x8_pack_f32_per_token_kernel` (per token)
- `online_safe_softmax_f32_per_token_kernel` (per token, online softmax)
- `online_safe_softmax_f32x4_pack_per_token_kernel` (per token, online softmax)
- PyTorch bindings
```bash
# 仅针对 Ada 架构编译测试；若不指定 TORCH_CUDA_ARCH_LIST，则默认编译所有架构（Volta, Ampere, Ada, Hopper, ...），耗时较长
export TORCH_CUDA_ARCH_LIST=Ada
python3 softmax.py
```
输出：
----------------------------------------------------------------------------------------------------
N=16384
----------------------------------------------------------------------------------------------------
out_f32(fence): ['0.00011554 ', '1.172e-05 ', '3.789e-05 '], time:0.00707126ms
out_f32x4(fence): ['0.00011554 ', '1.172e-05 ', '3.789e-05 '], time:0.00714874ms
out_f32_th: ['0.00011554 ', '1.172e-05 ', '3.789e-05 '], time:0.00871110ms
----------------------------------------------------------------------------------------------------
S=4096, H=256
----------------------------------------------------------------------------------------------------
out_f32(per): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01259184ms
out_f32x4(per): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01004362ms
out_f32(safe): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01583433ms
out_f32(safe+online): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01357031ms
out_f32x4(safe+online): ['0.00489145 ', '0.00030952 ', '0.00112878 '], time:0.01050377ms
out_f32x4(safe): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01027584ms
out_f32_th(per): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01042914ms
----------------------------------------------------------------------------------------------------
out_f16f32(safe): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.01418757ms
out_f16x2f32(safe): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.00781608ms
out_f16x8packf32(safe): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.00523329ms
out_f16_th(per): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.00563836ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=512
----------------------------------------------------------------------------------------------------
out_f32(per): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02372313ms
out_f32x4(per): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02219534ms
out_f32(safe): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.03100491ms
out_f32(safe+online): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02549100ms
out_f32x4(safe+online): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02228165ms
out_f32x4(safe): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02230835ms
out_f32_th(per): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02294350ms
----------------------------------------------------------------------------------------------------
out_f16f32(safe): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.02967048ms
out_f16x2f32(safe): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.01563406ms
out_f16x8packf32(safe): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.01033092ms
out_f16_th(per): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.01410413ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=1024
----------------------------------------------------------------------------------------------------
out_f32(per): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.06144118ms
out_f32x4(per): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04208207ms
out_f32(safe): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.08846235ms
out_f32(safe+online): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.06275535ms
out_f32x4(safe+online): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04195666ms
out_f32x4(safe): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04199767ms
out_f32_th(per): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04214501ms
----------------------------------------------------------------------------------------------------
out_f16f32(safe): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.07461023ms
out_f16x2f32(safe): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.02805471ms
out_f16x8packf32(safe): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.02210021ms
out_f16_th(per): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.02429175ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=2048
----------------------------------------------------------------------------------------------------
out_f32x4(per): ['0.00014777 ', '0.00018938 ', '9.769e-05 '], time:0.08160353ms
out_f32x4(safe): ['0.00014777 ', '0.00018938 ', '9.769e-05 '], time:0.08181977ms
out_f32_th(per): ['0.00014777 ', '0.00018938 ', '9.769e-05 '], time:0.10212374ms
----------------------------------------------------------------------------------------------------
out_f16x2f32(safe): ['0.0001477 ', '0.00018942 ', '9.769e-05 '], time:0.07831120ms
out_f16x8packf32(safe): ['0.0001477 ', '0.00018942 ', '9.769e-05 '], time:0.04206920ms
out_f16_th(per): ['0.0001477 ', '0.00018942 ', '9.769e-05 '], time:0.05331278ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=4096
----------------------------------------------------------------------------------------------------
out_f32x4(per): ['4.063e-05 ', '0.00038625 ', '0.00019391 '], time:0.16202784ms
out_f32x4(safe): ['4.063e-05 ', '0.00038625 ', '0.00019391 '], time:0.16271973ms
out_f32_th(per): ['4.063e-05 ', '0.00038625 ', '0.00019391 '], time:0.19028711ms
----------------------------------------------------------------------------------------------------
out_f16x8packf32(safe): ['4.065e-05 ', '0.00038624 ', '0.00019383 '], time:0.08193207ms
out_f16_th(per): ['4.065e-05 ', '0.00038624 ', '0.00019383 '], time:0.10132599ms
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
S=4096, H=8192
----------------------------------------------------------------------------------------------------
out_f16x8packf32(safe): ['0.00044656 ', '1.872e-05 ', '0.00054884 '], time:0.16337919ms
out_f16_th(per): ['0.00044656 ', '1.872e-05 ', '0.00054884 '], time:0.18709970ms
----------------------------------------------------------------------------------------------------
S=8192, H=8192
----------------------------------------------------------------------------------------------------
out_f16x8packf32(safe): ['4.601e-05 ', '9.853e-05 ', '1.711e-05 '], time:0.32324409ms
out_f16_th(per): ['4.601e-05 ', '9.853e-05 ', '1.711e-05 '], time:0.36632204ms
----------------------------------------------------------------------------------------------------