From 0fe5493a11c1c0c28cd71d9d139cd93e9a94e542 Mon Sep 17 00:00:00 2001 From: Adam Mainz Date: Tue, 10 Dec 2024 16:50:17 -0800 Subject: [PATCH] only running config checks for triton kernels Summary: TSIA Reviewed By: xuzhao9 Differential Revision: D67052354 fbshipit-source-id: fddd6482493051be3a0c69cc1bb0f4508310388c --- tritonbench/utils/triton_op.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index fa70499..9ffa185 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -872,18 +872,29 @@ def all_configs(self, fn): from triton.runtime import Autotuner + from triton.runtime.jit import JITFunction + original_run = Autotuner.run + original_run_jit = JITFunction.run autotuner = None + compiled_kernels = [] def run_and_capture(self, *args, **kwargs): nonlocal autotuner autotuner = self original_run(self, *args, **kwargs) - with mock.patch.object(Autotuner, "run", run_and_capture): - fn() + # There isn't really a great way to get the compiled kernels without monkeypatching + def run_and_capture_jit(self, *args, **kwargs): + compiled_kernel = original_run_jit(self, *args, **kwargs) + compiled_kernels.append(compiled_kernel) + return compiled_kernel - if autotuner is not None: + with mock.patch.object(JITFunction, "run", run_and_capture_jit): + with mock.patch.object(Autotuner, "run", run_and_capture): + fn() + + if autotuner is not None and len(compiled_kernels): configs = [] for config in autotuner.configs: configs.append(str(config))