Skip to content

Commit

Permalink
try changes from #3036
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored and whitneywhtsang committed Dec 18, 2024
1 parent d82e3ea commit 231b07a
Show file tree
Hide file tree
Showing 11 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401
from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401

if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
from triton.runtime import driver
Expand Down
8 changes: 8 additions & 0 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,14 @@ def extract_kernels(funcs):
raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")


def make_do_bench_for_autotune():

def autotuner_do_bench(*args, **kwargs):
return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)

return autotuner_do_bench


def assert_close(x, y, atol=None, rtol=None, err_msg=""):
import numpy as np
import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
for w in [8, 16, 32] \
]

tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
tune_attn_fwd = tuner(_attn_fwd)


Expand Down
1 change: 1 addition & 0 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def naive_softmax(x):
triton.Config({"threads_per_warp": 16}, num_warps=4),
],
key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
num_stages=s, num_warps=32) for s in [2, 3]
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -116,6 +117,7 @@ def matmul_kernel_with_block_pointers(
num_stages=s, num_warps=4) for s in [2]
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -109,6 +110,7 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def gelu(x):
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -122,6 +123,7 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
Expand Down Expand Up @@ -107,6 +108,7 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
num_stages=4, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def _kernel(A, B, C, #
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def mac_loop(
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def first_wave(
Expand Down Expand Up @@ -143,6 +144,7 @@ def first_wave(
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def full_tiles(
Expand Down
2 changes: 1 addition & 1 deletion python/triton/runtime/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def kernel(x_ptr, x_size, **META):
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
use_cuda_graph=use_cuda_graph)
use_cuda_graph=use_cuda_graph, do_bench=do_bench)

return decorator

Expand Down

0 comments on commit 231b07a

Please sign in to comment.