From 202ca0e7fa42c92c68a91b469995f36e885958ec Mon Sep 17 00:00:00 2001 From: Zack Grannan Date: Thu, 7 Dec 2023 06:49:00 -0800 Subject: [PATCH] Raise an error when postconditions of pure functions contain old() expressions (#1474) * Raise an error if old() appears in postcondition of pure functions * rustfmt, commit more files * Add a test * Clippy * Remove unnecessary debug * Fix test * More tests --- .../incorrect/old_in_pure_postcondition.rs | 18 +++++ .../old_in_pure_postcondition_extern.rs | 22 ++++++ prusti-viper/src/encoder/interface.rs | 46 ++++++++++++ .../mir/pure/pure_functions/encoder_poly.rs | 75 +++++++++++-------- .../mir/pure/pure_functions/interface.rs | 35 ++++++--- prusti-viper/src/encoder/mod.rs | 1 + .../src/encoder/stub_function_encoder.rs | 67 ++++++++++------- vir/defs/polymorphic/ast/expr.rs | 14 ++++ 8 files changed, 212 insertions(+), 66 deletions(-) create mode 100644 prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition.rs create mode 100644 prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition_extern.rs create mode 100644 prusti-viper/src/encoder/interface.rs diff --git a/prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition.rs b/prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition.rs new file mode 100644 index 00000000000..a1790f8eb60 --- /dev/null +++ b/prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition.rs @@ -0,0 +1,18 @@ +use prusti_contracts::*; + +struct MyWrapper(u32); + +impl MyWrapper { + #[pure] + #[ensures(old(self.0) == self.0)] + fn unwrap(&self) -> u32 { //~ ERROR old expressions should not appear in the postconditions of pure functions + self.0 + } +} + +fn test(x: &MyWrapper) -> u32 { + // Following error is due to stub encoding of invalid spec for function `unwrap()` + x.unwrap() //~ ERROR precondition of pure function call might not hold +} + +fn main() { } diff --git a/prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition_extern.rs b/prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition_extern.rs new file mode 100644 index 00000000000..01faa27f922 --- /dev/null +++ b/prusti-tests/tests/verify/fail/incorrect/old_in_pure_postcondition_extern.rs @@ -0,0 +1,22 @@ +use prusti_contracts::*; + +#[extern_spec] +impl std::option::Option { + #[pure] // <=== Error triggered by this + #[requires(self.is_some())] + #[ensures(old(self) === Some(result))] + pub fn unwrap(self) -> T; //~ ERROR old expressions should not appear in the postconditions of pure functions + + #[pure] + #[ensures(result == matches!(self, Some(_)))] + pub const fn is_some(&self) -> bool; +} + +#[pure] +#[requires(x.is_some())] +fn test(x: Option) -> i32 { + // Following error is due to stub encoding of invalid external spec for function `unwrap()` + x.unwrap() //~ ERROR precondition of pure function call might not hold +} + +fn main() { } diff --git a/prusti-viper/src/encoder/interface.rs b/prusti-viper/src/encoder/interface.rs new file mode 100644 index 00000000000..e5b8a64e0de --- /dev/null +++ b/prusti-viper/src/encoder/interface.rs @@ -0,0 +1,46 @@ +use crate::encoder::{ + errors::{SpannedEncodingResult, WithSpan}, + snapshot::interface::SnapshotEncoderInterface, + Encoder, +}; + +use prusti_rustc_interface::{ + middle::{mir, ty, ty::Binder}, + span::Span, +}; + +use vir_crate::polymorphic as vir_poly; + +pub(crate) trait PureFunctionFormalArgsEncoderInterface<'p, 'v: 'p, 'tcx: 'v> { + fn encoder(&self) -> &'p Encoder<'v, 'tcx>; + + fn check_type( + &self, + var_span: Span, + ty: Binder<'tcx, ty::Ty<'tcx>>, + ) -> SpannedEncodingResult<()>; + + fn get_span(&self, local: mir::Local) -> Span; + + fn encode_formal_args( + &self, + sig: ty::PolyFnSig<'tcx>, + ) -> SpannedEncodingResult> { + let mut formal_args = vec![]; + for local_idx in 0..sig.skip_binder().inputs().len() { + let local_ty = sig.input(local_idx); + let local = mir::Local::from_usize(local_idx + 1); + let var_name = format!("{local:?}"); + let var_span = self.get_span(local); + + self.check_type(var_span, local_ty)?; + + let var_type = self + .encoder() + .encode_snapshot_type(local_ty.skip_binder()) + .with_span(var_span)?; + formal_args.push(vir_poly::LocalVar::new(var_name, var_type)) + } + Ok(formal_args) + } +} diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_poly.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_poly.rs index 7de3b1fdc2e..99e0a6012fd 100644 --- a/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_poly.rs +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/encoder_poly.rs @@ -7,6 +7,7 @@ use crate::encoder::{ errors::{ErrorCtxt, SpannedEncodingError, SpannedEncodingResult, WithSpan}, high::{generics::HighGenericsEncoderInterface, types::HighTypeEncoderInterface}, + interface::PureFunctionFormalArgsEncoderInterface, mir::{ contracts::{ContractsEncoderInterface, ProcedureContract}, pure::{ @@ -50,7 +51,7 @@ pub(super) struct PureFunctionEncoder<'p, 'v: 'p, 'tcx: 'v> { /// Span of the function declaration. span: Span, /// Signature of the function to be encoded. - sig: ty::PolyFnSig<'tcx>, + pub(crate) sig: ty::PolyFnSig<'tcx>, /// Spans of MIR locals, when encoding a local pure function. local_spans: Option>, } @@ -137,6 +138,38 @@ fn encode_mir<'p, 'v: 'p, 'tcx: 'v>( Ok(body_expr) } +impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionFormalArgsEncoderInterface<'p, 'v, 'tcx> + for PureFunctionEncoder<'p, 'v, 'tcx> +{ + fn encoder(&self) -> &'p Encoder<'v, 'tcx> { + self.encoder + } + + fn check_type( + &self, + var_span: Span, + ty: ty::Binder<'tcx, ty::Ty<'tcx>>, + ) -> SpannedEncodingResult<()> { + if !self + .encoder + .env() + .query + .type_is_copy(ty, self.parent_def_id) + { + Err(SpannedEncodingError::incorrect( + "pure function parameters must be Copy", + var_span, + )) + } else { + Ok(()) + } + } + + fn get_span(&self, local: mir::Local) -> Span { + self.get_local_span(local) + } +} + impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { #[tracing::instrument( name = "PureFunctionEncoder::new", @@ -314,7 +347,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { let mut precondition = vec![type_precondition, func_precondition]; let mut postcondition = vec![self.encode_postcondition_expr(&contract)?]; - let formal_args = self.encode_formal_args()?; + let formal_args = self.encode_formal_args(self.sig)?; let return_type = self.encode_function_return_type()?; let res_value_range_pos = self.encoder.error_manager().register_error( @@ -545,6 +578,13 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { .replace_place(&encoded_return.into(), &pure_fn_return_variable.into()) .set_default_pos(postcondition_pos); + if post.has_old_expression() { + return Err(SpannedEncodingError::incorrect( + "old expressions should not appear in the postconditions of pure functions", + self.span, + )); + } + Ok(post) } @@ -620,40 +660,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> { .with_span(self.span) } - fn encode_formal_args(&self) -> SpannedEncodingResult> { - let mut formal_args = vec![]; - for local_idx in 0..self.sig.skip_binder().inputs().len() { - let local_ty = self.sig.input(local_idx); - let local = prusti_rustc_interface::middle::mir::Local::from_usize(local_idx + 1); - let var_name = format!("{local:?}"); - let var_span = self.get_local_span(local); - - if !self - .encoder - .env() - .query - .type_is_copy(local_ty, self.parent_def_id) - { - return Err(SpannedEncodingError::incorrect( - "pure function parameters must be Copy", - var_span, - )); - } - - let var_type = self - .encoder - .encode_snapshot_type(local_ty.skip_binder()) - .with_span(var_span)?; - formal_args.push(vir::LocalVar::new(var_name, var_type)) - } - Ok(formal_args) - } - pub fn encode_function_call_info(&self) -> SpannedEncodingResult { Ok(FunctionCallInfo { name: self.encode_function_name(), type_arguments: self.encode_type_arguments()?, - formal_args: self.encode_formal_args()?, + formal_args: self.encode_formal_args(self.sig)?, return_type: self.encode_function_return_type()?, }) } diff --git a/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs b/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs index 49e44c76b25..6b8aa0cea89 100644 --- a/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs +++ b/prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs @@ -331,10 +331,11 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> substs, ); + let is_bodyless = self.is_trusted(proc_def_id, Some(substs)) + || !self.env().query.has_body(proc_def_id); + let maybe_identifier: SpannedEncodingResult = (|| { let proc_kind = self.get_proc_kind(proc_def_id, Some(substs)); - let is_bodyless = self.is_trusted(proc_def_id, Some(substs)) - || !self.env().query.has_body(proc_def_id); let mut function = if is_bodyless { pure_function_encoder.encode_bodyless_function()? } else { @@ -393,13 +394,29 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx> Err(error) => { self.register_encoding_error(error); debug!("Error encoding pure function: {:?}", proc_def_id); - let body = self - .env() - .body - .get_pure_fn_body(proc_def_id, substs, parent_def_id); - // TODO(tymap): does stub encoder need substs? - let stub_encoder = StubFunctionEncoder::new(self, proc_def_id, &body, substs); - let function = stub_encoder.encode_function()?; + let function = if !is_bodyless { + let pure_fn_body = + self.env() + .body + .get_pure_fn_body(proc_def_id, substs, parent_def_id); + let encoder = StubFunctionEncoder::new( + self, + proc_def_id, + Some(&pure_fn_body), + substs, + pure_function_encoder.sig, + ); + encoder.encode_function()? + } else { + let encoder = StubFunctionEncoder::new( + self, + proc_def_id, + None, + substs, + pure_function_encoder.sig, + ); + encoder.encode_function()? + }; self.log_vir_program_before_viper(function.to_string()); let identifier = self.insert_function(function); self.pure_function_encoder_state diff --git a/prusti-viper/src/encoder/mod.rs b/prusti-viper/src/encoder/mod.rs index 5faec050a9b..161b24f9288 100644 --- a/prusti-viper/src/encoder/mod.rs +++ b/prusti-viper/src/encoder/mod.rs @@ -12,6 +12,7 @@ mod encoder; mod errors; mod foldunfold; mod initialisation; +mod interface; mod loop_encoder; mod mir_encoder; mod mir_successor; diff --git a/prusti-viper/src/encoder/stub_function_encoder.rs b/prusti-viper/src/encoder/stub_function_encoder.rs index 5c5c3270ac9..68d3f4853bb 100644 --- a/prusti-viper/src/encoder/stub_function_encoder.rs +++ b/prusti-viper/src/encoder/stub_function_encoder.rs @@ -7,23 +7,45 @@ use crate::encoder::{ errors::{SpannedEncodingResult, WithSpan}, high::generics::HighGenericsEncoderInterface, - mir_encoder::{MirEncoder, PlaceEncoder}, + interface::PureFunctionFormalArgsEncoderInterface, snapshot::interface::SnapshotEncoderInterface, Encoder, }; use log::debug; use prusti_rustc_interface::{ hir::def_id::DefId, - middle::{mir, ty::GenericArgsRef}, + middle::{ + mir, ty, + ty::{Binder, GenericArgsRef}, + }, + span::Span, }; use vir_crate::polymorphic as vir; +use super::mir::specifications::SpecificationsInterface; + pub struct StubFunctionEncoder<'p, 'v: 'p, 'tcx: 'v> { encoder: &'p Encoder<'v, 'tcx>, - mir: &'p mir::Body<'tcx>, - mir_encoder: MirEncoder<'p, 'v, 'tcx>, + mir: Option<&'p mir::Body<'tcx>>, proc_def_id: DefId, substs: GenericArgsRef<'tcx>, + sig: ty::PolyFnSig<'tcx>, +} + +impl<'p, 'v, 'tcx> PureFunctionFormalArgsEncoderInterface<'p, 'v, 'tcx> + for StubFunctionEncoder<'p, 'v, 'tcx> +{ + fn check_type(&self, _span: Span, _ty: Binder>) -> SpannedEncodingResult<()> { + Ok(()) + } + + fn encoder(&self) -> &'p Encoder<'v, 'tcx> { + self.encoder + } + + fn get_span(&self, _local: mir::Local) -> Span { + self.encoder.get_spec_span(self.proc_def_id) + } } impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> { @@ -31,40 +53,36 @@ impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> { pub fn new( encoder: &'p Encoder<'v, 'tcx>, proc_def_id: DefId, - mir: &'p mir::Body<'tcx>, + mir: Option<&'p mir::Body<'tcx>>, substs: GenericArgsRef<'tcx>, + sig: ty::PolyFnSig<'tcx>, ) -> Self { StubFunctionEncoder { encoder, mir, - mir_encoder: MirEncoder::new(encoder, mir, proc_def_id), proc_def_id, substs, + sig, } } + fn default_span(&self) -> Span { + self.mir + .map(|m| m.span) + .unwrap_or_else(|| self.encoder.get_spec_span(self.proc_def_id)) + } + #[tracing::instrument(level = "debug", skip(self))] pub fn encode_function(&self) -> SpannedEncodingResult { let function_name = self.encode_function_name(); debug!("Encode stub function {}", function_name); - let formal_args: Vec<_> = self - .mir - .args_iter() - .map(|local| { - let var_name = self.mir_encoder.encode_local_var_name(local); - let mir_type = self.mir_encoder.get_local_ty(local); - self.encoder - .encode_snapshot_type(mir_type) - .map(|var_type| vir::LocalVar::new(var_name, var_type)) - }) - .collect::>() - .with_span(self.mir.span)?; + let formal_args = self.encode_formal_args(self.sig)?; let type_arguments = self .encoder .encode_generic_arguments(self.proc_def_id, self.substs) - .with_span(self.mir.span)?; + .with_span(self.default_span())?; let return_type = self.encode_function_return_type()?; @@ -74,8 +92,6 @@ impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> { formal_args, return_type, pres: vec![false.into()], - // Note: Silicon is currently unsound when declaring a function that ensures `false` - // See: https://github.com/viperproject/silicon/issues/376 posts: vec![], body: None, }; @@ -94,9 +110,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> { } pub fn encode_function_return_type(&self) -> SpannedEncodingResult { - let ty = self.mir.return_ty(); - let return_local = mir::Place::return_place().as_local().unwrap(); - let span = self.mir_encoder.get_local_span(return_local); - self.encoder.encode_snapshot_type(ty).with_span(span) + let ty = self.sig.output(); + + self.encoder + .encode_snapshot_type(ty.skip_binder()) + .with_span(self.encoder.get_spec_span(self.proc_def_id)) } } diff --git a/vir/defs/polymorphic/ast/expr.rs b/vir/defs/polymorphic/ast/expr.rs index 12e3183dfaf..e85ee38e739 100644 --- a/vir/defs/polymorphic/ast/expr.rs +++ b/vir/defs/polymorphic/ast/expr.rs @@ -1454,6 +1454,20 @@ impl Expr { PlaceReplacer { replacements }.fold(self) } + pub fn has_old_expression(&self) -> bool { + struct OldFinder { + has_old: bool, + } + impl ExprWalker for OldFinder { + fn walk_labelled_old(&mut self, _labelled_old: &LabelledOld) { + self.has_old = true; + } + } + let mut walker = OldFinder { has_old: false }; + walker.walk(self); + walker.has_old + } + /// Replaces expressions like `old[l5](old[l5](_9.val_ref).foo.bar)` /// into `old[l5](_9.val_ref.foo.bar)` #[must_use]