From 8bb917fa9d25087cdd927ad6c66c6d4f439f16fe Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Fri, 11 Oct 2024 22:29:51 -0400
Subject: [PATCH] [FA] Update autotune configs on default path (#2475)

In this PR, we allow autotune configurations on the default path that are
equivalent to the one used on the advanced path, where `BLOCK_M` is 128 and
`num_warps` can be `8` or `16`. Out of the box, this gives a 4% geomean
performance improvement for FA.

Signed-off-by: Whitney Tsang
---
 .../flash_attention_fwd_benchmark.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index 0f36efb26f..2f7716a93c 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -154,10 +154,10 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
 
 configs = [
     triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w) \
-    for BM in [256] \
+    for BM in [128, 256] \
     for BN in [32, 64] \
-    for s in [3] \
-    for w in [32] \
+    for s in [3, 4] \
+    for w in [8, 16, 32] \
 ]
 
 tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
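
For reference, below is a minimal standalone sketch (not part of the patch) of the expanded autotune search space after this change. It assumes a Triton build with the Intel XPU backend, which understands the `grf_mode` option; the `_attn_fwd` kernel body and the benchmark harness are omitted.

```python
# Sketch of the autotune config space after this change (assumption: Triton with
# the Intel XPU backend is installed, so the 'grf_mode' option is meaningful).
import triton

configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large'}, num_stages=s, num_warps=w)
    for BM in [128, 256]   # 128 matches the advanced path; 256 was the only choice before
    for BN in [32, 64]
    for s in [3, 4]
    for w in [8, 16, 32]   # 8 and 16 match the advanced path; 32 was the only choice before
]

# 2 * 2 * 2 * 3 = 24 candidate configs, up from 2 before this patch
# (BLOCK_M=256, BLOCK_N in {32, 64}, num_stages=3, num_warps=32).
print(len(configs))  # 24
```

With this list, the configs equivalent to the advanced path (`BLOCK_M=128` with `num_warps` of `8` or `16`) become reachable by `triton.autotune`, which re-tunes whenever the `N_CTX` or `BLOCK_DMODEL` key changes.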