diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index 4ad3d8d5e5..ac45d629d9 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -305,11 +305,9 @@ def benchmark(B, M, N, K, provider): elif provider == 'xetla': if B == 1: c = torch.zeros((M, N), device='xpu', dtype=torch.float32) - acc = torch.zeros((M, N), device='xpu', dtype=torch.float32) cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32) else: c = torch.zeros((B, M, N), device='xpu', dtype=torch.float32) - acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32) cnt = torch.zeros((B, M, N), device='xpu', dtype=torch.int32) name = f'gemm_shape_{B}_{M}_{K}_{N}' # FIXME: Use gemm_streamk_benchmark.py when Triton streamk can get @@ -317,7 +315,18 @@ def benchmark(B, M, N, K, provider): if (B, M, N, K) == (1, 3072, 3072, 4096): name = 'gemm_streamk_shape_3072_4096_3072' func = getattr(xetla_kernel, name) - xetla_fn = lambda: func(a, b, c, acc, cnt) + + + def xetla_func_with_acc_allocation(): + # allocating `acc` matrix on every function call, to be as similar as + # possible to the triton kernel, which also does this on every call. + if B == 1: + acc = torch.zeros((M, N), device='xpu', dtype=torch.float32) + else: + acc = torch.zeros((B, M, N), device='xpu', dtype=torch.float32) + return func(a, b, c, acc, cnt) + + xetla_fn = xetla_func_with_acc_allocation torch_fn = lambda: torch.matmul(a, b).to(torch.float32) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')