Skip to content

Commit

Permalink
Fix pointer arithmetic in fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Jul 16, 2024
1 parent ef668c7 commit 0792ff2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
22 changes: 15 additions & 7 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1460,13 +1460,13 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
// including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left
// nullptr is used for treating expressions like ((A && B) && C) correctly.
return StmtDiff(opDiff, nullptr);
}
if (!opDiff) {
} else {
// FIXME: add support for other binary operators
unsupportedOpWarn(BinOp->getEndLoc());
opDiff = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
}
opDiff = folder.fold(opDiff);
if (opDiff)
opDiff = folder.fold(opDiff);
// Recover the original operation from the Ldiff and Rdiff instead of
// cloning the tree.
Expr* op;
Expand Down Expand Up @@ -1499,11 +1499,19 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD,
// This may not necessarily be true in the future.
VarDecl* VDClone =
BuildVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
VD->isDirectInit(), /*TSI=*/nullptr, VD->getInitStyle());
// FIXME: Create unique identifier for derivative.
VarDecl* VDDerived = BuildVarDecl(
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
Expr* initDx = initDiff.getExpr_dx();
if (VD->getType()->isPointerType() && !initDx) {
// initialize with nullptr.
// NOLINTBEGIN(cppcoreguidelines-owned-memory)
initDx =
new (m_Context) CXXNullPtrLiteralExpr(VD->getType(), VD->getBeginLoc());
// NOLINTEND(cppcoreguidelines-owned-memory)
}
VarDecl* VDDerived =
BuildVarDecl(VD->getType(), "_d_" + VD->getNameAsString(), initDx,
VD->isDirectInit(), /*TSI=*/nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
return DeclDiff<VarDecl>(VDClone, VDDerived);
}
Expand Down
2 changes: 1 addition & 1 deletion test/ForwardMode/MemberFunctions.C
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ public:
// CHECK-NEXT: double _d_j = 0;
// CHECK-NEXT: SimpleFunctions _d_this_obj;
// CHECK-NEXT: SimpleFunctions *_d_this = &_d_this_obj;
// CHECK-NEXT: double *_d_p;
// CHECK-NEXT: double *_d_p = nullptr;
// CHECK-NEXT: double *p;
// CHECK-NEXT: _d_p = _d_this->arr[1];
// CHECK-NEXT: p = this->arr[1];
Expand Down
14 changes: 14 additions & 0 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ double fn9(double* params, const double *constants) {
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }

double fn10(double *params, const double *constants) {
double c0 = *(constants + 0);
return params[0] * c0;
}

// CHECK: double fn10_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0;
// CHECK-NEXT: double c0 = *(constants + 0);
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }

int main() {
INIT_DIFFERENTIATE(fn1, "i");
Expand Down Expand Up @@ -230,6 +240,10 @@ int main() {
auto fn9_dx = clad::differentiate(fn9, "params[0]");
d_param = fn9_dx.execute(params, constants);
printf("{%.2f}\n", d_param); // CHECK-EXEC: {5.00}

auto fn10_dx = clad::differentiate(fn10, "params[0]");
d_param = fn10_dx.execute(params, constants);
printf("{%.2f}\n", d_param); // CHECK-EXEC: {5.00}
}

// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
Expand Down

0 comments on commit 0792ff2

Please sign in to comment.