From 3aeb266ce9f9d1c8488f4f914a71fbf849352e96 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Fri, 14 Jun 2024 12:36:12 -0400 Subject: [PATCH 1/5] [ANALYSIS] Fix allocation of empty shapes (#4143) --- include/triton/Dialect/Triton/IR/Utility.h | 6 ++++++ lib/Analysis/Allocation.cpp | 9 +++------ test/Analysis/test-allocation.mlir | 8 ++++++++ .../DecomposeUnsupportedConversions.cpp | 4 ++-- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h index 0ef5971473..267c5617d4 100644 --- a/include/triton/Dialect/Triton/IR/Utility.h +++ b/include/triton/Dialect/Triton/IR/Utility.h @@ -25,6 +25,12 @@ template Int product(llvm::ArrayRef arr) { template auto product(const VecT &vec) { return product(llvm::ArrayRef(vec)); } +template Int getNumElements(ArrayRef shape) { + if (shape.empty()) { + return 0; + } + return product(shape); +} // TODO(jlebar): Rename to ceilOfRatio. template Int ceil(Int m, Int n) { return (m + n - 1) / n; } diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 8c1ce494a8..95e53b12d3 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -270,8 +270,7 @@ class AllocationAnalysis { unsigned inVec = 0; unsigned outVec = 0; auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, - std::multiplies{}); + auto elems = getNumElements(smemShape); auto bytes = isa(srcTy.getElementType()) ? elems * kPtrBitWidth / 8 @@ -286,8 +285,7 @@ class AllocationAnalysis { // nothing to do } else { auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, - std::multiplies{}); + auto elems = getNumElements(smemShape); auto elemTy = cast(value.getType()).getPointeeType(); auto bytes = @@ -305,8 +303,7 @@ class AllocationAnalysis { // nothing to do } else { auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp); - unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1, - std::multiplies{}); + auto elems = getNumElements(smemShape); auto elemTy = cast(value.getType()).getPointeeType(); auto bytes = isa(elemTy) diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 4ea1fd5e4b..70b7cb907f 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -12,6 +12,14 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { +// CHECK-LABEL: empty +tt.func @empty(%A : !tt.ptr) { + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> + tt.return + // CHECK: size = 0 +} + // CHECK-LABEL: matmul_loop tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp index 68ebe9499d..471dfa8b19 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -5,6 +5,7 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/Patterns.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include @@ -61,8 +62,7 @@ static int getCvtOpLDSUsage(triton::gpu::ConvertLayoutOp &cvtOp) { unsigned inVec = 0; unsigned outVec = 0; auto smemShape = triton::getScratchConfigForCvtLayout(cvtOp, inVec, outVec); - unsigned elems = - std::accumulate(smemShape.begin(), smemShape.end(), 1, std::multiplies{}); + unsigned elems = getNumElements(smemShape); auto srcType = cvtOp.getSrc().getType(); auto bytes = isa(srcType.getElementType()) From 6d936965f5b5ed2e14e89fc80e60e737c3170f5e Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 14 Jun 2024 11:43:11 -0700 Subject: [PATCH 2/5] s/ConversionPatternRewriter/RewriterBase/ in utils (NFC) (#4140) You can't create a ConversionPatternRewriter in C++ unit tests, so this means that anything which is touched from a C++ unit test cannot transitively touch anything which uses a ConversionPatternRewriter. This is a simple search+replace, there's no functional difference between these types for our purposes. --- .../TritonGPUToLLVM/TargetInfoBase.h | 36 +++++------ .../Conversion/TritonGPUToLLVM/Utility.h | 56 +++++++++--------- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 37 ++++++------ .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 49 ++++++++------- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 49 +++++++-------- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 29 ++++----- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 26 ++++---- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 59 +++++++++---------- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 47 +++++++-------- .../lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 25 ++++---- .../lib/TritonNVIDIAGPUToLLVM/Utility.h | 22 +++---- 11 files changed, 205 insertions(+), 230 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index f977d30c02..61f8fb87b4 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -10,34 +10,34 @@ class TargetInfoBase { virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; - virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const = 0; + virtual Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const = 0; - virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const = 0; - virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + virtual void storeShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val, Value pred) const = 0; + virtual Value loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const = 0; - virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const = 0; + virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const = 0; - virtual Value programId(ConversionPatternRewriter &rewriter, Location loc, + virtual Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const = 0; - virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + virtual bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const = 0; virtual bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + RewriterBase &rewriter, Location loc, Value smemBase, SmallVector &vals, RankedTensorType srcTy, Type elemTy, ArrayRef paddedRepShape, ArrayRef origRepShape, ArrayRef outOrd, unsigned accumNumReplicates, @@ -48,11 +48,11 @@ class TargetInfoBase { // format from the device. |formatStrStart| is the pointer to the start of // the format string global variable; |args| are the arguments to fill // placeholders in the format string. - virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + virtual void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const = 0; // Emits LLVM code with |rewriter| to perform assertion failure with the given // |message| from the given |func| in |file|. - virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc, + virtual void assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const = 0; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 6ad63e3796..54f79ab0a2 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -202,9 +202,9 @@ T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, namespace gpu { Type getFunctionType(Type resultType, ValueRange operands); -LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, - Operation *op, StringRef funcName, - Type funcType, StringRef libname = "", +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname = "", StringRef libpath = ""); } // namespace gpu @@ -305,7 +305,7 @@ struct SharedMemoryObject { } Value getBaseBeforeSlice(int order, Location loc, - ConversionPatternRewriter &rewriter) const { + RewriterBase &rewriter) const { Value cSwizzleOffset = getCSwizzleOffset(order); Value offset = sub(i32_val(0), cSwizzleOffset); Type type = base.getType(); @@ -313,9 +313,10 @@ struct SharedMemoryObject { } }; -SharedMemoryObject -getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, - ConversionPatternRewriter &rewriter); +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter); // Convert an \param index to a multi-dim coordinate given \param shape and // \param order. @@ -329,15 +330,14 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, SmallVector delinearize(RewriterBase &rewriter, Location loc, Value linear, ArrayRef shape); -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape, - ArrayRef order); +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape); +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape); -Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, - StringRef key, StringRef content); +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content); // Given an elemId which represents the index of an element from the list of // elements that are in the thread's registers (i.e. total of @@ -346,7 +346,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, // when converting distributed to distributed layout. Also, a replica is the // smallest CTA tile that is common between input and output layouts. SmallVector getMultiDimOffset(Attribute layout, Location loc, - ConversionPatternRewriter &rewriter, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, @@ -355,15 +355,15 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, // Given a multiDimOffset, this function wraps around each dimension to be // within shape. SmallVector getWrappedMultiDimOffset( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDimOffset, ArrayRef shape, - SmallVector shapePerCTATile, SmallVector shapePerCTA); + RewriterBase &rewriter, Location loc, ArrayRef multiDimOffset, + ArrayRef shape, SmallVector shapePerCTATile, + SmallVector shapePerCTA); inline bool isKernel(FunctionOpInterface funcOp) { return funcOp.getVisibility() == SymbolTable::Visibility::Public; } -inline Value getStackPointer(PatternRewriter &rewriter, +inline Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { auto mod = funcOp->getParentOfType(); LLVM::GlobalOp globalBase = nullptr; @@ -378,8 +378,7 @@ inline Value getStackPointer(PatternRewriter &rewriter, return funcOp.getArgument(funcOp.getNumArguments() - 1); } -inline Value getSharedMemoryBase(Location loc, - ConversionPatternRewriter &rewriter, +inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Operation *op) { auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); FunctionOpInterface func = @@ -1566,9 +1565,9 @@ inline void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, } } -inline Value -getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, - ConversionPatternRewriter &rewriter) { +inline Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter) { auto elems = smemObj.getElems(); auto types = smemObj.getTypes(); auto structTy = @@ -1582,9 +1581,8 @@ getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, return llvmStruct; } -inline SmallVector -unpackLLElements(Location loc, Value llvmStruct, - ConversionPatternRewriter &rewriter) { +inline SmallVector unpackLLElements(Location loc, Value llvmStruct, + RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); if (llvmStruct.getType().isIntOrIndexOrFloat() || isa(llvmStruct.getType()) || @@ -1602,8 +1600,8 @@ unpackLLElements(Location loc, Value llvmStruct, inline Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter, - ValueRange resultVals, - ConversionPatternRewriter &rewriter, Type type) { + ValueRange resultVals, RewriterBase &rewriter, + Type type) { auto structType = dyn_cast(typeConverter->convertType(type)); if (!structType) { diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 7811a2ef56..eaaa690c0b 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -12,7 +12,7 @@ using CoordTy = SmallVector; using ValueTable = std::map, std::pair>; static SmallVector -getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter, +getMNCoords(Value thread, Location loc, RewriterBase &rewriter, ArrayRef wpt, const NvidiaMmaEncodingAttr &mmaLayout, ArrayRef shape, bool isARow, bool isBRow, bool isAVec4, bool isBVec4) { @@ -120,9 +120,8 @@ Type getFunctionType(Type resultType, ValueRange operands) { return LLVM::LLVMFunctionType::get(resultType, operandTypes); } -LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, - Operation *op, StringRef funcName, - Type funcType, +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, StringRef libname /*= ""*/, StringRef libpath /*= ""*/) { using LLVM::LLVMFuncOp; @@ -496,9 +495,10 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, builder.getIntegerAttr(ty, value)); } -SharedMemoryObject -getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, - ConversionPatternRewriter &rewriter) { +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter) { ArrayRef types = cast(llvmStruct.getType()).getBody(); SmallVector elems(types.size()); @@ -580,15 +580,14 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, return multiDim; } -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape, - ArrayRef order) { +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { return linearize(rewriter, loc, applyPermutation(multiDim, order), applyPermutation(shape, order)); } -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape) { +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape) { auto rank = multiDim.size(); Value linear = i32_val(0); if (rank > 0) { @@ -602,8 +601,8 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc, return linear; } -Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, - StringRef key, StringRef content) { +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto ctx = moduleOp.getContext(); unsigned stringNumber = 0; @@ -619,7 +618,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, LLVM::GlobalOp global; { - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); global = rewriter.create( UnknownLoc::get(ctx), globalType, @@ -637,7 +636,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, } SmallVector getMultiDimOffset(Attribute layout, Location loc, - ConversionPatternRewriter &rewriter, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, @@ -791,9 +790,9 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, } SmallVector getWrappedMultiDimOffset( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDimOffset, ArrayRef shape, - SmallVector shapePerCTATile, SmallVector shapePerCTA) { + RewriterBase &rewriter, Location loc, ArrayRef multiDimOffset, + ArrayRef shape, SmallVector shapePerCTATile, + SmallVector shapePerCTA) { unsigned rank = shape.size(); SmallVector multiDimOffsetWrapped(rank); for (unsigned d = 0; d < rank; ++d) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 55e3609da1..527b89d305 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -10,12 +10,11 @@ namespace mlir::triton::AMD { namespace { template LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, - ConversionPatternRewriter &rewriter, - StringRef name, + RewriterBase &rewriter, StringRef name, LLVM::LLVMFunctionType type) { LLVM::LLVMFuncOp ret; if (!(ret = moduleOp.template lookupSymbol(name))) { - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); ret = rewriter.create(loc, name, type, LLVM::Linkage::External); @@ -24,7 +23,7 @@ LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, } // Extend all values to 64-bit per printf call requirements. -Value printfPromoteValue(ConversionPatternRewriter &rewriter, Value value) { +Value printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto loc = UnknownLoc::get(context); auto type = value.getType(); @@ -68,8 +67,8 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { return rewriter.create(loc, 0, 32); } -Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const { +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); SmallVector operands = {cmp}; Value asmResult = @@ -78,12 +77,12 @@ Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, return asmResult; } -void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const { +void TargetInfo::storeShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val, Value pred) const { mlir::LLVM::AMD::llStore(rewriter, loc, ptr, val, pred); } -Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const { Value falseVal = rewriter.create( @@ -91,32 +90,32 @@ Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, return mlir::LLVM::AMD::llLoad(rewriter, loc, ptr, elemTy, pred, falseVal); } -Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleXor(loc, rewriter, val, i); } -Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleUp(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const { return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } -bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { @@ -124,7 +123,7 @@ bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, } bool TargetInfo::processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + RewriterBase &rewriter, Location loc, Value smemBase, SmallVector &vals, RankedTensorType srcTy, Type elemTy, ArrayRef paddedRepShape, ArrayRef origRepShape, ArrayRef outOrd, unsigned accumNumReplicates, @@ -133,8 +132,7 @@ bool TargetInfo::processReplicaUsingStMatrix( } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, - ValueRange args, - ConversionPatternRewriter &rewriter, + ValueRange args, RewriterBase &rewriter, bool useStdErr) const { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto *ctx = rewriter.getContext(); @@ -205,14 +203,13 @@ std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { return funcName; } -void TargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int formatStrByteCount, - ValueRange args) const { +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const { return printfImpl(formatStrStart, formatStrByteCount, args, rewriter, /*useStdError=*/false); } -void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { // Compose and print an assert message. diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 4e86beb3ca..2e7c604cf1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -18,51 +18,52 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; - Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) const override; - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const override; + Value loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const override; - Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const override; - Value programId(ConversionPatternRewriter &rewriter, Location loc, - ModuleOp moduleOp, int axis) const override; + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; - bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, - SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce, unsigned interleave) const override; + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; - bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const override; + bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc, + Value smemBase, SmallVector &vals, + RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, + unsigned accumNumReplicates, + int swizzleByteWidth) const override; std::string getMulhiFuncName(Type resultElementTy) const override; - void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; - void assertFail(ConversionPatternRewriter &rewriter, Location loc, - StringRef message, StringRef file, StringRef func, - int line) const override; + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; bool enableLinearLayout() const override { return false; } private: void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, - ConversionPatternRewriter &rewriter, bool useStdErr) const; + RewriterBase &rewriter, bool useStdErr) const; std::string arch; }; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 111045d134..f4325bbcf9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -38,9 +38,8 @@ std::string mangleFunc(std::string name, Type type) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, - Value val, Value i, int strideInt, ShflKind mode, - Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -126,30 +125,26 @@ static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, return Value(); } -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); } -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis) { +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { assert(axis >= 0); assert(axis < 3); assert(moduleOp); @@ -160,8 +155,8 @@ Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, return rewriter.create(loc, i32_ty, blockId); } -Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Type elemTy, Value pred, Value falseVal) { +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal) { Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); auto parent = ptr.getParentRegion()->getParentOfType(); auto funcName = mangleFunc(mlir::LLVM::AMD::Predicated_Load, funcType); @@ -174,8 +169,8 @@ Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, return loadVal; } -void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) { +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) { auto ctx = ptr.getContext(); Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred})); auto parent = ptr.getParentRegion()->getParentOfType(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index c60d53f4b4..b8aa25475f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -13,26 +13,22 @@ namespace mlir::LLVM::AMD { const char Predicated_Load[] = "__predicated_load"; const char Predicated_Store[] = "__predicated_store"; -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i); - -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); // Loads from shared or global memory with predication. // `otherElems` is used to mask out the elements that are not loaded -Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Type elemTy, Value pred, Value falseVal); +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal); // Stores to shared or global memory with predication. -void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred); +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred); } // namespace mlir::LLVM::AMD #endif diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 2afa2383c8..902951cc94 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -14,8 +14,7 @@ using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; namespace { Value computeStMatrixAddr(Value laneId, int matStride, Location loc, - ConversionPatternRewriter &rewriter, - int swizzleByteWidth) { + RewriterBase &rewriter, int swizzleByteWidth) { Value rowInMat = urem(laneId, i32_val(8)); // row in the 8x8 matrix // linear index of the matrix in the 2x2 matrices // Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in @@ -34,7 +33,7 @@ Value computeStMatrixAddr(Value laneId, int matStride, Location loc, void stMatrixm8n8x4(Value offset, ArrayRef vals, int indexOffset, Value smemBase, Type elemTy, Location loc, - ConversionPatternRewriter &rewriter) { + RewriterBase &rewriter) { SmallVector inputs; auto prTy = ptr_ty(rewriter.getContext(), 3); // Pack the input into 2xf16 @@ -53,8 +52,8 @@ void stMatrixm8n8x4(Value offset, ArrayRef vals, int indexOffset, void storeDistributedToSharedWithStMatrix( RankedTensorType tensorTy, Type elemTy, SmallVector &inVals, Value smemBase, ArrayRef paddedRepShape, - ArrayRef origRepShape, Location loc, - ConversionPatternRewriter &rewriter, int swizzlingByteWidth) { + ArrayRef origRepShape, Location loc, RewriterBase &rewriter, + int swizzlingByteWidth) { auto shapePerCTA = getShapePerCTA(tensorTy); auto mmaLayout = mlir::cast(tensorTy.getEncoding()); auto order = triton::gpu::getOrder(mmaLayout); @@ -140,7 +139,7 @@ bool isStMatrixCompatible(RankedTensorType tensorTy, int swizzlingByteWidth) { } // declare vprintf(i8*, i8*) as external function -LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { +LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); StringRef funcName("vprintf"); Operation *funcOp = moduleOp.lookupSymbol(funcName); @@ -152,7 +151,7 @@ LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { SmallVector argsType{ptr_ty(context), ptr_ty(context)}; auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); return rewriter.create(UnknownLoc::get(context), funcName, @@ -161,8 +160,7 @@ LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { // extend integer to int32, extend float to float64 // this comes from vprintf alignment requirements. -std::pair printfPromoteValue(ConversionPatternRewriter &rewriter, - Value value) { +std::pair printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto type = value.getType(); Value newOp = value; @@ -186,7 +184,7 @@ std::pair printfPromoteValue(ConversionPatternRewriter &rewriter, return {newType, newOp}; } -LLVM::LLVMFuncOp getAssertfailDeclaration(ConversionPatternRewriter &rewriter) { +LLVM::LLVMFuncOp getAssertfailDeclaration(RewriterBase &rewriter) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); StringRef funcName("__assertfail"); { @@ -200,7 +198,7 @@ LLVM::LLVMFuncOp getAssertfailDeclaration(ConversionPatternRewriter &rewriter) { SmallVector argsType{ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx), rewriter.getIntegerType(sizeof(size_t) * 8)}; auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); auto funcOp = rewriter.create(UnknownLoc::get(ctx), funcName, funcType); @@ -260,13 +258,13 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { rewriter.getI32Type()); } -Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const { +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { Value threadMask = int_val(type.getIntOrFloatBitWidth(), -1); return rewriter.create(loc, type, threadMask, cmp); } -void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const { +void TargetInfo::storeShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val, Value pred) const { MLIRContext *ctx = rewriter.getContext(); unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth()); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); @@ -279,7 +277,7 @@ void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, builder.launch(rewriter, loc, void_ty(ctx)); } -Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const { MLIRContext *ctx = rewriter.getContext(); @@ -297,31 +295,31 @@ Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, return builder.launch(rewriter, loc, elemTy); } -Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleXor(loc, rewriter, val, i); } -Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleUp(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const { return LLVM::NVIDIA::llGetPid(loc, rewriter, moduleOp, axis); } -bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { @@ -362,7 +360,7 @@ bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, return false; } bool TargetInfo::processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + RewriterBase &rewriter, Location loc, Value smemBase, SmallVector &vals, RankedTensorType srcTy, Type elemTy, ArrayRef paddedRepShape, ArrayRef origRepShape, ArrayRef outOrd, unsigned accumNumReplicates, @@ -383,9 +381,8 @@ std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { return funcName; } -void TargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int /*formatStrByteCount*/, - ValueRange args) const { +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args) const { auto *ctx = rewriter.getContext(); Type ptr = ptr_ty(ctx); auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); @@ -426,7 +423,7 @@ void TargetInfo::printf(ConversionPatternRewriter &rewriter, call(funcOp, operands); } -void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { auto funcOp = getAssertfailDeclaration(rewriter); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 9b59993e6a..8572feeb6f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -13,45 +13,46 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; - Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) const override; - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const override; + Value loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const override; - Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const override; - Value programId(ConversionPatternRewriter &rewriter, Location loc, - ModuleOp moduleOp, int axis) const override; + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; - bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, - SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce, unsigned interleave) const override; + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; - bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const override; + bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc, + Value smemBase, SmallVector &vals, + RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, + unsigned accumNumReplicates, + int swizzleByteWidth) const override; std::string getMulhiFuncName(Type resultElementTy) const override; - void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; - void assertFail(ConversionPatternRewriter &rewriter, Location loc, - StringRef message, StringRef file, StringRef func, - int line) const override; + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; private: int computeCapability; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 37c5b6ec7c..685b836208 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -8,9 +8,8 @@ namespace LLVM { namespace NVIDIA { using namespace mlir::triton; -static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, - Value val, Value i, NVVM::ShflKind mode, - Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, NVVM::ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); if (bits == 64) { @@ -42,31 +41,27 @@ static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, return result; } -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { return shuffleCommon(loc, rewriter, val, i, NVVM::ShflKind::idx, i32_val(0x1f)); } -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis) { +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { assert(axis >= 0); assert(axis < 3); assert(moduleOp); @@ -92,8 +87,8 @@ Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) { return val; } -Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, - Value b, Value mask) { +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask) { PTXBuilder builder; auto &prmt = builder.create("prmt")->o("b32"); auto *destOpr = builder.newOperand("=r"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h index bb4e9dd336..3d3eeb1aff 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h @@ -40,19 +40,15 @@ namespace LLVM { namespace NVIDIA { Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i); -Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, - Value b, Value mask); - -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); /// Usage of macro load_dsmem /// (1) load_dsmem(addr, ctaId) From 6ee351ecce9f295ed29443724961697b2c1ad442 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 14 Jun 2024 15:27:56 -0700 Subject: [PATCH 3/5] Refactor shared load/store utilities. (#4141) Refactor shared load/store utilities. (This commit message is written about loads, but everything also applies to stores.) Previous to this PR we had two ways of loading from shared memory within the same CTA. 1. LLVM::LoadOp. This supports vector loads, but not predication. 2. TargetInfo::loadShared. This supported predication, but not vector loads. Loads from shared memory in different CTAs were accessible only through an nvidia-specific header. These did not support predication, and although they supported vector loads, it worked slightly differently than LLVM::LoadOp (namely, you have to know you're loading a vector and unwrap the type before passing to the function). This PR reworks all this. Now 1. TargetInfo::loadShared and TargetInfo::loadDShared have the same API. 2. They both support predication and vectors, and the vectors work like LLVM::LoadOp. 3. They share code; they both just emit PTX. 4. Because we're emitting PTX directly from loadDShared, we can delete the NVIDIA::LoadDSmem op. In general I think a logical operation should have either A. A function createFoo() that emits one or more MLIR operations, or B. An MLIR op FooOp that lowers to one or more MLIR operations. But for distributed shmem loads, we had both (A) and (B). This was a redundant layer of indirection. This is used in a future LLs patch. --- .../TritonGPUToLLVM/TargetInfoBase.h | 28 ++- .../Conversion/TritonGPUToLLVM/Utility.h | 21 +-- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 4 +- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 18 ++ .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 19 ++- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 11 +- .../include/Dialect/NVGPU/IR/NVGPUOps.td | 24 --- .../nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp | 53 ------ .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 116 ------------- .../ConvertLayoutOpToLLVM.cpp | 7 +- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 161 ++++++++++++++++-- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 11 +- .../lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 86 ---------- .../lib/TritonNVIDIAGPUToLLVM/Utility.h | 28 --- 14 files changed, 226 insertions(+), 361 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index 61f8fb87b4..f4d904eae0 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -13,11 +13,29 @@ class TargetInfoBase { virtual Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const = 0; - virtual void storeShared(RewriterBase &rewriter, Location loc, Value ptr, - Value val, Value pred) const = 0; - virtual Value loadShared(RewriterBase &rewriter, Location loc, - const TypeConverter *converter, Value ptr, - Type elemTy, Value pred) const = 0; + // Store/load a value from shared memory, either in the same CTA or, if + // `ctaId` is non-nullopt, in another CTA in the same group. + // + // A target that does not support cross-CTA transfers will assert if ctaId is + // non-nullopt. + // + // Assumes the address is aligned to the width of `val`. + virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const = 0; + virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const = 0; + + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const { + storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, val, pred); + } + Value loadShared(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred) const { + return loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, elemTy, + pred); + } virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const = 0; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 54f79ab0a2..15e19cc7fc 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -124,6 +124,9 @@ using namespace mlir::triton; #define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) // Constants +#define i1_val(val) LLVM::createConstantI1(loc, rewriter, val) +#define true_val() i1_val(true) +#define false_val() i1_val(false) #define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__) #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) @@ -213,31 +216,21 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, namespace LLVM { using namespace mlir::triton; +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v); Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v); - -/// Create a 64-bit integer constant. Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v); - -/// Create a 16-bit float constant. Value createConstantF16(Location loc, OpBuilder &rewriter, float v); - -/// Create a 32-bit float constant. Value createConstantF32(Location loc, OpBuilder &rewriter, float v); - -/// Create a 64-bit float constant. Value createConstantF64(Location loc, OpBuilder &rewriter, double v); - -/// Create NaN constant of specified type. Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type); - -/// Create an index type constant. Value createIndexConstant(OpBuilder &builder, Location loc, const TypeConverter *converter, int64_t value); - -/// Create an integer constant of \param width bits. Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, int64_t value); +// Is v an integer or floating-point scalar constant equal to 0? +bool isConstantZero(Value v); + /// Helper function to get strides from a given shape and its order SmallVector getStridesFromShapeAndOrder(ArrayRef shape, ArrayRef order, diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 66f7b90e0d..47ff90c7ed 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -341,8 +341,8 @@ struct ReduceOpConversion auto elemTy = getElementType(op, i); Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy, smemBases[i], readOffset); - acc[i] = targetInfo.loadShared(rewriter, loc, getTypeConverter(), - readPtr, elemTy, threadIsNeeded); + acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, + threadIsNeeded); } warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); // only the first thread in each sizeInterWarps is writing diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index eaaa690c0b..a8f525a466 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -441,6 +441,12 @@ using namespace mlir::triton; using mlir::triton::gpu::getOrder; using mlir::triton::gpu::getSizePerThread; +Value createConstantI1(Location loc, OpBuilder &rewriter, bool v) { + auto i1ty = rewriter.getIntegerType(1); + return rewriter.create(loc, i1ty, + IntegerAttr::get(i1ty, v)); +} + Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) { auto i32ty = rewriter.getIntegerType(32); return rewriter.create(loc, i32ty, @@ -495,6 +501,18 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, builder.getIntegerAttr(ty, value)); } +bool isConstantZero(Value v) { + if (auto constantOp = v.getDefiningOp()) { + if (auto attr = dyn_cast(constantOp.getValue())) { + return attr.getValue().isZero(); + } + if (auto attr = dyn_cast(constantOp.getValue())) { + return attr.getValue().isZero(); + } + } + return false; +} + SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 527b89d305..c2552bd3ff 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -77,14 +77,23 @@ Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, return asmResult; } -void TargetInfo::storeShared(RewriterBase &rewriter, Location loc, Value ptr, - Value val, Value pred) const { +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { + if (ctaId.has_value()) { + llvm::report_fatal_error( + "AMDGPU does not support cross-CTA shared memory transfers"); + } mlir::LLVM::AMD::llStore(rewriter, loc, ptr, val, pred); } -Value TargetInfo::loadShared(RewriterBase &rewriter, Location loc, - const TypeConverter *converter, Value ptr, - Type elemTy, Value pred) const { +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const { + if (ctaId.has_value()) { + llvm::report_fatal_error( + "AMDGPU does not support cross-CTA shared memory transfers"); + } Value falseVal = rewriter.create( loc, elemTy, rewriter.getZeroAttr(elemTy)); return mlir::LLVM::AMD::llLoad(rewriter, loc, ptr, elemTy, pred, falseVal); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 2e7c604cf1..01c7c9e901 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -21,11 +21,12 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, - Value pred) const override; - Value loadShared(RewriterBase &rewriter, Location loc, - const TypeConverter *converter, Value ptr, Type elemTy, - Value pred) const override; + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const override; Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index ca9d18873e..7affd88406 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -87,30 +87,6 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; } -def NVGPU_LoadDSmemOp : NVGPU_Op<"load_dsmem", [MemoryEffects<[MemRead]>]> { - let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, I32Attr:$bitwidth, I32Attr:$vec); - let builders = [ - OpBuilder<(ins "Type":$resultTy, "Value":$addr, "Value":$ctaId)>, - OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth, "unsigned":$vec)>, - OpBuilder<(ins "Value":$addr, "Value":$ctaId, "unsigned":$bitwidth)> - ]; - let results = (outs LLVM_LoadableType:$result); - let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; -} - -def NVGPU_StoreDSmemOp : NVGPU_Op<"store_dsmem", [MemoryEffects<[MemWrite]>]> { - let arguments = (ins LLVM_AnyPointer:$addr, I32:$ctaId, - Variadic:$values, I1:$pred); - let builders = [ - OpBuilder<(ins "Value":$addr, "Value":$ctaId, "Value":$value, "Value":$pred)>, - ]; - let assemblyFormat = "operands attr-dict `:` type(operands)"; - let extraClassDeclaration = [{ - unsigned getBitwidth(); - unsigned getVec(); - }]; -} - def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { let arguments = (ins BoolAttr:$bCluster); let assemblyFormat = "attr-dict"; diff --git a/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp b/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp index ed87a588f2..f623f50c63 100644 --- a/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp +++ b/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp @@ -32,59 +32,6 @@ using namespace mlir; using namespace mlir::triton::nvgpu; -void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, - Type resultTy, Value addr, Value ctaId) { - unsigned vec, bitwidth; - if (auto structTy = dyn_cast(resultTy)) { - auto types = structTy.getBody(); - assert(types.size() > 0 && "Invalid result type of LoadDSmemOp"); - vec = types.size(); - for (unsigned i = 0; i < vec; ++i) - assert(types[0] == types[i]); - bitwidth = types[0].getIntOrFloatBitWidth(); - } else { - vec = 1; - bitwidth = resultTy.getIntOrFloatBitWidth(); - } - build(builder, state, resultTy, addr, ctaId, bitwidth, vec); -} - -void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, - Value ctaId, unsigned bitwidth, unsigned vec) { - Type resultTy = builder.getIntegerType(bitwidth); - if (vec > 1) { - SmallVector types(vec, resultTy); - resultTy = LLVM::LLVMStructType::getLiteral(builder.getContext(), types); - } - build(builder, state, resultTy, addr, ctaId, bitwidth, vec); -} - -void LoadDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, - Value ctaId, unsigned bitwidth) { - build(builder, state, addr, ctaId, bitwidth, /*vec*/ 1); -} - -void StoreDSmemOp::build(OpBuilder &builder, OperationState &state, Value addr, - Value ctaId, Value value, Value pred) { - SmallVector values = {value}; - build(builder, state, addr, ctaId, values, pred); -} - -unsigned StoreDSmemOp::getBitwidth() { - auto addrTy = getAddr().getType(); - assert(isa(addrTy) && "addr must be a pointer type"); - if (getValues().empty()) - return 0; - auto elemTy = getValues().back().getType(); - return elemTy.getIntOrFloatBitWidth(); -} - -unsigned StoreDSmemOp::getVec() { return getValues().size(); } - -static LogicalResult verify(mlir::triton::nvgpu::WGMMAOp op) { - return success(); -} - void mlir::triton::nvgpu::NVGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index e192165204..3377510620 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -353,122 +353,6 @@ class StoreMatrixOpPattern } }; -class StoreDSmemOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - OperandsAndConstraints getOperandsAndConstraints(ttn::StoreDSmemOp op) const { - OperandsAndConstraints operandsAndTypes; - auto addr = op.getAddr(); - auto ctaId = op.getCtaId(); - auto values = op.getValues(); - auto pred = op.getPred(); - auto bitwidth = op.getBitwidth(); - operandsAndTypes.push_back({addr, "r"}); - operandsAndTypes.push_back({ctaId, "r"}); - operandsAndTypes.push_back({pred, "b"}); - std::string c = bitwidth == 16 ? "h" : (bitwidth == 32 ? "r" : "l"); - for (unsigned i = 0; i < values.size(); i++) { - operandsAndTypes.push_back({values[i], c}); - } - return operandsAndTypes; - } - - std::string getPtxAsm(ttn::StoreDSmemOp op) const { - auto bitwidth = op.getBitwidth(); - auto vec = op.getVec(); - auto values = op.getValues(); - assert( - (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && - "invalid bitwidth"); - assert((vec == 1 || vec == 2 || vec == 4) && vec == values.size() && - "invalid vec size"); - std::string ptxAsm; - if (vec == 1) { - ptxAsm = "{ \n" - ".reg .u32 remoteAddr; \n" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" - ".reg .pred p; \n" - "mov.pred p, $2; \n" - "@p st.shared::cluster.u#bitwidth [remoteAddr], $3; \n" - "}\n"; - } - if (vec == 2) { - ptxAsm = "{ \n" - ".reg .u32 remoteAddr; \n" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" - ".reg .pred p; \n" - "mov.pred p, $2; \n" - "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4}; \n" - "}\n"; - } - if (vec == 4) { - ptxAsm = "{ \n" - ".reg .u32 remoteAddr; \n" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" - ".reg .pred p; \n" - "mov.pred p, $2; \n" - "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4, $5, " - "$6}; \n" - "}\n"; - } - return ptxAsm; - } -}; - -class LoadDSmemOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::vector getOutputConstraints(ttn::LoadDSmemOp op) const { - auto bitwidth = op.getBitwidth(); - std::string c = bitwidth == 16 ? "=h" : (bitwidth == 32 ? "=r" : "=l"); - auto vec = op.getVec(); - return std::vector(vec, c); - } - OperandsAndConstraints getOperandsAndConstraints(ttn::LoadDSmemOp op) const { - OperandsAndConstraints operandsAndTypes; - auto addr = op.getAddr(); - auto ctaId = op.getCtaId(); - - operandsAndTypes.push_back({addr, "r"}); - operandsAndTypes.push_back({ctaId, "r"}); - return operandsAndTypes; - } - - std::string getPtxAsm(ttn::LoadDSmemOp op) const { - auto addr = op.getAddr(); - auto ctaId = op.getCtaId(); - auto bitwidth = op.getBitwidth(); - auto vec = op.getVec(); - - assert( - (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && - "invalid bitwidth"); - assert((vec == 1 || vec == 2 || vec == 4) && "invalid vec size"); - - std::string o1 = vec > 1 ? ".v.u" : ".u"; - std::string vecStr = vec == 1 ? "$0" - : vec == 2 ? "{$0, $1}" - : "{$0, $1, $2, $3}"; - unsigned argNum = vec == 1 ? 1 : vec == 2 ? 2 : 4; - auto ptxAsm = "{\n" - ".reg .u32 remoteAddr;\n" - "mapa.shared::cluster.u32 remoteAddr, $" + - std::to_string(argNum) + " , $" + std::to_string(argNum + 1) + - " ; \n" - "ld.shared::cluster" + - o1 + std::to_string(bitwidth) + " " + vecStr + - ", [remoteAddr];\n" - "}\n"; - return ptxAsm; - } -}; - class WGMMAWaitGroupOpPattern : public NVGPUOpPatternBase { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 6b09f3ebfe..2a60feb586 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -363,6 +363,7 @@ struct ConvertLayoutOpConversion lowerDistToDistWithDistSmem(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + MLIRContext *ctx = rewriter.getContext(); auto loc = op.getLoc(); auto typeConverter = getTypeConverter(); auto srcTy = op.getSrc().getType(); @@ -427,7 +428,8 @@ struct ConvertLayoutOpConversion Value localOffset = linearize(rewriter, loc, localCoord, smemShape); Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, localOffset); - outVals.push_back(load_dsmem(ptr, remoteCTAId, llvmElemTy)); + outVals.push_back(targetInfo.loadDShared( + rewriter, loc, ptr, remoteCTAId, llvmElemTy, /*pred=*/true_val())); } Value result = @@ -772,6 +774,9 @@ void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { // For now give ConvertLayoutOpConversion higher benefit, I can split before // merging + // + // TODO(jlebar): lowerDistributedToDistributed does not get hit in any + // testcases. Is this dead code? Does the benefit need to be increased? patterns.add(typeConverter, targetInfo, benefit); // Same default benefit patterns.add(typeConverter, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 902951cc94..2e781f3398 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -3,8 +3,10 @@ #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "Utility.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "llvm/Support/MathExtras.h" using namespace mlir; @@ -263,36 +265,161 @@ Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, Value threadMask = int_val(type.getIntOrFloatBitWidth(), -1); return rewriter.create(loc, type, threadMask, cmp); } -void TargetInfo::storeShared(RewriterBase &rewriter, Location loc, Value ptr, - Value val, Value pred) const { + +static Value mapa(RewriterBase &rewriter, Location loc, Value ptr, Value ctaid, + Value pred) { + PTXBuilder builder; + (*builder.create<>("mapa.shared::cluster.u32"))( + builder.newOperand("=r"), // + builder.newAddrOperand(ptr, "r"), builder.newAddrOperand(ctaid, "r")) + .predicate(pred, "b"); + return builder.launch(rewriter, loc, i32_ty, /*hasSideEffects=*/false); +} + +static std::string getConstraintForBitwidth(unsigned bitwidth) { + switch (bitwidth) { + case 8: + case 16: + return "h"; + case 32: + return "r"; + case 64: + return "l"; + default: + llvm_unreachable("unsupported bitwidth"); + } +} + +void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const { MLIRContext *ctx = rewriter.getContext(); - unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth()); - const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); + auto ptrTy = cast(ptr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + // Simpliy the special case of a single-element vector. + if (auto vecTy = dyn_cast(val.getType())) { + if (vecTy.getNumElements() == 1) { + val = extract_element(val, i32_val(0)); + } + } + + auto vecTy = dyn_cast(val.getType()); + unsigned vec; + unsigned bitwidth; + if (vecTy) { + vec = vecTy.getNumElements(); + bitwidth = vecTy.getElementType().getIntOrFloatBitWidth(); + assert(bitwidth >= 8 && "can't load/store vectors with sub-byte elems"); + } else { + vec = 1; + bitwidth = std::max(8u, val.getType().getIntOrFloatBitWidth()); + } + assert(llvm::isPowerOf2_32(vec)); + + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, split it into multiple ops. + if (vec > 4) { + // TODO(jlebar): Implement this once we can write a testcase. + assert(false && "not yet implemented"); + } + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } PTXBuilder builder; + auto st = builder.create<>("st") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .b(bitwidth) + .v(vec, /*predicate=*/vec > 1); + + PTXBuilder::Operand *valOpr; auto *ptrOpr = builder.newAddrOperand(ptr, "r"); - auto *valOpr = builder.newOperand(val, c); - auto &st = builder.create<>("st")->shared().b(bits); + + std::string elemConstraint = getConstraintForBitwidth(bitwidth); + if (vecTy) { + SmallVector vecVals; + for (int i = 0; i < vec; i++) { + vecVals.push_back(extract_element(val, i32_val(i))); + } + valOpr = builder.newListOperand(vec, elemConstraint); + } else { + valOpr = builder.newOperand(val, elemConstraint); + } st(ptrOpr, valOpr).predicate(pred, "b"); builder.launch(rewriter, loc, void_ty(ctx)); } -Value TargetInfo::loadShared(RewriterBase &rewriter, Location loc, - const TypeConverter *converter, Value ptr, - Type elemTy, Value pred) const { +Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type loadTy, + Value pred) const { MLIRContext *ctx = rewriter.getContext(); auto ptrTy = cast(ptr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); - unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + + auto vecTy = dyn_cast(loadTy); + unsigned vec; + unsigned bitwidth; + if (vecTy) { + vec = vecTy.getNumElements(); + bitwidth = vecTy.getElementType().getIntOrFloatBitWidth(); + assert(bitwidth >= 8 && "can't load/store vectors with sub-byte elems"); + } else { + vec = 1; + bitwidth = std::max(8u, loadTy.getIntOrFloatBitWidth()); + } + assert(llvm::isPowerOf2_32(vec)); - const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? "=h" : "=r"); + // load/store ops only support v2 and v4. If the vector width is larger than + // 4, split it into multiple ops. + if (vec > 4) { + // TODO(jlebar): Implement this once we can write a testcase. + assert(false && "not yet implemented"); + } + + // Get pointer to remote shared memory if needed. + if (ctaId.has_value()) { + ptr = mapa(rewriter, loc, ptr, *ctaId, pred); + } PTXBuilder builder; - auto *dOpr = builder.newOperand(c); - auto *ptrOpr = builder.newAddrOperand(ptr, "r"); - auto &ld = builder.create<>("ld")->shared().b(bitwidth); - ld(dOpr, ptrOpr).predicate(pred, "b"); - return builder.launch(rewriter, loc, elemTy); + auto ld = builder.create<>("ld") + ->o("shared::cta", ctaId.has_value()) + .o("shared", !ctaId.has_value()) + .b(bitwidth) + .v(vec, /*predicate=*/vec > 1); + + std::string elemConstraint = "=" + getConstraintForBitwidth(bitwidth); + auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) + : builder.newListOperand(vec, elemConstraint); + ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); + + Type resultTy; + if (vec == 1) { + resultTy = int_ty(bitwidth); + } else { + resultTy = struct_ty(SmallVector(vec, int_ty(bitwidth))); + } + Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); + + if (vecTy) { + // Unpack the struct returned by the inline asm into a vector. + SmallVector vals; + for (int i = 0; i < vec; i++) { + auto elem = extract_val(int_ty(bitwidth), load, i); + vals.push_back(bitcast(elem, vecTy.getElementType())); + } + Value ret = undef(loadTy); + for (int i = 0; i < vec; i++) { + ret = insert_element(ret, i32_val(i), vals[i]); + } + return ret; + } else { + return bitcast(load, loadTy); + } } Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 8572feeb6f..7a7cd72c71 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -16,11 +16,12 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, - Value pred) const override; - Value loadShared(RewriterBase &rewriter, Location loc, - const TypeConverter *converter, Value ptr, Type elemTy, - Value pred) const override; + void storeDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Value val, + Value pred) const override; + Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr, + std::optional ctaId, Type elemTy, + Value pred) const override; Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 685b836208..86421e19b1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -99,92 +99,6 @@ Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, return builder.launch(rewriter, loc, rewriter.getIntegerType(32), false); } -// A wrapper of LoadDSmemOp when vec = 1 -// (1) Get bitwidth from elemTy -// (2) Create LoadDSmemOp -// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy -Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Type elemTy) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value ret = - rewriter.create(loc, addr, ctaId, bitwidth); - return bitcast(ret, elemTy); -} - -// A wrapper of LoadDSmemOp when vec > 1 -// (1) Get bitwidth from elemTy -// (2) Create LoadDSmemOp and extract results from retStruct -// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy -SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, - Value addr, Value ctaId, unsigned vec, - Type elemTy) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value retStruct = rewriter.create( - loc, addr, ctaId, bitwidth, vec); - SmallVector retVals; - for (unsigned i = 0; i < vec; ++i) { - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = extract_val(dataTy, retStruct, i); - retVals.push_back(bitcast(data, elemTy)); - } - return retVals; -} - -// A wrapper of StoreDSmemOp when vec = 1 -// (1) Get bitwidth from elemTy -// (2) Bitcast value from elemTy to dataTy (u16/u32/u64) -// (3) Create StoreDSmemOp -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value, Value pred) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = value.getType().getIntOrFloatBitWidth(); - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = bitcast(value, dataTy); - rewriter.create(loc, addr, ctaId, data, pred); -} - -// A wrapper of StoreDSmemOp when vec = 1 and pred = 1 -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value) { - Value pred = int_val(/*width=*/1, 1); - createStoreDSmem(loc, rewriter, addr, ctaId, value, pred); -} - -// A wrapper of StoreDSmemOp when vec > 1 -// (1) Get bitwidth from elemTy -// (2) Bitcast values from elemTy to dataTy (u16/u32/u64) -// (3) Create StoreDSmemOp -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values, Value pred) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = 0; - if (!values.empty()) { - bitwidth = values.back().getType().getIntOrFloatBitWidth(); - } - auto dataTy = rewriter.getIntegerType(bitwidth); - SmallVector data; - for (unsigned i = 0; i < values.size(); ++i) - data.push_back(bitcast(values[i], dataTy)); - rewriter.create(loc, addr, ctaId, data, pred); -} - -// A wrapper of StoreDSmemOp when vec > 1 and pred = 1 -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values) { - Value pred = int_val(/*width=*/1, 1); - createStoreDSmem(loc, rewriter, addr, ctaId, values, pred); -} - /// Create a predicate with just single active thread. Value createElectPredicate(Location loc, PatternRewriter &rewriter) { PTXBuilder ptxBuilder; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h index 3d3eeb1aff..97c4a48d7e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h @@ -29,11 +29,6 @@ using namespace mlir::triton; ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \ } while (0) -#define load_dsmem(...) \ - ::mlir::LLVM::NVIDIA::createLoadDSmem(loc, rewriter, __VA_ARGS__) -#define store_dsmem(...) \ - ::mlir::LLVM::NVIDIA::createStoreDSmem(loc, rewriter, __VA_ARGS__) - namespace mlir { namespace LLVM { @@ -50,29 +45,6 @@ Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, int axis); -/// Usage of macro load_dsmem -/// (1) load_dsmem(addr, ctaId) -/// (2) load_dsmem(addr, ctaId, vec) -Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Type elemTy); -SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, - Value addr, Value ctaId, unsigned vec, - Type elemTy); - -/// Usage of macro store_dsmem -/// (1) store_dsmem(addr, ctaId, value, pred) -/// (2) store_dsmem(addr, ctaId, value) -/// (3) store_dsmem(addr, ctaId, values, pred) -/// (4) store_dsmem(addr, ctaId, values) -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value, Value pred); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values, Value pred); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values); - /// Create a predicate with just single active thread. Value createElectPredicate(Location loc, PatternRewriter &rewriter); From 83a9b340fea73f4cf25a07b782671dc0067e9881 Mon Sep 17 00:00:00 2001 From: Manman Ren Date: Fri, 14 Jun 2024 20:50:02 -0700 Subject: [PATCH 4/5] thread-locality: Do not segfault for reduce on forOp arguments (#4144) If definingOp for the reduce operands is nullptr, bail out instead of seg fault. --------- Co-authored-by: Manman Ren --- .../Transforms/OptimizeThreadLocality.cpp | 3 +- test/TritonGPU/optimize-locality.mlir | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp index 30211da081..3775b4f7d8 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -121,8 +121,7 @@ class TritonGPUOptimizeThreadLocalityPass if (!(isa(srcEncoding) && rank > 1)) return; for (auto operand : reduce->getOperands()) { - auto def = operand.getDefiningOp(); - if (!isa(def)) + if (!operand.getDefiningOp()) return; } auto elemsPerThread = diff --git a/test/TritonGPU/optimize-locality.mlir b/test/TritonGPU/optimize-locality.mlir index 9504fe20ec..5073f997d4 100644 --- a/test/TritonGPU/optimize-locality.mlir +++ b/test/TritonGPU/optimize-locality.mlir @@ -595,3 +595,31 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : tt.return %1 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> } } + +// ----- +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +#slice = #triton_gpu.slice<{dim = 1, parent = #blocked}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { + tt.func public @reduce_for_arg(%arg: tensor<64x128xf32, #blocked>, %arg1: !tt.ptr) { + %c0_i32 = arith.constant 0 : i32 + %c128_i32 = arith.constant 128 : i32 + %c4096_i32 = arith.constant 4096 : i32 + %cst_1 = arith.constant dense<1.000000e+00> : tensor<64x128xf32, #blocked> + %64:1 = scf.for %arg22 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg29 = %arg) -> (tensor<64x128xf32, #blocked>) : i32 { + %129 = "tt.reduce"(%arg29) <{axis = 1 : i32}> ({ + ^bb0(%arg31: f32, %arg32: f32): + %160 = arith.maxnumf %arg31, %arg32 : f32 + tt.reduce.return %160 : f32 + }) : (tensor<64x128xf32, #blocked>) -> tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %75 = triton_gpu.convert_layout %129 : tensor<64xf32, #slice> -> tensor<64xf32, #blocked1> + %79 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked1> + %80 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr, #blocked1> + %81 = tt.addptr %80, %79 : tensor<64x!tt.ptr, #blocked1>, tensor<64xi32, #blocked1> + tt.store %81, %75 : tensor<64x!tt.ptr, #blocked1> + %141 = arith.addf %arg29, %cst_1 : tensor<64x128xf32, #blocked> + scf.yield %141 : tensor<64x128xf32, #blocked> + } + tt.return + } +} From 4f94c88498767d32420adb7f1b8d45c956a388ac Mon Sep 17 00:00:00 2001 From: HyoungWook Nam Date: Sat, 15 Jun 2024 08:57:08 -0500 Subject: [PATCH 5/5] [TEST] Fixing precision bug in fp8 test_dot (#4131) There is a precision bug in test_core.py::test_dot for chain-dot cases. Inside the kernel, it does fp8xfp8->fp32 dot-product (`z=xy`) first. Then, for chain-dot, it casts the output `z` back to fp8 and do the fp8xfp8->fp32 dot-product again `z=zw`. However, the reference numpy computation *does not* cast the intermediate output `z_ref` to fp8. Therefore, the second dot-product becomes fp32xfp8->fp32, whose result is different from the kernel output. In some fp8 setup (float8e4nv in our case), it sometimes causes a test failure due to this precision issue. I have fixed the reference computation to reduce the precision of the intermediate output. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --- python/test/unit/language/test_core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 20a41c8dd4..a3488c178e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3206,6 +3206,14 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid z_ref = num / denom if epilogue == 'chain-dot': if 'float8' in in_dtype: + # Reduce z_ref's precision to fp8 to match the kernel behavior + if in_dtype == 'float8e4nv': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e4m3fn) + elif in_dtype == 'float8e5': + z_fp8 = torch.tensor(z_ref, dtype=torch.float8_e5m2) + else: + assert "Unsupported float8 dtype" + z_ref = to_numpy(z_fp8.to(torch.float32)) w = to_numpy(convert_fp8_to_fp32(w, device, in_dtype)) z_ref = np.matmul(z_ref, w) # compare