diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 73ed020a0..ad9ce8126 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -37,9 +37,8 @@ class BaseForwardModeVisitor DerivativeAndOverload DerivePushforward(); - /// Returns the return type for the pushforward function of the function - /// `m_DiffReq->Function`. - clang::QualType ComputePushforwardFnReturnType(); + /// Computes the return type of the derivative in `m_DiffReq->Function`. + clang::QualType ComputeDerivativeFunctionType(); virtual void ExecuteInsidePushforwardFunctionBlock(); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index eda02e000..e242d0ccd 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -172,8 +172,9 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { llvm::SaveAndRestore SaveScope(getCurrentScope()); m_Sema.CurContext = DC; + QualType derivedFnType = ComputeDerivativeFunctionType(); DeclWithContext result = - m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType()); + m_Builder.cloneFunction(FD, *this, DC, validLoc, name, derivedFnType); FunctionDecl* derivedFD = result.first; m_Derivative = derivedFD; @@ -364,19 +365,49 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { /*OverloadFunctionDecl=*/nullptr}; } -clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() { +QualType BaseForwardModeVisitor::ComputeDerivativeFunctionType() { + const FunctionDecl* FD = m_DiffReq.Function; + + if (m_DiffReq.Mode == DiffMode::forward) + return FD->getType(); + assert(m_DiffReq.Mode == GetPushForwardMode()); - QualType originalFnRT = m_DiffReq->getReturnType(); - if (originalFnRT->isVoidType()) - return m_Context.VoidTy; - TemplateDecl* valueAndPushforward = - LookupTemplateDeclInCladNamespace("ValueAndPushforward"); - assert(valueAndPushforward && - "clad::ValueAndPushforward template not found!!"); - QualType RT = InstantiateTemplate( - valueAndPushforward, - {originalFnRT, GetPushForwardDerivativeType(originalFnRT)}); - return RT; + + const auto* FnProtoTy = cast(FD->getType()); + llvm::SmallVector FnTypes(FnProtoTy->getParamTypes().begin(), + FnProtoTy->getParamTypes().end()); + + bool HasThis = false; + if (const auto* MD = dyn_cast(FD)) { + if (MD->isInstance()) { + FnTypes.push_back(MD->getThisType()); + HasThis = true; + } + } + + // Iterate over all but the "this" type and extend the signature to add the + // extra parameters. + for (size_t i = 0, e = FnTypes.size() - HasThis; i < e; ++i) { + QualType PVDTy = FnTypes[i]; + if (BaseForwardModeVisitor::IsDifferentiableType(PVDTy)) + FnTypes.push_back(GetPushForwardDerivativeType(PVDTy)); + } + + QualType oRetTy = FD->getReturnType(); + QualType dRetTy; + if (oRetTy->isVoidType()) { + dRetTy = m_Context.VoidTy; + } else { + TemplateDecl* valueAndPushforward = + LookupTemplateDeclInCladNamespace("ValueAndPushforward"); + assert(valueAndPushforward && + "clad::ValueAndPushforward template not found!!"); + // FIXME: Sink GetPushForwardDerivativeType here. + QualType PushFwdTy = GetPushForwardDerivativeType(oRetTy); + dRetTy = InstantiateTemplate(valueAndPushforward, {oRetTy, PushFwdTy}); + } + FunctionProtoType::ExtProtoInfo EPI = FnProtoTy->getExtProtoInfo(); + return m_Context.getFunctionType(dRetTy, FnTypes, EPI); } void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { @@ -393,41 +424,9 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { "Doesn't support recursive diff. Use DiffPlan."); m_DerivativeInFlight = true; - llvm::SmallVector paramTypes; - llvm::SmallVector derivedParamTypes; - - // If we are differentiating an instance member function then - // create a parameter type for the parameter that will represent the - // derivative of `this` pointer with respect to the independent parameter. - if (const auto* MD = dyn_cast(FD)) { - if (MD->isInstance()) { - QualType thisType = MD->getThisType(); - derivedParamTypes.push_back(thisType); - } - } - - for (auto* PVD : m_DiffReq->parameters()) { - paramTypes.push_back(PVD->getType()); - - if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) - derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType())); - } - - paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(), - derivedParamTypes.end()); - - const auto* originalFnType = - dyn_cast(m_DiffReq->getType()); - QualType returnType = ComputePushforwardFnReturnType(); - QualType derivedFnType = m_Context.getFunctionType( - returnType, paramTypes, originalFnType->getExtProtoInfo()); llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto* DC = const_cast(m_DiffReq->getDeclContext()); - m_Sema.CurContext = DC; auto originalFnEffectiveName = utils::ComputeEffectiveFnName(FD); @@ -436,6 +435,11 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { SourceLocation loc{FD->getLocation()}; DeclarationNameInfo derivedFnName(derivedFnII, loc); + // FIXME: We should not use const_cast to get the decl context here. + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + auto* DC = const_cast(m_DiffReq->getDeclContext()); + m_Sema.CurContext = DC; + QualType derivedFnType = ComputeDerivativeFunctionType(); DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType); m_Derivative = cloneFunctionResult.first;