Skip to content

Commit

Permalink
Use Clang CFG to analyse control-flow.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Oct 5, 2023
1 parent 5fa310f commit 884d21c
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 561 deletions.
124 changes: 63 additions & 61 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define CLAD_DIFFERENTIATOR_TBRANALYZER_H

#include "clang/AST/StmtVisitor.h"
#include "clang/Analysis/CFG.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/Compatibility.h"

Expand Down Expand Up @@ -107,7 +109,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
std::unordered_map<const llvm::APInt, VarData*, APIntHash, APIntComp>;

struct VarData {
enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE };
enum VarDataType { UNDEFINED =1, FUND_TYPE =2, OBJ_TYPE =3, ARR_TYPE =4, REF_TYPE =5};
union VarDataValue {
bool fundData;
/// objData, arrData are stored as pointers for VarDataValue to take
Expand Down Expand Up @@ -139,6 +141,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
delete pair.second;
}
}

/// Recursively sets all the leaves' bools to isReq.
void setIsRequired(bool isReq = true);
/// Returns true if there is at least one required to store node among
Expand Down Expand Up @@ -167,6 +170,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
void restoreRefs(std::unordered_map<VarData*, VarData*>& refVars);
};

clang::CFGBlock* getCFGBlockByID(unsigned ID);

/// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its
/// corresponding VarData. If the given element of an array does not have a
/// VarData* yet it will be added automatically. If addNonConstIdx==false this
Expand All @@ -191,7 +196,40 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// particular moment.
/// Note: 'this' pointer does not have a declaration so nullptr is used as
/// its key instead.
using VarsData = std::unordered_map<const clang::VarDecl*, VarData*>;
struct VarsData {
std::unordered_map<const clang::VarDecl*, VarData*> data = std::unordered_map<const clang::VarDecl*, VarData*>();
VarsData* prev = nullptr;


VarsData() {}
VarsData(VarsData& other) : data(other.data), prev(other.prev){}

using iterator = std::unordered_map<const clang::VarDecl*, VarData*>::iterator;
iterator begin() {
return data.begin();
}
iterator end() {
return data.end();
}
VarData*& operator[] (const clang::VarDecl* VD) {
return data[VD];
}
iterator find(const clang::VarDecl* VD) {
return data.find(VD);
}
void emplace (const clang::VarDecl* VD, VarData* varsData) {
data.emplace(VD, varsData);
}
void emplace (std::pair<const clang::VarDecl*, VarData*> pair) {
data.emplace(pair);
}

std::unique_ptr<VarsData> collectDataFromPredecessors(VarsData* limit=nullptr);
VarsData* findLowestCommonAncestor(VarsData* other);
void merge(VarsData* mergeData);
};


/// Used to find DeclRefExpr's that will be used in the backwards pass.
/// In order to be marked as required, a variables has to appear in a place
/// where it would have a differential influence and will appear non-linearly
Expand All @@ -201,32 +239,25 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Tells if the variable at a given location is required to store. Basically,
/// is the result of analysis.
std::map<clang::SourceLocation, bool> TBRLocs;
/// Stores VarsData for every branch in control flow (e.g. if-else statements,
/// loops).
std::vector<std::vector<VarsData>> reqStack;

/// Stores modes in a stack (used to retrieve the old mode after entering
/// a new one).
std::vector<int> modeStack;
/// Stores local variables to delete them after exiting the corresponding
/// scope.
/// Note: This is not used every time a new scope is entered. This is only
/// used when merging an if-else statement to get rid of local variables in
/// the then-branch.
std::vector<std::vector<const VarDecl*>> localVarsStack;
std::vector<short> modeStack;

ASTContext* m_Context;

/// The index of the innermost branch corresponding to a loop (used to handle
/// break/continue statements).
size_t innermostLoopLayer = 0;
/// Tells if the current branch should be deleted instead of merged with
/// others. This happens when the branch has a break/continue statement or a
/// return expression in it.
bool deleteCurBranch = false;
/// Loop bodies have to be passed twice. This tells us what pass is currently
/// happening.
bool firstLoopPass = false;
/// Set to true when a non-const index is found while analysing an
std::unique_ptr<clang::CFG> m_CFG;

std::vector<VarsData*> blockData;

std::vector<short> blockPassCounter;

unsigned curBlockID;

std::set<unsigned> CFGQueue;


/// Set to true when a non-const index is found while analysing an
/// array subscript expression.
bool nonConstIndexFound = false;

Expand All @@ -242,32 +273,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// ArraySubscriptExpr* or MemberExpr*.
void setIsRequired(const clang::Expr* E, bool isReq = true);

//// Control Flow
/// Returns the current branch.
VarsData& getCurBranch() { return reqStack.back().back(); }
/// Adds a new layer.
void addLayer() { reqStack.emplace_back(); }
/// Creates a new empty branch.
void addBranch() { reqStack.back().emplace_back(); }
/// Deletes the last branch.
void deleteBranch() {
for (auto& pair : getCurBranch())
delete pair.second;
reqStack.back().pop_back();
}
/// Merges the last layer into the one last branch on the previous layer
/// right and deletes the last layer.
void mergeLayer();
/// Merges the last layer but, unlike the previous method, basically replaces
/// the last branch on the previous layer with the result of merging. After
/// that, removes the last layer.
void mergeLayerOnTop();
/// Merges the branch with index targetBranch into a sourceBranchNum.
/// No branches are deleted.
void mergeBranchTo(size_t sourceBranchNum, VarsData& targetBranch);
/// Removes local variables from the current branch (uses localVarsStack).
/// This is necessary when merging if-else branches.
void removeLocalVars();
VarsData& getCurBranch() { return *blockData[curBlockID]; }

//// Modes Setters
/// Sets the mode manually
Expand All @@ -287,16 +293,16 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Constructor
TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) {
modeStack.push_back(0);
addLayer();
addBranch();
}

/// Destructor
~TBRAnalyzer() {
for (auto& layer : reqStack)
for (auto& branch : layer)
for (auto& pair : branch)
delete pair.second;
for (auto varsData : blockData) {
for (auto pair : *varsData) {
delete pair.second;
}
delete varsData;
}
}

/// Delete copy/move operators and constructors.
Expand All @@ -311,33 +317,29 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// Visitors
void Analyze(const clang::FunctionDecl* FD);

void VisitCFGBlock(clang::CFGBlock* block);

void Visit(const clang::Stmt* stmt) {
clang::ConstStmtVisitor<TBRAnalyzer, void>::Visit(stmt);
}

void VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
void VisitBinaryOperator(const clang::BinaryOperator* BinOp);
void VisitBreakStmt(const clang::BreakStmt* BS);
void VisitCallExpr(const clang::CallExpr* CE);
void VisitCompoundStmt(const clang::CompoundStmt* CS);
void VisitConditionalOperator(const clang::ConditionalOperator* CO);
void VisitContinueStmt(const clang::ContinueStmt* CS);
void VisitCXXConstructExpr(const clang::CXXConstructExpr* CE);
void VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
void VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE);
void VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
void VisitDeclStmt(const clang::DeclStmt* DS);
void VisitDoStmt(const clang::DoStmt* DS);
void VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
void VisitForStmt(const clang::ForStmt* FS);
void VisitIfStmt(const clang::IfStmt* If);
void VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE);
void VisitInitListExpr(const clang::InitListExpr* ILE);
void VisitMemberExpr(const clang::MemberExpr* ME);
void VisitParenExpr(const clang::ParenExpr* PE);
void VisitReturnStmt(const clang::ReturnStmt* RS);
void VisitUnaryOperator(const clang::UnaryOperator* UnOp);
void VisitWhileStmt(const clang::WhileStmt* WS);

/// FIXME: Make sure these are not necessary
/// Unused Visitors:
Expand Down
Loading

0 comments on commit 884d21c

Please sign in to comment.