Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Dec 14, 2024
1 parent 37af7f5 commit cfb6990
Showing 1 changed file with 2 additions and 16 deletions.
18 changes: 2 additions & 16 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,8 @@ def _init_extra_metrics() -> Dict[str, Any]:
)
from tritonbench.components.compile_time import do_compile_time_in_task

metrics.extra_metrics["_compile_time_in_task"] = do_compile_time_in_task(fn)
metrics.extra_metrics["_compile_time_in_task"] = do_compile_time_in_task(fn)
self._latency_with_compile_in_task = metrics.extra_metrics["_compile_time_in_task"]
if "_ncu_trace_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_ncu_trace_in_task"]
Expand Down Expand Up @@ -1578,21 +1579,6 @@ def hw_roofline(self) -> float:
return rooflines[self.tb_args.precision]
return rooflines

def _compile_time_in_task(
    self,
    fn: Callable,
) -> float:
    """Time a single call of ``fn`` made inside ``fresh_triton_cache()``.

    The call is bracketed by two timing-enabled CUDA events; the elapsed
    time between them is stored on ``self._latency_with_compile_in_task``
    and also returned.

    Args:
        fn: Zero-argument callable to invoke once.

    Returns:
        The value of ``Event.elapsed_time`` between the event recorded
        before ``fn()`` and the one recorded after it (PyTorch documents
        this as milliseconds).

    NOTE(review): presumably ``fresh_triton_cache`` forces Triton to
    recompile so the measurement includes compilation -- confirm against
    its definition.
    """
    with fresh_triton_cache():
        # Wait for outstanding GPU work before the first event is recorded.
        torch.cuda.synchronize()
        begin_marker = torch.cuda.Event(enable_timing=True)
        finish_marker = torch.cuda.Event(enable_timing=True)
        begin_marker.record()
        fn()
        finish_marker.record()
        # Wait for the events to be recorded before reading the elapsed time.
        torch.cuda.synchronize()
    elapsed = begin_marker.elapsed_time(finish_marker)
    self._latency_with_compile_in_task = elapsed
    return elapsed

def tflops(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
Expand Down

0 comments on commit cfb6990

Please sign in to comment.