diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 93d87734c7..df6029db0d 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy); bool atomicNeedsSharedMemory(Value result); -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT); +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); // Return true if the src and dst layout match. bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index c728cfbb32..47e3fca79b 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -679,6 +679,13 @@ class LinearLayout { // (i.e. every input bit affects the output). llvm::MapVector getFreeVariableMasks() const; + // Increase an input dimension without affecting the output dimension. The + // added free variables are mapped to 0, ensuring that the new input + // dimensions correspond directly to the existing output space. The function + // errors out if `newInDimSize` is less than the current size or the new size + // is not a power of 2. + LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const; + std::string toString() const; friend bool operator==(LinearLayout lhs, LinearLayout rhs); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index ddce7f544a..78e1ca787f 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -543,7 +543,7 @@ bool supportMMA(Value value, int version) { (elemTy.isInteger(8) && version >= 2); } -bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { +bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) { auto blockedLayout = dyn_cast(srcTy.getEncoding()); auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); if (blockedLayout == nullptr || dotOperandLayout == nullptr) @@ -646,8 +646,46 @@ std::optional minimalCvtLayout(RankedTensorType srcTy, toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); if (!(srcLayout.has_value() && dstLayout.has_value())) return std::nullopt; + StringAttr kRegister = StringAttr::get(ctx, "register"); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kBlock = StringAttr::get(ctx, "block"); + auto numSrcRegs = srcLayout->getInDimSize(kRegister); + auto numDstRegs = dstLayout->getInDimSize(kRegister); + // The `invertAndCompose` function will generate a layout that is injective + // by assigning new output dimensions to free variables. For instance, + // consider a scenario where `srcLayout` has a free variable in the lane + // dimension, while `dstLayout` has two free variables in the lane + // dimension and also a larger number of registers. + // The injective form of `srcLayout` will add only a single additional row + // to the transformation matrix, whereas the injective form of `dstLayout` + // will add two additional rows. This discrepancy causes misleading results + // because the matrices end up with a different number of rows. + // + // Take `dstLayout ⋅ srcLayout^-1` as an example: + // + // - `injective(dstLayout)`: [n, m] → [n + 2, m] + // - `injective(srcLayout)`: [n, m] → [n + 1, m] + // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1] + // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n + + // 1] → [n + 2, n + 1] + // + // Here, the `(n + 1)`-th row added by `dstLayout` represents the free + // variable in registers, and the `(n + 2)`-th row represents the free + // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout` + // represents the free variable in lanes. As a result, the `(n + 1)`-th row + // in two layouts do not correspond to the same free variable. + // + // To address this issue, we pad the free variables in `srcLayout` and + // `dstLayout` to ensure they have the same number of registers. This + // guarantees that the resulting matrices have the same number of rows, + // ensuring consistency in the composition process. + auto numRegs = std::max(numSrcRegs, numDstRegs); + auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs); + auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs); // comp describes the layout function to create dst from src. - LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); + LinearLayout comp = + dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs); // We try to quotient by the largest subspace first auto dims = SmallVector{"block", "warp", "lane", "register"}; for (auto dim : dims) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 65ee8cc002..b889d4812c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -328,20 +328,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } else { // Cast 5. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. - auto dstCvt = requiresI32Conversion(dstTy); - auto srcCvt = requiresI32Conversion(srcTy); - if (dstCvt || srcCvt) { - auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(), - getTypeConverter()); - inVals = - packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter()); - auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals, - rewriter, op.getType()); - rewriter.replaceOp(op, res); - } else { - rewriter.replaceOp(op, adaptor.getSrc()); - } + rewriter.replaceOp(op, adaptor.getSrc()); return success(); } } @@ -358,9 +345,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); SmallVector outVals(numRegs); - for (int i = 0; i < numRegs; i++) { + for (int i = 0; i < outVals.size(); i++) { // Remove free masks from the register index // For example, if idx = 0b00111, and masks = 0b00100, then we get // 0b00011. It means that register 7 (0b111) has the same value as diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index bf017f8c64..4319d1f086 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const { return true; } +LinearLayout LinearLayout::resize(StringAttr inDim, + int32_t newInDimSize) const { + BasesT bases = getBases(); + assert(bases.contains(inDim) && "inDim not in layout"); + assert(llvm::isPowerOf2_32(newInDimSize) && + "newInDimSize must be a power of 2"); + assert(newInDimSize >= getInDimSize(inDim) && + "newInDimSize must be >= old size"); + auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim); + for (int i = 0; i < numFreeVariables; i++) { + bases[inDim].push_back(std::vector(getNumOutDims(), 0)); + } + return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames())); +} + std::string LinearLayout::toString() const { // Start with a newline because we print out a bulleted list; it doesn't // make sense for the first line of this list to be on the same line as diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 66bba6e060..146c6ffec7 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -946,6 +946,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // ----- +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: convert_layout_mmav2_dot_reg + tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_0 + tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_1 + tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_2 + tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0> + tt.return + } +} + +// ----- + +#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: convert_layout_mmav3_mmav3_3 + tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) { + // CHECK-NOT: st.shared + // CHECK-NOT: llvm.load + %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0> + tt.return + } +} + +// ----- + #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index f006447002..897172fd6d 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -747,6 +747,39 @@ TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); } +TEST_F(LinearLayoutTest, Resize) { + auto init = LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")}); + EXPECT_EQ(init.resize(S("in0"), 8), + LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}, {0, 0}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(init.resize(S("in0"), 4), LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ(init.resize(S("in1"), 8), + LinearLayout( + { + {S("in0"), {{0, 1}, {0, 2}}}, + {S("in1"), {{1, 0}, {2, 0}, {0, 0}}}, + {S("in2"), {}}, + }, + {S("dim0"), S("dim1")})); +} + } // anonymous namespace } // namespace mlir::triton