Skip to content

Commit

Permalink
Infer slightly more types in mapping patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
Alasdair committed May 13, 2024
1 parent bad1b54 commit 3a0e9d3
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 37 deletions.
115 changes: 78 additions & 37 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1933,11 +1933,21 @@ let tc_assume nc (E_aux (aux, annot)) = E_aux (E_internal_assume (nc, E_aux (aux
type ('a, 'b) pattern_functions = {
infer : Env.t -> 'a -> 'b * Env.t * uannot exp list;
bind : Env.t -> 'a -> typ -> 'b * Env.t * uannot exp list;
strip : 'b -> 'a;
typ_of : 'b -> typ;
get_loc : 'a -> l;
get_loc_typed : 'b -> l;
}

type ('a, 'b) vector_concat_elem = VC_elem_ok of 'a | VC_elem_error of 'b * exn | VC_elem_unknown of 'a

let unwrap_vector_concat_elem ~at:l = function
| VC_elem_ok x -> x
| VC_elem_unknown x -> x
| VC_elem_error _ -> Reporting.unreachable l __POS__ "Tried to unwrap VC_elem_error"

let vector_concat_elem_is_ok = function VC_elem_ok _ -> true | _ -> false

module PC_config = struct
type t = tannot
let typ_of_t = typ_of_tannot
Expand Down Expand Up @@ -2748,40 +2758,50 @@ and bind_vector_concat_generic :
let typ_opt =
Option.bind typ_opt (fun typ ->
match destruct_any_vector_typ l env typ with
| Destruct_vector (len, elem_typ) -> Option.map (fun l -> (l, Some elem_typ)) (solve_unique env len)
| Destruct_bitvector len -> Option.map (fun l -> (l, None)) (solve_unique env len)
| Destruct_vector (len, elem_typ) -> Option.map (fun len -> (len, Some elem_typ)) (solve_unique env len)
| Destruct_bitvector len -> Option.map (fun len -> (len, None)) (solve_unique env len)
)
in

(* Try to infer any subpatterns, skipping those we cannot infer *)
let fold_pats (pats, env, guards) pat =
let wrap_some (x, y, z) = (Ok x, y, z) in
let wrap_ok (x, y, z) = (VC_elem_ok x, y, z) in
let inferred_pat, env, guards' =
if Option.is_none typ_opt then wrap_some (funcs.infer env pat)
else (try wrap_some (funcs.infer env pat) with Type_error _ as exn -> (Error (pat, exn), env, []))
if Option.is_none typ_opt then wrap_ok (funcs.infer env pat)
else (try wrap_ok (funcs.infer env pat) with Type_error _ as exn -> (VC_elem_error (pat, exn), env, []))
in
(inferred_pat :: pats, env, guards' @ guards)
in
let inferred_pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in
let inferred_pats = List.rev inferred_pats in

(* If we are checking a mapping we can have unknown types, in this
case we can't continue, so just give the entire vector concat an
unknown type. *)
let have_unknown =
allow_unknown
&& List.exists (function Ok pat -> is_unknown_type (funcs.typ_of pat) | Error _ -> false) inferred_pats
case we can't continue if there is more than a single unknown, so
just give the entire vector concat an unknown type. *)
let inferred_pats =
if allow_unknown then
List.map
(function
| VC_elem_ok pat -> if is_unknown_type (funcs.typ_of pat) then VC_elem_unknown pat else VC_elem_ok pat
| err -> err
)
inferred_pats
else inferred_pats
in
if have_unknown && List.for_all Result.is_ok inferred_pats then
(annotate (List.map Result.get_ok inferred_pats) unknown_typ, env, guards)
let num_unknowns = List.length (List.filter (function VC_elem_unknown _ -> true | _ -> false) inferred_pats) in
if num_unknowns > 1 || (num_unknowns > 0 && Option.is_none typ_opt) then (
match Util.option_first (function VC_elem_error (_, exn) -> Some exn | _ -> None) inferred_pats with
| Some exn -> raise exn
| None -> (annotate (List.map (unwrap_vector_concat_elem ~at:l) inferred_pats) unknown_typ, env, guards)
)
else (
(* Will be none if the subpatterns are bitvectors *)
let elem_typ =
match typ_opt with
| Some (_, elem_typ) -> elem_typ
| None -> (
match List.find_opt Result.is_ok inferred_pats with
| Some (Ok pat) -> begin
match List.find_opt vector_concat_elem_is_ok inferred_pats with
| Some (VC_elem_ok pat) -> begin
match destruct_any_vector_typ l env (funcs.typ_of pat) with
| Destruct_vector (_, t) -> Some t
| Destruct_bitvector _ -> None
Expand All @@ -2790,32 +2810,43 @@ and bind_vector_concat_generic :
)
in

(* We can handle a single None in inferred_pats from something like
(* We can handle a single uninferred element (treating unknown
elements as uninferred) in inferred_pats from something like
0b00 @ _ @ 0b00, because we know the wildcard will be bits('n - 4)
where 'n is the total length of the pattern. *)
let before_uninferred, rest = Util.take_drop Result.is_ok inferred_pats in
let before_uninferred = List.map Result.get_ok before_uninferred in
let before_uninferred, rest = Util.take_drop vector_concat_elem_is_ok inferred_pats in
let before_uninferred = List.map (unwrap_vector_concat_elem ~at:l) before_uninferred in
let uninferred, after_uninferred =
(* When we encounter an unknown or uninferred pattern, check the rest for a second 'bad'
pattern that is also unknown or uninferred *)
let check_rest ~first_bad rest =
let msg =
"Cannot infer width here, as there are multiple subpatterns with unclear width in vector concatenation \
pattern"
in
match List.find_opt (fun elem -> not (vector_concat_elem_is_ok elem)) rest with
| Some (VC_elem_error (second_bad, _)) ->
typ_raise (funcs.get_loc second_bad)
(err_because (Err_other msg, first_bad, Err_other "A previous subpattern is here"))
| Some (VC_elem_unknown second_bad) ->
typ_raise
(funcs.get_loc (funcs.strip second_bad))
(err_because (Err_other msg, first_bad, Err_other "A previous subpattern is here"))
| _ -> ()
in
match rest with
| Error (first_uninferred, exn) :: rest ->
begin
match List.find_opt Result.is_error rest with
| Some (Error (second_uninferred, _)) ->
let msg =
"Cannot infer width here, as there are multiple subpatterns with unclear width in vector \
concatenation pattern"
in
typ_raise (funcs.get_loc second_uninferred)
(err_because
(Err_other msg, funcs.get_loc first_uninferred, Err_other "A previous subpattern is here")
)
| _ -> ()
end;
begin
match typ_opt with
| Some (total_len, _) -> (Some (total_len, first_uninferred), List.map Result.get_ok rest)
| None -> raise exn
end
| VC_elem_error (first_uninferred, exn) :: rest -> begin
check_rest ~first_bad:(funcs.get_loc first_uninferred) rest;
match typ_opt with
| Some (total_len, _) -> (Some (total_len, first_uninferred), List.map (unwrap_vector_concat_elem ~at:l) rest)
| None -> raise exn
end
| VC_elem_unknown first_unknown :: rest ->
let first_unknown = funcs.strip first_unknown in
check_rest ~first_bad:(funcs.get_loc first_unknown) rest;
(* If we have unknown elems, we check above that the typ_opt is Some _ *)
let total_len = fst (Option.get typ_opt) in
(Some (total_len, first_unknown), List.map (unwrap_vector_concat_elem ~at:l) rest)
| _ -> (None, [])
in

Expand Down Expand Up @@ -2876,7 +2907,16 @@ and bind_vector_concat_generic :

and bind_vector_concat_pat l env uannot pat pats typ_opt =
let annot_vcp pats typ = P_aux (P_vector_concat pats, (l, mk_tannot ~uannot env typ)) in
let funcs = { infer = infer_pat; bind = bind_pat; typ_of = typ_of_pat; get_loc = pat_loc; get_loc_typed = pat_loc } in
let funcs =
{
infer = infer_pat;
bind = bind_pat;
strip = strip_pat;
typ_of = typ_of_pat;
get_loc = pat_loc;
get_loc_typed = pat_loc;
}
in
bind_vector_concat_generic funcs annot_vcp l false env pat pats typ_opt

and bind_vector_concat_mpat l allow_unknown other_env env uannot mpat mpats typ_opt =
Expand All @@ -2885,6 +2925,7 @@ and bind_vector_concat_mpat l allow_unknown other_env env uannot mpat mpats typ_
{
infer = infer_mpat allow_unknown other_env;
bind = bind_mpat allow_unknown other_env;
strip = strip_mpat;
typ_of = typ_of_mpat;
get_loc = mpat_loc;
get_loc_typed = mpat_loc;
Expand Down
4 changes: 4 additions & 0 deletions test/c/encdec_subrange.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
x = 0xFFFC03FF
x = 0xF0000000
x = 0xFAB00000
ok
31 changes: 31 additions & 0 deletions test/c/encdec_subrange.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
default Order dec

$include <prelude.sail>

scattered union instr

val encdec : instr <-> bits(32)

scattered mapping encdec

union clause instr = X : bits(21)

mapping clause encdec = X(imm20 @ 0b0)
<-> 0xF @ imm20[9..0] @ 0x00 @ imm20[19..10]

val main : unit -> unit

function main() = {
let x = encdec(X(0xFFFFF @ 0b0));
print_bits("x = ", x);
let x = encdec(X(0x00000 @ 0b0));
print_bits("x = ", x);
let x = encdec(X(0b00 @ 0x00AB @ 0b00 @ 0b0));
print_bits("x = ", x);

match encdec(0xFFFC03FF) {
X(y) => assert(y == 0xFFFFF @ 0b0),
};

print_endline("ok");
}
10 changes: 10 additions & 0 deletions test/typecheck/fail/encdec_unknown_and_error.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Type error:
fail/encdec_unknown_and_error.sail:13.30-31:
13 |mapping clause encdec = X(y @ x @ Z() : bits(1))
 | ^
 | Cannot infer width here, as there are multiple subpatterns with unclear width in vector concatenation pattern
 |
 | Caused by fail/encdec_unknown_and_error.sail:13.26-27:
 | 13 |mapping clause encdec = X(y @ x @ Z() : bits(1))
 |  | ^
 |  | A previous subpattern is here
14 changes: 14 additions & 0 deletions test/typecheck/fail/encdec_unknown_and_error.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
default Order dec

$include <prelude.sail>

scattered union instr

val encdec : instr <-> bits(32)

scattered mapping encdec

union clause instr = X : bits(21)

mapping clause encdec = X(y @ x @ Z() : bits(1))
<-> 0xF @ x @ 0x00 @ y

0 comments on commit 3a0e9d3

Please sign in to comment.