From 39341b1127a6694867f6399d478e1df39389d84c Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Fri, 16 Feb 2024 08:33:43 +0000 Subject: [PATCH] More --- include/clad/Differentiator/Sins.h | 29 ++++++++++++++ lib/Differentiator/VisitorBase.cpp | 37 +++--------------- test/Misc/ClangConsumers.cpp | 8 ++++ tools/ClangPlugin.cpp | 47 +++-------------------- tools/ClangPlugin.h | 61 ++++++++++++++++++++++++++++-- 5 files changed, 105 insertions(+), 77 deletions(-) create mode 100644 include/clad/Differentiator/Sins.h create mode 100644 test/Misc/ClangConsumers.cpp diff --git a/include/clad/Differentiator/Sins.h b/include/clad/Differentiator/Sins.h new file mode 100644 index 000000000..28983d626 --- /dev/null +++ b/include/clad/Differentiator/Sins.h @@ -0,0 +1,29 @@ +#ifndef CLAD_DIFFERENTIATOR_SINS_H +#define CLAD_DIFFERENTIATOR_SINS_H + +#include + +/// Standard-protected facility allowing access into private members in C++. +/// Use with caution! +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define CONCATE_(X, Y) X##Y +#define CONCATE(X, Y) CONCATE_(X, Y) +#define ALLOW_ACCESS(CLASS, MEMBER, ...) \ + template \ + struct CONCATE(MEMBER, __LINE__) { \ + friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \ + }; \ + template struct Only_##MEMBER; \ + template <> struct Only_##MEMBER { \ + friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER*); \ + }; \ + template struct CONCATE(MEMBER, \ + __LINE__), &CLASS::MEMBER> + +#define ACCESS(OBJECT, MEMBER) \ + (OBJECT).*Access((Only_##MEMBER< \ + std::remove_reference::type>*)nullptr) + +// NOLINTEND(cppcoreguidelines-macro-usage) + +#endif // CLAD_DIFFERENTIATOR_SINS_H diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..32ab9f161 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -8,10 +8,11 @@ #include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/StmtClone.h" -#include "clad/Differentiator/CladUtils.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -59,42 +60,14 @@ namespace clad { return true; } - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - - // Tag used to access Sema::CurScope. - using namespace clang; - struct Sema_CurScope : TagBase {}; - template struct Rob; - } // namespace + ALLOW_ACCESS(Sema, CurScope, Scope*); clang::Scope*& VisitorBase::getCurrentScope() { - return m_Sema.*get(Sema_CurScope()); + return ACCESS(m_Sema, CurScope); } void VisitorBase::setCurrentScope(clang::Scope* S) { - m_Sema.*get(Sema_CurScope()) = S; + getCurrentScope() = S; assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); } diff --git a/test/Misc/ClangConsumers.cpp b/test/Misc/ClangConsumers.cpp new file mode 100644 index 000000000..cfd977b99 --- /dev/null +++ b/test/Misc/ClangConsumers.cpp @@ -0,0 +1,8 @@ +// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out -Xclang -print-stats 2>&1 | FileCheck %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +// CHECK: HandleTopLevelDecl +int main() { + +} diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index d678d12e1..4b87eff15 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -9,6 +9,7 @@ #include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/EstimationModel.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/Version.h" #include "clang/AST/ASTConsumer.h" @@ -121,39 +122,8 @@ namespace clad { CladPlugin::~CladPlugin() {} - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - // Tag used to access MultiplexConsumer::Consumers. - using namespace clang; - struct MultiplexConsumer_Consumers - : TagBase< - MultiplexConsumer_Consumers, - std::vector> MultiplexConsumer::*> { - }; - template struct Rob; - } // namespace + ALLOW_ACCESS(MultiplexConsumer, Consumers, + std::vector>); void CladPlugin::Initialize(clang::ASTContext& C) { // We know we have a multiplexer. We commit a sin here by stealing it and @@ -163,7 +133,7 @@ namespace clad { using namespace clang; auto& MultiplexC = static_cast(m_CI.getASTConsumer()); - auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers()); + auto& RobbedCs = ACCESS(MultiplexC, Consumers); assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); std::vector> StolenConsumers; @@ -191,11 +161,6 @@ namespace clad { if (!m_DerivativeBuilder) m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this)); - // if HandleTopLevelDecl was called through clad we don't need to process - // it for diff requests - if (m_HandleTopLevelDeclInternal) - return m_Multiplexer->HandleTopLevelDecl(DGR); // true; - DiffSchedule requests{}; DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); @@ -226,9 +191,7 @@ namespace clad { } void CladPlugin::ProcessTopLevelDecl(Decl* D) { - m_HandleTopLevelDeclInternal = true; - m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D)); - m_HandleTopLevelDeclInternal = false; + m_Multiplexer->HandleTopLevelDecl(DeclGroupRef(D)); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3e99a93b9..8e2035a4e 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -89,7 +89,6 @@ namespace clad { std::unique_ptr m_DerivativeBuilder; bool m_HasRuntime = false; bool m_PendingInstantiationsInFlight = false; - bool m_HandleTopLevelDeclInternal = false; DerivedFnCollector m_DFC; enum class CallKind { HandleCXXStaticMemberVarInstantiation, @@ -190,7 +189,64 @@ namespace clad { // clang::ASTMutationListener *GetASTMutationListener() override; // clang::ASTDeserializationListener *GetASTDeserializationListener() // override; - void PrintStats() override { m_Multiplexer->PrintStats(); } + void PrintStats() override { + llvm::errs() << "*** INFORMATION ABOUT THE DELAYED CALLS\n"; + for (const DelayedCallInfo& DCI : m_DelayedCalls) { + llvm::errs() << " "; + switch (DCI.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; + break; + case CallKind::HandleTopLevelDecl: + llvm::errs() << "HandleTopLevelDecl"; + break; + case CallKind::HandleInlineFunctionDefinition: + llvm::errs() << "HandleInlineFunctionDefinition"; + break; + case CallKind::HandleInterestingDecl: + llvm::errs() << "HandleInterestingDecl"; + break; + case CallKind::HandleTagDeclDefinition: + llvm::errs() << "HandleTagDeclDefinition"; + break; + case CallKind::HandleTagDeclRequiredDefinition: + llvm::errs() << "HandleTagDeclRequiredDefinition"; + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + llvm::errs() << "HandleTopLevelDeclInObjCContainer"; + break; + case CallKind::HandleImplicitImportDecl: + llvm::errs() << "HandleImplicitImportDecl"; + break; + case CallKind::CompleteTentativeDefinition: + llvm::errs() << "CompleteTentativeDefinition"; + break; + case CallKind::CompleteExternalDeclaration: + llvm::errs() << "CompleteExternalDeclaration"; + break; + case CallKind::AssignInheritanceModel: + llvm::errs() << "AssignInheritanceModel"; + break; + case CallKind::HandleVTable: + llvm::errs() << "HandleVTable"; + break; + case CallKind::InitializeSema: + llvm::errs() << "InitializeSema"; + break; + case CallKind::ForgetSema: + llvm::errs() << "ForgetSema"; + break; + }; + for (const clang::Decl* D : DCI.m_DGR) + llvm::errs() << " " << D; + llvm::errs() << "\n"; + } + + m_Multiplexer->PrintStats(); + } bool shouldSkipFunctionBody(clang::Decl* D) override { return m_Multiplexer->shouldSkipFunctionBody(D); } @@ -205,7 +261,6 @@ namespace clad { m_Multiplexer->ForgetSema(); } - // bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); private: