Skip to content

Commit

Permalink
Harmonize Derive and DerivePushforward
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 26, 2024
1 parent 3e11004 commit 5b47ca0
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,6 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

auto originalFnEffectiveName =
utils::ComputeEffectiveFnName(m_DiffReq.Function);

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
DeclarationNameInfo derivedFnName(derivedFnII, m_DiffReq->getLocation());
llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

Expand Down Expand Up @@ -446,9 +440,15 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;

SourceLocation loc{m_DiffReq->getLocation()};
DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(
m_DiffReq.Function, *this, DC, loc, derivedFnName, derivedFnType);
auto originalFnEffectiveName = utils::ComputeEffectiveFnName(FD);

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
SourceLocation loc{FD->getLocation()};
DeclarationNameInfo derivedFnName(derivedFnII, loc);

DeclWithContext cloneFunctionResult =
m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
Expand Down Expand Up @@ -518,7 +518,6 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {

Stmt* derivativeBody = endBlock();
m_Derivative->setBody(derivativeBody);

endScope(); // Function body scope

// Size >= current derivative order means that there exists a declaration
Expand Down

0 comments on commit 5b47ca0

Please sign in to comment.