Skip to content

Commit

Permalink
[Tutorials] Rename reference library name (#2452)
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <[email protected]>
  • Loading branch information
whitneywhtsang authored Oct 9, 2024
1 parent 76426a7 commit b3ddfca
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tutorials/03-matrix-multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def matmul(a, b, activation=""):
# We can now compare the performance of our kernel against that of cuBLAS or rocBLAS. Here we focus on square matrices,
# but feel free to arrange this script as you wish to benchmark any other matrix shape.

ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'
ref_lib = 'cuBLAS' if is_cuda() else 'oneDNN' if is_xpu() else 'rocBLAS'

configs = []
for fp8_inputs in [False, True]:
Expand Down
13 changes: 10 additions & 3 deletions python/tutorials/08-grouped-gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
import triton.language as tl


def is_cuda():
    """Return True when the active Triton driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"


@triton.autotune(
configs=[
triton.Config({
Expand Down Expand Up @@ -228,6 +232,9 @@ def torch_perf_fn(group_A, group_B):
torch.matmul(a, b)


ref_lib = 'cuBLAS' if is_cuda() else 'oneDNN'


@triton.testing.perf_report(
triton.testing.Benchmark(
# argument names to use as an x-axis for the plot
Expand All @@ -236,9 +243,9 @@ def torch_perf_fn(group_A, group_B):
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
line_vals=['cublas', 'triton'],
line_vals=[ref_lib.lower(), 'triton'],
# label name for the lines
line_names=["cuBLAS", "Triton"],
line_names=[ref_lib, "Triton"],
# line styles
styles=[('green', '-'), ('blue', '-')],
ylabel="runtime(ms)", # label name for the y-axis
Expand Down Expand Up @@ -276,7 +283,7 @@ def benchmark(N, provider):
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="xpu")

quantiles = [0.5, 0.2, 0.8]
if provider == 'cublas':
if provider == ref_lib.lower():
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(
Expand Down

0 comments on commit b3ddfca

Please sign in to comment.