From 06c28ed91cbba027606f721d41136555f87a7ff9 Mon Sep 17 00:00:00 2001 From: Adam Mainz Date: Thu, 19 Dec 2024 15:31:08 -0800 Subject: [PATCH] quick update for fp8 rowwise Summary: Currently, when production shapes are requested, rowwise uses only the production shapes and skips the default shapes. We want this to be inclusive rather than either/or. Reviewed By: danzimm Differential Revision: D67478772 fbshipit-source-id: edc1e903ab91fcdefcf880b8a892b41a8d3b39d5 --- tritonbench/operators/fp8_gemm_rowwise/operator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tritonbench/operators/fp8_gemm_rowwise/operator.py b/tritonbench/operators/fp8_gemm_rowwise/operator.py index 499a82a..53a63bc 100644 --- a/tritonbench/operators/fp8_gemm_rowwise/operator.py +++ b/tritonbench/operators/fp8_gemm_rowwise/operator.py @@ -118,14 +118,16 @@ def __init__( super().__init__(tb_args, extra_args) self.use_cuda_graphs = True addmm_args = parse_args(self.extra_args) - if hasattr(tb_args, "production_shapes") and tb_args.production_shapes: - self.shapes = get_production_shapes(self.name, "fp32_gemm") - elif addmm_args.m and addmm_args.n and addmm_args.k: + if addmm_args.m and addmm_args.n and addmm_args.k: self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)] elif addmm_args.llama: self.shapes = gemm_shapes() else: self.shapes = BUILDIN_SHAPES + if hasattr(tb_args, "production_shapes") and tb_args.production_shapes: + extras = get_production_shapes(self.name, "fp32_gemm") + if len(extras): + self.shapes.extend(extras) self.fp8_fast_accum = addmm_args.fp8_fast_accum self.use_tma = addmm_args.use_tma self.no_use_persistent = addmm_args.no_use_persistent