Skip to content

Commit

Permalink
add support for operator overload in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 authored and PhrygianGates committed Sep 1, 2023
1 parent bba8cb1 commit 985097e
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 25 deletions.
8 changes: 8 additions & 0 deletions benchmark/BenchmarkedFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ inline double product(double p[], int n) {
}
return prod;
}

/// \returns the weighted sum of the first \p n elements of \p p, where the
/// i-th element is weighted by the corresponding entry of \p w.
inline double weightedSum(double p[], double w[], int n) {
  double acc = 0;
  int i = 0;
  while (i < n) {
    acc += p[i] * w[i];
    ++i;
  }
  return acc;
}
1 change: 1 addition & 0 deletions benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ CB_ADD_GBENCHMARK(Simple Simple.cpp)
CB_ADD_GBENCHMARK(AlgorithmicComplexity AlgorithmicComplexity.cpp)
CB_ADD_GBENCHMARK(EnzymeCladComparison EnzymeCladComparison.cpp)
CB_ADD_GBENCHMARK(MemoryComplexity MemoryComplexity.cpp)
CB_ADD_GBENCHMARK(VectorModeComparison VectorModeComparison.cpp)

set (CLAD_BENCHMARK_DEPS clad)
get_property(_benchmark_names DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY TESTS)
Expand Down
122 changes: 122 additions & 0 deletions benchmark/VectorModeComparison.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#include "benchmark/benchmark.h"

#include "clad/Differentiator/Differentiator.h"

#include "BenchmarkedFunctions.h"

// Benchmark forward mode for weighted sum.
static void BM_ForwardModeWeightedSum(benchmark::State& state) {
auto dp0 = clad::differentiate(weightedSum, "p[0]");
auto dp1 = clad::differentiate(weightedSum, "p[1]");
auto dp2 = clad::differentiate(weightedSum, "p[2]");
auto dp3 = clad::differentiate(weightedSum, "p[3]");
auto dp4 = clad::differentiate(weightedSum, "p[4]");

auto dw0 = clad::differentiate(weightedSum, "w[0]");
auto dw1 = clad::differentiate(weightedSum, "w[1]");
auto dw2 = clad::differentiate(weightedSum, "w[2]");
auto dw3 = clad::differentiate(weightedSum, "w[3]");
auto dw4 = clad::differentiate(weightedSum, "w[4]");

constexpr int n = 5;
double inputs[n];
double weights[n];
for (int i = 0; i < n; ++i) {
inputs[i] = i + 1;
weights[i] = 1.0 / (double)(i + 1);
}

double sum = 0;
for (auto _ : state) {
benchmark::DoNotOptimize(
sum +=
dp0.execute(inputs, weights, n) + dp1.execute(inputs, weights, n) +
dp2.execute(inputs, weights, n) + dp3.execute(inputs, weights, n) +
dp4.execute(inputs, weights, n) + dw0.execute(inputs, weights, n) +
dw1.execute(inputs, weights, n) + dw2.execute(inputs, weights, n) +
dw3.execute(inputs, weights, n) + dw4.execute(inputs, weights, n));
}
}
BENCHMARK(BM_ForwardModeWeightedSum);

// Benchmark reverse mode for weighted sum.
static void BM_ReverseModeWeightedSum(benchmark::State& state) {
auto grad = clad::gradient(weightedSum, "p, w");
constexpr int n = 5;

double inputs[n];
double weights[n];
for (int i = 0; i < n; ++i) {
inputs[i] = i + 1;
weights[i] = 1.0 / (double)(i + 1);
}

double dinp[n];
double dweights[n];
clad::array_ref<double> dinp_ref(dinp, n);
clad::array_ref<double> dweights_ref(dweights, n);

double sum = 0;
for (auto _ : state) {
grad.execute(inputs, weights, n, dinp_ref, dweights_ref);
for (int i = 0; i < n; ++i)
sum += dinp[i] + dweights[i];
}
}
BENCHMARK(BM_ReverseModeWeightedSum);

// Benchmark enzyme's reverse mode for weighted sum.
static void BM_EnzymeReverseModeWeightedSum(benchmark::State& state) {
auto grad = clad::gradient<clad::opts::use_enzyme>(weightedSum, "p, w");
constexpr int n = 5;

double inputs[n];
double weights[n];
for (int i = 0; i < n; ++i) {
inputs[i] = i + 1;
weights[i] = 1.0 / (double)(i + 1);
}

double dinp[n];
double dweights[n];
clad::array_ref<double> dinp_ref(dinp, n);
clad::array_ref<double> dweights_ref(dweights, n);

double sum = 0;
for (auto _ : state) {
grad.execute(inputs, weights, n, dinp_ref, dweights_ref);
for (int i = 0; i < n; ++i)
sum += dinp[i] + dweights[i];
}
}
BENCHMARK(BM_EnzymeReverseModeWeightedSum);

// Benchmark vector forward mode for weighted sum.
static void BM_VectorForwardModeWeightedSum(benchmark::State& state) {
auto vm_grad =
clad::differentiate<clad::opts::vector_mode>(weightedSum, "p, w");
constexpr int n = 5;

double inputs[n];
double weights[n];
for (int i = 0; i < n; ++i) {
inputs[i] = i + 1;
weights[i] = 1.0 / (double)(i + 1);
}

double dinp[n];
double dweights[n];
clad::array_ref<double> dinp_ref(dinp, n);
clad::array_ref<double> dweights_ref(dweights, n);

double sum = 0;
for (auto _ : state) {
vm_grad.execute(inputs, weights, n, dinp_ref, dweights_ref);
for (int i = 0; i < n; ++i)
sum += dinp[i] + dweights[i];
}
}
BENCHMARK(BM_VectorForwardModeWeightedSum);

// Expands to the google-benchmark-provided main() that runs all registered
// benchmarks in this translation unit.
BENCHMARK_MAIN();
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp) override;
};
} // namespace clad

Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ namespace clad {
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
virtual StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
/// Decl is not Stmt, so it cannot be visited directly.
StmtDiff VisitWhileStmt(const clang::WhileStmt* WS);
Expand Down
32 changes: 28 additions & 4 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
DiffParams args{};
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

auto fnName = m_Function->getNameAsString() + "_forw";
auto fnName = clad::utils::ComputeEffectiveFnName(m_Function) + "_forw";
auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -86,8 +86,6 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
QualType
ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand All @@ -106,7 +104,7 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) {

QualType effectiveReturnType =
m_Function->getReturnType().getNonReferenceType();

if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
const CXXRecordDecl* RD = MD->getParent();
if (MD->isInstance() && !RD->isLambda()) {
Expand Down Expand Up @@ -240,4 +238,30 @@ ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get();
return {newRS};
}

/// Differentiates a unary operator expression in the reverse-mode forward
/// pass. Handles dereference of `this` (for instance member functions) and
/// unary plus/minus; for any other opcode `diff` stays empty, so the rebuilt
/// operator is applied to a null sub-expression — presumably unsupported
/// opcodes never reach here; TODO confirm.
StmtDiff
ReverseModeForwPassVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
  auto opCode = UnOp->getOpcode();
  StmtDiff diff{};
  // NOTE(review): no post-increment/decrement handling exists in this
  // function; ResultRef is always null here, so the returned StmtDiff never
  // carries a result reference.
  Expr* ResultRef = nullptr;
  if (opCode == UnaryOperatorKind::UO_Deref) {
    // Only `*this`-style dereference inside an instance method gets special
    // treatment: return the cloned dereference together with the derivative
    // of the sub-expression directly (no wrapping BuildOp below).
    if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
      if (MD->isInstance()) {
        diff = Visit(UnOp->getSubExpr());
        Expr* cloneE = BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());
        Expr* derivedE = diff.getExpr_dx();
        return {cloneE, derivedE};
      }
    }
  } else if (opCode == UO_Plus)
    // d(+x) = +dx: propagate dfdx() unchanged.
    diff = Visit(UnOp->getSubExpr(), dfdx());
  else if (opCode == UO_Minus) {
    // d(-x) = -dx: negate the incoming adjoint before visiting the operand.
    auto d = BuildOp(UO_Minus, dfdx());
    diff = Visit(UnOp->getSubExpr(), d);
  }
  // Rebuild the original (primal) unary expression from the cloned operand.
  Expr* op = BuildOp(opCode, diff.getExpr());
  return StmtDiff(op, ResultRef);
}
} // namespace clad
54 changes: 37 additions & 17 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
// contribute to gradient.
if (!NArgs && !isa<CXXMemberCallExpr>(CE))
if ((NArgs == 0U) && !isa<CXXMemberCallExpr>(CE) &&
!isa<CXXOperatorCallExpr>(CE))
return StmtDiff(Clone(CE));

// Stores the call arguments for the function to be derived
Expand All @@ -1391,7 +1392,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derived function. In the case of member functions, `implicit`
// this object is always passed by reference.
if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMemberCallExpr>(CE)) {
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
Expand All @@ -1415,9 +1416,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
auto PVD = FD->getParamDecl(i);
const auto* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff{};
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
// We do not need to create result arg for arguments passed by reference
Expand Down Expand Up @@ -1597,8 +1603,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

/// Add base derivative expression in the derived call output args list if
/// `CE` is a call to an instance member function.
if (auto MCE = dyn_cast<CXXMemberCallExpr>(CE)) {
if (auto MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseDiff = Visit(MCE->getImplicitObjectArgument());
else if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE))
baseDiff = Visit(OCE->getArg(0));
if (baseDiff.getExpr()) {
StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
if (isInsideLoop) {
addToCurrentBlock(baseDiffStore.getExpr());
Expand Down Expand Up @@ -1689,15 +1698,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackCallArgs = DerivedCallArgs;

if (pullback)
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(),
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() -
static_cast<int>(isCXXOperatorCall),
pullback);

// Try to find it in builtin derivatives
std::string customPullback = FD->getNameAsString() + "_pullback";
if (baseDiff.getExpr())
pullbackCallArgs.insert(
pullbackCallArgs.begin(),
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
std::string customPullback =
clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
}

// should be true if we are using numerical differentiation to differentiate
Expand Down Expand Up @@ -1728,7 +1745,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;
pullbackRequest.BaseFunctionName = FD->getNameAsString();
pullbackRequest.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
Expand Down Expand Up @@ -1775,7 +1793,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
usingNumericalDiff = true;
}
} else if (pullbackFD) {
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD);
Expand Down Expand Up @@ -1861,7 +1879,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString();
calleeFnForwPassReq.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
calleeFnForwPassReq.VerboseDiags = true;
FunctionDecl* calleeFnForwPassFD =
plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq);
Expand All @@ -1878,20 +1897,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// We cannot reuse the derivatives previously computed because
// they might contain 'clad::pop(..)` expression.
if (isa<CXXMemberCallExpr>(CE)) {
if (isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE)) {
Expr* derivedBase = baseDiff.getExpr_dx();
// FIXME: We may need this if-block once we support pointers, and
// passing pointers-by-reference if
// (isCladArrayType(derivedBase->getType()))
// CallArgs.push_back(derivedBase);
// else
// Currently derivedBase `*d_this` can never be CladArrayType
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc));
}

for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const ParmVarDecl* PVD = FD->getParamDecl(i);
const ParmVarDecl* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff = Visit(arg);
if ((argDiff.getExpr_dx() != nullptr) &&
PVD->getType()->isReferenceType()) {
Expand All @@ -1906,7 +1929,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get());
}
if (isa<CXXMemberCallExpr>(CE)) {
if (isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE)) {
Expr* baseE = baseDiff.getExpr();
call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
CallArgs, calleeFnForwPassFD);
Expand Down Expand Up @@ -3154,9 +3177,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_Mode == DiffMode::reverse)
assert(yType->isRealType() &&
"yType should be a non-reference builtin-numerical scalar type!!");
else if (m_Mode == DiffMode::experimental_pullback)
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand Down
Loading

0 comments on commit 985097e

Please sign in to comment.