Skip to content

Commit

Permalink
Fix op_task
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 20, 2024
1 parent f2d52db commit 4b7f1b0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
4 changes: 2 additions & 2 deletions test/test_gpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _run_one_operator(args: List[str]):
check_ci_output(op)
del op
# Test backward (if applicable)
if op.has_bwd:
if op.has_bwd():
tb_args.mode = "bwd"
op = Operator(tb_args=tb_args, extra_args=extra_args)
op.run()
Expand All @@ -89,7 +89,7 @@ def _run_operator_in_task(op: str, args: List[str]):
task.run()
task.check_output()
# Test backward (if applicable)
if task.get_attribute("has_bwd"):
if task.get_attribute("has_bwd", method=True):
task.del_op_instance()
args.extend(["--bwd"])
task.make_operator_instance(args=args)
Expand Down
11 changes: 7 additions & 4 deletions tritonbench/operators/op_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,21 @@ def run(self) -> None:
@base_task.run_in_worker(scoped=True)
@staticmethod
def get_attribute(
attr: str, field: Optional[str] = None, classattr: bool = False
attr: str,
field: Optional[str] = None,
classattr: bool = False,
method: bool = False,
) -> Any:
if classattr:
op = globals()["Operator"]
else:
op = globals()["op"]
if hasattr(op, attr):
if field:
op_attr = getattr(op, attr)
return getattr(op_attr, field)
op_attr = getattr(getattr(op, attr), field)
else:
return getattr(op, attr)
op_attr = getattr(op, attr)
return op_attr() if method else op_attr
else:
return None

Expand Down

0 comments on commit 4b7f1b0

Please sign in to comment.