From 361dfa7105ecc8ded76803623d4e066fac1297e8 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Sun, 29 Sep 2024 13:21:01 +0200 Subject: [PATCH] Don't use `fast_flush=False` as it seems to be deprecated (#2323) Closes #2324 Note: PyTorch remove it as well: https://github.com/pytorch/pytorch/pull/135387 CI: * ~https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11015650093~ * https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11078059091 Signed-off-by: Anatoly Myachev --- .../flash_attention_fwd_benchmark.py | 8 +++----- benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 8 +++----- .../gemm_postop_addmatrix_benchmark.py | 3 +-- .../gemm_postop_gelu_benchmark.py | 3 +-- .../triton_kernels_benchmark/gemm_preop_exp_benchmark.py | 3 +-- .../triton_kernels_benchmark/gemm_splitk_benchmark.py | 5 ++--- .../triton_kernels_benchmark/gemm_streamk_benchmark.py | 5 ++--- benchmarks/triton_kernels_benchmark/prefix_sums.py | 2 +- 8 files changed, 14 insertions(+), 23 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 8c491327e6..ae49da2d0c 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -238,7 +238,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= CAUSAL, scale=sm_scale), warmup=10, rep=10, - quantiles=quantiles, fast_flush=False) + quantiles=quantiles) elif provider == 'triton': # FIXME: remove below if condition when extend attention support for Causal = True done @@ -257,8 +257,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) elif provider == 'xetla': module_name = f'flash_attn_causal_{CAUSAL}'.lower() @@ -273,8 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider): l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index ecb6d8ca71..b647de0d23 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -250,7 +250,7 @@ def benchmark(B, M, N, K, provider): if provider == 'onednn': _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, - quantiles=quantiles, fast_flush=False) + quantiles=quantiles) elif provider == 'triton': assert len(a.shape) == len(b.shape), 'Incompatible sizes' if len(a.shape) == 3: @@ -262,8 +262,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) elif provider == 'xetla': if B == 1: c = torch.empty((M, N), device='xpu', dtype=torch.float32) @@ -278,8 +277,7 @@ def benchmark(B, M, N, K, provider): xetla_fn = lambda: func(a, b, c, acc, cnt) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py index fa3f9d5ae2..08080abf96 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py @@ -273,8 +273,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py index ad2efc7e6d..788354e978 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py @@ -275,8 +275,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32)) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index 7456f2a2f1..e273cf4366 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -263,8 +263,7 @@ def benchmark(B, M, N, K, provider): torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index 2146c0434e..468393be88 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -150,15 +150,14 @@ def benchmark(M, N, K, provider): if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, - quantiles=quantiles, fast_flush=False) + quantiles=quantiles) elif provider == 'triton': c = torch.empty((M, N), device='xpu', dtype=torch.float32) triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index 50a1f2ea32..65d5070212 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -272,14 +272,13 @@ def benchmark(M, N, K, provider): if provider == 'onednn': _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, - quantiles=quantiles, fast_flush=False) + quantiles=quantiles) elif provider == 'triton': c = torch.empty((M, N), device=a.device, dtype=torch.float32) triton_fn = lambda: matmul(a, b, c) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/prefix_sums.py b/benchmarks/triton_kernels_benchmark/prefix_sums.py index f3beb0707d..bb3d2069f0 100644 --- a/benchmarks/triton_kernels_benchmark/prefix_sums.py +++ b/benchmarks/triton_kernels_benchmark/prefix_sums.py @@ -44,7 +44,7 @@ def benchmark(M, N, AXIS, provider): if provider == 'triton': triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS) - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles) else: raise NotImplementedError(f'Unsupported provider {provider}')