Merge OpenAI Triton commit 4d2e9e5 #2978

Merged: 11 commits, Dec 10, 2024
17 changes: 0 additions & 17 deletions bin/CMakeLists.txt
@@ -7,13 +7,7 @@ add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
# TODO: what's this?
llvm_update_compile_flags(triton-opt)
target_link_libraries(triton-opt PRIVATE
TritonLLVMIR
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonIntelLLVMIR
MLIRGPUToROCDLTransforms
${dialect_libs}
${conversion_libs}
${triton_libs}
@@ -32,11 +26,6 @@ mlir_check_all_link_libraries(triton-reduce)

llvm_update_compile_flags(triton-reduce)
target_link_libraries(triton-reduce PRIVATE
TritonLLVMIR
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
${triton_libs}
@@ -54,10 +43,6 @@ add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED)

llvm_update_compile_flags(triton-lsp)
target_link_libraries(triton-lsp PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
${dialect_libs}
${conversion_libs}
${triton_libs}
@@ -96,8 +81,6 @@ export_executable_symbols_for_plugins(triton-llvm-opt)

add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
target_link_libraries(triton-tensor-layout PRIVATE
TritonGPUIR
TritonNvidiaGPUIR
${triton_libs}
${conversion_libs}
${dialect_libs}
5 changes: 5 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -231,6 +231,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
RankedTensorType dstTy);

// Check if MFMA layout can be converted to the dot operand
// layout using warp shuffle.
bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
RankedTensorType dstTy);

// TODO: Move utility functions that belong to ConvertLayoutOp to class
// ConvertLayoutOpHelper in the future
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
12 changes: 6 additions & 6 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1154,15 +1154,15 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
// Returns true on success.
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
Type elemLlvmTy, std::optional<int32_t> maxVecElems,
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
Location loc, const TargetInfoBase &target, unsigned inVec,
RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout,
Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter,
Type resElemTy, const SharedMemoryObject &smemObj, RewriterBase &rewriter,
ArrayRef<Value> offsetVals, ArrayRef<Value> srcStrides) {
// This utility computes the pointers for accessing the provided swizzled
// shared memory layout `resSharedLayout`. More specifically, it computes,
@@ -1324,14 +1324,14 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
triton::gpu::MemDescType srcTy,
Type elemLlvmTy,
SharedMemoryObject smemObj,
const SharedMemoryObject &smemObj,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

void storeDistributedToShared(
triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
ArrayRef<Value> srcVals, const SharedMemoryObject &smemObj, Location loc,
RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
7 changes: 7 additions & 0 deletions include/triton/Tools/LinearLayout.h
@@ -414,6 +414,10 @@ class LinearLayout {

bool isSurjective() const { return surjective; }

bool isInvertible() const {
return surjective && getTotalInDimSize() == getTotalOutDimSize();
}

const BasesT &getBases() const { return bases; }

// Get the pos'th basis vector for the inDim -> outDim mapping.
@@ -673,6 +677,9 @@ class LinearLayout {
// don't place any guarantees on the choices made by this function.
[[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const;

// Get the layout that is the inverse of this layout.
[[nodiscard]] LinearLayout invert() const;

// For each in-dim, returns a bitmask of the "free variables" in the layout
// function.
//
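The two LinearLayout additions above (isInvertible and invert) give callers an explicit bijection check before inverting a layout. A minimal sketch of the intended usage, assuming the class lives in namespace mlir::triton as declared in this header; the helper name invertOrFail is hypothetical and not part of this PR:

```cpp
#include <cassert>

#include "triton/Tools/LinearLayout.h"

using mlir::triton::LinearLayout;

// Hypothetical helper (illustrative only): invert a layout only when it is a
// bijection, i.e. surjective with equally sized domain and codomain, which is
// exactly what the new isInvertible() reports.
LinearLayout invertOrFail(const LinearLayout &layout) {
  assert(layout.isInvertible() && "layout must be a bijection to be inverted");
  return layout.invert(); // maps each output point back to its unique input
}
```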
25 changes: 24 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -650,6 +651,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
return ans;
}

bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (!mfmaLayout || !dotOperandLayout)
return false;

// Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
return dotOperandLayout.getParent() == mfmaLayout &&
dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
dotOperandLayout.getKWidth() == 8 &&
getContigPerThread(mfmaLayout)[1] == 4 &&
((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
(mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
triton::type::isFloat8(srcTy.getElementType()) &&
triton::type::isFloat8(dstTy.getElementType()) &&
mfmaLayout.getWarpsPerCTA()[1] == 1;
}

// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
// under kBlock, kWarp or kLane (in that order). The idea here is that if we
// have a transformation that's the identity on kBlock, we don't need to use
@@ -749,7 +769,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
return !cvtReordersRegisters(srcTy, dstTy) &&
!triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!matchMmaV3AndDotOperandLayout(srcTy, dstTy);
!matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
// to be removed when generalized warp shuffle conversions
// are ready:
!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
}

bool atomicNeedsSharedMemory(Value value) {
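A minimal sketch of how the new predicate interacts with cvtNeedsSharedMemory after this change. This is illustrative code, not part of the PR: the wrapper name lowersWithWarpShuffle is hypothetical, and it assumes the declarations in triton/Analysis/Utility.h are visible (namespace assumed to be mlir):

```cpp
#include <cassert>

#include "triton/Analysis/Utility.h"

using namespace mlir;

// Hypothetical wrapper (illustrative only): the shuffle fast path applies to a
// transposed 16x16 or 32x32 FP8 MFMA result feeding opIdx == 0 of a dot with
// kWidth == 8 and warpsPerCTA[1] == 1; for exactly that case,
// cvtNeedsSharedMemory now reports that no shared-memory round trip is needed.
bool lowersWithWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
  if (!matchMFMAAndDotOperandShuffleCase(srcTy, dstTy))
    return false;
  assert(!cvtNeedsSharedMemory(srcTy, dstTy));
  return true;
}
```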
6 changes: 6 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -402,6 +402,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
return failure();
}

// The following check can be removed when generalized warp shuffle
// conversions are ready:
if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
return failure();
}

assert(cvtNeedsSharedMemory(srcTy, dstTy));

SmallVector<Value> inVals =
17 changes: 5 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -25,11 +25,9 @@ void lowerDistributedToShared(
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
auto elemTy = typeConverter->convertType(srcTy.getElementType());

auto smemBase = smemObj.getBase();
auto dstStrides = smemObj.getStrides();
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
loc, rewriter, targetInfo, llvmOpCount);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
targetInfo, llvmOpCount);
}

struct GlobalScratchAllocOpConversion
@@ -157,14 +155,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
// If we remove this one, ldmatrix will IMA. It can probably be relaxed
// though
canUseLdmatrix &=
srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
// To be removed in https://github.com/triton-lang/triton/pull/5154
bool legacyLoweringIsBuggy =
(kWidth >= 8 || (kWidth == 4 && bitwidth == 32) ||
dstTy.getRank() == 3) &&
mma.isAmpere();
return (mma.isHopper() && !canUseLdmatrix) ||
(mma.isAmpere() && legacyLoweringIsBuggy);
srcTy.getShape()[0] >= 8 &&
srcTy.getShape()[1] >= 4 * kWidth && dstTy.getRank() <= 2;
return !canUseLdmatrix;
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
return true;
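The rewritten guard above folds the old legacyLoweringIsBuggy special case into a single ldmatrix eligibility check. A standalone arithmetic sketch of that condition; the shape, kWidth, and rank values are assumed for illustration and are not taken from the PR:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Assumed example operand: an 8x16 rank-2 shared-memory tile with
  // kWidth == 4 elements per thread along K.
  int64_t shape[2] = {8, 16};
  int kWidth = 4;
  int rank = 2;
  bool canUseLdmatrix =
      shape[0] >= 8 && shape[1] >= 4 * kWidth && rank <= 2;
  assert(canUseLdmatrix); // such a tile keeps the ldmatrix path
  return 0;
}
```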