Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter the set of overloads before checking them #526

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 170 additions & 27 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ let wf_binding l env (typq, typ) =
let wf_typschm env (TypSchm_aux (TypSchm_ts (typq, typ), l)) = wf_binding l env (typq, typ)

let dvector_typ _env n typ = vector_typ n typ
let bits_typ _env n = bitvector_typ n

let add_existential l kopts nc env =
let env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts in
Expand Down Expand Up @@ -1341,7 +1340,7 @@ let rec get_implicits typs =
| _ :: typs -> get_implicits typs
| [] -> []

let infer_lit env (L_aux (lit_aux, l)) =
let infer_lit (L_aux (lit_aux, l)) =
match lit_aux with
| L_unit -> unit_typ
| L_zero -> bit_typ
Expand All @@ -1352,8 +1351,8 @@ let infer_lit env (L_aux (lit_aux, l)) =
| L_string _ when !Type_env.opt_string_literal_type -> string_literal_typ
| L_string _ -> string_typ
| L_real _ -> real_typ
| L_bin str -> bits_typ env (nint (String.length str))
| L_hex str -> bits_typ env (nint (String.length str * 4))
| L_bin str -> bitvector_typ (nint (String.length str))
| L_hex str -> bitvector_typ (nint (String.length str * 4))
| L_undef -> typ_error l "Cannot infer the type of undefined"

let instantiate_simple_equations =
Expand Down Expand Up @@ -1654,6 +1653,143 @@ let bind_pattern_vector_subranges (P_aux (_, (l, _)) as pat) env =
)
id_ranges env

let unbound_id_error ~at:l env v =
match Bindings.find_opt v (Env.get_val_specs env) with
| Some _ -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = true })
| None -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = false })

type overload_leaf_type = OL_app of id | OL_id of id | OL_unknown

type 'a overload_tree =
| OT_overloads of id * id list * 'a overload_tree list * 'a annot
| OT_leaf of 'a exp * overload_leaf_type

let overload_leaf_type (Typ_aux (aux, _)) =
match aux with Typ_id id -> OL_id id | Typ_app (id, _) -> OL_app id | _ -> OL_unknown

let rec build_overload_tree env f xs annot =
let overloads = Env.get_overloads_recursive f env in
OT_overloads (f, overloads, List.map (build_overload_tree_arg env) xs, annot)

and build_overload_tree_arg env (E_aux (aux, annot) as exp) =
match aux with
| E_app_infix (x, op, y) when Env.is_overload (deinfix op) env -> build_overload_tree env (deinfix op) [x; y] annot
| E_app (f, xs) when Env.is_overload f env -> build_overload_tree env f xs annot
| E_id v -> begin
match Env.lookup_id v env with
| Local (_, typ) | Enum typ | Register typ -> OT_leaf (exp, overload_leaf_type (Env.expand_synonyms env typ))
| Unbound _ -> unbound_id_error ~at:(fst annot) env v
end
| E_lit lit -> begin
match lit with
| L_aux (L_undef, _) -> OT_leaf (exp, OL_unknown)
| _ -> OT_leaf (exp, overload_leaf_type (infer_lit lit))
end
| _ -> OT_leaf (exp, OL_unknown)

let string_of_overload_leaf = function
| OL_app id -> ": " ^ string_of_id id ^ "(...)"
| OL_id id -> ": " ^ string_of_id id
| OL_unknown -> ": ?"

let rec filter_overload_tree env =
let atom_like id =
let s = string_of_id id in
s = "atom" || s = "range" || s = "implicit"
in
let int_or_nat id =
let s = string_of_id id in
s = "int" || s = "nat"
in
let both_strings s1 s2 = (s1 = "string" && s2 = "string_literal") || (s1 = "string_literal" && s2 = "string") in
let plausible x y =
match (x, y) with
| OL_app id1, OL_id id2 | OL_id id2, OL_app id1 ->
(atom_like id1 && int_or_nat id2) || (string_of_id id1 = "atom_bool" && string_of_id id2 = "bool")
| OL_id id1, OL_id id2 ->
Id.compare id1 id2 = 0
|| both_strings (string_of_id id1) (string_of_id id2)
|| (int_or_nat id1 && int_or_nat id2)
| OL_app id1, OL_app id2 -> Id.compare id1 id2 = 0 || (atom_like id1 && atom_like id2)
| OL_unknown, _ -> true
| _, OL_unknown -> true
in
let is_implicit = function OL_app id -> string_of_id id = "implicit" | _ -> false in
let is_unit = function OL_id id -> string_of_id id = "unit" | _ -> false in
function
| OT_overloads (f, overloads, args, annot) ->
let args = List.map (filter_overload_tree env) args in
let overload_info =
List.map
(fun overload ->
let unwrap_overload_type = function
| Typ_aux (Typ_fn (arg_typs, ret_typ), _) ->
[(overload, List.map overload_leaf_type arg_typs, overload_leaf_type ret_typ)]
| Typ_aux (Typ_bidir (lhs_typ, rhs_typ), _) ->
let lhs = overload_leaf_type lhs_typ in
let rhs = overload_leaf_type rhs_typ in
[(overload, [lhs], rhs); (overload, [rhs], lhs)]
| _ ->
typ_error (fst annot) ("Overload " ^ string_of_id overload ^ " must have a function or mapping type")
in
unwrap_overload_type (snd (Env.get_val_spec overload env))
)
overloads
|> List.concat
in
let plausible_overloads =
List.filter_map
(fun (overload, param_lts, ret_lt) ->
(* If the overload and usage arity don't match, immediatly discard that overload *)
let arity_check = List.compare_lengths args param_lts in
if arity_check = 0 || (arity_check = -1 && is_implicit (List.hd param_lts)) then (
let param_lts = if arity_check = -1 then List.tl param_lts else param_lts in
(* Special case for a function with a single implicit argument *)
match (args, param_lts) with
| [(_, arg_lts)], [param_lt] when List.exists is_unit arg_lts && is_implicit param_lt ->
Some (overload, ret_lt)
| _ ->
let is_plausible =
List.fold_left2
(fun acc (tree, arg_lts) param_lt ->
acc && List.exists (fun arg_lt -> plausible arg_lt param_lt) arg_lts
)
true args param_lts
in
if is_plausible then Some (overload, ret_lt) else None
)
else None
)
overload_info
in
let overloads, returns = List.split plausible_overloads in
(OT_overloads (f, overloads, List.map fst args, annot), returns)
| OT_leaf (_, leaf_type) as tree -> (tree, [leaf_type])

let rec overload_tree_to_exp env = function
| OT_overloads (f, overloads, args, annot) ->
let id, env = Env.add_filtered_overload f overloads env in
let args, env =
List.fold_left
(fun (args, env) arg ->
let arg, env = overload_tree_to_exp env arg in
(arg :: args, env)
)
([], env) args
in
(E_aux (E_app (id, List.rev args), annot), env)
| OT_leaf (exp, _) -> (exp, env)

let rec string_of_overload_tree depth =
let indent = String.make depth ' ' in
function
| OT_overloads (_, overloads, args, _) ->
indent
^ Util.string_of_list ", " string_of_id overloads
^ ("\n" ^ indent)
^ Util.string_of_list ("\n" ^ indent) (string_of_overload_tree (depth + 4)) args
| OT_leaf (exp, leaf) -> indent ^ string_of_exp exp ^ string_of_overload_leaf leaf

let crule r env exp typ =
incr depth;
typ_print (lazy (Util.("Check " |> cyan |> clear) ^ string_of_exp exp ^ " <= " ^ string_of_typ typ));
Expand Down Expand Up @@ -1969,11 +2105,10 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
typ_raise l (Err_no_overloading (mapping, [(forwards_id, err1_loc, err1); (backwards_id, err2_loc, err2)]))
end
end
| E_app (f, xs), _ when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
| E_app (f, xs), _ when Env.is_filtered_overload f env ->
let orig_f, overloads = Env.get_filtered_overloads ~at:l f env in
let rec try_overload = function
| errs, [] -> typ_raise l (Err_no_overloading (f, errs))
| errs, [] -> typ_raise l (Err_no_overloading (orig_f, errs))
| errs, f :: fs -> begin
typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
try crule check_exp env (E_aux (E_app (f, xs), (l, uannot))) typ
Expand All @@ -1983,6 +2118,13 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
end
in
try_overload ([], overloads)
| E_app (f, xs), _ when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
let tree = build_overload_tree env f xs (l, uannot) in
let tree, _ = filter_overload_tree env tree in
let exp, env = overload_tree_to_exp env tree in
check_exp env exp typ
| E_app (f, [x; y]), _ when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> begin
(* We have to ensure that the type of y in (x || y) and (x && y)
is non-empty, otherwise it could force the entire type of the
Expand Down Expand Up @@ -2553,7 +2695,7 @@ and infer_pat env (P_aux (pat_aux, (l, uannot)) as pat) =
| P_lit (L_aux (L_string _, _) as lit) ->
(* String literal patterns match strings, not just string_literals *)
(annot_pat (P_lit lit) string_typ, env, [])
| P_lit lit -> (annot_pat (P_lit lit) (infer_lit env lit), env, [])
| P_lit lit -> (annot_pat (P_lit lit) (infer_lit lit), env, [])
| P_vector (pat :: pats) ->
let fold_pats (pats, env, guards) pat =
let typed_pat, env, guards' = bind_pat env pat bit_typ in
Expand All @@ -2564,7 +2706,7 @@ and infer_pat env (P_aux (pat_aux, (l, uannot)) as pat) =
let etyp = typ_of_pat (List.hd pats) in
(* BVS TODO: Non-bitvector P_vector *)
List.iter (fun pat -> typ_equality l env etyp (typ_of_pat pat)) pats;
(annot_pat (P_vector pats) (bits_typ env len), env, guards)
(annot_pat (P_vector pats) (bitvector_typ len), env, guards)
| P_vector_concat (pat :: pats) -> bind_vector_concat_pat l env uannot pat pats None
| P_vector_subrange (id, n, m) ->
let typ = bitvector_typ_from_range l env n m in
Expand Down Expand Up @@ -2723,7 +2865,7 @@ and bind_vector_concat_generic :
env,
guards' @ guards
)
| None -> (annotate before_uninferred (bits_typ env inferred_len), env, guards)
| None -> (annotate before_uninferred (bitvector_typ inferred_len), env, guards)
end
)

Expand Down Expand Up @@ -3043,15 +3185,10 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
annot_exp (E_block inferred_block) (last_typ inferred_block)
| E_id v -> begin
match Env.lookup_id v env with
| Local (_, typ) | Enum typ -> annot_exp (E_id v) typ
| Register typ -> annot_exp (E_id v) typ
| Unbound _ -> (
match Bindings.find_opt v (Env.get_val_specs env) with
| Some _ -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = true })
| None -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = false })
)
| Local (_, typ) | Enum typ | Register typ -> annot_exp (E_id v) typ
| Unbound _ -> unbound_id_error ~at:l env v
end
| E_lit lit -> annot_exp (E_lit lit) (infer_lit env lit)
| E_lit lit -> annot_exp (E_lit lit) (infer_lit lit)
| E_sizeof nexp -> begin
match nexp with
| Nexp_aux (Nexp_id id, _) when Env.is_abstract_typ id env -> annot_exp (E_sizeof nexp) (atom_typ nexp)
Expand Down Expand Up @@ -3150,11 +3287,10 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
typ_raise l (Err_no_overloading (mapping, [(forwards_id, err1_loc, err1); (backwards_id, err2_loc, err2)]))
end
end
| E_app (f, xs) when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
| E_app (f, xs) when Env.is_filtered_overload f env ->
let orig_f, overloads = Env.get_filtered_overloads ~at:l f env in
let rec try_overload = function
| errs, [] -> typ_raise l (Err_no_overloading (f, errs))
| errs, [] -> typ_raise l (Err_no_overloading (orig_f, errs))
| errs, f :: fs -> begin
typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
try irule infer_exp env (E_aux (E_app (f, xs), (l, uannot)))
Expand All @@ -3164,6 +3300,13 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
end
in
try_overload ([], overloads)
| E_app (f, xs) when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
let tree = build_overload_tree env f xs (l, uannot) in
let tree, _ = filter_overload_tree env tree in
let exp, env = overload_tree_to_exp env tree in
infer_exp env exp
| E_app (f, [x; y]) when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> begin
match destruct_exist (typ_of (irule infer_exp env y)) with
| None | Some (_, NC_aux (NC_true, _), _) -> infer_funapp l env f [x; y] None
Expand Down Expand Up @@ -3275,7 +3418,7 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
begin
match typ_of inferred_item with
| Typ_aux (Typ_id id, _) when string_of_id id = "bit" ->
let bitvec_typ = bits_typ env (nint (List.length vec)) in
let bitvec_typ = bitvector_typ (nint (List.length vec)) in
annot_exp (E_vector (inferred_item :: checked_items)) bitvec_typ
| _ ->
let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in
Expand Down Expand Up @@ -3866,7 +4009,7 @@ and infer_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, uannot)) as mp
| _ -> typ_error l ("Malformed mapping type " ^ string_of_id f)
end
| MP_lit (L_aux (L_string _, _) as lit) -> (annot_mpat (MP_lit lit) string_typ, env, [])
| MP_lit lit -> (annot_mpat (MP_lit lit) (infer_lit env lit), env, [])
| MP_lit lit -> (annot_mpat (MP_lit lit) (infer_lit lit), env, [])
| MP_typ (mpat, typ_annot) ->
Env.wf_typ ~at:l env typ_annot;
let typed_mpat, env, guards = bind_mpat allow_unknown other_env env mpat typ_annot in
Expand Down Expand Up @@ -4000,7 +4143,7 @@ let infer_funtyp l env tannotopt funcls =
| Typ_annot_opt_aux (Typ_annot_opt_some (quant, ret_typ), _) -> begin
let rec typ_from_pat (P_aux (pat_aux, (l, _)) as pat) =
match pat_aux with
| P_lit lit -> infer_lit env lit
| P_lit lit -> infer_lit lit
| P_typ (typ, _) -> typ
| P_tuple pats -> mk_typ (Typ_tuple (List.map typ_from_pat pats))
| _ -> typ_error l ("Cannot infer type from pattern " ^ string_of_pat pat)
Expand Down Expand Up @@ -4298,7 +4441,7 @@ let check_record l env def_annot id typq fields =
try
match get_def_attribute "bitfield" def_annot with
| Some (_, size) when not (Env.is_bitfield id env) ->
Env.add_bitfield id (bits_typ env (nconstant (Big_int.of_string size))) Bindings.empty env
Env.add_bitfield id (bitvector_typ (nconstant (Big_int.of_string size))) Bindings.empty env
| _ -> env
with _ -> env
in
Expand Down
19 changes: 19 additions & 0 deletions src/lib/type_env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type env = {
open_all : bool;
opened : Project.ModSet.t;
locals : (mut * typ) Bindings.t;
filtered_overloads : (id * id list) Bindings.t;
typ_vars : (Ast.l * kind_aux) KBindings.t;
shadow_vars : int KBindings.t;
allow_bindings : bool;
Expand Down Expand Up @@ -257,6 +258,7 @@ let empty =
open_all = false;
opened = Project.ModSet.empty;
locals = Bindings.empty;
filtered_overloads = Bindings.empty;
typ_vars = KBindings.empty;
shadow_vars = KBindings.empty;
allow_bindings = true;
Expand Down Expand Up @@ -463,6 +465,11 @@ let get_overloads id env =
let ids = get_item_with_loc hd_opt (id_loc id) env item in
List.filter (overload_item_in_scope env) ids

let get_overloads_recursive id env =
let overloads = get_overloads id env in
List.map (fun overload -> if is_overload overload env then get_overloads overload env else [overload]) overloads
|> List.concat

let add_overloads l id ids env =
typ_print
(lazy (adding ^ "overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]"))
Expand Down Expand Up @@ -506,6 +513,18 @@ let add_overloads l id ids env =
)
env

let is_filtered_overload id env = Bindings.mem id env.filtered_overloads

let get_filtered_overloads ~at:l id env =
match Bindings.find_opt id env.filtered_overloads with
| None -> Reporting.unreachable l __POS__ "Failed to get filtered overload"
| Some overloads -> overloads

let add_filtered_overload original_id ids env =
let n = Bindings.cardinal env.filtered_overloads in
let id = mk_id ("filtered_overload#" ^ string_of_int n) in
(id, { env with filtered_overloads = Bindings.add id (original_id, ids) env.filtered_overloads })

let infer_kind env id =
let l = id_loc id in
if Bindings.mem id builtin_typs then Bindings.find id builtin_typs
Expand Down
5 changes: 5 additions & 0 deletions src/lib/type_env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ val is_overload : id -> t -> bool
val get_overload_locs : id -> t -> Ast.l list
val add_overloads : l -> id -> id list -> t -> t
val get_overloads : id -> t -> id list
val get_overloads_recursive : id -> t -> id list

val is_filtered_overload : id -> t -> bool
val get_filtered_overloads : at:l -> id -> t -> id * id list
val add_filtered_overload : id -> id list -> t -> id * t

val is_extern : id -> t -> string -> bool
val add_extern : id -> extern -> t -> t
Expand Down
14 changes: 2 additions & 12 deletions test/typecheck/fail/add_vec_lit_old.expect
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,5 @@ Cast annotations are deprecated. They will be removed in a future version of the
Type error:
fail/add_vec_lit_old.sail:14.23-32:
14 |let x : range(0, 30) = 0xF + 0x2
 | ^-------^ (operator +)
 | No overloading for (operator +), tried:
 | * add_vec
 | fail/add_vec_lit_old.sail:14.23-32:
 | 14 |let x : range(0, 30) = 0xF + 0x2
 |  | ^-------^
 |  | Type mismatch between int('ex5#) and bitvector(4)
 | * add_range
 | fail/add_vec_lit_old.sail:14.23-26:
 | 14 |let x : range(0, 30) = 0xF + 0x2
 |  | ^-^ checking function argument has type int('ex1#)
 |  | Type mismatch between int('ex1#) and bitvector(4)
 | ^-------^
 | Type mismatch between int('ex5#) and bitvector(4)
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Old set syntax, {|1, 2, 3|} can now be written as {1, 2, 3}.
26 | Some(Ctor1(a, x, c))
 | ^------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor1
 | * 'ex343# in {32, 64}
 | * 'ex335# in {32, 64}
Loading
Loading