Skip to content

Commit

Permalink
Only one region per Node
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasAlaif committed Sep 28, 2023
1 parent 0e5f84f commit 7610848
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 96 deletions.
216 changes: 164 additions & 52 deletions mir-state-analysis/src/coupling_graph/impl/cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use prusti_rustc_interface::{
index::bit_set::BitSet,
dataflow::fmt::DebugWithContext, index::IndexVec, middle::mir::Local,
borrowck::{borrow_set::BorrowData, consumers::BorrowIndex},
middle::{mir::{ConstraintCategory, RETURN_PLACE, Location}, ty::{RegionVid, TyKind}},
middle::{mir::{BasicBlock, ConstraintCategory, RETURN_PLACE, Location, Operand}, ty::{RegionVid, TyKind}},
};

use crate::{
Expand All @@ -32,11 +32,93 @@ pub struct Regions<'a, 'tcx> {

pub version: usize,
pub(crate) graph: Graph<'a, 'tcx>,
pub path_condition: Vec<PathCondition>,
// If I'm here,
pub path_condition: Vec<PathConditionMap<'tcx>>,
}

pub type NodeId = usize;

#[derive(Debug, Clone)]
pub struct PathConditionMap<'tcx> {
pub branch_arg: Operand<'tcx>,
pub branch_tgts: FxHashMap<BasicBlock, Option<u128>>,
}

pub struct PathConditionCollection(FxHashMap<Vec<PathCondition>, Vec<PathCondition>>);
impl PathConditionCollection {
pub fn new() -> Self {
Self(FxHashMap::default())
}
pub fn push(&mut self, pc: PathCondition) {
// let new = self.0.drain().flat_map(|(mut k, mut v)| {
// match v.pop() {
// None => Vec::new(),
// Some(last) => {
// v.into_iter().map(|mut v| {
// let mut k = k.clone();
// k.push(v);
// (k, pc.clone())
// }).collect()
// k.push(last);

// }
// }
// if v.len() == 0 {
// Vec::new()
// } else {
// v.pop().un
// }
// if v.len() == 1 {
// k.push(v.pop().unwrap());
// std::iter::once((k, pc.clone()))
// }
// k.push(v);
// (k, pc.clone())
// }).collect();
// self.0 = new;
}
// pub fn set_value_all(&mut self, value: Option<u128>) {
// self.0.
// }
}

pub type BitSetValues = u128;
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PathCondition {
pub bb: BasicBlock,
pub values: BitSetValues,
pub all_values: u128,
pub other_values: Vec<bool>,
}

impl PathCondition {
pub fn new(bb: BasicBlock, values_len: usize) -> Self {
let (all_values, other_values) = if values_len < 127 {
((1 << (values_len+1)) - 1, Vec::new())
} else {
(u128::MAX, vec![false; values_len - 127])
};
Self {
bb,
values: 0,
all_values,
other_values,
}
}
pub fn set_value(&mut self, value: Option<u128>) -> bool {
match value {
None => self.values |= 1,
Some(v) if v < 127 => self.values |= 1 << (v+1),
Some(v) => self.other_values[TryInto::<usize>::try_into(v).unwrap() - 127] = true,
};
self.all_values_set()
}
fn all_values_set(&self) -> bool {
assert!(self.values <= self.all_values);
self.values == self.all_values && self.other_values.iter().all(|b| *b)
}
}

#[derive(Clone)]
pub struct Graph<'a, 'tcx> {
pub id: Option<String>,
Expand All @@ -61,9 +143,6 @@ impl Debug for Graph<'_, '_> {
}
}

#[derive(Debug, Clone)]
pub struct PathCondition;

impl PartialEq for Graph<'_, '_> {
fn eq(&self, other: &Self) -> bool {
self.nodes == other.nodes
Expand Down Expand Up @@ -153,11 +232,11 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
}
// r1 outlives r2, or `r1: r2` (i.e. r1 gets blocked)
#[tracing::instrument(name = "Graph::outlives", level = "trace", skip(self), ret)]
pub fn outlives(&mut self, r1: RegionVid, r2: RegionVid, reason: ConstraintCategory<'tcx>) -> bool {
pub fn outlives(&mut self, r1: RegionVid, r2: RegionVid, reason: EdgeInfo<'tcx>) -> bool {
self.outlives_many(r1, r2, &FxHashSet::from_iter([reason].into_iter()))
}
#[tracing::instrument(name = "Graph::outlives_many", level = "trace", skip(self), ret)]
pub fn outlives_many(&mut self, r1: RegionVid, r2: RegionVid, reasons: &FxHashSet<ConstraintCategory<'tcx>>) -> bool {
pub fn outlives_many(&mut self, r1: RegionVid, r2: RegionVid, reasons: &FxHashSet<EdgeInfo<'tcx>>) -> bool {
// eprintln!("Outlives: {:?} -> {:?} ({:?})", r1, r2, reasons);
let Some(n2) = self.region_to_node(r2) else {
// `r2` is static
Expand Down Expand Up @@ -201,17 +280,30 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
// self.sanity_check();
}
#[tracing::instrument(name = "Graph::remove", level = "trace", skip(self))]
pub fn remove(&mut self, r: RegionVid, maybe_already_removed: bool) {
// Set `remove_dangling_children` when removing regions which are not tracked by the regular borrowck,
// to remove in e.g. `let y: &'a i32 = &'b *x;` the region `'b` when removing `'a` (if `x: &'c i32`).
// NOTE: Maybe shouldn't be set, since it seems that the regular borrowck does not kill off `'b` this eagerly (if `x: &'c mut i32`).
pub fn remove(&mut self, r: RegionVid, maybe_already_removed: bool, remove_dangling_children: bool) {
if self.static_regions.remove(&r) {
return;
}
for n in self.nodes.iter_mut() {
if let Some(n) = n {
if n.regions.contains(&r) {
n.regions.remove(&r);
if n.regions.is_empty() {
let id = n.id;
self.remove_node_rejoin(id);
// if n.regions.contains(&r) {
// n.regions.remove(&r);
// if n.regions.is_empty() {
// let id = n.id;
// self.remove_node_rejoin(id);
// }
// return;
// }
if n.region == r {
let id = n.id;
let n = self.remove_node_rejoin(id);
if remove_dangling_children {
for (block, _) in n.blocks {
self.remove_if_dangling_and_sh_borrow(block);
}
}
return;
}
Expand All @@ -220,20 +312,29 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
assert!(maybe_already_removed, "Region {:?} not found in graph", r);
// self.sanity_check();
}
// Used when merging two graphs (and we know from one graph that two regions are equal)
#[tracing::instrument(name = "Graph::equate_regions", level = "trace", skip(self), ret)]
pub fn equate_regions(&mut self, ra: RegionVid, rb: RegionVid) -> bool {
let mut changed = self.outlives(ra, rb, ConstraintCategory::Internal);
changed = changed || self.outlives(rb, ra, ConstraintCategory::Internal);
// self.sanity_check();
changed
fn remove_if_dangling_and_sh_borrow(&mut self, n: NodeId) {
let node = self.get_node(n);
if node.blocked_by.is_empty() && self.cgx.sbs.location_map.values().any(|data| data.region == node.region) {
self.remove_node(n);
}
}

// // Used when merging two graphs (and we know from one graph that two regions are equal)
// #[tracing::instrument(name = "Graph::equate_regions", level = "trace", skip(self), ret)]
// pub fn equate_regions(&mut self, ra: RegionVid, rb: RegionVid) -> bool {
// let mut changed = self.outlives(ra, rb, ConstraintCategory::Internal);
// changed = changed || self.outlives(rb, ra, ConstraintCategory::Internal);
// // self.sanity_check();
// changed
// }
#[tracing::instrument(name = "Graph::equate_regions", level = "trace", skip(self), ret)]
pub fn edge_to_regions(&self, from: NodeId, to: NodeId) -> (RegionVid, RegionVid) {
let n1 = self.get_node(from);
let n2 = self.get_node(to);
let r1 = *n1.regions.iter().next().unwrap();
let r2 = *n2.regions.iter().next().unwrap();
// let r1 = *n1.regions.iter().next().unwrap();
// let r2 = *n2.regions.iter().next().unwrap();
let r1 = n1.region;
let r2 = n2.region;
(r1, r2)
}
#[tracing::instrument(name = "Graph::equate_regions", level = "trace", skip(self), ret)]
Expand Down Expand Up @@ -264,7 +365,8 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
// If there is a cycle we could have already removed and made static
if let Some(mut node) = self.remove_node(n) {
// println!("Making static node: {node:?}");
self.static_regions.extend(node.regions.drain());
// self.static_regions.extend(node.regions.drain());
self.static_regions.insert(node.region);
for (block_by, _) in node.blocked_by.drain() {
// assert!(node.blocked_by.is_empty());
self.set_borrows_from_static_node(block_by);
Expand Down Expand Up @@ -300,7 +402,8 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
let mut last_none = self.nodes.len();
for (i, n) in self.nodes.iter().enumerate() {
if let Some(n) = n {
if n.regions.contains(&r) {
// if n.regions.contains(&r) {
if n.region == r {
return Some(i);
}
} else {
Expand All @@ -316,21 +419,21 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
}
fn merge(&mut self, n1: NodeId, n2: NodeId) {
panic!("{:?} and {:?}", self.get_node(n1), self.get_node(n2)); // Should never need to merge!
assert_ne!(n1, n2);
let to_merge = self.remove_node(n1).unwrap();
for (block, edge) in to_merge.blocks {
if block != n2 {
self.blocks(n2, block, &edge);
}
}
for (block_by, edge) in to_merge.blocked_by {
if block_by != n2 {
self.blocks(block_by, n2, &edge);
}
}
let n2 = self.get_node_mut(n2);
// n2.contained_by.extend(to_merge.contained_by);
n2.regions.extend(to_merge.regions);
// assert_ne!(n1, n2);
// let to_merge = self.remove_node(n1).unwrap();
// for (block, edge) in to_merge.blocks {
// if block != n2 {
// self.blocks(n2, block, &edge);
// }
// }
// for (block_by, edge) in to_merge.blocked_by {
// if block_by != n2 {
// self.blocks(block_by, n2, &edge);
// }
// }
// let n2 = self.get_node_mut(n2);
// // n2.contained_by.extend(to_merge.contained_by);
// n2.regions.extend(to_merge.regions);
}
fn kill_node(&mut self, n: NodeId) {
let removed = self.remove_node(n).unwrap();
Expand Down Expand Up @@ -376,7 +479,7 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
fn get_node_mut(&mut self, n: NodeId) -> &mut Node<'tcx> {
self.nodes[n].as_mut().unwrap()
}
fn blocks(&mut self, n1: NodeId, n2: NodeId, reason: &FxHashSet<ConstraintCategory<'tcx>>) -> bool {
fn blocks(&mut self, n1: NodeId, n2: NodeId, reason: &FxHashSet<EdgeInfo<'tcx>>) -> bool {
assert_ne!(n1, n2);
let mut changed = false;
let reasons = self.get_node_mut(n1).blocks.entry(n2).or_default();
Expand All @@ -397,31 +500,39 @@ impl<'a, 'tcx> Graph<'a, 'tcx> {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Node<'tcx> {
pub id: NodeId,
pub regions: FxHashSet<RegionVid>,
pub blocks: FxHashMap<NodeId, FxHashSet<ConstraintCategory<'tcx>>>,
pub blocked_by: FxHashMap<NodeId, FxHashSet<ConstraintCategory<'tcx>>>,
// pub regions: FxHashSet<RegionVid>,
pub region: RegionVid,
pub blocks: FxHashMap<NodeId, FxHashSet<EdgeInfo<'tcx>>>,
pub blocked_by: FxHashMap<NodeId, FxHashSet<EdgeInfo<'tcx>>>,
pub borrows_from_static: bool,
// pub contained_by: Vec<Local>,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct EdgeInfo<'tcx> {
pub creation: BasicBlock,
pub reason: ConstraintCategory<'tcx>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Edge<'tcx> {
pub from: NodeId,
pub to: NodeId,
pub reasons: FxHashSet<ConstraintCategory<'tcx>>,
pub reasons: FxHashSet<EdgeInfo<'tcx>>,
}

impl<'tcx> Edge<'tcx> {
pub(crate) fn new(from: NodeId, to: NodeId, reasons: FxHashSet<ConstraintCategory<'tcx>>) -> Self {
pub(crate) fn new(from: NodeId, to: NodeId, reasons: FxHashSet<EdgeInfo<'tcx>>) -> Self {
Self { from, to, reasons }
}
}

impl<'tcx> Node<'tcx> {
pub fn new(id: NodeId, r: RegionVid) -> Self {
pub fn new(id: NodeId, region: RegionVid) -> Self {
Self {
id,
regions: [r].into_iter().collect(),
// regions: [r].into_iter().collect(),
region,
blocks: FxHashMap::default(),
blocked_by: FxHashMap::default(),
borrows_from_static: false,
Expand Down Expand Up @@ -501,13 +612,13 @@ impl<'tcx> Regions<'_, 'tcx> {
self.output_to_dot("log/coupling/error.dot");
panic!("{node:?}");
} else {
let r = *node.regions.iter().next().unwrap();
if universal_region.contains(&r) {
// let r = *node.regions.iter().next().unwrap();
if universal_region.contains(&node.region) {
continue;
}

let proof = outlives.iter().find(|c| {
universal_region.contains(&c.sub) && c.sup == r
universal_region.contains(&c.sub) && c.sup == node.region
// let r = c.sub.as_u32(); // The thing that lives shorter
// r == 0 || r == 1 // `0` means that it's static, `1` means that it's the function region
});
Expand Down Expand Up @@ -545,13 +656,14 @@ impl<'tcx> Regions<'_, 'tcx> {
pub fn sanity_check(&self) {
let mut all = self.graph.static_regions.clone();
for n in self.graph.all_nodes() {
for r in &n.regions {
// for r in &n.regions {
let r = &n.region;
let contained = all.insert(*r);
if !contained {
self.output_to_dot("log/coupling/error.dot");
panic!();
}
}
// }
}

for n1 in self.graph.all_nodes().map(|n| n.id) {
Expand Down
Loading

0 comments on commit 7610848

Please sign in to comment.