diff --git a/.gitignore b/.gitignore
index ecfa253b..fc0098f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,4 @@ __pycache__/
 .idea
 *.egg-info/
 torch_compile_debug/
+build/
diff --git a/benchmarks/flash_attention_bench/run.py b/benchmarks/flash_attention_bench/run.py
new file mode 100644
index 00000000..ac1df18d
--- /dev/null
+++ b/benchmarks/flash_attention_bench/run.py
@@ -0,0 +1,42 @@
+"""
+Run flash_attention benchmark with a single input shape:
+BATCH: 4
+
+SEQ_LEN: 16384
+
+Print tflops metrics
+"""
+
+import tritonbench
+from tritonbench.utils.parser import get_parser
+
+
+def run():
+    args = [
+        "--batch",
+        "4",
+        "--seq-len",
+        "16384",
+        "--n-heads",
+        "32",
+        "--d-head",
+        "64",
+        "--precision",
+        "fp16",
+        "--bwd",
+        "--only",
+        "triton_tutorial_flash_v2",
+        "--causal",
+        "--metrics",
+        "tflops",
+    ]
+    flash_attn_op = tritonbench.load_opbench_by_name("flash_attention")
+    parser = get_parser()
+    args, extra_args = parser.parse_known_args(args)
+    flash_attn_bench = flash_attn_op(args, extra_args)
+    flash_attn_bench.run()
+    print(flash_attn_bench.output)
+
+
+if __name__ == "__main__":
+    run()
diff --git a/benchmarks/gemm_bench/run.py b/benchmarks/gemm_bench/run.py
new file mode 100644
index 00000000..7dcc1dad
--- /dev/null
+++ b/benchmarks/gemm_bench/run.py
@@ -0,0 +1,39 @@
+"""
+Run gemm benchmark with a single input shape:
+
+M = 4096
+N = 4096
+K = 4096
+
+Print tflops metrics
+"""
+
+import tritonbench
+from tritonbench.utils.parser import get_parser
+
+
+def run():
+    args = [
+        "--m",
+        "4096",
+        "--n",
+        "4096",
+        "--k",
+        "4096",
+        "--precision",
+        "fp16",
+        "--only",
+        "triton_tutorial_matmul",
+        "--metrics",
+        "tflops",
+    ]
+    gemm_op = tritonbench.load_opbench_by_name("gemm")
+    parser = get_parser()
+    args, extra_args = parser.parse_known_args(args)
+    gemm_bench = gemm_op(args, extra_args)
+    gemm_bench.run()
+    print(gemm_bench.output)
+
+
+if __name__ == "__main__":
+    run()
diff --git a/tritonbench/operators/flash_attention/operator.py b/tritonbench/operators/flash_attention/operator.py
index 5d3bcc46..5cc4d223 100644
--- a/tritonbench/operators/flash_attention/operator.py
+++ b/tritonbench/operators/flash_attention/operator.py
@@ -34,6 +34,7 @@
 import argparse
 import math
 import os
+from contextlib import nullcontext
 from itertools import chain
 
 import torch
@@ -58,6 +59,7 @@
     attention_opt as triton_tutorial_FA2_opt,
 )
 
+
 # [Optional] flash_attn v2
 try:
     from flash_attn.flash_attn_interface import (
@@ -94,7 +96,7 @@
     from .test_fmha_utils import permute_qkv
 
     HAS_XFORMERS = True
-except (ImportError, IOError, AttributeError):
+except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False
 
 # [Optional] colfax cutlass backend
@@ -141,7 +143,7 @@
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch", type=int, default=4, help="Batch size")
-    parser.add_argument("--seq-len", type=int, default=11, help="Batch size")
+    parser.add_argument("--seq-len", type=int, default=None, help="Sequence length")
    parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
     parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
     parser.add_argument(
@@ -149,6 +151,12 @@
         "--causal",
         action="store_true",
         help="enable causal (always true on backward)",
     )
+    parser.add_argument(
+        "--native-sdpa", action="store_true", help="Use SDPA native choice."
+    )
+    parser.add_argument(
+        "--additional-inputs", action="store_true", help="enable additional inputs"
+    )
     return parser.parse_args(args)
 
@@ -168,10 +176,12 @@ def __init__(
         self.D_HEAD = args.d_head
         self.N_CTX = None
         self.causal = args.causal
+        self.native_sdpa = args.native_sdpa
         # We always turn on causal for backward
         # Because Triton-Flash-V2 does not support backward with non-causal
         if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
             self.causal = True
+        self.additional_inputs = args.additional_inputs
         self.sm_scale = 1.3
 
     @register_benchmark()
@@ -201,7 +211,12 @@ def sdpa(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
     ) -> Callable:
         def sdpa_flash_attention(q, k, v):
-            with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
+            cxt = (
+                nullcontext()
+                if self.native_sdpa
+                else sdpa_kernel([SDPBackend.FLASH_ATTENTION])
+            )
+            with cxt:
                 return sdpa(
                     q,
                     k,
@@ -480,16 +495,23 @@ def get_input_iter(self) -> Generator:
         D_HEAD = self.D_HEAD
         BATCH = self.BATCH
         H = self.H
+        SEQ_LEN_LOG2 = 7
 
         def get_ctx_vals():
-            for i in range(self.SEQ_LEN, 15):
+            if self.SEQ_LEN:
+                yield (BATCH, self.H, self.SEQ_LEN, self.D_HEAD)
+                return
+            for i in range(SEQ_LEN_LOG2, 15):
                 N_CTX = 2**i
                 # BATCH = 16384 // N_CTX
                 # H = 2048 // D_HEAD
                 yield (BATCH, H, N_CTX, D_HEAD)
 
         ctx_vals = get_ctx_vals()
-        shapes = self.__additional_example_input(ctx_vals)
+        if self.additional_inputs:
+            shapes = self.__additional_example_input(ctx_vals)
+        else:
+            shapes = ctx_vals
         requires_grad = True
         for shape in shapes:
             BATCH, H, N_CTX, D_HEAD = shape