Skip to content

Commit

Permalink
Fix some mapping issues and add additional tests
Browse files Browse the repository at this point in the history
Make sure variables on both sides of mappings are the same type.
  • Loading branch information
Alasdair committed May 13, 2024
1 parent 3a0e9d3 commit 4c323bd
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 21 deletions.
115 changes: 94 additions & 21 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1605,18 +1605,95 @@ let check_pattern_duplicates env pat =
!subrange_ids;
!ids

(* Test that the output of two calls to check pattern duplicates refer to the same identifiers *)
let same_bindings env lhs rhs =
match Bindings.find_first_opt (fun id -> not (Bindings.mem id rhs)) lhs with
| Some (id, _) ->
typ_error (id_loc id) ("Identifier " ^ string_of_id id ^ " found on left hand side of mapping, but not on right")
| None -> (
match Bindings.find_first_opt (fun id -> not (Bindings.mem id lhs)) rhs with
| Some (id, _) ->
typ_error (id_loc id)
("Identifier " ^ string_of_id id ^ " found on right hand side of mapping, but not on left")
| None -> ()
(* This function checks if a type from one side of a mapping is the
same as a type from the other side of the mapping. Types from one
side of the mapping will have been checked in a different
environment, so we have to take the root environment used to create
both sides environments and carefully transfer unshared constraints
from one child environment to the other, before we can do the
type equality check. *)
let check_mapping_typ_equality ~root_env ~other_env ~env ~other_typ ~typ =
let kopt_arg (KOpt_aux (KOpt_kind (K_aux (k, _), v), _)) =
match k with K_int -> arg_nexp (nvar v) | K_bool -> arg_bool (nc_var v) | K_type -> arg_typ (mk_typ (Typ_var v))
in
let shared_vars = Env.get_typ_vars root_env in
let other_vars = KBindings.filter (fun v _ -> not (KBindings.mem v shared_vars)) (Env.get_typ_vars other_env) in
let substs =
KBindings.mapi
(fun v k ->
let fresh = Env.fresh_kid ~kid:v env in
mk_kopt k fresh
)
other_vars
in
let new_vars = KBindings.fold (fun _ subst set -> KidSet.add (kopt_kid subst) set) substs KidSet.empty in
let env = KBindings.fold (fun _ subst env -> Env.add_typ_var Parse_ast.Unknown subst env) substs env in
let env =
List.fold_left
(fun env nc ->
let nc = KBindings.fold (fun v subst nc -> constraint_subst v (kopt_arg subst) nc) substs nc in
Env.add_constraint nc env
)
env (Env.get_constraints other_env)
in
let other_typ = KBindings.fold (fun v subst typ -> typ_subst v (kopt_arg subst) typ) substs other_typ in
let goals = KidSet.filter (fun k -> KidSet.mem k new_vars) (tyvars_of_typ other_typ) in
let unifiers = unify Parse_ast.Unknown env goals other_typ typ in
let env =
KBindings.fold
(fun v arg env ->
match arg with
| A_aux (A_nexp n, _) -> Env.add_constraint (nc_eq (nvar v) n) env
| A_aux (A_bool nc, _) ->
Env.add_constraint (nc_or (nc_and (nc_var v) nc) (nc_and (nc_not (nc_var v)) (nc_not nc))) env
| A_aux (A_typ _, _) -> env
)
unifiers env
in
typ_equality Parse_ast.Unknown env other_typ typ

(* Test that the output of two calls to check pattern duplicates refer
to the same identifiers, and that those identifiers have the same
types. *)
let same_bindings ~at:l ~env ~left_env ~right_env lhs rhs =
let get_loc = function Pattern_singleton l -> l | Pattern_duplicate (l, _) -> l in
Bindings.iter
(fun id left ->
match Bindings.find_opt id rhs with
| Some right ->
let left_lvar = Env.lookup_id id left_env in
let right_lvar = Env.lookup_id id right_env in
if not (is_unbound left_lvar || is_unbound right_lvar) then (
let left_typ = lvar_typ left_lvar in
let right_typ = lvar_typ right_lvar in
let mapping_type_mismatch err =
typ_raise l
(Err_inner
( Err_other
(Printf.sprintf "'%s' must have the same type on both sides of the mapping" (string_of_id id)),
Hint ("has type " ^ string_of_typ left_typ, get_loc left, get_loc right),
"",
Err_with_hint ("has type " ^ string_of_typ right_typ, err)
)
)
in
try
check_mapping_typ_equality ~root_env:env ~other_env:left_env ~env:right_env ~other_typ:left_typ
~typ:right_typ
with
| Unification_error (_, m) -> mapping_type_mismatch (Err_other m)
| Type_error (_, err) -> mapping_type_mismatch err
)
| None ->
typ_error (get_loc left)
("Identifier " ^ string_of_id id ^ " found on left hand side of mapping, but not on right")
)
lhs;
match Bindings.find_first_opt (fun id -> not (Bindings.mem id lhs)) rhs with
| Some (id, right) ->
typ_error (get_loc right)
("Identifier " ^ string_of_id id ^ " found on right hand side of mapping, but not on left")
| None -> ()

let bitvector_typ_from_range l env n m =
let len =
Expand Down Expand Up @@ -3858,11 +3935,7 @@ and bind_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, uannot)) as mpa
let unifiers = unify l env (tyvars_of_typ ret_typ) ret_typ typ in
let arg_typ' = subst_unifiers unifiers arg_typ in
let quants' = List.fold_left instantiate_quants quants (KBindings.bindings unifiers) in
if not (List.for_all (solve_quant env) quants') then
typ_raise l
(Err_unresolved_quants
(f, quants', Env.get_locals env, Env.get_typ_vars_info env, Env.get_constraints env)
);
let env = Env.add_typquant l (mk_typquant quants') env in
let _ret_typ' = subst_unifiers unifiers ret_typ in
let tpats, env, guards =
try List.fold_left2 bind_tuple_mpat ([], env, []) mpats (untuple arg_typ')
Expand Down Expand Up @@ -4166,14 +4239,14 @@ let check_mapcl env (MCL_aux (cl, (def_annot, _))) typ =
| MCL_bidir (left_mpexp, right_mpexp) -> begin
let left_mpat, _, _ = destruct_mpexp left_mpexp in
let left_dups = check_pattern_duplicates env (pat_of_mpat left_mpat) in
let left_id_env = find_types env left_mpat typ1 in
let left_env = find_types env left_mpat typ1 in
let right_mpat, _, _ = destruct_mpexp right_mpexp in
let right_dups = check_pattern_duplicates env (pat_of_mpat right_mpat) in
same_bindings env left_dups right_dups;
let right_id_env = find_types env right_mpat typ2 in
let right_env = find_types env right_mpat typ2 in
same_bindings ~at:def_annot.loc ~env ~left_env ~right_env left_dups right_dups;

let typed_left_mpexp = check_mpexp right_id_env env left_mpexp typ1 in
let typed_right_mpexp = check_mpexp left_id_env env right_mpexp typ2 in
let typed_left_mpexp = check_mpexp right_env env left_mpexp typ1 in
let typed_right_mpexp = check_mpexp left_env env right_mpexp typ2 in
MCL_aux (MCL_bidir (typed_left_mpexp, typed_right_mpexp), (def_annot, mk_expected_tannot env typ (Some typ)))
end
| MCL_forwards pexp -> begin
Expand Down
1 change: 1 addition & 0 deletions src/lib/type_env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ val with_global_scope : t -> t * module_state

val restore_scope : module_state -> t -> t

val fresh_kid : ?kid:kid -> env -> kid
val freshen_bind : t -> typquant * typ -> typquant * typ

val get_default_order : t -> order
Expand Down
3 changes: 3 additions & 0 deletions src/lib/type_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ let message_of_type_error type_error =
in
let rec to_message = function
| Err_hint h -> (Seq [], Some h)
| Err_with_hint (h, err) ->
let msg, _ = to_message err in
(msg, Some h)
| Err_inner (err, l', prefix, err') ->
let prefix = if prefix = "" then "" else Util.(prefix ^ " " |> yellow |> clear) in
let msg, hint = to_message err in
Expand Down
1 change: 1 addition & 0 deletions src/lib/type_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type type_error =
(** Takes the name of the identifier, the set of local bindings, and
whether we have a function of the same name in scope. *)
| Err_hint of string (** A short error that only appears attached to a location *)
| Err_with_hint of string * type_error

exception Type_error of Parse_ast.l * type_error

Expand Down
1 change: 1 addition & 0 deletions src/lib/type_internal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ type type_error =
| Err_no_function_type of { id : id; functions : (typquant * typ) Bindings.t }
| Err_unbound_id of { id : id; locals : (mut * typ) Bindings.t; have_function : bool }
| Err_hint of string
| Err_with_hint of string * type_error

let err_because (error1, l, error2) = Err_inner (error1, l, "Caused by", error2)

Expand Down
13 changes: 13 additions & 0 deletions test/typecheck/fail/mapping_length_mismatch.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Type error:
fail/mapping_length_mismatch.sail:9.21-48:
9 |mapping clause foo = x : bits(1) <-> x : bits(2)
 | ^-------------------------^
 | 'x' must have the same type on both sides of the mapping
 |
 | fail/mapping_length_mismatch.sail:9.21-22:
 | 9 |mapping clause foo = x : bits(1) <-> x : bits(2)
 |  | ^ has type bits(1)
 | fail/mapping_length_mismatch.sail:9.37-38:
 | 9 |mapping clause foo = x : bits(1) <-> x : bits(2)
 |  | ^ has type bits(2)
 |  | Integer expressions 1 and 2 are not equal
9 changes: 9 additions & 0 deletions test/typecheck/fail/mapping_length_mismatch.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
default Order dec

$include <prelude.sail>

val foo : bits(1) <-> bits(2)

scattered mapping foo

mapping clause foo = x : bits(1) <-> x : bits(2)
13 changes: 13 additions & 0 deletions test/typecheck/fail/poly_ab_mapping.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Type error:
fail/poly_ab_mapping.sail:9.21-38:
9 |mapping clause foo = x : 'a <-> x : 'b
 | ^---------------^
 | 'x' must have the same type on both sides of the mapping
 |
 | fail/poly_ab_mapping.sail:9.21-22:
 | 9 |mapping clause foo = x : 'a <-> x : 'b
 |  | ^ has type 'a
 | fail/poly_ab_mapping.sail:9.32-33:
 | 9 |mapping clause foo = x : 'a <-> x : 'b
 |  | ^ has type 'b
 |  | Type mismatch between 'a and 'b
9 changes: 9 additions & 0 deletions test/typecheck/fail/poly_ab_mapping.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
default Order dec

$include <prelude.sail>

val foo : forall ('a 'b : Type). 'a <-> 'b

scattered mapping foo

mapping clause foo = x : 'a <-> x : 'b
13 changes: 13 additions & 0 deletions test/typecheck/pass/bool_mapping.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
default Order dec

$include <prelude.sail>

union instr = {
Foo : bool
}

val bar : instr <-> bool

scattered mapping bar

mapping clause bar = Foo(b) <-> b
13 changes: 13 additions & 0 deletions test/typecheck/pass/bool_mapping2.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
default Order dec

$include <prelude.sail>

union instr = {
Foo : (bool, unit)
}

val bar : instr <-> bool

scattered mapping bar

mapping clause bar = Foo(b, ()) <-> b

0 comments on commit 4c323bd

Please sign in to comment.