Skip to content

Commit

Permalink
Fix linter
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Dec 14, 2024
1 parent cfb6990 commit 1a4b96f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tritonbench/components/compile_time/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .trace import do_compile_time_in_task
from .trace import do_compile_time_in_task
3 changes: 2 additions & 1 deletion tritonbench/components/compile_time/trace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Callable

import torch
from tritonbench.utils.env_utils import fresh_triton_cache

from typing import Callable

def do_compile_time_in_task(fn: Callable) -> float:
with fresh_triton_cache():
Expand Down
13 changes: 7 additions & 6 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,8 +1205,12 @@ def _init_extra_metrics() -> Dict[str, Any]:
)
from tritonbench.components.compile_time import do_compile_time_in_task

metrics.extra_metrics["_compile_time_in_task"] = do_compile_time_in_task(fn)
self._latency_with_compile_in_task = metrics.extra_metrics["_compile_time_in_task"]
metrics.extra_metrics["_compile_time_in_task"] = (
do_compile_time_in_task(fn)
)
self._latency_with_compile_in_task = metrics.extra_metrics[
"_compile_time_in_task"
]
if "_ncu_trace_in_task" in self.required_metrics:
assert (
self.required_metrics == ["_ncu_trace_in_task"]
Expand Down Expand Up @@ -1552,9 +1556,7 @@ def compile_time(
]
)
op_task = OpTask(name=self.name)
op_task.make_operator_instance(
args = op_task_args
)
op_task.make_operator_instance(args=op_task_args)
op_task.run()
latency_with_compile = op_task.get_attribute("_latency_with_compile_in_task")
del op_task
Expand All @@ -1579,7 +1581,6 @@ def hw_roofline(self) -> float:
return rooflines[self.tb_args.precision]
return rooflines


def tflops(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
Expand Down

0 comments on commit 1a4b96f

Please sign in to comment.