Skip to content

Commit

Permalink
Filter the set of overloads before checking them
Browse files Browse the repository at this point in the history
Construct a tree of possible overloadings whenever we find an E_app node
that is overloaded, then filter any which are not possible. We do this by doing
a simple check on the shape of the argument types. If they can't possibly match
then we discard that overload.
  • Loading branch information
Alasdair committed May 7, 2024
1 parent 3fde1d5 commit eb17745
Show file tree
Hide file tree
Showing 21 changed files with 243 additions and 183 deletions.
197 changes: 170 additions & 27 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ let wf_binding l env (typq, typ) =
let wf_typschm env (TypSchm_aux (TypSchm_ts (typq, typ), l)) = wf_binding l env (typq, typ)

let dvector_typ _env n typ = vector_typ n typ
let bits_typ _env n = bitvector_typ n

let add_existential l kopts nc env =
let env = List.fold_left (fun env kopt -> Env.add_typ_var l kopt env) env kopts in
Expand Down Expand Up @@ -1341,7 +1340,7 @@ let rec get_implicits typs =
| _ :: typs -> get_implicits typs
| [] -> []

let infer_lit env (L_aux (lit_aux, l)) =
let infer_lit (L_aux (lit_aux, l)) =
match lit_aux with
| L_unit -> unit_typ
| L_zero -> bit_typ
Expand All @@ -1352,8 +1351,8 @@ let infer_lit env (L_aux (lit_aux, l)) =
| L_string _ when !Type_env.opt_string_literal_type -> string_literal_typ
| L_string _ -> string_typ
| L_real _ -> real_typ
| L_bin str -> bits_typ env (nint (String.length str))
| L_hex str -> bits_typ env (nint (String.length str * 4))
| L_bin str -> bitvector_typ (nint (String.length str))
| L_hex str -> bitvector_typ (nint (String.length str * 4))
| L_undef -> typ_error l "Cannot infer the type of undefined"

let instantiate_simple_equations =
Expand Down Expand Up @@ -1654,6 +1653,143 @@ let bind_pattern_vector_subranges (P_aux (_, (l, _)) as pat) env =
)
id_ranges env

let unbound_id_error ~at:l env v =
match Bindings.find_opt v (Env.get_val_specs env) with
| Some _ -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = true })
| None -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = false })

type overload_leaf_type = OL_app of id | OL_id of id | OL_unknown

type 'a overload_tree =
| OT_overloads of id * id list * 'a overload_tree list * 'a annot
| OT_leaf of 'a exp * overload_leaf_type

let overload_leaf_type (Typ_aux (aux, _)) =
match aux with Typ_id id -> OL_id id | Typ_app (id, _) -> OL_app id | _ -> OL_unknown

let rec build_overload_tree env f xs annot =
let overloads = Env.get_overloads_recursive f env in
OT_overloads (f, overloads, List.map (build_overload_tree_arg env) xs, annot)

and build_overload_tree_arg env (E_aux (aux, annot) as exp) =
match aux with
| E_app_infix (x, op, y) when Env.is_overload (deinfix op) env -> build_overload_tree env (deinfix op) [x; y] annot
| E_app (f, xs) when Env.is_overload f env -> build_overload_tree env f xs annot
| E_id v -> begin
match Env.lookup_id v env with
| Local (_, typ) | Enum typ | Register typ -> OT_leaf (exp, overload_leaf_type (Env.expand_synonyms env typ))
| Unbound _ -> unbound_id_error ~at:(fst annot) env v
end
| E_lit lit -> begin
match lit with
| L_aux (L_undef, _) -> OT_leaf (exp, OL_unknown)
| _ -> OT_leaf (exp, overload_leaf_type (infer_lit lit))
end
| _ -> OT_leaf (exp, OL_unknown)

let string_of_overload_leaf = function
| OL_app id -> ": " ^ string_of_id id ^ "(...)"
| OL_id id -> ": " ^ string_of_id id
| OL_unknown -> ": ?"

let rec filter_overload_tree env =
let atom_like id =
let s = string_of_id id in
s = "atom" || s = "range" || s = "implicit"
in
let int_or_nat id =
let s = string_of_id id in
s = "int" || s = "nat"
in
let both_strings s1 s2 = (s1 = "string" && s2 = "string_literal") || (s1 = "string_literal" && s2 = "string") in
let plausible x y =
match (x, y) with
| OL_app id1, OL_id id2 | OL_id id2, OL_app id1 ->
(atom_like id1 && int_or_nat id2) || (string_of_id id1 = "atom_bool" && string_of_id id2 = "bool")
| OL_id id1, OL_id id2 ->
Id.compare id1 id2 = 0
|| both_strings (string_of_id id1) (string_of_id id2)
|| (int_or_nat id1 && int_or_nat id2)
| OL_app id1, OL_app id2 -> Id.compare id1 id2 = 0 || (atom_like id1 && atom_like id2)
| OL_unknown, _ -> true
| _, OL_unknown -> true
in
let is_implicit = function OL_app id -> string_of_id id = "implicit" | _ -> false in
let is_unit = function OL_id id -> string_of_id id = "unit" | _ -> false in
function
| OT_overloads (f, overloads, args, annot) ->
let args = List.map (filter_overload_tree env) args in
let overload_info =
List.map
(fun overload ->
let unwrap_overload_type = function
| Typ_aux (Typ_fn (arg_typs, ret_typ), _) ->
[(overload, List.map overload_leaf_type arg_typs, overload_leaf_type ret_typ)]
| Typ_aux (Typ_bidir (lhs_typ, rhs_typ), _) ->
let lhs = overload_leaf_type lhs_typ in
let rhs = overload_leaf_type rhs_typ in
[(overload, [lhs], rhs); (overload, [rhs], lhs)]
| _ ->
typ_error (fst annot) ("Overload " ^ string_of_id overload ^ " must have a function or mapping type")
in
unwrap_overload_type (snd (Env.get_val_spec overload env))
)
overloads
|> List.concat
in
let plausible_overloads =
List.filter_map
(fun (overload, param_lts, ret_lt) ->
(* If the overload and usage arity don't match, immediatly discard that overload *)
let arity_check = List.compare_lengths args param_lts in
if arity_check = 0 || (arity_check = -1 && is_implicit (List.hd param_lts)) then (
let param_lts = if arity_check = -1 then List.tl param_lts else param_lts in
(* Special case for a function with a single implicit argument *)
match (args, param_lts) with
| [(_, arg_lts)], [param_lt] when List.exists is_unit arg_lts && is_implicit param_lt ->
Some (overload, ret_lt)
| _ ->
let is_plausible =
List.fold_left2
(fun acc (tree, arg_lts) param_lt ->
acc && List.exists (fun arg_lt -> plausible arg_lt param_lt) arg_lts
)
true args param_lts
in
if is_plausible then Some (overload, ret_lt) else None
)
else None
)
overload_info
in
let overloads, returns = List.split plausible_overloads in
(OT_overloads (f, overloads, List.map fst args, annot), returns)
| OT_leaf (_, leaf_type) as tree -> (tree, [leaf_type])

let rec overload_tree_to_exp env = function
| OT_overloads (f, overloads, args, annot) ->
let id, env = Env.add_filtered_overload f overloads env in
let args, env =
List.fold_left
(fun (args, env) arg ->
let arg, env = overload_tree_to_exp env arg in
(arg :: args, env)
)
([], env) args
in
(E_aux (E_app (id, List.rev args), annot), env)
| OT_leaf (exp, _) -> (exp, env)

let rec string_of_overload_tree depth =
let indent = String.make depth ' ' in
function
| OT_overloads (_, overloads, args, _) ->
indent
^ Util.string_of_list ", " string_of_id overloads
^ ("\n" ^ indent)
^ Util.string_of_list ("\n" ^ indent) (string_of_overload_tree (depth + 4)) args
| OT_leaf (exp, leaf) -> indent ^ string_of_exp exp ^ string_of_overload_leaf leaf

let crule r env exp typ =
incr depth;
typ_print (lazy (Util.("Check " |> cyan |> clear) ^ string_of_exp exp ^ " <= " ^ string_of_typ typ));
Expand Down Expand Up @@ -1969,11 +2105,10 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
typ_raise l (Err_no_overloading (mapping, [(forwards_id, err1_loc, err1); (backwards_id, err2_loc, err2)]))
end
end
| E_app (f, xs), _ when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
| E_app (f, xs), _ when Env.is_filtered_overload f env ->
let orig_f, overloads = Env.get_filtered_overloads ~at:l f env in
let rec try_overload = function
| errs, [] -> typ_raise l (Err_no_overloading (f, errs))
| errs, [] -> typ_raise l (Err_no_overloading (orig_f, errs))
| errs, f :: fs -> begin
typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
try crule check_exp env (E_aux (E_app (f, xs), (l, uannot))) typ
Expand All @@ -1983,6 +2118,13 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
end
in
try_overload ([], overloads)
| E_app (f, xs), _ when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
let tree = build_overload_tree env f xs (l, uannot) in
let tree, _ = filter_overload_tree env tree in
let exp, env = overload_tree_to_exp env tree in
check_exp env exp typ
| E_app (f, [x; y]), _ when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> begin
(* We have to ensure that the type of y in (x || y) and (x && y)
is non-empty, otherwise it could force the entire type of the
Expand Down Expand Up @@ -2553,7 +2695,7 @@ and infer_pat env (P_aux (pat_aux, (l, uannot)) as pat) =
| P_lit (L_aux (L_string _, _) as lit) ->
(* String literal patterns match strings, not just string_literals *)
(annot_pat (P_lit lit) string_typ, env, [])
| P_lit lit -> (annot_pat (P_lit lit) (infer_lit env lit), env, [])
| P_lit lit -> (annot_pat (P_lit lit) (infer_lit lit), env, [])
| P_vector (pat :: pats) ->
let fold_pats (pats, env, guards) pat =
let typed_pat, env, guards' = bind_pat env pat bit_typ in
Expand All @@ -2564,7 +2706,7 @@ and infer_pat env (P_aux (pat_aux, (l, uannot)) as pat) =
let etyp = typ_of_pat (List.hd pats) in
(* BVS TODO: Non-bitvector P_vector *)
List.iter (fun pat -> typ_equality l env etyp (typ_of_pat pat)) pats;
(annot_pat (P_vector pats) (bits_typ env len), env, guards)
(annot_pat (P_vector pats) (bitvector_typ len), env, guards)
| P_vector_concat (pat :: pats) -> bind_vector_concat_pat l env uannot pat pats None
| P_vector_subrange (id, n, m) ->
let typ = bitvector_typ_from_range l env n m in
Expand Down Expand Up @@ -2723,7 +2865,7 @@ and bind_vector_concat_generic :
env,
guards' @ guards
)
| None -> (annotate before_uninferred (bits_typ env inferred_len), env, guards)
| None -> (annotate before_uninferred (bitvector_typ inferred_len), env, guards)
end
)

Expand Down Expand Up @@ -3043,15 +3185,10 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
annot_exp (E_block inferred_block) (last_typ inferred_block)
| E_id v -> begin
match Env.lookup_id v env with
| Local (_, typ) | Enum typ -> annot_exp (E_id v) typ
| Register typ -> annot_exp (E_id v) typ
| Unbound _ -> (
match Bindings.find_opt v (Env.get_val_specs env) with
| Some _ -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = true })
| None -> typ_raise l (Err_unbound_id { id = v; locals = Env.get_locals env; have_function = false })
)
| Local (_, typ) | Enum typ | Register typ -> annot_exp (E_id v) typ
| Unbound _ -> unbound_id_error ~at:l env v
end
| E_lit lit -> annot_exp (E_lit lit) (infer_lit env lit)
| E_lit lit -> annot_exp (E_lit lit) (infer_lit lit)
| E_sizeof nexp -> begin
match nexp with
| Nexp_aux (Nexp_id id, _) when Env.is_abstract_typ id env -> annot_exp (E_sizeof nexp) (atom_typ nexp)
Expand Down Expand Up @@ -3150,11 +3287,10 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
typ_raise l (Err_no_overloading (mapping, [(forwards_id, err1_loc, err1); (backwards_id, err2_loc, err2)]))
end
end
| E_app (f, xs) when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
| E_app (f, xs) when Env.is_filtered_overload f env ->
let orig_f, overloads = Env.get_filtered_overloads ~at:l f env in
let rec try_overload = function
| errs, [] -> typ_raise l (Err_no_overloading (f, errs))
| errs, [] -> typ_raise l (Err_no_overloading (orig_f, errs))
| errs, f :: fs -> begin
typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
try irule infer_exp env (E_aux (E_app (f, xs), (l, uannot)))
Expand All @@ -3164,6 +3300,13 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
end
in
try_overload ([], overloads)
| E_app (f, xs) when Env.is_overload f env ->
let overloads = Env.get_overloads f env in
check_overload_member_scope l f overloads env;
let tree = build_overload_tree env f xs (l, uannot) in
let tree, _ = filter_overload_tree env tree in
let exp, env = overload_tree_to_exp env tree in
infer_exp env exp
| E_app (f, [x; y]) when string_of_id f = "and_bool" || string_of_id f = "or_bool" -> begin
match destruct_exist (typ_of (irule infer_exp env y)) with
| None | Some (_, NC_aux (NC_true, _), _) -> infer_funapp l env f [x; y] None
Expand Down Expand Up @@ -3275,7 +3418,7 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
begin
match typ_of inferred_item with
| Typ_aux (Typ_id id, _) when string_of_id id = "bit" ->
let bitvec_typ = bits_typ env (nint (List.length vec)) in
let bitvec_typ = bitvector_typ (nint (List.length vec)) in
annot_exp (E_vector (inferred_item :: checked_items)) bitvec_typ
| _ ->
let vec_typ = dvector_typ env (nint (List.length vec)) (typ_of inferred_item) in
Expand Down Expand Up @@ -3866,7 +4009,7 @@ and infer_mpat allow_unknown other_env env (MP_aux (mpat_aux, (l, uannot)) as mp
| _ -> typ_error l ("Malformed mapping type " ^ string_of_id f)
end
| MP_lit (L_aux (L_string _, _) as lit) -> (annot_mpat (MP_lit lit) string_typ, env, [])
| MP_lit lit -> (annot_mpat (MP_lit lit) (infer_lit env lit), env, [])
| MP_lit lit -> (annot_mpat (MP_lit lit) (infer_lit lit), env, [])
| MP_typ (mpat, typ_annot) ->
Env.wf_typ ~at:l env typ_annot;
let typed_mpat, env, guards = bind_mpat allow_unknown other_env env mpat typ_annot in
Expand Down Expand Up @@ -4000,7 +4143,7 @@ let infer_funtyp l env tannotopt funcls =
| Typ_annot_opt_aux (Typ_annot_opt_some (quant, ret_typ), _) -> begin
let rec typ_from_pat (P_aux (pat_aux, (l, _)) as pat) =
match pat_aux with
| P_lit lit -> infer_lit env lit
| P_lit lit -> infer_lit lit
| P_typ (typ, _) -> typ
| P_tuple pats -> mk_typ (Typ_tuple (List.map typ_from_pat pats))
| _ -> typ_error l ("Cannot infer type from pattern " ^ string_of_pat pat)
Expand Down Expand Up @@ -4298,7 +4441,7 @@ let check_record l env def_annot id typq fields =
try
match get_def_attribute "bitfield" def_annot with
| Some (_, size) when not (Env.is_bitfield id env) ->
Env.add_bitfield id (bits_typ env (nconstant (Big_int.of_string size))) Bindings.empty env
Env.add_bitfield id (bitvector_typ (nconstant (Big_int.of_string size))) Bindings.empty env
| _ -> env
with _ -> env
in
Expand Down
20 changes: 20 additions & 0 deletions src/lib/type_env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type env = {
open_all : bool;
opened : Project.ModSet.t;
locals : (mut * typ) Bindings.t;
filtered_overloads : (id * id list) Bindings.t;
typ_vars : (Ast.l * kind_aux) KBindings.t;
shadow_vars : int KBindings.t;
allow_bindings : bool;
Expand Down Expand Up @@ -257,6 +258,7 @@ let empty =
open_all = false;
opened = Project.ModSet.empty;
locals = Bindings.empty;
filtered_overloads = Bindings.empty;
typ_vars = KBindings.empty;
shadow_vars = KBindings.empty;
allow_bindings = true;
Expand Down Expand Up @@ -463,6 +465,12 @@ let get_overloads id env =
let ids = get_item_with_loc hd_opt (id_loc id) env item in
List.filter (overload_item_in_scope env) ids
let get_overloads_recursive id env =
let overloads = get_overloads id env in
List.concat_map
(fun overload -> if is_overload overload env then get_overloads overload env else [overload])
overloads
let add_overloads l id ids env =
typ_print
(lazy (adding ^ "overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]"))
Expand Down Expand Up @@ -506,6 +514,18 @@ let add_overloads l id ids env =
)
env
let is_filtered_overload id env = Bindings.mem id env.filtered_overloads
let get_filtered_overloads ~at:l id env =
match Bindings.find_opt id env.filtered_overloads with
| None -> Reporting.unreachable l __POS__ "Failed to get filtered overload"
| Some overloads -> overloads
let add_filtered_overload original_id ids env =
let n = Bindings.cardinal env.filtered_overloads in
let id = mk_id ("filtered_overload#" ^ string_of_int n) in
(id, { env with filtered_overloads = Bindings.add id (original_id, ids) env.filtered_overloads })
let infer_kind env id =
let l = id_loc id in
if Bindings.mem id builtin_typs then Bindings.find id builtin_typs
Expand Down
5 changes: 5 additions & 0 deletions src/lib/type_env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ val is_overload : id -> t -> bool
val get_overload_locs : id -> t -> Ast.l list
val add_overloads : l -> id -> id list -> t -> t
val get_overloads : id -> t -> id list
val get_overloads_recursive : id -> t -> id list

val is_filtered_overload : id -> t -> bool
val get_filtered_overloads : at:l -> id -> t -> id * id list
val add_filtered_overload : id -> id list -> t -> id * t

val is_extern : id -> t -> string -> bool
val add_extern : id -> extern -> t -> t
Expand Down
14 changes: 2 additions & 12 deletions test/typecheck/fail/add_vec_lit_old.expect
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,5 @@ Cast annotations are deprecated. They will be removed in a future version of the
Type error:
fail/add_vec_lit_old.sail:14.23-32:
14 |let x : range(0, 30) = 0xF + 0x2
 | ^-------^ (operator +)
 | No overloading for (operator +), tried:
 | * add_vec
 | fail/add_vec_lit_old.sail:14.23-32:
 | 14 |let x : range(0, 30) = 0xF + 0x2
 |  | ^-------^
 |  | Type mismatch between int('ex5#) and bitvector(4)
 | * add_range
 | fail/add_vec_lit_old.sail:14.23-26:
 | 14 |let x : range(0, 30) = 0xF + 0x2
 |  | ^-^ checking function argument has type int('ex1#)
 |  | Type mismatch between int('ex1#) and bitvector(4)
 | ^-------^
 | Type mismatch between int('ex5#) and bitvector(4)
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Old set syntax, {|1, 2, 3|} can now be written as {1, 2, 3}.
26 | Some(Ctor1(a, x, c))
 | ^------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor1
 | * 'ex343# in {32, 64}
 | * 'ex335# in {32, 64}
Loading

0 comments on commit eb17745

Please sign in to comment.