From c26a51fc226cfbb051a3a230dc2b28e9f5c6be62 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 4 Dec 2024 18:30:59 -0500
Subject: [PATCH] Add option to do native sdpa

---
 tritonbench/operators/flash_attention/operator.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index dcc2daca..ade5d515 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -34,6 +34,7 @@
 import argparse
 import math
 import os
+from contextlib import nullcontext
 from itertools import chain
 
 import torch
@@ -150,6 +151,7 @@ def parse_op_args(args: List[str]):
         action="store_true",
         help="enable causal (always true on backward)",
     )
+    parser.add_argument("--native-sdpa", action="store_true", help="Use SDPA native choice.")
     parser.add_argument(
         "--additional-inputs", action="store_true", help="enable additional inputs"
     )
@@ -172,6 +174,7 @@ def __init__(
         self.D_HEAD = args.d_head
         self.N_CTX = None
         self.causal = args.causal
+        self.native_sdpa = args.native_sdpa
         # We always turn on causal for backward
         # Because Triton-Flash-V2 does not support backward with non-causal
         if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
@@ -206,7 +209,9 @@ def sdpa(
         v: torch.Tensor,
     ) -> Callable:
         def sdpa_flash_attention(q, k, v):
-            with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
+            cxt = nullcontext() if self.native_sdpa else \
+                sdpa_kernel([SDPBackend.FLASH_ATTENTION])
+            with cxt:
                 return sdpa(
                     q,
                     k,
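
Note (not part of the patch): below is a minimal, self-contained sketch of the selection logic this change introduces, assuming PyTorch 2.3+ where torch.nn.attention.sdpa_kernel and SDPBackend are available and a CUDA device is present. The helper name run_native_or_flash and the tensor shapes are illustrative, not names from the operator.

    from contextlib import nullcontext

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.nn.functional import scaled_dot_product_attention as sdpa


    def run_native_or_flash(q, k, v, native_sdpa: bool):
        # With --native-sdpa, SDPA picks its backend natively (nullcontext);
        # otherwise the FLASH_ATTENTION backend is pinned via sdpa_kernel.
        cxt = nullcontext() if native_sdpa else sdpa_kernel([SDPBackend.FLASH_ATTENTION])
        with cxt:
            return sdpa(q, k, v)


    if __name__ == "__main__":
        q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
        out = run_native_or_flash(q, k, v, native_sdpa=True)
        print(out.shape)

With the patch applied, passing --native-sdpa to the flash_attention operator makes its sdpa baseline take the nullcontext path instead of forcing the flash backend.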