diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py
index c744fe8..493968f 100644
--- a/tritonbench/operators/geglu/operator.py
+++ b/tritonbench/operators/geglu/operator.py
@@ -64,7 +64,9 @@ def inductor_geglu(self, input) -> Callable:
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
-            import torch._functorch.config
+            from torch._functorch import config as functorch_config
+
+            functorch_config.donated_buffer = False
 
         compiled = torch.compile(self.baseline_model)
         return lambda: compiled(input)
diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index db48a03..40228c2 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -38,9 +38,9 @@ def torch_compile_layer_norm(self, *args):
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
-            import torch._functorch.config
+            from torch._functorch import config as functorch_config
 
-            torch._functorch.config.donated_buffer = False
+            functorch_config.donated_buffer = False
         import torch
 
         @torch.compile
diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py
index b414513..48d998e 100644
--- a/tritonbench/operators/swiglu/operator.py
+++ b/tritonbench/operators/swiglu/operator.py
@@ -64,7 +64,9 @@ def inductor_swiglu(self, input) -> Callable:
         # We need to run backward multiple times for proper benchmarking
         # so donated buffer have to be disabled
         if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
-            import torch._functorch.config
+            from torch._functorch import config as functorch_config
+
+            functorch_config.donated_buffer = False
         
         compiled = torch.compile(self.baseline_op)
         return lambda: compiled(input)
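
The same pattern appears in all three hunks: when the benchmark exercises the backward pass (Mode.BWD or Mode.FWD_BWD), the functorch donated_buffer setting is switched off before torch.compile is invoked, since donated buffers are not kept around after the first backward and would break repeated backward runs. The following is a minimal, self-contained sketch of that pattern outside tritonbench; the Linear module and tensor shapes are illustrative stand-ins for self.baseline_model / self.baseline_op, not code from this repository.

import torch
from torch._functorch import config as functorch_config

# Disable donated buffers so the saved tensors of the compiled graph
# remain valid across more than one backward call, which the benchmark needs.
functorch_config.donated_buffer = False

model = torch.nn.Linear(64, 64)  # stand-in for the benchmarked op/model
compiled = torch.compile(model)

x = torch.randn(8, 64, requires_grad=True)
loss = compiled(x).sum()
loss.backward(retain_graph=True)  # first backward
loss.backward()  # second backward; with donated buffers enabled this can error out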