Bank conflict
victor-eds committed Nov 18, 2024
1 parent 6c5b353 commit 3586668
Showing 5 changed files with 66 additions and 90 deletions.
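Judging from the diff, the change pads each row of the shared-memory scratch buffer used for the sub-group transpose by one element (threadsPerWarp + 1 instead of threadsPerWarp), switches the transpose to one scalar per SIMD block write with per-lane scalar loads on the way back, and relaxes the TritonGEN SIMD block read/write ops to also accept scalar element types. A minimal standalone sketch of why the one-element padding helps is below; it is not part of the diff, and the sub-group size of 16 and the 32 four-byte shared-memory banks are assumed values, not taken from this commit:

// sketch_bank_conflict.cpp -- illustration only, not part of this commit.
#include <cstdio>

int main() {
  constexpr int threadsPerWarp = 16; // assumed sub-group size
  constexpr int numBanks = 32;       // assumed: 32 banks of 4-byte words
  const int rowLengths[] = {threadsPerWarp, threadsPerWarp + 1};
  for (int rowLength : rowLengths) {
    // Lane l starts reading at the first element of row l, i.e. offset
    // l * rowLength; all lanes issue these reads at the same time.
    std::printf("rowLength %2d -> banks:", rowLength);
    for (int lane = 0; lane < threadsPerWarp; ++lane)
      std::printf(" %2d", (lane * rowLength) % numBanks);
    std::printf("\n");
  }
  // rowLength 16 maps every lane onto banks {0, 16} (heavy conflicts);
  // rowLength 17 spreads the 16 lanes over 16 distinct banks.
  return 0;
}

The cost is one extra element per scratch row, which is why the Allocation.cpp change below grows the leading repShape dimension from threadsPerWarp to threadsPerWarp + 1.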
16 changes: 10 additions & 6 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
@@ -314,8 +314,16 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
let hasVerifier = 1;
}

def TritonGEN_SIMDBlockMemoryAccessElementType : AnyTypeOf<[I8, I16, I32, I64]>;

def TritonGEN_SIMDBlockMemoryAccessType
: AnyTypeOf<[TritonGEN_SIMDBlockMemoryAccessElementType,
FixedVectorOfLengthAndType<[2, 4, 8], [TritonGEN_SIMDBlockMemoryAccessElementType]>,
// Vectors of length 16 only allowed for i8 for now.
FixedVectorOfLengthAndType<[16], [I8]>]>;

def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
Results<(outs TritonGEN_SIMDBlockMemoryAccessType:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
)> {
@@ -331,14 +339,12 @@ def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
let assemblyFormat = [{
operands ` ` attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
}

def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
TritonGEN_SIMDBlockMemoryAccessType:$val
)> {

let summary = "simd block write";
@@ -353,7 +359,5 @@ def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
let assemblyFormat = [{
operands ` ` attr-dict `:` `(` type(operands) `)`
}];

let hasVerifier = 1;
}
#endif // TRITONGEN_OPS
4 changes: 2 additions & 2 deletions third_party/intel/lib/Analysis/Allocation.cpp
@@ -127,8 +127,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
unsigned warpsPerCTA = product(srcEncoding.getWarpsPerCTA());
unsigned remaining = product(srcTy.getShape()) /
(threadsPerWarp * threadsPerWarp * warpsPerCTA);
SmallVector<unsigned> repShape{threadsPerWarp, threadsPerWarp, remaining,
warpsPerCTA};
SmallVector<unsigned> repShape{threadsPerWarp + 1, threadsPerWarp,
remaining, warpsPerCTA};
return ScratchConfig(repShape, repShape,
/*inVec=*/1, /*outVec=*/threadsPerWarp);
}
28 changes: 0 additions & 28 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
@@ -48,18 +48,6 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
return success();
}

static LogicalResult verifySIMDBlockTy(Operation *op, VectorType vecTy) {
unsigned numElems = vecTy.getNumElements();
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());

// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
(elemTy.getWidth() != 8 || numElems != 16))
return op->emitOpError("unsupported vector type");

return success();
}

//===----------------------------------------------------------------------===//
// gen.sub_group_reduce
//===----------------------------------------------------------------------===//
@@ -438,19 +426,3 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {

return success();
}

//===----------------------------------------------------------------------===//
// gen.simdblockread
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
return verifySIMDBlockTy(*this, getRes().getType());
}

//===----------------------------------------------------------------------===//
// gen.simdblockwrite
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::SIMDBlockWriteOp::verify() {
return verifySIMDBlockTy(*this, getVal().getType());
}
40 changes: 25 additions & 15 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -28,6 +28,7 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Attributes.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ModRef.h"
@@ -937,25 +938,34 @@ struct TritonMatrix2DBlockPrefetchLowering
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SIMDBlockReadOp,
TritonGEN::SIMDBlockWriteOp>::value>>
static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) {
static std::string getSIMDBlockManglingName(OpType op, Type type) {
constexpr bool isWrite =
std::is_same<OpType, TritonGEN::SIMDBlockWriteOp>::value;
const LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
const unsigned numElems = vecTy.getNumElements();
// Note: OCL builtin name here differs from regular mangling.
std::string funcName = "intel_sub_group_block_";
if constexpr (isWrite)
funcName += "write";
else
funcName += "read";
funcName += "_u" + intel::getTypeMangling(vecTy.getElementType()) +
(numElems == 1 ? "" : std::to_string(numElems));
funcName =
"_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true);
TypeSwitch<Type>(type)
.Case([&](VectorType vecType) {
const unsigned numElems = vecType.getNumElements();
funcName += "_u" + intel::getTypeMangling(vecType.getElementType()) +
std::to_string(numElems);
funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
intel::getTypeMangling(vecType.getElementType(),
/*isUnsigned=*/true);
})
.Case([&](IntegerType vecType) {
funcName += "_u" + intel::getTypeMangling(type);
funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
intel::getTypeMangling(type, /*isUnsigned=*/true);
});
if constexpr (isWrite)
funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true);
funcName += intel::getTypeMangling(type, /*isUnsigned=*/true);
return funcName;
}
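For the newly supported scalar case, the IntegerType branch above builds the OCL builtin name without an element-count suffix. A standalone sketch of the name it would produce for an i32 read from address space 3 follows; it assumes getTypeMangling(i32) returns "i" and its unsigned form returns "j" (Itanium-style), which is an assumption, not something shown in this diff:

// sketch_mangling.cpp -- illustration only, not part of this commit.
#include <iostream>
#include <string>

int main() {
  std::string funcName = "intel_sub_group_block_read"; // scalar i32 read
  funcName += "_ui"; // "_u" + assumed getTypeMangling(i32) == "i"
  funcName = "_Z" + std::to_string(funcName.size()) + funcName +
             "PU3AS3" + // pointer argument, assumed address space 3
             "j";       // assumed unsigned mangling of i32
  std::cout << funcName << '\n'; // prints _Z29intel_sub_group_block_read_uiPU3AS3j
  return 0;
}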

@@ -968,17 +978,17 @@ struct TritonSIMDBlockReadLowering
matchAndRewrite(TritonGEN::SIMDBlockReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getRes().getType();
Type type = op.getRes().getType();

std::string funcName = getSIMDBlockManglingName(op, vecTy);
std::string funcName = getSIMDBlockManglingName(op, type);
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::Ref,
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
auto funcAttrs = noUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, vecTy, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
rewriter, funcName, type, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});

rewriter.replaceOp(op, call.getResult());
return success();
@@ -995,9 +1005,9 @@ struct TritonSIMDBlockWriteLowering
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = rewriter.getContext();
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getVal().getType();
Type type = op.getVal().getType();

std::string funcName = getSIMDBlockManglingName(op, vecTy);
std::string funcName = getSIMDBlockManglingName(op, type);

auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
Expand All @@ -1006,7 +1016,7 @@ struct TritonSIMDBlockWriteLowering
auto funcAttrs = noUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, void_ty(ctx), {ptrTy, vecTy},
rewriter, funcName, void_ty(ctx), {ptrTy, type},
{op.getPtr(), op.getVal()}, {}, funcAttrs);

rewriter.replaceOp(op, call);
@@ -767,14 +767,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
rewriter.replaceOp(op, result);
}

VectorType
getTypeForSubGroupTranspose(ArrayRef<Value> inVals,
ConversionPatternRewriter &rewriter) const {
auto elementTy = cast<IntegerType>(inVals.front().getType());
return elementTy.getWidth() <= 16 ? vec_ty(elementTy, 16)
: vec_ty(elementTy, 8);
}

Value wrapInVector(Location loc, VectorType type, ArrayRef<Value> values,
ConversionPatternRewriter &rewriter) const {
assert(type.getShape()[0] == values.size() && "Size mismatch");
@@ -800,18 +792,18 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
performSubGroupTranspose(Location loc, ArrayRef<Value> inVals,
ConversionPatternRewriter &rewriter,
bool isContiguous) const {
VectorType opType = getTypeForSubGroupTranspose(inVals, rewriter);
Type elementType = inVals.front().getType();
auto mod = rewriter.getInsertionPoint()->getParentOfType<ModuleOp>();
unsigned vecWidth = opType.getShape()[0];

Value smemBase = LLVM::intel::getSharedMemoryBase(
loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
Type ptrType = smemBase.getType();

int numElements = inVals.size();
int numRows = inVals.size();
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
int offset = threadsPerWarp;
int rowLength = threadsPerWarp + 1;
Type offsetType = getTypeConverter()->getIndexType();
Value subGroupOffset =
int_val(offsetType.getIntOrFloatBitWidth(), rowLength * numRows);
Value subGroupId = getValueOrCreateCastToIndexLike(
rewriter, loc, offsetType,
rewriter.create<mlir::gpu::SubgroupIdOp>(
@@ -820,42 +812,40 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
rewriter, loc, offsetType,
rewriter.create<mlir::gpu::LaneIdOp>(loc,
/*upper_bound=*/IntegerAttr{}));
int wiStrideNum = isContiguous ? numElements : threadsPerWarp;
Value wiStride =
rewriter.create<LLVM::ConstantOp>(loc, offsetType, wiStrideNum);
Value sgStride = rewriter.create<LLVM::ConstantOp>(
loc, offsetType, threadsPerWarp * numElements);
Value subGroupOffset = mul(sgStride, subGroupId);
Type elementType = opType.getElementType();
Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
ValueRange{subGroupOffset}, /*inbounds=*/true);
Value base = subGroupBasePtr;
// Store in matrix, transposed
for (ArrayRef<Value> vals = inVals; !vals.empty();
vals = vals.drop_front(vecWidth)) {
ArrayRef<Value> curr = vals.take_front(vecWidth);
Value vec = wrapInVector(loc, opType, curr, rewriter);
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, base, vec);
base = gep(base.getType(), opType, base, ArrayRef<LLVM::GEPArg>{offset},
for (Value val : inVals) {
rewriter.create<TritonGEN::SIMDBlockWriteOp>(loc, base, val);
base = gep(base.getType(), elementType, base,
ArrayRef<LLVM::GEPArg>{rowLength},
/*inbounds=*/true);
}

// Load from matrix, non-transposed.
// As per SIMD block semantics, we have stored the elements in a matrix of
// `Nxsub_group_size` size, so we need to load back in blocks of
// `sub_group_size` (`N/sub_group_size` loads).
Value workItemOffset = mul(wiStride, subGroupLocalId);
int32_t numContiguous = isContiguous ? inVals.size() / threadsPerWarp : 1;
int32_t workItemStride =
isContiguous ? rowLength : rowLength * threadsPerWarp;
Value workItemOffset =
mul(subGroupLocalId, int_val(offsetType.getIntOrFloatBitWidth(),
numContiguous * rowLength));
Value workItemBasePtr =
gep(ptrType, elementType, subGroupBasePtr, ValueRange{workItemOffset},
/*inbounds=*/true);
SmallVector<Value> transposedVecs;
Type loadTy = vec_ty(opType.getElementType(), wiStrideNum);
for (std::size_t i = 0, n = inVals.size(); i < n; i += wiStrideNum) {
transposedVecs.push_back(load(loadTy, workItemBasePtr));
workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr,
ArrayRef<LLVM::GEPArg>{offset}, /*inbounds=*/true);
int32_t rowsPerThread = numRows / threadsPerWarp;
SmallVector<Value> outputVals;
for (int i = 0; i < rowsPerThread; ++i) {
for (int j = 0; j < threadsPerWarp; ++j) {
outputVals.push_back(load(elementType, workItemBasePtr));
workItemBasePtr =
gep(workItemBasePtr.getType(), elementType, workItemBasePtr,
ArrayRef<LLVM::GEPArg>{1}, /*inbounds=*/true);
}
workItemBasePtr =
gep(workItemBasePtr.getType(), elementType, workItemBasePtr,
ArrayRef<LLVM::GEPArg>{workItemStride - threadsPerWarp},
/*inbounds=*/true);
}
return unwrapFromVectors(loc, transposedVecs, rewriter);
return outputVals;
}
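To make the new load pattern concrete, the sketch below enumerates which rows of the per-sub-group padded matrix each lane visits, following the offsets computed above (workItemOffset as the starting point, one workItemStride advance per outer iteration). It is illustration only; the sub-group size of 16 and the 32 stored rows are assumed values, not taken from this diff:

// sketch_transpose_reads.cpp -- illustration only, not part of this commit.
#include <cstdio>

int main() {
  constexpr int threadsPerWarp = 16;            // assumed sub-group size
  constexpr int numRows = 32;                   // assumed rows stored per sub-group
  constexpr int rowLength = threadsPerWarp + 1; // padded row, as in the diff
  const bool modes[] = {true, false};
  for (bool isContiguous : modes) {
    int numContiguous = isContiguous ? numRows / threadsPerWarp : 1;
    int workItemStride = isContiguous ? rowLength : rowLength * threadsPerWarp;
    std::printf("isContiguous=%d\n", isContiguous ? 1 : 0);
    for (int lane = 0; lane < 2; ++lane) { // two lanes are enough to see the pattern
      int base = lane * numContiguous * rowLength; // workItemOffset for this lane
      std::printf("  lane %d reads rows:", lane);
      for (int i = 0; i < numRows / threadsPerWarp; ++i)
        std::printf(" %d", (base + i * workItemStride) / rowLength);
      std::printf("\n");
    }
  }
  // isContiguous: lane 0 -> rows 0,1  and lane 1 -> rows 2,3  (adjacent rows per lane);
  // otherwise:    lane 0 -> rows 0,16 and lane 1 -> rows 1,17 (rows strided by warp size).
  return 0;
}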

void performUnbroadcast(ConvertLayoutOp op, const LinearLayout &srcLayout,
