diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index c7855cf..89994e4 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -1490,14 +1490,16 @@ def service_exists(service_name): ).resolve() ncu_args = [ "ncu", - "--nvtx", - "--nvtx-include", - f"{_RANGE_NAME}/", "--target-processes", "all", "--import-source", "yes", ] + # NCU does not recognize NVTX range on backward + # So we have to trace the entire process over backward + # See: https://github.com/pytorch-labs/tritonbench/issues/87 + if not (self.mode == Mode.BWD or self.mode == Mode.FWD_BWD): + ncu_args.extend(["--nvtx", "--nvtx-include", f"{_RANGE_NAME}/"]) ncu_args.extend(extend_ncu_args) if replay: ncu_args.extend(