Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes incorrect derivative produced when array is passed to call expression inside a loop #560

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ namespace clad {
/// to maintain the correct statement order when the current block has
/// delayed emission i.e. assignment LHS.
Stmts m_PopIdxValues;
Stmts m_RPopIdxValues;//for reference variables as input in function argumentlist
std::vector<Stmts> m_LoopBlock;
unsigned outputArrayCursor = 0;
unsigned numParams = 0;
Expand Down Expand Up @@ -392,6 +393,33 @@ namespace clad {
/// https://github.com/vgvassilev/clad/issues/385
clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);


bool Ref=false;
bool hasReferenceType(const clang::QualType& type) {
return type->isReferenceType();
}
bool printArgTypes(const clang::CallExpr* CE) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this function called printArgTypes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a prototype function which did print the argument types. I had done this on last day of GSoC application so it is a complete mess and forgot to change the name. This part can be shifted to CladUtils. This function can be simplified a lot. All it is trying to find if there is any ReferenceType. I am very sorry for this.

const clang::FunctionDecl* FD = CE->getDirectCallee();
if (!FD) {
return false;
}
int numParams = FD->getNumParams();
if (CE->getNumArgs() != numParams) {
return false;
}
bool hasRefType = false;
for (int i = 0; i < numParams; i++) {
const clang::ParmVarDecl* param = FD->getParamDecl(i);
const clang::Expr* arg = CE->getArg(i);
bool isReferenceType = hasReferenceType(param->getType());
if (isReferenceType) {
hasRefType = true;
}
}
return hasRefType;
}


/// Allows to easily create and manage a counter for counting the number of
/// executed iterations of a loop.
Expand Down Expand Up @@ -584,4 +612,4 @@ namespace clad {
};
} // end namespace clad

#endif // CLAD_REVERSE_MODE_VISITOR_H
#endif // CLAD_REVERSE_MODE_VISITOR_H
12 changes: 10 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/*DirectInit=*/true);
if (dfdx())
addToCurrentBlock(BuildDeclStmt(popVal), direction::reverse);
else if(Ref)
m_RPopIdxValues.push_back(BuildDeclStmt(popVal));
else
m_PopIdxValues.push_back(BuildDeclStmt(popVal));
IdxStored = StmtDiff(IdxStored.getExpr(), BuildDeclRef(popVal));
Expand Down Expand Up @@ -1372,6 +1374,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff ReverseModeVisitor::VisitCallExpr(const CallExpr* CE) {
const FunctionDecl* FD = CE->getDirectCallee();
Ref=printArgTypes(CE);//finds if reference passed as argument
if (!FD) {
diag(DiagnosticsEngine::Warning,
CE->getEndLoc(),
Expand Down Expand Up @@ -1547,10 +1550,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
block.insert(block.begin() + insertionPoint,
BuildDeclStmt(argDiffLocalVD));
Expr* argDiffLocalE = BuildDeclRef(argDiffLocalVD);

int Numref=0;
while(!m_RPopIdxValues.empty())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does m_RPopIdxValues represent?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parth-07 It keeps the DeclStmt of the variables which were being generated because of having a Reference Variable in the argument list but were not being added to the derived function. These variables were being generated in VisitArraySubscriptExpr()

{
Numref++;
block.insert(block.begin() + insertionPoint,m_RPopIdxValues.pop_back_val());
}
// We added local variable to store result of `clad::pop(...)`. Thus
// we need to correspondingly adjust the insertion point.
insertionPoint += 1;
insertionPoint = insertionPoint+1+Numref;
// We cannot use the already existing `argDiff.getExpr()` here because
// it will cause inconsistent pushes and pops to the clad tape.
// FIXME: Modify `GlobalStoreAndRef` such that its functioning is
Expand Down