From 2474f1ed1a31a12f5eb7af3c838734a51d172da9 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Mon, 2 Dec 2024 14:53:47 -0800
Subject: [PATCH] Disable donated buffer when benchmarking layer_norm with backwards (#88)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Torch sets `donated_buffer = True` by default, but donated buffers do not support running backward multiple times, so the option has to be disabled when benchmarking backward passes.

Fixes https://github.com/pytorch-labs/tritonbench/issues/40

Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/88

Test Plan:
```
$ python run.py --op layer_norm --bwd
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:26<00:00,  1.13it/s]
  x_val    torch_layer_norm-latency    triton_layer_norm-latency    torch_compile_layer_norm-latency    liger_layer_norm-latency
-------  --------------------------  ---------------------------  ----------------------------------  --------------------------
   1024                     0.06768                       0.0888                            0.068736                    0.068256
   1536                     0.09872                     0.090592                            0.090368                     0.08032
   2048                    0.121568                     0.100352                            0.104608                    0.088224
   2560                    0.149536                     0.107424                            0.122656                    0.097472
   3072                    0.184768                     0.116544                            0.143456                    0.124288
   3584                    0.213216                     0.127264                            0.176576                    0.117312
   4096                    0.240576                       0.1384                            0.195168                    0.123936
   4608                    0.271232                     0.180928                            0.218176                    0.179744
   5120                    0.294272                     0.191328                            0.240352                    0.185056
   5632                     0.31952                     0.199616                             0.26704                    0.197792
   6144                    0.344064                     0.208448                            0.297792                     0.21168
   6656                     0.36864                     0.219232                            0.339936                    0.219552
   7168                    0.393792                     0.226816                            0.365152                     0.22592
   7680                    0.419456                     0.240736                            0.390432                    0.236992
   8192                     0.44576                     0.251936                            0.419776                     0.25088
   8704                    0.480256                     0.264032                            0.448672                     2.67574
   9216                    0.502624                     0.274272                            0.477312                     2.72173
   9728                    0.527168                     0.293152                            0.522656                      2.7551
  10240                    0.554528                      0.30736                            0.549216                     2.78102
  10752                    0.576192                     0.325824                            0.573888                      2.8047
  11264                    0.601088                     0.339392                            0.598272                     2.84749
  11776                    0.635232                     0.351808                            0.631072                      2.8816
  12288                    0.653088                      0.36336                            0.655776                     2.92502
  12800                    0.684352                     0.381472                            0.696512                     2.95024
  13312                    0.708384                     0.391296                            0.720288                     2.97734
  13824                     0.73264                     0.406944                            0.743584                     3.00829
  14336                    0.756224                     0.417472                            0.771136                     3.04874
  14848                    0.781728                     0.434144                             0.79568                     3.07536
  15360                    0.806656                     0.432192                             0.82064                     3.10083
  15872                    0.833216                     0.459456                            0.858624                     3.10598
```

Reviewed By: FindHao

Differential Revision: D66667947

Pulled By: xuzhao9

fbshipit-source-id: 14f9304fb3684881b5d0f91635f1cde58b6fcc8e
---
 tritonbench/operators/layer_norm/operator.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tritonbench/operators/layer_norm/operator.py b/tritonbench/operators/layer_norm/operator.py
index 6627697c..ecfc4444 100644
--- a/tritonbench/operators/layer_norm/operator.py
+++ b/tritonbench/operators/layer_norm/operator.py
@@ -34,6 +34,14 @@ def torch_layer_norm(self, *args):
 
     @register_benchmark()
     def torch_compile_layer_norm(self, *args):
+        # We need to run backward multiple times for proper benchmarking,
+        # so donated buffers have to be disabled
+        if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
+            import torch._functorch.config
+
+            torch._functorch.config.donated_buffer = False
+        import torch
+
         @torch.compile
         def inner(*args):
             return F.layer_norm(*args)
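
For context, the sketch below illustrates the behavior the patch works around: with the default `donated_buffer = True`, backward cannot be replayed over the same compiled graph, which is exactly what the benchmark's BWD/FWD_BWD modes do. This is a minimal standalone example, not code from tritonbench; the tensor shapes, device selection, and replay count are illustrative assumptions, while `torch._functorch.config.donated_buffer` and `F.layer_norm` are the actual config knob and op the patch touches.

```
# Illustrative sketch only (not part of the patch above). Disabling donated
# buffers lets backward be replayed over the same compiled graph, the way a
# latency benchmark does.
import torch
import torch._functorch.config
import torch.nn.functional as F

# The same knob the patch flips; with the default True, the repeated
# backward calls below are not supported.
torch._functorch.config.donated_buffer = False

device = "cuda" if torch.cuda.is_available() else "cpu"  # illustrative device pick
N = 4096  # hypothetical hidden size, not taken from the benchmark
x = torch.randn(1024, N, device=device, requires_grad=True)
w = torch.randn(N, device=device, requires_grad=True)
b = torch.randn(N, device=device, requires_grad=True)

@torch.compile
def layer_norm(x, w, b):
    return F.layer_norm(x, (N,), w, b)

y = layer_norm(x, w, b)
dy = torch.randn_like(y)

# A benchmark replays backward many times over one forward graph,
# hence retain_graph=True.
for _ in range(5):
    y.backward(dy, retain_graph=True)
```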