Skip to content

Commit

Permalink
Move computing of the effective derivative name in DiffRequest. NFC.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 27, 2024
1 parent 39ce604 commit 14a79e6
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 33 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ struct DiffRequest {

bool shouldBeRecorded(clang::Expr* E) const;
bool shouldHaveAdjoint(const clang::VarDecl* VD) const;
std::string ComputeDerivativeName() const;
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "Compatibility.h"
#include "DerivativeBuilder.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
Expand Down Expand Up @@ -199,7 +200,7 @@ namespace clad {
}
/// For a qualtype QT returns if it's type is Array or Pointer Type
static bool isArrayOrPointerType(const clang::QualType QT) {
return QT->isArrayType() || QT->isPointerType();
return utils::isArrayOrPointerType(QT);
}

clang::CompoundStmt* MakeCompoundStmt(const Stmts& Stmts);
Expand Down
36 changes: 4 additions & 32 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
// performed. Mathematically, independent variables are all the function
// parameters, thus, does not convey the intendend meaning.
m_IndependentVar = DVI.back().param;
std::string derivativeSuffix("");
// If param is not real (i.e. floating point or integral), a pointer to a
// real type, or an array of a real type we cannot differentiate it.
// FIXME: we should support custom numeric types in the future.
Expand All @@ -113,7 +112,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
return {};
}
m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start;
derivativeSuffix = "_" + std::to_string(m_IndependentVarIndex);
} else {
QualType T = m_IndependentVar->getType();
bool isField = false;
Expand All @@ -132,32 +130,9 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
}
}

// If we are differentiating a call operator, that has no parameters,
// then the specified independent argument is a member variable of the
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
unsigned argIndex = ~0;
const CXXRecordDecl* functor = m_DiffReq.Functor;
if (m_DiffReq->param_empty() && functor)
argIndex = std::distance(functor->field_begin(),
std::find(functor->field_begin(),
functor->field_end(), m_IndependentVar));
else
argIndex = std::distance(
FD->param_begin(),
std::find(FD->param_begin(), FD->param_end(), m_IndependentVar));

std::string argInfo = std::to_string(argIndex);
for (auto field : diffVarInfo.fields)
argInfo += "_" + field;

std::string s;
if (m_DiffReq.CurrentDerivativeOrder > 1)
s = std::to_string(m_DiffReq.CurrentDerivativeOrder);

// Check if the function is already declared as a custom derivative.
std::string gradientName = m_DiffReq.BaseFunctionName + "_d" + s + "arg" +
argInfo + derivativeSuffix;
std::string gradientName = m_DiffReq.ComputeDerivativeName();

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
Expand Down Expand Up @@ -428,12 +403,9 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() {
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());

auto originalFnEffectiveName = utils::ComputeEffectiveFnName(FD);

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
IdentifierInfo* II = &m_Context.Idents.get(m_DiffReq.ComputeDerivativeName());
SourceLocation loc{FD->getLocation()};
DeclarationNameInfo derivedFnName(derivedFnII, loc);
DeclarationNameInfo derivedFnName(II, loc);

// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
Expand Down
38 changes: 38 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,44 @@ namespace clad {
return found != m_ActivityRunInfo.ToBeRecorded.end();
}

std::string DiffRequest::ComputeDerivativeName() const {
if (Mode != DiffMode::forward && Mode != DiffMode::reverse)
return BaseFunctionName + "_" + DiffModeToString(Mode);

if (DVI.empty())
return "<no independent variable specified>";

Check warning on line 677 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L677

Added line #L677 was not covered by tests

DiffInputVarInfo VarInfo = DVI.back();
const ValueDecl* IndependentVar = VarInfo.param;
unsigned argIndex = ~0;
// If we are differentiating a call operator, that has no parameters,
// then the specified independent argument is a member variable of the
// class defining the call operator.
// Thus, we need to find index of the member variable instead.
if (Function->param_empty() && Functor)
argIndex = std::distance(Functor->field_begin(),
std::find(Functor->field_begin(),
Functor->field_end(), IndependentVar));
else
argIndex =
std::distance(Function->param_begin(),
std::find(Function->param_begin(),
Function->param_end(), IndependentVar));

std::string argInfo = std::to_string(argIndex);
for (auto field : VarInfo.fields)
argInfo += "_" + field;

std::string s;
if (CurrentDerivativeOrder > 1)
s = std::to_string(CurrentDerivativeOrder);

if (utils::isArrayOrPointerType(IndependentVar->getType()))
argInfo += "_" + std::to_string(VarInfo.paramIndexInterval.Start);

return BaseFunctionName + "_d" + s + "arg" + argInfo;
}

///\returns true on error.
static bool ProcessInvocationArgs(Sema& S, SourceLocation endLoc,
const RequestOptions& ReqOpts,
Expand Down

0 comments on commit 14a79e6

Please sign in to comment.