Skip to content

Commit

Permalink
Move the seed initialization into a separate function.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 27, 2024
1 parent bbe5390 commit c893c31
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 124 deletions.
7 changes: 7 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ class BaseForwardModeVisitor
/// Prepares the derivative function parameters.
void
SetupDerivativeParameters(llvm::SmallVectorImpl<clang::ParmVarDecl*>& params);

/// Generate a seed initializing each independent argument with 1 and 0
/// otherwise:
/// double f_darg0(double x, double y) {
/// double _d_x = 1;
/// double _d_y = 0;
void GenerateSeeds(const clang::FunctionDecl* dFD);
};
} // end namespace clad

Expand Down
257 changes: 133 additions & 124 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTLambda.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/Type.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
#include "clang/Sema/Scope.h"
Expand All @@ -29,6 +33,7 @@

#include <algorithm>
#include <string>
#include <vector>

#include "clad/Differentiator/Compatibility.h"

Expand Down Expand Up @@ -171,131 +176,9 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();
// For each function parameter variable, store its derivative value.
for (auto* param : params) {
// We cannot create derivatives of reference type since seed value is
// always a constant (r-value). We assume that all the arguments have no
// relation among them, thus it is safe (correct) to use the corresponding
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
continue;
Expr* dParam = nullptr;
if (dParamType->isRealType()) {
// If param is independent variable, its derivative is 1, otherwise 0.
int dValue = (param == m_IndependentVar);
dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
dValue);
}
// For each function arg, create a variable _d_arg to store derivatives
// of potential reassignments, e.g.:
// double f_darg0(double x, double y) {
// double _d_x = 1;
// double _d_y = 0;
// ...
auto* dParamDecl =
BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam);
addToCurrentBlock(BuildDeclStmt(dParamDecl));
dParam = BuildDeclRef(dParamDecl);
if (dParamType->isRecordType() && param == m_IndependentVar) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
Expr* memRef =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref);
assert(memRef->getType()->isRealType() &&
"Forward mode can only differentiate w.r.t builtin scalar "
"numerical types.");
addToCurrentBlock(BuildOp(
BinaryOperatorKind::BO_Assign, memRef,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1)));
}
// Memorize the derivative of param, i.e. whenever the param is visited
// in the future, it's derivative dParam is found (unless reassigned with
// something new).
m_Variables[param] = dParam;
}

if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
// We cannot create derivative of lambda yet because lambdas default
// constructor is deleted.
if (MD->isInstance() && !MD->getParent()->isLambda()) {
QualType thisObjectType =
clad_compat::CXXMethodDecl_GetThisObjectType(m_Sema, MD);
QualType thisType = MD->getThisType();
// Here we are effectively doing:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// We are not creating `this` expression derivative using `new` because
// then we would be responsible for freeing the memory as well and its
// more convenient to let compiler handle the object lifecycle.
VarDecl* derivativeVD = BuildVarDecl(thisObjectType, "_d_this_obj");
DeclRefExpr* derivativeE = BuildDeclRef(derivativeVD);
VarDecl* thisExprDerivativeVD =
BuildVarDecl(thisType, "_d_this",
BuildOp(UnaryOperatorKind::UO_AddrOf, derivativeE));
addToCurrentBlock(BuildDeclStmt(derivativeVD));
addToCurrentBlock(BuildDeclStmt(thisExprDerivativeVD));
m_ThisExprDerivative = BuildDeclRef(thisExprDerivativeVD);
}
}

// Create derived variable for each member variable if we are
// differentiating a call operator.
if (m_DiffReq.Functor) {
for (FieldDecl* fieldDecl : m_DiffReq.Functor->fields()) {
Expr* dInitializer = nullptr;
QualType fieldType = fieldDecl->getType();

if (const auto* arrType =
dyn_cast<ConstantArrayType>(fieldType.getTypePtr())) {
if (!arrType->getElementType()->isRealType())
continue;

auto arrSize = arrType->getSize().getZExtValue();
std::vector<Expr*> dArrVal;

// Create an initializer list to initialize derived variable created
// for array member variable.
// For example, if we are differentiating wrt arr[3], then
// ```
// double arr[7];
// ```
// will get differentiated to,
//
// ```
// double _d_arr[7] = {0, 0, 0, 1, 0, 0, 0};
// ```
for (size_t i = 0; i < arrSize; ++i) {
int dValue =
(fieldDecl == m_IndependentVar && i == m_IndependentVarIndex);
auto* dValueLiteral = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, dValue);
dArrVal.push_back(dValueLiteral);
}
dInitializer =
m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get();
} else if (const auto* ptrType =
dyn_cast<PointerType>(fieldType.getTypePtr())) {
if (!ptrType->getPointeeType()->isRealType())
continue;
// Pointer member variables should be initialised by `nullptr`.
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get();
} else {
int dValue = (fieldDecl == m_IndependentVar);
dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
m_Context, dValue);
}
VarDecl* derivedFieldDecl =
BuildVarDecl(fieldType.getNonReferenceType(),
"_d_" + fieldDecl->getNameAsString(), dInitializer);
addToCurrentBlock(BuildDeclStmt(derivedFieldDecl));
m_Variables.emplace(fieldDecl, BuildDeclRef(derivedFieldDecl));
}
}
if (m_DiffReq.Mode == DiffMode::forward)
GenerateSeeds(derivedFD);

Stmt* BodyDiff = Visit(FD->getBody()).getStmt();
if (auto* CS = dyn_cast<CompoundStmt>(BodyDiff))
Expand Down Expand Up @@ -428,6 +311,132 @@ void BaseForwardModeVisitor::SetupDerivativeParameters(
}
}

void BaseForwardModeVisitor::GenerateSeeds(const clang::FunctionDecl* dFD) {
// For each function parameter variable, store its derivative value.
for (const ParmVarDecl* param : dFD->parameters()) {
// We cannot create derivatives of reference type since seed value is
// always a constant (r-value). We assume that all the arguments have no
// relation among them, thus it is safe (correct) to use the corresponding
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
continue;
Expr* dParam = nullptr;
if (dParamType->isRealType()) {
// If param is independent variable, its derivative is 1, otherwise 0.
int dValue = (param == m_IndependentVar);
dParam =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, dValue);
}
// For each function arg, create a variable _d_arg to store derivatives
// of potential reassignments, e.g.:
// double f_darg0(double x, double y) {
// double _d_x = 1;
// double _d_y = 0;
// ...
auto* dParamDecl =
BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam);
addToCurrentBlock(BuildDeclStmt(dParamDecl));
dParam = BuildDeclRef(dParamDecl);
if (dParamType->isRecordType() && param == m_IndependentVar) {
DiffInputVarInfo diffVarInfo = m_DiffReq.DVI.back();
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
Expr* memRef =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), dParam, ref);
assert(memRef->getType()->isRealType() &&
"Forward mode can only differentiate w.r.t builtin scalar "
"numerical types.");
addToCurrentBlock(BuildOp(BinaryOperatorKind::BO_Assign, memRef,
ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/1)));
}
// Memorize the derivative of param, i.e. whenever the param is visited
// in the future, it's derivative dParam is found (unless reassigned with
// something new).
m_Variables[param] = dParam;
}
if (const auto* MD = dyn_cast<CXXMethodDecl>(dFD)) {
// We cannot create derivative of lambda yet because lambdas default
// constructor is deleted.
if (MD->isInstance() && !MD->getParent()->isLambda()) {
QualType thisObjectType =
clad_compat::CXXMethodDecl_GetThisObjectType(m_Sema, MD);
QualType thisType = MD->getThisType();
// Here we are effectively doing:
// ```
// Class _d_this_obj;
// Class* _d_this = &_d_this_obj;
// ```
// We are not creating `this` expression derivative using `new` because
// then we would be responsible for freeing the memory as well and its
// more convenient to let compiler handle the object lifecycle.
VarDecl* derivativeVD = BuildVarDecl(thisObjectType, "_d_this_obj");
DeclRefExpr* derivativeE = BuildDeclRef(derivativeVD);
VarDecl* thisExprDerivativeVD =
BuildVarDecl(thisType, "_d_this",
BuildOp(UnaryOperatorKind::UO_AddrOf, derivativeE));
addToCurrentBlock(BuildDeclStmt(derivativeVD));
addToCurrentBlock(BuildDeclStmt(thisExprDerivativeVD));
m_ThisExprDerivative = BuildDeclRef(thisExprDerivativeVD);
}
}
SourceLocation validLoc{m_DiffReq->getLocation()};
// Create derived variable for each member variable if we are
// differentiating a call operator.
if (m_DiffReq.Functor) {
for (FieldDecl* fieldDecl : m_DiffReq.Functor->fields()) {
Expr* dInitializer = nullptr;
QualType fieldType = fieldDecl->getType();

if (const auto* arrType = dyn_cast<ConstantArrayType>(fieldType)) {
if (!arrType->getElementType()->isRealType())
continue;

auto arrSize = arrType->getSize().getZExtValue();
std::vector<Expr*> dArrVal;

// Create an initializer list to initialize derived variable created
// for array member variable.
// For example, if we are differentiating wrt arr[3], then
// ```
// double arr[7];
// ```
// will get differentiated to,
//
// ```
// double _d_arr[7] = {0, 0, 0, 1, 0, 0, 0};
// ```
for (size_t i = 0; i < arrSize; ++i) {
int dValue =
(fieldDecl == m_IndependentVar && i == m_IndependentVarIndex);
auto* dValueLiteral = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, dValue);
dArrVal.push_back(dValueLiteral);
}
dInitializer = m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get();
} else if (const auto* ptrType =
dyn_cast<PointerType>(fieldType.getTypePtr())) {
if (!ptrType->getPointeeType()->isRealType())
continue;
// Pointer member variables should be initialised by `nullptr`.
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get();
} else {
int dValue = (fieldDecl == m_IndependentVar);
dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
m_Context, dValue);
}
VarDecl* derivedFieldDecl =
BuildVarDecl(fieldType.getNonReferenceType(),
"_d_" + fieldDecl->getNameAsString(), dInitializer);
addToCurrentBlock(BuildDeclStmt(derivedFieldDecl));
m_Variables.emplace(fieldDecl, BuildDeclRef(derivedFieldDecl));
}
}
}

void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() {
Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt();
auto* CS = cast<CompoundStmt>(bodyDiff);
Expand Down

0 comments on commit c893c31

Please sign in to comment.