Revert "try changes from #3036"
This reverts commit 2a4b818.
anmyachev committed Dec 18, 2024
1 parent 2a4b818 · commit 0d66c8e
Showing 11 changed files with 3 additions and 23 deletions.
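
In short, the revert removes the make_do_bench_for_autotune helper from the benchmark suite and drops the do_bench= argument it supplied to triton.autotune at every call site, so the autotuner presumably falls back to its built-in timing. Below is a minimal sketch of the reverted pattern, reconstructed from the hunks that follow; the benchmark_suit import alias and the configs/key values are illustrative assumptions, not taken from any single file.

import triton
import triton_kernels_benchmark as benchmark_suit  # assumed alias, matching the call sites below

# Helper removed by this revert: wraps the suite's do_bench with a short
# warmup/repeat budget so autotuning spends less time per config.
def make_do_bench_for_autotune():

    def autotuner_do_bench(*args, **kwargs):
        return benchmark_suit.do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)

    return autotuner_do_bench

# Pre-revert call sites (removed by this commit):
#   @triton.autotune(configs, key=['M', 'N', 'K'],
#                    do_bench=benchmark_suit.make_do_bench_for_autotune())
# Post-revert call sites:
#   @triton.autotune(configs, key=['M', 'N', 'K'])
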
2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,4 +1,4 @@
-from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401
+from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401

if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
from triton.runtime import driver
8 changes: 0 additions & 8 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -237,14 +237,6 @@ def extract_kernels(funcs):
raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")


-def make_do_bench_for_autotune():
-
-    def autotuner_do_bench(*args, **kwargs):
-        return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)
-
-    return autotuner_do_bench


def assert_close(x, y, atol=None, rtol=None, err_msg=""):
import numpy as np
import torch
@@ -164,7 +164,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
for w in [8, 16, 32] \
]

-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
tune_attn_fwd = tuner(_attn_fwd)


1 change: 0 additions & 1 deletion benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -50,7 +50,6 @@ def naive_softmax(x):
triton.Config({"threads_per_warp": 16}, num_warps=4),
],
key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
2 changes: 0 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -43,7 +43,6 @@
num_stages=s, num_warps=32) for s in [2, 3]
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -117,7 +116,6 @@ def matmul_kernel_with_block_pointers(
num_stages=s, num_warps=4) for s in [2]
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -35,7 +35,6 @@
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -110,7 +109,6 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -54,7 +54,6 @@ def gelu(x):
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -123,7 +122,6 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -36,7 +36,6 @@
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -108,7 +107,6 @@ def matmul_kernel_with_block_pointers(
num_stages=2, num_warps=4),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -15,7 +15,6 @@
num_stages=4, num_warps=32),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def _kernel(A, B, C, #
2 changes: 0 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -107,7 +107,6 @@ def mac_loop(
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def first_wave(
@@ -144,7 +143,6 @@ def first_wave(
num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
-do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def full_tiles(
2 changes: 1 addition & 1 deletion python/triton/runtime/autotuner.py
@@ -357,7 +357,7 @@ def kernel(x_ptr, x_size, **META):
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-use_cuda_graph=use_cuda_graph, do_bench=do_bench)
+use_cuda_graph=use_cuda_graph)

return decorator

