diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6db6ee915..ecc17c951 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -398,9 +398,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - QualType pullbackFnType = m_Context.getFunctionType( - m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); - llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); @@ -409,6 +406,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); SourceLocation validLoc{m_DiffReq->getLocation()}; + QualType pullbackFnType = m_Context.getFunctionType( + m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); DeclWithContext fnBuildRes = m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext, validLoc, DNI, pullbackFnType);