Skip to content

Commit

Permalink
Fix gradient of fxns with const reference parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Oct 25, 2023
1 parent 1ae1f75 commit 8bda639
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
24 changes: 20 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
StmtDiff argDiff{};
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
if (passByRef && isa<MaterializeTemporaryExpr>(arg)) {
// If the argument is a temporary variable, this means that param type
// is a reference to a const type and we are passing a temporary
// variable to it. In this case, we should not pass the derivative
// argument by reference.
passByRef = false;
}
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
Expand Down Expand Up @@ -1498,7 +1505,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// same as the call expression as it is the type used to declare the
// _gradX array
Expr* dArg;
dArg = StoreAndRef(/*E=*/nullptr, arg->getType(), direction::reverse, "_r",
QualType argType = utils::GetValueType(arg->getType());
dArg = StoreAndRef(/*E=*/nullptr, argType, direction::reverse, "_r",
/*forceDeclCreation=*/true);
ArgResultDecls.push_back(
cast<VarDecl>(cast<DeclRefExpr>(dArg)->getDecl()));
Expand Down Expand Up @@ -1673,6 +1681,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto PVD = FD->getParamDecl(idx);
bool passByRef = utils::IsReferenceOrPointerType(PVD->getType());
if (passByRef && isa<MaterializeTemporaryExpr>(CE->getArg(idx))) {
// If the argument is a temporary variable, this means that param type
// is a reference to a const type and we are passing a temporary
// variable to it. In this case, we should not pass the derivative
// argument by reference.
passByRef = false;
}
if (passByRef) {
// If derivative type is constant array type instead of
// `clad::array_ref` or `clad::array` type, then create an
Expand Down Expand Up @@ -1700,13 +1715,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else {
// Declare: diffArgType _grad;
Expr* initVal = nullptr;
if (!PVD->getType()->isRecordType()) {
QualType gradVarType = utils::GetValueType(PVD->getType());
if (!gradVarType->isRecordType()) {
// If the argument is not a class type, then initialize the grad
// variable with 0.
initVal =
ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0);
ConstantFolder::synthesizeLiteral(gradVarType, m_Context, 0);
}
gradVarDecl = BuildVarDecl(PVD->getType(), gradVarII, initVal);
gradVarDecl = BuildVarDecl(gradVarType, gradVarII, initVal);
// Pass the address of the declared variable
gradVarExpr = BuildDeclRef(gradVarDecl);
gradArgExpr =
Expand Down
44 changes: 44 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,48 @@ double fn8(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double custom_max(const double& a, const double& b) {
return a > b ? a : b;
}

// CHECK: void custom_max_pullback(const double &a, const double &b, double _d_y, clad::array_ref<double> _d_a, clad::array_ref<double> _d_b) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: _cond0 = a > b;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: * _d_a += _d_y;
// CHECK-NEXT: else
// CHECK-NEXT: * _d_b += _d_y;
// CHECK-NEXT: }

double fn9(double x, double y) {
return custom_max(x*y, y);
}

// CHECK: void fn9_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: _t1 = x;
// CHECK-NEXT: _t0 = y;
// CHECK-NEXT: _t2 = _t1 * _t0;
// CHECK-NEXT: _t3 = y;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _grad0 = 0.;
// CHECK-NEXT: custom_max_pullback(_t2, _t3, 1, &_grad0, &* _d_y);
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: double _r1 = _r0 * _t0;
// CHECK-NEXT: * _d_x += _r1;
// CHECK-NEXT: double _r2 = _t1 * _r0;
// CHECK-NEXT: * _d_y += _r2;
// CHECK-NEXT: double _r3 = * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -544,6 +586,7 @@ int main() {
INIT(fn6);
INIT(fn7);
INIT(fn8);
INIT(fn9);

TEST1_float(fn1, 11); // CHECK-EXEC: {3.00}
TEST2(fn2, 3, 5); // CHECK-EXEC: {1.00, 3.00}
Expand All @@ -554,4 +597,5 @@ int main() {
TEST2(fn6, 3, 5); // CHECK-EXEC: {5.00, 3.00}
TEST2(fn7, 3, 5); // CHECK-EXEC: {10.00, 71.00}
TEST2(fn8, 3, 5); // CHECK-EXEC: {7.62, 4.57}
TEST2(fn9, 3, 5); // CHECK-EXEC: {5.00, 3.00}
}

0 comments on commit 8bda639

Please sign in to comment.