diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index 6627697c..ecfc4444 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -34,6 +34,14 @@ def torch_layer_norm(self, *args):
 
     @register_benchmark()
     def torch_compile_layer_norm(self, *args):
+        # We need to run backward multiple times for proper benchmarking,
+        # so donated buffers have to be disabled.
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+
+            torch._functorch.config.donated_buffer = False
+        import torch
+
         @torch.compile
         def inner(*args):
             return F.layer_norm(*args)
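
Context for reviewers (illustrative, not part of the patch): with donated buffers enabled, AOTAutograd assumes each compiled backward graph runs only once and may free or reuse the buffers it saved for backward, so replaying backward on the same graph, as a benchmark harness does, can raise a RuntimeError that asks for exactly this flag. Below is a minimal standalone sketch of the pattern the patch enables; the shapes, names, and loop count are arbitrary, and the flag must be set before compilation, as it is in the patch.

import torch
import torch._functorch.config
import torch.nn.functional as F

# Set before compiling, mirroring the patch; with the flag left at its
# default, re-running backward with retain_graph=True can fail.
torch._functorch.config.donated_buffer = False

@torch.compile
def inner(x, w, b):
    return F.layer_norm(x, (512,), w, b)

x = torch.randn(32, 512, requires_grad=True)
w = torch.randn(512, requires_grad=True)
b = torch.randn(512, requires_grad=True)

y = inner(x, w, b)
g = torch.randn_like(y)
# Benchmark-style replay: backward over the same compiled graph several times.
for _ in range(3):
    y.backward(g, retain_graph=True)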