[metrics] Add proton profiling #102

Closed
wants to merge 8 commits into from
2 changes: 2 additions & 0 deletions .gitignore
@@ -16,3 +16,5 @@ __pycache__/
*.egg-info/
torch_compile_debug/
build/
/*.csv
*.hatchet
1 change: 1 addition & 0 deletions tritonbench/components/do_bench/__init__.py
@@ -0,0 +1 @@
from .run import do_bench_wrapper
34 changes: 34 additions & 0 deletions tritonbench/components/do_bench/run.py
@@ -0,0 +1,34 @@
import torch
import triton


def do_bench_wrapper(
    fn,
    warmup,
    rep,
    grad_to_none,
    use_cuda_graphs: bool = False,
    bypass_fail: bool = False,
):
    """Wrapper around triton's do_bench that returns the median latency."""
    if use_cuda_graphs:
        with torch.cuda.stream(torch.cuda.Stream()):
            return triton.testing.do_bench_cudagraph(
                fn,
                rep=rep,
                return_mode="median",
                grad_to_none=grad_to_none,
            )
    else:
        try:
            return triton.testing.do_bench(
                fn,
                warmup=warmup,
                rep=rep,
                return_mode="median",
                grad_to_none=grad_to_none,
            )
        except Exception as e:
            if not bypass_fail:
                raise e
            return None
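
For context, a minimal usage sketch of the new wrapper (the tensor size, warmup/rep values, and the matmul callable are illustrative only; this mirrors the call added to triton_op.py below):

import torch
from tritonbench.components.do_bench import do_bench_wrapper

x = torch.randn(4096, 4096, device="cuda")

# Returns the median latency in milliseconds; with bypass_fail=True a failing
# callable yields None instead of raising.
latency = do_bench_wrapper(
    lambda: torch.mm(x, x),
    warmup=25,
    rep=100,
    grad_to_none=None,
    use_cuda_graphs=False,
    bypass_fail=True,
)
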
1 change: 1 addition & 0 deletions tritonbench/components/proton/__init__.py
@@ -0,0 +1 @@
from .trace import proton_trace
25 changes: 25 additions & 0 deletions tritonbench/components/proton/trace.py
@@ -0,0 +1,25 @@
from typing import Callable, Optional

import triton.profiler as proton


def proton_trace(
    session_id: int,
    scope_name: str,
    fn: Callable,
    warmup: int,
    flops: Optional[int] = None,
    bytes: Optional[int] = None,
):
    # warmup
    for _ in range(warmup):
        fn()
    metrics_dict = {}
    if flops:
        metrics_dict["flops"] = flops
    if bytes:
        metrics_dict["bytes"] = bytes
    proton.activate(session_id)
    with proton.scope(scope_name, metrics=metrics_dict):
        fn()
    proton.deactivate(session_id)
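
A sketch of how proton_trace is driven end to end, based on the session handling added to triton_op.py below (the scope name and the flops/bytes numbers are placeholders):

import triton.profiler as proton
from tritonbench.components.proton import proton_trace

def fn():
    ...  # the benchmarked callable, e.g. a compiled kernel launch

session_id = proton.start()      # profiling starts active ...
proton.deactivate(session_id)    # ... so pause it until the scoped region

# Warms up fn, then records one profiled call under the named scope,
# attaching flops/bytes so proton can report derived metrics.
proton_trace(session_id, "gemm_fwd", fn, warmup=25, flops=2 * 4096**3, bytes=64 * 1024**2)

proton.finalize()                # write out the profile (the *.hatchet file ignored above)
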
73 changes: 48 additions & 25 deletions tritonbench/utils/triton_op.py
@@ -23,6 +23,7 @@
import torch
import triton

from tritonbench.components.do_bench import do_bench_wrapper
from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
from tritonbench.utils.env_utils import (
    apply_precision,
@@ -706,15 +707,26 @@ def run(
"""Benchmarking the operator and returning its metrics."""
metrics = []
try:
if "proton" in self.required_metrics:
import triton.profiler as proton

self._proton_session_id = proton.start()
proton.enter_scope(f"tritonbench_run_op_{self.name}")
proton.deactivate(self._proton_session_id)
input_id_range = range(self._input_id, self._input_id + self._num_inputs)
if tqdm is not None:
input_id_range = tqdm(input_id_range)
if self._input_id:
for _dryrun_input_id in range(self._input_id):
self.example_inputs = self.get_example_inputs()
for input_id in input_id_range:
self._cur_input_id = input_id
self.example_inputs = self.get_example_inputs()
x_val = self.get_x_val(self.example_inputs)
if "proton" in self.required_metrics:
proton.activate(self._proton_session_id)
proton.enter_scope(f"x_val_{x_val}")
proton.deactivate(self._proton_session_id)
self._cur_input_id = input_id
if self.example_inputs is None:
logger.warn(
f"The input generator get_input_iter() has depleted at id {input_id}. Available number of "
Expand All @@ -733,7 +745,6 @@ def run(
self.baseline_fn = None
self.baseline_metrics = None
self._op_flops = {}
x_val = self.get_x_val(self.example_inputs)
if self._only:
benchmarks = self._only
else:
Expand Down Expand Up @@ -774,8 +785,15 @@ def _reduce_benchmarks(acc, bm_name: str):
                    _reduce_benchmarks, benchmarks, {}
                )
                metrics.append((x_val, y_vals))
                del self.example_inputs
                gc.collect()
                del self.example_inputs  # save some memory
                if "proton" in self.required_metrics:
                    proton.activate(self._proton_session_id)
                    proton.exit_scope()
                    proton.deactivate(self._proton_session_id)
            if "proton" in self.required_metrics:
                proton.activate(self._proton_session_id)
                proton.exit_scope()
                proton.finalize()
        except (KeyboardInterrupt, Exception):
            logger.warning(
                "Caught exception, terminating early with partial results",
@@ -968,27 +986,14 @@ def _init_extra_metrics() -> Dict[str, Any]:
if {"latency", "tflops", "speedup", "compile_time"} & set(
self.required_metrics
):
if self.use_cuda_graphs:
with torch.cuda.stream(torch.cuda.Stream()):
metrics.latency = triton.testing.do_bench_cudagraph(
fn,
rep=rep,
return_mode="median",
grad_to_none=self.get_grad_to_none(self.example_inputs),
)
else:
try:
metrics.latency = triton.testing.do_bench(
fn,
warmup=warmup,
rep=rep,
return_mode="median",
grad_to_none=self.get_grad_to_none(self.example_inputs),
)
except Exception as e:
if not self.tb_args.bypass_fail:
raise e
metrics.latency = None
metrics.latency = do_bench_wrapper(
fn,
warmup,
rep,
grad_to_none=self.get_grad_to_none(self.example_inputs),
use_cuda_graphs=self.use_cuda_graphs,
bypass_fail=self.tb_args.bypass_fail,
)
if {
"gpu_peak_mem",
"gpu_mem_footprint_compression_ratio",
Expand Down Expand Up @@ -1118,6 +1123,20 @@ def _init_extra_metrics() -> Dict[str, Any]:
                )
            if "kineto_trace" in self.required_metrics:
                metrics.kineto_trace = self.kineto_trace(input_id, fn)
            if "proton" in self.required_metrics:
                from tritonbench.components.proton import proton_trace

                scope_name = fn_name
                flops = self.flops() if self.has_metric("flops") else None
                num_bytes = self.bytes() if self.has_metric("bytes") else None
                proton_trace(
                    self._proton_session_id,
                    scope_name,
                    fn,
                    warmup=warmup,
                    flops=flops,
                    bytes=num_bytes,
                )
            if "best_config" in self.required_metrics:
                metrics.best_config = self.best_config(fn)
            # run the hidden metric "_compile_time_in_task"
@@ -1592,3 +1611,7 @@ def run_and_capture(self, *args, **kwargs):
    @classmethod
    def has_bwd(cls) -> bool:
        return cls.get_bwd_fn is not BenchmarkOperator.get_bwd_fn

    @classmethod
    def has_metric(cls, metric_name: str) -> bool:
        return bool(getattr(cls, metric_name, None))
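
Putting the triton_op.py pieces together, the scope nesting this PR builds per operator run looks roughly like this (standalone sketch; the operator name and x_val values are placeholders, and the real code interleaves these calls with benchmarking):

import triton.profiler as proton

session_id = proton.start()
proton.enter_scope("tritonbench_run_op_softmax")   # one outer scope per operator run
proton.deactivate(session_id)

for x_val in ("4096", "8192"):                     # one scope per benchmarked input
    proton.activate(session_id)
    proton.enter_scope(f"x_val_{x_val}")
    proton.deactivate(session_id)

    # ... per-benchmark proton_trace() scopes are recorded here ...

    proton.activate(session_id)
    proton.exit_scope()
    proton.deactivate(session_id)

proton.activate(session_id)
proton.exit_scope()                                # close the operator-level scope
proton.finalize()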