diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index 9b9a364ca..94c649778 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -1933,11 +1933,21 @@ let tc_assume nc (E_aux (aux, annot)) = E_aux (E_internal_assume (nc, E_aux (aux type ('a, 'b) pattern_functions = { infer : Env.t -> 'a -> 'b * Env.t * uannot exp list; bind : Env.t -> 'a -> typ -> 'b * Env.t * uannot exp list; + strip : 'b -> 'a; typ_of : 'b -> typ; get_loc : 'a -> l; get_loc_typed : 'b -> l; } +type ('a, 'b) vector_concat_elem = VC_elem_ok of 'a | VC_elem_error of 'b * exn | VC_elem_unknown of 'a + +let unwrap_vector_concat_elem ~at:l = function + | VC_elem_ok x -> x + | VC_elem_unknown x -> x + | VC_elem_error _ -> Reporting.unreachable l __POS__ "Tried to unwrap VC_elem_error" + +let vector_concat_elem_is_ok = function VC_elem_ok _ -> true | _ -> false + module PC_config = struct type t = tannot let typ_of_t = typ_of_tannot @@ -2748,17 +2758,17 @@ and bind_vector_concat_generic : let typ_opt = Option.bind typ_opt (fun typ -> match destruct_any_vector_typ l env typ with - | Destruct_vector (len, elem_typ) -> Option.map (fun l -> (l, Some elem_typ)) (solve_unique env len) - | Destruct_bitvector len -> Option.map (fun l -> (l, None)) (solve_unique env len) + | Destruct_vector (len, elem_typ) -> Option.map (fun len -> (len, Some elem_typ)) (solve_unique env len) + | Destruct_bitvector len -> Option.map (fun len -> (len, None)) (solve_unique env len) ) in (* Try to infer any subpatterns, skipping those we cannot infer *) let fold_pats (pats, env, guards) pat = - let wrap_some (x, y, z) = (Ok x, y, z) in + let wrap_ok (x, y, z) = (VC_elem_ok x, y, z) in let inferred_pat, env, guards' = - if Option.is_none typ_opt then wrap_some (funcs.infer env pat) - else (try wrap_some (funcs.infer env pat) with Type_error _ as exn -> (Error (pat, exn), env, [])) + if Option.is_none typ_opt then wrap_ok (funcs.infer env pat) + else (try wrap_ok (funcs.infer env pat) with Type_error _ as exn -> (VC_elem_error (pat, exn), env, [])) in (inferred_pat :: pats, env, guards' @ guards) in @@ -2766,22 +2776,32 @@ and bind_vector_concat_generic : let inferred_pats = List.rev inferred_pats in (* If we are checking a mapping we can have unknown types, in this - case we can't continue, so just give the entire vector concat an - unknown type. *) - let have_unknown = - allow_unknown - && List.exists (function Ok pat -> is_unknown_type (funcs.typ_of pat) | Error _ -> false) inferred_pats + case we can't continue if there is more than a single unknown, so + just give the entire vector concat an unknown type. *) + let inferred_pats = + if allow_unknown then + List.map + (function + | VC_elem_ok pat -> if is_unknown_type (funcs.typ_of pat) then VC_elem_unknown pat else VC_elem_ok pat + | err -> err + ) + inferred_pats + else inferred_pats in - if have_unknown && List.for_all Result.is_ok inferred_pats then - (annotate (List.map Result.get_ok inferred_pats) unknown_typ, env, guards) + let num_unknowns = List.length (List.filter (function VC_elem_unknown _ -> true | _ -> false) inferred_pats) in + if num_unknowns > 1 || (num_unknowns > 0 && Option.is_none typ_opt) then ( + match Util.option_first (function VC_elem_error (_, exn) -> Some exn | _ -> None) inferred_pats with + | Some exn -> raise exn + | None -> (annotate (List.map (unwrap_vector_concat_elem ~at:l) inferred_pats) unknown_typ, env, guards) + ) else ( (* Will be none if the subpatterns are bitvectors *) let elem_typ = match typ_opt with | Some (_, elem_typ) -> elem_typ | None -> ( - match List.find_opt Result.is_ok inferred_pats with - | Some (Ok pat) -> begin + match List.find_opt vector_concat_elem_is_ok inferred_pats with + | Some (VC_elem_ok pat) -> begin match destruct_any_vector_typ l env (funcs.typ_of pat) with | Destruct_vector (_, t) -> Some t | Destruct_bitvector _ -> None @@ -2790,32 +2810,43 @@ and bind_vector_concat_generic : ) in - (* We can handle a single None in inferred_pats from something like + (* We can handle a single uninferred element (treating unknown + elements as uninferred) in inferred_pats from something like 0b00 @ _ @ 0b00, because we know the wildcard will be bits('n - 4) where 'n is the total length of the pattern. *) - let before_uninferred, rest = Util.take_drop Result.is_ok inferred_pats in - let before_uninferred = List.map Result.get_ok before_uninferred in + let before_uninferred, rest = Util.take_drop vector_concat_elem_is_ok inferred_pats in + let before_uninferred = List.map (unwrap_vector_concat_elem ~at:l) before_uninferred in let uninferred, after_uninferred = + (* When we encounter an unknown or uninferred pattern, check the rest for a second 'bad' + pattern that is also unknown or uninferred *) + let check_rest ~first_bad rest = + let msg = + "Cannot infer width here, as there are multiple subpatterns with unclear width in vector concatenation \ + pattern" + in + match List.find_opt (fun elem -> not (vector_concat_elem_is_ok elem)) rest with + | Some (VC_elem_error (second_bad, _)) -> + typ_raise (funcs.get_loc second_bad) + (err_because (Err_other msg, first_bad, Err_other "A previous subpattern is here")) + | Some (VC_elem_unknown second_bad) -> + typ_raise + (funcs.get_loc (funcs.strip second_bad)) + (err_because (Err_other msg, first_bad, Err_other "A previous subpattern is here")) + | _ -> () + in match rest with - | Error (first_uninferred, exn) :: rest -> - begin - match List.find_opt Result.is_error rest with - | Some (Error (second_uninferred, _)) -> - let msg = - "Cannot infer width here, as there are multiple subpatterns with unclear width in vector \ - concatenation pattern" - in - typ_raise (funcs.get_loc second_uninferred) - (err_because - (Err_other msg, funcs.get_loc first_uninferred, Err_other "A previous subpattern is here") - ) - | _ -> () - end; - begin - match typ_opt with - | Some (total_len, _) -> (Some (total_len, first_uninferred), List.map Result.get_ok rest) - | None -> raise exn - end + | VC_elem_error (first_uninferred, exn) :: rest -> begin + check_rest ~first_bad:(funcs.get_loc first_uninferred) rest; + match typ_opt with + | Some (total_len, _) -> (Some (total_len, first_uninferred), List.map (unwrap_vector_concat_elem ~at:l) rest) + | None -> raise exn + end + | VC_elem_unknown first_unknown :: rest -> + let first_unknown = funcs.strip first_unknown in + check_rest ~first_bad:(funcs.get_loc first_unknown) rest; + (* If we have unknown elems, we check above that the typ_opt is Some _ *) + let total_len = fst (Option.get typ_opt) in + (Some (total_len, first_unknown), List.map (unwrap_vector_concat_elem ~at:l) rest) | _ -> (None, []) in @@ -2876,7 +2907,16 @@ and bind_vector_concat_generic : and bind_vector_concat_pat l env uannot pat pats typ_opt = let annot_vcp pats typ = P_aux (P_vector_concat pats, (l, mk_tannot ~uannot env typ)) in - let funcs = { infer = infer_pat; bind = bind_pat; typ_of = typ_of_pat; get_loc = pat_loc; get_loc_typed = pat_loc } in + let funcs = + { + infer = infer_pat; + bind = bind_pat; + strip = strip_pat; + typ_of = typ_of_pat; + get_loc = pat_loc; + get_loc_typed = pat_loc; + } + in bind_vector_concat_generic funcs annot_vcp l false env pat pats typ_opt and bind_vector_concat_mpat l allow_unknown other_env env uannot mpat mpats typ_opt = @@ -2885,6 +2925,7 @@ and bind_vector_concat_mpat l allow_unknown other_env env uannot mpat mpats typ_ { infer = infer_mpat allow_unknown other_env; bind = bind_mpat allow_unknown other_env; + strip = strip_mpat; typ_of = typ_of_mpat; get_loc = mpat_loc; get_loc_typed = mpat_loc; diff --git a/test/c/encdec_subrange.expect b/test/c/encdec_subrange.expect new file mode 100644 index 000000000..09282999d --- /dev/null +++ b/test/c/encdec_subrange.expect @@ -0,0 +1,4 @@ +x = 0xFFFC03FF +x = 0xF0000000 +x = 0xFAB00000 +ok diff --git a/test/c/encdec_subrange.sail b/test/c/encdec_subrange.sail new file mode 100644 index 000000000..37dae0950 --- /dev/null +++ b/test/c/encdec_subrange.sail @@ -0,0 +1,31 @@ +default Order dec + +$include + +scattered union instr + +val encdec : instr <-> bits(32) + +scattered mapping encdec + +union clause instr = X : bits(21) + +mapping clause encdec = X(imm20 @ 0b0) + <-> 0xF @ imm20[9..0] @ 0x00 @ imm20[19..10] + +val main : unit -> unit + +function main() = { + let x = encdec(X(0xFFFFF @ 0b0)); + print_bits("x = ", x); + let x = encdec(X(0x00000 @ 0b0)); + print_bits("x = ", x); + let x = encdec(X(0b00 @ 0x00AB @ 0b00 @ 0b0)); + print_bits("x = ", x); + + match encdec(0xFFFC03FF) { + X(y) => assert(y == 0xFFFFF @ 0b0), + }; + + print_endline("ok"); +} \ No newline at end of file diff --git a/test/typecheck/fail/encdec_unknown_and_error.expect b/test/typecheck/fail/encdec_unknown_and_error.expect new file mode 100644 index 000000000..a0d5f336e --- /dev/null +++ b/test/typecheck/fail/encdec_unknown_and_error.expect @@ -0,0 +1,10 @@ +Type error: +fail/encdec_unknown_and_error.sail:13.30-31: +13 |mapping clause encdec = X(y @ x @ Z() : bits(1)) +  | ^ +  | Cannot infer width here, as there are multiple subpatterns with unclear width in vector concatenation pattern +  | +  | Caused by fail/encdec_unknown_and_error.sail:13.26-27: +  | 13 |mapping clause encdec = X(y @ x @ Z() : bits(1)) +  |  | ^ +  |  | A previous subpattern is here diff --git a/test/typecheck/fail/encdec_unknown_and_error.sail b/test/typecheck/fail/encdec_unknown_and_error.sail new file mode 100644 index 000000000..cdc139a89 --- /dev/null +++ b/test/typecheck/fail/encdec_unknown_and_error.sail @@ -0,0 +1,14 @@ +default Order dec + +$include + +scattered union instr + +val encdec : instr <-> bits(32) + +scattered mapping encdec + +union clause instr = X : bits(21) + +mapping clause encdec = X(y @ x @ Z() : bits(1)) + <-> 0xF @ x @ 0x00 @ y