diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f288c7fc2cb77b..b58a95c3baf70a 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -624,10 +624,9 @@ class ModifyOperationRewrite : public OperationRewrite { class ReplaceOperationRewrite : public OperationRewrite { public: ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Operation *op, const TypeConverter *converter, - bool changedResults) + Operation *op, const TypeConverter *converter) : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), - converter(converter), changedResults(changedResults) {} + converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::ReplaceOperation; @@ -641,15 +640,10 @@ class ReplaceOperationRewrite : public OperationRewrite { const TypeConverter *getConverter() const { return converter; } - bool hasChangedResults() const { return changedResults; } - private: /// An optional type converter that can be used to materialize conversions /// between the new and old values if necessary. const TypeConverter *converter; - - /// A boolean flag that indicates whether result types have changed or not. - bool changedResults; }; class CreateOperationRewrite : public OperationRewrite { @@ -941,6 +935,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// to modify/access them is invalid rewriter API usage. SetVector replacedOps; + /// A set of all unresolved materializations. + DenseSet unresolvedMaterializations; + /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -1066,6 +1063,7 @@ void UnresolvedMaterializationRewrite::rollback() { for (Value input : op->getOperands()) rewriterImpl.mapping.erase(input); } + rewriterImpl.unresolvedMaterializations.erase(op); op->erase(); } @@ -1347,6 +1345,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create(loc, outputType, inputs); + unresolvedMaterializations.insert(convertOp); appendRewrite(convertOp, converter, kind); return convertOp.getResult(0); } @@ -1379,22 +1378,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); - // Track if any of the results changed, e.g. erased and replaced with null. - bool resultChanged = false; - // Create mappings for each of the new result values. for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) { if (!newValue) { - resultChanged = true; - continue; + // This result was dropped and no replacement value was provided. + if (unresolvedMaterializations.contains(op)) { + // Do not create another materializations if we are erasing a + // materialization. + continue; + } + + // Materialize a replacement value "out of thin air". + newValue = buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(result), + result.getLoc(), /*inputs=*/ValueRange(), + /*outputType=*/result.getType(), currentTypeConverter); } + // Remap, and check for any result type changes. mapping.map(result, newValue); - resultChanged |= (newValue.getType() != result.getType()); } - appendRewrite(op, currentTypeConverter, - resultChanged); + appendRewrite(op, currentTypeConverter); // Mark this operation and all nested ops as replaced. op->walk([&](Operation *op) { replacedOps.insert(op); }); @@ -2359,11 +2364,6 @@ struct OperationConverter { ConversionPatternRewriterImpl &rewriterImpl, DenseMap> &inverseMapping); - /// Legalize an operation result that was marked as "erased". - LogicalResult - legalizeErasedResult(Operation *op, OpResult result, - ConversionPatternRewriterImpl &rewriterImpl); - /// Dialect conversion configuration. ConversionConfig config; @@ -2455,77 +2455,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, return failure(); } -/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results -/// are not used (transitively) by any op that is not in the given list of -/// cast ops. -/// -/// In particular, this function erases cyclic casts that may be inserted -/// during the dialect conversion process. E.g.: -/// %0 = unrealized_conversion_cast(%1) -/// %1 = unrealized_conversion_cast(%0) -// Note: This step will become unnecessary when -// https://github.com/llvm/llvm-project/pull/106760 has been merged. -static void eraseDeadUnrealizedCasts( - ArrayRef castOps, - SmallVectorImpl *remainingCastOps) { - // Ops that have already been visited or are currently being visited. - DenseSet visited; - // Set of all cast ops for faster lookups. - DenseSet castOpSet; - // Set of all cast ops that have been determined to be alive. - DenseSet live; - - for (UnrealizedConversionCastOp op : castOps) - castOpSet.insert(op); - - // Visit a cast operation. Return "true" if the operation is live. - std::function visit = [&](Operation *op) -> bool { - // No need to traverse any IR if the op was already marked as live. - if (live.contains(op)) - return true; - - // Do not visit ops multiple times. If we find a circle, no live user was - // found on the current path. - if (!visited.insert(op).second) - return false; - - // Visit all users. - for (Operation *user : op->getUsers()) { - // If the user is not an unrealized_conversion_cast op, then the given op - // is live. - if (!castOpSet.contains(user)) { - live.insert(op); - return true; - } - // Otherwise, it is live if a live op can be reached from one of its - // users (which must all be unrealized_conversion_cast ops). - if (visit(user)) { - live.insert(op); - return true; - } - } - - return false; - }; - - // Visit all cast ops. - for (UnrealizedConversionCastOp op : castOps) { - visit(op); - visited.clear(); - } - - // Erase all cast ops that are dead. - for (UnrealizedConversionCastOp op : castOps) { - if (live.contains(op)) { - if (remainingCastOps) - remainingCastOps->push_back(op); - continue; - } - op->dropAllUses(); - op->erase(); - } -} - LogicalResult OperationConverter::convertOperations(ArrayRef ops) { if (ops.empty()) return success(); @@ -2584,14 +2513,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Reconcile all UnrealizedConversionCastOps that were inserted by the // dialect conversion frameworks. (Not the one that were inserted by // patterns.) - SmallVector remainingCastOps1, remainingCastOps2; - eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1); - reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2); + SmallVector remainingCastOps; + reconcileUnrealizedCasts(allCastOps, &remainingCastOps); // Try to legalize all unresolved materializations. if (config.buildMaterializations) { IRRewriter rewriter(rewriterImpl.context, config.listener); - for (UnrealizedConversionCastOp castOp : remainingCastOps2) { + for (UnrealizedConversionCastOp castOp : remainingCastOps) { auto it = rewriteMap.find(castOp.getOperation()); assert(it != rewriteMap.end() && "inconsistent state"); if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) @@ -2646,30 +2574,22 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes( for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { auto *opReplacement = dyn_cast(rewriterImpl.rewrites[i].get()); - if (!opReplacement || !opReplacement->hasChangedResults()) + if (!opReplacement) continue; Operation *op = opReplacement->getOperation(); for (OpResult result : op->getResults()) { - Value newValue = rewriterImpl.mapping.lookupOrNull(result); - - // If the operation result was replaced with null, all of the uses of this - // value should be replaced. - if (!newValue) { - if (failed(legalizeErasedResult(op, result, rewriterImpl))) - return failure(); + // If the type of this op result changed and the result is still live, + // we need to materialize a conversion. + if (rewriterImpl.mapping.lookupOrNull(result, result.getType())) continue; - } - - // Otherwise, check to see if the type of the result changed. - if (result.getType() == newValue.getType()) - continue; - Operation *liveUser = findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); if (!liveUser) continue; // Legalize this result. + Value newValue = rewriterImpl.mapping.lookupOrNull(result); + assert(newValue && "replacement value not found"); Value castValue = rewriterImpl.buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), op->getLoc(), /*inputs=*/newValue, /*outputType=*/result.getType(), @@ -2727,25 +2647,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( return success(); } -LogicalResult OperationConverter::legalizeErasedResult( - Operation *op, OpResult result, - ConversionPatternRewriterImpl &rewriterImpl) { - // If the operation result was replaced with null, all of the uses of this - // value should be replaced. - auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - if (liveUserIt != result.user_end()) { - InFlightDiagnostic diag = op->emitError("failed to legalize operation '") - << op->getName() << "' marked as erased"; - diag.attachNote(liveUserIt->getLoc()) - << "found live user of result #" << result.getResultNumber() << ": " - << *liveUserIt; - return failure(); - } - return success(); -} - //===----------------------------------------------------------------------===// // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir index 49275e8008e749..6e8f0162e505d0 100644 --- a/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir +++ b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir @@ -3,8 +3,8 @@ // Test that an error is emitted when an operation is marked as "erased", but // has users that live across the conversion. func.func @remove_all_ops(%arg0: i32) -> i32 { - // expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}} + // expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}} %0 = "test.illegal_op_a"() : () -> i32 - // expected-note@below {{found live user of result #0: func.return %0 : i32}} + // expected-note@below {{see existing live user here}} return %0 : i32 }