diff --git a/lib/sail.c b/lib/sail.c index c939598c3..cdd9a0336 100644 --- a/lib/sail.c +++ b/lib/sail.c @@ -812,6 +812,16 @@ void count_leading_zeros(sail_int *rop, const lbits op) } } +void count_trailing_zeros(sail_int *rop, const lbits op) +{ + if (mpz_cmp_ui(*op.bits, 0) == 0) { + mpz_set_ui(*rop, op.len); + } else { + mp_bitcnt_t ix = mpz_scan1(*op.bits, 0); + mpz_set_ui(*rop, ix); + } +} + bool eq_bits(const lbits op1, const lbits op2) { assert(op1.len == op2.len); diff --git a/lib/sail.h b/lib/sail.h index 259a4e56b..e035b9218 100644 --- a/lib/sail.h +++ b/lib/sail.h @@ -336,6 +336,7 @@ fbits fast_sign_extend2(const sbits op, const uint64_t m); void length_lbits(sail_int *rop, const lbits op); void count_leading_zeros(sail_int *rop, const lbits op); +void count_trailing_zeros(sail_int *rop, const lbits op); bool eq_bits(const lbits op1, const lbits op2); bool EQUAL(lbits)(const lbits op1, const lbits op2); diff --git a/lib/vector.sail b/lib/vector.sail index 88fc77c7a..1ec6f52a1 100644 --- a/lib/vector.sail +++ b/lib/vector.sail @@ -90,6 +90,7 @@ val vector_init = pure "vector_init" : forall 'n ('a : Type), 'n >= 0. (implicit overload length = {bitvector_length, vector_length} val count_leading_zeros = pure "count_leading_zeros" : forall 'N , 'N >= 1. bits('N) -> {'n, 0 <= 'n <= 'N . atom('n)} +val count_trailing_zeros = pure "count_trailing_zeros" : forall 'N , 'N >= 1. bits('N) -> {'n, 0 <= 'n <= 'N . atom('n)} $[sv_module { stdout = true }] val print_bits = pure "print_bits" : forall 'n. (string, bits('n)) -> unit diff --git a/src/gen_lib/sail2_operators_bitlists.lem b/src/gen_lib/sail2_operators_bitlists.lem index dc18b4e3e..70e2cf25a 100644 --- a/src/gen_lib/sail2_operators_bitlists.lem +++ b/src/gen_lib/sail2_operators_bitlists.lem @@ -312,3 +312,4 @@ let eq_vec = eq_bv let neq_vec = neq_bv let inline count_leading_zeros v = count_leading_zero_bits v +let inline count_trailing_zeros v = count_trailing_zero_bits v diff --git a/src/gen_lib/sail2_operators_mwords.lem b/src/gen_lib/sail2_operators_mwords.lem index bec1403c6..0885cf7bb 100644 --- a/src/gen_lib/sail2_operators_mwords.lem +++ b/src/gen_lib/sail2_operators_mwords.lem @@ -336,3 +336,6 @@ let inline neq_vec = neq_mword val count_leading_zeros : forall 'a. Size 'a => mword 'a -> integer let count_leading_zeros v = count_leading_zeros_bv v + +val count_trailing_zeros : forall 'a. Size 'a => mword 'a -> integer +let count_trailing_zeros v = count_trailing_zeros_bv v diff --git a/src/gen_lib/sail2_values.lem b/src/gen_lib/sail2_values.lem index a63b0628b..5eee9d265 100644 --- a/src/gen_lib/sail2_values.lem +++ b/src/gen_lib/sail2_values.lem @@ -754,6 +754,12 @@ let rec count_leading_zero_bits v = val count_leading_zeros_bv : forall 'a. Bitvector 'a => 'a -> integer let count_leading_zeros_bv v = count_leading_zero_bits (bits_of v) +val count_trailing_zero_bits : list bitU -> integer +let rec count_trailing_zero_bits v = count_leading_zeros_bv (List.reverse v) + +val count_trailing_zeros_bv : forall 'a. Bitvector 'a => 'a -> integer +let count_trailing_zeros_bv v = count_trailing_zero_bits (bits_of v) + val decimal_string_of_bv : forall 'a. Bitvector 'a => 'a -> string let decimal_string_of_bv bv = let place_values = diff --git a/src/lib/sail_lib.ml b/src/lib/sail_lib.ml index 962a54db2..9e9f8c95b 100644 --- a/src/lib/sail_lib.ml +++ b/src/lib/sail_lib.ml @@ -174,6 +174,8 @@ let count_leading_zeros xs = let rec aux bs acc = match bs with B0 :: bs' -> aux bs' (acc + 1) | _ -> acc in Big_int.of_int (aux xs 0) +let count_trailing_zeros xs = count_leading_zeros (List.rev xs) + let subrange (list, n, m) = let n = Big_int.to_int n in let m = Big_int.to_int m in diff --git a/src/lib/smt_gen.ml b/src/lib/smt_gen.ml index a9554ba83..74a3598e5 100644 --- a/src/lib/smt_gen.ml +++ b/src/lib/smt_gen.ml @@ -235,6 +235,7 @@ module type PRIMOP_GEN = sig val hex_str : Parse_ast.l -> ctyp -> string val hex_str_upper : Parse_ast.l -> ctyp -> string val count_leading_zeros : Parse_ast.l -> int -> string + val count_trailing_zeros : Parse_ast.l -> int -> string val fvector_store : Parse_ast.l -> int -> ctyp -> string val is_empty : Parse_ast.l -> ctyp -> string val hd : Parse_ast.l -> ctyp -> string @@ -1321,6 +1322,77 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct ) | _ -> builtin_type_error "count_leading_zeros" [v] (Some ret_ctyp) + let builtin_count_trailing_zeros v ret_ctyp = + let rec tzcnt ret_sz sz smt = + if sz == 1 then + Ite + ( Fn ("=", [Extract (0, 0, smt); Bitvec_lit [Sail2_values.B0]]), + bvint ret_sz (Big_int.of_int 1), + bvint ret_sz Big_int.zero + ) + else ( + assert (sz land (sz - 1) = 0); + let hsz = sz / 2 in + Ite + ( Fn ("=", [Extract (hsz - 1, 0, smt); bvzero hsz]), + Fn ("bvadd", [bvint ret_sz (Big_int.of_int hsz); tzcnt ret_sz hsz (Extract (sz - 1, hsz, smt))]), + tzcnt ret_sz hsz (Extract (hsz - 1, 0, smt)) + ) + ) + in + let smallest_greater_power_of_two n = + let m = ref 1 in + while !m < n do + m := !m lsl 1 + done; + assert (!m land (!m - 1) = 0); + !m + in + let ret_sz = int_size ret_ctyp in + let* smt = smt_cval v in + match cval_ctyp v with + | CT_fbits sz when sz land (sz - 1) = 0 -> return (tzcnt ret_sz sz smt) + | CT_fbits sz -> + let padded_sz = smallest_greater_power_of_two sz in + let padding = bvzero (padded_sz - sz) in + assert (padded_sz > sz); + return + (Fn + ( "bvsub", + [tzcnt ret_sz padded_sz (Fn ("concat", [padding; smt])); bvint ret_sz (Big_int.of_int (padded_sz - sz))] + ) + ) + | CT_lbits -> + if ret_sz > lbits_index then + return + (Fn + ( "bvsub", + [ + tzcnt ret_sz lbits_size (Fn ("contents", [smt])); + Fn + ( "bvsub", + [ + bvint ret_sz (Big_int.of_int lbits_size); + Fn ("concat", [bvzero (ret_sz - lbits_index); Fn ("len", [smt])]); + ] + ); + ] + ) + ) + else ( + let trailing_zeros = + Fn + ( "bvsub", + [ + tzcnt lbits_index lbits_size (Fn ("contents", [smt])); + Fn ("bvsub", [bvint lbits_index (Big_int.of_int lbits_size); Fn ("len", [smt])]); + ] + ) + in + return (Extract (ret_sz - 1, 0, trailing_zeros)) + ) + | _ -> builtin_type_error "count_trailing_zeros" [v] (Some ret_ctyp) + let rec builtin_eq_anything x y = match (cval_ctyp x, cval_ctyp y) with | CT_struct (xid, xfields), CT_struct (yid, yfields) -> @@ -1508,6 +1580,7 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct | "length" -> unary_primop builtin_length | "replicate_bits" -> binary_primop builtin_replicate_bits | "count_leading_zeros" -> unary_primop builtin_count_leading_zeros + | "count_trailing_zeros" -> unary_primop builtin_count_trailing_zeros | "eq_real" -> binary_primop (binary_smt "=") | "neg_real" -> unary_primop (unary_smt "-") | "add_real" -> binary_primop (binary_smt "+") diff --git a/src/lib/smt_gen.mli b/src/lib/smt_gen.mli index 49daccb5d..af40109f2 100644 --- a/src/lib/smt_gen.mli +++ b/src/lib/smt_gen.mli @@ -148,6 +148,7 @@ module type PRIMOP_GEN = sig val hex_str : Parse_ast.l -> ctyp -> string val hex_str_upper : Parse_ast.l -> ctyp -> string val count_leading_zeros : Parse_ast.l -> int -> string + val count_trailing_zeros : Parse_ast.l -> int -> string val fvector_store : Parse_ast.l -> int -> ctyp -> string val is_empty : Parse_ast.l -> ctyp -> string val hd : Parse_ast.l -> ctyp -> string diff --git a/src/lib/value.ml b/src/lib/value.ml index e0b33c734..c95e9f7c1 100644 --- a/src/lib/value.ml +++ b/src/lib/value.ml @@ -383,6 +383,10 @@ let value_count_leading_zeros = function | [v1] -> V_int (Sail_lib.count_leading_zeros (coerce_bv v1)) | _ -> failwith "value count_leading_zeros" +let value_count_trailing_zeros = function + | [v1] -> V_int (Sail_lib.count_trailing_zeros (coerce_bv v1)) + | _ -> failwith "value count_trailing_zeros" + let is_member = function V_member _ -> true | _ -> false let is_ctor = function V_ctor _ -> true | _ -> false @@ -807,6 +811,7 @@ let primops = ("internal_pick", value_internal_pick); ("replicate_bits", value_replicate_bits); ("count_leading_zeros", value_count_leading_zeros); + ("count_trailing_zeros", value_count_trailing_zeros); ("Elf_loader.elf_entry", fun _ -> V_int !Elf_loader.opt_elf_entry); ("Elf_loader.elf_tohost", fun _ -> V_int !Elf_loader.opt_elf_tohost); ("string_append", value_string_append); diff --git a/src/sail_smt_backend/jib_smt.ml b/src/sail_smt_backend/jib_smt.ml index e820ba222..11fde48af 100644 --- a/src/sail_smt_backend/jib_smt.ml +++ b/src/sail_smt_backend/jib_smt.ml @@ -247,6 +247,8 @@ module Make (Config : CONFIG) = struct let count_leading_zeros l = function _ -> Reporting.unreachable l __POS__ "count_leading_zeros" + let count_trailing_zeros l = function _ -> Reporting.unreachable l __POS__ "count_trailing_zeros" + let fvector_store l _ _ = "store" let is_empty l = function _ -> Reporting.unreachable l __POS__ "is_empty" diff --git a/src/sail_sv_backend/jib_sv.ml b/src/sail_sv_backend/jib_sv.ml index ee0eeed9a..3d10a7add 100644 --- a/src/sail_sv_backend/jib_sv.ml +++ b/src/sail_sv_backend/jib_sv.ml @@ -634,6 +634,8 @@ module Make (Config : CONFIG) = struct let count_leading_zeros l _ = Reporting.unreachable l __POS__ "count_leading_zeros" + let count_trailing_zeros l _ = Reporting.unreachable l __POS__ "count_trailing_zeros" + let fvector_store _l len ctyp = Primops.fvector_store len ctyp let is_empty l = function diff --git a/test/builtins/ctz.sail b/test/builtins/ctz.sail new file mode 100644 index 000000000..77e926ae1 --- /dev/null +++ b/test/builtins/ctz.sail @@ -0,0 +1,13 @@ +default Order dec +$include + +function main () : unit -> unit = { + assert(count_trailing_zeros(0x0) == 4); + assert(count_trailing_zeros(0x1) == 0); + assert(count_trailing_zeros(0x4) == 2); + assert(count_trailing_zeros(0xf) == 0); + + foreach (i from 0 to 32 by 1 in inc) { + assert(count_trailing_zeros(sail_shiftleft(0x00000001, i)) == i); + } +} diff --git a/test/smt/tzcnt.unsat.sail b/test/smt/tzcnt.unsat.sail new file mode 100644 index 000000000..44f8477a4 --- /dev/null +++ b/test/smt/tzcnt.unsat.sail @@ -0,0 +1,16 @@ +default Order dec + +$include + +val lzcnt = pure "count_trailing_zeros" : forall 'w. bits('w) -> range(0, 'w) + +$property +function prop() -> bool = { + let p1 = lzcnt(0x0) == 4; + let p2 = lzcnt(0x00) == 8; + let p3 = lzcnt(0x20) == 5; + let p4 = lzcnt(0b10000000) == 7; + let p5 = lzcnt(0b1) == 0; + let p6 = lzcnt(0xF) == 0; + p1 & p2 & p3 & p4 & p5 & p6 +}