Skip to content

Commit

Permalink
Add count_trailing_zero primop (#792)
Browse files Browse the repository at this point in the history
* Add count_trailing_zero primop

To go along with count_leading_zeros

* Trivial TC test fixups

Apparently I've caused the typechecker to ask for another existential
variable along the way, and so some error messages have changed only in
the syntactic labels of these variables.
  • Loading branch information
nwf authored Nov 22, 2024
1 parent ac83f8c commit 9787829
Show file tree
Hide file tree
Showing 28 changed files with 164 additions and 28 deletions.
10 changes: 10 additions & 0 deletions lib/sail.c
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,16 @@ void count_leading_zeros(sail_int *rop, const lbits op)
}
}

void count_trailing_zeros(sail_int *rop, const lbits op)
{
if (mpz_cmp_ui(*op.bits, 0) == 0) {
mpz_set_ui(*rop, op.len);
} else {
mp_bitcnt_t ix = mpz_scan1(*op.bits, 0);
mpz_set_ui(*rop, ix);
}
}

bool eq_bits(const lbits op1, const lbits op2)
{
assert(op1.len == op2.len);
Expand Down
1 change: 1 addition & 0 deletions lib/sail.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ fbits fast_sign_extend2(const sbits op, const uint64_t m);

void length_lbits(sail_int *rop, const lbits op);
void count_leading_zeros(sail_int *rop, const lbits op);
void count_trailing_zeros(sail_int *rop, const lbits op);

bool eq_bits(const lbits op1, const lbits op2);
bool EQUAL(lbits)(const lbits op1, const lbits op2);
Expand Down
1 change: 1 addition & 0 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ val vector_init = pure "vector_init" : forall 'n ('a : Type), 'n >= 0. (implicit
overload length = {bitvector_length, vector_length}

val count_leading_zeros = pure "count_leading_zeros" : forall 'N , 'N >= 1. bits('N) -> {'n, 0 <= 'n <= 'N . atom('n)}
val count_trailing_zeros = pure "count_trailing_zeros" : forall 'N , 'N >= 1. bits('N) -> {'n, 0 <= 'n <= 'N . atom('n)}

$[sv_module { stdout = true }]
val print_bits = pure "print_bits" : forall 'n. (string, bits('n)) -> unit
Expand Down
1 change: 1 addition & 0 deletions src/gen_lib/sail2_operators_bitlists.lem
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,4 @@ let eq_vec = eq_bv
let neq_vec = neq_bv

let inline count_leading_zeros v = count_leading_zero_bits v
let inline count_trailing_zeros v = count_trailing_zero_bits v
3 changes: 3 additions & 0 deletions src/gen_lib/sail2_operators_mwords.lem
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,6 @@ let inline neq_vec = neq_mword

val count_leading_zeros : forall 'a. Size 'a => mword 'a -> integer
let count_leading_zeros v = count_leading_zeros_bv v

val count_trailing_zeros : forall 'a. Size 'a => mword 'a -> integer
let count_trailing_zeros v = count_trailing_zeros_bv v
6 changes: 6 additions & 0 deletions src/gen_lib/sail2_values.lem
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,12 @@ let rec count_leading_zero_bits v =
val count_leading_zeros_bv : forall 'a. Bitvector 'a => 'a -> integer
let count_leading_zeros_bv v = count_leading_zero_bits (bits_of v)

val count_trailing_zero_bits : list bitU -> integer
let rec count_trailing_zero_bits v = count_leading_zeros_bv (List.reverse v)

val count_trailing_zeros_bv : forall 'a. Bitvector 'a => 'a -> integer
let count_trailing_zeros_bv v = count_trailing_zero_bits (bits_of v)

val decimal_string_of_bv : forall 'a. Bitvector 'a => 'a -> string
let decimal_string_of_bv bv =
let place_values =
Expand Down
2 changes: 2 additions & 0 deletions src/lib/sail_lib.ml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ let count_leading_zeros xs =
let rec aux bs acc = match bs with B0 :: bs' -> aux bs' (acc + 1) | _ -> acc in
Big_int.of_int (aux xs 0)

let count_trailing_zeros xs = count_leading_zeros (List.rev xs)

let subrange (list, n, m) =
let n = Big_int.to_int n in
let m = Big_int.to_int m in
Expand Down
73 changes: 73 additions & 0 deletions src/lib/smt_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ module type PRIMOP_GEN = sig
val hex_str : Parse_ast.l -> ctyp -> string
val hex_str_upper : Parse_ast.l -> ctyp -> string
val count_leading_zeros : Parse_ast.l -> int -> string
val count_trailing_zeros : Parse_ast.l -> int -> string
val fvector_store : Parse_ast.l -> int -> ctyp -> string
val is_empty : Parse_ast.l -> ctyp -> string
val hd : Parse_ast.l -> ctyp -> string
Expand Down Expand Up @@ -1321,6 +1322,77 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct
)
| _ -> builtin_type_error "count_leading_zeros" [v] (Some ret_ctyp)

let builtin_count_trailing_zeros v ret_ctyp =
let rec tzcnt ret_sz sz smt =
if sz == 1 then
Ite
( Fn ("=", [Extract (0, 0, smt); Bitvec_lit [Sail2_values.B0]]),
bvint ret_sz (Big_int.of_int 1),
bvint ret_sz Big_int.zero
)
else (
assert (sz land (sz - 1) = 0);
let hsz = sz / 2 in
Ite
( Fn ("=", [Extract (hsz - 1, 0, smt); bvzero hsz]),
Fn ("bvadd", [bvint ret_sz (Big_int.of_int hsz); tzcnt ret_sz hsz (Extract (sz - 1, hsz, smt))]),
tzcnt ret_sz hsz (Extract (hsz - 1, 0, smt))
)
)
in
let smallest_greater_power_of_two n =
let m = ref 1 in
while !m < n do
m := !m lsl 1
done;
assert (!m land (!m - 1) = 0);
!m
in
let ret_sz = int_size ret_ctyp in
let* smt = smt_cval v in
match cval_ctyp v with
| CT_fbits sz when sz land (sz - 1) = 0 -> return (tzcnt ret_sz sz smt)
| CT_fbits sz ->
let padded_sz = smallest_greater_power_of_two sz in
let padding = bvzero (padded_sz - sz) in
assert (padded_sz > sz);
return
(Fn
( "bvsub",
[tzcnt ret_sz padded_sz (Fn ("concat", [padding; smt])); bvint ret_sz (Big_int.of_int (padded_sz - sz))]
)
)
| CT_lbits ->
if ret_sz > lbits_index then
return
(Fn
( "bvsub",
[
tzcnt ret_sz lbits_size (Fn ("contents", [smt]));
Fn
( "bvsub",
[
bvint ret_sz (Big_int.of_int lbits_size);
Fn ("concat", [bvzero (ret_sz - lbits_index); Fn ("len", [smt])]);
]
);
]
)
)
else (
let trailing_zeros =
Fn
( "bvsub",
[
tzcnt lbits_index lbits_size (Fn ("contents", [smt]));
Fn ("bvsub", [bvint lbits_index (Big_int.of_int lbits_size); Fn ("len", [smt])]);
]
)
in
return (Extract (ret_sz - 1, 0, trailing_zeros))
)
| _ -> builtin_type_error "count_trailing_zeros" [v] (Some ret_ctyp)

let rec builtin_eq_anything x y =
match (cval_ctyp x, cval_ctyp y) with
| CT_struct (xid, xfields), CT_struct (yid, yfields) ->
Expand Down Expand Up @@ -1508,6 +1580,7 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct
| "length" -> unary_primop builtin_length
| "replicate_bits" -> binary_primop builtin_replicate_bits
| "count_leading_zeros" -> unary_primop builtin_count_leading_zeros
| "count_trailing_zeros" -> unary_primop builtin_count_trailing_zeros
| "eq_real" -> binary_primop (binary_smt "=")
| "neg_real" -> unary_primop (unary_smt "-")
| "add_real" -> binary_primop (binary_smt "+")
Expand Down
1 change: 1 addition & 0 deletions src/lib/smt_gen.mli
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ module type PRIMOP_GEN = sig
val hex_str : Parse_ast.l -> ctyp -> string
val hex_str_upper : Parse_ast.l -> ctyp -> string
val count_leading_zeros : Parse_ast.l -> int -> string
val count_trailing_zeros : Parse_ast.l -> int -> string
val fvector_store : Parse_ast.l -> int -> ctyp -> string
val is_empty : Parse_ast.l -> ctyp -> string
val hd : Parse_ast.l -> ctyp -> string
Expand Down
5 changes: 5 additions & 0 deletions src/lib/value.ml
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,10 @@ let value_count_leading_zeros = function
| [v1] -> V_int (Sail_lib.count_leading_zeros (coerce_bv v1))
| _ -> failwith "value count_leading_zeros"

let value_count_trailing_zeros = function
| [v1] -> V_int (Sail_lib.count_trailing_zeros (coerce_bv v1))
| _ -> failwith "value count_trailing_zeros"

let is_member = function V_member _ -> true | _ -> false

let is_ctor = function V_ctor _ -> true | _ -> false
Expand Down Expand Up @@ -807,6 +811,7 @@ let primops =
("internal_pick", value_internal_pick);
("replicate_bits", value_replicate_bits);
("count_leading_zeros", value_count_leading_zeros);
("count_trailing_zeros", value_count_trailing_zeros);
("Elf_loader.elf_entry", fun _ -> V_int !Elf_loader.opt_elf_entry);
("Elf_loader.elf_tohost", fun _ -> V_int !Elf_loader.opt_elf_tohost);
("string_append", value_string_append);
Expand Down
2 changes: 2 additions & 0 deletions src/sail_smt_backend/jib_smt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ module Make (Config : CONFIG) = struct

let count_leading_zeros l = function _ -> Reporting.unreachable l __POS__ "count_leading_zeros"

let count_trailing_zeros l = function _ -> Reporting.unreachable l __POS__ "count_trailing_zeros"

let fvector_store l _ _ = "store"

let is_empty l = function _ -> Reporting.unreachable l __POS__ "is_empty"
Expand Down
2 changes: 2 additions & 0 deletions src/sail_sv_backend/jib_sv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ module Make (Config : CONFIG) = struct

let count_leading_zeros l _ = Reporting.unreachable l __POS__ "count_leading_zeros"

let count_trailing_zeros l _ = Reporting.unreachable l __POS__ "count_trailing_zeros"

let fvector_store _l len ctyp = Primops.fvector_store len ctyp

let is_empty l = function
Expand Down
13 changes: 13 additions & 0 deletions test/builtins/ctz.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
default Order dec
$include <prelude.sail>

function main () : unit -> unit = {
assert(count_trailing_zeros(0x0) == 4);
assert(count_trailing_zeros(0x1) == 0);
assert(count_trailing_zeros(0x4) == 2);
assert(count_trailing_zeros(0xf) == 0);

foreach (i from 0 to 32 by 1 in inc) {
assert(count_trailing_zeros(sail_shiftleft(0x00000001, i)) == i);
}
}
16 changes: 16 additions & 0 deletions test/smt/tzcnt.unsat.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
default Order dec

$include <prelude.sail>

val lzcnt = pure "count_trailing_zeros" : forall 'w. bits('w) -> range(0, 'w)

$property
function prop() -> bool = {
let p1 = lzcnt(0x0) == 4;
let p2 = lzcnt(0x00) == 8;
let p3 = lzcnt(0x20) == 5;
let p4 = lzcnt(0b10000000) == 7;
let p5 = lzcnt(0b1) == 0;
let p6 = lzcnt(0xF) == 0;
p1 & p2 & p3 & p4 & p5 & p6
}
2 changes: 1 addition & 1 deletion test/typecheck/fail/tuple_lexp1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
fail/tuple_lexp1.sail:10.11-20:
10 | (x, y) = (2, 3, 4)
 | ^-------^
 | Type mismatch between (int('ex182#), int('ex183#)) and (int(2), int(3), int(4))
 | Type mismatch between (int('ex183#), int('ex184#)) and (int(2), int(3), int(4))
 |
 | Caused by fail/tuple_lexp1.sail:10.2-8:
 | 10 | (x, y) = (2, 3, 4)
Expand Down
2 changes: 1 addition & 1 deletion test/typecheck/fail/tuple_lexp2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
fail/tuple_lexp2.sail:10.11-12:
10 | (x, y) = 2
 | ^
 | Type mismatch between (int('ex182#), int('ex183#)) and int(2)
 | Type mismatch between (int('ex183#), int('ex184#)) and int(2)
 |
 | Caused by fail/tuple_lexp2.sail:10.2-8:
 | 10 | (x, y) = 2
Expand Down
2 changes: 1 addition & 1 deletion test/typecheck/pass/Replicate/v2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ Explicit effect annotations are deprecated. They are no longer used and can be r
 | ^------------------------^
 | Could not resolve quantifiers for replicate_bits
 | * 'M >= 0
 | * 'ex178# >= 0
 | * 'ex179# >= 0
6 changes: 3 additions & 3 deletions test/typecheck/pass/bool_constraint/v1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ All external bindings should be marked as either pure or impure
12 | if b then n else 4
 | ^
 | int(4) is not a subtype of {('m : Int), (('b & 'm == 'n) | (not('b) & 'm == 3)). int('m)}
 | as (('b & 'ex173 == 'n) | (not('b) & 'ex173 == 3)) could not be proven
 | as (('b & 'ex174 == 'n) | (not('b) & 'ex174 == 3)) could not be proven
 |
 | type variable 'ex173:
 | type variable 'ex174:
 | pass/bool_constraint/v1.sail:9.25-73:
 | 9 | (bool('b), int('n)) -> {'m, 'b & 'm == 'n | not('b) & 'm == 3. int('m)}
 |  | ^----------------------------------------------^ derived from here
 | pass/bool_constraint/v1.sail:12.19-20:
 | 12 | if b then n else 4
 |  | ^ bound here
 |  | has constraint: 4 == 'ex173
 |  | has constraint: 4 == 'ex174
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Old set syntax, {|1, 2, 3|} can now be written as {1, 2, 3}.
26 | Some(Ctor1(a, x, c))
 | ^------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor1
 | * 'ex249# in {32, 64}
 | * 'ex250# in {32, 64}
8 changes: 4 additions & 4 deletions test/typecheck/pass/existential_ast3/v1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
pass/existential_ast3/v1.sail:17.48-65:
17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 | ^---------------^
 | (int(33), int('ex214)) is not a subtype of (int('ex209), int('ex210))
 | (int(33), int('ex215)) is not a subtype of (int('ex210), int('ex211))
 | as false could not be proven
 |
 | type variable 'ex209:
 | type variable 'ex210:
 | pass/existential_ast3/v1.sail:16.23-25:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex210:
 | type variable 'ex211:
 | pass/existential_ast3/v1.sail:16.26-28:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex214:
 | type variable 'ex215:
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
8 changes: 4 additions & 4 deletions test/typecheck/pass/existential_ast3/v2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
pass/existential_ast3/v2.sail:17.48-65:
17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 | ^---------------^
 | (int(31), int('ex214)) is not a subtype of (int('ex209), int('ex210))
 | (int(31), int('ex215)) is not a subtype of (int('ex210), int('ex211))
 | as false could not be proven
 |
 | type variable 'ex209:
 | type variable 'ex210:
 | pass/existential_ast3/v2.sail:16.23-25:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex210:
 | type variable 'ex211:
 | pass/existential_ast3/v2.sail:16.26-28:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex214:
 | type variable 'ex215:
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast3/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
25 | Some(Ctor(64, unsigned(0b0 @ b @ a)))
 | ^-----------------------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor
 | * (64 in {32, 64} & (0 <= 'ex246# & 'ex246# < 64))
 | * (64 in {32, 64} & (0 <= 'ex247# & 'ex247# < 64))
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v4.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
36 | if is_64 then 64 else 32;
 | ^^
 | int(64) is not a subtype of {('d : Int), (('is_64 & 'd == 63) | (not('is_64) & 'd == 32)). int('d)}
 | as (('is_64 & 'ex258 == 63) | (not('is_64) & 'ex258 == 32)) could not be proven
 | as (('is_64 & 'ex259 == 63) | (not('is_64) & 'ex259 == 32)) could not be proven
 |
 | type variable 'ex258:
 | type variable 'ex259:
 | pass/existential_ast3/v4.sail:35.18-79:
 | 35 | let 'datasize : {'d, ('is_64 & 'd == 63) | (not('is_64) & 'd == 32). int('d)} =
 |  | ^-----------------------------------------------------------^ derived from here
 | pass/existential_ast3/v4.sail:36.18-20:
 | 36 | if is_64 then 64 else 32;
 |  | ^^ bound here
 |  | has constraint: 64 == 'ex258
 |  | has constraint: 64 == 'ex259
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v5.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 | ^-------------^
 | range(0, 63) is not a subtype of range(0, ('datasize - 2))
 | as (0 <= 'ex266 & 'ex266 <= ('datasize - 2)) could not be proven
 | as (0 <= 'ex267 & 'ex267 <= ('datasize - 2)) could not be proven
 |
 | type variable 'ex266:
 | type variable 'ex267:
 | pass/existential_ast3/v5.sail:37.10-33:
 | 37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 |  | ^---------------------^ derived from here
 | pass/existential_ast3/v5.sail:37.50-65:
 | 37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 |  | ^-------------^ bound here
 |  | has constraint: (0 <= 'ex266 & 'ex266 <= 63)
 |  | has constraint: (0 <= 'ex267 & 'ex267 <= 63)
Loading

0 comments on commit 9787829

Please sign in to comment.