Skip to content

Commit

Permalink
Delay the differentiation process until the end of TU.
Browse files Browse the repository at this point in the history
Before this patch clad attaches itself as a first consumer and applies AD before
code generation. However, that is limited since clang sends every top-level
declaration to codegen which limits the amount of flexibility clad has. For
example, we have to instantiate all pending templates at every
HandleTopLevelDecl calls; we cannot really differentiate virtual functions
whose classes have sent their key function to CodeGen; and in general we perform
actions which are semantically useful for the end of the translation unit.

This patch makes clad a single consumer of clang which dispatches to the others.
That's done by delaying all calls to the consumers until the end of the TU where
clad can replay the exact sequence of calls to the other consumers as if they
were directly connected to the frontend.

Fixes #248
  • Loading branch information
vgvassilev committed Feb 18, 2024
1 parent 1231901 commit 6f0fb3a
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 8 deletions.
68 changes: 64 additions & 4 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,70 @@ namespace clad {

CladPlugin::~CladPlugin() {}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};
// Tag used to access MultiplexConsumer::Consumers.
using namespace clang;
struct MultiplexConsumer_Consumers
: TagBase<
MultiplexConsumer_Consumers,
std::vector<std::unique_ptr<ASTConsumer>> MultiplexConsumer::*> {
};
template struct Rob<MultiplexConsumer_Consumers,
&MultiplexConsumer::Consumers>;
} // namespace

void CladPlugin::Initialize(clang::ASTContext& C) {
// We know we have a multiplexer. We commit a sin here by stealing it and
// making the consumer pass-through so that we can delay all operations
// until clad is happy.

using namespace clang;

auto& MultiplexC = static_cast<MultiplexConsumer&>(m_CI.getASTConsumer());
auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers());
assert(RobbedCs.back().get() == this && "Clad is not the last consumer");
std::vector<std::unique_ptr<ASTConsumer>> StolenConsumers;

// The range-based for loop in MultiplexConsumer::Initialize has
// dispatched this call. Generally, it is unsafe to delete elements while
// iterating but we know we are in the end of the loop and ::end() won't
// be invalidated.
for (auto& RC : RobbedCs)
if (RC.get() == this)
RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1);
else
StolenConsumers.push_back(std::move(RC));
m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers)));
}

// We cannot use HandleTranslationUnit because codegen already emits code on
// HandleTopLevelDecl calls and makes updateCall with no effect.
bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) {
AppendDelayed({CallKind::HandleTopLevelDecl, DGR});
if (!CheckBuiltins())
return true;
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;

Sema& S = m_CI.getSema();

Expand All @@ -135,13 +194,13 @@ namespace clad {
// if HandleTopLevelDecl was called through clad we don't need to process
// it for diff requests
if (m_HandleTopLevelDeclInternal)
return true;
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;

DiffSchedule requests{};
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema());

if (requests.empty())
return true;
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;

// FIXME: flags have to be set manually since DiffCollector's constructor
// does not have access to m_DO.
Expand All @@ -162,7 +221,8 @@ namespace clad {

for (DiffRequest& request : requests)
ProcessDiffRequest(request);
return true; // Happiness

return m_Multiplexer->HandleTopLevelDecl(DGR); // Happiness
}

void CladPlugin::ProcessTopLevelDecl(Decl* D) {
Expand Down
121 changes: 117 additions & 4 deletions tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/Version.h"

#include "clang/AST/ASTConsumer.h"
#include "clang/AST/Decl.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/Version.h"
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "clang/Frontend/MultiplexConsumer.h"
#include "clang/Sema/SemaConsumer.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/DenseMap.h"
Expand Down Expand Up @@ -82,21 +83,133 @@ namespace clad {
std::string CustomModelName;
};

class CladPlugin : public clang::ASTConsumer {
class CladPlugin : public clang::SemaConsumer {
clang::CompilerInstance& m_CI;
DifferentiationOptions m_DO;
std::unique_ptr<DerivativeBuilder> m_DerivativeBuilder;
bool m_HasRuntime = false;
bool m_PendingInstantiationsInFlight = false;
bool m_HandleTopLevelDeclInternal = false;
DerivedFnCollector m_DFC;
enum class CallKind {
HandleCXXStaticMemberVarInstantiation,
HandleTopLevelDecl,
HandleInlineFunctionDefinition,
HandleInterestingDecl,
HandleTagDeclDefinition,
HandleTagDeclRequiredDefinition,
HandleCXXImplicitFunctionInstantiation,
HandleTopLevelDeclInObjCContainer,
HandleImplicitImportDecl,
CompleteTentativeDefinition,
CompleteExternalDeclaration,
AssignInheritanceModel,
HandleVTable,
InitializeSema,
ForgetSema
};
struct DelayedCallInfo {
CallKind m_Kind;
clang::DeclGroupRef m_DGR;
DelayedCallInfo(CallKind K, clang::DeclGroupRef DGR)
: m_Kind(K), m_DGR(DGR) {}
DelayedCallInfo(CallKind K, const clang::Decl* D)
: m_Kind(K), m_DGR(const_cast<clang::Decl*>(D)) {}
};
std::vector<DelayedCallInfo> m_DelayedCalls;
std::unique_ptr<clang::MultiplexConsumer> m_Multiplexer;

public:
CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO);
~CladPlugin();
bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override;
// ASTConsumer
void Initialize(clang::ASTContext& Context) override;
void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override {
AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D});
m_Multiplexer->HandleCXXStaticMemberVarInstantiation(D);
}
bool HandleTopLevelDecl(clang::DeclGroupRef D) override; /*{
AppendDelayed({CallKind::HandleTopLevelDecl, D});
return true; // happyness, continue parsing
}*/
void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override {
AppendDelayed({CallKind::HandleInlineFunctionDefinition, D});
m_Multiplexer->HandleInlineFunctionDefinition(D);
}
void HandleInterestingDecl(clang::DeclGroupRef D) override {
AppendDelayed({CallKind::HandleInterestingDecl, D});
m_Multiplexer->HandleInterestingDecl(D);
}
void HandleTagDeclDefinition(clang::TagDecl* D) override {
AppendDelayed({CallKind::HandleTagDeclDefinition, D});
m_Multiplexer->HandleTagDeclDefinition(D);
}
void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override {
AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D});
m_Multiplexer->HandleTagDeclRequiredDefinition(D);
}
void
HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override {
AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D});
m_Multiplexer->HandleCXXImplicitFunctionInstantiation(D);
}
void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override {
AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D});
m_Multiplexer->HandleTopLevelDeclInObjCContainer(D);
}
void HandleImplicitImportDecl(clang::ImportDecl* D) override {
AppendDelayed({CallKind::HandleImplicitImportDecl, D});
m_Multiplexer->HandleImplicitImportDecl(D);
}
void CompleteTentativeDefinition(clang::VarDecl* D) override {
AppendDelayed({CallKind::CompleteTentativeDefinition, D});
m_Multiplexer->CompleteTentativeDefinition(D);
}
#if CLANG_VERSION_MAJOR > 9
void CompleteExternalDeclaration(clang::VarDecl* D) override {
AppendDelayed({CallKind::CompleteExternalDeclaration, D});
m_Multiplexer->CompleteExternalDeclaration(D);
}
#endif
void AssignInheritanceModel(clang::CXXRecordDecl* D) override {
AppendDelayed({CallKind::AssignInheritanceModel, D});
m_Multiplexer->AssignInheritanceModel(D);
}
void HandleVTable(clang::CXXRecordDecl* D) override {
AppendDelayed({CallKind::HandleVTable, D});
m_Multiplexer->HandleVTable(D);
}

// Not delayed.
void HandleTranslationUnit(clang::ASTContext& C) override {
m_Multiplexer->HandleTranslationUnit(C);
}
// No need to handle the listeners, they will be handled at non-delayed by
// the parent multiplexer.
//
// clang::ASTMutationListener *GetASTMutationListener() override;
// clang::ASTDeserializationListener *GetASTDeserializationListener()
// override;
void PrintStats() override { m_Multiplexer->PrintStats(); }
bool shouldSkipFunctionBody(clang::Decl* D) override {
return m_Multiplexer->shouldSkipFunctionBody(D);
}

// SemaConsumer
void InitializeSema(clang::Sema& S) override {
AppendDelayed({CallKind::InitializeSema, nullptr});
m_Multiplexer->InitializeSema(S);
}
void ForgetSema() override {
AppendDelayed({CallKind::ForgetSema, nullptr});
m_Multiplexer->ForgetSema();
}

// bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override;
clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request);

private:
void AppendDelayed(DelayedCallInfo DCI) { m_DelayedCalls.push_back(DCI); }
bool CheckBuiltins();
void ProcessTopLevelDecl(clang::Decl* D);
};
Expand Down Expand Up @@ -179,7 +292,7 @@ namespace clad {
}

PluginASTAction::ActionType getActionType() override {
return AddBeforeMainAction;
return AddAfterMainAction;
}
};
} // end namespace plugin
Expand Down

0 comments on commit 6f0fb3a

Please sign in to comment.