Skip to content

Commit

Permalink
Add access control mechanism to typing contexts
Browse files Browse the repository at this point in the history
This commit adds a feature that allows for typechecking using only
a subset of objects (functions/types/registers etc) in the global
typing context. Each `env_item` is extended with an identifier, and
the local typing context contains a set of 'in scope' such identifiers.
All functions accessing objects in the typing context will either raise
Err_not_in_scope (if accessing a single object) or return a filtered set
of objects based on the set of in-scope identifiers.

Note that this commit is effectively a no-op. While it adds the mechanism,
it does not add any way to use it other than a magic directive `unscope#`
which kicks an object out of the current scope (this is guarded by the
-dmagic_hash flag), which is used to test that simple cases have reasonable
error messages.
  • Loading branch information
Alasdair committed Nov 21, 2023
1 parent 13cbea3 commit 8484749
Show file tree
Hide file tree
Showing 17 changed files with 235 additions and 36 deletions.
3 changes: 2 additions & 1 deletion src/lib/preprocess.ml
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ let preprocess dir target opts =
incorrect start/end annotations *)
| (DEF_aux (DEF_pragma ("file_start", _), _) | DEF_aux (DEF_pragma ("file_end", _), _)) :: defs -> aux acc defs
| DEF_aux (DEF_pragma (p, arg), l) :: defs ->
if not (StringSet.mem p all_pragmas) then Reporting.warn "" l ("Unrecognised directive: " ^ p);
if not (StringSet.mem p all_pragmas || String.contains p '#') then
Reporting.warn "" l ("Unrecognised directive: " ^ p);
aux (DEF_aux (DEF_pragma (p, arg), l) :: acc) defs
| DEF_aux (DEF_outcome (outcome_spec, inner_defs), l) :: defs ->
aux (DEF_aux (DEF_outcome (outcome_spec, aux [] inner_defs), l) :: acc) defs
Expand Down
22 changes: 18 additions & 4 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4057,10 +4057,17 @@ let check_fundef env def_annot (FD_aux (FD_function (recopt, tannotopt, funcls),
in
typ_print (lazy ("\n" ^ Util.("Check function " |> cyan |> clear) ^ string_of_id id));
let have_val_spec, (quant, typ), env =
try (true, Env.get_val_spec id env, env)
with Type_error (l, _) ->
let quant, typ = infer_funtyp l env tannotopt funcls in
(false, (quant, typ), env)
try (true, Env.get_val_spec id env, env) with
| Type_error (l, Err_not_in_scope (_, scope_l)) ->
typ_raise l
(Err_not_in_scope
( Some "Cannot infer type of function as it has a defined type already. However, this type is not in scope.",
scope_l
)
)
| Type_error (l, _) ->
let quant, typ = infer_funtyp l env tannotopt funcls in
(false, (quant, typ), env)
in
let vtyp_args, vtyp_ret, vl =
match typ with
Expand Down Expand Up @@ -4493,6 +4500,13 @@ and check_def : Env.t -> uannot def -> tannot def list * Env.t =
],
env
)
| DEF_pragma ("unscope#", arg, l) when !Initial_check.opt_magic_hash ->
let env =
match String.split_on_char ' ' arg with
| [id_category; id] -> Type_env.unscope_pragma id_category (mk_id id) env
| _ -> env
in
([DEF_aux (DEF_pragma ("unscope#", arg, l), def_annot)], env)
| DEF_pragma (pragma, arg, l) -> ([DEF_aux (DEF_pragma (pragma, arg, l), def_annot)], env)
| DEF_scattered sdef -> check_scattered env def_annot sdef
| DEF_measure (id, pat, exp) -> ([check_termination_measure_decl env def_annot (id, pat, exp)], env)
Expand Down
110 changes: 79 additions & 31 deletions src/lib/type_env.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,13 @@ end

module IdPairMap = Map.Make (IdPair)

type ('a, 'b) generic_env_item = { item : 'a; loc : 'b }
type ('a, 'b) generic_env_item = { item : 'a; loc : 'b; mod_id : int }

type 'a env_item = ('a, Parse_ast.l) generic_env_item

type 'a multiple_env_item = ('a, Parse_ast.l list) generic_env_item

let mk_item ~loc:l item = { item; loc = l }

let get_item item = item.item
let mk_item ~loc:l item = { item; loc = l; mod_id = 0 }

let item_loc item = item.loc

Expand Down Expand Up @@ -145,6 +143,7 @@ let empty_global_env =

type env = {
global : global_env;
opened : IntSet.t;
locals : (mut * typ) Bindings.t;
typ_vars : (Ast.l * kind_aux) KBindings.t;
shadow_vars : int KBindings.t;
Expand All @@ -158,15 +157,50 @@ type env = {
outcome_typschm : (typquant * typ) option;
}

let filter_items_with f env bindings =
Bindings.map
(fun item -> f item.item)
(Bindings.filter (fun _ item -> item.mod_id = 0 || IntSet.mem item.mod_id env.opened) bindings)

let filter_items env bindings = filter_items_with (fun x -> x) env bindings

let item_in_scope env item = item.mod_id = 0 || IntSet.mem item.mod_id env.opened

let get_item_with_loc get_loc l env item =
if item_in_scope env item then item.item else typ_raise l (Err_not_in_scope (None, get_loc item.loc))

let get_item env item = get_item_with_loc (fun l -> Some l) env item

let hd_opt = function x :: _ -> Some x | [] -> None

type type_variables = Type_internal.type_variables

type t = env

let update_global f env = { env with global = f env.global }

let unscope_pragma id_category id env =
typ_debug (lazy ("Unscope " ^ id_category ^ " " ^ string_of_id id));
let update_id id m = Bindings.update id (function Some item -> Some { item with mod_id = -1 } | None -> None) m in
match id_category with
| "bitfield" -> update_global (fun global -> { global with bitfields = update_id id global.bitfields }) env
| "enum" -> update_global (fun global -> { global with enums = update_id id global.enums }) env
| "let" -> update_global (fun global -> { global with letbinds = update_id id global.letbinds }) env
| "mapping" -> update_global (fun global -> { global with mappings = update_id id global.mappings }) env
| "overload" -> update_global (fun global -> { global with overloads = update_id id global.overloads }) env
| "register" -> update_global (fun global -> { global with registers = update_id id global.registers }) env
| "struct" -> update_global (fun global -> { global with records = update_id id global.records }) env
| "type" -> update_global (fun global -> { global with synonyms = update_id id global.synonyms }) env
| "union" -> update_global (fun global -> { global with unions = update_id id global.unions }) env
| "val" -> update_global (fun global -> { global with val_specs = update_id id global.val_specs }) env
| _ ->
typ_debug (lazy "Unrecognized unscope category");
env

let empty =
{
global = empty_global_env;
opened = IntSet.empty;
locals = Bindings.empty;
typ_vars = KBindings.empty;
shadow_vars = KBindings.empty;
Expand Down Expand Up @@ -345,7 +379,8 @@ let already_bound_ctor_fn str id env =
Reporting.unreachable (id_loc id) __POS__
("Could not find original binding for duplicate " ^ str ^ " called " ^ string_of_id id)

let get_overloads id env = try get_item (Bindings.find id env.global.overloads) with Not_found -> []
let get_overloads id env =
try get_item_with_loc hd_opt (id_loc id) env (Bindings.find id env.global.overloads) with Not_found -> []

let add_overloads l id ids env =
typ_print (lazy (adding ^ "overloads for " ^ string_of_id id ^ " [" ^ string_of_list ", " string_of_id ids ^ "]"));
Expand All @@ -364,7 +399,9 @@ let add_overloads l id ids env =
{
global with
overloads =
Bindings.add id (mk_item ~loc:(l :: item_loc existing) (get_item existing @ ids)) global.overloads;
Bindings.add id
(mk_item ~loc:(l :: item_loc existing) (get_item_with_loc hd_opt l env existing @ ids))
global.overloads;
}
)
env
Expand All @@ -374,9 +411,10 @@ let add_overloads l id ids env =
env

let infer_kind env id =
let l = id_loc id in
if Bindings.mem id builtin_typs then Bindings.find id builtin_typs
else if Bindings.mem id env.global.unions then fst (get_item (Bindings.find id env.global.unions))
else if Bindings.mem id env.global.records then fst (get_item (Bindings.find id env.global.records))
else if Bindings.mem id env.global.unions then fst (get_item l env (Bindings.find id env.global.unions))
else if Bindings.mem id env.global.records then fst (get_item l env (Bindings.find id env.global.records))
else if Bindings.mem id env.global.enums then mk_typquant []
else if Bindings.mem id env.global.synonyms then
typ_error (id_loc id) ("Cannot infer kind of type synonym " ^ string_of_id id)
Expand Down Expand Up @@ -440,11 +478,11 @@ let mk_synonym typq typ_arg =
)

let get_typ_synonym id env =
match Option.map get_item (Bindings.find_opt id env.global.synonyms) with
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.synonyms) with
| Some (typq, arg) -> mk_synonym typq arg
| None -> raise Not_found

let get_typ_synonyms env = Bindings.map get_item env.global.synonyms
let get_typ_synonyms env = filter_items env env.global.synonyms

let get_constraints env = List.map snd env.constraints

Expand Down Expand Up @@ -791,12 +829,12 @@ let add_typ_synonym id typq arg env =
)

let get_val_spec_orig id env =
try get_item (Bindings.find id env.global.val_specs)
try get_item (id_loc id) env (Bindings.find id env.global.val_specs)
with Not_found -> typ_error (id_loc id) ("No type signature found for " ^ string_of_id id)

let get_val_spec id env =
try
let bind = get_item (Bindings.find id env.global.val_specs) in
let bind = get_item (id_loc id) env (Bindings.find id env.global.val_specs) in
typ_debug
( lazy
("get_val_spec: Env has "
Expand All @@ -810,19 +848,19 @@ let get_val_spec id env =
bind'
with Not_found -> typ_error (id_loc id) ("No type declaration found for " ^ string_of_id id)

let get_val_specs env = Bindings.map get_item env.global.val_specs
let get_val_specs env = filter_items env env.global.val_specs

let add_union_id id bind env =
if bound_ctor_fn env id then already_bound_ctor_fn "union constructor" id env
else (
typ_print (lazy (adding ^ "union identifier " ^ string_of_id id ^ " : " ^ string_of_bind bind));
update_global
(fun global -> { global with union_ids = Bindings.add id { item = bind; loc = id_loc id } global.union_ids })
(fun global -> { global with union_ids = Bindings.add id (mk_item ~loc:(id_loc id) bind) global.union_ids })
env
)

let get_union_id id env =
match Option.map get_item (Bindings.find_opt id env.global.union_ids) with
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.union_ids) with
| Some bind -> List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars)
| None -> typ_error (id_loc id) ("No union constructor found for " ^ string_of_id id)

Expand Down Expand Up @@ -910,7 +948,7 @@ and add_outcome id (typq, typ, params, vals, outcome_env) env =
env

and get_outcome l id env =
match Option.map get_item (Bindings.find_opt id env.global.outcomes) with
match Option.map (get_item l env) (Bindings.find_opt id env.global.outcomes) with
| Some (typq, typ, params, vals, val_specs) ->
(typq, typ, params, vals, { empty with global = { empty_global_env with val_specs } })
| None -> typ_error l ("Outcome " ^ string_of_id id ^ " does not exist")
Expand Down Expand Up @@ -1030,20 +1068,20 @@ let add_enum_clause id member env =
| None -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist")

let get_enum id env =
match Option.map get_item (Bindings.find_opt id env.global.enums) with
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.enums) with
| Some (_, enum) -> IdSet.elements enum
| None -> typ_error (id_loc id) ("Enumeration " ^ string_of_id id ^ " does not exist")

let get_enums env = Bindings.map (fun item -> item |> get_item |> snd) env.global.enums
let get_enums env = filter_items_with snd env env.global.enums

let is_record id env = Bindings.mem id env.global.records

let get_record id env =
match Option.map get_item (Bindings.find_opt id env.global.records) with
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.records) with
| Some record -> record
| None -> typ_error (id_loc id) ("Struct type " ^ string_of_id id ^ " does not exist")

let get_records env = Bindings.map get_item env.global.records
let get_records env = filter_items env env.global.records

let add_record id typq fields env =
let fields = List.map (fun (typ, id) -> (expand_synonyms env typ, id)) fields in
Expand Down Expand Up @@ -1086,7 +1124,7 @@ let get_accessor_fn record_id field env =
let freshen_bind bind =
List.fold_left (fun bind (kid, _) -> freshen_kid env kid bind) bind (KBindings.bindings env.typ_vars)
in
try freshen_bind (get_item (IdPairMap.find (record_id, field) env.global.accessors))
try freshen_bind (get_item (id_loc field) env (IdPairMap.find (record_id, field) env.global.accessors))
with Not_found ->
typ_error (id_loc field) ("No field accessor found for " ^ string_of_id record_id ^ "." ^ string_of_id field)

Expand Down Expand Up @@ -1173,10 +1211,10 @@ let add_variant_clause id tu env =
env
| None -> typ_error (id_loc id) ("scattered union " ^ string_of_id id ^ " not found")

let get_variants env = Bindings.map get_item env.global.unions
let get_variants env = filter_items env env.global.unions

let get_variant id env =
match Option.map get_item (Bindings.find_opt id env.global.unions) with
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.unions) with
| Some (typq, tus) -> (typq, tus)
| None -> typ_error (id_loc id) ("union " ^ string_of_id id ^ " not found")

Expand All @@ -1193,11 +1231,11 @@ let set_scattered_variant_env ~variant_env id env =
let is_register id env = Bindings.mem id env.global.registers

let get_register id env =
match Option.map get_item (Bindings.find_opt id env.global.registers) with
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.registers) with
| Some typ -> typ
| None -> typ_error (id_loc id) ("No register binding found for " ^ string_of_id id)

let get_registers env = Bindings.map get_item env.global.registers
let get_registers env = filter_items env env.global.registers

let is_extern id env backend =
try not (Ast_util.extern_assoc backend (Bindings.find_opt id env.global.externs) = None) with Not_found -> false
Expand Down Expand Up @@ -1234,16 +1272,26 @@ let lookup_id id env =
match Bindings.find_opt id env.locals with
| Some (mut, typ) -> Local (mut, typ)
| None -> (
match Bindings.find_opt id env.global.letbinds with
| Some { item = typ; _ } -> Local (Immutable, typ)
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.letbinds) with
| Some typ -> Local (Immutable, typ)
| None -> (
match Bindings.find_opt id env.global.registers with
| Some { item = typ; _ } -> Register typ
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.registers) with
| Some typ -> Register typ
| None -> (
match
List.find_opt (fun (_, { item = _, ctors }) -> IdSet.mem id ctors) (Bindings.bindings env.global.enums)
with
| Some (enum, _) -> Enum (mk_typ (Typ_id enum))
| Some (enum_id, item) ->
if item_in_scope env item then Enum (mk_typ (Typ_id enum_id))
else (
let enum_name = string_of_id enum_id in
typ_raise (id_loc id)
(Err_not_in_scope
( Some ("Enumeration " ^ enum_name ^ " containing " ^ string_of_id id ^ " is not in scope"),
Some item.loc
)
)
)
| None -> Unbound id
)
)
Expand Down Expand Up @@ -1288,7 +1336,7 @@ let base_typ_of env typ =
let is_bitfield id env = Bindings.mem id env.global.bitfields

let get_bitfield id env =
match Option.map get_item (Bindings.find_opt id env.global.bitfields) with
match Option.map (get_item (id_loc id) env) (Bindings.find_opt id env.global.bitfields) with
| Some bitfield -> bitfield
| None -> typ_error (id_loc id) ("Could not find bitfield " ^ string_of_id id)

Expand Down
6 changes: 6 additions & 0 deletions src/lib/type_env.mli
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ type env

type t = env

(** This implements the unscope# pragma, which removes an identifier
from scope by setting its module id to -1 (which is never a valid
module id). This is used only to test the typechecker, and should
not be used for any other reason! *)
val unscope_pragma : string -> id -> t -> t

val freshen_bind : t -> typquant * typ -> typquant * typ

val get_default_order : t -> order
Expand Down
8 changes: 8 additions & 0 deletions src/lib/type_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,14 @@ let message_of_type_error =
([coercion; Line "Coercion failed because:"; msg trigger]
@ if not (reasons = []) then Line "Possible reasons:" :: List.map msg reasons else []
)
| Err_not_in_scope (explanation, Some l) ->
Seq
[
Line (Option.value ~default:"Not in scope" explanation);
Line "Try bringing the following definition into scope:";
Location ("", Some "definition here", l, Seq []);
]
| Err_not_in_scope (explanation, None) -> Line (Option.value ~default:"Not in scope" explanation)
in
msg

Expand Down
1 change: 1 addition & 0 deletions src/lib/type_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ type type_error =
| Err_no_num_ident of id
| Err_other of string
| Err_inner of type_error * Parse_ast.l * string * string option * type_error
| Err_not_in_scope of string option * Parse_ast.l option

exception Type_error of Parse_ast.l * type_error

Expand Down
1 change: 1 addition & 0 deletions src/lib/type_internal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ type type_error =
| Err_no_num_ident of id
| Err_other of string
| Err_inner of type_error * Parse_ast.l * string * string option * type_error
| Err_not_in_scope of string option * Parse_ast.l option

let err_because (error1, l, error2) = Err_inner (error1, l, "Caused by", None, error2)

Expand Down
9 changes: 9 additions & 0 deletions test/typecheck/fail/unscope_enum.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Type error:
fail/unscope_enum.sail:14.10-11:
14 | let _ = A;
 | ^
 | Enumeration E containing A is not in scope
 | Try bringing the following definition into scope:
 | fail/unscope_enum.sail:7.5-6:
 | 7 |enum E = A | B | C
 |  | ^ definition here
16 changes: 16 additions & 0 deletions test/typecheck/fail/unscope_enum.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
$option -dmagic_hash

default Order dec

$include <prelude.sail>

enum E = A | B | C

$unscope# enum E

val bar : unit -> unit

function bar() = {
let _ = A;
()
}
Loading

0 comments on commit 8484749

Please sign in to comment.