diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py index 9ec2e846..4bdb25f8 100644 --- a/tritonbench/operators/flash_attention/operator.py +++ b/tritonbench/operators/flash_attention/operator.py @@ -474,7 +474,7 @@ def get_input_iter(self) -> Generator: H = self.H def get_ctx_vals(): - for i in range(self.SEQ_LEN, self.SEQ_LEN + 1): + for i in range(self.SEQ_LEN, 15): N_CTX = 2**i # BATCH = 16384 // N_CTX # H = 2048 // D_HEAD