Precompile triton templates (pytorch#121998)
Before this PR, Triton templates were not precompiled in parallel; compilation happened during benchmarking instead.

Triton benchmarking templates were emitted as:

```
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```

In order to precompile, we need to give the full kernel specification, as we do when we emit the template during final output code generation:

```
@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'cdeecfeccd31ad7810f96b5752194b1c2406d0a81e39a6ca09c8ee150baae183'},
)
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```
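With the full specification available up front, each candidate template can be compiled eagerly and concurrently, before any benchmarking runs. A minimal sketch of the idea (not the PR's exact code), assuming each `choice` exposes the `precompile()` and `benchmark()` methods added in this PR:

```python
from concurrent.futures import ThreadPoolExecutor

def precompile_then_benchmark(choices, inputs, out):
    # Compile every candidate template in parallel so that benchmarking
    # measures kernel runtime rather than Triton compile time.
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(choice.precompile) for choice in choices]
        for future in futures:
            future.result()  # surface compilation errors before benchmarking
    # With compilation already done, timing each choice is a pure measurement.
    return {choice: choice.benchmark(*inputs, out=out) for choice in choices}
```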

Pull Request resolved: pytorch#121998
Approved by: https://github.com/jansel
eellison authored and pytorchmergebot committed Mar 25, 2024
1 parent 9b095c3 commit ebde6c7
Showing 4 changed files with 46 additions and 24 deletions.
torch/_inductor/autotune_process.py (14 changes: 8 additions & 6 deletions)

```diff
@@ -502,7 +502,6 @@ def benchmark(
 class TritonBenchmarkRequest(BenchmarkRequest):
     # Important: Instances of this class have to be serializable
     # across process boundaries. Do not put CUDA Tensors in here!
-
     def __init__(
         self,
         kernel_name: str,
@@ -545,6 +544,8 @@ def make_run_fn(
         if "warmup" in inspect.signature(run_method).parameters:
             warmup_arg["warmup"] = False
 
+        from torch._C import _cuda_getCurrentRawStream as get_raw_stream
+
         if torch.version.hip and self.matrix_instr_nonkdim != 0:
             return functools.partial(
                 run_method,
@@ -553,9 +554,7 @@ def make_run_fn(
                 *self.extra_args,
                 grid=self.grid,
                 **warmup_arg,
-                num_stages=self.num_stages,
-                num_warps=self.num_warps,
-                matrix_instr_nonkdim=self.matrix_instr_nonkdim,
+                stream=get_raw_stream(self.output_tensor_meta.device.index),
             )
         else:
             return functools.partial(
@@ -565,10 +564,13 @@ def make_run_fn(
                 *self.extra_args,
                 grid=self.grid,
                 **warmup_arg,
-                num_stages=self.num_stages,
-                num_warps=self.num_warps,
+                stream=get_raw_stream(self.output_tensor_meta.device.index),
             )
 
+    def precompile(self):
+        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
+        getattr(mod, self.kernel_name).precompile()
+
     def __str__(self) -> str:
         return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
```
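Note on the `make_run_fn` change above: with the full `@triton_heuristics.template(...)` emission, `num_stages`, `num_warps`, and `matrix_instr_nonkdim` are baked into the decorator at compile time, so the launcher now passes the raw CUDA stream instead. A hedged sketch of the resulting launch path, with `kernel` standing in for the loaded template:

```python
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

def launch(kernel, a, b, out, grid):
    # Tuning parameters live in the template decorator, so the call site
    # only supplies the grid and the raw stream of the output's device.
    stream = get_raw_stream(out.device.index)
    kernel.run(a, b, out, grid=grid, stream=stream)
```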
torch/_inductor/codegen/triton_utils.py (30 changes: 27 additions & 3 deletions)

```diff
@@ -63,6 +63,32 @@ def signature_to_meta(
     }
 
 
+def is_unaligned_buffer(arg: TensorArg):
+    buf_name = arg.buffer
+    if buf_name in V.graph.graph_inputs:
+        return not config.assume_aligned_inputs
+
+    if buf_name in V.graph.constants:
+        # all constants are assumed to be aligned
+        return False
+
+    if V.graph.scheduler:
+        layout = V.graph.scheduler.get_buffer_layout(buf_name)
+    else:
+        buffer = V.graph.get_buffer(buf_name)
+        # output arg
+        if not buffer:
+            assert buf_name == V.kernel.output_node.name
+            layout = V.kernel.output_node.layout
+        else:
+            layout = buffer.get_layout()
+
+    if isinstance(layout, torch._inductor.ir.NonOwningLayout):
+        return not layout.maybe_guard_aligned()
+    else:
+        return False
+
+
 def config_of(
     args: List[KernelArgType],
     *,
@@ -81,9 +107,7 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
             offset_aligned = V.graph.sizevars.statically_known_multiple_of(
                 x.offset * x.dtype.itemsize, alignment  # type: ignore[arg-type]
             )
-            return offset_aligned and not V.graph.scheduler.is_unaligned_buffer(
-                x.buffer
-            )
+            return offset_aligned and not is_unaligned_buffer(x)
         else:
             return False
     if isinstance(x, SizeArg):
```
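`is_unaligned_buffer` moves out of the `Scheduler` because templates are now precompiled before a scheduler exists; the scheduler-less branch resolves layouts straight from the graph or the kernel's output node. Its result feeds `config_of`, which populates the `divisible_by_16` indices of the `AttrsDescriptor` shown in the commit message. A simplified sketch of that relationship (the real code uses `V.graph.sizevars.statically_known_multiple_of` to handle symbolic offsets):

```python
ALIGNMENT = 16  # bytes

def aligned_tensor_arg_indices(args):
    # An argument index lands in AttrsDescriptor.divisible_by_16 only when
    # its storage offset is a static multiple of 16 bytes and the underlying
    # buffer is not flagged as unaligned.
    return tuple(
        i
        for i, arg in enumerate(args)
        if isinstance(arg, TensorArg)
        and (arg.offset * arg.dtype.itemsize) % ALIGNMENT == 0
        and not is_unaligned_buffer(arg)
    )
```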
torch/_inductor/scheduler.py (13 changes: 2 additions & 11 deletions)

```diff
@@ -2451,18 +2451,9 @@ def codegen(self):
 
         self.flush()
 
-    def is_unaligned_buffer(self, buf_name):
-        if buf_name in V.graph.graph_inputs:
-            return not config.assume_aligned_inputs
-        if buf_name in V.graph.constants:
-            # all constants are assumed to be aligned
-            return False
+    def get_buffer_layout(self, buf_name: str) -> ir.Layout:
         node = self.name_to_node[buf_name]
-        layout = node.node.get_layout()
-        if isinstance(layout, ir.NonOwningLayout):
-            return not layout.maybe_guard_aligned()
-        else:
-            return False
+        return node.node.get_layout()
 
 
 class BaseScheduling:
```
torch/_inductor/select_algorithm.py (13 changes: 9 additions & 4 deletions)

```diff
@@ -94,7 +94,7 @@ def __init__(
         grid_fn,
         meta,
         call_sizes,
-        use_jit=True,
+        use_jit=False,
         prefix_args=0,
         suffix_args=0,
         epilogue_fn=identity,
@@ -150,8 +150,8 @@ def jit_lines(self):
         argdefs, _, signature = self.args.python_argdefs()
         triton_meta = {
             "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": self.output_node.get_device().index,
+            "device_type": self.output_node.get_device().type,
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]
@@ -498,7 +498,7 @@ def generate(
         ), TritonTemplateKernel(
             kernel_name=kernel_name,
             output_node=fake_out,
-            use_jit=True,
+            use_jit=False,
             **kernel_options,
         ) as kernel:
             try:
@@ -684,6 +684,10 @@ def benchmark(self, *args, out):
         assert self.bmreq is not None
         return self.bmreq.benchmark(*args, output_tensor=out)
 
+    def precompile(self):
+        assert self.bmreq is not None
+        self.bmreq.precompile()
+
     def __str__(self):
         return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"
 
@@ -821,6 +825,7 @@ def __call__(
 
         # TODO(nmacchioni): remove once CI tests are fixed
         choices = [choice for choice in choices if choice is not None]
+
         if len(choices) == 0:
             raise RuntimeError(
                 "No choices to select, please consider adding ATEN into max_autotune_gemm_backends "
```
