diff --git a/tritonbench/operators/gemm/operator.py b/tritonbench/operators/gemm/operator.py index e413e4a..270f133 100644 --- a/tritonbench/operators/gemm/operator.py +++ b/tritonbench/operators/gemm/operator.py @@ -143,11 +143,7 @@ def __init__( super().__init__(tb_args, extra_args) gemm_args = parse_args(self.extra_args) self.layout = gemm_args.layout - if IS_FBCODE and tb_args.production_shapes: - self.shapes = get_production_shapes( - self.name, f"{tb_args.precision}_gemm", self.tb_args.shuffle_shapes - ) - elif gemm_args.input: + if gemm_args.input: self.shapes = read_shapes_from_csv(gemm_args.input) elif gemm_args.splitk: self.shapes = SPLIT_K_SHAPES @@ -158,6 +154,19 @@ def __init__( else: self.shapes = BUILDIN_SHAPES + if IS_FBCODE and tb_args.production_shapes: + additional_shapes = get_production_shapes( + self.name, f"{tb_args.precision}_gemm", self.tb_args.shuffle_shapes + ) + if len(additional_shapes): # only append if not empty + self.shapes.append( + get_production_shapes( + self.name, + f"{tb_args.precision}_gemm", + self.tb_args.shuffle_shapes, + ) + ) + @register_benchmark() def triton_tutorial_matmul(self, a, b, bias) -> Callable: if not bias == None: