From 507ac9f98aa305876d0fdef48c3a9e1c6dac153f Mon Sep 17 00:00:00 2001 From: Alasdair Date: Tue, 21 May 2024 00:13:17 +0100 Subject: [PATCH] Update typing rule for cons Ensure that (x :: y :: [||]) is always checked the same as [|x, y|]. Previously this was not the case as the list literal would consider all elements simultaneously in some cases, whereas the cons would only look at the head and tail. --- src/lib/type_check.ml | 125 ++++++++++++++++------- test/lem/run_tests.py | 2 + test/typecheck/pass/ex_list_infer.sail | 17 +++ test/typecheck/pass/list_infer.sail | 15 +++ test/typecheck/pass/list_infer/v1.expect | 5 + test/typecheck/pass/list_infer/v1.sail | 9 ++ 6 files changed, 134 insertions(+), 39 deletions(-) create mode 100644 test/typecheck/pass/ex_list_infer.sail create mode 100644 test/typecheck/pass/list_infer.sail create mode 100644 test/typecheck/pass/list_infer/v1.expect create mode 100644 test/typecheck/pass/list_infer/v1.sail diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index 02c09c7b0..065c0cb50 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -112,7 +112,7 @@ let rec orig_nexp (Nexp_aux (nexp, l)) = | Nexp_if (i, t, e) -> rewrap (Nexp_if (i, orig_nexp t, orig_nexp e)) | _ -> rewrap nexp -let is_list (Typ_aux (typ_aux, _)) = +let destruct_list (Typ_aux (typ_aux, _)) = match typ_aux with Typ_app (f, [A_aux (A_typ typ, _)]) when string_of_id f = "list" -> Some typ | _ -> None let is_unknown_type = function Typ_aux (Typ_internal_unknown, _) -> true | _ -> false @@ -2007,6 +2007,18 @@ let backwards_attr l uannot = add_attribute l "backwards" None (remove_attribute let tc_assume nc (E_aux (aux, annot)) = E_aux (E_internal_assume (nc, E_aux (aux, annot)), annot) +let rec unroll_cons = function + | E_aux (E_cons (h, t), annot) -> + let elems, annots, last_tail = unroll_cons t in + (h :: elems, annot :: annots, last_tail) + | exp -> ([], [], exp) + +let rec reroll_cons ~at:l elems annots last_tail = + match (elems, annots) with + | elem :: elems, annot :: annots -> E_aux (E_cons (elem, reroll_cons ~at:l elems annots last_tail), annot) + | [], [] -> last_tail + | _, _ -> Reporting.unreachable l __POS__ "Could not recreate cons list due to element and annotation length mismatch" + type ('a, 'b) pattern_functions = { infer : Env.t -> 'a -> 'b * Env.t * uannot exp list; bind : Env.t -> 'a -> typ -> 'b * Env.t * uannot exp list; @@ -2071,21 +2083,6 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au | E_try (exp, cases), _ -> let checked_exp = crule check_exp env exp typ in annot_exp (E_try (checked_exp, List.map (fun case -> check_case env exc_typ case typ) cases)) typ - | E_cons (x, xs), _ -> begin - match is_list (Env.expand_synonyms env typ) with - | Some elem_typ -> - let checked_xs = crule check_exp env xs typ in - let checked_x = crule check_exp env x elem_typ in - annot_exp (E_cons (checked_x, checked_xs)) typ - | None -> typ_error l ("Cons " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) - end - | E_list xs, _ -> begin - match is_list (Env.expand_synonyms env typ) with - | Some elem_typ -> - let checked_xs = List.map (fun x -> crule check_exp env x elem_typ) xs in - annot_exp (E_list checked_xs) typ - | None -> typ_error l ("List " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) - end | E_struct_update (exp, fexps), _ -> let checked_exp = crule check_exp env exp typ in let rectyp_id = @@ -2350,12 +2347,47 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au else if prove __POS__ env (nc_eq (nint literal_len) (nexp_simp len)) then (tyvars, nc, elem_typ) else typ_error l "Vector literal with incorrect length" in - match check_or_infer_sequence ~at:l env vec tyvars nc elem_typ with + match check_or_infer_sequence ~at:l env vec tyvars nc (Some elem_typ) with | Some (vec, elem_typ) -> annot_exp (E_vector vec) (if is_generic then vector_typ (nint literal_len) elem_typ else bitvector_typ (nint literal_len)) | None -> typ_error l ("This vector literal does not satisfy the constraint in " ^ string_of_typ (mk_typ orig_typ)) end + | E_cons (x, xs), orig_typ -> begin + let xs, annots, last_tail = unroll_cons xs in + let tyvars, nc, typ = + match destruct_exist_plain typ with Some (tyvars, nc, typ) -> (tyvars, nc, typ) | None -> ([], nc_true, typ) + in + let tyvars = List.fold_left (fun set kopt -> KidSet.add (kopt_kid kopt) set) KidSet.empty tyvars in + match destruct_list (Env.expand_synonyms env typ) with + | Some elem_typ -> begin + match check_or_infer_sequence ~at:l env (x :: xs) tyvars nc (Some elem_typ) with + | Some (xs, elem_typ) -> + let checked_last_tail = crule check_exp env last_tail (list_typ elem_typ) in + let annots = + List.map + (fun (l, uannot) -> (l, mk_expected_tannot ~uannot env (list_typ elem_typ) (Some (mk_typ orig_typ)))) + ((l, uannot) :: annots) + in + reroll_cons ~at:l xs annots checked_last_tail + | _ -> typ_error l ("This list does not satisfy the constraint in " ^ string_of_typ (mk_typ orig_typ)) + end + | None -> typ_error l ("Cons " ^ string_of_exp exp ^ " must have list type") + end + | E_list xs, orig_typ -> begin + let tyvars, nc, typ = + match destruct_exist_plain typ with Some (tyvars, nc, typ) -> (tyvars, nc, typ) | None -> ([], nc_true, typ) + in + let tyvars = List.fold_left (fun set kopt -> KidSet.add (kopt_kid kopt) set) KidSet.empty tyvars in + match destruct_list (Env.expand_synonyms env typ) with + | Some elem_typ -> begin + match check_or_infer_sequence ~at:l env xs tyvars nc (Some elem_typ) with + | Some (xs, elem_typ) -> annot_exp (E_list xs) (list_typ elem_typ) + | None -> + typ_error l ("This list literal does not satisfy the constraint in " ^ string_of_typ (mk_typ orig_typ)) + end + | None -> typ_error l ("List " ^ string_of_exp exp ^ " must have list type, got " ^ string_of_typ typ) + end | E_lit (L_aux (L_undef, _) as lit), _ -> if can_be_undefined ~at:l env typ then if is_typ_inhabited env (Env.expand_synonyms env typ) then annot_exp (E_lit lit) typ @@ -2374,31 +2406,41 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au the same type, where that type can have additional type variables and constraints that must be instantiated (usually these variables/constraints come from an existential). *) -and check_or_infer_sequence ~at:l env xs tyvars nc typ = - let tyvars, nc, typ, xs = +and check_or_infer_sequence ~at:l env xs tyvars nc typ_opt = + let tyvars, nc, typ_opt, xs = List.fold_left - (fun (tyvars, nc, typ, xs) x -> - let goals = KidSet.inter tyvars (tyvars_of_typ typ) in - if not (KidSet.is_empty goals) then ( - match irule infer_exp env x with - | exception Type_error _ -> (tyvars, nc, typ, Error x :: xs) - | x -> - let unifiers = unify l env goals typ (typ_of x) in - let typ = subst_unifiers unifiers typ in - let nc = KBindings.fold (fun v arg nc -> constraint_subst v arg nc) unifiers nc in - let tyvars = KBindings.fold (fun v _ tyvars -> KidSet.remove v tyvars) unifiers tyvars in - (tyvars, nc, typ, Ok x :: xs) - ) - else ( - let x = crule check_exp env x typ in - (tyvars, nc, typ, Ok x :: xs) - ) + (fun (tyvars, nc, typ_opt, xs) x -> + match typ_opt with + | Some typ -> + let goals = KidSet.inter tyvars (tyvars_of_typ typ) in + if not (KidSet.is_empty goals) then ( + match irule infer_exp env x with + | exception Type_error _ -> (tyvars, nc, Some typ, Error x :: xs) + | x -> + let unifiers = unify l env goals typ (typ_of x) in + let typ = subst_unifiers unifiers typ in + let nc = KBindings.fold (fun v arg nc -> constraint_subst v arg nc) unifiers nc in + let tyvars = KBindings.fold (fun v _ tyvars -> KidSet.remove v tyvars) unifiers tyvars in + (tyvars, nc, Some typ, Ok x :: xs) + ) + else ( + let x = crule check_exp env x typ in + (tyvars, nc, Some typ, Ok x :: xs) + ) + | None -> ( + match irule infer_exp env x with + | exception Type_error _ -> (tyvars, nc, None, Error x :: xs) + | x -> (tyvars, nc, Some (typ_of x), Ok x :: xs) + ) ) - (tyvars, nc, typ, []) xs + (tyvars, nc, typ_opt, []) xs in - if KidSet.is_empty tyvars && prove __POS__ env nc then - Some (List.rev_map (function Ok x -> x | Error x -> crule check_exp env x typ) xs, typ) - else None + match typ_opt with + | Some typ -> + if KidSet.is_empty tyvars && prove __POS__ env nc then + Some (List.rev_map (function Ok x -> x | Error x -> crule check_exp env x typ) xs, typ) + else None + | None -> None and check_block l env exps ret_typ = let final env exp = match ret_typ with Some typ -> crule check_exp env exp typ | None -> irule infer_exp env exp in @@ -3601,6 +3643,11 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) = let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in annot_exp (E_vector (inferred_item :: checked_items)) vec_typ end + | E_list xs -> begin + match check_or_infer_sequence ~at:l env xs KidSet.empty nc_true None with + | Some (xs, elem_typ) -> annot_exp (E_list xs) (list_typ elem_typ) + | None -> typ_error l "Could not infer type of list literal" + end | E_assert (test, msg) -> let msg = assert_msg msg in let checked_test = crule check_exp env test bool_typ in diff --git a/test/lem/run_tests.py b/test/lem/run_tests.py index c048957c5..57d27be36 100755 --- a/test/lem/run_tests.py +++ b/test/lem/run_tests.py @@ -58,6 +58,8 @@ 'abstract_bool2', 'constraint_syn', 'ex_vector_infer', + 'ex_list_infer', + 'ex_cons_infer', } print('Sail is {}'.format(sail)) diff --git a/test/typecheck/pass/ex_list_infer.sail b/test/typecheck/pass/ex_list_infer.sail new file mode 100644 index 000000000..493a7385d --- /dev/null +++ b/test/typecheck/pass/ex_list_infer.sail @@ -0,0 +1,17 @@ +default Order dec + +$include + +register R : bool + +register X : bits(32) + +val test : unit -> {'n, 'n > 1. list(bits('n))} + +function test() = { + if R then { + [| 0b00, 0b11 |] + } else { + [| match X { _ => 0b000 }, 0b001, 0b100 |] + } +} diff --git a/test/typecheck/pass/list_infer.sail b/test/typecheck/pass/list_infer.sail new file mode 100644 index 000000000..b8d67a16f --- /dev/null +++ b/test/typecheck/pass/list_infer.sail @@ -0,0 +1,15 @@ +default Order dec + +$include + +val test : unit -> unit + +function test() = { + let _ = [| 0b00, 0b11 |]; +} + +val test2 : unit -> unit + +function test2() = { + let _ = [| match 0b00 { x => x }, 0b11 |]; +} diff --git a/test/typecheck/pass/list_infer/v1.expect b/test/typecheck/pass/list_infer/v1.expect new file mode 100644 index 000000000..d868eec86 --- /dev/null +++ b/test/typecheck/pass/list_infer/v1.expect @@ -0,0 +1,5 @@ +Type error: +pass/list_infer/v1.sail:8.25-30: +8 | let _ = [| 0b00, 0b11, 0b111 |]; +  | ^---^ +  | Failed to prove constraint: 3 == 2 diff --git a/test/typecheck/pass/list_infer/v1.sail b/test/typecheck/pass/list_infer/v1.sail new file mode 100644 index 000000000..9cbdb62bb --- /dev/null +++ b/test/typecheck/pass/list_infer/v1.sail @@ -0,0 +1,9 @@ +default Order dec + +$include + +val test : unit -> unit + +function test() = { + let _ = [| 0b00, 0b11, 0b111 |]; +}