Add pt2 sdpa
xuzhao9 committed Dec 7, 2024
1 parent edc923f commit 133e727
Showing 1 changed file with 11 additions and 0 deletions.
tritonbench/operators/flash_attention/operator.py
@@ -154,6 +154,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument(
         "--native-sdpa", action="store_true", help="Use SDPA native choice."
     )
+    parser.add_argument(
+        "--pt2-sdpa", action="store_true", help="Compile SDPA with PT2."
+    )
     parser.add_argument(
         "--additional-inputs", action="store_true", help="enable additional inputs"
     )
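
Note on the flag spelling: argparse replaces hyphens with underscores when deriving the attribute name from a long option, which is how --pt2-sdpa becomes args.pt2_sdpa in __init__ below. A minimal standalone check (illustrative only, not part of the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--pt2-sdpa", action="store_true", help="Compile SDPA with PT2.")
args = parser.parse_args(["--pt2-sdpa"])
assert args.pt2_sdpa is True  # hyphens in the option name become underscores in the attribute
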
@@ -177,6 +180,7 @@ def __init__(
         self.N_CTX = None
         self.causal = args.causal
         self.native_sdpa = args.native_sdpa
+        self.pt2_sdpa = args.pt2_sdpa
         # We always turn on causal for backward
         # Because Triton-Flash-V2 does not support backward with non-causal
         if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
@@ -217,6 +221,13 @@ def sdpa_flash_attention(q, k, v):
                 else sdpa_kernel([SDPBackend.FLASH_ATTENTION])
             )
             with cxt:
+                if self.pt2_sdpa:
+                    sdpa = torch.compile(
+                        sdpa,
+                        fullgraph=True,
+                        backend="inductor",
+                        mode="max-autotune",
+                    )
                 return sdpa(
                     q,
                     k,
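
For context, the sketch below shows the behavior this commit enables, outside the benchmark harness: PyTorch's scaled_dot_product_attention wrapped in torch.compile with the Inductor backend and max-autotune, pinned to the flash-attention SDPA backend. The tensor shapes, dtype, and CUDA device are illustrative assumptions, not values taken from the benchmark.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def sdpa(q, k, v):
    # Causal attention, matching the setting the benchmark forces for backward passes.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

compiled_sdpa = torch.compile(
    sdpa,
    fullgraph=True,       # error out instead of silently splitting the graph
    backend="inductor",   # the default PT2 compiler backend
    mode="max-autotune",  # search a wider kernel-configuration space at compile time
)

# Illustrative shapes: (batch, heads, seq_len, head_dim); flash attention requires fp16/bf16 on CUDA.
q = k = v = torch.randn(4, 8, 1024, 64, device="cuda", dtype=torch.float16)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):  # restrict SDPA to the flash-attention backend
    out = compiled_sdpa(q, k, v)

The first call absorbs compilation and autotuning time, so in any benchmark setting warmup iterations matter before timing the compiled path.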
