Skip to content

Commit

Permalink
Add a better inference rule for vector literals
Browse files Browse the repository at this point in the history
Mostly this improves typechecking in situations like the following,
where we are returning a bitvector with some length that satisfied
a property. In the example below, we are returning a vector of
length `'n` containing bitvectors of length `'m`, where `'n == 'm`
and `'n > 1`. Sail can now type-check this without any annotations.

```
register R : bool

register X : bits(32)

val test : unit -> {'n 'm, 'n > 1 & 'n == 'm. vector('n, bits('m))}

function test() = {
  if R then {
    [0b00, 0b11]
  } else {
    [match X { _ => 0b000 }, 0b001, 0b100]
  }
}
```

Previously one would need a lot of annotations to convince Sail that
this was ok, one on each vector literal.

Note that we don't rely on inferring the first element (any element
can be inferred), as can be seen in the second literal with the match.
  • Loading branch information
Alasdair committed May 15, 2024
1 parent 36eb33a commit 05f5970
Show file tree
Hide file tree
Showing 18 changed files with 143 additions and 33 deletions.
63 changes: 56 additions & 7 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2328,15 +2328,34 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
in
let checked_body = crule check_exp env body typ in
annot_exp (E_internal_plet (tpat, bind_exp, checked_body)) typ
| E_vector vec, _ ->
let len, vtyp =
| E_vector vec, orig_typ -> begin
let literal_len = List.length vec in
let tyvars, nc, typ =
match destruct_exist_plain typ with Some (tyvars, nc, typ) -> (tyvars, nc, typ) | None -> ([], nc_true, typ)
in
let len, elem_typ, is_generic =
match destruct_any_vector_typ l env typ with
| Destruct_vector (len, vtyp) -> (len, vtyp)
| Destruct_bitvector len -> (len, bit_typ)
| Destruct_vector (len, elem_typ) -> (len, elem_typ, true)
| Destruct_bitvector len -> (len, bit_typ, false)
in
let tyvars = List.fold_left (fun set kopt -> KidSet.add (kopt_kid kopt) set) KidSet.empty tyvars in
let tyvars, nc, elem_typ =
if not (KidSet.is_empty (KidSet.inter tyvars (tyvars_of_nexp len))) then (
let unifiers = unify_nexp l env tyvars len (nint literal_len) in
let elem_typ = subst_unifiers unifiers elem_typ in
let nc = KBindings.fold (fun v arg nc -> constraint_subst v arg nc) unifiers nc in
let tyvars = KBindings.fold (fun v _ tyvars -> KidSet.remove v tyvars) unifiers tyvars in
(tyvars, nc, elem_typ)
)
else if prove __POS__ env (nc_eq (nint literal_len) (nexp_simp len)) then (tyvars, nc, elem_typ)
else typ_error l "Vector literal with incorrect length"
in
let checked_items = List.map (fun i -> crule check_exp env i vtyp) vec in
if prove __POS__ env (nc_eq (nint (List.length vec)) (nexp_simp len)) then annot_exp (E_vector checked_items) typ
else typ_error l "Vector literal with incorrect length" (* FIXME: improve error message *)
match check_or_infer_sequence ~at:l env vec tyvars nc elem_typ with
| Some (vec, elem_typ) ->
annot_exp (E_vector vec)
(if is_generic then vector_typ (nint literal_len) elem_typ else bitvector_typ (nint literal_len))
| None -> typ_error l ("This vector literal does not satisfy the constraint in " ^ string_of_typ (mk_typ orig_typ))
end
| E_lit (L_aux (L_undef, _) as lit), _ ->
if can_be_undefined ~at:l env typ then
if is_typ_inhabited env (Env.expand_synonyms env typ) then annot_exp (E_lit lit) typ
Expand All @@ -2351,6 +2370,36 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
let inferred_exp = irule infer_exp env exp in
expect_subtype env inferred_exp typ

(* This function will check that a sequence of expressions all have
the same type, where that type can have additional type variables
and constraints that must be instantiated (usually these
variables/constraints come from an existential). *)
and check_or_infer_sequence ~at:l env xs tyvars nc typ =
let tyvars, nc, typ, xs =
List.fold_left
(fun (tyvars, nc, typ, xs) x ->
let goals = KidSet.inter tyvars (tyvars_of_typ typ) in
if not (KidSet.is_empty goals) then (
match irule infer_exp env x with
| exception Type_error _ -> (tyvars, nc, typ, Error x :: xs)
| x ->
let unifiers = unify l env goals typ (typ_of x) in
let typ = subst_unifiers unifiers typ in
let nc = KBindings.fold (fun v arg nc -> constraint_subst v arg nc) unifiers nc in
let tyvars = KBindings.fold (fun v _ tyvars -> KidSet.remove v tyvars) unifiers tyvars in
(tyvars, nc, typ, Ok x :: xs)
)
else (
let x = crule check_exp env x typ in
(tyvars, nc, typ, Ok x :: xs)
)
)
(tyvars, nc, typ, []) xs
in
if KidSet.is_empty tyvars && prove __POS__ env nc then
Some (List.rev_map (function Ok x -> x | Error x -> crule check_exp env x typ) xs, typ)
else None

and check_block l env exps ret_typ =
let final env exp = match ret_typ with Some typ -> crule check_exp env exp typ | None -> irule infer_exp env exp in
let annot_exp exp typ exp_typ = E_aux (exp, (l, mk_expected_tannot env typ exp_typ)) in
Expand Down
2 changes: 1 addition & 1 deletion test/typecheck/pass/Replicate/v2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ Explicit effect annotations are deprecated. They are no longer used and can be r
pass/Replicate/v2.sail:13.4-30:
13 | replicate_bits(x, 'N / 'M)
 | ^------------------------^
 | Failed to prove constraint: ('M * 'ex249#) == 'N
 | Failed to prove constraint: ('M * 'ex248#) == 'N
6 changes: 3 additions & 3 deletions test/typecheck/pass/bool_constraint/v1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ All external bindings should be marked as either pure or impure
12 | if b then n else 4
 | ^
 | int(4) is not a subtype of {('m : Int), (('b & 'm == 'n) | (not('b) & 'm == 3)). int('m)}
 | as (('b & 'ex240 == 'n) | (not('b) & 'ex240 == 3)) could not be proven
 | as (('b & 'ex239 == 'n) | (not('b) & 'ex239 == 3)) could not be proven
 |
 | type variable 'ex240:
 | type variable 'ex239:
 | pass/bool_constraint/v1.sail:9.25-73:
 | 9 | (bool('b), int('n)) -> {'m, 'b & 'm == 'n | not('b) & 'm == 3. int('m)}
 |  | ^----------------------------------------------^ derived from here
 | pass/bool_constraint/v1.sail:12.19-20:
 | 12 | if b then n else 4
 |  | ^ bound here
 |  | has constraint: 4 == 'ex240
 |  | has constraint: 4 == 'ex239
17 changes: 17 additions & 0 deletions test/typecheck/pass/ex_vector_infer.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
default Order dec

$include <prelude.sail>

register R : bool

register X : bits(32)

val test : unit -> {'n 'm, 'n > 1 & 'n == 'm. vector('n, bits('m))}

function test() = {
if R then {
[0b00, 0b11]
} else {
[match X { _ => 0b000 }, 0b001, 0b100]
}
}
5 changes: 5 additions & 0 deletions test/typecheck/pass/ex_vector_infer/v1.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Type error:
pass/ex_vector_infer/v1.sail:13.4-9:
13 | [0b1]
 | ^---^
 | This vector literal does not satisfy the constraint in {('n : Int) ('m : Int), ('n > 1 & 'n == 'm). vector('n, bitvector('m))}
17 changes: 17 additions & 0 deletions test/typecheck/pass/ex_vector_infer/v1.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
default Order dec

$include <prelude.sail>

register R : bool

register X : bits(32)

val test : unit -> {'n 'm, 'n > 1 & 'n == 'm. vector('n, bits('m))}

function test() = {
if R then {
[0b1]
} else {
[match X { _ => 0b000 }, 0b001, 0b100]
}
}
5 changes: 5 additions & 0 deletions test/typecheck/pass/ex_vector_infer/v2.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Type error:
pass/ex_vector_infer/v2.sail:13.11-16:
13 | [0b00, 0b111]
 | ^---^
 | Failed to prove constraint: 3 == 2
17 changes: 17 additions & 0 deletions test/typecheck/pass/ex_vector_infer/v2.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
default Order dec

$include <prelude.sail>

register R : bool

register X : bits(32)

val test : unit -> {'n 'm, 'n > 1 & 'n == 'm. vector('n, bits('m))}

function test() = {
if R then {
[0b00, 0b111]
} else {
[match X { _ => 0b000 }, 0b001, 0b100]
}
}
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Old set syntax, {|1, 2, 3|} can now be written as {1, 2, 3}.
26 | Some(Ctor1(a, x, c))
 | ^------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor1
 | * 'ex335# in {32, 64}
 | * 'ex334# in {32, 64}
8 changes: 4 additions & 4 deletions test/typecheck/pass/existential_ast3/v1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
pass/existential_ast3/v1.sail:17.48-65:
17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 | ^---------------^
 | (int(33), int('ex289)) is not a subtype of (int('ex284), int('ex285))
 | (int(33), int('ex288)) is not a subtype of (int('ex283), int('ex284))
 | as false could not be proven
 |
 | type variable 'ex284:
 | type variable 'ex283:
 | pass/existential_ast3/v1.sail:16.23-25:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex285:
 | type variable 'ex284:
 | pass/existential_ast3/v1.sail:16.26-28:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex289:
 | type variable 'ex288:
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
8 changes: 4 additions & 4 deletions test/typecheck/pass/existential_ast3/v2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
pass/existential_ast3/v2.sail:17.48-65:
17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 | ^---------------^
 | (int(31), int('ex289)) is not a subtype of (int('ex284), int('ex285))
 | (int(31), int('ex288)) is not a subtype of (int('ex283), int('ex284))
 | as false could not be proven
 |
 | type variable 'ex284:
 | type variable 'ex283:
 | pass/existential_ast3/v2.sail:16.23-25:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex285:
 | type variable 'ex284:
 | pass/existential_ast3/v2.sail:16.26-28:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex289:
 | type variable 'ex288:
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast3/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
25 | Some(Ctor(64, unsigned(0b0 @ b @ a)))
 | ^-----------------------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor
 | * (64 in {32, 64} & (0 <= 'ex326# & 'ex326# < 64))
 | * (64 in {32, 64} & (0 <= 'ex325# & 'ex325# < 64))
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v4.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
36 | if is_64 then 64 else 32;
 | ^^
 | int(64) is not a subtype of {('d : Int), (('is_64 & 'd == 63) | (not('is_64) & 'd == 32)). int('d)}
 | as (('is_64 & 'ex342 == 63) | (not('is_64) & 'ex342 == 32)) could not be proven
 | as (('is_64 & 'ex341 == 63) | (not('is_64) & 'ex341 == 32)) could not be proven
 |
 | type variable 'ex342:
 | type variable 'ex341:
 | pass/existential_ast3/v4.sail:35.18-79:
 | 35 | let 'datasize : {'d, ('is_64 & 'd == 63) | (not('is_64) & 'd == 32). int('d)} =
 |  | ^-----------------------------------------------------------^ derived from here
 | pass/existential_ast3/v4.sail:36.18-20:
 | 36 | if is_64 then 64 else 32;
 |  | ^^ bound here
 |  | has constraint: 64 == 'ex342
 |  | has constraint: 64 == 'ex341
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v5.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 | ^-------------^
 | range(0, 63) is not a subtype of range(0, ('datasize - 2))
 | as (0 <= 'ex350 & 'ex350 <= ('datasize - 2)) could not be proven
 | as (0 <= 'ex349 & 'ex349 <= ('datasize - 2)) could not be proven
 |
 | type variable 'ex350:
 | type variable 'ex349:
 | pass/existential_ast3/v5.sail:37.10-33:
 | 37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 |  | ^---------------------^ derived from here
 | pass/existential_ast3/v5.sail:37.50-65:
 | 37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 |  | ^-------------^ bound here
 |  | has constraint: (0 <= 'ex350 & 'ex350 <= 63)
 |  | has constraint: (0 <= 'ex349 & 'ex349 <= 63)
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v6.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
37 | let n : range(0, 'datasize - 1) = if is_64 then unsigned(b @ a) else unsigned(b @ a);
 | ^-------------^
 | range(0, 63) is not a subtype of range(0, ('datasize - 1))
 | as (0 <= 'ex356 & 'ex356 <= ('datasize - 1)) could not be proven
 | as (0 <= 'ex355 & 'ex355 <= ('datasize - 1)) could not be proven
 |
 | type variable 'ex356:
 | type variable 'ex355:
 | pass/existential_ast3/v6.sail:37.10-33:
 | 37 | let n : range(0, 'datasize - 1) = if is_64 then unsigned(b @ a) else unsigned(b @ a);
 |  | ^---------------------^ derived from here
 | pass/existential_ast3/v6.sail:37.71-86:
 | 37 | let n : range(0, 'datasize - 1) = if is_64 then unsigned(b @ a) else unsigned(b @ a);
 |  | ^-------------^ bound here
 |  | has constraint: (0 <= 'ex356 & 'ex356 <= 63)
 |  | has constraint: (0 <= 'ex355 & 'ex355 <= 63)
2 changes: 1 addition & 1 deletion test/typecheck/pass/if_infer/v1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
10 | let _ = 0b100[if R then 0 else f()];
 | ^-------------------------^
 | Could not resolve quantifiers for bitvector_access
 | * (0 <= 'ex241# & 'ex241# < 3)
 | * (0 <= 'ex240# & 'ex240# < 3)
 |
 | Caused by pass/if_infer/v1.sail:10.10-37:
 | 10 | let _ = 0b100[if R then 0 else f()];
Expand Down
2 changes: 1 addition & 1 deletion test/typecheck/pass/if_infer/v2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
10 | let _ = 0b1001[if R then 0 else f()];
 | ^--------------------------^
 | Could not resolve quantifiers for bitvector_access
 | * (0 <= 'ex241# & 'ex241# < 4)
 | * (0 <= 'ex240# & 'ex240# < 4)
 |
 | Caused by pass/if_infer/v2.sail:10.10-38:
 | 10 | let _ = 0b1001[if R then 0 else f()];
Expand Down
2 changes: 1 addition & 1 deletion test/typecheck/pass/reg_32_64/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ Explicit effect annotations are deprecated. They are no longer used and can be r
 | * sub_int
 | pass/reg_32_64/v3.sail:29.15-17:
 | 29 | reg_deref(R)['d - 1 .. 0]
 |  | ^^ checking function argument has type int('ex174#)
 |  | ^^ checking function argument has type int('ex173#)
 |  | Cannot re-write sizeof('d)

0 comments on commit 05f5970

Please sign in to comment.