From ebde6c72cb1de53797dcc7e29897c98f7bdce3a3 Mon Sep 17 00:00:00 2001
From: eellison
Date: Mon, 25 Mar 2024 10:30:09 -0700
Subject: [PATCH] Precompile triton templates (#121998)

Before this PR, we were not precompiling triton templates in parallel;
compilation occurred during benchmarking.

Triton benchmarking templates were emitted as:
```
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```

In order to precompile, we need to emit the full kernel specification, as
we do when we emit the template in the final output code generation:
```
@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'cdeecfeccd31ad7810f96b5752194b1c2406d0a81e39a6ca09c8ee150baae183'},
)
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121998
Approved by: https://github.com/jansel
---
 torch/_inductor/autotune_process.py     | 14 +++++++-----
 torch/_inductor/codegen/triton_utils.py | 30 ++++++++++++++++++++++---
 torch/_inductor/scheduler.py            | 13 ++---------
 torch/_inductor/select_algorithm.py     | 13 +++++++----
 4 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index 790ec9d60ec0fe..8c44167bc32a2b 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -502,7 +502,6 @@ def benchmark(
 class TritonBenchmarkRequest(BenchmarkRequest):
     # Important: Instances of this class have to be serializable
     # across process boundaries. Do not put CUDA Tensors in here!
-
     def __init__(
         self,
         kernel_name: str,
@@ -545,6 +544,8 @@ def make_run_fn(
         if "warmup" in inspect.signature(run_method).parameters:
             warmup_arg["warmup"] = False
 
+        from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
         if torch.version.hip and self.matrix_instr_nonkdim != 0:
             return functools.partial(
                 run_method,
@@ -553,9 +554,7 @@ def make_run_fn(
                 *self.extra_args,
                 grid=self.grid,
                 **warmup_arg,
-                num_stages=self.num_stages,
-                num_warps=self.num_warps,
-                matrix_instr_nonkdim=self.matrix_instr_nonkdim,
+                stream=get_raw_stream(self.output_tensor_meta.device.index),
             )
         else:
             return functools.partial(
@@ -565,10 +564,13 @@ def make_run_fn(
                 *self.extra_args,
                 grid=self.grid,
                 **warmup_arg,
-                num_stages=self.num_stages,
-                num_warps=self.num_warps,
+                stream=get_raw_stream(self.output_tensor_meta.device.index),
             )
 
+    def precompile(self):
+        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
+        getattr(mod, self.kernel_name).precompile()
+
     def __str__(self) -> str:
         return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
 
diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py
index d1f58187ca4a4e..ba514552409d24 100644
--- a/torch/_inductor/codegen/triton_utils.py
+++ b/torch/_inductor/codegen/triton_utils.py
@@ -63,6 +63,32 @@ def signature_to_meta(
     }
 
 
+def is_unaligned_buffer(arg: TensorArg):
+    buf_name = arg.buffer
+    if buf_name in V.graph.graph_inputs:
+        return not config.assume_aligned_inputs
+
+    if buf_name in V.graph.constants:
+        # all constants are assumed to be aligned
+        return False
+
+    if V.graph.scheduler:
+        layout = V.graph.scheduler.get_buffer_layout(buf_name)
+    else:
+        buffer = V.graph.get_buffer(buf_name)
+        # output arg
+        if not buffer:
+            assert buf_name == V.kernel.output_node.name
+            layout = V.kernel.output_node.layout
+        else:
+            layout = buffer.get_layout()
+
+    if isinstance(layout, torch._inductor.ir.NonOwningLayout):
+        return not layout.maybe_guard_aligned()
+    else:
+        return False
+
+
 def config_of(
     args: List[KernelArgType],
     *,
@@ -81,9 +107,7 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
             offset_aligned = V.graph.sizevars.statically_known_multiple_of(
                 x.offset * x.dtype.itemsize, alignment  # type: ignore[arg-type]
             )
-            return offset_aligned and not V.graph.scheduler.is_unaligned_buffer(
-                x.buffer
-            )
+            return offset_aligned and not is_unaligned_buffer(x)
         else:
             return False
     if isinstance(x, SizeArg):
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 435efdf57b8c37..18f249321fd81c 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2451,18 +2451,9 @@ def codegen(self):
 
         self.flush()
 
-    def is_unaligned_buffer(self, buf_name):
-        if buf_name in V.graph.graph_inputs:
-            return not config.assume_aligned_inputs
-        if buf_name in V.graph.constants:
-            # all constants are assumed to be aligned
-            return False
+    def get_buffer_layout(self, buf_name: str) -> ir.Layout:
         node = self.name_to_node[buf_name]
-        layout = node.node.get_layout()
-        if isinstance(layout, ir.NonOwningLayout):
-            return not layout.maybe_guard_aligned()
-        else:
-            return False
+        return node.node.get_layout()
 
 
 class BaseScheduling:
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 50820a72ca9b5d..2b41e93a2140b7 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -94,7 +94,7 @@ def __init__(
         grid_fn,
         meta,
         call_sizes,
-        use_jit=True,
+        use_jit=False,
         prefix_args=0,
         suffix_args=0,
         epilogue_fn=identity,
@@ -150,8 +150,8 @@ def jit_lines(self):
         argdefs, _, signature = self.args.python_argdefs()
         triton_meta = {
             "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": self.output_node.get_device().index,
+            "device_type": self.output_node.get_device().type,
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]
@@ -498,7 +498,7 @@ def generate(
         ), TritonTemplateKernel(
             kernel_name=kernel_name,
             output_node=fake_out,
-            use_jit=True,
+            use_jit=False,
             **kernel_options,
         ) as kernel:
             try:
@@ -684,6 +684,10 @@ def benchmark(self, *args, out):
         assert self.bmreq is not None
         return self.bmreq.benchmark(*args, output_tensor=out)
 
+    def precompile(self):
+        assert self.bmreq is not None
+        self.bmreq.precompile()
+
     def __str__(self):
         return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"
 
@@ -821,6 +825,7 @@ def __call__(
 
         # TODO(nmacchioni): remove once CI tests are fixed
         choices = [choice for choice in choices if choice is not None]
+
        if len(choices) == 0:
            raise RuntimeError(
                "No choices to select, please consider adding ATEN into max_autotune_gemm_backends "
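
For context on how the new hook is meant to be used: `precompile()` exists so the autotuner can warm every candidate kernel before any benchmarking starts, instead of paying compilation cost inside the benchmark loop. Below is a minimal sketch of that calling pattern, not part of this patch: the `precompile_choices` helper, the thread pool, and the worker count are illustrative assumptions; only the `precompile()` methods on `TritonTemplateCaller` and `TritonBenchmarkRequest` are added by the diff above.
```
# Hypothetical usage sketch (not in this patch): fan out precompilation of
# all autotune candidates before benchmarking them. Each
# TritonTemplateCaller.precompile() delegates to its TritonBenchmarkRequest,
# which loads the generated module from PyCodeCache and compiles the kernel
# eagerly.
from concurrent.futures import ThreadPoolExecutor


def precompile_choices(choices, max_workers=8):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(choice.precompile)
            for choice in choices
            if hasattr(choice, "precompile")  # extern/ATen choices have nothing to compile
        ]
        for future in futures:
            future.result()  # surface compilation errors before benchmarking
```
The patch itself only exposes the hook; how callers overlap the compilations (threads, subprocesses, or sequentially) is left to the autotuning driver.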