Skip to content

Commit

Permalink
Add option to use native SDPA
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Dec 4, 2024
1 parent c3e6792 commit c26a51f
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tritonbench/operators/flash_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import argparse
import math
import os
from contextlib import nullcontext
from itertools import chain

import torch
Expand Down Expand Up @@ -150,6 +151,7 @@ def parse_op_args(args: List[str]):
action="store_true",
help="enable causal (always true on backward)",
)
parser.add_argument("--native-sdpa", action="store_true", help="Use SDPA native choice.")
parser.add_argument(
"--additional-inputs", action="store_true", help="enable additional inputs"
)
Expand All @@ -172,6 +174,7 @@ def __init__(
self.D_HEAD = args.d_head
self.N_CTX = None
self.causal = args.causal
self.native_sdpa = args.native_sdpa
# We always turn on causal for backward
# Because Triton-Flash-V2 does not support backward with non-causal
if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
Expand Down Expand Up @@ -206,7 +209,9 @@ def sdpa(
v: torch.Tensor,
) -> Callable:
def sdpa_flash_attention(q, k, v):
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
cxt = nullcontext if self.native_sdpa else \
sdpa_kernel([SDPBackend.FLASH_ATTENTION])
with cxt:
return sdpa(
q,
k,
Expand Down

0 comments on commit c26a51f

Please sign in to comment.