diff --git a/.gitignore b/.gitignore
index fc0098f4..ffc5241f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,3 +16,5 @@ __pycache__/
 *.egg-info/
 torch_compile_debug/
 build/
+/*.csv
+*.hatchet
diff --git a/tritonbench/components/do_bench/__init__.py b/tritonbench/components/do_bench/__init__.py
new file mode 100644
index 00000000..1daabb0a
--- /dev/null
+++ b/tritonbench/components/do_bench/__init__.py
@@ -0,0 +1 @@
+from .run import do_bench_wrapper
diff --git a/tritonbench/components/do_bench/run.py b/tritonbench/components/do_bench/run.py
new file mode 100644
index 00000000..441bef11
--- /dev/null
+++ b/tritonbench/components/do_bench/run.py
@@ -0,0 +1,34 @@
+import torch
+import triton
+
+
+def do_bench_wrapper(
+    fn,
+    warmup,
+    rep,
+    grad_to_none,
+    use_cuda_graphs: bool = False,
+    bypass_fail: bool = False,
+):
+    """Wrapper to triton's do_bench to gain latency."""
+    if use_cuda_graphs:
+        with torch.cuda.stream(torch.cuda.Stream()):
+            return triton.testing.do_bench_cudagraph(
+                fn,
+                rep=rep,
+                return_mode="median",
+                grad_to_none=grad_to_none,
+            )
+    else:
+        try:
+            return triton.testing.do_bench(
+                fn,
+                warmup=warmup,
+                rep=rep,
+                return_mode="median",
+                grad_to_none=grad_to_none,
+            )
+        except Exception as e:
+            if not bypass_fail:
+                raise e
+            return None
diff --git a/tritonbench/components/proton/__init__.py b/tritonbench/components/proton/__init__.py
new file mode 100644
index 00000000..dbd7f529
--- /dev/null
+++ b/tritonbench/components/proton/__init__.py
@@ -0,0 +1 @@
+from .trace import proton_trace
diff --git a/tritonbench/components/proton/trace.py b/tritonbench/components/proton/trace.py
new file mode 100644
index 00000000..7776063e
--- /dev/null
+++ b/tritonbench/components/proton/trace.py
@@ -0,0 +1,25 @@
+from typing import Callable, Optional
+
+import triton.profiler as proton
+
+
+def proton_trace(
+    session_id: int,
+    scope_name: str,
+    fn: Callable,
+    warmup: int,
+    flops: Optional[int] = None,
+    bytes: Optional[int] = None,
+):
+    # warmup
+    for _ in range(warmup):
+        fn()
+    metrics_dict = {}
+    if flops:
+        metrics_dict["flops"] = flops
+    if bytes:
+        metrics_dict["bytes"] = bytes
+    proton.activate(session_id)
+    with proton.scope(scope_name, metrics=metrics_dict):
+        fn()
+    proton.deactivate(session_id)
diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py
index 8979cc5f..72050ae7 100644
--- a/tritonbench/utils/triton_op.py
+++ b/tritonbench/utils/triton_op.py
@@ -23,6 +23,7 @@
 import torch
 import triton
 
+from tritonbench.components.do_bench import do_bench_wrapper
 from tritonbench.components.ncu import ncu_analyzer, nsys_analyzer
 from tritonbench.utils.env_utils import (
     apply_precision,
@@ -706,6 +707,12 @@ def run(
         """Benchmarking the operator and returning its metrics."""
         metrics = []
         try:
+            if "proton" in self.required_metrics:
+                import triton.profiler as proton
+
+                self._proton_session_id = proton.start()
+                proton.enter_scope(f"tritonbench_run_op_{self.name}")
+                proton.deactivate(self._proton_session_id)
             input_id_range = range(self._input_id, self._input_id + self._num_inputs)
             if tqdm is not None:
                 input_id_range = tqdm(input_id_range)
@@ -713,8 +720,13 @@ def run(
                 for _dryrun_input_id in range(self._input_id):
                     self.example_inputs = self.get_example_inputs()
             for input_id in input_id_range:
-                self._cur_input_id = input_id
                 self.example_inputs = self.get_example_inputs()
+                x_val = self.get_x_val(self.example_inputs)
+                if "proton" in self.required_metrics:
+                    proton.activate(self._proton_session_id)
+                    proton.enter_scope(f"x_val_{x_val}")
+                    proton.deactivate(self._proton_session_id)
+                self._cur_input_id = input_id
                 if self.example_inputs is None:
                     logger.warn(
                         f"The input generator get_input_iter() has depleted at id {input_id}. Available number of "
@@ -733,7 +745,6 @@ def run(
                 self.baseline_fn = None
                 self.baseline_metrics = None
                 self._op_flops = {}
-                x_val = self.get_x_val(self.example_inputs)
                 if self._only:
                     benchmarks = self._only
                 else:
@@ -774,8 +785,15 @@ def _reduce_benchmarks(acc, bm_name: str):
                     _reduce_benchmarks, benchmarks, {}
                 )
                 metrics.append((x_val, y_vals))
-                del self.example_inputs
-                gc.collect()
+                del self.example_inputs  # save some memory
+                if "proton" in self.required_metrics:
+                    proton.activate(self._proton_session_id)
+                    proton.exit_scope()
+                    proton.deactivate(self._proton_session_id)
+            if "proton" in self.required_metrics:
+                proton.activate(self._proton_session_id)
+                proton.exit_scope()
+                proton.finalize()
         except (KeyboardInterrupt, Exception):
             logger.warning(
                 "Caught exception, terminating early with partial results",
@@ -968,27 +986,14 @@ def _init_extra_metrics() -> Dict[str, Any]:
         if {"latency", "tflops", "speedup", "compile_time"} & set(
             self.required_metrics
         ):
-            if self.use_cuda_graphs:
-                with torch.cuda.stream(torch.cuda.Stream()):
-                    metrics.latency = triton.testing.do_bench_cudagraph(
-                        fn,
-                        rep=rep,
-                        return_mode="median",
-                        grad_to_none=self.get_grad_to_none(self.example_inputs),
-                    )
-            else:
-                try:
-                    metrics.latency = triton.testing.do_bench(
-                        fn,
-                        warmup=warmup,
-                        rep=rep,
-                        return_mode="median",
-                        grad_to_none=self.get_grad_to_none(self.example_inputs),
-                    )
-                except Exception as e:
-                    if not self.tb_args.bypass_fail:
-                        raise e
-                    metrics.latency = None
+            metrics.latency = do_bench_wrapper(
+                fn,
+                warmup,
+                rep,
+                grad_to_none=self.get_grad_to_none(self.example_inputs),
+                use_cuda_graphs=self.use_cuda_graphs,
+                bypass_fail=self.tb_args.bypass_fail,
+            )
         if {
             "gpu_peak_mem",
             "gpu_mem_footprint_compression_ratio",
@@ -1118,6 +1123,20 @@ def _init_extra_metrics() -> Dict[str, Any]:
             )
         if "kineto_trace" in self.required_metrics:
             metrics.kineto_trace = self.kineto_trace(input_id, fn)
+        if "proton" in self.required_metrics:
+            from tritonbench.components.proton import proton_trace
+
+            scope_name = fn_name
+            flops = self.flops() if self.has_metric("flops") else None
+            num_bytes = self.bytes() if self.has_metric("bytes") else None
+            proton_trace(
+                self._proton_session_id,
+                scope_name,
+                fn,
+                warmup=warmup,
+                flops=flops,
+                bytes=num_bytes,
+            )
         if "best_config" in self.required_metrics:
             metrics.best_config = self.best_config(fn)
         # run the hidden metric "_compile_time_in_task"
@@ -1592,3 +1611,7 @@ def run_and_capture(self, *args, **kwargs):
     @classmethod
     def has_bwd(cls) -> bool:
         return cls.get_bwd_fn is not BenchmarkOperator.get_bwd_fn
+
+    @classmethod
+    def has_metric(cls, metric_name: str) -> bool:
+        return bool(getattr(cls, metric_name, None))
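For reviewers, a minimal usage sketch of the two extracted helpers outside of triton_op.py. This is not part of the patch; the matmul workload, tensor sizes, warmup/rep values, and flop count below are hypothetical placeholders chosen only to illustrate the call signatures introduced above.

import torch
import triton.profiler as proton

from tritonbench.components.do_bench import do_bench_wrapper
from tritonbench.components.proton import proton_trace

# Hypothetical workload: any zero-argument callable can be benchmarked.
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
fn = lambda: torch.mm(a, b)

# Median latency in ms; returns None on failure only when bypass_fail=True.
latency = do_bench_wrapper(
    fn,
    warmup=25,
    rep=100,
    grad_to_none=None,
    use_cuda_graphs=False,
    bypass_fail=True,
)

# Record the same callable under a named proton scope, attaching optional
# flops/bytes metadata (2 * N^3 flops for an N x N matmul).
session_id = proton.start()
proton_trace(session_id, "mm_example", fn, warmup=3, flops=2 * 1024**3)
proton.finalize()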