Skip to content

Commit

Permalink
Cherry-pick from #3026
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Dec 18, 2024
1 parent 352de9d commit d82e3ea
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,28 @@ def benchmark(B, M, N, K, provider):
elif provider == 'xetla':
if B == 1:
c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)
else:
c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32)
cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32)
name = f'gemm_shape_{B}_{M}_{K}_{N}'
# FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get
# better performance.
if (B, M, N, K) == (1, 3072, 3072, 4096):
name = 'gemm_streamk_shape_3072_4096_3072'
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(a, b, c, acc, cnt)


def xetla_func_with_acc_allocation():
    """Run the selected XeTLA kernel with a freshly allocated ``acc`` buffer.

    A new zero-filled float32 ``acc`` tensor is created on the 'xpu' device
    on every call, so that the measured work is as similar as possible to
    the triton kernel, which also allocates it on every call. The tensor is
    (M, N) when B == 1 and (B, M, N) otherwise.

    Returns:
        Whatever ``func(a, b, c, acc, cnt)`` returns.
    """
    # Only the shape depends on the batch size; device and dtype are fixed.
    acc_shape = (M, N) if B == 1 else (B, M, N)
    acc = torch.zeros(acc_shape, device='xpu', dtype=torch.float32)
    return func(a, b, c, acc, cnt)

xetla_fn = xetla_func_with_acc_allocation
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
Expand Down

0 comments on commit d82e3ea

Please sign in to comment.