Skip to content

Commit

Permalink
Support __enzyme_reverse when inferring arg activities
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Nov 12, 2024
1 parent dd955e6 commit 20cb61f
Showing 1 changed file with 43 additions and 16 deletions.
59 changes: 43 additions & 16 deletions enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,14 @@ struct PrintActivityAnalysisPass
void inferArgActivitiesFromEnzymeAutodiff(
FunctionOpInterface callee, CallOpInterface autodiff_call,
MutableArrayRef<enzyme::Activity> argActivities,
MutableArrayRef<enzyme::Activity> resultActivities) {
MutableArrayRef<enzyme::Activity> resultActivities,
bool foundEnzymeReverse = false) {
unsigned argIdx = 1;
// __enzyme_reverse additionally takes in two extra parameters relating to
// the tape
if (foundEnzymeReverse)
argIdx += 2;

for (const auto &[paramIdx, paramType] :
llvm::enumerate(callee.getArgumentTypes())) {
Value arg = autodiff_call.getArgOperands()[argIdx];
Expand Down Expand Up @@ -205,6 +211,35 @@ struct PrintActivityAnalysisPass
}
}

FailureOr<Operation *> lookupAutodiffOrReverseCall(ModuleOp moduleOp,
bool *foundEnzymeReverse) {
auto tryToLookup = [&moduleOp](StringRef name) -> Operation * {
Operation *autodiff_decl = moduleOp.lookupSymbol(name);
if (!autodiff_decl) {
for (auto &subOp : *moduleOp.getBody()) {
if (auto func = dyn_cast<FunctionOpInterface>(&subOp)) {
if (func.getName().contains(name)) {
return &subOp;
}
}
}
}
return autodiff_decl;
};

Operation *autodiff_decl = tryToLookup("__enzyme_autodiff");
if (!autodiff_decl) {
autodiff_decl = tryToLookup("__enzyme_reverse");
*foundEnzymeReverse = autodiff_decl != nullptr;
}

if (!autodiff_decl) {
llvm::errs() << "Failed to find __enzyme_autodiff symbol";
return failure();
}
return autodiff_decl;
}

void runOnOperation() override {
enzyme::ActivityPrinterConfig config;
config.dataflow = dataflow;
Expand All @@ -217,22 +252,13 @@ struct PrintActivityAnalysisPass

if (inferFromAutodiff) {
// Infer the activity attributes from the __enzyme_autodiff call
Operation *autodiff_decl = moduleOp.lookupSymbol("__enzyme_autodiff");
if (!autodiff_decl) {
for (auto &subOp : *moduleOp.getBody()) {
if (auto func = dyn_cast<FunctionOpInterface>(&subOp)) {
if (func.getName().contains("__enzyme_autodiff")) {
autodiff_decl = &subOp;
break;
}
}
}
}
if (!autodiff_decl) {
moduleOp.emitError("Failed to find __enzyme_autodiff symbol");
bool foundEnzymeReverse = false;
auto autodiff_decl =
lookupAutodiffOrReverseCall(moduleOp, &foundEnzymeReverse);
if (failed(autodiff_decl)) {
return signalPassFailure();
}
auto uses = SymbolTable::getSymbolUses(autodiff_decl, moduleOp);
auto uses = SymbolTable::getSymbolUses(*autodiff_decl, moduleOp);
assert(uses && "failed to find symbol uses of autodiff decl");

for (SymbolTable::SymbolUse use : *uses) {
Expand All @@ -249,7 +275,8 @@ struct PrintActivityAnalysisPass
// Populate the argument activities based on either the type or the
// supplied annotation. First argument is the callee
inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call,
argActivities, resultActivities);
argActivities, resultActivities,
foundEnzymeReverse);
runActivityAnalysis(config, callee, argActivities, resultActivities);
}
return;
Expand Down

0 comments on commit 20cb61f

Please sign in to comment.