From 3b105423d8341f2964b15dc70344878caf7c8338 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Thu, 15 Feb 2024 15:48:34 +0000 Subject: [PATCH] Delay the differentiation process until the end of TU. Before this patch clad attaches itself as a first consumer and applies AD before code generation. However, that is limited since clang sends every top-level declaration to codegen which limits the amount of flexibility clad has. For example, we have to instantiate all pending templates at every HandleTopLevelDecl calls; we cannot really differentiate virtual functions whose classes have sent their key function to CodeGen; and in general we perform actions which are semantically useful for the end of the translation unit. This patch makes clad a single consumer of clang which dispatches to the others. That's done by delaying all calls to the consumers until the end of the TU where clad can replay the exact sequence of calls to the other consumers as if they were directly connected to the frontend. Fixes #248 --- tools/ClangPlugin.cpp | 68 ++++++++++++++++++++++-- tools/ClangPlugin.h | 121 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 181 insertions(+), 8 deletions(-) diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index c4ab7039a..d678d12e1 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -121,11 +121,70 @@ namespace clad { CladPlugin::~CladPlugin() {} + // A facility allowing us to access the private member CurScope of the Sema + // object using standard-conforming C++. + namespace { + template struct Rob { + friend typename Tag::type get(Tag) { return M; } + }; + + template struct TagBase { + using type = Member; +#ifdef MSVC +#pragma warning(push, 0) +#endif // MSVC +#pragma GCC diagnostic push +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wunknown-warning-option" +#endif // __clang__ +#pragma GCC diagnostic ignored "-Wnon-template-friend" + friend type get(Tag); +#pragma GCC diagnostic pop +#ifdef MSVC +#pragma warning(pop) +#endif // MSVC + }; + // Tag used to access MultiplexConsumer::Consumers. + using namespace clang; + struct MultiplexConsumer_Consumers + : TagBase< + MultiplexConsumer_Consumers, + std::vector> MultiplexConsumer::*> { + }; + template struct Rob; + } // namespace + + void CladPlugin::Initialize(clang::ASTContext& C) { + // We know we have a multiplexer. We commit a sin here by stealing it and + // making the consumer pass-through so that we can delay all operations + // until clad is happy. + + using namespace clang; + + auto& MultiplexC = static_cast(m_CI.getASTConsumer()); + auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers()); + assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); + std::vector> StolenConsumers; + + // The range-based for loop in MultiplexConsumer::Initialize has + // dispatched this call. Generally, it is unsafe to delete elements while + // iterating but we know we are in the end of the loop and ::end() won't + // be invalidated. + for (auto& RC : RobbedCs) + if (RC.get() == this) + RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); + else + StolenConsumers.push_back(std::move(RC)); + m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers))); + } + // We cannot use HandleTranslationUnit because codegen already emits code on // HandleTopLevelDecl calls and makes updateCall with no effect. bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) { + AppendDelayed({CallKind::HandleTopLevelDecl, DGR}); if (!CheckBuiltins()) - return true; + return m_Multiplexer->HandleTopLevelDecl(DGR); // true; Sema& S = m_CI.getSema(); @@ -135,13 +194,13 @@ namespace clad { // if HandleTopLevelDecl was called through clad we don't need to process // it for diff requests if (m_HandleTopLevelDeclInternal) - return true; + return m_Multiplexer->HandleTopLevelDecl(DGR); // true; DiffSchedule requests{}; DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); if (requests.empty()) - return true; + return m_Multiplexer->HandleTopLevelDecl(DGR); // true; // FIXME: flags have to be set manually since DiffCollector's constructor // does not have access to m_DO. @@ -162,7 +221,8 @@ namespace clad { for (DiffRequest& request : requests) ProcessDiffRequest(request); - return true; // Happiness + + return m_Multiplexer->HandleTopLevelDecl(DGR); // Happiness } void CladPlugin::ProcessTopLevelDecl(Decl* D) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3394a0d38..3e99a93b9 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -13,11 +13,12 @@ #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/Version.h" -#include "clang/AST/ASTConsumer.h" #include "clang/AST/Decl.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/Version.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Frontend/MultiplexConsumer.h" +#include "clang/Sema/SemaConsumer.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/DenseMap.h" @@ -82,7 +83,7 @@ namespace clad { std::string CustomModelName; }; - class CladPlugin : public clang::ASTConsumer { + class CladPlugin : public clang::SemaConsumer { clang::CompilerInstance& m_CI; DifferentiationOptions m_DO; std::unique_ptr m_DerivativeBuilder; @@ -90,13 +91,125 @@ namespace clad { bool m_PendingInstantiationsInFlight = false; bool m_HandleTopLevelDeclInternal = false; DerivedFnCollector m_DFC; + enum class CallKind { + HandleCXXStaticMemberVarInstantiation, + HandleTopLevelDecl, + HandleInlineFunctionDefinition, + HandleInterestingDecl, + HandleTagDeclDefinition, + HandleTagDeclRequiredDefinition, + HandleCXXImplicitFunctionInstantiation, + HandleTopLevelDeclInObjCContainer, + HandleImplicitImportDecl, + CompleteTentativeDefinition, + CompleteExternalDeclaration, + AssignInheritanceModel, + HandleVTable, + InitializeSema, + ForgetSema + }; + struct DelayedCallInfo { + CallKind m_Kind; + clang::DeclGroupRef m_DGR; + DelayedCallInfo(CallKind K, clang::DeclGroupRef DGR) + : m_Kind(K), m_DGR(DGR) {} + DelayedCallInfo(CallKind K, const clang::Decl* D) + : m_Kind(K), m_DGR(const_cast(D)) {} + }; + std::vector m_DelayedCalls; + std::unique_ptr m_Multiplexer; + public: CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO); ~CladPlugin(); - bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; + // ASTConsumer + void Initialize(clang::ASTContext& Context) override; + void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override { + AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D}); + m_Multiplexer->HandleCXXStaticMemberVarInstantiation(D); + } + bool HandleTopLevelDecl(clang::DeclGroupRef D) override; /*{ + AppendDelayed({CallKind::HandleTopLevelDecl, D}); + return true; // happyness, continue parsing + }*/ + void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleInlineFunctionDefinition, D}); + m_Multiplexer->HandleInlineFunctionDefinition(D); + } + void HandleInterestingDecl(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleInterestingDecl, D}); + m_Multiplexer->HandleInterestingDecl(D); + } + void HandleTagDeclDefinition(clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclDefinition, D}); + m_Multiplexer->HandleTagDeclDefinition(D); + } + void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D}); + m_Multiplexer->HandleTagDeclRequiredDefinition(D); + } + void + HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D}); + m_Multiplexer->HandleCXXImplicitFunctionInstantiation(D); + } + void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D}); + m_Multiplexer->HandleTopLevelDeclInObjCContainer(D); + } + void HandleImplicitImportDecl(clang::ImportDecl* D) override { + AppendDelayed({CallKind::HandleImplicitImportDecl, D}); + m_Multiplexer->HandleImplicitImportDecl(D); + } + void CompleteTentativeDefinition(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteTentativeDefinition, D}); + m_Multiplexer->CompleteTentativeDefinition(D); + } +#if CLANG_VERSION_MAJOR > 9 + void CompleteExternalDeclaration(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteExternalDeclaration, D}); + m_Multiplexer->CompleteExternalDeclaration(D); + } +#endif + void AssignInheritanceModel(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::AssignInheritanceModel, D}); + m_Multiplexer->AssignInheritanceModel(D); + } + void HandleVTable(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::HandleVTable, D}); + m_Multiplexer->HandleVTable(D); + } + + // Not delayed. + void HandleTranslationUnit(clang::ASTContext& C) override { + m_Multiplexer->HandleTranslationUnit(C); + } + // No need to handle the listeners, they will be handled at non-delayed by + // the parent multiplexer. + // + // clang::ASTMutationListener *GetASTMutationListener() override; + // clang::ASTDeserializationListener *GetASTDeserializationListener() + // override; + void PrintStats() override { m_Multiplexer->PrintStats(); } + bool shouldSkipFunctionBody(clang::Decl* D) override { + return m_Multiplexer->shouldSkipFunctionBody(D); + } + + // SemaConsumer + void InitializeSema(clang::Sema& S) override { + AppendDelayed({CallKind::InitializeSema, nullptr}); + m_Multiplexer->InitializeSema(S); + } + void ForgetSema() override { + AppendDelayed({CallKind::ForgetSema, nullptr}); + m_Multiplexer->ForgetSema(); + } + + // bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); private: + void AppendDelayed(DelayedCallInfo DCI) { m_DelayedCalls.push_back(DCI); } bool CheckBuiltins(); void ProcessTopLevelDecl(clang::Decl* D); }; @@ -179,7 +292,7 @@ namespace clad { } PluginASTAction::ActionType getActionType() override { - return AddBeforeMainAction; + return AddAfterMainAction; } }; } // end namespace plugin