From 7472742f8b21207c718e3806306bdc546ce0234a Mon Sep 17 00:00:00 2001 From: FindHao Date: Wed, 11 Dec 2024 10:26:01 -0800 Subject: [PATCH] change imports --- tritonbench/operators/geglu/operator.py | 4 +++- tritonbench/operators/layer_norm/operator.py | 4 ++-- tritonbench/operators/swiglu/operator.py | 4 +++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tritonbench/operators/geglu/operator.py b/tritonbench/operators/geglu/operator.py index c744fe8..493968f 100644 --- a/tritonbench/operators/geglu/operator.py +++ b/tritonbench/operators/geglu/operator.py @@ -64,7 +64,9 @@ def inductor_geglu(self, input) -> Callable: # We need to run backward multiple times for proper benchmarking # so donated buffer have to be disabled if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD: - import torch._functorch.config + from torch._functorch import config as functorch_config + + functorch_config.donated_buffer = False compiled = torch.compile(self.baseline_model) return lambda: compiled(input) diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py index db48a03..40228c2 100644 --- a/tritonbench/operators/layer_norm/operator.py +++ b/tritonbench/operators/layer_norm/operator.py @@ -38,9 +38,9 @@ def torch_compile_layer_norm(self, *args): # We need to run backward multiple times for proper benchmarking # so donated buffer have to be disabled if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD: - import torch._functorch.config + from torch._functorch import config as functorch_config - torch._functorch.config.donated_buffer = False + functorch_config.donated_buffer = False import torch @torch.compile diff --git a/tritonbench/operators/swiglu/operator.py b/tritonbench/operators/swiglu/operator.py index b414513..48d998e 100644 --- a/tritonbench/operators/swiglu/operator.py +++ b/tritonbench/operators/swiglu/operator.py @@ -64,7 +64,9 @@ def inductor_swiglu(self, input) -> Callable: # We need to run backward multiple times for proper benchmarking # so donated buffer have to be disabled if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD: - import torch._functorch.config + from torch._functorch import config as functorch_config + + functorch_config.donated_buffer = False compiled = torch.compile(self.baseline_op) return lambda: compiled(input)