diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index c8a1413..4004b59 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -621,6 +621,7 @@ def __init__( self.mode = Mode.BWD if self.mode in [Mode.FWD_BWD, Mode.BWD]: # TODO: remove this once we have a better way to handle backward benchmarking + import torch._functorch.config torch._functorch.config.donated_buffer = False self.device = tb_args.device self.required_metrics = (