From e5be006a4f8c1d8a47ae7c618844eece8ec8612c Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 9 Dec 2024 00:32:48 -0800 Subject: [PATCH 1/9] [AMD] Use warp shuffle for fp8 MFMA to dot operand layout conversion (#5362) This relands https://github.com/triton-lang/triton/pull/5139: Adding a shortcut case for fp8 MFMA to dot operand layout conversion that avoids using shared memory, to speed up FP8 attention kernels. --------- Co-authored-by: ilia-cher <30845429+ilia-cher@users.noreply.github.com> --- include/triton/Analysis/Utility.h | 5 + lib/Analysis/Utility.cpp | 25 ++- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 6 + test/Conversion/amd/mfma-shortcut.mlir | 190 +++++++++++++++++- .../ConvertLayoutOpToLLVM.cpp | 154 ++++++++++++++ 5 files changed, 378 insertions(+), 2 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index e06db19c6d..a3e38e177d 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -231,6 +231,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy); bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, RankedTensorType dstTy); +// Check if MFMA layout can be converted to the dot operand +// layout using warp shuffle. +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy); + // TODO: Move utility functions that belong to ConvertLayoutOp to class // ConvertLayoutOpHelper in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 3a8be9ee33..3014245e61 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" +#include "triton/Conversion/MLIRTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -644,6 +645,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, return ans; } +bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto dotOperandLayout = dyn_cast(dstTy.getEncoding()); + if (!mfmaLayout || !dotOperandLayout) + return false; + + // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case + return dotOperandLayout.getParent() == mfmaLayout && + dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() && + dotOperandLayout.getKWidth() == 8 && + getContigPerThread(mfmaLayout)[1] == 4 && + ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) || + (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) && + triton::type::isFloat8(srcTy.getElementType()) && + triton::type::isFloat8(dstTy.getElementType()) && + mfmaLayout.getWarpsPerCTA()[1] == 1; +} + // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity // under kBlock, kWarp or kLane (in that order). The idea here is that if we // have a transformation that's the identity on kBlock, we don't need to use @@ -708,7 +728,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { // supported yet in Triton's backend. 
return !cvtReordersRegisters(srcTy, dstTy) && !isBlockedToDotShortcut(srcTy, dstTy) && - !matchMmaV3AndDotOperandLayout(srcTy, dstTy); + !matchMmaV3AndDotOperandLayout(srcTy, dstTy) && + // to be removed when generalized warp shuffle conversions + // are ready: + !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy); } bool atomicNeedsSharedMemory(Value value) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 7e8f6b7836..4b0993c865 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -391,6 +391,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return failure(); } + // The following check can be removed when generalized warp shuffle + // conversions are ready: + if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) { + return failure(); + } + assert(cvtNeedsSharedMemory(srcTy, dstTy)); SmallVector inVals = diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 9a9764d992..bdf6db4e69 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx90a" -split-input-file | FileCheck %s +// RUN: triton-opt %s --decompose-unsupported-amd-conversions --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> @@ -27,3 +27,191 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.return } } + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma32 + tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + 
// CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1) + // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0) + + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32 + tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_f8_mfma16 + tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + + // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + + // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] + // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] + // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] + // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], 
[[bvec0]] + // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] + + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] + + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] + // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] + + // Input (8 values): (vec0, vec1) + // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): + // resVec0 resVec1 + // lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1) + // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0) + // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1) + // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0) + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> + tt.return + } +} + +// ----- + +#mfma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +#dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16 + tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) { + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return + %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 208483beb8..3b61fb8cc4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -116,6 +116,158 @@ struct 
LocalLoadOpConversion } }; +struct ConvertLayoutOpMFMAToDotOpConversion + : public ConvertOpToLLVMPattern { +public: + explicit ConvertLayoutOpMFMAToDotOpConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(op.getSrc().getType()); + auto dstType = cast(op.getType()); + + if (!matchMFMAAndDotOperandShuffleCase(srcType, dstType)) + return failure(); + + auto loc = op.getLoc(); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (inVals.empty() || inVals.size() % 8 != 0) + return failure(); + + auto mfmaLayout = dyn_cast(srcType.getEncoding()); + assert((mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32) && + "Expected MFMA size 16 or 32"); + assert(triton::gpu::getWarpSize(mfmaLayout) == 64 && + "Expected warp size 64 for MFMA"); + + auto elemTy = int_ty(8); + auto vecTy = vec_ty(elemTy, 4); + + Value c16 = i32_val(16); + Value c32 = i32_val(32); + Value c48 = i32_val(48); + Value c64 = i32_val(64); + + Value threadId = tid_val(); + Value laneId = urem(threadId, c64); + + Value mask0 = icmp_slt(laneId, c32); + Value mask1 = icmp_slt(urem(laneId, c32), c16); + + Value addrShift16 = urem(add(laneId, c16), c64); + Value addrShift32 = urem(add(laneId, c32), c64); + Value addrShift48 = urem(add(laneId, c48), c64); + + SmallVector outVals; + for (size_t startIdx = 0; startIdx < inVals.size(); startIdx += 8) { + Value vec0 = undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec0 = + insert_element(vecTy, vec0, inVals[startIdx + vIdx], i32_val(vIdx)); + } + Value vec1 = undef(vecTy); + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + vec1 = insert_element(vecTy, vec1, inVals[startIdx + vIdx + 4], + i32_val(vIdx)); + } + + Value resVec0, resVec1; + if (mfmaLayout.getMDim() == 32) { + /* + Using wave shuffle to convert layouts (32x32x16 case): + 1) Input MMA layout (32x32, fp8, 16 values): + _____________________________________________________________ + |(t0 v0 v1 v2 v3) (t32 v0 v1 v2 v3) ... (t32 v12 v13 v14 v15)| + | ... ... | + |(t31 v0 v1 v2 v3) (t63 v0 v1 v2 v3) ... (t63 v12 v13 v14 v15)| + |_____________________________________________________________| + + 2) Output Dot operand layout (two 32x16 tiles, fp8, 8 values each): + ____________________________________________________________ ___ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) (t32 v0 v1 v2 v3 v4 v5 v6 v7) || + | ... ... ||... + |(t31 v0 v1 v2 v3 v4 v5 v6 v7) (t63 v0 v1 v2 v3 v4 v5 v6 v7) || + |____________________________________________________________||___ + */ + + Value shflVec0 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), + vecTy); + Value shflVec1 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), + vecTy); + + resVec0 = select(mask0, vec0, shflVec1); + resVec1 = select(mask0, shflVec0, vec1); + } else if (mfmaLayout.getMDim() == 16) { + /* + 16x16x32 case: + 1) Input MMA layout (two 16x16, fp8, 4 values each): + _________________________________________________________ ___________ + |(t0 v0 v1 v2 v3) (t16 v0 v1 v2 v3) ... (t48 v0 v1 v2 v3)||(t0 v4 ... + | ... ... || ... + |(t15 v0 v1 v2 v3) (t31 v0 v1 v2 v3) ... (t63 v0 v1 v2 v3)||(t15 v4 ... 
+ |_________________________________________________________||___________ + + 2) Output Dot operand layout (16x32 tile, fp8, 8 values): + ________________________________________________________________ + |(t0 v0 v1 v2 v3 v4 v5 v6 v7) ... (t48 v0 v1 v2 v3 v4 v5 v6 v7) | + | ... ... | + |(t15 v0 v1 v2 v3 v4 v5 v6 v7) ... (t63 v0 v1 v2 v3 v4 v5 v6 v7) | + |________________________________________________________________| + */ + + Value shflVec0_16 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift16), + vecTy); + Value shflVec0_32 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec0, int_ty(32)), addrShift32), + vecTy); + Value shflVec1_32 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift32), + vecTy); + Value shflVec1_48 = + bitcast(targetInfo.shuffleIdx( + rewriter, loc, bitcast(vec1, int_ty(32)), addrShift48), + vecTy); + + resVec0 = select(mask0, select(mask1, vec0, shflVec0_16), + select(mask1, shflVec1_32, shflVec1_48)); + resVec1 = select(mask0, select(mask1, shflVec0_16, shflVec0_32), + select(mask1, shflVec1_48, vec1)); + } + + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(extract_element(elemTy, resVec0, i32_val(vIdx))); + } + for (size_t vIdx = 0; vIdx < 4; ++vIdx) { + outVals.push_back(extract_element(elemTy, resVec1, i32_val(vIdx))); + } + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + +protected: + const TargetInfoBase &targetInfo; +}; + } // namespace namespace mlir::triton::AMD { @@ -124,5 +276,7 @@ void populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, + benefit); } } // namespace mlir::triton::AMD From 9743ec0dca5bbd9dbce20adc3ee273af6b095f94 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 9 Dec 2024 00:58:30 -0800 Subject: [PATCH 2/9] [FRONTEND] added support for tuples (#5220) --- python/src/ir.cc | 1 + .../test/unit/language/test_compile_errors.py | 49 ++-- python/test/unit/language/test_core.py | 23 +- python/test/unit/language/test_decorator.py | 2 +- python/test/unit/language/test_tuple.py | 100 +++++++ python/test/unit/runtime/test_bindings.py | 15 +- python/test/unit/runtime/test_cache.py | 4 +- python/test/unit/runtime/test_subproc.py | 8 +- python/test/unit/test_perf_warning.py | 5 +- python/triton/_utils.py | 49 ++++ python/triton/backends/compiler.py | 41 ++- python/triton/compiler/code_generator.py | 258 +++++++++++------- python/triton/compiler/compiler.py | 29 +- python/triton/language/__init__.py | 19 +- python/triton/language/core.py | 141 ++++++++-- python/triton/language/semantic.py | 4 +- python/triton/runtime/jit.py | 44 ++- python/triton/tools/compile.py | 16 +- third_party/amd/backend/compiler.py | 13 +- third_party/amd/backend/driver.py | 59 ++-- third_party/nvidia/backend/driver.py | 43 ++- 21 files changed, 635 insertions(+), 288 deletions(-) create mode 100644 python/test/unit/language/test_tuple.py diff --git a/python/src/ir.cc b/python/src/ir.cc index 23bb86e5eb..53ba39ae10 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) { "Function argument index out of range"); return self.getArgument(idx); }) + .def("get_num_args", &FuncOp::getNumArguments) .def( "add_entry_block", [](FuncOp &self) -> 
Block * { return self.addEntryBlock(); }, diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 1bafa551e3..2760d26bc5 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -17,7 +17,7 @@ def kernel(): a += 1 # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "is not defined" in str(e.value), "error should mention the undefined variable" @@ -32,7 +32,7 @@ def kernel(): 0 + "a" with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the 0" @@ -47,7 +47,7 @@ def kernel(): tl.static_assert(isinstance(0, tl.tensor)) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert isinstance(e.value, CompileTimeAssertionFailure) @@ -66,7 +66,7 @@ def kernel(): not (0, 0) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert e.value.__cause__ is None @@ -83,7 +83,7 @@ def kernel(): 1.0 << 1 with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the 1.0" @@ -107,7 +107,7 @@ def kernel(): nested_call() with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: inner = e.value.__cause__ @@ -130,7 +130,7 @@ def kernel(): tl.expand_dims(None, -1) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: inner = e.value.__cause__ @@ -157,7 +157,7 @@ def kernel(): a = two_returns() a + tl.arange(0, 4) # only works if we took the first return - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) def test_not_const_annotate_no_err(): @@ -166,7 +166,7 @@ def test_not_const_annotate_no_err(): def kernel(N: int = 1): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) @triton.jit @@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 4) - triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0})) + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0})) @triton.jit def kernel2(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 8) - 
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1})) + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1})) @triton.jit @@ -211,7 +211,7 @@ def kernel(N: int): returns_branched_on_non_constexpr(N) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the function call" @@ -227,7 +227,7 @@ def kernel(): tl.arange(2, 7) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert str(e.value.__cause__) == "arange's range must be a power of 2" @@ -238,7 +238,7 @@ def kernel(): tl.full((33, ), 0, dtype=tl.int64) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" @@ -251,7 +251,7 @@ def kernel(): a = CAPTURED # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert "CAPTURED is not defined" in str(e.value) @@ -265,7 +265,7 @@ def kernel(): a = GLOBAL # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert "global variable" in str(e.value) @@ -279,7 +279,7 @@ def kernel(): a = CONSTEXPR_ANNOTATED_GLOBAL # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) CONSTEXPR_GLOBAL = tl.constexpr(42) @@ -292,7 +292,7 @@ def kernel(): a = CONSTEXPR_GLOBAL # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) TYPE_ALIAS = tl.pointer_type(tl.int32) @@ -305,7 +305,7 @@ def kernel(): a = TYPE_ALIAS # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) def test_global_access_in_fn_default_arg(): @@ -315,7 +315,7 @@ def kernel(a=GLOBAL): pass # No error. 
- triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={})) def test_defaults_assign_no_err(): @@ -324,7 +324,7 @@ def test_defaults_assign_no_err(): def kernel(a=1, B: tl.constexpr = ""): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""})) def test_where_warning(fresh_triton_cache): @@ -337,7 +337,7 @@ def kernel(): tl.where(a, b, c) with pytest.warns(UserWarning): - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) @@ -369,7 +369,8 @@ def dtype_kernel(dtype: tl.constexpr): ctx = pytest.raises(CompilationError, match="") with ctx as e: - triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype})) + triton.compile( + triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) if dtype not in supported_dtypes: try: @@ -388,7 +389,7 @@ def dot_kernel(): tl.dot(a, b, max_num_imprecise_acc=128) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={})) try: assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") except AssertionError as assertion_err: diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2daa8aaf07..ac22cdee43 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4352,15 +4352,17 @@ def kernel(x): def test_value_specialization(value: int, value_type: str, device) -> None: def repr(specialization): - spec_type = specialization.signature["VALUE"] - return f"kernel_{spec_type}" + ty = specialization.signature["value1"] + cst = '_'.join([k for k, v in specialization.constants.items() if v == 1]) + return f"kernel_{ty}_{cst}" @triton.jit(repr=repr) - def kernel(VALUE, X): + def kernel(value1, is_one, X): pass x = torch.tensor([3.14159], device=device) - h = kernel[(1, )](value, x) + h = kernel[(1, )](value, 1, x) + assert "is_one" in h.name assert value_type in h.name @@ -6130,6 +6132,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) +def test_dtype(device): + + @triton.jit + def kernel(X): + dtype_x: tl.constexpr = X.dtype.element_ty + tl.static_assert(dtype_x == tl.int32) + tl.static_assert(dtype_x == tl.constexpr(tl.int32)) + tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32)) + + X = torch.zeros(1, dtype=torch.int32, device=device) + kernel[(1, )](X) + + def test_side_effectful_scan(device): if device != "cuda": pytest.skip() diff --git a/python/test/unit/language/test_decorator.py b/python/test/unit/language/test_decorator.py index fbbfb71446..42207cc1fa 100644 --- a/python/test/unit/language/test_decorator.py +++ b/python/test/unit/language/test_decorator.py @@ -23,7 +23,7 @@ def kernel(): pass try: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, 
constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) except Exception as e: pytest.fail(f"triton compile failed with error: {e}") diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py new file mode 100644 index 0000000000..863034579a --- /dev/null +++ b/python/test/unit/language/test_tuple.py @@ -0,0 +1,100 @@ +import pytest +import triton +import triton.language as tl +import torch + + +@triton.jit +def _tuple_increment(values): + for i in tl.static_range(len(values)): + values[i] = values[i] + 1 + return values + + +@triton.jit +def _tuple_index_func(Ptrs, values): + for i in tl.static_range(len(values)): + tl.store(Ptrs[i], values[i]) + + +@triton.jit +def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): + values = _tuple_increment(values) + _tuple_index_func(Ptrs, values) + + +@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) +def test_index(size, device="cuda"): + vals = tuple([i + 1 for i in range(size)]) + rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) + _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) + assert vals == tuple([x.item() - 1 for x in rets]) + + +# ---- + + +@triton.jit +def _tuple_assign(XPtrs, YPtrs, values): + # assign from tuple + X0, X1 = XPtrs + x0, x1 = values + tl.store(X0, x0) + tl.store(X1, x1) + # assign to tuple + Y0, Y1, Y2 = YPtrs + Y = Y0, Y1, Y2 + y = x0, 10, x1 + tl.store(Y[0], y[0]) + tl.store(Y[1], y[1]) + tl.store(Y[2], y[2]) + + +def test_assign(device="cuda"): + vals = (2., 3.) + x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) + y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) + _tuple_assign[(1, )](x, y, vals) + assert x[0] == vals[0] + assert x[1] == vals[1] + assert y[0] == vals[0] + assert y[1] == 10 + assert y[2] == vals[1] + + +# ------- + + +@triton.jit +def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): + tl.static_assert(tuple1[1] is None) + tl.store(Ptr + 5, cst2) + tl.store(Ptr + 6, tuple1[0]) + tl.store(Ptr + 7, tl.load(tuple1[2][0])) + tl.store(Ptr + 8, tuple1[2][1][0]) + tl.store(Ptr + 9, tl.load(tuple1[2][1][2])) + + +# test serialization/deserialization of tuple arguments in +# the frontend. 
+@triton.jit +def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2): + tl.static_assert(N1 is None) + tl.static_assert(tuple1[1][1] is None) + tl.store(Ptr + 0, tl.load(tuple1[0])) + tl.store(Ptr + 1, tuple1[1][0]) + tl.store(Ptr + 2, tl.load(tuple1[1][2])) + tl.store(Ptr + 3, cst1 + val1) + tl.store(Ptr + 4, tl.load(tuple2[0])) + _tuple_fn0(Ptr, 15, (-1, None, tuple1)) + + +def test_serialize(device="cuda"): + x0 = torch.tensor([8], dtype=torch.int32, device=device) + x1 = torch.tensor([12], dtype=torch.int32, device=device) + y0 = torch.tensor([10], dtype=torch.int32, device=device) + z = torch.empty((10, ), dtype=torch.int32, device=device) + # we want to check that JIT specialization propagates to tuples: + _tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, )) + ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device) + assert torch.equal(z, ref) diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index 206d132301..e621eefc01 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -63,15 +63,12 @@ def walk_fn(op): backend = triton.compiler.compiler.make_backend(target) src = triton.compiler.compiler.ASTSource( fn=kernel, - signature={ - kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(args) - if i not in kernel.constexprs - }, - constants={kernel.arg_names[i]: arg - for i, arg in enumerate(args) - if not isinstance(arg, torch.Tensor)}, - attrs=backend.get_attrs_descriptor(args, kernel.params), + signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args)}, + constexprs={kernel.arg_names[i]: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=backend.get_attrs_descriptor(kernel.params, args), ) context = triton._C.libtriton.ir.context() diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a5f381dc9f..23c943aeb1 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -592,10 +592,10 @@ def cache_hook(*args, **kwargs): JITFunction.cache_hook = cache_hook # In warmup we assume that the pointer range is 32 bits kernel_add.warmup(torch.float32, grid=(1, )) - assert pointer_range_32 == [0] + assert pointer_range_32 == [(0, )] # Torch tensor > 2GB kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) assert len(pointer_range_32) == 0 # Torch tensor <= 2GB kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) - assert pointer_range_32 == [0] + assert pointer_range_32 == [(0, )] diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 334d5d635f..ecd7227a30 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -19,8 +19,8 @@ def kernel_sub(a, b, o, N: tl.constexpr): src = ASTSource( fn=kernel_sub, - constants={'N': 32}, - signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"}, + constexprs={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'}, attrs=attrs, ) triton.compile(src=src, target=target) @@ -44,7 +44,7 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={}) + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constexprs={}) triton.compile(src=src, target=target) @@ -65,7 +65,7 @@ 
def empty_kernel(): import gc gc.collect() - src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={}) + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constexprs={}) triton.compile(src=src, target=target) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 6646d94f50..461dcb46b4 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -92,7 +92,7 @@ def matmul_kernel( "stride_cm": "i32", "stride_cn": "i32", }, - constants={}, + constexprs={}, )) captured = capfd.readouterr() @@ -136,8 +136,9 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) "in_ptr2": "*fp16", "in_ptr3": "*fp32", "out_ptr0": "*fp16", + "XBLOCK": "constexpr", }, - constants={"XBLOCK": XBLOCK}, + constexprs={"XBLOCK": XBLOCK}, ), options={"num_warps": 1}, ) diff --git a/python/triton/_utils.py b/python/triton/_utils.py index ca60c8c3cb..0ce1a53a70 100644 --- a/python/triton/_utils.py +++ b/python/triton/_utils.py @@ -20,3 +20,52 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]: idx += size assert idx == len(flat) return ret + + +def find_paths_if(iterable, pred): + from .language import core + is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type)) + ret = dict() + + def _impl(current, path): + path = (path[0], ) if len(path) == 1 else tuple(path) + if is_iterable(current): + for idx, item in enumerate(current): + _impl(item, path + (idx, )) + elif pred(path, current): + if len(path) == 1: + ret[(path[0], )] = current + else: + ret[tuple(path)] = current + + if is_iterable(iterable): + _impl(iterable, []) + elif pred(list(), iterable): + ret = {tuple(): iterable} + else: + ret = dict() + return ret + + +def parse_list_string(s): + s = s.strip() + if s.startswith('[') and s.endswith(']'): + s = s[1:-1] + result = [] + current = '' + depth = 0 + for c in s: + if c == '[': + depth += 1 + current += c + elif c == ']': + depth -= 1 + current += c + elif c == ',' and depth == 0: + result.append(current.strip()) + current = '' + else: + current += c + if current.strip(): + result.append(current.strip()) + return result diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 6d33dbd6fa..4c5ac74cf2 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -3,11 +3,11 @@ import hashlib import subprocess import sysconfig - from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import Dict, List, Tuple, Union from types import ModuleType +from .._utils import find_paths_if # Table that associates strings to AttrsDescriptor (sub)classes. 
# In this way we can dynamically select the correct class @@ -52,7 +52,8 @@ class AttrsDescriptor: `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant """ - __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') + __slots__ = ('divisibility_16', 'equal_to_1', 'equal_to_none', 'arg_properties', 'property_values', + 'constant_properties') def __init__(self, params=None, values=None): """ @@ -67,6 +68,7 @@ def __init__(self, params=None, values=None): # Default initialization self.arg_properties = {} self.property_values = {} + self.equal_to_none = {} self.constant_properties = set() self._add_common_properties(params, values) @@ -86,17 +88,30 @@ def _add_common_properties(self, params, values): assert (len(params) == len(values)) # Divisibility property - self.arg_properties["tt.divisibility"] = [ - param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) - and not param.do_not_specialize and not param.do_not_specialize_on_alignment - ] + divisibility_16 = [] + for param, arg in zip(params, values): + if param.do_not_specialize or \ + param.do_not_specialize_on_alignment: + continue + paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_divisible_by_16(val)) + divisibility_16 += [(param.num, ) + x for x in paths] + self.arg_properties["tt.divisibility"] = divisibility_16 # Equal to 1 property - self.arg_properties["tt.equal_to"] = [ - param.num - for param, arg in zip(params, values) - if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize - ] + equal_to_1 = [] + for param, arg in zip(params, values): + if param.do_not_specialize: + continue + paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_equal_to_1(val)) + equal_to_1 += [(param.num, ) + x for x in paths] + self.arg_properties["tt.equal_to"] = equal_to_1 + + # Equal to None property + equal_to_none = [] + for param, arg in zip(params, values): + paths = find_paths_if(arg, lambda path, val: val is None) + equal_to_none += [(param.num, ) + x for x in paths] + self.equal_to_none = equal_to_none def _add_backend_properties(self, params=None, values=None): """ This method is for different subclasses to implement their own compile-time properties """ @@ -130,6 +145,8 @@ def get_constants(self) -> Dict: for prop_name in self.constant_properties: for p in self.arg_properties.get(prop_name, []): constants[p] = self.property_values[prop_name] + for v in self.equal_to_none: + constants[v] = None return constants def filter_out_constants(self): @@ -166,7 +183,7 @@ def from_dict(data): """ attrs_descriptor = _descriptor_table[data["cls"]]() for prop_name, param_ids in data["arg_properties"].items(): - attrs_descriptor.arg_properties[prop_name] = param_ids + attrs_descriptor.arg_properties[prop_name] = list(map(tuple, param_ids)) attrs_descriptor._init_slots() return attrs_descriptor diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 1c39d778ec..050e8ad0d7 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,9 +15,13 @@ from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType from triton._utils import list_list_flatten, list_list_unflatten +from functools import reduce +from .._utils import find_paths_if def mangle_ty(ty): + if ty.is_tuple(): + return 'T' + '_'.join(map(mangle_ty, ty.types)) + 
'T' if ty.is_ptr(): return 'P' + mangle_ty(ty.element_ty) if ty.is_int(): @@ -56,7 +60,7 @@ def _is_triton_tensor(o: Any) -> bool: def _is_constexpr(o: Any) -> bool: - return isinstance(o, constexpr) + return o is None or isinstance(o, (constexpr, language.core.dtype)) def _is_triton_scalar(o: Any) -> bool: @@ -189,11 +193,66 @@ def visit_Call(self, node: ast.Call) -> bool: return self.visit(node.func) +class ASTFunction: + + def get_path(self, x, path): + return reduce(lambda a, idx: a[idx], path, x) + + def set_path(self, x, path, val): + prev = x if len(path) == 1 else self.get_path(x, path[:-1]) + prev[path[-1]] = val + + def __init__(self, ret_types, arg_types, constexprs, constants, attrs): + self.ret_types = ret_types + self.arg_types = arg_types + self.constexprs = constexprs + self.constants = constants + self.attrs = attrs + + def serialize(self, builder: ir.builder): + # fill up IR values in template + # > build function + is_val = lambda path, _: path not in self.constexprs and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(arg_types, ret_types) + + def deserialize(self, fn): + # create "template" + def make_template(val): + if isinstance(val, (list, tuple, language.tuple_type)): + return language.tuple([make_template(x) for x in val]) + return language.constexpr(None) + + vals = make_template(self.arg_types) + is_val = lambda path, _: path not in self.constexprs and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + # > set attributes + for attr_path, attr_specs in self.attrs.items(): + for attr_name, attr_val in attr_specs: + if attr_path in val_paths: + fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + if isinstance(ty, nv_tma_desc_type): + fn.set_arg_attr(i, "tt.nv_tma_desc", 1) + # > add IR values to the template + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + self.set_path(vals, path, language.tensor(fn.args(i), ty)) + # > add constexpr values to the template + constants = self.constants | self.constexprs + for path, val in constants.items(): + self.set_path(vals, path, language.constexpr(val)) + return vals + + class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, - codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, - noinline=False, file_name: Optional[str] = None, begin_line=0): + def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, + module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, + file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -223,8 +282,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.gscope[k] = v self.lscope = {} - self.attributes = attributes - self.constants = constants self.jit_fn = jit_fn self.function_name = function_name self.is_kernel = is_kernel @@ -342,7 +399,6 @@ def visit_compound_statement(self, stmts): stmts = [stmts] for stmt in stmts: self.visit(stmt) - # Stop parsing as soon as we hit a `return` 
statement; everything # after this is dead code. if isinstance(stmt, ast.Return): @@ -354,7 +410,7 @@ def visit_Module(self, node): def visit_List(self, node): ctx = self.visit(node.ctx) assert ctx is None - elts = [self.visit(elt) for elt in node.elts] + elts = language.tuple([self.visit(elt) for elt in node.elts]) return elts # By design, only non-kernel functions can return @@ -363,16 +419,15 @@ def visit_Return(self, node): if ret_value is None: self.builder.ret([]) ret_ty = language.void - elif isinstance(ret_value, tuple): - ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] + elif isinstance(ret_value, language.tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value.values] ret_types = [v.type for v in ret_values] self.builder.ret([v.handle for v in ret_values]) - ret_ty = tuple(ret_types) + ret_ty = language.tuple_type(ret_types) else: ret = language.semantic.to_tensor(ret_value, self.builder) self.builder.ret([ret.handle]) ret_ty = ret.type - if self.ret_type is None: self.ret_type = ret_ty elif self.ret_type != ret_ty: @@ -397,7 +452,6 @@ def visit_FunctionDef(self, node): init_node = ast.Assign(targets=[st_target], value=default_value) else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) - try: assert not self.visiting_arg_default_value self.visiting_arg_default_value = True @@ -407,34 +461,15 @@ def visit_FunctionDef(self, node): # initialize function visibility = "public" if self.is_kernel else "private" - self.fn = self.builder.get_or_insert_function(self.module, self.function_name, - self.prototype.to_ir(self.builder), visibility, self.noinline) + fn_ty = self.prototype.serialize(self.builder) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline) self.module.push_back(self.fn) entry = self.fn.add_entry_block() - arg_values = [] - idx = 0 - for i in range(len(arg_names)): - if i in self.constants: - cst = self.constants[i] - if not _is_constexpr(cst): - cst = constexpr(self.constants[i]) - arg_values.append(cst) - continue - else: - if i in self.attributes: - for name, value in self.attributes[i]: - self.fn.set_arg_attr(idx, name, value) - - # Mark this argument as a pass-by-value TMA descriptor (nvidia) - if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): - self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) - - arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) - idx += 1 - - insert_pt = self.builder.get_insertion_block() + arg_values = self.prototype.deserialize(self.fn) + # bind arguments to symbols for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) + insert_pt = self.builder.get_insertion_block() self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) @@ -445,8 +480,11 @@ def visit_FunctionDef(self, node): self.ret_type = language.void self.builder.ret([]) else: - self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] - self.fn.reset_type(self.prototype.to_ir(self.builder)) + if isinstance(self.ret_type, language.tuple_type): + self.prototype.ret_types = self.ret_type.types + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.serialize(self.builder)) self.builder.ret([ self.builder.create_poison(ty.to_ir(self.builder)) for ty in self.prototype.ret_types @@ -478,37 +516,41 @@ def 
visit_AnnAssign(self, node): if target in self.lscope: raise ValueError(f'{target} is already defined.' f' constexpr cannot be reassigned.') - if not _is_constexpr(value): - value = constexpr(value) + value = constexpr(value) self.lscope[target] = value return self.lscope[target] # default: call visit_Assign return self.visit_Assign(node) + def assignTarget(self, target, value): + if isinstance(target, ast.Subscript): + assert target.ctx.__class__.__name__ == "Store" + return self.visit_Subscript_Store(target, value) + if isinstance(target, ast.Tuple): + assert target.ctx.__class__.__name__ == "Store" + for i, name in enumerate(target.elts): + self.set_value(self.visit(name), value.values[i]) + return + assert isinstance(target, ast.Name) + self.set_value(self.visit(target), value) + def visit_Assign(self, node): - _names = [] - if isinstance(node, ast.AnnAssign): - _names += [self.visit(node.target)] - else: - for target in node.targets: - _names += [self.visit(target)] - if len(_names) > 1: - raise self._unsupported(node, "simultaneous multiple assignment is not supported.") - names = _names[0] - values = self.visit(node.value) - if not _is_list_like(names): - names = [names] - if not _is_list_like(values): - values = [values] - native_nontensor_types = (language.dtype, ) - for name, value in zip(names, values): - # by default, constexpr are assigned into python variable + # construct values to assign + def _sanitize_value(value): + if isinstance(value, language.tuple): + return language.tuple([_sanitize_value(v) for v in value.values]) + native_nontensor_types = (language.dtype, language.tuple) value = _unwrap_if_constexpr(value) if value is not None and \ - not _is_triton_value(value) and \ - not isinstance(value, native_nontensor_types): + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) - self.set_value(name, value) + return value + + values = _sanitize_value(self.visit(node.value)) + targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets + assert len(targets) == 1 + self.assignTarget(targets[0], values) def visit_AugAssign(self, node): name = node.target.id @@ -531,7 +573,7 @@ def visit_Load(self, node): def visit_Tuple(self, node): args = [self.visit(x) for x in node.elts] - return tuple(args) + return language.tuple(args) def _apply_binary_method(self, method_name, lhs, rhs): # TODO: raise something meaningful if getattr fails below, esp for reverse method @@ -903,7 +945,7 @@ def visit_While(self, node): assert False, "Not implemented" ast.NodeVisitor.generic_visit(self, stmt) - def visit_Subscript(self, node): + def visit_Subscript_Load(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) @@ -911,6 +953,16 @@ def visit_Subscript(self, node): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] + def visit_Subscript_Store(self, node, value): + assert node.ctx.__class__.__name__ == "Store" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + assert isinstance(lhs, language.tuple) + lhs.__setitem__(slices, value) + + def visit_Subscript(self, node): + return self.visit_Subscript_Load(node) + def visit_ExtSlice(self, node): return [self.visit(dim) for dim in node.dims] @@ -1067,7 +1119,7 @@ def visit_Slice(self, node): lower = self.visit(node.lower) upper = self.visit(node.upper) step = self.visit(node.step) - return slice(lower, upper, step) + return language.slice(lower, 
upper, step) def visit_Index(self, node): return self.visit(node.value) @@ -1083,24 +1135,26 @@ def visit_Assert(self, node) -> Any: def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] - # generate function def - attributes = {} - constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) + for i, arg in enumerate(args): + if isinstance(arg, (language.dtype, float, int, bool)): + args[i] = language.core.constexpr(arg) + args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) + args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values() + # mangle + fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) # generate function def if necessary if not self.module.has_function(fn_name): - prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller file_name, begin_line = get_jit_fn_file_line(fn) - generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + arg_types = [ + language.core.constexpr if arg is None or isinstance(arg, + (bool, int, language.core.dtype)) else arg.type + for arg in args + ] + prototype = ASTFunction([], arg_types, args_cst, dict(), dict()) + generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn, + function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, options=self.builder.options, codegen_fns=self.builder.codegen_fns, module_map=self.builder.module_map) @@ -1115,8 +1169,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): else: callee_ret_type = self.function_ret_types[fn_name] symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0 or callee_ret_type is None: + args_val = [arg.handle for arg in args_val] + call_op = self.builder.call(symbol, args_val) + if callee_ret_type is None: return None elif call_op.get_num_results() == 1: return tensor(call_op.get_result(0), callee_ret_type) @@ -1124,8 +1179,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): # should return a tuple of tl.tensor results = [] for i in range(call_op.get_num_results()): - results.append(tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) + results.append(tensor(call_op.get_result(i), callee_ret_type.types[i])) + return language.tuple(results) def visit_Call(self, node): fn = _unwrap_if_constexpr(self.visit(node.func)) @@ -1144,7 +1199,11 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: - return fn(*args, **extra_kwargs, **kws) + ret = fn(*args, **extra_kwargs, **kws) + # builtin functions return plain tuples for readability + if isinstance(ret, tuple): + ret = language.tuple(ret) + return ret except Exception as e: # Normally when we raise a CompilationError, we raise it as # `from 
None`, because the original fileline from the exception @@ -1285,38 +1344,29 @@ def kernel_suffix(signature, specialization): suffix = '' for i, _ in enumerate(signature): suffix += str(i) - if i in specialization.equal_to_1: + if (i, ) in specialization.equal_to_1: suffix += 'c' - if i in specialization.divisibility_16: + if (i, ) in specialization.divisibility_16: suffix += 'd' return suffix def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): + constexprs = specialization.constexprs + arg_idx = lambda x: (fn.arg_names.index(x), ) if isinstance(x, str) else x + constants = specialization.attrs.get_constants() + constexprs = {arg_idx(k): v for k, v in constexprs.items()} + arg_types = [str_to_ty(ty) for ty in specialization.signature.values()] + # find index of constants in serialized order attrs = specialization.attrs - # create kernel prototype - cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in specialization.constants.items()} - # visit kernel AST - gscope = fn.__globals__.copy() - function_name = fn.repr(specialization) - tys = list(specialization.signature.values()) - new_constants = attrs.get_constants() - for k in new_constants: - if k in tys and tys[k] == "i1" and new_constants[k] == 1: - new_constants[k] = True - new_attrs = attrs.filter_out_constants() fn_attrs = new_attrs.get_fn_attrs() - all_constants = constants.copy() - all_constants.update(new_constants) - arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + fn_attrs = {k: v for k, v in fn_attrs.items() if k not in constants} file_name, begin_line = get_jit_fn_file_line(fn) - - prototype = language.function_type([], arg_types) - generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, - begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) + prototype = ASTFunction([], arg_types, constexprs, constants, fn_attrs) + generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(specialization), + jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, + codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) ret = generator.module diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index f70c46a9d4..52b8afea14 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -51,12 +51,12 @@ def convert_type_repr(x): class ASTSource: - def __init__(self, fn, signature, constants=None, attrs=None) -> None: + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: self.fn = fn self.ext = "ttir" self.name = fn.__name__ self.signature = signature - self.constants = constants + self.constexprs = constexprs self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} @@ -64,20 +64,19 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: for k in self.signature.keys(): if not isinstance(k, str): raise TypeError("Signature keys must be string") - if self.constants is None: - self.constants = {} - else: - for k in self.constants.keys(): - if not isinstance(k, str): - raise TypeError("Constants keys must be string") + if self.constexprs is None: + self.constexprs = 
{} if self.attrs is None: self.attrs = AttrsDescriptor() + # this is the constexprs plus the specialized constants + spec_constants = {self.fn.arg_names[k[0]]: v for k, v in self.attrs.get_constants().items() if len(k) == 1} + self.constants = self.constexprs | spec_constants def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] # Note - we stringify the keys here to allow sorting to work for cases # where constants have mixed int/str keys. - sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + sorted_constants = sorted((str(k), v) for k, v in self.constexprs.items()) key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() @@ -276,11 +275,11 @@ def compile(src, target=None, options=None): codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() - try: - module = src.make_ir(options, codegen_fns, module_map, context) - except Exception as e: - filter_traceback(e) - raise + # try: + module = src.make_ir(options, codegen_fns, module_map, context) + # except Exception as e: + # filter_traceback(e) + # raise use_ir_loc = os.environ.get("USE_IR_LOC", None) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) @@ -412,7 +411,7 @@ def launch_metadata(self, grid, stream, *args): arg_idx = 0 for i, arg_name in enumerate(self.src.fn.arg_names): if i in self.src.fn.constexprs: - arg_dict[arg_name] = self.src.constants[arg_name] + arg_dict[arg_name] = self.src.constexprs[arg_name] else: arg_dict[arg_name] = args[arg_idx] arg_idx += 1 diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0c8965fc52..5f5d464d63 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,6 +1,7 @@ """isort:skip_file""" # Import order is significant here. +from .._utils import parse_list_string from . import math from . 
import extra from .standard import ( @@ -69,7 +70,6 @@ float8e5, float8e5b16, full, - function_type, gather, histogram, inline_asm_elementwise, @@ -95,6 +95,7 @@ range, reduce, reshape, + slice, split, static_assert, static_print, @@ -102,6 +103,8 @@ store, tensor, trans, + tuple, + tuple_type, uint16, uint32, uint64, @@ -188,7 +191,6 @@ "floor", "fma", "full", - "function_type", "gather", "histogram", "inline_asm_elementwise", @@ -232,6 +234,7 @@ "reduce", "reshape", "rsqrt", + "slice", "sigmoid", "sin", "softmax", @@ -248,6 +251,7 @@ "tensor", "trans", "triton", + "tuple", "uint16", "uint32", "uint64", @@ -264,6 +268,9 @@ def str_to_ty(name): + if name == "none": + return None + if name[0] == "*": name = name[1:] const = False @@ -273,9 +280,17 @@ def str_to_ty(name): ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) + if name[0] == "[": + names = parse_list_string(name) + tys = [str_to_ty(x) for x in names] + return tuple_type(types=tys) + if name == "nvTmaDesc": return nv_tma_desc_type() + if name == "constexpr": + return constexpr + tys = { "fp8e4nv": float8e4nv, "fp8e4b8": float8e4b8, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 85d5f6beba..31b19754c6 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -140,6 +140,7 @@ def __init__(self, value): self.value = value.value else: self.value = value + self.type = constexpr def __repr__(self) -> str: return f"constexpr[{self.value}]" @@ -473,6 +474,10 @@ def is_ptr(): def is_const(): return False + @staticmethod + def is_tuple(): + return False + def __eq__(self, other: dtype): if not isinstance(other, dtype): return False @@ -608,11 +613,10 @@ def __init__(self, element_ty: dtype, shape: List): # Note that block_type's shape is a list of int # while tensor's shape is a list of constexpr. - - assert (isinstance(shape, list)) + assert (isinstance(shape, (list, tuple))) # shape can be empty ([]) when an input is a 0D tensor. 
- self.shape = _unwrap_shape(shape) + self.shape = tuple(_unwrap_shape(shape)) if not self.shape: raise TypeError('0d block_type is forbidden') @@ -647,19 +651,32 @@ def scalar(self): return self.element_ty -class function_type(dtype): +class tuple_type(dtype): - def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: - self.ret_types = ret_types - self.param_types = param_types + def __init__(self, types): + self.types = types + self.name = f"[{','.join(map(str, self.types))}]" def __str__(self): - return f'fn ({self.param_types}) -> {self.ret_types}' + return self.name + + def __iter__(self): + return iter(self.types) def to_ir(self, builder: ir.builder): - ir_param_types = [ty.to_ir(builder) for ty in self.param_types] - ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] - return builder.get_function_ty(ir_param_types, ret_types) + return [ty.to_ir(builder) for ty in self.types] + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def is_tuple(self): + return True + + +class slice_type(dtype): + + def __init__(self): + self.name = 'slice_type' # scalar types @@ -761,7 +778,7 @@ def __init__(self, handle, type: dtype): self.type = type # Tensor type (can be block_type) # Following the practice in pytorch, dtype is scalar type self.dtype = type.scalar - self.shape = [constexpr(s) for s in self.shape] + self.shape = tuple([constexpr(s) for s in self.shape]) def _flatten_ir(self): return [self.handle] @@ -982,13 +999,16 @@ def __not__(self, _builder=None): @builtin def __getitem__(self, slices, _builder=None): - if isinstance(slices, (slice, constexpr)) or slices is None: + import builtins + if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None: slices = [slices] + if isinstance(slices, tuple): + slices = slices.values ret = self for dim, sl in enumerate(slices): if sl is None or isinstance(sl, constexpr) and sl.value is None: ret = semantic.expand_dims(ret, dim, _builder) - elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None: pass else: raise ValueError(f"unsupported tensor index: {sl}") @@ -1147,6 +1167,77 @@ def flip(self, dim=None) -> tensor: ... 
+class tuple: + + def __init__(self, args: list): + self.values = [i for i in args] + + @property + def type(self): + + def get_type(x): + if isinstance(x, dtype): + return dtype + return x.type + + return tuple_type([get_type(x) for x in self.values]) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + import builtins + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + # TODO: remove + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + if isinstance(other, list): + other = tuple(other) + return tuple(self.values + other.values) + # return tuple(a + b for a, b in zip(self.values, other.values)) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + import builtins + if isinstance(other, (list, builtins.tuple)): + other = tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + import builtins + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + +class slice: + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = slice_type() + + class _experimental_tensor_descriptor_base(_value): """" A tensor descriptor with unknown shape and strides @@ -1562,7 +1653,7 @@ def expand_dims(input, axis, _builder=None): """ input = semantic.to_tensor(input, _builder) axis = _constexpr_to_value(axis) - axes = list(axis) if isinstance(axis, Sequence) else [axis] + axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] new_ndim = len(input.shape) + len(axes) axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] @@ -2215,14 +2306,12 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] def make_combine_region(reduce_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = reduce_op.get_region(0) with _insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] @@ -2316,14 +2405,12 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] def make_combine_region(scan_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = scan_op.get_region(0) with 
_insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 60890ac596..2f7dba929b 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -759,14 +759,14 @@ def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> # Add new axes to lhs for _ in range(len(lhs_shape), len(rhs_shape)): lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), - tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) lhs_ty = lhs.type lhs_shape = lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for _ in range(len(rhs_shape), len(lhs_shape)): rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), - tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d04f516e81..4ae7a918a1 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -308,6 +308,8 @@ def mangle_type(arg, is_const=False): return "fp32" elif hasattr(arg, "tma_desc_cpu_ptr"): return "nvTmaDesc" + elif isinstance(arg, tuple): + return "[" + ",".join(map(mangle_type, arg)) + "]" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) @@ -335,8 +337,8 @@ def serialize_specialization_data(name, signature, constants, attrs, options, ke constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} import json obj = { - 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': - options.__dict__, 'key': key + 'name': name, 'signature': signature, 'constant_keys': list(constants.keys()), 'constant_vals': + list(constants.values()), 'attrs': attrs.to_dict(), 'options': options.__dict__, 'key': key } serialized_obj = json.dumps(obj) return serialized_obj @@ -368,6 +370,7 @@ def create_function_from_signature(sig, kparams, backend): func_args.append(f"{name}=default_{name}") dict_entries.append(f"'{name}': {name}") if kp.is_constexpr: + signature_types.append('"constexpr"') constexpr_vals.append(name) else: non_constexpr_vals.append(name) @@ -601,32 +604,23 @@ def run(self, *args, grid, warmup, **kwargs): # done here rather than when we build the signature as otherwise # the kernel cache key could not distinguish between byte pointers # and None arguments, resulting in a downstream mismatch: - sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigkeys = [param.name for param in self.params] sigvals = sig_and_spec[:len(sigkeys)] - signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - - configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) - constant_params = configs[0].get_constants() - constants = { - 
p.name: v - for (v, p) in zip(bound_vals, self.params) - if p.is_constexpr or (p.num in constant_params) or v is None - } - for i, arg in constants.items(): + signature = {k: v for (k, v) in zip(sigkeys, sigvals)} + + attrs = backend.get_attrs_descriptor(self.params, bound_vals) + constexprs = {p.name: v for (v, p) in zip(bound_vals, self.params) if p.is_constexpr} + for i, arg in constexprs.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True): return None # compile the kernel - src = self.ASTSource(self, signature, constants, configs[0]) - kernel = self.compile( - src, - target=target, - options=options.__dict__, - ) + src = self.ASTSource(self, signature, constexprs, attrs) + kernel = self.compile(src, target=target, options=options.__dict__) self.cache[device][key] = kernel - self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) + self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False) # Check that used global values have not changed. not_present = object() @@ -639,15 +633,11 @@ def run(self, *args, grid, warmup, **kwargs): # canonicalize grid assert grid is not None if callable(grid): - # Arguments are passed as a dict to `grid`, by contract. - # TODO(jlebar): In the new launch API, pass the compiler flags as a - # second parameter to `grid`. grid = grid(bound_args) grid_size = len(grid) grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 - # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, @@ -738,9 +728,11 @@ def preload(self, specialization_data): if deserialized_obj['name'] != self.fn.__name__: raise RuntimeError( f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constant_keys = deserialized_obj['constant_keys'] + constant_vals = deserialized_obj['constant_vals'] constants = { key: tl.dtype(value) if tl.dtype.is_dtype(value) else value - for key, value in deserialized_obj['constants'].items() + for key, value in zip(constant_keys, constant_vals) } signature = dict(deserialized_obj['signature'].items()) src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 6adf7794cc..50483b2362 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -91,15 +91,13 @@ def constexpr(s): pass return None - hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} hints = {k: v for k, v in hints.items() if v is not None} constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} constants = {k: v for k, v in constants.items() if v is not None} - signature = { - kernel.arg_names[i]: s.split(":")[0] - for i, s in enumerate(signature) - if kernel.arg_names[i] not in constants - } + signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)} + for key in constants: + signature[key] = 'constexpr' const_sig = 'x'.join([str(v) for v in constants.values()]) 
doc_string = [f"{k}={v}" for k, v in constants.items()] doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] @@ -109,8 +107,8 @@ def constexpr(s): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints) for p, v in attrs.get_constants().items(): - constants.update({kernel.arg_names[p]: v}) - src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) + constants.update({kernel.arg_names[p[0]]: v}) + src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts) if ccinfo.metadata.global_scratch_size > 0: @@ -126,7 +124,7 @@ def constexpr(s): arg_types.append(signature[arg_name]) arg_names_not_1.append(arg_name) arg_types_not_1.append(signature[arg_name]) - elif i in attrs.equal_to_1: + elif (i, ) in attrs.equal_to_1: arg_names.append(arg_name) arg_types.append(signature[arg_name]) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 81b07f2e7d..a8d806a8b1 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -1,5 +1,6 @@ from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor from triton._C.libtriton import ir, passes, llvm, amd +from triton._utils import find_paths_if from dataclasses import dataclass from typing import Any, Dict, Tuple from types import ModuleType @@ -100,10 +101,14 @@ def _add_backend_properties(self, params=None, values=None): if params is None or values is None: return - self.arg_properties["tt.pointer_range"] = [ - param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg) - and not param.do_not_specialize and not param.do_not_specialize_on_alignment - ] + pointer_range = [] + for param, arg in zip(params, values): + if param.do_not_specialize or \ + param.do_not_specialize_on_alignment: + continue + paths = find_paths_if(arg, lambda path, val: HIPAttrsDescriptor.is_within2gb(val)) + pointer_range += [(param.num, ) + x for x in paths] + self.arg_properties["tt.pointer_range"] = pointer_range @staticmethod def is_within2gb(arg): diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 99e5509eca..965341b96e 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -8,6 +8,7 @@ from triton.runtime.cache import get_cache_manager from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver +from triton._utils import parse_list_string dirname = os.path.dirname(os.path.realpath(__file__)) include_dir = [os.path.join(dirname, "include")] @@ -164,7 +165,7 @@ def __init__(self): # -------------------- Launcher ---------------------------- def ty_to_cpp(ty): - if ty[0] == '*': + if ty[0] == '*' or ty == "none": return "hipDeviceptr_t" return { "i1": "int32_t", @@ -186,32 +187,27 @@ def ty_to_cpp(ty): def make_launcher(constants, signature, ids, warp_size): - start_desc = len(signature) - #signature = generate_cu_signature(constants, signature, ids) - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): - if ty[0] == '*': + if ty[0] == '*' or ty == "none": return "PyObject*" - return { - 'i1': 'int32_t', - 'i8': 'int8_t', - 'i16': 'int16_t', - 'i32': 'int32_t', - 'i64': 'int64_t', - 'u1': 'uint32_t', - 
'u8': 'uint8_t', - 'u16': 'uint16_t', - 'u32': 'uint32_t', - 'u64': 'uint64_t', - 'fp16': 'float', - 'bf16': 'float', - 'fp32': 'float', - 'f32': 'float', - 'fp64': 'double', - }[ty] + if ty[0] == '[': + if ty == "[]": + return "[]" + tys = parse_list_string(ty) + val = ','.join(map(_extracted_type, tys)) + return f"[{val}]" + return ty_to_cpp(ty) def format_of(ty): + if ty == "hipDeviceptr_t": + return "O" + if ty[0] == "[": + if ty == "[]": + return "()" + tys = parse_list_string(ty) + val = ''.join(map(format_of, tys)) + return f"({val})" return { "PyObject*": "O", "float": "f", @@ -227,14 +223,22 @@ def format_of(ty): "uint64_t": "K", }[ty] + signature = {k: v for k, v in signature.items() if v != 'constexpr'} args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiKKOOOO" + args_format + signature = ','.join(signature.values()).replace('[', '').replace(']', '') + signature = list(filter(bool, signature.split(','))) + signature = {i: s for i, s in enumerate(signature)} args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) libhip_path = _get_path_to_hip_runtime_dylib() # generate glue code - params = [f"&arg{i}" for i in signature.keys() if i not in constants] + params = list(range(len(signature))) + params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"] params.append("&global_scratch") src = f""" #define __HIP_PLATFORM_AMD__ @@ -416,8 +420,8 @@ def format_of(ty): // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); if(launch_exit_hook != Py_None){{ PyObject* args = Py_BuildValue("(O)", launch_metadata); @@ -468,9 +472,8 @@ class HIPLauncher(object): def __init__(self, src, metadata): ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} + constants = {idx: value for idx, value in constants.items()} + signature = {idx: value for idx, value in src.signature.items()} src = make_launcher(constants, signature, ids, metadata.warp_size) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch diff --git 
a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 196f189caa..468a2e9dea 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -10,6 +10,7 @@ from triton.runtime import _allocation from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver +from triton._utils import parse_list_string dirname = os.path.dirname(os.path.realpath(__file__)) include_dir = [os.path.join(dirname, "include")] @@ -95,7 +96,7 @@ def __init__(self): def ty_to_cpp(ty): - if ty[0] == '*': + if ty[0] == '*' or ty == "none": return "CUdeviceptr" return { "i1": "int32_t", @@ -118,19 +119,29 @@ def ty_to_cpp(ty): def make_launcher(constants, signature, ids): - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): - if ty[0] == '*': + if ty[0] == '*' or ty == "none": return "PyObject*" if ty == "nvTmaDesc": return "PyObject*" - + if ty[0] == '[': + if ty == "[]": + return "[]" + tys = parse_list_string(ty) + val = ','.join(map(_extracted_type, tys)) + return f"[{val}]" return ty_to_cpp(ty) def format_of(ty): + if ty == "CUdeviceptr": + return "O" + if ty[0] == "[": + if ty == "[]": + return "()" + tys = parse_list_string(ty) + val = ''.join(map(format_of, tys)) + return f"({val})" return { "PyObject*": "O", "float": "f", @@ -146,22 +157,29 @@ def format_of(ty): "uint64_t": "K", }[ty] + signature = {k: v for k, v in signature.items() if v != 'constexpr'} args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiKKOOOOO" + args_format + signature = ','.join(signature.values()).replace('[', '').replace(']', '') + signature = list(filter(bool, signature.split(','))) + signature = {i: s for i, s in enumerate(signature)} args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. 
+ arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) internal_args_list = [] for i, ty in signature.items(): - if ty[0] == "*": + if ty[0] == "*" or ty == "none": internal_args_list.append(f"ptr_info{i}.dev_ptr") elif ty == "nvTmaDesc": # Note: we have to dereference the pointer internal_args_list.append(f"*tma_ptr{i}") else: internal_args_list.append(f"_arg{i}") + params = range(len(signature)) # generate glue code - params = [f"&arg{i}" for i in signature.keys() if i not in constants] + params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"] params.append("&global_scratch") src = f""" #include \"cuda.h\" @@ -395,7 +413,7 @@ def format_of(ty): }} // raise exception asap - {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])}; {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; Py_BEGIN_ALLOW_THREADS; _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); @@ -446,9 +464,8 @@ class CudaLauncher(object): def __init__(self, src, metadata): ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} + constants = {idx: value for idx, value in constants.items()} + signature = {idx: value for idx, value in src.signature.items()} src = make_launcher(constants, signature, ids) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch From 105cb56487cd8a433b8fbfe9cc63c1f1c04a4b2a Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 9 Dec 2024 10:23:42 +0100 Subject: [PATCH 3/9] Use `get_current_target` function to select the device to run tutorials on (#5286) This pull request contains changes for all tutorials except `09-persistent-matmul.py`, as there is a lot of cuda-specific function. 
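A minimal sketch of the pattern applied across the tutorials, assuming a working torch install and an active Triton GPU backend (the `DEVICE` name and the assertion mirror the tutorial code; the tensor size is arbitrary):

    import torch
    import triton

    # Ask the active backend for a torch device instead of hard-coding 'cuda'.
    # On ROCm this still yields torch.device("cuda", idx), because PyTorch
    # exposes HIP devices under the "cuda" device type.
    DEVICE = triton.runtime.driver.active.get_active_torch_device()

    x = torch.rand(98432, device=DEVICE)
    y = torch.rand(98432, device=DEVICE)
    assert x.device == DEVICE and y.device == DEVICE

Written this way, allocations and device checks run unchanged on both CUDA and HIP backends.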
--------- Signed-off-by: Anatoly Myachev --- python/triton/backends/driver.py | 9 +++-- python/tutorials/01-vector-add.py | 12 ++++--- python/tutorials/02-fused-softmax.py | 13 ++++---- python/tutorials/03-matrix-multiplication.py | 14 ++++---- python/tutorials/04-low-memory-dropout.py | 8 +++-- python/tutorials/05-layer-norm.py | 6 ++-- python/tutorials/06-fused-attention.py | 12 ++++--- python/tutorials/07-extern-functions.py | 6 ++-- python/tutorials/08-grouped-gemm.py | 35 ++++++++++---------- third_party/amd/backend/driver.py | 5 +++ third_party/nvidia/backend/driver.py | 4 +++ 11 files changed, 76 insertions(+), 48 deletions(-) diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index 202ae15686..6606b21ca8 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -1,4 +1,4 @@ -from abc import ABCMeta, abstractmethod, abstractclassmethod +from abc import ABCMeta, abstractmethod from typing import Callable, List, Protocol, Sequence @@ -10,7 +10,8 @@ def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) - class DriverBase(metaclass=ABCMeta): - @abstractclassmethod + @classmethod + @abstractmethod def is_active(self): pass @@ -18,6 +19,10 @@ def is_active(self): def get_current_target(self): pass + @abstractmethod + def get_active_torch_device(self): + pass + @abstractmethod def get_benchmarker(self) -> Benchmarker: """ diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index e0220a45ce..e527e5fc7a 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -23,6 +23,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. @@ -60,7 +62,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. def add(x: torch.Tensor, y: torch.Tensor): # We need to preallocate the output. output = torch.empty_like(x) - assert x.is_cuda and y.is_cuda and output.is_cuda + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE n_elements = output.numel() # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. @@ -81,8 +83,8 @@ def add(x: torch.Tensor, y: torch.Tensor): torch.manual_seed(0) size = 98432 -x = torch.rand(size, device='cuda') -y = torch.rand(size, device='cuda') +x = torch.rand(size, device=DEVICE) +y = torch.rand(size, device=DEVICE) output_torch = x + y output_triton = add(x, y) print(output_torch) @@ -116,8 +118,8 @@ def add(x: torch.Tensor, y: torch.Tensor): args={}, # Values for function arguments not in `x_names` and `y_name`. 
)) def benchmark(size, provider): - x = torch.rand(size, device='cuda', dtype=torch.float32) - y = torch.rand(size, device='cuda', dtype=torch.float32) + x = torch.rand(size, device=DEVICE, dtype=torch.float32) + y = torch.rand(size, device=DEVICE, dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] if provider == 'torch': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index c980425597..06d5cd0e3f 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -27,6 +27,8 @@ import triton.language as tl from triton.runtime import driver +DEVICE = triton.runtime.driver.active.get_active_torch_device() + def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" @@ -110,8 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n # %% # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. -device = torch.cuda.current_device() -properties = driver.active.utils.get_device_properties(device) +properties = driver.active.utils.get_device_properties(DEVICE.index) NUM_SM = properties["multiprocessor_count"] NUM_REGS = properties["max_num_regs"] SIZE_SMEM = properties["max_shared_mem"] @@ -189,7 +190,7 @@ def softmax(x): # This will allow us to verify that our padding mechanism works. torch.manual_seed(0) -x = torch.randn(1823, 781, device='cuda') +x = torch.randn(1823, 781, device=DEVICE) y_triton = softmax(x) y_torch = torch.softmax(x, axis=1) assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) @@ -221,9 +222,9 @@ def softmax(x): args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` )) def benchmark(M, N, provider): - x = torch.randn(M, N, device='cuda', dtype=torch.float32) - stream = torch.cuda.Stream() - torch.cuda.set_stream(stream) + x = torch.randn(M, N, device=DEVICE, dtype=torch.float32) + stream = getattr(torch, DEVICE.type).Stream() + getattr(torch, DEVICE.type).set_stream(stream) if provider == 'torch': ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) if provider == 'triton': diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 8153509055..7b838b17a2 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -154,6 +154,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + def is_cuda(): return triton.runtime.driver.active.get_current_target().backend == "cuda" @@ -355,8 +357,8 @@ def matmul(a, b, activation=""): # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). 
torch.manual_seed(0) -a = torch.randn((512, 512), device='cuda', dtype=torch.float16) -b = torch.randn((512, 512), device='cuda', dtype=torch.float16) +a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) +b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output_with_fp16_inputs={triton_output}") @@ -373,8 +375,8 @@ def matmul(a, b, activation=""): TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2") if TORCH_HAS_FP8 and is_cuda(): torch.manual_seed(0) - a = torch.randn((512, 512), device="cuda", dtype=torch.float16) - b = torch.randn((512, 512), device="cuda", dtype=torch.float16) + a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) + b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) a = a.to(torch.float8_e5m2) # pre-transpose b for efficiency. b = b.T @@ -423,8 +425,8 @@ def matmul(a, b, activation=""): @triton.testing.perf_report(configs) def benchmark(M, N, K, provider, fp8_inputs): - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) + a = torch.randn((M, K), device=DEVICE, dtype=torch.float16) + b = torch.randn((K, N), device=DEVICE, dtype=torch.float16) if TORCH_HAS_FP8 and fp8_inputs: a = a.to(torch.float8_e5m2) b = b.T diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index a8bfc46a16..3dd84da47e 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -38,6 +38,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def _dropout( @@ -71,10 +73,10 @@ def dropout(x, x_keep, p): # Input tensor -x = torch.randn(size=(10, )).cuda() +x = torch.randn(size=(10, ), device=DEVICE) # Dropout mask p = 0.5 -x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda() +x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32) # output = dropout(x, x_keep=x_keep, p=p) print(tabulate.tabulate([ @@ -138,7 +140,7 @@ def seeded_dropout(x, p, seed): return output -x = torch.randn(size=(10, )).cuda() +x = torch.randn(size=(10, ), device=DEVICE) # Compare this to the baseline - dropout mask is never instantiated! 
output = seeded_dropout(x, p=0.5, seed=123) output2 = seeded_dropout(x, p=0.5, seed=123) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index a234153a04..5be07a9ea7 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -42,6 +42,8 @@ except ModuleNotFoundError: HAS_APEX = False +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def _layer_norm_fwd_fused( @@ -290,7 +292,7 @@ def backward(ctx, dy): layer_norm = LayerNorm.apply -def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): +def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) @@ -328,7 +330,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): plot_name='layer-norm-backward', args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, )) -def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): +def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE): # create data x_shape = (M, N) w_shape = (x_shape[-1], ) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 09efc06de4..1ddb2eef17 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -19,6 +19,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" @@ -526,13 +528,13 @@ def backward(ctx, do): @pytest.mark.parametrize("causal", [True]) def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): torch.manual_seed(20) - q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) - v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE)) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale if causal: p[:, :, M == 0] = float("-inf") @@ -599,7 +601,7 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16): @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"): +def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device=DEVICE): assert mode in ["fwd", "bwd"] dtype = torch.float16 if "triton" in provider: diff --git a/python/tutorials/07-extern-functions.py b/python/tutorials/07-extern-functions.py index bf5f0acf96..800563701f 100644 --- a/python/tutorials/07-extern-functions.py +++ b/python/tutorials/07-extern-functions.py @@ -25,6 +25,8 @@ from pathlib import Path +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.jit def asin_kernel( @@ -49,8 +51,8 @@ def asin_kernel( torch.manual_seed(0) size = 98432 -x = torch.rand(size, 
device='cuda') -output_triton = torch.zeros(size, device='cuda') +x = torch.rand(size, device=DEVICE) +output_triton = torch.zeros(size, device=DEVICE) output_torch = torch.asin(x) assert x.is_cuda and output_triton.is_cuda n_elements = output_torch.numel() diff --git a/python/tutorials/08-grouped-gemm.py b/python/tutorials/08-grouped-gemm.py index 43be11382f..57fd0722c4 100644 --- a/python/tutorials/08-grouped-gemm.py +++ b/python/tutorials/08-grouped-gemm.py @@ -31,6 +31,8 @@ import triton import triton.language as tl +DEVICE = triton.runtime.driver.active.get_active_torch_device() + @triton.autotune( configs=[ @@ -141,7 +143,6 @@ def grouped_matmul_kernel( def group_gemm_fn(group_A, group_B): - device = torch.device('cuda') assert len(group_A) == len(group_B) group_size = len(group_A) @@ -157,7 +158,7 @@ def group_gemm_fn(group_A, group_B): assert A.shape[1] == B.shape[0] M, K = A.shape K, N = B.shape - C = torch.empty((M, N), device=device, dtype=A.dtype) + C = torch.empty((M, N), device=DEVICE, dtype=A.dtype) group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) @@ -166,11 +167,11 @@ def group_gemm_fn(group_A, group_B): g_lds += [A.stride(0), B.stride(0), C.stride(0)] # note these are device tensors - d_a_ptrs = torch.tensor(A_addrs, device=device) - d_b_ptrs = torch.tensor(B_addrs, device=device) - d_c_ptrs = torch.tensor(C_addrs, device=device) - d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) - d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) # we use a fixed number of CTA, and it's auto-tunable grid = lambda META: (META['NUM_SM'], ) grouped_matmul_kernel[grid]( @@ -197,8 +198,8 @@ def group_gemm_fn(group_A, group_B): M = group_m[i] N = group_n[i] K = group_k[i] - A = torch.rand((M, K), device="cuda", dtype=torch.float16) - B = torch.rand((K, N), device="cuda", dtype=torch.float16) + A = torch.rand((M, K), device=DEVICE, dtype=torch.float16) + B = torch.rand((K, N), device=DEVICE, dtype=torch.float16) group_A.append(A) group_B.append(B) @@ -255,9 +256,9 @@ def benchmark(N, provider): g_lds = [] group_C = [] for i in range(group_size): - A = torch.rand((N, N), device="cuda", dtype=torch.float16) - B = torch.rand((N, N), device="cuda", dtype=torch.float16) - C = torch.empty((N, N), device="cuda", dtype=torch.float16) + A = torch.rand((N, N), device=DEVICE, dtype=torch.float16) + B = torch.rand((N, N), device=DEVICE, dtype=torch.float16) + C = torch.empty((N, N), device=DEVICE, dtype=torch.float16) group_A.append(A) group_B.append(B) group_C.append(C) @@ -267,11 +268,11 @@ def benchmark(N, provider): g_sizes += [N, N, N] g_lds += [N, N, N] - d_a_ptrs = torch.tensor(A_addrs, device="cuda") - d_b_ptrs = torch.tensor(B_addrs, device="cuda") - d_c_ptrs = torch.tensor(C_addrs, device="cuda") - d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda") - d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda") + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) quantiles = [0.5, 0.2, 0.8] if provider == 
'cublas': diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 965341b96e..dc424caddb 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -505,6 +505,11 @@ def get_current_target(self): warp_size = device_properties['warpSize'] return GPUTarget("hip", arch.split(':')[0], warp_size) + def get_active_torch_device(self): + import torch + # when using hip devices, the device string in pytorch is "cuda" + return torch.device("cuda", self.get_current_device()) + def get_benchmarker(self): from triton.testing import do_bench return do_bench diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 468a2e9dea..ee440bd4f6 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -496,6 +496,10 @@ def get_current_target(self): warp_size = 32 return GPUTarget("cuda", capability, warp_size) + def get_active_torch_device(self): + import torch + return torch.device("cuda", self.get_current_device()) + def get_device_interface(self): import torch return torch.cuda From 07e1cc632cef7df02e54269b11b1fa0e9ed019d5 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 9 Dec 2024 15:01:07 +0100 Subject: [PATCH 4/9] [NFC] Remove duplicate libraries in `bin/CMakeLists.txt` (#5370) All deleted libraries are either in `${triton_libs}` or in `${conversion_libs}`. Signed-off-by: Anatoly Myachev --- bin/CMakeLists.txt | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index fa84e9fd69..ec5cc0c544 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,12 +7,6 @@ add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED) # TODO: what's this? llvm_update_compile_flags(triton-opt) target_link_libraries(triton-opt PRIVATE - TritonLLVMIR - TritonAnalysis - TritonTransforms - TritonGPUTransforms - TritonNvidiaGPUTransforms - MLIRGPUToROCDLTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} @@ -31,11 +25,6 @@ mlir_check_all_link_libraries(triton-reduce) llvm_update_compile_flags(triton-reduce) target_link_libraries(triton-reduce PRIVATE - TritonLLVMIR - TritonAnalysis - TritonTransforms - TritonGPUTransforms - TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} @@ -53,10 +42,6 @@ add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED) llvm_update_compile_flags(triton-lsp) target_link_libraries(triton-lsp PRIVATE - TritonAnalysis - TritonTransforms - TritonGPUTransforms - TritonNvidiaGPUTransforms ${dialect_libs} ${conversion_libs} ${triton_libs} @@ -93,8 +78,6 @@ export_executable_symbols_for_plugins(triton-llvm-opt) add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED) target_link_libraries(triton-tensor-layout PRIVATE - TritonGPUIR - TritonNvidiaGPUIR ${triton_libs} ${conversion_libs} ${dialect_libs} From 2626f2f1281c3ec093d8dac4b1475801429774be Mon Sep 17 00:00:00 2001 From: pawelszczerbuk <153013546+pawelszczerbuk@users.noreply.github.com> Date: Mon, 9 Dec 2024 08:24:54 -0800 Subject: [PATCH 5/9] [PIPELINING] Fix stage for the local_load in the TMA pipelining (#5365) `local_load` should be in the same stage that the `subview` that it is using. 
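For reference, the resulting call, with the created op written out (the `ttg::LocalLoadOp` template argument is an assumption; the other names appear in `createTMAAsyncCopy` below):

    // Tag the local_load that replaces the TMA descriptor load with the
    // stage/cluster of its first use, i.e. the same schedule point as the
    // subview (viewLoad) it reads, instead of the async copy's own stage.
    auto sharedLoad = builder.createWithStage<ttg::LocalLoadOp>(
        loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(), viewLoad);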
--- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f0fe8d43f4..548702fdc0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -272,7 +272,7 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, builder.setInsertionPointAfter(viewLoad); auto sharedLoad = builder.createWithStage( - loc, stage, clusterId, loadOp.getType(), + loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(), viewLoad /*,wait->getResult(0)*/); auto result = sharedLoad->getResults(); loadOp->replaceAllUsesWith(result); From e3d3851ed51644245ff44067d0239db4613aec36 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 9 Dec 2024 14:29:23 -0500 Subject: [PATCH 6/9] [BACKEND] Use linear layout for loading mmav2 dot operand tensors from shared memory (#5154) --- .../Conversion/TritonGPUToLLVM/Utility.h | 12 +- include/triton/Tools/LinearLayout.h | 7 + .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 17 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 170 +++++++++++++++--- .../TritonGPUToLLVM/ViewOpToLLVM.cpp | 2 +- lib/Tools/LinearLayout.cpp | 12 ++ test/Conversion/tritongpu_to_llvm.mlir | 2 + .../LoadStoreOpToLLVM.cpp | 4 +- 8 files changed, 176 insertions(+), 50 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index a1c37efb52..4b2179611a 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -1154,15 +1154,15 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, // Returns true on success. [[nodiscard]] bool emitTransferBetweenRegistersAndShared( RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, - Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, - ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, + Type elemLlvmTy, std::optional maxVecElems, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback); inline DenseMap getSwizzledSharedPtrs( Location loc, const TargetInfoBase &target, unsigned inVec, RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, - Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter, + Type resElemTy, const SharedMemoryObject &smemObj, RewriterBase &rewriter, ArrayRef offsetVals, ArrayRef srcStrides) { // This utility computes the pointers for accessing the provided swizzled // shared memory layout `resSharedLayout`. 
More specifically, it computes, @@ -1324,14 +1324,14 @@ inline DenseMap getSwizzledSharedPtrs( SmallVector loadSharedToDistributed(RankedTensorType dstTy, triton::gpu::MemDescType srcTy, Type elemLlvmTy, - SharedMemoryObject smemObj, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target); void storeDistributedToShared( triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, - ArrayRef srcVals, Value smemBase, ArrayRef dstStrides, - Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + ArrayRef srcVals, const SharedMemoryObject &smemObj, Location loc, + RewriterBase &rewriter, const TargetInfoBase &target, std::pair *const llvmOpCount = nullptr); inline Value getStructFromSharedMemoryObject(Location loc, diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 9ddec88812..aa831bcc35 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -414,6 +414,10 @@ class LinearLayout { bool isSurjective() const { return surjective; } + bool isInvertible() const { + return surjective && getTotalInDimSize() == getTotalOutDimSize(); + } + const BasesT &getBases() const { return bases; } // Get the pos'th basis vector for the inDim -> outDim mapping. @@ -673,6 +677,9 @@ class LinearLayout { // don't place any guarantees on the choices made by this function. [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const; + // Get the layout that is the inverse of this layout. + [[nodiscard]] LinearLayout invert() const; + // For each in-dim, returns a bitmask of the "free variables" in the layout // function. // diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 1e6e1c1fd7..81a4d6f6bf 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -25,11 +25,9 @@ void lowerDistributedToShared( auto outOrd = mlir::cast(dstTy.getEncoding()).getOrder(); auto elemTy = typeConverter->convertType(srcTy.getElementType()); - auto smemBase = smemObj.getBase(); - auto dstStrides = smemObj.getStrides(); auto inVals = unpackLLElements(loc, adaptorSrc, rewriter); - storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides, - loc, rewriter, targetInfo, llvmOpCount); + storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter, + targetInfo, llvmOpCount); } struct GlobalScratchAllocOpConversion @@ -157,14 +155,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { // If we remove this one, ldmatrix will IMA. 
It can probably be relaxed // though canUseLdmatrix &= - srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth; - // To be removed in https://github.com/triton-lang/triton/pull/5154 - bool legacyLoweringIsBuggy = - (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) || - dstTy.getRank() == 3) && - mma.isAmpere(); - return (mma.isHopper() && !canUseLdmatrix) || - (mma.isAmpere() && legacyLoweringIsBuggy); + srcTy.getShape()[0] >= 8 && + srcTy.getShape()[1] >= 4 * kWidth & dstTy.getRank() <= 2; + return !canUseLdmatrix; } if (isa(dot.getParent())) return true; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index a310cdba5f..02b3b121f4 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -169,10 +169,139 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, return ret; } +namespace { + +Value getSmemVecAddr(RankedTensorType registerTy, + triton::gpu::MemDescType sharedTy, Type elemLlvmTy, + Location loc, RewriterBase &rewriter, + const LinearLayout ®ToSharedLayout, Value regId, + Value laneId, Value warpId, + const SharedMemoryObject &smemObj) { + MLIRContext *ctx = rewriter.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + auto shape = sharedTy.getShape(); + auto rank = shape.size(); + auto allocShape = sharedTy.getAllocShape(); + auto sharedEnc = + dyn_cast(sharedTy.getEncoding()); + + auto smemBase = smemObj.getBase(); + auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); + auto smemOffsets = smemObj.getOffsets(); + auto smemStrides = smemObj.getStrides(); + Value smemOffset; + // When loading or storing to shared memory, we consider two cases for + // performance reasons: + // + // 1. Non-swizzled shared memory. + // 2. Swizzled shared memory. + // + // Consider lowering `ttg.local_load %a`. In the first case, we can + // directly construct a linear layout using `%a`'s shape and shared memory + // encoding, irrespective of `%a`'s rank or whether it represents a slice of a + // larger tensor. + // + // The method does not apply for swizzled shared memory in some scenarios. + // Key properties of swizzling in Triton are: + // + // - Swizzling applies only to tensors with rank ≥ 2. + // - It is restricted to the last two dimensions of the tensor. + // - These last two dimensions are always treated as the most "minor." + // + // An important edge case arises when `%a` results from `%a = ttg.subview %b`, + // where `%b` is swizzled (and so is `%a`). In this case, constructing a + // layout and determining shared memory offsets using `%a`'s shape is + // incorrect. This is because swizzling depends on the original shape of `%b`, + // which differs from `%a`'s shape. As a result, some locations may fall + // outside `%a`'s contiguous view of memory. Specifically, an element `[i + // (row_idx), j (col_idx)]` in `%a` might map to `[i, j']` after swizzling, + // where `j'` lies outside `%a`'s shape but still within `%b`'s shape. + // + // We propose case 2 (see comments below), which provides a more general + // solution for all swizzled shared memory scenarios, including the edge case + // mentioned above. 
+ if (/*no swizzling*/ sharedEnc.getMaxPhase() == 1 || + /*swizzling but same shape*/ shape == allocShape || + /*swizzling and rank-reduced and rank >= 2*/ + (shape == allocShape.take_back(rank) && rank >= 2)) { // Case 1 + // Get the address to load/store. The multi-dim address is (offsetX1, ..., + // offsetXN, block), where the offsets appear in minor-to-major order, and + // we drop_end to drop block, which we know from above will be 0. + smemOffsets = llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})))); + // Reorder strides according to `order`. This way they match the + // multi-dimensional offsets in regToSharedLayout. + smemOffset = dot(rewriter, loc, smemOffsets, + applyPermutation(smemStrides, sharedOrder)); + } else { // Case 2 -> rank-reduced swizzling + assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2"); + // We define both tensor offsets and shared memory offsets: + // + // - Tensor offsets: Relative offsets within a given tensor. + // - Shared memory offsets: Absolute offsets within the shared memory. + // + // In Triton, the shared memory layout provides an invertible, one-to-one + // mapping between tensor offsets and shared memory offsets. The `base` + // field of any shared memory object represents both the shared memory + // offset and the tensor offset relative to the original tensor at + // allocation, prior to any subview operations. + // + // To determine the shared memory offsets for a specific register when + // dealing with swizzled and sliced tensors, the process involves: + // + // 1. Retrieving the original tensor's `invertAllocSharedLayout`, which + // maps the allocated tensor's offsets back to shared memory offsets. + // 2. Reconstructing the register's offsets in the allocated tensor by + // summing: + // - The shared memory offsets of the current view's base, and + // - The relative tensor offsets of the register. + // + // This approach ensures that "absolute" tensor offsets can be + // mapped to the correct shared memory addresses using + // `invertAllocSharedLayout`. 
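Summarizing the Case 2 steps implemented below as a small Python sketch; reg_layout and inv_alloc_layout stand in for the linear layouts built in the C++ code, and view_base_offsets/strides for the SharedMemoryObject fields, so treat the names as placeholders rather than real APIs.

def case2_smem_offset(reg_layout, inv_alloc_layout, view_base_offsets, strides,
                      reg_id, lane_id, warp_id):
    # 1. Tensor offsets of this register within the current view.
    view_offsets = reg_layout(reg_id, lane_id, warp_id)
    # 2. Add the view's base offsets to get offsets in the *allocated* tensor.
    alloc_offsets = [v + b for v, b in zip(view_offsets, view_base_offsets)]
    # 3. Invert the allocation's shared layout: tensor offsets -> linear
    #    shared-memory offset (this is where swizzling is accounted for).
    smem_offset = inv_alloc_layout(alloc_offsets)
    # 4. smemBase already points at the view's first element, so subtract the
    #    distance from the allocation base to the view base.
    base_dist = sum(b * s for b, s in zip(view_base_offsets, strides))
    return smem_offset - base_dist
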
+ std::optional regLayout = + triton::gpu::toLinearLayout(shape, registerTy.getEncoding()); + auto allocSharedLayout = triton::gpu::toLinearLayout( + allocShape.take_back(rank), sharedTy.getEncoding(), + elemLlvmTy.getIntOrFloatBitWidth()); + assert(allocSharedLayout.has_value() && + "Failed to convert layout to linear layout"); + auto invertAllocSharedLayout = allocSharedLayout->invert(); + auto multiDimTensorOffsets = + llvm::to_vector(applyLinearLayout(loc, rewriter, *regLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})); + for (auto i = 0; i < rank; i++) { + multiDimTensorOffsets[i].second = + add(multiDimTensorOffsets[i].second, smemOffsets[i]); + } + smemOffset = applyLinearLayout(loc, rewriter, invertAllocSharedLayout, + multiDimTensorOffsets)[0] + .second; + Value baseToAllocBaseDist = dot(rewriter, loc, smemOffsets, smemStrides); + smemOffset = sub(smemOffset, baseToAllocBaseDist); + } + auto ptrTy = smemBase.getType(); + auto vecAddr = gep(ptrTy, elemLlvmTy, smemBase, smemOffset); + vecAddr.setInbounds(true); + return vecAddr; +} + +} // namespace + bool emitTransferBetweenRegistersAndShared( RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, - Type elemLlvmTy, std::optional maxVecElems, Value shmemBase, - ArrayRef shmemStrides, Location loc, RewriterBase &rewriter, + Type elemLlvmTy, std::optional maxVecElems, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback) { MLIRContext *ctx = rewriter.getContext(); @@ -230,28 +359,12 @@ bool emitTransferBetweenRegistersAndShared( int numElems = regToSharedLayout->getInDimSize(kRegister); auto vecTy = vec_ty(elemLlvmTy, vecElems); - auto ptrTy = shmemBase.getType(); Value zero = i32_val(0); SmallVector ret; for (int i = 0; i < numElems / vecElems; i++) { - // Get the address to load/store. The multi-dim address is (offsetX1, ..., - // offsetXN, block), where the offsets appear in minor-to-major order, and - // we drop_end to drop block, which we know from above will be 0. - auto multiDimShmemOffset = - llvm::to_vector(llvm::drop_end(llvm::make_second_range( - applyLinearLayout(loc, rewriter, *regToSharedLayout, - {{kRegister, i32_val(i * vecElems)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, zero}})))); - - // Reorder strides according to `order`. This way they match the - // multi-dimensional offsets in regToSharedLayout. 
- auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding()); - Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset, - applyPermutation(shmemStrides, sharedOrder)); - auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset); - vecAddr.setInbounds(true); + auto vecAddr = getSmemVecAddr( + registerTy, sharedTy, elemLlvmTy, loc, rewriter, *regToSharedLayout, + i32_val(i * vecElems), laneId, warpId, smemObj); perVectorCallback(vecTy, vecAddr); } @@ -261,14 +374,13 @@ bool emitTransferBetweenRegistersAndShared( SmallVector loadSharedToDistributed(RankedTensorType dstTy, triton::gpu::MemDescType srcTy, Type elemLlvmTy, - SharedMemoryObject smemObj, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target) { SmallVector ret; bool success = emitTransferBetweenRegistersAndShared( - dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(), - smemObj.getStrides(), loc, rewriter, target, - [&](VectorType vecTy, Value vecAddr) { + dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { auto vecVal = load(vecTy, vecAddr); vecVal.setAlignment(vecTy.getNumElements() * elemLlvmTy.getIntOrFloatBitWidth() / 8); @@ -285,14 +397,14 @@ SmallVector loadSharedToDistributed(RankedTensorType dstTy, void storeDistributedToShared(triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, - ArrayRef srcVals, Value smemBase, - ArrayRef dstStrides, Location loc, + ArrayRef srcVals, + const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::pair *const llvmOpCount) { bool success = emitTransferBetweenRegistersAndShared( - srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase, - dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) { + srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, + rewriter, target, [&](VectorType vecTy, Value vecAddr) { ArrayRef vals = srcVals.take_front(vecTy.getNumElements()); srcVals = srcVals.drop_front(vecTy.getNumElements()); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index ea05490c7a..1dcba27e15 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -394,7 +394,7 @@ struct MemDescSubviewOpConversion int rankReduced = srcTy.getRank() - destRank; for (int i = rankReduced; i < opOffsetVals.size(); i++) { strides.push_back(smemObj.strides[i]); - offsetVals.push_back(opOffsetVals[i]); + offsetVals.push_back(add(opOffsetVals[i], smemObj.offsets[i])); } // Compute the offset based on the original strides of the shared memory // object diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 0ab563908a..984625e895 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -313,6 +313,7 @@ LinearLayout::checkInvariants(bool requireSurjective) { "can be reached by some `in` coordinate, but was not:" + toString(); } + return std::nullopt; } @@ -955,6 +956,17 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const { return ret; } +LinearLayout LinearLayout::invert() const { + // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A) + assert(isInvertible() && + "A linear layout must be surjective and square to be invertible"); + LinearLayout identity = LinearLayout::empty(); + for (auto outDim : getOutDimNames()) { + identity *= 
LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim); + } + return identity.invertAndCompose(*this); +} + llvm::MapVector LinearLayout::getFreeVariableMasks() const { std::unique_ptr mat = getMatrix(*this); diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index a97ac476cb..fe178e5758 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -469,6 +469,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: llvm.extractvalue // CHECK-NEXT: llvm.extractvalue // CHECK-NEXT: llvm.extractvalue + // CHECK-NEXT: llvm.add + // CHECK-NEXT: llvm.add // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32 // CHECK-NEXT: llvm.mul // CHECK-NEXT: llvm.add diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index d2cef405eb..f21120dbca 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -935,8 +935,8 @@ struct AsyncCopyGlobalToLocalOpConversion VectorType vecTy; SmallVector shmemAddrs; bool ok = emitTransferBetweenRegistersAndShared( - srcTy, dstTy, resElemTy, maxVec, smemObj.base, smemObj.strides, loc, - rewriter, targetInfo, [&](VectorType vecTy_, Value shmemAddr) { + srcTy, dstTy, resElemTy, maxVec, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { vecTy = vecTy_; shmemAddrs.push_back(shmemAddr); }); From 5700c1468773d224075597f53710a79a796d5fd2 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Mon, 9 Dec 2024 19:43:26 -0500 Subject: [PATCH 7/9] [FRONTEND] Fix and improve minimum dot size checks (#5383) 1. Fix the problem that [m, k, n] but not [m, n, k] is returned on the nvidia backend 2. Check both int8 and float8 3. Add a new compiler error test 4. 
Fix dtype check in AMD backend --- .../test/unit/language/test_compile_errors.py | 38 ++++++++++++++++++- python/triton/language/semantic.py | 1 + third_party/amd/backend/compiler.py | 3 +- third_party/nvidia/backend/compiler.py | 12 +++++- 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 2760d26bc5..efb6dab02f 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -7,7 +7,7 @@ import triton.language as tl from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure import traceback -from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300 +from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300, is_hip_mi200 def test_err_undefined_variable(): @@ -379,6 +379,42 @@ def dtype_kernel(dtype: tl.constexpr): raise assertion_err from e.value +@pytest.mark.parametrize("dtype", [tl.float8e5, tl.int8, tl.float16]) +def test_min_dot_size(dtype): + error_msg = "Input shapes should have " + if is_cuda(): + if dtype.primitive_bitwidth == 8: + error_msg += "M >= 16, N >= 16 and K >= 32" + else: + error_msg = "M >= 16, N >= 16 and K >= 16" + elif is_hip_mi300(): + if dtype.is_int8(): + error_msg += "M >= 16, N >= 16 and K >= 16" + else: + error_msg += "M >= 16, N >= 16 and K >= 8" + elif is_hip_mi200(): + error_msg += "M >= 16, N >= 16 and K >= 8" + elif is_hip(): + error_msg = "M >= 16, N >= 16 and K >= 16" + else: + pytest.skip("Test only supported on CUDA and HIP") + + @triton.jit + def dot_kernel(dtype: tl.constexpr): + SIZE: tl.constexpr = 8 + a = tl.full((SIZE, SIZE), 0.0, dtype) + b = tl.full((SIZE, SIZE), 0.0, dtype) + tl.dot(a, b) + + with pytest.raises(CompilationError) as e: + triton.compile( + triton.compiler.ASTSource(fn=dot_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + try: + assert (error_msg in str(e.value.__cause__)) + except AssertionError as assertion_err: + raise assertion_err from e.value + + def test_max_num_imprecise_acc_limit(): @triton.jit diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 2f7dba929b..f5fea3d3b2 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1473,6 +1473,7 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. 
Got {lhs.dtype} and {rhs.dtype}" if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + # We upcast because there's no fp8e4b15 type in MLIR lhs = cast(lhs, tl.float16, builder) rhs = cast(rhs, tl.float16, builder) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index a8d806a8b1..91955c921e 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -18,7 +18,8 @@ def min_dot_size(target: GPUTarget): # CDNA 3.0 supports k==8 in all mfma variants except for int8 # (where the smallest `k` supported is 16) if "gfx94" in arch_str: - return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.is_int8() or rhsType.is_int8()) else (16, 16, 8) + return lambda lhsType, rhsType: (16, 16, 16) if (lhsType.scalar.is_int8() or rhsType.scalar.is_int8()) else ( + 16, 16, 8) # CDNA 2.0 always supports `k==8` if "gfx9" in arch_str: return lambda lhsType, rhsType: (16, 16, 8) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index d94be93872..137fef4bd0 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -17,7 +17,17 @@ def min_dot_size(target: GPUTarget): - return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16) + + def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, n, k] + lhs_bitwidth = lhs_type.scalar.primitive_bitwidth + rhs_bitwidth = rhs_type.scalar.primitive_bitwidth + assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same" + if lhs_bitwidth == 8: + return (16, 16, 32) + else: + return (16, 16, 16) + + return check_dot_compatibility @functools.lru_cache() From 4d2e9e5de96a5d6ea163f2de04ae5c5b6be45825 Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Tue, 10 Dec 2024 00:43:46 +0000 Subject: [PATCH 8/9] [FRONTEND] Fix bitcast with constexpr dtype (#5382) Fixes #5364 --- python/test/unit/language/test_core.py | 20 +++++++++++++------- python/triton/_internal_testing.py | 6 +++++- python/triton/language/core.py | 13 ++++--------- python/triton/language/semantic.py | 4 ---- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ac22cdee43..5e60d9fd14 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -21,6 +21,7 @@ from triton._internal_testing import ( integral_dtypes, int_dtypes, + str_to_triton_dtype, uint_dtypes, float_dtypes, float_dtypes_with_bfloat16, @@ -1641,7 +1642,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr): ('float32', 'bfloat16', False, 1024), ('bfloat16', 'float32', False, 1024), ('float32', 'int32', True, 1024), - ('float32', 'int1', False, 1024), + ('float32', 'bool', False, 1024), ('int8', 'bfloat16', False, 1024), ] + [(f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) @@ -1687,19 +1688,21 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): # triton kernel @triton.jit - def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): + def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr): x_ptr = X + tl.arange(0, SIZE) z_ptr = Z + tl.arange(0, SIZE) x = tl.load(x_ptr) # Depending on the value of ARG_HASH (a "random" number determined by # the test parameters), spell the cast one of three different ways. 
- if ARG_HASH % 3 == 0: + if ARG_HASH % 4 == 0: z = x.to(Z.dtype.element_ty, bitcast=BITCAST) - elif ARG_HASH % 3 == 1: + elif ARG_HASH % 4 == 1: z = x.cast(Z.dtype.element_ty, bitcast=BITCAST) - else: + elif ARG_HASH % 4 == 2: z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST) + else: + z = tl.cast(x, TO_TYPE, bitcast=BITCAST) tl.store(z_ptr, z) @@ -1707,7 +1710,7 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constex # This way we don't have to increase the number of tests. arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas)) - dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' + dtype_z_np = dtype_z if dtype_z != 'bool' else 'bool_' # triton result if dtype_z.startswith('bfloat'): z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device) @@ -1715,7 +1718,10 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constex z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) else: z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) - kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas) + + dtype_z_tri = str_to_triton_dtype(dtype_z) + kernel[(1, )](x_tri, z_tri, TO_TYPE=dtype_z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, + num_ctas=num_ctas) # torch result if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( 'float8') or dtype_x.startswith('float8'): diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 377eed877a..87836a886c 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -9,7 +9,7 @@ from numpy.random import RandomState from typing import Optional, Union -from triton.runtime.jit import TensorWrapper, reinterpret +from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] @@ -119,6 +119,10 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torc return torch.tensor(x, device=device) +def str_to_triton_dtype(x: str) -> tl.dtype: + return tl.str_to_ty(type_canonicalisation_dict[x]) + + def torch_dtype_name(dtype) -> str: if isinstance(dtype, triton.language.dtype): return dtype.name diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 31b19754c6..dcce42908c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1024,13 +1024,7 @@ def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: """ Alias for :py:func:`tensor.cast`. """ - # Triton doesn't like core functions calling other core functions, so we - # just copy-paste the implementation of cast here. It's not too bad. - dtype = _unwrap_if_constexpr(dtype) - bitcast = _unwrap_if_constexpr(bitcast) - if bitcast: - return semantic.bitcast(self, dtype, _builder) - return semantic.cast(self, dtype, _builder, fp_downcast_rounding) + return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder) # Type stubs for functions added by the _tensor_member_fn decorator. # (Unfortunately these can't be created automatically.) 
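The user-visible effect of unwrapping constexpr inside cast (and routing tensor.to through it, as in the hunk below) is that a dtype passed as a tl.constexpr kernel argument can be used directly as the cast target, which is what the new TO_TYPE branch in the test above exercises. A minimal sketch, assuming a CUDA device and that float16 is a suitable target for the input:

import torch
import triton
import triton.language as tl

@triton.jit
def cast_kernel(X, Z, TO_TYPE: tl.constexpr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(X + offs)
    # TO_TYPE arrives wrapped in tl.constexpr; cast() now unwraps it itself.
    z = tl.cast(x, TO_TYPE)
    tl.store(Z + offs, z)

x = torch.randn(16, device="cuda")
z = torch.empty(16, dtype=torch.float16, device="cuda")
cast_kernel[(1, )](x, z, TO_TYPE=tl.float16, N=16)
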
@@ -1685,8 +1679,9 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas :type bitcast: bool, optional """ input = semantic.to_tensor(input, _builder) - if isinstance(bitcast, constexpr): - bitcast = bitcast.value + dtype = _constexpr_to_value(dtype) + fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding) + bitcast = _constexpr_to_value(bitcast) if bitcast: return semantic.bitcast(input, dtype, _builder) return semantic.cast(input, dtype, _builder, fp_downcast_rounding) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index f5fea3d3b2..8a83a29018 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -828,10 +828,6 @@ def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tenso def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None) -> tl.tensor: src_ty = input.type - if isinstance(dst_ty, tl.constexpr): - dst_ty = dst_ty.value - if isinstance(fp_downcast_rounding, tl.constexpr): - fp_downcast_rounding = fp_downcast_rounding.value if src_ty.is_block(): dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) if src_ty == dst_ty: From 492ea92c05ac2fdde8abf7ded241442f029217ea Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Tue, 10 Dec 2024 04:02:08 +0000 Subject: [PATCH 9/9] Revert "[FRONTEND] added support for tuples (#5220)" This reverts commit 9743ec0dca5bbd9dbce20adc3ee273af6b095f94. --- python/src/ir.cc | 1 - .../test/unit/language/test_compile_errors.py | 49 ++-- python/test/unit/language/test_core.py | 23 +- python/test/unit/language/test_decorator.py | 2 +- python/test/unit/language/test_tuple.py | 100 ------- python/test/unit/runtime/test_bindings.py | 15 +- python/test/unit/runtime/test_cache.py | 4 +- python/test/unit/runtime/test_subproc.py | 8 +- python/test/unit/test_perf_warning.py | 5 +- python/triton/_utils.py | 49 ---- python/triton/backends/compiler.py | 41 +-- python/triton/compiler/code_generator.py | 258 +++++++----------- python/triton/compiler/compiler.py | 29 +- python/triton/language/__init__.py | 19 +- python/triton/language/core.py | 141 ++-------- python/triton/language/semantic.py | 4 +- python/triton/runtime/jit.py | 44 +-- python/triton/tools/compile.py | 16 +- third_party/amd/backend/compiler.py | 13 +- third_party/amd/backend/driver.py | 59 ++-- third_party/nvidia/backend/driver.py | 43 +-- 21 files changed, 288 insertions(+), 635 deletions(-) delete mode 100644 python/test/unit/language/test_tuple.py diff --git a/python/src/ir.cc b/python/src/ir.cc index 53ba39ae10..23bb86e5eb 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -605,7 +605,6 @@ void init_triton_ir(py::module &&m) { "Function argument index out of range"); return self.getArgument(idx); }) - .def("get_num_args", &FuncOp::getNumArguments) .def( "add_entry_block", [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index a117597213..c775714859 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -17,7 +17,7 @@ def kernel(): a += 1 # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) try: assert "is not defined" in str(e.value), "error should 
mention the undefined variable" @@ -32,7 +32,7 @@ def kernel(): 0 + "a" with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) try: assert "at 2:4:" in str(e.value), "error should point to the 0" @@ -47,7 +47,7 @@ def kernel(): tl.static_assert(isinstance(0, tl.tensor)) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) try: assert isinstance(e.value, CompileTimeAssertionFailure) @@ -66,7 +66,7 @@ def kernel(): not (0, 0) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) try: assert e.value.__cause__ is None @@ -83,7 +83,7 @@ def kernel(): 1.0 << 1 with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) try: assert "at 2:4:" in str(e.value), "error should point to the 1.0" @@ -107,7 +107,7 @@ def kernel(): nested_call() with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) try: inner = e.value.__cause__ @@ -130,7 +130,7 @@ def kernel(): tl.expand_dims(None, -1) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) try: inner = e.value.__cause__ @@ -157,7 +157,7 @@ def kernel(): a = two_returns() a + tl.arange(0, 4) # only works if we took the first return - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) def test_not_const_annotate_no_err(): @@ -166,7 +166,7 @@ def test_not_const_annotate_no_err(): def kernel(N: int = 1): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) @triton.jit @@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 4) - triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0})) + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0})) @triton.jit def kernel2(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 8) - triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1})) + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1})) @triton.jit @@ -211,7 +211,7 @@ def kernel(N: int): returns_branched_on_non_constexpr(N) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) try: assert "at 2:4:" in str(e.value), "error should point to the function call" @@ -227,7 +227,7 @@ def kernel(): tl.arange(2, 7) 
with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) assert str(e.value.__cause__) == "arange's range must be a power of 2" @@ -238,7 +238,7 @@ def kernel(): tl.full((33, ), 0, dtype=tl.int64) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" @@ -251,7 +251,7 @@ def kernel(): a = CAPTURED # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) assert "CAPTURED is not defined" in str(e.value) @@ -265,7 +265,7 @@ def kernel(): a = GLOBAL # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) assert "global variable" in str(e.value) @@ -279,7 +279,7 @@ def kernel(): a = CONSTEXPR_ANNOTATED_GLOBAL # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) CONSTEXPR_GLOBAL = tl.constexpr(42) @@ -292,7 +292,7 @@ def kernel(): a = CONSTEXPR_GLOBAL # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) TYPE_ALIAS = tl.pointer_type(tl.int32) @@ -305,7 +305,7 @@ def kernel(): a = TYPE_ALIAS # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) def test_global_access_in_fn_default_arg(): @@ -315,7 +315,7 @@ def kernel(a=GLOBAL): pass # No error. 
- triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={})) def test_defaults_assign_no_err(): @@ -324,7 +324,7 @@ def test_defaults_assign_no_err(): def kernel(a=1, B: tl.constexpr = ""): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""})) def test_where_warning(fresh_triton_cache): @@ -337,7 +337,7 @@ def kernel(): tl.where(a, b, c) with pytest.warns(UserWarning): - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) @@ -371,8 +371,7 @@ def dtype_kernel(dtype: tl.constexpr): ctx = pytest.raises(CompilationError, match="") with ctx as e: - triton.compile( - triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) + triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype})) if dtype not in supported_dtypes: try: @@ -391,7 +390,7 @@ def dot_kernel(): tl.dot(a, b, max_num_imprecise_acc=128) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={})) try: assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") except AssertionError as assertion_err: diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2d570e66b3..6d41b6eece 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4407,17 +4407,15 @@ def kernel(x): def test_value_specialization(value: int, value_type: str, device) -> None: def repr(specialization): - ty = specialization.signature["value1"] - cst = '_'.join([k for k, v in specialization.constants.items() if v == 1]) - return f"kernel_{ty}_{cst}" + spec_type = specialization.signature["VALUE"] + return f"kernel_{spec_type}" @triton.jit(repr=repr) - def kernel(value1, is_one, X): + def kernel(VALUE, X): pass x = torch.tensor([3.14159], device=device) - h = kernel[(1, )](value, 1, x) - assert "is_one" in h.name + h = kernel[(1, )](value, x) assert value_type in h.name @@ -6188,19 +6186,6 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) -def test_dtype(device): - - @triton.jit - def kernel(X): - dtype_x: tl.constexpr = X.dtype.element_ty - tl.static_assert(dtype_x == tl.int32) - tl.static_assert(dtype_x == tl.constexpr(tl.int32)) - tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32)) - - X = torch.zeros(1, dtype=torch.int32, device=device) - kernel[(1, )](X) - - def test_side_effectful_scan(device): if device != "cuda": pytest.xfail() diff --git a/python/test/unit/language/test_decorator.py b/python/test/unit/language/test_decorator.py index 42207cc1fa..fbbfb71446 100644 --- a/python/test/unit/language/test_decorator.py +++ b/python/test/unit/language/test_decorator.py @@ -23,7 +23,7 @@ def kernel(): pass try: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, 
constexprs={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) except Exception as e: pytest.fail(f"triton compile failed with error: {e}") diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py deleted file mode 100644 index 863034579a..0000000000 --- a/python/test/unit/language/test_tuple.py +++ /dev/null @@ -1,100 +0,0 @@ -import pytest -import triton -import triton.language as tl -import torch - - -@triton.jit -def _tuple_increment(values): - for i in tl.static_range(len(values)): - values[i] = values[i] + 1 - return values - - -@triton.jit -def _tuple_index_func(Ptrs, values): - for i in tl.static_range(len(values)): - tl.store(Ptrs[i], values[i]) - - -@triton.jit -def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): - values = _tuple_increment(values) - _tuple_index_func(Ptrs, values) - - -@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) -def test_index(size, device="cuda"): - vals = tuple([i + 1 for i in range(size)]) - rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) - _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) - assert vals == tuple([x.item() - 1 for x in rets]) - - -# ---- - - -@triton.jit -def _tuple_assign(XPtrs, YPtrs, values): - # assign from tuple - X0, X1 = XPtrs - x0, x1 = values - tl.store(X0, x0) - tl.store(X1, x1) - # assign to tuple - Y0, Y1, Y2 = YPtrs - Y = Y0, Y1, Y2 - y = x0, 10, x1 - tl.store(Y[0], y[0]) - tl.store(Y[1], y[1]) - tl.store(Y[2], y[2]) - - -def test_assign(device="cuda"): - vals = (2., 3.) - x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) - y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) - _tuple_assign[(1, )](x, y, vals) - assert x[0] == vals[0] - assert x[1] == vals[1] - assert y[0] == vals[0] - assert y[1] == 10 - assert y[2] == vals[1] - - -# ------- - - -@triton.jit -def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): - tl.static_assert(tuple1[1] is None) - tl.store(Ptr + 5, cst2) - tl.store(Ptr + 6, tuple1[0]) - tl.store(Ptr + 7, tl.load(tuple1[2][0])) - tl.store(Ptr + 8, tuple1[2][1][0]) - tl.store(Ptr + 9, tl.load(tuple1[2][1][2])) - - -# test serialization/deserialization of tuple arguments in -# the frontend. 
-@triton.jit -def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2): - tl.static_assert(N1 is None) - tl.static_assert(tuple1[1][1] is None) - tl.store(Ptr + 0, tl.load(tuple1[0])) - tl.store(Ptr + 1, tuple1[1][0]) - tl.store(Ptr + 2, tl.load(tuple1[1][2])) - tl.store(Ptr + 3, cst1 + val1) - tl.store(Ptr + 4, tl.load(tuple2[0])) - _tuple_fn0(Ptr, 15, (-1, None, tuple1)) - - -def test_serialize(device="cuda"): - x0 = torch.tensor([8], dtype=torch.int32, device=device) - x1 = torch.tensor([12], dtype=torch.int32, device=device) - y0 = torch.tensor([10], dtype=torch.int32, device=device) - z = torch.empty((10, ), dtype=torch.int32, device=device) - # we want to check that JIT specialization propagates to tuples: - _tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, )) - ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device) - assert torch.equal(z, ref) diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index e621eefc01..206d132301 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -63,12 +63,15 @@ def walk_fn(op): backend = triton.compiler.compiler.make_backend(target) src = triton.compiler.compiler.ASTSource( fn=kernel, - signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(args)}, - constexprs={kernel.arg_names[i]: arg - for i, arg in enumerate(args) - if not isinstance(arg, torch.Tensor)}, - attrs=backend.get_attrs_descriptor(kernel.params, args), + signature={ + kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args) + if i not in kernel.constexprs + }, + constants={kernel.arg_names[i]: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=backend.get_attrs_descriptor(args, kernel.params), ) context = triton._C.libtriton.ir.context() diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 23c943aeb1..a5f381dc9f 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -592,10 +592,10 @@ def cache_hook(*args, **kwargs): JITFunction.cache_hook = cache_hook # In warmup we assume that the pointer range is 32 bits kernel_add.warmup(torch.float32, grid=(1, )) - assert pointer_range_32 == [(0, )] + assert pointer_range_32 == [0] # Torch tensor > 2GB kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) assert len(pointer_range_32) == 0 # Torch tensor <= 2GB kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) - assert pointer_range_32 == [(0, )] + assert pointer_range_32 == [0] diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index ecd7227a30..334d5d635f 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -19,8 +19,8 @@ def kernel_sub(a, b, o, N: tl.constexpr): src = ASTSource( fn=kernel_sub, - constexprs={'N': 32}, - signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'}, + constants={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"}, attrs=attrs, ) triton.compile(src=src, target=target) @@ -44,7 +44,7 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constexprs={}) + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={}) triton.compile(src=src, target=target) @@ -65,7 +65,7 @@ 
def empty_kernel(): import gc gc.collect() - src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constexprs={}) + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={}) triton.compile(src=src, target=target) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 461dcb46b4..6646d94f50 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -92,7 +92,7 @@ def matmul_kernel( "stride_cm": "i32", "stride_cn": "i32", }, - constexprs={}, + constants={}, )) captured = capfd.readouterr() @@ -136,9 +136,8 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) "in_ptr2": "*fp16", "in_ptr3": "*fp32", "out_ptr0": "*fp16", - "XBLOCK": "constexpr", }, - constexprs={"XBLOCK": XBLOCK}, + constants={"XBLOCK": XBLOCK}, ), options={"num_warps": 1}, ) diff --git a/python/triton/_utils.py b/python/triton/_utils.py index 0ce1a53a70..ca60c8c3cb 100644 --- a/python/triton/_utils.py +++ b/python/triton/_utils.py @@ -20,52 +20,3 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]: idx += size assert idx == len(flat) return ret - - -def find_paths_if(iterable, pred): - from .language import core - is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type)) - ret = dict() - - def _impl(current, path): - path = (path[0], ) if len(path) == 1 else tuple(path) - if is_iterable(current): - for idx, item in enumerate(current): - _impl(item, path + (idx, )) - elif pred(path, current): - if len(path) == 1: - ret[(path[0], )] = current - else: - ret[tuple(path)] = current - - if is_iterable(iterable): - _impl(iterable, []) - elif pred(list(), iterable): - ret = {tuple(): iterable} - else: - ret = dict() - return ret - - -def parse_list_string(s): - s = s.strip() - if s.startswith('[') and s.endswith(']'): - s = s[1:-1] - result = [] - current = '' - depth = 0 - for c in s: - if c == '[': - depth += 1 - current += c - elif c == ']': - depth -= 1 - current += c - elif c == ',' and depth == 0: - result.append(current.strip()) - current = '' - else: - current += c - if current.strip(): - result.append(current.strip()) - return result diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 98e6f95c8c..a24f768192 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -3,11 +3,11 @@ import hashlib import subprocess import sysconfig + from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import Dict, List, Tuple, Union from types import ModuleType -from .._utils import find_paths_if # Table that associates strings to AttrsDescriptor (sub)classes. 
# In this way we can dynamically select the correct class @@ -52,8 +52,7 @@ class AttrsDescriptor: `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant """ - __slots__ = ('divisibility_16', 'equal_to_1', 'equal_to_none', 'arg_properties', 'property_values', - 'constant_properties') + __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') def __init__(self, params=None, values=None): """ @@ -68,7 +67,6 @@ def __init__(self, params=None, values=None): # Default initialization self.arg_properties = {} self.property_values = {} - self.equal_to_none = {} self.constant_properties = set() self._add_common_properties(params, values) @@ -88,30 +86,17 @@ def _add_common_properties(self, params, values): assert (len(params) == len(values)) # Divisibility property - divisibility_16 = [] - for param, arg in zip(params, values): - if param.do_not_specialize or \ - param.do_not_specialize_on_alignment: - continue - paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_divisible_by_16(val)) - divisibility_16 += [(param.num, ) + x for x in paths] - self.arg_properties["tt.divisibility"] = divisibility_16 + self.arg_properties["tt.divisibility"] = [ + param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) + and not param.do_not_specialize and not param.do_not_specialize_on_alignment + ] # Equal to 1 property - equal_to_1 = [] - for param, arg in zip(params, values): - if param.do_not_specialize: - continue - paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_equal_to_1(val)) - equal_to_1 += [(param.num, ) + x for x in paths] - self.arg_properties["tt.equal_to"] = equal_to_1 - - # Equal to None property - equal_to_none = [] - for param, arg in zip(params, values): - paths = find_paths_if(arg, lambda path, val: val is None) - equal_to_none += [(param.num, ) + x for x in paths] - self.equal_to_none = equal_to_none + self.arg_properties["tt.equal_to"] = [ + param.num + for param, arg in zip(params, values) + if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize + ] def _add_backend_properties(self, params=None, values=None): """ This method is for different subclasses to implement their own compile-time properties """ @@ -145,8 +130,6 @@ def get_constants(self) -> Dict: for prop_name in self.constant_properties: for p in self.arg_properties.get(prop_name, []): constants[p] = self.property_values[prop_name] - for v in self.equal_to_none: - constants[v] = None return constants def filter_out_constants(self): @@ -183,7 +166,7 @@ def from_dict(data): """ attrs_descriptor = _descriptor_table[data["cls"]]() for prop_name, param_ids in data["arg_properties"].items(): - attrs_descriptor.arg_properties[prop_name] = list(map(tuple, param_ids)) + attrs_descriptor.arg_properties[prop_name] = param_ids attrs_descriptor._init_slots() return attrs_descriptor diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 050e8ad0d7..1c39d778ec 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,13 +15,9 @@ from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType from triton._utils import list_list_flatten, list_list_unflatten -from functools import reduce -from .._utils import find_paths_if def mangle_ty(ty): - if ty.is_tuple(): - return 'T' + '_'.join(map(mangle_ty, ty.types)) + 
'T' if ty.is_ptr(): return 'P' + mangle_ty(ty.element_ty) if ty.is_int(): @@ -60,7 +56,7 @@ def _is_triton_tensor(o: Any) -> bool: def _is_constexpr(o: Any) -> bool: - return o is None or isinstance(o, (constexpr, language.core.dtype)) + return isinstance(o, constexpr) def _is_triton_scalar(o: Any) -> bool: @@ -193,66 +189,11 @@ def visit_Call(self, node: ast.Call) -> bool: return self.visit(node.func) -class ASTFunction: - - def get_path(self, x, path): - return reduce(lambda a, idx: a[idx], path, x) - - def set_path(self, x, path, val): - prev = x if len(path) == 1 else self.get_path(x, path[:-1]) - prev[path[-1]] = val - - def __init__(self, ret_types, arg_types, constexprs, constants, attrs): - self.ret_types = ret_types - self.arg_types = arg_types - self.constexprs = constexprs - self.constants = constants - self.attrs = attrs - - def serialize(self, builder: ir.builder): - # fill up IR values in template - # > build function - is_val = lambda path, _: path not in self.constexprs and _ is not None - val_paths = list(find_paths_if(self.arg_types, is_val).keys()) - arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths] - ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] - return builder.get_function_ty(arg_types, ret_types) - - def deserialize(self, fn): - # create "template" - def make_template(val): - if isinstance(val, (list, tuple, language.tuple_type)): - return language.tuple([make_template(x) for x in val]) - return language.constexpr(None) - - vals = make_template(self.arg_types) - is_val = lambda path, _: path not in self.constexprs and _ is not None - val_paths = list(find_paths_if(self.arg_types, is_val).keys()) - # > set attributes - for attr_path, attr_specs in self.attrs.items(): - for attr_name, attr_val in attr_specs: - if attr_path in val_paths: - fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) - for i, path in enumerate(val_paths): - ty = self.get_path(self.arg_types, path) - if isinstance(ty, nv_tma_desc_type): - fn.set_arg_attr(i, "tt.nv_tma_desc", 1) - # > add IR values to the template - for i, path in enumerate(val_paths): - ty = self.get_path(self.arg_types, path) - self.set_path(vals, path, language.tensor(fn.args(i), ty)) - # > add constexpr values to the template - constants = self.constants | self.constexprs - for path, val in constants.items(): - self.set_path(vals, path, language.constexpr(val)) - return vals - - class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, - module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, - file_name: Optional[str] = None, begin_line=0): + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -282,6 +223,8 @@ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunctio self.gscope[k] = v self.lscope = {} + self.attributes = attributes + self.constants = constants self.jit_fn = jit_fn self.function_name = function_name self.is_kernel = is_kernel @@ -399,6 +342,7 @@ def visit_compound_statement(self, stmts): stmts = [stmts] for stmt in stmts: self.visit(stmt) + # Stop parsing as soon as we hit a `return` 
statement; everything # after this is dead code. if isinstance(stmt, ast.Return): @@ -410,7 +354,7 @@ def visit_Module(self, node): def visit_List(self, node): ctx = self.visit(node.ctx) assert ctx is None - elts = language.tuple([self.visit(elt) for elt in node.elts]) + elts = [self.visit(elt) for elt in node.elts] return elts # By design, only non-kernel functions can return @@ -419,15 +363,16 @@ def visit_Return(self, node): if ret_value is None: self.builder.ret([]) ret_ty = language.void - elif isinstance(ret_value, language.tuple): - ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value.values] + elif isinstance(ret_value, tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] ret_types = [v.type for v in ret_values] self.builder.ret([v.handle for v in ret_values]) - ret_ty = language.tuple_type(ret_types) + ret_ty = tuple(ret_types) else: ret = language.semantic.to_tensor(ret_value, self.builder) self.builder.ret([ret.handle]) ret_ty = ret.type + if self.ret_type is None: self.ret_type = ret_ty elif self.ret_type != ret_ty: @@ -452,6 +397,7 @@ def visit_FunctionDef(self, node): init_node = ast.Assign(targets=[st_target], value=default_value) else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + try: assert not self.visiting_arg_default_value self.visiting_arg_default_value = True @@ -461,15 +407,34 @@ def visit_FunctionDef(self, node): # initialize function visibility = "public" if self.is_kernel else "private" - fn_ty = self.prototype.serialize(self.builder) - self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) self.module.push_back(self.fn) entry = self.fn.add_entry_block() - arg_values = self.prototype.deserialize(self.fn) - # bind arguments to symbols + arg_values = [] + idx = 0 + for i in range(len(arg_names)): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + + # Mark this argument as a pass-by-value TMA descriptor (nvidia) + if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): + self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) + + arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) - insert_pt = self.builder.get_insertion_block() self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) @@ -480,11 +445,8 @@ def visit_FunctionDef(self, node): self.ret_type = language.void self.builder.ret([]) else: - if isinstance(self.ret_type, language.tuple_type): - self.prototype.ret_types = self.ret_type.types - else: - self.prototype.ret_types = [self.ret_type] - self.fn.reset_type(self.prototype.serialize(self.builder)) + self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) self.builder.ret([ self.builder.create_poison(ty.to_ir(self.builder)) for ty in self.prototype.ret_types @@ -516,41 +478,37 @@ def 
visit_AnnAssign(self, node): if target in self.lscope: raise ValueError(f'{target} is already defined.' f' constexpr cannot be reassigned.') - value = constexpr(value) + if not _is_constexpr(value): + value = constexpr(value) self.lscope[target] = value return self.lscope[target] # default: call visit_Assign return self.visit_Assign(node) - def assignTarget(self, target, value): - if isinstance(target, ast.Subscript): - assert target.ctx.__class__.__name__ == "Store" - return self.visit_Subscript_Store(target, value) - if isinstance(target, ast.Tuple): - assert target.ctx.__class__.__name__ == "Store" - for i, name in enumerate(target.elts): - self.set_value(self.visit(name), value.values[i]) - return - assert isinstance(target, ast.Name) - self.set_value(self.visit(target), value) - def visit_Assign(self, node): - # construct values to assign - def _sanitize_value(value): - if isinstance(value, language.tuple): - return language.tuple([_sanitize_value(v) for v in value.values]) - native_nontensor_types = (language.dtype, language.tuple) + _names = [] + if isinstance(node, ast.AnnAssign): + _names += [self.visit(node.target)] + else: + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable value = _unwrap_if_constexpr(value) if value is not None and \ - not _is_triton_value(value) and \ - not isinstance(value, native_nontensor_types): + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) - return value - - values = _sanitize_value(self.visit(node.value)) - targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets - assert len(targets) == 1 - self.assignTarget(targets[0], values) + self.set_value(name, value) def visit_AugAssign(self, node): name = node.target.id @@ -573,7 +531,7 @@ def visit_Load(self, node): def visit_Tuple(self, node): args = [self.visit(x) for x in node.elts] - return language.tuple(args) + return tuple(args) def _apply_binary_method(self, method_name, lhs, rhs): # TODO: raise something meaningful if getattr fails below, esp for reverse method @@ -945,7 +903,7 @@ def visit_While(self, node): assert False, "Not implemented" ast.NodeVisitor.generic_visit(self, stmt) - def visit_Subscript_Load(self, node): + def visit_Subscript(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) @@ -953,16 +911,6 @@ def visit_Subscript_Load(self, node): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] - def visit_Subscript_Store(self, node, value): - assert node.ctx.__class__.__name__ == "Store" - lhs = self.visit(node.value) - slices = self.visit(node.slice) - assert isinstance(lhs, language.tuple) - lhs.__setitem__(slices, value) - - def visit_Subscript(self, node): - return self.visit_Subscript_Load(node) - def visit_ExtSlice(self, node): return [self.visit(dim) for dim in node.dims] @@ -1119,7 +1067,7 @@ def visit_Slice(self, node): lower = self.visit(node.lower) upper = self.visit(node.upper) step = self.visit(node.step) - return language.slice(lower, upper, step) + return 
slice(lower, upper, step) def visit_Index(self, node): return self.visit(node.value) @@ -1135,26 +1083,24 @@ def visit_Assert(self, node) -> Any: def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - for i, arg in enumerate(args): - if isinstance(arg, (language.dtype, float, int, bool)): - args[i] = language.core.constexpr(arg) - args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) - args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values() - # mangle - fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) + args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = {} + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) # generate function def if necessary if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller file_name, begin_line = get_jit_fn_file_line(fn) - arg_types = [ - language.core.constexpr if arg is None or isinstance(arg, - (bool, int, language.core.dtype)) else arg.type - for arg in args - ] - prototype = ASTFunction([], arg_types, args_cst, dict(), dict()) - generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn, - function_name=fn_name, function_types=self.function_ret_types, + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, options=self.builder.options, codegen_fns=self.builder.codegen_fns, module_map=self.builder.module_map) @@ -1169,9 +1115,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): else: callee_ret_type = self.function_ret_types[fn_name] symbol = self.module.get_function(fn_name) - args_val = [arg.handle for arg in args_val] - call_op = self.builder.call(symbol, args_val) - if callee_ret_type is None: + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: return None elif call_op.get_num_results() == 1: return tensor(call_op.get_result(0), callee_ret_type) @@ -1179,8 +1124,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): # should return a tuple of tl.tensor results = [] for i in range(call_op.get_num_results()): - results.append(tensor(call_op.get_result(i), callee_ret_type.types[i])) - return language.tuple(results) + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) def visit_Call(self, node): fn = _unwrap_if_constexpr(self.visit(node.func)) @@ -1199,11 +1144,7 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: - ret = fn(*args, **extra_kwargs, **kws) - # builtin functions return plain tuples for readability - if isinstance(ret, tuple): - ret = language.tuple(ret) - return ret + return fn(*args, **extra_kwargs, **kws) except Exception as e: # Normally when we raise a CompilationError, we raise it 
as # `from None`, because the original fileline from the exception @@ -1344,29 +1285,38 @@ def kernel_suffix(signature, specialization): suffix = '' for i, _ in enumerate(signature): suffix += str(i) - if (i, ) in specialization.equal_to_1: + if i in specialization.equal_to_1: suffix += 'c' - if (i, ) in specialization.divisibility_16: + if i in specialization.divisibility_16: suffix += 'd' return suffix def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): - constexprs = specialization.constexprs - arg_idx = lambda x: (fn.arg_names.index(x), ) if isinstance(x, str) else x - constants = specialization.attrs.get_constants() - constexprs = {arg_idx(k): v for k, v in constexprs.items()} - arg_types = [str_to_ty(ty) for ty in specialization.signature.values()] - # find index of constants in serialized order attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = attrs.get_constants() + for k in new_constants: + if k in tys and tys[k] == "i1" and new_constants[k] == 1: + new_constants[k] = True + new_attrs = attrs.filter_out_constants() fn_attrs = new_attrs.get_fn_attrs() - fn_attrs = {k: v for k, v in fn_attrs.items() if k not in constants} + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] file_name, begin_line = get_jit_fn_file_line(fn) - prototype = ASTFunction([], arg_types, constexprs, constants, fn_attrs) - generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(specialization), - jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, - codegen_fns=codegen_fns, module_map=module_map) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) ret = generator.module diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 5de3b0f344..9c7e9f28a4 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -51,12 +51,12 @@ def convert_type_repr(x): class ASTSource: - def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: + def __init__(self, fn, signature, constants=None, attrs=None) -> None: self.fn = fn self.ext = "ttir" self.name = fn.__name__ self.signature = signature - self.constexprs = constexprs + self.constants = constants self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} @@ -64,19 +64,20 @@ def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: for k in self.signature.keys(): if not isinstance(k, str): raise TypeError("Signature keys must be string") - if self.constexprs is None: - self.constexprs = {} + if self.constants is None: + self.constants = {} + else: + for k in self.constants.keys(): + if not isinstance(k, str): + raise TypeError("Constants keys 
must be string") if self.attrs is None: self.attrs = AttrsDescriptor() - # this is the constexprs plus the specialized constants - spec_constants = {self.fn.arg_names[k[0]]: v for k, v in self.attrs.get_constants().items() if len(k) == 1} - self.constants = self.constexprs | spec_constants def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] # Note - we stringify the keys here to allow sorting to work for cases # where constants have mixed int/str keys. - sorted_constants = sorted((str(k), v) for k, v in self.constexprs.items()) + sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() @@ -275,11 +276,11 @@ def compile(src, target=None, options=None): codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() - # try: - module = src.make_ir(options, codegen_fns, module_map, context) - # except Exception as e: - # filter_traceback(e) - # raise + try: + module = src.make_ir(options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise use_ir_loc = os.environ.get("USE_IR_LOC", None) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) @@ -411,7 +412,7 @@ def launch_metadata(self, grid, stream, *args): arg_idx = 0 for i, arg_name in enumerate(self.src.fn.arg_names): if i in self.src.fn.constexprs: - arg_dict[arg_name] = self.src.constexprs[arg_name] + arg_dict[arg_name] = self.src.constants[arg_name] else: arg_dict[arg_name] = args[arg_idx] arg_idx += 1 diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 5f5d464d63..0c8965fc52 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,7 +1,6 @@ """isort:skip_file""" # Import order is significant here. -from .._utils import parse_list_string from . import math from . 
import extra from .standard import ( @@ -70,6 +69,7 @@ float8e5, float8e5b16, full, + function_type, gather, histogram, inline_asm_elementwise, @@ -95,7 +95,6 @@ range, reduce, reshape, - slice, split, static_assert, static_print, @@ -103,8 +102,6 @@ store, tensor, trans, - tuple, - tuple_type, uint16, uint32, uint64, @@ -191,6 +188,7 @@ "floor", "fma", "full", + "function_type", "gather", "histogram", "inline_asm_elementwise", @@ -234,7 +232,6 @@ "reduce", "reshape", "rsqrt", - "slice", "sigmoid", "sin", "softmax", @@ -251,7 +248,6 @@ "tensor", "trans", "triton", - "tuple", "uint16", "uint32", "uint64", @@ -268,9 +264,6 @@ def str_to_ty(name): - if name == "none": - return None - if name[0] == "*": name = name[1:] const = False @@ -280,17 +273,9 @@ def str_to_ty(name): ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) - if name[0] == "[": - names = parse_list_string(name) - tys = [str_to_ty(x) for x in names] - return tuple_type(types=tys) - if name == "nvTmaDesc": return nv_tma_desc_type() - if name == "constexpr": - return constexpr - tys = { "fp8e4nv": float8e4nv, "fp8e4b8": float8e4b8, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 31b19754c6..85d5f6beba 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -140,7 +140,6 @@ def __init__(self, value): self.value = value.value else: self.value = value - self.type = constexpr def __repr__(self) -> str: return f"constexpr[{self.value}]" @@ -474,10 +473,6 @@ def is_ptr(): def is_const(): return False - @staticmethod - def is_tuple(): - return False - def __eq__(self, other: dtype): if not isinstance(other, dtype): return False @@ -613,10 +608,11 @@ def __init__(self, element_ty: dtype, shape: List): # Note that block_type's shape is a list of int # while tensor's shape is a list of constexpr. - assert (isinstance(shape, (list, tuple))) + + assert (isinstance(shape, list)) # shape can be empty ([]) when an input is a 0D tensor. 
- self.shape = tuple(_unwrap_shape(shape)) + self.shape = _unwrap_shape(shape) if not self.shape: raise TypeError('0d block_type is forbidden') @@ -651,32 +647,19 @@ def scalar(self): return self.element_ty -class tuple_type(dtype): +class function_type(dtype): - def __init__(self, types): - self.types = types - self.name = f"[{','.join(map(str, self.types))}]" + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types def __str__(self): - return self.name - - def __iter__(self): - return iter(self.types) + return f'fn ({self.param_types}) -> {self.ret_types}' def to_ir(self, builder: ir.builder): - return [ty.to_ir(builder) for ty in self.types] - - def __getitem__(self, index: int) -> dtype: - return self.types[index] - - def is_tuple(self): - return True - - -class slice_type(dtype): - - def __init__(self): - self.name = 'slice_type' + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) # scalar types @@ -778,7 +761,7 @@ def __init__(self, handle, type: dtype): self.type = type # Tensor type (can be block_type) # Following the practice in pytorch, dtype is scalar type self.dtype = type.scalar - self.shape = tuple([constexpr(s) for s in self.shape]) + self.shape = [constexpr(s) for s in self.shape] def _flatten_ir(self): return [self.handle] @@ -999,16 +982,13 @@ def __not__(self, _builder=None): @builtin def __getitem__(self, slices, _builder=None): - import builtins - if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None: + if isinstance(slices, (slice, constexpr)) or slices is None: slices = [slices] - if isinstance(slices, tuple): - slices = slices.values ret = self for dim, sl in enumerate(slices): if sl is None or isinstance(sl, constexpr) and sl.value is None: ret = semantic.expand_dims(ret, dim, _builder) - elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None: + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: pass else: raise ValueError(f"unsupported tensor index: {sl}") @@ -1167,77 +1147,6 @@ def flip(self, dim=None) -> tensor: ... 
-class tuple: - - def __init__(self, args: list): - self.values = [i for i in args] - - @property - def type(self): - - def get_type(x): - if isinstance(x, dtype): - return dtype - return x.type - - return tuple_type([get_type(x) for x in self.values]) - - def __getitem__(self, idx: constexpr): - if isinstance(idx, int): - idx = constexpr(idx) - if isinstance(idx, constexpr): - return self.values[idx] - else: - import builtins - assert isinstance(idx, (slice, builtins.slice)) - return tuple(self.values[idx.start:idx.stop:idx.step]) - - # TODO: remove - def __setitem__(self, idx: constexpr, value): - if isinstance(idx, int): - idx = constexpr(idx) - assert isinstance(idx, constexpr) - self.values[idx] = value - - def __add__(self, other): - if isinstance(other, list): - other = tuple(other) - return tuple(self.values + other.values) - # return tuple(a + b for a, b in zip(self.values, other.values)) - - def __mul__(self, other): - assert isinstance(other, constexpr) - return tuple(self.values * other.value) - - def __eq__(self, other): - import builtins - if isinstance(other, (list, builtins.tuple)): - other = tuple(other) - return constexpr(self.values == other.values) - - def __hash__(self): - import builtins - return hash(builtins.tuple(self.values)) - - def __str__(self): - return str([str(x) for x in self.values]) - - def __iter__(self): - return iter(self.values) - - def __len__(self): - return len(self.values) - - -class slice: - - def __init__(self, start, stop, step): - self.start = start - self.stop = stop - self.step = step - self.type = slice_type() - - class _experimental_tensor_descriptor_base(_value): """" A tensor descriptor with unknown shape and strides @@ -1653,7 +1562,7 @@ def expand_dims(input, axis, _builder=None): """ input = semantic.to_tensor(input, _builder) axis = _constexpr_to_value(axis) - axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] + axes = list(axis) if isinstance(axis, Sequence) else [axis] new_ndim = len(input.shape) + len(axes) axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] @@ -2306,12 +2215,14 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] def make_combine_region(reduce_op): - param_types = [t.type.scalar for t in input] * 2 + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + region = reduce_op.get_region(0) with _insertion_guard(_builder): - to_ir = lambda T: T.to_ir(_builder) - block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] @@ -2405,12 +2316,14 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] def make_combine_region(scan_op): - param_types = [t.type.scalar for t in input] * 2 + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + region = scan_op.get_region(0) with 
_insertion_guard(_builder): - to_ir = lambda T: T.to_ir(_builder) - block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 72451737cd..57c44e0fa6 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -759,14 +759,14 @@ def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> # Add new axes to lhs for _ in range(len(lhs_shape), len(rhs_shape)): lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), - tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) lhs_ty = lhs.type lhs_shape = lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for _ in range(len(rhs_shape), len(lhs_shape)): rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), - tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 4ae7a918a1..d04f516e81 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -308,8 +308,6 @@ def mangle_type(arg, is_const=False): return "fp32" elif hasattr(arg, "tma_desc_cpu_ptr"): return "nvTmaDesc" - elif isinstance(arg, tuple): - return "[" + ",".join(map(mangle_type, arg)) + "]" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) @@ -337,8 +335,8 @@ def serialize_specialization_data(name, signature, constants, attrs, options, ke constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} import json obj = { - 'name': name, 'signature': signature, 'constant_keys': list(constants.keys()), 'constant_vals': - list(constants.values()), 'attrs': attrs.to_dict(), 'options': options.__dict__, 'key': key + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key } serialized_obj = json.dumps(obj) return serialized_obj @@ -370,7 +368,6 @@ def create_function_from_signature(sig, kparams, backend): func_args.append(f"{name}=default_{name}") dict_entries.append(f"'{name}': {name}") if kp.is_constexpr: - signature_types.append('"constexpr"') constexpr_vals.append(name) else: non_constexpr_vals.append(name) @@ -604,23 +601,32 @@ def run(self, *args, grid, warmup, **kwargs): # done here rather than when we build the signature as otherwise # the kernel cache key could not distinguish between byte pointers # and None arguments, resulting in a downstream mismatch: - sigkeys = [param.name for param in self.params] + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] sigvals = sig_and_spec[:len(sigkeys)] - signature = {k: v for (k, v) in zip(sigkeys, sigvals)} - - attrs = backend.get_attrs_descriptor(self.params, bound_vals) - constexprs = {p.name: v for (v, p) in zip(bound_vals, self.params) if p.is_constexpr} - for i, arg 
in constexprs.items(): + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) + constant_params = configs[0].get_constants() + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or (p.num in constant_params) or v is None + } + for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True): + if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): return None # compile the kernel - src = self.ASTSource(self, signature, constexprs, attrs) - kernel = self.compile(src, target=target, options=options.__dict__) + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) self.cache[device][key] = kernel - self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False) + self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) # Check that used global values have not changed. not_present = object() @@ -633,11 +639,15 @@ def run(self, *args, grid, warmup, **kwargs): # canonicalize grid assert grid is not None if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. grid = grid(bound_args) grid_size = len(grid) grid_0 = grid[0] grid_1 = grid[1] if grid_size > 1 else 1 grid_2 = grid[2] if grid_size > 2 else 1 + # launch kernel launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, @@ -728,11 +738,9 @@ def preload(self, specialization_data): if deserialized_obj['name'] != self.fn.__name__: raise RuntimeError( f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") - constant_keys = deserialized_obj['constant_keys'] - constant_vals = deserialized_obj['constant_vals'] constants = { key: tl.dtype(value) if tl.dtype.is_dtype(value) else value - for key, value in zip(constant_keys, constant_vals) + for key, value in deserialized_obj['constants'].items() } signature = dict(deserialized_obj['signature'].items()) src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 50483b2362..6adf7794cc 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -91,13 +91,15 @@ def constexpr(s): pass return None - hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} hints = {k: v for k, v in hints.items() if v is not None} constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} constants = {k: v for k, v in constants.items() if v is not None} - signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)} - for key in constants: - signature[key] = 'constexpr' + signature = { + kernel.arg_names[i]: s.split(":")[0] + for i, s in enumerate(signature) + if kernel.arg_names[i] not in constants + } const_sig = 'x'.join([str(v) for v in constants.values()]) 
doc_string = [f"{k}={v}" for k, v in constants.items()] doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] @@ -107,8 +109,8 @@ def constexpr(s): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints) for p, v in attrs.get_constants().items(): - constants.update({kernel.arg_names[p[0]]: v}) - src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs) + constants.update({kernel.arg_names[p]: v}) + src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts) if ccinfo.metadata.global_scratch_size > 0: @@ -124,7 +126,7 @@ def constexpr(s): arg_types.append(signature[arg_name]) arg_names_not_1.append(arg_name) arg_types_not_1.append(signature[arg_name]) - elif (i, ) in attrs.equal_to_1: + elif i in attrs.equal_to_1: arg_names.append(arg_name) arg_types.append(signature[arg_name]) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index a8d806a8b1..81b07f2e7d 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -1,6 +1,5 @@ from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor from triton._C.libtriton import ir, passes, llvm, amd -from triton._utils import find_paths_if from dataclasses import dataclass from typing import Any, Dict, Tuple from types import ModuleType @@ -101,14 +100,10 @@ def _add_backend_properties(self, params=None, values=None): if params is None or values is None: return - pointer_range = [] - for param, arg in zip(params, values): - if param.do_not_specialize or \ - param.do_not_specialize_on_alignment: - continue - paths = find_paths_if(arg, lambda path, val: HIPAttrsDescriptor.is_within2gb(val)) - pointer_range += [(param.num, ) + x for x in paths] - self.arg_properties["tt.pointer_range"] = pointer_range + self.arg_properties["tt.pointer_range"] = [ + param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg) + and not param.do_not_specialize and not param.do_not_specialize_on_alignment + ] @staticmethod def is_within2gb(arg): diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 965341b96e..99e5509eca 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -8,7 +8,6 @@ from triton.runtime.cache import get_cache_manager from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver -from triton._utils import parse_list_string dirname = os.path.dirname(os.path.realpath(__file__)) include_dir = [os.path.join(dirname, "include")] @@ -165,7 +164,7 @@ def __init__(self): # -------------------- Launcher ---------------------------- def ty_to_cpp(ty): - if ty[0] == '*' or ty == "none": + if ty[0] == '*': return "hipDeviceptr_t" return { "i1": "int32_t", @@ -187,27 +186,32 @@ def ty_to_cpp(ty): def make_launcher(constants, signature, ids, warp_size): + start_desc = len(signature) + #signature = generate_cu_signature(constants, signature, ids) + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): - if ty[0] == '*' or ty == "none": + if ty[0] == '*': return "PyObject*" - if ty[0] == '[': - if ty == "[]": - return "[]" - tys = parse_list_string(ty) - val = ','.join(map(_extracted_type, tys)) - return 
f"[{val}]" - return ty_to_cpp(ty) + return { + 'i1': 'int32_t', + 'i8': 'int8_t', + 'i16': 'int16_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u1': 'uint32_t', + 'u8': 'uint8_t', + 'u16': 'uint16_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] def format_of(ty): - if ty == "hipDeviceptr_t": - return "O" - if ty[0] == "[": - if ty == "[]": - return "()" - tys = parse_list_string(ty) - val = ''.join(map(format_of, tys)) - return f"({val})" return { "PyObject*": "O", "float": "f", @@ -223,22 +227,14 @@ def format_of(ty): "uint64_t": "K", }[ty] - signature = {k: v for k, v in signature.items() if v != 'constexpr'} args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiKKOOOO" + args_format - signature = ','.join(signature.values()).replace('[', '').replace(']', '') - signature = list(filter(bool, signature.split(','))) - signature = {i: s for i, s in enumerate(signature)} args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) libhip_path = _get_path_to_hip_runtime_dylib() # generate glue code - params = list(range(len(signature))) - params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"] + params = [f"&arg{i}" for i in signature.keys() if i not in constants] params.append("&global_scratch") src = f""" #define __HIP_PLATFORM_AMD__ @@ -420,8 +416,8 @@ def format_of(ty): // raise exception asap - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])}; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); if(launch_exit_hook != Py_None){{ PyObject* args = Py_BuildValue("(O)", launch_metadata); @@ -472,8 +468,9 @@ class HIPLauncher(object): def __init__(self, src, metadata): ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else dict() - constants = {idx: value for idx, value in constants.items()} - signature = {idx: value for idx, value in src.signature.items()} + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} src = make_launcher(constants, signature, ids, metadata.warp_size) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch diff --git 
a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index ca4c8bcf02..edeab969ab 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -10,7 +10,6 @@ from triton.runtime import _allocation from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver -from triton._utils import parse_list_string dirname = os.path.dirname(os.path.realpath(__file__)) include_dir = [os.path.join(dirname, "include")] @@ -96,7 +95,7 @@ def __init__(self): def ty_to_cpp(ty): - if ty[0] == '*' or ty == "none": + if ty[0] == '*': return "CUdeviceptr" return { "i1": "int32_t", @@ -119,29 +118,19 @@ def ty_to_cpp(ty): def make_launcher(constants, signature, ids): + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): - if ty[0] == '*' or ty == "none": + if ty[0] == '*': return "PyObject*" if ty == "nvTmaDesc": return "PyObject*" - if ty[0] == '[': - if ty == "[]": - return "[]" - tys = parse_list_string(ty) - val = ','.join(map(_extracted_type, tys)) - return f"[{val}]" + return ty_to_cpp(ty) def format_of(ty): - if ty == "CUdeviceptr": - return "O" - if ty[0] == "[": - if ty == "[]": - return "()" - tys = parse_list_string(ty) - val = ''.join(map(format_of, tys)) - return f"({val})" return { "PyObject*": "O", "float": "f", @@ -157,29 +146,22 @@ def format_of(ty): "uint64_t": "K", }[ty] - signature = {k: v for k, v in signature.items() if v != 'constexpr'} args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiKKOOOOO" + args_format - signature = ','.join(signature.values()).replace('[', '').replace(']', '') - signature = list(filter(bool, signature.split(','))) - signature = {i: s for i, s in enumerate(signature)} args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. 
- arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + internal_args_list = [] for i, ty in signature.items(): - if ty[0] == "*" or ty == "none": + if ty[0] == "*": internal_args_list.append(f"ptr_info{i}.dev_ptr") elif ty == "nvTmaDesc": # Note: we have to dereference the pointer internal_args_list.append(f"*tma_ptr{i}") else: internal_args_list.append(f"_arg{i}") - params = range(len(signature)) # generate glue code - params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"] + params = [f"&arg{i}" for i in signature.keys() if i not in constants] params.append("&global_scratch") src = f""" #include \"cuda.h\" @@ -438,7 +420,7 @@ def format_of(ty): }} // raise exception asap - {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])}; + {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])}; Py_BEGIN_ALLOW_THREADS; _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); @@ -489,8 +471,9 @@ class CudaLauncher(object): def __init__(self, src, metadata): ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} constants = src.constants if hasattr(src, "constants") else dict() - constants = {idx: value for idx, value in constants.items()} - signature = {idx: value for idx, value in src.signature.items()} + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} src = make_launcher(constants, signature, ids) mod = compile_module_from_src(src, "__triton_launcher") self.launch = mod.launch
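
For context on the launcher hunks above: both the HIP and CUDA launchers now normalize constant and signature keys from parameter names to positional indices before generating the C glue code, so later lookups (such as skipping constants when building the argument list) operate purely on indices. The following is a minimal standalone sketch of that normalization; the arg_names, constants, and signature values are hypothetical examples for illustration only and are not taken from the patch.

    # Hypothetical kernel parameters; in the launchers these come from
    # src.fn.arg_names, src.constants, and src.signature.
    arg_names = ["x_ptr", "y_ptr", "n_elements", "BLOCK_SIZE"]
    constants = {"BLOCK_SIZE": 128}          # keys may be names or indices
    signature = {"x_ptr": "*fp32", "y_ptr": "*fp32", "n_elements": "i32"}

    def cst_key(key):
        # Map a parameter name to its positional index; indices pass through.
        return arg_names.index(key) if isinstance(key, str) else key

    constants_by_index = {cst_key(k): v for k, v in constants.items()}
    signature_by_index = {cst_key(k): v for k, v in signature.items()}

    assert constants_by_index == {3: 128}
    assert signature_by_index == {0: "*fp32", 1: "*fp32", 2: "i32"}

With indices as keys, the glue-code generator can then skip any argument whose index appears in constants when it emits the &arg{i} parameter list, which is exactly what the updated make_launcher code in both backends does.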