Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add abstract types and global constraints #412

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion language/sail.ott
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,11 @@ n_constraint :: 'NC_' ::=
| nexp '<=' nexp' :: :: bounded_le
| nexp '<' nexp' :: :: bounded_lt
| nexp != nexp' :: :: not_equal
| kid 'IN' { num1 , ... , numn } :: :: set
| nexp 'IN' { num1 , ... , numn } :: :: set
| n_constraint & n_constraint' :: :: or
| n_constraint | n_constraint' :: :: and
| id ( typ_arg0 , ... , typ_argn ) :: :: app
| id :: :: id
| kid :: :: var
| true :: :: true
| false :: :: false
Expand Down Expand Up @@ -318,6 +319,8 @@ type_def_aux :: 'TD_' ::=
{{ com tagged union type definition}} {{ texlong }}
| typedef id = enumerate { id1 ; ... ; idn semi_opt } :: :: enum
{{ com enumeration type definition}} {{ texlong }}
| typedef id : kind :: :: abstract
{{ com abstract type }}
| bitfield id : typ = { id1 : index_range1 , ... , idn : index_rangen } :: :: bitfield
{{ com register mutable bitfield type definition }} {{ texlong }}

Expand Down Expand Up @@ -760,6 +763,8 @@ def :: 'DEF_' ::=
{{ aux _ def_annot }} {{ auxparam 'a }}
| type_def :: :: type
{{ com type definition }}
| constraint n_constraint :: :: constraint
{{ com top-level constraint }}
| fundef :: :: fundef
{{ com function definition }}
| mapdef :: :: mapdef
Expand Down
1 change: 1 addition & 0 deletions src/bin/sail.ml
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ let rec options =
("-all_modules", Arg.Set opt_all_modules, " use all modules in project file");
("-list_files", Arg.Set Frontend.opt_list_files, " list files used in all project files");
("-config", Arg.String (fun file -> opt_config_file := Some file), "<file> configuration file");
("-abstract_types", Arg.Set Initial_check.opt_abstract_types, " (experimental) allow abstract types");
("-fmt", Arg.Set opt_format, " format input source code");
( "-fmt_backup",
Arg.String (fun suffix -> opt_format_backup := Some suffix),
Expand Down
62 changes: 31 additions & 31 deletions src/lib/ast_util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,13 @@ let rec get_nexp_constant (Nexp_aux (n, _)) =
let rec constraint_simp (NC_aux (nc_aux, l)) =
let nc_aux =
match nc_aux with
| NC_set (nexp, ints) ->
let nexp = nexp_simp nexp in
begin
match nexp with
| Nexp_aux (Nexp_constant c, _) -> if List.exists (fun i -> Big_int.equal c i) ints then NC_true else NC_false
| _ -> NC_set (nexp, ints)
end
| NC_equal (nexp1, nexp2) ->
let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in
if nexp_identical nexp1 nexp2 then NC_true else NC_equal (nexp1, nexp2)
Expand Down Expand Up @@ -538,6 +545,7 @@ let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown)
let nc_lt n1 n2 = NC_aux (NC_bounded_lt (n1, n2), Parse_ast.Unknown)
let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown)
let nc_gt n1 n2 = NC_aux (NC_bounded_gt (n1, n2), Parse_ast.Unknown)
let nc_id id = mk_nc (NC_id id)
let nc_var kid = mk_nc (NC_var kid)
let nc_true = mk_nc NC_true
let nc_false = mk_nc NC_false
Expand Down Expand Up @@ -565,6 +573,7 @@ let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc]))

let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown)

let mk_empty_typquant ~loc:l = TypQ_aux (TypQ_no_forall, l)
let mk_typquant qis = TypQ_aux (TypQ_tq qis, Parse_ast.Unknown)

let mk_fexp id exp = FE_aux (FE_fexp (id, exp), no_annot)
Expand Down Expand Up @@ -824,6 +833,7 @@ and map_def_annot f (DEF_aux (aux, annot)) =
let aux =
match aux with
| DEF_type td -> DEF_type (map_typedef_annot f td)
| DEF_constraint nc -> DEF_constraint nc
| DEF_fundef fd -> DEF_fundef (map_fundef_annot f fd)
| DEF_mapdef md -> DEF_mapdef (map_mapdef_annot f md)
| DEF_outcome (outcome_spec, defs) -> DEF_outcome (outcome_spec, List.map (map_def_annot f) defs)
Expand Down Expand Up @@ -912,6 +922,7 @@ and string_of_typ_arg_aux = function
| A_bool nc -> string_of_n_constraint nc

and string_of_n_constraint = function
| NC_aux (NC_id id, _) -> string_of_id id
| NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2
| NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2
| NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2
Expand All @@ -920,7 +931,7 @@ and string_of_n_constraint = function
| NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2
| NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_set (kid, ns), _) -> string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}"
| NC_aux (NC_set (n, ns), _) -> string_of_nexp n ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}"
| NC_aux (NC_app (Id_aux (Operator op, _), [arg1; arg2]), _) ->
"(" ^ string_of_typ_arg arg1 ^ " " ^ op ^ " " ^ string_of_typ_arg arg2 ^ ")"
| NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")"
Expand Down Expand Up @@ -1137,8 +1148,10 @@ let id_of_type_def_aux = function
| TD_record (id, _, _, _)
| TD_variant (id, _, _, _)
| TD_enum (id, _, _)
| TD_abstract (id, _)
| TD_bitfield (id, _, _) ->
id

let id_of_type_def (TD_aux (td_aux, _)) = id_of_type_def_aux td_aux

let id_of_val_spec (VS_aux (VS_val_spec (_, id, _), _)) = id
Expand Down Expand Up @@ -1207,14 +1220,15 @@ let lex_ord f g x1 x2 y1 y2 = match f x1 x2 with 0 -> g y1 y2 | n -> n

let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) =
match (nc1, nc2) with
| NC_id id1, NC_id id2 -> Id.compare id1 id2
| NC_equal (n1, n2), NC_equal (n3, n4)
| NC_bounded_ge (n1, n2), NC_bounded_ge (n3, n4)
| NC_bounded_gt (n1, n2), NC_bounded_gt (n3, n4)
| NC_bounded_le (n1, n2), NC_bounded_le (n3, n4)
| NC_bounded_lt (n1, n2), NC_bounded_lt (n3, n4)
| NC_not_equal (n1, n2), NC_not_equal (n3, n4) ->
lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4
| NC_set (k1, s1), NC_set (k2, s2) -> lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2
| NC_set (n1, s1), NC_set (n2, s2) -> lex_ord Nexp.compare (Util.compare_list Nat_big_num.compare) n1 n2 s1 s2
| NC_or (nc1, nc2), NC_or (nc3, nc4) | NC_and (nc1, nc2), NC_and (nc3, nc4) ->
lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4
| NC_app (f1, args1), NC_app (f2, args2) -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2
Expand Down Expand Up @@ -1244,6 +1258,8 @@ let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) =
| _, NC_var _ -> 1
| NC_true, _ -> -1
| _, NC_true -> 1
| NC_id _, _ -> -1
| _, NC_id _ -> 1

and typ_compare (Typ_aux (t1, _)) (Typ_aux (t2, _)) =
match (t1, t2) with
Expand Down Expand Up @@ -1298,6 +1314,9 @@ let is_typ_arg_typ = function A_aux (A_typ _, _) -> true | _ -> false

let is_typ_arg_bool = function A_aux (A_bool _, _) -> true | _ -> false

let typ_arg_kind (A_aux (aux, l)) =
match aux with A_typ _ -> K_aux (K_type, l) | A_bool _ -> K_aux (K_bool, l) | A_nexp _ -> K_aux (K_int, l)

module NC = struct
type t = n_constraint
let compare = nc_compare
Expand Down Expand Up @@ -1411,11 +1430,11 @@ let rec kopts_of_constraint (NC_aux (nc, _)) =
| NC_bounded_lt (nexp1, nexp2)
| NC_not_equal (nexp1, nexp2) ->
KOptSet.union (kopts_of_nexp nexp1) (kopts_of_nexp nexp2)
| NC_set (kid, _) -> KOptSet.singleton (mk_kopt K_int kid)
| NC_set (nexp, _) -> kopts_of_nexp nexp
| NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KOptSet.union (kopts_of_constraint nc1) (kopts_of_constraint nc2)
| NC_app (_, args) -> List.fold_left (fun s t -> KOptSet.union s (kopts_of_typ_arg t)) KOptSet.empty args
| NC_var kid -> KOptSet.singleton (mk_kopt K_bool kid)
| NC_true | NC_false -> KOptSet.empty
| NC_id _ | NC_true | NC_false -> KOptSet.empty

and kopts_of_typ (Typ_aux (t, _)) =
match t with
Expand Down Expand Up @@ -1456,11 +1475,11 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) =
| NC_bounded_lt (nexp1, nexp2)
| NC_not_equal (nexp1, nexp2) ->
KidSet.union (tyvars_of_nexp nexp1) (tyvars_of_nexp nexp2)
| NC_set (kid, _) -> KidSet.singleton kid
| NC_set (nexp, _) -> tyvars_of_nexp nexp
| NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2)
| NC_app (_, args) -> List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args
| NC_var kid -> KidSet.singleton kid
| NC_true | NC_false -> KidSet.empty
| NC_id _ | NC_true | NC_false -> KidSet.empty

and tyvars_of_typ (Typ_aux (t, _)) =
match t with
Expand Down Expand Up @@ -1728,13 +1747,14 @@ let rec locate_nexp f (Nexp_aux (nexp_aux, l)) =
let rec locate_nc f (NC_aux (nc_aux, l)) =
let nc_aux =
match nc_aux with
| NC_id id -> NC_id (locate_id f id)
| NC_equal (nexp1, nexp2) -> NC_equal (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_gt (nexp1, nexp2) -> NC_bounded_gt (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_le (nexp1, nexp2) -> NC_bounded_le (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_lt (nexp1, nexp2) -> NC_bounded_lt (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_not_equal (nexp1, nexp2) -> NC_not_equal (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_set (kid, nums) -> NC_set (locate_kid f kid, nums)
| NC_set (nexp, nums) -> NC_set (locate_nexp f nexp, nums)
| NC_or (nc1, nc2) -> NC_or (locate_nc f nc1, locate_nc f nc2)
| NC_and (nc1, nc2) -> NC_and (locate_nc f nc1, locate_nc f nc2)
| NC_true -> NC_true
Expand Down Expand Up @@ -1920,26 +1940,17 @@ and nexp_subst_aux sv subst = function
| Nexp_exp nexp -> Nexp_exp (nexp_subst sv subst nexp)
| Nexp_neg nexp -> Nexp_neg (nexp_subst sv subst nexp)

let rec nexp_set_to_or l subst = function
| [] -> raise (Reporting.err_unreachable l __POS__ "Empty set in constraint")
| [int] -> NC_equal (subst, nconstant int)
| int :: ints -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints))

let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l)

and constraint_subst_aux l sv subst = function
| NC_id id -> NC_id id
| NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_gt (n1, n2) -> NC_bounded_gt (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_lt (n1, n2) -> NC_bounded_lt (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_set (kid, ints) as set_nc -> begin
match subst with
| A_aux (A_nexp (Nexp_aux (Nexp_var kid', _)), _) when Kid.compare kid sv = 0 -> NC_set (kid', ints)
| A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> nexp_set_to_or l n ints
| _ -> set_nc
end
| NC_set (n, ints) -> NC_set (nexp_subst sv subst n, ints)
| NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
| NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
| NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args)
Expand Down Expand Up @@ -2017,25 +2028,14 @@ let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg =
let snc nc = subst_kids_nc substs nc in
let re nc = NC_aux (nc, l) in
match nc with
| NC_id id -> re (NC_id id)
| NC_equal (n1, n2) -> re (NC_equal (snexp n1, snexp n2))
| NC_bounded_ge (n1, n2) -> re (NC_bounded_ge (snexp n1, snexp n2))
| NC_bounded_gt (n1, n2) -> re (NC_bounded_gt (snexp n1, snexp n2))
| NC_bounded_le (n1, n2) -> re (NC_bounded_le (snexp n1, snexp n2))
| NC_bounded_lt (n1, n2) -> re (NC_bounded_lt (snexp n1, snexp n2))
| NC_not_equal (n1, n2) -> re (NC_not_equal (snexp n1, snexp n2))
| NC_set (kid, is) -> begin
match KBindings.find kid substs with
| Nexp_aux (Nexp_constant i, _) ->
if List.exists (fun j -> Big_int.equal i j) is then re NC_true else re NC_false
| nexp -> begin
match List.rev is with
| i :: is ->
let equal_num i = re (NC_equal (nexp, nconstant i)) in
List.fold_left (fun nc i -> re (NC_or (equal_num i, nc))) (equal_num i) is
| [] -> re NC_false
end
| exception Not_found -> n_constraint
end
| NC_set (n, ints) -> re (NC_set (snexp n, ints))
| NC_or (nc1, nc2) -> re (NC_or (snc nc1, snc nc2))
| NC_and (nc1, nc2) -> re (NC_and (snc nc1, snc nc2))
| NC_true | NC_false -> n_constraint
Expand Down
9 changes: 7 additions & 2 deletions src/lib/ast_util.mli
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ val mk_funcl : ?loc:l -> id -> uannot pat -> uannot exp -> uannot funcl
val mk_fundef : uannot funcl list -> uannot def
val mk_val_spec : val_spec_aux -> uannot def
val mk_typschm : typquant -> typ -> typschm
val mk_empty_typquant : loc:l -> typquant
val mk_typquant : quant_item list -> typquant
val mk_qi_id : kind_aux -> kid -> quant_item
val mk_qi_nc : n_constraint -> quant_item
Expand Down Expand Up @@ -215,6 +216,8 @@ val is_typ_arg_nexp : typ_arg -> bool
val is_typ_arg_typ : typ_arg -> bool
val is_typ_arg_bool : typ_arg -> bool

val typ_arg_kind : typ_arg -> kind

(** {2 Sail built-in types} *)

val unknown_typ : typ
Expand Down Expand Up @@ -301,8 +304,9 @@ val nc_or : n_constraint -> n_constraint -> n_constraint
val nc_not : n_constraint -> n_constraint
val nc_true : n_constraint
val nc_false : n_constraint
val nc_set : kid -> Big_int.num list -> n_constraint
val nc_int_set : kid -> int list -> n_constraint
val nc_set : nexp -> Big_int.num list -> n_constraint
val nc_int_set : nexp -> int list -> n_constraint
val nc_id : id -> n_constraint
val nc_var : kid -> n_constraint

(** {2 Functions for building type arguments}*)
Expand Down Expand Up @@ -466,6 +470,7 @@ val string_of_index_range : index_range -> string

val id_of_fundef : 'a fundef -> id
val id_of_mapdef : 'a mapdef -> id
val id_of_type_def_aux : type_def_aux -> id
val id_of_type_def : 'a type_def -> id
val id_of_val_spec : 'a val_spec -> id
val id_of_dec_spec : 'a dec_spec -> id
Expand Down
12 changes: 3 additions & 9 deletions src/lib/callgraph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ let rec constraint_ids' (NC_aux (aux, _)) =
IdSet.union (nexp_ids' n1) (nexp_ids' n2)
| NC_or (nc1, nc2) | NC_and (nc1, nc2) -> IdSet.union (constraint_ids' nc1) (constraint_ids' nc2)
| NC_var _ | NC_true | NC_false | NC_set _ -> IdSet.empty
| NC_id id -> IdSet.singleton id
| NC_app (id, args) -> IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map typ_arg_ids' args))

and nexp_ids' (Nexp_aux (aux, _)) =
Expand Down Expand Up @@ -278,6 +279,7 @@ let add_def_to_graph graph (DEF_aux (def, _)) =
scan_typquant (Type id) typq
| TD_enum (id, ctors, _) ->
List.iter (fun ctor_id -> graph := G.add_edge (Constructor ctor_id) (Type id) !graph) ctors
| TD_abstract (id, _) -> graph := G.add_edges (Type id) [] !graph
| TD_bitfield (id, typ, ranges) ->
graph := G.add_edges (Type id) (List.map (fun id -> Type id) (IdSet.elements (typ_ids typ))) !graph
in
Expand Down Expand Up @@ -378,14 +380,6 @@ let rec graph_of_defs defs =

let graph_of_ast ast = graph_of_defs ast.defs

let id_of_typedef (TD_aux (aux, _)) =
match aux with
| TD_abbrev (id, _, _) -> id
| TD_record (id, _, _, _) -> id
| TD_variant (id, _, _, _) -> id
| TD_enum (id, _, _) -> id
| TD_bitfield (id, _, _) -> id

let id_of_reg_dec (DEC_aux (DEC_reg (_, id, _), _)) = id

let id_of_funcl (FCL_aux (FCL_funcl (id, _), _)) = id
Expand Down Expand Up @@ -426,7 +420,7 @@ let filter_ast_extra cuts g ast keep_std =
let ids = pat_ids pat |> IdSet.elements in
if List.exists (fun id -> NM.mem (Letbind id) g) ids then DEF_aux (DEF_let lb, def_annot) :: filter_ast' g defs
else filter_ast' g defs
| DEF_aux (DEF_type tdef, def_annot) :: defs when NM.mem (Type (id_of_typedef tdef)) g ->
| DEF_aux (DEF_type tdef, def_annot) :: defs when NM.mem (Type (id_of_type_def tdef)) g ->
DEF_aux (DEF_type tdef, def_annot) :: filter_ast' g defs
| DEF_aux (DEF_type _, _) :: defs -> filter_ast' g defs
| DEF_aux (DEF_measure (id, _, _), _) :: defs when NS.mem (Function id) cuts -> filter_ast' g defs
Expand Down
4 changes: 2 additions & 2 deletions src/lib/constant_propagation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ let lit_match = function
let fabricate_nexp_exist env l typ kids nc typ' =
match (kids, nc, Env.expand_synonyms env typ') with
| ( [kid],
NC_aux (NC_set (kid', i :: _), _),
NC_aux (NC_set (Nexp_aux (Nexp_var kid', _), i :: _), _),
Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid'', _)), _)]), _) )
when Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 ->
Nexp_aux (Nexp_constant i, Unknown)
Expand All @@ -154,7 +154,7 @@ let fabricate_nexp_exist env l typ kids nc typ' =
when Kid.compare kid kid'' = 0 ->
nint 32
| ( [kid],
NC_aux (NC_set (kid', i :: _), _),
NC_aux (NC_set (Nexp_aux (Nexp_var kid', _), i :: _), _),
Typ_aux
( Typ_app
( Id_aux (Id "range", _),
Expand Down
Loading
Loading