Add flash attention and gemm benchmarks
xuzhao9 committed Dec 4, 2024
1 parent d044227 commit a562c2f
Showing 4 changed files with 28 additions and 307 deletions.
2 changes: 1 addition & 1 deletion benchmarks/flash_attention_bench/run.py
@@ -12,7 +12,7 @@


def run():
args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "bf16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops", "--precision", "fp16"]
args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "fp16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops"]
flash_attn_op = tritonbench.load_opbench_by_name("flash_attention")
parser = get_parser()
args, extra_args = parser.parse_known_args(args)
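For reference, the rest of this runner is truncated above. A minimal sketch of the full file, assuming the omitted lines mirror the new gemm_bench script below; the names after the parsed args (flash_attn_bench, the trailing __main__ guard) are assumptions, not confirmed by the visible hunk:

import tritonbench
from tritonbench.utils.parser import get_parser


def run():
    # fp16 is now passed exactly once; the old argument list set --precision
    # twice (bf16 first, then fp16).
    args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "fp16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops"]
    flash_attn_op = tritonbench.load_opbench_by_name("flash_attention")
    parser = get_parser()
    args, extra_args = parser.parse_known_args(args)
    # Assumed to mirror gemm_bench/run.py: construct the operator, run it,
    # and print the metrics table.
    flash_attn_bench = flash_attn_op(args, extra_args)
    flash_attn_bench.run()
    print(flash_attn_bench.output)


if __name__ == "__main__":
    run()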
25 changes: 25 additions & 0 deletions benchmarks/gemm_bench/run.py
@@ -0,0 +1,25 @@
"""
Run gemm benchmark with a single input shape:
M = 4096
N = 4096
K = 4096
Print tflops metrics
"""

import tritonbench
from tritonbench.utils.parser import get_parser


def run():
args = ["--m", "4096", "--n", "4096", "--k", "4096", "--precision", "fp16", "--only", "triton_tutorial_matmul", "--metrics", "tflops"]
gemm_op = tritonbench.load_opbench_by_name("gemm")
parser = get_parser()
args, extra_args = parser.parse_known_args(args)
gemm_bench = gemm_op(args, extra_args)
gemm_bench.run()
print(gemm_bench.output)

if __name__ == "__main__":
run()
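The new script pins a single 4096x4096x4096 shape. As a usage note, the same tritonbench calls can be reused to sweep other shapes; the sketch below is illustrative only (the shape list and the run_sweep name are not part of this commit):

import tritonbench
from tritonbench.utils.parser import get_parser


def run_sweep():
    # Illustrative only: reuse the pattern above for a few square GEMM shapes.
    gemm_op = tritonbench.load_opbench_by_name("gemm")
    for size in ("2048", "4096", "8192"):
        args = ["--m", size, "--n", size, "--k", size, "--precision", "fp16", "--only", "triton_tutorial_matmul", "--metrics", "tflops"]
        parser = get_parser()
        args, extra_args = parser.parse_known_args(args)
        gemm_bench = gemm_op(args, extra_args)
        gemm_bench.run()
        print(gemm_bench.output)


if __name__ == "__main__":
    run_sweep()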
301 changes: 0 additions & 301 deletions tritonbench/kernels/triton_fused_attention_vanilla.py

This file was deleted.

7 changes: 2 additions & 5 deletions tritonbench/operators/flash_attention/operator.py
@@ -57,9 +57,6 @@
from tritonbench.kernels.triton_fused_attention import (
    attention_opt as triton_tutorial_FA2_opt,
)
-from tritonbench.kernels.triton_fused_attention_vanilla import (
-    attention as triton_tutorial_FA2,
-)


# [Optional] flash_attn v2
@@ -257,8 +254,8 @@ def triton_tutorial_flash_v2(
        v: torch.Tensor,
    ) -> Callable:
        # base: do not enable TMA/WarpSpec/CompPipe
-        return lambda: triton_tutorial_FA2(
-            q, k, v, self.causal, self.sm_scale
+        return lambda: triton_tutorial_FA2_opt(
+            q, k, v, self.causal, self.sm_scale, "base"
        )

    @register_benchmark(enabled=HAS_CUDA_124)
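With triton_fused_attention_vanilla.py deleted, the base benchmark is built from the same attention_opt entry point, and the trailing "base" string selects the plain FlashAttention-2 path without TMA/WarpSpec/CompPipe. A minimal sketch of that call pattern outside the operator class (make_base_callable is a hypothetical helper, not part of the commit):

from tritonbench.kernels.triton_fused_attention import attention_opt as triton_tutorial_FA2_opt


def make_base_callable(q, k, v, causal, sm_scale):
    # Hypothetical helper mirroring the diff above: the trailing "base"
    # argument selects the unoptimized variant of the fused attention kernel.
    return lambda: triton_tutorial_FA2_opt(q, k, v, causal, sm_scale, "base")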
