Skip to content

Commit

Permalink
Rework computation of derivative function prototype.
Browse files Browse the repository at this point in the history
This patch centralizes it in one place.
  • Loading branch information
vgvassilev committed Dec 26, 2024
1 parent ed826c4 commit 39ce604
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 48 deletions.
5 changes: 2 additions & 3 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ class BaseForwardModeVisitor

DerivativeAndOverload DerivePushforward();

/// Returns the return type for the pushforward function of the function
/// `m_DiffReq->Function`.
clang::QualType ComputePushforwardFnReturnType();
/// Computes the return type of the derivative in `m_DiffReq->Function`.
clang::QualType ComputeDerivativeFunctionType();

virtual void ExecuteInsidePushforwardFunctionBlock();

Expand Down
94 changes: 49 additions & 45 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());

m_Sema.CurContext = DC;
QualType derivedFnType = ComputeDerivativeFunctionType();
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType());
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, derivedFnType);
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

Expand Down Expand Up @@ -364,19 +365,49 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
/*OverloadFunctionDecl=*/nullptr};
}

clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() {
QualType BaseForwardModeVisitor::ComputeDerivativeFunctionType() {
const FunctionDecl* FD = m_DiffReq.Function;

if (m_DiffReq.Mode == DiffMode::forward)
return FD->getType();

assert(m_DiffReq.Mode == GetPushForwardMode());
QualType originalFnRT = m_DiffReq->getReturnType();
if (originalFnRT->isVoidType())
return m_Context.VoidTy;
TemplateDecl* valueAndPushforward =
LookupTemplateDeclInCladNamespace("ValueAndPushforward");
assert(valueAndPushforward &&
"clad::ValueAndPushforward template not found!!");
QualType RT = InstantiateTemplate(
valueAndPushforward,
{originalFnRT, GetPushForwardDerivativeType(originalFnRT)});
return RT;

const auto* FnProtoTy = cast<FunctionProtoType>(FD->getType());
llvm::SmallVector<QualType, 16> FnTypes(FnProtoTy->getParamTypes().begin(),
FnProtoTy->getParamTypes().end());

bool HasThis = false;
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
FnTypes.push_back(MD->getThisType());
HasThis = true;
}
}

// Iterate over all but the "this" type and extend the signature to add the
// extra parameters.
for (size_t i = 0, e = FnTypes.size() - HasThis; i < e; ++i) {
QualType PVDTy = FnTypes[i];
if (BaseForwardModeVisitor::IsDifferentiableType(PVDTy))
FnTypes.push_back(GetPushForwardDerivativeType(PVDTy));
}

QualType oRetTy = FD->getReturnType();
QualType dRetTy;
if (oRetTy->isVoidType()) {
dRetTy = m_Context.VoidTy;
} else {
TemplateDecl* valueAndPushforward =
LookupTemplateDeclInCladNamespace("ValueAndPushforward");
assert(valueAndPushforward &&
"clad::ValueAndPushforward template not found!!");
// FIXME: Sink GetPushForwardDerivativeType here.
QualType PushFwdTy = GetPushForwardDerivativeType(oRetTy);
dRetTy = InstantiateTemplate(valueAndPushforward, {oRetTy, PushFwdTy});
}
FunctionProtoType::ExtProtoInfo EPI = FnProtoTy->getExtProtoInfo();
return m_Context.getFunctionType(dRetTy, FnTypes, EPI);
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Expand All @@ -393,41 +424,9 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
"Doesn't support recursive diff. Use DiffPlan.");
m_DerivativeInFlight = true;

llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

// If we are differentiating an instance member function then
// create a parameter type for the parameter that will represent the
// derivative of `this` pointer with respect to the independent parameter.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
QualType thisType = MD->getThisType();
derivedParamTypes.push_back(thisType);
}
}

for (auto* PVD : m_DiffReq->parameters()) {
paramTypes.push_back(PVD->getType());

if (BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
derivedParamTypes.push_back(GetPushForwardDerivativeType(PVD->getType()));
}

paramTypes.insert(paramTypes.end(), derivedParamTypes.begin(),
derivedParamTypes.end());

const auto* originalFnType =
dyn_cast<FunctionProtoType>(m_DiffReq->getType());
QualType returnType = ComputePushforwardFnReturnType();
QualType derivedFnType = m_Context.getFunctionType(
returnType, paramTypes, originalFnType->getExtProtoInfo());
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;

auto originalFnEffectiveName = utils::ComputeEffectiveFnName(FD);

Expand All @@ -436,6 +435,11 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
SourceLocation loc{FD->getLocation()};
DeclarationNameInfo derivedFnName(derivedFnII, loc);

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
m_Sema.CurContext = DC;
QualType derivedFnType = ComputeDerivativeFunctionType();
DeclWithContext cloneFunctionResult =
m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;
Expand Down

0 comments on commit 39ce604

Please sign in to comment.