Skip to content

Commit

Permalink
Synthesize the nested name specifiers to include the namespace qualif…
Browse files Browse the repository at this point in the history
…iers.

This patch appends to each DeclRefExpr the namespace qualifiers so that the code
can compile properly and refer to the exact namespace where the entities were
defined.
  • Loading branch information
vgvassilev committed Dec 15, 2024
1 parent c1a87d4 commit b36eae9
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 62 deletions.
22 changes: 7 additions & 15 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
#ifndef CLAD_VISITOR_BASE_H
#define CLAD_VISITOR_BASE_H

namespace clad {
class DerivativeBuilder;
}

#include "Compatibility.h"
#include "DerivativeBuilder.h"

Expand All @@ -22,6 +18,10 @@ namespace clad {
#include <stack>
#include <unordered_map>

namespace clang {
class NestedNameSpecifier;
}

namespace clad {
/// A class that represents the result of Visit of ForwardModeVisitor.
/// Stmt() allows to access the original (cloned) Stmt and Stmt_dx() allows
Expand Down Expand Up @@ -360,18 +360,11 @@ namespace clad {
/// declaration reference expressions. This function builds a declaration
/// reference given a declaration.
/// \param[in] D The declaration to build a DeclRefExpr for.
/// \param[in] SS The scope specifier for the declaration.
/// \param[in] SS The nested name specifier for the declaration.
/// \returns the DeclRefExpr for the given declaration.
clang::DeclRefExpr*
BuildDeclRef(clang::DeclaratorDecl* D,
const clang::CXXScopeSpec* SS = nullptr,
clang::ExprValueKind VK = clang::VK_LValue);
/// Builds a DeclRefExpr to a given Decl, adding proper nested name
/// qualifiers.
/// \param[in] D The declaration to build a DeclRefExpr for.
/// \param[in] NNS The nested name specifier to use.
clang::DeclRefExpr*
BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS,
clang::NestedNameSpecifier* NNS = nullptr,
clang::ExprValueKind VK = clang::VK_LValue);

/// Stores the result of an expression in a temporary variable (of the same
Expand Down Expand Up @@ -543,8 +536,7 @@ namespace clad {
clang::Expr*
BuildCallExprToFunction(clang::FunctionDecl* FD,
llvm::MutableArrayRef<clang::Expr*> argExprs,
bool useRefQualifiedThisObj = false,
const clang::CXXScopeSpec* SS = nullptr);
bool useRefQualifiedThisObj = false);

/// Build a call to templated free function inside the clad namespace.
///
Expand Down
70 changes: 37 additions & 33 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Sema/Lookup.h"
Expand All @@ -25,8 +26,9 @@
#include "clang/Sema/SemaInternal.h"
#include "clang/Sema/Template.h"

#include "llvm/ADT/SmallVector.h"

#include <algorithm>
#include <llvm/ADT/SmallVector.h>
#include <numeric>

#include "clad/Differentiator/Compatibility.h"
Expand Down Expand Up @@ -237,38 +239,40 @@ namespace clad {
}

DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D,
const CXXScopeSpec* SS /*=nullptr*/,
NestedNameSpecifier* NNS /*=nullptr*/,
ExprValueKind VK /*=VK_LValue*/) {
QualType T = D->getType();
T = T.getNonReferenceType();
return cast<DeclRefExpr>(clad_compat::GetResult<Expr*>(
m_Sema.BuildDeclRefExpr(D, T, VK, D->getBeginLoc(), SS)));
}

DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D,
NestedNameSpecifier* NNS,
ExprValueKind VK /*=VK_LValue*/) {
std::vector<NestedNameSpecifier*> NNChain;
CXXScopeSpec CSS;
while (NNS) {
NNChain.push_back(NNS);
NNS = NNS->getPrefix();
}

std::reverse(NNChain.begin(), NNChain.end());

for (size_t i = 0; i < NNChain.size(); ++i) {
NNS = NNChain[i];
// FIXME: this needs to be extended to support more NNS kinds. An
// inspiration can be take from getFullyQualifiedNestedNameSpecifier in
// llvm-project/clang/lib/AST/QualTypeNames.cpp
if (NNS->getKind() == NestedNameSpecifier::Namespace) {
NamespaceDecl* NS = NNS->getAsNamespace();
CSS.Extend(m_Context, NS, noLoc, noLoc);
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
if (NNS) {
CSS.MakeTrivial(m_Context, NNS, fakeLoc);
} else {
// If no CXXScopeSpec is provided we should try to find the common path
// between the current scope (in which presumably we will make the call)
// and where `D` is.
llvm::SmallVector<DeclContext*, 4> DCs;
DeclContext* DeclDC = D->getDeclContext();
// FIXME: We should respect using clauses and shorten the qualified names.
while (!DeclDC->isTranslationUnit()) {
// Stop when we find the common ancestor.
if (DeclDC->Equals(m_Sema.CurContext))
break;

// FIXME: We should extend that for classes and class templates. See
// clang's getFullyQualifiedNestedNameSpecifier.
if (DeclDC->isNamespace() && !DeclDC->isInlineNamespace())
DCs.push_back(DeclDC);

DeclDC = DeclDC->getParent();
}
}

return BuildDeclRef(D, &CSS, VK);
for (unsigned i = DCs.size(); i > 0; --i)
CSS.Extend(m_Context, cast<NamespaceDecl>(DCs[i - 1]), fakeLoc,
fakeLoc);
}
QualType T = D->getType();
T = T.getNonReferenceType();
return cast<DeclRefExpr>(clad_compat::GetResult<Expr*>(
m_Sema.BuildDeclRefExpr(D, T, VK, D->getBeginLoc(), &CSS)));
}

IdentifierInfo*
Expand Down Expand Up @@ -681,13 +685,12 @@ namespace clad {
Expr*
VisitorBase::BuildCallExprToFunction(FunctionDecl* FD,
llvm::MutableArrayRef<Expr*> argExprs,
bool useRefQualifiedThisObj /*=false*/,
const CXXScopeSpec* SS /*=nullptr*/) {
bool useRefQualifiedThisObj /*=false*/) {
Expr* call = nullptr;
if (auto derMethod = dyn_cast<CXXMethodDecl>(FD)) {
call = BuildCallExprToMemFn(derMethod, argExprs, useRefQualifiedThisObj);
} else {
Expr* exprFunc = BuildDeclRef(FD, SS);
Expr* exprFunc = BuildDeclRef(FD);
call = m_Sema
.ActOnCallExpr(
getCurrentScope(),
Expand Down Expand Up @@ -720,7 +723,8 @@ namespace clad {
clang::TemplateArgumentList TL(TemplateArgumentList::OnStack, templateArgs);
FunctionDecl* FD = m_Sema.InstantiateFunctionDeclaration(FTD, &TL, loc);

return BuildCallExprToFunction(FD, argExprs, false, &CSS);
return BuildCallExprToFunction(FD, argExprs,
/*useRefQualifiedThisObj=*/false);
}

TemplateDecl* VisitorBase::GetCladArrayRefDecl() {
Expand Down
10 changes: 5 additions & 5 deletions test/FirstDerivative/FunctionsInNamespaces.C
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ int test_1(int x, int y) {
// CHECK: int test_1_darg1(int x, int y) {
// CHECK-NEXT: int _d_x = 0;
// CHECK-NEXT: int _d_y = 1;
// CHECK-NEXT: clad::ValueAndPushforward<int, int> _t0 = func3_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: clad::ValueAndPushforward<int, int> _t0 = function_namespace2::func3_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

Expand Down Expand Up @@ -83,7 +83,7 @@ double fn1(double i, double j) {
// CHECK: double fn1_darg1(double i, double j) {
// CHECK-NEXT: double _d_i = 0;
// CHECK-NEXT: double _d_j = 1;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = someFn_pushforward(i, j, _d_i, _d_j);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = A::B::C::someFn_pushforward(i, j, _d_i, _d_j);
// CHECK-NEXT: return _d_i + _d_j;
// CHECK-NEXT: }

Expand All @@ -98,14 +98,14 @@ int main () {
// CHECK: clad::ValueAndPushforward<int, int> func4_pushforward(int x, int y, int _d_x, int _d_y);

// CHECK: clad::ValueAndPushforward<int, int> func3_pushforward(int x, int y, int _d_x, int _d_y) {
// CHECK-NEXT: clad::ValueAndPushforward<int, int> _t0 = func4_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: clad::ValueAndPushforward<int, int> _t0 = function_namespace10::function_namespace11::func4_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return {_t0.value, _t0.pushforward};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> someFn_1_pushforward(double &i, double j, double &_d_i, double _d_j);

// CHECK: clad::ValueAndPushforward<double, double> someFn_pushforward(double &i, double &j, double &_d_i, double &_d_j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = someFn_1_pushforward(i, j, _d_i, _d_j);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = A::B::C::someFn_1_pushforward(i, j, _d_i, _d_j);
// CHECK-NEXT: return {(double)3, (double)0};
// CHECK-NEXT: }

Expand All @@ -116,7 +116,7 @@ int main () {
// CHECK: clad::ValueAndPushforward<double, double> someFn_1_pushforward(double &i, double j, double k, double &_d_i, double _d_j, double _d_k);

// CHECK: clad::ValueAndPushforward<double, double> someFn_1_pushforward(double &i, double j, double &_d_i, double _d_j) {
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = someFn_1_pushforward(i, j, j, _d_i, _d_j, _d_j);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = A::B::C::someFn_1_pushforward(i, j, j, _d_i, _d_j, _d_j);
// CHECK-NEXT: return {(double)2, (double)0};
// CHECK-NEXT: }

Expand Down
10 changes: 5 additions & 5 deletions test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -560,10 +560,10 @@ std::complex<double> fn10(double i, double j) {
// CHECK-NEXT: c1.imag_pushforward(5 * i, &_d_c1, 0 * i + 5 * _d_i);
// CHECK-NEXT: c2.real_pushforward(5 * i, &_d_c2, 0 * i + 5 * _d_i);
// CHECK-NEXT: c2.imag_pushforward(2 * i, &_d_c2, 0 * i + 2 * _d_i);
// CHECK-NEXT: clad::ValueAndPushforward<complex<double>, complex<double> > _t0 = operator_plus_pushforward(c1, c2, _d_c1, _d_c2);
// CHECK-NEXT: clad::ValueAndPushforward<complex<double>, complex<double> > _t0 = std::operator_plus_pushforward(c1, c2, _d_c1, _d_c2);
// CHECK-NEXT: clad::ValueAndPushforward<complex<double> &, complex<double> &> _t1 = c1.operator_equal_pushforward({{(static_cast<std(::__1)?::complex<double> &&>\(_t0.value\))|(_t0.value)}}, &_d_c1, {{(static_cast<std(::__1)?::complex<double> &&>\(_t0.pushforward\))|(_t0.pushforward)}});
// CHECK-NEXT: clad::ValueAndPushforward<complex<double> &, complex<double> &> _t2 = c1.operator_plus_equal_pushforward(c2, &_d_c1, _d_c2);
// CHECK-NEXT: clad::ValueAndPushforward<complex<double>, complex<double> > _t3 = operator_plus_pushforward(c1, c1, _d_c1, _d_c1);
// CHECK-NEXT: clad::ValueAndPushforward<complex<double>, complex<double> > _t3 = std::operator_plus_pushforward(c1, c1, _d_c1, _d_c1);
// CHECK-NEXT: return _t3.pushforward;
// CHECK-NEXT: }

Expand Down Expand Up @@ -883,13 +883,13 @@ double fn14(double i, double j) {
// CHECK-NEXT: {{.*}}ValueAndPushforward<{{.*}}, {{.*}}> _t1 = clad::custom_derivatives::class_functions::operator_subscript_pushforward(&v, 1, &_d_v, 0);
// CHECK-NEXT: _t1.pushforward = 0 * i + 11 * _d_i;
// CHECK-NEXT: _t1.value = 11 * i;
// CHECK-NEXT: clad::ValueAndPushforward<decltype({{.*}}.begin()), decltype({{.*}}.begin())> _t2 = begin_pushforward(v, _d_v);
// CHECK-NEXT: clad::ValueAndPushforward<decltype({{.*}}.begin()), decltype({{.*}}.begin())> _t2 = std::begin_pushforward(v, _d_v);
// CHECK-NEXT: {{.*}} _d_b = _t2.pushforward;
// CHECK-NEXT: {{.*}} b = _t2.value;
// CHECK-NEXT: clad::ValueAndPushforward<decltype({{.*}}.end()), decltype({{.*}}.end())> _t3 = end_pushforward(v, _d_v);
// CHECK-NEXT: clad::ValueAndPushforward<decltype({{.*}}.end()), decltype({{.*}}.end())> _t3 = std::end_pushforward(v, _d_v);
// CHECK-NEXT: {{.*}} _d_e = _t3.pushforward;
// CHECK-NEXT: {{.*}} e = _t3.value;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t4 = accumulate_pushforward(b, e, 0., _d_b, _d_e, 0.);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t4 = std::accumulate_pushforward(b, e, 0., _d_b, _d_e, 0.);
// CHECK-NEXT: double _d_res = _t4.pushforward;
// CHECK-NEXT: double res = _t4.value;
// CHECK-NEXT: return _d_res;
Expand Down
8 changes: 4 additions & 4 deletions test/NthDerivative/CustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ float test_sin(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = sin_pushforward_pushforward(x, _d_x0, _d_x, _d__d_x);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::sin_pushforward_pushforward(x, _d_x0, _d_x, _d__d_x);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: return _d__t0.pushforward;
Expand All @@ -28,7 +28,7 @@ float test_cos(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = cos_pushforward_pushforward(x, _d_x0, _d_x, _d__d_x);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::cos_pushforward_pushforward(x, _d_x0, _d_x, _d__d_x);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: return _d__t0.pushforward;
Expand All @@ -55,7 +55,7 @@ float test_trig(float x, float y, int a, int b) {
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::sin_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<double, double>, ValueAndPushforward<double, double> > _t1 = pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<double, double>, ValueAndPushforward<double, double> > _t1 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: ValueAndPushforward<double, double> _d__t1 = _t1.pushforward;
// CHECK-NEXT: ValueAndPushforward<double, double> _t10 = _t1.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t2 = clad::custom_derivatives::std::cos_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
Expand Down Expand Up @@ -130,7 +130,7 @@ float test_exp(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: float _d_x0 = 1;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = exp_pushforward_pushforward(x * x, _d_x0 * x + x * _d_x0, _d_x * x + x * _d_x, _d__d_x * x + _d_x0 * _d_x + _d_x * _d_x0 + x * _d__d_x);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::exp_pushforward_pushforward(x * x, _d_x0 * x + x * _d_x0, _d_x * x + x * _d_x, _d__d_x * x + _d_x0 * _d_x + _d_x * _d_x0 + x * _d__d_x);
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: return _d__t0.pushforward;
Expand Down

0 comments on commit b36eae9

Please sign in to comment.