Skip to content

Commit

Permalink
Make integral type variables non-differentiable in the forward mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Apr 30, 2024
1 parent 6c6b123 commit 20b6dfc
Show file tree
Hide file tree
Showing 31 changed files with 582 additions and 747 deletions.
13 changes: 6 additions & 7 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ namespace custom_derivatives {
#ifdef __CUDACC__
template <typename T>
ValueAndPushforward<cudaError_t, cudaError_t>
cudaMalloc_pushforward(T** devPtr, size_t sz, T** d_devPtr, size_t d_sz)
cudaMalloc_pushforward(T** devPtr, size_t sz, T** d_devPtr)
__attribute__((host)) {
return {cudaMalloc(devPtr, sz), cudaMalloc(d_devPtr, sz)};
}

ValueAndPushforward<cudaError_t, cudaError_t>
cudaMemcpy_pushforward(void* destPtr, void* srcPtr, size_t count,
cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr,
size_t d_count) __attribute__((host)) {
cudaMemcpyKind kind, void* d_destPtr, void* d_srcPtr)
__attribute__((host)) {
return {cudaMemcpy(destPtr, srcPtr, count, kind),
cudaMemcpy(d_destPtr, d_srcPtr, count, kind)};
}
Expand Down Expand Up @@ -199,18 +199,17 @@ CUDA_HOST_DEVICE void clamp_pullback(const T& v, const T& lo, const T& hi,

// NOLINTBEGIN(cppcoreguidelines-no-malloc)
// NOLINTBEGIN(cppcoreguidelines-owning-memory)
inline ValueAndPushforward<void*, void*> malloc_pushforward(size_t sz,
size_t d_sz) {
inline ValueAndPushforward<void*, void*> malloc_pushforward(size_t sz) {
return {malloc(sz), malloc(sz)};
}

inline ValueAndPushforward<void*, void*>
calloc_pushforward(size_t n, size_t sz, size_t d_n, size_t d_sz) {
calloc_pushforward(size_t n, size_t sz) {
return {calloc(n, sz), calloc(n, sz)};
}

inline ValueAndPushforward<void*, void*>
realloc_pushforward(void* ptr, size_t sz, void* d_ptr, size_t d_sz) {
realloc_pushforward(void* ptr, size_t sz, void* d_ptr) {
return {realloc(ptr, sz), realloc(d_ptr, sz)};
}

Expand Down
25 changes: 16 additions & 9 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
}
}
}

// If clad failed to derive it, try finding its derivative using
// numerical diff.
if (!callDiff) {
Expand Down Expand Up @@ -1293,9 +1292,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
opDiff = BuildOp(opCode, derivedL, derivedR);
} else if (BinOp->isAssignmentOp()) {
if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
// If the LHS has a non-differentiable type, Ldiff.getExpr_dx() will be 0.
// Don't create a warning then.
if (IsDifferentiableType(BinOp->getLHS()->getType()))
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
opDiff = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
} else if (opCode == BO_Assign || opCode == BO_AddAssign ||
opCode == BO_SubAssign) {
Expand Down Expand Up @@ -1372,10 +1374,13 @@ BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
BuildVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
// FIXME: Create unique identifier for derivative.
VarDecl* VDDerived = BuildVarDecl(
VarDecl* VDDerived = nullptr;
if (IsDifferentiableType(VD->getType())) {
VDDerived = BuildVarDecl(
VD->getType(), "_d_" + VD->getNameAsString(), initDiff.getExpr_dx(),
VD->isDirectInit(), nullptr, VD->getInitStyle());
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
}
return DeclDiff<VarDecl>(VDClone, VDDerived);
}

Expand Down Expand Up @@ -1437,7 +1442,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();
decls.push_back(VDDiff.getDecl());
declsDiff.push_back(VDDiff.getDecl_dx());
if (VDDiff.getDecl_dx())
declsDiff.push_back(VDDiff.getDecl_dx());
} else if (auto* SAD = dyn_cast<StaticAssertDecl>(D)) {
DeclDiff<StaticAssertDecl> SADDiff = DifferentiateStaticAssertDecl(SAD);
if (SADDiff.getDecl())
Expand Down Expand Up @@ -1576,7 +1582,7 @@ StmtDiff BaseForwardModeVisitor::VisitWhileStmt(const WhileStmt* WS) {
// ...
// ...
// }
if (condVarClone) {
if (condVarRes.getDecl_dx()) {
bodyResult = utils::PrependAndCreateCompoundStmt(
m_Sema.getASTContext(), cast<CompoundStmt>(bodyResult),
BuildDeclStmt(condVarRes.getDecl_dx()));
Expand Down Expand Up @@ -1655,7 +1661,8 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
if (condVarDecl) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
if (condVarDeclDiff.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()));
}

StmtDiff initVarRes = (SS->getInit() ? Visit(SS->getInit()) : StmtDiff());
Expand Down
58 changes: 36 additions & 22 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,30 +892,40 @@ namespace clad {
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get();

llvm::SmallVector<Expr*, 16> ExtendedCallArgs(CallArgs.begin(),
CallArgs.end());
llvm::SmallVector<Expr*, 16> ExtendedCallArgs;
llvm::SmallVector<Stmt*, 16> DeclStmts;
// FIXME: for now, integer types are considered differentiable in the
// forward mode.
if (m_Mode != DiffMode::forward &&
m_Mode != DiffMode::vector_forward_mode &&
m_Mode != DiffMode::experimental_pushforward)
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!IsDifferentiableType(paramTy)) {
QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema);
VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy));
Expr* arg = BuildDeclRef(argDecl);
if (!utils::isArrayOrPointerType(argTy))
arg = BuildOp(UO_AddrOf, arg);
ExtendedCallArgs.insert(ExtendedCallArgs.begin() + e + i + 1, arg);
DeclStmts.push_back(BuildDeclStmt(argDecl));
auto MARargs = llvm::MutableArrayRef<Expr*>(CallArgs);
if (noOverloadExists(UnresolvedLookup, MARargs)) {
bool isMethodCall = isa<CXXMethodDecl>(originalFD);
ExtendedCallArgs = llvm::SmallVector<Expr*, 16>(CallArgs.begin(), CallArgs.end());
if (m_Mode != DiffMode::forward &&
m_Mode != DiffMode::vector_forward_mode &&
m_Mode != DiffMode::experimental_pushforward)
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!IsDifferentiableType(paramTy)) {
QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema);
VarDecl* argDecl = BuildVarDecl(argTy, "_r", getZeroInit(argTy));
Expr* arg = BuildDeclRef(argDecl);
if (!utils::isArrayOrPointerType(argTy))
arg = BuildOp(UO_AddrOf, arg);
ExtendedCallArgs.insert(ExtendedCallArgs.begin() + e + i + 1 + 2 * isMethodCall, arg);
DeclStmts.push_back(BuildDeclStmt(argDecl));
}
}
}
auto MARargs = llvm::MutableArrayRef<Expr*>(ExtendedCallArgs);

if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;
else
for (size_t i = 0, e = originalFD->getNumParams(); i < e; ++i) {
QualType paramTy = originalFD->getParamDecl(i)->getType();
if (!IsDifferentiableType(paramTy)) {
QualType argTy = utils::getNonConstType(paramTy, m_Context, m_Sema);
Expr* zero = getZeroInit(argTy);
ExtendedCallArgs.insert(ExtendedCallArgs.begin() + e + i + 2 * isMethodCall, zero);
}
}
MARargs = llvm::MutableArrayRef<Expr*>(ExtendedCallArgs);
if (noOverloadExists(UnresolvedLookup, MARargs))
return nullptr;
}

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, noLoc, MARargs, noLoc)
Expand Down Expand Up @@ -958,6 +968,10 @@ namespace clad {
return true;
}
}
return false;
} else if (const auto* DRE = dyn_cast<DeclRefExpr>(UnresolvedLookup)) {
const auto* FD = cast<FunctionDecl>(DRE->getDecl());
return FD->getNumParams() != ARargs.size();
}
return false;
}
Expand Down
30 changes: 9 additions & 21 deletions test/Arrays/ArrayInputsForwardMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@ double addArr(const double *arr, int n) {
}

//CHECK: double addArr_darg0_1(const double *arr, int n) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: _d_ret += (i == 1.);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: _d_ret += (i == 1.);
//CHECK-NEXT: ret += arr[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_ret;
//CHECK-NEXT: }
Expand All @@ -59,25 +55,17 @@ double numMultIndex(double* arr, size_t n, double x) {
}

// CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) {
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: bool _d_flag = 0;
// CHECK-NEXT: bool flag = false;
// CHECK-NEXT: size_t _d_idx = 0;
// CHECK-NEXT: size_t idx = 0;
// CHECK-NEXT: {
// CHECK-NEXT: size_t _d_i = 0;
// CHECK-NEXT: for (size_t i = 0; i < n; ++i) {
// CHECK-NEXT: if (arr[i] == x) {
// CHECK-NEXT: _d_flag = 0;
// CHECK-NEXT: flag = true;
// CHECK-NEXT: _d_idx = _d_i;
// CHECK-NEXT: idx = i;
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: for (size_t i = 0; i < n; ++i) {
// CHECK-NEXT: if (arr[i] == x) {
// CHECK-NEXT: flag = true;
// CHECK-NEXT: idx = i;
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return flag ? _d_idx * x + idx * _d_x : 0;
// CHECK-NEXT: return flag ? 0 * x + idx * _d_x : 0;
// CHECK-NEXT: }

int main() {
Expand Down
27 changes: 9 additions & 18 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@ double sum(double x, double y, double z) {
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double _d_s = 0;
//CHECK-NEXT: double s = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_vars[i];
//CHECK-NEXT: s = s + vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_vars[i];
//CHECK-NEXT: s = s + vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_s;
//CHECK-NEXT: }
Expand All @@ -55,21 +52,15 @@ double sum_squares(double x, double y, double z) {
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double _d_squares[3];
//CHECK-NEXT: double squares[3];
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_squares[i] = _d_vars[i] * vars[i] + vars[i] * _d_vars[i];
//CHECK-NEXT: squares[i] = vars[i] * vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_squares[i] = _d_vars[i] * vars[i] + vars[i] * _d_vars[i];
//CHECK-NEXT: squares[i] = vars[i] * vars[i];
//CHECK-NEXT: }
//CHECK-NEXT: double _d_s = 0;
//CHECK-NEXT: double s = 0;
//CHECK-NEXT: {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_squares[i];
//CHECK-NEXT: s = s + squares[i];
//CHECK-NEXT: }
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _d_s = _d_s + _d_squares[i];
//CHECK-NEXT: s = s + squares[i];
//CHECK-NEXT: }
//CHECK-NEXT: return _d_s;
//CHECK-NEXT: }
Expand Down
Loading

0 comments on commit 20b6dfc

Please sign in to comment.