Skip to content

Commit

Permalink
Add proton profiling
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Dec 6, 2024
1 parent df6cae6 commit b419bb5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,9 +719,10 @@ def run(
for _dryrun_input_id in range(self._input_id):
self.example_inputs = self.get_example_inputs()
for input_id in input_id_range:
x_val = self.get_x_val(self.example_inputs)
if "proton" in self.required_metrics:
proton.activate(self._proton_session_id)
proton.enter_scope(f"input_id_{input_id}")
proton.enter_scope(f"x_val_{x_val}")
proton.deactivate(self._proton_session_id)
self._cur_input_id = input_id
self.example_inputs = self.get_example_inputs()
Expand All @@ -743,7 +744,6 @@ def run(
self.baseline_fn = None
self.baseline_metrics = None
self._op_flops = {}
x_val = self.get_x_val(self.example_inputs)
if self._only:
benchmarks = self._only
else:
Expand Down Expand Up @@ -786,11 +786,11 @@ def _reduce_benchmarks(acc, bm_name: str):
del self.example_inputs # save some memory
if "proton" in self.required_metrics:
proton.activate(self._proton_session_id)
proton.exit_scope(f"input_{input_id}")
proton.exit_scope(f"x_val_{x_val}")
proton.deactivate(self._proton_session_id)
if "proton" in self.required_metrics:
proton.activate(self._proton_session_id)
proton.exit_scope("tritonbench_run")
proton.exit_scope(f"tritonbench_run_op_{self.name}")
proton.finalize()
except (KeyboardInterrupt, Exception):
logger.warning(
Expand Down

0 comments on commit b419bb5

Please sign in to comment.