diff --git a/tritonbench/operators/fp8_gemm_rowwise/operator.py b/tritonbench/operators/fp8_gemm_rowwise/operator.py
index cd5ee843..30594ed5 100644
--- a/tritonbench/operators/fp8_gemm_rowwise/operator.py
+++ b/tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -48,10 +48,10 @@ def parse_args(args: List[str]) -> argparse.Namespace:


 try:
-    cutlass_fp8_row = torch.ops.fbgemm.f8f8bf16_rowwise
-    HAS_CUTLASS = True
+    cutlass_or_ck_fp8_row = torch.ops.fbgemm.f8f8bf16_rowwise
+    HAS_CUTLASS_OR_CK = True
 except ImportError:
-    HAS_CUTLASS = False
+    HAS_CUTLASS_OR_CK = False

 try:
     cublas_fp8_row = torch.ops.fbgemm.f8f8bf16_cublas
@@ -134,9 +134,11 @@ def _triton(self, xq, wq, x_scale, w_scale) -> Callable:
             no_use_persistent=self.no_use_persistent,
         )

-    @register_benchmark(enabled=HAS_CUTLASS)
-    def _cutlass(self, xq, wq, x_scale, w_scale) -> Callable:
-        return lambda: cutlass_fp8_row(
+    @register_benchmark(
+        enabled=HAS_CUTLASS_OR_CK, label="ck" if torch.version.hip else "cutlass"
+    )
+    def _cutlass_or_ck(self, xq, wq, x_scale, w_scale) -> Callable:
+        return lambda: cutlass_or_ck_fp8_row(
             xq, wq, x_scale, w_scale, use_fast_accum=self.fp8_fast_accum
         )

@@ -207,13 +209,13 @@ def plot(self):
                 line_vals=[
                     "_torch",
                     "_triton",
-                    "_cutlass",
+                    "_ck" if torch.version.hip else "_cutlass",
                     "_cublas",
                 ],  # possible values for `line_arg``
                 line_names=[
                     "Torch",
                     "Triton",
-                    "Cutlass",
+                    "CK" if torch.version.hip else "Cutlass",
                     "cuBLAS",
                 ],  # label name for the lines
                 styles=[("blue", "-"), ("green", "-"), ("yellow", "-")],  # line styles
diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py
index 43ef3655..fb2afa0b 100644
--- a/tritonbench/operators/gemm/operator.py
+++ b/tritonbench/operators/gemm/operator.py
@@ -222,7 +222,7 @@ def pt2_triton_matmul(self, a, b, bias) -> Callable:
             compiled(a, b)
         return lambda: compiled(a, b)

-    @register_benchmark()
+    @register_benchmark(enabled=not torch.version.hip)
     def pt2_cutlass_matmul(self, a, b, bias) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
diff --git a/tritonbench/operators/gemm/persistent_matmul.py b/tritonbench/operators/gemm/persistent_matmul.py
index 1ece9a73..a4394d48 100644
--- a/tritonbench/operators/gemm/persistent_matmul.py
+++ b/tritonbench/operators/gemm/persistent_matmul.py
@@ -133,13 +133,17 @@ def matmul_kernel_persistent(


 def matmul_persistent(a, b):
+    reduced_stages = 0
+    if torch.version.hip:
+        # amd hits shared memory limits with current settings
+        reduced_stages = 1
     configs = {
         torch.float8_e4m3fn: {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 8,
-            "num_stages": 4,
+            "num_stages": 4 - reduced_stages,
             "num_warps": 8,
         },
         torch.float16: {
@@ -147,7 +151,7 @@ def matmul_persistent(a, b):
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3,
+            "num_stages": 3 - reduced_stages,
             "num_warps": 8,
         },
         torch.bfloat16: {
@@ -155,7 +159,7 @@ def matmul_persistent(a, b):
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3,
+            "num_stages": 3 - reduced_stages,
             "num_warps": 8,
         },
     }