Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nightly benchmarking on Triton pytorch and triton-main versions #38

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
20 changes: 20 additions & 0 deletions .ci/tritonbench/test-nightly-h100.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash

# CI script to run Triton nightly on H100
# For now, only run launch-latency benchmark
set -x

if [ -z "${SETUP_SCRIPT}" ]; then
echo "ERROR: SETUP_SCRIPT is not set"
exit 1
fi

. "${SETUP_SCRIPT}"

# Run on Triton-pytorch
conda activate pytorch
python -m benchmarks.nightly.run

# Run on Triton-main
conda activate triton-main
python -m benchmarks.nightly.run
4 changes: 4 additions & 0 deletions benchmarks/nightly/postrun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def postrun_analysis(output_dir: str, remove_raw_data: bool=True):
""" Aggregate all benchmark json files into a single file.
"""
pass
96 changes: 96 additions & 0 deletions benchmarks/nightly/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Tritonbench nightly run
"""
import argparse
import os
import sys
from os.path import abspath, exists
import subprocess
from datetime import datetime
from pathlib import Path
import time

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

OPERATORS = {
# Case 1: Launch latency
"launch_latency": [
"--op", "launch_latency",
"--metrics", "walltime",
],
# TODO: Add compile time
# Case 2: GEMM TFLOPS
"gemm": [
"--op", "gemm",
"--only", "triton_tutorial_matmul",
"--metrics", "tflops"
],
# TODO: Add compile time
# Case 3: Flash Attention FWD_BWD TFLOPS
"flash_attention": [
"--op", "flash_attention",
"--mode", "fwd_bwd",
"--metrics", "tflops",
],
}


def setup_tritonbench_cwd():
original_dir = abspath(os.getcwd())

for tritonbench_dir in (
".",
"../../../tritonbench",
):
if exists(tritonbench_dir):
break

if exists(tritonbench_dir):
tritonbench_dir = abspath(tritonbench_dir)
os.chdir(tritonbench_dir)
sys.path.append(tritonbench_dir)
return original_dir

def setup_output_dir(create: bool=True) -> str:
output_dir = os.path.join(CURRENT_DIR, ".data",
"run_{}_{}".format(os.environ["USER"], datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M%S")))
if create:
Path(output_dir).mkdir(exist_ok=True, parents=True)
return output_dir

def run_op(op: str, output_dir: str, continue_on_fail: bool=False) -> None:
from tritonbench.utils.path_utils import REPO_PATH
from tritonbench.utils.triton_op import IS_FBCODE
assert op in OPERATORS, f"Operator {op} not in {OPERATORS.keys()}."
op_task_cmd = [] if IS_FBCODE else [sys.executable, "run.py"]
op_task_cmd.extend(OPERATORS[op])
op_task_cmd.extend(["--output", os.path.join(output_dir, f"nightly_{op}.csv")])
try:
print("[tritonbench] running command: " + " ".join(op_task_cmd))
subprocess.check_call(op_task_cmd, stdout=sys.stdout, stderr=sys.stderr, cwd=REPO_PATH)
except subprocess.CalledProcessError:
if continue_on_fail:
pass
else:
raise
except KeyboardInterrupt:
print("KeyboardInterrupt received, exiting...")
sys.exit(1)


def run():
parser = argparse.ArgumentParser()
parser.add_argument("--continue-on-fail", action="store_true", help="Continue on failed operator.")
parser.add_argument("--output-dir", default=None, help="Directory to save the results.")
args = parser.parse_args()
setup_tritonbench_cwd()
if not args.output_dir:
args.output_dir = setup_output_dir()
for op in OPERATORS:
run_op(op, args.output_dir, continue_on_fail=args.continue_on_fail)
# analyze the json files post run
from .postrun import postrun_analysis
postrun_analysis(args.output_dir)

if __name__ == "__main__":
run()
10 changes: 1 addition & 9 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse

from typing import List
from typing import List, Optional

from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS
from tritonbench.utils.triton_op import DEFAULT_RUN_ITERS, DEFAULT_WARMUP, IS_FBCODE
Expand Down Expand Up @@ -80,11 +79,6 @@ def get_parser(args=None):
action="store_true",
help="Plot the result.",
)
parser.add_argument(
"--ci",
action="store_true",
help="Run in the CI mode.",
)
parser.add_argument(
"--metrics",
default=None,
Expand Down Expand Up @@ -187,8 +181,6 @@ def get_parser(args=None):
)

args, extra_args = parser.parse_known_args(args)
if args.op and args.ci:
parser.error("cannot specify operator when in CI mode")
if not args.op and not args.op_collection:
print(
"Neither operator nor operator collection is specified. Running all operators in the default collection."
Expand Down
Loading