diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 3b668ff87..9d683bedc 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -59,6 +59,7 @@ namespace clad { /// to maintain the correct statement order when the current block has /// delayed emission i.e. assignment LHS. Stmts m_PopIdxValues; + Stmts m_RPopIdxValues;//for reference variables as input in function argumentlist std::vector m_LoopBlock; unsigned outputArrayCursor = 0; unsigned numParams = 0; @@ -392,6 +393,33 @@ namespace clad { /// https://github.com/vgvassilev/clad/issues/385 clang::QualType GetParameterDerivativeType(clang::QualType yType, clang::QualType xType); + + + bool Ref=false; + bool hasReferenceType(const clang::QualType& type) { + return type->isReferenceType(); + } + bool printArgTypes(const clang::CallExpr* CE) { + const clang::FunctionDecl* FD = CE->getDirectCallee(); + if (!FD) { + return false; + } + int numParams = FD->getNumParams(); + if (CE->getNumArgs() != numParams) { + return false; + } + bool hasRefType = false; + for (int i = 0; i < numParams; i++) { + const clang::ParmVarDecl* param = FD->getParamDecl(i); + const clang::Expr* arg = CE->getArg(i); + bool isReferenceType = hasReferenceType(param->getType()); + if (isReferenceType) { + hasRefType = true; + } + } + return hasRefType; + } + /// Allows to easily create and manage a counter for counting the number of /// executed iterations of a loop. @@ -584,4 +612,4 @@ namespace clad { }; } // end namespace clad -#endif // CLAD_REVERSE_MODE_VISITOR_H +#endif // CLAD_REVERSE_MODE_VISITOR_H \ No newline at end of file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ebdab0751..06d47f970 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1260,6 +1260,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /*DirectInit=*/true); if (dfdx()) addToCurrentBlock(BuildDeclStmt(popVal), direction::reverse); + else if(Ref) + m_RPopIdxValues.push_back(BuildDeclStmt(popVal)); else m_PopIdxValues.push_back(BuildDeclStmt(popVal)); IdxStored = StmtDiff(IdxStored.getExpr(), BuildDeclRef(popVal)); @@ -1372,6 +1374,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitCallExpr(const CallExpr* CE) { const FunctionDecl* FD = CE->getDirectCallee(); + Ref=printArgTypes(CE);//finds if reference passed as argument if (!FD) { diag(DiagnosticsEngine::Warning, CE->getEndLoc(), @@ -1547,10 +1550,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, block.insert(block.begin() + insertionPoint, BuildDeclStmt(argDiffLocalVD)); Expr* argDiffLocalE = BuildDeclRef(argDiffLocalVD); - + int Numref=0; + while(!m_RPopIdxValues.empty()) + { + Numref++; + block.insert(block.begin() + insertionPoint,m_RPopIdxValues.pop_back_val()); + } // We added local variable to store result of `clad::pop(...)`. Thus // we need to correspondingly adjust the insertion point. - insertionPoint += 1; + insertionPoint = insertionPoint+1+Numref; // We cannot use the already existing `argDiff.getExpr()` here because // it will cause inconsistent pushes and pops to the clad tape. // FIXME: Modify `GlobalStoreAndRef` such that its functioning is