Skip to content

Commit

Permalink
Fix passing pointers as call arguments
Browse files Browse the repository at this point in the history
fixes #735, #636
  • Loading branch information
vaithak authored and vgvassilev committed Feb 15, 2024
1 parent 9f3556d commit d305002
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 14 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ namespace clad {
bool ContainsFunctionCalls(const clang::Stmt* E);

void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt);

bool IsLiteral(const clang::Expr* E);
} // namespace utils
} // namespace clad

Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ namespace clad {
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE);

/// A helper method to differentiate a single Stmt in the reverse mode.
/// Internally, calls Visit(S, expr). Its result is wrapped into a
Expand Down
7 changes: 7 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,5 +634,12 @@ namespace clad {
else
cast<DefaultStmt>(SC)->setSubStmt(subStmt);
}

bool IsLiteral(const clang ::Expr* E) {
return isa<IntegerLiteral>(E) || isa<FloatingLiteral>(E) ||
isa<CharacterLiteral>(E) || isa<StringLiteral>(E) ||
isa<ObjCBoolLiteralExpr>(E) || isa<CXXBoolLiteralExpr>(E) ||
isa<GNUNullExpr>(E);
}
} // namespace utils
} // namespace clad
38 changes: 29 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
SL->getType(), utils::GetValidSLoc(m_Sema)));
}

StmtDiff ReverseModeVisitor::VisitCXXNullPtrLiteralExpr(
const CXXNullPtrLiteralExpr* NPE) {
return StmtDiff(Clone(NPE), Clone(NPE));
}

StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
// Initially, df/df = 1.
const Expr* value = RS->getRetValue();
Expand Down Expand Up @@ -1360,7 +1365,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) {
return StmtDiff(Clone(IL));
auto* Constant0 =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
return StmtDiff(Clone(IL), Constant0);
}

StmtDiff ReverseModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) {
Expand Down Expand Up @@ -1461,22 +1468,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// argument by reference.
passByRef = false;
}
QualType argDiffType;
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
if (passByRef) {
argDiff = Visit(arg);
Expr* dArg = nullptr;
argDiffType = argDiff.getExpr()->getType();
QualType argResultValueType =
utils::GetValueType(argDiff.getExpr()->getType())
.getNonReferenceType();
utils::GetValueType(argDiffType).getNonReferenceType();
// Create ArgResult variable for each reference argument because it is
// required by error estimator. For automatic differentiation, we do not need
// to create ArgResult variable for arguments passed by reference.
// ```
// _r0 = _d_a;
// ```
Expr* dArg = nullptr;
if (utils::isArrayOrPointerType(argDiff.getExpr()->getType())) {
if (argDiff.getExpr_dx() && utils::IsLiteral(argDiff.getExpr_dx())) {
dArg = StoreAndRef(argDiff.getExpr_dx(), arg->getType(),
direction::reverse, "_r",
/*forceDeclCreation=*/true);
} else if (argDiffType->isArrayType()) {
Expr* init = argDiff.getExpr_dx();
if (isa<ConstantArrayType>(argDiff.getExpr_dx()->getType()))
init = utils::BuildCladArrayInitByConstArray(m_Sema,
Expand All @@ -1486,6 +1498,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
direction::reverse, "_r",
/*forceDeclCreation=*/true,
VarDecl::InitializationStyle::CallInit);
} else if (argDiffType->isPointerType()) {
dArg = StoreAndRef(argDiff.getExpr_dx(), argDiffType,
direction::reverse, "_r",
/*forceDeclCreation=*/true);
} else {
dArg = StoreAndRef(argDiff.getExpr_dx(), argResultValueType,
direction::reverse, "_r",
Expand All @@ -1511,6 +1527,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
cast<VarDecl>(cast<DeclRefExpr>(dArg)->getDecl()));
// Visit using uninitialized reference.
argDiff = Visit(arg, dArg);
argDiffType = argDiff.getExpr()->getType();
}

// FIXME: We may use same argDiff.getExpr_dx at two places. This can
Expand All @@ -1536,7 +1553,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: We cannot use GlobalStoreAndRef to store a whole array so now
// arrays are not stored.
StmtDiff argDiffStore;
if (passByRef && !argDiff.getExpr()->getType()->isArrayType())
if (passByRef && !argDiffType->isArrayType() &&
!argDiff.getExpr()->isEvaluatable(m_Context))
argDiffStore =
GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true);
else
Expand Down Expand Up @@ -1567,7 +1585,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// ```
// FIXME: We cannot use GlobalStoreAndRef to store a whole array so now
// arrays are not stored.
if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) {
if (passByRef && !argDiffType->isArrayType()) {
if (isInsideLoop) {
// Add tape push expression. We need to explicitly add it here because
// we cannot add it as call expression argument -- we need to pass the
Expand Down Expand Up @@ -1606,7 +1624,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// inside loop and outside loop cases separately.
Expr* newArgE = Visit(arg).getExpr();
argDiffStore = {newArgE, argDiffLocalE};
} else {
} else if (isa<DeclRefExpr>(argDiff.getExpr())) {
// Restore args
auto& block = getCurrentBlock(direction::reverse);
auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(),
Expand Down Expand Up @@ -1734,7 +1752,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
argDerivative = BuildDeclRef(derivativeArrayRefVD);
}
if ((argDerivative != nullptr) &&
isCladArrayType(argDerivative->getType()))
(isCladArrayType(argDerivative->getType()) ||
argDerivative->getType()->isPointerType() ||
!argDerivative->isLValue()))
gradArgExpr = argDerivative;
else
gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative);
Expand Down
4 changes: 2 additions & 2 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ double f(double *arr) {
//CHECK-NEXT: arr = _t0;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(_t0, 3, 1, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: double *_r0 = _d_arr;
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -473,7 +473,7 @@ double func8(double i, double *arr, int n) {
//CHECK-NEXT: helper2_pullback(i, _t2, n, _r_d1, &_grad0, _d_arr, &_grad2);
//CHECK-NEXT: double _r0 = _grad0;
//CHECK-NEXT: * _d_i += _r0;
//CHECK-NEXT: clad::array<double> _r1(_d_arr);
//CHECK-NEXT: double *_r1 = _d_arr;
//CHECK-NEXT: int _r2 = _grad2;
//CHECK-NEXT: * _d_n += _r2;
//CHECK-NEXT: }
Expand Down
35 changes: 32 additions & 3 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ double fn4(double* arr, int n) {
// CHECK-NEXT: arr = _t1;
// CHECK-NEXT: int _grad1 = 0;
// CHECK-NEXT: sum_pullback(_t1, n, _r_d0, _d_arr, &_grad1);
// CHECK-NEXT: clad::array<double> _r0(_d_arr);
// CHECK-NEXT: double *_r0 = _d_arr;
// CHECK-NEXT: int _r1 = _grad1;
// CHECK-NEXT: * _d_n += _r1;
// CHECK-NEXT: }
Expand Down Expand Up @@ -348,7 +348,7 @@ double fn5(double* arr, int n) {
// CHECK-NEXT: {
// CHECK-NEXT: arr = _t0;
// CHECK-NEXT: modify2_pullback(_t0, _d_temp, _d_arr);
// CHECK-NEXT: clad::array<double> _r0(_d_arr);
// CHECK-NEXT: double *_r0 = _d_arr;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down Expand Up @@ -497,7 +497,7 @@ double fn8(double x, double y) {
// CHECK-NEXT: double _r0 = _grad0;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: char _r1 = _grad1;
// CHECK-NEXT: clad::array<char> _r2({"", 3UL});
// CHECK-NEXT: const char *_r2 = "";
// CHECK-NEXT: * _d_y += _t3 * 1 * _t0 * _t1;
// CHECK-NEXT: }
// CHECK-NEXT: }
Expand Down Expand Up @@ -645,6 +645,33 @@ double fn11(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double do_nothing(double* u, double* v, double* w) {
return u[0];
}

// CHECK: void do_nothing_pullback(double *u, double *v, double *w, double _d_y, clad::array_ref<double> _d_u, clad::array_ref<double> _d_v, clad::array_ref<double> _d_w) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_u[0] += _d_y;
// CHECK-NEXT: }

double fn12(double x, double y) {
return do_nothing(&x, nullptr, 0);
}

// CHECK: void fn12_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double *_t0;
// CHECK-NEXT: _t0 = &x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: do_nothing_pullback(_t0, nullptr, 0, 1, &* _d_x, nullptr, 0);
// CHECK-NEXT: double *_r0 = &* _d_x;
// CHECK-NEXT: {{(std::)?}}nullptr_t _r1 = nullptr;
// CHECK-NEXT: double *_r2 = 0;
// CHECK-NEXT: }
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -701,6 +728,7 @@ int main() {
INIT(fn9);
INIT(fn10);
INIT(fn11);
INIT(fn12);

TEST1_float(fn1, 11); // CHECK-EXEC: {3.00}
TEST2(fn2, 3, 5); // CHECK-EXEC: {1.00, 3.00}
Expand All @@ -714,4 +742,5 @@ int main() {
TEST2(fn9, 3, 5); // CHECK-EXEC: {5.00, 3.00}
TEST2(fn10, 8, 5); // CHECK-EXEC: {0.00, 7.00}
TEST2(fn11, 3, 5); // CHECK-EXEC: {1.00, 1.00}
TEST2(fn12, 3, 5); // CHECK-EXEC: {1.00, 0.00}
}

0 comments on commit d305002

Please sign in to comment.