Skip to content

Commit

Permalink
Disable gemm on persistent
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 19, 2024
1 parent 07000ee commit f2963f1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
4 changes: 4 additions & 0 deletions test/test_gpu/skip_tests_h100_pytorch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ fp8_gemm:
- triton_persistent_fp8_gemm
- triton_tma_persistent_fp8_gemm
fp8_gemm_rowwise:
gemm:
- triton_persistent_matmul
- triton_tma_persistent_matmul
- triton_tma_persistent_cached_matmul
jagged_layer_norm:
jagged_mean:
jagged_softmax:
Expand Down
21 changes: 13 additions & 8 deletions tritonbench/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@

from .kernels import matmul as kernels
from .partition_k import matmul_partition_k
from .persistent_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_tma_persistent_cached,
)
# The persistent-matmul kernels require a recent Triton build; on older
# versions the module is absent, so import defensively and gate the
# corresponding benchmarks on availability instead of failing at import time.
try:
    from .persistent_matmul import (
        matmul_persistent,
        matmul_tma_persistent,
        matmul_tma_persistent_cached,
    )

    # NOTE(review): "PRESISTENT" is a typo of "PERSISTENT"; the misspelled
    # name is kept because the @register_benchmark decorators below read it.
    HAS_PRESISTENT = True
except ModuleNotFoundError:
    HAS_PRESISTENT = False

# Correctly spelled alias for new code; always mirrors HAS_PRESISTENT.
HAS_PERSISTENT = HAS_PRESISTENT

from .triton_matmul import (
matmul as triton_tutorial_matmul,
matmul_kernel as triton_tutorial_matmul_kernel,
Expand Down Expand Up @@ -158,22 +163,22 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
else:
return lambda: matmul_partition_k(a, bt)

@register_benchmark(enabled=HAS_PRESISTENT)
def triton_persistent_matmul(self, a, b, bias) -> Callable:
    """Benchmark the persistent Triton matmul kernel.

    Returns a zero-argument callable computing ``a @ b`` (plus ``bias``
    when one is supplied). Only registered when the persistent kernels
    imported successfully (HAS_PRESISTENT).
    """
    # `is not None` instead of `not bias == None`: identity check is the
    # idiomatic null test and avoids invoking __eq__ on tensor-like types.
    if bias is not None:
        return lambda: matmul_persistent(a, b) + bias
    return lambda: matmul_persistent(a, b)

@register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
    """Benchmark the TMA persistent Triton matmul kernel.

    Returns a zero-argument callable computing ``a @ b`` (plus ``bias``
    when one is supplied). Disabled in fbcode and when the persistent
    kernels are unavailable.
    """
    # The TMA kernel expects the second operand transposed and contiguous.
    b = b.T.contiguous()
    # `is not None` instead of `not bias == None`: identity check is the
    # idiomatic null test and avoids invoking __eq__ on tensor-like types.
    if bias is not None:
        return lambda: matmul_tma_persistent(a, b) + bias
    return lambda: matmul_tma_persistent(a, b)

@register_benchmark(enabled=not IS_FBCODE)
@register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
b = b.T.contiguous()
if not bias == None:
Expand Down

0 comments on commit f2963f1

Please sign in to comment.