Skip to content

Commit

Permalink
add support for operator overload in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 authored and PhrygianGates committed Aug 30, 2023
1 parent bba8cb1 commit 75508b2
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 27 deletions.
6 changes: 2 additions & 4 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
DiffParams args{};
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

auto fnName = m_Function->getNameAsString() + "_forw";
auto fnName = clad::utils::ComputeEffectiveFnName(m_Function) + "_forw";
auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -86,8 +86,6 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
QualType
ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand All @@ -106,7 +104,7 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) {

QualType effectiveReturnType =
m_Function->getReturnType().getNonReferenceType();

if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
const CXXRecordDecl* RD = MD->getParent();
if (MD->isInstance() && !RD->isLambda()) {
Expand Down
60 changes: 39 additions & 21 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
// contribute to gradient.
if (!NArgs && !isa<CXXMemberCallExpr>(CE))
if ((NArgs == 0U) && !isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE))
return StmtDiff(Clone(CE));

// Stores the call arguments for the function to be derived
Expand All @@ -1391,7 +1391,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derived function. In the case of member functions, `implicit`
// this object is always passed by reference.
if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMemberCallExpr>(CE)) {
!isa<CXXMemberCallExpr>(CE) && !isa<CXXOperatorCallExpr>(CE)) {
for (const Expr* Arg : CE->arguments()) {
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
Expand All @@ -1415,9 +1415,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall), e = CE->getNumArgs(); i != e; ++i) {
const Expr* arg = CE->getArg(i);
auto PVD = FD->getParamDecl(i);
const auto *PVD = FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff{};
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
// We do not need to create result arg for arguments passed by reference
Expand Down Expand Up @@ -1599,6 +1601,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// `CE` is a call to an instance member function.
if (auto MCE = dyn_cast<CXXMemberCallExpr>(CE)) {
baseDiff = Visit(MCE->getImplicitObjectArgument());
}
else if (const auto *OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
baseDiff = Visit(OCE->getArg(0));
}
if (baseDiff.getExpr()) {
StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
if (isInsideLoop) {
addToCurrentBlock(baseDiffStore.getExpr());
Expand Down Expand Up @@ -1689,15 +1696,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackCallArgs = DerivedCallArgs;

if (pullback)
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(),
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() - static_cast<int>(isCXXOperatorCall),
pullback);

// Try to find it in builtin derivatives
std::string customPullback = FD->getNameAsString() + "_pullback";
if (baseDiff.getExpr()) {
pullbackCallArgs.insert(pullbackCallArgs.begin(), BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr()));
}
std::string customPullback = clad::utils::ComputeEffectiveFnName(FD) + "_pullback";
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
if (baseDiff.getExpr()) {
pullbackCallArgs.erase(pullbackCallArgs.begin());
}
}

// should be true if we are using numerical differentiation to differentiate
Expand Down Expand Up @@ -1728,7 +1741,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// derive the called function.
DiffRequest pullbackRequest{};
pullbackRequest.Function = FD;
pullbackRequest.BaseFunctionName = FD->getNameAsString();
pullbackRequest.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD);
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
Expand Down Expand Up @@ -1775,7 +1788,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
usingNumericalDiff = true;
}
} else if (pullbackFD) {
if (isa<CXXMemberCallExpr>(CE)) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD);
Expand Down Expand Up @@ -1861,7 +1874,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString();
calleeFnForwPassReq.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD);
calleeFnForwPassReq.VerboseDiags = true;
FunctionDecl* calleeFnForwPassFD =
plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq);
Expand All @@ -1878,20 +1891,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// We cannot reuse the derivatives previously computed because
// they might contain 'clad::pop(..)` expression.
if (isa<CXXMemberCallExpr>(CE)) {
if (isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE)) {
Expr* derivedBase = baseDiff.getExpr_dx();
// FIXME: We may need this if-block once we support pointers, and
// passing pointers-by-reference if
// (isCladArrayType(derivedBase->getType()))
// CallArgs.push_back(derivedBase);
// else
// derivedBase `*d_this` can never be CladArrayType
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc));
}

for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall), e = CE->getNumArgs(); i != e; ++i) {
const Expr* arg = CE->getArg(i);
const ParmVarDecl* PVD = FD->getParamDecl(i);
const ParmVarDecl* PVD = FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff = Visit(arg);
if ((argDiff.getExpr_dx() != nullptr) &&
PVD->getType()->isReferenceType()) {
Expand All @@ -1906,7 +1915,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get());
}
if (isa<CXXMemberCallExpr>(CE)) {
if (isa<CXXMemberCallExpr>(CE) || isa<CXXOperatorCallExpr>(CE)) {
Expr* baseE = baseDiff.getExpr();
call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
CallArgs, calleeFnForwPassFD);
Expand Down Expand Up @@ -1993,6 +2002,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}
}

if (opCode == UnaryOperatorKind::UO_Deref && m_Mode == DiffMode::reverse_mode_forward_pass) {
if (const auto *MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (MD->isInstance()) {
diff = Visit(UnOp->getSubExpr());
Expr* cloneE =
BuildOp(UnaryOperatorKind::UO_Deref, diff.getExpr());
Expr* derivedE = diff.getExpr_dx();
return {cloneE, derivedE};
}
}
}
// We should not output any warning on visiting boolean conditions
// FIXME: We should support boolean differentiation or ignore it
// completely
Expand Down Expand Up @@ -3154,9 +3175,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_Mode == DiffMode::reverse)
assert(yType->isRealType() &&
"yType should be a non-reference builtin-numerical scalar type!!");
else if (m_Mode == DiffMode::experimental_pullback)
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
// derivative variables should always be of non-const type.
xValueType.removeLocalConst();
Expand Down
51 changes: 49 additions & 2 deletions test/Gradient/MemberFunctions.C
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,10 @@ public:
// CHECK-NEXT: }

double& ref_mem_fn(double i) {return x;}
SimpleFunctions& operator+=(double value) {
x += value;
return *this;
}

void mem_fn_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j);
void const_mem_fn_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j);
Expand Down Expand Up @@ -783,6 +787,44 @@ double fn2(SimpleFunctions& sf, double i) {
// CHECK-NEXT: _t1.ref_mem_fn_pullback(_t0, 1, &(* _d_sf), &_grad0);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_i += _r0;

double fn3(SimpleFunctions& v, double value) {
v += value;
return v.x;
}

// CHECK: void operator_plus_equal_pullback(double value, SimpleFunctions _d_y, clad::array_ref<SimpleFunctions> _d_this, clad::array_ref<double> _d_value) {
// CHECK-NEXT: this->x += value;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: double _r_d0 = (* _d_this).x;
// CHECK-NEXT: (* _d_this).x += _r_d0;
// CHECK-NEXT: * _d_value += _r_d0;
// CHECK-NEXT: (* _d_this).x -= _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> operator_plus_equal_forw(double value, clad::array_ref<SimpleFunctions> _d_this, clad::array_ref<SimpleFunctions> _d_value) {
// CHECK-NEXT: this->x += value;
// CHECK-NEXT: return {*this, (* _d_this)};
// CHECK-NEXT: }

// CHECK: void fn3_grad(SimpleFunctions &v, double value, clad::array_ref<SimpleFunctions> _d_v, clad::array_ref<double> _d_value) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: SimpleFunctions _t1;
// CHECK-NEXT: _t0 = value;
// CHECK-NEXT: _t1 = v;
// CHECK-NEXT: clad::ValueAndAdjoint<SimpleFunctions &, SimpleFunctions &> _t2 = _t1.operator_plus_equal_forw(_t0, &(* _d_v), nullptr);
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: (* _d_v).x += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: _t1.operator_plus_equal_pullback(_t0, {}, &(* _d_v), &_grad0);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_value += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down Expand Up @@ -821,12 +863,17 @@ int main() {
printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00
}

SimpleFunctions sf(2, 3);
SimpleFunctions sf1(2, 3), sf2(3, 4);
SimpleFunctions d_sf;

auto d_fn2 = clad::gradient(fn2);
d_fn2.execute(sf, 2, &d_sf, &result[0]);
d_fn2.execute(sf1, 2, &d_sf, &result[0]);
printf("%.2f", result[0]); //CHECK-EXEC: 40.00

auto d_fn3 = clad::gradient(fn3);
d_fn3.execute(sf2, 3, &d_sf, &result[0]);
printf("%.2f", result[0]); //CHECK-EXEC: 42.00

auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i");

// CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref<volatile SimpleFunctions> _d_this, clad::array_ref<double> _d_i) const volatile & {
Expand Down

0 comments on commit 75508b2

Please sign in to comment.