Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BV, CP): Add propagators for bvshl and bvlshr #1085

Merged
merged 1 commit into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/lib/reasoners/bitlist.ml
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,61 @@ let mul a b =
in
concat (unknown (sz - width mid_bits - width low_bits) Ex.empty) @@
concat mid_bits low_bits

let shl a b =
(* If the minimum value for [b] is larger than the bitwidth, the result is
zero.

Otherwise, any low zero bit in [a] is also a zero bit in the result, and
the minimum value for [b] also accounts for that many minimum zeros (e.g.
?000 shifted by at least 2 has at least 5 low zeroes).

NB: [increase_lower_bound b Z.zero] is the smallest positive integer that
matches the bitlist pattern, and so is a lower bound. Ideally we would
like to use the lower bound from the interval domain for [b] instead. *)
match Z.to_int (increase_lower_bound b Z.zero) with
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a bit weird to use increase_lower_bound in order to extract the lower bound of b. I think we should write a specialization:

let extract_lower_bound b = increase_lower_bound Z.zero 

Alternatively, we can rename increase_lower_bound because its name is not very descriptive but I have no good suggestions for it right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the comment above notes, we would want to use the bounds from the interval domain instead, so this should disappear at some point. I will clarify that increase_lower_bound Z.zero is the smallest positive valid lower bound.

| n when n < width a ->
let low_zeros = Z.trailing_zeros @@ Z.lognot @@ a.bits_clr in
if low_zeros + n >= width a then
exact (width a) Z.zero (Ex.union (explanation a) (explanation b))
else if low_zeros + n > 0 then
concat (unknown (width a - low_zeros - n) Ex.empty) @@
exact (low_zeros + n) Z.zero (Ex.union (explanation a) (explanation b))
Halbaroth marked this conversation as resolved.
Show resolved Hide resolved
else
unknown (width a) Ex.empty
| _ | exception Z.Overflow ->
exact (width a) Z.zero (explanation b)
Comment on lines +323 to +333
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's bind width a to a shorter name like sz.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this code is removed in #1144 I'd rather not — it would only add rebase conflicts.


let lshr a b =
(* If the minimum value for [b] is larger than the bitwidth, the result is
zero.

Otherwise, any high zero bit in [a] is also a zero bit in the result, and
the minimum value for [b] also accounts for that many minimum zeros (e.g.
000??? shifted by at least 2 is 00000?).

NB: [increase_lower_bound b Z.zero] is the smallest positive integer that
matches the bitlist pattern, and so is a lower bound. Ideally we would
like to use the lower bound from the interval domain for [b] instead. *)
match Z.to_int (increase_lower_bound b Z.zero) with
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same remark here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

| n when n < width a ->
let sz = width a in
if Z.testbit a.bits_clr (sz - 1) then (* MSB is zero *)
let low_msb_zero = Z.numbits @@ Z.extract (Z.lognot a.bits_clr) 0 sz in
let nb_msb_zeros = sz - low_msb_zero in
assert (nb_msb_zeros > 0);
let nb_zeros = nb_msb_zeros + n in
if nb_zeros >= sz then
exact sz Z.zero (Ex.union (explanation a) (explanation b))
else
concat
(exact nb_zeros Z.zero (Ex.union (explanation a) (explanation b)))
(unknown (sz - nb_zeros) Ex.empty)
else if n > 0 then
concat
(exact n Z.zero (explanation b))
(unknown (sz - n) Ex.empty)
else
unknown sz Ex.empty
| _ | exception Z.Overflow ->
exact (width a) Z.zero (explanation b)
6 changes: 6 additions & 0 deletions src/lib/reasoners/bitlist.mli
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ val logxor : t -> t -> t
val mul : t -> t -> t
(** Integer multiplication. *)

val shl : t -> t -> t
(** Logical left shift. *)

val lshr : t -> t -> t
(** Logical right shift. *)

val concat : t -> t -> t
(** Bit-vector concatenation. *)

Expand Down
4 changes: 3 additions & 1 deletion src/lib/reasoners/bitv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ module Shostak(X : ALIEN) = struct
| Op (
Concat | Extract _ | BV2Nat
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem)
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr)
-> true
| _ -> false

Expand Down Expand Up @@ -409,6 +410,7 @@ module Shostak(X : ALIEN) = struct
| { f = Op (
BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr
); _ } ->
X.term_embed t, []
| _ -> X.make t
Expand Down
78 changes: 77 additions & 1 deletion src/lib/reasoners/bitv_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ module Constraint : sig

This uses the convention that [x % 0] is [x]. *)

val bvshl : X.r -> X.r -> X.r -> t
(** [bvshl r x y] is the constraint [r = x << y] *)

val bvlshr : X.r -> X.r -> X.r -> t
(** [bvshl r x y] is the constraint [r = x >> y] *)

val bvule : X.r -> X.r -> t

val bvugt : X.r -> X.r -> t
Expand All @@ -273,6 +279,8 @@ end = struct
| Band | Bor | Bxor
(* Arithmetic operations *)
| Badd | Bmul | Budiv | Burem
(* Shift operations *)
| Bshl | Blshr

let pp_binop ppf = function
| Band -> Fmt.pf ppf "bvand"
Expand All @@ -282,6 +290,8 @@ end = struct
| Bmul -> Fmt.pf ppf "bvmul"
| Budiv -> Fmt.pf ppf "bvudiv"
| Burem -> Fmt.pf ppf "bvurem"
| Bshl -> Fmt.pf ppf "bvshl"
| Blshr -> Fmt.pf ppf "bvlshr"

let equal_binop op1 op2 =
match op1, op2 with
Expand All @@ -304,12 +314,18 @@ end = struct
| Budiv, _ | _, Budiv -> false

| Burem, Burem -> true
| Burem, _ | _, Burem -> false

| Bshl, Bshl -> true
| Bshl, _ | _, Bshl -> false

| Blshr, Blshr -> true

let hash_binop : binop -> int = Hashtbl.hash

let is_commutative = function
| Band | Bor | Bxor | Badd | Bmul -> true
| Budiv | Burem -> false
| Budiv | Burem | Bshl | Blshr -> false

let propagate_binop ~ex dx op dy dz =
let open Bitlist_domains.Ephemeral in
Expand Down Expand Up @@ -343,6 +359,12 @@ end = struct
(* TODO: full adder propagation *)
()

| Bshl -> (* Only forward propagation for now *)
update ~ex dx (Bitlist.shl !!dy !!dz)

| Blshr -> (* Only forward propagation for now *)
update ~ex dx (Bitlist.lshr !!dy !!dz)

| Bmul -> (* Only forward propagation for now *)
update ~ex dx (Bitlist.mul !!dy !!dz)

Expand All @@ -361,6 +383,12 @@ end = struct
update ~ex dy @@ norm @@ Intervals.Int.sub !!dr !!dx;
update ~ex dx @@ norm @@ Intervals.Int.sub !!dr !!dy

| Bshl -> (* Only forward propagation for now *)
update ~ex dr @@ Intervals.Int.bvshl ~size:sz !!dx !!dy

| Blshr -> (* Only forward propagation for now *)
update ~ex dr @@ Intervals.Int.lshr !!dx !!dy

| Bmul -> (* Only forward propagation for now *)
update ~ex dr @@ norm @@ Intervals.Int.mul !!dx !!dy

Expand Down Expand Up @@ -574,6 +602,8 @@ end = struct
let bvmul = cbinop Bmul
let bvudiv = cbinop Budiv
let bvurem = cbinop Burem
let bvshl = cbinop Bshl
let bvlshr = cbinop Blshr

let crel r = hcons @@ Crel r

Expand Down Expand Up @@ -729,6 +759,27 @@ end = struct
) else
false

(* Add the constraint: r = x >> c *)
let add_lshr_const acts r x c =
let sz = bitwidth r in
match Z.to_int c with
| 0 -> add_eq acts r x
| n when n < sz ->
assert (n > 0);
let r_bitv = Shostak.Bitv.embed r in
let low_bits =
Shostak.Bitv.is_mine @@
Bitv.extract sz n (sz - 1) (Shostak.Bitv.embed x)
in
add_eq acts
(Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) r_bitv)
low_bits;
add_eq_const acts
(Shostak.Bitv.is_mine @@ Bitv.extract sz (sz - n) (sz - 1) r_bitv)
Z.zero
| _ | exception Z.Overflow ->
add_eq_const acts r Z.zero

(* Ground evaluation rules for binary operators. *)
let eval_binop op ty x y =
match op with
Expand All @@ -747,6 +798,18 @@ end = struct
cast ty x
else
cast ty @@ Z.rem x y
| Bshl -> (
match ty, Z.to_int y with
| Tbitv sz, y when y < sz ->
cast ty @@ Z.shift_left x y
| _ | exception Z.Overflow -> cast ty Z.zero
)
| Blshr -> (
match ty, Z.to_int y with
| Tbitv sz, y when y < sz ->
cast ty @@ Z.shift_right x y
| _ | exception Z.Overflow -> cast ty Z.zero
)

(* Constant simplification rules for binary operators.

Expand Down Expand Up @@ -793,6 +856,17 @@ end = struct

| Budiv | Burem -> false

(* shifts becomes a simple extraction when we know the right-hand side *)
| Bshl when X.is_constant y ->
add_shl_const acts r x (value y);
true
| Bshl -> false

| Blshr when X.is_constant y ->
add_lshr_const acts r x (value y);
true
| Blshr -> false

(* Algebraic rewrite rules for binary operators.

Rules based on constant simplifications are in [rw_binop_const]. *)
Expand Down Expand Up @@ -864,6 +938,8 @@ let extract_binop =
| BVmul -> Some bvmul
| BVudiv -> Some bvudiv
| BVurem -> Some bvurem
| BVshl -> Some bvshl
| BVlshr -> Some bvlshr
| _ -> None

let extract_constraints bcs uf r t =
Expand Down
67 changes: 67 additions & 0 deletions src/lib/reasoners/intervals.ml
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,48 @@ module ZEuclideanType = struct
| Neg_infinite -> Pos_infinite
| Pos_infinite -> Neg_infinite
| Finite n -> Finite (Z.lognot n)

(* Values larger than [max_int] are treated as +oo *)
let shift_left ?(max_int = max_int) x y =
let[@inline always] must_be_nonnegative () =
Fmt.invalid_arg "shl: must shift by nonnegative amount"
in
match y with
| Neg_infinite -> must_be_nonnegative ()
| Finite y when Z.sign y < 0 -> must_be_nonnegative ()
| Pos_infinite -> Pos_infinite
| Finite y ->
match Z.to_int y with
| exception Z.Overflow -> Pos_infinite
| y ->
if y <= max_int then
match x with
| Neg_infinite -> Neg_infinite
| Pos_infinite -> Pos_infinite
| Finite x -> Finite (Z.shift_left x y)
else Pos_infinite

let shift_right x y =
match y with
| Neg_infinite ->
invalid_arg "shift_right: must shift by nonnegative amount"
| Finite y when Z.sign y < 0 ->
invalid_arg "shift_right: must shift by nonnegative amount"
| Pos_infinite -> (
match x with
| Pos_infinite -> invalid_arg "shift_right: undefined limit"
| _ -> zero
)
| Finite y ->
match x with
| Neg_infinite -> Neg_infinite
| Pos_infinite -> Pos_infinite
| Finite x ->
match Z.to_int y with
| exception Z.Overflow ->
(* y > max_int -> x >> y = 0 since numbits x <= max_int *)
zero
| y -> Finite (Z.shift_right x y)
end

(* AlgebraicType interface for reals
Expand Down Expand Up @@ -669,6 +711,31 @@ module Int = struct
interval_set { lb = ZEuclideanType.zero ; ub }
) u1
) u2

let bvshl ~size u1 u2 =
assert (size > 0);
(* Values higher than [max_int] ultimately map to [0] *)
let max_int = size - 1 in
let zero_i = { lb = ZEuclideanType.zero ; ub = ZEuclideanType.zero } in
extract ~ofs:0 ~len:size @@
of_set_nonempty @@
map_to_set (fun i2 ->
assert (ZEuclideanType.sign i2.lb >= 0);
if ZEuclideanType.(compare i2.lb (finite @@ Z.of_int max_int)) > 0 then
(* if i2.lb > max_int, the result is always zero
must not call ZEuclideanType.shift_left or we will likely OOM *)
interval_set zero_i
else
(* equivalent to multiplication by a positive constant *)
approx_map_inc_to_set
(fun lb -> ZEuclideanType.shift_left lb i2.lb)
(fun ub -> ZEuclideanType.shift_left ~max_int ub i2.ub)
u1
) u2

let lshr u1 u2 =
of_set_nonempty @@
map2_mon_to_set ZEuclideanType.shift_right Inc u1 Dec u2
end

module Legacy = struct
Expand Down
13 changes: 13 additions & 0 deletions src/lib/reasoners/intervals.mli
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ module Int : sig
theory, i.e. where [bvurem n 0] is [n].

[s] and [t] must be within the [0, 2^sz - 1] range. *)

val bvshl : size:int -> t -> t -> t
(** [shl sz s t] computes an overapproximation of the left shift [s lsl t],
truncating the result to [sz] bits.

[s] and [t] must only contain non-negative integers. *)

val lshr : t -> t -> t
(** [lshr s t] computes an approximation of the logical right shift [s lsr t].

Note that the result of logical right shift is independent of bit width.

[s] and [t] must only contain non-negative integers. *)
end

module Legacy : sig
Expand Down
4 changes: 2 additions & 2 deletions src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3185,8 +3185,8 @@ module BV = struct
(bvneg u)

(* Shift operations *)
let bvshl s t = int2bv (size2 s t) Ints.(bv2nat s * (~$2 ** bv2nat t))
let bvlshr s t = int2bv (size2 s t) Ints.(bv2nat s / (~$2 ** bv2nat t))
let bvshl s t = mk_term (Op BVshl) [s; t] (type_info s)
let bvlshr s t = mk_term (Op BVlshr) [s; t] (type_info s)
let bvashr s t =
let m = size2 s t in
ite (is (extract (m - 1) (m - 1) s) 0)
Expand Down
6 changes: 6 additions & 0 deletions src/lib/structures/symbols.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type operator =
| Extract of int * int (* lower bound * upper bound *)
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr
| Int2BV of int | BV2Nat
(* FP *)
| Float
Expand Down Expand Up @@ -193,6 +194,7 @@ let compare_operators op1 op2 =
| Integer_log2 | Pow | Integer_round
| BVnot | BVand | BVor | BVxor
| BVadd | BVsub | BVmul | BVudiv | BVurem
| BVshl | BVlshr
| Int2BV _ | BV2Nat
| Not_theory_constant | Is_theory_constant | Linear_dependency
| Constr _ | Destruct _ | Tite) -> assert false
Expand Down Expand Up @@ -356,6 +358,8 @@ module AEPrinter = struct
| BVmul -> Fmt.pf ppf "bvmul"
| BVudiv -> Fmt.pf ppf "bvudiv"
| BVurem -> Fmt.pf ppf "bvurem"
| BVshl -> Fmt.pf ppf "bvshl"
| BVlshr -> Fmt.pf ppf "bvlshr"

(* ArraysEx theory *)
| Get -> Fmt.pf ppf "get"
Expand Down Expand Up @@ -461,6 +465,8 @@ module SmtPrinter = struct
| BVmul -> Fmt.pf ppf "bvmul"
| BVudiv -> Fmt.pf ppf "bvudiv"
| BVurem -> Fmt.pf ppf "bvurem"
| BVshl -> Fmt.pf ppf "bvshl"
| BVlshr -> Fmt.pf ppf "bvlshr"

(* ArraysEx theory *)
| Get -> Fmt.pf ppf "select"
Expand Down
Loading
Loading