From ce74f5e7413ef6d19461383bf0aa773fa6712305 Mon Sep 17 00:00:00 2001
From: Adam Mainz
Date: Tue, 3 Dec 2024 09:45:09 -0800
Subject: [PATCH] fixing small bug caused in gemms and softmax

Summary:
gemms are failing because of a logical bug here: get_production_shapes never
forwarded the shuffle flag to the downstream loader. flash_attention does not
use this function and calls the downstream function directly, so when testing
the change for shuffled data I only tested on flash attention and didn't see
this break.

Reviewed By: danzimm

Differential Revision: D66706511

fbshipit-source-id: f89a7561a41d82b8c486bd6af097057b2413b7d2
---
 tritonbench/utils/data_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tritonbench/utils/data_utils.py b/tritonbench/utils/data_utils.py
index 904321b2..8b852b23 100644
--- a/tritonbench/utils/data_utils.py
+++ b/tritonbench/utils/data_utils.py
@@ -1,7 +1,7 @@
 from .triton_op import IS_FBCODE
 
 
-def get_production_shapes(op_name, op_type):
+def get_production_shapes(op_name, op_type, shuffle_shapes=False):
     """Gets a list of Softmax shapes for benchmarking"""
     if IS_FBCODE:
         from .fb.durin_data import productionDataLoader
@@ -9,6 +9,6 @@
     return [
         shape
         for shape in productionDataLoader.get_shapes_from_frozen_durin(
-            op_name, op_type
+            op_name, op_type, shuffle_shapes
        )
    ]
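
For context, a minimal usage sketch of the patched helper. It assumes an fbcode environment where IS_FBCODE is true and productionDataLoader is importable; the "gemm"/"mm" argument values and the load_gemm_shapes wrapper are hypothetical illustrations, not taken from the actual benchmark callers.

# Minimal sketch of a caller exercising the fixed signature. Only
# get_production_shapes and its new shuffle_shapes keyword come from this
# patch; the "gemm"/"mm" arguments and this wrapper are hypothetical.
from tritonbench.utils.data_utils import get_production_shapes


def load_gemm_shapes(shuffle: bool = False):
    # Before the patch the function took only (op_name, op_type), so a shuffle
    # flag could not be threaded through from gemm/softmax benchmarks;
    # flash_attention sidestepped this by calling the downstream loader directly.
    return get_production_shapes("gemm", "mm", shuffle_shapes=shuffle)


if __name__ == "__main__":
    for shape in load_gemm_shapes(shuffle=True):
        print(shape)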