From 839483e28577b629e35acdaf88fa9d5d3623e20e Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Mon, 25 Nov 2024 14:58:04 -0800 Subject: [PATCH] Add spec map generation to all expression call sites --- source/rust_verify/src/commands.rs | 39 ++++++--- source/vir/src/mono.rs | 42 +++++----- source/vir/src/sst_to_air.rs | 86 +++++++++++--------- source/vir/src/sst_to_air_func.rs | 125 +++++++++++++++++++++-------- 4 files changed, 190 insertions(+), 102 deletions(-) diff --git a/source/rust_verify/src/commands.rs b/source/rust_verify/src/commands.rs index 47befb86a..037ab9ac2 100644 --- a/source/rust_verify/src/commands.rs +++ b/source/rust_verify/src/commands.rs @@ -309,18 +309,30 @@ impl<'a> OpGenerator<'a> { return Ok(vec![]); }; - let (commands, snap_map) = - vir::sst_to_air_func::func_sst_to_air(self.ctx, &function, func_check_sst)?; + let specs = self.specializations.get(fun).unwrap_or("Specialization does not exist"); + let results: Vec<_> = specs + .iter() + .map(|spec| { + let spec_map = spec.create_spec_map(&function.x.typ_params); + let (commands, snap_map) = vir::sst_to_air_func::func_sst_to_air( + self.ctx, + &function, + func_check_sst, + &spec_map, + )?; + Ok(Op::query( + QueryOp::Body(style), + commands, + snap_map, + &function, + Some(func_check_sst.clone()), + )) + }) + .collect::>()?; self.ctx.fun = None; - Ok(vec![Op::query( - QueryOp::Body(style), - commands, - snap_map, - &function, - Some(func_check_sst.clone()), - )]) + Ok(results) } fn handle_proof_body_expand( @@ -331,8 +343,13 @@ impl<'a> OpGenerator<'a> { ) -> Result { self.ctx.fun = mk_fun_ctx(&function, false /*recommend*/); - let (commands, snap_map) = - vir::sst_to_air_func::func_sst_to_air(self.ctx, &function, &expanded_function_sst)?; + let fun = &function.x.name; + let (commands, snap_map) = vir::sst_to_air_func::func_sst_to_air( + self.ctx, + &function, + &expanded_function_sst, + &Default::default(), + )?; let commands = focus_commands_with_context_on_assert_id(commands, assert_id); self.ctx.fun = None; diff --git a/source/vir/src/mono.rs b/source/vir/src/mono.rs index a2d518b8f..48534f550 100644 --- a/source/vir/src/mono.rs +++ b/source/vir/src/mono.rs @@ -23,8 +23,8 @@ use crate::ast::Idents; use crate::ast::IntRange; use crate::ast::Primitive; use crate::ast::{Dt, Fun, TypDecoration, TypDecorationArg}; -use crate::def::{path_to_string, Spanned}; use crate::def::POLY; +use crate::def::{path_to_string, Spanned}; use crate::poly; use crate::sst::{CallFun, Exp, ExpX, KrateSstX, Stm}; use crate::sst::{Par, ParX}; @@ -116,7 +116,7 @@ impl SpecTypX { } } } -fn typs_as_spec(typs: &Typs, spec_map: &SpecMap<'_>) -> Vec { +fn typs_as_spec(typs: &Typs, spec_map: &SpecMap) -> Vec { let mut spec_typs: Vec = Vec::new(); for typ in typs.iter() { let spec_typ = typ_as_spec(typ, spec_map); @@ -125,7 +125,7 @@ fn typs_as_spec(typs: &Typs, spec_map: &SpecMap<'_>) -> Vec { spec_typs } -pub(crate) fn typ_as_spec(typ: &Typ, spec_map: &SpecMap<'_>) -> SpecTyp { +pub(crate) fn typ_as_spec(typ: &Typ, spec_map: &SpecMap) -> SpecTyp { match &**typ { TypX::Bool => Arc::new(SpecTypX::Bool), TypX::Int(range) => Arc::new(SpecTypX::Int(*range)), @@ -158,12 +158,10 @@ pub(crate) fn typ_as_spec(typ: &Typ, spec_map: &SpecMap<'_>) -> SpecTyp { }; (*spec_typ).clone() } - TypX::Boxed(..) | TypX::SpecFn(..) | TypX::FnDef(..) => { -Arc::new(SpecTypX::Poly) - } + TypX::Boxed(..) | TypX::SpecFn(..) | TypX::FnDef(..) => Arc::new(SpecTypX::Poly), TypX::ConstInt(_) => Arc::new(SpecTypX::Poly), TypX::Projection { .. } => Arc::new(SpecTypX::Poly), - TypX::Poly => Arc::new(SpecTypX::Poly) + TypX::Poly => Arc::new(SpecTypX::Poly), } } @@ -181,7 +179,7 @@ impl Specialization { pub fn empty() -> Self { Self { typs: Arc::new(vec![]) } } - pub fn from_exp<'a>(exp: &'a ExpX, spec_map: &SpecMap<'_>) -> Option<(&'a Fun, Self)> { + pub fn from_exp<'a>(exp: &'a ExpX, spec_map: &SpecMap) -> Option<(&'a Fun, Self)> { let ExpX::Call(CallFun::Fun(fun, _) | CallFun::Recursive(fun), typs, _) = exp else { return None; }; @@ -263,10 +261,16 @@ impl Specialization { format!(" specialized to {:?}", &self.typs) } - pub fn is_empty(&self) -> bool { self.typs.is_empty() } + pub fn is_empty(&self) -> bool { + self.typs.is_empty() + } + + pub fn create_spec_map(&self, typ_params: &Idents) -> SpecMap { + assert!(self.is_empty() || self.typs.len() == typ_params.len()); + std::iter::zip(typ_params.iter().cloned(), self.typs.iter().cloned()).collect() + } } -impl Default for Specialization -{ +impl Default for Specialization { fn default() -> Self { return Self::empty(); } @@ -283,7 +287,7 @@ struct SpecializationVisitor<'a> { spec_map: &'a SpecMap, } impl<'a> SpecializationVisitor<'a> { - fn new(spec_map: &SpecMap) -> Self { + fn new(spec_map: &'a SpecMap) -> Self { Self { invocations: vec![], spec_map } } } @@ -303,10 +307,7 @@ pub(crate) fn collect_specializations_from_function( spec: &Specialization, function: &FunctionSst, ) -> Vec<(Fun, Specialization)> { - - // Build map from function type parametres to spec types. - assert!(spec.is_empty() || spec.typs.len() == function.x.typ_params.len()); - let spec_map: SpecMap = std::iter::zip(function.x.typ_params.iter().cloned(), spec.typs.iter().cloned()).collect(); + let spec_map = spec.create_spec_map(&function.x.typ_params); let mut visitor = SpecializationVisitor::new(&spec_map); visitor.visit_function(function).unwrap(); @@ -319,7 +320,8 @@ Collect all polymorphic function invocations in a module pub fn mono_krate_for_module(krate: &KrateSst) -> HashMap> { let KrateSstX { functions, .. } = &**krate; - let mut to_visit: VecDeque<(Specialization, &FunctionSst)> = functions.iter().map(|f| (Default::default(), f)).collect(); + let mut to_visit: VecDeque<(Specialization, &FunctionSst)> = + functions.iter().map(|f| (Default::default(), f)).collect(); let mut invocations: HashMap> = HashMap::new(); while let Some((caller_spec, caller_sst)) = to_visit.pop_front() { @@ -330,14 +332,14 @@ pub fn mono_krate_for_module(krate: &KrateSst) -> HashMap Vec { mk_id(str_apply(crate::def::TYPE_ID_CONST_INT, &vec![big_int_to_expr(c)])) } TypX::Air(_) => panic!("internal error: typ_to_ids of Air"), - TypX::Poly => mk_id(str_var(crate::def::TYPE_ID_POLY)), + TypX::Poly => mk_id(str_var(crate::def::TYPE_ID_POLY)), } } @@ -664,20 +664,25 @@ pub(crate) enum ExprMode { } #[derive(Debug, Clone)] -pub(crate) struct ExprCtxt { +pub(crate) struct ExprCtxt<'a> { pub mode: ExprMode, pub is_singular: bool, + pub spec_map: &'a mono::SpecMap, } -impl ExprCtxt { - pub(crate) fn new() -> Self { - ExprCtxt { mode: ExprMode::Body, is_singular: false } +impl<'a> ExprCtxt<'a> { + pub(crate) fn new(spec_map: &'a mono::SpecMap) -> Self { + ExprCtxt { mode: ExprMode::Body, is_singular: false, spec_map } } - pub(crate) fn new_mode(mode: ExprMode) -> Self { - ExprCtxt { mode, is_singular: false } + pub(crate) fn new_mode(mode: ExprMode, spec_map: &'a mono::SpecMap) -> Self { + ExprCtxt { mode, is_singular: false, spec_map } } - pub(crate) fn new_mode_singular(mode: ExprMode, is_singular: bool) -> Self { - ExprCtxt { mode, is_singular } + pub(crate) fn new_mode_singular( + mode: ExprMode, + is_singular: bool, + spec_map: &'a mono::SpecMap, + ) -> Self { + ExprCtxt { mode, is_singular, spec_map } } } @@ -771,9 +776,9 @@ pub(crate) fn exp_to_expr(ctx: &Ctx, exp: &Exp, expr_ctxt: &ExprCtxt) -> Result< ExpX::Call(f @ (CallFun::Fun(..) | CallFun::Recursive(_)), typs, args) => { let specialization = match ctx.global.poly_strategy { mono::PolyStrategy::Mono => { - let (_, spec) = mono::Specialization::from_exp(&exp.x) - .expect("Could not create specialization rom call site"); - spec + let (_, spec) = mono::Specialization::from_exp(&exp.x, expr_ctxt.spec_map) + .expect("Could not create specialization rom call site"); + spec } mono::PolyStrategy::Poly => mono::Specialization::empty(), }; @@ -1543,12 +1548,17 @@ fn assume_other_fields_unchanged_inner( } } -// fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, VirErr> { +// fn // let expr_ctxt = ExprCtxt { mode: ExprMode::Body, is_bit_vector: false }; // let result = match &stm.x { -fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, VirErr> { - let expr_ctxt = &ExprCtxt::new(); +fn stm_to_stmts( + ctx: &Ctx, + state: &mut State, + stm: &Stm, + spec_map: &mono::SpecMap, +) -> Result, VirErr> { + let expr_ctxt = &ExprCtxt::new(spec_map); let result = match &stm.x { StmX::Call { fun, resolved_method, mode, typ_args: typs, args, split, dest, assert_id } => { assert!(split.is_none()); @@ -1804,7 +1814,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi let mut stmts = if let Some(dest_id) = state.post_condition_info.dest.clone() { let ret_exp = ret_exp.as_ref().expect("if dest is provided, expr must be provided"); - stm_to_stmts(ctx, state, &assume_var(&stm.span, &dest_id, ret_exp))? + stm_to_stmts(ctx, state, &assume_var(&stm.span, &dest_id, ret_exp), spec_map)? } else { // If there is no `dest_id`, then the returned expression // gets ignored. This should happen for functions that @@ -1816,7 +1826,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi if ctx.checking_spec_preconditions() { for stm in state.post_condition_info.ens_spec_precondition_stms.clone().iter() { - let mut new_stmts = stm_to_stmts(ctx, state, stm)?; + let mut new_stmts = stm_to_stmts(ctx, state, stm, spec_map)?; stmts.append(&mut new_stmts); } } else { @@ -1865,7 +1875,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi } state.push_scope(); - let proof_stmts: Vec = stm_to_stmts(ctx, state, body)?; + let proof_stmts: Vec = stm_to_stmts(ctx, state, body, spec_map)?; state.pop_scope(); let mut air_body: Vec = Vec::new(); air_body.append(&mut proof_stmts.clone()); @@ -1936,7 +1946,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi } StmX::Assign { lhs: Dest { dest, is_init: true }, rhs } => { let x = loc_is_var(dest).expect("is_init assign dest must be a variable"); - stm_to_stmts(ctx, state, &assume_var(&stm.span, x, rhs))? + stm_to_stmts(ctx, state, &assume_var(&stm.span, x, rhs), spec_map)? } StmX::Assign { lhs: Dest { dest, is_init: false }, rhs } => { let mut stmts: Vec = Vec::new(); @@ -1964,6 +1974,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi ctx, state, &Spanned::new(stm.span.clone(), StmX::Assume(eq)), + spec_map, )?); stmts.extend(assume_other_fields_unchanged( ctx, @@ -1980,7 +1991,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi stmts } StmX::DeadEnd(s) => { - vec![Arc::new(StmtX::DeadEnd(one_stmt(stm_to_stmts(ctx, state, s)?)))] + vec![Arc::new(StmtX::DeadEnd(one_stmt(stm_to_stmts(ctx, state, s, spec_map)?)))] } StmX::BreakOrContinue { label, is_break } => { let loop_info = if label.is_some() { @@ -2073,7 +2084,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi let mut unwind = UnwindAir::MayUnwind; std::mem::swap(&mut state.unwind, &mut unwind); - let mut body_stmts = stm_to_stmts(ctx, state, body)?; + let mut body_stmts = stm_to_stmts(ctx, state, body, spec_map)?; std::mem::swap(&mut state.mask, &mut mask); std::mem::swap(&mut state.unwind, &mut unwind); @@ -2086,10 +2097,10 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi let neg_cond = Arc::new(ExprX::Unary(air::ast::UnaryOp::Not, pos_cond.clone())); let pos_assume = Arc::new(StmtX::Assume(pos_cond)); let neg_assume = Arc::new(StmtX::Assume(neg_cond)); - let mut lhss = stm_to_stmts(ctx, state, lhs)?; + let mut lhss = stm_to_stmts(ctx, state, lhs, spec_map)?; let mut rhss = match rhs { None => vec![], - Some(rhs) => stm_to_stmts(ctx, state, rhs)?, + Some(rhs) => stm_to_stmts(ctx, state, rhs, spec_map)?, }; lhss.insert(0, pos_assume); rhss.insert(0, neg_assume); @@ -2276,10 +2287,10 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi air_body.push(Arc::new(StmtX::Assume(inv.clone()))); } for dec in decrease_init.iter() { - air_body.append(&mut stm_to_stmts(ctx, state, dec)?); + air_body.append(&mut stm_to_stmts(ctx, state, dec, spec_map)?); } - let cond_stmts = cond_stm.map(|s| stm_to_stmts(ctx, state, s)).transpose()?; + let cond_stmts = cond_stm.map(|s| stm_to_stmts(ctx, state, s, spec_map)).transpose()?; if let Some(cond_stmts) = &cond_stmts { assert!(loop_isolation); air_body.append(&mut cond_stmts.clone()); @@ -2301,7 +2312,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi decrease: decrease.clone(), }; state.loop_infos.push(loop_info); - air_body.append(&mut stm_to_stmts(ctx, state, body)?); + air_body.append(&mut stm_to_stmts(ctx, state, body, spec_map)?); state.loop_infos.pop(); if !ctx.checking_spec_preconditions() { @@ -2421,7 +2432,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi // Build the names_expr. Note: In the SST, this should have been assigned // to an expression whose value is constant for the entire block. - let namespace_expr = exp_to_expr(ctx, namespace_exp, &ExprCtxt::new())?; + let namespace_expr = exp_to_expr(ctx, namespace_exp, &ExprCtxt::new(spec_map))?; // Assert that the namespace of the inv we are opening is in the mask set if !ctx.checking_spec_preconditions() { @@ -2438,7 +2449,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi UnwindAir::NoUnwind(ReasonForNoUnwind::OpenInvariant(stm.span.clone())); swap(&mut state.mask, &mut inner_mask); swap(&mut state.unwind, &mut inner_unwind); - stmts.append(&mut stm_to_stmts(ctx, state, body_stm)?); + stmts.append(&mut stm_to_stmts(ctx, state, body_stm, spec_map)?); swap(&mut state.mask, &mut inner_mask); swap(&mut state.unwind, &mut inner_unwind); @@ -2490,7 +2501,7 @@ fn stm_to_stmts(ctx: &Ctx, state: &mut State, stm: &Stm) -> Result, Vi } let mut stmts: Vec = Vec::new(); for s in stms.iter() { - stmts.extend(stm_to_stmts(ctx, state, s)?); + stmts.extend(stm_to_stmts(ctx, state, s, spec_map)?); } if ctx.debug { state.pop_scope(); @@ -2610,6 +2621,7 @@ pub(crate) fn body_stm_to_air( is_integer_ring: bool, is_bit_vector_mode: bool, is_nonlinear: bool, + spec_map: &mono::SpecMap, ) -> Result<(Vec, Vec<(Span, SnapPos)>), VirErr> { let FuncCheckSst { reqs, post_condition, mask_set, body: stm, local_decls, statics, unwind } = func_check_sst; @@ -2677,14 +2689,14 @@ pub(crate) fn body_stm_to_air( let mut ens_exprs: Vec<(Span, Expr)> = Vec::new(); for ens in post_condition.ens_exps.iter() { - let expr_ctxt = &ExprCtxt::new_mode(ExprMode::Body); + let expr_ctxt = &ExprCtxt::new_mode(ExprMode::Body, spec_map); let e = exp_to_expr(ctx, &ens, expr_ctxt)?; ens_exprs.push((ens.span.clone(), e)); } let f_mask_singletons = |v: &Vec>| -> Result>, VirErr> { - let expr_ctxt = &ExprCtxt::new_mode(ExprMode::Body); + let expr_ctxt = &ExprCtxt::new_mode(ExprMode::Body, spec_map); let mut v2: Vec> = Vec::new(); for m in v.iter() { let expr = exp_to_expr(ctx, &m.expr, expr_ctxt)?; @@ -2702,7 +2714,7 @@ pub(crate) fn body_stm_to_air( UnwindSst::MayUnwind => UnwindAir::MayUnwind, UnwindSst::NoUnwind => UnwindAir::NoUnwind(ReasonForNoUnwind::Function), UnwindSst::NoUnwindWhen(exp) => { - let expr_ctxt = &ExprCtxt::new_mode(ExprMode::Body); + let expr_ctxt = &ExprCtxt::new_mode(ExprMode::Body, spec_map); let e = exp_to_expr(ctx, &exp, expr_ctxt)?; UnwindAir::NoUnwindWhen(e) } @@ -2751,7 +2763,7 @@ pub(crate) fn body_stm_to_air( stm, ); - let mut stmts = stm_to_stmts(ctx, &mut state, &stm)?; + let mut stmts = stm_to_stmts(ctx, &mut state, &stm, spec_map)?; if has_mut_params { stmts.insert(0, Arc::new(StmtX::Snapshot(snapshot_ident(SNAPSHOT_PRE)))); @@ -2779,7 +2791,7 @@ pub(crate) fn body_stm_to_air( } for req in reqs.iter() { - let expr_ctxt = &ExprCtxt::new_mode(ExprMode::BodyPre); + let expr_ctxt = &ExprCtxt::new_mode(ExprMode::BodyPre, spec_map); let e = exp_to_expr(ctx, &req, expr_ctxt)?; local.push(mk_unnamed_axiom(e)); } @@ -2803,7 +2815,8 @@ pub(crate) fn body_stm_to_air( "Unspported expression in integer_ring".to_string(), "at the require clause".to_string(), ); - let air_expr = exp_to_expr(ctx, req, &ExprCtxt::new_mode(ExprMode::BodyPre))?; + let air_expr = + exp_to_expr(ctx, req, &ExprCtxt::new_mode(ExprMode::BodyPre, spec_map))?; let assert_stm = Arc::new(StmtX::Assert(None, error, None, air_expr)); singular_req_stmts.push(assert_stm); } @@ -2815,7 +2828,8 @@ pub(crate) fn body_stm_to_air( "Unspported expression in integer_ring".to_string(), "at the ensure clause".to_string(), ); - let air_expr = exp_to_expr(ctx, ens, &ExprCtxt::new_mode(ExprMode::BodyPre))?; + let air_expr = + exp_to_expr(ctx, ens, &ExprCtxt::new_mode(ExprMode::BodyPre, spec_map))?; let assert_stm = Arc::new(StmtX::Assert(None, error, None, air_expr)); singular_ens_stmts.push(assert_stm); } diff --git a/source/vir/src/sst_to_air_func.rs b/source/vir/src/sst_to_air_func.rs index ad6ce2ed2..219958aea 100644 --- a/source/vir/src/sst_to_air_func.rs +++ b/source/vir/src/sst_to_air_func.rs @@ -203,7 +203,7 @@ fn func_body_to_air( let mut new_pars: Pars = Arc::new(Vec::new()); let new_pars_mut = Arc::make_mut(&mut new_pars); - for i in pars.iter(){ + for i in pars.iter() { new_pars_mut.push(specialization.transform_par(&function.x.typ_params, i)); } @@ -235,19 +235,23 @@ fn func_body_to_air( if function.x.has.has_decrease { for param in pars.iter() { let arg = ident_var(¶m.x.name.lower()); - if let Some(pre) = typ_invariant(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), &arg) { + if let Some(pre) = typ_invariant( + ctx, + &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), + &arg, + ) { def_reqs.push(pre.clone()); } } } + let spec_map = specialization.create_spec_map(&function.x.typ_params); if let Some(exp) = decrease_when { - let expr = exp_to_expr(ctx, &exp, &ExprCtxt::new_mode(ExprMode::Spec))?; + let expr = exp_to_expr(ctx, &exp, &ExprCtxt::new_mode(ExprMode::Spec, &spec_map))?; // conditions on value arguments: def_reqs.push(expr); } - if let Some(termination_check) = termination_check { let (termination_commands, _snap_map) = crate::sst_to_air::body_stm_to_air( ctx, @@ -260,6 +264,7 @@ fn func_body_to_air( false, false, false, + &spec_map, )?; check_commands.extend(termination_commands.iter().cloned()); } @@ -290,9 +295,6 @@ fn func_body_to_air( let typ_args = vec_map(&function.x.typ_params, |x| Arc::new(TypX::TypParam(x.clone()))); (function.x.name.clone(), function.x.name.clone(), Arc::new(typ_args)) }; - //CONSTRUCT MAP HERE - assert!(specialization.is_empty() || specialization.typs.len() == function.x.typ_params.len()); - let spec_map: HashMap<&Ident, &SpecTyp> = std::iter::zip(function.x.typ_params.iter(), specialization.typs.iter()).collect(); // non-recursive: // (axiom (=> (fuel_bool fuel%f) (forall (...) (= (f ...) body)))) // recursive: @@ -300,7 +302,7 @@ fn func_body_to_air( // (axiom (forall (... fuel) (= (rec%f ... fuel) (rec%f ... zero) ))) // (axiom (forall (... fuel) (= (rec%f ... (succ fuel)) body[rec%f ... fuel] ))) // (axiom (=> (fuel_bool fuel%f) (forall (...) (= (f ...) (rec%f ... (succ fuel_nat%f)))))) - let body_expr = exp_to_expr(&ctx, &new_body_exp, &ExprCtxt::new(), &spec_map)?; + let body_expr = exp_to_expr(&ctx, &new_body_exp, &ExprCtxt::new(&spec_map))?; let def_body = if !function.x.has.is_recursive { body_expr } else { @@ -335,8 +337,10 @@ fn func_body_to_air( let eq_body = mk_eq(&rec_f_succ, &body_expr); let name_zero = format!("{}_fuel_to_zero", &fun_to_air_ident(&name)); let name_body = format!("{}_fuel_to_body", &fun_to_air_ident(&name)); - let bind_zero = func_bind(ctx, name_zero, &function.x.typ_params, &new_pars, &rec_f_fuel, true); - let bind_body = func_bind(ctx, name_body, &function.x.typ_params, &new_pars, &rec_f_succ, true); + let bind_zero = + func_bind(ctx, name_zero, &function.x.typ_params, &new_pars, &rec_f_fuel, true); + let bind_body = + func_bind(ctx, name_body, &function.x.typ_params, &new_pars, &rec_f_succ, true); let implies_body = mk_implies(&mk_and(&def_reqs), &eq_body); let forall_zero = mk_bind_expr(&bind_zero, &eq_zero); let forall_body = mk_bind_expr(&bind_body, &implies_body); @@ -379,6 +383,7 @@ fn req_ens_to_air( typ: air::ast::Typ, inherit_from: Option<(Ident, Typs)>, filter: Option, + spec_map: &mono::SpecMap, ) -> Result { if specs.len() + typing_invs.len() > 0 { let mut all_typs = (**typs).clone(); @@ -406,9 +411,9 @@ fn req_ens_to_air( } for exp in specs.iter() { let expr_ctxt = if is_singular { - ExprCtxt::new_mode_singular(ExprMode::Spec, true) + ExprCtxt::new_mode_singular(ExprMode::Spec, true, spec_map) } else { - ExprCtxt::new_mode(ExprMode::Spec) + ExprCtxt::new_mode(ExprMode::Spec, spec_map) }; let expr = exp_to_expr(ctx, exp, &expr_ctxt)?; let loc_expr = match msg { @@ -447,15 +452,22 @@ pub fn func_name_to_air( if function.x.has.is_recursive { let rec_f = suffix_global_id(&fun_to_air_ident(&prefix_recursive_fun(&function.x.name))); - let mut rec_typs = - vec_map(&*function.x.pars, |param| typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ))); + let mut rec_typs = vec_map(&*function.x.pars, |param| { + typ_to_air( + ctx, + &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), + ) + }); for _ in function.x.typ_params.iter() { for x in crate::def::types().iter().rev() { rec_typs.insert(0, str_typ(x)); } } rec_typs.push(str_typ(FUEL_TYPE)); - let typ = typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ)); + let typ = typ_to_air( + ctx, + &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ), + ); let ident = specialization.transform_ident(rec_f); let rec_decl = Arc::new(DeclX::Fun(ident, Arc::new(rec_typs), typ)); commands.push(Arc::new(CommandX::Global(rec_decl))); @@ -471,18 +483,21 @@ pub fn func_name_to_air( return Ok(Arc::new(commands)); } - let mut all_typs = vec_map(&function.x.pars, |param| typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ))); + let mut all_typs = vec_map(&function.x.pars, |param| { + typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ)) + }); for _ in function.x.typ_params.iter() { for x in crate::def::types().iter().rev() { all_typs.insert(0, str_typ(x)); } } - let all_typs = Arc::new(all_typs); - - let typ = typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ)); + let typ = typ_to_air( + ctx, + &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ), + ); let mut names = vec![function.x.name.clone()]; if let FunctionKind::TraitMethodDecl { .. } = &function.x.kind { names.push(crate::def::trait_default_name(&function.x.name)); @@ -503,20 +518,28 @@ pub fn func_name_to_air( // represent as 0-argument function) commands.push(Arc::new(CommandX::Global(Arc::new(DeclX::Const( specialization.transform_ident(static_name(&function.x.name)), - typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ)), + typ_to_air( + ctx, + &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ), + ), ))))); } Ok(Arc::new(commands)) } -pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: &mono::Specialization) -> Result { +pub fn func_decl_to_air( + ctx: &mut Ctx, + function: &FunctionSst, + specialization: &mono::Specialization, +) -> Result { let func_decl_sst = &function.x.decl; let (is_trait_method_impl, inherit_fn_ens) = match &function.x.kind { FunctionKind::TraitMethodImpl { method, trait_typ_args, .. } => { if ctx.funcs_with_ensure_predicate[method] { // NOTE: Maybe we should use a different specialization - let ens = prefix_ensures(&specialization.transform_ident(fun_to_air_ident(&method))); + let ens = + prefix_ensures(&specialization.transform_ident(fun_to_air_ident(&method))); let mut typ_args = (**trait_typ_args).clone(); let num_trait_and_method_typ_params = ctx.func_map[method].x.typ_params.len(); let num_method_typ_params = num_trait_and_method_typ_params - trait_typ_args.len(); @@ -537,11 +560,20 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & _ => (false, None), }; - let req_typs: Arc> = - Arc::new(function.x.pars.iter().map(|param| typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ))).collect()); + let req_typs: Arc> = Arc::new( + function + .x + .pars + .iter() + .map(|param| { + typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ)) + }) + .collect(), + ); let mut decl_commands: Vec = Vec::new(); let func_name = specialization.transform_ident(fun_to_air_ident(&function.x.name)); + let spec_map = specialization.create_spec_map(&function.x.typ_params); // Requires if function.x.has.has_requires && !function.x.attrs.broadcast_forall_only { assert!(!is_trait_method_impl); @@ -567,6 +599,7 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & bool_typ(), None, Some(func_name.clone()), + &spec_map, )?; } @@ -587,6 +620,7 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & int_typ(), None, None, + &spec_map, ); } } @@ -607,6 +641,7 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & bool_typ(), None, None, + &spec_map, ); } @@ -616,7 +651,10 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & .pars .iter() .flat_map(|param| { - let air_typ = typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ)); + let air_typ = typ_to_air( + ctx, + &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), + ); if !param.x.is_mut { vec![air_typ] } else { @@ -627,8 +665,10 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & let mut ens_typing_invs: Vec = Vec::new(); if matches!(function.x.mode, Mode::Exec | Mode::Proof) { if function.x.has.has_return_name { - let ParX { name, typ, .. } = &specialization.transform_par(&function.x.typ_params, &function.x.ret).x; - ens_typs.push(typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, &typ))); + let ParX { name, typ, .. } = + &specialization.transform_par(&function.x.typ_params, &function.x.ret).x; + ens_typs + .push(typ_to_air(ctx, &specialization.transform_typ(&function.x.typ_params, &typ))); if let Some(expr) = typ_invariant(ctx, &typ, &ident_var(&name.lower())) { ens_typing_invs.push(expr); } @@ -637,8 +677,11 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & for param in func_decl_sst.post_pars.iter().filter(|p| matches!(p.x.purpose, ParPurpose::MutPost)) { - if let Some(expr) = typ_invariant(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), &ident_var(¶m.x.name.lower())) - { + if let Some(expr) = typ_invariant( + ctx, + &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), + &ident_var(¶m.x.name.lower()), + ) { ens_typing_invs.push(expr); } } @@ -669,12 +712,13 @@ pub fn func_decl_to_air(ctx: &mut Ctx, function: &FunctionSst, specialization: & bool_typ(), inherit_fn_ens, None, + &spec_map, )? }; ctx.funcs_with_ensure_predicate.insert(function.x.name.clone(), has_ens_pred); for exp in func_decl_sst.fndef_axioms.iter() { - let expr = exp_to_expr(ctx, exp, &ExprCtxt::new_mode(ExprMode::Spec))?; + let expr = exp_to_expr(ctx, exp, &ExprCtxt::new_mode(ExprMode::Spec, &spec_map))?; let axiom = mk_unnamed_axiom(expr); decl_commands.push(Arc::new(CommandX::Global(axiom))); } @@ -706,9 +750,10 @@ pub fn func_axioms_to_air( let mut new_pars: Pars = Arc::new(Vec::new()); let new_pars_mut = Arc::make_mut(&mut new_pars); - for i in function.x.pars.iter(){ + for i in function.x.pars.iter() { new_pars_mut.push(specialization.transform_par(&function.x.typ_params, i)); } + let spec_map = specialization.create_spec_map(&function.x.typ_params); match function.x.mode { Mode::Spec => { // Body @@ -778,12 +823,20 @@ pub fn func_axioms_to_air( for param in function.x.pars.iter() { let arg = ident_var(¶m.x.name.lower()); f_args.push(arg.clone()); - if let Some(pre) = typ_invariant(ctx, &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), &arg) { + if let Some(pre) = typ_invariant( + ctx, + &specialization.transform_typ(&function.x.typ_params, ¶m.x.typ), + &arg, + ) { f_pre.push(pre.clone()); } } let f_app = ident_apply(&name, &Arc::new(f_args)); - if let Some(post) = typ_invariant(ctx, &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ), &f_app) { + if let Some(post) = typ_invariant( + ctx, + &specialization.transform_typ(&function.x.typ_params, &function.x.ret.x.typ), + &f_app, + ) { // (axiom (forall (...) (=> pre post))) let name = format!("{}_pre_post", name); let e_forall = mk_bind_expr( @@ -820,9 +873,9 @@ pub fn func_axioms_to_air( let forall: Arc> = SpannedTyped::new(&span, &Arc::new(TypX::Bool), forallx); let expr_ctxt = if is_singular { - ExprCtxt::new_mode_singular(ExprMode::Spec, true) + ExprCtxt::new_mode_singular(ExprMode::Spec, true, &spec_map) } else { - ExprCtxt::new_mode(ExprMode::Spec) + ExprCtxt::new_mode(ExprMode::Spec, &spec_map) }; let expr = exp_to_expr(ctx, &forall, &expr_ctxt)?; @@ -854,6 +907,7 @@ pub fn func_sst_to_air( ctx: &Ctx, function: &FunctionSst, func_check_sst: &FuncCheckSst, + spec_map: &mono::SpecMap, ) -> Result<(Arc>, Vec<(Span, SnapPos)>), VirErr> { let (commands, snap_map) = crate::sst_to_air::body_stm_to_air( ctx, @@ -866,6 +920,7 @@ pub fn func_sst_to_air( function.x.attrs.integer_ring, function.x.attrs.bit_vector, function.x.attrs.nonlinear, + &spec_map, )?; Ok((Arc::new(commands), snap_map))