From afa8e31dd971a3febaafa186245eaf9df90204aa Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 2 Dec 2024 14:55:56 -0500 Subject: [PATCH 1/3] Disable donated buffer when benchmarking --- tritonbench/operators/layer_norm/operator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py index 6627697c..a87e8b03 100644 --- a/tritonbench/operators/layer_norm/operator.py +++ b/tritonbench/operators/layer_norm/operator.py @@ -34,6 +34,11 @@ def torch_layer_norm(self, *args): @register_benchmark() def torch_compile_layer_norm(self, *args): + # We need to run backward multiple times for proper benchmarking # so donated buffers have to be disabled + if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD: + import torch._functorch.config torch._functorch.config.donated_buffer = False @torch.compile def inner(*args): return F.layer_norm(*args) From 3f75a34f4258ab0aa17059abb00b45e06cbfe100 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 2 Dec 2024 12:06:12 -0800 Subject: [PATCH 2/3] Lint --- tritonbench/operators/layer_norm/operator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py index a87e8b03..cfde8064 100644 --- a/tritonbench/operators/layer_norm/operator.py +++ b/tritonbench/operators/layer_norm/operator.py @@ -38,7 +38,9 @@ def torch_compile_layer_norm(self, *args): # so donated buffers have to be disabled if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD: import torch._functorch.config + torch._functorch.config.donated_buffer = False + @torch.compile def inner(*args): return F.layer_norm(*args) From f0e2a047167c139368bc85b4bf3c9f8478a39708 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 2 Dec 2024 12:30:19 -0800 Subject: [PATCH 3/3] Fix unit test --- tritonbench/operators/layer_norm/operator.py | 1 + 1 file changed, 1 insertion(+) diff --git
a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py index cfde8064..ecfc4444 100644 --- a/tritonbench/operators/layer_norm/operator.py +++ b/tritonbench/operators/layer_norm/operator.py @@ -40,6 +40,7 @@ def torch_compile_layer_norm(self, *args): import torch._functorch.config torch._functorch.config.donated_buffer = False + import torch @torch.compile def inner(*args):