Skip to content

Commit

Permalink
safety switch for production data
Browse files Browse the repository at this point in the history
Summary:
Currently our input durin datasource was dropped so adding a justknob to control the pattern where we use the prod data or not. Will run all benchmarks against this with diff tag to make sure they dont error out too

justknob lives here https://www.internalfb.com/intern/justknobs/?name=tritonbench%2Fkillswitches

Reviewed By: xuzhao9

Differential Revision: D67358140

fbshipit-source-id: 6b9b52a38b8e87aa882cede4c3c2e095dd1b4efa
  • Loading branch information
adamomainz authored and facebook-github-bot committed Dec 18, 2024
1 parent 65badab commit 0497c5d
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tritonbench/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,7 @@ def __init__(
super().__init__(tb_args, extra_args)
gemm_args = parse_args(self.extra_args)
self.layout = gemm_args.layout
if IS_FBCODE and tb_args.production_shapes:
self.shapes = get_production_shapes(
self.name, f"{tb_args.precision}_gemm", self.tb_args.shuffle_shapes
)
elif gemm_args.input:
if gemm_args.input:
self.shapes = read_shapes_from_csv(gemm_args.input)
elif gemm_args.splitk:
self.shapes = SPLIT_K_SHAPES
Expand All @@ -158,6 +154,19 @@ def __init__(
else:
self.shapes = BUILDIN_SHAPES

if IS_FBCODE and tb_args.production_shapes:
additional_shapes = get_production_shapes(
self.name, f"{tb_args.precision}_gemm", self.tb_args.shuffle_shapes
)
if len(additional_shapes): # only append if not empty
self.shapes.append(
get_production_shapes(
self.name,
f"{tb_args.precision}_gemm",
self.tb_args.shuffle_shapes,
)
)

@register_benchmark()
def triton_tutorial_matmul(self, a, b, bias) -> Callable:
if not bias == None:
Expand Down

0 comments on commit 0497c5d

Please sign in to comment.