From 1a612c52752bad62008a10feacada688972bcde5 Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Tue, 26 Nov 2024 15:16:10 -0800
Subject: [PATCH] Fix performance issue with persistent

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 tritonbench/kernels/triton_fused_attention.py | 62 +++++++++++++++++--
 .../operators/flash_attention/operator.py     |  1 -
 2 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index 2cb663e4..62e2017f 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -502,6 +502,52 @@ def _attn_fwd_inner_ws(
         (24, 240)
     ]  # , (40, 232)] #32,240 hangs, 24, 240 works 40, 232 works
 ]
+configsTmaWSPersistent = [
+    (
+        triton.Config(
+            {
+                "BLOCK_M": BM,
+                "BLOCK_N": BN,
+                "ENABLE_TMA": enable_tma,
+                "LOOP_SCHEDULE": sched,
+                "GRID_MULTIPLE": mult,
+                "GRID_GROUP": gr,
+            },
+            num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0,
+            num_warps=w,
+            num_buffers_warp_spec=buf,
+            num_consumer_groups=grp,
+            reg_dec_producer=dec,
+            reg_inc_consumer=inc,
+        )
+        if has_warp_spec
+        else triton.Config(
+            {
+                "BLOCK_M": BM,
+                "BLOCK_N": BN,
+                "ENABLE_TMA": enable_tma,
+                "LOOP_SCHEDULE": sched,
+                "GRID_MULTIPLE": mult,
+                "GRID_GROUP": gr,
+            },
+            num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0,
+            num_warps=w,
+        )
+    )
+    for BM in [128]
+    for BN in [128]
+    for mult in [1, 2, 4, 8, 16]
+    for gr in [0, 1]
+    for sched in schedList
+    for enable_tma in [True]
+    for enable_ws in [True]
+    for w in [4]
+    for buf in [2]
+    for grp in [2]  # 2
+    for dec, inc in [
+        (24, 240)
+    ]  # , (40, 232)] #32,240 hangs, 24, 240 works 40, 232 works
+]
 
 
 def keep(conf):
@@ -1298,7 +1344,7 @@ def _attn_fwd_tma_ws(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     )
 
 
-@triton.autotune(list(filter(keep, configsTmaWS)), key=["N_CTX"])
+@triton.autotune(list(filter(keep, configsTmaWSPersistent)), key=["N_CTX"])
 @triton.jit
 def _attn_fwd_tma_ws_persistent(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     Q,
@@ -1337,6 +1383,8 @@ def _attn_fwd_tma_ws_persistent(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     ENABLE_TMA: tl.constexpr,
     LOOP_SCHEDULE: tl.constexpr,
     ENABLE_WS: tl.constexpr,
+    GRID_MULTIPLE: tl.constexpr,
+    GRID_GROUP: tl.constexpr,
 ):
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     # original grid
@@ -1353,8 +1401,12 @@ def _attn_fwd_tma_ws_persistent(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
 
     tile_idx = prog_id
     for _ in range(0, tiles_per_sm):
-        pid = tile_idx // (Z * H)
-        off_hz = tile_idx % (Z * H)  # tl.program_id(1)
+        if GRID_GROUP == 0:
+            pid = tile_idx // (Z * H)
+            off_hz = tile_idx % (Z * H)  # tl.program_id(1)
+        else:
+            pid = tile_idx % n_tile_num
+            off_hz = tile_idx // n_tile_num
         _attn_fwd_compute_ws(
             Q,
             K,
@@ -1830,7 +1882,7 @@ def grid_tma_persistent(META):
             if META["ENABLE_TMA"] == False:
                 return (
                     min(
-                        NUM_SMS,
+                        NUM_SMS * META["GRID_MULTIPLE"],
                         triton.cdiv(q.shape[2], META["BLOCK_M"])
                         * q.shape[0]
                         * q.shape[1],
@@ -1890,7 +1942,7 @@ def grid_tma_persistent(META):
             )
             return (
                 min(
-                    NUM_SMS,
+                    NUM_SMS * META["GRID_MULTIPLE"],
                     triton.cdiv(q.shape[2], META["BLOCK_M"]) * q.shape[0] * q.shape[1],
                 ),
                 1,
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 9e2cc5d6..9ec2e846 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -450,7 +450,6 @@ def tflops(
         BATCH, H, N_CTX, D_HEAD = q.shape
         flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
         tflops = 2 * flops_per_matmul
-        print("tflops, latency: ", tflops, metrics.latency)
         if self.causal:
             tflops *= 0.5
         if self.mode == BenchmarkMode.BWD:
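
For reference, below is a minimal plain-Python sketch of how the new GRID_MULTIPLE and GRID_GROUP knobs are intended to interact with the persistent launch; persistent_grid_size and map_tile are illustrative helper names used only for this sketch, not functions added by the patch.

import math


def persistent_grid_size(num_sms, grid_multiple, n_ctx, block_m, z, h):
    # Host side: the persistent launch caps the grid at a multiple of the
    # SM count instead of exactly NUM_SMS, mirroring grid_tma_persistent.
    n_tile_num = math.ceil(n_ctx / block_m)
    total_tiles = n_tile_num * z * h
    return min(num_sms * grid_multiple, total_tiles)


def map_tile(tile_idx, n_tile_num, z, h, grid_group):
    # Device side: each persistent program walks tile indices and converts
    # them to (pid, off_hz), mirroring the GRID_GROUP branch added to
    # _attn_fwd_tma_ws_persistent. grid_group == 0 walks over (Z * H) first;
    # otherwise it walks over the query-tile dimension first.
    if grid_group == 0:
        pid = tile_idx // (z * h)
        off_hz = tile_idx % (z * h)
    else:
        pid = tile_idx % n_tile_num
        off_hz = tile_idx // n_tile_num
    return pid, off_hz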