Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make built-in type variable declarations global #738

Merged
merged 1 commit into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ template <typename T> class array {
std::size_t m_size = 0;

public:
/// Delete default constructor
array() = delete;
/// Default constructor
array() = default;
/// Constructor to create an array of the specified size
CUDA_HOST_DEVICE array(std::size_t size)
: m_arr(new T[size]{static_cast<T>(0)}), m_size(size) {}
Expand Down Expand Up @@ -81,6 +81,12 @@ template <typename T> class array {
}

CUDA_HOST_DEVICE array<T>& operator=(const array<T>& arr) {
if (m_size < arr.m_size) {
delete[] m_arr;
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
m_arr = new T[arr.m_size];
PetroZarytskyi marked this conversation as resolved.
Show resolved Hide resolved
m_size = arr.m_size;
}
(*this) = arr.m_arr;
return *this;
}
Expand Down
26 changes: 25 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,23 @@ namespace clad {
clang::Expr* R, clang::SourceLocation OpLoc = noLoc);

clang::Expr* BuildParens(clang::Expr* E);

/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
/// \param[in] Identifier The identifier information for the variable
/// declaration.
/// \param[in] Init The initalization expression to assign to the variable
/// declaration.
/// \param[in] DirectInit A check for if the initialization expression is a
/// C style initalization.
/// \param[in] TSI The type source information of the variable declaration.
/// \returns The newly built variable declaration.
clang::VarDecl*
BuildVarDecl(clang::QualType Type, clang::IdentifierInfo* Identifier,
clang::Scope* scope, clang::Expr* Init = nullptr,
bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand Down Expand Up @@ -311,6 +327,14 @@ namespace clad {
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Builds variable declaration to be used inside the derivative
/// body in the derivative function global scope.
clang::VarDecl*
BuildGlobalVarDecl(clang::QualType Type, llvm::StringRef prefix = "_t",
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Creates a namespace declaration and enters its context. All subsequent
/// Stmts are built inside that namespace, until
/// m_Sema.PopDeclContextIsUsed.
Expand Down
212 changes: 111 additions & 101 deletions lib/Differentiator/ReverseModeVisitor.cpp

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,12 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
// Replace the declaration if it is present in `m_DeclReplacements`.
if (VarDecl* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
auto it = m_DeclReplacements.find(VD);
if (it != std::end(m_DeclReplacements))
if (it != std::end(m_DeclReplacements)) {
DRE->setDecl(it->second);
QualType NonRefQT = it->second->getType().getNonReferenceType();
if (NonRefQT != DRE->getType())
DRE->setType(NonRefQT);
}
}

DeclarationNameInfo DNI = DRE->getNameInfo();
Expand Down
35 changes: 29 additions & 6 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,17 @@ namespace clad {
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
delete oldScope;
}

VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {

return BuildVarDecl(Type, Identifier, getCurrentScope(), Init, DirectInit,
TSI, IS);
}
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Scope* Scope, Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
// add namespace specifier in variable declaration if needed.
Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
auto VD =
Expand All @@ -144,7 +149,7 @@ namespace clad {
}
m_Sema.FinalizeDeclaration(VD);
// Add the identifier to the scope and IdResolver
m_Sema.PushOnScopeChains(VD, getCurrentScope(), /*AddToContext*/ false);
m_Sema.PushOnScopeChains(VD, Scope, /*AddToContext*/ false);
return VD;
}

Expand All @@ -162,6 +167,14 @@ namespace clad {
TSI, IS);
}

VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type,
llvm::StringRef prefix, Expr* Init,
bool DirectInit, TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
return BuildVarDecl(Type, CreateUniqueIdentifier(prefix),
m_DerivativeFnScope, Init, DirectInit, TSI, IS);
}

NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II,
bool isInline) {
// Check if the namespace is being redeclared.
Expand Down Expand Up @@ -409,9 +422,19 @@ namespace clad {
Expr* VisitorBase::BuildArraySubscript(
Expr* Base, const llvm::SmallVectorImpl<clang::Expr*>& Indices) {
PetroZarytskyi marked this conversation as resolved.
Show resolved Hide resolved
Expr* result = Base;
for (Expr* I : Indices)
result =
m_Sema.CreateBuiltinArraySubscriptExpr(result, noLoc, I, noLoc).get();
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
if (utils::isArrayOrPointerType(Base->getType())) {
for (Expr* I : Indices)
result =
m_Sema.CreateBuiltinArraySubscriptExpr(result, fakeLoc, I, fakeLoc)
.get();
} else {
Expr* idx = Indices.back();
result = m_Sema
.ActOnArraySubscriptExpr(getCurrentScope(), Base, fakeLoc,
idx, fakeLoc)
.get();
}
return result;
}

Expand Down
77 changes: 51 additions & 26 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ double addArr(const double *arr, int n) {
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, ret);
//CHECK-NEXT: ret += arr[i];
Expand Down Expand Up @@ -71,11 +72,12 @@ float func(float* a, float* b) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, a[i]);
//CHECK-NEXT: a[i] *= b[i];
Expand Down Expand Up @@ -123,10 +125,11 @@ float func2(float* a) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += helper(a[i]);
Expand Down Expand Up @@ -156,11 +159,12 @@ float func3(float* a, float* b) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: clad::push(_t2, a[i]);
Expand Down Expand Up @@ -194,11 +198,12 @@ double func4(double x) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[3] = {x, 2 * x, x * x};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
Expand Down Expand Up @@ -242,23 +247,25 @@ double func5(int k) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t2;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int _d_i0 = 0;
//CHECK-NEXT: int i0 = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: int n = k;
//CHECK-NEXT: clad::array<double> _d_arr(n);
//CHECK-NEXT: double arr[n];
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: for (i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr[i]);
//CHECK-NEXT: arr[i] = k;
//CHECK-NEXT: }
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t2 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i0 = 0; i0 < 3; i0++) {
//CHECK-NEXT: _t2++;
//CHECK-NEXT: clad::push(_t3, sum);
//CHECK-NEXT: sum += addArr(arr, n);
Expand All @@ -267,7 +274,7 @@ double func5(int k) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t2; _t2--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: i0--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d1 = _d_sum;
Expand Down Expand Up @@ -304,14 +311,19 @@ double func6(double seed) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {};
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::array<double> arr(3UL);
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<clad::array<double> > _t3 = {};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: for (i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double arr[3] = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: clad::push(_t1, arr) , arr = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t2, sum);
//CHECK-NEXT: clad::push(_t3, arr);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -320,12 +332,14 @@ double func6(double seed) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: clad::array<double> _r1 = clad::pop(_t3);
//CHECK-NEXT: arr = _r1;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: addArr_pullback(_r1, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r0 = _d_arr;
//CHECK-NEXT: int _r2 = _grad1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_seed += _d_arr[0];
Expand All @@ -334,6 +348,7 @@ double func6(double seed) {
//CHECK-NEXT: * _d_seed += _d_arr[2];
//CHECK-NEXT: _d_i += _d_arr[2];
//CHECK-NEXT: _d_arr = {};
//CHECK-NEXT: arr = clad::pop(_t1);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -367,14 +382,19 @@ double func7(double *params) {
//CHECK-NEXT: double _d_out = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: std::size_t _d_i = 0;
//CHECK-NEXT: std::size_t i = 0;
//CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {};
//CHECK-NEXT: clad::array<double> _d_paramsPrime(1UL);
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::array<double> paramsPrime(1UL);
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<clad::array<double> > _t3 = {};
//CHECK-NEXT: double out = 0.;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (std::size_t i = 0; i < 1; ++i) {
//CHECK-NEXT: for (i = 0; i < 1; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double paramsPrime[1] = {params[0]};
//CHECK-NEXT: clad::push(_t1, out);
//CHECK-NEXT: clad::push(_t1, paramsPrime) , paramsPrime = {params[0]};
//CHECK-NEXT: clad::push(_t2, out);
//CHECK-NEXT: clad::push(_t3, paramsPrime);
//CHECK-NEXT: out = out + inv_square(paramsPrime);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -383,16 +403,19 @@ double func7(double *params) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: --i;
//CHECK-NEXT: {
//CHECK-NEXT: out = clad::pop(_t1);
//CHECK-NEXT: out = clad::pop(_t2);
//CHECK-NEXT: double _r_d0 = _d_out;
//CHECK-NEXT: _d_out -= _r_d0;
//CHECK-NEXT: _d_out += _r_d0;
//CHECK-NEXT: inv_square_pullback(paramsPrime, _r_d0, _d_paramsPrime);
//CHECK-NEXT: clad::array<double> _r0(_d_paramsPrime);
//CHECK-NEXT: clad::array<double> _r1 = clad::pop(_t3);
//CHECK-NEXT: paramsPrime = _r1;
//CHECK-NEXT: inv_square_pullback(_r1, _r_d0, _d_paramsPrime);
//CHECK-NEXT: clad::array<double> _r0 = _d_paramsPrime;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _d_params[0] += _d_paramsPrime[0];
//CHECK-NEXT: _d_paramsPrime = {};
//CHECK-NEXT: paramsPrime = clad::pop(_t1);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -491,10 +514,11 @@ double func9(double i, double j) {
//CHECK-NEXT: clad::array<double> _d_arr(5UL);
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_idx = 0;
//CHECK-NEXT: int idx = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[5] = {};
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int idx = 0; idx < 5; ++idx) {
//CHECK-NEXT: for (idx = 0; idx < 5; ++idx) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr[idx]);
//CHECK-NEXT: modify(arr[idx], i);
Expand Down Expand Up @@ -557,11 +581,12 @@ double func10(double *arr, int n) {
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; ++i) {
//CHECK-NEXT: for (i = 0; i < n; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, res);
//CHECK-NEXT: clad::push(_t2, arr[i]);
Expand Down
3 changes: 2 additions & 1 deletion test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double _t3;
Expand All @@ -42,7 +43,7 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: double _t6;
//CHECK-NEXT: double t = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < dim; i++) {
//CHECK-NEXT: for (i = 0; i < dim; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, t);
//CHECK-NEXT: t += (x[i] - p[i]) * (x[i] - p[i]);
Expand Down
Loading
Loading