From 66e8629ff64da193e8c35f58eeea4a9c1e739e40 Mon Sep 17 00:00:00 2001 From: Ognjen Plavsic <130548569+oplavsic@users.noreply.github.com> Date: Mon, 18 Nov 2024 15:37:10 +0100 Subject: [PATCH 01/11] [AMD] Implement RepOrder for AMD MMA layouts (#5126) Implement RepOrder methods for MFMA and WMMA layouts. Both layouts have row major rep layout. Also, isTranspose flag in MFMA layout does not affect RepOrder, meaning RepOrder is row major in both cases. Co-authored-by: Ognjen Plavsic --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 5 +++++ lib/Dialect/TritonGPU/IR/Dialect.cpp | 21 ++++++++++++++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 2c1f7da609..09330e7eda 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -777,6 +777,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { "getSizePerThreadForOperand", (ins "int":$opIdx, "int":$kWidth)>, + + InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first", + "SmallVector", + "getRepOrderForOperand", + (ins "int":$opIdx)>, ]; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1e63c4b390..0237d9815c 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1658,7 +1658,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const { } SmallVector AMDMfmaEncodingAttr::getRepOrder() const { - llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder"); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); +} + +SmallVector +AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); } SmallVector @@ -1745,8 +1752,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { return shapePerCTATile; } SmallVector AMDWmmaEncodingAttr::getRepOrder() const { - llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder"); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); } + +SmallVector +AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { + auto rank = getWarpsPerCTA().size(); + return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true); +} + SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -2016,7 +2031,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { // DotOperand Encoding //===----------------------------------------------------------------------===// SmallVector DotOperandEncodingAttr::getRepOrder() const { - if (auto mma = mlir::dyn_cast(getParent())) { + if (auto mma = mlir::dyn_cast(getParent())) { return mma.getRepOrderForOperand(getOpIdx()); } llvm::report_fatal_error( From 220e51c4fc4bd5eefec21eef50ee3599ddb7b4b3 Mon Sep 17 00:00:00 2001 From: Adnan Akhundov Date: Mon, 18 Nov 2024 06:37:52 -0800 Subject: [PATCH 02/11] [BACKEND] Fix ProgramPoint passing in AxisInfoAnalysis (#5181) Fixes #5122. The `ProgramPoint` [here](https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/lib/Analysis/AxisInfo.cpp#L1087) is created on the stack. Then its address is [passed](https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/lib/Analysis/AxisInfo.cpp#L1088-L1089) to the MLIR `SparseAnalysis` code, where it is [added as a dependency](https://github.com/llvm/llvm-project/blob/33ff9e43b4c5bdc3da31c6b11ad51d35a69bec5f/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L311) and later [dereferenced](https://github.com/llvm/llvm-project/blob/33ff9e43b4c5bdc3da31c6b11ad51d35a69bec5f/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp#L90). By the time the `ProramPoint` is dereferenced in the `AbstractSparseForwardDataFlowAnalysis::visit`, the `AxisInfoAnalysis::visitForOpInductionVar` will have finished and the `ProgramPoint` stack variable destroyed. This leads to a segfault (which can be reproed on the base rev with the lit test added in this PR). The code modified in this PR was originally added in #4927, in conjunction with updating the `llvm-project` hash to `b5cc222d7429`. However, as noted in https://github.com/llvm/llvm-project/pull/110344 (the `llvm-project` PR that has made the refactoring prompting the `AxisInfo.cpp` change in #4927): > For dense forward data-flow analysis and other analysis (except dense backward data-flow analysis), the program point corresponding to the original operation can be obtained by `getProgramPointAfter(op)` As the `AxisInfoAnalysis` (in Triton) inherits from `SparseForwardDataFlowAnalysis` (in MLIR), in this PR we follow the above which resolves the segfault issue (as the `ProgramPoint` is now stored in the instance-level state of the pass). P.S. The lit test added in this PR is not exactly minimal. However, I did my best to minimize it starting from the 400-line repro TTGIR in #5122. Further minimization does not seem to expose the segfault. --- lib/Analysis/AxisInfo.cpp | 6 +++--- test/TritonGPU/coalesce.mlir | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index f0c5ae3167..717df8d1bd 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1084,9 +1084,9 @@ LogicalResult AxisInfoAnalysis::visitOperation( void AxisInfoAnalysis::visitForOpInductionVar( scf::ForOp op, ArrayRef *> argLattices) { - ProgramPoint programPoint(op); - auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue(); - auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue(); + ProgramPoint *programPoint = getProgramPointAfter(op); + auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); + auto step = getLatticeElementFor(programPoint, op.getStep())->getValue(); AxisInfo::DimVectorT knownContiguity(1, 1); AxisInfo::DimVectorT knownDivisibility(1, 1); diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index cf93c37b84..5d35f43e9e 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -131,3 +131,32 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war tt.return } } + +// ----- + +// COM: Reproducer for issue #5122 +// CHECK-LABEL: @test_5122 +module { + tt.func public @test_5122(%arg0: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %0 = arith.cmpi sgt, %arg0, %c1_i32 : i32 + scf.if %0 { + %1 = scf.if %0 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %2 = arith.cmpi sgt, %1, %c1_i32 : i32 + %3 = scf.if %2 -> (i32) { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 { + %5 = arith.addi %arg2, %c1_i32 : i32 + scf.yield %5 : i32 + } + } + tt.return + } +} From 689dcfee09ed8436bcc4871bad2c4f35c3d54e03 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 18 Nov 2024 08:00:12 -0800 Subject: [PATCH 03/11] [INTERPRETER] Fix argument passing for internal parameters in function declarations (#5169) --- python/test/unit/language/test_core.py | 27 +++++++++++++------------- python/triton/runtime/interpreter.py | 10 +++++----- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4a22ee1cda..a499dc2321 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5581,7 +5581,7 @@ def matmul_kernel( # stride_cm, stride_cn, # BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # low_precision_acc: tl.constexpr, # - num_pipeline_stages: tl.constexpr = 3 # + num_stages: tl.constexpr = 3 # ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -5593,7 +5593,7 @@ def matmul_kernel( # a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages): + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): a = tl.load(a_ptrs) b = tl.load(b_ptrs) accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) @@ -5632,7 +5632,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, - num_pipeline_stages=num_stages) + num_stages=num_stages) torch_a = torch.from_numpy(A).to(device=device) th_a = f8_to_f16(torch_a, in_type_str) torch_b = torch.from_numpy(B).to(device=device) @@ -5824,7 +5824,7 @@ def test_tl_range(device): pgm = matmul_kernel[ 1, ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, - BLOCK_K, 0, num_pipeline_stages=5) + BLOCK_K, 0, num_stages=5) ref_out = torch.matmul(a, b).to(torch.float32) if is_interpreter(): # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. @@ -5850,8 +5850,8 @@ def maxnreg_noinline2(X): tl.store(X, 0) +@pytest.mark.interpreter def test_maxnreg(device): - assert not is_interpreter(), "this test won't work with the interpreter" if not is_cuda(): pytest.skip('maxnreg only works on CUDA') @@ -5865,14 +5865,15 @@ def kernel(X): X = torch.empty(1, dtype=torch.int32, device=device) k = kernel[(1, )](X, maxnreg=42) - # Ensure that .maxnreg is set on the kernel function (marked with .entry) - # and not on either of the noinline functions (marked with .func). - try: - assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) - assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) - except AssertionError: - print("Failing ptx:\n", k.asm["ptx"]) - raise + if not is_interpreter(): + # Ensure that .maxnreg is set on the kernel function (marked with .entry) + # and not on either of the noinline functions (marked with .func). + try: + assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"]) + assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"]) + except AssertionError: + print("Failing ptx:\n", k.asm["ptx"]) + raise @pytest.mark.interpreter diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 7c53697429..8d14107188 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -1034,9 +1034,6 @@ def _implicit_cvt(arg): interpreter_builder = InterpreterBuilder() -# These keywords are not supported by the interpreter -RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"] - class GridExecutor: @@ -1077,10 +1074,13 @@ def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) def __call__(self, *args_dev, **kwargs): - # removes reserved keywords from kwargs - kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} if kwargs.pop("warmup", False): return + # Removes not used reserved keywords from kwargs + # Triton doesn't support keyword-only, variable positional or variable keyword arguments + # It's safe to inspect only positional or keyword arguments (i.e., argspec.args) + argspec = inspect.getfullargspec(self.fn) + kwargs = {k: v for k, v in kwargs.items() if k in argspec.args} # copy arguments to the host args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) # remaps core language functions to interpreted ones From 1fc326912b6fc7e080ca7dc6e131103d9f3aa8d7 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 18 Nov 2024 20:28:20 +0100 Subject: [PATCH 04/11] [NFC] Use reference instead of copies in few places (#5118) Apply fixes suggested by coverity static analysis. Signed-off-by: Anatoly Myachev --- include/triton/Analysis/AxisInfo.h | 7 ++++--- lib/Analysis/Allocation.cpp | 8 ++++---- lib/Analysis/AxisInfo.cpp | 6 ++++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index aad4503b48..1bf9c8a690 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -27,11 +27,12 @@ class AxisInfo { public: AxisInfo() : AxisInfo({}, {}, {}) {} - AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy) : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} - AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, - std::optional constantValue) + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy, std::optional constantValue) : contiguity(contiguity), divisibility(divisibility), constancy(constancy), constantValue(constantValue) { assert(divisibility.size() == contiguity.size()); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 74c64e65c3..02269c9aac 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -84,8 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, assert(cvtNeedsSharedMemory(srcTy, dstTy)); - auto inOrd = gpu::getOrder(srcLayout); - auto outOrd = gpu::getOrder(dstLayout); + const auto &inOrd = gpu::getOrder(srcLayout); + const auto &outOrd = gpu::getOrder(dstLayout); scratchConfig.order = outOrd; unsigned srcContigPerThread = @@ -303,7 +303,7 @@ class AllocationAnalysis { /// arguments are involved. void resolveAliasBufferLiveness( function_ref(Value value)> getLiveness) { - for (auto aliasBufferIter : allocation->aliasBuffer) { + for (const auto &aliasBufferIter : allocation->aliasBuffer) { auto value = aliasBufferIter.first; auto buffers = aliasBufferIter.second; auto range = getLiveness(value); @@ -443,7 +443,7 @@ class AllocationAnalysis { std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { auto xRange = bufferRange[buffer]; bool res = xRange.intersects(range); - for (auto val : tripleMap) + for (const auto &val : tripleMap) res = res && !val.second.intersects(xRange); // only one buffer intersect return res; diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 717df8d1bd..fc6a2c73be 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1085,8 +1085,10 @@ LogicalResult AxisInfoAnalysis::visitOperation( void AxisInfoAnalysis::visitForOpInductionVar( scf::ForOp op, ArrayRef *> argLattices) { ProgramPoint *programPoint = getProgramPointAfter(op); - auto lb = getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); - auto step = getLatticeElementFor(programPoint, op.getStep())->getValue(); + const auto &lb = + getLatticeElementFor(programPoint, op.getLowerBound())->getValue(); + const auto &step = + getLatticeElementFor(programPoint, op.getStep())->getValue(); AxisInfo::DimVectorT knownContiguity(1, 1); AxisInfo::DimVectorT knownDivisibility(1, 1); From c76b342a2d704b6552c1224a4e7706bb85a4b888 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 18 Nov 2024 14:14:16 -0800 Subject: [PATCH 05/11] [BACKEND] Add missing precondition in optimize acc init (#5184) We need scalar select to be able to do this optimization. --- .../Transforms/OptimizeAccumulatorInit.cpp | 2 ++ test/TritonGPU/accumulator-init.mlir | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp index dd9b4ad139..2111d6241c 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -65,6 +65,8 @@ std::optional> findZeroInitOp(Value accUse, return std::nullopt; } if (auto selOp = dyn_cast(defOp)) { + if (!selOp.getCondition().getType().isInteger(1)) + return std::nullopt; if (isConstantZeroTensor(selOp.getTrueValue()) || isConstantZeroTensor(selOp.getFalseValue())) { return std::make_pair(selOp, 0); diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir index 72ef11dcaf..2f56d0e975 100644 --- a/test/TritonGPU/accumulator-init.mlir +++ b/test/TritonGPU/accumulator-init.mlir @@ -348,4 +348,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : } tt.return %17 : tensor<128x16xf32, #mma1> } + +// If the condition is a tensor skip the optimization. +// CHECK-LABEL: @negative_sel_tensor +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @negative_sel_tensor(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1> + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } } From 29d27d746eb6c00ccaf31a45b04d02baec91d845 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 19 Nov 2024 16:02:35 +0100 Subject: [PATCH 06/11] Make sure `ext_oneapi_get_default_context` doesn't broke runtime on windows (#2742) Part of #2478 (to reduce diff) These are quite stable changes, we can merge it without CI on Windows. @gshimansky if you don't mind. --------- Signed-off-by: Anatoly Myachev --- third_party/intel/backend/driver.c | 23 +++++++++++++++++++++-- utils/SPIRVRunner/SPIRVRunner.cpp | 23 ++++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/third_party/intel/backend/driver.c b/third_party/intel/backend/driver.c index 56e9ac4149..295af3e3fe 100644 --- a/third_party/intel/backend/driver.c +++ b/third_party/intel/backend/driver.c @@ -166,6 +166,26 @@ struct BuildFlags { } }; +sycl::context get_default_context(const sycl::device &sycl_device) { + const auto &platform = sycl_device.get_platform(); +#ifdef WIN32 + sycl::context ctx; + try { + ctx = platform.ext_oneapi_get_default_context(); + } catch (const std::runtime_error &ex) { + // This exception is thrown on Windows because + // ext_oneapi_get_default_context is not implemented. But it can be safely + // ignored it seems. +#if _DEBUG + std::cout << "ERROR: " << ex.what() << std::endl; +#endif + } + return ctx; +#else + return platform.ext_oneapi_get_default_context(); +#endif +} + static PyObject *loadBinary(PyObject *self, PyObject *args) { const char *name, *build_flags_ptr; int shared; @@ -194,8 +214,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { const size_t binary_size = PyBytes_Size(py_bytes); uint8_t *binary_ptr = (uint8_t *)PyBytes_AsString(py_bytes); - const auto ctx = - sycl_device.get_platform().ext_oneapi_get_default_context(); + const auto &ctx = get_default_context(sycl_device); const auto l0_device = sycl::get_native(sycl_device); const auto l0_context = diff --git a/utils/SPIRVRunner/SPIRVRunner.cpp b/utils/SPIRVRunner/SPIRVRunner.cpp index 2a2bd54577..0342e39b28 100644 --- a/utils/SPIRVRunner/SPIRVRunner.cpp +++ b/utils/SPIRVRunner/SPIRVRunner.cpp @@ -122,6 +122,26 @@ static inline T checkSyclErrors(const std::tuple tuple) { return std::get<0>(tuple); } +sycl::context get_default_context(const sycl::device &sycl_device) { + const auto &platform = sycl_device.get_platform(); +#ifdef WIN32 + sycl::context ctx; + try { + ctx = platform.ext_oneapi_get_default_context(); + } catch (const std::runtime_error &ex) { + // This exception is thrown on Windows because + // ext_oneapi_get_default_context is not implemented. But it can be safely + // ignored it seems. +#if _DEBUG + std::cout << "ERROR: " << ex.what() << std::endl; +#endif + } + return ctx; +#else + return platform.ext_oneapi_get_default_context(); +#endif +} + /** SYCL Functions **/ std::tuple, sycl::kernel, int32_t, int32_t> @@ -138,7 +158,8 @@ loadBinary(const std::string &kernel_name, const std::string &build_flags, const auto &sycl_l0_device_pair = g_sycl_l0_device_list[deviceId]; const sycl::device sycl_device = sycl_l0_device_pair.first; - const auto ctx = sycl_device.get_platform().ext_oneapi_get_default_context(); + const auto &ctx = get_default_context(sycl_device); + const auto l0_device = sycl::get_native(sycl_device); const auto l0_context = From 31bfb67090e7dc5bfed5214a772a14c2c3b68e62 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Mon, 18 Nov 2024 13:46:29 +0000 Subject: [PATCH 07/11] Revert "Revert "[Triton][Allocation] Enable `getScratchValueSize` specialization (#5070)"" This reverts commit c17a0fb139ec2ae216e435598e8dcabdb2a376b6. --- include/triton/Analysis/Allocation.h | 20 ++-- lib/Analysis/Allocation.cpp | 140 ++++++++++++++------------- test/Analysis/test-allocation.mlir | 7 ++ test/lib/Analysis/TestAllocation.cpp | 32 +++++- 4 files changed, 123 insertions(+), 76 deletions(-) diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 5d5f1a5709..1bdcae3d71 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -18,6 +18,12 @@ namespace mlir { namespace triton { class AllocationAnalysis; +/// Callback to allow backends to specify target-specific scratch sizes for +/// some operations. +using AllocationAnalysisScratchSizeFn = std::function; + +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op); + // To convert a tensor from one layout to another, we need to allocate a // temporary buffer (i.e., scratch buffer) in shared memory. The conversion may // require multiple iterations, with each iteration involving multiple @@ -141,7 +147,8 @@ class Allocation { explicit Allocation(Operation *operation) : operation(operation) {} /// Runs allocation analysis on the given top-level operation. - template void run(FuncAllocMapT &funcAllocMap); + void run(FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter); /// Returns the operation this analysis was constructed from. Operation *getOperation() const { return operation; } @@ -255,17 +262,18 @@ class ModuleAllocation : public CallGraph { public: using FuncOffsetMapT = DenseMap; - template - static ModuleAllocation get(ModuleOp moduleOp) { - ModuleAllocation res(moduleOp); - res.walk( + ModuleAllocation(ModuleOp moduleOp, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter = + triton::defaultAllocationAnalysisScratchSizeFn) + : CallGraph(moduleOp) { + walk( // Pre-order edge walk callback [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, // Post-order node walk callback [&](FunctionOpInterface funcOp) { auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp); if (inserted) - iter->second.template run(res.funcMap); + iter->second.run(funcMap, scratchSizeGetter); }); return res; } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 6bd32b4746..6972344c87 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -118,13 +118,70 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, return scratchConfig; } +unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { + if (auto reduceOp = dyn_cast(op)) { + ReduceOpHelper helper(reduceOp); + return helper.getScratchSizeInBytes(); + } + if (auto scanOp = dyn_cast(op)) { + ScanLoweringHelper helper(scanOp); + return helper.getScratchSizeInBytes(); + } + if (auto histogram = dyn_cast(op)) { + auto dstTy = histogram.getType(); + int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + return std::max(dstTy.getNumElements(), threadsPerWarp) * + std::max(8, dstTy.getElementTypeBitWidth()) / 8; + } + if (auto cvtLayout = dyn_cast(op)) { + auto srcTy = cvtLayout.getSrc().getType(); + auto dstTy = cvtLayout.getType(); + auto srcEncoding = srcTy.getEncoding(); + auto dstEncoding = dstTy.getEncoding(); + if (mlir::isa(srcEncoding) || + mlir::isa(dstEncoding)) { + // Conversions from/to shared memory do not need scratch memory. + return 0; + } + // ConvertLayoutOp with both input/output non-shared_layout + // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's + // also possible to realize it with other approaches in restricted + // conditions, such as warp-shuffle + auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); + auto elems = getNumScratchElements(scratchConfig.paddedRepShape); + return isa(srcTy.getElementType()) + ? elems * kPtrBitWidth / 8 + : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + } + if (isa(op)) { + auto value = op->getOperand(0); + // only scalar requires scratch memory + // make it explicit for readability + if (dyn_cast(value.getType())) { + return 0; + } + auto smemShape = getRepShapeForAtomic(op->getResult(0)); + auto elems = getNumScratchElements(smemShape); + auto elemTy = cast(value.getType()).getPointeeType(); + assert(!isa(elemTy) && "unexpected pointer type"); + return elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + } + if (auto createTensormap = dyn_cast(op)) { + constexpr int32_t kTMASize = 128; + return kTMASize; + } + return 0; +} + class AllocationAnalysis { public: AllocationAnalysis(Operation *operation, Allocation::FuncAllocMapT *funcAllocMap, - Allocation *allocation) + Allocation *allocation, + AllocationAnalysisScratchSizeFn scratchSizeGetter) : operation(operation), funcAllocMap(funcAllocMap), - allocation(allocation) { + allocation(allocation), scratchSizeGetter(scratchSizeGetter) { run(); } @@ -177,77 +234,19 @@ class AllocationAnalysis { /// Initializes temporary shared memory for a given operation. void getScratchValueSize(Operation *op) { - const size_t scratchAlignment = 128; - if (auto reduceOp = dyn_cast(op)) { - ReduceOpHelper helper(reduceOp); - unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto scanOp = dyn_cast(op)) { - ScanLoweringHelper helper(scanOp); - unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto histogram = dyn_cast(op)) { - auto dstTy = histogram.getType(); - int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp( - op->getParentOfType()); - auto bytes = std::max(dstTy.getNumElements(), threadsPerWarp) * - std::max(8, dstTy.getElementTypeBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto cvtLayout = dyn_cast(op)) { - auto srcTy = cvtLayout.getSrc().getType(); - auto dstTy = cvtLayout.getType(); - auto srcEncoding = srcTy.getEncoding(); - auto dstEncoding = dstTy.getEncoding(); - if (mlir::isa(srcEncoding) || - mlir::isa(dstEncoding)) { - // Conversions from/to shared memory do not need scratch memory. - return; - } - // ConvertLayoutOp with both input/output non-shared_layout - // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's - // also possible to realize it with other approaches in restricted - // conditions, such as warp-shuffle - auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); - auto elems = getNumScratchElements(scratchConfig.paddedRepShape); - auto bytes = - isa(srcTy.getElementType()) - ? elems * kPtrBitWidth / 8 - : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (isa(op)) { - auto value = op->getOperand(0); - // only scalar requires scratch memory - // make it explicit for readability - if (dyn_cast(value.getType())) { - // nothing to do - } else { - auto smemShape = getRepShapeForAtomic(op->getResult(0)); - auto elems = getNumScratchElements(smemShape); - auto elemTy = cast(value.getType()).getPointeeType(); - assert(!isa(elemTy) && "unexpected pointer type"); - auto bytes = - elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } - } else if (auto callOp = dyn_cast(op)) { + constexpr size_t scratchAlignment = 128; + if (auto callOp = dyn_cast(op)) { auto callable = callOp.resolveCallable(); auto funcOp = dyn_cast(callable); auto *funcAlloc = &(*funcAllocMap)[funcOp]; auto bytes = funcAlloc->getSharedMemorySize(); maybeAddScratchBuffer(op, bytes, scratchAlignment); - } else if (auto createTensormap = - dyn_cast(op)) { - constexpr int32_t kTMASize = 128; - constexpr int32_t kTMAAlign = 128; - maybeAddScratchBuffer(op, kTMASize, - kTMAAlign); + return; } + unsigned bytes = scratchSizeGetter(op); + maybeAddScratchBuffer(op, bytes, + scratchAlignment); } void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { @@ -547,13 +546,16 @@ class AllocationAnalysis { Allocation::FuncAllocMapT *funcAllocMap; Allocation *allocation; BufferRangeMapT bufferRange; + AllocationAnalysisScratchSizeFn scratchSizeGetter; }; } // namespace triton -template <> -void Allocation::run(FuncAllocMapT &funcAllocMap) { - triton::AllocationAnalysis(getOperation(), &funcAllocMap, this); +void Allocation::run( + FuncAllocMapT &funcAllocMap, + triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this, + scratchSizeGetter); } std::map> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 208f6b80bf..a12e3a0260 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -1,4 +1,11 @@ // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128 + +// Check there are no lines with a size different to 128 and we have at least a line with size 128. + +// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}} +// CHECK-128: scratch offset = {{.*}}, size = 128 +// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}} #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}> diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp index 97adcfdf96..e7245e75cb 100644 --- a/test/lib/Analysis/TestAllocation.cpp +++ b/test/lib/Analysis/TestAllocation.cpp @@ -5,21 +5,42 @@ using namespace mlir; namespace { +unsigned getScratchSize128(Operation *) { return 128; } + +enum class GetScratchSizeFunction { + None, + ValidConstant, +}; + struct TestAllocationPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass); + TestAllocationPass() = default; + TestAllocationPass(const TestAllocationPass &other) + : PassWrapper>(other) {} + StringRef getArgument() const final { return "test-print-allocation"; } StringRef getDescription() const final { return "print the result of the allocation pass"; } + ModuleAllocation getModuleAllocation() { + switch (getScratchSizeFunction) { + case GetScratchSizeFunction::None: + return {getOperation()}; + case GetScratchSizeFunction::ValidConstant: + return {getOperation(), getScratchSize128}; + } + llvm_unreachable("Unhandled case"); + } + void runOnOperation() override { auto &os = llvm::errs(); ModuleOp moduleOp = getOperation(); // Convert to std::string can remove quotes from opName - ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp); + ModuleAllocation moduleAllocation = getModuleAllocation(); moduleOp.walk([&](triton::FuncOp funcOp) { auto opName = SymbolTable::getSymbolName(funcOp).getValue().str(); os << opName << "\n"; @@ -48,6 +69,15 @@ struct TestAllocationPass os << "size = " << allocation->getSharedMemorySize() << "\n"; }); } + + Option getScratchSizeFunction{ + *this, "get-scratch-size-function", + llvm::cl::desc("Custom scratch size function to use"), + llvm::cl::init(GetScratchSizeFunction::None), + llvm::cl::values( + clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"), + clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant", + "ValidConstant"))}; }; } // namespace From 3f682b16293f6f4192c9e9ce347de6134a90024a Mon Sep 17 00:00:00 2001 From: victor-eds Date: Mon, 18 Nov 2024 14:22:50 +0000 Subject: [PATCH 08/11] [XPU][Alloc] Use upstream interface to specialize Allocation analysis Defie custom scratch memory size getter to specialize Allocation analysis. Signed-off-by: victor-eds --- include/triton/Analysis/Allocation.h | 9 +- .../TritonGPUToLLVM/AllocateSharedMemory.cpp | 2 +- python/src/passes.cc | 3 +- test/lib/Analysis/TestMembar.cpp | 2 +- .../TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp | 2 +- .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- .../intel/include/Analysis/Allocation.h | 11 +- third_party/intel/lib/Analysis/Allocation.cpp | 642 +----------------- .../AllocateSharedMemory.cpp | 4 +- .../TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 4 +- .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- 11 files changed, 44 insertions(+), 639 deletions(-) diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 1bdcae3d71..f4de4c6b53 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -249,9 +249,6 @@ class Allocation { size_t sharedMemorySize = 0; }; -template <> -void Allocation::run(FuncAllocMapT &funcAllocMap); - /// Static analysis that computes the allocation of shared memory buffers /// of the entire call graph. /// The allocation is performed in a post-order walk of the call graph. @@ -271,11 +268,10 @@ class ModuleAllocation : public CallGraph { [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, // Post-order node walk callback [&](FunctionOpInterface funcOp) { - auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp); + auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); if (inserted) iter->second.run(funcMap, scratchSizeGetter); }); - return res; } size_t getSharedMemorySize() { @@ -300,9 +296,6 @@ class ModuleAllocation : public CallGraph { } private: - explicit ModuleAllocation(ModuleOp moduleOp) - : CallGraph(moduleOp) {} - FuncOffsetMapT sharedMemoryValue; }; diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp index a85abe7c7f..aae9faf0ee 100644 --- a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -23,7 +23,7 @@ struct AllocateSharedMemory void runOnOperation() override { ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - ModuleAllocation allocation = ModuleAllocation::get(mod); + ModuleAllocation allocation(mod); mod.walk([&](FunctionOpInterface funcOp) { funcOp.walk([&](Operation *op) { diff --git a/python/src/passes.cc b/python/src/passes.cc index 989c2dae5a..d6612387b2 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -17,8 +17,7 @@ namespace py = pybind11; void init_triton_analysis(py::module &&m) { py::class_(m, "allocation", py::module_local()) - .def(py::init( - &mlir::ModuleAllocation::get)); + .def(py::init()); py::class_(m, "membar", py::module_local()) .def(py::init()) .def("run", &mlir::ModuleMembarAnalysis::run); diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp index 32546808bb..25e8e2d198 100644 --- a/test/lib/Analysis/TestMembar.cpp +++ b/test/lib/Analysis/TestMembar.cpp @@ -25,7 +25,7 @@ struct TestMembarPass Operation *operation = getOperation(); ModuleOp moduleOp = cast(operation); // Print all ops after membar pass - ModuleAllocation allocation = ModuleAllocation::get(moduleOp); + ModuleAllocation allocation(moduleOp); ModuleMembarAnalysis membarPass(&allocation, mlir::triton::NVIDIA::canSkipBarSync); membarPass.run(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp index e6af9391e6..4a0a7fed22 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp @@ -231,7 +231,7 @@ class OptimizeAMDLDSUsage LDSLimit = targetInfo.getSharedMemorySize(); } - ModuleAllocation allocAnalysis = ModuleAllocation::get(mod); + ModuleAllocation allocAnalysis(mod); if (allocAnalysis.getSharedMemorySize() <= LDSLimit) return; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 79bfa96cbe..f99cd50b0d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -107,7 +107,7 @@ struct ConvertTritonAMDGPUToLLVM } // Allocate shared memory and set barrier - ModuleAllocation allocation = ModuleAllocation::get(mod); + ModuleAllocation allocation(mod); ModuleMembarAnalysis membarPass(&allocation); membarPass.run(); diff --git a/third_party/intel/include/Analysis/Allocation.h b/third_party/intel/include/Analysis/Allocation.h index afdef179a1..08d5135e3b 100644 --- a/third_party/intel/include/Analysis/Allocation.h +++ b/third_party/intel/include/Analysis/Allocation.h @@ -3,13 +3,8 @@ #include "triton/Analysis/Allocation.h" -namespace mlir { -namespace triton::intel { -class AllocationAnalysis; -} // namespace triton::intel -template <> -void Allocation::run( - FuncAllocMapT &funcAllocMap); -} // namespace mlir +namespace mlir::triton::intel { +unsigned allocationAnalysisScratchSizeFn(Operation *op); +} // namespace mlir::triton::intel #endif diff --git a/third_party/intel/lib/Analysis/Allocation.cpp b/third_party/intel/lib/Analysis/Allocation.cpp index b868711673..70782aaa36 100644 --- a/third_party/intel/lib/Analysis/Allocation.cpp +++ b/third_party/intel/lib/Analysis/Allocation.cpp @@ -1,624 +1,42 @@ #include "intel/include/Analysis/Allocation.h" -#include -#include -#include +#include "llvm/ADT/TypeSwitch.h" -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Analysis/Liveness.h" -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Support/LLVM.h" -#include "triton/Analysis/Alias.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "llvm/ADT/SmallVector.h" #include "intel/include/Analysis/Utility.h" -using ::mlir::triton::gpu::AMDMfmaEncodingAttr; -using ::mlir::triton::gpu::BlockedEncodingAttr; -using ::mlir::triton::gpu::DotOperandEncodingAttr; -using ::mlir::triton::gpu::getContigPerThread; -using ::mlir::triton::gpu::getOrder; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; -using ::mlir::triton::gpu::getSizePerThread; -using ::mlir::triton::gpu::getUniqueContigPerThread; -using ::mlir::triton::gpu::NvidiaMmaEncodingAttr; -using ::mlir::triton::gpu::SharedEncodingAttr; -using ::mlir::triton::gpu::SliceEncodingAttr; - -namespace mlir { - -//===----------------------------------------------------------------------===// -// Shared Memory Allocation Analysis -//===----------------------------------------------------------------------===// -namespace triton::intel { - -// Bitwidth of pointers +namespace mlir::triton::intel { +namespace { constexpr int kPtrBitWidth = 64; +constexpr unsigned invalidSize = -1; -static std::pair, SmallVector> -getCvtOrder(Attribute srcLayout, Attribute dstLayout) { - auto srcMmaLayout = mlir::dyn_cast(srcLayout); - auto srcDotLayout = mlir::dyn_cast(srcLayout); - auto dstMmaLayout = mlir::dyn_cast(dstLayout); - auto dstDotLayout = mlir::dyn_cast(dstLayout); - - assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() && - !srcMmaLayout.isHopper()) && - "mma -> mma layout conversion is only supported on Ampere"); - - // mma or dot layout does not have an order, so the order depends on the - // layout of the other operand. - const auto &inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout) - : getOrder(srcLayout); - const auto &outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout) - : getOrder(dstLayout); - - return {inOrd, outOrd}; -} - -static SmallVector getRepShapeForCvt(RankedTensorType srcTy, - RankedTensorType dstTy) { - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - if (!cvtNeedsSharedMemory(srcTy, dstTy)) { - return {}; - } - - if (shouldUseDistSmem(srcLayout, dstLayout)) { - // TODO: padding to avoid bank conflicts - return convertType(getShapePerCTA(srcTy)); - } - - assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()"); - - auto srcShapePerCTA = getShapePerCTA(srcTy); - auto dstShapePerCTA = getShapePerCTA(dstTy); - auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); - - unsigned rank = dstTy.getRank(); - SmallVector repShape(rank); - for (unsigned d = 0; d < rank; ++d) { - repShape[d] = - std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), - std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); - } - return repShape; -} - -// Both `atomic_cas` and `atomic_rmw need a single scratch element if returning -// a scalar value because Triton's block-based programming model ensures that -// all threads in each block see the same return value, even those threads that -// do not participate in the atomic operation -static SmallVector getRepShapeForAtomic(Value result) { - SmallVector smemShape; - if (atomicNeedsSharedMemory(result)) { - smemShape.push_back(1); - } - return smemShape; -} - -ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, - RankedTensorType dstTy) { - if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) { - // Conversions that can be implemented as sub-group shuffles do not need - // scratch memory. - return ScratchConfig({}, {}); - } - +unsigned allocationAnalysisScratchSizeFn(gpu::ConvertLayoutOp convertLayout) { + RankedTensorType srcTy = convertLayout.getSrc().getType(); + RankedTensorType dstTy = convertLayout.getResult().getType(); + if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) + return 0; if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) { - // Conversions that can be implemented as sub-group transposes store the - // whole tensor in shared memory and read it afterwards. - auto srcEncoding = cast(srcTy.getEncoding()); - unsigned threadsPerWarp = product(srcEncoding.getThreadsPerWarp()); - unsigned warpsPerCTA = product(srcEncoding.getWarpsPerCTA()); - unsigned remaining = product(srcTy.getShape()) / - (threadsPerWarp * threadsPerWarp * warpsPerCTA); - SmallVector repShape{threadsPerWarp, threadsPerWarp, remaining, - warpsPerCTA}; - return ScratchConfig(repShape, repShape, - /*inVec=*/1, /*outVec=*/threadsPerWarp); - } - - // Initialize vector sizes and stride - auto repShape = getRepShapeForCvt(srcTy, dstTy); - if (repShape.empty()) - return ScratchConfig({}, {}); - ScratchConfig scratchConfig(repShape, repShape); - auto rank = repShape.size(); - Attribute srcLayout = srcTy.getEncoding(); - Attribute dstLayout = dstTy.getEncoding(); - - assert(cvtNeedsSharedMemory(srcTy, dstTy)); - - // FIXME This is NOT entirely correct - // This should be getElemOrder, but we don't have such a method - // TODO Implement getElemOrder and make sure it's consistent with - // getContigPerThread - auto inOrd = gpu::getThreadOrder(srcLayout); - auto outOrd = gpu::getThreadOrder(dstLayout); - scratchConfig.order = outOrd; - - unsigned srcContigPerThread = - getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; - unsigned dstContigPerThread = - getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; - // TODO: Fix the legacy issue that ourOrd[0] == 0 always means - // that we cannot do vectorization. - unsigned innerDim = rank - 1; - scratchConfig.inVec = outOrd[0] != innerDim ? 1 - : inOrd[0] != innerDim ? 1 - : srcContigPerThread; - scratchConfig.outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread; - - if (auto mma = mlir::dyn_cast(srcLayout)) { - if (mma.getVersionMajor() == 1) { - // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the - // codegen. - scratchConfig.inVec = srcContigPerThread; - } else if (mlir::isa(dstLayout)) { - // when storing from mma layout and loading in blocked layout vectorizing - // the load back gives better performance even if there is a - // transposition. - scratchConfig.outVec = dstContigPerThread; - } - } - - // No padding is required if the tensor is 1-D, or if all dimensions except - // the first accessed dimension have a size of 1. - if (rank <= 1 || product(repShape) == repShape[outOrd[0]]) - return scratchConfig; - - auto paddedSize = std::max(scratchConfig.inVec, scratchConfig.outVec); - scratchConfig.paddedRepShape[outOrd[0]] += paddedSize; - return scratchConfig; + Type elemTy = srcTy.getElementType(); + unsigned bytesPerElement = + isa(elemTy) + ? kPtrBitWidth / 8 + : std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; + return product(srcTy.getShape()) * bytesPerElement; + } + return invalidSize; } - -class AllocationAnalysis { -public: - AllocationAnalysis(Operation *operation, - Allocation::FuncAllocMapT *funcAllocMap, - Allocation *allocation) - : operation(operation), funcAllocMap(funcAllocMap), - allocation(allocation) { - run(); - } - -private: - using BufferT = Allocation::BufferT; - - /// Value -> Liveness Range - /// Use MapVector to ensure determinism. - using BufferRangeMapT = llvm::MapVector>; - /// Nodes -> Nodes - using GraphT = DenseMap>; - - void run() { - getValuesAndSizes(); - resolveLiveness(); - computeOffsets(); - } - - /// Initializes explicitly defined shared memory values for a given operation. - void getExplicitValueSize(Operation *op) { - for (Value result : op->getResults()) { - auto alloc = result.getDefiningOp(); - if (alloc && alloc.isSharedMemoryAlloc()) { - // Bytes could be a different value once we support padding or other - // allocation policies. - auto allocType = alloc.getType(); - auto shapePerCTA = triton::gpu::getShapePerCTA(allocType); - auto bytes = product(shapePerCTA) * - allocType.getElementTypeBitWidth() / 8; - - auto alignment = alloc.getAlignmentOrDefault(); - allocation->addBuffer(result, bytes, - alignment); - } - } - } - - template - void maybeAddScratchBuffer(Operation *op, unsigned bytes, - unsigned alignment) { - if (bytes > 0) - allocation->addBuffer(op, bytes, alignment); - } - - template - void maybeAddScratchBuffer(Operation *op, unsigned bytes) { - if (bytes > 0) - allocation->addBuffer(op, bytes); - } - - /// Initializes temporary shared memory for a given operation. - void getScratchValueSize(Operation *op) { - const size_t scratchAlignment = 128; - if (auto reduceOp = dyn_cast(op)) { - ReduceOpHelper helper(reduceOp); - unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto scanOp = dyn_cast(op)) { - ScanLoweringHelper helper(scanOp); - unsigned bytes = helper.getScratchSizeInBytes(); - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto histogram = dyn_cast(op)) { - auto dstTy = histogram.getType(); - int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( - op->getParentOfType()); - auto bytes = std::max(dstTy.getNumElements(), threadsPerWarp) * - std::max(8, dstTy.getElementTypeBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto cvtLayout = dyn_cast(op)) { - auto srcTy = cvtLayout.getSrc().getType(); - auto dstTy = cvtLayout.getType(); - auto srcEncoding = srcTy.getEncoding(); - auto dstEncoding = dstTy.getEncoding(); - if (mlir::isa(srcEncoding) || - mlir::isa(dstEncoding)) { - // Conversions from/to shared memory do not need scratch memory. - return; - } - // ConvertLayoutOp with both input/output non-shared_layout - // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's - // also possible to realize it with other approaches in restricted - // conditions, such as warp-shuffle - auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy); - auto elems = getNumScratchElements(scratchConfig.paddedRepShape); - auto bytes = - isa(srcTy.getElementType()) - ? elems * kPtrBitWidth / 8 - : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (isa(op)) { - auto value = op->getOperand(0); - // only scalar requires scratch memory - // make it explicit for readability - if (dyn_cast(value.getType())) { - // nothing to do - } else { - auto smemShape = getRepShapeForAtomic(op->getResult(0)); - auto elems = getNumScratchElements(smemShape); - auto elemTy = - cast(value.getType()).getPointeeType(); - auto bytes = - isa(elemTy) - ? elems * kPtrBitWidth / 8 - : elems * std::max(8, elemTy.getIntOrFloatBitWidth()) / 8; - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } - } else if (auto callOp = dyn_cast(op)) { - auto callable = callOp.resolveCallable(); - auto funcOp = dyn_cast(callable); - auto *funcAlloc = &(*funcAllocMap)[funcOp]; - auto bytes = funcAlloc->getSharedMemorySize(); - maybeAddScratchBuffer(op, bytes, - scratchAlignment); - } else if (auto createTensormap = - dyn_cast(op)) { - constexpr int32_t kTMASize = 128; - constexpr int32_t kTMAAlign = 128; - maybeAddScratchBuffer(op, kTMASize, - kTMAAlign); - } - } - - void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) { - dataflow::Lattice *latticeElement = - analysis.getLatticeElement(value); - if (latticeElement) { - AliasInfo &info = latticeElement->getValue(); - if (!info.getAllocs().empty()) { - for (auto alloc : info.getAllocs()) { - allocation->addAlias(value, alloc); - } - } - } - } - - /// Extract all shared memory values and their sizes - void getValuesAndSizes() { - // Get the alloc values - operation->walk([&](Operation *op) { - getExplicitValueSize(op); - getScratchValueSize(op); - }); - // Get the alias values - std::unique_ptr solver = createDataFlowSolver(); - SharedMemoryAliasAnalysis *aliasAnalysis = - solver->load(); - if (failed(solver->initializeAndRun(operation))) { - // TODO: return error instead of bailing out.. - llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); - } - operation->walk([&](Operation *op) { - for (auto operand : op->getOperands()) { - getValueAlias(operand, *aliasAnalysis); - } - for (auto value : op->getResults()) { - getValueAlias(value, *aliasAnalysis); - } - }); - } - - /// Computes the liveness range of the allocated value. - /// Each buffer is allocated only once. - void resolveExplicitBufferLiveness( - function_ref(Value value)> getLiveness) { - for (auto valueBufferIter : allocation->getValueBuffer()) { - auto value = valueBufferIter.first; - auto *buffer = valueBufferIter.second; - bufferRange[buffer] = getLiveness(value); - } - } - - /// Extends the liveness range by unionizing the liveness range of the aliased - /// values because each allocated buffer could be an alias of others, if block - /// arguments are involved. - void resolveAliasBufferLiveness( - function_ref(Value value)> getLiveness) { - for (const auto &aliasBufferIter : allocation->getAliasBuffer()) { - auto value = aliasBufferIter.first; - auto buffers = aliasBufferIter.second; - auto range = getLiveness(value); - for (auto *buffer : buffers) { - auto minId = range.start(); - auto maxId = range.end(); - if (bufferRange.count(buffer)) { - // Extend the allocated buffer's range - minId = std::min(minId, bufferRange[buffer].start()); - maxId = std::max(maxId, bufferRange[buffer].end()); - } - bufferRange[buffer] = Interval(minId, maxId); - } - } - } - - /// Computes the liveness range of scratched buffers. - /// Some operations may have a temporary buffer that is not explicitly - /// allocated, but is used to store intermediate results. - void resolveScratchBufferLiveness( - const DenseMap &operationId) { - // Analyze liveness of scratch buffers and virtual buffers. - auto processScratchMemory = [&](const auto &container) { - for (auto opScratchIter : container) { - // Any scratch memory's live range is the current operation's live - // range. - auto *op = opScratchIter.first; - auto *buffer = opScratchIter.second; - bufferRange.insert({buffer, Interval(operationId.lookup(op), - operationId.lookup(op) + 1)}); - } - }; - processScratchMemory(allocation->getOpScratch()); - processScratchMemory(allocation->getOpVirtual()); - } - - /// Resolves liveness of all values involved under the root operation. - void resolveLiveness() { - // Assign an ID to each operation using post-order traversal. - // To achieve the correct liveness range, the parent operation's ID - // should be greater than each of its child operation's ID . - // Example: - // ... - // %5 = triton.convert_layout %4 - // %6 = scf.for ... iter_args(%arg0 = %0) -> (i32) { - // %2 = triton.convert_layout %5 - // ... - // scf.yield %arg0 - // } - // For example, %5 is defined in the parent region and used in - // the child region, and is not passed as a block argument. - // %6 should should have an ID greater than its child operations, - // otherwise %5 liveness range ends before the child operation's liveness - // range ends. - DenseMap operationId; - operation->walk( - [&](Operation *op) { operationId[op] = operationId.size(); }); - - // Analyze liveness of explicit buffers - Liveness liveness(operation); - auto getValueLivenessRange = [&](Value value) { - auto liveOperations = liveness.resolveLiveness(value); - auto minId = std::numeric_limits::max(); - auto maxId = std::numeric_limits::min(); - std::for_each(liveOperations.begin(), liveOperations.end(), - [&](Operation *liveOp) { - if (operationId[liveOp] < minId) { - minId = operationId[liveOp]; - } - if ((operationId[liveOp] + 1) > maxId) { - maxId = operationId[liveOp] + 1; - } - }); - return Interval(minId, maxId); - }; - - resolveExplicitBufferLiveness(getValueLivenessRange); - resolveAliasBufferLiveness(getValueLivenessRange); - resolveScratchBufferLiveness(operationId); - } - - /// Computes the shared memory offsets for all related values. - /// Paper: Algorithms for Compile-Time Memory Optimization - /// (https://dl.acm.org/doi/pdf/10.5555/314500.315082) - void computeOffsets() { - SmallVector buffers; - for (auto bufferIter : bufferRange) { - buffers.emplace_back(bufferIter.first); - } - - calculateStarts(buffers); - - // NOTE: The original paper doesn't consider interference between - // the bumped ranges. Buffers that previously do not interfere with - // could interfere after offset bumping if their liveness ranges overlap. - // Therefore, we rerun the interference graph algorithm after bumping so - // that we regroup the buffers and color them again. Since we always - // increase the buffer offset and keep reducing conflicts, we will - // eventually reach a fixed point. - GraphT interference; - buildInterferenceGraph(buffers, interference); - do { - allocate(buffers, interference); - buildInterferenceGraph(buffers, interference); - } while (!interference.empty()); - } - - /// Computes the initial shared memory offsets. - void calculateStarts(const SmallVector &buffers) { - // v = values in shared memory - // t = triplet of (size, start, end) - // shared memory space - // - - // | *******t4 - // | /|\ v2 inserts t4, t5, and t6 - // | | - // | ******t5 ************t6 - // | ^^^^^v2^^^^^^ - // | | *********************t2 - // | \|/ v2 erases t1 - // | ******t1 ^^^^^^^^^v1^^^^^^^^^ ************t3 - // |---------------------------------------------| liveness range - // 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 ... - // If the available triple's range is less than a given buffer range, - // we won't know if there has been an overlap without using graph coloring. - // Start -> Liveness Range - using TripleMapT = std::multimap>; - TripleMapT tripleMap; - tripleMap.insert(std::make_pair(0, Interval())); - SmallVector xBuffers = buffers; - while (!xBuffers.empty()) { - auto tripleIt = tripleMap.begin(); - auto offset = tripleIt->first; - auto range = tripleIt->second; - tripleMap.erase(tripleIt); - auto bufferIt = - std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) { - auto xRange = bufferRange[buffer]; - bool res = xRange.intersects(range); - for (const auto &val : tripleMap) - res = res && - !val.second.intersects(xRange); // only one buffer intersect - return res; - }); - if (bufferIt != xBuffers.end()) { - auto buffer = *bufferIt; - auto xSize = buffer->size; - auto xRange = bufferRange.lookup(buffer); - // TODO(Keren): A buffer's size shouldn't be determined here, have to - // clean it up - size_t alignOffset = buffer->setOffsetAligned(offset); - tripleMap.insert({alignOffset + xSize, - Interval{std::max(range.start(), xRange.start()), - std::min(range.end(), xRange.end())}}); - // We could either insert (range.start, xRange.start) or (range.start, - // xRange.end), both are correct and determine the potential buffer - // offset, and the graph coloring algorithm will solve the interference, - // if any - if (range.start() < xRange.start()) - tripleMap.insert({offset, Interval{range.start(), xRange.end()}}); - if (xRange.end() < range.end()) - tripleMap.insert({offset, Interval{xRange.start(), range.end()}}); - xBuffers.erase(bufferIt); - } - } - } - - /// Builds a graph of all shared memory values. Edges are created between - /// shared memory values that are overlapping. - void buildInterferenceGraph(const SmallVector &buffers, - GraphT &interference) { - // Reset interference graph - interference.clear(); - for (auto x : buffers) { - for (auto y : buffers) { - if (x == y) - continue; - auto xStart = x->offset; - auto yStart = y->offset; - auto xSize = x->size; - auto ySize = y->size; - Interval xSizeRange = {xStart, xStart + xSize}; - Interval ySizeRange = {yStart, yStart + ySize}; - auto xOpRange = bufferRange.lookup(x); - auto yOpRange = bufferRange.lookup(y); - if (xOpRange.intersects(yOpRange) && - xSizeRange.intersects(ySizeRange)) { - interference[x].insert(y); - } - } - } - } - - /// Finalizes shared memory offsets considering interference. - void allocate(const SmallVector &buffers, - const GraphT &interference) { - // Reset shared memory size - allocation->setSharedMemorySize(0); - // First-fit graph coloring - // Neighbors are nodes that interfere with each other. - // We color a node by finding the index of the first available - // non-neighboring node or the first neighboring node without any color. - // Nodes with the same color do not interfere with each other. - DenseMap colors; - for (auto value : buffers) { - colors[value] = (value == buffers[0]) ? 0 : -1; - } - SmallVector available(buffers.size()); - for (auto x : buffers) { - std::fill(available.begin(), available.end(), true); - for (auto y : interference.lookup(x)) { - int color = colors[y]; - if (color >= 0) { - available[color] = false; - } - } - auto it = std::find(available.begin(), available.end(), true); - colors[x] = std::distance(available.begin(), it); - } - // Finalize allocation - // color0: [0, 7), [0, 8), [0, 15) -> [0, 7), [0, 8), [0, 15) - // color1: [7, 9) -> [0 + 1 * 15, 9 + 1 * 15) -> [15, 24) - // color2: [8, 12) -> [8 + 2 * 15, 12 + 2 * 15) -> [38, 42) - // TODO(Keren): We are wasting memory here. - // Nodes with color2 can actually start with 24. - for (auto x : buffers) { - size_t newOffset = 0; - for (auto y : interference.lookup(x)) { - newOffset = std::max(newOffset, y->offset + y->size); - } - if (colors.lookup(x) != 0) - x->setOffsetAligned(newOffset); - allocation->setSharedMemorySize( - std::max(allocation->getSharedMemorySize(), x->offset + x->size)); - } - } - -private: - Operation *operation; - Allocation::FuncAllocMapT *funcAllocMap; - Allocation *allocation; - BufferRangeMapT bufferRange; -}; - -} // namespace triton::intel - -template <> -void Allocation::run( - FuncAllocMapT &funcAllocMap) { - triton::intel::AllocationAnalysis(getOperation(), &funcAllocMap, this); +} // namespace + +unsigned allocationAnalysisScratchSizeFn(Operation *op) { + return TypeSwitch(op) + .Case([](auto op) { + unsigned size = allocationAnalysisScratchSizeFn(op); + return size == invalidSize ? defaultAllocationAnalysisScratchSizeFn(op) + : size; + }) + .Default([](Operation *op) { + return defaultAllocationAnalysisScratchSizeFn(op); + }); } - -} // namespace mlir +} // namespace mlir::triton::intel diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp index 1a9e44e92e..f44489c501 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp @@ -22,8 +22,8 @@ struct AllocateSharedMemory void runOnOperation() override { ModuleOp mod = getOperation(); MLIRContext *ctx = &getContext(); - ModuleAllocation allocation = - ModuleAllocation::get(mod); + ModuleAllocation allocation( + mod, ::mlir::triton::intel::allocationAnalysisScratchSizeFn); mod.walk([&](FunctionOpInterface funcOp) { if (allocation.isRoot(funcOp) && allocation.getSharedMemorySize()) { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index a4c2da184e..7feadbd22d 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -97,8 +97,8 @@ struct ConvertTritonGPUToLLVM // Allocate shared memory and set barrier if (!pipelineManager.skipSharedMemoryAllocation()) { - ModuleAllocation allocation = - ModuleAllocation::get(mod); + ModuleAllocation allocation( + mod, ::mlir::triton::intel::allocationAnalysisScratchSizeFn); ModuleMembarAnalysis membarPass(&allocation, ::mlir::intel::membarFilter); membarPass.run(); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 9c7cfc044d..6674c9a810 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -96,7 +96,7 @@ struct ConvertTritonGPUToLLVM int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); // Allocate shared memory and set barrier - ModuleAllocation allocation = ModuleAllocation::get(mod); + ModuleAllocation allocation(mod); ModuleMembarAnalysis membarPass(&allocation, NVIDIA::canSkipBarSync); membarPass.run(); From b3ca9884310650690e6d82abd644dee0d932e385 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 19 Nov 2024 17:46:50 +0100 Subject: [PATCH 09/11] Run `test_debug.py` on XPU (#2754) Closes #2753 Pass rate: 93.25% -> 93.16% --------- Signed-off-by: Anatoly Myachev --- scripts/requirements-test.txt | 1 + scripts/skiplist/a770/debug.txt | 20 ++++++++++++++++++++ scripts/skiplist/conda/debug.txt | 20 ++++++++++++++++++++ scripts/skiplist/default/debug.txt | 20 ++++++++++++++++++++ scripts/skiplist/lts/debug.txt | 20 ++++++++++++++++++++ scripts/skiplist/mtl/debug.txt | 20 ++++++++++++++++++++ scripts/skiplist/xe2/debug.txt | 20 ++++++++++++++++++++ scripts/test-triton.sh | 3 +++ 8 files changed, 124 insertions(+) create mode 100644 scripts/skiplist/a770/debug.txt create mode 100644 scripts/skiplist/conda/debug.txt create mode 100644 scripts/skiplist/default/debug.txt create mode 100644 scripts/skiplist/lts/debug.txt create mode 100644 scripts/skiplist/mtl/debug.txt create mode 100644 scripts/skiplist/xe2/debug.txt diff --git a/scripts/requirements-test.txt b/scripts/requirements-test.txt index 4a9fbef6a7..c966bb12df 100644 --- a/scripts/requirements-test.txt +++ b/scripts/requirements-test.txt @@ -12,6 +12,7 @@ tabulate # Used by test-triton.sh pytest-xdist +pytest-forked pytest-rerunfailures pytest-select pytest-timeout diff --git a/scripts/skiplist/a770/debug.txt b/scripts/skiplist/a770/debug.txt new file mode 100644 index 0000000000..f3fde1d586 --- /dev/null +++ b/scripts/skiplist/a770/debug.txt @@ -0,0 +1,20 @@ +https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 +test/unit/test_debug.py::test_device_assert[True-True-True-False] +test/unit/test_debug.py::test_device_assert[True-False-None-False] +test/unit/test_debug.py::test_device_assert[False-True-True-False] +test/unit/test_debug.py::test_device_assert[False-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-None-False] +test/unit/test_debug.py::test_device_assert[True-False-True-False] +test/unit/test_debug.py::test_device_assert[False-False-True-False] +test/unit/test_debug.py::test_device_assert[False-True-None-False] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-2147483648--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-100-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-32768--1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[32767-1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[-1073741824--4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-2-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[-2147483648-1-int32-int32-True-True] diff --git a/scripts/skiplist/conda/debug.txt b/scripts/skiplist/conda/debug.txt new file mode 100644 index 0000000000..f3fde1d586 --- /dev/null +++ b/scripts/skiplist/conda/debug.txt @@ -0,0 +1,20 @@ +https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 +test/unit/test_debug.py::test_device_assert[True-True-True-False] +test/unit/test_debug.py::test_device_assert[True-False-None-False] +test/unit/test_debug.py::test_device_assert[False-True-True-False] +test/unit/test_debug.py::test_device_assert[False-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-None-False] +test/unit/test_debug.py::test_device_assert[True-False-True-False] +test/unit/test_debug.py::test_device_assert[False-False-True-False] +test/unit/test_debug.py::test_device_assert[False-True-None-False] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-2147483648--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-100-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-32768--1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[32767-1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[-1073741824--4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-2-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[-2147483648-1-int32-int32-True-True] diff --git a/scripts/skiplist/default/debug.txt b/scripts/skiplist/default/debug.txt new file mode 100644 index 0000000000..f3fde1d586 --- /dev/null +++ b/scripts/skiplist/default/debug.txt @@ -0,0 +1,20 @@ +https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 +test/unit/test_debug.py::test_device_assert[True-True-True-False] +test/unit/test_debug.py::test_device_assert[True-False-None-False] +test/unit/test_debug.py::test_device_assert[False-True-True-False] +test/unit/test_debug.py::test_device_assert[False-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-None-False] +test/unit/test_debug.py::test_device_assert[True-False-True-False] +test/unit/test_debug.py::test_device_assert[False-False-True-False] +test/unit/test_debug.py::test_device_assert[False-True-None-False] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-2147483648--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-100-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-32768--1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[32767-1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[-1073741824--4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-2-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[-2147483648-1-int32-int32-True-True] diff --git a/scripts/skiplist/lts/debug.txt b/scripts/skiplist/lts/debug.txt new file mode 100644 index 0000000000..f3fde1d586 --- /dev/null +++ b/scripts/skiplist/lts/debug.txt @@ -0,0 +1,20 @@ +https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 +test/unit/test_debug.py::test_device_assert[True-True-True-False] +test/unit/test_debug.py::test_device_assert[True-False-None-False] +test/unit/test_debug.py::test_device_assert[False-True-True-False] +test/unit/test_debug.py::test_device_assert[False-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-None-False] +test/unit/test_debug.py::test_device_assert[True-False-True-False] +test/unit/test_debug.py::test_device_assert[False-False-True-False] +test/unit/test_debug.py::test_device_assert[False-True-None-False] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-2147483648--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-100-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-32768--1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[32767-1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[-1073741824--4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-2-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[-2147483648-1-int32-int32-True-True] diff --git a/scripts/skiplist/mtl/debug.txt b/scripts/skiplist/mtl/debug.txt new file mode 100644 index 0000000000..f3fde1d586 --- /dev/null +++ b/scripts/skiplist/mtl/debug.txt @@ -0,0 +1,20 @@ +https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 +test/unit/test_debug.py::test_device_assert[True-True-True-False] +test/unit/test_debug.py::test_device_assert[True-False-None-False] +test/unit/test_debug.py::test_device_assert[False-True-True-False] +test/unit/test_debug.py::test_device_assert[False-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-None-False] +test/unit/test_debug.py::test_device_assert[True-False-True-False] +test/unit/test_debug.py::test_device_assert[False-False-True-False] +test/unit/test_debug.py::test_device_assert[False-True-None-False] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-2147483648--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-100-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-32768--1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[32767-1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[-1073741824--4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-2-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[-2147483648-1-int32-int32-True-True] diff --git a/scripts/skiplist/xe2/debug.txt b/scripts/skiplist/xe2/debug.txt new file mode 100644 index 0000000000..f3fde1d586 --- /dev/null +++ b/scripts/skiplist/xe2/debug.txt @@ -0,0 +1,20 @@ +https://github.com/intel/intel-xpu-backend-for-triton/issues/2755 +test/unit/test_debug.py::test_device_assert[True-True-True-False] +test/unit/test_debug.py::test_device_assert[True-False-None-False] +test/unit/test_debug.py::test_device_assert[False-True-True-False] +test/unit/test_debug.py::test_device_assert[False-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-False-False] +test/unit/test_debug.py::test_device_assert[True-True-None-False] +test/unit/test_debug.py::test_device_assert[True-False-True-False] +test/unit/test_debug.py::test_device_assert[False-False-True-False] +test/unit/test_debug.py::test_device_assert[False-True-None-False] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-2147483648--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[2147483647-100-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[-32768--1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_add_overflow[32767-1-int16-int16-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[-1073741824--4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-2-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_mul_overflow[1073741824-4-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[2147483647--1-int32-int32-True-True] +test/unit/test_debug.py::test_sanitize_int_sub_overflow[-2147483648-1-int32-int32-True-True] diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 9fb2256498..d35e1983d8 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -191,6 +191,9 @@ run_core_tests() { TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=runtime \ pytest --verbose --device xpu runtime/ + TRITON_TEST_SUITE=debug \ + pytest --verbose -n ${PYTEST_MAX_PROCESSES:-8} test_debug.py --forked --device xpu + # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 TRITON_DISABLE_LINE_INFO=0 TRITON_TEST_SUITE=line_info \ pytest -k "not test_line_info_interpreter" --verbose --device xpu language/test_line_info.py From 513aedeb9e91e3a72f1d28003f67c59462b77778 Mon Sep 17 00:00:00 2001 From: Gregory Shimansky Date: Tue, 19 Nov 2024 14:40:25 -0600 Subject: [PATCH 10/11] Windows native port (#2478) Fixes #2407. Current state is that code builds on windows and is able to pass many unit tests. More fixes will be done when this patch is integrated into main branch. --------- Signed-off-by: Gregory Shimansky --- CMakeLists.txt | 65 +++++++++++++++----- python/setup.py | 58 ++++++++++++++++- python/triton/runtime/CLFinder.py | 55 +++++++++++++++++ python/triton/runtime/build.py | 45 +++++++++++--- third_party/intel/backend/driver.py | 10 ++- third_party/nvidia/backend/driver.py | 11 +++- third_party/nvidia/include/cublas_instance.h | 18 +++++- 7 files changed, 229 insertions(+), 33 deletions(-) create mode 100644 python/triton/runtime/CLFinder.py diff --git a/CMakeLists.txt b/CMakeLists.txt index da6eec1410..aa9bd605c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,23 +8,24 @@ endif() include(ExternalProject) -set(CMAKE_CXX_STANDARD 17) - set(CMAKE_INCLUDE_CURRENT_DIR ON) project(triton CXX) include(CTest) -if(NOT WIN32) - list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") -endif() - - +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # Options +if(WIN32) + set(DEFAULT_BUILD_PROTON OFF) +else() + set(DEFAULT_BUILD_PROTON ON) +endif() + +# Define the option with the determined default value +option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ${DEFAULT_BUILD_PROTON}) option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON) option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) -option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON) set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") @@ -49,10 +50,21 @@ endif() # used conditionally in this file and by lit tests # Customized release build type with assertions: TritonRelBuildWithAsserts -set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") -set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") -set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") +if(NOT MSVC) + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g") + set(CMAKE_C_FLAGS_TRITONBUILDWITHO1 "-O1") + set(CMAKE_CXX_FLAGS_TRITONBUILDWITHO1 "-O1") +else() + set(CMAKE_CXX_STANDARD 20) + set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor") + set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1 /bigobj /Zc:preprocessor") + set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") +endif() # Default build type if(NOT CMAKE_BUILD_TYPE) @@ -70,7 +82,15 @@ endif() # Compiler flags include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17") +if(NOT MSVC) + if(NOT WIN32) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -fPIC") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS -Wno-deprecated") + endif() +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530") +endif() # ######### @@ -124,7 +144,11 @@ endfunction() # Disable warnings that show up in external code (gtest;pybind11) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-") +endif() include_directories(".") include_directories(${MLIR_INCLUDE_DIRS}) @@ -134,7 +158,8 @@ include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files include_directories(${PROJECT_SOURCE_DIR}/third_party) include_directories(${PROJECT_BINARY_DIR}/third_party) # Tablegen'd files -# link_directories(${LLVM_LIBRARY_DIR}) +link_directories(${LLVM_LIBRARY_DIR}) + add_subdirectory(include) add_subdirectory(lib) @@ -163,6 +188,8 @@ if(TRITON_BUILD_PYTHON_MODULE) # using pip install. include_directories(${PYTHON_INCLUDE_DIRS}) include_directories(${PYBIND11_INCLUDE_DIR}) + message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}") + link_directories(${PYTHON_LIB_DIRS}) else() # Otherwise, we might be building from top CMakeLists.txt directly. # Try to find Python and pybind11 packages. @@ -245,7 +272,7 @@ if(TRITON_BUILD_PYTHON_MODULE) LLVMAArch64CodeGen LLVMAArch64AsmParser ) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64") list(APPEND TRITON_LIBRARIES LLVMX86CodeGen LLVMX86AsmParser @@ -280,6 +307,8 @@ if(TRITON_BUILD_PYTHON_MODULE) target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES}) if(WIN32) target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) + set_target_properties(triton PROPERTIES SUFFIX ".pyd") + set_target_properties(triton PROPERTIES PREFIX "lib") else() target_link_libraries(triton PRIVATE z) endif() @@ -306,6 +335,10 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() endif() +if(WIN32) + option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON) + option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON) +endif() add_subdirectory(third_party/f2reduce) add_subdirectory(bin) diff --git a/python/setup.py b/python/setup.py index 6bc6e1e489..95e0519b37 100644 --- a/python/setup.py +++ b/python/setup.py @@ -103,6 +103,49 @@ def copy_externals(): ] +def find_vswhere(): + program_files = os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)") + vswhere_path = Path(program_files) / "Microsoft Visual Studio" / "Installer" / "vswhere.exe" + if vswhere_path.exists(): + return vswhere_path + return None + + +def find_visual_studio(version_ranges): + vswhere = find_vswhere() + if not vswhere: + raise FileNotFoundError("vswhere.exe not found.") + + for version_range in version_ranges: + command = [ + str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", + "-property", "installationPath", "-prerelease" + ] + + try: + output = subprocess.check_output(command, text=True).strip() + if output: + return output + except subprocess.CalledProcessError: + continue + + return None + + +def set_env_vars(vs_path, arch="x64"): + vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat" + if not vcvarsall_path.exists(): + raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}") + + command = ["call", vcvarsall_path, arch, "&&", "set"] + output = subprocess.check_output(command, shell=True, text=True) + + for line in output.splitlines(): + if '=' in line: + var, value = line.split('=', 1) + os.environ[var] = value + + # Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py def check_env_flag(name: str, default: str = "") -> bool: return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"] @@ -196,6 +239,8 @@ def get_llvm_package_info(): f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build." ) return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") + elif system == 'Windows': + system_suffix = "windows-x64" else: print( f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build." @@ -281,10 +326,10 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func): base_dir = os.path.dirname(__file__) system = platform.system() try: - arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] + arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] except KeyError: arch = platform.machine() - supported = {"Linux": "linux", "Darwin": "linux"} + supported = {"Linux": "linux", "Darwin": "linux", "Windows": "win"} url = url_func(supported[system], arch, version) tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path @@ -401,6 +446,11 @@ def get_proton_cmake_args(self): def build_extension(self, ext): lit_dir = shutil.which('lit') ninja_dir = shutil.which('ninja') + if platform.system() == "Windows": + vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"]) + env = set_env_vars(vs_path) + if not vs_path: + raise EnvironmentError("Visual Studio 2019 or 2022 not found.") # lit is used by the test suite thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()]) thirdparty_cmake_args += self.get_pybind11_cmake_args() @@ -421,6 +471,10 @@ def build_extension(self, ext): "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]), "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]) ] + if platform.system() == "Windows": + installed_base = sysconfig.get_config_var('installed_base') + py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs")) + cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs) if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) cmake_args.extend(thirdparty_cmake_args) diff --git a/python/triton/runtime/CLFinder.py b/python/triton/runtime/CLFinder.py new file mode 100644 index 0000000000..2021e0b04f --- /dev/null +++ b/python/triton/runtime/CLFinder.py @@ -0,0 +1,55 @@ +import os +import subprocess +from pathlib import Path + + +def find_vswhere(): + program_files = os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)") + vswhere_path = Path(program_files) / "Microsoft Visual Studio" / "Installer" / "vswhere.exe" + if vswhere_path.exists(): + return vswhere_path + return None + + +def find_visual_studio(version_ranges): + vswhere = find_vswhere() + if not vswhere: + raise FileNotFoundError("vswhere.exe not found.") + + for version_range in version_ranges: + command = [ + str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", + "-property", "installationPath", "-prerelease" + ] + + try: + output = subprocess.check_output(command, text=True).strip() + if output: + return output + except subprocess.CalledProcessError: + continue + + return None + + +def set_env_vars(vs_path, arch="x64"): + vcvarsall_path = Path(vs_path) / "VC" / "Auxiliary" / "Build" / "vcvarsall.bat" + if not vcvarsall_path.exists(): + raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}") + + command = f'call "{vcvarsall_path}" {arch} && set' + output = subprocess.check_output(command, shell=True, text=True) + + for line in output.splitlines(): + if '=' in line: + var, value = line.split('=', 1) + os.environ[var] = value + + +def initialize_visual_studio_env(version_ranges, arch="x64"): + # Check if the environment variable that vcvarsall.bat sets is present + if os.environ.get('VSCMD_ARG_TGT_ARCH') != arch: + vs_path = find_visual_studio(version_ranges) + if not vs_path: + raise EnvironmentError("Visual Studio not found in specified version ranges.") + set_env_vars(vs_path, arch) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index aae62030e4..b40411aec6 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -6,6 +6,8 @@ import shutil import subprocess import setuptools +import platform +from .CLFinder import initialize_visual_studio_env def is_xpu(): @@ -23,6 +25,29 @@ def quiet(): sys.stdout, sys.stderr = old_stdout, old_stderr +def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries): + if cc in ["cl", "clang-cl"]: + cc_cmd = [cc, src, "/nologo", "/O2", "/LD"] + cc_cmd += [f"/I{dir}" for dir in include_dirs] + cc_cmd += [f"/Fo{os.path.join(os.path.dirname(out), 'main.obj')}"] + cc_cmd += ["/link"] + cc_cmd += [f"/OUT:{out}"] + cc_cmd += [f"/IMPLIB:{os.path.join(os.path.dirname(out), 'main.lib')}"] + cc_cmd += [f"/PDB:{os.path.join(os.path.dirname(out), 'main.pdb')}"] + cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs] + cc_cmd += [f'{lib}.lib' for lib in libraries] + else: + cc_cmd = [cc, src, "-O3", "-shared", "-Wno-psabi"] + if os.name != "nt": + cc_cmd += ["-fPIC"] + cc_cmd += [f'-l{lib}' for lib in libraries] + cc_cmd += [f"-L{dir}" for dir in library_dirs] + cc_cmd += [f"-I{dir}" for dir in include_dirs] + cc_cmd += ["-o", out] + + return cc_cmd + + def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compile_args=[]): suffix = sysconfig.get_config_var('EXT_SUFFIX') so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) @@ -33,6 +58,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi clang = shutil.which("clang") gcc = shutil.which("gcc") cc = gcc if gcc is not None else clang + if platform.system() == "Windows": + cc = "cl" + initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"]) if cc is None: raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") # This function was renamed and made public in Python 3.10 @@ -55,25 +83,24 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi clangpp = shutil.which("clang++") gxx = shutil.which("g++") icpx = shutil.which("icpx") - cxx = icpx or clangpp or gxx + cxx = icpx if os.name == "nt" else icpx or clangpp or gxx if cxx is None: raise RuntimeError("Failed to find C++ compiler. Please specify via CXX environment variable.") + cc = cxx import numpy as np numpy_include_dir = np.get_include() include_dirs = include_dirs + [numpy_include_dir] - cc_cmd = [cxx] if icpx is not None: - cc_cmd += ["-fsycl"] + extra_compile_args += ["-fsycl"] else: - cc_cmd += ["--std=c++17"] + extra_compile_args += ["--std=c++17"] + if os.name == "nt": + library_dirs += [os.path.join(sysconfig.get_paths(scheme=scheme)["stdlib"], "..", "libs")] else: cc_cmd = [cc] # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 - cc_cmd += [src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] - cc_cmd += [f'-l{lib}' for lib in libraries] - cc_cmd += [f"-L{dir}" for dir in library_dirs] - cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries) cc_cmd += extra_compile_args if os.getenv("VERBOSE"): @@ -90,7 +117,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries, extra_compi language='c', sources=[src], include_dirs=include_dirs, - extra_compile_args=extra_compile_args + ['-O3'], + extra_compile_args=extra_compile_args + ['-O3' if "-O3" in cc_cmd else "/O2"], extra_link_args=extra_link_args, library_dirs=library_dirs, libraries=libraries, diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index ec35e4e46a..f372091e56 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -34,7 +34,7 @@ def find_sycl(include_dir: list[str]) -> tuple[list[str], Optional[str]]: "or provide `ONEAPI_ROOT` environment " "or install `intel-sycl-rt>=2025.0.0` wheel") - if shutil.which("icpx"): + if shutil.which("icpx") and os.name != "nt": # only `icpx` compiler knows where sycl runtime binaries and header files are return include_dir, None @@ -74,7 +74,9 @@ def __init__(self): self._library_dir = None self._include_dir = None self._libsycl_dir = None - self.libraries = ['ze_loader', 'sycl'] + self.libraries = ['ze_loader'] + if os.name != "nt": + self.libraries += ["sycl"] @cached_property def _compute_compilation_options_lazy(self): @@ -85,6 +87,8 @@ def _compute_compilation_options_lazy(self): include_dir, self._libsycl_dir = find_sycl(include_dir) if self._libsycl_dir: library_dir += [self._libsycl_dir] + if os.name == "nt": + library_dir += [os.path.join(ze_root, "lib")] dirname = os.path.dirname(os.path.realpath(__file__)) include_dir += [os.path.join(dirname, "include")] @@ -215,7 +219,7 @@ def format_of(ty): "int8_t": "b", "int16_t": "h", "int32_t": "i", - "int64_t": "l", + "int64_t": "L", "uint8_t": "B", "uint16_t": "H", "uint32_t": "I", diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index e0dac56d01..5f946e70bb 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -2,8 +2,10 @@ import os import sysconfig import hashlib +import sysconfig import subprocess import tempfile +import sys from pathlib import Path from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager @@ -22,6 +24,11 @@ def libcuda_dirs(): env_libcuda_path = os.getenv("TRITON_LIBCUDA_PATH") if env_libcuda_path: return [env_libcuda_path] + if os.name == "nt": + installed_base = sysconfig.get_config_var('installed_base') + dirs = [os.path.join(os.environ.get("CUDA_PATH"), "lib", "x64")] + dirs += [os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))] + return dirs libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() # each line looks like the following: @@ -139,7 +146,7 @@ def format_of(ty): "int8_t": "b", "int16_t": "h", "int32_t": "i", - "int64_t": "l", + "int64_t": "L", "uint8_t": "B", "uint16_t": "H", "uint32_t": "I", @@ -235,7 +242,7 @@ def format_of(ty): #endif static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(params)} }}; + void *params[] = {{{', '.join(f'&arg{i}' for i in params) if params else 'NULL'}}}; if (gridX*gridY*gridZ > 0) {{ if (num_ctas == 1) {{ CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); diff --git a/third_party/nvidia/include/cublas_instance.h b/third_party/nvidia/include/cublas_instance.h index d79d4d76bf..ad0fb2e9f3 100644 --- a/third_party/nvidia/include/cublas_instance.h +++ b/third_party/nvidia/include/cublas_instance.h @@ -2,7 +2,13 @@ #define TRITON_CUBLAS_INSTANCE_H #include "cublas_types.h" +#ifdef WIN32 +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#else #include +#endif #include #include @@ -64,6 +70,9 @@ class CublasLtInstance { cublasLtMatmulPreference_t preference = NULL; void loadCublasDylib() { +#ifdef WIN32 + assert(0 && "Not implemented on Windows"); +#else if (dylibHandle == nullptr) { // First reuse the existing handle dylibHandle = dlopen(name, RTLD_NOLOAD); @@ -108,9 +117,16 @@ class CublasLtInstance { std::string(name) + "`: " + std::string(dlsym_error)); } +#endif } - void unloadCublasDylib() { dlclose(dylibHandle); } + void unloadCublasDylib() { +#ifdef WIN32 + assert(0 && "Not implemented on Windows"); +#else + dlclose(dylibHandle); +#endif + } void successOrExit(cublasStatus_t status) { if (status != CUBLAS_STATUS_SUCCESS) { From a8ca9e558026ff49c3bb74c6471c112b04f63d2d Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Tue, 19 Nov 2024 15:40:57 -0500 Subject: [PATCH 11/11] Fix coalescing pass (#2760) Fix Intel coalescing pass for cases where the result of a SCF loop (containing a coalescable block ptr load) is used by an operation with operands that do not have block ptr type (e.g. `tt.reduce`) --------- Signed-off-by: Tiotto, Ettore --- test/TritonIntelGPU/coalesce.mlir | 47 +++++++++++++++++++ .../lib/TritonIntelGPUTransforms/Coalesce.cpp | 11 +++-- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/test/TritonIntelGPU/coalesce.mlir b/test/TritonIntelGPU/coalesce.mlir index d9b2de454c..b078158d8b 100644 --- a/test/TritonIntelGPU/coalesce.mlir +++ b/test/TritonIntelGPU/coalesce.mlir @@ -336,3 +336,50 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +// COM: Test coalescing on blocked pointers: loop result used by tt.reduce + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}> + // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [1, 1, 16], order = [0, 1, 2]}> + // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}> + // CHECK: @triton_red_fused_mul_sum_0 + tt.func public @triton_red_fused_mul_sum_0(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c128_i32 = arith.constant 128 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %c262144_i64 = arith.constant 262144 : i64 + %c1_i64 = arith.constant 1 : i64 + %c512_i64 = arith.constant 512 : i64 + %c32_i32 = arith.constant 32 : i32 + %c512_i32 = arith.constant 512 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c32_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %3 = tt.expand_dims %2 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %4 = arith.divsi %1, %c512_i32 : i32 + %5 = arith.remsi %1, %c512_i32 : i32 + // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg0, {{.*}} : > + %6 = tt.make_tensor_ptr %arg0, [%c512_i64, %c512_i64, %c512_i64], [%c1_i64, %c512_i64, %c262144_i64], [%4, %5, %c0_i32] {order = array} : > + // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR1]], [[ARG2:%.*]] = {{.*}}) -> (!tt.ptr>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>) + %8:2 = scf.for %arg5 = %c0_i32 to %c512_i32 step %c128_i32 iter_args(%arg6 = %6, %arg8 = %cst_0) -> (!tt.ptr>, tensor<32x128xf32, #blocked>) : i32 { + // CHECK: [[LOAD:%.*]] = tt.load [[ARG1]] evictionPolicy = evict_last {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]> + %17 = tt.load %arg6 evictionPolicy = evict_last {boundaryCheck = array, padding = 1 : i32} : !tt.ptr> + // CHECK: scf.yield [[ARG1]], [[ARG2]] : !tt.ptr>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]> + scf.yield %arg6, %arg8 : !tt.ptr>, tensor<32x128xf32, #blocked> + } + // CHECK: = "tt.reduce"([[RES]]#1) <{axis = 1 : i32}> ({ + // CHECK }) : (tensor<32x128xf32, [[BLOCKED_LAYOUT]]) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = [[BLOCKED_LAYOUT]]}>> + %9 = "tt.reduce"(%8#1) <{axis = 1 : i32}> ({ + ^bb0(%arg5: f32, %arg6: f32): + %14 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %14 : f32 + }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp index 7f52090f4e..978622ecc0 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp @@ -148,11 +148,13 @@ struct CoalescePass if (op->getNumResults() == 0 && op->getNumRegions() == 0) return true; + // Operations that do not consume a block pointer aren't interesting. + if (llvm::none_of(op->getOperandTypes(), tt::isTensorPointerType)) + return true; + // Operations that do not yield a block pointer aren't interesting. if (op->getNumRegions() == 0 && - llvm::none_of(op->getResultTypes(), [](Type resType) { - return tt::isTensorPointerType(resType); - })) + llvm::none_of(op->getResultTypes(), tt::isTensorPointerType)) return true; return false; @@ -367,8 +369,7 @@ struct CoalescePass }); LLVM_DEBUG({ - DBGS() << "\nlayoutMap:" - << "\n"; + DBGS() << "\nlayoutMap:\n"; for (auto [op, encoding] : layoutMap) { DBGS() << "op: " << *op << "\n"; DBGS() << "encoding: " << encoding << "\n\n";