Skip to content

Commit

Permalink
Harmonize Derive and DerivePullback
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 26, 2024
1 parent 5b47ca0 commit f201b71
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes);

QualType pullbackFnType = m_Context.getFunctionType(
m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo());

llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
Expand All @@ -409,6 +406,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Sema.CurContext = const_cast<DeclContext*>(m_DiffReq->getDeclContext());

SourceLocation validLoc{m_DiffReq->getLocation()};
QualType pullbackFnType = m_Context.getFunctionType(
m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo());
DeclWithContext fnBuildRes =
m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext,
validLoc, DNI, pullbackFnType);
Expand Down

0 comments on commit f201b71

Please sign in to comment.