diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index 6f9ee93b9..5a05521f7 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -121,25 +121,69 @@ type type_error = let err_because (error1, l, error2) = Err_inner (error1, l, "Caused by", None, error2) -type env = { - top_val_specs : (typquant * typ) Bindings.t; +module IdPair = struct + type t = id * id + let compare (a, b) (c, d) = + let x = Id.compare a c in + if x = 0 then Id.compare b d else x +end + +module IdPairMap = Map.Make (IdPair) + +type ('a, 'b) generic_env_item = { item : 'a; loc : 'b } + +type 'a env_item = ('a, Parse_ast.l) generic_env_item + +type 'a multiple_env_item = ('a, Parse_ast.l list) generic_env_item + +let mk_item ~loc:l item = { item; loc = l } + +let get_item item = item.item + +let item_loc item = item.loc + +type global_env = { + val_specs : (typquant * typ) env_item Bindings.t; defined_val_specs : IdSet.t; + externs : extern Bindings.t; + mappings : (typquant * typ * typ) env_item Bindings.t; + unions : (typquant * type_union list) env_item Bindings.t; + union_ids : (typquant * typ) env_item Bindings.t; + scattered_union_envs : global_env Bindings.t; + enums : (bool * IdSet.t) env_item Bindings.t; + records : (typquant * (typ * id) list) env_item Bindings.t; + synonyms : (typquant * typ_arg) env_item Bindings.t; + accessors : (typquant * typ) env_item IdPairMap.t; + bitfields : (typ * index_range Bindings.t) env_item Bindings.t; + letbinds : typ env_item Bindings.t; + registers : typ env_item Bindings.t; + overloads : id list multiple_env_item Bindings.t; +} + +let empty_global_env = + { + val_specs = Bindings.empty; + defined_val_specs = IdSet.empty; + externs = Bindings.empty; + mappings = Bindings.empty; + unions = Bindings.empty; + union_ids = Bindings.empty; + scattered_union_envs = Bindings.empty; + enums = Bindings.empty; + records = Bindings.empty; + accessors = IdPairMap.empty; + synonyms = Bindings.empty; + bitfields = Bindings.empty; + letbinds = Bindings.empty; + registers = Bindings.empty; + overloads = Bindings.empty; + } + +type env = { + global : global_env; locals : (mut * typ) Bindings.t; - top_letbinds : IdSet.t; - union_ids : (typquant * typ) Bindings.t; - registers : typ Bindings.t; - variants : (typquant * type_union list) Bindings.t; - scattered_variant_envs : env Bindings.t; - mappings : (typquant * typ * typ) Bindings.t; typ_vars : (Ast.l * kind_aux) KBindings.t; shadow_vars : int KBindings.t; - typ_synonyms : (typquant * typ_arg) Bindings.t; - typ_params : typquant Bindings.t; - overloads : id list Bindings.t; - enums : (bool * IdSet.t) Bindings.t; - records : (typquant * (typ * id) list) Bindings.t; - accessors : (typquant * typ) Bindings.t; - externs : extern Bindings.t; allow_bindings : bool; constraints : (constraint_reason * n_constraint) list; default_order : order option; @@ -147,13 +191,14 @@ type env = { poly_undefineds : bool; prove : (env -> n_constraint -> bool) option; allow_unknowns : bool; - bitfields : (typ * index_range Bindings.t) Bindings.t; toplevel : l option; outcomes : (typquant * typ * kinded_id list * id list * env) Bindings.t; outcome_typschm : (typquant * typ) option; outcome_instantiation : (Ast.l * typ) KBindings.t; } +let update_global f env = { env with global = f env.global } + exception Type_error of l * type_error let typ_error l m = raise (Type_error (l, Err_other m)) @@ -162,9 +207,6 @@ let typ_raise l err = raise (Type_error (l, err)) let deinfix = function Id_aux (Id v, l) -> Id_aux (Operator v, l) | Id_aux (Operator v, l) -> Id_aux (Operator v, l) -let field_name rec_id id = - match (rec_id, id) with Id_aux (Id r, _), Id_aux (Id v, l) -> Id_aux (Id (r ^ "." ^ v), l) | _, _ -> assert false - let string_of_bind (typquant, typ) = string_of_typquant typquant ^ ". " ^ string_of_typ typ let orig_kid (Kid_aux (Var v, l) as kid) = @@ -578,24 +620,10 @@ end = struct let empty = { - top_val_specs = Bindings.empty; - defined_val_specs = IdSet.empty; + global = empty_global_env; locals = Bindings.empty; - top_letbinds = IdSet.empty; - union_ids = Bindings.empty; - registers = Bindings.empty; - variants = Bindings.empty; - scattered_variant_envs = Bindings.empty; - mappings = Bindings.empty; typ_vars = KBindings.empty; shadow_vars = KBindings.empty; - typ_synonyms = Bindings.empty; - typ_params = Bindings.empty; - overloads = Bindings.empty; - enums = Bindings.empty; - records = Bindings.empty; - accessors = Bindings.empty; - externs = Bindings.empty; allow_bindings = true; constraints = []; default_order = None; @@ -603,7 +631,6 @@ end = struct poly_undefineds = false; prove = None; allow_unknowns = false; - bitfields = Bindings.empty; toplevel = None; outcomes = Bindings.empty; outcome_typschm = None; @@ -698,18 +725,16 @@ end = struct ] let bound_typ_id env id = - Bindings.mem id env.typ_synonyms || Bindings.mem id env.variants || Bindings.mem id env.records - || Bindings.mem id env.enums || Bindings.mem id builtin_typs + Bindings.mem id env.global.synonyms || Bindings.mem id env.global.unions || Bindings.mem id env.global.records + || Bindings.mem id env.global.enums || Bindings.mem id builtin_typs let get_binding_loc env id = - let find map = - Bindings.bindings map |> List.find (fun (id', _) -> Id.compare id id' = 0) |> fun (id', _) -> Some (id_loc id') - in + let find map = Some (item_loc (Bindings.find id map)) in if Bindings.mem id builtin_typs then None - else if Bindings.mem id env.variants then find env.variants - else if Bindings.mem id env.records then find env.records - else if Bindings.mem id env.enums then find env.enums - else if Bindings.mem id env.typ_synonyms then find env.typ_synonyms + else if Bindings.mem id env.global.unions then find env.global.unions + else if Bindings.mem id env.global.records then find env.global.records + else if Bindings.mem id env.global.enums then find env.global.enums + else if Bindings.mem id env.global.synonyms then find env.global.synonyms else None let already_bound str id env = @@ -728,14 +753,11 @@ end = struct let suffix = if Bindings.mem id builtin_typs then " as a built-in type" else "" in typ_error (id_loc id) ("Cannot create " ^ str ^ " type " ^ string_of_id id ^ ", name is already bound" ^ suffix) - let bound_ctor_fn env id = Bindings.mem id env.top_val_specs || Bindings.mem id env.union_ids + let bound_ctor_fn env id = Bindings.mem id env.global.val_specs || Bindings.mem id env.global.union_ids let get_ctor_fn_binding_loc env id = - let find map = - Bindings.bindings map |> List.find (fun (id', _) -> Id.compare id id' = 0) |> fun (id', _) -> Some (id_loc id') - in - if Bindings.mem id env.top_val_specs then find env.top_val_specs - else if Bindings.mem id env.union_ids then find env.union_ids + if Bindings.mem id env.global.val_specs then Some (item_loc (Bindings.find id env.global.val_specs)) + else if Bindings.mem id env.global.union_ids then Some (item_loc (Bindings.find id env.global.union_ids)) else None let already_bound_ctor_fn str id env = @@ -757,27 +779,40 @@ end = struct Reporting.unreachable (id_loc id) __POS__ ("Could not find original binding for duplicate " ^ str ^ " called " ^ string_of_id id) - let get_overloads id env = try Bindings.find id env.overloads with Not_found -> [] + let get_overloads id env = try get_item (Bindings.find id env.global.overloads) with Not_found -> [] let add_overloads l id ids env = typ_print (lazy (adding ^ "overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]")); List.iter (fun overload -> - if not (bound_ctor_fn env overload || Bindings.mem overload env.overloads) then + if not (bound_ctor_fn env overload || Bindings.mem overload env.global.overloads) then typ_error (Hint ("unbound identifier", id_loc overload, l)) ("Cannot create or extend overload " ^ string_of_id id ^ ", " ^ string_of_id overload ^ " is not bound") ) ids; - let existing = try Bindings.find id env.overloads with Not_found -> [] in - { env with overloads = Bindings.add id (existing @ ids) env.overloads } + match Bindings.find_opt id env.global.overloads with + | Some existing -> + update_global + (fun global -> + { + global with + overloads = + Bindings.add id (mk_item ~loc:(l :: item_loc existing) (get_item existing @ ids)) global.overloads; + } + ) + env + | None -> + update_global + (fun global -> { global with overloads = Bindings.add id (mk_item ~loc:[l] ids) global.overloads }) + env let infer_kind env id = if Bindings.mem id builtin_typs then Bindings.find id builtin_typs - else if Bindings.mem id env.variants then fst (Bindings.find id env.variants) - else if Bindings.mem id env.records then fst (Bindings.find id env.records) - else if Bindings.mem id env.enums then mk_typquant [] - else if Bindings.mem id env.typ_synonyms then + else if Bindings.mem id env.global.unions then fst (get_item (Bindings.find id env.global.unions)) + else if Bindings.mem id env.global.records then fst (get_item (Bindings.find id env.global.records)) + else if Bindings.mem id env.global.enums then mk_typquant [] + else if Bindings.mem id env.global.synonyms then typ_error (id_loc id) ("Cannot infer kind of type synonym " ^ string_of_id id) else typ_error (id_loc id) ("Cannot infer kind of " ^ string_of_id id) @@ -844,9 +879,11 @@ end = struct ) let get_typ_synonym id env = - match Bindings.find_opt id env.typ_synonyms with Some (typq, arg) -> mk_synonym typq arg | None -> raise Not_found + match Option.map get_item (Bindings.find_opt id env.global.synonyms) with + | Some (typq, arg) -> mk_synonym typq arg + | None -> raise Not_found - let get_typ_synonyms env = env.typ_synonyms + let get_typ_synonyms env = Bindings.map get_item env.global.synonyms let get_constraints env = List.map snd env.constraints @@ -1181,20 +1218,26 @@ end = struct ( lazy (adding ^ "type synonym " ^ string_of_id id ^ ", " ^ string_of_typquant typq ^ " = " ^ string_of_typ_arg arg) ); - { - env with - typ_synonyms = - Bindings.add id (typq, expand_arg_synonyms (add_typquant (id_loc id) typq env) arg) env.typ_synonyms; - } + update_global + (fun global -> + { + global with + synonyms = + Bindings.add id + (mk_item ~loc:(id_loc id) (typq, expand_arg_synonyms (add_typquant (id_loc id) typq env) arg)) + global.synonyms; + } + ) + env ) let get_val_spec_orig id env = - try Bindings.find id env.top_val_specs + try get_item (Bindings.find id env.global.val_specs) with Not_found -> typ_error (id_loc id) ("No type signature found for " ^ string_of_id id) let get_val_spec id env = try - let bind = Bindings.find id env.top_val_specs in + let bind = get_item (Bindings.find id env.global.val_specs) in typ_debug ( lazy ("get_val_spec: Env has " @@ -1210,17 +1253,19 @@ end = struct bind' with Not_found -> typ_error (id_loc id) ("No type declaration found for " ^ string_of_id id) - let get_val_specs env = env.top_val_specs + let get_val_specs env = Bindings.map get_item env.global.val_specs let add_union_id id bind env = if bound_ctor_fn env id then already_bound_ctor_fn "union constructor" id env - else begin + else ( typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind)); - { env with union_ids = Bindings.add id bind env.union_ids } - end + update_global + (fun global -> { global with union_ids = Bindings.add id { item = bind; loc = id_loc id } global.union_ids }) + env + ) let get_union_id id env = - match Bindings.find_opt id env.union_ids with + match Option.map get_item (Bindings.find_opt id env.global.union_ids) with | Some bind -> List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) | None -> typ_error (id_loc id) ("No union constructor found for " ^ string_of_id id) @@ -1256,20 +1301,29 @@ end = struct let arg_typs = List.map2 (fun typ -> function Some (_, _, typ) -> typ | None -> typ) arg_typs base_args in let typ = Typ_aux (Typ_fn (arg_typs, ret_typ), l) in typ_print (lazy (adding ^ "val " ^ string_of_id id ^ " : " ^ string_of_bind (typq, typ))); - { env with top_val_specs = Bindings.add id (typq, typ) env.top_val_specs } + update_global + (fun global -> + { global with val_specs = Bindings.add id (mk_item ~loc:(id_loc id) (typq, typ)) global.val_specs } + ) + env | Typ_aux (Typ_bidir (typ1, typ2), _) -> let env = add_mapping id (typq, typ1, typ2) env in typ_print (lazy (adding ^ "mapping " ^ string_of_id id ^ " : " ^ string_of_bind (typq, typ))); - { env with top_val_specs = Bindings.add id (typq, typ) env.top_val_specs } + update_global + (fun global -> + { global with val_specs = Bindings.add id (mk_item ~loc:(id_loc id) (typq, typ)) global.val_specs } + ) + env | _ -> typ_error (id_loc id) "val definition must have a mapping or function type" end and add_val_spec ?(ignore_duplicate = false) id (bind_typq, bind_typ) env = - if (not (Bindings.mem id env.top_val_specs)) || ignore_duplicate then update_val_spec id (bind_typq, bind_typ) env + if (not (Bindings.mem id env.global.val_specs)) || ignore_duplicate then + update_val_spec id (bind_typq, bind_typ) env else if ignore_duplicate then env else ( let previous_loc = - match Bindings.choose_opt (Bindings.filter (fun key _ -> Id.compare id key = 0) env.top_val_specs) with + match Bindings.choose_opt (Bindings.filter (fun key _ -> Id.compare id key = 0) env.global.val_specs) with | Some (prev_id, _) -> id_loc prev_id | None -> Parse_ast.Unknown in @@ -1305,7 +1359,10 @@ end = struct let backwards_typ = Typ_aux (Typ_fn ([typ2], typ1), Parse_ast.Unknown) in let backwards_matches_typ = Typ_aux (Typ_fn ([typ2], bool_typ), Parse_ast.Unknown) in let env = - { env with mappings = Bindings.add id (typq, typ1, typ2) env.mappings } + env + |> update_global (fun global -> + { global with mappings = Bindings.add id (mk_item ~loc:(id_loc id) (typq, typ1, typ2)) global.mappings } + ) |> add_val_spec ~ignore_duplicate:true forwards_id (typq, forwards_typ) |> add_val_spec ~ignore_duplicate:true backwards_id (typq, backwards_typ) |> add_val_spec ~ignore_duplicate:true forwards_matches_id (typq, forwards_matches_typ) @@ -1338,17 +1395,17 @@ end = struct { env with outcome_instantiation = KBindings.add kid (l, typ) env.outcome_instantiation } let define_val_spec id env = - if IdSet.mem id env.defined_val_specs then + if IdSet.mem id env.global.defined_val_specs then typ_error (id_loc id) ("Function " ^ string_of_id id ^ " has already been declared") - else { env with defined_val_specs = IdSet.add id env.defined_val_specs } + else update_global (fun global -> { global with defined_val_specs = IdSet.add id global.defined_val_specs }) env - let get_defined_val_specs env = env.defined_val_specs + let get_defined_val_specs env = env.global.defined_val_specs let is_ctor id (Tu_aux (tu, _)) = match tu with Tu_ty_id (_, ctor_id) when Id.compare id ctor_id = 0 -> true | _ -> false let union_constructor_info id env = - let type_unions = List.map (fun (id, (_, tus)) -> (id, tus)) (Bindings.bindings env.variants) in + let type_unions = List.map (fun (id, { item = _, tus; _ }) -> (id, tus)) (Bindings.bindings env.global.unions) in Util.find_map (fun (union_id, tus) -> Option.map (fun (n, tu) -> (n, List.length tus, union_id, tu)) (Util.find_index_opt (is_ctor id) tus) @@ -1356,56 +1413,69 @@ end = struct type_unions let is_union_constructor id env = - let type_unions = List.concat (List.map (fun (_, (_, tus)) -> tus) (Bindings.bindings env.variants)) in + let type_unions = + List.concat (List.map (fun (_, { item = _, tus; _ }) -> tus) (Bindings.bindings env.global.unions)) + in List.exists (is_ctor id) type_unions let is_singleton_union_constructor id env = - let type_unions = List.map (fun (_, (_, tus)) -> tus) (Bindings.bindings env.variants) in + let type_unions = List.map (fun (_, { item = _, tus; _ }) -> tus) (Bindings.bindings env.global.unions) in match List.find (List.exists (is_ctor id)) type_unions with l -> List.length l = 1 | exception Not_found -> false - let is_mapping id env = Bindings.mem id env.mappings + let is_mapping id env = Bindings.mem id env.global.mappings let add_enum' is_scattered id ids env = if bound_typ_id env id then already_bound "enum" id env else ( typ_print (lazy (adding ^ "enum " ^ string_of_id id)); - { env with enums = Bindings.add id (is_scattered, IdSet.of_list ids) env.enums } + update_global + (fun global -> + { + global with + enums = Bindings.add id (mk_item ~loc:(id_loc id) (is_scattered, IdSet.of_list ids)) global.enums; + } + ) + env ) let add_scattered_enum id env = add_enum' true id [] env let add_enum id ids env = add_enum' false id ids env let add_enum_clause id member env = - match Bindings.find_opt id env.enums with - | Some (true, members) -> + match Bindings.find_opt id env.global.enums with + | Some ({ item = true, members; _ } as item) -> if IdSet.mem member members then typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " already has a member " ^ string_of_id member) - else { env with enums = Bindings.add id (true, IdSet.add member members) env.enums } - | Some (false, _) -> - let prev_id, _ = Bindings.find_first (fun prev_id -> Id.compare id prev_id = 0) env.enums in + else + update_global + (fun global -> + { global with enums = Bindings.add id { item with item = (true, IdSet.add member members) } global.enums } + ) + env + | Some { item = false, _; loc = l } -> typ_error - (Parse_ast.Hint ("Declared as regular enumeration here", id_loc prev_id, id_loc id)) + (Parse_ast.Hint ("Declared as regular enumeration here", l, id_loc id)) ("Enumeration " ^ string_of_id id ^ " is not scattered - cannot add a new member with 'enum clause'") | None -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist") let get_enum id env = - match Bindings.find_opt id env.enums with + match Option.map get_item (Bindings.find_opt id env.global.enums) with | Some (_, enum) -> IdSet.elements enum | None -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist") - let get_enums env = Bindings.map snd env.enums + let get_enums env = Bindings.map (fun item -> item |> get_item |> snd) env.global.enums - let is_record id env = Bindings.mem id env.records + let is_record id env = Bindings.mem id env.global.records - let get_record id env = Bindings.find id env.records + let get_record id env = get_item (Bindings.find id env.global.records) - let get_records env = env.records + let get_records env = Bindings.map get_item env.global.records let add_record id typq fields env = let fields = List.map (fun (typ, id) -> (expand_synonyms env typ, id)) fields in if bound_typ_id env id then already_bound "struct" id env else ( - typ_print (lazy (adding ^ "record " ^ string_of_id id)); + typ_print (lazy (adding ^ "struct " ^ string_of_id id)); let rec record_typ_args = function | [] -> [] | QI_aux (QI_id kopt, _) :: qis when is_int_kopt kopt -> @@ -1414,62 +1484,79 @@ end = struct mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt)))) :: record_typ_args qis | _ :: qis -> record_typ_args qis in - let rectyp = + let record_typ = match record_typ_args (quant_items typq) with [] -> mk_id_typ id | args -> mk_typ (Typ_app (id, args)) in - let fold_accessors accs (typ, fid) = - let acc_typ = mk_typ (Typ_fn ([rectyp], typ)) in + let fold_accessors accessors (typ, field) = + let accessor_typ = mk_typ (Typ_fn ([record_typ], typ)) in typ_print ( lazy - (indent 1 ^ adding ^ "accessor " ^ string_of_id id ^ "." ^ string_of_id fid ^ " :: " - ^ string_of_bind (typq, acc_typ) + (indent 1 ^ adding ^ "field accessor " ^ string_of_id id ^ "." ^ string_of_id field ^ " :: " + ^ string_of_bind (typq, accessor_typ) ) ); - Bindings.add (field_name id fid) (typq, acc_typ) accs + IdPairMap.add (id, field) (mk_item ~loc:(id_loc field) (typq, accessor_typ)) accessors in - { - env with - records = Bindings.add id (typq, fields) env.records; - accessors = List.fold_left fold_accessors env.accessors fields; - } + update_global + (fun global -> + { + global with + records = Bindings.add id (mk_item ~loc:(id_loc id) (typq, fields)) global.records; + accessors = List.fold_left fold_accessors global.accessors fields; + } + ) + env ) - let get_accessor_fn rec_id id env = + let get_accessor_fn record_id field env = let freshen_bind bind = List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars) in - try freshen_bind (Bindings.find (field_name rec_id id) env.accessors) - with Not_found -> typ_error (id_loc id) ("No accessor found for " ^ string_of_id (field_name rec_id id)) + try freshen_bind (get_item (IdPairMap.find (record_id, field) env.global.accessors)) + with Not_found -> + typ_error (id_loc field) ("No field accessor found for " ^ string_of_id record_id ^ "." ^ string_of_id field) - let get_accessor rec_id id env = - match get_accessor_fn rec_id id env with + let get_accessor record_id field env = + match get_accessor_fn record_id field env with (* All accessors should have a single argument (the record itself) *) - | typq, Typ_aux (Typ_fn ([rec_typ], field_typ), _) -> (typq, rec_typ, field_typ) - | _ -> typ_error (id_loc id) ("Accessor with non-function type found for " ^ string_of_id (field_name rec_id id)) + | typq, Typ_aux (Typ_fn ([record_typ], field_typ), _) -> (typq, record_typ, field_typ) + | _ -> + typ_error (id_loc field) + ("Field accessor with non-function type found for " ^ string_of_id record_id ^ "." ^ string_of_id field) let is_mutable id env = - try - let mut, _ = Bindings.find id env.locals in - match mut with Mutable -> true | Immutable -> false - with Not_found -> false + let to_bool = function Mutable -> true | Immutable -> false in + match Bindings.find_opt id env.locals with Some (mut, _) -> to_bool mut | None -> false let string_of_mtyp (mut, typ) = match mut with Immutable -> string_of_typ typ | Mutable -> "ref<" ^ string_of_typ typ ^ ">" let add_local id mtyp env = - if not env.allow_bindings then typ_error (id_loc id) "Bindings are not allowed in this context" else (); + if not env.allow_bindings then typ_error (id_loc id) "Bindings are not allowed in this context"; wf_typ env (snd mtyp); - if Bindings.mem id env.top_val_specs then + if Bindings.mem id env.global.val_specs then typ_error (id_loc id) ("Local variable " ^ string_of_id id ^ " is already bound as a function name") else (); typ_print (lazy (adding ^ "local binding " ^ string_of_id id ^ " : " ^ string_of_mtyp mtyp)); - { env with locals = Bindings.add id mtyp env.locals; top_letbinds = IdSet.remove id env.top_letbinds } - - let add_toplevel_lets ids env = { env with top_letbinds = IdSet.union ids env.top_letbinds } + { env with locals = Bindings.add id mtyp env.locals } + + (* Promote a set of identifiers from local bindings to top-level global letbindings *) + let add_toplevel_lets ids (env : env) = + IdSet.fold + (fun id (env : env) -> + match Bindings.find_opt id env.locals with + | Some (_, typ) -> + let env = { env with locals = Bindings.remove id env.locals } in + update_global + (fun global -> { global with letbinds = Bindings.add id (mk_item ~loc:(id_loc id) typ) global.letbinds }) + env + | None -> env + ) + ids env - let get_toplevel_lets env = env.top_letbinds + let get_toplevel_lets env = Bindings.bindings env.global.letbinds |> List.map fst |> IdSet.of_list - let is_variant id env = Bindings.mem id env.variants + let is_variant id env = Bindings.mem id env.global.unions let add_variant id (typq, constructors) env = let constructors = @@ -1482,82 +1569,106 @@ end = struct if bound_typ_id env id then already_bound "union" id env else ( typ_print (lazy (adding ^ "variant " ^ string_of_id id)); - { env with variants = Bindings.add id (typq, constructors) env.variants } + update_global + (fun global -> + { global with unions = Bindings.add id (mk_item ~loc:(id_loc id) (typq, constructors)) global.unions } + ) + env ) let add_scattered_variant id typq env = if bound_typ_id env id then already_bound "scattered union" id env else ( typ_print (lazy (adding ^ "scattered variant " ^ string_of_id id)); - { - env with - variants = Bindings.add id (typq, []) env.variants; - scattered_variant_envs = Bindings.add id env env.scattered_variant_envs; - } + update_global + (fun global -> + { + global with + unions = Bindings.add id (mk_item ~loc:(id_loc id) (typq, [])) global.unions; + scattered_union_envs = Bindings.add id env.global global.scattered_union_envs; + } + ) + env ) let add_variant_clause id tu env = - match Bindings.find_opt id env.variants with - | Some (typq, tus) -> { env with variants = Bindings.add id (typq, tus @ [tu]) env.variants } + match Bindings.find_opt id env.global.unions with + | Some ({ item = typq, tus; _ } as item) -> + update_global + (fun global -> { global with unions = Bindings.add id { item with item = (typq, tus @ [tu]) } global.unions }) + env | None -> typ_error (id_loc id) ("scattered union " ^ string_of_id id ^ " not found") - let get_variants env = env.variants + let get_variants env = Bindings.map get_item env.global.unions let get_variant id env = - match Bindings.find_opt id env.variants with + match Option.map get_item (Bindings.find_opt id env.global.unions) with | Some (typq, tus) -> (typq, tus) | None -> typ_error (id_loc id) ("union " ^ string_of_id id ^ " not found") let get_scattered_variant_env id env = - match Bindings.find_opt id env.scattered_variant_envs with - | Some env' -> env' + match Bindings.find_opt id env.global.scattered_union_envs with + | Some global_env -> { env with global = global_env } | None -> typ_error (id_loc id) ("scattered union " ^ string_of_id id ^ " has not been declared") - let is_register id env = Bindings.mem id env.registers + let is_register id env = Bindings.mem id env.global.registers let get_register id env = - try Bindings.find id env.registers - with Not_found -> typ_error (id_loc id) ("No register binding found for " ^ string_of_id id) + match Option.map get_item (Bindings.find_opt id env.global.registers) with + | Some typ -> typ + | None -> typ_error (id_loc id) ("No register binding found for " ^ string_of_id id) - let get_registers env = env.registers + let get_registers env = Bindings.map get_item env.global.registers let is_extern id env backend = - try not (Ast_util.extern_assoc backend (Bindings.find_opt id env.externs) = None) with Not_found -> false + try not (Ast_util.extern_assoc backend (Bindings.find_opt id env.global.externs) = None) with Not_found -> false - let add_extern id ext env = { env with externs = Bindings.add id ext env.externs } + let add_extern id ext env = + update_global (fun global -> { global with externs = Bindings.add id ext global.externs }) env let get_extern id env backend = try - match Ast_util.extern_assoc backend (Bindings.find_opt id env.externs) with + match Ast_util.extern_assoc backend (Bindings.find_opt id env.global.externs) with | Some ext -> ext | None -> typ_error (id_loc id) ("No extern binding found for " ^ string_of_id id) with Not_found -> typ_error (id_loc id) ("No extern binding found for " ^ string_of_id id) let add_register id typ env = wf_typ env typ; - if Bindings.mem id env.registers then typ_error (id_loc id) ("Register " ^ string_of_id id ^ " is already bound") - else begin + if Bindings.mem id env.global.registers then + typ_error (id_loc id) ("Register " ^ string_of_id id ^ " is already bound") + else ( typ_print (lazy (adding ^ "register binding " ^ string_of_id id ^ " :: " ^ string_of_typ typ)); - { env with registers = Bindings.add id typ env.registers } - end + update_global + (fun global -> { global with registers = Bindings.add id (mk_item ~loc:(id_loc id) typ) global.registers }) + env + ) - let get_locals env = env.locals + let get_locals env = + Bindings.fold + (fun id { item = typ; _ } locals -> + if not (Bindings.mem id locals) then Bindings.add id (Immutable, typ) locals else locals + ) + env.global.letbinds env.locals let lookup_id id env = - try - let mut, typ = Bindings.find id env.locals in - Local (mut, typ) - with Not_found -> ( - try - let typ = Bindings.find id env.registers in - Register typ - with Not_found -> ( - try - let enum, _ = List.find (fun (_, (_, ctors)) -> IdSet.mem id ctors) (Bindings.bindings env.enums) in - Enum (mk_typ (Typ_id enum)) - with Not_found -> Unbound id + match Bindings.find_opt id env.locals with + | Some (mut, typ) -> Local (mut, typ) + | None -> ( + match Bindings.find_opt id env.global.letbinds with + | Some { item = typ; _ } -> Local (Immutable, typ) + | None -> ( + match Bindings.find_opt id env.global.registers with + | Some { item = typ; _ } -> Register typ + | None -> ( + match + List.find_opt (fun (_, { item = _, ctors }) -> IdSet.mem id ctors) (Bindings.bindings env.global.enums) + with + | Some (enum, _) -> Enum (mk_typ (Typ_id enum)) + | None -> Unbound id + ) + ) ) - ) let get_ret_typ env = env.ret_typ @@ -1593,11 +1704,19 @@ end = struct in aux (expand_synonyms env typ) - let is_bitfield id env = Bindings.mem id env.bitfields + let is_bitfield id env = Bindings.mem id env.global.bitfields - let get_bitfield id env = Bindings.find id env.bitfields + let get_bitfield id env = + match Option.map get_item (Bindings.find_opt id env.global.bitfields) with + | Some bitfield -> bitfield + | None -> typ_error (id_loc id) ("Could not find bitfield " ^ string_of_id id) - let add_bitfield id typ ranges env = { env with bitfields = Bindings.add id (typ, ranges) env.bitfields } + let add_bitfield id typ ranges env = + update_global + (fun global -> + { global with bitfields = Bindings.add id (mk_item ~loc:(id_loc id) (typ, ranges)) global.bitfields } + ) + env let allow_polymorphic_undefineds env = { env with poly_undefineds = true } diff --git a/src/lib/type_check.mli b/src/lib/type_check.mli index 5b7afffa1..c66762ca6 100644 --- a/src/lib/type_check.mli +++ b/src/lib/type_check.mli @@ -198,7 +198,7 @@ module Env : sig val get_variants : t -> (typquant * type_union list) Bindings.t - (** Return type is: quantifier, argument type, return type, effect *) + (** Return type is: quantifier, argument type, return type *) val get_accessor : id -> id -> t -> typquant * typ * typ (** If the environment is checking a function, then this will get