Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various mapping fixes #540

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 172 additions & 58 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1605,18 +1605,95 @@ let check_pattern_duplicates env pat =
!subrange_ids;
!ids

(* Test that the output of two calls to check pattern duplicates refer to the same identifiers *)
let same_bindings env lhs rhs =
match Bindings.find_first_opt (fun id -> not (Bindings.mem id rhs)) lhs with
| Some (id, _) ->
typ_error (id_loc id) ("Identifier " ^ string_of_id id ^ " found on left hand side of mapping, but not on right")
| None -> (
match Bindings.find_first_opt (fun id -> not (Bindings.mem id lhs)) rhs with
| Some (id, _) ->
typ_error (id_loc id)
("Identifier " ^ string_of_id id ^ " found on right hand side of mapping, but not on left")
| None -> ()
(* This function checks if a type from one side of a mapping is the
same as a type from the other side of the mapping. Types from one
side of the mapping will have been checked in a different
environment, so we have to take the root environment used to create
both sides environments and carefully transfer unshared constraints
from one child environment to the other, before we can do the
type equality check. *)
let check_mapping_typ_equality ~root_env ~other_env ~env ~other_typ ~typ =
let kopt_arg (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) =
match k with K_int -> arg_nexp (nvar v) | K_bool -> arg_bool (nc_var v) | K_type -> arg_typ (mk_typ (Typ_var v))
in
let shared_vars = Env.get_typ_vars root_env in
let other_vars = KBindings.filter (fun v _ -> not (KBindings.mem v shared_vars)) (Env.get_typ_vars other_env) in
let substs =
KBindings.mapi
(fun v k ->
let fresh = Env.fresh_kid ~kid:v env in
mk_kopt k fresh
)
other_vars
in
let new_vars = KBindings.fold (fun _ subst set -> KidSet.add (kopt_kid subst) set) substs KidSet.empty in
let env = KBindings.fold (fun _ subst env -> Env.add_typ_var Parse_ast.Unknown subst env) substs env in
let env =
List.fold_left
(fun env nc ->
let nc = KBindings.fold (fun v subst nc -> constraint_subst v (kopt_arg subst) nc) substs nc in
Env.add_constraint nc env
)
env (Env.get_constraints other_env)
in
let other_typ = KBindings.fold (fun v subst typ -> typ_subst v (kopt_arg subst) typ) substs other_typ in
let goals = KidSet.filter (fun k -> KidSet.mem k new_vars) (tyvars_of_typ other_typ) in
let unifiers = unify Parse_ast.Unknown env goals other_typ typ in
let env =
KBindings.fold
(fun v arg env ->
match arg with
| A_aux (A_nexp n, _) -> Env.add_constraint (nc_eq (nvar v) n) env
| A_aux (A_bool nc, _) ->
Env.add_constraint (nc_or (nc_and (nc_var v) nc) (nc_and (nc_not (nc_var v)) (nc_not nc))) env
| A_aux (A_typ _, _) -> env
)
unifiers env
in
typ_equality Parse_ast.Unknown env other_typ typ

(* Test that the output of two calls to check pattern duplicates refer
to the same identifiers, and that those identifiers have the same
types. *)
let same_bindings ~at:l ~env ~left_env ~right_env lhs rhs =
let get_loc = function Pattern_singleton l -> l | Pattern_duplicate (l, _) -> l in
Bindings.iter
(fun id left ->
match Bindings.find_opt id rhs with
| Some right ->
let left_lvar = Env.lookup_id id left_env in
let right_lvar = Env.lookup_id id right_env in
if not (is_unbound left_lvar || is_unbound right_lvar) then (
let left_typ = lvar_typ left_lvar in
let right_typ = lvar_typ right_lvar in
let mapping_type_mismatch err =
typ_raise l
(Err_inner
( Err_other
(Printf.sprintf "'%s' must have the same type on both sides of the mapping" (string_of_id id)),
Hint ("has type " ^ string_of_typ left_typ, get_loc left, get_loc right),
"",
Err_with_hint ("has type " ^ string_of_typ right_typ, err)
)
)
in
try
check_mapping_typ_equality ~root_env:env ~other_env:left_env ~env:right_env ~other_typ:left_typ
~typ:right_typ
with
| Unification_error (_, m) -> mapping_type_mismatch (Err_other m)
| Type_error (_, err) -> mapping_type_mismatch err
)
| None ->
typ_error (get_loc left)
("Identifier " ^ string_of_id id ^ " found on left hand side of mapping, but not on right")
)
lhs;
match Bindings.find_first_opt (fun id -> not (Bindings.mem id lhs)) rhs with
| Some (id, right) ->
typ_error (get_loc right)
("Identifier " ^ string_of_id id ^ " found on right hand side of mapping, but not on left")
| None -> ()

let bitvector_typ_from_range l env n m =
let len =
Expand Down Expand Up @@ -1933,11 +2010,21 @@ let tc_assume nc (E_aux (aux, annot)) = E_aux (E_internal_assume (nc, E_aux (aux
type ('a, 'b) pattern_functions = {
infer : Env.t -> 'a -> 'b * Env.t * uannot exp list;
bind : Env.t -> 'a -> typ -> 'b * Env.t * uannot exp list;
strip : 'b -> 'a;
typ_of : 'b -> typ;
get_loc : 'a -> l;
get_loc_typed : 'b -> l;
}

type ('a, 'b) vector_concat_elem = VC_elem_ok of 'a | VC_elem_error of 'b * exn | VC_elem_unknown of 'a

let unwrap_vector_concat_elem ~at:l = function
| VC_elem_ok x -> x
| VC_elem_unknown x -> x
| VC_elem_error _ -> Reporting.unreachable l __POS__ "Tried to unwrap VC_elem_error"

let vector_concat_elem_is_ok = function VC_elem_ok _ -> true | _ -> false

module PC_config = struct
type t = tannot
let typ_of_t = typ_of_tannot
Expand Down Expand Up @@ -2748,40 +2835,50 @@ and bind_vector_concat_generic :
let typ_opt =
Option.bind typ_opt (fun typ ->
match destruct_any_vector_typ l env typ with
| Destruct_vector (len, elem_typ) -> Option.map (fun l -> (l, Some elem_typ)) (solve_unique env len)
| Destruct_bitvector len -> Option.map (fun l -> (l, None)) (solve_unique env len)
| Destruct_vector (len, elem_typ) -> Option.map (fun len -> (len, Some elem_typ)) (solve_unique env len)
| Destruct_bitvector len -> Option.map (fun len -> (len, None)) (solve_unique env len)
)
in

(* Try to infer any subpatterns, skipping those we cannot infer *)
let fold_pats (pats, env, guards) pat =
let wrap_some (x, y, z) = (Ok x, y, z) in
let wrap_ok (x, y, z) = (VC_elem_ok x, y, z) in
let inferred_pat, env, guards' =
if Option.is_none typ_opt then wrap_some (funcs.infer env pat)
else (try wrap_some (funcs.infer env pat) with Type_error _ as exn -> (Error (pat, exn), env, []))
if Option.is_none typ_opt then wrap_ok (funcs.infer env pat)
else (try wrap_ok (funcs.infer env pat) with Type_error _ as exn -> (VC_elem_error (pat, exn), env, []))
in
(inferred_pat :: pats, env, guards' @ guards)
in
let inferred_pats, env, guards = List.fold_left fold_pats ([], env, []) (pat :: pats) in
let inferred_pats = List.rev inferred_pats in

(* If we are checking a mapping we can have unknown types, in this
case we can't continue, so just give the entire vector concat an
unknown type. *)
let have_unknown =
allow_unknown
&& List.exists (function Ok pat -> is_unknown_type (funcs.typ_of pat) | Error _ -> false) inferred_pats
case we can't continue if there is more than a single unknown, so
just give the entire vector concat an unknown type. *)
let inferred_pats =
if allow_unknown then
List.map
(function
| VC_elem_ok pat -> if is_unknown_type (funcs.typ_of pat) then VC_elem_unknown pat else VC_elem_ok pat
| err -> err
)
inferred_pats
else inferred_pats
in
if have_unknown && List.for_all Result.is_ok inferred_pats then
(annotate (List.map Result.get_ok inferred_pats) unknown_typ, env, guards)
let num_unknowns = List.length (List.filter (function VC_elem_unknown _ -> true | _ -> false) inferred_pats) in
if num_unknowns > 1 || (num_unknowns > 0 && Option.is_none typ_opt) then (
match Util.option_first (function VC_elem_error (_, exn) -> Some exn | _ -> None) inferred_pats with
| Some exn -> raise exn
| None -> (annotate (List.map (unwrap_vector_concat_elem ~at:l) inferred_pats) unknown_typ, env, guards)
)
else (
(* Will be none if the subpatterns are bitvectors *)
let elem_typ =
match typ_opt with
| Some (_, elem_typ) -> elem_typ
| None -> (
match List.find_opt Result.is_ok inferred_pats with
| Some (Ok pat) -> begin
match List.find_opt vector_concat_elem_is_ok inferred_pats with
| Some (VC_elem_ok pat) -> begin
match destruct_any_vector_typ l env (funcs.typ_of pat) with
| Destruct_vector (_, t) -> Some t
| Destruct_bitvector _ -> None
Expand All @@ -2790,32 +2887,43 @@ and bind_vector_concat_generic :
)
in

(* We can handle a single None in inferred_pats from something like
(* We can handle a single uninferred element (treating unknown
elements as uninferred) in inferred_pats from something like
0b00 @ _ @ 0b00, because we know the wildcard will be bits('n - 4)
where 'n is the total length of the pattern. *)
let before_uninferred, rest = Util.take_drop Result.is_ok inferred_pats in
let before_uninferred = List.map Result.get_ok before_uninferred in
let before_uninferred, rest = Util.take_drop vector_concat_elem_is_ok inferred_pats in
let before_uninferred = List.map (unwrap_vector_concat_elem ~at:l) before_uninferred in
let uninferred, after_uninferred =
(* When we encounter an unknown or uninferred pattern, check the rest for a second 'bad'
pattern that is also unknown or uninferred *)
let check_rest ~first_bad rest =
let msg =
"Cannot infer width here, as there are multiple subpatterns with unclear width in vector concatenation \
pattern"
in
match List.find_opt (fun elem -> not (vector_concat_elem_is_ok elem)) rest with
| Some (VC_elem_error (second_bad, _)) ->
typ_raise (funcs.get_loc second_bad)
(err_because (Err_other msg, first_bad, Err_other "A previous subpattern is here"))
| Some (VC_elem_unknown second_bad) ->
typ_raise
(funcs.get_loc (funcs.strip second_bad))
(err_because (Err_other msg, first_bad, Err_other "A previous subpattern is here"))
| _ -> ()
in
match rest with
| Error (first_uninferred, exn) :: rest ->
begin
match List.find_opt Result.is_error rest with
| Some (Error (second_uninferred, _)) ->
let msg =
"Cannot infer width here, as there are multiple subpatterns with unclear width in vector \
concatenation pattern"
in
typ_raise (funcs.get_loc second_uninferred)
(err_because
(Err_other msg, funcs.get_loc first_uninferred, Err_other "A previous subpattern is here")
)
| _ -> ()
end;
begin
match typ_opt with
| Some (total_len, _) -> (Some (total_len, first_uninferred), List.map Result.get_ok rest)
| None -> raise exn
end
| VC_elem_error (first_uninferred, exn) :: rest -> begin
check_rest ~first_bad:(funcs.get_loc first_uninferred) rest;
match typ_opt with
| Some (total_len, _) -> (Some (total_len, first_uninferred), List.map (unwrap_vector_concat_elem ~at:l) rest)
| None -> raise exn
end
| VC_elem_unknown first_unknown :: rest ->
let first_unknown = funcs.strip first_unknown in
check_rest ~first_bad:(funcs.get_loc first_unknown) rest;
(* If we have unknown elems, we check above that the typ_opt is Some _ *)
let total_len = fst (Option.get typ_opt) in
(Some (total_len, first_unknown), List.map (unwrap_vector_concat_elem ~at:l) rest)
| _ -> (None, [])
in

Expand Down Expand Up @@ -2876,7 +2984,16 @@ and bind_vector_concat_generic :

and bind_vector_concat_pat l env uannot pat pats typ_opt =
let annot_vcp pats typ = P_aux (P_vector_concat pats, (l, mk_tannot ~uannot env typ)) in
let funcs = { infer = infer_pat; bind = bind_pat; typ_of = typ_of_pat; get_loc = pat_loc; get_loc_typed = pat_loc } in
let funcs =
{
infer = infer_pat;
bind = bind_pat;
strip = strip_pat;
typ_of = typ_of_pat;
get_loc = pat_loc;
get_loc_typed = pat_loc;
}
in
bind_vector_concat_generic funcs annot_vcp l false env pat pats typ_opt

and bind_vector_concat_mpat l allow_unknown other_env env uannot mpat mpats typ_opt =
Expand All @@ -2885,6 +3002,7 @@ and bind_vector_concat_mpat l allow_unknown other_env env uannot mpat mpats typ_
{
infer = infer_mpat allow_unknown other_env;
bind = bind_mpat allow_unknown other_env;
strip = strip_mpat;
typ_of = typ_of_mpat;
get_loc = mpat_loc;
get_loc_typed = mpat_loc;
Expand Down Expand Up @@ -3817,11 +3935,7 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, uannot)) as mpa
let unifiers = unify l env (tyvars_of_typ ret_typ) ret_typ typ in
let arg_typ' = subst_unifiers unifiers arg_typ in
let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if not (List.for_all (solve_quant env) quants') then
typ_raise l
(Err_unresolved_quants
(f, quants', Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env)
);
let env = Env.add_typquant l (mk_typquant quants') env in
let _ret_typ' = subst_unifiers unifiers ret_typ in
let tpats, env, guards =
try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ')
Expand Down Expand Up @@ -4125,14 +4239,14 @@ let check_mapcl env (MCL_aux (cl, (def_annot, _))) typ =
| MCL_bidir (left_mpexp, right_mpexp) -> begin
let left_mpat, _, _ = destruct_mpexp left_mpexp in
let left_dups = check_pattern_duplicates env (pat_of_mpat left_mpat) in
let left_id_env = find_types env left_mpat typ1 in
let left_env = find_types env left_mpat typ1 in
let right_mpat, _, _ = destruct_mpexp right_mpexp in
let right_dups = check_pattern_duplicates env (pat_of_mpat right_mpat) in
same_bindings env left_dups right_dups;
let right_id_env = find_types env right_mpat typ2 in
let right_env = find_types env right_mpat typ2 in
same_bindings ~at:def_annot.loc ~env ~left_env ~right_env left_dups right_dups;

let typed_left_mpexp = check_mpexp right_id_env env left_mpexp typ1 in
let typed_right_mpexp = check_mpexp left_id_env env right_mpexp typ2 in
let typed_left_mpexp = check_mpexp right_env env left_mpexp typ1 in
let typed_right_mpexp = check_mpexp left_env env right_mpexp typ2 in
MCL_aux (MCL_bidir (typed_left_mpexp, typed_right_mpexp), (def_annot, mk_expected_tannot env typ (Some typ)))
end
| MCL_forwards pexp -> begin
Expand Down
1 change: 1 addition & 0 deletions src/lib/type_env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ val with_global_scope : t -> t * module_state

val restore_scope : module_state -> t -> t

val fresh_kid : ?kid:kid -> env -> kid
val freshen_bind : t -> typquant * typ -> typquant * typ

val get_default_order : t -> order
Expand Down
3 changes: 3 additions & 0 deletions src/lib/type_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ let message_of_type_error type_error =
in
let rec to_message = function
| Err_hint h -> (Seq [], Some h)
| Err_with_hint (h, err) ->
let msg, _ = to_message err in
(msg, Some h)
| Err_inner (err, l', prefix, err') ->
let prefix = if prefix = "" then "" else Util.(prefix ^ " " |> yellow |> clear) in
let msg, hint = to_message err in
Expand Down
1 change: 1 addition & 0 deletions src/lib/type_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type type_error =
(** Takes the name of the identifier, the set of local bindings, and
whether we have a function of the same name in scope. *)
| Err_hint of string (** A short error that only appears attached to a location *)
| Err_with_hint of string * type_error

exception Type_error of Parse_ast.l * type_error

Expand Down
1 change: 1 addition & 0 deletions src/lib/type_internal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ type type_error =
| Err_no_function_type of { id : id; functions : (typquant * typ) Bindings.t }
| Err_unbound_id of { id : id; locals : (mut * typ) Bindings.t; have_function : bool }
| Err_hint of string
| Err_with_hint of string * type_error

let err_because (error1, l, error2) = Err_inner (error1, l, "Caused by", error2)

Expand Down
4 changes: 4 additions & 0 deletions test/c/encdec_subrange.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
x = 0xFFFC03FF
x = 0xF0000000
x = 0xFAB00000
ok
31 changes: 31 additions & 0 deletions test/c/encdec_subrange.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
default Order dec

$include <prelude.sail>

scattered union instr

val encdec : instr <-> bits(32)

scattered mapping encdec

union clause instr = X : bits(21)

mapping clause encdec = X(imm20 @ 0b0)
<-> 0xF @ imm20[9..0] @ 0x00 @ imm20[19..10]

val main : unit -> unit

function main() = {
let x = encdec(X(0xFFFFF @ 0b0));
print_bits("x = ", x);
let x = encdec(X(0x00000 @ 0b0));
print_bits("x = ", x);
let x = encdec(X(0b00 @ 0x00AB @ 0b00 @ 0b0));
print_bits("x = ", x);

match encdec(0xFFFC03FF) {
X(y) => assert(y == 0xFFFFF @ 0b0),
};

print_endline("ok");
}
Loading
Loading