Skip to content

Commit

Permalink
Add spec map generation to all expression call sites
Browse files Browse the repository at this point in the history
  • Loading branch information
lenianiva committed Nov 25, 2024
1 parent 342d5d3 commit 839483e
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 102 deletions.
39 changes: 28 additions & 11 deletions source/rust_verify/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,18 +309,30 @@ impl<'a> OpGenerator<'a> {
return Ok(vec![]);
};

let (commands, snap_map) =
vir::sst_to_air_func::func_sst_to_air(self.ctx, &function, func_check_sst)?;
let specs = self.specializations.get(fun).unwrap_or("Specialization does not exist");

let results: Vec<_> = specs
.iter()
.map(|spec| {
let spec_map = spec.create_spec_map(&function.x.typ_params);
let (commands, snap_map) = vir::sst_to_air_func::func_sst_to_air(
self.ctx,
&function,
func_check_sst,
&spec_map,
)?;
Ok(Op::query(
QueryOp::Body(style),
commands,
snap_map,
&function,
Some(func_check_sst.clone()),
))
})
.collect::<Result<_, VirErr>>()?;
self.ctx.fun = None;

Ok(vec![Op::query(
QueryOp::Body(style),
commands,
snap_map,
&function,
Some(func_check_sst.clone()),
)])
Ok(results)
}

fn handle_proof_body_expand(
Expand All @@ -331,8 +343,13 @@ impl<'a> OpGenerator<'a> {
) -> Result<Op, VirErr> {
self.ctx.fun = mk_fun_ctx(&function, false /*recommend*/);

let (commands, snap_map) =
vir::sst_to_air_func::func_sst_to_air(self.ctx, &function, &expanded_function_sst)?;
let fun = &function.x.name;
let (commands, snap_map) = vir::sst_to_air_func::func_sst_to_air(
self.ctx,
&function,
&expanded_function_sst,
&Default::default(),
)?;
let commands = focus_commands_with_context_on_assert_id(commands, assert_id);

self.ctx.fun = None;
Expand Down
42 changes: 22 additions & 20 deletions source/vir/src/mono.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::ast::Idents;
use crate::ast::IntRange;
use crate::ast::Primitive;
use crate::ast::{Dt, Fun, TypDecoration, TypDecorationArg};
use crate::def::{path_to_string, Spanned};
use crate::def::POLY;
use crate::def::{path_to_string, Spanned};
use crate::poly;
use crate::sst::{CallFun, Exp, ExpX, KrateSstX, Stm};
use crate::sst::{Par, ParX};
Expand Down Expand Up @@ -116,7 +116,7 @@ impl SpecTypX {
}
}
}
fn typs_as_spec(typs: &Typs, spec_map: &SpecMap<'_>) -> Vec<SpecTyp> {
fn typs_as_spec(typs: &Typs, spec_map: &SpecMap) -> Vec<SpecTyp> {
let mut spec_typs: Vec<SpecTyp> = Vec::new();
for typ in typs.iter() {
let spec_typ = typ_as_spec(typ, spec_map);
Expand All @@ -125,7 +125,7 @@ fn typs_as_spec(typs: &Typs, spec_map: &SpecMap<'_>) -> Vec<SpecTyp> {
spec_typs
}

pub(crate) fn typ_as_spec(typ: &Typ, spec_map: &SpecMap<'_>) -> SpecTyp {
pub(crate) fn typ_as_spec(typ: &Typ, spec_map: &SpecMap) -> SpecTyp {
match &**typ {
TypX::Bool => Arc::new(SpecTypX::Bool),
TypX::Int(range) => Arc::new(SpecTypX::Int(*range)),
Expand Down Expand Up @@ -158,12 +158,10 @@ pub(crate) fn typ_as_spec(typ: &Typ, spec_map: &SpecMap<'_>) -> SpecTyp {
};
(*spec_typ).clone()
}
TypX::Boxed(..) | TypX::SpecFn(..) | TypX::FnDef(..) => {
Arc::new(SpecTypX::Poly)
}
TypX::Boxed(..) | TypX::SpecFn(..) | TypX::FnDef(..) => Arc::new(SpecTypX::Poly),
TypX::ConstInt(_) => Arc::new(SpecTypX::Poly),
TypX::Projection { .. } => Arc::new(SpecTypX::Poly),
TypX::Poly => Arc::new(SpecTypX::Poly)
TypX::Poly => Arc::new(SpecTypX::Poly),
}
}

Expand All @@ -181,7 +179,7 @@ impl Specialization {
pub fn empty() -> Self {
Self { typs: Arc::new(vec![]) }
}
pub fn from_exp<'a>(exp: &'a ExpX, spec_map: &SpecMap<'_>) -> Option<(&'a Fun, Self)> {
pub fn from_exp<'a>(exp: &'a ExpX, spec_map: &SpecMap) -> Option<(&'a Fun, Self)> {
let ExpX::Call(CallFun::Fun(fun, _) | CallFun::Recursive(fun), typs, _) = exp else {
return None;
};
Expand Down Expand Up @@ -263,10 +261,16 @@ impl Specialization {
format!(" specialized to {:?}", &self.typs)
}

pub fn is_empty(&self) -> bool { self.typs.is_empty() }
pub fn is_empty(&self) -> bool {
self.typs.is_empty()
}

pub fn create_spec_map(&self, typ_params: &Idents) -> SpecMap {
assert!(self.is_empty() || self.typs.len() == typ_params.len());
std::iter::zip(typ_params.iter().cloned(), self.typs.iter().cloned()).collect()
}
}
impl Default for Specialization
{
impl Default for Specialization {
fn default() -> Self {
return Self::empty();
}
Expand All @@ -283,7 +287,7 @@ struct SpecializationVisitor<'a> {
spec_map: &'a SpecMap,
}
impl<'a> SpecializationVisitor<'a> {
fn new(spec_map: &SpecMap) -> Self {
fn new(spec_map: &'a SpecMap) -> Self {
Self { invocations: vec![], spec_map }
}
}
Expand All @@ -303,10 +307,7 @@ pub(crate) fn collect_specializations_from_function(
spec: &Specialization,
function: &FunctionSst,
) -> Vec<(Fun, Specialization)> {

// Build map from function type parametres to spec types.
assert!(spec.is_empty() || spec.typs.len() == function.x.typ_params.len());
let spec_map: SpecMap = std::iter::zip(function.x.typ_params.iter().cloned(), spec.typs.iter().cloned()).collect();
let spec_map = spec.create_spec_map(&function.x.typ_params);

let mut visitor = SpecializationVisitor::new(&spec_map);
visitor.visit_function(function).unwrap();
Expand All @@ -319,7 +320,8 @@ Collect all polymorphic function invocations in a module
pub fn mono_krate_for_module(krate: &KrateSst) -> HashMap<Fun, HashSet<Specialization>> {
let KrateSstX { functions, .. } = &**krate;

let mut to_visit: VecDeque<(Specialization, &FunctionSst)> = functions.iter().map(|f| (Default::default(), f)).collect();
let mut to_visit: VecDeque<(Specialization, &FunctionSst)> =
functions.iter().map(|f| (Default::default(), f)).collect();
let mut invocations: HashMap<Fun, HashSet<Specialization>> = HashMap::new();

while let Some((caller_spec, caller_sst)) = to_visit.pop_front() {
Expand All @@ -330,14 +332,14 @@ pub fn mono_krate_for_module(krate: &KrateSst) -> HashMap<Fun, HashSet<Specializ
continue;
}
}
invocations.entry(callee).or_insert_with(HashSet::new).insert(callee_spec.clone());

// Push this call site back into queue
let callee_sst = functions
.iter()
.find(|f| f.x.name == callee)
.unwrap_or_else(|| panic!("Function name not found: {callee}"));
to_visit.push_back((callee_spec, callee_sst))
to_visit.push_back((callee_spec.clone(), callee_sst));

invocations.entry(callee).or_insert_with(HashSet::new).insert(callee_spec);
}
}
invocations
Expand Down
Loading

0 comments on commit 839483e

Please sign in to comment.