Skip to content

Commit

Permalink
Use code detection to check bwd method override.
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 20, 2024
1 parent b151b84 commit 5bfe61c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
11 changes: 2 additions & 9 deletions test/test_gpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,11 @@ def _run_one_operator(args: List[str]):
check_ci_output(op)
del op
# Test backward (if applicable)
try:
if op.has_bwd:
tb_args.mode = "bwd"
op = Operator(tb_args=tb_args, extra_args=extra_args)
op.run()
check_ci_output(op)
except NotImplementedError:
logger.info(
f"Operator {op.name} does not support backward, skipping backward test."
)


def _run_operator_in_task(op: str, args: List[str]):
Expand All @@ -94,14 +90,11 @@ def _run_operator_in_task(op: str, args: List[str]):
task.check_output()
task.del_op_instance()
# Test backward (if applicable)
try:
if task.get_attribute("has_bwd"):
args.extend(["--bwd"])
task.make_operator_instance(args=args)
task.run()
task.check_output()
except NotImplementedError:
# Operator does not support backward, skip the test
pass


def make_test(operator):
Expand Down
4 changes: 4 additions & 0 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,3 +1517,7 @@ def run_and_capture(self, *args, **kwargs):
ir_dir / f"{fn._name}_{kernel.name}_{input_id}.sass", "w"
) as f:
f.write(sass)

@property
def has_bwd(self) -> bool:
return self.get_bwd_fn.__code__ is BenchmarkOperator.get_bwd_fn.__code__

0 comments on commit 5bfe61c

Please sign in to comment.