Commit

Add format
xuzhao9 committed Dec 4, 2024
1 parent a562c2f commit c3e6792
Showing 4 changed files with 38 additions and 6 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -14,8 +14,5 @@ __pycache__/
 .ipynb_checkpoints/
 .idea
 *.egg-info/
-<<<<<<< HEAD
 torch_compile_debug/
-=======
 build/
->>>>>>> 1a642f1 (Add flash_attention_benchmark)
20 changes: 19 additions & 1 deletion benchmarks/flash_attention_bench/run.py
@@ -12,13 +12,31 @@

+
 def run():
-    args = ["--batch", "4", "--seq-len", "16384", "--n-heads", "32", "--d-head", "64", "--precision", "fp16", "--bwd", "--only", "triton_tutorial_flash_v2", "--causal", "--metrics", "tflops"]
+    args = [
+        "--batch",
+        "4",
+        "--seq-len",
+        "16384",
+        "--n-heads",
+        "32",
+        "--d-head",
+        "64",
+        "--precision",
+        "fp16",
+        "--bwd",
+        "--only",
+        "triton_tutorial_flash_v2",
+        "--causal",
+        "--metrics",
+        "tflops",
+    ]
     flash_attn_op = tritonbench.load_opbench_by_name("flash_attention")
     parser = get_parser()
     args, extra_args = parser.parse_known_args(args)
     flash_attn_bench = flash_attn_op(args, extra_args)
     flash_attn_bench.run()
     print(flash_attn_bench.output)


 if __name__ == "__main__":
     run()
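For readers adapting this script, a minimal sketch of how the same entry point could be reused to sweep sequence lengths. The function name and the specific lengths are illustrative, and the import path for get_parser is an assumption, not shown in the diff:

# Hypothetical variation on benchmarks/flash_attention_bench/run.py:
# sweep --seq-len while keeping the other flags from the diff above.
import tritonbench
from tritonbench.utils.parser import get_parser  # assumed import path

def run_seq_len_sweep():
    base_args = [
        "--batch", "4",
        "--n-heads", "32",
        "--d-head", "64",
        "--precision", "fp16",
        "--bwd",
        "--only", "triton_tutorial_flash_v2",
        "--causal",
        "--metrics", "tflops",
    ]
    flash_attn_op = tritonbench.load_opbench_by_name("flash_attention")
    for seq_len in ("4096", "8192", "16384"):  # illustrative lengths
        args, extra_args = get_parser().parse_known_args(base_args + ["--seq-len", seq_len])
        bench = flash_attn_op(args, extra_args)
        bench.run()
        print(seq_len, bench.output)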
16 changes: 15 additions & 1 deletion benchmarks/gemm_bench/run.py
@@ -13,13 +13,27 @@

+
 def run():
-    args = ["--m", "4096", "--n", "4096", "--k", "4096", "--precision", "fp16", "--only", "triton_tutorial_matmul", "--metrics", "tflops"]
+    args = [
+        "--m",
+        "4096",
+        "--n",
+        "4096",
+        "--k",
+        "4096",
+        "--precision",
+        "fp16",
+        "--only",
+        "triton_tutorial_matmul",
+        "--metrics",
+        "tflops",
+    ]
     gemm_op = tritonbench.load_opbench_by_name("gemm")
     parser = get_parser()
     args, extra_args = parser.parse_known_args(args)
     gemm_bench = gemm_op(args, extra_args)
     gemm_bench.run()
     print(gemm_bench.output)


 if __name__ == "__main__":
     run()
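Similarly, a sketch of sweeping GEMM problem sizes through the same interface. The function name and sizes are illustrative, and the get_parser import path is assumed as above:

# Hypothetical variation on benchmarks/gemm_bench/run.py:
# benchmark several square problem sizes with the flags from the diff above.
import tritonbench
from tritonbench.utils.parser import get_parser  # assumed import path

def run_gemm_size_sweep():
    gemm_op = tritonbench.load_opbench_by_name("gemm")
    for size in ("2048", "4096", "8192"):  # illustrative sizes
        args = [
            "--m", size,
            "--n", size,
            "--k", size,
            "--precision", "fp16",
            "--only", "triton_tutorial_matmul",
            "--metrics", "tflops",
        ]
        parsed_args, extra_args = get_parser().parse_known_args(args)
        bench = gemm_op(parsed_args, extra_args)
        bench.run()
        print(size, bench.output)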
5 changes: 4 additions & 1 deletion tritonbench/operators/flash_attention/operator.py
@@ -150,7 +150,9 @@ def parse_op_args(args: List[str]):
         action="store_true",
         help="enable causal (always true on backward)",
     )
-    parser.add_argument("--additional-inputs", action="store_true", help="enable additional inputs")
+    parser.add_argument(
+        "--additional-inputs", action="store_true", help="enable additional inputs"
+    )
     return parser.parse_args(args)
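The reformatted --additional-inputs flag uses argparse's standard boolean-flag pattern; a self-contained sketch of how it behaves:

# Standalone illustration of the store_true pattern used above.
import argparse
from typing import List

def parse_flag(args: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--additional-inputs", action="store_true", help="enable additional inputs"
    )
    return parser.parse_args(args)

print(parse_flag(["--additional-inputs"]).additional_inputs)  # True
print(parse_flag([]).additional_inputs)                       # False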


@@ -481,6 +483,7 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:

     def get_input_iter(self) -> Generator:
         import math
+
         D_HEAD = self.D_HEAD
         BATCH = self.BATCH
         H = self.H
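For context, get_input_iter yields benchmark inputs sized by the attributes above. A minimal sketch of that generator pattern, with shapes and torch calls as illustrative assumptions rather than the operator's actual implementation:

# Hypothetical input generator shaped like the flash_attention benchmark's
# CLI defaults (--batch 4 --n-heads 32 --seq-len 16384 --d-head 64, fp16).
from typing import Generator
import torch

def make_input_iter(BATCH: int = 4, H: int = 32, D_HEAD: int = 64) -> Generator:
    for N_CTX in (1024, 4096, 16384):  # illustrative sequence lengths
        q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch.float16, device="cuda")
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        yield (q, k, v)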
