diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index 6e7a3626..607d2935 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -68,7 +68,7 @@ def _run_one_operator(args: List[str]): check_ci_output(op) del op # Test backward (if applicable) - if op.has_bwd: + if op.has_bwd(): tb_args.mode = "bwd" op = Operator(tb_args=tb_args, extra_args=extra_args) op.run() @@ -89,7 +89,7 @@ def _run_operator_in_task(op: str, args: List[str]): task.run() task.check_output() # Test backward (if applicable) - if task.get_attribute("has_bwd"): + if task.get_attribute("has_bwd", method=True): task.del_op_instance() args.extend(["--bwd"]) task.make_operator_instance(args=args) diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py index adc4c1de..ff365a49 100644 --- a/tritonbench/operators/op_task.py +++ b/tritonbench/operators/op_task.py @@ -165,7 +165,10 @@ def run(self) -> None: @base_task.run_in_worker(scoped=True) @staticmethod def get_attribute( - attr: str, field: Optional[str] = None, classattr: bool = False + attr: str, + field: Optional[str] = None, + classattr: bool = False, + method: bool = False, ) -> Any: if classattr: op = globals()["Operator"] @@ -173,10 +176,10 @@ def get_attribute( op = globals()["op"] if hasattr(op, attr): if field: - op_attr = getattr(op, attr) - return getattr(op_attr, field) + op_attr = getattr(getattr(op, attr), field) else: - return getattr(op, attr) + op_attr = getattr(op, attr) + return op_attr() if method else op_attr else: return None