Skip to content

Commit

Permalink
Make WpoBuilder faster
Browse files Browse the repository at this point in the history
Summary: This is a behavior-preserving change.

Reviewed By: arnaudvenet

Differential Revision: D66673243

fbshipit-source-id: 44926e36291eae4bb96d6755f060bbb8db00e7bc
  • Loading branch information
Nikolai Tillmann authored and facebook-github-bot committed Dec 16, 2024
1 parent ab3d413 commit fc961d9
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 41 deletions.
6 changes: 4 additions & 2 deletions include/sparta/MonotonicFixpointIterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ class ParallelMonotonicFixpointIterator
}

private:
WeakPartialOrdering<NodeId, NodeHash> m_wpo;
WeakPartialOrdering<NodeId, NodeHash, /*Support_is_from_outside=*/false>
m_wpo;
size_t m_num_thread;
std::unordered_set<NodeId> m_all_nodes;
static constexpr size_t ChunkSize = 512;
Expand Down Expand Up @@ -673,7 +674,8 @@ class MonotonicFixpointIterator
}

private:
WeakPartialOrdering<NodeId, NodeHash> m_wpo;
WeakPartialOrdering<NodeId, NodeHash, /*Support_is_from_outside=*/false>
m_wpo;
};

/*
Expand Down
108 changes: 69 additions & 39 deletions include/sparta/WeakPartialOrdering.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ namespace sparta {
template <typename NodeId>
class WpoNode;

template <typename NodeId, typename NodeHash>
template <typename NodeId, typename NodeHash, bool Support_is_from_outside>
class WeakPartialOrdering;

namespace wpo_impl {

template <typename NodeId, typename NodeHash>
template <typename NodeId, typename NodeHash, bool Support_is_from_outside>
class WpoBuilder;

} // end namespace wpo_impl
Expand Down Expand Up @@ -131,7 +131,7 @@ class WpoNode final {
WpoNode& operator=(const WpoNode&) = delete;
WpoNode& operator=(WpoNode&&) = delete;

template <typename T1, typename T2>
template <typename T1, typename T2, bool B>
friend class wpo_impl::WpoBuilder;
}; // end class WpoNode

Expand All @@ -145,7 +145,9 @@ class WpoNode final {
* https://dl.acm.org/ft_gateway.cfm?id=3371082
*
*/
template <typename NodeId, typename NodeHash = std::hash<NodeId>>
template <typename NodeId,
typename NodeHash = std::hash<NodeId>,
bool Support_is_from_outside = true>
class WeakPartialOrdering final {
private:
using WpoNodeT = WpoNode<NodeId>;
Expand Down Expand Up @@ -186,7 +188,7 @@ class WeakPartialOrdering final {
m_post_dfn[root] = 1;
return;
}
wpo_impl::WpoBuilder<NodeId, NodeHash> builder(
wpo_impl::WpoBuilder<NodeId, NodeHash, Support_is_from_outside> builder(
successors, m_nodes, m_toplevel, m_post_dfn, lift);
builder.build(root);
}
Expand Down Expand Up @@ -238,6 +240,7 @@ class WeakPartialOrdering final {
// := !is_from_outside(head, pred) /\ is_predecessor(head, pred)
// This is used in interleaved widening and narrowing.
bool is_from_outside(NodeId head, NodeId pred) {
SPARTA_RUNTIME_CHECK(Support_is_from_outside, undefined_operation());
return get_post_dfn(head) < get_post_dfn(pred);
}

Expand All @@ -249,11 +252,40 @@ class WeakPartialOrdering final {

namespace wpo_impl {

template <typename NodeId, typename NodeHash>
template <class T>
struct VectorMap {
private:
std::vector<T> m_vec;

void recalibrate(size_t i) {
if (i >= m_vec.size()) {
m_vec.resize(i * 2 + 1);
}
}

public:
T& operator[](size_t i) {
recalibrate(i);
return m_vec[i];
}

void set(size_t i, T value) {
recalibrate(i);
m_vec[i] = std::move(value);
}

T get(size_t i) const { return i >= m_vec.size() ? T() : m_vec[i]; }

const T* get_opt(size_t i) const {
return i >= m_vec.size() ? nullptr : &m_vec[i];
}
};

template <typename NodeId, typename NodeHash, bool Support_is_from_outside>
class WpoBuilder final {
private:
using WpoNodeT = WpoNode<NodeId>;
using WpoT = WeakPartialOrdering<NodeId, NodeHash>;
using WpoT = WeakPartialOrdering<NodeId, NodeHash, Support_is_from_outside>;
using Type = typename WpoNodeT::Type;
using WpoIdx = uint32_t;

Expand Down Expand Up @@ -308,55 +340,56 @@ class WpoBuilder final {
p_pmap_t p_pmap(parent_map);
boost::disjoint_sets<r_pmap_t, p_pmap_t> dsets(r_pmap, p_pmap);

std::unordered_map<uint32_t, uint32_t> ancestor;
std::stack<std::tuple<NodeId, bool, uint32_t>> stack;
std::unordered_map<uint32_t, bool> black;
VectorMap<uint32_t> ancestor;
struct StackEntry {
NodeId vertex_ref;
uint32_t pred;
uint32_t finished_vertex;
};
std::stack<StackEntry> stack;
VectorMap<bool> black;

stack.push(std::make_tuple(root, false, 0));
stack.push(StackEntry{root, 0, 0});
while (!stack.empty()) {
// Iterative DFS.
auto& stack_top = stack.top();
auto vertex_ref = std::get<0>(stack_top);
auto finished = std::get<1>(stack_top);
auto pred = std::get<2>(stack_top);
auto [vertex_ref, pred, finished_vertex] = stack.top();
stack.pop();

if (finished) {
// DFS is done with this vertex.
// Set the post DFN.
m_post_dfn[vertex_ref] = m_next_post_dfn++;
if (finished_vertex != 0) {
if (Support_is_from_outside) {
// DFS is done with this vertex.
// Set the post DFN.
m_post_dfn[vertex_ref] = m_next_post_dfn++;
}

auto vertex = get_dfn(vertex_ref);
// Mark visited.
black[vertex] = true;
black.set(finished_vertex, true);

dsets.union_set(vertex, pred);
dsets.union_set(finished_vertex, pred);
ancestor[dsets.find_set(pred)] = pred;
} else {
if (get_dfn(vertex_ref) !=
0 /* means that the vertex is already discovered. */) {
auto& vertex = m_dfn[vertex_ref];
if (vertex != 0 /* means that the vertex is already discovered. */) {
// A forward edge.
// Forward edges can be ignored, as they are redundant.
continue;
}
// New vertex is discovered.
auto vertex = m_next_dfn++;
vertex = m_next_dfn++;
push_ref(vertex_ref);
set_dfn(vertex_ref, vertex);
dsets.make_set(vertex);
ancestor[vertex] = vertex;

// This will be popped after all its successors are finished.
stack.push(std::make_tuple(vertex_ref, true, pred));
stack.push((StackEntry){vertex_ref, pred, vertex});

auto successors = m_successors(vertex_ref);
// Successors vector is reversed to match the order with WTO.
for (auto rit = successors.rbegin(); rit != successors.rend(); ++rit) {
auto succ = get_dfn(*rit);
if (succ == 0 /* 0 means that vertex is undiscovered. */) {
// Newly discovered vertex. Search continues.
stack.push(std::make_tuple(*rit, false, vertex));
} else if (black[succ]) {
stack.push((StackEntry){*rit, vertex, 0});
} else if (black.get(succ)) {
// A cross edge.
auto lca = ancestor[dsets.find_set(succ)];
m_cross_fwds[lca].emplace_back(vertex, succ);
Expand Down Expand Up @@ -406,9 +439,9 @@ class WpoBuilder final {
// In reverse DFS order, build WPOs for SCCs bottom-up.
for (uint32_t h = get_next_dfn() - 1; h > 0; h--) {
// Restore cross/fwd edges which has h as the LCA.
auto it = m_cross_fwds.find(h);
if (it != m_cross_fwds.end()) {
for (auto& edge : it->second) {
auto opt = m_cross_fwds.get_opt(h);
if (opt != nullptr) {
for (auto& edge : *opt) {
// edge: u -> v
auto& u = edge.first;
auto& v = edge.second;
Expand Down Expand Up @@ -564,8 +597,6 @@ class WpoBuilder final {
return 0;
}

void set_dfn(NodeId n, uint32_t num) { m_dfn[n] = num; }

const NodeId& get_ref(uint32_t num) const { return m_ref.at(num - 1); }

void push_ref(NodeId n) { m_ref.push_back(n); }
Expand Down Expand Up @@ -616,12 +647,11 @@ class WpoBuilder final {
// A map from DFN to NodeId.
std::vector<NodeId> m_ref;
// A map from DFN to DFNs of its backedge predecessors.
std::unordered_map<uint32_t, std::vector<uint32_t>> m_back_preds;
VectorMap<std::vector<uint32_t>> m_back_preds;
// A map from DFN to DFNs of its non-backedge predecessors.
std::unordered_map<uint32_t, std::vector<uint32_t>> m_non_back_preds;
VectorMap<std::vector<uint32_t>> m_non_back_preds;
// A map from DFN to cross/forward edges (DFN is the lowest common ancestor).
std::unordered_map<uint32_t, std::vector<std::pair<uint32_t, uint32_t>>>
m_cross_fwds;
VectorMap<std::vector<std::pair<uint32_t, uint32_t>>> m_cross_fwds;
// Increase m_num_outer_preds[x][pair.first] for component C_x that satisfies
// pair.first \in C_x \subseteq C_{pair.second}.
std::vector<std::pair<WpoIdx, WpoIdx>> m_for_outer_preds;
Expand Down

0 comments on commit fc961d9

Please sign in to comment.