Backport D64976612 and D65060122
Summary:
We are migrating from pytorch/benchmark/torchbenchmark/operators to pytorch/tritonbench/tritonbench/operators.

Backport D64976612 and D65060122, fixing gemm for AMD in pytorch/tritonbench.

Reviewed By: danzimm, adamomainz

Differential Revision: D65098976

fbshipit-source-id: a60d63075d3f2719a7d91ccb8fa9c75f01b1b05c
xuzhao9 authored and facebook-github-bot committed Oct 29, 2024
1 parent 67ad80b commit 7d471dc
Showing 3 changed files with 18 additions and 12 deletions.
18 changes: 10 additions & 8 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
@@ -48,10 +48,10 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 
 
 try:
-    cutlass_fp8_row = torch.ops.fbgemm.f8f8bf16_rowwise
-    HAS_CUTLASS = True
+    cutlass_or_ck_fp8_row = torch.ops.fbgemm.f8f8bf16_rowwise
+    HAS_CUTLASS_OR_CK = True
 except ImportError:
-    HAS_CUTLASS = False
+    HAS_CUTLASS_OR_CK = False
 
 try:
     cublas_fp8_row = torch.ops.fbgemm.f8f8bf16_cublas
@@ -134,9 +134,11 @@ def _triton(self, xq, wq, x_scale, w_scale) -> Callable:
             no_use_persistent=self.no_use_persistent,
         )
 
-    @register_benchmark(enabled=HAS_CUTLASS)
-    def _cutlass(self, xq, wq, x_scale, w_scale) -> Callable:
-        return lambda: cutlass_fp8_row(
+    @register_benchmark(
+        enabled=HAS_CUTLASS_OR_CK, label="ck" if torch.version.hip else "cutlass"
+    )
+    def _cutlass_or_ck(self, xq, wq, x_scale, w_scale) -> Callable:
+        return lambda: cutlass_or_ck_fp8_row(
             xq, wq, x_scale, w_scale, use_fast_accum=self.fp8_fast_accum
         )
 
@@ -207,13 +209,13 @@ def plot(self):
                 line_vals=[
                     "_torch",
                     "_triton",
-                    "_cutlass",
+                    "_ck" if torch.version.hip else "_cutlass",
                     "_cublas",
                 ],  # possible values for `line_arg``
                 line_names=[
                     "Torch",
                     "Triton",
-                    "Cutlass",
+                    "CK" if torch.version.hip else "Cutlass",
                     "cuBLAS",
                 ],  # label name for the lines
                 styles=[("blue", "-"), ("green", "-"), ("yellow", "-")],  # line styles
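
The rename above reflects that fbgemm's `f8f8bf16_rowwise` op is backed by CUTLASS on CUDA builds and by Composable Kernel (CK) on ROCm builds, so a single benchmark entry covers both backends and only its display label changes. A minimal sketch of that backend check, assuming `fbgemm_gpu` is installed; the names `fp8_rowwise_gemm`, `HAS_FP8_ROWWISE`, and `BACKEND_LABEL` are illustrative, not from the repository:

```python
# Probe for the fbgemm fp8 rowwise GEMM op and pick a per-backend label.
# torch.version.hip is a version string on ROCm builds and None on CUDA
# builds, so it doubles as an AMD/NVIDIA switch.
import torch

try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed to register the fp8 ops)

    fp8_rowwise_gemm = torch.ops.fbgemm.f8f8bf16_rowwise
    HAS_FP8_ROWWISE = True
except (ImportError, AttributeError):
    HAS_FP8_ROWWISE = False

BACKEND_LABEL = "ck" if torch.version.hip else "cutlass"
```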
2 changes: 1 addition & 1 deletion tritonbench/operators/gemm/operator.py
@@ -222,7 +222,7 @@ def pt2_triton_matmul(self, a, b, bias) -> Callable:
         compiled(a, b)
         return lambda: compiled(a, b)
 
-    @register_benchmark()
+    @register_benchmark(enabled=not torch.version.hip)
     def pt2_cutlass_matmul(self, a, b, bias) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
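
This one-line change skips the CUTLASS-backed `torch.compile` benchmark on ROCm, where inductor's CUTLASS max-autotune backend is not usable. A hedged, self-contained sketch of the same gating idea; the body approximates what a pt2/CUTLASS matmul benchmark does (dynamo reset, inductor config patch, warm-up call) rather than reproducing the repository's implementation, and the `bias` argument is dropped for brevity:

```python
# Skip the CUTLASS compile path on ROCm, mirroring
# @register_benchmark(enabled=not torch.version.hip) from the diff above.
import torch
import torch._inductor.config as inductor_config


def maybe_pt2_cutlass_matmul(a, b):
    if torch.version.hip:
        return None  # AMD build: no CUTLASS autotune backend, so nothing to time
    torch._dynamo.reset()
    with inductor_config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="CUTLASS",  # assumption: mirrors the patched config
    ):
        compiled = torch.compile(lambda x, y: torch.mm(x, y), dynamic=False)
        compiled(a, b)  # compile and autotune once up front
    return lambda: compiled(a, b)  # the timed callable, as in the benchmark above
```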
10 changes: 7 additions & 3 deletions tritonbench/operators/gemm/persistent_matmul.py
@@ -133,29 +133,33 @@ def matmul_kernel_persistent(
 
 
 def matmul_persistent(a, b):
+    reduced_stages = 0
+    if torch.version.hip:
+        # amd hits shared memory limits with current settings
+        reduced_stages = 1
     configs = {
         torch.float8_e4m3fn: {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 8,
-            "num_stages": 4,
+            "num_stages": 4 - reduced_stages,
             "num_warps": 8,
         },
         torch.float16: {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3,
+            "num_stages": 3 - reduced_stages,
             "num_warps": 8,
         },
         torch.bfloat16: {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_N": 256,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,
-            "num_stages": 3,
+            "num_stages": 3 - reduced_stages,
             "num_warps": 8,
         },
     }
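
The `reduced_stages` tweak matters because software pipelining keeps roughly `num_stages` copies of the A and B input tiles resident in shared memory (LDS on AMD). A rough estimate under that simple `num_stages * tile_bytes` model, ignoring backend-specific pipelining details: the fp16/bf16 config stages about (128*64 + 64*256) * 2 bytes ≈ 48 KiB per stage, so dropping one stage frees roughly 48 KiB, which is what the "amd hits shared memory limits" comment is working around. The helper below is illustrative arithmetic only, not part of the benchmark:

```python
# Back-of-the-envelope shared-memory estimate for a pipelined matmul tile.
# This is an approximation, not a compiler guarantee.
def estimate_smem_bytes(block_m: int, block_n: int, block_k: int,
                        elem_bytes: int, num_stages: int) -> int:
    per_stage = (block_m * block_k + block_k * block_n) * elem_bytes
    return per_stage * num_stages


# fp16 config from the table above (BLOCK_M=128, BLOCK_N=256, BLOCK_K=64):
print(estimate_smem_bytes(128, 256, 64, 2, 3))  # 147456 bytes with 3 stages
print(estimate_smem_bytes(128, 256, 64, 2, 2))  # 98304 bytes with 2 stages
```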
