Skip to content

Commit

Permalink
[XeGPUToVC] Fix 'offset' computation for 'base address + offset' calc… (
Browse files Browse the repository at this point in the history
#992)

[XeGPUToVC] Fix 'offset' computation for 'base address + offset' calculation

Current implementation of computing 'offset' fails for sub-byte types.
This patch generalizes the implementation of 'offset' computation so that it works
even for sub-byte types.

Co-authored-by: Mahesha S <[email protected]>
  • Loading branch information
silee2 and hsmahesha authored Dec 23, 2024
1 parent ff51594 commit de3e322
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 30 deletions.
12 changes: 12 additions & 0 deletions include/imex/Utils/VCUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef VC_UTILS_H
#define VC_UTILS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -64,6 +65,11 @@ using namespace mlir;
#define dense_vector_val(attr, vecTy) \
rewriter.create<arith::ConstantOp>(loc, DenseElementsAttr::get(vecTy, attr))

#define divi(a, b) rewriter.createOrFold<arith::DivSIOp>(loc, a, b)
#define muli(a, b) rewriter.createOrFold<arith::MulIOp>(loc, a, b)
#define addi(a, b) rewriter.createOrFold<arith::AddIOp>(loc, a, b)
#define subi(a, b) rewriter.createOrFold<arith::SubIOp>(loc, a, b)

/// This function adds necessary Func Declaration for Imported VC-intrinsics
/// functions and sets linkage attributes to those declaration
/// to support SPIRV compilation
Expand All @@ -78,4 +84,10 @@ func::CallOp createFuncCall(PatternRewriter &rewriter, Location loc,
StringRef funcName, TypeRange resultType,
ValueRange operands, bool emitCInterface);

Value getOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
Type addrTy, Value offset, unsigned eTyBitWidth);

Value getVecOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
unsigned vecSize, Type addrTy, Value offset,
unsigned eTyBitWidth);
#endif // XEGPU_VC_UTILS_H
3 changes: 1 addition & 2 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,9 +727,8 @@ auto getElemBitWidth = [](TensorDescType tdescTy) -> unsigned {
};

auto isLowPrecision = [](TensorDescType tdescTy) -> bool {
// Note: Handling for sub 8bit types is unclear so report as false
auto width = getElemBitWidth(tdescTy);
return width < 32 && width >= 8;
return width < 32 && width >= 4;
};

auto getScaled1DTdesc =
Expand Down
65 changes: 37 additions & 28 deletions lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ static Value castValueTo(Value val, Type toType, Location loc,
return val;
}

#define muli(a, b) rewriter.createOrFold<arith::MulIOp>(loc, a, b)
#define addi(a, b) rewriter.createOrFold<arith::AddIOp>(loc, a, b)
#define subi(a, b) rewriter.createOrFold<arith::SubIOp>(loc, a, b)

// Given an n-dim memref, a tensor descriptor with tile rank of 2 defines a
// 2d memory region with respect to the two inner-most dimensions. Other
// outer dimensions affect the base address of the 2d plane. For 2d, we
Expand Down Expand Up @@ -135,8 +131,10 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,
// size
auto effectiveRank = strides.size();
int64_t ranksToAdjust = effectiveRank;
auto bytesPerElem =
op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth() / 8;
auto eTyBitWidth =
op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
auto bytesPerElem = eTyBitWidth / 8;
Value eTyBitWidthVal = index_val(eTyBitWidth);
Value bytesPerElemVal = index_val(bytesPerElem);

// We only need combine ranks that are larger than tileRank (e.g., if we the
Expand All @@ -147,15 +145,21 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,

auto computeBase = [&](Value base) {
for (auto i = 0; i < ranksToAdjust; i++) {
auto factor = muli(strides[i], bytesPerElemVal);
Value factor;
Value offsetVal;
if (eTyBitWidth < 8)
factor = muli(strides[i], eTyBitWidthVal);
else
factor = muli(strides[i], bytesPerElemVal);
if (offsets[i].is<Value>()) {
offsetVal = offsets[i].get<Value>();
} else {
offsetVal = index_val(
llvm::cast<IntegerAttr>(offsets[i].get<Attribute>()).getInt());
}
auto linearOffset = muli(offsetVal, factor);
if (eTyBitWidth < 8)
linearOffset = divi(linearOffset, index_val(8));
base = addi(base, linearOffset);
}

Expand All @@ -176,7 +180,7 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
auto tdescTy = op.getType();
auto scope = tdescTy.getMemorySpace();
auto rank = tdescTy.getRank();
auto elemBytes = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
auto eTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();

// SLM has to use 32-bit address, while ugm needs to use 64-bit address.
auto addrTy =
Expand Down Expand Up @@ -209,12 +213,13 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
auto payloadTy = VectorType::get(simd_lanes, addrTy);

// adjust base address to get absolute offset in unit of bytes.
// the computation is simply: base + linearOffset * elemBytes
// the computation is simply: base + linearOffset (in bytes)
Value offset =
getValueOrConstantOp(op.getMixedOffsets().back(), loc, rewriter);
offset = castValueTo(offset, addrTy, loc, rewriter);
Value factor = integer_val(elemBytes, addrTy);
auto payload = addi(base, muli(offset, factor));
Value numOffsetBytes =
getOffsetInUnitOfBytes(rewriter, loc, addrTy, offset, eTyBitWidth);
auto payload = addi(base, numOffsetBytes);

// convert the payload into vector type
payload = rewriter.create<vector::BroadcastOp>(loc, payloadTy, payload);
Expand Down Expand Up @@ -251,12 +256,15 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
if (v) {
auto value = ofr.get<Value>();
value = rewriter.create<arith::IndexCastUIOp>(loc, i32Ty, value);
if (mul > 1)
value = rewriter.create<arith::MulIOp>(loc, value, i32_val(mul));
if (mul > 8)
value =
rewriter.create<arith::MulIOp>(loc, value, i32_val(mul / 8));
else if (mul >= 4)
value = getOffsetInUnitOfBytes(rewriter, loc, i32Ty, value, mul);
return (!minus) ? value : subi(value, i32_val(minus));
} else {
int value = cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
return i32_val(value * mul - minus);
return i32_val(((value * mul) / 8) - minus);
}
};

Expand All @@ -267,8 +275,9 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
// is in rows.
auto matrixShape = op.getMixedSizes();
auto size = matrixShape.size();
auto surfaceW = encodeShapeAndOffset(matrixShape[size - 1], elemBytes, 1);
auto surfaceH = encodeShapeAndOffset(matrixShape[size - 2], 1, 1);
auto surfaceW =
encodeShapeAndOffset(matrixShape[size - 1], eTyBitWidth, 1);
auto surfaceH = encodeShapeAndOffset(matrixShape[size - 2], 8, 1);

// encode the pitch, which is in bytes minus 1
auto matrixStrides = op.getMixedStrides();
Expand All @@ -279,16 +288,16 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
assert(isOneOrUnknow(matrixStrides[size - 1]) &&
"Fast Changing Dimension can only have stride of 1.");
auto surfaceP =
encodeShapeAndOffset(matrixStrides[size - 2], elemBytes, 1);
encodeShapeAndOffset(matrixStrides[size - 2], eTyBitWidth, 1);

payload = rewriter.create<vector::InsertOp>(loc, surfaceW, payload, 2);
payload = rewriter.create<vector::InsertOp>(loc, surfaceH, payload, 3);
payload = rewriter.create<vector::InsertOp>(loc, surfaceP, payload, 4);

// encode the offset, they are in elements
auto offsets = op.getMixedOffsets();
auto offsetX = encodeShapeAndOffset(offsets[size - 1], 1, 0);
auto offsetY = encodeShapeAndOffset(offsets[size - 2], 1, 0);
auto offsetX = encodeShapeAndOffset(offsets[size - 1], 8, 0);
auto offsetY = encodeShapeAndOffset(offsets[size - 2], 8, 0);
payload = rewriter.create<vector::InsertOp>(loc, offsetX, payload, 5);
payload = rewriter.create<vector::InsertOp>(loc, offsetY, payload, 6);

Expand Down Expand Up @@ -337,11 +346,11 @@ class UpdateNDOffsetPattern : public OpConversionPattern<UpdateNdOffsetOp> {
}

// update offset from unit of elements to unit of bytes
auto elemBytes = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
auto factor = integer_val(elemBytes, addrTy);
auto offset = getValueOrConstantOp(offsets.back(), loc, rewriter);
offset = castValueTo(offset, addrTy, loc, rewriter);
offset = muli(offset, factor);
auto eTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
offset =
getOffsetInUnitOfBytes(rewriter, loc, addrTy, offset, eTyBitWidth);

// convert offset to vector type and update the payload
const int simd_lanes = 1;
Expand Down Expand Up @@ -409,10 +418,10 @@ class CreateDescPattern : public OpConversionPattern<CreateDescOp> {
auto payloadTy = vecTy(simd_lanes, addrTy);

// offset is represented in number of elements, need to scale it to bytes
auto elemBytes = elemTy.getIntOrFloatBitWidth() / 8;
auto factor = dense_vector_int_val(elemBytes, addrTy, simd_lanes);
auto eTyBitWidth = elemTy.getIntOrFloatBitWidth();
Value offsets = castValueTo(adaptor.getOffsets(), payloadTy, loc, rewriter);
offsets = muli(factor, offsets);
offsets = getVecOffsetInUnitOfBytes(rewriter, loc, simd_lanes, addrTy,
offsets, eTyBitWidth);

// create a payload with the base address broadcasted to all simd lanes
Value payload = rewriter.create<vector::BroadcastOp>(loc, payloadTy, base);
Expand Down Expand Up @@ -444,10 +453,10 @@ class UpdateOffsetOpPattern : public OpConversionPattern<UpdateOffsetOp> {
auto simd_lanes = tdescTy.getShape()[0];
auto payloadTy = VectorType::get(simd_lanes, addrTy);

auto elemBytes = elemTy.getIntOrFloatBitWidth() / 8;
Value factor = dense_vector_int_val(elemBytes, addrTy, simd_lanes);
auto eTyBitWidth = elemTy.getIntOrFloatBitWidth();
Value offsets = castValueTo(adaptor.getOffsets(), payloadTy, loc, rewriter);
offsets = muli(factor, offsets);
offsets = getVecOffsetInUnitOfBytes(rewriter, loc, simd_lanes, addrTy,
offsets, eTyBitWidth);

auto payload = addi(adaptor.getTensorDesc(), offsets);
rewriter.replaceOp(op, payload);
Expand Down
27 changes: 27 additions & 0 deletions lib/Utils/VCUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,30 @@ func::CallOp createFuncCall(PatternRewriter &rewriter, Location loc,
true /*isVectorComputeFunctionINTEL=true*/, emitCInterface);
return rewriter.create<func::CallOp>(loc, fn, resultType, operands);
}

Value getOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
Type addrTy, Value offset, unsigned eTyBitWidth) {
if (eTyBitWidth >= 8) {
unsigned eTyBytes = eTyBitWidth / 8;
Value factor = integer_val(eTyBytes, addrTy);
return muli(offset, factor);
} else {
Value eight = integer_val(8, addrTy);
Value bw = integer_val(eTyBitWidth, addrTy);
return divi(muli(offset, bw), eight);
}
}

Value getVecOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
unsigned vecSize, Type addrTy, Value offset,
unsigned eTyBitWidth) {
if (eTyBitWidth >= 8) {
unsigned eTyBytes = eTyBitWidth / 8;
Value factor = dense_vector_int_val(eTyBytes, addrTy, vecSize);
return muli(offset, factor);
} else {
Value eight = dense_vector_int_val(8, addrTy, vecSize);
Value bw = dense_vector_int_val(eTyBitWidth, addrTy, vecSize);
return divi(muli(offset, bw), eight);
}
}

0 comments on commit de3e322

Please sign in to comment.