diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 07b4dd219..cf1688e63 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1304,24 +1304,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { - DeclRefExpr* clonedDRE = nullptr; + DeclRefExpr* clonedDRE = cast(Clone(DRE)); // Check if referenced Decl was "replaced" with another identifier inside // the derivative - if (const auto* VD = dyn_cast(DRE->getDecl())) { - clonedDRE = cast(Clone(DRE)); + if (const auto* VD = dyn_cast(clonedDRE->getDecl())) { // If current context is different than the context of the original // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. - if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { - auto* referencedDecl = cast(clonedDRE->getDecl()); - clonedDRE = cast(BuildDeclRef(referencedDecl)); - } - } else - clonedDRE = cast(Clone(DRE)); + if (VD->getDeclContext() != m_Sema.CurContext) + clonedDRE = cast(BuildDeclRef(VD)); - if (auto* decl = dyn_cast(clonedDRE->getDecl())) { if (isVectorValued) { if (m_VectorOutput.size() <= outputArrayCursor) return StmtDiff(clonedDRE);