Skip to content

Commit

Permalink
Make all declarations function global.
Browse files Browse the repository at this point in the history
When we produce a gradient function we generally have a forward and reverse
sweep. In the forward sweep we accumulate the state and in the reverse sweep
we use that state to invert the program execution. The forward sweep generally
follows the semantics of the primal function and when necessary stores the state
which would be needed but lost for the reverse sweep.

However, to minimize the stores onto the tape we need to reuse some of the
variables between the forward and the reverse sweeps which requires some
variables to be promoted to the enclosing lexical scope of both sweeps.

Fixes #659, fixes #681.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Feb 8, 2024
1 parent b58816f commit e0de8e7
Show file tree
Hide file tree
Showing 19 changed files with 401 additions and 221 deletions.
10 changes: 8 additions & 2 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ template <typename T> class array {
std::size_t m_size = 0;

public:
/// Delete default constructor
array() = delete;
/// Default constructor
array() = default;
/// Constructor to create an array of the specified size
CUDA_HOST_DEVICE array(std::size_t size)
: m_arr(new T[size]{static_cast<T>(0)}), m_size(size) {}
Expand Down Expand Up @@ -81,6 +81,12 @@ template <typename T> class array {
}

/// Copy-assigns the contents of \p arr to this array.
///
/// The underlying buffer is reallocated only when it is too small to hold
/// \p arr; the element copy itself is delegated to the assignment overload
/// that accepts arr.m_arr.
/// NOTE(review): when this array is already larger than \p arr, m_size is
/// left unchanged -- confirm that the delegated overload never reads more
/// than arr.m_size elements from the source buffer.
/// \param[in] arr The array to copy from.
/// \returns A reference to this array.
CUDA_HOST_DEVICE array<T>& operator=(const array<T>& arr) {
  if (m_size < arr.m_size) {
    // Allocate the replacement buffer before releasing the old one so that
    // m_arr never dangles (and is never freed twice) if the allocation fails
    // by throwing.
    // NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
    T* grown = new T[arr.m_size];
    delete[] m_arr;
    m_arr = grown;
    m_size = arr.m_size;
  }
  (*this) = arr.m_arr;
  return *this;
}
Expand Down
26 changes: 25 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,23 @@ namespace clad {
clang::Expr* R, clang::SourceLocation OpLoc = noLoc);

clang::Expr* BuildParens(clang::Expr* E);

/// Builds variable declaration to be used inside the derivative
/// body and pushes it onto the scope chain of the given scope.
/// \param[in] Type The type of variable declaration to build.
/// \param[in] Identifier The identifier information for the variable
/// declaration.
/// \param[in] scope The scope onto whose chain the newly built declaration
/// is pushed.
/// \param[in] Init The initialization expression to assign to the variable
/// declaration.
/// \param[in] DirectInit A check for if the initialization expression is a
/// C style initialization.
/// \param[in] TSI The type source information of the variable declaration.
/// \param[in] IS The initialization style of the variable declaration;
/// defaults to CInit.
/// \returns The newly built variable declaration.
clang::VarDecl*
BuildVarDecl(clang::QualType Type, clang::IdentifierInfo* Identifier,
clang::Scope* scope, clang::Expr* Init = nullptr,
bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand Down Expand Up @@ -311,6 +327,14 @@ namespace clad {
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Builds variable declaration to be used inside the derivative
/// body in the derivative function global scope.
/// The variable is given a unique identifier created from \p prefix and is
/// built in the derivative function scope instead of the current scope.
/// \param[in] Type The type of variable declaration to build.
/// \param[in] prefix The prefix from which the unique identifier of the
/// variable is created; defaults to "_t".
/// \param[in] Init The initialization expression forwarded to BuildVarDecl.
/// \param[in] DirectInit Forwarded unchanged to BuildVarDecl.
/// \param[in] TSI The type source information forwarded to BuildVarDecl.
/// \param[in] IS The initialization style forwarded to BuildVarDecl;
/// defaults to CInit.
/// \returns The variable declaration returned by BuildVarDecl.
clang::VarDecl*
BuildGlobalVarDecl(clang::QualType Type, llvm::StringRef prefix = "_t",
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Creates a namespace declaration and enters its context. All subsequent
/// Stmts are built inside that namespace, until
/// m_Sema.PopDeclContextIsUsed.
Expand Down
212 changes: 111 additions & 101 deletions lib/Differentiator/ReverseModeVisitor.cpp

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,12 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
// Replace the declaration if it is present in `m_DeclReplacements`.
if (VarDecl* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
auto it = m_DeclReplacements.find(VD);
if (it != std::end(m_DeclReplacements))
if (it != std::end(m_DeclReplacements)) {
DRE->setDecl(it->second);
QualType NonRefQT = it->second->getType().getNonReferenceType();
if (NonRefQT != DRE->getType())
DRE->setType(NonRefQT);
}
}

DeclarationNameInfo DNI = DRE->getNameInfo();
Expand Down
35 changes: 29 additions & 6 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,17 @@ namespace clad {
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
delete oldScope;
}

// Convenience overload: builds the declaration in whichever scope is current
// at the time of the call by delegating to the scope-taking overload.
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
                                   Expr* Init, bool DirectInit,
                                   TypeSourceInfo* TSI,
                                   VarDecl::InitializationStyle IS) {
  auto* activeScope = getCurrentScope();
  return BuildVarDecl(Type, Identifier, activeScope, Init, DirectInit, TSI,
                      IS);
}
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Scope* Scope, Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
// add namespace specifier in variable declaration if needed.
Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
auto VD =
Expand All @@ -144,7 +149,7 @@ namespace clad {
}
m_Sema.FinalizeDeclaration(VD);
// Add the identifier to the scope and IdResolver
m_Sema.PushOnScopeChains(VD, getCurrentScope(), /*AddToContext*/ false);
m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false);
return VD;
}

Expand All @@ -162,6 +167,14 @@ namespace clad {
TSI, IS);
}

// Builds a variable with a unique name derived from `prefix` and registers it
// in the derivative function scope (m_DerivativeFnScope) by delegating to the
// scope-taking BuildVarDecl overload.
VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,
                                         llvm::StringRef prefix, Expr* Init,
                                         bool DirectInit, TypeSourceInfo* TSI,
                                         VarDecl::InitializationStyle IS) {
  auto* uniqueId = CreateUniqueIdentifier(prefix);
  return BuildVarDecl(Type, uniqueId, m_DerivativeFnScope, Init, DirectInit,
                      TSI, IS);
}

NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II,
bool isInline) {
// Check if the namespace is being redeclared.
Expand Down Expand Up @@ -409,9 +422,19 @@ namespace clad {
Expr* VisitorBase::BuildArraySubscript(
Expr* Base, const llvm::SmallVectorImpl<clang::Expr*>& Indices) {
Expr* result = Base;
for (Expr* I : Indices)
result =
m_Sema.CreateBuiltinArraySubscriptExpr(result, noLoc, I, noLoc).get();
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
if (utils::isArrayOrPointerType(Base->getType())) {
for (Expr* I : Indices)
result =
m_Sema.CreateBuiltinArraySubscriptExpr(result, fakeLoc, I, fakeLoc)
.get();
} else {
Expr* idx = Indices.back();
result = m_Sema
.ActOnArraySubscriptExpr(getCurrentScope(), Base, fakeLoc,
idx, fakeLoc)
.get();
}
return result;
}

Expand Down
77 changes: 51 additions & 26 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ double addArr(const double *arr, int n) {
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, ret);
//CHECK-NEXT: ret += arr[i];
Expand Down Expand Up @@ -71,11 +72,12 @@ float func(float* a, float* b) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, a[i]);
//CHECK-NEXT: a[i] *= b[i];
Expand Down Expand Up @@ -123,10 +125,11 @@ float func2(float* a) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += helper(a[i]);
Expand Down Expand Up @@ -156,11 +159,12 @@ float func3(float* a, float* b) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: clad::push(_t2, a[i]);
Expand Down Expand Up @@ -194,11 +198,12 @@ double func4(double x) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[3] = {x, 2 * x, x * x};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
Expand Down Expand Up @@ -242,23 +247,25 @@ double func5(int k) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t2;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int _d_i0 = 0;
//CHECK-NEXT: int i0 = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: int n = k;
//CHECK-NEXT: clad::array<double> _d_arr(n);
//CHECK-NEXT: double arr[n];
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr[i]);
//CHECK-NEXT: arr[i] = k;
//CHECK-NEXT: }
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t2 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i0 = 0; i0 < 3; i0++) {
//CHECK-NEXT: _t2++;
//CHECK-NEXT: clad::push(_t3, sum);
//CHECK-NEXT: sum += addArr(arr, n);
Expand All @@ -267,7 +274,7 @@ double func5(int k) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t2; _t2--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: i0--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d1 = _d_sum;
Expand Down Expand Up @@ -304,14 +311,19 @@ double func6(double seed) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {};
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::array<double> arr(3UL);
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<clad::array<double> > _t3 = {};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double arr[3] = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: clad::push(_t1, arr) , arr = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t2, sum);
//CHECK-NEXT: clad::push(_t3, arr);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -320,12 +332,14 @@ double func6(double seed) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: clad::array<double> _r1 = clad::pop(_t3);
//CHECK-NEXT: arr = _r1;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: addArr_pullback(_r1, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r0 = _d_arr;
//CHECK-NEXT: int _r2 = _grad1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_seed += _d_arr[0];
Expand All @@ -334,6 +348,7 @@ double func6(double seed) {
//CHECK-NEXT: * _d_seed += _d_arr[2];
//CHECK-NEXT: _d_i += _d_arr[2];
//CHECK-NEXT: _d_arr = {};
//CHECK-NEXT: arr = clad::pop(_t1);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -367,14 +382,19 @@ double func7(double *params) {
//CHECK-NEXT: double _d_out = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: std::size_t _d_i = 0;
//CHECK-NEXT: std::size_t i = 0;
//CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {};
//CHECK-NEXT: clad::array<double> _d_paramsPrime(1UL);
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::array<double> paramsPrime(1UL);
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<clad::array<double> > _t3 = {};
//CHECK-NEXT: double out = 0.;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (std::size_t i = 0; i < 1; ++i) {
//CHECK-NEXT: for (i = 0; i < 1; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double paramsPrime[1] = {params[0]};
//CHECK-NEXT: clad::push(_t1, out);
//CHECK-NEXT: clad::push(_t1, paramsPrime) , paramsPrime = {params[0]};
//CHECK-NEXT: clad::push(_t2, out);
//CHECK-NEXT: clad::push(_t3, paramsPrime);
//CHECK-NEXT: out = out + inv_square(paramsPrime);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -383,16 +403,19 @@ double func7(double *params) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: --i;
//CHECK-NEXT: {
//CHECK-NEXT: out = clad::pop(_t1);
//CHECK-NEXT: out = clad::pop(_t2);
//CHECK-NEXT: double _r_d0 = _d_out;
//CHECK-NEXT: _d_out -= _r_d0;
//CHECK-NEXT: _d_out += _r_d0;
//CHECK-NEXT: inv_square_pullback(paramsPrime, _r_d0, _d_paramsPrime);
//CHECK-NEXT: clad::array<double> _r0(_d_paramsPrime);
//CHECK-NEXT: clad::array<double> _r1 = clad::pop(_t3);
//CHECK-NEXT: paramsPrime = _r1;
//CHECK-NEXT: inv_square_pullback(_r1, _r_d0, _d_paramsPrime);
//CHECK-NEXT: clad::array<double> _r0 = _d_paramsPrime;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _d_params[0] += _d_paramsPrime[0];
//CHECK-NEXT: _d_paramsPrime = {};
//CHECK-NEXT: paramsPrime = clad::pop(_t1);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -491,10 +514,11 @@ double func9(double i, double j) {
//CHECK-NEXT: clad::array<double> _d_arr(5UL);
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_idx = 0;
//CHECK-NEXT: int idx = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[5] = {};
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int idx = 0; idx < 5; ++idx) {
//CHECK-NEXT: for (idx = 0; idx < 5; ++idx) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr[idx]);
//CHECK-NEXT: modify(arr[idx], i);
Expand Down Expand Up @@ -557,11 +581,12 @@ double func10(double *arr, int n) {
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; ++i) {
//CHECK-NEXT: for (i = 0; i < n; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, res);
//CHECK-NEXT: clad::push(_t2, arr[i]);
Expand Down
3 changes: 2 additions & 1 deletion test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double _t3;
Expand All @@ -42,7 +43,7 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: double _t6;
//CHECK-NEXT: double t = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < dim; i++) {
//CHECK-NEXT: for (i = 0; i < dim; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, t);
//CHECK-NEXT: t += (x[i] - p[i]) * (x[i] - p[i]);
Expand Down
Loading

0 comments on commit e0de8e7

Please sign in to comment.