From f2963f1aa581e51159105ebec46775263c09ee3c Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 13:35:46 -0500
Subject: [PATCH] Disable gemm on persistent

---
 test/test_gpu/skip_tests_h100_pytorch.yaml |  4 ++++
 tritonbench/operators/gemm/operator.py     | 21 +++++++++++++--------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/test/test_gpu/skip_tests_h100_pytorch.yaml b/test/test_gpu/skip_tests_h100_pytorch.yaml
index aa6847d6..a75eff52 100644
--- a/test/test_gpu/skip_tests_h100_pytorch.yaml
+++ b/test/test_gpu/skip_tests_h100_pytorch.yaml
@@ -25,6 +25,10 @@ fp8_gemm:
   - triton_persistent_fp8_gemm
   - triton_tma_persistent_fp8_gemm
 fp8_gemm_rowwise:
+gemm:
+  - triton_persistent_matmul
+  - triton_tma_persistent_matmul
+  - triton_tma_persistent_cached_matmul
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 7bbae383..d32f0a7c 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -24,11 +24,16 @@
 from .kernels import matmul as kernels
 from .partition_k import matmul_partition_k
-from .persistent_matmul import (
-    matmul_persistent,
-    matmul_tma_persistent,
-    matmul_tma_persistent_cached,
-)
+try:
+    from .persistent_matmul import (
+        matmul_persistent,
+        matmul_tma_persistent,
+        matmul_tma_persistent_cached,
+    )
+    HAS_PERSISTENT = True
+except ModuleNotFoundError:
+    HAS_PERSISTENT = False
+
 from .triton_matmul import (
     matmul as triton_tutorial_matmul,
     matmul_kernel as triton_tutorial_matmul_kernel,
@@ -158,14 +163,14 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_partition_k(a, bt)
 
-    @register_benchmark()
+    @register_benchmark(enabled=HAS_PERSISTENT)
     def triton_persistent_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
             return lambda: matmul_persistent(a, b) + bias
         else:
             return lambda: matmul_persistent(a, b)
 
-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PERSISTENT)
     def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
@@ -173,7 +178,7 @@ def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
             return lambda: matmul_tma_persistent(a, b) + bias
         else:
             return lambda: matmul_tma_persistent(a, b)
 
-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PERSISTENT)
     def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
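
For readers unfamiliar with the pattern this patch applies, the sketch below shows the general idea: wrap an optional kernel import in try/except, record the result in an availability flag, and let the registration decorator skip variants whose kernels are missing rather than failing at import time. This is a minimal, self-contained illustration, not tritonbench's actual registry code; the BENCHMARKS dict, the register_benchmark body, and the top-level persistent_matmul module name are stand-ins invented for the example.

    # Minimal sketch of guarding an optional dependency and gating registration on it.
    BENCHMARKS = {}

    try:
        # Stand-in for the relative `from .persistent_matmul import ...` in the patch;
        # on installs without the persistent Triton kernels this import fails.
        from persistent_matmul import matmul_persistent  # hypothetical module name

        HAS_PERSISTENT = True
    except ModuleNotFoundError:
        HAS_PERSISTENT = False


    def register_benchmark(enabled: bool = True):
        """Register the decorated function as a benchmark only when `enabled` is True."""

        def decorator(fn):
            if enabled:
                BENCHMARKS[fn.__name__] = fn
            return fn

        return decorator


    @register_benchmark(enabled=HAS_PERSISTENT)
    def triton_persistent_matmul(a, b):
        # Never registered (and never called) when the kernel module is unavailable,
        # so the missing import cannot raise at collection time.
        return matmul_persistent(a, b)


    @register_benchmark()  # always-available baseline
    def naive_matmul(a, b):
        return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


    if __name__ == "__main__":
        # On a machine without persistent_matmul this prints only ['naive_matmul'].
        print(sorted(BENCHMARKS))

The same flag is combined with the existing IS_FBCODE check in the patch, so a variant is registered only when both the build environment and the kernel availability allow it; the yaml entries additionally skip those variants in the H100 CI test list.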