diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index 91c7bf99e4..f55aa01e6e 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -203,31 +203,71 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
   SmallVector queue = {op->getResult(0)};
   SetVector forwardSlice;
   llvm::SmallDenseSet seen;
-  bool isMMAV3 = encoding.cast().getVersionMajor() == 3;
   while (!queue.empty()) {
     Value currentValue = queue.back();
     queue.pop_back();
     getForwardSlice(currentValue, &forwardSlice);
     for (Operation *op : forwardSlice) {
+      // HACK: Stop propagation if the ReduceOp is using mma layout but is
+      // producing tensor smaller than the layout we would like to propagate.
+      // This is to avoid stepping into the known bug.
+      if (isa(op)) {
+        auto tensorType =
+            op->getOperand(0).getType().dyn_cast();
+        if (tensorType &&
+            tensorType.getEncoding().isa()) {
+          auto mmaInstrShape =
+              encoding.cast().getInstrShape();
+          if (tensorType.getShape()[tensorType.getRank() - 2] <
+                  mmaInstrShape[0] ||
+              tensorType.getShape()[tensorType.getRank() - 1] <
+                  mmaInstrShape[1]) {
+            return false;
+          }
+        }
+      }
+
       if (auto convertOp = dyn_cast(op)) {
         Attribute dstEncoding = convertOp.getType().getEncoding();
         if (auto mmaLayout = dstEncoding.dyn_cast())
           return (mmaLayout.getVersionMajor() > 1) ? true
                                                    : mmaLayout == encoding;
-        if (dstEncoding.isa())
-          return encoding.cast().getVersionMajor() > 1;
+        if (dstEncoding.isa())
+          return true;
+        if (dstEncoding.isa()) {
+          if (auto mmaLayout = encoding.dyn_cast()) {
+            return mmaLayout.getVersionMajor() > 1;
+          } else {
+            assert(encoding.isa() ||
+                   encoding.isa());
+            return true;
+          }
+        }
       }
+      bool isMMAV3 =
+          encoding.isa() &&
+          encoding.cast().getVersionMajor() == 3;
       if (isMMAV3 && isa(op))
         return true;
       auto yield = dyn_cast(op);
       if (!yield)
         continue;
+      if (auto ifOp = dyn_cast(yield->getParentOp())) {
+        for (OpOperand &operand : yield->getOpOperands()) {
+          Operation *def = operand.get().getDefiningOp();
+          if (def &&
+              (forwardSlice.count(def) || operand.get() == currentValue) &&
+              (seen.insert(operand.get()).second == true))
+            queue.push_back(ifOp.getResult(operand.getOperandNumber()));
+        }
+      }
       auto forOp = dyn_cast(yield.getOperation()->getParentOp());
       if (!forOp)
         continue;
       for (OpOperand &operand : yield->getOpOperands()) {
         Operation *def = operand.get().getDefiningOp();
-        if (def && forwardSlice.count(def) &&
+        if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
             (seen.insert(operand.get()).second == true))
           queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
       }
@@ -257,12 +297,12 @@ bool isLayoutAnchor(Operation *op) {
 
 void LayoutPropagation::initAnchorLayout() {
   auto maybeAddAnchor = [&](Value v) {
-    if (auto tensorType = v.getType().dyn_cast()) {
+    if (auto tensorType = dyn_cast(v.getType())) {
       // Workaround, don't popagate MMA layout unless there is a convert
       // back to mma further down to avoid generating reduction with MMA
       // layout that may have lower performance.
       // This can be improved with more aggressive backward propagation.
-      if (tensorType.getEncoding().isa() &&
+      if (tensorType.getEncoding().isa() &&
           v.getDefiningOp() &&
           !hasConvertToMMATransisitiveUse(v.getDefiningOp(),
                                           tensorType.getEncoding())) {
@@ -292,7 +332,7 @@ void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info,
                                     SmallVector &changed,
                                     Operation *op) {
   for (Value value : values) {
-    if (!value.getType().isa())
+    if (!isa(value.getType()))
       continue;
     bool hasChanged = false;
     for (auto encoding : info.encodings) {
@@ -401,7 +441,7 @@ void LayoutPropagation::resolveConflicts() {
         op && isa(op);
     for (Attribute e : info.encodings) {
       if ((isLoadOrStore && e.isa()) ||
-          (!isLoadOrStore && e.isa())) {
+          (!isLoadOrStore && e.isa())) {
         encoding = e;
         break;
       }
@@ -431,7 +471,7 @@ void LayoutPropagation::rewrite() { rewriteRegion(funcOp->getRegion(0)); }
 bool reduceToScalar(Operation *op) {
   // For reductions returning a scalar we can change the src encoding without
   // affecting the output.
-  return isa(op) && !op->getResultTypes()[0].isa();
+  return isa(op) && !isa(op->getResultTypes()[0]);
 }
 
 void LayoutPropagation::rewriteRegion(Region &region) {
@@ -450,7 +490,7 @@ void LayoutPropagation::rewriteRegion(Region &region) {
       LayoutInfo &info = it->second;
       assert(info.encodings.size() == 1 &&
              "we should have resolved to a single encoding");
-      auto encoding = result.getType().cast().getEncoding();
+      auto encoding = cast(result.getType()).getEncoding();
       // If the encoding is already what we want skip.
       if (encoding == *info.encodings.begin())
         continue;
@@ -476,7 +516,7 @@ void LayoutPropagation::rewriteRegion(Region &region) {
         if (it == layouts.end())
           continue;
         Attribute encoding =
-            operand.get().getType().cast().getEncoding();
+            cast(operand.get().getType()).getEncoding();
         Value newOperand = getValueAs(operand.get(), encoding);
         op.setOperand(operand.getOperandNumber(), newOperand);
       }
@@ -490,12 +530,12 @@ void LayoutPropagation::rewriteRegion(Region &region) {
 }
 
 void LayoutPropagation::map(Value old, Value newV) {
-  rewriteMapping[{old, newV.getType().cast().getEncoding()}] =
+  rewriteMapping[{old, cast(newV.getType()).getEncoding()}] =
       newV;
 }
 
 Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
-  if (auto tensorType = value.getType().dyn_cast()) {
+  if (auto tensorType = dyn_cast(value.getType())) {
     Value rewrittenValue;
     auto layoutIt = layouts.find(value);
     if (layoutIt == layouts.end()) {
@@ -510,7 +550,7 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
       rewrittenValue = rewriteMapping[{value, encodingPicked}];
     }
     assert(rewrittenValue);
-    if (rewrittenValue.getType().cast().getEncoding() ==
+    if (cast(rewrittenValue.getType()).getEncoding() ==
        encoding)
      return rewrittenValue;
    OpBuilder rewriter(value.getContext());
@@ -542,7 +582,7 @@ Operation *LayoutPropagation::cloneElementwise(OpBuilder &rewriter,
  }
 
  for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
-    auto origType = op->getResult(i).getType().dyn_cast();
+    auto origType = dyn_cast(op->getResult(i).getType());
    if (!origType)
      continue;
    auto newType = RankedTensorType::get(origType.getShape(),
@@ -608,7 +648,7 @@ Operation *LayoutPropagation::rewriteWhileOp(scf::WhileOp whileOp) {
      returnTypes.push_back(ret.getType());
      continue;
    }
-    auto origType = ret.getType().dyn_cast();
+    auto origType = dyn_cast(ret.getType());
    auto newType =
        RankedTensorType::get(origType.getShape(), origType.getElementType(),
                              it->second.encodings[0]);
@@ -659,7 +699,7 @@ Operation *LayoutPropagation::rewriteIfOp(scf::IfOp ifOp) {
    auto it = layouts.find(ifOp->getResult(i));
    if (it == layouts.end())
      continue;
-    auto origType = ifOp->getResult(i).getType().cast();
+    auto origType = cast(ifOp->getResult(i).getType());
    Attribute encoding = *(it->second.encodings.begin());
    newResultTypes[i] = RankedTensorType::get(
        origType.getShape(), origType.getElementType(), encoding);
@@ -688,7 +728,7 @@ void LayoutPropagation::rewriteYieldOp(scf::YieldOp yieldOp) {
    if (auto whileOp = dyn_cast(parentOp))
      yieldType =
          whileOp.getBeforeArguments()[operand.getOperandNumber()].getType();
-    auto tensorType = yieldType.dyn_cast();
+    auto tensorType = dyn_cast(yieldType);
    if (!tensorType)
      continue;
    Value newOperand = getValueAs(operand.get(), tensorType.getEncoding());
@@ -701,7 +741,7 @@ void LayoutPropagation::rewriteConditionOp(scf::ConditionOp conditionOp) {
  for (unsigned i = 1; i < conditionOp->getNumOperands(); ++i) {
    OpOperand &operand = conditionOp->getOpOperand(i);
    Type argType = whileOp->getResult(operand.getOperandNumber() - 1).getType();
-    auto tensorType = argType.dyn_cast();
+    auto tensorType = dyn_cast(argType);
    if (!tensorType)
      continue;
    Value newOperand = getValueAs(operand.get(), tensorType.getEncoding());
@@ -757,7 +797,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
    if (it != layouts.end())
      srcEncoding = *(it->second.encodings.begin());
    Value src = getValueAs(convertOp.getSrc(), srcEncoding);
-    auto tensorType = op->getResult(0).getType().cast();
+    auto tensorType = cast(op->getResult(0).getType());
    auto newType = RankedTensorType::get(tensorType.getShape(),
                                         tensorType.getElementType(), encoding);
    auto cvt = rewriter.create(op->getLoc(), newType, src);
@@ -766,7 +806,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
  }
  if (canFoldIntoConversion(op, encoding)) {
    Operation *newOp = rewriter.clone(*op);
-    auto tensorType = op->getResult(0).getType().cast();
+    auto tensorType = cast(op->getResult(0).getType());
    auto newType = RankedTensorType::get(tensorType.getShape(),
                                         tensorType.getElementType(), encoding);
    auto cvt = rewriter.create(op->getLoc(), newType,
@@ -834,6 +874,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice,
                                           ConvertLayoutOp convertOp,
                                           IRMapping &mapping) {
  SetVector opsToRewrite;
+  // Keep track of yield operands that need to be duplicated.
+  DenseMap> yieldOperandsMap;
  for (Value v : slice) {
    auto layoutIt = layout.find(v);
    assert(layoutIt != layout.end());
@@ -845,13 +887,22 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice,
    if (v.getDefiningOp()) {
      opsToRewrite.insert(v.getDefiningOp());
      if (auto ifOp = v.getDefiningOp()) {
+        unsigned operandIdx = v.cast().getResultNumber();
        opsToRewrite.insert(ifOp.thenYield().getOperation());
+        yieldOperandsMap[ifOp.thenYield()].push_back(operandIdx);
        opsToRewrite.insert(ifOp.elseYield().getOperation());
+        yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx);
      }
    } else {
-      opsToRewrite.insert(v.cast().getOwner()->getParentOp());
-      // We also need to rewrite the yield op.
-      opsToRewrite.insert(v.cast().getOwner()->getTerminator());
+      BlockArgument blockArg = v.cast();
+      Operation *parentOp = blockArg.getOwner()->getParentOp();
+      if (auto loopOp = cast(parentOp)) {
+        opsToRewrite.insert(loopOp.getOperation());
+        OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg);
+        auto yieldOp = blockArg.getOwner()->getTerminator();
+        yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber());
+        opsToRewrite.insert(yieldOp);
+      }
    }
  }
  opsToRewrite = multiRootTopologicalSort(opsToRewrite);
@@ -895,6 +946,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice,
        Value oldArg = loopBody.getArgument(m.first + numIndVars);
        addRematValue(newForOp.getResult(m.first), layout[oldArg],
                      newForOp.getResult(m.second));
+        addRematValue(oldArg, layout[oldArg],
+                      loopBody.getArgument(m.second + numIndVars));
      }
      continue;
    }
@@ -905,7 +958,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice,
      auto it = layout.find(res);
      assert(it != layout.end());
-      auto oldType = res.getType().cast();
+      auto oldType = cast(res.getType());
      auto newType = RankedTensorType::get(
          oldType.getShape(), oldType.getElementType(), it->second);
      newTypes.push_back(newType);
@@ -931,10 +984,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice,
    builder.setInsertionPoint(op);
    if (auto yieldOp = dyn_cast(op)) {
      auto yieldOperands = llvm::to_vector(yieldOp.getOperands());
-      for (Value operand : yieldOp.getOperands()) {
-        if (slice.count(operand) == 0)
-          continue;
-        yieldOperands.push_back(mapping.lookup(operand));
+      SmallVector operandsToRewrite = yieldOperandsMap[op];
+      // Sort so that operands are added in the same order as the new scf
+      // results/arguments.
+      std::sort(operandsToRewrite.begin(), operandsToRewrite.end());
+      for (int operandIdx : operandsToRewrite) {
+        yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx)));
      }
      builder.create(op->getLoc(), yieldOperands);
      op->erase();
@@ -942,7 +997,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice,
    }
    if (isa(op)) {
      Operation *newOp = builder.clone(*op);
-      auto tensorType = op->getResult(0).getType().cast();
+      auto tensorType = cast(op->getResult(0).getType());
      auto newType = RankedTensorType::get(tensorType.getShape(),
                                           tensorType.getElementType(),
                                           layout[op->getResult(0)]);
@@ -961,17 +1016,16 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice,
      Type oldType = old.getType();
      Type newType;
      if (isTensorPointerType(oldType)) {
-        auto ptrType = oldType.cast();
-        auto tensorType = ptrType.getPointeeType().cast();
+        auto ptrType = cast(oldType);
+        auto tensorType = cast(ptrType.getPointeeType());
        newType = triton::PointerType::get(
            RankedTensorType::get(tensorType.getShape(),
                                  tensorType.getElementType(), it->second),
            ptrType.getAddressSpace());
      } else {
        newType = RankedTensorType::get(
-            old.getType().cast().getShape(),
-            old.getType().cast().getElementType(),
-            it->second);
+            cast(old.getType()).getShape(),
+            cast(old.getType()).getElementType(), it->second);
      }
      newV.setType(newType);
      addRematValue(old, it->second, newV);
@@ -1137,7 +1191,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
  // Move the convert before the ext op and rewrite the slice.
  OpBuilder builder(extOrBroadcatOp);
  auto tensorType =
-      extOrBroadcatOp->getOperand(0).getType().cast();
+      cast(extOrBroadcatOp->getOperand(0).getType());
  auto newType = RankedTensorType::get(
      tensorType.getShape(), tensorType.getElementType(), *srcEncoding);
  auto newConvertOp = builder.create(
@@ -1145,7 +1199,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
  Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp);
  newExtOrBroadcast->setOperand(0, newConvertOp.getResult());
  auto oldExtOrBroadcastType =
-      extOrBroadcatOp->getResult(0).getType().cast();
+      cast(extOrBroadcatOp->getResult(0).getType());
  Type newExtOrBroadcasrType = RankedTensorType::get(
      oldExtOrBroadcastType.getShape(), oldExtOrBroadcastType.getElementType(),
      dstEncoding);
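
Note (illustrative, not part of the patch): the mechanical `.cast()`/`.dyn_cast()`/`.isa()` rewrites scattered through the hunks above follow the MLIR migration from the deprecated member-function casts to the free-function forms. The template arguments were lost in this rendering of the diff; the sketch below only shows the shape of the change on a `RankedTensorType`, and `pickEncoding` is a made-up helper, not a function from this pass.

// Minimal sketch, assuming an MLIR build with the usual headers.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static Attribute pickEncoding(Type type) {
  // Old, deprecated style:
  //   if (auto tensorType = type.dyn_cast<RankedTensorType>())
  //     return tensorType.getEncoding();
  // New style, as used throughout the diff above:
  if (auto tensorType = dyn_cast<RankedTensorType>(type))
    return tensorType.getEncoding();
  return {}; // null Attribute when the type is not a ranked tensor
}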