From d1958329c9c531d33b4933e9bfcdbc4b15ab4ed0 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 17 Oct 2024 19:29:53 +0200 Subject: [PATCH 1/8] [NFC] Add `#include <string>` into `TritonToTritonGPUPass.h` (#4943) This makes the dependence more obvious since `const std::string &target` is used in this file. --- .../triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h index d3da1394e4..78917fdfdd 100644 --- a/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h +++ b/include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h @@ -3,6 +3,7 @@ #include #include +#include <string> namespace mlir { From 188370325395c79a4ba3de0bc47e39a19fc83224 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 17 Oct 2024 21:07:59 +0200 Subject: [PATCH 2/8] [NFC] Make some tests platform independent (#4946) Minor changes that reuse Python's capabilities for writing platform-independent code. 
Signed-off-by: Anatoly Myachev --- python/test/unit/language/test_compile_errors.py | 3 ++- python/test/unit/language/test_subprocess.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 8756ac24e1..12c3997ec7 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -150,7 +150,8 @@ def kernel(): try: inner = e.value.__cause__ outer = e.value - assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py" + assert f"{os.sep}core.py" in '\n'.join(traceback.format_tb( + inner.__traceback__)), "error should point inside core.py" assert "at 2:4:" in str(outer), "error should point to expand_dims call" assert "" not in str(outer) diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 2ad97e8a68..193895757d 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -49,7 +49,7 @@ def test_print(func_type: str, data_type: str, device: str): assert proc.stderr == b'' return - outs = [line for line in proc.stdout.decode("UTF-8").split("\n") if line] + outs = [line for line in proc.stdout.decode("UTF-8").splitlines() if line] # The total number of elements in the 1-D tensor to print. N = 128 From 692143cd869f2fa1501edb1db76fb452b85ac914 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 17 Oct 2024 21:09:30 +0100 Subject: [PATCH 3/8] [AMD] Add a tt.pointer_range_32 specialization (#4910) This is a PR adding an attribute in the HIP backend to test for a tensor storage to be within 2GB. This will enable support of buffer operations. 
--- python/test/unit/runtime/test_cache.py | 27 +++++++++++++ python/triton/backends/compiler.py | 52 +++++++++++++++++++------- python/triton/runtime/jit.py | 4 ++ third_party/amd/backend/compiler.py | 47 ++++++++++++++++++++++- 4 files changed, 115 insertions(+), 15 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 48e4eeebd3..d83cb67c3e 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -9,6 +9,7 @@ import triton import triton.language as tl from triton.runtime.jit import JITFunction +from triton._internal_testing import is_hip @triton.jit @@ -572,3 +573,29 @@ def compiled_hook(*args, **kwargs): assert specialization_data is not None and specialization_data_compiled == specialization_data assert is_warmup is True assert key in kernel_add.cache[getattr(torch, device).current_device()] + + +@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip()) +def test_within_2gb(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a): + tl.load(a) + + # This is the attribute we want to test + pointer_range_32 = None + + def cache_hook(*args, **kwargs): + nonlocal pointer_range_32 + pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32 + + JITFunction.cache_hook = cache_hook + # In warmup we assume that the pointer range is 32 bits + kernel_add.warmup(torch.float32, grid=(1, )) + assert pointer_range_32 == [0] + # Torch tensor > 2GB + kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) + assert len(pointer_range_32) == 0 + # Torch tensor <= 2GB + kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) + assert pointer_range_32 == [0] diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 037cd1b597..f2ba8eac80 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -8,7 +8,21 @@ from typing 
import Dict, List, Tuple, Union from types import ModuleType +# Table that associates strings to AttrsDescriptor (sub)classes. +# In this way we can dynamically select the correct class +# constructor +_descriptor_table = {} + +def register_descriptor(cls): + """ + Register a descriptor into the descriptor table + """ + _descriptor_table[cls.__name__] = cls + return cls + + +@register_descriptor class AttrsDescriptor: """ This class handles compile-time properties for specific function parameters. @@ -135,18 +149,28 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() def to_dict(self): - return self.arg_properties + """ + Store the fields of this class in a serializable dictionary + """ + # We need to only store the `arg_properties` field. To initialize the + # other fields we rely on the class type. We store it as a string in + # the dictionary so that we can use it to invoke the appropriate + # (sub)class constructor in the `from_dict` method. + return {"arg_properties": self.arg_properties, "cls": type(self).__name__} @staticmethod def from_dict(data): - attrsDescriptor = AttrsDescriptor() - for prop_name, param_ids in data.items(): - attrsDescriptor.arg_properties[prop_name] = param_ids - attrsDescriptor._init_slots() - return attrsDescriptor - - @staticmethod - def from_hints(hints: List[Tuple[int, int]]): + """ + Create the object from a serializable dictionary + """ + attrs_descriptor = _descriptor_table[data["cls"]]() + for prop_name, param_ids in data["arg_properties"].items(): + attrs_descriptor.arg_properties[prop_name] = param_ids + attrs_descriptor._init_slots() + return attrs_descriptor + + @classmethod + def from_hints(cls, hints: List[Tuple[int, int]]): """ Create the class from a set of hints that are passed in. 
@@ -156,11 +180,11 @@ def from_hints(hints: List[Tuple[int, int]]): then we insert `param_index` into the correct list (e.g., in `arg_properties[prop0]`) """ - attrsDescriptor = AttrsDescriptor() - for prop_name, prop_val in attrsDescriptor.property_values.items(): - attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] - attrsDescriptor._init_slots() - return attrsDescriptor + attrs_descriptor = cls() + for prop_name, prop_val in attrs_descriptor.property_values.items(): + attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] + attrs_descriptor._init_slots() + return attrs_descriptor @staticmethod def is_divisible_by_16(x): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 0842849ad9..45178a40bb 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -879,6 +879,10 @@ def __init__(self, dtype): def data_ptr(): return 0 # optimistically assumes multiple of 16 + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + class TensorWrapper: diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 640ccead59..66e05ce034 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -1,4 +1,4 @@ -from triton.backends.compiler import BaseBackend, GPUTarget +from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor from triton._C.libtriton import ir, passes, llvm, amd from dataclasses import dataclass from typing import Any, Dict, Tuple @@ -72,6 +72,44 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() +@register_descriptor +class HIPAttrsDescriptor(AttrsDescriptor): + # This property asserts if the underlying storage area of a given pointer + # can be represented as a 32 bit integer. 
When this is true, we can be + # sure that all indices into the tensor behind that pointer can use 32-bit + # indexing. That opens the door for the AMD backend to use buffer load/store + # intrinsics, which require this property. Buffer load/store intrinsics + # give direct out-of-bound support and simplify index calculation for + # lower register pressure. + __slots__ = ("pointer_range_32") + + def _add_backend_properties(self, params=None, values=None): + self.property_values["tt.pointer_range"] = 32 + if params is None or values is None: + return + + self.arg_properties["tt.pointer_range"] = [ + param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg) + and not param.do_not_specialize and not param.do_not_specialize_on_alignment + ] + + @staticmethod + def is_within2gb(arg): + if hasattr(arg, "ptr_range"): + return arg.ptr_range() <= 2**31 - 1 + if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"): + # Please note that 2**31-1 is the max int32 positive limit + return arg.untyped_storage().size() <= 2**31 - 1 + return False + + @staticmethod + def get_property_key(val, align): + generic_key = AttrsDescriptor.get_property_key(val, align) + hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N" + key = (generic_key + hip_key).replace("N", "") + return key if key else "N" + + class HIPBackend(BaseBackend): @staticmethod @@ -118,6 +156,13 @@ def get_module_map(self) -> Dict[str, ModuleType]: def load_dialects(self, ctx): amd.load_dialects(ctx) + def get_attrs_descriptor(self, params, args): + return HIPAttrsDescriptor(params, args) + + @staticmethod + def compute_spec_key(arg, align): + return HIPAttrsDescriptor.get_property_key(arg, align) + @staticmethod def path_to_rocm_lld(): # Check env path for ld.lld From 538c237078926a6dd1b4ebb9140d538e1c958d29 Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Fri, 18 Oct 2024 15:18:36 +0200 Subject: [PATCH 4/8] [TEST] Use device fixture for unit tests (#4948) 
--- python/test/unit/language/test_pipeliner.py | 2 +- python/test/unit/runtime/test_cache.py | 2 +- python/test/unit/test_debug.py | 25 ++++++++++----------- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py index e01ed2fc27..cc02cf0e33 100644 --- a/python/test/unit/language/test_pipeliner.py +++ b/python/test/unit/language/test_pipeliner.py @@ -203,7 +203,7 @@ def kernel_up(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows width = ROW_COUNT depth = 78 - x = torch.zeros(width, depth, device='cuda') + x = torch.zeros(width, depth, device=device) y0 = torch.rand_like(x) n_rows, n_cols = x.shape BLOCK_SIZE = triton.next_power_of_2(n_cols) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index d83cb67c3e..a45cb3f888 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -432,7 +432,7 @@ def kernel(tmp): tl.device_assert(tl.load(tmp) == 1, "tmp == 1") device = getattr(torch, device).current_device() - tmp = torch.tensor([1], dtype=torch.int32, device="cuda") + tmp = torch.tensor([1], dtype=torch.int32, device=device) assert len(kernel.cache[device]) == 0 kernel[(1, )](tmp, debug=False) assert len(kernel.cache[device]) == 1 diff --git a/python/test/unit/test_debug.py b/python/test/unit/test_debug.py index d370623966..05bf1fe494 100644 --- a/python/test/unit/test_debug.py +++ b/python/test/unit/test_debug.py @@ -10,7 +10,7 @@ for env_var in [True, False]\ ]) @pytest.mark.forked -def test_device_assert(cond, opt_flag, env_var, device="cuda"): +def test_device_assert(cond, opt_flag, env_var, device): os.environ['TRITON_DEBUG'] = str(int(env_var)) torch.zeros([1], dtype=torch.int32, device=device) @@ -21,11 +21,11 @@ def _kernel(COND: tl.constexpr): if not cond and (opt_flag or env_var): with pytest.raises(RuntimeError): _kernel[(1, )](cond, debug=opt_flag) - 
torch.cuda.synchronize() + getattr(torch, device).synchronize() return _kernel[(1, )](cond, debug=opt_flag) - torch.cuda.synchronize() + getattr(torch, device).synchronize() @pytest.mark.parametrize("cond", [False, True]) @@ -43,19 +43,18 @@ def _kernel(COND: tl.constexpr): _kernel[(1, )](cond) -def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func): - device = "cuda" +def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func, device): x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device) y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device) z = torch.empty_like(x) if should_overflow and debug: with pytest.raises(RuntimeError) as exc_info: tri_func[(1, )](x, y, z, debug=debug) - torch.cuda.synchronize() + getattr(torch, device).synchronize() assert "device-side assert" in str(exc_info.value) else: tri_func[(1, )](x, y, z, debug=debug) - torch.cuda.synchronize() + getattr(torch, device).synchronize() assert int(z) == int(ref_func(x, y)) @@ -74,13 +73,13 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref (2**15 - 1, 1, 'int16', 'int16', True, True), ]) @pytest.mark.forked -def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow): +def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): @triton.jit def _kernel_add(X, Y, Z): tl.store(Z, tl.load(X) + tl.load(Y)) - _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y) + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y, device) # mul overflow @@ -95,13 +94,13 @@ def _kernel_add(X, Y, Z): (-2**30, 2, 'int32', 'int32', True, False), ]) @pytest.mark.forked -def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow): +def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): @triton.jit def 
_kernel_mul(X, Y, Z): tl.store(Z, tl.load(X) * tl.load(Y)) - _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y) + _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y, device) # sub overflow @@ -115,10 +114,10 @@ def _kernel_mul(X, Y, Z): (-2**31, -1, 'int32', 'int32', True, False), ]) @pytest.mark.forked -def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow): +def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device): @triton.jit def _kernel_sub(X, Y, Z): tl.store(Z, tl.load(X) - tl.load(Y)) - _test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y) + _test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y, device) From bce48c82f3a8586e2f35b51d1afbcfe72a3fbbcd Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Fri, 18 Oct 2024 18:05:15 +0200 Subject: [PATCH 5/8] [NFC] Make cuda links parameterizable by `system` parameter (#4945) This will make it easier to support other platforms downstream. I hope that such code should not complicate the support of Triton itself. 
--------- Signed-off-by: Anatoly Myachev --- python/setup.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/python/setup.py b/python/setup.py index 5810b18075..714668462f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -284,7 +284,8 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func): arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] except KeyError: arch = platform.machine() - url = url_func(arch, version) + supported = {"Linux": "linux", "Darwin": "linux"} + url = url_func(supported[system], arch, version) tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux" @@ -500,11 +501,11 @@ def get_platform_dependent_src_path(subdir): download_and_copy( name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH", - version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda arch, version: + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version: ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/linux-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2" + f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2" if int(version_major) >= 12 and int(version_minor1) >= 5 else - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2") + f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") (*version.split('.')))) download_and_copy( name="cuobjdump", @@ -512,8 +513,8 @@ def get_platform_dependent_src_path(subdir): dst_path="bin/cuobjdump", 
variable="TRITON_CUOBJDUMP_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"], - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", ) download_and_copy( name="nvdisasm", @@ -521,40 +522,41 @@ def get_platform_dependent_src_path(subdir): dst_path="bin/nvdisasm", variable="TRITON_NVDISASM_PATH", version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"], - url_func=lambda arch, version: - f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", + url_func=lambda system, arch, version: + f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", ) download_and_copy( name="cudacrt", src_path=get_platform_dependent_src_path("include"), dst_path="include", - variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda arch, version: + variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda system, arch, version: ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-crt-dev_linux-{arch}/{version}/download/noarch/cuda-crt-dev_linux-{arch}-{version}-0.tar.bz2" + f"https://anaconda.org/nvidia/cuda-crt-dev_{system}-{arch}/{version}/download/noarch/cuda-crt-dev_{system}-{arch}-{version}-0.tar.bz2" if int(version_major) >= 12 and int(version_minor1) >= 5 else - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2") + f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") (*version.split('.')))) download_and_copy( name="cudart", src_path=get_platform_dependent_src_path("include"), dst_path="include", - 
variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda arch, version: + variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda system, arch, version: ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-cudart-dev_linux-{arch}/{version}/download/noarch/cuda-cudart-dev_linux-{arch}-{version}-0.tar.bz2" + f"https://anaconda.org/nvidia/cuda-cudart-dev_{system}-{arch}/{version}/download/noarch/cuda-cudart-dev_{system}-{arch}-{version}-0.tar.bz2" if int(version_major) >= 12 and int(version_minor1) >= 5 else - f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/linux-{arch}/cuda-cudart-dev-{version}-0.tar.bz2" + f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/{system}-{arch}/cuda-cudart-dev-{version}-0.tar.bz2" )(*version.split('.')))) download_and_copy( name="cupti", src_path=get_platform_dependent_src_path("include"), dst_path="include", - variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version: + variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], + url_func=lambda system, arch, version: ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2" + f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2" if int(version_major) >= 12 and int(version_minor1) >= 5 else - f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2") + f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2") (*version.split('.')))) download_and_copy( name="cupti", src_path=get_platform_dependent_src_path("lib"), dst_path="lib/cupti", - variable="TRITON_CUPTI_LIB_PATH", 
version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version: + variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda system, arch, version: ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2" + f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2" if int(version_major) >= 12 and int(version_minor1) >= 5 else - f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2") + f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2") (*version.split('.')))) backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()] From d4e5a7873323107ec16ff03c0727256d060908eb Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Fri, 18 Oct 2024 18:32:35 +0200 Subject: [PATCH 6/8] [Triton] Use `UnitAttr` in `tt.reshape` definition (#4947) Make `allow_reorder` and `efficient_layout` `UnitAttr` for a cleaner interface. This way, the operation exposes a `bool getEfficientLayout()` member to check for that attribute and a constructor receiving `bool` arguments for both of these attributes (defaulted to `false`). The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. - [X] I am not making a trivial change, such as fixing a typo in a comment. - [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). 
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [X] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [X] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) Signed-off-by: victor-eds --- include/triton/Dialect/Triton/IR/TritonOps.td | 9 ++------- lib/Dialect/Triton/IR/Ops.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 ++--- .../TritonGPU/Transforms/OptimizeThreadLocality.cpp | 4 ++-- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 3 +-- test/Conversion/tritongpu_to_llvm.mlir | 2 +- test/Triton/combine.mlir | 12 ++++++------ test/Triton/invalid.mlir | 2 +- test/Triton/ops.mlir | 10 ++++++++-- test/TritonGPU/canonicalize.mlir | 8 ++++---- test/TritonGPU/combine.mlir | 10 +++++----- test/TritonGPU/loop-pipeline-hip.mlir | 2 +- test/TritonGPU/loop-pipeline.mlir | 2 +- test/TritonGPU/optimize-locality.mlir | 12 ++++++------ 14 files changed, 41 insertions(+), 42 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 66946c20cc..a8358d968c 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -460,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure, If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason. The compiler is still free to change it for better performance. 
}]; - let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr:$efficient_layout); + let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout); let results = (outs TT_Tensor:$result); - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; + let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)"; let hasCanonicalizeMethod = 1; let hasFolder = 1; let hasVerifier = 1; - let builders = [ - OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder), - [{ - build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr()); - }]>]; } def TT_BroadcastOp : TT_Op<"broadcast", [Pure, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 1240caebe2..c2c057f42c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -678,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op, } LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { - if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + if (!op.getAllowReorder() || op.getEfficientLayout()) return failure(); return canonicalizeViewOrBroadcast(op, rewriter); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 8179c1cda1..70eaf5d3b6 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2748,7 +2748,7 @@ struct CanonicalizeConvertFromReshape return failure(); if (isExpensiveView(convert.getSrc().getType(), op.getType())) return failure(); - if (!op.getAllowReorder() || op.getEfficientLayout().has_value()) + if (!op.getAllowReorder() || op.getEfficientLayout()) return failure(); rewriter.replaceOpWithNewOp( @@ -2868,8 +2868,7 @@ struct CanonicalizeConvertFromConvert // cvt(reshape) -> reshape if (auto reshape = dyn_cast(arg)) { - if (!reshape.getAllowReorder() 
|| - reshape.getEfficientLayout().has_value() || + if (!reshape.getAllowReorder() || reshape.getEfficientLayout() || isExpensiveView(reshape.getSrc().getType(), op.getType())) return failure(); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp index dbb41c4098..b0e5095ac8 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -314,8 +314,8 @@ class TritonGPUOptimizeThreadLocalityPass IRMapping mapping; for (auto operand : reduce.getOperands()) { auto viewOp = builder.create( - reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true); - viewOp.setEfficientLayout(true); + reduce.getLoc(), viewOpTensorType, operand, + /*allowReorder=*/true, /*efficientLayout=*/true); mapping.map(operand, viewOp); } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 91acba38bf..4ef9d1cd1d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -556,8 +556,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) { RankedTensorType newDstType = RankedTensorType::get(reshapeDstType.getShape(), reshapeDstType.getElementType(), targetEncoding); - return reshape.getAllowReorder() && - !reshape.getEfficientLayout().has_value() && + return reshape.getAllowReorder() && !reshape.getEfficientLayout() && !triton::gpu::isExpensiveView(reshape.getSrc().getType(), newDstType); } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 76e85e3c7a..e2f43f4ba6 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.undef // CHECK: %[[T0:.*]] = llvm.extractvalue // CHECK: %[[T1:.*]] = 
llvm.extractvalue - %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> + %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2> // CHECK: llvm.mlir.undef // CHECK: llvm.insertvalue %[[T0]] // CHECK: llvm.insertvalue %[[T1]] diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 2197b18738..41a3ba15a8 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -292,15 +292,15 @@ tt.func @test_canonicalize_expand_dims(%arg0: tensor, %arg1: tensor<1xf32>) // CHECK-LABEL: @test_canonicalize_view tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) { - %view0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<2x4xf32> - // CHECK: %{{.*}} = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<4x2xf32> - %view1 = tt.reshape %view0 {allow_reorder = true} : tensor<2x4xf32> -> tensor<4x2xf32> + %view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32> + // CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32> + %view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32> %splat = tt.splat %arg1 : tensor -> tensor<8xf32> // CHECK: %{{.*}} = tt.splat %arg1 : tensor -> tensor<2x2x2xf32> - %view2 = tt.reshape %splat {allow_reorder = true} : tensor<8xf32> -> tensor<2x2x2xf32> + %view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32> - %view3 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<8xf32> + %view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32> // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32> %add = arith.addf %view3, %arg0 : tensor<8xf32> @@ -329,7 +329,7 @@ tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x %a = arith.constant dense<1.0> : tensor<1x128xf32> // CHECK-DAG: %{{.*}} = 
arith.constant dense<1.{{.*}}> : tensor<16x8xf32> - %b = tt.reshape %a {allow_reorder = true} : tensor<1x128xf32> -> tensor<16x8xf32> + %b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32> // CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32> %c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32> diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 6eee82fec7..a3826dded0 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) { tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) { // expected-error @+1 {{number of src and dst elements of reshape must be the same}} - %a = tt.reshape %arg0 {allow_reorder = false} : tensor<32x128xf16> -> tensor<64x32xf16> + %a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16> tt.return } diff --git a/test/Triton/ops.mlir b/test/Triton/ops.mlir index c5d7ec8b65..c3b92b7ee4 100644 --- a/test/Triton/ops.mlir +++ b/test/Triton/ops.mlir @@ -225,8 +225,14 @@ tt.func @inline_asm_scalar(%0: i32) { // CHECK-LABEL: reshape tt.func @reshape(%0: tensor<512xi32>) { - // CHECK: tt.reshape %{{.+}} {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32> - %1 = tt.reshape %0 {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} : tensor<512xi32> -> tensor<16x32xi32> + %1 = tt.reshape %0 : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32> + %2 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %3 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + // CHECK: tt.reshape %{{.+}} efficient_layout : tensor<512xi32> -> tensor<16x32xi32> + %4 = tt.reshape %0 efficient_layout : tensor<512xi32> -> 
tensor<16x32xi32> tt.return } diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index ecee359cb1..9422bb0f85 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: @test_canonicalize_convert_view // CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32 // CHECK-NOT: triton_gpu.convert_layout -// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] {allow_reorder = true} +// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder // CHECK: tt.return %[[V]] #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -13,7 +13,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { %c = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2> - %r = tt.reshape %c {allow_reorder = true} : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> + %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } } // end module @@ -25,7 +25,7 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> // CHECK-LABEL: @test_canonicalize_convert_expensive_view // CHECK-SAME: (%[[ARG:.+]]: tensor<256x16xf32 // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[ARG]] -// CHECK: %[[V:.+]] = tt.reshape %[[C]] {allow_reorder = true} +// CHECK: %[[V:.+]] = tt.reshape %[[C]] allow_reorder // CHECK: tt.return %[[V]] #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order 
= [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -33,7 +33,7 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.target" = "cuda:80"} { tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blocked0>) -> tensor<4096xf32, #blocked1> { %c = triton_gpu.convert_layout %arg0 : tensor<256x16xf32, #blocked0> -> tensor<256x16xf32, #blocked2> - %r = tt.reshape %c {allow_reorder = true} : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> + %r = tt.reshape %c allow_reorder : tensor<256x16xf32, #blocked2> -> tensor<4096xf32, #blocked1> tt.return %r : tensor<4096xf32, #blocked1> } } // end module diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 601b0cc44b..78c6f68bf6 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2097,7 +2097,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.func public @reshape_propagate(%arg0: tensor<16x2xf32, #blocked>) -> tensor<32xf32, #blocked3> { // CHECK-NOT: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = triton_gpu.convert_layout %b : tensor<32xf32, #blocked2> -> tensor<32xf32, #blocked3> tt.return %c : tensor<32xf32, #blocked3> } @@ -2116,7 +2116,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK: tt.reshape // CHECK: triton_gpu.convert_layout %a = triton_gpu.convert_layout %arg0 : tensor<16x2xf32, 
#blocked> -> tensor<16x2xf32, #blocked1> - %b = tt.reshape %a {allow_reorder = false} : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> + %b = tt.reshape %a : tensor<16x2xf32, #blocked1> -> tensor<32xf32, #blocked2> tt.return %b : tensor<32xf32, #blocked2> } } @@ -2133,7 +2133,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-NOT: triton_gpu.convert_layout // CHECK: arith.truncf // CHECK: triton_gpu.convert_layout - %a = tt.reshape %arg0 {allow_reorder = true, efficient_layout} : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> + %a = tt.reshape %arg0 allow_reorder efficient_layout : tensor<16x2xf32, #blocked> -> tensor<32xf32, #blocked1> %b = triton_gpu.convert_layout %a : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked2> %c = arith.truncf %b : tensor<32xf32, #blocked2> to tensor<32xf16, #blocked2> tt.return %c : tensor<32xf16, #blocked2> @@ -2536,9 +2536,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32, #triton_gpu.slice<{dim = 0, parent = #triton_gpu.slice<{dim = 2, parent = #blocked2}>}>> -> tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> %2 = tt.expand_dims %1 {axis = 2 : i32} : tensor<1x2xi32, #triton_gpu.slice<{dim = 2, parent = #blocked2}>> -> tensor<1x2x1xi32, #blocked2> %3 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<1x2x128xi32, #blocked2> - %4 = tt.reshape %3 {allow_reorder = false} : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %4 = tt.reshape %3 : tensor<1x2x128xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %5 = tt.broadcast %2 : tensor<1x2x1xi32, #blocked2> -> tensor<2x2x64xi32, #blocked2> - %6 = tt.reshape %5 {allow_reorder = false} : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> + %6 = tt.reshape %5 : tensor<2x2x64xi32, #blocked2> -> tensor<1x256xi32, #blocked1> %7 = arith.cmpi ne, %4, %cst : tensor<1x256xi32, #blocked1> %8 
= arith.select %7, %6, %cst : tensor<1x256xi1, #blocked1>, tensor<1x256xi32, #blocked1> %9 = triton_gpu.convert_layout %8 : tensor<1x256xi32, #blocked1> -> tensor<1x256xi32, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 28c815febb..7fa7812c5a 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -221,7 +221,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %33:3 = scf.for %arg7 = %c0_i32 to %c5_i32 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %arg0, %arg10 = %arg1) -> (tensor<64x1x32xf32, #blocked>, tensor<1x512x!tt.ptr, #blocked2>, tensor<64x8x32x!tt.ptr, #blocked1>) : i32 { %39 = tt.load %arg9 : tensor<1x512x!tt.ptr, #blocked2> %40 = tt.load %arg10 : tensor<64x8x32x!tt.ptr, #blocked1> - %41 = tt.reshape %39 {allow_reorder = true} : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> + %41 = tt.reshape %39 allow_reorder : tensor<1x512xf32, #blocked2> -> tensor<64x1x8xf32, #blocked5> %43 = triton_gpu.convert_layout %41 : tensor<64x1x8xf32, #blocked5> -> tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %44 = triton_gpu.convert_layout %40 : tensor<64x8x32xf32, #blocked1> -> tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> %45 = tt.dot %43, %44, %arg8 : tensor<64x1x8xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x8x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x1x32xf32, #blocked> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index d22224b3e1..3d215a635d 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -1338,7 +1338,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked> %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> 
tensor<64x256x2xf16, #blocked3> %86 = tt.trans %85 {order = array} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4> - %87 = tt.reshape %86 {allow_reorder = false} : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> + %87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5> %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma> diff --git a/test/TritonGPU/optimize-locality.mlir b/test/TritonGPU/optimize-locality.mlir index 5073f997d4..5442998671 100644 --- a/test/TritonGPU/optimize-locality.mlir +++ b/test/TritonGPU/optimize-locality.mlir @@ -4,7 +4,7 @@ // CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0.000000e+00> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.addf // CHECK: arith.addf %[[FOR_ARG]], %[[REDUCE]] @@ -207,7 +207,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[INIT_ARG:.*]] = arith.constant dense<0xFF800000> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[INIT_ARG]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, 
efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.maximumf // CHECK: arith.maximumf %[[FOR_ARG]], %[[REDUCE]] @@ -314,7 +314,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[CST:.*]] = arith.constant dense<0x7F800000> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.minimumf // CHECK: arith.minimumf %[[FOR_ARG]], %[[REDUCE]] @@ -421,7 +421,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> // CHECK: %[[LOOP_OUTPUT:.*]] = scf.for {{.*}} iter_args(%[[FOR_ARG:.*]] = %[[CST]]) -> {{.*}} // CHECK: %[[LOAD:.*]] = tt.load -// CHECK: tt.reshape %[[LOAD]] {allow_reorder = true, efficient_layout} : {{.*}} -> tensor<{{32x32x4xf32.*}} +// CHECK: tt.reshape %[[LOAD]] allow_reorder efficient_layout : {{.*}} -> tensor<{{32x32x4xf32.*}} // CHECK-NEXT: %[[REDUCE:.*]] = "tt.reduce"({{%.*}}) <{axis = 2 : i32}> // CHECK: arith.mulf // CHECK: arith.mulf %[[FOR_ARG]], %[[REDUCE]] @@ -579,14 +579,14 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK-DAG: #[[$BLOCK1:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> // CHECK-DAG: #[[$BLOCK2:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}> // CHECK-LABEL: 
optimize_view_layout -// CHECK: %[[R:.+]] = tt.reshape {{.*}} {allow_reorder = true, efficient_layout} : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> +// CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<8x128xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK2]]> // CHECK: %[[C:.+]] = triton_gpu.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK2]]> -> tensor<64x16xf32, #[[$BLOCK1]]> // CHECK: "tt.reduce"(%[[C]]) #blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [2, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @optimize_view_layout(%arg0: tensor<8x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> { - %0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> + %0 = tt.reshape %arg0 allow_reorder : tensor<8x128xf32, #blocked> -> tensor<64x16xf32, #blocked1> %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = arith.maximumf %arg1, %arg2 : f32 From 76ed94df1924b2262be9b37d778b6e0ccccb1180 Mon Sep 17 00:00:00 2001 From: SJW <48454132+sjw36@users.noreply.github.com> Date: Fri, 18 Oct 2024 16:48:14 -0500 Subject: [PATCH 7/8] [AMD] Remove stream pipeliner v1 (#4845) We have flipped stream pipeliner v2 on as default for quite some time. All known issues have been fixed. So now remove old v1 pipeliner. Note that this changes how `num_stages` is handled: previously we used to enable pipelining if `num_stages` is `0`, which really is not a good behavior. Now switched to follow common practice where `0`/`1` won't trigger pipelining anymore; need `2` or more to trigger. 
Given downstream users might be using `0` in the codebase, right now we `assert` to give developers a clear indication of the switch of behavior instead of silently dropping the perf. The `assert` is expected to be dropped sometime down the line. --------- Co-authored-by: Lei Zhang --- bin/RegisterTritonDialects.h | 1 - test/TritonGPU/amd/amd-loop-pipeline-v1.mlir | 31 - third_party/amd/backend/compiler.py | 20 +- .../include/TritonAMDGPUTransforms/Passes.h | 2 - .../include/TritonAMDGPUTransforms/Passes.td | 13 - .../lib/TritonAMDGPUTransforms/CMakeLists.txt | 1 - .../TritonAMDGPUTransforms/StreamPipeline.cpp | 920 ------------------ third_party/amd/python/triton_amd.cc | 2 - 8 files changed, 8 insertions(+), 982 deletions(-) delete mode 100644 test/TritonGPU/amd/amd-loop-pipeline-v1.mlir delete mode 100644 third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 28dc3befd0..25a891c2f7 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -60,7 +60,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUReorderInstructions(); - mlir::registerTritonAMDGPUStreamPipeline(); mlir::registerTritonAMDGPUStreamPipelineV2(); mlir::registerTritonAMDGPUCanonicalizePointers(); diff --git a/test/TritonGPU/amd/amd-loop-pipeline-v1.mlir b/test/TritonGPU/amd/amd-loop-pipeline-v1.mlir deleted file mode 100644 index 45eae93880..0000000000 --- a/test/TritonGPU/amd/amd-loop-pipeline-v1.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline | FileCheck %s - -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> 
-#loc = loc("/data/users/dberard/triton-env/scripts/matmul.py":6:0) -#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> -module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @use_dep_args - tt.func @use_dep_args(%a_ptrs: tensor<64x32x!tt.ptr, #blocked>, %b_ptrs: tensor<32x64x!tt.ptr, #blocked1>, %loop_range: i32) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1>) { - %cst = arith.constant dense<32> : tensor<64x32xi32, #blocked> - %cst2 = arith.constant dense<2048> : tensor<32x64xi32, #blocked1> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> - %c0_i32 = arith.constant 0 : i32 - %c8_i32 = arith.constant 8 : i32 - %c32_i32 = arith.constant 32 : i32 - // CHECK: tt.load - // CHECK: [[FOR_OUT:%[a-z0-9_]+]]:{{[0-9]+}} = scf.for - %for:3 = scf.for %arg6 = %c0_i32 to %loop_range step %c32_i32 iter_args(%arg7 = %cst_0, %arg8 = %a_ptrs, %arg9 = %b_ptrs) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1>) : i32 { - %63 = tt.load %arg8 : tensor<64x32x!tt.ptr, #blocked> - %64 = tt.load %arg9 : tensor<32x64x!tt.ptr, #blocked1> - %65 = triton_gpu.convert_layout %63 : tensor<64x32xbf16, #blocked> -> tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> - %66 = triton_gpu.convert_layout %64 : tensor<32x64xbf16, #blocked1> -> tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> - %67 = tt.dot %65, %66, %arg7 : tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma> - %68 = tt.addptr %arg8, %cst : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> - %69 = tt.addptr 
%arg9, %cst2 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi32, #blocked1> - scf.yield %67, %68, %69 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1> - } - // CHECK: tt.return {{[^,]+}}, [[FOR_OUT]]#3, [[FOR_OUT]]#4 - tt.return %for#0, %for#1, %for#2 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1> - } -} diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 66e05ce034..a53a06dd42 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -29,7 +29,7 @@ def min_dot_size(target: GPUTarget): class HIPOptions: num_warps: int = 4 waves_per_eu: int = 1 - num_stages: int = 0 + num_stages: int = 2 num_ctas: int = 1 extern_libs: dict = None cluster_dims: tuple = (1, 1, 1) @@ -215,23 +215,19 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) - use_new_pipeliner = os.getenv("TRITON_HIP_USE_NEW_STREAM_PIPELINE", "1") == "1" if amd.has_matrix_core_feature(options.arch): - if use_new_pipeliner: - # In the old pipeliner we only support num_stages = 0/1, which means something - # different than the NVIDIA side. In the new pipeliner we unify the num_stages - # interpretation. Default to use 2 stages if not explicitly set. - num_stages = options.num_stages if options.num_stages != 0 else 2 - amd.passes.ttgpuir.add_stream_pipelinev2(pm, num_stages) - else: - if options.num_stages == 0: - amd.passes.ttgpuir.add_stream_pipeline(pm) + assert options.num_stages != 0, ("Triton AMD backend pipeliner has been updated. " + "We used to trigger software pipelining with " + "num_stages == 0. 
Now it will not happen anymore; " + "please update to use num_stages == 2 for " + "equivalent behavior in the past.") + amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.insert_instruction_sched_hints(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_reduce_data_duplication(pm) - if use_new_pipeliner or options.num_stages != 0: + if amd.has_matrix_core_feature(options.arch): amd.passes.ttgpuir.add_reorder_instructions(pm) amd.passes.ttgpuir.add_canonicalize_pointers(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index 027e7f0cc5..841137887b 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -6,8 +6,6 @@ namespace mlir { -std::unique_ptr createTritonAMDGPUStreamPipelinePass(); - std::unique_ptr createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2); std::unique_ptr diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 7f4c15c3d2..d59935e796 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -3,19 +3,6 @@ include "mlir/Pass/PassBase.td" -def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> { - let summary = "pipeline"; - - let description = [{ - Pipeline global loads through registers to shared memory while computing on previous - tile - }]; - - let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()"; - - let dependentDialects = []; -} - def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> { let summary = "pipeline"; diff --git 
a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 129d7780c0..414e4a329f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_triton_library(TritonAMDGPUTransforms CanonicalizePointers.cpp OptimizeEpilogue.cpp ReorderInstructions.cpp - StreamPipeline.cpp StreamPipelineV2.cpp MfmaGroup.cpp diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp deleted file mode 100644 index 784ce52e1b..0000000000 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ /dev/null @@ -1,920 +0,0 @@ -#include "TritonAMDGPUTransforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "llvm/ADT/MapVector.h" - -//===----------------------------------------------------------------------===// -// This file implements stream software pipelining for loops. The implementation -// here is inspired by the pipeline pass in Triton and the rocMLIR pipeliner. -// -// We divide the loop body into the following phases: -// a. Pre-load operations: for instance, index computation. -// b. Load operations: loading from global memory to shared memory. -// c. Compute operations: for instance, Triton dot. -// d. Post-load operations: for instance, index computation. -// -// To pipeline the loop, we need to: -// - Find all the dependencies of the load operations. 
-// - Prologue: Hoist the pipelinable load operations and shared memory store -// for the ramp up stage -// - Pipelined Loop: Assemble the loop body minus last iteration -// - Prefetch next tile from global into regs (while computing from previous) -// - Non-load loop body -// - Store next tile into shared mem -// - Epilogue: Peeled non-load loop body for last iteration -// -//===----------------------------------------------------------------------===// - -using llvm::MapVector; -using namespace mlir; -namespace ttg = triton::gpu; - -#define GEN_PASS_CLASSES -#include "TritonAMDGPUTransforms/Passes.h.inc" - -namespace { - -class LoopPipeliner { - /// Cache of ForOp and YieldOp related to this pipeliner. - scf::ForOp forOp; - scf::YieldOp yieldOp; - - bool peelLastIter = true; - - /// The new pipelined ForOp. - scf::ForOp pplForOp; - - /// Loads to be pipelined - SetVector validLoads; - /// The value that each load will be mapped to (after layout conversion) - DenseMap convertMapping; - /// load => buffer - DenseMap loadsBuffer; - /// load => buffer type (with shared layout after swizzling) - DenseMap loadsBufferType; - - /// Iterator values - Value nextLoopCond; - - /// Yield values - SmallVector yieldValues; - - /// The number of stages in the pipeline is fixed to '2' for - /// analysis since there will be a current buffer stored in - /// shared mem and a next buffer stored in regs. 
- int numStages = 2; - - /// Arg indicies in in pplForOp - size_t depArgsBeginIdx; - DenseMap depArgsIdx; - - /// value (in loop) => value at stage N - DenseMap> valueMapping; - /// loop iter arg => value - DenseMap depArgsMapping; - - /// forOp value => pplForOp value - IRMapping curMapping; - /// forOp value => prefetch value - IRMapping nextMapping; - - /// Dependency ops by program order - SmallVector orderedDeps; - - SetVector currentDeps; - - /// block arguments that loads depend on - SetVector depArgs; - - /// operation => source operand defined stages - DenseMap> immediateOpStages; - - /// operations that loads depend on - SetVector depOps; - - /// Collect values that `v` depends on and are defined inside the loop - void collectValueDep(Value v, int stage, SetVector &deps, - SetVector &args); - - /// Collect all op dependencies - void collectDeps(SetVector &ops, - MapVector> &opDeps); - - void collectDepChain(Operation *op, SetVector &ops); - - /// Check if none of the for-ops has valid uses - LogicalResult checkOpUses(); - - /// Check if ops have dependencies that are not pipelinable - LogicalResult checkOpDeps(); - - void createBufferTypes(); - - void createOrderedDeps(); - - void createCurrentDeps(); - - /// Return the stage at which `v` is defined prior to `stage` - int getValueDefStage(Value v, int stage); - - /// Map `origin` to `newValue` at `stage` - void setValueMapping(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at `stage` according to the association between - /// yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue, int stage); - - /// Map `origin` to `newValue` at the next stage according to the association - /// between yieldOp and forOp - void setValueMappingYield(Value origin, Value newValue); - - /// Return the value mapped to `origin` at `stage`, if it exists. 
- Value lookupOrDefault(Value origin, int stage); - - Value getLoadMask(triton::LoadOp loadOp, Value mappedMask, Value loopCond, - OpBuilder &builder); - /// Collect all args of the new loop - SmallVector collectNewLoopArgs(); - - /// Clone the forOp and return the new forOp - scf::ForOp cloneForOp(ArrayRef newLoopArgs, OpBuilder &builder); - - void updateLoadMask(triton::LoadOp loadOp, Value newMask); - /// Prefetch the next iteration for `pplForOp` - void prefetchNextBuffer(OpBuilder &builder); - void cloneCurrentBody(OpBuilder &builder); - void storeNextBuffer(OpBuilder &builder); - - bool isLoadChain(Operation *op) const; - - /// Assemble `pplForOp`'s yield op - void finalizeYield(OpBuilder &builder); - -public: - LoopPipeliner(scf::ForOp forOp) : forOp(forOp) { - yieldOp = cast(forOp.getBody()->getTerminator()); - } - - /// Collect loads to pipeline. Return success if we can pipeline this loop - LogicalResult initialize(); - - // Update mapping from old forOp results to new pplForOp results - void setResultMapping(DenseMap &newResults); - - /// Emit pipelined loads (before loop body) - void emitPrologue(); - - /// emit pipelined loads (after loop body) - void emitEpilogue(DenseMap &newResults); - - /// create the new ForOp (add new args & insert prefetched ops) - scf::ForOp createNewForOp(); - - friend struct PipelinePass; -}; - -void LoopPipeliner::collectValueDep(Value v, int stage, - SetVector &deps, - SetVector &args) { - // Since we only need to peel the loop numStages-1 times, don't worry - // about depends that are too far away - if (stage < 0) - return; - - // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getRegion()) - return; - - if (Operation *op = v.getDefiningOp()) { - if (!deps.contains(op)) { - deps.insert(op); - for (Value opr : op->getOperands()) - collectValueDep(opr, stage, deps, args); - } - } else if (auto arg = dyn_cast(v)) { - if (arg.getArgNumber() > 0) { - args.insert(arg); - 
collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1, - deps, args); - } - } -} - -void LoopPipeliner::collectDeps( - SetVector &ops, - MapVector> &valueDeps) { - for (auto op : ops) { - for (Value v : op->getOperands()) { - SetVector deps; - SetVector args; - collectValueDep(v, numStages - 1, deps, args); - valueDeps[op] = deps; - } - } -} - -LogicalResult LoopPipeliner::checkOpUses() { - SetVector ops; - // We cannot use forOp.walk(...) here because we only want to visit the - // operations in the loop body block. Nested blocks are handled separately. - for (Operation &op : forOp) { - if (auto loadOp = dyn_cast(&op)) - ops.insert(&op); - } - - // Collect all ops' dependencies - MapVector> opDeps; - collectDeps(ops, opDeps); - - for (Operation *op : ops) { - auto loadOp = dyn_cast(op); - // Don't pipeline valid loads that depend on other valid loads - // (Because if a valid load depends on another valid load, this load needs - // to wait on the other load in the prologue, which is against the point - // of the pipeline pass) - bool isCandidate = true; - for (Operation *other : ops) - if (isa(other)) - if (opDeps[op].contains(other)) { - isCandidate = false; - break; - } - // We only pipeline loads that have one covert_layout (to dot_op) use - // TODO: lift this constraint in the future - if (isCandidate && loadOp.getResult().hasOneUse()) { - isCandidate = false; - Operation *use = *loadOp.getResult().getUsers().begin(); - - // Advance to the first conversion as long as the use resides in shared - // memory and it has a single use itself - while (use) { - if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse()) - break; - auto tensorType = - dyn_cast(use->getResult(0).getType()); - if (!tensorType || - !isa(tensorType.getEncoding())) - break; - use = *use->getResult(0).getUsers().begin(); - } - - // TODO: handle fp_to_fp conversions in between - if (auto convertLayout = llvm::dyn_cast(use)) - if (auto tensorType = - 
dyn_cast(convertLayout.getResult().getType())) - if (auto dotOpEnc = dyn_cast( - tensorType.getEncoding())) { - isCandidate = true; - convertMapping[loadOp] = convertLayout; - } - } else - isCandidate = false; - - if (isCandidate) - validLoads.insert(op); - } - - return validLoads.empty() ? failure() : success(); -} - -LogicalResult LoopPipeliner::checkOpDeps() { - /// arg => source operand defined stages - DenseMap> immediateArgStages; - SetVector nonImmediateDepArgs; - SetVector nonImmediateOps; - for (Operation *op : validLoads) { - for (Value v : op->getOperands()) { - SetVector deps; - SetVector args; - collectValueDep(v, numStages - 1, deps, args); - int defStage = getValueDefStage(v, numStages - 1); - if (defStage < 0) { - // assert(defStage >= 0 && - // "newLoopArgs has null args without a define op. Consider - // either " "rewrite the loop to reduce cross iteration - // dependencies or " "increase the num_stages value."); - return failure(); - } - bool immediate = args.size() > 0; - for (auto *dep : deps) { - depOps.insert(dep); - if (immediate) - immediateOpStages[dep].insert(defStage); - else - nonImmediateOps.insert(dep); - } - for (auto arg : args) { - depArgs.insert(arg); - if (immediate) - immediateArgStages[arg].insert(defStage); - else - nonImmediateDepArgs.insert(arg); - } - } - } - - // XXX: We could remove the following constraints if we can rematerialize in - // the loop. - // Check if immediateDepArgs and nonImmediateDepArgs are disjoint. - for (auto &[arg, stages] : immediateArgStages) { - assert(stages.size() == 1 && - "Triton doesn't support an argument provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateDepArgs.contains(arg) && - stages.contains(numStages - 2)) && - "Loop-carried arguments provide values for both immediate and " - "non-immediate operands of loads. 
Please consider removing " - "pre/post load instructions dependency on this argument."); - } - - // Check if immediateOps and nonImmediateOps are disjoint. - for (auto &[op, stages] : immediateOpStages) { - assert(stages.size() == 1 && - "Triton doesn't support an operation provides values for " - "immediate operands of loads from multiple stages. Consider " - "removing post load instructions dependency on this argument."); - assert(!(nonImmediateOps.contains(op) && stages.contains(numStages - 2)) && - "Operations provide values for both immediate and " - "non-immediate operands of loads. Please consider " - "removing pre/post load instructions dependency on this " - "operation."); - } - return success(); -} - -// helpers -void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - valueMapping[origin] = SmallVector(numStages); - valueMapping[origin][stage] = newValue; -} - -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue, - int stage) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto value = forOp.getRegionIterArgs()[yieldIdx]; - setValueMapping(value, newValue, stage); - } - } -} - -void LoopPipeliner::setValueMappingYield(Value origin, Value newValue) { - for (OpOperand &operand : origin.getUses()) { - if (operand.getOwner() == yieldOp) { - auto yieldIdx = operand.getOperandNumber(); - auto depYieldIdx = depArgsIdx[forOp.getRegionIterArgs()[yieldIdx]]; - auto originArg = forOp.getRegionIterArgs()[yieldIdx]; - nextMapping.map(originArg, newValue); - auto newArg = pplForOp.getRegionIterArgs()[depYieldIdx]; - if (!depArgsMapping.contains(newArg)) - depArgsMapping[newArg] = newValue; - } - } -} - -Value LoopPipeliner::lookupOrDefault(Value origin, int stage) { - if (valueMapping.find(origin) == valueMapping.end()) - return origin; - return valueMapping[origin][stage]; -} - -void 
LoopPipeliner::createBufferTypes() { - for (auto loadCvt : convertMapping) { - auto loadOp = loadCvt.first; - Value cvt = loadCvt.second; - auto dotOpEnc = cast( - cast(cvt.getType()).getEncoding()); - auto ty = cast(loadOp.getType()); - SmallVector bufferShape(ty.getShape().begin(), - ty.getShape().end()); - Type eType = ty.getElementType(); - auto blockedEnc = cast(ty.getEncoding()); - auto CTALayout = ttg::getCTALayout(ty.getEncoding()); - // unsigned bitWidth = dotOpEnc.getMMAv2kWidth() - // ? 32 / dotOpEnc.getMMAv2kWidth() - // : ty.getElementType().getIntOrFloatBitWidth(); - auto srcOrder = ttg::getOrder(ty.getEncoding()); - SmallVector sharedOrder; - int rank = srcOrder.size(); - // TODO rework this when shared -> dotOp conversions support arbitrary - // shared memory ordering - if (rank == 3) { - // Move the batch dimension (dim #0) to be the last so that it will be the - // slowest varying dimension. - for (unsigned i = 0; i < rank; ++i) - if (srcOrder[i] != 0) - sharedOrder.emplace_back(srcOrder[i]); - sharedOrder.emplace_back(0); - } else { - sharedOrder = srcOrder; - } - auto sharedEnc = - ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(), - sharedOrder, CTALayout, eType); - loadsBufferType[loadOp] = triton::MemDescType::get( - bufferShape, eType, sharedEnc, - triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()), - /*mutableMemory=*/true); - } -} - -void LoopPipeliner::createOrderedDeps() { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (depOps.contains(&op)) - orderedDeps.push_back(&op); - else if (op.getNumResults() > 0 && validLoads.contains(&op)) - orderedDeps.push_back(&op); - } - assert(depOps.size() + validLoads.size() == orderedDeps.size() && - "depOps contains invalid values"); -} - -void LoopPipeliner::collectDepChain(Operation *op, - SetVector &ops) { - if (op->getNumResults() == 1 && validLoads.contains(op)) - return; - if (!ops.contains(op)) { - ops.insert(op); - for (Value opr : 
op->getOperands()) - if (Operation *oprOp = opr.getDefiningOp()) - collectDepChain(oprOp, ops); - } -} - -void LoopPipeliner::createCurrentDeps() { - for (Operation &op : forOp.getBody()->without_terminator()) { - if (!llvm::is_contained(orderedDeps, &op)) - collectDepChain(&op, currentDeps); - } -} - -int LoopPipeliner::getValueDefStage(Value v, int stage) { - if (stage < 0) - return -1; - if (auto arg = dyn_cast(v)) { - if (arg.getArgNumber() > 0) - return getValueDefStage(yieldOp->getOperand(arg.getArgNumber() - 1), - stage - 1); - llvm_unreachable("Loop induction variable should not be a dependency"); - } else - return stage; -} - -LogicalResult LoopPipeliner::initialize() { - if (checkOpUses().failed()) - return failure(); - - if (checkOpDeps().failed()) - return failure(); - - createBufferTypes(); - - createOrderedDeps(); - - createCurrentDeps(); - - return success(); -} - -Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask, - Value loopCond, OpBuilder &builder) { - if (!peelLastIter) { - // add mask for last iteration when not peeled to epilogue - Value mask = loadOp.getMask(); - Type maskType = triton::getI1SameShape(loadOp.getType()); - Value newMask; - if (mask) { - Value cond = loopCond; - if (isa(maskType)) { - cond = - builder.create(mask.getLoc(), maskType, loopCond); - } - newMask = builder.create(mask.getLoc(), mappedMask, cond); - } else { - if (isa(maskType)) { - newMask = builder.create(loopCond.getLoc(), maskType, - loopCond); - } else { - newMask = loopCond; - } - } - return newMask; - } - // use original mask when peeling last iteration bc the loop will not do - // extra loads for the tail of the pipeline - return mappedMask; -} - -bool LoopPipeliner::isLoadChain(Operation *op) const { - if (auto cvtOp = dyn_cast(op)) { - Value loadVal = cvtOp.getSrc(); - if (auto f2fOp = dyn_cast(op)) - loadVal = f2fOp.getSrc(); - if (validLoads.contains(loadVal.getDefiningOp())) { - if (isa(cvtOp.getType().getEncoding())) - return true; 
- } - } - return false; -} - -void LoopPipeliner::emitPrologue() { - /// forOp block args => forOp operands - /// forOp iterator => lower bound - IRMapping prologueMap; - OpBuilder builder(forOp); - // Get init operands for loop carried values - for (BlockArgument &arg : forOp.getRegionIterArgs()) { - OpOperand &operand = *forOp.getTiedLoopInit(arg); - prologueMap.map(arg, operand.get()); - } - - // Emit prologue - // Map IV to lower bound - prologueMap.map(forOp.getInductionVar(), forOp.getLowerBound()); - - // Emit Iteration 0 loads, etc - for (Operation *op : orderedDeps) { - Operation *newOp = nullptr; - if (validLoads.contains(op)) { - auto loadOp = cast(op); - // Load from global -> regs - auto newLoadOp = cloneWithInferType(builder, op, prologueMap); - Value loadVal = newLoadOp->getResult(0); - // Convert from regs to shared mem - newOp = builder.create( - loadOp.getLoc(), loadsBufferType[loadOp], loadVal); - Value cvtVal = newOp->getResult(0); - prologueMap.map(loadOp->getResult(0), cvtVal); - loadsBuffer[op] = cvtVal; - } else { - newOp = cloneWithInferType(builder, op, prologueMap); - } - // Capture loop carried results for pipelined for input - for (unsigned idx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(idx), newOp->getResult(idx), 1); - } // for (Operation *op : orderedDeps) -} - -void LoopPipeliner::setResultMapping(DenseMap &newResults) { - // After pipelining, some of the depArgs have beem mapped to new args. - // We need to remap these. - // - // For example, if we have - // - // ptr = ... - // c = [zeros] - // ret = scf.for iter_args(a_ptr=ptr, c=c) - // a = load(a_ptr) - // c += dot(a, ...) - // a_ptr_new = a_ptr + N - // scf.yield %a_ptr_new, %c - // - // then the ptr arg should be mapped to a new arg in the for loop. - // - // ptr = ... 
- // c = [zeros] - // load_pre = load(ptr) - // ptr_new = ptr + N - // ret = scf.for iter_args(a_ptr=ptr, c=c, ld=load_pre, A_ptr_1=ptr_new) - // a_next = load(A_ptr_1) - // c += dot(ld, ...) - // A_ptr_new = A_ptr_1 + N - // scf.yield a_ptr, c, a_next, A_ptr_new - // - // After this, if there are downstream users of a_ptr, they should reference - // ret#3 instead of ret#0 - for (const auto &origArg : llvm::enumerate(forOp.getRegionIterArgs())) { - if (depArgs.contains(origArg.value())) { - auto oldIdx = origArg.index(); - auto newIdx = depArgsIdx[origArg.value()]; - auto oldResult = forOp->getResult(oldIdx); - auto newResult = pplForOp->getResult(newIdx); - newResults[oldResult] = newResult; - } - } -} - -void LoopPipeliner::emitEpilogue(DenseMap &newResults) { - if (!peelLastIter) - return; - OpBuilder builder(pplForOp); - builder.setInsertionPointAfter(pplForOp); - - IRMapping epilogueMap; - // Map 'for' iteration args to pipelined-for results - auto args = forOp.getRegionIterArgs(); - for (uint32_t i = 0; i < args.size(); ++i) - epilogueMap.map(args[i], pplForOp.getResult(i)); - for (auto *loadOp : validLoads) - epilogueMap.map(loadOp->getResult(0), loadsBuffer[loadOp]); - - // This is computing the upper bound of the pipelined loop as: - // pplUpperBound = lb+((ub-1-lb)/step)*step - Location loc = forOp.getLoc(); - Value ub = forOp.getUpperBound(); - Value lb = forOp.getLowerBound(); - Value step = forOp.getStep(); - Value one = builder.create(loc, 1, 32); - - // pplRange = ub-1-lb - Value pplRange = builder.create( - loc, builder.create(loc, ub, one), lb); - - // pplIters = (pplrRange/step)*step - Value pplIters = builder.create( - loc, builder.create(loc, pplRange, step), step); - - // pplUpperBound = lb+pplIters - Value pplUpperBound = builder.create(loc, lb, pplIters); - epilogueMap.map(forOp.getInductionVar(), pplUpperBound); - - const auto &yieldOprs = yieldOp.getOperands(); - // Clone the loop body after the new ForOp - // , replace original args with 
results of the new ForOp. - for (Operation &op : forOp.getBody()->without_terminator()) { - if (currentDeps.contains(&op)) { - Operation *newOp = nullptr; - if (isLoadChain(&op)) { - if (auto cvt = dyn_cast(&op)) { - Value mappedValue = epilogueMap.lookup(cvt.getSrc()); - if (isa(mappedValue.getType())) { - auto newCvt = builder.create( - cvt.getLoc(), cvt.getType(), mappedValue); - epilogueMap.map(cvt.getResult(), newCvt); - newOp = newCvt; - } - } - if (!newOp) - newOp = builder.clone(op, epilogueMap); - } else { - newOp = cloneWithInferType(builder, &op, epilogueMap); - } - // substitute for these results for the results of the new for loop - for (const auto &pair : llvm::zip(op.getResults(), newOp->getResults())) { - auto val = std::get<0>(pair); - auto it = llvm::find(yieldOprs, val); - if (it != yieldOprs.end()) { - uint32_t idx = std::distance(yieldOprs.begin(), it); - newResults[forOp->getResult(idx)] = std::get<1>(pair); - } - } - } - } -} - -SmallVector LoopPipeliner::collectNewLoopArgs() { - // Order of new args: - // (original args) - // (shared mem buffers for each load) - // (depArgs at stage numStages - 1) - - // We need this to update operands for yield - // original block arg => new arg's idx - SmallVector newLoopArgs; - for (auto v : forOp.getInitArgs()) { - newLoopArgs.push_back(lookupOrDefault(v, numStages - 1)); /*1*/ - } - - // Loop carried vals - depArgsBeginIdx = newLoopArgs.size(); - for (auto depArg : depArgs) { - depArgsIdx[depArg] = newLoopArgs.size(); - newLoopArgs.push_back(valueMapping[depArg][numStages - 1]); /*1*/ - } - - return newLoopArgs; -} - -scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, - OpBuilder &builder) { - auto loc = forOp.getLoc(); - // Peel off the last iteration - auto pplUpperBound = forOp.getUpperBound(); - if (peelLastIter) - pplUpperBound = - builder.create(loc, pplUpperBound, forOp.getStep()); - - // Clone the original ForOp - pplForOp = builder.create( - loc, forOp.getLowerBound(), pplUpperBound, 
forOp.getStep(), newLoopArgs); - - // Set mapping on body of the new ForOp - builder.setInsertionPointToStart(pplForOp.getBody()); - for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) - curMapping.map(arg.value(), pplForOp.getRegionIterArgs()[arg.index()]); - for (auto *loadOp : validLoads) - curMapping.map(loadOp->getResult(0), loadsBuffer[loadOp]); - curMapping.map(forOp.getInductionVar(), pplForOp.getInductionVar()); - - nextMapping = curMapping; - // Map the dep args of the next iteration to the dep args of the current - auto iterArgs = pplForOp.getRegionIterArgs(); - size_t argIdx = 0; - for (auto depArg : depArgs) { - BlockArgument nextArg = iterArgs[argIdx + depArgsBeginIdx]; - nextMapping.map(depArg, nextArg); - ++argIdx; - } - - // Compute next IV for pre-loads - Value iv = pplForOp.getInductionVar(); - curMapping.map(forOp.getInductionVar(), iv); - Value nextIV = - builder.create(iv.getLoc(), iv, pplForOp.getStep()); - nextMapping.map(forOp.getInductionVar(), nextIV); - nextLoopCond = - builder.create(nextIV.getLoc(), arith::CmpIPredicate::slt, - nextIV, pplForOp.getUpperBound()); - - return pplForOp; -} - -void LoopPipeliner::updateLoadMask(triton::LoadOp loadOp, Value newMask) { - if (newMask) { - if (loadOp->getNumOperands() > 1) - loadOp->setOperand(1, newMask); - else { - auto mask = loadOp.getMaskMutable(); - mask.assign(newMask); - } - } -} - -void LoopPipeliner::prefetchNextBuffer(OpBuilder &builder) { - // Emit prefetch loads of next buffer before compute of current buffer - for (Operation *op : orderedDeps) { - Operation *nextOp = nullptr; - if (validLoads.contains(op)) { - // Update loading mask - auto loadOp = llvm::cast(op); - auto mask = loadOp.getMask(); - // pre-load global -> regs - Value newMask = getLoadMask(loadOp, nextMapping.lookupOrDefault(mask), - nextLoopCond, builder); - if (mask) { - // If mask is defined outside the loop, don't update the map more than - // once - if (!(forOp.isDefinedOutsideOfLoop(mask) && 
nextMapping.contains(mask))) - nextMapping.map(loadOp.getMask(), newMask); - newMask = nextMapping.lookupOrDefault(mask); - } - auto newOp = builder.clone(*op, nextMapping); - updateLoadMask(cast(newOp), newMask); - } else if (!immediateOpStages[op].contains(numStages - 2)) { - Operation *nextOp = builder.clone(*op, nextMapping); - if (auto loadOp = dyn_cast(op)) { - if (auto newMask = getLoadMask( - loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - nextLoopCond, builder)) { - updateLoadMask(cast(nextOp), newMask); - } - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(dstIdx), nextOp->getResult(dstIdx)); - } - } -} - -void LoopPipeliner::cloneCurrentBody(OpBuilder &builder) { - auto loc = forOp.getLoc(); - // only add instructions that are not part of the restructuring - for (Operation &op : forOp.getBody()->without_terminator()) { - if (currentDeps.contains(&op)) { - Operation *newOp = nullptr; - if (isLoadChain(&op)) { - if (auto cvt = dyn_cast(&op)) { - Value mappedValue = curMapping.lookup(cvt.getSrc()); - if (isa(mappedValue.getType())) { - auto newCvt = builder.create( - cvt.getLoc(), cvt.getType(), mappedValue); - curMapping.map(cvt.getResult(), newCvt); - newOp = newCvt; - } - } - if (!newOp) - newOp = builder.clone(op, curMapping); - } else { - newOp = cloneWithInferType(builder, &op, curMapping); - } - } - } -} - -void LoopPipeliner::storeNextBuffer(OpBuilder &builder) { - // Store the next buffer at the end of the loop body for the next iteration - for (Operation *op : orderedDeps) { - if (!validLoads.contains(op)) { - if (immediateOpStages[op].contains(numStages - 2)) { - Operation *nextOp = builder.clone(*op, nextMapping); - if (auto loadOp = dyn_cast(op)) { - auto newMask = - getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()), - 
nextLoopCond, builder); - updateLoadMask(cast(nextOp), newMask); - } - - for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) - setValueMappingYield(op->getResult(dstIdx), - nextOp->getResult(dstIdx)); - } - } - } - - // PL loads -> store next to shared - for (auto *loadOp : validLoads) { - Value loadVal = nextMapping.lookup(loadOp->getResult(0)); - // then store regs -> shared - Value storeBuf = loadsBuffer[loadOp]; - builder.create(loadOp->getLoc(), loadVal, storeBuf); - } - - // Some values have not been used by any ops in the loop body - for (BlockArgument arg : forOp.getRegionIterArgs()) - setValueMappingYield(arg, pplForOp.getRegionIterArgs()[depArgsIdx[arg]]); -} - -void LoopPipeliner::finalizeYield(OpBuilder &builder) { - SmallVector yieldValues; - for (const auto &opr : llvm::enumerate(yieldOp->getOperands())) { - if (curMapping.contains(opr.value())) - yieldValues.push_back(curMapping.lookup(opr.value())); - else - yieldValues.push_back(pplForOp.getRegionIterArgs()[opr.index()]); - } - for (size_t i = 0; i < depArgsMapping.size(); ++i) { - auto arg = pplForOp.getRegionIterArgs()[depArgsBeginIdx + i]; - assert(depArgsMapping.count(arg) && "Missing loop-carried value"); - yieldValues.push_back(depArgsMapping[arg]); - } - - builder.setInsertionPointToEnd(pplForOp.getBody()); - builder.create(yieldOp->getLoc(), yieldValues); -} - -scf::ForOp LoopPipeliner::createNewForOp() { - OpBuilder builder(forOp); - auto newLoopArgs = collectNewLoopArgs(); - cloneForOp(newLoopArgs, builder); - prefetchNextBuffer(builder); - cloneCurrentBody(builder); - storeNextBuffer(builder); - finalizeYield(builder); - return pplForOp; -} - -// Stream Pipeline -struct PipelinePass : public TritonAMDGPUStreamPipelineBase { - PipelinePass() = default; - - void runOnOperation() override { - // Pre-processing - // we make sure element-wise ops are done *after* the conversion - // to dot operands - // we can achieve this with simple recursive pattern matching - // 
MLIRContext *context = &getContext(); - // mlir::RewritePatternSet patterns(context); - // patterns.add(context); - // auto didPreprocess = - // applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - - // Do the pipelining - getOperation()->walk([&](scf::ForOp forOp) -> void { - LoopPipeliner pipeliner(forOp); - - if (pipeliner.initialize().failed()) - return; - - pipeliner.emitPrologue(); - scf::ForOp pplForOp = pipeliner.createNewForOp(); - DenseMap newResults; - for (unsigned i = 0; i < forOp->getNumResults(); ++i) - newResults[forOp->getResult(i)] = pplForOp->getResult(i); - pipeliner.setResultMapping(newResults); - pipeliner.emitEpilogue(newResults); - - // Replace the original loop - for (auto &pair : newResults) - std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair)); - forOp->erase(); - }); - } -}; -} // anonymous namespace - -std::unique_ptr mlir::createTritonAMDGPUStreamPipelinePass() { - return std::make_unique(); -} diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index 5b5cca5b05..a9f3a8ee2f 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -68,8 +68,6 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUCanonicalizePointersPass); ADD_PASS_WRAPPER_0("add_reorder_instructions", mlir::createTritonAMDGPUReorderInstructionsPass); - ADD_PASS_WRAPPER_0("add_stream_pipeline", - mlir::createTritonAMDGPUStreamPipelinePass); ADD_PASS_WRAPPER_1("add_stream_pipelinev2", mlir::createTritonAMDGPUStreamPipelineV2Pass, int); } From 7f80413d5b497e4171beb79c0d97edd2198c85c3 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 23 Oct 2024 14:04:47 +0000 Subject: [PATCH 8/8] Revert "[AMD] Add a tt.pointer_range_32 specialization (#4910)" This reverts commit 692143cd869f2fa1501edb1db76fb452b85ac914. 
--- python/test/unit/runtime/test_cache.py | 27 ------------- python/triton/backends/compiler.py | 52 +++++++------------------- python/triton/runtime/jit.py | 4 -- third_party/amd/backend/compiler.py | 47 +---------------------- 4 files changed, 15 insertions(+), 115 deletions(-) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a45cb3f888..a0084e0be9 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -9,7 +9,6 @@ import triton import triton.language as tl from triton.runtime.jit import JITFunction -from triton._internal_testing import is_hip @triton.jit @@ -573,29 +572,3 @@ def compiled_hook(*args, **kwargs): assert specialization_data is not None and specialization_data_compiled == specialization_data assert is_warmup is True assert key in kernel_add.cache[getattr(torch, device).current_device()] - - -@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip()) -def test_within_2gb(device, fresh_triton_cache) -> None: - - @triton.jit - def kernel_add(a): - tl.load(a) - - # This is the attribute we want to test - pointer_range_32 = None - - def cache_hook(*args, **kwargs): - nonlocal pointer_range_32 - pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32 - - JITFunction.cache_hook = cache_hook - # In warmup we assume that the pointer range is 32 bits - kernel_add.warmup(torch.float32, grid=(1, )) - assert pointer_range_32 == [0] - # Torch tensor > 2GB - kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) - assert len(pointer_range_32) == 0 - # Torch tensor <= 2GB - kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) - assert pointer_range_32 == [0] diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 8c3eda90f5..11e1dc4cef 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -8,21 +8,7 @@ from typing 
import Dict, Union from types import ModuleType -# Table that associates strings to AttrsDescriptor (sub)classes. -# In this way we can dynamically select the correct class -# constructor -_descriptor_table = {} - -def register_descriptor(cls): - """ - Register a descriptor into the descriptor table - """ - _descriptor_table[cls.__name__] = cls - return cls - - -@register_descriptor class AttrsDescriptor: """ This class handles compile-time properties for specific function parameters. @@ -149,28 +135,18 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() def to_dict(self): - """ - Store the fields of this class in a serializable dictionary - """ - # We need to only store the `arg_properties` field. To initialize the - # other fields we relay on the class type. We store it as a string in - # the dictionary so that we can use it to invoke the appropriate - # (sub)class constructor in the `from_dict` method. - return {"arg_properties": self.arg_properties, "cls": type(self).__name__} + return self.arg_properties @staticmethod def from_dict(data): - """ - Create the object from a serializable dictionary - """ - attrs_descriptor = _descriptor_table[data["cls"]]() - for prop_name, param_ids in data["arg_properties"].items(): - attrs_descriptor.arg_properties[prop_name] = param_ids - attrs_descriptor._init_slots() - return attrs_descriptor - - @classmethod - def from_hints(cls, hints: list[tuple[int, int]]): + attrsDescriptor = AttrsDescriptor() + for prop_name, param_ids in data.items(): + attrsDescriptor.arg_properties[prop_name] = param_ids + attrsDescriptor._init_slots() + return attrsDescriptor + + @staticmethod + def from_hints(hints: list[tuple[int, int]]): """ Create the class from a set of hints that are passed in. 
@@ -180,11 +156,11 @@ def from_hints(cls, hints: list[tuple[int, int]]): then we insert `param_index` into the correct list (e.g., in `arg_properties[prop0]`) """ - attrs_descriptor = cls() - for prop_name, prop_val in attrs_descriptor.property_values.items(): - attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] - attrs_descriptor._init_slots() - return attrs_descriptor + attrsDescriptor = AttrsDescriptor() + for prop_name, prop_val in attrsDescriptor.property_values.items(): + attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] + attrsDescriptor._init_slots() + return attrsDescriptor @staticmethod def is_divisible_by_16(x): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 45178a40bb..0842849ad9 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -879,10 +879,6 @@ def __init__(self, dtype): def data_ptr(): return 0 # optimistically assumes multiple of 16 - @staticmethod - def ptr_range(): - return 0 # optimistically assumes 32 bit pointer range - class TensorWrapper: diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 66e05ce034..640ccead59 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -1,4 +1,4 @@ -from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor +from triton.backends.compiler import BaseBackend, GPUTarget from triton._C.libtriton import ir, passes, llvm, amd from dataclasses import dataclass from typing import Any, Dict, Tuple @@ -72,44 +72,6 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() -@register_descriptor -class HIPAttrsDescriptor(AttrsDescriptor): - # This property asserts if the underlying storage area of a given pointer - # can be resepresented as a 32 bit integer. 
When this is true, we can be - # sure that all indices into the tensor behind that pointer can use 32-bit - # indexing. That opens the door for the AMD backend to use buffer load/store - # instrinsics, which requires this property. Buffer load/store intrinsics - # gives direct out-of-bound support and simplifies index calculation for - # lower register pressure. - __slots__ = ("pointer_range_32") - - def _add_backend_properties(self, params=None, values=None): - self.property_values["tt.pointer_range"] = 32 - if params is None or values is None: - return - - self.arg_properties["tt.pointer_range"] = [ - param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg) - and not param.do_not_specialize and not param.do_not_specialize_on_alignment - ] - - @staticmethod - def is_within2gb(arg): - if hasattr(arg, "ptr_range"): - return arg.ptr_range() <= 2**31 - 1 - if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"): - # Please note that 2**31-1 is the max int32 positive limit - return arg.untyped_storage().size() <= 2**31 - 1 - return False - - @staticmethod - def get_property_key(val, align): - generic_key = AttrsDescriptor.get_property_key(val, align) - hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N" - key = (generic_key + hip_key).replace("N", "") - return key if key else "N" - - class HIPBackend(BaseBackend): @staticmethod @@ -156,13 +118,6 @@ def get_module_map(self) -> Dict[str, ModuleType]: def load_dialects(self, ctx): amd.load_dialects(ctx) - def get_attrs_descriptor(self, params, args): - return HIPAttrsDescriptor(params, args) - - @staticmethod - def compute_spec_key(arg, align): - return HIPAttrsDescriptor.get_property_key(arg, align) - @staticmethod def path_to_rocm_lld(): # Check env path for ld.lld