diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index 473a526142f..2236c79ac19 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -77,8 +77,14 @@ struct PrintActivityAnalysisPass void inferArgActivitiesFromEnzymeAutodiff( FunctionOpInterface callee, CallOpInterface autodiff_call, MutableArrayRef argActivities, - MutableArrayRef resultActivities) { + MutableArrayRef resultActivities, + bool foundEnzymeReverse = false) { unsigned argIdx = 1; + // __enzyme_reverse additionally takes in two extra parameters relating to + // the tape + if (foundEnzymeReverse) + argIdx += 2; + for (const auto &[paramIdx, paramType] : llvm::enumerate(callee.getArgumentTypes())) { Value arg = autodiff_call.getArgOperands()[argIdx]; @@ -205,6 +211,35 @@ struct PrintActivityAnalysisPass } } + FailureOr lookupAutodiffOrReverseCall(ModuleOp moduleOp, + bool *foundEnzymeReverse) { + auto tryToLookup = [&moduleOp](StringRef name) -> Operation * { + Operation *autodiff_decl = moduleOp.lookupSymbol(name); + if (!autodiff_decl) { + for (auto &subOp : *moduleOp.getBody()) { + if (auto func = dyn_cast(&subOp)) { + if (func.getName().contains(name)) { + return &subOp; + } + } + } + } + return autodiff_decl; + }; + + Operation *autodiff_decl = tryToLookup("__enzyme_autodiff"); + if (!autodiff_decl) { + autodiff_decl = tryToLookup("__enzyme_reverse"); + *foundEnzymeReverse = autodiff_decl != nullptr; + } + + if (!autodiff_decl) { + llvm::errs() << "Failed to find __enzyme_autodiff symbol"; + return failure(); + } + return autodiff_decl; + } + void runOnOperation() override { enzyme::ActivityPrinterConfig config; config.dataflow = dataflow; @@ -217,22 +252,13 @@ struct PrintActivityAnalysisPass if (inferFromAutodiff) { // Infer the activity attributes from the __enzyme_autodiff call - Operation *autodiff_decl = moduleOp.lookupSymbol("__enzyme_autodiff"); - if (!autodiff_decl) { - for (auto &subOp : *moduleOp.getBody()) { - if (auto func = dyn_cast(&subOp)) { - if (func.getName().contains("__enzyme_autodiff")) { - autodiff_decl = &subOp; - break; - } - } - } - } - if (!autodiff_decl) { - moduleOp.emitError("Failed to find __enzyme_autodiff symbol"); + bool foundEnzymeReverse = false; + auto autodiff_decl = + lookupAutodiffOrReverseCall(moduleOp, &foundEnzymeReverse); + if (failed(autodiff_decl)) { return signalPassFailure(); } - auto uses = SymbolTable::getSymbolUses(autodiff_decl, moduleOp); + auto uses = SymbolTable::getSymbolUses(*autodiff_decl, moduleOp); assert(uses && "failed to find symbol uses of autodiff decl"); for (SymbolTable::SymbolUse use : *uses) { @@ -249,7 +275,8 @@ struct PrintActivityAnalysisPass // Populate the argument activities based on either the type or the // supplied annotation. First argument is the callee inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call, - argActivities, resultActivities); + argActivities, resultActivities, + foundEnzymeReverse); runActivityAnalysis(config, callee, argActivities, resultActivities); } return;