From a53f3a06889b56556f77623b4fd22a71c26264a2 Mon Sep 17 00:00:00 2001 From: Alasdair Date: Fri, 1 Dec 2023 14:08:47 +0000 Subject: [PATCH] Add abstract types and global constraints Will likely fail with anything other than --just-check, as only implemented in type system and parser for now --- language/sail.ott | 7 +- src/bin/sail.ml | 1 + src/lib/ast_util.ml | 62 ++-- src/lib/ast_util.mli | 9 +- src/lib/callgraph.ml | 12 +- src/lib/constant_propagation.ml | 4 +- src/lib/constraint.ml | 107 ++++-- src/lib/constraint.mli | 8 +- src/lib/effects.ml | 1 + src/lib/frontend.ml | 25 ++ src/lib/frontend.mli | 9 +- src/lib/initial_check.ml | 23 +- src/lib/initial_check.mli | 4 + src/lib/jib_compile.ml | 2 + src/lib/monomorphise.ml | 25 +- src/lib/parse_ast.ml | 2 + src/lib/parser.mly | 4 + src/lib/pretty_print_sail.ml | 9 +- src/lib/rewriter.ml | 4 +- src/lib/rewrites.ml | 1 + src/lib/spec_analysis.ml | 6 +- src/lib/specialize.ml | 4 +- src/lib/type_check.ml | 61 +-- src/lib/type_check.mli | 4 +- src/lib/type_env.ml | 350 ++++++++++-------- src/lib/type_env.mli | 8 +- src/lib/type_internal.ml | 6 +- src/sail_coq_backend/pretty_print_coq.ml | 21 +- src/sail_latex_backend/latex.ml | 1 + src/sail_lem_backend/pretty_print_lem.ml | 1 + src/sail_ocaml_backend/ocaml_backend.ml | 4 +- .../fail/abstract_bool_inconsistent.expect | 5 + .../fail/abstract_bool_inconsistent.sail | 8 + .../fail/global_false_constraint.expect | 5 + .../fail/global_false_constraint.sail | 2 + test/typecheck/pass/abstract_bool.sail | 21 ++ test/typecheck/pass/abstract_bool2.sail | 16 + .../pass/complex_exist_sat/v2.expect | 2 +- .../pass/constrained_struct/v1.expect | 2 +- test/typecheck/pass/constraint_syn.sail | 27 ++ .../typecheck/pass/existential_ast3/v1.expect | 2 +- .../typecheck/pass/existential_ast3/v2.expect | 2 +- .../typecheck/pass/existential_ast3/v3.expect | 2 +- test/typecheck/pass/reg_32_64/v1.expect | 2 +- 44 files changed, 572 insertions(+), 309 deletions(-) create mode 100644 test/typecheck/fail/abstract_bool_inconsistent.expect create mode 100644 test/typecheck/fail/abstract_bool_inconsistent.sail create mode 100644 test/typecheck/fail/global_false_constraint.expect create mode 100644 test/typecheck/fail/global_false_constraint.sail create mode 100644 test/typecheck/pass/abstract_bool.sail create mode 100644 test/typecheck/pass/abstract_bool2.sail create mode 100644 test/typecheck/pass/constraint_syn.sail diff --git a/language/sail.ott b/language/sail.ott index 12de6d03e..e4c979c71 100644 --- a/language/sail.ott +++ b/language/sail.ott @@ -266,10 +266,11 @@ n_constraint :: 'NC_' ::= | nexp '<=' nexp' :: :: bounded_le | nexp '<' nexp' :: :: bounded_lt | nexp != nexp' :: :: not_equal - | kid 'IN' { num1 , ... , numn } :: :: set + | nexp 'IN' { num1 , ... , numn } :: :: set | n_constraint & n_constraint' :: :: or | n_constraint | n_constraint' :: :: and | id ( typ_arg0 , ... , typ_argn ) :: :: app + | id :: :: id | kid :: :: var | true :: :: true | false :: :: false @@ -318,6 +319,8 @@ type_def_aux :: 'TD_' ::= {{ com tagged union type definition}} {{ texlong }} | typedef id = enumerate { id1 ; ... ; idn semi_opt } :: :: enum {{ com enumeration type definition}} {{ texlong }} + | typedef id : kind :: :: abstract + {{ com abstract type }} | bitfield id : typ = { id1 : index_range1 , ... , idn : index_rangen } :: :: bitfield {{ com register mutable bitfield type definition }} {{ texlong }} @@ -760,6 +763,8 @@ def :: 'DEF_' ::= {{ aux _ def_annot }} {{ auxparam 'a }} | type_def :: :: type {{ com type definition }} + | constraint n_constraint :: :: constraint + {{ com top-level constraint }} | fundef :: :: fundef {{ com function definition }} | mapdef :: :: mapdef diff --git a/src/bin/sail.ml b/src/bin/sail.ml index 4815b7375..d90169cd1 100644 --- a/src/bin/sail.ml +++ b/src/bin/sail.ml @@ -218,6 +218,7 @@ let rec options = ("-all_modules", Arg.Set opt_all_modules, " use all modules in project file"); ("-list_files", Arg.Set Frontend.opt_list_files, " list files used in all project files"); ("-config", Arg.String (fun file -> opt_config_file := Some file), " configuration file"); + ("-abstract_types", Arg.Set Initial_check.opt_abstract_types, " (experimental) allow abstract types"); ("-fmt", Arg.Set opt_format, " format input source code"); ( "-fmt_backup", Arg.String (fun suffix -> opt_format_backup := Some suffix), diff --git a/src/lib/ast_util.ml b/src/lib/ast_util.ml index 4f157cd30..db928b910 100644 --- a/src/lib/ast_util.ml +++ b/src/lib/ast_util.ml @@ -409,6 +409,13 @@ let rec get_nexp_constant (Nexp_aux (n, _)) = let rec constraint_simp (NC_aux (nc_aux, l)) = let nc_aux = match nc_aux with + | NC_set (nexp, ints) -> + let nexp = nexp_simp nexp in + begin + match nexp with + | Nexp_aux (Nexp_constant c, _) -> if List.exists (fun i -> Big_int.equal c i) ints then NC_true else NC_false + | _ -> NC_set (nexp, ints) + end | NC_equal (nexp1, nexp2) -> let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in if nexp_identical nexp1 nexp2 then NC_true else NC_equal (nexp1, nexp2) @@ -538,6 +545,7 @@ let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown) let nc_lt n1 n2 = NC_aux (NC_bounded_lt (n1, n2), Parse_ast.Unknown) let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown) let nc_gt n1 n2 = NC_aux (NC_bounded_gt (n1, n2), Parse_ast.Unknown) +let nc_id id = mk_nc (NC_id id) let nc_var kid = mk_nc (NC_var kid) let nc_true = mk_nc NC_true let nc_false = mk_nc NC_false @@ -565,6 +573,7 @@ let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc])) let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown) +let mk_empty_typquant ~loc:l = TypQ_aux (TypQ_no_forall, l) let mk_typquant qis = TypQ_aux (TypQ_tq qis, Parse_ast.Unknown) let mk_fexp id exp = FE_aux (FE_fexp (id, exp), no_annot) @@ -824,6 +833,7 @@ and map_def_annot f (DEF_aux (aux, annot)) = let aux = match aux with | DEF_type td -> DEF_type (map_typedef_annot f td) + | DEF_constraint nc -> DEF_constraint nc | DEF_fundef fd -> DEF_fundef (map_fundef_annot f fd) | DEF_mapdef md -> DEF_mapdef (map_mapdef_annot f md) | DEF_outcome (outcome_spec, defs) -> DEF_outcome (outcome_spec, List.map (map_def_annot f) defs) @@ -912,6 +922,7 @@ and string_of_typ_arg_aux = function | A_bool nc -> string_of_n_constraint nc and string_of_n_constraint = function + | NC_aux (NC_id id, _) -> string_of_id id | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2 @@ -920,7 +931,7 @@ and string_of_n_constraint = function | NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2 | NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" - | NC_aux (NC_set (kid, ns), _) -> string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}" + | NC_aux (NC_set (n, ns), _) -> string_of_nexp n ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_app (Id_aux (Operator op, _), [arg1; arg2]), _) -> "(" ^ string_of_typ_arg arg1 ^ " " ^ op ^ " " ^ string_of_typ_arg arg2 ^ ")" | NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")" @@ -1137,8 +1148,10 @@ let id_of_type_def_aux = function | TD_record (id, _, _, _) | TD_variant (id, _, _, _) | TD_enum (id, _, _) + | TD_abstract (id, _) | TD_bitfield (id, _, _) -> id + let id_of_type_def (TD_aux (td_aux, _)) = id_of_type_def_aux td_aux let id_of_val_spec (VS_aux (VS_val_spec (_, id, _), _)) = id @@ -1207,6 +1220,7 @@ let lex_ord f g x1 x2 y1 y2 = match f x1 x2 with 0 -> g y1 y2 | n -> n let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) = match (nc1, nc2) with + | NC_id id1, NC_id id2 -> Id.compare id1 id2 | NC_equal (n1, n2), NC_equal (n3, n4) | NC_bounded_ge (n1, n2), NC_bounded_ge (n3, n4) | NC_bounded_gt (n1, n2), NC_bounded_gt (n3, n4) @@ -1214,7 +1228,7 @@ let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) = | NC_bounded_lt (n1, n2), NC_bounded_lt (n3, n4) | NC_not_equal (n1, n2), NC_not_equal (n3, n4) -> lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 - | NC_set (k1, s1), NC_set (k2, s2) -> lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2 + | NC_set (n1, s1), NC_set (n2, s2) -> lex_ord Nexp.compare (Util.compare_list Nat_big_num.compare) n1 n2 s1 s2 | NC_or (nc1, nc2), NC_or (nc3, nc4) | NC_and (nc1, nc2), NC_and (nc3, nc4) -> lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4 | NC_app (f1, args1), NC_app (f2, args2) -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2 @@ -1244,6 +1258,8 @@ let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) = | _, NC_var _ -> 1 | NC_true, _ -> -1 | _, NC_true -> 1 + | NC_id _, _ -> -1 + | _, NC_id _ -> 1 and typ_compare (Typ_aux (t1, _)) (Typ_aux (t2, _)) = match (t1, t2) with @@ -1298,6 +1314,9 @@ let is_typ_arg_typ = function A_aux (A_typ _, _) -> true | _ -> false let is_typ_arg_bool = function A_aux (A_bool _, _) -> true | _ -> false +let typ_arg_kind (A_aux (aux, l)) = + match aux with A_typ _ -> K_aux (K_type, l) | A_bool _ -> K_aux (K_bool, l) | A_nexp _ -> K_aux (K_int, l) + module NC = struct type t = n_constraint let compare = nc_compare @@ -1411,11 +1430,11 @@ let rec kopts_of_constraint (NC_aux (nc, _)) = | NC_bounded_lt (nexp1, nexp2) | NC_not_equal (nexp1, nexp2) -> KOptSet.union (kopts_of_nexp nexp1) (kopts_of_nexp nexp2) - | NC_set (kid, _) -> KOptSet.singleton (mk_kopt K_int kid) + | NC_set (nexp, _) -> kopts_of_nexp nexp | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KOptSet.union (kopts_of_constraint nc1) (kopts_of_constraint nc2) | NC_app (_, args) -> List.fold_left (fun s t -> KOptSet.union s (kopts_of_typ_arg t)) KOptSet.empty args | NC_var kid -> KOptSet.singleton (mk_kopt K_bool kid) - | NC_true | NC_false -> KOptSet.empty + | NC_id _ | NC_true | NC_false -> KOptSet.empty and kopts_of_typ (Typ_aux (t, _)) = match t with @@ -1456,11 +1475,11 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) = | NC_bounded_lt (nexp1, nexp2) | NC_not_equal (nexp1, nexp2) -> KidSet.union (tyvars_of_nexp nexp1) (tyvars_of_nexp nexp2) - | NC_set (kid, _) -> KidSet.singleton kid + | NC_set (nexp, _) -> tyvars_of_nexp nexp | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2) | NC_app (_, args) -> List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args | NC_var kid -> KidSet.singleton kid - | NC_true | NC_false -> KidSet.empty + | NC_id _ | NC_true | NC_false -> KidSet.empty and tyvars_of_typ (Typ_aux (t, _)) = match t with @@ -1728,13 +1747,14 @@ let rec locate_nexp f (Nexp_aux (nexp_aux, l)) = let rec locate_nc f (NC_aux (nc_aux, l)) = let nc_aux = match nc_aux with + | NC_id id -> NC_id (locate_id f id) | NC_equal (nexp1, nexp2) -> NC_equal (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_gt (nexp1, nexp2) -> NC_bounded_gt (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_le (nexp1, nexp2) -> NC_bounded_le (locate_nexp f nexp1, locate_nexp f nexp2) | NC_bounded_lt (nexp1, nexp2) -> NC_bounded_lt (locate_nexp f nexp1, locate_nexp f nexp2) | NC_not_equal (nexp1, nexp2) -> NC_not_equal (locate_nexp f nexp1, locate_nexp f nexp2) - | NC_set (kid, nums) -> NC_set (locate_kid f kid, nums) + | NC_set (nexp, nums) -> NC_set (locate_nexp f nexp, nums) | NC_or (nc1, nc2) -> NC_or (locate_nc f nc1, locate_nc f nc2) | NC_and (nc1, nc2) -> NC_and (locate_nc f nc1, locate_nc f nc2) | NC_true -> NC_true @@ -1920,26 +1940,17 @@ and nexp_subst_aux sv subst = function | Nexp_exp nexp -> Nexp_exp (nexp_subst sv subst nexp) | Nexp_neg nexp -> Nexp_neg (nexp_subst sv subst nexp) -let rec nexp_set_to_or l subst = function - | [] -> raise (Reporting.err_unreachable l __POS__ "Empty set in constraint") - | [int] -> NC_equal (subst, nconstant int) - | int :: ints -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints)) - let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l) and constraint_subst_aux l sv subst = function + | NC_id id -> NC_id id | NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_gt (n1, n2) -> NC_bounded_gt (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_bounded_lt (n1, n2) -> NC_bounded_lt (nexp_subst sv subst n1, nexp_subst sv subst n2) | NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2) - | NC_set (kid, ints) as set_nc -> begin - match subst with - | A_aux (A_nexp (Nexp_aux (Nexp_var kid', _)), _) when Kid.compare kid sv = 0 -> NC_set (kid', ints) - | A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> nexp_set_to_or l n ints - | _ -> set_nc - end + | NC_set (n, ints) -> NC_set (nexp_subst sv subst n, ints) | NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2) | NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2) | NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args) @@ -2017,25 +2028,14 @@ let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg = let snc nc = subst_kids_nc substs nc in let re nc = NC_aux (nc, l) in match nc with + | NC_id id -> re (NC_id id) | NC_equal (n1, n2) -> re (NC_equal (snexp n1, snexp n2)) | NC_bounded_ge (n1, n2) -> re (NC_bounded_ge (snexp n1, snexp n2)) | NC_bounded_gt (n1, n2) -> re (NC_bounded_gt (snexp n1, snexp n2)) | NC_bounded_le (n1, n2) -> re (NC_bounded_le (snexp n1, snexp n2)) | NC_bounded_lt (n1, n2) -> re (NC_bounded_lt (snexp n1, snexp n2)) | NC_not_equal (n1, n2) -> re (NC_not_equal (snexp n1, snexp n2)) - | NC_set (kid, is) -> begin - match KBindings.find kid substs with - | Nexp_aux (Nexp_constant i, _) -> - if List.exists (fun j -> Big_int.equal i j) is then re NC_true else re NC_false - | nexp -> begin - match List.rev is with - | i :: is -> - let equal_num i = re (NC_equal (nexp, nconstant i)) in - List.fold_left (fun nc i -> re (NC_or (equal_num i, nc))) (equal_num i) is - | [] -> re NC_false - end - | exception Not_found -> n_constraint - end + | NC_set (n, ints) -> re (NC_set (snexp n, ints)) | NC_or (nc1, nc2) -> re (NC_or (snc nc1, snc nc2)) | NC_and (nc1, nc2) -> re (NC_and (snc nc1, snc nc2)) | NC_true | NC_false -> n_constraint diff --git a/src/lib/ast_util.mli b/src/lib/ast_util.mli index 8dddee4f0..83afe890c 100644 --- a/src/lib/ast_util.mli +++ b/src/lib/ast_util.mli @@ -166,6 +166,7 @@ val mk_funcl : ?loc:l -> id -> uannot pat -> uannot exp -> uannot funcl val mk_fundef : uannot funcl list -> uannot def val mk_val_spec : val_spec_aux -> uannot def val mk_typschm : typquant -> typ -> typschm +val mk_empty_typquant : loc:l -> typquant val mk_typquant : quant_item list -> typquant val mk_qi_id : kind_aux -> kid -> quant_item val mk_qi_nc : n_constraint -> quant_item @@ -215,6 +216,8 @@ val is_typ_arg_nexp : typ_arg -> bool val is_typ_arg_typ : typ_arg -> bool val is_typ_arg_bool : typ_arg -> bool +val typ_arg_kind : typ_arg -> kind + (** {2 Sail built-in types} *) val unknown_typ : typ @@ -301,8 +304,9 @@ val nc_or : n_constraint -> n_constraint -> n_constraint val nc_not : n_constraint -> n_constraint val nc_true : n_constraint val nc_false : n_constraint -val nc_set : kid -> Big_int.num list -> n_constraint -val nc_int_set : kid -> int list -> n_constraint +val nc_set : nexp -> Big_int.num list -> n_constraint +val nc_int_set : nexp -> int list -> n_constraint +val nc_id : id -> n_constraint val nc_var : kid -> n_constraint (** {2 Functions for building type arguments}*) @@ -466,6 +470,7 @@ val string_of_index_range : index_range -> string val id_of_fundef : 'a fundef -> id val id_of_mapdef : 'a mapdef -> id +val id_of_type_def_aux : type_def_aux -> id val id_of_type_def : 'a type_def -> id val id_of_val_spec : 'a val_spec -> id val id_of_dec_spec : 'a dec_spec -> id diff --git a/src/lib/callgraph.ml b/src/lib/callgraph.ml index 5369bcd67..df85e0d8d 100644 --- a/src/lib/callgraph.ml +++ b/src/lib/callgraph.ml @@ -130,6 +130,7 @@ let rec constraint_ids' (NC_aux (aux, _)) = IdSet.union (nexp_ids' n1) (nexp_ids' n2) | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> IdSet.union (constraint_ids' nc1) (constraint_ids' nc2) | NC_var _ | NC_true | NC_false | NC_set _ -> IdSet.empty + | NC_id id -> IdSet.singleton id | NC_app (id, args) -> IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map typ_arg_ids' args)) and nexp_ids' (Nexp_aux (aux, _)) = @@ -278,6 +279,7 @@ let add_def_to_graph graph (DEF_aux (def, _)) = scan_typquant (Type id) typq | TD_enum (id, ctors, _) -> List.iter (fun ctor_id -> graph := G.add_edge (Constructor ctor_id) (Type id) !graph) ctors + | TD_abstract (id, _) -> graph := G.add_edges (Type id) [] !graph | TD_bitfield (id, typ, ranges) -> graph := G.add_edges (Type id) (List.map (fun id -> Type id) (IdSet.elements (typ_ids typ))) !graph in @@ -378,14 +380,6 @@ let rec graph_of_defs defs = let graph_of_ast ast = graph_of_defs ast.defs -let id_of_typedef (TD_aux (aux, _)) = - match aux with - | TD_abbrev (id, _, _) -> id - | TD_record (id, _, _, _) -> id - | TD_variant (id, _, _, _) -> id - | TD_enum (id, _, _) -> id - | TD_bitfield (id, _, _) -> id - let id_of_reg_dec (DEC_aux (DEC_reg (_, id, _), _)) = id let id_of_funcl (FCL_aux (FCL_funcl (id, _), _)) = id @@ -426,7 +420,7 @@ let filter_ast_extra cuts g ast keep_std = let ids = pat_ids pat |> IdSet.elements in if List.exists (fun id -> NM.mem (Letbind id) g) ids then DEF_aux (DEF_let lb, def_annot) :: filter_ast' g defs else filter_ast' g defs - | DEF_aux (DEF_type tdef, def_annot) :: defs when NM.mem (Type (id_of_typedef tdef)) g -> + | DEF_aux (DEF_type tdef, def_annot) :: defs when NM.mem (Type (id_of_type_def tdef)) g -> DEF_aux (DEF_type tdef, def_annot) :: filter_ast' g defs | DEF_aux (DEF_type _, _) :: defs -> filter_ast' g defs | DEF_aux (DEF_measure (id, _, _), _) :: defs when NS.mem (Function id) cuts -> filter_ast' g defs diff --git a/src/lib/constant_propagation.ml b/src/lib/constant_propagation.ml index fd20eda36..6ced6cff8 100644 --- a/src/lib/constant_propagation.ml +++ b/src/lib/constant_propagation.ml @@ -144,7 +144,7 @@ let lit_match = function let fabricate_nexp_exist env l typ kids nc typ' = match (kids, nc, Env.expand_synonyms env typ') with | ( [kid], - NC_aux (NC_set (kid', i :: _), _), + NC_aux (NC_set (Nexp_aux (Nexp_var kid', _), i :: _), _), Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid'', _)), _)]), _) ) when Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 -> Nexp_aux (Nexp_constant i, Unknown) @@ -154,7 +154,7 @@ let fabricate_nexp_exist env l typ kids nc typ' = when Kid.compare kid kid'' = 0 -> nint 32 | ( [kid], - NC_aux (NC_set (kid', i :: _), _), + NC_aux (NC_set (Nexp_aux (Nexp_var kid', _), i :: _), _), Typ_aux ( Typ_app ( Id_aux (Id "range", _), diff --git a/src/lib/constraint.ml b/src/lib/constraint.ml index fe3586f89..ad820605b 100644 --- a/src/lib/constraint.ml +++ b/src/lib/constraint.ml @@ -72,11 +72,19 @@ open Util let opt_smt_verbose = ref false -type solver = { command : string; header : string; footer : string; negative_literals : bool; uninterpret_power : bool } +type solver = { + command : string; + args : string -> string Array.t; + header : string; + footer : string; + negative_literals : bool; + uninterpret_power : bool; +} let cvc4_solver = { - command = "cvc4 -L smtlib2 --tlimit=2000"; + command = "cvc4"; + args = (fun input -> [| "-L"; "smtlib2"; "--tlimit=2000"; input |]); header = "(set-logic QF_UFNIA)\n"; footer = ""; negative_literals = false; @@ -86,6 +94,7 @@ let cvc4_solver = let mathsat_solver = { command = "mathsat"; + args = (fun input -> [| input |]); header = "(set-logic QF_UFLIA)\n"; footer = ""; negative_literals = false; @@ -94,7 +103,8 @@ let mathsat_solver = let z3_solver = { - command = "z3 -t:1000 -T:10"; + command = "z3"; + args = (fun input -> [| "-t:1000"; "-T:10"; input |]); (* Using push and pop is much faster, I believe because incremental mode uses a different solver. *) header = "(push)\n"; @@ -106,25 +116,23 @@ let z3_solver = let yices_solver = { command = "yices-smt2 --timeout=2"; + args = (fun input -> [| "--timeout=2"; input |]); header = "(set-logic QF_UFLIA)\n"; footer = ""; negative_literals = false; uninterpret_power = true; } -let vampire_solver = +let alt_ergo_solver = { - (* vampire sometimes likes to ignore its time limit *) - command = "timeout -s SIGKILL 3s vampire --time_limit 2s --input_syntax smtlib2 --mode smtcomp"; + command = "alt-ergo"; + args = (fun input -> [| input |]); header = ""; footer = ""; negative_literals = false; uninterpret_power = true; } -let alt_ergo_solver = - { command = "alt-ergo"; header = ""; footer = ""; negative_literals = false; uninterpret_power = true } - let opt_solver = ref z3_solver let set_solver = function @@ -132,7 +140,6 @@ let set_solver = function | "alt-ergo" -> opt_solver := alt_ergo_solver | "cvc4" -> opt_solver := cvc4_solver | "mathsat" -> opt_solver := mathsat_solver - | "vampire" -> opt_solver := vampire_solver | "yices" -> opt_solver := yices_solver | unknown -> prerr_endline ("Unrecognised SMT solver " ^ unknown) @@ -168,9 +175,9 @@ let rec add_list buf sep add_elem = function let smt_type l = function | K_int -> Atom "Int" | K_bool -> Atom "Bool" - | _ -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kinded variable to SMT solver") + | K_type -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type kinded variable to SMT solver") -let to_smt l vars constr = +let to_smt l abstract vars constr = (* Numbering all SMT variables v0, ... vn, rather than generating names based on their Sail names (e.g. using zencode) ensures that alpha-equivalent constraints generate the same SMT problem, which @@ -189,10 +196,20 @@ let to_smt l vars constr = let exponentials = ref [] in + let abstract_decs = + abstract |> Bindings.bindings + |> List.filter_map (fun (id, kind) -> + match kind with + | K_aux (K_type, _) -> None + | _ -> + Some (sfun "declare-const" [Atom (Util.zencode_string (string_of_id id)); smt_type l (unaux_kind kind)]) + ) + in + (* var_decs outputs the list of variables to be used by the SMT solver in SMTLIB v2.0 format. It takes a kind_aux KBindings, as returned by Type_check.get_typ_vars *) - let var_decs l (vars : kind_aux KBindings.t) : sexpr list = + let var_decs (vars : kind_aux KBindings.t) : sexpr list = vars |> KBindings.bindings |> List.map (fun (v, k) -> sfun "declare-const" [fst (smt_var v); smt_type l k]) in let rec smt_nexp (Nexp_aux (aux, _) : nexp) : sexpr = @@ -223,13 +240,14 @@ let to_smt l vars constr = in let rec smt_constraint (NC_aux (aux, _) : n_constraint) : sexpr = match aux with + | NC_id id -> Atom (Util.zencode_string (string_of_id id)) | NC_equal (nexp1, nexp2) -> sfun "=" [smt_nexp nexp1; smt_nexp nexp2] | NC_bounded_le (nexp1, nexp2) -> sfun "<=" [smt_nexp nexp1; smt_nexp nexp2] | NC_bounded_lt (nexp1, nexp2) -> sfun "<" [smt_nexp nexp1; smt_nexp nexp2] | NC_bounded_ge (nexp1, nexp2) -> sfun ">=" [smt_nexp nexp1; smt_nexp nexp2] | NC_bounded_gt (nexp1, nexp2) -> sfun ">" [smt_nexp nexp1; smt_nexp nexp2] | NC_not_equal (nexp1, nexp2) -> sfun "not" [sfun "=" [smt_nexp nexp1; smt_nexp nexp2]] - | NC_set (v, ints) -> sfun "or" (List.map (fun i -> sfun "=" [fst (smt_var v); Atom (Big_int.to_string i)]) ints) + | NC_set (nexp, ints) -> sfun "or" (List.map (fun i -> sfun "=" [smt_nexp nexp; Atom (Big_int.to_string i)]) ints) | NC_or (nc1, nc2) -> sfun "or" [smt_constraint nc1; smt_constraint nc2] | NC_and (nc1, nc2) -> sfun "and" [smt_constraint nc1; smt_constraint nc2] | NC_app (id, args) -> sfun (string_of_id id) (List.map smt_typ_arg args) @@ -243,18 +261,19 @@ let to_smt l vars constr = | _ -> raise (Reporting.err_unreachable l __POS__ "Tried to pass Type or Order kind to SMT function") in let smt_constr = smt_constraint constr in - (var_decs l vars, smt_constr, smt_var, !exponentials) + (abstract_decs @ var_decs vars, smt_constr, smt_var, !exponentials) let sailexp_concrete n = List.init (n + 1) (fun i -> sfun "=" [sfun "sailexp" [Atom (string_of_int i)]; Atom (Big_int.to_string (Big_int.pow_int_positive 2 i))] ) -let smtlib_of_constraints ?(get_model = false) l vars extra constr : string * (kid -> sexpr * bool) * sexpr list = +let smtlib_of_constraints ?(get_model = false) l abstract vars extra constr : + string * (kid -> sexpr * bool) * sexpr list = let open Buffer in let buf = create 512 in add_string buf !opt_solver.header; - let variables, problem, var_map, exponentials = to_smt l vars constr in + let variables, problem, var_map, exponentials = to_smt l abstract vars constr in add_list buf '\n' add_sexpr variables; add_char buf '\n'; if !opt_solver.uninterpret_power then add_string buf "(declare-fun sailexp (Int) Int)\n"; @@ -326,7 +345,7 @@ let constraint_to_smt l constr = kopts_of_constraint constr |> KOptSet.elements |> List.map kopt_pair |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty in - let vars, sexpr, var_map, exponentials = to_smt l vars constr in + let vars, sexpr, var_map, exponentials = to_smt l Bindings.empty vars constr in let vars = string_of_list "\n" pp_sexpr vars in ( vars ^ "\n(assert " ^ pp_sexpr sexpr ^ ")", (fun v -> @@ -336,13 +355,13 @@ let constraint_to_smt l constr = List.map pp_sexpr exponentials ) -let rec call_smt' l extra constraints = +let rec call_smt' l abstract extra constraints = let vars = kopts_of_constraint constraints |> KOptSet.elements |> List.map kopt_pair |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty in let problems = [constraints] in - let smt_file, _, exponentials = smtlib_of_constraints l vars extra constraints in + let smt_file, _, exponentials = smtlib_of_constraints l abstract vars extra constraints in if !opt_smt_verbose then prerr_endline (Printf.sprintf "SMTLIB2 constraints are: \n%s%!" smt_file); @@ -371,7 +390,11 @@ let rec call_smt' l extra constraints = let status, smt_output, smt_errors = try let smt_out, smt_in, smt_err = - Unix.open_process_full (!opt_solver.command ^ " " ^ input_file) (Unix.environment ()) + let cmd = + !opt_solver.command ^ " " + ^ Util.string_of_list " " (fun x -> x) (Array.to_list (!opt_solver.args input_file)) + in + Unix.open_process_full cmd (Unix.environment ()) in let smt_output = try List.combine problems (input_lines smt_out (List.length problems)) @@ -418,7 +441,7 @@ let rec call_smt' l extra constraints = then try replacing `2^` with an uninterpreted function to see if the problem would be unsat in that case. *) opt_solver := { !opt_solver with uninterpret_power = true }; - let result = call_smt_uninterpret_power ~bound:64 l constraints in + let result = call_smt_uninterpret_power ~bound:64 l abstract constraints in opt_solver := { !opt_solver with uninterpret_power = false }; result | Unknown -> Unknown @@ -426,31 +449,31 @@ let rec call_smt' l extra constraints = exponentials ) -and call_smt_uninterpret_power ~bound l constraints = - match call_smt' l (sailexp_concrete bound) constraints with +and call_smt_uninterpret_power ~bound l abstract constraints = + match call_smt' l abstract (sailexp_concrete bound) constraints with | Unsat, _ -> Unsat | Sat, exponentials -> begin - match call_smt' l (sailexp_concrete bound @ List.map bound_exponential exponentials) constraints with + match call_smt' l abstract (sailexp_concrete bound @ List.map bound_exponential exponentials) constraints with | Sat, _ -> Sat | _ -> Unknown end | _ -> Unknown -let call_smt l constraints = +let call_smt l abstract constraints = let t = Profile.start_smt () in let result = - if !opt_solver.uninterpret_power then call_smt_uninterpret_power ~bound:64 l constraints - else fst (call_smt' l [] constraints) + if !opt_solver.uninterpret_power then call_smt_uninterpret_power ~bound:64 l abstract constraints + else fst (call_smt' l abstract [] constraints) in Profile.finish_smt t; result -let solve_smt_file l extra constraints = +let solve_smt_file l abstract extra constraints = let vars = kopts_of_constraint constraints |> KOptSet.elements |> List.map kopt_pair |> List.fold_left (fun m (k, v) -> KBindings.add k v m) KBindings.empty in - smtlib_of_constraints ~get_model:true l vars extra constraints + smtlib_of_constraints ~get_model:true l abstract vars extra constraints let call_smt_solve l smt_file smt_vars var = let smt_var = pp_sexpr (fst (smt_vars var)) in @@ -539,23 +562,23 @@ let call_smt_solve_bitvector l smt_file smt_vars = smt_vars |> Util.option_all -let solve_smt l constraints var = - let smt_file, smt_vars, _ = solve_smt_file l [] constraints in +let solve_smt l abstract constraints var = + let smt_file, smt_vars, _ = solve_smt_file l abstract [] constraints in call_smt_solve l smt_file smt_vars var -let solve_all_smt l constraints var = +let solve_all_smt l abstract constraints var = let rec aux results = let constraints = List.fold_left (fun ncs r -> nc_and ncs (nc_neq (nconstant r) (nvar var))) constraints results in - match solve_smt l constraints var with + match solve_smt l abstract constraints var with | Some result -> aux (result :: results) | None -> ( - match call_smt l constraints with Unsat -> Some results | _ -> None + match call_smt l abstract constraints with Unsat -> Some results | _ -> None ) in aux [] -let solve_unique_smt' l constraints exp_defn exp_bound var = - let smt_file, smt_vars, exponentials = solve_smt_file l (exp_defn @ exp_bound) constraints in +let solve_unique_smt' l abstract constraints exp_defn exp_bound var = + let smt_file, smt_vars, exponentials = solve_smt_file l abstract (exp_defn @ exp_bound) constraints in let digest = Digest.string (smt_file ^ pp_sexpr (fst (smt_vars var))) in let result = match DigestMap.find_opt digest !known_uniques with @@ -565,7 +588,9 @@ let solve_unique_smt' l constraints exp_defn exp_bound var = match call_smt_solve l smt_file smt_vars var with | Some result -> let t = Profile.start_smt () in - let smt_result' = fst (call_smt' l exp_defn (nc_and constraints (nc_neq (nconstant result) (nvar var)))) in + let smt_result' = + fst (call_smt' l abstract exp_defn (nc_and constraints (nc_neq (nconstant result) (nvar var)))) + in Profile.finish_smt t; begin match smt_result' with @@ -588,17 +613,17 @@ let solve_unique_smt' l constraints exp_defn exp_bound var = (* Follows the same approach as call_smt' for unknown results due to exponentials, retrying with a bounded spec. *) -let solve_unique_smt l constraints var = +let solve_unique_smt l abstract constraints var = let t = Profile.start_smt () in let result = - match solve_unique_smt' l constraints [] [] var with + match solve_unique_smt' l abstract constraints [] [] var with | Some result, _ -> Some result | None, [] -> None | None, exponentials -> opt_solver := { !opt_solver with uninterpret_power = true }; let sailexp = sailexp_concrete 64 in let exp_bound = List.map bound_exponential exponentials in - let result, _ = solve_unique_smt' l constraints sailexp exp_bound var in + let result, _ = solve_unique_smt' l abstract constraints sailexp exp_bound var in opt_solver := { !opt_solver with uninterpret_power = false }; result in diff --git a/src/lib/constraint.mli b/src/lib/constraint.mli index 333e1c0a9..de2cbc501 100644 --- a/src/lib/constraint.mli +++ b/src/lib/constraint.mli @@ -83,12 +83,12 @@ val save_digests : unit -> unit val constraint_to_smt : l -> n_constraint -> string * (kid -> string * bool) * string list -val call_smt : l -> n_constraint -> smt_result +val call_smt : l -> kind Bindings.t -> n_constraint -> smt_result val call_smt_solve_bitvector : l -> string -> (int * string) list -> (int * lit) list option -val solve_smt : l -> n_constraint -> kid -> Big_int.num option +val solve_smt : l -> kind Bindings.t -> n_constraint -> kid -> Big_int.num option -val solve_all_smt : l -> n_constraint -> kid -> Big_int.num list option +val solve_all_smt : l -> kind Bindings.t -> n_constraint -> kid -> Big_int.num list option -val solve_unique_smt : l -> n_constraint -> kid -> Big_int.num option +val solve_unique_smt : l -> kind Bindings.t -> n_constraint -> kid -> Big_int.num option diff --git a/src/lib/effects.ml b/src/lib/effects.ml index 9306cc0d5..f845c8610 100644 --- a/src/lib/effects.ml +++ b/src/lib/effects.ml @@ -263,6 +263,7 @@ let infer_mapdef_extra_direct_effects def = let can_have_direct_side_effect (DEF_aux (aux, _)) = match aux with | DEF_type _ -> false + | DEF_constraint _ -> false | DEF_fundef _ -> true | DEF_mapdef _ -> false | DEF_impl _ -> true diff --git a/src/lib/frontend.ml b/src/lib/frontend.ml index 1c9cf3139..c5abe9898 100644 --- a/src/lib/frontend.ml +++ b/src/lib/frontend.ml @@ -84,6 +84,31 @@ let check_ast (asserts_termination : bool) (env : Type_check.Env.t) (ast : uanno let () = if !opt_ddump_tc_ast then Pretty_print_sail.pp_ast stdout (Type_check.strip_ast ast) else () in (ast, env, side_effects) +let instantiate_abstract_types insts ast = + let open Ast in + let instantiate = function + | DEF_aux (DEF_type (TD_aux (TD_abstract (id, kind), (l, _))), def_annot) as def -> begin + match Bindings.find_opt id insts with + | Some arg -> + let arg_kind = typ_arg_kind arg in + if Kind.compare arg_kind kind <> 0 then + raise + (Reporting.err_general l + (Printf.sprintf + "Failed to instantiate abstract type. Abstract type has kind %s, but instantiation has kind %s" + (string_of_kind kind) (string_of_kind arg_kind) + ) + ); + DEF_aux + ( DEF_type (TD_aux (TD_abbrev (id, mk_empty_typquant ~loc:(gen_loc l), arg), (l, Type_check.empty_tannot))), + def_annot + ) + | None -> def + end + | def -> def + in + { ast with defs = List.map instantiate ast.defs } + type parsed_module = { id : Project.mod_id; included : bool; diff --git a/src/lib/frontend.mli b/src/lib/frontend.mli index 2225804cf..cd1983eda 100644 --- a/src/lib/frontend.mli +++ b/src/lib/frontend.mli @@ -65,17 +65,20 @@ (* SUCH DAMAGE. *) (****************************************************************************) +open Ast +open Ast_defs +open Ast_util + val opt_ddump_initial_ast : bool ref val opt_ddump_tc_ast : bool ref val opt_list_files : bool ref val opt_reformat : string option ref -open Ast_defs -open Ast_util - val check_ast : bool -> Type_check.Env.t -> uannot ast -> Type_check.tannot ast * Type_check.Env.t * Effects.side_effect_info +val instantiate_abstract_types : typ_arg Bindings.t -> Type_check.tannot ast -> Type_check.tannot ast + val load_modules : ?target:Target.target -> string -> diff --git a/src/lib/initial_check.ml b/src/lib/initial_check.ml index 34683ccbc..f363d9255 100644 --- a/src/lib/initial_check.ml +++ b/src/lib/initial_check.ml @@ -77,6 +77,9 @@ module P = Parse_ast (* See mli file for details on what these flags do *) let opt_fast_undefined = ref false let opt_magic_hash = ref false +let opt_abstract_types = ref false + +let abstract_type_error = "Abstract types are currently experimental, use the --abstract-types flag to enable" module StringSet = Set.Make (String) module StringMap = Map.Make (String) @@ -334,7 +337,7 @@ let rec to_ast_typ ctx atyp = | P.ATyp_bidir (typ1, typ2, _) -> Typ_aux (Typ_bidir (to_ast_typ ctx typ1, to_ast_typ ctx typ2), l) | P.ATyp_nset nums -> let n = Kid_aux (Var "'n", gen_loc l) in - Typ_aux (Typ_exist ([mk_kopt ~loc:l K_int n], nc_set n nums, atom_typ (nvar n)), l) + Typ_aux (Typ_exist ([mk_kopt ~loc:l K_int n], nc_set (nvar n) nums, atom_typ (nvar n)), l) | P.ATyp_tuple typs -> Typ_aux (Typ_tuple (List.map (to_ast_typ ctx) typs), l) | P.ATyp_app (P.Id_aux (P.Id "int", il), [n]) -> Typ_aux (Typ_app (Id_aux (Id "atom", il), [to_ast_typ_arg ctx n K_int]), l) @@ -473,10 +476,11 @@ and to_ast_constraint ctx atyp = ) ) end + | P.ATyp_id id -> NC_id (to_ast_id ctx id) | P.ATyp_var v -> NC_var (to_ast_var v) | P.ATyp_lit (P.L_aux (P.L_true, _)) -> NC_true | P.ATyp_lit (P.L_aux (P.L_false, _)) -> NC_false - | P.ATyp_in (P.ATyp_aux (P.ATyp_var v, _), P.ATyp_aux (P.ATyp_nset bounds, _)) -> NC_set (to_ast_var v, bounds) + | P.ATyp_in (n, P.ATyp_aux (P.ATyp_nset bounds, _)) -> NC_set (to_ast_nexp ctx n, bounds) | _ -> raise (Reporting.err_typ l "Invalid constraint") in NC_aux (aux, l) @@ -1055,6 +1059,17 @@ let rec to_ast_typedef ctx def_annot (P.TD_aux (aux, l) : P.type_def) : uannot d ( fns @ [DEF_aux (DEF_type (TD_aux (TD_enum (id, enums, false), (l, empty_uannot))), def_annot)], { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } ) + | P.TD_abstract (id, kind) -> + if not !opt_abstract_types then raise (Reporting.err_general l abstract_type_error); + let id = to_ast_reserved_type_id ctx id in + begin + match to_ast_kind kind with + | Some kind -> + ( [DEF_aux (DEF_type (TD_aux (TD_abstract (id, kind), (l, empty_uannot))), def_annot)], + { ctx with type_constructors = Bindings.add id [] ctx.type_constructors } + ) + | None -> raise (Reporting.err_general l "Abstract type cannot have Order kind") + end | P.TD_bitfield (id, typ, ranges) -> let id = to_ast_reserved_type_id ctx id in let typ = to_ast_typ ctx typ in @@ -1329,6 +1344,10 @@ let rec to_ast_def doc attrs vis ctx (P.DEF_aux (def, l)) : uannot def list ctx_ | P.DEF_register dec -> let d = to_ast_dec ctx dec in ([DEF_aux (DEF_register d, annot)], ctx) + | P.DEF_constraint nc -> + if not !opt_abstract_types then raise (Reporting.err_general l abstract_type_error); + let nc = to_ast_constraint ctx nc in + ([DEF_aux (DEF_constraint nc, annot)], ctx) | P.DEF_pragma (pragma, arg, ltrim) -> let l = pragma_arg_loc pragma ltrim l in begin diff --git a/src/lib/initial_check.mli b/src/lib/initial_check.mli index 96b4760c8..2534d9988 100644 --- a/src/lib/initial_check.mli +++ b/src/lib/initial_check.mli @@ -77,6 +77,10 @@ val merge_ctx : Parse_ast.l -> ctx -> ctx -> ctx (** {2 Options} *) +(** Enable abstract types in the AST. If unset, will report an error + if they are encountered. *) +val opt_abstract_types : bool ref + (** Generate faster undefined_T functions. Rather than generating functions that allow for the undefined values of enums and variants to be picked at runtime using a RNG or similar, this creates diff --git a/src/lib/jib_compile.ml b/src/lib/jib_compile.ml index 55c6c4659..5301ffac9 100644 --- a/src/lib/jib_compile.ml +++ b/src/lib/jib_compile.ml @@ -1154,6 +1154,7 @@ module Make (C : CONFIG) = struct | TD_bitfield _ -> Reporting.unreachable l __POS__ "Cannot compile TD_bitfield" (* All type abbreviations are filtered out in compile_def *) | TD_abbrev _ -> Reporting.unreachable l __POS__ "Found TD_abbrev in compile_type_def" + | TD_abstract _ -> Reporting.unreachable l __POS__ "Abstract type not supported yet" let generate_cleanup instrs = let generate_cleanup' (I_aux (instr, _)) = @@ -1559,6 +1560,7 @@ module Make (C : CONFIG) = struct | DEF_scattered _ | DEF_mapdef _ | DEF_outcome _ | DEF_impl _ | DEF_instantiation _ -> Reporting.unreachable (def_loc def) __POS__ ("Could not compile:\n" ^ Pretty_print_sail.to_string (Pretty_print_sail.doc_def (strip_def def))) + | DEF_constraint _ -> Reporting.unreachable (def_loc def) __POS__ "Toplevel constraint not supported" let mangle_mono_id id ctx ctyps = append_id id ("<" ^ Util.string_of_list "," (mangle_string_of_ctyp ctx) ctyps ^ ">") diff --git a/src/lib/monomorphise.ml b/src/lib/monomorphise.ml index 5b2bc4fd2..662fd9470 100644 --- a/src/lib/monomorphise.ml +++ b/src/lib/monomorphise.ml @@ -232,7 +232,7 @@ let extract_set_nc env l var nc = in match nc with - | NC_set (id, is) when KidSet.mem id vars -> Some (is, re NC_true) + | NC_set (Nexp_aux (Nexp_var id, _), is) when KidSet.mem id vars -> Some (is, re NC_true) | NC_equal (Nexp_aux (Nexp_var id, _), Nexp_aux (Nexp_constant n, _)) when KidSet.mem id vars -> Some ([n], re NC_true) | NC_and ((NC_aux (NC_bounded_le (Nexp_aux (Nexp_constant n, _), Nexp_aux (Nexp_var kid, _)), _) as nc1), nc2) @@ -1255,8 +1255,8 @@ let split_defs target all_errors (splits : split_req list) env ast = let map_def idx (DEF_aux (aux, def_annot) as def) = Util.progress "Monomorphising " (string_of_int idx ^ "/" ^ string_of_int num_defs) idx num_defs; match aux with - | DEF_type _ | DEF_val _ | DEF_default _ | DEF_register _ | DEF_overload _ | DEF_fixity _ | DEF_pragma _ - | DEF_internal_mutrec _ -> + | DEF_type _ | DEF_constraint _ | DEF_val _ | DEF_default _ | DEF_register _ | DEF_overload _ | DEF_fixity _ + | DEF_pragma _ | DEF_internal_mutrec _ -> [def] | DEF_fundef fd -> [DEF_aux (DEF_fundef (map_fundef fd), def_annot)] | DEF_let lb -> [DEF_aux (DEF_let (map_letbind lb), def_annot)] @@ -2030,16 +2030,12 @@ module Analysis = struct | NC_bounded_lt (nexp1, nexp2) | NC_not_equal (nexp1, nexp2) -> dmerge (deps_of_nexp l kid_deps [] nexp1) (deps_of_nexp l kid_deps [] nexp2) - | NC_set (kid, _) -> ( - match KBindings.find kid kid_deps with - | deps -> deps - | exception Not_found -> Unknown (l, "Unknown type variable in constraint " ^ string_of_kid kid) - ) + | NC_set (nexp, _) -> deps_of_nexp l kid_deps [] nexp | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> dmerge (deps_of_nc kid_deps nc1) (deps_of_nc kid_deps nc2) | NC_true | NC_false -> dempty | NC_app (Id_aux (Id "mod", _), [A_aux (A_nexp nexp1, _); A_aux (A_nexp nexp2, _)]) -> dmerge (deps_of_nexp l kid_deps [] nexp1) (deps_of_nexp l kid_deps [] nexp2) - | NC_var _ | NC_app _ -> dempty + | NC_id _ | NC_var _ | NC_app _ -> dempty and deps_of_typ l kid_deps arg_deps typ = deps_of_tyvars l kid_deps arg_deps (tyvars_of_typ typ) @@ -2397,7 +2393,7 @@ module Analysis = struct let l' = Generated l in let split = match typ_of e1 with - | Typ_aux (Typ_exist ([kdid], NC_aux (NC_set (kid, sizes), _), typ), _) + | Typ_aux (Typ_exist ([kdid], NC_aux (NC_set (Nexp_aux (Nexp_var kid, _), sizes), _), typ), _) when Kid.compare (kopt_kid kdid) kid == 0 -> begin match Type_check.destruct_atom_nexp (env_of e1) typ with | Some nexp when Nexp.compare (nvar kid) nexp == 0 -> @@ -2412,7 +2408,11 @@ module Analysis = struct begin match Util.find_map - (function NC_aux (NC_set (kid'', is), _) when KidSet.mem kid'' vars -> Some is | _ -> None) + (function + | NC_aux (NC_set (Nexp_aux (Nexp_var kid'', _), is), _) when KidSet.mem kid'' vars -> + Some is + | _ -> None + ) constraints with | Some sizes -> @@ -2766,7 +2766,7 @@ module Analysis = struct let rec sets_from_nc (NC_aux (nc, l) as nc_full) = match nc with | NC_and (nc1, nc2) -> merge_set_asserts_by_kid (sets_from_nc nc1) (sets_from_nc nc2) - | NC_set (kid, is) -> KBindings.singleton kid (l, is) + | NC_set (Nexp_aux (Nexp_var kid, _), is) -> KBindings.singleton kid (l, is) | NC_equal (Nexp_aux (Nexp_var kid, _), Nexp_aux (Nexp_constant n, _)) -> KBindings.singleton kid (l, [n]) | NC_or _ -> ( match set_from_nc_or nc_full with @@ -4659,6 +4659,7 @@ module ToplevelNexpRewrites = struct | TD_abbrev (id, typq, A_aux (A_typ typ, l)) -> TD_aux (TD_abbrev (id, typq, A_aux (A_typ (expand_type typ), l)), annot) | TD_abbrev (id, typq, typ_arg) -> TD_aux (TD_abbrev (id, typq, typ_arg), annot) + | TD_abstract (id, kind) -> TD_aux (TD_abstract (id, kind), annot) | TD_record (id, typq, typ_ids, flag) -> TD_aux (TD_record (id, typq, List.map (fun (typ, id) -> (expand_type typ, id)) typ_ids, flag), annot) | TD_variant (id, typq, tus, flag) -> TD_aux (TD_variant (id, typq, List.map rw_union tus, flag), annot) diff --git a/src/lib/parse_ast.ml b/src/lib/parse_ast.ml index 107df6a30..94f00df9d 100644 --- a/src/lib/parse_ast.ml +++ b/src/lib/parse_ast.ml @@ -385,6 +385,7 @@ type type_def_aux = | TD_record of id * typquant * (atyp * id) list (* struct type definition *) | TD_variant of id * typquant * type_union list (* union type definition *) | TD_enum of id * (id * atyp) list * (id * exp option) list (* enumeration type definition *) + | TD_abstract of id * kind | TD_bitfield of id * atyp * (id * index_range) list (* register mutable bitfield type definition *) type val_spec_aux = (* Value type specification *) @@ -427,6 +428,7 @@ type fixity_token = prec * Big_int.num * string type def_aux = (* Top-level definition *) | DEF_type of type_def (* type definition *) + | DEF_constraint of atyp (* global constraint *) | DEF_fundef of fundef (* function definition *) | DEF_mapdef of mapdef (* mapping definition *) | DEF_impl of funcl (* impl definition *) diff --git a/src/lib/parser.mly b/src/lib/parser.mly index 52f20b690..a97abd50f 100644 --- a/src/lib/parser.mly +++ b/src/lib/parser.mly @@ -964,6 +964,8 @@ type_def: { mk_td (TD_abbrev ($2, $3, $5, $7)) $startpos $endpos } | Typedef id Colon kind Eq typ { mk_td (TD_abbrev ($2, mk_typqn, $4, $6)) $startpos $endpos } + | Typedef id Colon kind + { mk_td (TD_abstract ($2, $4)) $startpos $endpos } | Struct id Eq Lcurly struct_fields Rcurly { mk_td (TD_record ($2, TypQ_aux (TypQ_tq [], loc $endpos($2) $startpos($3)), $5)) $startpos $endpos } | Struct id typaram Eq Lcurly struct_fields Rcurly @@ -1305,6 +1307,8 @@ def_aux: { DEF_scattered $1 } | default_def { DEF_default $1 } + | Constraint typ + { DEF_constraint $2 } | Mutual Lcurly fun_def_list Rcurly { DEF_internal_mutrec $3 } | Pragma diff --git a/src/lib/pretty_print_sail.ml b/src/lib/pretty_print_sail.ml index 1d083782e..b1c76ec5a 100644 --- a/src/lib/pretty_print_sail.ml +++ b/src/lib/pretty_print_sail.ml @@ -144,6 +144,7 @@ let rec doc_nc nc = let nc_op op n1 n2 = separate space [doc_nexp n1; string op; doc_nexp n2] in let rec atomic_nc (NC_aux (nc_aux, _) as nc) = match nc_aux with + | NC_id id -> doc_id id | NC_true -> string "true" | NC_false -> string "false" | NC_equal (n1, n2) -> nc_op "==" n1 n2 @@ -152,8 +153,8 @@ let rec doc_nc nc = | NC_bounded_gt (n1, n2) -> nc_op ">" n1 n2 | NC_bounded_le (n1, n2) -> nc_op "<=" n1 n2 | NC_bounded_lt (n1, n2) -> nc_op "<" n1 n2 - | NC_set (kid, ints) -> - separate space [doc_kid kid; string "in"; braces (separate_map (comma ^^ space) doc_int ints)] + | NC_set (nexp, ints) -> + separate space [doc_nexp nexp; string "in"; braces (separate_map (comma ^^ space) doc_int ints)] | NC_app (id, args) -> doc_id id ^^ parens (separate_map (comma ^^ space) doc_typ_arg args) | NC_var kid -> doc_kid kid | NC_or _ | NC_and _ -> nc0 ~parenthesize:true nc @@ -197,7 +198,7 @@ and doc_typ ?(simple = false) (Typ_aux (typ_aux, l)) = (* Resugar set types like {|1, 2, 3|} *) | Typ_exist ( [kopt], - NC_aux (NC_set (kid1, ints), _), + NC_aux (NC_set (Nexp_aux (Nexp_var kid1, _), ints), _), Typ_aux (Typ_app (id, [A_aux (A_nexp (Nexp_aux (Nexp_var kid2, _)), _)]), _) ) when Kid.compare (kopt_kid kopt) kid1 == 0 && Kid.compare kid1 kid2 == 0 && Id.compare (mk_id "atom") id == 0 -> @@ -677,6 +678,7 @@ let doc_typ_arg_kind sep (A_aux (aux, _)) = let doc_typdef (TD_aux (td, _)) = match td with + | TD_abstract (id, kind) -> begin doc_op colon (concat [string "type"; space; doc_id id]) (doc_kind kind) end | TD_abbrev (id, typq, typ_arg) -> begin match doc_typquant typq with | Some qdoc -> @@ -792,6 +794,7 @@ let rec doc_def_no_hardline (DEF_aux (aux, def_annot)) = | DEF_type t_def -> doc_typdef t_def | DEF_fundef f_def -> doc_fundef f_def | DEF_mapdef m_def -> doc_mapdef m_def + | DEF_constraint nc -> string "constraint" ^^ space ^^ doc_nc nc | DEF_outcome (OV_aux (OV_outcome (id, typschm, args), _), defs) -> ( string "outcome" ^^ space ^^ doc_id id ^^ space ^^ colon ^^ space ^^ doc_typschm typschm ^^ break 1 ^^ (string "with" ^//^ separate_map (comma ^^ break 1) doc_kopt_no_parens args) diff --git a/src/lib/rewriter.ml b/src/lib/rewriter.ml index a1c6a5484..c2f7d5d93 100644 --- a/src/lib/rewriter.ml +++ b/src/lib/rewriter.ml @@ -328,8 +328,8 @@ let rec rewrite_def rewriters (DEF_aux (aux, def_annot)) = match aux with | DEF_register (DEC_aux (DEC_reg (typ, id, Some exp), annot)) -> DEF_register (DEC_aux (DEC_reg (typ, id, Some (rewriters.rewrite_exp rewriters exp)), annot)) - | DEF_type _ | DEF_mapdef _ | DEF_val _ | DEF_default _ | DEF_register _ | DEF_overload _ | DEF_fixity _ - | DEF_instantiation _ -> + | DEF_type _ | DEF_constraint _ | DEF_mapdef _ | DEF_val _ | DEF_default _ | DEF_register _ | DEF_overload _ + | DEF_fixity _ | DEF_instantiation _ -> aux | DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef) | DEF_impl funcl -> DEF_impl (rewrite_funcl rewriters funcl) diff --git a/src/lib/rewrites.ml b/src/lib/rewrites.ml index a4456245a..9d54fa3f3 100644 --- a/src/lib/rewrites.ml +++ b/src/lib/rewrites.ml @@ -1989,6 +1989,7 @@ let rewrite_type_union_typs rw_typ (Tu_aux (Tu_ty_id (typ, id), annot)) = Tu_aux let rewrite_type_def_typs rw_typ rw_typquant (TD_aux (td, annot)) = match td with + | TD_abstract (id, kind) -> TD_aux (TD_abstract (id, kind), annot) | TD_abbrev (id, typq, A_aux (A_typ typ, l)) -> TD_aux (TD_abbrev (id, rw_typquant typq, A_aux (A_typ (rw_typ typ), l)), annot) | TD_abbrev (id, typq, typ_arg) -> TD_aux (TD_abbrev (id, rw_typquant typq, typ_arg), annot) diff --git a/src/lib/spec_analysis.ml b/src/lib/spec_analysis.ml index 00770238d..53be761de 100644 --- a/src/lib/spec_analysis.ml +++ b/src/lib/spec_analysis.ml @@ -160,12 +160,14 @@ and fv_of_nconstraint consider_var bound used (Ast.NC_aux (nc, _)) = | NC_bounded_lt (n1, n2) | NC_not_equal (n1, n2) -> fv_of_nexp consider_var bound (fv_of_nexp consider_var bound used n1) n2 - | NC_set (Ast.Kid_aux (Ast.Var i, _), _) | NC_var (Ast.Kid_aux (Ast.Var i, _)) -> + | NC_var (Ast.Kid_aux (Ast.Var i, _)) -> if consider_var then conditional_add_typ bound used (Ast.Id_aux (Ast.Id i, Parse_ast.Unknown)) else used + | NC_set (n, _) -> fv_of_nexp consider_var bound used n | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> fv_of_nconstraint consider_var bound (fv_of_nconstraint consider_var bound used nc1) nc2 | NC_app (id, targs) -> List.fold_right (fun ta n -> fv_of_targ consider_var bound n ta) targs (conditional_add_typ bound used id) + | NC_id id -> conditional_add_typ bound used id | NC_true | NC_false -> used let typq_bindings (TypQ_aux (tq, _)) = @@ -361,6 +363,7 @@ let fv_of_abbrev consider_var bound used typq typ_arg = let fv_of_type_def consider_var (TD_aux (t, _)) = match t with + | TD_abstract (id, kind) -> (init_env ("typ:" ^ string_of_id id), Nameset.empty) | TD_abbrev (id, typq, typ_arg) -> (init_env ("typ:" ^ string_of_id id), snd (fv_of_abbrev consider_var mt mt typq typ_arg)) | TD_record (id, typq, tids, _) -> @@ -522,6 +525,7 @@ let fv_of_def consider_var consider_scatter_as_one all_defs (DEF_aux (aux, _) as | DEF_type tdef -> fv_of_type_def consider_var tdef | DEF_fundef fdef -> fv_of_fun consider_var fdef | DEF_mapdef mdef -> (mt, mt (* fv_of_map consider_var mdef *)) + | DEF_constraint nc -> (mt, mt) | DEF_let lebind -> (fun (b, u, _) -> (b, u)) (fv_of_let consider_var mt mt mt lebind) | DEF_val vspec -> fv_of_vspec consider_var vspec | DEF_fixity _ -> (mt, mt) diff --git a/src/lib/specialize.ml b/src/lib/specialize.ml index 5be9f512b..e7f1d56b4 100644 --- a/src/lib/specialize.ml +++ b/src/lib/specialize.ml @@ -197,6 +197,7 @@ let string_of_instantiation instantiation = | A_typ typ -> string_of_typ typ | A_bool nc -> string_of_n_constraint nc and string_of_n_constraint = function + | NC_aux (NC_id id, _) -> string_of_id id | NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " = " ^ string_of_nexp n2 | NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2 | NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2 @@ -205,8 +206,7 @@ let string_of_instantiation instantiation = | NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2 | NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")" | NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")" - | NC_aux (NC_set (kid, ns), _) -> - kid_name (mk_kopt K_int kid) ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" + | NC_aux (NC_set (n, ns), _) -> string_of_nexp n ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}" | NC_aux (NC_true, _) -> "true" | NC_aux (NC_false, _) -> "false" | NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid) diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index 0cc1d960d..09630de2f 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -156,6 +156,7 @@ and replace_nexp_nc nexp nexp' (NC_aux (nc_aux, l) as nc) = let rep_nc = replace_nexp_nc nexp nexp' in let rep n = if Nexp.compare n nexp == 0 then nexp' else n in match nc_aux with + | NC_id id -> NC_aux (NC_id id, l) | NC_equal (n1, n2) -> NC_aux (NC_equal (rep n1, rep n2), l) | NC_bounded_ge (n1, n2) -> NC_aux (NC_bounded_ge (rep n1, rep n2), l) | NC_bounded_le (n1, n2) -> NC_aux (NC_bounded_le (rep n1, rep n2), l) @@ -393,8 +394,8 @@ and simp_typ_aux = function which is then a problem we can feed to the constraint solver expecting unsat. *) -let prove_smt ~assumptions:ncs (NC_aux (_, l) as nc) = - match Constraint.call_smt l (List.fold_left nc_and (nc_not nc) ncs) with +let prove_smt ~abstract ~assumptions:ncs (NC_aux (_, l) as nc) = + match Constraint.call_smt l abstract (List.fold_left nc_and (nc_not nc) ncs) with | Constraint.Unsat -> typ_debug (lazy "unsat"); true @@ -406,7 +407,7 @@ let prove_smt ~assumptions:ncs (NC_aux (_, l) as nc) = constraints, even when such constraints are irrelevant *) let ncs' = List.concat (List.map constraint_conj ncs) in let ncs' = List.filter (fun nc -> KidSet.is_empty (constraint_power_variables nc)) ncs' in - match Constraint.call_smt l (List.fold_left nc_and (nc_not nc) ncs') with + match Constraint.call_smt l abstract (List.fold_left nc_and (nc_not nc) ncs') with | Constraint.Unsat -> typ_debug (lazy "unsat"); true @@ -429,8 +430,9 @@ let solve_unique env (Nexp_aux (_, l) as nexp) = let env = Env.add_typ_var l (mk_kopt K_int (mk_kid "solve#")) env in let vars = Env.get_typ_vars env in let _vars = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) vars in + let abstract = Env.get_abstract_typs env in let constr = List.fold_left nc_and (nc_eq (nvar (mk_kid "solve#")) nexp) (Env.get_constraints env) in - Constraint.solve_unique_smt l constr (mk_kid "solve#") + Constraint.solve_unique_smt l abstract constr (mk_kid "solve#") let debug_pos (file, line, _, _) = "(" ^ file ^ "/" ^ string_of_int line ^ ") " @@ -451,7 +453,7 @@ let prove pos env nc = ^ string_of_list ", " string_of_n_constraint ncs ^ " |- " ^ string_of_n_constraint nc ); - match nc_aux with NC_true -> true | _ -> prove_smt ~assumptions:ncs nc + match nc_aux with NC_true -> true | _ -> prove_smt ~abstract:(Env.get_abstract_typs env) ~assumptions:ncs nc (**************************************************************************) (* 3. Unification *) @@ -495,8 +497,8 @@ let rec nc_identical (NC_aux (nc1, _)) (NC_aux (nc2, _)) = | NC_and (nc1a, nc1b), NC_and (nc2a, nc2b) -> nc_identical nc1a nc2a && nc_identical nc1b nc2b | NC_true, NC_true -> true | NC_false, NC_false -> true - | NC_set (kid1, ints1), NC_set (kid2, ints2) when List.length ints1 = List.length ints2 -> - Kid.compare kid1 kid2 = 0 && List.for_all2 (fun i1 i2 -> i1 = i2) ints1 ints2 + | NC_set (nexp1, ints1), NC_set (nexp2, ints2) when List.length ints1 = List.length ints2 -> + nexp_identical nexp1 nexp2 && List.for_all2 (fun i1 i2 -> i1 = i2) ints1 ints2 | NC_var kid1, NC_var kid2 -> Kid.compare kid1 kid2 = 0 | NC_app (id1, args1), NC_app (id2, args2) when List.length args1 = List.length args2 -> Id.compare id1 id2 = 0 && List.for_all2 typ_arg_identical args1 args2 @@ -926,10 +928,10 @@ and kid_order_arg kind_map (A_aux (aux, _)) = and kid_order_constraint kind_map (NC_aux (aux, _)) = match aux with - | (NC_var kid | NC_set (kid, _)) when KBindings.mem kid kind_map -> + | NC_var kid when KBindings.mem kid kind_map -> ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) - | NC_var _ | NC_set _ -> ([], kind_map) - | NC_true | NC_false -> ([], kind_map) + | NC_set (n, _) -> kid_order_nexp kind_map n + | NC_var _ | NC_id _ | NC_true | NC_false -> ([], kind_map) | NC_equal (n1, n2) | NC_not_equal (n1, n2) | NC_bounded_le (n1, n2) @@ -1065,7 +1067,7 @@ let rec subtyp l env typ1 typ2 = KBindings.filter (fun _ k -> match k with K_int | K_bool -> true | _ -> false) (Env.get_typ_vars env) in begin - match Constraint.call_smt l (nc_eq nexp1 nexp2) with + match Constraint.call_smt l Bindings.empty (nc_eq nexp1 nexp2) with | Constraint.Sat -> let env = Env.add_constraint (nc_eq nexp1 nexp2) env in if prove __POS__ env nc2 then () @@ -1218,6 +1220,7 @@ let rec rewrite_sizeof' l env (Nexp_aux (aux, _) as nexp) = let exp1 = rewrite_sizeof' l env nexp1 in let exp2 = rewrite_sizeof' l env nexp2 in mk_exp (E_app (mk_id "emod_int", [exp1; exp2])) + | Nexp_id id when Env.is_abstract_typ id env -> mk_exp (E_sizeof nexp) | Nexp_app _ | Nexp_id _ -> typ_error l ("Cannot re-write sizeof(" ^ string_of_nexp nexp ^ ")") let rewrite_sizeof l env nexp = @@ -1255,14 +1258,16 @@ and rewrite_nc_aux l env = | NC_false -> E_lit (mk_lit L_false) | NC_true -> E_lit (mk_lit L_true) | NC_set (_, []) -> E_lit (mk_lit L_false) - | NC_set (kid, int :: ints) -> - let kid_eq kid int = nc_eq (nvar kid) (nconstant int) in - unaux_exp (rewrite_nc env (List.fold_left (fun nc int -> nc_or nc (kid_eq kid int)) (kid_eq kid int) ints)) + | NC_set (nexp, int :: ints) -> + let nexp_eq int = nc_eq nexp (nconstant int) in + unaux_exp (rewrite_nc env (List.fold_left (fun nc int -> nc_or nc (nexp_eq int)) (nexp_eq int) ints)) | NC_app (f, [A_aux (A_bool nc, _)]) when string_of_id f = "not" -> E_app (mk_id "not_bool", [rewrite_nc env nc]) | NC_app (f, args) -> unaux_exp (rewrite_nc env (Env.expand_constraint_synonyms env (mk_nc (NC_app (f, args))))) | NC_var v -> (* Would be better to translate change E_sizeof to take a kid, then rewrite to E_sizeof *) E_id (id_of_kid v) + | NC_id id when Env.is_abstract_typ id env -> E_constraint (NC_aux (NC_id id, l)) + | NC_id id -> typ_error l ("Cannot re-write constraint(" ^ string_of_id id ^ ")") let can_be_undefined ~at:l env typ = let rec check (Typ_aux (aux, _)) = @@ -1932,11 +1937,11 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au end | E_app_infix (x, op, y), _ -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, uannot))) typ | E_app (f, [E_aux (E_constraint nc, _)]), _ when string_of_id f = "_prove" -> - Env.wf_constraint env nc; + Env.wf_constraint ~at:l env nc; if prove __POS__ env nc then annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ else typ_error l ("Cannot prove " ^ string_of_n_constraint nc) | E_app (f, [E_aux (E_constraint nc, _)]), _ when string_of_id f = "_not_prove" -> - Env.wf_constraint env nc; + Env.wf_constraint ~at:l env nc; if prove __POS__ env nc then typ_error l ("Can prove " ^ string_of_n_constraint nc) else annot_exp (E_lit (L_aux (L_unit, Parse_ast.Unknown))) unit_typ | E_app (f, [E_aux (E_typ (typ, exp), _)]), _ when string_of_id f = "_check" -> @@ -2119,7 +2124,7 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au else typ_error l ("Type " ^ string_of_typ typ ^ " could be empty") else typ_error l ("Type " ^ string_of_typ typ ^ " cannot be undefined") | E_internal_assume (nc, exp), _ -> - Env.wf_constraint env nc; + Env.wf_constraint ~at:l env nc; let env = Env.add_constraint nc env in let exp' = crule check_exp env exp typ in annot_exp (E_internal_assume (nc, exp')) typ @@ -2162,7 +2167,7 @@ and check_block l env exps ret_typ = end | [exp] -> [final env exp] | E_aux (E_app (f, [E_aux (E_constraint nc, _)]), _) :: exps when string_of_id f = "_assume" -> - Env.wf_constraint env nc; + Env.wf_constraint ~at:l env nc; let env = Env.add_constraint nc env in let annotated_exp = annot_exp (E_app (f, [annot_exp (E_constraint nc) bool_typ None])) unit_typ None in annotated_exp :: check_block l env exps ret_typ @@ -3068,9 +3073,13 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) = ) end | E_lit lit -> annot_exp (E_lit lit) (infer_lit env lit) - | E_sizeof nexp -> irule infer_exp env (rewrite_sizeof l env (Env.expand_nexp_synonyms env nexp)) + | E_sizeof nexp -> begin + match nexp with + | Nexp_aux (Nexp_id id, _) when Env.is_abstract_typ id env -> annot_exp (E_sizeof nexp) (atom_typ nexp) + | _ -> irule infer_exp env (rewrite_sizeof l env (Env.expand_nexp_synonyms env nexp)) + end | E_constraint nc -> - Env.wf_constraint env nc; + Env.wf_constraint ~at:l env nc; crule check_exp env (rewrite_nc env (Env.expand_constraint_synonyms env nc)) (atom_bool_typ nc) | E_field (exp, field) -> begin let inferred_exp = irule infer_exp env exp in @@ -3348,7 +3357,7 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) = let typ = Env.get_register id env in annot_exp (E_ref id) (register_typ typ) | E_internal_assume (nc, exp) -> - Env.wf_constraint env nc; + Env.wf_constraint ~at:l env nc; let env = Env.add_constraint nc env in let exp' = irule infer_exp env exp in annot_exp (E_internal_assume (nc, exp')) (typ_of exp') @@ -4305,8 +4314,15 @@ let check_record l env def_annot id typq fields = in Env.add_record id typq fields env +let check_global_constraint env def_annot nc = + let env = Env.add_constraint ~global:true nc env in + if prove __POS__ env nc_false then + typ_error def_annot.loc "Global constraint appears inconsistent with previous global constraints"; + ([DEF_aux (DEF_constraint nc, def_annot)], env) + let rec check_typedef : Env.t -> def_annot -> uannot type_def -> tannot def list * Env.t = fun env def_annot (TD_aux (tdef, (l, _))) -> + typ_print (lazy ("\n" ^ Util.("Check type " |> cyan |> clear) ^ string_of_id (id_of_type_def_aux tdef))); match tdef with | TD_abbrev (id, typq, typ_arg) -> begin @@ -4315,6 +4331,8 @@ let rec check_typedef : Env.t -> def_annot -> uannot type_def -> tannot def list | _ -> () end; ([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.add_typ_synonym id typq typ_arg env) + | TD_abstract (id, kind) -> + ([DEF_aux (DEF_type (TD_aux (tdef, (l, empty_tannot))), def_annot)], Env.add_abstract_typ id kind env) | TD_record (id, typq, fields, _) -> let env = check_record l env def_annot id typq fields in begin @@ -4625,6 +4643,7 @@ and check_def : Env.t -> uannot def -> tannot def list * Env.t = match aux with | DEF_fixity (prec, n, op) -> ([DEF_aux (DEF_fixity (prec, n, op), def_annot)], env) | DEF_type tdef -> check_typedef env def_annot tdef + | DEF_constraint nc -> check_global_constraint env def_annot nc | DEF_fundef fdef -> check_fundef env def_annot fdef | DEF_mapdef mdef -> check_mapdef env def_annot mdef | DEF_impl funcl -> check_impldef env def_annot funcl diff --git a/src/lib/type_check.mli b/src/lib/type_check.mli index ca825ab7f..dbc0dc8c2 100644 --- a/src/lib/type_check.mli +++ b/src/lib/type_check.mli @@ -147,7 +147,7 @@ module Env : sig (** Get the current set of constraints. *) val get_constraints : t -> n_constraint list - val add_constraint : ?reason:Ast.l * string -> n_constraint -> t -> t + val add_constraint : ?global:bool -> ?reason:Ast.l * string -> n_constraint -> t -> t (** Push all the type variables and constraints from a typquant into an environment *) @@ -512,4 +512,4 @@ val initial_env : Env.t (** The initial type checking environment, with a specific set of available modules. *) val initial_env_with_modules : Project.project_structure -> Env.t -val prove_smt : assumptions:n_constraint list -> n_constraint -> bool +val prove_smt : abstract:kind Bindings.t -> assumptions:n_constraint list -> n_constraint -> bool diff --git a/src/lib/type_env.ml b/src/lib/type_env.ml index 9015a1abf..921d29dbe 100644 --- a/src/lib/type_env.ml +++ b/src/lib/type_env.ml @@ -108,6 +108,8 @@ type global_env = { unions : (typquant * type_union list) env_item Bindings.t; union_ids : (typquant * typ) env_item Bindings.t; scattered_union_envs : global_env Bindings.t; + abstract_typs : kind env_item Bindings.t; + constraints : (constraint_reason * n_constraint) list; enums : (bool * IdSet.t) env_item Bindings.t; records : (typquant * (typ * id) list) env_item Bindings.t; synonyms : (typquant * typ_arg) env_item Bindings.t; @@ -132,6 +134,8 @@ let empty_global_env = unions = Bindings.empty; union_ids = Bindings.empty; scattered_union_envs = Bindings.empty; + abstract_typs = Bindings.empty; + constraints = []; enums = Bindings.empty; records = Bindings.empty; accessors = IdPairMap.empty; @@ -382,6 +386,7 @@ let builtin_typs = let bound_typ_id env id = Bindings.mem id env.global.synonyms || Bindings.mem id env.global.unions || Bindings.mem id env.global.records || Bindings.mem id env.global.enums || Bindings.mem id builtin_typs + || Bindings.mem id env.global.abstract_typs let get_binding_loc env id = let find map = Some (item_loc (Bindings.find id map)) in @@ -390,6 +395,7 @@ let get_binding_loc env id = else if Bindings.mem id env.global.records then find env.global.records else if Bindings.mem id env.global.enums then find env.global.enums else if Bindings.mem id env.global.synonyms then find env.global.synonyms + else if Bindings.mem id env.global.abstract_typs then find env.global.abstract_typs else None let already_bound str id env = @@ -510,6 +516,7 @@ let infer_kind env id = else if Bindings.mem id env.global.enums then mk_typquant [] else if Bindings.mem id env.global.synonyms then typ_error (id_loc id) ("Cannot infer kind of type synonym " ^ string_of_id id) + else if Bindings.mem id env.global.abstract_typs then mk_typquant [] else typ_error (id_loc id) ("Cannot infer kind of " ^ string_of_id id) let check_args_typquant id env args typq = @@ -530,6 +537,10 @@ let check_args_typquant id env args typq = typ_error (id_loc id) ("Could not prove " ^ string_of_list ", " string_of_n_constraint ncs ^ " for type constructor " ^ string_of_id id) +let get_constraints env = List.map snd env.global.constraints @ List.map snd env.constraints + +let get_constraint_reasons env = env.global.constraints @ env.constraints + let mk_synonym typq typ_arg = let kopts, ncs = quant_split typq in let kopts = List.map (fun kopt -> (kopt, fresh_existential (kopt_loc kopt) (unaux_kind (kopt_kind kopt)))) kopts in @@ -571,7 +582,7 @@ let mk_synonym typq typ_arg = ("Could not prove constraints " ^ string_of_list ", " string_of_n_constraint ncs ^ " in type synonym " ^ string_of_typ_arg typ_arg ^ " with " - ^ Util.string_of_list ", " string_of_n_constraint (List.map snd env.constraints) + ^ Util.string_of_list ", " string_of_n_constraint (get_constraints env) ) let get_typ_synonym id env = @@ -581,141 +592,150 @@ let get_typ_synonym id env = let get_typ_synonyms env = filter_items env env.global.synonyms -let get_constraints env = List.map snd env.constraints - -let get_constraint_reasons env = env.constraints - -let wf_debug str f x exs = - typ_debug ~level:2 - (lazy ("wf_" ^ str ^ ": " ^ f x ^ " exs: " ^ Util.string_of_list ", " string_of_kid (KidSet.elements exs))) -[@@coverage off] - -(* Check if a type, order, n-expression or constraint is - well-formed. Throws a type error if the type is badly formed. *) -let rec wf_typ' ?(exs = KidSet.empty) env (Typ_aux (typ_aux, l) as typ) = - match typ_aux with - | Typ_id id when bound_typ_id env id -> - let typq = infer_kind env id in - if quant_kopts typq != [] then - typ_error l ("Type constructor " ^ string_of_id id ^ " expected " ^ string_of_typquant typq) - else () - | Typ_id id -> typ_error l ("Undefined type " ^ string_of_id id) - | Typ_var kid -> begin - match KBindings.find kid env.typ_vars with - | _, K_type -> () - | _, k -> - typ_error l - ("Type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ ^ " is " ^ string_of_kind_aux k - ^ " rather than Type" - ) - | exception Not_found -> - typ_error l ("Unbound type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ) - end - | Typ_fn (arg_typs, ret_typ) -> - List.iter (wf_typ' ~exs env) arg_typs; - wf_typ' ~exs env ret_typ - | Typ_bidir (typ1, typ2) when unloc_typ typ1 = unloc_typ typ2 -> - typ_error l "Bidirectional types cannot be the same on both sides" - | Typ_bidir (typ1, typ2) -> - wf_typ' ~exs env typ1; - wf_typ' ~exs env typ2 - | Typ_tuple typs -> List.iter (wf_typ' ~exs env) typs - | Typ_app (id, [(A_aux (A_nexp _, _) as arg)]) when string_of_id id = "implicit" -> wf_typ_arg ~exs env arg - | Typ_app (id, args) when bound_typ_id env id -> - List.iter (wf_typ_arg ~exs env) args; - check_args_typquant id env args (infer_kind env id) - | Typ_app (id, _) -> typ_error l ("Undefined type " ^ string_of_id id) - | Typ_exist ([], _, _) -> typ_error l "Existential must have some type variables" - | Typ_exist (kopts, nc, typ) when KidSet.is_empty exs -> - wf_constraint ~exs:(KidSet.of_list (List.map kopt_kid kopts)) env nc; - wf_typ' ~exs:(KidSet.of_list (List.map kopt_kid kopts)) env typ - | Typ_exist (_, _, _) -> typ_error l "Nested existentials are not allowed" - | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" [@coverage off] - -and wf_typ_arg ?(exs = KidSet.empty) env (A_aux (typ_arg_aux, _)) = - match typ_arg_aux with - | A_nexp nexp -> wf_nexp ~exs env nexp - | A_typ typ -> wf_typ' ~exs env typ - | A_bool nc -> wf_constraint ~exs env nc - -and wf_nexp ?(exs = KidSet.empty) env (Nexp_aux (nexp_aux, l) as nexp) = - wf_debug "nexp" string_of_nexp nexp exs; - match nexp_aux with - | Nexp_id id -> typ_error l ("Undefined type synonym " ^ string_of_id id) - | Nexp_var kid when KidSet.mem kid exs -> () - | Nexp_var kid -> begin - match get_typ_var kid env with - | K_int -> () - | kind -> - typ_error l - ("Constraint is badly formed, " ^ string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind - ^ " but should have kind Int" - ) - end - | Nexp_constant _ -> () - | Nexp_app (id, nexps) -> - let name = string_of_id id in - (* We allow the abs, mod, and div functions that are included in the SMTLIB2 integer theory *) - if name = "abs" || name = "mod" || name = "div" || Bindings.mem id env.global.synonyms then - List.iter (fun n -> wf_nexp ~exs env n) nexps - else typ_error l ("Unknown type level operator or function " ^ name) - | Nexp_times (nexp1, nexp2) -> - wf_nexp ~exs env nexp1; - wf_nexp ~exs env nexp2 - | Nexp_sum (nexp1, nexp2) -> - wf_nexp ~exs env nexp1; - wf_nexp ~exs env nexp2 - | Nexp_minus (nexp1, nexp2) -> - wf_nexp ~exs env nexp1; - wf_nexp ~exs env nexp2 - | Nexp_exp nexp -> wf_nexp ~exs env nexp (* MAYBE: Could put restrictions on what is allowed here *) - | Nexp_neg nexp -> wf_nexp ~exs env nexp - -and wf_constraint ?(exs = KidSet.empty) env (NC_aux (nc_aux, l) as nc) = - wf_debug "constraint" string_of_n_constraint nc exs; - match nc_aux with - | NC_equal (n1, n2) -> - wf_nexp ~exs env n1; - wf_nexp ~exs env n2 - | NC_not_equal (n1, n2) -> - wf_nexp ~exs env n1; - wf_nexp ~exs env n2 - | NC_bounded_ge (n1, n2) -> - wf_nexp ~exs env n1; - wf_nexp ~exs env n2 - | NC_bounded_gt (n1, n2) -> - wf_nexp ~exs env n1; - wf_nexp ~exs env n2 - | NC_bounded_le (n1, n2) -> - wf_nexp ~exs env n1; - wf_nexp ~exs env n2 - | NC_bounded_lt (n1, n2) -> - wf_nexp ~exs env n1; - wf_nexp ~exs env n2 - | NC_set (kid, _) when KidSet.mem kid exs -> () - | NC_set (kid, _) -> begin - match get_typ_var kid env with - | K_int -> () - | kind -> - typ_error l - ("Set constraint is badly formed, " ^ string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind - ^ " but should have kind Int" - ) - end - | NC_or (nc1, nc2) -> - wf_constraint ~exs env nc1; - wf_constraint ~exs env nc2 - | NC_and (nc1, nc2) -> - wf_constraint ~exs env nc1; - wf_constraint ~exs env nc2 - | NC_app (_, args) -> List.iter (wf_typ_arg ~exs env) args - | NC_var kid when KidSet.mem kid exs -> () - | NC_var kid -> begin - match get_typ_var kid env with - | K_bool -> () - | kind -> typ_error l (string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind ^ " but should have kind Bool") - end - | NC_true | NC_false -> () +module Well_formedness = struct + let wf_debug str f x exs = + typ_debug ~level:2 + (lazy ("wf_" ^ str ^ ": " ^ f x ^ " exs: " ^ Util.string_of_list ", " string_of_kid (KidSet.elements exs))) + [@@coverage off] + + (* Check if a type, order, n-expression or constraint is + well-formed. Throws a type error if the type is badly formed. *) + let rec wf_typ exs env (Typ_aux (typ_aux, l) as typ) = + match typ_aux with + | Typ_id id when bound_typ_id env id -> + let typq = infer_kind env id in + if not (Util.list_empty (quant_kopts typq)) then + typ_error l ("Type constructor " ^ string_of_id id ^ " expected " ^ string_of_typquant typq) + else () + | Typ_id id -> typ_error l ("Undefined type " ^ string_of_id id) + | Typ_var kid -> begin + match KBindings.find kid env.typ_vars with + | _, K_type -> () + | _, k -> + typ_error l + ("Type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ ^ " is " ^ string_of_kind_aux k + ^ " rather than Type" + ) + | exception Not_found -> + typ_error l ("Unbound type variable " ^ string_of_kid kid ^ " in type " ^ string_of_typ typ) + end + | Typ_fn (arg_typs, ret_typ) -> + List.iter (wf_typ exs env) arg_typs; + wf_typ exs env ret_typ + | Typ_bidir (typ1, typ2) when unloc_typ typ1 = unloc_typ typ2 -> + typ_error l "Bidirectional types cannot be the same on both sides" + | Typ_bidir (typ1, typ2) -> + wf_typ exs env typ1; + wf_typ exs env typ2 + | Typ_tuple typs -> List.iter (wf_typ exs env) typs + | Typ_app (id, [(A_aux (A_nexp _, _) as arg)]) when string_of_id id = "implicit" -> wf_typ_arg exs env arg + | Typ_app (id, args) when bound_typ_id env id -> + List.iter (wf_typ_arg exs env) args; + check_args_typquant id env args (infer_kind env id) + | Typ_app (id, _) -> typ_error l ("Undefined type " ^ string_of_id id) + | Typ_exist ([], _, _) -> typ_error l "Existential must have some type variables" + | Typ_exist (kopts, nc, typ) when KidSet.is_empty exs -> + wf_constraint (KidSet.of_list (List.map kopt_kid kopts)) env nc; + wf_typ (KidSet.of_list (List.map kopt_kid kopts)) env typ + | Typ_exist (_, _, _) -> typ_error l "Nested existentials are not allowed" + | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown" [@coverage off] + + and wf_typ_arg exs env (A_aux (typ_arg_aux, _)) = + match typ_arg_aux with + | A_nexp nexp -> wf_nexp exs env nexp + | A_typ typ -> wf_typ exs env typ + | A_bool nc -> wf_constraint exs env nc + + and wf_nexp exs env (Nexp_aux (nexp_aux, l) as nexp) = + wf_debug "nexp" string_of_nexp nexp exs; + match nexp_aux with + | Nexp_id id when Bindings.mem id env.global.abstract_typs -> () + | Nexp_id id -> typ_error l ("Undefined type synonym " ^ string_of_id id) + | Nexp_var kid when KidSet.mem kid exs -> () + | Nexp_var kid -> begin + match get_typ_var kid env with + | K_int -> () + | kind -> + typ_error l + ("Constraint is badly formed, " ^ string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind + ^ " but should have kind Int" + ) + end + | Nexp_constant _ -> () + | Nexp_app (id, nexps) -> + let name = string_of_id id in + (* We allow the abs, mod, and div functions that are included in the SMTLIB2 integer theory *) + if name = "abs" || name = "mod" || name = "div" || Bindings.mem id env.global.synonyms then + List.iter (fun n -> wf_nexp exs env n) nexps + else typ_error l ("Unknown type level operator or function " ^ name) + | Nexp_times (nexp1, nexp2) -> + wf_nexp exs env nexp1; + wf_nexp exs env nexp2 + | Nexp_sum (nexp1, nexp2) -> + wf_nexp exs env nexp1; + wf_nexp exs env nexp2 + | Nexp_minus (nexp1, nexp2) -> + wf_nexp exs env nexp1; + wf_nexp exs env nexp2 + | Nexp_exp nexp -> wf_nexp exs env nexp (* MAYBE: Could put restrictions on what is allowed here *) + | Nexp_neg nexp -> wf_nexp exs env nexp + + and wf_constraint exs env (NC_aux (nc_aux, l) as nc) = + wf_debug "constraint" string_of_n_constraint nc exs; + match nc_aux with + | NC_id id when Bindings.mem id env.global.abstract_typs -> () + | NC_id id -> typ_error l ("Undefined type synonym " ^ string_of_id id) + | NC_equal (n1, n2) -> + wf_nexp exs env n1; + wf_nexp exs env n2 + | NC_not_equal (n1, n2) -> + wf_nexp exs env n1; + wf_nexp exs env n2 + | NC_bounded_ge (n1, n2) -> + wf_nexp exs env n1; + wf_nexp exs env n2 + | NC_bounded_gt (n1, n2) -> + wf_nexp exs env n1; + wf_nexp exs env n2 + | NC_bounded_le (n1, n2) -> + wf_nexp exs env n1; + wf_nexp exs env n2 + | NC_bounded_lt (n1, n2) -> + wf_nexp exs env n1; + wf_nexp exs env n2 + | NC_set (nexp, _) -> wf_nexp exs env nexp + | NC_or (nc1, nc2) -> + wf_constraint exs env nc1; + wf_constraint exs env nc2 + | NC_and (nc1, nc2) -> + wf_constraint exs env nc1; + wf_constraint exs env nc2 + | NC_app (_, args) -> List.iter (wf_typ_arg exs env) args + | NC_var kid when KidSet.mem kid exs -> () + | NC_var kid -> begin + match get_typ_var kid env with + | K_bool -> () + | kind -> typ_error l (string_of_kid kid ^ " has kind " ^ string_of_kind_aux kind ^ " but should have kind Bool") + end + | NC_true | NC_false -> () +end + +let add_abstract_typ id kind env = + if bound_typ_id env id then + typ_error (id_loc id) + ("Cannot introduce abstract type " ^ string_of_id id ^ " as a type or synonym with that name already exists") + else ( + typ_print (lazy (adding ^ "abstract type " ^ string_of_id id ^ " : " ^ string_of_kind kind)) [@coverage off]; + update_global + (fun global -> + { global with abstract_typs = Bindings.add id (mk_item env ~loc:(id_loc id) kind) global.abstract_typs } + ) + env + ) + +let get_abstract_typs env = filter_items env env.global.abstract_typs + +let is_abstract_typ id env = Bindings.mem id env.global.abstract_typs let rec expand_constraint_synonyms env (NC_aux (aux, l) as nc) = match aux with @@ -737,6 +757,16 @@ let rec expand_constraint_synonyms env (NC_aux (aux, l) as nc) = end with Not_found -> NC_aux (NC_app (id, List.map (expand_arg_synonyms env) args), l) ) + | NC_id id -> ( + try + begin + match get_typ_synonym id env l env [] with + | A_aux (A_bool nc, _) -> expand_constraint_synonyms env nc + | arg -> + typ_error l ("Expected Bool when expanding synonym " ^ string_of_id id ^ " got " ^ string_of_typ_arg arg) + end + with Not_found -> nc + ) | NC_true | NC_false | NC_var _ | NC_set _ -> nc and expand_nexp_synonyms env (Nexp_aux (aux, l) as nexp) = @@ -834,9 +864,9 @@ and expand_arg_synonyms env (A_aux (typ_arg, l)) = | A_bool nc -> A_aux (A_bool (expand_constraint_synonyms env nc), l) | A_nexp nexp -> A_aux (A_nexp (expand_nexp_synonyms env nexp), l) -and add_constraint ?reason constr env = +and add_constraint ?(global = false) ?reason constr env = let (NC_aux (nc_aux, l) as constr) = constraint_simp (expand_constraint_synonyms env constr) in - wf_constraint env constr; + Well_formedness.wf_constraint KidSet.empty env constr; let power_vars = constraint_power_variables constr in if KidSet.cardinal power_vars > 1 && !opt_smt_linearize then typ_error l @@ -847,7 +877,7 @@ and add_constraint ?reason constr env = let v = KidSet.choose power_vars in let constrs = List.fold_left nc_and nc_true (get_constraints env) in begin - match Constraint.solve_all_smt l constrs v with + match Constraint.solve_all_smt l (get_abstract_typs env) constrs v with | Some solutions -> typ_print ( lazy @@ -877,31 +907,47 @@ and add_constraint ?reason constr env = | NC_true -> env | _ -> typ_print (lazy (adding ^ "constraint " ^ string_of_n_constraint constr)) [@coverage off]; - { env with constraints = (reason, constr) :: env.constraints } + if global then + update_global + (fun global_env -> { global_env with constraints = (reason, constr) :: global_env.constraints }) + env + else { env with constraints = (reason, constr) :: env.constraints } ) -let add_typquant l quant env = - let rec add_quant_item env = function QI_aux (qi, _) -> add_quant_item_aux env qi - and add_quant_item_aux env = function - | QI_constraint constr -> add_constraint constr env - | QI_id kopt -> add_typ_var l kopt env - in - match quant with - | TypQ_aux (TypQ_no_forall, _) -> env - | TypQ_aux (TypQ_tq quants, _) -> List.fold_left add_quant_item env quants - let wf_typ ~at:at_l env (Typ_aux (_, l) as typ) = let typ = expand_synonyms env typ in - wf_debug "typ" string_of_typ typ KidSet.empty; + Well_formedness.wf_debug "typ" string_of_typ typ KidSet.empty; incr depth; try - wf_typ' env typ; + Well_formedness.wf_typ KidSet.empty env typ; decr depth with Type_error (err_l, err) -> decr depth; let extra, l = match l with Parse_ast.Unknown -> (" here", at_l) | _ -> ("", l) in typ_raise l (err_because (Err_other ("Well-formedness check failed for type" ^ extra), err_l, err)) +let wf_constraint ~at:at_l env (NC_aux (_, l) as nc) = + let nc = expand_constraint_synonyms env nc in + Well_formedness.wf_debug "constraint" string_of_n_constraint nc KidSet.empty; + incr depth; + try + Well_formedness.wf_constraint KidSet.empty env nc; + decr depth + with Type_error (err_l, err) -> + decr depth; + let extra, l = match l with Parse_ast.Unknown -> (" here", at_l) | _ -> ("", l) in + typ_raise l (err_because (Err_other ("Well-formedness check failed for constraint" ^ extra), err_l, err)) + +let add_typquant l quant env = + let rec add_quant_item env = function QI_aux (qi, _) -> add_quant_item_aux env qi + and add_quant_item_aux env = function + | QI_constraint constr -> add_constraint constr env + | QI_id kopt -> add_typ_var l kopt env + in + match quant with + | TypQ_aux (TypQ_no_forall, _) -> env + | TypQ_aux (TypQ_tq quants, _) -> List.fold_left add_quant_item env quants + let add_typ_synonym id typq arg env = if bound_typ_id env id then typ_error (id_loc id) diff --git a/src/lib/type_env.mli b/src/lib/type_env.mli index 8cfe39cb7..7755547b7 100644 --- a/src/lib/type_env.mli +++ b/src/lib/type_env.mli @@ -143,6 +143,10 @@ val is_user_undefined : id -> t -> bool val allow_user_undefined : id -> t -> t +val add_abstract_typ : id -> kind -> t -> t +val is_abstract_typ : id -> t -> bool +val get_abstract_typs : t -> kind Bindings.t + val is_variant : id -> t -> bool val add_variant : id -> typquant * type_union list -> t -> t val add_scattered_variant : id -> typquant -> t -> t @@ -180,7 +184,7 @@ val add_register : id -> typ -> t -> t val get_constraints : t -> n_constraint list val get_constraint_reasons : t -> ((Ast.l * string) option * n_constraint) list -val add_constraint : ?reason:Ast.l * string -> n_constraint -> t -> t +val add_constraint : ?global:bool -> ?reason:Ast.l * string -> n_constraint -> t -> t val add_typquant : l -> typquant -> t -> t @@ -244,7 +248,7 @@ val is_toplevel : t -> l option (* Well formedness-checks *) val wf_typ : at:l -> t -> typ -> unit -val wf_constraint : ?exs:KidSet.t -> t -> n_constraint -> unit +val wf_constraint : at:l -> t -> n_constraint -> unit (** Some of the code in the environment needs to use the smt solver, which is defined below. To break the circularity this would cause diff --git a/src/lib/type_internal.ml b/src/lib/type_internal.ml index 98d43ecd4..f089440fe 100644 --- a/src/lib/type_internal.ml +++ b/src/lib/type_internal.ml @@ -144,13 +144,14 @@ and unloc_nexp_aux = function and unloc_nexp = function Nexp_aux (nexp_aux, _) -> Nexp_aux (unloc_nexp_aux nexp_aux, Parse_ast.Unknown) and unloc_n_constraint_aux = function + | NC_id id -> NC_id (unloc_id id) | NC_equal (nexp1, nexp2) -> NC_equal (unloc_nexp nexp1, unloc_nexp nexp2) | NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (unloc_nexp nexp1, unloc_nexp nexp2) | NC_bounded_gt (nexp1, nexp2) -> NC_bounded_gt (unloc_nexp nexp1, unloc_nexp nexp2) | NC_bounded_le (nexp1, nexp2) -> NC_bounded_le (unloc_nexp nexp1, unloc_nexp nexp2) | NC_bounded_lt (nexp1, nexp2) -> NC_bounded_lt (unloc_nexp nexp1, unloc_nexp nexp2) | NC_not_equal (nexp1, nexp2) -> NC_not_equal (unloc_nexp nexp1, unloc_nexp nexp2) - | NC_set (kid, nums) -> NC_set (unloc_kid kid, nums) + | NC_set (nexp, nums) -> NC_set (unloc_nexp nexp, nums) | NC_or (nc1, nc2) -> NC_or (unloc_n_constraint nc1, unloc_n_constraint nc2) | NC_and (nc1, nc2) -> NC_and (unloc_n_constraint nc1, unloc_n_constraint nc2) | NC_var kid -> NC_var (unloc_kid kid) @@ -222,7 +223,8 @@ and constraint_nexps (NC_aux (nc_aux, _)) = | NC_bounded_lt (n1, n2) | NC_not_equal (n1, n2) -> [n1; n2] - | NC_set _ | NC_true | NC_false | NC_var _ -> [] + | NC_id _ | NC_true | NC_false | NC_var _ -> [] + | NC_set (n, _) -> [n] | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> constraint_nexps nc1 @ constraint_nexps nc2 | NC_app (_, args) -> List.concat (List.map typ_arg_nexps args) diff --git a/src/sail_coq_backend/pretty_print_coq.ml b/src/sail_coq_backend/pretty_print_coq.ml index 67e399702..7d9c9e8c4 100644 --- a/src/sail_coq_backend/pretty_print_coq.ml +++ b/src/sail_coq_backend/pretty_print_coq.ml @@ -320,12 +320,12 @@ let rec orig_nc (NC_aux (nc, l) as full_nc) = | NC_bounded_le (nexp1, nexp2) -> rewrap (NC_bounded_le (orig_nexp nexp1, orig_nexp nexp2)) | NC_bounded_lt (nexp1, nexp2) -> rewrap (NC_bounded_lt (orig_nexp nexp1, orig_nexp nexp2)) | NC_not_equal (nexp1, nexp2) -> rewrap (NC_not_equal (orig_nexp nexp1, orig_nexp nexp2)) - | NC_set (kid, s) -> rewrap (NC_set (orig_kid kid, s)) + | NC_set (nexp, s) -> rewrap (NC_set (orig_nexp nexp, s)) | NC_or (nc1, nc2) -> rewrap (NC_or (orig_nc nc1, orig_nc nc2)) | NC_and (nc1, nc2) -> rewrap (NC_and (orig_nc nc1, orig_nc nc2)) | NC_app (f, args) -> rewrap (NC_app (f, List.map orig_typ_arg args)) | NC_var kid -> rewrap (NC_var (orig_kid kid)) - | NC_true | NC_false -> full_nc + | NC_id _ | NC_true | NC_false -> full_nc and orig_typ_arg (A_aux (arg, l)) = let rewrap a = A_aux (a, l) in @@ -419,7 +419,8 @@ let rec count_nc_vars (NC_aux (nc, _)) = in match nc with | NC_or (nc1, nc2) | NC_and (nc1, nc2) -> merge_kid_count (count_nc_vars nc1) (count_nc_vars nc2) - | NC_var kid | NC_set (kid, _) -> KBindings.singleton kid 1 + | NC_var kid -> KBindings.singleton kid 1 + | NC_set (n, _) -> count_nexp_vars n | NC_equal (n1, n2) | NC_bounded_ge (n1, n2) | NC_bounded_gt (n1, n2) @@ -427,7 +428,7 @@ let rec count_nc_vars (NC_aux (nc, _)) = | NC_bounded_lt (n1, n2) | NC_not_equal (n1, n2) -> merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2) - | NC_true | NC_false -> KBindings.empty + | NC_id _ | NC_true | NC_false -> KBindings.empty | NC_app (_, args) -> List.fold_left merge_kid_count KBindings.empty (List.map count_arg args) (* Simplify some of the complex boolean types created by the Sail type checker, @@ -459,7 +460,7 @@ let simplify_atom_bool l kopts nc atom_nc = | NC_bounded_lt (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid | NC_not_equal (Nexp_aux (Nexp_var kid, _), _) when KBindings.mem kid lin_ty_vars -> Some kid | NC_not_equal (_, Nexp_aux (Nexp_var kid, _)) when KBindings.mem kid lin_ty_vars -> Some kid - | NC_set (kid, _ :: _) when KBindings.mem kid lin_ty_vars -> Some kid + | NC_set (Nexp_aux (Nexp_var kid, _), _ :: _) when KBindings.mem kid lin_ty_vars -> Some kid | _ -> None in let replace kills vars = @@ -816,17 +817,18 @@ and doc_nc_exp ctx env nc = match nc with | NC_not_equal (ne1, ne2) -> string "negb" ^^ space ^^ parens (doc_op (string "=?") (doc_nexp ctx ne1) (doc_nexp ctx ne2)) - | NC_set (kid, is) -> + | NC_set (nexp, is) -> separate space [ string "member_Z_list"; - doc_var ctx kid; + doc_nexp ctx nexp; brackets (separate (string "; ") (List.map (fun i -> string (Nat_big_num.to_string i)) is)); ] | NC_app (f, args) -> separate space (doc_nc_fn ctx f :: List.map doc_typ_arg_exp args) | _ -> l0 nc_full and l0 (NC_aux (nc, _) as nc_full) = match nc with + | NC_id id -> doc_id_type ctx.types_mod ctx.avoid_target_names (Some env) id | NC_true -> string "true" | NC_false -> string "false" | NC_var kid -> doc_nexp ctx (nvar kid) @@ -2326,7 +2328,8 @@ let types_used_with_generic_eq defs = in let typs_req_def (DEF_aux (aux, _) as def) = match aux with - | DEF_type _ | DEF_val _ | DEF_fixity _ | DEF_overload _ | DEF_default _ | DEF_pragma _ | DEF_register _ -> + | DEF_type _ | DEF_constraint _ | DEF_val _ | DEF_fixity _ | DEF_overload _ | DEF_default _ | DEF_pragma _ + | DEF_register _ -> IdSet.empty | DEF_fundef fd -> typs_req_fundef fd | DEF_internal_mutrec fds -> List.fold_left IdSet.union IdSet.empty (List.map typs_req_fundef fds) @@ -2387,6 +2390,7 @@ let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, an ^^ dot ^^ hardline ^^ separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."] ^^ twice hardline + | TD_abstract _ -> unreachable l __POS__ "Abstract type not supported by Coq backend" | TD_bitfield _ -> empty (* TODO? *) | TD_record (id, typq, fs, _) -> let fname fid = doc_field_name bare_ctxt id fid in @@ -3350,6 +3354,7 @@ let doc_def types_mod unimplemented avoid_target_names generic_eq_types effect_i ("Loop termination measures for " ^ string_of_id id ^ " should have been rewritten before backend") | DEF_impl _ | DEF_outcome _ | DEF_instantiation _ -> unreachable (def_loc def) __POS__ "Event definition should have been rewritten before backend" + | DEF_constraint _ -> unreachable (def_loc def) __POS__ "Abstract constraint not supported by Coq backend" let find_exc_typ defs = let is_exc_typ_def = function diff --git a/src/sail_latex_backend/latex.ml b/src/sail_latex_backend/latex.ml index e55dbdb95..e987928e7 100644 --- a/src/sail_latex_backend/latex.ml +++ b/src/sail_latex_backend/latex.ml @@ -470,6 +470,7 @@ let process_pragma l command = None let tdef_id = function + | TD_abstract (id, _) -> id | TD_abbrev (id, _, _) -> id | TD_record (id, _, _, _) -> id | TD_variant (id, _, _, _) -> id diff --git a/src/sail_lem_backend/pretty_print_lem.ml b/src/sail_lem_backend/pretty_print_lem.ml index 69578c100..8be214a2d 100644 --- a/src/sail_lem_backend/pretty_print_lem.ml +++ b/src/sail_lem_backend/pretty_print_lem.ml @@ -1715,6 +1715,7 @@ let doc_def_lem effect_info params_to_print type_env (DEF_aux (aux, _) as def) = | DEF_type t_def -> if List.mem (string_of_id (id_of_type_def t_def)) !opt_extern_types then empty else group (doc_typdef_lem params_to_print type_env t_def) ^/^ hardline + | DEF_constraint _ -> unreachable (def_loc def) __POS__ "Toplevel constraint not supported by lem backend" | DEF_register dec -> group (doc_dec_lem dec) | DEF_default df -> empty | DEF_fundef fdef -> group (doc_fundef_lem effect_info params_to_print type_env fdef) ^/^ hardline diff --git a/src/sail_ocaml_backend/ocaml_backend.ml b/src/sail_ocaml_backend/ocaml_backend.ml index 468e4f0af..500b73ccf 100644 --- a/src/sail_ocaml_backend/ocaml_backend.ml +++ b/src/sail_ocaml_backend/ocaml_backend.ml @@ -809,6 +809,7 @@ let ocaml_typedef ctx (TD_aux (td_aux, (l, _))) = separate space [string "type"; ocaml_typquant typq; zencode ctx id; equals; ocaml_typ ctx typ] ^^ ocaml_def_end ^^ ocaml_string_of_abbrev ctx id typq typ ^^ ocaml_def_end | TD_abbrev _ -> empty + | TD_abstract _ -> Reporting.unreachable l __POS__ "Abstract type not supported in OCaml backend" | TD_bitfield _ -> Reporting.unreachable l __POS__ "Bitfield should be re-written" let get_externs defs = @@ -895,6 +896,7 @@ let ocaml_pp_generators ctx defs orig_types required = match td with | TD_abbrev (_, _, A_aux (A_typ typ, _)) -> add_req_from_typ required typ | TD_abbrev _ -> required + | TD_abstract _ -> required | TD_record (_, _, fields, _) -> List.fold_left (fun req (typ, _) -> add_req_from_typ req typ) required fields | TD_variant (_, _, variants, _) -> List.fold_left (fun req (Tu_aux (Tu_ty_id (typ, _), _)) -> add_req_from_typ req typ) required variants @@ -911,7 +913,7 @@ let ocaml_pp_generators ctx defs orig_types required = | TD_abbrev (_, tqs, A_aux (A_typ _, _)) -> tqs | TD_record (_, tqs, _, _) -> tqs | TD_variant (_, tqs, _, _) -> tqs - | TD_enum _ -> TypQ_aux (TypQ_no_forall, Unknown) + | TD_abstract _ | TD_enum _ -> TypQ_aux (TypQ_no_forall, Unknown) | TD_abbrev (_, _, _) -> assert false | TD_bitfield _ -> assert false ) diff --git a/test/typecheck/fail/abstract_bool_inconsistent.expect b/test/typecheck/fail/abstract_bool_inconsistent.expect new file mode 100644 index 000000000..2d164ed1d --- /dev/null +++ b/test/typecheck/fail/abstract_bool_inconsistent.expect @@ -0,0 +1,5 @@ +Type error: +fail/abstract_bool_inconsistent.sail:8.0-17: +8 |constraint not(b) +  |^---------------^ +  | Global constraint appears inconsistent with previous global constraints diff --git a/test/typecheck/fail/abstract_bool_inconsistent.sail b/test/typecheck/fail/abstract_bool_inconsistent.sail new file mode 100644 index 000000000..05fdbc14b --- /dev/null +++ b/test/typecheck/fail/abstract_bool_inconsistent.sail @@ -0,0 +1,8 @@ +default Order dec + +$include + +type b : Bool + +constraint b +constraint not(b) diff --git a/test/typecheck/fail/global_false_constraint.expect b/test/typecheck/fail/global_false_constraint.expect new file mode 100644 index 000000000..24b52c54b --- /dev/null +++ b/test/typecheck/fail/global_false_constraint.expect @@ -0,0 +1,5 @@ +Type error: +fail/global_false_constraint.sail:2.0-16: +2 |constraint false +  |^--------------^ +  | Global constraint appears inconsistent with previous global constraints diff --git a/test/typecheck/fail/global_false_constraint.sail b/test/typecheck/fail/global_false_constraint.sail new file mode 100644 index 000000000..7ffb82660 --- /dev/null +++ b/test/typecheck/fail/global_false_constraint.sail @@ -0,0 +1,2 @@ + +constraint false diff --git a/test/typecheck/pass/abstract_bool.sail b/test/typecheck/pass/abstract_bool.sail new file mode 100644 index 000000000..365953f28 --- /dev/null +++ b/test/typecheck/pass/abstract_bool.sail @@ -0,0 +1,21 @@ +default Order dec + +$include + +type b : Bool + +val some_int : unit -> int + +val only_true : bool(true) -> unit + +val test : bool(b) -> unit + +function test(b) = { + let x = b; + let y = some_int(); + if x | y == 32 then { + if y != 32 then { + only_true(x) + } + } +} diff --git a/test/typecheck/pass/abstract_bool2.sail b/test/typecheck/pass/abstract_bool2.sail new file mode 100644 index 000000000..741f5a355 --- /dev/null +++ b/test/typecheck/pass/abstract_bool2.sail @@ -0,0 +1,16 @@ +default Order dec + +$include + +type b : Bool + +constraint b + +val only_true : bool(true) -> unit + +val test : bool(b) -> unit + +function test(b) = { + let x = b; + only_true(x) +} diff --git a/test/typecheck/pass/complex_exist_sat/v2.expect b/test/typecheck/pass/complex_exist_sat/v2.expect index 1a98f3339..e7d253d79 100644 --- a/test/typecheck/pass/complex_exist_sat/v2.expect +++ b/test/typecheck/pass/complex_exist_sat/v2.expect @@ -3,7 +3,7 @@ 3 |function foo(x) = 4  | ^  | int(4) is not a subtype of {('q : Int), 'q in {0, 1}. int((2 * 'q))} -  | as ('ex2 == 0 | 'ex2 == 1) could not be proven +  | as 'ex2 in {0, 1} could not be proven  |  | type variable 'ex2:  | pass/complex_exist_sat/v2.sail:1.18-50: diff --git a/test/typecheck/pass/constrained_struct/v1.expect b/test/typecheck/pass/constrained_struct/v1.expect index c9ed441f7..c832f69b6 100644 --- a/test/typecheck/pass/constrained_struct/v1.expect +++ b/test/typecheck/pass/constrained_struct/v1.expect @@ -12,4 +12,4 @@  |  | Caused by pass/constrained_struct/v1.sail:10.18-26:  |  | 10 |type MyStruct64 = MyStruct(65)  |  |  | ^------^ -  |  | Could not prove (65 == 32 | 65 == 64) for type constructor MyStruct +  |  | Could not prove 65 in {32, 64} for type constructor MyStruct diff --git a/test/typecheck/pass/constraint_syn.sail b/test/typecheck/pass/constraint_syn.sail new file mode 100644 index 000000000..c0e1d928c --- /dev/null +++ b/test/typecheck/pass/constraint_syn.sail @@ -0,0 +1,27 @@ +default Order dec +$include + +type xlen : Int + +constraint xlen in {32, 64} + +type is_32 : Bool = xlen == 32 + +type s_xlen : Int = xlen + +val test : (bool(is_32), bits(s_xlen)) -> unit + +function test(b: bool(is_32), xs: bits(s_xlen)) -> unit = { + () +} + +val main : unit -> unit + +function main() = { + if sizeof(xlen) == 32 then { + test(true, 0xFFFF_FFFF) + }; + if constraint(is_32) then { + () + } +} diff --git a/test/typecheck/pass/existential_ast3/v1.expect b/test/typecheck/pass/existential_ast3/v1.expect index 78dce04af..d2bb18654 100644 --- a/test/typecheck/pass/existential_ast3/v1.expect +++ b/test/typecheck/pass/existential_ast3/v1.expect @@ -3,7 +3,7 @@ 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));  | ^---------------^  | (int(33), int('ex291)) is not a subtype of (int('ex286), int('ex287)) -  | as ((33 == 32 | 33 == 64) & (0 <= 'ex291 & 'ex291 < 33)) could not be proven +  | as false could not be proven  |  | type variable 'ex286:  | pass/existential_ast3/v1.sail:16.23-25: diff --git a/test/typecheck/pass/existential_ast3/v2.expect b/test/typecheck/pass/existential_ast3/v2.expect index 858b6cf3f..3a1d2d19d 100644 --- a/test/typecheck/pass/existential_ast3/v2.expect +++ b/test/typecheck/pass/existential_ast3/v2.expect @@ -3,7 +3,7 @@ 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));  | ^---------------^  | (int(31), int('ex291)) is not a subtype of (int('ex286), int('ex287)) -  | as ((31 == 32 | 31 == 64) & (0 <= 'ex291 & 'ex291 < 31)) could not be proven +  | as false could not be proven  |  | type variable 'ex286:  | pass/existential_ast3/v2.sail:16.23-25: diff --git a/test/typecheck/pass/existential_ast3/v3.expect b/test/typecheck/pass/existential_ast3/v3.expect index 342792364..a69623584 100644 --- a/test/typecheck/pass/existential_ast3/v3.expect +++ b/test/typecheck/pass/existential_ast3/v3.expect @@ -3,4 +3,4 @@ 25 | Some(Ctor(64, unsigned(0b0 @ b @ a)))  | ^-----------------------------^  | Could not resolve quantifiers for Ctor -  | * ((64 == 32 | 64 == 64) & (0 <= 'ex330# & 'ex330# < 64)) +  | * (64 in {32, 64} & (0 <= 'ex330# & 'ex330# < 64)) diff --git a/test/typecheck/pass/reg_32_64/v1.expect b/test/typecheck/pass/reg_32_64/v1.expect index 5a9d57f87..bd5c98f54 100644 --- a/test/typecheck/pass/reg_32_64/v1.expect +++ b/test/typecheck/pass/reg_32_64/v1.expect @@ -11,6 +11,6 @@ Explicit effect annotations are deprecated. They are no longer used and can be r  | No overloading for R, tried:  | * set_R  | Could not resolve quantifiers for set_R -  | * (regno(0) & (56 == 32 | 56 == 64)) +  | * (regno(0) & 56 in {32, 64})  | * get_R  | Could not unify int('r) and bitvector(56)