Skip to content

Commit

Permalink
Move the computation of parameters in SetupDerivativeParameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 27, 2024
1 parent d05e3c9 commit d8a8798
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 71 deletions.
13 changes: 10 additions & 3 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

#include <array>
#include <stack>
#include <unordered_map>
Expand Down Expand Up @@ -37,9 +39,6 @@ class BaseForwardModeVisitor

DerivativeAndOverload DerivePushforward();

/// Computes the return type of the derivative in `m_DiffReq->Function`.
clang::QualType ComputeDerivativeFunctionType();

virtual void ExecuteInsidePushforwardFunctionBlock();

static bool IsDifferentiableType(clang::QualType T);
Expand Down Expand Up @@ -148,6 +147,14 @@ class BaseForwardModeVisitor
const clang::CXXConstructExpr* CE,
llvm::SmallVectorImpl<clang::Expr*>& clonedArgs,
llvm::SmallVectorImpl<clang::Expr*>& derivedArgs);

private:
/// Computes the return type of the derivative in `m_DiffReq->Function`.
clang::QualType ComputeDerivativeFunctionType();

/// Prepares the derivative function parameters.
void
SetupDerivativeParameters(llvm::SmallVectorImpl<clang::ParmVarDecl*>& params);
};
} // end namespace clad

Expand Down
132 changes: 64 additions & 68 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,32 +154,16 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

llvm::SmallVector<ParmVarDecl*, 4> params;
const ParmVarDecl* PVD = nullptr;

// Function declaration scope
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) {
PVD = FD->getParamDecl(i);
auto* newPVD = CloneParmVarDecl(PVD, PVD->getIdentifier(),
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);

// Make m_IndependentVar to point to the argument of the newly created
// derivedFD.
if (PVD == m_IndependentVar)
m_IndependentVar = newPVD;

params.push_back(newPVD);
}
llvm::SmallVector<ParmVarDecl*, 16> params;
SetupDerivativeParameters(params);
derivedFD->setParams(params);

llvm::ArrayRef<ParmVarDecl*> paramsRef =
clad_compat::makeArrayRef(params.data(), params.size());
derivedFD->setParams(paramsRef);
derivedFD->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down Expand Up @@ -386,6 +370,64 @@ QualType BaseForwardModeVisitor::ComputeDerivativeFunctionType() {
return m_Context.getFunctionType(dRetTy, FnTypes, EPI);
}

void BaseForwardModeVisitor::SetupDerivativeParameters(
llvm::SmallVectorImpl<ParmVarDecl*>& params) {
const FunctionDecl* FD = m_DiffReq.Function;
for (ParmVarDecl* PVD : FD->parameters()) {
IdentifierInfo* PVDII = PVD->getIdentifier();
// Implicitly created special member functions have no parameter names.
if (!PVD->getDeclName())
PVDII = CreateUniqueIdentifier("param");

auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);

// Point m_IndependentVar to the argument of the newly created param.
if (PVD == m_IndependentVar)
m_IndependentVar = newPVD;

if (!PVD->getDeclName()) // We can't use lookup-based replacements
m_DeclReplacements[PVD] = newPVD;

params.push_back(newPVD);
}

if (m_DiffReq.Mode == DiffMode::forward)
return;

bool HasThis = false;
// If we are differentiating an instance member function then create a
// parameter for representing derivative of `this` pointer with respect to the
// independent parameter.
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
IdentifierInfo* dThisII = CreateUniqueIdentifier("_d_this");
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, dThisII,
MD->getThisType());
m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false);
params.push_back(dPVD);
// FIXME: Replace m_ThisExprDerivative in favor of lookups of _d_this.
m_ThisExprDerivative = BuildDeclRef(dPVD);
HasThis = true;
}
}

for (size_t i = 0, e = params.size() - HasThis; i < e; ++i) {
const ParmVarDecl* PVD = params[i];

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
continue;

IdentifierInfo* II = CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
auto* dPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, II, GetPushForwardDerivativeType(PVD->getType()),
PVD->getStorageClass());
params.push_back(dPVD);
m_Variables[PVD] = BuildDeclRef(dPVD);
}
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt();
auto* CS = cast<CompoundStmt>(bodyDiff);
Expand Down Expand Up @@ -417,61 +459,15 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
llvm::SmallVector<ParmVarDecl*, 16> derivedParams;
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

// If we are differentiating an instance member function then
// create a parameter for representing derivative of
// `this` pointer with respect to the independent parameter.
if (const auto* MFD = dyn_cast<CXXMethodDecl>(FD)) {
if (MFD->isInstance()) {
auto thisType = MFD->getThisType();
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier("_d_this");
auto* derivedPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext,
derivedPVDII, thisType);
m_Sema.PushOnScopeChains(derivedPVD, getCurrentScope(),
/*AddToContext=*/false);
derivedParams.push_back(derivedPVD);
m_ThisExprDerivative = BuildDeclRef(derivedPVD);
}
}

std::size_t numParamsOriginalFn = m_DiffReq->getNumParams();
for (std::size_t i = 0; i < numParamsOriginalFn; ++i) {
const auto* PVD = m_DiffReq->getParamDecl(i);
// Some of the special member functions created implicitly by compilers
// have missing parameter identifier.
bool identifierMissing = false;
IdentifierInfo* PVDII = PVD->getIdentifier();
if (!PVDII || PVDII->getLength() == 0) {
PVDII = CreateUniqueIdentifier("param");
identifierMissing = true;
}
auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);
params.push_back(newPVD);

if (identifierMissing)
m_DeclReplacements[PVD] = newPVD;

if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType()))
continue;
auto derivedPVDName = "_d_" + std::string(PVDII->getName());
IdentifierInfo* derivedPVDII = CreateUniqueIdentifier(derivedPVDName);
auto* derivedPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, derivedPVDII,
GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass());
derivedParams.push_back(derivedPVD);
m_Variables[newPVD] = BuildDeclRef(derivedPVD);
}

params.insert(params.end(), derivedParams.begin(), derivedParams.end());
llvm::SmallVector<ParmVarDecl*, 16> params;
SetupDerivativeParameters(params);
m_Derivative->setParams(params);

m_Derivative->setBody(nullptr);

if (!m_DiffReq.DeclarationOnly) {
Expand Down

0 comments on commit d8a8798

Please sign in to comment.