diff --git a/include/clad/Differentiator/CladConfig.h b/include/clad/Differentiator/CladConfig.h index 39d47efd8..a31b04db9 100644 --- a/include/clad/Differentiator/CladConfig.h +++ b/include/clad/Differentiator/CladConfig.h @@ -31,6 +31,8 @@ enum opts : unsigned { disable_tbr = 1 << (ORDER_BITS + 3), enable_va = 1 << (ORDER_BITS + 5), disable_va = 1 << (ORDER_BITS + 6), + enable_ua = 1 << (ORDER_BITS + 7), + disable_ua = 1 << (ORDER_BITS + 8), // Specifying whether we only want the diagonal of the hessian. diagonal_only = 1 << (ORDER_BITS + 4), diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 2116d5ea0..1159c4fd2 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -38,9 +38,16 @@ struct DiffRequest { bool HasAnalysisRun = false; } m_ActivityRunInfo; + mutable struct UsefulRunInfo { + std::set UsefulDecls; + std::set UsefulFuncs; + bool HasAnalysisRun = false; + } m_UsefulRunInfo; + public: /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; + bool ReqAdj = true; /// Name of the base function to be differentiated. Can be different from /// function->getNameAsString() when higher-order derivatives are computed. std::string BaseFunctionName = {}; @@ -65,6 +72,7 @@ struct DiffRequest { /// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false; bool EnableVariedAnalysis = false; + bool EnableUsefulAnalysis = false; /// Puts the derived function and its code in the diff call void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD, clang::Sema& SemaRef); @@ -123,6 +131,7 @@ struct DiffRequest { CallContext == other.CallContext && Args == other.Args && Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && EnableVariedAnalysis == other.EnableVariedAnalysis && + EnableUsefulAnalysis == other.EnableUsefulAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && DeclarationOnly == other.DeclarationOnly; } @@ -141,6 +150,7 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + bool shouldHaveAdjointForw(const clang::VarDecl* VD) const; }; using DiffInterval = std::vector; @@ -150,6 +160,7 @@ struct DiffRequest { /// TBR analysis during reverse-mode differentiation. bool EnableTBRAnalysis = false; bool EnableVariedAnalysis = false; + bool EnableUsefulAnalysis = false; }; class DiffCollector: public clang::RecursiveASTVisitor { diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 8015b8fdb..a7b28ff7b 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1063,7 +1063,14 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // If DRE is of type pointer, then the derivative is a null pointer.
if (clonedDRE->getType()->isPointerType()) return StmtDiff(clonedDRE, nullptr); + + if (auto* i = dyn_cast(DRE->getDecl())) { + if (!m_DiffReq.shouldHaveAdjointForw(i)) + return StmtDiff(clonedDRE, nullptr); + } + QualType literalTy = utils::GetValueType(clonedDRE->getType()); + return StmtDiff(clonedDRE, ConstantFolder::synthesizeLiteral( literalTy, m_Context, /*val=*/0)); } @@ -1208,6 +1215,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { } llvm::SmallVector pushforwardFnArgs; + // pushforwardFnArgs.insert(pushforwardFnArgs.end(), CallArgs.begin(), CallArgs.end()); pushforwardFnArgs.insert(pushforwardFnArgs.end(), diffArgs.begin(), @@ -1284,6 +1292,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { pushforwardFnRequest.BaseFunctionName = utils::ComputeEffectiveFnName(FD); // Silence diag outputs in nested derivation process. pushforwardFnRequest.VerboseDiags = false; + pushforwardFnRequest.EnableUsefulAnalysis = m_DiffReq.EnableUsefulAnalysis; // Check if request already derived in DerivedFunctions.
FunctionDecl* pushforwardFD = @@ -1446,7 +1455,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { derivedR = BuildParens(derivedR); opDiff = BuildOp(opCode, derivedL, derivedR); } else if (BinOp->isAssignmentOp()) { - if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) { + if (Ldiff.getExpr_dx() && + Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) { diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(), "derivative of an assignment attempts to assign to unassignable " "expr, assignment ignored"); @@ -1575,11 +1585,16 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD, VarDecl* VDDerived = BuildVarDecl(VD->getType(), "_d_" + VD->getNameAsString(), initDx, VD->isDirectInit(), /*TSI=*/nullptr, VD->getInitStyle()); - m_Variables.emplace(VDClone, BuildDeclRef(VDDerived)); + + if (!m_DiffReq.shouldHaveAdjointForw(VD)) + VDDerived = nullptr; + else + m_Variables.emplace(VDClone, BuildDeclRef(VDDerived)); return DeclDiff(VDClone, VDDerived); } StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { + // llvm::errs() << "\nVisitDeclStmt"; llvm::SmallVector decls; llvm::SmallVector declsDiff; // If the type is marked as non_differentiable, skip generating its derivative @@ -1642,7 +1657,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { if (VDDiff.getDecl()->getDeclName() != VD->getDeclName()) m_DeclReplacements[VD] = VDDiff.getDecl(); decls.push_back(VDDiff.getDecl()); - declsDiff.push_back(VDDiff.getDecl_dx()); + if (m_DiffReq.shouldHaveAdjointForw(VD)) + declsDiff.push_back(VDDiff.getDecl_dx()); } else if (auto* SAD = dyn_cast(D)) { DeclDiff SADDiff = DifferentiateStaticAssertDecl(SAD); if (SADDiff.getDecl()) @@ -1661,6 +1677,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { DSClone = BuildDeclStmt(decls); if (!declsDiff.empty()) DSDiff = BuildDeclStmt(declsDiff); + // llvm::errs() << "\n====="; + // DSDiff->dump(); return
StmtDiff(DSClone, DSDiff); } diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 7f928b2ac..08beaa076 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -38,6 +38,7 @@ llvm_add_library(cladDifferentiator ReverseModeVisitor.cpp TBRAnalyzer.cpp StmtClone.cpp + UsefulAnalyzer.cpp VectorForwardModeVisitor.cpp VectorPushForwardModeVisitor.cpp Version.cpp diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index a526a6d58..bf4905782 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -2,6 +2,7 @@ #include "ActivityAnalyzer.h" #include "TBRAnalyzer.h" +#include "UsefulAnalyzer.h" #include "clang/AST/ASTContext.h" #include "clang/AST/RecursiveASTVisitor.h" @@ -636,6 +637,26 @@ namespace clad { return found != m_ActivityRunInfo.ToBeRecorded.end(); } + bool DiffRequest::shouldHaveAdjointForw(const VarDecl* VD) const { + if (!EnableUsefulAnalysis) + return true; + + if (!m_UsefulRunInfo.HasAnalysisRun) { + + UsefulAnalyzer analyzer(Function->getASTContext(), + m_UsefulRunInfo.UsefulDecls, + m_UsefulRunInfo.UsefulFuncs); + analyzer.Analyze(Function); + m_UsefulRunInfo.HasAnalysisRun = true; + // llvm::errs() << "ToBeRecorded: "; + // for (auto* i : m_UsefulRunInfo.UsefulDecls) + // llvm::errs() << i->getNameAsString() << " "; + // llvm::errs() << "\n"; + } + auto found = m_UsefulRunInfo.UsefulDecls.find(VD); + return found != m_UsefulRunInfo.UsefulDecls.end(); + } + bool DiffCollector::VisitCallExpr(CallExpr* E) { // Check if we should look into this. // FIXME: Generated code does not usually have valid source locations.
@@ -669,6 +690,8 @@ namespace clad { bool disable_tbr_in_req = false; bool enable_va_in_req = false; bool disable_va_in_req = false; + bool enable_ua_in_req = false; + bool disable_ua_in_req = false; if (!A->getAnnotation().equals("E") && FD->getTemplateSpecializationArgs()) { const auto template_arg = FD->getTemplateSpecializationArgs()->get(0); @@ -689,6 +712,10 @@ namespace clad { clad::HasOption(bitmasked_opts_value, clad::opts::enable_va); disable_va_in_req = clad::HasOption(bitmasked_opts_value, clad::opts::disable_va); + enable_ua_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::enable_ua); + disable_ua_in_req = + clad::HasOption(bitmasked_opts_value, clad::opts::disable_ua); if (enable_tbr_in_req && disable_tbr_in_req) { utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, "Both enable and disable TBR options are specified."); @@ -711,6 +738,10 @@ namespace clad { } else { request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis; } + if (enable_ua_in_req || disable_ua_in_req) + request.EnableUsefulAnalysis = enable_ua_in_req && !disable_ua_in_req; + else + request.EnableUsefulAnalysis = m_Options.EnableUsefulAnalysis; if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) { if (!A->getAnnotation().equals("H")) { utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc, diff --git a/lib/Differentiator/PushForwardModeVisitor.cpp b/lib/Differentiator/PushForwardModeVisitor.cpp index 0bb0c858d..8786806d9 100644 --- a/lib/Differentiator/PushForwardModeVisitor.cpp +++ b/lib/Differentiator/PushForwardModeVisitor.cpp @@ -24,7 +24,6 @@ StmtDiff PushForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { // If there is no return value, we must not attempt to differentiate if (!RS->getRetValue()) return nullptr; - StmtDiff retValDiff = Visit(RS->getRetValue()); Expr* retVal = retValDiff.getExpr(); Expr* retVal_dx = retValDiff.getExpr_dx(); diff --git a/lib/Differentiator/UsefulAnalyzer.cpp
b/lib/Differentiator/UsefulAnalyzer.cpp new file mode 100644 index 000000000..aa59b3baa --- /dev/null +++ b/lib/Differentiator/UsefulAnalyzer.cpp @@ -0,0 +1,156 @@ +#include "UsefulAnalyzer.h" + +using namespace clang; + +namespace clad { + +void UsefulAnalyzer::Analyze(const FunctionDecl* FD) { + // Build the CFG (control-flow graph) of FD. + clang::CFG::BuildOptions Options; + m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), &m_Context, Options); + + m_BlockData.resize(m_CFG->size()); + // Set current block ID to the ID of the exit block. + CFGBlock* exit = &m_CFG->getExit(); + m_CurBlockID = exit->getBlockID(); + m_BlockData[m_CurBlockID] = createNewVarsData({}); + for (const VarDecl* i : m_UsefulDecls) + m_BlockData[m_CurBlockID]->insert(i); + // Add the exit block to the queue. + m_CFGQueue.insert(m_CurBlockID); + + // Visit CFG blocks in the queue until it's empty. + while (!m_CFGQueue.empty()) { + auto IDIter = m_CFGQueue.begin(); + m_CurBlockID = *IDIter; + m_CFGQueue.erase(IDIter); + CFGBlock& nextBlock = *getCFGBlockByID(m_CurBlockID); + AnalyzeCFGBlock(nextBlock); + } +} + +CFGBlock* UsefulAnalyzer::getCFGBlockByID(unsigned ID) { + return *(m_CFG->begin() + ID); +} + +bool UsefulAnalyzer::isUseful(const VarDecl* VD) const { + const VarsData& curBranch = getCurBlockVarsData(); + return curBranch.find(VD) != curBranch.end(); +} + +void UsefulAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) { + VarsData& curBranch = getCurBlockVarsData(); + curBranch.insert(VD); +} + +static void mergeVarsData(std::set* targetData, + std::set* mergeData) { + for (const clang::VarDecl* i : *mergeData) + targetData->insert(i); + *mergeData = *targetData; +} + +void UsefulAnalyzer::AnalyzeCFGBlock(const CFGBlock& block) { + + for (auto ib = block.rbegin(); ib != block.rend(); ++ib) { + if (ib->getKind() == clang::CFGElement::Statement) { + + const clang::Stmt* S = ib->castAs().getStmt(); + // The const_cast is inevitable, since there is no + // ConstRecursiveASTVisitor.
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + TraverseStmt(const_cast(S)); + } + } + + for (const clang::CFGBlock::AdjacentBlock pred : block.preds()) { + if (!pred) + continue; + auto& predData = m_BlockData[pred->getBlockID()]; + if (!predData) + predData = createNewVarsData(*m_BlockData[block.getBlockID()]); + + bool shouldPushPred = true; + if (pred->getBlockID() < block.getBlockID()) { + if (m_LoopMem == *m_BlockData[block.getBlockID()]) + shouldPushPred = false; + + for (const VarDecl* i : *m_BlockData[block.getBlockID()]) + m_LoopMem.insert(i); + } + + if (shouldPushPred) + m_CFGQueue.insert(pred->getBlockID()); + + mergeVarsData(predData.get(), m_BlockData[block.getBlockID()].get()); + } + + for (const VarDecl* i : *m_BlockData[block.getBlockID()]) + m_UsefulDecls.insert(i); +} + +bool UsefulAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) { + Expr* L = BinOp->getLHS(); + Expr* R = BinOp->getRHS(); + const auto opCode = BinOp->getOpcode(); + if (BinOp->isAssignmentOp()) { + m_Useful = false; + TraverseStmt(L); + m_Marking = m_Useful; + TraverseStmt(R); + m_Marking = false; + } else if (opCode == BO_Add || opCode == BO_Sub || opCode == BO_Mul || + opCode == BO_Div) { + for (auto* subexpr : BinOp->children()) + if (!isa(subexpr)) + TraverseStmt(subexpr); + } + return true; +} + +bool UsefulAnalyzer::VisitDeclStmt(DeclStmt* DS) { + for (Decl* D : DS->decls()) { + if (auto* VD = dyn_cast(D)) { + if (isUseful(VD)) { + m_Useful = true; + m_Marking = true; + } + if (Expr* init = cast(D)->getInit()) + TraverseStmt(init); + m_Marking = false; + } + } + return true; +} + +bool UsefulAnalyzer::VisitReturnStmt(ReturnStmt* RS) { + m_Useful = true; + m_Marking = true; + auto* rv = RS->getRetValue(); + TraverseStmt(rv); + return true; +} + +bool UsefulAnalyzer::VisitCallExpr(CallExpr* CE) { + if (m_Useful) + return true; + if (FunctionDecl* FD = CE->getDirectCallee()) + m_UsefulFuncs.insert(FD); + return true; +} + +bool
UsefulAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { + auto* VD = dyn_cast(DRE->getDecl()); + if (!VD) + return true; + + if (isUseful(VD)) + m_Useful = true; + + if (m_Useful && m_Marking) + copyVarToCurBlock(VD); + + return true; +} + +} // namespace clad diff --git a/lib/Differentiator/UsefulAnalyzer.h b/lib/Differentiator/UsefulAnalyzer.h new file mode 100644 index 000000000..8bd2dce66 --- /dev/null +++ b/lib/Differentiator/UsefulAnalyzer.h @@ -0,0 +1,72 @@ +#ifndef CLAD_DIFFERENTIATOR_USEFULANALYZER_H +#define CLAD_DIFFERENTIATOR_USEFULANALYZER_H +#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Analysis/CFG.h" + +#include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/Compatibility.h" + +#include +#include + +namespace clad { + +class UsefulAnalyzer : public clang::RecursiveASTVisitor { + + bool m_Useful = false; + bool m_Marking = false; + + std::set& m_UsefulDecls; + std::set& m_UsefulFuncs; + // std::set& m_VariedDecls; + using VarsData = std::set; + /// A helper method to allocate VarsData + /// \param toAssign - Parameter to initialize new VarsData with. + /// \return Unique pointer to a new object of type VarsData.
+ static std::unique_ptr createNewVarsData(VarsData toAssign) { + return std::unique_ptr(new VarsData(std::move(toAssign))); + } + VarsData m_LoopMem; + + clang::CFGBlock* getCFGBlockByID(unsigned ID); + + clang::ASTContext& m_Context; + std::unique_ptr m_CFG; + std::vector> m_BlockData; + unsigned m_CurBlockID{}; + std::set m_CFGQueue; + bool isUseful(const clang::VarDecl* VD) const; + void copyVarToCurBlock(const clang::VarDecl* VD); + VarsData& getCurBlockVarsData() { return *m_BlockData[m_CurBlockID]; } + [[nodiscard]] const VarsData& getCurBlockVarsData() const { + return const_cast(this)->getCurBlockVarsData(); + } + void AnalyzeCFGBlock(const clang::CFGBlock& block); + +public: + /// Constructor + UsefulAnalyzer(clang::ASTContext& Context, + std::set& Decls, + std::set& Funcs) + : m_UsefulDecls(Decls), m_UsefulFuncs(Funcs), m_Context(Context) {} + + /// Destructor + ~UsefulAnalyzer() = default; + + /// Delete copy/move operators and constructors. + UsefulAnalyzer(const UsefulAnalyzer&) = delete; + UsefulAnalyzer& operator=(const UsefulAnalyzer&) = delete; + UsefulAnalyzer(UsefulAnalyzer&&) = delete; + UsefulAnalyzer& operator=(UsefulAnalyzer&&) = delete; + + /// Runs Useful analysis. + /// \param FD Function to run the analysis on.
+ void Analyze(const clang::FunctionDecl* FD); + bool VisitReturnStmt(clang::ReturnStmt* RS); + bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); + bool VisitBinaryOperator(clang::BinaryOperator* BinOp); + bool VisitDeclStmt(clang::DeclStmt* DS); + bool VisitCallExpr(clang::CallExpr* CE); +}; +} // namespace clad +#endif // CLAD_DIFFERENTIATOR_USEFULANALYZER_H diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp deleted file mode 100644 index e1b5cb35c..000000000 --- a/test/Analyses/ActivityReverse.cpp +++ /dev/null @@ -1,273 +0,0 @@ -// RUN: %cladclang %s -I%S/../../include -oActivity.out 2>&1 | %filecheck %s -// RUN: ./Activity.out | %filecheck_exec %s -// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-va %s -I%S/../../include -oActivity.out -// RUN: ./Activity.out | %filecheck_exec %s -//CHECK-NOT: {{.*error|warning|note:.*}} - -#include "clad/Differentiator/Differentiator.h" - -double f1(double x){ - double a = x*x; - double b = 1; - b = b*b; - return a; -} - -//CHECK: void f1_grad(double x, double *_d_x) { -//CHECK-NEXT: double _d_a = 0.; -//CHECK-NEXT: double a = x * x; -//CHECK-NEXT: double b = 1; -//CHECK-NEXT: double _t0 = b; -//CHECK-NEXT: b = b * b; -//CHECK-NEXT: _d_a += 1; -//CHECK-NEXT: b = _t0; -//CHECK-NEXT: { -//CHECK-NEXT: *_d_x += _d_a * x; -//CHECK-NEXT: *_d_x += x * _d_a; -//CHECK-NEXT: } -//CHECK-NEXT: } - -double f2(double x){ - double a = x*x; - double b = 1; - double g; - if(a) - b=x; - else if(b) - double d = b; - else - g = a; - return a; -} - -//CHECK: void f2_grad(double x, double *_d_x) { -//CHECK-NEXT: bool _cond0; -//CHECK-NEXT: double _t0; -//CHECK-NEXT: bool _cond1; -//CHECK-NEXT: double d = 0.; -//CHECK-NEXT: double _t1; -//CHECK-NEXT: double _d_a = 0.; -//CHECK-NEXT: double a = x * x; -//CHECK-NEXT: double _d_b = 0.; -//CHECK-NEXT: double b = 1; -//CHECK-NEXT: double _d_g = 0.; -//CHECK-NEXT: double g; -//CHECK-NEXT: { -//CHECK-NEXT: _cond0 = a; -//CHECK-NEXT: if
(_cond0) { -//CHECK-NEXT: _t0 = b; -//CHECK-NEXT: b = x; -//CHECK-NEXT: } else { -//CHECK-NEXT: _cond1 = b; -//CHECK-NEXT: if (_cond1) -//CHECK-NEXT: d = b; -//CHECK-NEXT: else { -//CHECK-NEXT: _t1 = g; -//CHECK-NEXT: g = a; -//CHECK-NEXT: } -//CHECK-NEXT: } -//CHECK-NEXT: } -//CHECK-NEXT: _d_a += 1; -//CHECK-NEXT: if (_cond0) { -//CHECK-NEXT: b = _t0; -//CHECK-NEXT: double _r_d0 = _d_b; -//CHECK-NEXT: _d_b = 0.; -//CHECK-NEXT: *_d_x += _r_d0; -//CHECK-NEXT: } else if (!_cond1) { -//CHECK-NEXT: g = _t1; -//CHECK-NEXT: double _r_d1 = _d_g; -//CHECK-NEXT: _d_g = 0.; -//CHECK-NEXT: _d_a += _r_d1; -//CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: *_d_x += _d_a * x; -//CHECK-NEXT: *_d_x += x * _d_a; -//CHECK-NEXT: } -//CHECK-NEXT: } - -double f3(double x){ - double x1, x2, x3, x4, x5 = 0; - while(!x3){ - x5 = x4; - x4 = x3; - x3 = x2; - x2 = x1; - x1 = x; - } - return x5; -} - -//CHECK: void f3_grad(double x, double *_d_x) { -//CHECK-NEXT: clad::tape _t1 = {}; -//CHECK-NEXT: clad::tape _t2 = {}; -//CHECK-NEXT: clad::tape _t3 = {}; -//CHECK-NEXT: clad::tape _t4 = {}; -//CHECK-NEXT: clad::tape _t5 = {}; -//CHECK-NEXT: double _d_x1 = 0., _d_x2 = 0., _d_x3 = 0., _d_x4 = 0., _d_x5 = 0.; -//CHECK-NEXT: double x1, x2, x3, x4, x5 = 0; -//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; -//CHECK-NEXT: while (!x3) -//CHECK-NEXT: { -//CHECK-NEXT: _t0++; -//CHECK-NEXT: clad::push(_t1, x5); -//CHECK-NEXT: x5 = x4; -//CHECK-NEXT: clad::push(_t2, x4); -//CHECK-NEXT: x4 = x3; -//CHECK-NEXT: clad::push(_t3, x3); -//CHECK-NEXT: x3 = x2; -//CHECK-NEXT: clad::push(_t4, x2); -//CHECK-NEXT: x2 = x1; -//CHECK-NEXT: clad::push(_t5, x1); -//CHECK-NEXT: x1 = x; -//CHECK-NEXT: } -//CHECK-NEXT: _d_x5 += 1; -//CHECK-NEXT: while (_t0) -//CHECK-NEXT: { -//CHECK-NEXT: { -//CHECK-NEXT: { -//CHECK-NEXT: x1 = clad::pop(_t5); -//CHECK-NEXT: double _r_d4 = _d_x1; -//CHECK-NEXT: _d_x1 = 0.; -//CHECK-NEXT: *_d_x += _r_d4; -//CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: x2 = clad::pop(_t4); -//CHECK-NEXT: 
double _r_d3 = _d_x2; -//CHECK-NEXT: _d_x2 = 0.; -//CHECK-NEXT: _d_x1 += _r_d3; -//CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: x3 = clad::pop(_t3); -//CHECK-NEXT: double _r_d2 = _d_x3; -//CHECK-NEXT: _d_x3 = 0.; -//CHECK-NEXT: _d_x2 += _r_d2; -//CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: x4 = clad::pop(_t2); -//CHECK-NEXT: double _r_d1 = _d_x4; -//CHECK-NEXT: _d_x4 = 0.; -//CHECK-NEXT: _d_x3 += _r_d1; -//CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: x5 = clad::pop(_t1); -//CHECK-NEXT: double _r_d0 = _d_x5; -//CHECK-NEXT: _d_x5 = 0.; -//CHECK-NEXT: _d_x4 += _r_d0; -//CHECK-NEXT: } -//CHECK-NEXT: } -//CHECK-NEXT: _t0--; -//CHECK-NEXT: } -//CHECK-NEXT: } - -double f4_1(double v, double u){ - double k = 2*u; - double n = 2*v; - return n*k; -} -double f4(double x){ - double c = f4_1(x, 1); - return c; -} -// CHECK-NEXT: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u); - -// CHECK: void f4_grad(double x, double *_d_x) { -// CHECK-NEXT: double _d_c = 0.; -// CHECK-NEXT: double c = f4_1(x, 1); -// CHECK-NEXT: _d_c += 1; -// CHECK-NEXT: { -// CHECK-NEXT: double _r0 = 0.; -// CHECK-NEXT: double _r1 = 0.; -// CHECK-NEXT: f4_1_pullback(x, 1, _d_c, &_r0, &_r1); -// CHECK-NEXT: *_d_x += _r0; -// CHECK-NEXT: } -// CHECK-NEXT: } - -double f5(double x){ - double g = x ? 1 : 2; - return g; -} -// CHECK: void f5_grad(double x, double *_d_x) { -// CHECK-NEXT: double _cond0 = x; -// CHECK-NEXT: double _d_g = 0.; -// CHECK-NEXT: double g = _cond0 ? 
1 : 2; -// CHECK-NEXT: _d_g += 1; -// CHECK-NEXT: } - -double f6(double x){ - double a = 0; - if(0){ - a = x; - } - return a; -} - -// CHECK: void f6_grad(double x, double *_d_x) { -// CHECK-NEXT: double _t0; -// CHECK-NEXT: double a = 0; -// CHECK-NEXT: if (0) { -// CHECK-NEXT: _t0 = a; -// CHECK-NEXT: a = x; -// CHECK-NEXT: } -// CHECK-NEXT: if (0) { -// CHECK-NEXT: a = _t0; -// CHECK-NEXT: } -// CHECK-NEXT: } - -double f7(double x){ - double &a = x; - double* b = &a; - double arr[3] = {1,2,3}; - double c = arr[0]*(*b)+arr[1]*a+arr[2]*x; - return a; -} - -// CHECK: void f7_grad(double x, double *_d_x) { -// CHECK-NEXT: double &_d_a = *_d_x; -// CHECK-NEXT: double &a = x; -// CHECK-NEXT: double *_d_b = &_d_a; -// CHECK-NEXT: double *b = &a; -// CHECK-NEXT: double _d_arr[3] = {0}; -// CHECK-NEXT: double arr[3] = {1, 2, 3}; -// CHECK-NEXT: double _d_c = 0.; -// CHECK-NEXT: double c = arr[0] * *b + arr[1] * a + arr[2] * x; -// CHECK-NEXT: _d_a += 1; -// CHECK-NEXT: { -// CHECK-NEXT: _d_arr[0] += _d_c * *b; -// CHECK-NEXT: *_d_b += arr[0] * _d_c; -// CHECK-NEXT: _d_arr[1] += _d_c * a; -// CHECK-NEXT: _d_a += arr[1] * _d_c; -// CHECK-NEXT: _d_arr[2] += _d_c * x; -// CHECK-NEXT: *_d_x += arr[2] * _d_c; -// CHECK-NEXT: } -// CHECK-NEXT: } - -#define TEST(F, x) { \ - result[0] = 0; \ - auto F##grad = clad::gradient(F);\ - F##grad.execute(x, result);\ - printf("{%.2f}\n", result[0]); \ -} - -int main(){ - double result[3] = {}; - TEST(f1, 3);// CHECK-EXEC: {6.00} - TEST(f2, 3);// CHECK-EXEC: {6.00} - TEST(f3, 3);// CHECK-EXEC: {0.00} - TEST(f4, 3);// CHECK-EXEC: {4.00} - TEST(f5, 3);// CHECK-EXEC: {0.00} - TEST(f6, 3);// CHECK-EXEC: {0.00} - TEST(f7, 3);// CHECK-EXEC: {1.00} -} - -// CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { -// CHECK-NEXT: double _d_k = 0.; -// CHECK-NEXT: double k = 2 * u; -// CHECK-NEXT: double _d_n = 0.; -// CHECK-NEXT: double n = 2 * v; -// CHECK-NEXT: { -// CHECK-NEXT: _d_n += _d_y * k; -// CHECK-NEXT: 
_d_k += n * _d_y; -// CHECK-NEXT: } -// CHECK-NEXT: *_d_v += 2 * _d_n; -// CHECK-NEXT: *_d_u += 2 * _d_k; -// CHECK-NEXT: } diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index d228a2dc3..dd42b0fbf 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -407,16 +407,26 @@ namespace clad { static void SetActivityAnalysisOptions(const DifferentiationOptions& DO, RequestOptions& opts) { // If user has explicitly specified the mode for AA, use it. - if (DO.EnableVariedAnalysis || DO.DisableActivityAnalysis) + if (DO.EnableVariedAnalysis || DO.DisableVariedAnalysis) opts.EnableVariedAnalysis = - DO.EnableVariedAnalysis && !DO.DisableActivityAnalysis; + DO.EnableVariedAnalysis && !DO.DisableVariedAnalysis; else opts.EnableVariedAnalysis = false; // Default mode. } + static void SetUsefulAnalysisOptions(const DifferentiationOptions& DO, + RequestOptions& opts) { + // If user has explicitly specified the mode for useful analysis, use it. + if (DO.EnableUsefulAnalysis || DO.DisableUsefulAnalysis) + opts.EnableUsefulAnalysis = + DO.EnableUsefulAnalysis && !DO.DisableUsefulAnalysis; + else + opts.EnableUsefulAnalysis = false; // Default mode.
+ } void CladPlugin::SetRequestOptions(RequestOptions& opts) const { SetTBRAnalysisOptions(m_DO, opts); SetActivityAnalysisOptions(m_DO, opts); + SetUsefulAnalysisOptions(m_DO, opts); } void CladPlugin::FinalizeTranslationUnit() { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 89b62ce8f..e2e195569 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -56,7 +56,8 @@ class CladTimerGroup { DumpDerivedAST(false), GenerateSourceFile(false), ValidateClangVersion(true), EnableTBRAnalysis(false), DisableTBRAnalysis(false), EnableVariedAnalysis(false), - DisableActivityAnalysis(false), CustomEstimationModel(false), + DisableVariedAnalysis(false), EnableUsefulAnalysis(false), + DisableUsefulAnalysis(false), CustomEstimationModel(false), PrintNumDiffErrorInfo(false) {} bool DumpSourceFn : 1; @@ -68,7 +69,9 @@ class CladTimerGroup { bool EnableTBRAnalysis : 1; bool DisableTBRAnalysis : 1; bool EnableVariedAnalysis : 1; - bool DisableActivityAnalysis : 1; + bool DisableVariedAnalysis : 1; + bool EnableUsefulAnalysis : 1; + bool DisableUsefulAnalysis : 1; bool CustomEstimationModel : 1; bool PrintNumDiffErrorInfo : 1; std::string CustomModelName; @@ -320,7 +323,11 @@ class CladTimerGroup { } else if (args[i] == "-enable-va") { m_DO.EnableVariedAnalysis = true; } else if (args[i] == "-disable-va") { - m_DO.DisableActivityAnalysis = true; + m_DO.DisableVariedAnalysis = true; + } else if (args[i] == "-enable-ua") { + m_DO.EnableUsefulAnalysis = true; + } else if (args[i] == "-disable-ua") { + m_DO.DisableUsefulAnalysis = true; } else if (args[i] == "-fcustom-estimation-model") { m_DO.CustomEstimationModel = true; if (++i == e) { @@ -374,11 +381,16 @@ class CladTimerGroup { "be used together.\n"; return false; } - if (m_DO.EnableVariedAnalysis && m_DO.DisableActivityAnalysis) { + if (m_DO.EnableVariedAnalysis && m_DO.DisableVariedAnalysis) { llvm::errs() << "clad: Error: -enable-va and -disable-va cannot " "be used together.\n"; return false; } + if
(m_DO.EnableUsefulAnalysis && m_DO.DisableUsefulAnalysis) { + llvm::errs() << "clad: Error: -enable-ua and -disable-ua cannot " + "be used together.\n"; + return false; + } return true; }