From 3c351e3279019b838c6cc3738e67d353bae2e34d Mon Sep 17 00:00:00 2001 From: Federico Poli Date: Tue, 9 Jan 2024 15:21:25 +0100 Subject: [PATCH] Strengthen the postcondition of bisect --- .../verify_overflow/pass/overflow/bisect.rs | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/prusti-tests/tests/verify_overflow/pass/overflow/bisect.rs b/prusti-tests/tests/verify_overflow/pass/overflow/bisect.rs index df6795e22dd..9ad25f16cf0 100644 --- a/prusti-tests/tests/verify_overflow/pass/overflow/bisect.rs +++ b/prusti-tests/tests/verify_overflow/pass/overflow/bisect.rs @@ -1,22 +1,44 @@ +// compile-flags: -Puse_more_complete_exhale=false use prusti_contracts::*; -/// A monotonically increasing discrete function, with domain [0, domain_size) -trait Function { +/// A monotonically strictly increasing discrete function, with domain [0, domain_size) +pub trait Function { #[pure] fn domain_size(&self) -> usize; #[pure] #[requires(x < self.domain_size())] fn eval(&self, x: usize) -> i32; + + predicate!{ + fn invariant(&self) -> bool { + forall(|x1: usize, x2: usize| + x1 < x2 && x2 < self.domain_size() ==> self.eval(x1) < self.eval(x2) + ) + } + } } -/// Find the `x` s.t. `f(x) == target` -#[ensures(if let Some(x) = result { f.eval(x) == target } else { true })] -fn bisect(f: &T, target: i32) -> Option { +/// Find the unique `x` s.t. `f(x) == target` +#[requires(f.invariant())] +#[ensures(match result { + Some(found_x) => { + f.eval(found_x) == target && + forall(|x: usize| x < f.domain_size() && f.eval(x) == target ==> x == found_x) + } + None => { + forall(|x: usize| x < f.domain_size() ==> f.eval(x) != target) + } +})] +pub fn bisect(f: &T, target: i32) -> Option { let mut low = 0; let mut high = f.domain_size(); while low < high { - body_invariant!(low < high && high <= f.domain_size()); + body_invariant!(f.invariant()); + body_invariant!(high <= f.domain_size()); + body_invariant!(forall(|x: usize| + (x < low || high <= x) && x < f.domain_size() ==> f.eval(x) != target + )); let mid = low + ((high - low) / 2); let mid_val = f.eval(mid); if mid_val < target {