Skip to content

Commit

Permalink
No need to handle recursive calls separately
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 3, 2024
1 parent 48b76c4 commit bca7dee
Showing 1 changed file with 74 additions and 92 deletions.
166 changes: 74 additions & 92 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1953,99 +1953,81 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Derivative was not found, check if it is a recursive call
if (!OverloadedDerivedFn) {
if (FD == m_DiffReq.Function &&
m_DiffReq.Mode == DiffMode::experimental_pullback) {
// Recursive call.
Expr* selfRef =
m_Sema
.BuildDeclarationNameExpr(
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
pullbackCallArgs, PreCallStmts, dfdx());

OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef, Loc,
pullbackCallArgs, Loc, CUDAExecConfig)
.get();
} else {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingCallExpr(
pullbackCallArgs, PreCallStmts, dfdx());

// Overloaded derivative was not found, request the CladPlugin to
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature parameter.
pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs;

pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (MD && isLambdaCallOperator(MD)) {
if (const auto* paramDecl = FD->getParamDecl(i))
pullbackRequest.DVI.push_back(paramDecl);
} else if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));

FunctionDecl* pullbackFD = nullptr;
if (m_ExternalSource)
// FIXME: Error estimation currently uses singleton objects -
// m_ErrorEstHandler and m_EstModel, which is cleared after each
// error_estimate request. This requires the pullback to be derived
// at the same time to access the singleton objects.
pullbackFD =
plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
else
pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest);

// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
// it already correctly support reference arguments then confirm the
// support and add tests for the same.
if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMethodDecl>(FD)) {
// Try numerically deriving it.
if (NArgs == 1) {
OverloadedDerivedFn = GetSingleArgCentralDiffCall(
Clone(CE->getCallee()), DerivedCallArgs[0],
/*targetPos=*/0,
/*numArgs=*/1, DerivedCallArgs, CUDAExecConfig);
asGrad = !OverloadedDerivedFn;
} else {
auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema);
OverloadedDerivedFn = GetMultiArgCentralDiffCall(
Clone(CE->getCallee()), CEType.getCanonicalType(),
CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts,
DerivedCallArgs, CallArgDx, CUDAExecConfig);
}
CallExprDiffDiagnostics(FD, CE->getBeginLoc());
if (!OverloadedDerivedFn) {
Stmts& block = getCurrentBlock(direction::reverse);
block.insert(block.begin(), PreCallStmts.begin(),
PreCallStmts.end());
return StmtDiff(Clone(CE));
}
} else if (pullbackFD) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, Loc);
} else {
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
Loc, pullbackCallArgs, Loc, CUDAExecConfig)
.get();
}
// Overloaded derivative was not found, request the CladPlugin to
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;

// Mark the indexes of the global args. Necessary if the argument of the
// call has a different name than the function's signature parameter.
pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs;

pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis;
pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (MD && isLambdaCallOperator(MD)) {
if (const auto* paramDecl = FD->getParamDecl(i))
pullbackRequest.DVI.push_back(paramDecl);
} else if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));

FunctionDecl* pullbackFD = nullptr;
if (m_ExternalSource)
// FIXME: Error estimation currently uses singleton objects -
// m_ErrorEstHandler and m_EstModel, which is cleared after each
// error_estimate request. This requires the pullback to be derived
// at the same time to access the singleton objects.
pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
else
pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest);

// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
// it already correctly support reference arguments then confirm the
// support and add tests for the same.
if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMethodDecl>(FD)) {
// Try numerically deriving it.
if (NArgs == 1) {
OverloadedDerivedFn = GetSingleArgCentralDiffCall(
Clone(CE->getCallee()), DerivedCallArgs[0],
/*targetPos=*/0,
/*numArgs=*/1, DerivedCallArgs, CUDAExecConfig);
asGrad = !OverloadedDerivedFn;
} else {
auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema);
OverloadedDerivedFn = GetMultiArgCentralDiffCall(
Clone(CE->getCallee()), CEType.getCanonicalType(),
CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts,
DerivedCallArgs, CallArgDx, CUDAExecConfig);
}
CallExprDiffDiagnostics(FD, CE->getBeginLoc());
if (!OverloadedDerivedFn) {
Stmts& block = getCurrentBlock(direction::reverse);
block.insert(block.begin(), PreCallStmts.begin(), PreCallStmts.end());
return StmtDiff(Clone(CE));
}
} else if (pullbackFD) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, Loc);
} else {
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
Loc, pullbackCallArgs, Loc, CUDAExecConfig)
.get();
}
}
}
Expand Down

0 comments on commit bca7dee

Please sign in to comment.