Skip to content

Commit

Permalink
Optimize pullback calls
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Mar 12, 2024
1 parent 55835f0 commit 836aa4f
Show file tree
Hide file tree
Showing 29 changed files with 386 additions and 699 deletions.
15 changes: 7 additions & 8 deletions include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ class ErrorEstimationHandler : public ExternalRMVSource {
/// \param[in] CallArgs The orignal call arguments of the function call.
/// \param[in] ArgResultDecls The differentiated call arguments.
/// \param[in] numArgs The number of call args.
void EmitNestedFunctionParamError(
clang::FunctionDecl* fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, size_t numArgs);
void
EmitNestedFunctionParamError(clang::FunctionDecl* fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
llvm::SmallVectorImpl<clang::Expr*>& ArgResult,
size_t numArgs);

/// Checks if a variable should be considered in error estimation.
///
Expand Down Expand Up @@ -181,16 +182,14 @@ class ErrorEstimationHandler : public ExternalRMVSource {
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& fnDecl,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) override;
void ActBeforeFinalizingAssignOp(clang::Expr*&, clang::Expr*&, clang::Expr*&,
clang::BinaryOperator::Opcode&) override;
void ActBeforeFinalizingDifferentiateSingleStmt(const direction& d) override;
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls,
bool hasAssignee) override;
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) override;
void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
llvm::SmallVectorImpl<clang::Decl*>& declsDiff) override;
Expand Down
4 changes: 2 additions & 2 deletions include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class ExternalRMVSource {
virtual void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, bool asGrad) {}
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) {}

/// This is called just before finalising processing of post and pre
/// increment and decrement operations.
Expand Down Expand Up @@ -157,7 +157,7 @@ class ExternalRMVSource {

virtual void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls, bool hasAssignee) {}
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) {}

virtual void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
Expand Down
6 changes: 2 additions & 4 deletions include/clad/Differentiator/MultiplexExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls,
bool asGrad) override;
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) override;
void ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) override;
void ActAfterCloningLHSOfAssignOp(clang::Expr*&, clang::Expr*&,
clang::BinaryOperatorKind& opCode) override;
Expand All @@ -60,8 +59,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActBeforeFinalizingDifferentiateSingleExpr(const direction& d) override;
void ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls,
bool hasAssignee) override;
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) override;
void ActBeforeFinalizingVisitDeclStmt(
llvm::SmallVectorImpl<clang::Decl*>& decls,
llvm::SmallVectorImpl<clang::Decl*>& declsDiff) override;
Expand Down
10 changes: 7 additions & 3 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,20 @@ namespace clad {
///
/// \param[in] targetFuncCall The function to get the derivative for.
/// \param[in] retType The return type of the target call expression.
/// \param[in] dfdx The dfdx corresponding to this call expression.
/// \param[in] numArgs The total number of 'args'.
/// \param[in] NumericalDiffMultiArg The built statements to add to block
/// later.
/// \param[in] PreCallStmts The built statements to add to block
/// before the call to the derived function.
/// \param[in] PostCallStmts The built statements to add to block
/// after the call to the derived function.
/// \param[in] args All the arguments to the target function.
/// \param[in] outputArgs The output gradient arguments.
///
/// \returns The derivative function call.
clang::Expr* GetMultiArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs,
llvm::SmallVectorImpl<clang::Stmt*>& NumericalDiffMultiArg,
clang::Expr* dfdx, llvm::SmallVectorImpl<clang::Stmt*>& PreCallStmts,
llvm::SmallVectorImpl<clang::Stmt*>& PostCallStmts,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
/// Emits diagnostic messages on differentiation (or lack thereof) for
Expand Down
10 changes: 5 additions & 5 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void ErrorEstimationHandler::SaveReturnExpr(Expr* retExpr) {

void ErrorEstimationHandler::EmitNestedFunctionParamError(
FunctionDecl* fnDecl, llvm::SmallVectorImpl<Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<VarDecl*>& ArgResultDecls, size_t numArgs) {
llvm::SmallVectorImpl<Expr*>& ArgResult, size_t numArgs) {
assert(fnDecl && "Must have a value");
for (size_t i = 0; i < numArgs; i++) {
if (!fnDecl->getParamDecl(0)->getType()->isLValueReferenceType())
Expand All @@ -109,7 +109,7 @@ void ErrorEstimationHandler::EmitNestedFunctionParamError(
// if (utils::IsReferenceOrPointerType(fnDecl->getParamDecl(i)->getType()))
// continue;
Expr* errorExpr = m_EstModel->AssignError(
{derivedCallArgs[i], m_RMV->BuildDeclRef(ArgResultDecls[i])},
{derivedCallArgs[i], m_RMV->Clone(ArgResult[i])},
fnDecl->getNameInfo().getAsString() + "_param_" + std::to_string(i));
Expr* errorStmt = m_RMV->BuildOp(BO_AddAssign, m_FinalError, errorExpr);
m_ReverseErrorStmts.push_back(errorStmt);
Expand Down Expand Up @@ -372,7 +372,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingPostIncDecOp(StmtDiff& diff) {
void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<VarDecl*>& ArgResultDecls, bool asGrad) {
llvm::SmallVectorImpl<Expr*>& ArgResult, bool asGrad) {
if (OverloadedDerivedFn && asGrad) {
// Derivative was found.
FunctionDecl* fnDecl =
Expand All @@ -382,7 +382,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingVisitCallExpr(
// in the input prameters (if of reference type) to call and save to
// emit them later.

EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResultDecls,
EmitNestedFunctionParamError(fnDecl, derivedCallArgs, ArgResult,
CE->getNumArgs());
}
}
Expand Down Expand Up @@ -416,7 +416,7 @@ void ErrorEstimationHandler::ActBeforeFinalizingDifferentiateSingleExpr(

void ErrorEstimationHandler::ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<DeclStmt*>& ArgDecls, bool hasAssignee) {
llvm::SmallVectorImpl<Stmt*>& ArgDecls, bool hasAssignee) {
auto errorRef =
m_RMV->BuildVarDecl(m_RMV->m_Context.DoubleTy, "_t",
m_RMV->getZeroInit(m_RMV->m_Context.DoubleTy));
Expand Down
8 changes: 4 additions & 4 deletions lib/Differentiator/MultiplexExternalRMVSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingVisitReturnStmt(
void MultiplexExternalRMVSource::ActBeforeFinalizingVisitCallExpr(
const clang::CallExpr*& CE, clang::Expr*& OverloadedDerivedFn,
llvm::SmallVectorImpl<clang::Expr*>& derivedCallArgs,
llvm::SmallVectorImpl<clang::VarDecl*>& ArgResultDecls, bool asGrad) {
llvm::SmallVectorImpl<clang::Expr*>& ArgResult, bool asGrad) {
for (auto source : m_Sources) {
source->ActBeforeFinalizingVisitCallExpr(CE, OverloadedDerivedFn, derivedCallArgs,
ArgResultDecls, asGrad);
source->ActBeforeFinalizingVisitCallExpr(
CE, OverloadedDerivedFn, derivedCallArgs, ArgResult, asGrad);
}
}

Expand Down Expand Up @@ -199,7 +199,7 @@ void MultiplexExternalRMVSource::ActBeforeFinalizingDifferentiateSingleExpr(

void MultiplexExternalRMVSource::ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<clang::DeclStmt*>& ArgDecls, bool hasAssignee) {
llvm::SmallVectorImpl<clang::Stmt*>& ArgDecls, bool hasAssignee) {
for (auto source : m_Sources)
source->ActBeforeDifferentiatingCallExpr(pullbackArgs, ArgDecls,
hasAssignee);
Expand Down
Loading

0 comments on commit 836aa4f

Please sign in to comment.