Add pt2_sdpa (#105)
Summary:
Add a `--pt2-sdpa` option to test compiled SDPA.

Pull Request resolved: #105

Test Plan:
```
$ python run.py --op flash_attention --batch 1 --n-heads 24 --seq-len 4608 --d-head 128 --only cudnn,sdpa,flash_v3 --metrics proton --native-sdpa --pt2-sdpa
```
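
For context, `--native-sdpa` keeps PyTorch's own SDPA backend choice, while `--pt2-sdpa` additionally wraps the SDPA call in `torch.compile` with the settings shown in the diff below. A rough standalone sketch of the compiled path this flag exercises, using the test-plan shape; the dtype/device and the `compiled_sdpa` name are illustrative assumptions, not part of the operator:

```python
import torch
import torch.nn.functional as F

# Test-plan shape: --batch 1 --n-heads 24 --seq-len 4608 --d-head 128
# (fp16 on CUDA is an assumption for illustration)
q, k, v = (
    torch.randn(1, 24, 4608, 128, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# Same compile settings the --pt2-sdpa path uses.
compiled_sdpa = torch.compile(
    F.scaled_dot_product_attention,
    fullgraph=True,
    backend="inductor",
    mode="max-autotune",
)

out = compiled_sdpa(q, k, v, is_causal=False)
```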

Time/ns:
![Time/ns](https://github.com/user-attachments/assets/d8ce093c-4aaf-4aff-b36e-3ea97f5c2c3d)

Tflops:
![TFLOPS](https://github.com/user-attachments/assets/c7e36b6f-04ae-4dd8-b6b8-1546702e1134)

Reviewed By: FindHao

Differential Revision: D66969136

Pulled By: xuzhao9

fbshipit-source-id: 40ed0223d99a10724e495cacaaeba7770adf78a6
xuzhao9 authored and facebook-github-bot committed Dec 10, 2024
1 parent e4f5305 commit d94fae4
Showing 1 changed file with 27 additions and 6 deletions.
`tritonbench/operators/flash_attention/operator.py` (27 additions, 6 deletions):
```diff
@@ -155,6 +155,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument(
         "--native-sdpa", action="store_true", help="Use SDPA native choice."
     )
+    parser.add_argument(
+        "--pt2-sdpa", action="store_true", help="Compile SDPA with PT2."
+    )
     parser.add_argument(
         "--additional-inputs", action="store_true", help="enable additional inputs"
     )
```

```diff
@@ -176,6 +179,7 @@ def __init__(
         self.N_CTX = None
         self.causal = args.causal
         self.native_sdpa = args.native_sdpa
+        self.pt2_sdpa = args.pt2_sdpa
         # We always turn on causal for backward
         # Because Triton-Flash-V2 does not support backward with non-causal
         if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
```

```diff
@@ -216,7 +220,17 @@ def sdpa_flash_attention(q, k, v):
                 else sdpa_kernel([SDPBackend.FLASH_ATTENTION])
             )
             with cxt:
-                return sdpa(
+                sdpa_impl = (
+                    torch.compile(
+                        sdpa,
+                        fullgraph=True,
+                        backend="inductor",
+                        mode="max-autotune",
+                    )
+                    if self.pt2_sdpa
+                    else sdpa
+                )
+                return sdpa_impl(
                     q,
                     k,
                     v,
```

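Since `--pt2-sdpa` only swaps in a compiled callable for `sdpa`, a quick standalone sanity check is to compare eager and compiled outputs on the same inputs. A minimal sketch, not part of this diff (shape and tolerances are illustrative):

```python
import torch
import torch.nn.functional as F

q, k, v = (
    torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

eager_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

compiled = torch.compile(
    F.scaled_dot_product_attention, fullgraph=True, mode="max-autotune"
)
compiled_out = compiled(q, k, v, is_causal=True)

# The compiled wrapper should dispatch to the same attention kernel,
# so the outputs are expected to match closely.
torch.testing.assert_close(eager_out, compiled_out, rtol=1e-3, atol=1e-3)
```
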
```diff
@@ -467,18 +481,25 @@ def causal_mask(b, h, q_idx, kv_idx):
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
+        analytic_flops = self.flops(fn_name, example_inputs, metrics)
+        return analytic_flops / metrics.latency * 1e-9
+
+    @register_metric(x_only=True)
+    def flops(
+        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
+    ) -> float:
         q, k, v = example_inputs
         BATCH, H, N_CTX, D_HEAD = q.shape
         flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
-        tflops = 2 * flops_per_matmul
+        flops = 2 * flops_per_matmul
         if self.causal:
-            tflops *= 0.5
+            flops *= 0.5
         if self.mode == BenchmarkMode.BWD:
-            tflops *= 2.5  # 2.0(bwd) + 0.5(recompute)
+            flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
         elif self.mode == BenchmarkMode.FWD_BWD:
-            tflops *= 3.5  # 1.0(fwd) + 2.0(bwd) + 0.5(recompute)
+            flops *= 3.5  # 1.0(fwd) + 2.0(bwd) + 0.5(recompute)
+        return flops
 
     def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         o = fwd_fn()
```

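The refactor separates the analytic FLOP count (`flops`, now an `x_only` metric) from throughput (`tflops`). A quick check of the model for the test-plan shape in forward, non-causal mode; the 1.0 ms latency is a made-up number, and the conversion assumes `metrics.latency` is reported in milliseconds (which is what makes `flops / latency * 1e-9` come out in TFLOPS):

```python
BATCH, H, N_CTX, D_HEAD = 1, 24, 4608, 128  # test-plan shape

# One attention matmul (Q @ K^T or P @ V) costs 2 * B * H * N^2 * D FLOPs.
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD  # ~1.30e11
flops = 2 * flops_per_matmul                                 # forward, non-causal: ~2.61e11

latency_ms = 1.0                     # hypothetical measured latency
tflops = flops / latency_ms * 1e-9   # ~261 TFLOPS at 1 ms
```
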
