Skip to content

Commit

Permalink
Initialize object adjoints using a copy of the original and clad::zer…
Browse files Browse the repository at this point in the history
…o_init
  • Loading branch information
PetroZarytskyi committed Dec 20, 2024
1 parent 3e50707 commit 7004010
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 68 deletions.
62 changes: 49 additions & 13 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
#include "Tape.h"

#include <assert.h>
#include <stddef.h>
#include <cstddef>
#include <cstring>
#include <iterator>
#include <type_traits>

namespace clad {

Expand Down Expand Up @@ -82,23 +84,57 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
}

/// The purpose of this function is to initialize adjoints
/// (or all of its differentiable fields) with 0.
// FIXME: Add support for objects.
/// Initialize a non-array variable.
template <typename T> CUDA_HOST_DEVICE void zero_init(T& x) { new (&x) T(); }
/// (or all of its iterable elements) with 0.
namespace zero_init_detail {
/// Local stand-in for std::iterator_traits. The void-pointer specializations
/// are intentionally empty: they expose no `value_type`, so any check that
/// needs one is discarded by SFINAE.
template <class Iter> struct iterator_traits : std::iterator_traits<Iter> {};
template <> struct iterator_traits<void*> {};
template <> struct iterator_traits<const void*> {};

/// Declaration only (it is used solely inside decltype). Yields
/// integral_constant<bool, true> unless the iterator's value_type is the
/// cv-unqualified T itself.
template <class T, class It>
auto is_range_check(It first, It last) -> std::integral_constant<
    bool, !std::is_same<typename std::remove_cv<T>::type,
                        typename iterator_traits<It>::value_type>::value>;

/// Preferred overload (an `int` argument beats the ellipsis). It is viable
/// only when std::begin/std::end accept a `const T&` and is_range_check is
/// well-formed for the resulting iterator type.
template <class T>
auto is_range(int)
    -> decltype(is_range_check<T>(std::begin(std::declval<const T&>()),
                                  std::end(std::declval<const T&>())));
/// Fallback overload: selected when the overload above is removed by SFINAE.
template <class T> auto is_range(...) -> std::false_type;
} // namespace zero_init_detail

/// Compile-time trait: `is_range<T>::value` is true when T can be iterated
/// with std::begin/std::end and its elements are not of type T itself.
template <class T>
struct is_range : decltype(zero_init_detail::is_range<T>(0)) {};

/// Forward declaration: the range overload of zero_impl calls zero_init on
/// every element before the definition of zero_init has been seen.
template <class T> void zero_init(T& t);

/// Zeroes a single object that is_range does not classify as a range, by
/// overwriting all sizeof(T) bytes of its storage with zeros.
/// \param[out] t the object to clear. It is taken as `volatile T&` so that
/// volatile-qualified objects bind as well.
// NOTE(review): unlike zero_init(T*, std::size_t) this overload is not marked
// CUDA_HOST_DEVICE -- confirm whether device-side callers are expected.
template <class T,
          typename std::enable_if<!is_range<T>::value, int>::type = 0>
void zero_impl(volatile T& t) {
  // Source buffer holding sizeof(T) zero bytes.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
  unsigned char zeros[sizeof(T)] = {};
  // The volatile qualifier is cast away so the address can be handed to
  // memcpy; C++ has deprecated volatile qualifiers, but codebases that still
  // use them must keep working.
  T* dest = const_cast<T*>(&t);
  // memcpy is used on purpose: it can implicitly create objects in the
  // destination region of storage immediately prior to copying the sequence
  // of characters to the destination [27.5.1(3)].
  std::memcpy(dest, zeros, sizeof(zeros));
}

/// Initialize a non-const sized array when the size is known and is equal to
/// N.
template <typename T> CUDA_HOST_DEVICE void zero_init(T* x, std::size_t N) {
for (std::size_t i = 0; i < N; ++i)
zero_init(x[i]);
/// Zeroes a range element by element. Each element is passed to zero_init,
/// so an element that is itself a range is zeroed element-wise in turn.
/// \param[out] t the range whose elements are cleared in place.
template <class T, typename std::enable_if<is_range<T>::value, int>::type = 0>
void zero_impl(T& t) {
  for (auto& element : t) {
    zero_init(element);
  }
}

/// Entry point for zeroing an adjoint: forwards to the zero_impl overload
/// selected for T (element-wise for ranges, byte-wise otherwise).
template <class T>
void zero_init(T& t) {
  zero_impl(t);
}

/// Initialize a const sized array.
// NOLINTBEGIN(cppcoreguidelines-avoid-c-arrays)
template <typename T, std::size_t N>
CUDA_HOST_DEVICE void zero_init(T (&arr)[N]) {
zero_init((T*)arr, N);
/// Zeroes the first N elements of the array pointed to by x by calling
/// zero_init on each element in index order.
/// \param[out] x pointer to the first element.
/// \param[in] N number of elements to clear.
template <typename T> CUDA_HOST_DEVICE void zero_init(T* x, std::size_t N) {
  std::size_t idx = 0;
  while (idx != N) {
    zero_init(x[idx]);
    ++idx;
  }
}
// NOLINTEND(cppcoreguidelines-avoid-c-arrays)

Expand Down
12 changes: 0 additions & 12 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,18 +451,6 @@ void at_pullback(::std::vector<T>* vec,
(*d_vec)[idx] += d_y;
}

template <typename T, typename S, typename U>
::clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
S count, U val,
typename ::std::vector<T>::allocator_type alloc,
S d_count, U d_val,
typename ::std::vector<T>::allocator_type d_alloc) {
::std::vector<T> v(count, val);
::std::vector<T> d_v(count, 0);
return {v, d_v};
}

template <typename T, typename S, typename U>
void constructor_pullback(::std::vector<T>* v, S count, U val,
typename ::std::vector<T>::allocator_type alloc,
Expand Down
99 changes: 92 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/Type.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Basic/TokenKinds.h"
#include "clang/Sema/Lookup.h"
Expand All @@ -32,6 +33,7 @@
#include <clang/AST/DeclCXX.h>
#include <clang/AST/ExprCXX.h>
#include <clang/AST/OperationKinds.h>
#include <clang/Basic/SourceLocation.h>
#include <clang/Sema/Ownership.h>

#include "llvm/ADT/SmallString.h"
Expand Down Expand Up @@ -2762,6 +2764,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

bool isConstructInit =
VD->getInit() && isa<CXXConstructExpr>(VD->getInit()->IgnoreImplicit());
const CXXRecordDecl* RD = VD->getType()->getAsCXXRecordDecl();
bool isNonAggrClass = RD && !RD->isAggregate();
bool emptyInitListInit = isNonAggrClass;

// VDDerivedInit now serves two purposes -- as the initial derivative value
// or the size of the derivative array -- depending on the primal type.
Expand Down Expand Up @@ -2814,8 +2819,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_TrackConstructorPullbackInfo = false;
constructorPullbackInfo = getConstructorPullbackCallInfo();
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx())
if (initDiff.getForwSweepExpr_dx()) {
VDDerivedInit = initDiff.getForwSweepExpr_dx();
emptyInitListInit = false;
}
}

// FIXME: Remove the special cases introduced by `specialThisDiffCase`
Expand Down Expand Up @@ -2864,7 +2871,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr, VD->getInitStyle());

if (!m_DiffReq.shouldHaveAdjoint((VD)))
if (!m_DiffReq.shouldHaveAdjoint(VD))
VDDerived = nullptr;

// If `VD` is a reference to a local variable, then it is already
Expand Down Expand Up @@ -2899,11 +2906,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (VDDerived && isInsideLoop) {
Stmt* assignToZero = nullptr;
Expr* declRef = BuildDeclRef(VDDerived);
if (!isa<ArrayType>(VDDerivedType))
if (isa<ArrayType>(VDDerivedType) || isNonAggrClass)
assignToZero = GetCladZeroInit(declRef);
else
assignToZero = BuildOp(BinaryOperatorKind::BO_Assign, declRef,
getZeroInit(VDDerivedType));
else
assignToZero = GetCladZeroInit(declRef);
if (!keepLocal)
addToCurrentBlock(assignToZero, direction::reverse);
}
Expand Down Expand Up @@ -2946,6 +2953,43 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(),
nullptr, VD->getInitStyle());

// We initialize adjoints with original variables as part of
// the strategy to maintain the structure of the original variable.
// After that, we'll zero-initialize the adjoint. e.g.
// ```
// std::vector<...> v{x, y, z};
// std::vector<...> _d_v{v}; // The length of the vector is preserved
// clad::zero_init(_d_v);
// ```
// Also, if the original is initialized with a zero-constructor, it can be
// used for the adjoint as well.
if (isConstructInit && emptyInitListInit &&
cast<CXXConstructExpr>(VD->getInit()->IgnoreImplicit())->getNumArgs() !=
0) {
Expr* copyExpr = BuildDeclRef(VDClone);
QualType origTy = VDClone->getType();
// if VDClone is volatile, we have to use const_cast to be able to use
// most copy constructors.
if (origTy.isVolatileQualified()) {
Qualifiers quals(origTy.getQualifiers());
quals.removeVolatile();
QualType castTy = m_Sema.BuildQualifiedType(origTy.getUnqualifiedType(),
noLoc, quals);
castTy = m_Context.getLValueReferenceType(castTy);
SourceRange range = utils::GetValidSRange(m_Sema);
copyExpr =
m_Sema
.BuildCXXNamedCast(noLoc, tok::kw_const_cast,
m_Context.getTrivialTypeSourceInfo(
castTy, utils::GetValidSLoc(m_Sema)),
copyExpr, range, range)
.get();
}
m_Sema.AddInitializerToDecl(VDDerived, copyExpr, /*DirectInit=*/true);
VDDerived->setInitStyle(VarDecl::InitializationStyle::CallInit);
}

if (isPointerType && derivedVDE) {
if (promoteToFnScope) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
Expand Down Expand Up @@ -3055,6 +3099,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVector<Stmt*, 16> inits;
llvm::SmallVector<Decl*, 4> decls;
llvm::SmallVector<Decl*, 4> declsDiff;
llvm::SmallVector<Decl*, 4> classDeclsDiff;
llvm::SmallVector<Stmt*, 4> memsetCalls;
// Need to put array decls inlined.
llvm::SmallVector<Decl*, 4> localDeclsDiff;
Expand Down Expand Up @@ -3143,9 +3188,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

decls.push_back(VDDiff.getDecl());
if (VDDiff.getDecl_dx()) {
const CXXRecordDecl* RD = VD->getType()->getAsCXXRecordDecl();
bool isNonAggrClass = RD && !RD->isAggregate();
if (isa<VariableArrayType>(VD->getType()))
localDeclsDiff.push_back(VDDiff.getDecl_dx());
else {
else if (isNonAggrClass) {
classDeclsDiff.push_back(VDDiff.getDecl_dx());
} else {
VarDecl* VDDerived = VDDiff.getDecl_dx();
declsDiff.push_back(VDDerived);
if (Stmt* memsetCall = CheckAndBuildCallToMemset(
Expand Down Expand Up @@ -3220,7 +3269,43 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToBlock(BuildDeclStmt(decl), m_Globals);
Stmt* initAssignments = MakeCompoundStmt(inits);
initAssignments = utils::unwrapIfSingleStmt(initAssignments);
return StmtDiff(initAssignments);
DSClone = initAssignments;
}

if (!classDeclsDiff.empty()) {
Stmts& block =
promoteToFnScope ? m_Globals : getCurrentBlock(direction::forward);
addToBlock(DSClone, block);
DSClone = nullptr;
addToBlock(BuildDeclStmt(classDeclsDiff), block);
for (Decl* decl : classDeclsDiff) {
auto* vDecl = cast<VarDecl>(decl);
Expr* init = vDecl->getInit();
if (promoteToFnScope && init) {
auto* declRef = BuildDeclRef(vDecl);
auto* assignment = BuildOp(BO_Assign, declRef, init);
addToCurrentBlock(assignment, direction::forward);
m_Sema.AddInitializerToDecl(vDecl, /*init=*/nullptr,
/*DirectInit=*/true);
}
// Adjoints are initialized with copy-constructors only as a part of
// the strategy to maintain the structure of the original variable.
// In such cases, we need to zero-initialize the adjoint. e.g.
// ```
// std::vector<...> v{x, y, z};
// std::vector<...> _d_v{v};
// clad::zero_init(_d_v); // this line is generated below
// ```
const auto* CE = dyn_cast<CXXConstructExpr>(init->IgnoreImplicit());
bool copyInit =
CE && (CE->getNumArgs() == 0 ||
isa<DeclRefExpr>(CE->getArg(0)->IgnoreImplicit()));
if (copyInit) {
std::array<Expr*, 1> arg{BuildDeclRef(vDecl)};
Stmt* initCall = GetCladZeroInit(arg);
addToCurrentBlock(initCall, direction::forward);
}
}
}

return StmtDiff(DSClone);
Expand Down
6 changes: 4 additions & 2 deletions test/Gradient/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,9 @@ int main() {
// CHECK-EXEC: 54.00 42.00

// CHECK: void CallFunctor_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: Experiment _d_E({});
// CHECK-NEXT: Experiment E(3, 5);
// CHECK-NEXT: Experiment _d_E(E);
// CHECK-NEXT: clad::zero_init(_d_E);
// CHECK-NEXT: Experiment _t0 = E;
// CHECK-NEXT: {
// CHECK-NEXT: double _r2 = 0.;
Expand Down Expand Up @@ -265,8 +266,9 @@ int main() {
// CHECK: void FunctorAsArg_pullback(Experiment fn, double i, double j, double _d_y, Experiment *_d_fn, double *_d_i, double *_d_j);

// CHECK: void FunctorAsArgWrapper_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: Experiment _d_E({});
// CHECK-NEXT: Experiment E(3, 5);
// CHECK-NEXT: Experiment _d_E(E);
// CHECK-NEXT: clad::zero_init(_d_E);
// CHECK-NEXT: {
// CHECK-NEXT: Experiment _r2 = {};
// CHECK-NEXT: double _r3 = 0.;
Expand Down
15 changes: 8 additions & 7 deletions test/Gradient/MemberFunctions.C
Original file line number Diff line number Diff line change
Expand Up @@ -528,17 +528,17 @@ double fn6(double u, double v) {
// CHECK-NEXT: double &_d_w = *_d_u;
// CHECK-NEXT: double &w = u;
// CHECK-NEXT: clad::ValueAndAdjoint<SafeTestClass, SafeTestClass> _t0 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag<SafeTestClass>());
// CHECK-NEXT: SafeTestClass _d_s1(_t0.adjoint);
// CHECK-NEXT: SafeTestClass s1(_t0.value);
// CHECK-NEXT: SafeTestClass _d_s1(_t0.adjoint);
// CHECK-NEXT: clad::ValueAndAdjoint<SafeTestClass, SafeTestClass> _t1 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag<SafeTestClass>(), u, &v, *_d_u, &*_d_v);
// CHECK-NEXT: SafeTestClass _d_s2(_t1.adjoint);
// CHECK-NEXT: SafeTestClass s2(_t1.value);
// CHECK-NEXT: SafeTestClass _d_s2(_t1.adjoint);
// CHECK-NEXT: clad::ValueAndAdjoint<SafeTestClass, SafeTestClass> _t2 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag<SafeTestClass>(), w, _d_w);
// CHECK-NEXT: SafeTestClass s3(_t2.value);
// CHECK-NEXT: SafeTestClass _d_s3(_t2.adjoint);
// CHECK-NEXT: SafeTestClass s3(_t2.value);
// CHECK-NEXT: *_d_v += 1;
// CHECK-NEXT: {{.*}}constructor_pullback(&s2, u, &v, &_d_s2, &*_d_u, &*_d_v);
// CHECK-NEXT: }
// CHECK-NEXT: *_d_v += 1;
// CHECK-NEXT: {{.*}}constructor_pullback(&s2, u, &v, &_d_s2, &*_d_u, &*_d_v);
// CHECK-NEXT: }


int main() {
Expand Down Expand Up @@ -632,8 +632,9 @@ int main() {
// CHECK: void fn3_grad_2_3(double x, double y, double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _d_x = 0.;
// CHECK-NEXT: double _d_y = 0.;
// CHECK-NEXT: SimpleFunctions _d_sf({});
// CHECK-NEXT: SimpleFunctions sf(x, y);
// CHECK-NEXT: SimpleFunctions _d_sf(sf);
// CHECK-NEXT: clad::zero_init(_d_sf);
// CHECK-NEXT: SimpleFunctions _t0 = sf;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
Expand Down
18 changes: 15 additions & 3 deletions test/Gradient/NonDifferentiable.C
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ public:
}
};

namespace clad {
/// Test-local specialization of clad::zero_init for SimpleFunctions1: the
/// pointer members are re-pointed at the object's own fields and the fields
/// themselves are set to zero.
// NOTE(review): presumably needed because a byte-wise zeroing would also null
// x_pointer/y_pointer, which the generated code dereferences -- confirm.
template <> void zero_init(SimpleFunctions1& f) {
  f.x_pointer = &f.x;
  f.y_pointer = &f.y;
  f.x = 0;
  f.y = 0;
}
} // namespace clad

double fn_s1_mem_fn(double i, double j) {
SimpleFunctions1 obj(2, 3);
return obj.mem_fn_1(i, j) + i * j;
Expand Down Expand Up @@ -125,8 +134,9 @@ int main() {
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);

// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: SimpleFunctions1 _d_obj(obj);
// CHECK-NEXT: clad::zero_init(_d_obj);
// CHECK-NEXT: SimpleFunctions1 _t0 = obj;
// CHECK-NEXT: {
// CHECK-NEXT: double _r2 = 0.;
Expand All @@ -144,8 +154,9 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: SimpleFunctions1 _d_obj(obj);
// CHECK-NEXT: clad::zero_init(_d_obj);
// CHECK-NEXT: {
// CHECK-NEXT: _d_obj.x += 1 * obj.y;
// CHECK-NEXT: *_d_i += 1 * j;
Expand All @@ -158,8 +169,9 @@ int main() {
// CHECK-NEXT: }

// CHECK: void fn_s1_field_pointer_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: SimpleFunctions1 _d_obj(obj);
// CHECK-NEXT: clad::zero_init(_d_obj);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_obj.x_pointer += 1 * *obj.y_pointer;
// CHECK-NEXT: *_d_i += 1 * j;
Expand Down
Loading

0 comments on commit 7004010

Please sign in to comment.