[RemoveLayoutConversions] Sync from upstream (#1037)
Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang authored May 3, 2024
1 parent 82240cd commit 7c92443
Showing 1 changed file with 91 additions and 37 deletions.
@@ -203,31 +203,71 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
SmallVector<Value> queue = {op->getResult(0)};
SetVector<Operation *> forwardSlice;
llvm::SmallDenseSet<Value> seen;
-bool isMMAV3 = encoding.cast<NvidiaMmaEncodingAttr>().getVersionMajor() == 3;
while (!queue.empty()) {
Value currentValue = queue.back();
queue.pop_back();
getForwardSlice(currentValue, &forwardSlice);
for (Operation *op : forwardSlice) {
+// HACK: Stop propagation if the ReduceOp is using mma layout but is
+// producing tensor smaller than the layout we would like to propagate.
+// This is to avoid stepping into the known bug.
+if (isa<mlir::triton::ReduceOp>(op)) {
+auto tensorType =
+op->getOperand(0).getType().dyn_cast<RankedTensorType>();
+if (tensorType &&
+tensorType.getEncoding().isa<NvidiaMmaEncodingAttr>()) {
+auto mmaInstrShape =
+encoding.cast<NvidiaMmaEncodingAttr>().getInstrShape();
+if (tensorType.getShape()[tensorType.getRank() - 2] <
+mmaInstrShape[0] ||
+tensorType.getShape()[tensorType.getRank() - 1] <
+mmaInstrShape[1]) {
+return false;
+}
+}
+}
+
if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
Attribute dstEncoding = convertOp.getType().getEncoding();
if (auto mmaLayout = dstEncoding.dyn_cast<NvidiaMmaEncodingAttr>())
return (mmaLayout.getVersionMajor() > 1) ? true
: mmaLayout == encoding;
-if (dstEncoding.isa<DotOperandEncodingAttr>())
-return encoding.cast<NvidiaMmaEncodingAttr>().getVersionMajor() > 1;
+if (dstEncoding.isa<triton::gpu::AMDMfmaEncodingAttr,
+triton::gpu::AMDWmmaEncodingAttr>())
+return true;
+if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
+if (auto mmaLayout = encoding.dyn_cast<NvidiaMmaEncodingAttr>()) {
+return mmaLayout.getVersionMajor() > 1;
+} else {
+assert(encoding.isa<triton::gpu::AMDMfmaEncodingAttr>() ||
+encoding.isa<triton::gpu::AMDWmmaEncodingAttr>());
+return true;
+}
+}
}
+bool isMMAV3 =
+encoding.isa<NvidiaMmaEncodingAttr>() &&
+encoding.cast<NvidiaMmaEncodingAttr>().getVersionMajor() == 3;
if (isMMAV3 && isa<LocalAllocOp>(op))
return true;
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)
continue;
+if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp())) {
+for (OpOperand &operand : yield->getOpOperands()) {
+Operation *def = operand.get().getDefiningOp();
+if (def &&
+(forwardSlice.count(def) || operand.get() == currentValue) &&
+(seen.insert(operand.get()).second == true))
+queue.push_back(ifOp.getResult(operand.getOperandNumber()));
+}
+}
auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
if (!forOp)
continue;
for (OpOperand &operand : yield->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
-if (def && forwardSlice.count(def) &&
+if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
(seen.insert(operand.get()).second == true))
queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
}
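
Note (context, not part of the diff): the traversal above relies on getForwardSlice, which does not follow values across region boundaries, so the scf.yield handling re-seeds the worklist with the matching scf.if result or scf.for region iter-arg. A minimal sketch of that worklist idiom, with collectTransitiveUsers as a hypothetical helper name:

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Collect every op transitively reachable from `root`, following values
// that escape scf.if / scf.for regions through their yields.
static void collectTransitiveUsers(Value root, SetVector<Operation *> &users) {
  SmallVector<Value> queue = {root};
  llvm::SmallDenseSet<Value> seen;
  while (!queue.empty()) {
    Value current = queue.pop_back_val();
    getForwardSlice(current, &users);
    for (Operation *user : users) {
      auto yield = dyn_cast<scf::YieldOp>(user);
      if (!yield)
        continue;
      for (OpOperand &operand : yield->getOpOperands()) {
        Value v = operand.get();
        Operation *def = v.getDefiningOp();
        // Only follow yield operands that are actually part of the slice.
        if (!(v == current || (def && users.count(def))))
          continue;
        if (!seen.insert(v).second)
          continue;
        unsigned idx = operand.getOperandNumber();
        // Continue the walk from the value visible outside the region.
        if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp()))
          queue.push_back(ifOp->getResult(idx));
        else if (auto forOp = dyn_cast<scf::ForOp>(yield->getParentOp()))
          queue.push_back(forOp.getRegionIterArg(idx));
      }
    }
  }
}
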
@@ -257,12 +297,12 @@ bool isLayoutAnchor(Operation *op) {

void LayoutPropagation::initAnchorLayout() {
auto maybeAddAnchor = [&](Value v) {
-if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
+if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
// Workaround, don't propagate MMA layout unless there is a convert
// back to mma further down to avoid generating reduction with MMA
// layout that may have lower performance.
// This can be improved with more aggressive backward propagation.
-if (tensorType.getEncoding().isa<NvidiaMmaEncodingAttr>() &&
+if (tensorType.getEncoding().isa<MmaEncodingTrait>() &&
v.getDefiningOp() &&
!hasConvertToMMATransisitiveUse(v.getDefiningOp(),
tensorType.getEncoding())) {
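
Note on the change above (context, not part of the diff): switching the anchor check from NvidiaMmaEncodingAttr to the MmaEncodingTrait interface lets the same guard cover the AMD MFMA/WMMA encodings as well. A hedged sketch of what the interface check is shorthand for, assuming the usual TritonGPU using-declarations:

// One interface check instead of enumerating every MMA-style encoding.
static bool isMmaFamilyEncoding(Attribute enc) {
  // Roughly equivalent to
  //   isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(enc)
  // but also matches any future encoding that implements the trait.
  return enc && isa<MmaEncodingTrait>(enc);
}
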
@@ -292,7 +332,7 @@ void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info,
SmallVector<Value> &changed,
Operation *op) {
for (Value value : values) {
-if (!value.getType().isa<RankedTensorType>())
+if (!isa<RankedTensorType>(value.getType()))
continue;
bool hasChanged = false;
for (auto encoding : info.encodings) {
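
Most of the remaining one-line changes in this sync are the MLIR cast migration: the member-style Type/Attribute casts are deprecated in favor of the free isa/cast/dyn_cast helpers. A small illustrative sketch (not from the commit, assuming the usual includes):

static bool hasBlockedEncoding(Value v) {
  // Old, deprecated style:
  //   auto tensorType = v.getType().dyn_cast<RankedTensorType>();
  // New style used throughout this patch:
  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
  return tensorType && tensorType.getEncoding() &&
         isa<triton::gpu::BlockedEncodingAttr>(tensorType.getEncoding());
}
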
@@ -401,7 +441,7 @@ void LayoutPropagation::resolveConflicts() {
op && isa<LoadOp, StoreOp, AtomicRMWOp, AtomicCASOp>(op);
for (Attribute e : info.encodings) {
if ((isLoadOrStore && e.isa<BlockedEncodingAttr>()) ||
-(!isLoadOrStore && e.isa<NvidiaMmaEncodingAttr>())) {
+(!isLoadOrStore && e.isa<MmaEncodingTrait>())) {
encoding = e;
break;
}
@@ -431,7 +471,7 @@ void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); }
bool reduceToScalar(Operation *op) {
// For reductions returning a scalar we can change the src encoding without
// affecting the output.
-return isa<ReduceOp>(op) && !op->getResultTypes()[0].isa<RankedTensorType>();
+return isa<ReduceOp>(op) && !isa<RankedTensorType>(op->getResultTypes()[0]);
}

void LayoutPropagation::rewriteRegion(Region &region) {
@@ -450,7 +490,7 @@ void LayoutPropagation::rewriteRegion(Region &region) {
LayoutInfo &info = it->second;
assert(info.encodings.size() == 1 &&
"we should have resolved to a single encoding");
-auto encoding = result.getType().cast<RankedTensorType>().getEncoding();
+auto encoding = cast<RankedTensorType>(result.getType()).getEncoding();
// If the encoding is already what we want skip.
if (encoding == *info.encodings.begin())
continue;
@@ -476,7 +516,7 @@ void LayoutPropagation::rewriteRegion(Region &region) {
if (it == layouts.end())
continue;
Attribute encoding =
-operand.get().getType().cast<RankedTensorType>().getEncoding();
+cast<RankedTensorType>(operand.get().getType()).getEncoding();
Value newOperand = getValueAs(operand.get(), encoding);
op.setOperand(operand.getOperandNumber(), newOperand);
}
@@ -490,12 +530,12 @@ void LayoutPropagation::rewriteRegion(Region &region) {
}

void LayoutPropagation::map(Value old, Value newV) {
-rewriteMapping[{old, newV.getType().cast<RankedTensorType>().getEncoding()}] =
+rewriteMapping[{old, cast<RankedTensorType>(newV.getType()).getEncoding()}] =
newV;
}

Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
-if (auto tensorType = value.getType().dyn_cast<RankedTensorType>()) {
+if (auto tensorType = dyn_cast<RankedTensorType>(value.getType())) {
Value rewrittenValue;
auto layoutIt = layouts.find(value);
if (layoutIt == layouts.end()) {
@@ -510,7 +550,7 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
rewrittenValue = rewriteMapping[{value, encodingPicked}];
}
assert(rewrittenValue);
-if (rewrittenValue.getType().cast<RankedTensorType>().getEncoding() ==
+if (cast<RankedTensorType>(rewrittenValue.getType()).getEncoding() ==
encoding)
return rewrittenValue;
OpBuilder rewriter(value.getContext());
@@ -542,7 +582,7 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
}

for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
-auto origType = op->getResult(i).getType().dyn_cast<RankedTensorType>();
+auto origType = dyn_cast<RankedTensorType>(op->getResult(i).getType());
if (!origType)
continue;
auto newType = RankedTensorType::get(origType.getShape(),
@@ -608,7 +648,7 @@ Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
returnTypes.push_back(ret.getType());
continue;
}
-auto origType = ret.getType().dyn_cast<RankedTensorType>();
+auto origType = dyn_cast<RankedTensorType>(ret.getType());
auto newType =
RankedTensorType::get(origType.getShape(), origType.getElementType(),
it->second.encodings[0]);
@@ -659,7 +699,7 @@ Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) {
auto it = layouts.find(ifOp->getResult(i));
if (it == layouts.end())
continue;
-auto origType = ifOp->getResult(i).getType().cast<RankedTensorType>();
+auto origType = cast<RankedTensorType>(ifOp->getResult(i).getType());
Attribute encoding = *(it->second.encodings.begin());
newResultTypes[i] = RankedTensorType::get(
origType.getShape(), origType.getElementType(), encoding);
@@ -688,7 +728,7 @@ void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) {
if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp))
yieldType =
whileOp.getBeforeArguments()[operand.getOperandNumber()].getType();
-auto tensorType = yieldType.dyn_cast<RankedTensorType>();
+auto tensorType = dyn_cast<RankedTensorType>(yieldType);
if (!tensorType)
continue;
Value newOperand = getValueAs(operand.get(), tensorType.getEncoding());
@@ -701,7 +741,7 @@ void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) {
for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) {
OpOperand &operand = conditionOp->getOpOperand(i);
Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType();
-auto tensorType = argType.dyn_cast<RankedTensorType>();
+auto tensorType = dyn_cast<RankedTensorType>(argType);
if (!tensorType)
continue;
Value newOperand = getValueAs(operand.get(), tensorType.getEncoding());
@@ -757,7 +797,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
if (it != layouts.end())
srcEncoding = *(it->second.encodings.begin());
Value src = getValueAs(convertOp.getSrc(), srcEncoding);
-auto tensorType = op->getResult(0).getType().cast<RankedTensorType>();
+auto tensorType = cast<RankedTensorType>(op->getResult(0).getType());
auto newType = RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
auto cvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), newType, src);
@@ -766,7 +806,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
}
if (canFoldIntoConversion(op, encoding)) {
Operation *newOp = rewriter.clone(*op);
-auto tensorType = op->getResult(0).getType().cast<RankedTensorType>();
+auto tensorType = cast<RankedTensorType>(op->getResult(0).getType());
auto newType = RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
auto cvt = rewriter.create<ConvertLayoutOp>(op->getLoc(), newType,
@@ -834,6 +874,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
ConvertLayoutOp convertOp,
IRMapping &mapping) {
SetVector<Operation *> opsToRewrite;
+// Keep track of yield operands that need to be duplicated.
+DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
for (Value v : slice) {
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
@@ -845,13 +887,22 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
if (v.getDefiningOp()) {
opsToRewrite.insert(v.getDefiningOp());
if (auto ifOp = v.getDefiningOp<scf::IfOp>()) {
+unsigned operandIdx = v.cast<OpResult>().getResultNumber();
opsToRewrite.insert(ifOp.thenYield().getOperation());
+yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx);
opsToRewrite.insert(ifOp.elseYield().getOperation());
+yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx);
}
} else {
-opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getParentOp());
-// We also need to rewrite the yield op.
-opsToRewrite.insert(v.cast<BlockArgument>().getOwner()->getTerminator());
+BlockArgument blockArg = v.cast<BlockArgument>();
+Operation *parentOp = blockArg.getOwner()->getParentOp();
+if (auto loopOp = cast<LoopLikeOpInterface>(parentOp)) {
+opsToRewrite.insert(loopOp.getOperation());
+OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg);
+auto yieldOp = blockArg.getOwner()->getTerminator();
+yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber());
+opsToRewrite.insert(yieldOp);
+}
}
}
opsToRewrite = multiRootTopologicalSort(opsToRewrite);
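
For context (not part of the diff): the new yieldOperandsMap records, per terminator, which operand indices feed values in the slice — the result number for scf.if results, and for loop region iter-args the tied yield operand recovered through LoopLikeOpInterface::getTiedLoopYieldedValue. A minimal sketch of the block-argument case, with recordYieldOperand as a hypothetical helper:

#include "mlir/Interfaces/LoopLikeInterface.h"

// Map a region iter-arg back to the terminator operand that feeds it.
static void recordYieldOperand(
    mlir::BlockArgument arg,
    llvm::DenseMap<mlir::Operation *, llvm::SmallVector<int>> &yieldOperandsMap) {
  auto loopOp =
      mlir::dyn_cast<mlir::LoopLikeOpInterface>(arg.getOwner()->getParentOp());
  if (!loopOp)
    return;
  if (mlir::OpOperand *yielded = loopOp.getTiedLoopYieldedValue(arg)) {
    mlir::Operation *yieldOp = arg.getOwner()->getTerminator();
    yieldOperandsMap[yieldOp].push_back(yielded->getOperandNumber());
  }
}
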
@@ -895,6 +946,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
Value oldArg = loopBody.getArgument(m.first + numIndVars);
addRematValue(newForOp.getResult(m.first), layout[oldArg],
newForOp.getResult(m.second));
+addRematValue(oldArg, layout[oldArg],
+loopBody.getArgument(m.second + numIndVars));
}
continue;
}
@@ -905,7 +958,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
auto it = layout.find(res);
assert(it != layout.end());

-auto oldType = res.getType().cast<RankedTensorType>();
+auto oldType = cast<RankedTensorType>(res.getType());
auto newType = RankedTensorType::get(
oldType.getShape(), oldType.getElementType(), it->second);
newTypes.push_back(newType);
@@ -931,18 +984,20 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
builder.setInsertionPoint(op);
if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
auto yieldOperands = llvm::to_vector(yieldOp.getOperands());
-for (Value operand : yieldOp.getOperands()) {
-if (slice.count(operand) == 0)
-continue;
-yieldOperands.push_back(mapping.lookup(operand));
+SmallVector<int> operandsToRewrite = yieldOperandsMap[op];
+// Sort so that operands are added in the same order as the new scf
+// results/arguments.
+std::sort(operandsToRewrite.begin(), operandsToRewrite.end());
+for (int operandIdx : operandsToRewrite) {
+yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx)));
}
builder.create<scf::YieldOp>(op->getLoc(), yieldOperands);
op->erase();
continue;
}
if (isa<arith::ConstantOp>(op)) {
Operation *newOp = builder.clone(*op);
-auto tensorType = op->getResult(0).getType().cast<RankedTensorType>();
+auto tensorType = cast<RankedTensorType>(op->getResult(0).getType());
auto newType = RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(),
layout[op->getResult(0)]);
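
The sorted duplication above keeps the extra yield operands positionally aligned with the extra results/iter-args added when the enclosing scf op was rewritten. A standalone sketch of that rebuild step (illustrative names, assuming the usual MLIR includes):

// Rebuild the yield operand list: original operands first, then the
// remapped duplicates in ascending operand index.
static SmallVector<Value> rebuildYieldOperands(scf::YieldOp yieldOp,
                                               ArrayRef<int> toDuplicate,
                                               IRMapping &mapping) {
  SmallVector<Value> newOperands(yieldOp.getOperands().begin(),
                                 yieldOp.getOperands().end());
  SmallVector<int> sorted(toDuplicate.begin(), toDuplicate.end());
  llvm::sort(sorted);
  for (int idx : sorted)
    newOperands.push_back(mapping.lookup(yieldOp.getOperand(idx)));
  return newOperands;
}
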
@@ -961,17 +1016,16 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
Type oldType = old.getType();
Type newType;
if (isTensorPointerType(oldType)) {
-auto ptrType = oldType.cast<PointerType>();
-auto tensorType = ptrType.getPointeeType().cast<RankedTensorType>();
+auto ptrType = cast<PointerType>(oldType);
+auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
newType = triton::PointerType::get(
RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), it->second),
ptrType.getAddressSpace());
} else {
newType = RankedTensorType::get(
-old.getType().cast<RankedTensorType>().getShape(),
-old.getType().cast<RankedTensorType>().getElementType(),
-it->second);
+cast<RankedTensorType>(old.getType()).getShape(),
+cast<RankedTensorType>(old.getType()).getElementType(), it->second);
}
newV.setType(newType);
addRematValue(old, it->second, newV);
@@ -1137,15 +1191,15 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
// Move the convert before the ext op and rewrite the slice.
OpBuilder builder(extOrBroadcatOp);
auto tensorType =
-extOrBroadcatOp->getOperand(0).getType().cast<RankedTensorType>();
+cast<RankedTensorType>(extOrBroadcatOp->getOperand(0).getType());
auto newType = RankedTensorType::get(
tensorType.getShape(), tensorType.getElementType(), *srcEncoding);
auto newConvertOp = builder.create<ConvertLayoutOp>(
convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp);
newExtOrBroadcast->setOperand(0, newConvertOp.getResult());
auto oldExtOrBroadcastType =
-extOrBroadcatOp->getResult(0).getType().cast<RankedTensorType>();
+cast<RankedTensorType>(extOrBroadcatOp->getResult(0).getType());
Type newExtOrBroadcasrType = RankedTensorType::get(
oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(),
dstEncoding);
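
For context (not part of the diff): the hunk above is the tail of the hoisting rewrite — the layout conversion is moved in front of the ext/broadcast so it runs on the smaller (pre-broadcast) or narrower (pre-extension) tensor, and the cloned ext/broadcast then produces the target layout directly. A simplified sketch of the same idea, with illustrative names and assuming the usual TritonGPU using-declarations:

// Hoist convert_layout above an elementwise ext or a broadcast so the
// conversion applies to the cheaper operand tensor.
static void hoistConvertAbove(OpBuilder &builder, Operation *extOrBroadcast,
                              ConvertLayoutOp convertOp, Attribute srcEncoding,
                              Attribute dstEncoding) {
  builder.setInsertionPoint(extOrBroadcast);
  auto srcType =
      cast<RankedTensorType>(extOrBroadcast->getOperand(0).getType());
  // Convert the (smaller) source tensor to the inferred source encoding.
  auto newCvt = builder.create<ConvertLayoutOp>(
      convertOp.getLoc(),
      RankedTensorType::get(srcType.getShape(), srcType.getElementType(),
                            srcEncoding),
      extOrBroadcast->getOperand(0));
  // Clone the ext/broadcast so it now produces the destination layout.
  Operation *newOp = builder.clone(*extOrBroadcast);
  newOp->setOperand(0, newCvt.getResult());
  auto oldResType =
      cast<RankedTensorType>(extOrBroadcast->getResult(0).getType());
  newOp->getResult(0).setType(RankedTensorType::get(
      oldResType.getShape(), oldResType.getElementType(), dstEncoding));
}
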
