Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Dec 16, 2024
1 parent 13b585c commit 5dae101
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 13 deletions.
4 changes: 4 additions & 0 deletions benchmarks/nightly/postrun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def postrun_analysis(output_dir: str, remove_raw_data: bool = True):
    """Aggregate all benchmark json files into a single file.

    Placeholder: no aggregation is implemented yet. Both arguments are
    accepted but unused, and the call returns ``None``.
    """
    return None
71 changes: 63 additions & 8 deletions benchmarks/nightly/run.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
"""
Tritonbench nightly run
"""
import argparse
import os
import sys
from os.path import abspath, exists
import subprocess
from datetime import datetime
from pathlib import Path
import time

# Absolute path of the directory that contains this script.
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

# Nightly benchmark cases. Each key is an operator name; each value is the
# list of command-line arguments that run_op() appends to the benchmark
# command for that operator.
OPERATORS = {
    # Case 1: Launch latency
    "launch_latency": [
        "--op", "launch_latency",
        "--metrics", "walltime",
    ],
    # TODO: Add compile time
    # Case 2: GEMM TFLOPS
    "gemm": [
        "--op", "gemm",
        "--only", "triton_tutorial_matmul",
        "--metrics", "tflops"
    ],
    # TODO: Add compile time
    # Case 3: Flash Attention FWD_BWD TFLOPS
    "flash_attention": [
        "--op", "flash_attention",
        "--mode", "fwd_bwd",
        "--metrics", "tflops",
    ],
}


def setup_tritonbench_cwd():
original_dir = abspath(os.getcwd())
Expand All @@ -21,21 +51,46 @@ def setup_tritonbench_cwd():
sys.path.append(tritonbench_dir)
return original_dir

def setup_output_dir(create: bool = True) -> str:
    """Build (and optionally create) the output directory for one run.

    The path has the form ``<CURRENT_DIR>/.data/run_<user>_<YYYYmmddHHMMSS>``,
    where the timestamp is the current local time.

    Args:
        create: When true, create the directory and any missing parents.
            An already-existing directory is not an error.

    Returns:
        The output directory path as a string.
    """
    # USER may be unset in some environments; fall back to a placeholder
    # rather than failing with KeyError before any benchmark has run.
    user = os.environ.get("USER", "unknown")
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    output_dir = os.path.join(CURRENT_DIR, ".data", f"run_{user}_{timestamp}")
    if create:
        Path(output_dir).mkdir(exist_ok=True, parents=True)
    return output_dir

# NOTE(review): this list looks like the removed side of the diff (the commit
# header reports 8 deletions in this file). If it were kept, it would rebind
# OPERATORS from the dict defined earlier to a list, and the dict-style
# lookups in run_op() would fail. TODO confirm and drop.
OPERATORS = [
    "launch_latency",
    "addmm",
    "gemm",
    "flash_attention",
]
def run_op(op: str, output_dir: str, continue_on_fail: bool = False) -> None:
    """Run the nightly benchmark for one operator in a child process.

    The command is built from the operator's entry in ``OPERATORS`` plus an
    ``--output`` argument pointing at ``<output_dir>/nightly_<op>.csv``, and is
    executed with ``REPO_PATH`` as the working directory.

    Args:
        op: Operator name; must be a key of ``OPERATORS``.
        output_dir: Directory that receives the per-operator result file.
        continue_on_fail: When true, a non-zero exit status of the child is
            reported on stderr and otherwise ignored; when false it is
            re-raised.

    Raises:
        AssertionError: If ``op`` is not a known operator.
        subprocess.CalledProcessError: If the child fails and
            ``continue_on_fail`` is false.
        SystemExit: With status 1 when interrupted by Ctrl-C.
    """
    from tritonbench.utils.path_utils import REPO_PATH
    from tritonbench.utils.triton_op import IS_FBCODE

    assert op in OPERATORS, f"Operator {op} not in {OPERATORS.keys()}."
    # NOTE(review): under IS_FBCODE the command starts with the operator
    # arguments themselves (no executable is prepended) -- confirm intended.
    op_task_cmd = [] if IS_FBCODE else [sys.executable, "run.py"]
    op_task_cmd.extend(OPERATORS[op])
    op_task_cmd.extend(["--output", os.path.join(output_dir, f"nightly_{op}.csv")])
    try:
        print("[tritonbench] running command: " + " ".join(op_task_cmd))
        subprocess.check_call(op_task_cmd, stdout=sys.stdout, stderr=sys.stderr, cwd=REPO_PATH)
    except subprocess.CalledProcessError as err:
        if not continue_on_fail:
            raise
        # Keep going, but leave a trace so a missing result file can be
        # attributed to the failed run instead of disappearing silently.
        print(
            f"[tritonbench] operator {op} failed with exit code {err.returncode}, continuing...",
            file=sys.stderr,
        )
    except KeyboardInterrupt:
        print("KeyboardInterrupt received, exiting...")
        sys.exit(1)


def run():
    """Entry point of the nightly run.

    Parses the command line, prepares the tritonbench working directory and
    the output directory, runs every operator in ``OPERATORS`` through
    ``run_op()``, and finally hands the output directory to
    ``postrun_analysis()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--continue-on-fail", action="store_true", help="Continue on failed operator.")
    parser.add_argument("--output-dir", default=None, help="Directory to save the results.")
    args = parser.parse_args()
    setup_tritonbench_cwd()
    if not args.output_dir:
        args.output_dir = setup_output_dir()
    else:
        # A user-supplied directory is not created anywhere else in this
        # script; make sure it exists before results are written into it.
        os.makedirs(args.output_dir, exist_ok=True)
    for op in OPERATORS:
        run_op(op, args.output_dir, continue_on_fail=args.continue_on_fail)
    # Post-process the per-operator result files of this run.
    try:
        from .postrun import postrun_analysis
    except ImportError:
        # Executed directly as a script: there is no parent package, so the
        # relative import is unavailable. Fall back to the sibling module,
        # which is importable because the script directory is on sys.path.
        from postrun import postrun_analysis
    postrun_analysis(args.output_dir)

# Start the nightly run only when this file is executed as a script.
if __name__ == "__main__":
    run()
5 changes: 0 additions & 5 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,6 @@ def get_parser(args=None):
action="store_true",
help="when true randomly shuffles the inputs before running benchmarks where possible.",
)
parser.add_argument(
"--child",
action="store_true",
help="Flag option that it is running in the child process.",
)

if IS_FBCODE:
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
Expand Down

0 comments on commit 5dae101

Please sign in to comment.