From 2e3652858459eb5171cb324dad3c88608d970f3f Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Tue, 17 Dec 2024 22:45:09 +0000 Subject: [PATCH 1/4] add int8 to addmatrix benchmark 1/? --- .../gemm_postop_addmatrix_benchmark.py | 50 ++++++++++++------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index 307100dcfe..d4b7fc535a 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -46,7 +46,8 @@ def matmul_kernel_with_block_pointers( stride_am: tl.constexpr, stride_ak: tl.constexpr, # stride_bk: tl.constexpr, stride_bn: tl.constexpr, # stride_cm: tl.constexpr, stride_cn: tl.constexpr, # - stride_dm: tl.constexpr, stride_dn: tl.constexpr, + stride_dm: tl.constexpr, stride_dn: tl.constexpr, # + ACCUMULATOR_DTYPE: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): pid = tl.program_id(axis=0) @@ -66,7 +67,7 @@ def matmul_kernel_with_block_pointers( offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE) for _ in range(0, K, BLOCK_SIZE_K): a = tl.load(a_block_ptr, boundary_check=(0, 1)) b = tl.load(b_block_ptr, boundary_check=(0, 1)) @@ -121,6 +122,7 @@ def matmul_kernel_with_block_pointers_batched( stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, # stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, # stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, + ACCUMULATOR_DTYPE: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): bid = tl.program_id(axis=0) @@ -144,7 +146,7 @@ def matmul_kernel_with_block_pointers_batched( offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE) for _ in range(0, K, BLOCK_SIZE_K): a = tl.load(a_block_ptr, boundary_check=(0, 1)) b = tl.load(b_block_ptr, boundary_check=(0, 1)) @@ -188,7 +190,8 @@ def matmul(a, b, d, c): a.stride(0), a.stride(1), a.stride(2), # b.stride(0), b.stride(1), b.stride(2), # c.stride(0), c.stride(1), c.stride(2), # - d.stride(0), d.stride(1), d.stride(2)) + d.stride(0), d.stride(1), d.stride(2), # + tl.float32 if a.dtype.is_floating_point else tl.int32) elif len(a.shape) == 2 and len(b.shape) == 2: assert a.shape[1] == b.shape[0], 'Incompatible dimensions' assert a.is_contiguous(), 'Matrix A must be contiguous' @@ -202,7 +205,8 @@ def matmul(a, b, d, c): a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # - d.stride(0), d.stride(1)) + d.stride(0), d.stride(1), # + tl.float32 if a.dtype.is_floating_point else tl.int32) else: assert False, 'Input matrixs dimensions mismatch' return c @@ -212,10 +216,10 @@ def matmul(a, b, d, c): @benchmark_suit.perf_report( benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot - x_names=['B', 'M', 'K', 'N'], + x_names=['B', 'M', 'K', 'N', 'dtype'], # different possible 
values for `x_name` - x_vals=[[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + # - [ # + x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in [torch.bfloat16, torch.int8]] + # + [[*shape, dtype] for shape in [ [1, 1, 5120, 13824], # [1, 4, 4096, 12288], # [1, 512, 8192, 8192], # @@ -236,7 +240,7 @@ def matmul(a, b, d, c): [32, 4096, 4096, 128], # [4096, 8, 128, 16384], # [4096, 8, 16384, 128] - ], + ] for dtype in [torch.bfloat16, torch.int8]], line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -250,29 +254,37 @@ def matmul(a, b, d, c): # name for the plot. Used also as a file name for saving the plot. args={}, )) -def benchmark(B, M, N, K, provider): +def benchmark(B, M, N, K, dtype, provider): + res_dtype = torch.float32 if dtype is torch.bfloat16 else torch.int32 + if dtype.is_floating_point: + rand = lambda shape, dtype : torch.rand(shape, device='xpu', dtype=dtype) + else: + rand = lambda shape, dtype : torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype) if B == 1: - a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16) - b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16) - d = torch.rand((M, N), device='xpu', dtype=torch.float32) + a = rand((M, K), dtype) + b = rand((K, N), dtype) + d = rand((M, N), res_dtype) else: - a = torch.rand((B, M, K), device='xpu', dtype=torch.bfloat16) - b = torch.rand((B, K, N), device='xpu', dtype=torch.bfloat16) - d = torch.rand((B, M, N), device='xpu', dtype=torch.float32) + a = rand((B, M, K), dtype) + b = rand((B, K, N), dtype) + d = rand((B, M, N), res_dtype) quantiles = [0.5, 0.0, 1.0] if provider == 'triton': assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: - c = torch.empty((B, M, N), device='xpu', dtype=torch.float32) + c = torch.empty((B, M, N), device='xpu', dtype=res_dtype) kernel_name = 'matmul_kernel_with_block_pointers_batched' else: assert len(a.shape) == 2, 'Expecting shape of length 2' - c = torch.empty((M, N), device='xpu', dtype=torch.float32) + c = torch.empty((M, N), device='xpu', dtype=res_dtype) kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, d, c) - torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d + # Torch does not support integer calculation in matmul + torch_device = 'xpu' if dtype.is_floating_point else 'cpu' + torch_dtype = dtype if dtype.is_floating_point else res_dtype + torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype), b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, From c7e06769bc505aeeb9d926134cb9baf338d2031c Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 18 Dec 2024 01:04:29 +0000 Subject: [PATCH 2/4] add int8 to addmatrix benchmark 2/? 
--- .../gemm_postop_addmatrix_benchmark.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index d4b7fc535a..8ff10ecc31 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -218,7 +218,7 @@ def matmul(a, b, d, c): # argument names to use as an x-axis for the plot x_names=['B', 'M', 'K', 'N', 'dtype'], # different possible values for `x_name` - x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in [torch.bfloat16, torch.int8]] + # + x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in [torch.bfloat16]] + # [[*shape, dtype] for shape in [ [1, 1, 5120, 13824], # [1, 4, 4096, 12288], # @@ -240,7 +240,7 @@ def matmul(a, b, d, c): [32, 4096, 4096, 128], # [4096, 8, 128, 16384], # [4096, 8, 16384, 128] - ] for dtype in [torch.bfloat16, torch.int8]], + ] for dtype in [torch.bfloat16]], line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -286,7 +286,9 @@ def benchmark(B, M, N, K, dtype, provider): torch_dtype = dtype if dtype.is_floating_point else res_dtype torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype), b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048], [1, 512, 8192, 32768], [4, 32768, 4096, 128]]: + # torch int8 matmul on GPU is not supported. 
only check a few int8 shapes to reduce runtime + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles, kernel_name=kernel_name) else: From 7a66ce76596a733d4c6dfcbe847ee9e1aebb59ea Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 18 Dec 2024 01:26:41 +0000 Subject: [PATCH 3/4] add int8 to addmatrix benchmark 3/3 --- .../gemm_postop_addmatrix_benchmark.py | 64 ++++++++++--------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index 8ff10ecc31..5b1c7971e0 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -121,8 +121,7 @@ def matmul_kernel_with_block_pointers_batched( stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, # stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, # stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, # - stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, - ACCUMULATOR_DTYPE: tl.constexpr, + stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, ACCUMULATOR_DTYPE: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): bid = tl.program_id(axis=0) @@ -218,29 +217,31 @@ def matmul(a, b, d, c): # argument names to use as an x-axis for the plot x_names=['B', 'M', 'K', 'N', 'dtype'], # different possible values for `x_name` - x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in [torch.bfloat16]] + # - [[*shape, dtype] for shape in [ - [1, 1, 5120, 13824], # - [1, 4, 4096, 12288], # - [1, 512, 8192, 8192], # - [1, 512, 8192, 32768], # - [1, 512, 32768, 8192], # - [1, 1024, 16384, 8192], # - [1, 1024, 28672, 8192], # - [1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark works - [1, 4096, 16384, 8192], # - [1, 8192, 16384, 1024], # - [1, 8192, 16384, 4096], # - [1, 16384, 1024, 8192], # - [1, 16384, 4096, 8192], # - [1, 16384, 8192, 1024], # - [1, 16384, 8192, 4096], # - [4, 32768, 128, 4096], # - [4, 32768, 4096, 128], # - [32, 4096, 4096, 128], # - [4096, 8, 128, 16384], # - [4096, 8, 16384, 128] - ] for dtype in [torch.bfloat16]], + x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] + for i in [1, 2, 4, 8] + for dtype in [torch.bfloat16, torch.int8]] + # + [[*shape, dtype] + for shape in [[1, 1, 5120, 13824], # + [1, 4, 4096, 12288], # + [1, 512, 8192, 8192], # + [1, 512, 8192, 32768], # + [1, 512, 32768, 8192], # + [1, 1024, 16384, 8192], # + [1, 1024, 28672, 8192], # + [1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark works + [1, 4096, 16384, 8192], # + [1, 8192, 16384, 1024], # + [1, 8192, 16384, 4096], # + [1, 16384, 1024, 8192], # + [1, 16384, 4096, 8192], # + [1, 16384, 8192, 1024], # + [1, 16384, 8192, 4096], # + [4, 32768, 128, 4096], # + [4, 32768, 4096, 128], # + [32, 4096, 4096, 128], # + [4096, 8, 128, 16384], # + [4096, 8, 16384, 128]] + for dtype in [torch.bfloat16, torch.int8]], line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -257,9 +258,9 @@ def matmul(a, b, d, c): def 
benchmark(B, M, N, K, dtype, provider): res_dtype = torch.float32 if dtype is torch.bfloat16 else torch.int32 if dtype.is_floating_point: - rand = lambda shape, dtype : torch.rand(shape, device='xpu', dtype=dtype) + rand = lambda shape, dtype: torch.rand(shape, device='xpu', dtype=dtype) else: - rand = lambda shape, dtype : torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype) + rand = lambda shape, dtype: torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype) if B == 1: a = rand((M, K), dtype) b = rand((K, N), dtype) @@ -281,12 +282,15 @@ def benchmark(B, M, N, K, dtype, provider): c = torch.empty((M, N), device='xpu', dtype=res_dtype) kernel_name = 'matmul_kernel_with_block_pointers' triton_fn = lambda: matmul(a, b, d, c) - # Torch does not support integer calculation in matmul + # Torch does not support integer calculation in matmul torch_device = 'xpu' if dtype.is_floating_point else 'cpu' torch_dtype = dtype if dtype.is_floating_point else res_dtype - torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype), b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype) + d + torch_fn = lambda: torch.matmul(a.to(device=torch_device, dtype=torch_dtype), + b.to(device=torch_device, dtype=torch_dtype)).to(device='xpu', dtype=res_dtype + ) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048], [1, 512, 8192, 32768], [4, 32768, 4096, 128]]: + if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048], + [1, 512, 8192, 32768], [4, 32768, 4096, 128]]: # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, From 164fd0079b7bb68aee83e57297957b61f8617106 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 18 Dec 2024 01:53:29 +0000 Subject: [PATCH 4/4] Add dtype to the param cols --- .github/workflows/triton-benchmarks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml index 659eedee72..c9c612ab52 100644 --- a/.github/workflows/triton-benchmarks.yml +++ b/.github/workflows/triton-benchmarks.yml @@ -235,7 +235,7 @@ jobs: cd benchmarks/triton_kernels_benchmark python gemm_postop_addmatrix_benchmark.py --reports $REPORTS source ../../scripts/capture-hw-details.sh - python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix.csv $REPORTS/gemm-postop-addmatrix-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG + python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix.csv $REPORTS/gemm-postop-addmatrix-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N,dtype" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG - name: Run Triton FA kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py') }}
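
Note (editor): for reference, the sketch below is a minimal standalone illustration of the dtype dispatch these patches add to the benchmark: random-input generation that switches between torch.rand and torch.randint, accumulator/result dtype selection (float32 for bfloat16 inputs, int32 for int8), and the widened reference matmul used because torch does not support int8 matmul on the GPU. The helper names (make_inputs, accumulator_dtype), the shapes, and the use of the CPU device here are illustrative assumptions, not part of the benchmark or the patches themselves.

import torch

def make_inputs(shape, dtype, device='cpu'):
    # torch.rand only supports floating-point dtypes, so integer inputs are
    # drawn with torch.randint over the int8 value range instead (as the
    # benchmark does for its int8 cases).
    if dtype.is_floating_point:
        return torch.rand(shape, device=device, dtype=dtype)
    return torch.randint(low=-127, high=128, size=shape, device=device, dtype=dtype)

def accumulator_dtype(dtype):
    # bfloat16 inputs accumulate in float32; int8 inputs accumulate in int32.
    return torch.float32 if dtype.is_floating_point else torch.int32

# int8 reference path: widen the inputs to the accumulator dtype before the
# matmul (the benchmark additionally runs this reference on the CPU because
# int8 matmul is unsupported on the GPU), then add the post-op matrix d.
a = make_inputs((64, 32), torch.int8)
b = make_inputs((32, 16), torch.int8)
d = make_inputs((64, 16), accumulator_dtype(torch.int8))
ref = torch.matmul(a.to(torch.int32), b.to(torch.int32)) + d
print(ref.shape, ref.dtype)  # torch.Size([64, 16]) torch.int32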