diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index cae596fafe..070e375ac8 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,5 +1,5 @@ #pragma once -#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -27,7 +27,6 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/InitAllPasses.h" -#include "triton/Tools/Sys/GetEnv.hpp" namespace mlir { namespace test { diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 4f54b761da..d1494fd7ee 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -4,7 +4,7 @@ #include "TargetInfoBase.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "triton/Analysis/AxisInfo.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" + using namespace mlir; using namespace mlir::triton; @@ -33,6 +33,7 @@ void populateElementwiseOpToLLVMPatterns( PatternBenefit benefit); void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit); @@ -42,6 +43,7 @@ void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit); void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit); diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index a7dbaee9bf..4e0e41295e 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -8,6 +8,8 @@ class TargetInfoBase { public: virtual bool supportMaximumMinimum() const = 0; + virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; + virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, Value cmp) const = 0; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 3b84850977..0b953f7ac0 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -5,7 +5,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "llvm/Support/ErrorHandling.h" @@ -59,8 +59,6 @@ using namespace mlir::triton; rewriter.create(loc, __VA_ARGS__) #define load(...) rewriter.create(loc, __VA_ARGS__) #define store(...) rewriter.create(loc, __VA_ARGS__) -#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__) -#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__) #define fcmp_ogt(lhs, rhs) \ rewriter.create(loc, rewriter.getI1Type(), \ LLVM::FCmpPredicate::ogt, lhs, rhs) @@ -222,29 +220,6 @@ Value createIndexConstant(OpBuilder &builder, Location loc, Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, int64_t value); -/// Usage of macro load_dsmem -/// (1) load_dsmem(addr, ctaId) -/// (2) load_dsmem(addr, ctaId, vec) -Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Type elemTy); -SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, - Value addr, Value ctaId, unsigned vec, - Type elemTy); - -/// Usage of macro store_dsmem -/// (1) store_dsmem(addr, ctaId, value, pred) -/// (2) store_dsmem(addr, ctaId, value) -/// (3) store_dsmem(addr, ctaId, values, pred) -/// (4) store_dsmem(addr, ctaId, values) -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value, Value pred); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values, Value pred); -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values); - /// Helper function to get strides from a given shape and its order SmallVector getStridesFromShapeAndOrder(ArrayRef shape, ArrayRef order, @@ -354,6 +329,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, // smallest CTA tile that is common between input and output layouts. SmallVector getMultiDimOffset(Attribute layout, Location loc, ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, ArrayRef shapePerCTATile); @@ -416,11 +392,6 @@ inline Value getThreadId(RewriterBase &rewriter, Location loc) { return tid; } -inline Value getClusterCTAId(RewriterBase &rewriter, Location loc) { - return rewriter.create(loc, - rewriter.getI32Type()); -} - // ----------------------------------------------------------------------- // Shared memory utilities // ----------------------------------------------------------------------- @@ -1023,6 +994,7 @@ emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout, inline SmallVector emitCTAOffsetForLayout(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, ArrayRef shape) { unsigned rank = shape.size(); @@ -1033,7 +1005,7 @@ inline SmallVector emitCTAOffsetForLayout(Location loc, triton::gpu::getShapePerCTA(CTASplitNum, shape); // Delinearize clusterCTAId - Value clusterCTAId = getClusterCTAId(rewriter, loc); + Value clusterCTAId = target.getClusterCTAId(rewriter, loc); SmallVector multiDimClusterCTAId = delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); @@ -1051,11 +1023,10 @@ inline SmallVector emitCTAOffsetForLayout(Location loc, return CTAOffset; } -inline SmallVector emitBaseIndexForLayoutImpl(Location loc, - RewriterBase &rewriter, - Attribute layout, - RankedTensorType type, - bool withCTAOffset) { +inline SmallVector +emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, + RankedTensorType type, bool withCTAOffset) { auto shape = type.getShape(); SmallVector baseIndex; @@ -1080,8 +1051,8 @@ inline SmallVector emitBaseIndexForLayoutImpl(Location loc, auto parentShape = sliceLayout.paddedShape(type.getShape()); RankedTensorType parentTy = RankedTensorType::get(parentShape, type.getElementType(), parentLayout); - result = emitBaseIndexForLayoutImpl(loc, rewriter, parentLayout, parentTy, - withCTAOffset); + result = emitBaseIndexForLayoutImpl(loc, rewriter, target, parentLayout, + parentTy, withCTAOffset); result.erase(result.begin() + sliceLayout.getDim()); // CTAOffset has been added in emitBaseIndexForLayout of parentLayout return result; @@ -1089,7 +1060,8 @@ inline SmallVector emitBaseIndexForLayoutImpl(Location loc, llvm_unreachable("unsupported emitBaseIndexForLayout"); } if (withCTAOffset) { - auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape); + auto CTAOffset = + emitCTAOffsetForLayout(loc, rewriter, target, layout, shape); assert(CTAOffset.size() == result.size() && "Rank mismatch"); for (unsigned k = 0; k < result.size(); ++k) { // Individual elements of `result` may be null. In the caller @@ -1104,10 +1076,11 @@ inline SmallVector emitBaseIndexForLayoutImpl(Location loc, } inline SmallVector -emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout, +emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, + const TargetInfoBase &target, Attribute layout, RankedTensorType type, bool withCTAOffset) { - SmallVector idx = - emitBaseIndexForLayoutImpl(loc, rewriter, layout, type, withCTAOffset); + SmallVector idx = emitBaseIndexForLayoutImpl( + loc, rewriter, target, layout, type, withCTAOffset); // Check that any null values were sliced out. for (Value v : idx) { @@ -1151,11 +1124,11 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) { // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. inline SmallVector> -emitIndices(Location loc, RewriterBase &rewriter, Attribute layout, - RankedTensorType type, bool withCTAOffset) { +emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, + Attribute layout, RankedTensorType type, bool withCTAOffset) { // step 1, delinearize threadId to get the base index - auto multiDimBase = - emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset); + auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout, + type, withCTAOffset); // step 2, get offset of each element auto offset = emitOffsetForLayout(layout, type); // step 3, add offset to base, and reorder the sequence @@ -1175,9 +1148,9 @@ emitIndices(Location loc, RewriterBase &rewriter, Attribute layout, /* ---------------- */ /* ---------------- */ inline DenseMap getSwizzledSharedPtrs( - Location loc, unsigned inVec, RankedTensorType srcTy, - triton::gpu::SharedEncodingAttr resSharedLayout, Type resElemTy, - SharedMemoryObject smemObj, RewriterBase &rewriter, + Location loc, const TargetInfoBase &target, unsigned inVec, + RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout, + Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter, SmallVectorImpl &offsetVals, SmallVectorImpl &srcStrides) { // This utility computes the pointers for accessing the provided swizzled // shared memory layout `resSharedLayout`. More specifically, it computes, @@ -1224,7 +1197,8 @@ inline DenseMap getSwizzledSharedPtrs( outVec * maxPhase <= srcShape[outOrder[0]] && "Swizzling would generate out of bounds memory accesses"); // Tensor indices held by the current thread, as LLVM values - auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false); + auto srcIndices = + emitIndices(loc, rewriter, target, srcEncoding, srcTy, false); // Swizzling with leading offsets (e.g. Hopper GMMA) unsigned swizzlingByteWidth = 0; if (resSharedLayout.getHasLeadingOffset()) { @@ -1336,10 +1310,9 @@ inline DenseMap getSwizzledSharedPtrs( return ret; } -inline SmallVector -loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj, - Type elemTy, Location loc, - ConversionPatternRewriter &rewriter) { +inline SmallVector loadSharedToDistributed( + Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc, + ConversionPatternRewriter &rewriter, const TargetInfoBase &target) { auto dstTy = cast(dst.getType()); auto dstShape = dstTy.getShape(); assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed"); @@ -1373,7 +1346,7 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj, SmallVector offsetVals = {smemObj.strides.size(), i32_val(0)}; DenseMap sharedPtrs = - getSwizzledSharedPtrs(loc, outVec, dstTy, srcSharedLayout, elemTy, + getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy, smemObj, rewriter, offsetVals, smemObj.strides); assert(outElems % minVec == 0 && "Unexpected number of elements"); unsigned numVecs = outElems / minVec; @@ -1395,7 +1368,8 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj, inline void storeDistributedToShared(Value src, ArrayRef inVals, ArrayRef dstStrides, Value dst, Value smemBase, Type elemTy, Location loc, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, + const TargetInfoBase &target) { auto srcTy = cast(src.getType()); auto srcShape = srcTy.getShape(); auto rank = srcShape.size(); @@ -1432,8 +1406,8 @@ inline void storeDistributedToShared(Value src, ArrayRef inVals, SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals); DenseMap sharedPtrs = - getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, elemTy, smemObj, - rewriter, offsetVals, srcStrides); + getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy, + smemObj, rewriter, offsetVals, srcStrides); LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = " << minVec << " " << wordTy); for (unsigned i = 0; i < numElems; ++i) { diff --git a/include/triton/Dialect/CMakeLists.txt b/include/triton/Dialect/CMakeLists.txt index d49732a513..49ab196b4c 100644 --- a/include/triton/Dialect/CMakeLists.txt +++ b/include/triton/Dialect/CMakeLists.txt @@ -3,4 +3,3 @@ add_subdirectory(TritonGEN) add_subdirectory(TritonGPU) add_subdirectory(TritonIntelGPU) add_subdirectory(TritonNvidiaGPU) -add_subdirectory(NVGPU) diff --git a/include/triton/Dialect/NVGPU/CMakeLists.txt b/include/triton/Dialect/NVGPU/CMakeLists.txt deleted file mode 100644 index 218c20c881..0000000000 --- a/include/triton/Dialect/NVGPU/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -#add_subdirectory(Transforms) diff --git a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp index a4a90e3c46..a3f55f1e73 100644 --- a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -1,3 +1,4 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 499a182dcb..bff2a2d6b7 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -1,3 +1,5 @@ +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Analysis/Allocation.h" @@ -25,8 +27,11 @@ namespace { struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { public: - using ConvertOpToLLVMPattern< - triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern; + LocalLoadOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } LogicalResult matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor, @@ -93,25 +98,28 @@ struct LocalLoadOpConversion auto srcStrides = getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter); - SmallVector outVals = loadSharedToDistributed( - op.getResult(), op.getSrc(), smemObj, elemTy, loc, rewriter); + SmallVector outVals = + loadSharedToDistributed(op.getResult(), op.getSrc(), smemObj, elemTy, + loc, rewriter, targetInfo); Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); return success(); } + +private: + const TargetInfoBase &targetInfo; }; struct ConvertLayoutOpConversion : public ConvertOpToLLVMPattern { public: - explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, - benefit), - targetInfo(targetInfo) {} + ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } LogicalResult matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, @@ -179,7 +187,7 @@ struct ConvertLayoutOpConversion // of performance issue observed. for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { SmallVector multiDimOffset = - getMultiDimOffset(layout, loc, rewriter, elemId, type, + getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, multiDimCTAInRepId, shapePerCTATile); SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, @@ -315,5 +323,5 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp index 82a08c51a5..acf940b3e8 100644 --- a/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp @@ -185,8 +185,8 @@ struct HistogramOpConversion LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); auto dstType = op.getType(); Attribute dstEncoding = dstType.getEncoding(); - auto indices = - emitIndices(op.getLoc(), rewriter, dstEncoding, dstType, true); + auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding, + dstType, true); SmallVector innerDimIndices; for (int i = 0; i < indices.size(); ++i) innerDimIndices.push_back(indices[i][0]); diff --git a/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp index 46297b6ae8..43120c7913 100644 --- a/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MakeRangeOpToLLVM.cpp @@ -1,4 +1,6 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace { @@ -7,8 +9,11 @@ using namespace mlir; using namespace mlir::triton; struct MakeRangeOpConversion : public ConvertOpToLLVMPattern { - MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} + MakeRangeOpConversion(LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -19,7 +24,7 @@ struct MakeRangeOpConversion auto elemTy = ty.getElementType(); assert(elemTy.isInteger(32)); Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart()); - auto idxs = emitIndices(loc, rewriter, layout, ty, true); + auto idxs = emitIndices(loc, rewriter, targetInfo, layout, ty, true); unsigned elems = idxs.size(); SmallVector retVals(elems); // TODO: slice layout has more elements than expected. @@ -34,12 +39,15 @@ struct MakeRangeOpConversion rewriter.replaceOp(op, result); return success(); } + +private: + const TargetInfoBase &targetInfo; }; } // namespace void mlir::triton::populateMakeRangeOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 126a506205..7f2adf0558 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -1,5 +1,10 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" namespace { @@ -12,7 +17,8 @@ using namespace mlir::triton::gpu; // A/B operands of dots. void lowerDistributedToShared(LocalAllocOp op, LocalAllocOpAdaptor adaptor, const LLVMTypeConverter *typeConverter, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { auto loc = op.getLoc(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); @@ -32,13 +38,16 @@ void lowerDistributedToShared(LocalAllocOp op, LocalAllocOpAdaptor adaptor, LLVM::getStridesFromShapeAndOrder(dstShapePerCTA, outOrd, loc, rewriter); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); storeDistributedToShared(op.getSrc(), inVals, dstStrides, op.getResult(), - smemBase, elemTy, loc, rewriter); + smemBase, elemTy, loc, rewriter, targetInfo); } struct LocalAllocOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern< - triton::gpu::LocalAllocOp>::ConvertOpToLLVMPattern; + LocalAllocOpConversion(const LLVMTypeConverter &converter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(converter, benefit), + targetInfo(targetInfo) {} LogicalResult matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, @@ -65,7 +74,8 @@ struct LocalAllocOpConversion // If there is an initial tensor, store it into the shared memory. if (op.getSrc()) { - lowerDistributedToShared(op, adaptor, typeConverter, rewriter); + lowerDistributedToShared(op, adaptor, typeConverter, rewriter, + targetInfo); } auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); @@ -76,6 +86,9 @@ struct LocalAllocOpConversion rewriter.replaceOp(op, retVal); return success(); } + +private: + const TargetInfoBase &targetInfo; }; struct LocalDeallocOpConversion @@ -94,8 +107,8 @@ struct LocalDeallocOpConversion } // namespace void mlir::triton::populateMemoryOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp index 1a31b40648..a2d047dc3a 100644 --- a/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp @@ -58,8 +58,8 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { SmallVector> indices; if (auto rankedTy = op.getOperand(i).getType().dyn_cast()) { - indices = - emitIndices(loc, rewriter, rankedTy.getEncoding(), rankedTy, true); + indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(), + rankedTy, true); for (int64_t dim : rankedTy.getShape()) { if (dim > 0) { dimWidths.push_back(static_cast(std::ceil(std::log10(dim)))); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 331cf70998..a0ecc41abb 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -147,8 +147,8 @@ struct ReduceOpConversion emitOffsetForLayout(helper.getSrcLayout(), operandType); unsigned srcElems = getTotalElemsPerThread(operandType); auto *combineOp = &op.getCombineOp(); - auto srcIndices = emitIndices(op.getLoc(), rewriter, helper.getSrcLayout(), - operandType, true); + auto srcIndices = emitIndices(op.getLoc(), rewriter, targetInfo, + helper.getSrcLayout(), operandType, true); // reduce within threads for (unsigned i = 0; i < srcElems; ++i) { SmallVector key = offset[i]; @@ -377,8 +377,8 @@ struct ReduceOpConversion // nd-tensor where n >= 1 auto resultLayout = resultTy.getEncoding().cast(); unsigned resultElems = getTotalElemsPerThread(resultTy); - auto resultIndices = - emitIndices(loc, rewriter, resultLayout, resultTy, true); + auto resultIndices = emitIndices(loc, rewriter, targetInfo, + resultLayout, resultTy, true); auto resultShape = resultTy.getShape(); auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); assert(resultIndices.size() == resultElems); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h index c873538474..3130001cc5 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -11,7 +11,6 @@ #include "mlir/IR/TypeUtilities.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include #include diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 8b6b5f5270..98b2e06592 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -1,7 +1,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" namespace SharedToDotOperandMMAv1 { using CoordTy = SmallVector; @@ -168,92 +168,6 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, builder.getIntegerAttr(ty, value)); } -// A wrapper of LoadDSmemOp when vec = 1 -// (1) Get bitwidth from elemTy -// (2) Create LoadDSmemOp -// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy -Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Type elemTy) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value ret = - rewriter.create(loc, addr, ctaId, bitwidth); - return bitcast(ret, elemTy); -} - -// A wrapper of LoadDSmemOp when vec > 1 -// (1) Get bitwidth from elemTy -// (2) Create LoadDSmemOp and extract results from retStruct -// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy -SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, - Value addr, Value ctaId, unsigned vec, - Type elemTy) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); - Value retStruct = rewriter.create( - loc, addr, ctaId, bitwidth, vec); - SmallVector retVals; - for (unsigned i = 0; i < vec; ++i) { - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = extract_val(dataTy, retStruct, i); - retVals.push_back(bitcast(data, elemTy)); - } - return retVals; -} - -// A wrapper of StoreDSmemOp when vec = 1 -// (1) Get bitwidth from elemTy -// (2) Bitcast value from elemTy to dataTy (u16/u32/u64) -// (3) Create StoreDSmemOp -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value, Value pred) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = value.getType().getIntOrFloatBitWidth(); - auto dataTy = rewriter.getIntegerType(bitwidth); - Value data = bitcast(value, dataTy); - rewriter.create(loc, addr, ctaId, data, pred); -} - -// A wrapper of StoreDSmemOp when vec = 1 and pred = 1 -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, Value value) { - Value pred = int_val(/*width=*/1, 1); - createStoreDSmem(loc, rewriter, addr, ctaId, value, pred); -} - -// A wrapper of StoreDSmemOp when vec > 1 -// (1) Get bitwidth from elemTy -// (2) Bitcast values from elemTy to dataTy (u16/u32/u64) -// (3) Create StoreDSmemOp -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values, Value pred) { - assert(isa(addr.getType()) && "addr must be a pointer type"); - auto ptrTy = cast(addr.getType()); - assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); - unsigned bitwidth = 0; - if (!values.empty()) { - bitwidth = values.back().getType().getIntOrFloatBitWidth(); - } - auto dataTy = rewriter.getIntegerType(bitwidth); - SmallVector data; - for (unsigned i = 0; i < values.size(); ++i) - data.push_back(bitcast(values[i], dataTy)); - rewriter.create(loc, addr, ctaId, data, pred); -} - -// A wrapper of StoreDSmemOp when vec > 1 and pred = 1 -void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, - Value ctaId, ArrayRef values) { - Value pred = int_val(/*width=*/1, 1); - createStoreDSmem(loc, rewriter, addr, ctaId, values, pred); -} - SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, ConversionPatternRewriter &rewriter) { @@ -395,14 +309,15 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, SmallVector getMultiDimOffset(Attribute layout, Location loc, ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, ArrayRef shapePerCTATile) { auto shape = type.getShape(); unsigned rank = shape.size(); if (auto blockedLayout = layout.dyn_cast()) { - auto multiDimOffsetFirstElem = - emitBaseIndexForLayout(loc, rewriter, blockedLayout, type, false); + auto multiDimOffsetFirstElem = emitBaseIndexForLayout( + loc, rewriter, targetInfo, blockedLayout, type, false); SmallVector multiDimOffset(rank); SmallVector multiDimElemId = getMultiDimIndex( elemId, getSizePerThread(layout), getOrder(layout)); @@ -429,10 +344,10 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, auto it = std::find(parentOffset.begin(), parentOffset.end(), off); idxs.push_back(std::distance(parentOffset.begin(), it)); } - auto multiDimOffsetParent = - getMultiDimOffset(parentEncoding, loc, rewriter, idxs[elemId], parentTy, - sliceLayout.paddedShape(multiDimCTAInRepId), - sliceLayout.paddedShape(shapePerCTATile)); + auto multiDimOffsetParent = getMultiDimOffset( + parentEncoding, loc, rewriter, targetInfo, idxs[elemId], parentTy, + sliceLayout.paddedShape(multiDimCTAInRepId), + sliceLayout.paddedShape(shapePerCTATile)); SmallVector multiDimOffset(rank); for (unsigned d = 0; d < rank + 1; ++d) { if (d == dim) @@ -533,7 +448,7 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, } if (layout.isa()) { auto multiDimBase = - emitBaseIndexForLayout(loc, rewriter, layout, type, false); + emitBaseIndexForLayout(loc, rewriter, targetInfo, layout, type, false); SmallVector> offsets; assert(rank == 2); SmallVector multiDimOffset(rank); diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index d49732a513..49ab196b4c 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -3,4 +3,3 @@ add_subdirectory(TritonGEN) add_subdirectory(TritonGPU) add_subdirectory(TritonIntelGPU) add_subdirectory(TritonNvidiaGPU) -add_subdirectory(NVGPU) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index e0016a6957..eacf448062 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,13 +1,8 @@ -#include "mlir/IR/Matchers.h" -#include "mlir/IR/TypeUtilities.h" - #include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" #include "Utility.h" - #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include - using namespace mlir; using ::mlir::LLVM::delinearize; @@ -20,11 +15,11 @@ namespace { // Return the mask for the unique data accessed by given tensor type. // Used to mask out the redundant data accessed by threads. Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, - Location loc) { + Location loc, const AMD::TargetInfo &targetInfo) { auto tensorTy = valueTy.dyn_cast(); Value mask = int_val(1, 1); auto tid = tid_val(); - auto clusterCTAId = getClusterCTAId(rewriter, loc); + auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc); if (tensorTy) { auto layout = tensorTy.getEncoding(); auto shape = tensorTy.getShape(); @@ -92,8 +87,9 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, } // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { - explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass) - : axisAnalysisPass(axisAnalysisPass) {} + explicit LoadStoreConversionBase(const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} unsigned getContiguity(Value ptr) const { auto tensorTy = ptr.getType().dyn_cast(); @@ -118,6 +114,7 @@ struct LoadStoreConversionBase { protected: ModuleAxisInfoAnalysis &axisAnalysisPass; + const AMD::TargetInfo &targetInfo; }; struct LoadOpConversion : public ConvertOpToLLVMPattern, @@ -125,10 +122,11 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LoadOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, @@ -229,10 +227,11 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; StoreOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, @@ -269,7 +268,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } - Value mask = redundantDataMask(valueTy, rewriter, loc); + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNBits = dtsize * 8; @@ -346,10 +345,11 @@ struct AtomicCASOpConversion using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; AtomicCASOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, @@ -387,7 +387,7 @@ struct AtomicCASOpConversion vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } - Value mask = redundantDataMask(valueTy, rewriter, loc); + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -481,10 +481,11 @@ struct AtomicRMWOpConversion using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; AtomicRMWOpConversion(LLVMTypeConverter &converter, + const AMD::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} /// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp. static std::optional matchAtomicOp(RMWOp atomicOp) { @@ -635,13 +636,13 @@ struct AtomicRMWOpConversion namespace mlir::triton::AMD { void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); + patterns.add(typeConverter, targetInfo, axisInfoAnalysis, + benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 71ea1dded4..e904f6794f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -21,6 +21,7 @@ void populateElementwiseOpToLLVMPatterns( ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, const TargetInfo &targetInfo, PatternBenefit benefit); void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, RewritePatternSet &patterns, int numWarps, ModuleAxisInfoAnalysis &axisInfoAnalysis, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 707761e796..3ad153151e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -1,6 +1,7 @@ #include "TargetInfo.h" #include "Utility.h" #include "amd/include/TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -53,6 +54,13 @@ Value printfPromoteValue(ConversionPatternRewriter &rewriter, Value value) { bool TargetInfo::supportMaximumMinimum() const { return false; } +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + // On AMD hardware we don't have CTA clusters like NVIDIA. So this will always + // be zero. Whoever calling into this should make sure the whole program does + // not try to utilize CTA clusters. + return rewriter.create(loc, 0, 32); +} + Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, Value cmp) const { auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 08b8828484..dab187ccf4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -14,6 +14,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool supportMaximumMinimum() const override; + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, Value cmp) const override; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 8a51319c6b..ed0ef89b06 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -2,33 +2,25 @@ #include "PatternTritonGPUOpToLLVM.h" #include "TargetInfo.h" -#include "Utility.h" -#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Tools/Sys/GetPlatform.hpp" namespace mlir { namespace triton { @@ -72,7 +64,6 @@ class TritonLLVMConversionTarget : public ConversionTarget { addLegalDialect(); addLegalDialect(); addLegalDialect(); - addLegalDialect(); addIllegalDialect(); addIllegalDialect(); addIllegalDialect(); @@ -89,8 +80,8 @@ struct ConvertTritonAMDGPUToLLVM } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() override { @@ -191,16 +182,16 @@ struct ConvertTritonAMDGPUToLLVM AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, benefit); populatePatterns6(AMD::populateElementwiseOpToLLVMPatterns); - AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps, - axisInfoAnalysis, benefit); + AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, + numWarps, axisInfoAnalysis, benefit); populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns); populatePatterns7(mlir::triton::populateScanOpToLLVMPatterns); populatePatterns5(mlir::triton::populateViewOpToLLVMPatterns); populatePatterns7(mlir::triton::populateHistogramOpToLLVMPatterns); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, patterns, - benefit); - mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, patterns, - benefit); + mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); + mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, @@ -225,15 +216,6 @@ struct ConvertTritonAMDGPUToLLVM if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); } - - // Fold CTAId when there is only 1 CTA. - if (numCTAs == 1) { - mod.walk([](triton::nvgpu::ClusterCTAIdOp id) { - OpBuilder b(id); - Value zero = LLVM::createConstantI32(id->getLoc(), b, 0); - id.replaceAllUsesWith(zero); - }); - } } private: diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 515542f248..1c0aaa6dd1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -3,7 +3,6 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" namespace { enum class ShflKind : uint32_t { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp index d75ae9008c..298a909a03 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp @@ -19,6 +19,12 @@ Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, assert("TODO: implement ballot on XPU"); return Value(); } + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + // Clusters of thread blocks aren't supported. + return i32_val(0); +} + Value TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value val, Value pred) const { LLVM::intel::createPredicatedBlock(rewriter, loc, pred, [&] { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h index c1ce4750f1..009414c29a 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h @@ -18,6 +18,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool supportMaximumMinimum() const override; + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, Value cmp) const override; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index 79a362d5f3..c14c5c36b1 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -14,7 +14,6 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -120,15 +119,6 @@ struct ConvertTritonGPUToLLVM if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); - - // Fold CTAId when there is only 1 CTA. - if (numCTAs == 1) { - mod.walk([](triton::nvgpu::ClusterCTAIdOp id) { - OpBuilder b(id); - Value zero = LLVM::createConstantI32(id->getLoc(), b, 0); - id.replaceAllUsesWith(zero); - }); - } } }; diff --git a/third_party/nvidia/include/CMakeLists.txt b/third_party/nvidia/include/CMakeLists.txt index b6c94e796d..2ef7aab106 100644 --- a/third_party/nvidia/include/CMakeLists.txt +++ b/third_party/nvidia/include/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Dialect) add_subdirectory(TritonNVIDIAGPUToLLVM) add_subdirectory(NVGPUToLLVM) diff --git a/third_party/nvidia/include/Dialect/CMakeLists.txt b/third_party/nvidia/include/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..edeac06603 --- /dev/null +++ b/third_party/nvidia/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(NVGPU) diff --git a/lib/Dialect/NVGPU/CMakeLists.txt b/third_party/nvidia/include/Dialect/NVGPU/CMakeLists.txt similarity index 100% rename from lib/Dialect/NVGPU/CMakeLists.txt rename to third_party/nvidia/include/Dialect/NVGPU/CMakeLists.txt diff --git a/include/triton/Dialect/NVGPU/IR/CMakeLists.txt b/third_party/nvidia/include/Dialect/NVGPU/IR/CMakeLists.txt similarity index 100% rename from include/triton/Dialect/NVGPU/IR/CMakeLists.txt rename to third_party/nvidia/include/Dialect/NVGPU/IR/CMakeLists.txt diff --git a/include/triton/Dialect/NVGPU/IR/Dialect.h b/third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h similarity index 87% rename from include/triton/Dialect/NVGPU/IR/Dialect.h rename to third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h index a27b556fed..6e238af4f2 100644 --- a/include/triton/Dialect/NVGPU/IR/Dialect.h +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h @@ -29,14 +29,14 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h.inc" -#include "triton/Dialect/NVGPU/IR/OpsEnums.h.inc" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc" +#include "nvidia/include/Dialect/NVGPU/IR/OpsEnums.h.inc" #define GET_ATTRDEF_CLASSES -#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" +#include "nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" #define GET_OP_CLASSES -#include "triton/Dialect/NVGPU/IR/Ops.h.inc" +#include "nvidia/include/Dialect/NVGPU/IR/Ops.h.inc" namespace mlir { namespace triton { diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.td similarity index 96% rename from include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td rename to third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.td index 20229f1e02..c904824ef0 100644 --- a/include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUAttrDefs.td @@ -22,8 +22,8 @@ #ifndef NVGPU_ATTRDEFS #define NVGPU_ATTRDEFS -include "triton/Dialect/NVGPU/IR/NVGPUDialect.td" include "mlir/IR/AttrTypeBase.td" +include "NVGPUDialect.td" class NVGPU_Attr traits = [], string baseCppClass = "::mlir::Attribute"> diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUDialect.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUDialect.td similarity index 100% rename from include/triton/Dialect/NVGPU/IR/NVGPUDialect.td rename to third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUDialect.td diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td similarity index 98% rename from include/triton/Dialect/NVGPU/IR/NVGPUOps.td rename to third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 3c7c502087..ca9d18873e 100644 --- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -22,12 +22,12 @@ #ifndef NVGPU_OPS #define NVGPU_OPS -include "triton/Dialect/NVGPU/IR/NVGPUDialect.td" -include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td" include "mlir/IR/OpBase.td" include "mlir/IR/EnumAttr.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "NVGPUDialect.td" +include "NVGPUAttrDefs.td" def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>; def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>; diff --git a/third_party/nvidia/lib/CMakeLists.txt b/third_party/nvidia/lib/CMakeLists.txt index b6c94e796d..2ef7aab106 100644 --- a/third_party/nvidia/lib/CMakeLists.txt +++ b/third_party/nvidia/lib/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Dialect) add_subdirectory(TritonNVIDIAGPUToLLVM) add_subdirectory(NVGPUToLLVM) diff --git a/third_party/nvidia/lib/Dialect/CMakeLists.txt b/third_party/nvidia/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..edeac06603 --- /dev/null +++ b/third_party/nvidia/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(NVGPU) diff --git a/third_party/nvidia/lib/Dialect/NVGPU/CMakeLists.txt b/third_party/nvidia/lib/Dialect/NVGPU/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/nvidia/lib/Dialect/NVGPU/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/lib/Dialect/NVGPU/IR/CMakeLists.txt b/third_party/nvidia/lib/Dialect/NVGPU/IR/CMakeLists.txt similarity index 100% rename from lib/Dialect/NVGPU/IR/CMakeLists.txt rename to third_party/nvidia/lib/Dialect/NVGPU/IR/CMakeLists.txt diff --git a/lib/Dialect/NVGPU/IR/Dialect.cpp b/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp similarity index 91% rename from lib/Dialect/NVGPU/IR/Dialect.cpp rename to third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp index 42497df494..ed87a588f2 100644 --- a/lib/Dialect/NVGPU/IR/Dialect.cpp +++ b/third_party/nvidia/lib/Dialect/NVGPU/IR/Dialect.cpp @@ -25,8 +25,8 @@ #include "mlir/IR/OpImplementation.h" // clang-format off -#include "triton/Dialect/NVGPU/IR/Dialect.h" -#include "triton/Dialect/NVGPU/IR/Dialect.cpp.inc" +#include "Dialect/NVGPU/IR/Dialect.h" +#include "Dialect/NVGPU/IR/Dialect.cpp.inc" // clang-format on using namespace mlir; @@ -88,15 +88,15 @@ static LogicalResult verify(mlir::triton::nvgpu::WGMMAOp op) { void mlir::triton::nvgpu::NVGPUDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST -#include "triton/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" +#include "Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" >(); addOperations< #define GET_OP_LIST -#include "triton/Dialect/NVGPU/IR/Ops.cpp.inc" +#include "Dialect/NVGPU/IR/Ops.cpp.inc" >(); } #define GET_OP_CLASSES -#include "triton/Dialect/NVGPU/IR/Ops.cpp.inc" -#include "triton/Dialect/NVGPU/IR/OpsEnums.cpp.inc" +#include "Dialect/NVGPU/IR/Ops.cpp.inc" +#include "Dialect/NVGPU/IR/OpsEnums.cpp.inc" diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index c61a0d16de..e192165204 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -1,12 +1,12 @@ #include "NVGPUToLLVM/NVGPUToLLVMPass.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h" diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp index e918507e75..a5ab6fddb2 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp @@ -21,9 +21,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +#include "Dialect/NVGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" using namespace mlir; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 3432a09b71..3147703c2a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -1,8 +1,13 @@ #include "PatternTritonGPUOpToLLVM.h" +#include "TargetInfo.h" #include "Utility.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/PatternMatch.h" #include "triton/Analysis/Allocation.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" using mlir::isLayoutMmaV1; @@ -203,8 +208,11 @@ struct ConvertLayoutOpOptimizedConversion struct ConvertLayoutOpConversion : public ConvertOpToLLVMPattern { public: - using ConvertOpToLLVMPattern< - triton::gpu::ConvertLayoutOp>::ConvertOpToLLVMPattern; + ConvertLayoutOpConversion(const LLVMTypeConverter &typeConverter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } LogicalResult matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, @@ -276,7 +284,7 @@ struct ConvertLayoutOpConversion // of performance issue observed. for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) { SmallVector multiDimOffset = - getMultiDimOffset(layout, loc, rewriter, elemId, type, + getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, multiDimCTAInRepId, shapePerCTATile); SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, @@ -370,7 +378,7 @@ struct ConvertLayoutOpConversion // TODO[Superjomn]: Move the coordinate computation out of loop, it is // duplicate in Volta. SmallVector multiDimOffset = - getMultiDimOffset(layout, loc, rewriter, elemId, type, + getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type, multiDimCTAInRepId, shapePerCTATile); coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]); } @@ -442,8 +450,8 @@ struct ConvertLayoutOpConversion // Store to local shared memory { auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - auto inIndices = - emitIndices(loc, rewriter, srcLayout, srcTy, /*withCTAOffset*/ false); + auto inIndices = emitIndices(loc, rewriter, targetInfo, srcLayout, srcTy, + /*withCTAOffset*/ false); assert(inIndices.size() == inVals.size() && "Unexpected number of indices emitted"); @@ -466,8 +474,8 @@ struct ConvertLayoutOpConversion srcShapePerCTACache.push_back(i32_val(srcShapePerCTA[i])); SmallVector outVals; - auto outIndices = - emitIndices(loc, rewriter, dstLayout, dstTy, /*withCTAOffset*/ true); + auto outIndices = emitIndices(loc, rewriter, targetInfo, dstLayout, dstTy, + /*withCTAOffset*/ true); for (unsigned i = 0; i < outIndices.size(); ++i) { auto coord = outIndices[i]; @@ -740,6 +748,9 @@ struct ConvertLayoutOpConversion } return failure(); } + +private: + const NVIDIA::TargetInfo &targetInfo; }; } // namespace @@ -754,7 +765,7 @@ void mlir::triton::NVIDIA::populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { // For now give ConvertLayoutOpConversion higher benefit, I can split before // merging - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); // Same default benefit patterns.add(typeConverter, benefit); mlir::triton::populateConvertLayoutOpToLLVMPatterns(typeConverter, targetInfo, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 4ad41fa4de..cfb490dba2 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1,3 +1,4 @@ +#include "TargetInfo.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" @@ -7,8 +8,6 @@ #include "Utility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include - using namespace mlir; using namespace mlir::triton; @@ -25,11 +24,11 @@ namespace { // Return the mask for the unique data accessed by given tensor type. // Used to mask out the redundant data accessed by threads. Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, - Location loc) { + Location loc, const NVIDIA::TargetInfo &targetInfo) { auto tensorTy = dyn_cast(valueTy); Value mask = int_val(1, 1); auto tid = tid_val(); - auto clusterCTAId = getClusterCTAId(rewriter, loc); + auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc); if (tensorTy) { auto layout = tensorTy.getEncoding(); auto shape = tensorTy.getShape(); @@ -98,8 +97,9 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { - explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass) - : axisAnalysisPass(axisAnalysisPass) {} + explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass) + : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {} unsigned getContiguity(Value ptr) const { auto tensorTy = dyn_cast(ptr.getType()); @@ -125,18 +125,18 @@ struct LoadStoreConversionBase { } protected: + const NVIDIA::TargetInfo &targetInfo; ModuleAxisInfoAnalysis &axisAnalysisPass; }; struct LoadOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LoadOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, @@ -343,13 +343,12 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, struct StoreOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - StoreOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, @@ -386,7 +385,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } - Value mask = redundantDataMask(valueTy, rewriter, loc); + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNBits = dtsize * 8; @@ -480,13 +479,12 @@ void createBarrier(ConversionPatternRewriter &rewriter, Location loc, struct AtomicCASOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - AtomicCASOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, @@ -521,7 +519,7 @@ struct AtomicCASOpConversion vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); } - Value mask = redundantDataMask(valueTy, rewriter, loc); + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -595,13 +593,12 @@ struct AtomicCASOpConversion struct AtomicRMWOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - AtomicRMWOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertOpToLLVMPattern(converter, benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, @@ -645,7 +642,7 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } - Value mask = redundantDataMask(valueTy, rewriter, loc); + Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -761,15 +758,12 @@ struct AtomicRMWOpConversion struct AsyncCopyGlobalToLocalOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { - using ConvertOpToLLVMPattern< - triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern; - AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, - benefit), - LoadStoreConversionBase(axisAnalysisPass) {} + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} LogicalResult matchAndRewrite(triton::gpu::AsyncCopyGlobalToLocalOp op, OpAdaptor adaptor, @@ -832,9 +826,9 @@ struct AsyncCopyGlobalToLocalOpConversion unsigned perPhase = resSharedLayout.getPerPhase(); unsigned maxPhase = resSharedLayout.getMaxPhase(); SmallVector offsetVals = {smemObj.strides.size(), i32_val(0)}; - DenseMap sharedPtrs = - getSwizzledSharedPtrs(loc, inVec, srcTy, resSharedLayout, resElemTy, - smemObj, rewriter, offsetVals, smemObj.strides); + DenseMap sharedPtrs = getSwizzledSharedPtrs( + loc, targetInfo, inVec, srcTy, resSharedLayout, resElemTy, smemObj, + rewriter, offsetVals, smemObj.strides); // A sharedLayout encoding has a "vec" parameter. // On the column dimension, if inVec > outVec, it means we have to divide @@ -882,7 +876,7 @@ struct AsyncCopyGlobalToLocalOpConversion // When 'other != 0' is supported, we will need to fold the op.getMask() // and redundantDataMask() into the same predicate, the way it is done // for LoadOp. - Value maskVal = redundantDataMask(srcTy, rewriter, loc); + Value maskVal = redundantDataMask(srcTy, rewriter, loc, targetInfo); // TODO: Masking does not work for CTA multicast with cp.async. This is // a quick and dirty workaround to avoid the issue. @@ -1021,14 +1015,12 @@ struct AsyncCommitGroupOpConversion } // namespace void mlir::triton::NVIDIA::populateLoadStoreOpToLLVMPatterns( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) { - patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, - axisInfoAnalysis, benefit); + LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, + RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, + PatternBenefit benefit) { + patterns.add( + typeConverter, targetInfo, axisInfoAnalysis, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h index c835671197..1013d5bc21 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -37,6 +37,7 @@ void populateElementwiseOpToLLVMPatterns( const TargetInfo &targetInfo, PatternBenefit benefit); void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfo &targetInfo, RewritePatternSet &patterns, ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 4c88e37dbe..d649a103f6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -1,7 +1,7 @@ #include "TargetInfo.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" #include "Utility.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -218,6 +218,12 @@ static std::optional matchReduxKind(triton::ReduceOp op, bool TargetInfo::supportMaximumMinimum() const { return computeCapability >= 80; } + +Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { + return rewriter.create(loc, + rewriter.getI32Type()); +} + Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, Value cmp) const { Value threadMask = int_val(type.getIntOrFloatBitWidth(), -1); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index d8ae340604..73ed25e784 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -1,13 +1,18 @@ #ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H #define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H + #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" + namespace mlir::triton::NVIDIA { + class TargetInfo : public mlir::triton::TargetInfoBase { public: TargetInfo(int computeCapability) : computeCapability(computeCapability) {} bool supportMaximumMinimum() const override; + Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; + Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, Value cmp) const override; @@ -50,5 +55,7 @@ class TargetInfo : public mlir::triton::TargetInfoBase { private: int computeCapability; }; + } // namespace mlir::triton::NVIDIA + #endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFONVIDIA_H diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 15bafe0b07..70103f23e6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -1,3 +1,4 @@ +#include "Dialect/NVGPU/IR/Dialect.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -16,7 +17,6 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -138,8 +138,8 @@ struct ConvertTritonGPUToLLVM populateClampFOpToLLVMPattern(typeConverter, patterns, axisInfoAnalysis, computeCapability, patternBenefitClampOptimizedPattern); - populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, - benefit); + populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, + axisInfoAnalysis, benefit); mlir::triton::populateReduceOpToLLVMPatterns(typeConverter, patterns, targetInfo, benefit); mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns, @@ -169,10 +169,10 @@ struct ConvertTritonGPUToLLVM benefit); mlir::triton::populateAssertOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); - mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, patterns, - benefit); - mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, patterns, - benefit); + mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); + mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, + patterns, benefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index b744630685..0752b4577e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -1,7 +1,7 @@ #include "Utility.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" namespace mlir { namespace LLVM { @@ -104,6 +104,92 @@ Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, return builder.launch(rewriter, loc, rewriter.getIntegerType(32), false); } +// A wrapper of LoadDSmemOp when vec = 1 +// (1) Get bitwidth from elemTy +// (2) Create LoadDSmemOp +// (3) Bitcast result from dataTy (u16/u32/u64) back to elemTy +Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Type elemTy) { + assert(isa(addr.getType()) && "addr must be a pointer type"); + auto ptrTy = cast(addr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + Value ret = + rewriter.create(loc, addr, ctaId, bitwidth); + return bitcast(ret, elemTy); +} + +// A wrapper of LoadDSmemOp when vec > 1 +// (1) Get bitwidth from elemTy +// (2) Create LoadDSmemOp and extract results from retStruct +// (3) Bitcast results from dataTy (u16/u32/u64) back to elemTy +SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, + Value addr, Value ctaId, unsigned vec, + Type elemTy) { + assert(isa(addr.getType()) && "addr must be a pointer type"); + auto ptrTy = cast(addr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + Value retStruct = rewriter.create( + loc, addr, ctaId, bitwidth, vec); + SmallVector retVals; + for (unsigned i = 0; i < vec; ++i) { + auto dataTy = rewriter.getIntegerType(bitwidth); + Value data = extract_val(dataTy, retStruct, i); + retVals.push_back(bitcast(data, elemTy)); + } + return retVals; +} + +// A wrapper of StoreDSmemOp when vec = 1 +// (1) Get bitwidth from elemTy +// (2) Bitcast value from elemTy to dataTy (u16/u32/u64) +// (3) Create StoreDSmemOp +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value, Value pred) { + assert(isa(addr.getType()) && "addr must be a pointer type"); + auto ptrTy = cast(addr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + unsigned bitwidth = value.getType().getIntOrFloatBitWidth(); + auto dataTy = rewriter.getIntegerType(bitwidth); + Value data = bitcast(value, dataTy); + rewriter.create(loc, addr, ctaId, data, pred); +} + +// A wrapper of StoreDSmemOp when vec = 1 and pred = 1 +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value) { + Value pred = int_val(/*width=*/1, 1); + createStoreDSmem(loc, rewriter, addr, ctaId, value, pred); +} + +// A wrapper of StoreDSmemOp when vec > 1 +// (1) Get bitwidth from elemTy +// (2) Bitcast values from elemTy to dataTy (u16/u32/u64) +// (3) Create StoreDSmemOp +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values, Value pred) { + assert(isa(addr.getType()) && "addr must be a pointer type"); + auto ptrTy = cast(addr.getType()); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for load_dsmem"); + unsigned bitwidth = 0; + if (!values.empty()) { + bitwidth = values.back().getType().getIntOrFloatBitWidth(); + } + auto dataTy = rewriter.getIntegerType(bitwidth); + SmallVector data; + for (unsigned i = 0; i < values.size(); ++i) + data.push_back(bitcast(values[i], dataTy)); + rewriter.create(loc, addr, ctaId, data, pred); +} + +// A wrapper of StoreDSmemOp when vec > 1 and pred = 1 +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values) { + Value pred = int_val(/*width=*/1, 1); + createStoreDSmem(loc, rewriter, addr, ctaId, values, pred); +} + } // namespace NVIDIA } // namespace LLVM } // namespace mlir diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h index ff14c2eb2b..163a49cb7b 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h @@ -7,9 +7,9 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #define DEBUG_TYPE "ttgpu_to_llvm" @@ -29,6 +29,11 @@ using namespace mlir::triton; ptxBuilder.launch(rewriter, op->getLoc(), voidTy); \ } while (0) +#define load_dsmem(...) \ + ::mlir::LLVM::NVIDIA::createLoadDSmem(loc, rewriter, __VA_ARGS__) +#define store_dsmem(...) \ + ::mlir::LLVM::NVIDIA::createStoreDSmem(loc, rewriter, __VA_ARGS__) + namespace mlir { namespace LLVM { @@ -48,6 +53,30 @@ Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, ModuleOp moduleOp, int axis); + +/// Usage of macro load_dsmem +/// (1) load_dsmem(addr, ctaId) +/// (2) load_dsmem(addr, ctaId, vec) +Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Type elemTy); +SmallVector createLoadDSmem(Location loc, PatternRewriter &rewriter, + Value addr, Value ctaId, unsigned vec, + Type elemTy); + +/// Usage of macro store_dsmem +/// (1) store_dsmem(addr, ctaId, value, pred) +/// (2) store_dsmem(addr, ctaId, value) +/// (3) store_dsmem(addr, ctaId, values, pred) +/// (4) store_dsmem(addr, ctaId, values) +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value, Value pred); +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, Value value); +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values, Value pred); +void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr, + Value ctaId, ArrayRef values); + } // namespace NVIDIA } // namespace LLVM diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index b10581a382..96f428cd4f 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,14 +1,12 @@ -#include "NVGPUToLLVM/Passes.h" +#include "Dialect/NVGPU/IR/Dialect.h" +#include "NVGPUToLLVM/NVGPUToLLVMPass.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "passes.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" #include "llvm/IR/Constants.h" -#include "llvm/Support/TargetSelect.h" #include #include #include diff --git a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp index 2c0fbf13ea..f2c6b2f65d 100644 --- a/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/DumpLayout.cpp @@ -23,8 +23,10 @@ #include "DumpLayout.h" #ifdef AMD_TARGET +#include "amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" #include "amd/lib/TritonAMDGPUToLLVM/Utility.h" #else +#include "intel/lib/TritonIntelGPUToLLVM/TargetInfo.h" #include "intel/lib/TritonIntelGPUToLLVM/Utility.h" #endif namespace mlir { @@ -57,7 +59,13 @@ class IndexEmitter { public: IndexEmitter(MLIRContext *context_) : context(context_), option(context), rewriter(context), - loc(UnknownLoc::get(context)) { + loc(UnknownLoc::get(context)), +#ifdef AMD_TARGET + targetInfo("gfx942") +#else + targetInfo(90) +#endif + { mlir::OpBuilder builder(context); std::vector inTypes{}; std::vector outTypes{}; @@ -73,7 +81,8 @@ class IndexEmitter { emitIndices(Attribute layout, llvm::ArrayRef shape, bool withCTAOffset) { auto type = RankedTensorType::get(shape, rewriter.getF16Type(), layout); - return mlir::emitIndices(loc, rewriter, layout, type, withCTAOffset); + return mlir::emitIndices(loc, rewriter, targetInfo, layout, type, + withCTAOffset); } llvm::DenseMap @@ -83,9 +92,9 @@ class IndexEmitter { auto srcTy = RankedTensorType::get(shape, elemTy, srcLayout); SharedMemoryObject smemObj(getMockSmemBaseImpl(rewriter, loc), elemTy, shape, sharedLayout.getOrder(), loc, rewriter); - return getSwizzledSharedPtrs(loc, /*inVec=*/1, srcTy, sharedLayout, elemTy, - smemObj, rewriter, smemObj.offsets, - smemObj.strides); + return getSwizzledSharedPtrs(loc, targetInfo, /*inVec=*/1, srcTy, + sharedLayout, elemTy, smemObj, rewriter, + smemObj.offsets, smemObj.strides); } private: @@ -94,6 +103,11 @@ class IndexEmitter { LowerToLLVMOptions option; IRRewriter rewriter; Location loc; +#ifdef AMD_TARGET + AMD::TargetInfo targetInfo; +#else + NVIDIA::TargetInfo targetInfo; +#endif }; //===----------------------------------------------------------------------===//