From d18b84051b689af330c31ac40fb786520ac1cd7d Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Thu, 15 Feb 2024 15:48:34 +0000 Subject: [PATCH] Delay the differentiation process until the end of TU. Before this patch clad attaches itself as a first consumer and applies AD before code generation. However, that is limited since clang sends every top-level declaration to codegen which limits the amount of flexibility clad has. For example, we have to instantiate all pending templates at every HandleTopLevelDecl calls; we cannot really differentiate virtual functions whose classes have sent their key function to CodeGen; and in general we perform actions which are semantically useful for the end of the translation unit. This patch makes clad a single consumer of clang which dispatches to the others. That's done by delaying all calls to the consumers until the end of the TU where clad can replay the exact sequence of calls to the other consumers as if they were directly connected to the frontend. Fixes #248 --- include/clad/Differentiator/Differentiator.h | 30 +-- include/clad/Differentiator/Sins.h | 29 +++ lib/Differentiator/VisitorBase.cpp | 37 +-- test/FirstDerivative/CodeGenSimple.C | 8 + test/Misc/ClangConsumers.cpp | 81 +++++++ test/lit.cfg | 3 + test/lit.site.cfg.in | 1 + tools/ClangPlugin.cpp | 228 +++++++++++++++---- tools/ClangPlugin.h | 220 +++++++++++++++--- 9 files changed, 520 insertions(+), 117 deletions(-) create mode 100644 include/clad/Differentiator/Sins.h create mode 100644 test/Misc/ClangConsumers.cpp diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index b3e708b18..42fae95e2 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -367,9 +367,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { gradient(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction, true>( - derivedFn /* will be replaced by gradient*/, code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction, true>( + derivedFn /* will be replaced by gradient*/, code); } /// Specialization for differentiating functors. @@ -384,8 +384,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { gradient(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction, true>( - derivedFn /* will be replaced by gradient*/, code, f); + return CladFunction, true>( + derivedFn /* will be replaced by gradient*/, code, f); } /// Generates function which computes hessian matrix of the given function wrt @@ -404,9 +404,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { hessian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction>( - derivedFn /* will be replaced by hessian*/, code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction>( + derivedFn /* will be replaced by hessian*/, code); } /// Specialization for differentiating functors. @@ -421,8 +421,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { hessian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction>( - derivedFn /* will be replaced by hessian*/, code, f); + return CladFunction>( + derivedFn /* will be replaced by hessian*/, code, f); } /// Generates function which computes jacobian matrix of the given function @@ -441,9 +441,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { jacobian(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction>( - derivedFn /* will be replaced by Jacobian*/, code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction>( + derivedFn /* will be replaced by Jacobian*/, code); } /// Specialization for differentiating functors. @@ -458,8 +458,8 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { jacobian(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction>( - derivedFn /* will be replaced by Jacobian*/, code, f); + return CladFunction>( + derivedFn /* will be replaced by Jacobian*/, code, f); } template + +/// Standard-protected facility allowing access into private members in C++. +/// Use with caution! +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define CONCATE_(X, Y) X##Y +#define CONCATE(X, Y) CONCATE_(X, Y) +#define ALLOW_ACCESS(CLASS, MEMBER, ...) \ + template \ + struct CONCATE(MEMBER, __LINE__) { \ + friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \ + }; \ + template struct Only_##MEMBER; \ + template <> struct Only_##MEMBER { \ + friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER*); \ + }; \ + template struct CONCATE(MEMBER, \ + __LINE__), &CLASS::MEMBER> + +#define ACCESS(OBJECT, MEMBER) \ + (OBJECT).*Access((Only_##MEMBER< \ + std::remove_reference::type>*)nullptr) + +// NOLINTEND(cppcoreguidelines-macro-usage) + +#endif // CLAD_DIFFERENTIATOR_SINS_H diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..32ab9f161 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -8,10 +8,11 @@ #include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/StmtClone.h" -#include "clad/Differentiator/CladUtils.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -59,42 +60,14 @@ namespace clad { return true; } - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - - // Tag used to access Sema::CurScope. - using namespace clang; - struct Sema_CurScope : TagBase {}; - template struct Rob; - } // namespace + ALLOW_ACCESS(Sema, CurScope, Scope*); clang::Scope*& VisitorBase::getCurrentScope() { - return m_Sema.*get(Sema_CurScope()); + return ACCESS(m_Sema, CurScope); } void VisitorBase::setCurrentScope(clang::Scope* S) { - m_Sema.*get(Sema_CurScope()) = S; + getCurrentScope() = S; assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); } diff --git a/test/FirstDerivative/CodeGenSimple.C b/test/FirstDerivative/CodeGenSimple.C index 02a815c92..4ff77e806 100644 --- a/test/FirstDerivative/CodeGenSimple.C +++ b/test/FirstDerivative/CodeGenSimple.C @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...); int f_1_darg0(int x); +double sq_defined_later(double); + int main() { int x = 4; clad::differentiate(f_1, 0); + auto df = clad::differentiate(sq_defined_later, "x"); printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2 + printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6 return 0; } + +double sq_defined_later(double x) { + return x * x; +} diff --git a/test/Misc/ClangConsumers.cpp b/test/Misc/ClangConsumers.cpp new file mode 100644 index 000000000..210c3060d --- /dev/null +++ b/test/Misc/ClangConsumers.cpp @@ -0,0 +1,81 @@ +// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out \ +// RUN: -fms-compatibility -DMS_COMPAT -std=c++14 -fmodules \ +// RUN: -Xclang -print-stats 2>&1 | FileCheck %s +// CHECK-NOT: {{.*error|warning|note:.*}} +// +// RUN: clang -xc -Xclang -add-plugin -Xclang clad -Xclang -load \ +// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \ +// RUN: -Xclang -debug-info-kind=limited -Xclang -triple -Xclang bpf-linux-gnu \ +// RUN: -S -emit-llvm -Xclang -target-cpu -Xclang generic \ +// RUN: -Xclang -print-stats 2>&1 | \ +// RUN: FileCheck -check-prefix=CHECK_C %s +// CHECK_C-NOT: {{.*error|warning|note:.*}} +// XFAIL: clang-7, clang-8, clang-9, target={{i586.*}}, target=arm64-apple-{{.*}} +// +// RUN: clang -xobjective-c -Xclang -add-plugin -Xclang clad -Xclang -load \ +// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \ +// RUN: -Xclang -print-stats 2>&1 | \ +// RUN: FileCheck -check-prefix=CHECK_OBJC %s +// CHECK_OBJC-NOT: {{.*error|warning|note:.*}} + +#ifdef __cplusplus + +#pragma clang module build N + module N {} + #pragma clang module contents + #pragma clang module begin N + struct f { void operator()() const {} }; + template auto vtemplate = f{}; + #pragma clang module end +#pragma clang module endbuild + +#pragma clang module import N + +#ifdef MS_COMPAT +class __single_inheritance IncSingle; +#endif // MS_COMPAT + +struct V { virtual int f(); }; +int V::f() { return 1; } +template T f() { return T(); } +int i = f(); + +// Check if shouldSkipFunctionBody is called. +// RUN: %cladclang -I%S/../../include -fsyntax-only -fmodules \ +// RUN: -Xclang -code-completion-at=%s:%(line-1):1 %s -o - | \ +// RUN: FileCheck -check-prefix=CHECK-CODECOMP %s +// CHECK-CODECOMP: COMPLETION + +// CHECK: HandleImplicitImportDecl +// CHECK: AssignInheritanceModel +// CHECK: HandleTopLevelDecl +// CHECK: HandleCXXImplicitFunctionInstantiation +// CHECK: HandleInterestingDecl +// CHECK: HandleVTable +// CHECK: HandleCXXStaticMemberVarInstantiation + +#endif // __cplusplus + +#ifdef __STDC_VERSION__ // C mode +int i; + +extern char ch; +int test(void) { return ch; } +char ch = 1; + +// CHECK_C: CompleteTentativeDefinition +// CHECK_C: CompleteExternalDeclaration +#endif // __STDC_VERSION__ + +#ifdef __OBJC__ +@interface I +void f(); +@end +// CHECK_OBJC: HandleTopLevelDeclInObjCContainer +#endif // __OBJC__ + +int main() { +#ifdef __cplusplus + vtemplate(); +#endif // __cplusplus +} diff --git a/test/lit.cfg b/test/lit.cfg index 014f22024..8c3271983 100644 --- a/test/lit.cfg +++ b/test/lit.cfg @@ -294,6 +294,9 @@ if.*\[ ?(llvm[^ ]*) ([^ ]*) ?\].*{ if platform.system() not in ['Windows'] or lit_config.getBashPath() != '': config.available_features.add('shell') + +config.available_features.add("clang-{0}".format(config.clang_version_major)) + # Loadable module # FIXME: This should be supplied by Makefile or autoconf. #if sys.platform in ['win32', 'cygwin']: diff --git a/test/lit.site.cfg.in b/test/lit.site.cfg.in index 6c36bb63a..868d2db09 100644 --- a/test/lit.site.cfg.in +++ b/test/lit.site.cfg.in @@ -3,6 +3,7 @@ import sys ## Autogenerated by LLVM/clad configuration. # Do not edit! llvm_version_major = @LLVM_VERSION_MAJOR@ +config.clang_version_major = @CLANG_VERSION_MAJOR@ config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 6ce14e591..8be825c3f 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -9,6 +9,7 @@ #include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/EstimationModel.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/Version.h" #include "clang/AST/ASTConsumer.h" @@ -91,57 +92,48 @@ namespace clad { CladPlugin::~CladPlugin() {} - // We cannot use HandleTranslationUnit because codegen already emits code on - // HandleTopLevelDecl calls and makes updateCall with no effect. - bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) { + ALLOW_ACCESS(MultiplexConsumer, Consumers, + std::vector>); + + void CladPlugin::Initialize(clang::ASTContext& C) { + // We know we have a multiplexer. We commit a sin here by stealing it and + // making the consumer pass-through so that we can delay all operations + // until clad is happy. + + auto& MultiplexC = cast(m_CI.getASTConsumer()); + auto& RobbedCs = ACCESS(MultiplexC, Consumers); + assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); + std::vector> StolenConsumers; + + // The range-based for loop in MultiplexConsumer::Initialize has + // dispatched this call. Generally, it is unsafe to delete elements while + // iterating but we know we are in the end of the loop and ::end() won't + // be invalidated. + std::move(RobbedCs.begin(), RobbedCs.end() - 1, + std::back_inserter(StolenConsumers)); + RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); + m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers))); + } + + void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) { if (!CheckBuiltins()) - return true; + return; Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this)); - - // if HandleTopLevelDecl was called through clad we don't need to process - // it for diff requests - if (m_HandleTopLevelDeclInternal) - return true; + m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this)); RequestOptions opts{}; SetRequestOptions(opts); - DiffSchedule requests{}; - DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema(), - opts); - - if (requests.empty()) - return true; - - // FIXME: Remove the PerformPendingInstantiations altogether. We should - // somehow make the relevant functions referenced. - // Instantiate all pending for instantiations templates, because we will - // need the full bodies to produce derivatives. - // FIXME: Confirm if we really need `m_PendingInstantiationsInFlight`? - if (!m_PendingInstantiationsInFlight) { - m_PendingInstantiationsInFlight = true; - S.PerformPendingInstantiations(); - m_PendingInstantiationsInFlight = false; - } - - for (DiffRequest& request : requests) - ProcessDiffRequest(request); - return true; // Happiness - } - - void CladPlugin::ProcessTopLevelDecl(Decl* D) { - m_HandleTopLevelDeclInternal = true; - m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D)); - m_HandleTopLevelDeclInternal = false; + DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S, opts); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { Sema& S = m_CI.getSema(); // Required due to custom derivatives function templates that might be // used in the function that we need to derive. + // FIXME: Remove the call to PerformPendingInstantiations(). S.PerformPendingInstantiations(); if (request.Function->getDefinition()) request.Function = request.Function->getDefinition(); @@ -264,6 +256,8 @@ namespace clad { // Call CodeGen only if the produced Decl is a top-most // decl or is contained in a namespace decl. + // FIXME: We could get rid of this by prepending the produced + // derivatives in CladPlugin::HandleTranslationUnitDecl DeclContext* derivativeDC = DerivativeDecl->getDeclContext(); bool isTUorND = derivativeDC->isTranslationUnit() || derivativeDC->isNamespace(); @@ -293,6 +287,68 @@ namespace clad { return nullptr; } + void CladPlugin::SendToMultiplexer() { + for (auto DelayedCall : m_DelayedCalls) { + DeclGroupRef& D = DelayedCall.m_DGR; + switch (DelayedCall.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + m_Multiplexer->HandleCXXStaticMemberVarInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDecl: + m_Multiplexer->HandleTopLevelDecl(D); + break; + case CallKind::HandleInlineFunctionDefinition: + m_Multiplexer->HandleInlineFunctionDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleInterestingDecl: + m_Multiplexer->HandleInterestingDecl(D); + break; + case CallKind::HandleTagDeclDefinition: + m_Multiplexer->HandleTagDeclDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTagDeclRequiredDefinition: + m_Multiplexer->HandleTagDeclRequiredDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + m_Multiplexer->HandleCXXImplicitFunctionInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + m_Multiplexer->HandleTopLevelDeclInObjCContainer(D); + break; + case CallKind::HandleImplicitImportDecl: + m_Multiplexer->HandleImplicitImportDecl( + cast(D.getSingleDecl())); + break; + case CallKind::CompleteTentativeDefinition: + m_Multiplexer->CompleteTentativeDefinition( + cast(D.getSingleDecl())); + break; +#if CLANG_VERSION_MAJOR > 9 + case CallKind::CompleteExternalDeclaration: + m_Multiplexer->CompleteExternalDeclaration( + cast(D.getSingleDecl())); + break; +#endif + case CallKind::AssignInheritanceModel: + m_Multiplexer->AssignInheritanceModel( + cast(D.getSingleDecl())); + break; + case CallKind::HandleVTable: + m_Multiplexer->HandleVTable(cast(D.getSingleDecl())); + break; + case CallKind::InitializeSema: + m_Multiplexer->InitializeSema(m_CI.getSema()); + break; + }; + } + m_HasMultiplexerProcessedDelayedCalls = true; + } + bool CladPlugin::CheckBuiltins() { // If we have included "clad/Differentiator/Differentiator.h" return. if (m_HasRuntime) @@ -316,18 +372,104 @@ namespace clad { return m_HasRuntime; } - void CladPlugin::SetRequestOptions(RequestOptions& opts) const { - SetTBRAnalysisOptions(m_DO, opts); - } - - void CladPlugin::SetTBRAnalysisOptions(const DifferentiationOptions& DO, - RequestOptions& opts) { + static void SetTBRAnalysisOptions(const DifferentiationOptions& DO, + RequestOptions& opts) { // If user has explicitly specified the mode for TBR analysis, use it. if (DO.EnableTBRAnalysis || DO.DisableTBRAnalysis) opts.EnableTBRAnalysis = DO.EnableTBRAnalysis && !DO.DisableTBRAnalysis; else opts.EnableTBRAnalysis = false; // Default mode. } + + void CladPlugin::SetRequestOptions(RequestOptions& opts) const { + SetTBRAnalysisOptions(m_DO, opts); + } + + void CladPlugin::HandleTranslationUnit(ASTContext& C) { + Sema& S = m_CI.getSema(); + // Restore the TUScope that became a 0 in Sema::ActOnEndOfTranslationUnit. + S.TUScope = m_StoredTUScope; + constexpr bool Enabled = true; + Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); + Sema::LocalEagerInstantiationScope LocalInstantiations(S); + + for (DiffRequest& request : m_DiffSchedule) { + // FIXME: flags have to be set manually since DiffCollector's + // constructor does not have access to m_DO. + request.EnableTBRAnalysis = m_DO.EnableTBRAnalysis; + ProcessDiffRequest(request); + } + // Put the TUScope in a consistent state after clad is done. + S.TUScope = nullptr; + // Force emission of the produced pending template instantiations. + LocalInstantiations.perform(); + GlobalInstantiations.perform(); + + SendToMultiplexer(); + m_Multiplexer->HandleTranslationUnit(C); + } + + void CladPlugin::PrintStats() { + llvm::errs() << "*** INFORMATION ABOUT THE DELAYED CALLS\n"; + for (const DelayedCallInfo& DCI : m_DelayedCalls) { + llvm::errs() << " "; + switch (DCI.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; + break; + case CallKind::HandleTopLevelDecl: + llvm::errs() << "HandleTopLevelDecl"; + break; + case CallKind::HandleInlineFunctionDefinition: + llvm::errs() << "HandleInlineFunctionDefinition"; + break; + case CallKind::HandleInterestingDecl: + llvm::errs() << "HandleInterestingDecl"; + break; + case CallKind::HandleTagDeclDefinition: + llvm::errs() << "HandleTagDeclDefinition"; + break; + case CallKind::HandleTagDeclRequiredDefinition: + llvm::errs() << "HandleTagDeclRequiredDefinition"; + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + llvm::errs() << "HandleTopLevelDeclInObjCContainer"; + break; + case CallKind::HandleImplicitImportDecl: + llvm::errs() << "HandleImplicitImportDecl"; + break; + case CallKind::CompleteTentativeDefinition: + llvm::errs() << "CompleteTentativeDefinition"; + break; +#if CLANG_VERSION_MAJOR > 9 + case CallKind::CompleteExternalDeclaration: + llvm::errs() << "CompleteExternalDeclaration"; + break; +#endif + case CallKind::AssignInheritanceModel: + llvm::errs() << "AssignInheritanceModel"; + break; + case CallKind::HandleVTable: + llvm::errs() << "HandleVTable"; + break; + case CallKind::InitializeSema: + llvm::errs() << "InitializeSema"; + break; + }; + for (const clang::Decl* D : DCI.m_DGR) { + llvm::errs() << " " << D; + if (const auto* ND = dyn_cast(D)) + llvm::errs() << " " << ND->getNameAsString(); + } + llvm::errs() << "\n"; + } + + m_Multiplexer->PrintStats(); + } + } // end namespace plugin clad::CladTimerGroup::CladTimerGroup() diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 47ae51046..7cabc788d 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -13,11 +13,12 @@ #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/Version.h" -#include "clang/AST/ASTConsumer.h" #include "clang/AST/Decl.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/Version.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Frontend/MultiplexConsumer.h" +#include "clang/Sema/SemaConsumer.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" @@ -74,47 +75,212 @@ namespace clad { namespace plugin { struct DifferentiationOptions { - DifferentiationOptions() - : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), - DumpDerivedAST(false), GenerateSourceFile(false), - ValidateClangVersion(true), EnableTBRAnalysis(false), - DisableTBRAnalysis(false), CustomEstimationModel(false), - PrintNumDiffErrorInfo(false) {} - - bool DumpSourceFn : 1; - bool DumpSourceFnAST : 1; - bool DumpDerivedFn : 1; - bool DumpDerivedAST : 1; - bool GenerateSourceFile : 1; - bool ValidateClangVersion : 1; - bool EnableTBRAnalysis : 1; - bool DisableTBRAnalysis : 1; - bool CustomEstimationModel : 1; - bool PrintNumDiffErrorInfo : 1; - std::string CustomModelName; + DifferentiationOptions() + : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), + DumpDerivedAST(false), GenerateSourceFile(false), + ValidateClangVersion(true), EnableTBRAnalysis(false), + DisableTBRAnalysis(false), CustomEstimationModel(false), + PrintNumDiffErrorInfo(false) {} + + bool DumpSourceFn : 1; + bool DumpSourceFnAST : 1; + bool DumpDerivedFn : 1; + bool DumpDerivedAST : 1; + bool GenerateSourceFile : 1; + bool ValidateClangVersion : 1; + bool EnableTBRAnalysis : 1; + bool DisableTBRAnalysis : 1; + bool CustomEstimationModel : 1; + bool PrintNumDiffErrorInfo : 1; + std::string CustomModelName; }; - class CladPlugin : public clang::ASTConsumer { + class CladExternalSource : public clang::ExternalSemaSource { + // ExternalSemaSource + void ReadUndefinedButUsed( + llvm::MapVector& Undefined) + override { + // namespace { double f_darg0(double x); } will issue a warning that + // f_darg0 has internal linkage but is not defined. This is because we + // have not yet started to differentiate it. The warning is triggered by + // Sema::ActOnEndOfTranslationUnit before Clad is given control. + // To avoid the warning we should remove the entry from here. + using namespace clang; + Undefined.remove_if([](std::pair P) { + NamedDecl* ND = P.first; + + if (!ND->getDeclName().isIdentifier()) + return false; + + // FIXME: We should replace this comparison with the canonical decl + // from the differentiation plan... + llvm::StringRef Name = ND->getName(); + return Name.contains("_darg") || Name.contains("_grad") || + Name.contains("_hessian") || Name.contains("_jacobian"); + }); + } + }; + class CladPlugin : public clang::SemaConsumer { clang::CompilerInstance& m_CI; DifferentiationOptions m_DO; std::unique_ptr m_DerivativeBuilder; bool m_HasRuntime = false; - bool m_PendingInstantiationsInFlight = false; - bool m_HandleTopLevelDeclInternal = false; CladTimerGroup m_CTG; DerivedFnCollector m_DFC; + DiffSchedule m_DiffSchedule; + enum class CallKind { + HandleCXXStaticMemberVarInstantiation, + HandleTopLevelDecl, + HandleInlineFunctionDefinition, + HandleInterestingDecl, + HandleTagDeclDefinition, + HandleTagDeclRequiredDefinition, + HandleCXXImplicitFunctionInstantiation, + HandleTopLevelDeclInObjCContainer, + HandleImplicitImportDecl, + CompleteTentativeDefinition, +#if CLANG_VERSION_MAJOR > 9 + CompleteExternalDeclaration, +#endif + AssignInheritanceModel, + HandleVTable, + InitializeSema, + }; + struct DelayedCallInfo { + CallKind m_Kind; + clang::DeclGroupRef m_DGR; + DelayedCallInfo(CallKind K, clang::DeclGroupRef DGR) + : m_Kind(K), m_DGR(DGR) {} + DelayedCallInfo(CallKind K, const clang::Decl* D) + : m_Kind(K), m_DGR(const_cast(D)) {} + bool operator==(const DelayedCallInfo& other) const { + if (m_Kind != other.m_Kind) + return false; + + if (std::distance(m_DGR.begin(), m_DGR.end()) != + std::distance(other.m_DGR.begin(), other.m_DGR.end())) + return false; + + clang::Decl* const* first1 = m_DGR.begin(); + clang::Decl* const* first2 = other.m_DGR.begin(); + clang::Decl* const* last1 = m_DGR.end(); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + for (; first1 != last1; ++first1, ++first2) + if (!(*first1 == *first2)) + return false; + return true; + } + }; + /// The calls to the main action which clad delayed and will dispatch at + /// then end of the translation unit. + std::vector m_DelayedCalls; + /// The default clang consumers which are called after clad is done. + std::unique_ptr m_Multiplexer; + + /// Have we processed all delayed calls. + bool m_HasMultiplexerProcessedDelayedCalls = false; + + /// The Sema::TUScope to restore in CladPlugin::HandleTranslationUnit. + clang::Scope* m_StoredTUScope = nullptr; + public: CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO); ~CladPlugin(); - bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; + // ASTConsumer + void Initialize(clang::ASTContext& Context) override; + void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override { + AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D}); + } + bool HandleTopLevelDecl(clang::DeclGroupRef D) override { + HandleTopLevelDeclForClad(D); + AppendDelayed({CallKind::HandleTopLevelDecl, D}); + return true; // happyness, continue parsing + } + void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleInlineFunctionDefinition, D}); + } + void HandleInterestingDecl(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleInterestingDecl, D}); + } + void HandleTagDeclDefinition(clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclDefinition, D}); + } + void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D}); + } + void + HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D}); + } + void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D}); + } + void HandleImplicitImportDecl(clang::ImportDecl* D) override { + AppendDelayed({CallKind::HandleImplicitImportDecl, D}); + } + void CompleteTentativeDefinition(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteTentativeDefinition, D}); + } +#if CLANG_VERSION_MAJOR > 9 + void CompleteExternalDeclaration(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteExternalDeclaration, D}); + } +#endif + void AssignInheritanceModel(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::AssignInheritanceModel, D}); + } + void HandleVTable(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::HandleVTable, D}); + } + + // Not delayed. + void HandleTranslationUnit(clang::ASTContext& C) override; + + // No need to handle the listeners, they will be handled non-delayed by + // the parent multiplexer. + // + // clang::ASTMutationListener *GetASTMutationListener() override; + // clang::ASTDeserializationListener *GetASTDeserializationListener() + // override; + void PrintStats() override; + + bool shouldSkipFunctionBody(clang::Decl* D) override { + return m_Multiplexer->shouldSkipFunctionBody(D); + } + + // SemaConsumer + void InitializeSema(clang::Sema& S) override { + // We are also a ExternalSemaSource. + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + S.addExternalSource(new CladExternalSource()); // Owned by Sema. + m_StoredTUScope = S.TUScope; + AppendDelayed({CallKind::InitializeSema, nullptr}); + } + void ForgetSema() override { + // ForgetSema is called in the destructor of Sema which is much later + // than where we can process anything. We can't delay this call. + m_Multiplexer->ForgetSema(); + } + + // FIXME: We should hide ProcessDiffRequest when we implement proper + // handling of the differentiation plans. clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); private: + void AppendDelayed(DelayedCallInfo DCI) { + assert(!m_HasMultiplexerProcessedDelayedCalls); + m_DelayedCalls.push_back(DCI); + } + void SendToMultiplexer(); bool CheckBuiltins(); void SetRequestOptions(RequestOptions& opts) const; - static void SetTBRAnalysisOptions(const DifferentiationOptions& DO, - RequestOptions& opts); - void ProcessTopLevelDecl(clang::Decl* D); + + void ProcessTopLevelDecl(clang::Decl* D) { + DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D}; + assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!"); + AppendDelayed(DCI); + } + void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR); }; clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P, @@ -210,7 +376,7 @@ namespace clad { } PluginASTAction::ActionType getActionType() override { - return AddBeforeMainAction; + return AddAfterMainAction; } }; } // end namespace plugin