Skip to content

Commit

Permalink
adding param to shuffle production data shapes
Browse files Browse the repository at this point in the history
Summary:
TSIA (title says it all). Right now this is only done for production shapes.

Noticed that we sometimes error out and do not run all the shapes. Since we run multiple times a day, randomly shuffling the shapes and aggregating over the day will produce a more stable output.

Reviewed By: danzimm, xuzhao9

Differential Revision: D66519495

fbshipit-source-id: 56993a8bb196174e05c4224bd386190c18883603
  • Loading branch information
adamomainz authored and facebook-github-bot committed Nov 26, 2024
1 parent 0ca9f40 commit c666f87
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tritonbench/operators/flash_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def __additional_example_input(self, standard_shapes: Generator) -> Generator:
shapes = chain(
shapes,
productionDataLoader.get_shapes_from_frozen_durin(
self.name, "attention"
self.name, "attention", shuffle_shapes=self.tb_args.shuffle_shapes
),
)
return shapes
Expand Down
4 changes: 3 additions & 1 deletion tritonbench/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def __init__(
gemm_args = parse_args(self.extra_args)
self.layout = gemm_args.layout
if IS_FBCODE and tb_args.production_shapes:
self.shapes = get_production_shapes(self.name, f"{tb_args.precision}_gemm")
self.shapes = get_production_shapes(
self.name, f"{tb_args.precision}_gemm", self.tb_args.shuffle_shapes
)
elif gemm_args.input:
self.shapes = read_shapes_from_csv(gemm_args.input)
elif gemm_args.splitk:
Expand Down
4 changes: 3 additions & 1 deletion tritonbench/operators/softmax/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def get_input_iter(self):
M = 4096
shapes = [(M, 128 * i) for i in range(2, 100)]
if IS_FBCODE and self.tb_args.production_shapes:
shapes = get_production_shapes(self.name, "softmax")
shapes = get_production_shapes(
self.name, "softmax", self.tb_args.shuffle_shapes
)
for M, N in shapes:
yield (torch.randn([M, N], dtype=self.dtype, device=self.device),)

Expand Down
5 changes: 5 additions & 0 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ def get_parser(args=None):
action="store_true",
help="bypass and continue on operator failure.",
)
parser.add_argument(
"--shuffle-shapes",
action="store_true",
help="when true randomly shuffles the inputs before running benchmarks where possible.",
)

if IS_FBCODE:
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
Expand Down

0 comments on commit c666f87

Please sign in to comment.