Skip to content

Commit

Permalink
Merge commit '4f94c88498767d32420adb7f1b8d45c956a388ac'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Jun 16, 2024
2 parents eddd58e + 4f94c88 commit e0a3716
Show file tree
Hide file tree
Showing 28 changed files with 538 additions and 670 deletions.
58 changes: 38 additions & 20 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,52 @@ class TargetInfoBase {

virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0;

virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc,
Type type, Value cmp) const = 0;
virtual Value ballot(RewriterBase &rewriter, Location loc, Type type,
Value cmp) const = 0;

virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, Value val, Value pred) const = 0;
virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc,
const TypeConverter *converter, Value ptr,
Type elemTy, Value pred) const = 0;
// Store/load a value from shared memory, either in the same CTA or, if
// `ctaId` is non-nullopt, in another CTA in the same group.
//
// A target that does not support cross-CTA transfers will assert if ctaId is
// non-nullopt.
//
// Assumes the address is aligned to the width of `val`.
virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
std::optional<Value> ctaId, Value val,
Value pred) const = 0;
virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
std::optional<Value> ctaId, Type elemTy,
Value pred) const = 0;

virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc,
Value val, int i) const = 0;
virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc,
Value val, int i) const = 0;
virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc,
Value val, int i) const = 0;
virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc,
Value val, Value i) const = 0;
// Stores `val` to shared memory at `ptr` under predicate `pred`, within the
// current CTA. Convenience wrapper that forwards to storeDShared with no
// target CTA (ctaId = nullopt).
void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val,
                 Value pred) const {
  storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, val, pred);
}
// Loads a value of type `elemTy` from shared memory at `ptr` under predicate
// `pred`, within the current CTA. Convenience wrapper that forwards to
// loadDShared with no source CTA (ctaId = nullopt).
Value loadShared(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
                 Value pred) const {
  return loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, elemTy,
                     pred);
}

virtual Value programId(ConversionPatternRewriter &rewriter, Location loc,
virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val,
int i) const = 0;
virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val,
int i) const = 0;
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
int i) const = 0;
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
Value i) const = 0;

virtual Value programId(RewriterBase &rewriter, Location loc,
ModuleOp moduleOp, int axis) const = 0;

virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc,
virtual bool warpReduce(RewriterBase &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce,
unsigned interleave) const = 0;

virtual bool processReplicaUsingStMatrix(
ConversionPatternRewriter &rewriter, Location loc, Value smemBase,
RewriterBase &rewriter, Location loc, Value smemBase,
SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
Expand All @@ -48,11 +66,11 @@ class TargetInfoBase {
// format from the device. |formatStrStart| is the pointer to the start of
// the format string global variable; |args| are the arguments to fill
// placeholders in the format string.
virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart,
virtual void printf(RewriterBase &rewriter, Value formatStrStart,
int formatStrByteCount, ValueRange args) const = 0;
// Emits LLVM code with |rewriter| to perform assertion failure with the given
// |message| from the given |func| in |file|.
virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc,
virtual void assertFail(RewriterBase &rewriter, Location loc,
StringRef message, StringRef file, StringRef func,
int line) const = 0;

Expand Down
77 changes: 34 additions & 43 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ using namespace mlir::triton;
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)

// Constants
#define i1_val(val) LLVM::createConstantI1(loc, rewriter, val)
#define true_val() i1_val(true)
#define false_val() i1_val(false)
#define f16_val(...) LLVM::createConstantF16(loc, rewriter, __VA_ARGS__)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
Expand Down Expand Up @@ -202,9 +205,9 @@ T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
namespace gpu {
Type getFunctionType(Type resultType, ValueRange operands);

LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter,
Operation *op, StringRef funcName,
Type funcType, StringRef libname = "",
LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
StringRef funcName, Type funcType,
StringRef libname = "",
StringRef libpath = "");
} // namespace gpu

Expand All @@ -213,31 +216,21 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter,
namespace LLVM {
using namespace mlir::triton;

Value createConstantI1(Location loc, OpBuilder &rewriter, bool v);
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);

/// Create a 64-bit integer constant.
Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v);

/// Create a 16-bit float constant.
Value createConstantF16(Location loc, OpBuilder &rewriter, float v);

/// Create a 32-bit float constant.
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);

/// Create a 64-bit float constant.
Value createConstantF64(Location loc, OpBuilder &rewriter, double v);

/// Create NaN constant of specified type.
Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type);

/// Create an index type constant.
Value createIndexConstant(OpBuilder &builder, Location loc,
const TypeConverter *converter, int64_t value);

/// Create an integer constant of \param width bits.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value);

// Is v an integer or floating-point scalar constant equal to 0?
bool isConstantZero(Value v);

/// Helper function to get strides from a given shape and its order
SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
ArrayRef<unsigned> order,
Expand Down Expand Up @@ -305,17 +298,18 @@ struct SharedMemoryObject {
}

Value getBaseBeforeSlice(int order, Location loc,
ConversionPatternRewriter &rewriter) const {
RewriterBase &rewriter) const {
Value cSwizzleOffset = getCSwizzleOffset(order);
Value offset = sub(i32_val(0), cSwizzleOffset);
Type type = base.getType();
return gep(type, baseElemType, base, offset);
}
};

SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy,
ConversionPatternRewriter &rewriter);
SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Value llvmStruct,
Type elemTy,
RewriterBase &rewriter);

// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
Expand All @@ -329,15 +323,14 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
Value linear, ArrayRef<unsigned> shape);

Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);

Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape);
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);

Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content);
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
StringRef content);

// Given an elemId which represents the index of an element from the list of
// elements that are in the thread's registers (i.e. total of
Expand All @@ -346,7 +339,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
// when converting distributed to distributed layout. Also, a replica is the
// smallest CTA tile that is common between input and output layouts.
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
RewriterBase &rewriter,
const TargetInfoBase &targetInfo,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
Expand All @@ -355,15 +348,15 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
// Given a multiDimOffset, this function wraps around each dimension to be
// within shape.
SmallVector<Value> getWrappedMultiDimOffset(
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDimOffset, ArrayRef<unsigned> shape,
SmallVector<unsigned> shapePerCTATile, SmallVector<int64_t> shapePerCTA);
RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDimOffset,
ArrayRef<unsigned> shape, SmallVector<unsigned> shapePerCTATile,
SmallVector<int64_t> shapePerCTA);

// A function is treated as a kernel entry point iff its symbol-table
// visibility is public.
inline bool isKernel(FunctionOpInterface funcOp) {
  const bool isPublic =
      funcOp.getVisibility() == SymbolTable::Visibility::Public;
  return isPublic;
}

inline Value getStackPointer(PatternRewriter &rewriter,
inline Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) {
auto mod = funcOp->getParentOfType<ModuleOp>();
LLVM::GlobalOp globalBase = nullptr;
Expand All @@ -378,8 +371,7 @@ inline Value getStackPointer(PatternRewriter &rewriter,
return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

inline Value getSharedMemoryBase(Location loc,
ConversionPatternRewriter &rewriter,
inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Operation *op) {
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
FunctionOpInterface func =
Expand Down Expand Up @@ -1566,9 +1558,9 @@ inline void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
}
}

inline Value
getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj,
ConversionPatternRewriter &rewriter) {
inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
RewriterBase &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
Expand All @@ -1582,9 +1574,8 @@ getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj,
return llvmStruct;
}

inline SmallVector<Value>
unpackLLElements(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
isa<triton::PointerType>(llvmStruct.getType()) ||
Expand All @@ -1602,8 +1593,8 @@ unpackLLElements(Location loc, Value llvmStruct,

inline Value packLLElements(Location loc,
const LLVMTypeConverter *typeConverter,
ValueRange resultVals,
ConversionPatternRewriter &rewriter, Type type) {
ValueRange resultVals, RewriterBase &rewriter,
Type type) {
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (!structType) {
Expand Down
6 changes: 6 additions & 0 deletions include/triton/Dialect/Triton/IR/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
template <typename VecT> auto product(const VecT &vec) {
return product(llvm::ArrayRef(vec));
}
// Total number of elements in a tensor of the given `shape`. An empty shape
// is treated as holding no elements and maps to 0 rather than falling through
// to `product`.
template <typename Int> Int getNumElements(ArrayRef<Int> shape) {
  return shape.empty() ? Int(0) : product(shape);
}

// TODO(jlebar): Rename to ceilOfRatio.
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
Expand Down
9 changes: 3 additions & 6 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ class AllocationAnalysis {
unsigned inVec = 0;
unsigned outVec = 0;
auto smemShape = getScratchConfigForCvtLayout(cvtLayout, inVec, outVec);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elems = getNumElements<unsigned>(smemShape);
auto bytes =
isa<triton::PointerType>(srcTy.getElementType())
? elems * kPtrBitWidth / 8
Expand All @@ -286,8 +285,7 @@ class AllocationAnalysis {
// nothing to do
} else {
auto smemShape = getScratchConfigForAtomicRMW(atomicRMWOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elems = getNumElements<unsigned>(smemShape);
auto elemTy =
cast<triton::PointerType>(value.getType()).getPointeeType();
auto bytes =
Expand All @@ -305,8 +303,7 @@ class AllocationAnalysis {
// nothing to do
} else {
auto smemShape = getScratchConfigForAtomicCAS(atomicCASOp);
unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
std::multiplies{});
auto elems = getNumElements<unsigned>(smemShape);
auto elemTy =
cast<triton::PointerType>(value.getType()).getPointeeType();
auto bytes = isa<triton::PointerType>(elemTy)
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ struct ReduceOpConversion
auto elemTy = getElementType(op, i);
Value readPtr = gep(ptr_ty(rewriter.getContext(), 3), elemTy,
smemBases[i], readOffset);
acc[i] = targetInfo.loadShared(rewriter, loc, getTypeConverter(),
readPtr, elemTy, threadIsNeeded);
acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy,
threadIsNeeded);
}
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */);
// only the first thread in each sizeInterWarps is writing
Expand Down
Loading

0 comments on commit e0a3716

Please sign in to comment.