From 5bfe61c67b523bed82a8c2c9f5055fd3829c117d Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 19 Nov 2024 16:45:43 -0500 Subject: [PATCH 1/4] Use code detection to check bwd method override. --- test/test_gpu/main.py | 11 ++--------- tritonbench/utils/triton_op.py | 4 ++++ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index 2738ba77..a0012ec4 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -68,15 +68,11 @@ def _run_one_operator(args: List[str]): check_ci_output(op) del op # Test backward (if applicable) - try: + if op.has_bwd: tb_args.mode = "bwd" op = Operator(tb_args=tb_args, extra_args=extra_args) op.run() check_ci_output(op) - except NotImplementedError: - logger.info( - f"Operator {op.name} does not support backward, skipping backward test." - ) def _run_operator_in_task(op: str, args: List[str]): @@ -94,14 +90,11 @@ def _run_operator_in_task(op: str, args: List[str]): task.check_output() task.del_op_instance() # Test backward (if applicable) - try: + if task.get_attribute("has_bwd"): args.extend(["--bwd"]) task.make_operator_instance(args=args) task.run() task.check_output() - except NotImplementedError: - # Operator does not support backward, skip the test - pass def make_test(operator): diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 785583c7..cdc109e3 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -1517,3 +1517,7 @@ def run_and_capture(self, *args, **kwargs): ir_dir / f"{fn._name}_{kernel.name}_{input_id}.sass", "w" ) as f: f.write(sass) + + @property + def has_bwd(self) -> bool: + return self.get_bwd_fn.__code__ is BenchmarkOperator.get_bwd_fn.__code__ From c4cb4ca2e566906788aa103790172c0636dd830a Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 19 Nov 2024 16:54:52 -0500 Subject: [PATCH 2/4] Bugfix --- test/test_gpu/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index a0012ec4..6e7a3626 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -88,9 +88,9 @@ def _run_operator_in_task(op: str, args: List[str]): task.make_operator_instance(args=args) task.run() task.check_output() - task.del_op_instance() # Test backward (if applicable) if task.get_attribute("has_bwd"): + task.del_op_instance() args.extend(["--bwd"]) task.make_operator_instance(args=args) task.run() From f2d52db29b1e5962cef593157b93d14816064a1a Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 19 Nov 2024 18:01:23 -0800 Subject: [PATCH 3/4] Use a smarter way to detect backward method. --- tritonbench/utils/triton_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index cdc109e3..ccaf88cb 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -1518,6 +1518,6 @@ def run_and_capture(self, *args, **kwargs): ) as f: f.write(sass) - @property - def has_bwd(self) -> bool: - return self.get_bwd_fn.__code__ is BenchmarkOperator.get_bwd_fn.__code__ + @classmethod + def has_bwd(cls) -> bool: + return cls.get_bwd_fn is not BenchmarkOperator.get_bwd_fn From 4b7f1b006a3b0a5a858c2cac6484f48d24982a6a Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Tue, 19 Nov 2024 21:06:18 -0500 Subject: [PATCH 4/4] Fix op_task --- test/test_gpu/main.py | 4 ++-- tritonbench/operators/op_task.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index 6e7a3626..607d2935 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -68,7 +68,7 @@ def _run_one_operator(args: List[str]): check_ci_output(op) del op # Test backward (if applicable) - if op.has_bwd: + if op.has_bwd(): tb_args.mode = "bwd" op = Operator(tb_args=tb_args, extra_args=extra_args) op.run() @@ -89,7 +89,7 @@ def _run_operator_in_task(op: str, args: List[str]): task.run() task.check_output() # Test backward (if applicable) - if task.get_attribute("has_bwd"): + if task.get_attribute("has_bwd", method=True): task.del_op_instance() args.extend(["--bwd"]) task.make_operator_instance(args=args) diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py index adc4c1de..ff365a49 100644 --- a/tritonbench/operators/op_task.py +++ b/tritonbench/operators/op_task.py @@ -165,7 +165,10 @@ def run(self) -> None: @base_task.run_in_worker(scoped=True) @staticmethod def get_attribute( - attr: str, field: Optional[str] = None, classattr: bool = False + attr: str, + field: Optional[str] = None, + classattr: bool = False, + method: bool = False, ) -> Any: if classattr: op = globals()["Operator"] @@ -173,10 +176,10 @@ def get_attribute( op = globals()["op"] if hasattr(op, attr): if field: - op_attr = getattr(op, attr) - return getattr(op_attr, field) + op_attr = getattr(getattr(op, attr), field) else: - return getattr(op, attr) + op_attr = getattr(op, attr) + return op_attr() if method else op_attr else: return None