Fix performance issue with persistent
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Nov 26, 2024
1 parent 7747e0c commit 1a612c5
Showing 2 changed files with 57 additions and 6 deletions.
62 changes: 57 additions & 5 deletions tritonbench/kernels/triton_fused_attention.py
@@ -502,6 +502,52 @@ def _attn_fwd_inner_ws(
(24, 240)
] # , (40, 232)] #32,240 hangs, 24, 240 works 40, 232 works
]
+configsTmaWSPersistent = [
+    (
+        triton.Config(
+            {
+                "BLOCK_M": BM,
+                "BLOCK_N": BN,
+                "ENABLE_TMA": enable_tma,
+                "LOOP_SCHEDULE": sched,
+                "GRID_MULTIPLE": mult,
+                "GRID_GROUP": gr,
+            },
+            num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0,
+            num_warps=w,
+            num_buffers_warp_spec=buf,
+            num_consumer_groups=grp,
+            reg_dec_producer=dec,
+            reg_inc_consumer=inc,
+        )
+        if has_warp_spec
+        else triton.Config(
+            {
+                "BLOCK_M": BM,
+                "BLOCK_N": BN,
+                "ENABLE_TMA": enable_tma,
+                "LOOP_SCHEDULE": sched,
+                "GRID_MULTIPLE": mult,
+                "GRID_GROUP": gr,
+            },
+            num_stages=2 if sched == "FA_firstDot" or sched == "FA_secondDot" else 0,
+            num_warps=w,
+        )
+    )
+    for BM in [128]
+    for BN in [128]
+    for mult in [1, 2, 4, 8, 16]
+    for gr in [0, 1]
+    for sched in schedList
+    for enable_tma in [True]
+    for enable_ws in [True]
+    for w in [4]
+    for buf in [2]
+    for grp in [2]  # 2
+    for dec, inc in [
+        (24, 240)
+    ]  # , (40, 232)] #32,240 hangs, 24, 240 works 40, 232 works
+]
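For context, the new configsTmaWSPersistent sweep adds two knobs on top of the existing TMA/warp-specialization parameters: GRID_MULTIPLE (how far the persistent grid may over-subscribe the SMs) and GRID_GROUP (how a persistent program orders its tiles). A minimal sketch of how large the resulting autotuning sweep is, assuming a stand-in schedList of two schedules (the real list is defined earlier in the file):

from itertools import product

# Stand-in schedule list for illustration; the real schedList is defined
# earlier in triton_fused_attention.py.
schedList = ["FA_firstDot", "FA_secondDot"]

# Mirror the varying loop clauses of configsTmaWSPersistent to see how
# large the autotuning sweep is for the persistent kernel.
sweep = list(
    product(
        [128],             # BLOCK_M
        [128],             # BLOCK_N
        [1, 2, 4, 8, 16],  # GRID_MULTIPLE
        [0, 1],            # GRID_GROUP
        schedList,         # LOOP_SCHEDULE
    )
)
print(len(sweep))  # 5 * 2 * len(schedList) = 20 candidate configs

Only GRID_MULTIPLE, GRID_GROUP, and the loop schedule vary; the remaining loop variables in the comprehension are pinned to a single value each.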


def keep(conf):
@@ -1298,7 +1344,7 @@ def _attn_fwd_tma_ws( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
)


-@triton.autotune(list(filter(keep, configsTmaWS)), key=["N_CTX"])
+@triton.autotune(list(filter(keep, configsTmaWSPersistent)), key=["N_CTX"])
@triton.jit
def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
Q,
@@ -1337,6 +1383,8 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
ENABLE_TMA: tl.constexpr,
LOOP_SCHEDULE: tl.constexpr,
ENABLE_WS: tl.constexpr,
+    GRID_MULTIPLE: tl.constexpr,
+    GRID_GROUP: tl.constexpr,
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
# original grid
@@ -1353,8 +1401,12 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #

tile_idx = prog_id
for _ in range(0, tiles_per_sm):
-        pid = tile_idx // (Z * H)
-        off_hz = tile_idx % (Z * H)  # tl.program_id(1)
+        if GRID_GROUP == 0:
+            pid = tile_idx // (Z * H)
+            off_hz = tile_idx % (Z * H)  # tl.program_id(1)
+        else:
+            pid = tile_idx % n_tile_num
+            off_hz = tile_idx // n_tile_num
_attn_fwd_compute_ws(
Q,
K,
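The new GRID_GROUP switch controls how a persistent program walks the flattened tile space of size n_tile_num * Z * H: with GRID_GROUP == 0 (the original order) consecutive tile indices keep the same query-tile block and sweep over (batch, head) pairs, while with GRID_GROUP == 1 they keep the same (batch, head) pair and sweep over query-tile blocks. A minimal sketch of the two orderings, using small hypothetical values for n_tile_num and Z * H:

# Hypothetical sizes for illustration only:
n_tile_num = 4  # stands in for triton.cdiv(N_CTX, BLOCK_M)
ZH = 3          # stands in for Z * H

def map_tile(tile_idx, grid_group):
    if grid_group == 0:
        # original order: consecutive tile_idx keep the same pid (query-tile block)
        pid = tile_idx // ZH
        off_hz = tile_idx % ZH
    else:
        # new order: consecutive tile_idx keep the same off_hz (batch, head)
        pid = tile_idx % n_tile_num
        off_hz = tile_idx // n_tile_num
    return pid, off_hz

for g in (0, 1):
    print(g, [map_tile(t, g) for t in range(n_tile_num * ZH)])

Which ordering is faster is left to the autotuner via the GRID_GROUP config key.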
@@ -1830,7 +1882,7 @@ def grid_tma_persistent(META):
if META["ENABLE_TMA"] == False:
return (
min(
-                NUM_SMS,
+                NUM_SMS * META["GRID_MULTIPLE"],
triton.cdiv(q.shape[2], META["BLOCK_M"])
* q.shape[0]
* q.shape[1],
@@ -1890,7 +1942,7 @@ def grid_tma_persistent(META):
)
return (
min(
-            NUM_SMS,
+            NUM_SMS * META["GRID_MULTIPLE"],
triton.cdiv(q.shape[2], META["BLOCK_M"]) * q.shape[0] * q.shape[1],
),
1,
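The grid changes above let the persistent launch over-subscribe the SMs: the first grid dimension is now capped at NUM_SMS * GRID_MULTIPLE rather than NUM_SMS (still bounded by the total tile count), so each persistent program loops over fewer tiles. A rough sketch of the effect, using hypothetical values; the kernel derives the real SM count from the device and the tile count from q.shape, and its tiles_per_sm computation (not shown in this diff) is approximated here with a ceiling division:

import math

NUM_SMS = 132       # hypothetical SM count (e.g., an H100-class GPU)
total_tiles = 4096  # hypothetical cdiv(N_CTX, BLOCK_M) * Z * H

for grid_multiple in (1, 2, 4, 8, 16):
    num_programs = min(NUM_SMS * grid_multiple, total_tiles)   # first grid dimension
    tiles_per_program = math.ceil(total_tiles / num_programs)  # loop trips per persistent program
    print(grid_multiple, num_programs, tiles_per_program)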
1 change: 0 additions & 1 deletion tritonbench/operators/flash_attention/operator.py
@@ -450,7 +450,6 @@ def tflops(
BATCH, H, N_CTX, D_HEAD = q.shape
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
tflops = 2 * flops_per_matmul
print("tflops, latency: ", tflops, metrics.latency)
if self.causal:
tflops *= 0.5
if self.mode == BenchmarkMode.BWD:
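For reference, the tflops metric in this hunk counts the two GEMMs of the attention forward pass (Q @ K^T and P @ V), each costing 2 * BATCH * H * N_CTX * N_CTX * D_HEAD flops, and halves the total for causal masking. A quick numeric check with hypothetical shapes:

# Hypothetical shapes for illustration.
BATCH, H, N_CTX, D_HEAD = 4, 32, 4096, 64

flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD  # one N_CTX x N_CTX x D_HEAD GEMM per (batch, head)
total_flops = 2 * flops_per_matmul                           # Q @ K^T and P @ V
causal_flops = 0.5 * total_flops                             # causal masking skips roughly half the work
print(total_flops / 1e12, causal_flops / 1e12)               # in TFLOPs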
