diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 56ea865..51033b7 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -155,6 +155,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument(
         "--native-sdpa", action="store_true", help="Use SDPA native choice."
     )
+    parser.add_argument(
+        "--pt2-sdpa", action="store_true", help="Compile SDPA with PT2."
+    )
     parser.add_argument(
         "--additional-inputs", action="store_true", help="enable additional inputs"
     )
@@ -176,6 +179,7 @@ def __init__(
        self.N_CTX = None
        self.causal = args.causal
        self.native_sdpa = args.native_sdpa
+       self.pt2_sdpa = args.pt2_sdpa
        # We always turn on causal for backward
        # Because Triton-Flash-V2 does not support backward with non-causal
        if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
@@ -216,7 +220,17 @@ def sdpa_flash_attention(q, k, v):
                else sdpa_kernel([SDPBackend.FLASH_ATTENTION])
            )
            with cxt:
-               return sdpa(
+               sdpa_impl = (
+                   torch.compile(
+                       sdpa,
+                       fullgraph=True,
+                       backend="inductor",
+                       mode="max-autotune",
+                   )
+                   if self.pt2_sdpa
+                   else sdpa
+               )
+               return sdpa_impl(
                    q,
                    k,
                    v,
@@ -467,18 +481,25 @@ def causal_mask(b, h, q_idx, kv_idx):
    @register_metric()
    def tflops(
        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
+   ) -> float:
+       analytic_flops = self.flops(fn_name, example_inputs, metrics)
+       return analytic_flops / metrics.latency * 1e-9
+
+   @register_metric(x_only=True)
+   def flops(
+       self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
    ) -> float:
        q, k, v = example_inputs
        BATCH, H, N_CTX, D_HEAD = q.shape
        flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
-       tflops = 2 * flops_per_matmul
+       flops = 2 * flops_per_matmul
        if self.causal:
-           tflops *= 0.5
+           flops *= 0.5
        if self.mode == BenchmarkMode.BWD:
-           tflops *= 2.5  # 2.0(bwd) + 0.5(recompute)
+           flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
        elif self.mode == BenchmarkMode.FWD_BWD:
-           tflops *= 3.5  # 1.0(fwd) + 2.0(bwd) + 0.5(recompute)
-       return tflops / metrics.latency * 1e-9
+           flops *= 3.5  # 1.0(fwd) + 2.0(bwd) + 0.5(recompute)
+       return flops

    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
        o = fwd_fn()