Skip to content

Commit

Permalink
Add abstract types and global constraints
Browse files Browse the repository at this point in the history
Will likely fail with anything other than --just-check, as
only implemented in type system and parser for now
  • Loading branch information
Alasdair committed Jan 8, 2024
1 parent b14f649 commit ac40c46
Show file tree
Hide file tree
Showing 42 changed files with 553 additions and 304 deletions.
7 changes: 6 additions & 1 deletion language/sail.ott
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,11 @@ n_constraint :: 'NC_' ::=
| nexp '<=' nexp' :: :: bounded_le
| nexp '<' nexp' :: :: bounded_lt
| nexp != nexp' :: :: not_equal
| kid 'IN' { num1 , ... , numn } :: :: set
| nexp 'IN' { num1 , ... , numn } :: :: set
| n_constraint & n_constraint' :: :: or
| n_constraint | n_constraint' :: :: and
| id ( typ_arg0 , ... , typ_argn ) :: :: app
| id :: :: id
| kid :: :: var
| true :: :: true
| false :: :: false
Expand Down Expand Up @@ -318,6 +319,8 @@ type_def_aux :: 'TD_' ::=
{{ com tagged union type definition}} {{ texlong }}
| typedef id = enumerate { id1 ; ... ; idn semi_opt } :: :: enum
{{ com enumeration type definition}} {{ texlong }}
| typedef id : kind :: :: abstract
{{ com abstract type }}
| bitfield id : typ = { id1 : index_range1 , ... , idn : index_rangen } :: :: bitfield
{{ com register mutable bitfield type definition }} {{ texlong }}

Expand Down Expand Up @@ -760,6 +763,8 @@ def :: 'DEF_' ::=
{{ aux _ def_annot }} {{ auxparam 'a }}
| type_def :: :: type
{{ com type definition }}
| constraint n_constraint :: :: constraint
{{ com top-level constraint }}
| fundef :: :: fundef
{{ com function definition }}
| mapdef :: :: mapdef
Expand Down
62 changes: 31 additions & 31 deletions src/lib/ast_util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,13 @@ and nexp_simp_aux = function
let rec constraint_simp (NC_aux (nc_aux, l)) =
let nc_aux =
match nc_aux with
| NC_set (nexp, ints) ->
let nexp = nexp_simp nexp in
begin
match nexp with
| Nexp_aux (Nexp_constant c, _) -> if List.exists (fun i -> Big_int.equal c i) ints then NC_true else NC_false
| _ -> NC_set (nexp, ints)
end
| NC_equal (nexp1, nexp2) ->
let nexp1, nexp2 = (nexp_simp nexp1, nexp_simp nexp2) in
if nexp_identical nexp1 nexp2 then NC_true else NC_equal (nexp1, nexp2)
Expand Down Expand Up @@ -518,6 +525,7 @@ let nc_lteq n1 n2 = NC_aux (NC_bounded_le (n1, n2), Parse_ast.Unknown)
let nc_lt n1 n2 = NC_aux (NC_bounded_lt (n1, n2), Parse_ast.Unknown)
let nc_gteq n1 n2 = NC_aux (NC_bounded_ge (n1, n2), Parse_ast.Unknown)
let nc_gt n1 n2 = NC_aux (NC_bounded_gt (n1, n2), Parse_ast.Unknown)
let nc_id id = mk_nc (NC_id id)
let nc_var kid = mk_nc (NC_var kid)
let nc_true = mk_nc NC_true
let nc_false = mk_nc NC_false
Expand Down Expand Up @@ -545,6 +553,7 @@ let nc_not nc = mk_nc (NC_app (mk_id "not", [arg_bool nc]))

let mk_typschm typq typ = TypSchm_aux (TypSchm_ts (typq, typ), Parse_ast.Unknown)

let mk_empty_typquant ~loc:l = TypQ_aux (TypQ_no_forall, l)
let mk_typquant qis = TypQ_aux (TypQ_tq qis, Parse_ast.Unknown)

let mk_fexp id exp = FE_aux (FE_fexp (id, exp), no_annot)
Expand Down Expand Up @@ -804,6 +813,7 @@ and map_def_annot f (DEF_aux (aux, annot)) =
let aux =
match aux with
| DEF_type td -> DEF_type (map_typedef_annot f td)
| DEF_constraint nc -> DEF_constraint nc
| DEF_fundef fd -> DEF_fundef (map_fundef_annot f fd)
| DEF_mapdef md -> DEF_mapdef (map_mapdef_annot f md)
| DEF_outcome (outcome_spec, defs) -> DEF_outcome (outcome_spec, List.map (map_def_annot f) defs)
Expand Down Expand Up @@ -892,6 +902,7 @@ and string_of_typ_arg_aux = function
| A_bool nc -> string_of_n_constraint nc

and string_of_n_constraint = function
| NC_aux (NC_id id, _) -> string_of_id id
| NC_aux (NC_equal (n1, n2), _) -> string_of_nexp n1 ^ " == " ^ string_of_nexp n2
| NC_aux (NC_not_equal (n1, n2), _) -> string_of_nexp n1 ^ " != " ^ string_of_nexp n2
| NC_aux (NC_bounded_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2
Expand All @@ -900,7 +911,7 @@ and string_of_n_constraint = function
| NC_aux (NC_bounded_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2
| NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_set (kid, ns), _) -> string_of_kid kid ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}"
| NC_aux (NC_set (n, ns), _) -> string_of_nexp n ^ " in {" ^ string_of_list ", " Big_int.to_string ns ^ "}"
| NC_aux (NC_app (Id_aux (Operator op, _), [arg1; arg2]), _) ->
"(" ^ string_of_typ_arg arg1 ^ " " ^ op ^ " " ^ string_of_typ_arg arg2 ^ ")"
| NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_typ_arg args ^ ")"
Expand Down Expand Up @@ -1117,8 +1128,10 @@ let id_of_type_def_aux = function
| TD_record (id, _, _, _)
| TD_variant (id, _, _, _)
| TD_enum (id, _, _)
| TD_abstract (id, _)
| TD_bitfield (id, _, _) ->
id

let id_of_type_def (TD_aux (td_aux, _)) = id_of_type_def_aux td_aux

let id_of_val_spec (VS_aux (VS_val_spec (_, id, _), _)) = id
Expand Down Expand Up @@ -1187,14 +1200,15 @@ let lex_ord f g x1 x2 y1 y2 = match f x1 x2 with 0 -> g y1 y2 | n -> n

let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) =
match (nc1, nc2) with
| NC_id id1, NC_id id2 -> Id.compare id1 id2
| NC_equal (n1, n2), NC_equal (n3, n4)
| NC_bounded_ge (n1, n2), NC_bounded_ge (n3, n4)
| NC_bounded_gt (n1, n2), NC_bounded_gt (n3, n4)
| NC_bounded_le (n1, n2), NC_bounded_le (n3, n4)
| NC_bounded_lt (n1, n2), NC_bounded_lt (n3, n4)
| NC_not_equal (n1, n2), NC_not_equal (n3, n4) ->
lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4
| NC_set (k1, s1), NC_set (k2, s2) -> lex_ord Kid.compare (Util.compare_list Nat_big_num.compare) k1 k2 s1 s2
| NC_set (n1, s1), NC_set (n2, s2) -> lex_ord Nexp.compare (Util.compare_list Nat_big_num.compare) n1 n2 s1 s2
| NC_or (nc1, nc2), NC_or (nc3, nc4) | NC_and (nc1, nc2), NC_and (nc3, nc4) ->
lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4
| NC_app (f1, args1), NC_app (f2, args2) -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2
Expand Down Expand Up @@ -1224,6 +1238,8 @@ let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) =
| _, NC_var _ -> 1
| NC_true, _ -> -1
| _, NC_true -> 1
| NC_id _, _ -> -1
| _, NC_id _ -> 1

and typ_compare (Typ_aux (t1, _)) (Typ_aux (t2, _)) =
match (t1, t2) with
Expand Down Expand Up @@ -1278,6 +1294,9 @@ let is_typ_arg_typ = function A_aux (A_typ _, _) -> true | _ -> false

let is_typ_arg_bool = function A_aux (A_bool _, _) -> true | _ -> false

let typ_arg_kind (A_aux (aux, l)) =
match aux with A_typ _ -> K_aux (K_type, l) | A_bool _ -> K_aux (K_bool, l) | A_nexp _ -> K_aux (K_int, l)

module NC = struct
type t = n_constraint
let compare = nc_compare
Expand Down Expand Up @@ -1391,11 +1410,11 @@ let rec kopts_of_constraint (NC_aux (nc, _)) =
| NC_bounded_lt (nexp1, nexp2)
| NC_not_equal (nexp1, nexp2) ->
KOptSet.union (kopts_of_nexp nexp1) (kopts_of_nexp nexp2)
| NC_set (kid, _) -> KOptSet.singleton (mk_kopt K_int kid)
| NC_set (nexp, _) -> kopts_of_nexp nexp
| NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KOptSet.union (kopts_of_constraint nc1) (kopts_of_constraint nc2)
| NC_app (_, args) -> List.fold_left (fun s t -> KOptSet.union s (kopts_of_typ_arg t)) KOptSet.empty args
| NC_var kid -> KOptSet.singleton (mk_kopt K_bool kid)
| NC_true | NC_false -> KOptSet.empty
| NC_id _ | NC_true | NC_false -> KOptSet.empty

and kopts_of_typ (Typ_aux (t, _)) =
match t with
Expand Down Expand Up @@ -1436,11 +1455,11 @@ let rec tyvars_of_constraint (NC_aux (nc, _)) =
| NC_bounded_lt (nexp1, nexp2)
| NC_not_equal (nexp1, nexp2) ->
KidSet.union (tyvars_of_nexp nexp1) (tyvars_of_nexp nexp2)
| NC_set (kid, _) -> KidSet.singleton kid
| NC_set (nexp, _) -> tyvars_of_nexp nexp
| NC_or (nc1, nc2) | NC_and (nc1, nc2) -> KidSet.union (tyvars_of_constraint nc1) (tyvars_of_constraint nc2)
| NC_app (_, args) -> List.fold_left (fun s t -> KidSet.union s (tyvars_of_typ_arg t)) KidSet.empty args
| NC_var kid -> KidSet.singleton kid
| NC_true | NC_false -> KidSet.empty
| NC_id _ | NC_true | NC_false -> KidSet.empty

and tyvars_of_typ (Typ_aux (t, _)) =
match t with
Expand Down Expand Up @@ -1708,13 +1727,14 @@ let rec locate_nexp f (Nexp_aux (nexp_aux, l)) =
let rec locate_nc f (NC_aux (nc_aux, l)) =
let nc_aux =
match nc_aux with
| NC_id id -> NC_id (locate_id f id)
| NC_equal (nexp1, nexp2) -> NC_equal (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_ge (nexp1, nexp2) -> NC_bounded_ge (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_gt (nexp1, nexp2) -> NC_bounded_gt (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_le (nexp1, nexp2) -> NC_bounded_le (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_bounded_lt (nexp1, nexp2) -> NC_bounded_lt (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_not_equal (nexp1, nexp2) -> NC_not_equal (locate_nexp f nexp1, locate_nexp f nexp2)
| NC_set (kid, nums) -> NC_set (locate_kid f kid, nums)
| NC_set (nexp, nums) -> NC_set (locate_nexp f nexp, nums)
| NC_or (nc1, nc2) -> NC_or (locate_nc f nc1, locate_nc f nc2)
| NC_and (nc1, nc2) -> NC_and (locate_nc f nc1, locate_nc f nc2)
| NC_true -> NC_true
Expand Down Expand Up @@ -1900,26 +1920,17 @@ and nexp_subst_aux sv subst = function
| Nexp_exp nexp -> Nexp_exp (nexp_subst sv subst nexp)
| Nexp_neg nexp -> Nexp_neg (nexp_subst sv subst nexp)

let rec nexp_set_to_or l subst = function
| [] -> raise (Reporting.err_unreachable l __POS__ "Empty set in constraint")
| [int] -> NC_equal (subst, nconstant int)
| int :: ints -> NC_or (mk_nc (NC_equal (subst, nconstant int)), mk_nc (nexp_set_to_or l subst ints))

let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l)

and constraint_subst_aux l sv subst = function
| NC_id id -> NC_id id
| NC_equal (n1, n2) -> NC_equal (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_ge (n1, n2) -> NC_bounded_ge (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_gt (n1, n2) -> NC_bounded_gt (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_le (n1, n2) -> NC_bounded_le (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_bounded_lt (n1, n2) -> NC_bounded_lt (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_not_equal (n1, n2) -> NC_not_equal (nexp_subst sv subst n1, nexp_subst sv subst n2)
| NC_set (kid, ints) as set_nc -> begin
match subst with
| A_aux (A_nexp (Nexp_aux (Nexp_var kid', _)), _) when Kid.compare kid sv = 0 -> NC_set (kid', ints)
| A_aux (A_nexp n, _) when Kid.compare kid sv = 0 -> nexp_set_to_or l n ints
| _ -> set_nc
end
| NC_set (n, ints) -> NC_set (nexp_subst sv subst n, ints)
| NC_or (nc1, nc2) -> NC_or (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
| NC_and (nc1, nc2) -> NC_and (constraint_subst sv subst nc1, constraint_subst sv subst nc2)
| NC_app (id, args) -> NC_app (id, List.map (typ_arg_subst sv subst) args)
Expand Down Expand Up @@ -1997,25 +2008,14 @@ let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg =
let snc nc = subst_kids_nc substs nc in
let re nc = NC_aux (nc, l) in
match nc with
| NC_id id -> re (NC_id id)
| NC_equal (n1, n2) -> re (NC_equal (snexp n1, snexp n2))
| NC_bounded_ge (n1, n2) -> re (NC_bounded_ge (snexp n1, snexp n2))
| NC_bounded_gt (n1, n2) -> re (NC_bounded_gt (snexp n1, snexp n2))
| NC_bounded_le (n1, n2) -> re (NC_bounded_le (snexp n1, snexp n2))
| NC_bounded_lt (n1, n2) -> re (NC_bounded_lt (snexp n1, snexp n2))
| NC_not_equal (n1, n2) -> re (NC_not_equal (snexp n1, snexp n2))
| NC_set (kid, is) -> begin
match KBindings.find kid substs with
| Nexp_aux (Nexp_constant i, _) ->
if List.exists (fun j -> Big_int.equal i j) is then re NC_true else re NC_false
| nexp -> begin
match List.rev is with
| i :: is ->
let equal_num i = re (NC_equal (nexp, nconstant i)) in
List.fold_left (fun nc i -> re (NC_or (equal_num i, nc))) (equal_num i) is
| [] -> re NC_false
end
| exception Not_found -> n_constraint
end
| NC_set (n, ints) -> re (NC_set (snexp n, ints))
| NC_or (nc1, nc2) -> re (NC_or (snc nc1, snc nc2))
| NC_and (nc1, nc2) -> re (NC_and (snc nc1, snc nc2))
| NC_true | NC_false -> n_constraint
Expand Down
9 changes: 7 additions & 2 deletions src/lib/ast_util.mli
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ val mk_funcl : ?loc:l -> id -> uannot pat -> uannot exp -> uannot funcl
val mk_fundef : uannot funcl list -> uannot def
val mk_val_spec : val_spec_aux -> uannot def
val mk_typschm : typquant -> typ -> typschm
val mk_empty_typquant : loc:l -> typquant
val mk_typquant : quant_item list -> typquant
val mk_qi_id : kind_aux -> kid -> quant_item
val mk_qi_nc : n_constraint -> quant_item
Expand Down Expand Up @@ -215,6 +216,8 @@ val is_typ_arg_nexp : typ_arg -> bool
val is_typ_arg_typ : typ_arg -> bool
val is_typ_arg_bool : typ_arg -> bool

val typ_arg_kind : typ_arg -> kind

(** {2 Sail built-in types} *)

val unknown_typ : typ
Expand Down Expand Up @@ -299,8 +302,9 @@ val nc_or : n_constraint -> n_constraint -> n_constraint
val nc_not : n_constraint -> n_constraint
val nc_true : n_constraint
val nc_false : n_constraint
val nc_set : kid -> Big_int.num list -> n_constraint
val nc_int_set : kid -> int list -> n_constraint
val nc_set : nexp -> Big_int.num list -> n_constraint
val nc_int_set : nexp -> int list -> n_constraint
val nc_id : id -> n_constraint
val nc_var : kid -> n_constraint

(** {2 Functions for building type arguments}*)
Expand Down Expand Up @@ -464,6 +468,7 @@ val string_of_index_range : index_range -> string

val id_of_fundef : 'a fundef -> id
val id_of_mapdef : 'a mapdef -> id
val id_of_type_def_aux : type_def_aux -> id
val id_of_type_def : 'a type_def -> id
val id_of_val_spec : 'a val_spec -> id
val id_of_dec_spec : 'a dec_spec -> id
Expand Down
12 changes: 3 additions & 9 deletions src/lib/callgraph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ let rec constraint_ids' (NC_aux (aux, _)) =
IdSet.union (nexp_ids' n1) (nexp_ids' n2)
| NC_or (nc1, nc2) | NC_and (nc1, nc2) -> IdSet.union (constraint_ids' nc1) (constraint_ids' nc2)
| NC_var _ | NC_true | NC_false | NC_set _ -> IdSet.empty
| NC_id id -> IdSet.singleton id
| NC_app (id, args) -> IdSet.add id (List.fold_left IdSet.union IdSet.empty (List.map typ_arg_ids' args))

and nexp_ids' (Nexp_aux (aux, _)) =
Expand Down Expand Up @@ -278,6 +279,7 @@ let add_def_to_graph graph (DEF_aux (def, _)) =
scan_typquant (Type id) typq
| TD_enum (id, ctors, _) ->
List.iter (fun ctor_id -> graph := G.add_edge (Constructor ctor_id) (Type id) !graph) ctors
| TD_abstract (id, _) -> graph := G.add_edges (Type id) [] !graph
| TD_bitfield (id, typ, ranges) ->
graph := G.add_edges (Type id) (List.map (fun id -> Type id) (IdSet.elements (typ_ids typ))) !graph
in
Expand Down Expand Up @@ -378,14 +380,6 @@ let rec graph_of_defs defs =

let graph_of_ast ast = graph_of_defs ast.defs

let id_of_typedef (TD_aux (aux, _)) =
match aux with
| TD_abbrev (id, _, _) -> id
| TD_record (id, _, _, _) -> id
| TD_variant (id, _, _, _) -> id
| TD_enum (id, _, _) -> id
| TD_bitfield (id, _, _) -> id

let id_of_reg_dec (DEC_aux (DEC_reg (_, id, _), _)) = id

let id_of_funcl (FCL_aux (FCL_funcl (id, _), _)) = id
Expand Down Expand Up @@ -426,7 +420,7 @@ let filter_ast_extra cuts g ast keep_std =
let ids = pat_ids pat |> IdSet.elements in
if List.exists (fun id -> NM.mem (Letbind id) g) ids then DEF_aux (DEF_let lb, def_annot) :: filter_ast' g defs
else filter_ast' g defs
| DEF_aux (DEF_type tdef, def_annot) :: defs when NM.mem (Type (id_of_typedef tdef)) g ->
| DEF_aux (DEF_type tdef, def_annot) :: defs when NM.mem (Type (id_of_type_def tdef)) g ->
DEF_aux (DEF_type tdef, def_annot) :: filter_ast' g defs
| DEF_aux (DEF_type _, _) :: defs -> filter_ast' g defs
| DEF_aux (DEF_measure (id, _, _), _) :: defs when NS.mem (Function id) cuts -> filter_ast' g defs
Expand Down
4 changes: 2 additions & 2 deletions src/lib/constant_propagation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ let lit_match = function
let fabricate_nexp_exist env l typ kids nc typ' =
match (kids, nc, Env.expand_synonyms env typ') with
| ( [kid],
NC_aux (NC_set (kid', i :: _), _),
NC_aux (NC_set (Nexp_aux (Nexp_var kid', _), i :: _), _),
Typ_aux (Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var kid'', _)), _)]), _) )
when Kid.compare kid kid' = 0 && Kid.compare kid kid'' = 0 ->
Nexp_aux (Nexp_constant i, Unknown)
Expand All @@ -154,7 +154,7 @@ let fabricate_nexp_exist env l typ kids nc typ' =
when Kid.compare kid kid'' = 0 ->
nint 32
| ( [kid],
NC_aux (NC_set (kid', i :: _), _),
NC_aux (NC_set (Nexp_aux (Nexp_var kid', _), i :: _), _),
Typ_aux
( Typ_app
( Id_aux (Id "range", _),
Expand Down
Loading

0 comments on commit ac40c46

Please sign in to comment.