From ae07bf2c559157f3525e5a5fdbc7b7c444b1e12f Mon Sep 17 00:00:00 2001 From: Robin Hundt <24554122+robinhundt@users.noreply.github.com> Date: Thu, 21 Mar 2024 11:16:01 +0100 Subject: [PATCH] Refactored base_circuit iter --- crates/seec/src/circuit/base_circuit.rs | 106 +++++++++++------------- 1 file changed, 47 insertions(+), 59 deletions(-) diff --git a/crates/seec/src/circuit/base_circuit.rs b/crates/seec/src/circuit/base_circuit.rs index 482f2de..0a91713 100644 --- a/crates/seec/src/circuit/base_circuit.rs +++ b/crates/seec/src/circuit/base_circuit.rs @@ -14,7 +14,6 @@ use bytemuck::{Pod, Zeroable}; use petgraph::dot::{Config, Dot}; use petgraph::graph::NodeIndex; use petgraph::visit::IntoNodeIdentifiers; -use petgraph::visit::{VisitMap, Visitable}; use petgraph::{Directed, Direction, Graph}; use serde::{Deserialize, Serialize}; use tracing::{debug, info, instrument, trace}; @@ -508,10 +507,9 @@ impl Gate for BaseGate { pub struct BaseLayerIter<'a, G, Idx: GateIdx, W> { circuit: &'a BaseCircuit, inputs_needed_cnt: Vec, + prev_interactive: VecDeque>, to_visit: VecDeque>, next_layer: VecDeque>, - visited: as Visitable>::Map, - added_to_next: as Visitable>::Map, // only used for SIMD circuits // TODO remove entries from hashmap when count recheas 0 inputs_left_to_provide: HashMap, u32>, @@ -549,15 +547,12 @@ impl<'a, Idx: GateIdx, G: Gate, W: Wire> BaseLayerIter<'a, G, Idx, W> { .collect(); let to_visit = VecDeque::new(); let next_layer = circuit.constant_gates.iter().map(|&g| g.into()).collect(); - let visited = circuit.graph.visit_map(); - let added_to_next = circuit.graph.visit_map(); Self { circuit, inputs_needed_cnt, + prev_interactive: VecDeque::new(), to_visit, next_layer, - visited, - added_to_next, inputs_left_to_provide: Default::default(), last_layer_size: (0, 0), gates_produced: 0, @@ -581,10 +576,7 @@ impl<'a, Idx: GateIdx, G: Gate, W: Wire> BaseLayerIter<'a, G, Idx, W> { /// Adds idx to the next layer if it has not been visited pub fn add_to_next_layer(&mut self, idx: NodeIndex) { - if !self.added_to_next.is_visited(&idx) { - self.next_layer.push_back(idx); - self.added_to_next.visit(idx); - } + self.next_layer.push_back(idx); } pub fn is_exhausted(&self) -> bool { @@ -699,71 +691,67 @@ impl<'a, G: Gate, Idx: GateIdx, W: Wire> Iterator for BaseLayerIter<'a, G, Idx, #[tracing::instrument(level = "trace", skip(self), ret)] fn next(&mut self) -> Option { - // TODO this current implementation is confusing -> Refactor let graph = self.circuit.as_graph(); let mut layer = CircuitLayer::with_capacity(self.last_layer_size); std::mem::swap(&mut self.to_visit, &mut self.next_layer); - while let Some(node_idx) = self.to_visit.pop_front() { - // This case handles the interactive gates at the front of to_visit that - // are here because they were `add_to_next_layer` but whose neighbours have not - // had their counts decreased - if self.visited.is_visited(&node_idx) { - let mut neigh_cnt = 0; - for neigh in graph.neighbors(node_idx) { - neigh_cnt += 1; - { - let count = self.inputs_needed_cnt[neigh.index()]; - trace!("Node: {node_idx:?} -> Neigh {neigh:?}: count {count}") - } - self.inputs_needed_cnt[neigh.index()] -= 1; - let inputs_needed = self.inputs_needed_cnt[neigh.index()]; - if inputs_needed == 0 { - self.add_to_visit(neigh); - } + // Unfortunately this is hard to factor into a function on self due to borrowing issues :/ + let update_queue = |node, + inputs_needed: &mut [u32], + inputs_left_to_provide: &mut HashMap<_, _>, + queue: &mut VecDeque<_>| { + let mut neigh_cnt = 0; + for neigh in graph.neighbors(node) { + neigh_cnt += 1; + let inputs_needed = &mut inputs_needed[neigh.index()]; + *inputs_needed -= 1; + if *inputs_needed == 0 { + queue.push_back(neigh); } - if self.circuit.is_simd() { - self.inputs_left_to_provide - .entry(node_idx) - .or_insert(neigh_cnt); - } - continue; } - self.visited.visit(node_idx); if self.circuit.is_simd() { - for neigh in graph.neighbors_directed(node_idx, Direction::Incoming) { - let cnt = self - .inputs_left_to_provide - .get_mut(&neigh) - .expect("inputs_left_to_provide must be initialize"); - *cnt -= 1; - if *cnt == 0 { - layer.freeable_gates.push(neigh.into()); - } - } + inputs_left_to_provide.entry(node).or_insert(neigh_cnt); } + neigh_cnt + }; + + while let Some(node_idx) = self.prev_interactive.pop_front() { + update_queue( + node_idx, + &mut self.inputs_needed_cnt, + &mut self.inputs_left_to_provide, + &mut self.to_visit, + ); + } + + while let Some(node_idx) = self.to_visit.pop_front() { let gate = graph[node_idx].clone(); if gate.is_interactive() { - self.add_to_next_layer(node_idx); layer.push_interactive((gate.clone(), node_idx.into())); + self.prev_interactive.push_back(node_idx); } else { layer.push_non_interactive((gate.clone(), node_idx.into())); - let mut neigh_cnt = 0; - for neigh in graph.neighbors(node_idx) { - neigh_cnt += 1; - self.inputs_needed_cnt[neigh.index()] -= 1; - let inputs_needed = self.inputs_needed_cnt[neigh.index()]; - if inputs_needed == 0 { - self.add_to_visit(neigh) + update_queue( + node_idx, + &mut self.inputs_needed_cnt, + &mut self.inputs_left_to_provide, + &mut self.to_visit, + ); + } + + if self.circuit.is_simd() { + for neigh in graph.neighbors_directed(node_idx, Direction::Incoming) { + let cnt = self + .inputs_left_to_provide + .get_mut(&neigh) + .expect("inputs_left_to_provide is initialized because of topo order"); + *cnt -= 1; + if *cnt == 0 { + layer.freeable_gates.push(neigh.into()); } } - if self.circuit.is_simd() { - self.inputs_left_to_provide - .entry(node_idx) - .or_insert(neigh_cnt); - } } } if layer.is_empty() {