Improve GEMM performance of shape 4096x8x128x16384 (#2646)
This change (a `grid` order adjustment to improve the cache hit rate) originates from #2600.
It applies to the batched GEMM kernel only.
Performance reaches ~99% of XeTLA for `4096x8x128x16384`.

![image](https://github.com/user-attachments/assets/ef7e9750-b3f7-4adc-aa66-5be704383e40)
ESI-SYD authored Nov 11, 2024
1 parent 85682e4 commit ca95a70
Showing 1 changed file with 3 additions and 3 deletions: benchmarks/triton_kernels_benchmark/gemm_benchmark.py
```diff
@@ -129,8 +129,8 @@ def matmul_kernel_with_block_pointers_batched(
         stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
-    bid = tl.program_id(axis=0)
-    pid = tl.program_id(axis=1)
+    bid = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -186,8 +186,8 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
     B = a.shape[0]
     # 1D launch kernel where each block gets its own program.
     grid = lambda META: (
-        B,
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
+        B,
     )
     matmul_kernel_with_block_pointers_batched[grid](
         a, b, c, #
```
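For intuition, here is a minimal sketch (not part of the commit) of why the reordering can help, assuming the common convention that axis 0 of the launch grid varies fastest between consecutively scheduled programs: with the batch index on the fastest axis, back-to-back programs work on different matrices, whereas with the tile index on the fastest axis they walk tiles of the same batch and can reuse cached blocks of A and B. The toy sizes `B` and `num_tiles` below are illustrative only.

```python
# Hypothetical illustration only: compare the launch orders before and after
# the change, assuming programs are scheduled with grid axis 0 varying fastest.
from itertools import product

B, num_tiles = 2, 4  # toy sizes

# Before: grid = (B, num_tiles), bid = program_id(axis=0), pid = program_id(axis=1).
# bid varies fastest, so consecutive programs belong to different batches.
old_order = [(bid, pid) for pid, bid in product(range(num_tiles), range(B))]

# After: grid = (num_tiles, B), pid = program_id(axis=0), bid = program_id(axis=1).
# pid varies fastest, so one batch's tiles run back to back and share cached data.
new_order = [(bid, pid) for bid, pid in product(range(B), range(num_tiles))]

print("old (batch, tile):", old_order)  # [(0, 0), (1, 0), (0, 1), (1, 1), ...]
print("new (batch, tile):", new_order)  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), ...]
```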
