Don't use fast_flush=False as it seems to be deprecated (#2323)
anmyachev authored Sep 29, 2024
1 parent 77d819c commit 361dfa7
Showing 8 changed files with 14 additions and 23 deletions.
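The commit drops the fast_flush keyword from every do_bench call in the benchmark suite. Below is a minimal sketch of the updated call pattern, not part of the commit itself. It assumes benchmark_suit is this repository's triton_kernels_benchmark module, that its do_bench mirrors triton.testing.do_bench (where a non-default fast_flush value is deprecated), and that it returns the five values unpacked throughout the diff; the shapes, dtype, and quantiles are illustrative only.

    # Sketch of the post-commit call pattern (assumptions noted above).
    import torch
    import triton_kernels_benchmark as benchmark_suit  # assumed import path

    quantiles = [0.5, 0.0, 1.0]
    a = torch.randn((4096, 4096), device='xpu', dtype=torch.bfloat16)
    b = torch.randn((4096, 4096), device='xpu', dtype=torch.bfloat16)

    # Old call: do_bench(fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
    # New call: omit fast_flush and rely on the suite's default cache-flush behaviour.
    _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(
        lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles)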
@@ -238,7 +238,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
             lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
             CAUSAL, scale=sm_scale), warmup=10, rep=10,
-            quantiles=quantiles, fast_flush=False)
+            quantiles=quantiles)
 
     elif provider == 'triton':
         # FIXME: remove below if condition when extend attention support for Causal = True done
@@ -257,8 +257,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
             ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
 
     elif provider == 'xetla':
         module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -273,8 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
 
         xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
8 changes: 3 additions & 5 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -250,7 +250,7 @@ def benchmark(B, M, N, K, provider):
 
     if provider == 'onednn':
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                                 quantiles=quantiles, fast_flush=False)
+                                                                 quantiles=quantiles)
     elif provider == 'triton':
         assert len(a.shape) == len(b.shape), 'Incompatible sizes'
         if len(a.shape) == 3:
@@ -262,8 +262,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     elif provider == 'xetla':
         if B == 1:
             c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -278,8 +277,7 @@ def benchmark(B, M, N, K, provider):
         xetla_fn = lambda: func(a, b, c, acc, cnt)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

@@ -273,8 +273,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

@@ -275,8 +275,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

@@ -263,8 +263,7 @@ def benchmark(B, M, N, K, provider):
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                                 fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -150,15 +150,14 @@ def benchmark(M, N, K, provider):
 
     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                              quantiles=quantiles, fast_flush=False)
+                                                              quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

5 changes: 2 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -272,14 +272,13 @@ def benchmark(M, N, K, provider):
 
     if provider == 'onednn':
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
-                                                              quantiles=quantiles, fast_flush=False)
+                                                              quantiles=quantiles)
     elif provider == 'triton':
         c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
-                                                              fast_flush=False)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/prefix_sums.py
@@ -44,7 +44,7 @@ def benchmark(M, N, AXIS, provider):
 
     if provider == 'triton':
         triton_fn = lambda: scan_kernel[(1, )](x, BLOCK_SIZE_M=M, BLOCK_SIZE_N=N, AXIS=AXIS)
-        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, fast_flush=False)
+        _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles)
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
