diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 74dd6dea..57d526e7 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -154,6 +154,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument(
         "--native-sdpa", action="store_true", help="Use SDPA native choice."
     )
+    parser.add_argument(
+        "--pt2-sdpa", action="store_true", help="Compile SDPA with PT2."
+    )
     parser.add_argument(
         "--additional-inputs", action="store_true", help="enable additional inputs"
     )
@@ -177,6 +180,7 @@ def __init__(
         self.N_CTX = None
         self.causal = args.causal
         self.native_sdpa = args.native_sdpa
+        self.pt2_sdpa = args.pt2_sdpa
         # We always turn on causal for backward
         # Because Triton-Flash-V2 does not support backward with non-causal
         if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
@@ -217,6 +221,16 @@ def sdpa_flash_attention(q, k, v):
                 else sdpa_kernel([SDPBackend.FLASH_ATTENTION])
             )
             with cxt:
-                return sdpa(
+                # Optionally compile the SDPA call with PT2 (inductor, max-autotune).
+                # Bind the callable to a new name: rebinding `sdpa` here would make
+                # it an unbound local inside this function.
+                sdpa_impl = sdpa
+                if self.pt2_sdpa:
+                    sdpa_impl = torch.compile(
+                        sdpa,
+                        fullgraph=True,
+                        backend="inductor",
+                        mode="max-autotune",
+                    )
+                return sdpa_impl(
                     q,
                     k,
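For reference, here is a minimal standalone sketch of what the new --pt2-sdpa path does outside the benchmark harness: wrap scaled_dot_product_attention in torch.compile and run it under the FLASH_ATTENTION backend, mirroring the compile settings in the diff. The import alias, tensor shapes, dtype, is_causal value, and the assumption of a CUDA device below are illustrative, not taken from the operator.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention as sdpa

# Illustrative (batch, heads, seq_len, head_dim) tensors in fp16 on CUDA.
q, k, v = (
    torch.randn(4, 16, 1024, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# Same compile settings as the diff: full-graph inductor with max-autotune.
compiled_sdpa = torch.compile(
    sdpa, fullgraph=True, backend="inductor", mode="max-autotune"
)

# Restrict dispatch to the flash-attention backend, as the operator does
# when --native-sdpa is not set.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = compiled_sdpa(q, k, v, is_causal=True)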