diff --git a/src/lib/frontend/d_cnf.ml b/src/lib/frontend/d_cnf.ml index bb03097ca..7b3582e86 100644 --- a/src/lib/frontend/d_cnf.ml +++ b/src/lib/frontend/d_cnf.ml @@ -546,8 +546,8 @@ and handle_ty_app ?(update = false) ty_c l = apply_ty_substs tysubsts tv ) - | Tadt (hs, tyl, enum) -> - Tadt (hs, List.map (apply_ty_substs tysubsts) tyl, enum) + | Tadt (hs, tyl) -> + Tadt (hs, List.map (apply_ty_substs tysubsts) tyl) | Trecord ({ args; lbs; _ } as rcrd) -> Trecord { @@ -565,7 +565,7 @@ and handle_ty_app ?(update = false) ty_c l = (* Recover the initial versions of the types and apply them on the provided type arguments stored in [tyl]. *) match Cache.find_ty ty_c with - | Tadt (hs, _, enum) -> Tadt (hs, tyl, enum) + | Tadt (hs, _) -> Tadt (hs, tyl) | Trecord { args; _ } as ty -> let tysubsts = @@ -615,10 +615,9 @@ let mk_ty_decl (ty_c: DE.ty_cst) = let uid = Uid.of_dolmen ty_c in let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in Cache.store_ty ty_c (Ty.t_adt uid tyvl); - let rev_cs, is_enum = + let rev_cs = Array.fold_left ( - fun (accl, is_enum) DE.{ cstr; dstrs; _ } -> - let is_enum = is_enum && Array.length dstrs = 0 in + fun accl DE.{ cstr; dstrs; _ } -> let rev_fields = Array.fold_left ( fun acc tc_o -> @@ -628,12 +627,11 @@ let mk_ty_decl (ty_c: DE.ty_cst) = | None -> assert false ) [] dstrs in - (Uid.of_dolmen cstr, List.rev rev_fields) :: accl, is_enum - ) ([], true) cases + (Uid.of_dolmen cstr, List.rev rev_fields) :: accl + ) [] cases in let body = Some (List.rev rev_cs) in - let kind = if is_enum then `Enum else `Adt in - let ty = Ty.t_adt ~kind ~body uid tyvl in + let ty = Ty.t_adt ~body uid tyvl in Cache.store_ty ty_c ty | None | Some Abstract -> @@ -693,7 +691,7 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) = in Cache.store_ty ty_c ty - | Tadt (hs, tyl, _), Some (Adt { cases; ty = ty_c; _ }) -> + | Tadt (hs, tyl), Some (Adt { cases; ty = ty_c; _ }) -> let rev_cs = Array.fold_left ( fun accl DE.{ cstr; dstrs; _ } -> @@ -737,12 +735,6 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) = fun acc tdef -> match tdef with | DE.Adt { cases; record; ty = ty_c; } as adt -> - let is_enum = - Array.fold_left ( - fun is_enum DE.{ dstrs; _ } -> - is_enum && Array.length dstrs = 0 - ) true cases - in let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in let uid = Uid.of_dolmen ty_c in let ty = @@ -750,8 +742,7 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) = then Ty.trecord ~record_constr:uid tyvl uid [] else - let kind = if is_enum then `Enum else `Adt in - Ty.t_adt ~kind uid tyvl + Ty.t_adt uid tyvl in Cache.store_ty ty_c ty; (ty, Some adt) :: acc diff --git a/src/lib/frontend/typechecker.ml b/src/lib/frontend/typechecker.ml index 7227f9098..fdc2373c0 100644 --- a/src/lib/frontend/typechecker.ml +++ b/src/lib/frontend/typechecker.ml @@ -69,14 +69,10 @@ module Types = struct match ty with | Ty.Text (lty', s) | Ty.Trecord { Ty.args = lty'; name = s; _ } - | Ty.Tadt (s,lty',`Adt) -> + | Ty.Tadt (s,lty') -> if List.length lty <> List.length lty' then Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc; lty' - | Ty.Tadt (s,lty',`Enum) -> - if List.length lty <> 0 || List.length lty' <> 0 then - Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc; - [] | _ -> assert false let equal_pp_vars lpp lvars = @@ -149,7 +145,7 @@ module Types = struct if not (Lists.is_empty ty_vars) then Errors.typing_error (PolymorphicEnum id) loc; let body = List.map (fun constr -> Uid.of_string constr, []) l in - let ty = Ty.t_adt ~kind:`Enum ~body:(Some body) (Uid.of_string id) [] in + let ty = Ty.t_adt ~body:(Some body) (Uid.of_string id) [] in ty, { env with to_ty = MString.add id ty env.to_ty } | Record (record_constr, lbs) -> let lbs = @@ -275,8 +271,9 @@ module Env = struct let add_fpa_enum map = let ty = Fpa_rounding.fpa_rounding_mode in match ty with - | Ty.Tadt (name, [], `Enum) -> - let Adt cases = Ty.type_body name [] in + | Ty.Tadt (name, []) -> + let Ty.{ cases; kind } = Ty.type_body name [] in + assert (Stdlib.(kind = Ty.Enum)); let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in List.fold_left (fun m c -> @@ -301,8 +298,9 @@ module Env = struct let find_builtin_cstr ty n = match ty with - | Ty.Tadt (name, [], `Enum) -> - let Adt cases = Ty.type_body name [] in + | Ty.Tadt (name, []) -> + let Ty.{ cases; kind } = Ty.type_body name [] in + assert (Stdlib.(kind = Ty.Enum)); let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in List.find (Uid.equal n) cstrs | _ -> @@ -1005,8 +1003,8 @@ let rec type_term ?(call_from_type_form=false) env f = let e = type_term env e in let ty = Ty.shorten e.c.tt_ty in let ty_body = match ty with - | Ty.Tadt (name, params, _) -> - let Ty.Adt cases = Ty.type_body name params in + | Ty.Tadt (name, params) -> + let Ty.{ cases; _ } = Ty.type_body name params in cases | Ty.Trecord { Ty.record_constr; lbs; _ } -> [{Ty.constr = record_constr; destrs = lbs}] @@ -1413,8 +1411,8 @@ and type_form ?(in_theory=false) env f = let e = type_term env e in let ty = e.c.tt_ty in let ty_body = match ty with - | Ty.Tadt (name, params, _) -> - let Ty.Adt cases = Ty.type_body name params in + | Ty.Tadt (name, params) -> + let Ty.{ cases; _ } = Ty.type_body name params in cases | Ty.Trecord { Ty.record_constr; lbs; _ } -> [{Ty.constr = record_constr ; destrs = lbs}] diff --git a/src/lib/reasoners/adt.ml b/src/lib/reasoners/adt.ml index 34cfc922f..be34ff793 100644 --- a/src/lib/reasoners/adt.ml +++ b/src/lib/reasoners/adt.ml @@ -54,9 +54,9 @@ let constr_of_destr ty dest = ~module_name:"Adt" ~function_name:"constr_of_destr" "ty = %a" Ty.print ty; match ty with - | Ty.Tadt (s, params, _) -> + | Ty.Tadt (s, params) -> begin - let Ty.Adt cases = Ty.type_body s params in + let Ty.{ cases; _ } = Ty.type_body s params in try List.find (fun { Ty.destrs; _ } -> @@ -173,8 +173,8 @@ module Shostak (X : ALIEN) = struct in let xs = List.rev sx in match f, xs, ty with - | Sy.Op Sy.Constr hs, _, Ty.Tadt (name, params, _) -> - let Ty.Adt cases = Ty.type_body name params in + | Sy.Op Sy.Constr hs, _, Ty.Tadt (name, params) -> + let Ty.{ cases; _ } = Ty.type_body name params in let case_hs = try Ty.assoc_destrs hs cases with Not_found -> assert false in diff --git a/src/lib/reasoners/adt_rel.ml b/src/lib/reasoners/adt_rel.ml index ca13e5c76..e85d243c4 100644 --- a/src/lib/reasoners/adt_rel.ml +++ b/src/lib/reasoners/adt_rel.ml @@ -83,14 +83,14 @@ module Domain = struct let unknown ty = match ty with - | Ty.Tadt (name, params, _) -> + | Ty.Tadt (name, params) -> (* Return the list of all the constructors of the type of [r]. *) - let Adt body = Ty.type_body name params in + let Ty.{ cases; _ } = Ty.type_body name params in let constrs = List.fold_left (fun acc Ty.{ constr; _ } -> TSet.add constr acc - ) TSet.empty body + ) TSet.empty cases in assert (not @@ TSet.is_empty constrs); { constrs; ex = Ex.empty } @@ -156,7 +156,12 @@ module Domains = struct let is_enum r = match X.type_info r with - | Ty.Tadt (_, [], `Enum) -> true + | Ty.Tadt (name, params) -> + let Ty.{ kind; _ } = Ty.type_body name params in + begin match kind with + | Enum -> true + | Adt -> false + end | _ -> false let internal_update r nd t = @@ -477,10 +482,10 @@ let build_constr_eq r c = match Th.embed r with | Alien r -> begin match X.type_info r with - | Ty.Tadt (name, params, _) as ty -> - let Ty.Adt body = Ty.type_body name params in + | Ty.Tadt (name, params) as ty -> + let Ty.{ cases; _ } = Ty.type_body name params in let ds = - try Ty.assoc_destrs c body with Not_found -> assert false + try Ty.assoc_destrs c cases with Not_found -> assert false in let xs = List.map (fun (_, ty) -> E.fresh_name ty) ds in let cons = E.mk_constr c xs ty in @@ -578,9 +583,9 @@ let two = Numbers.Q.from_int 2 (* TODO: we should compute this reverse map in `Ty` and store it there. *) let constr_of_destr ty d = match ty with - | Ty.Tadt (name, params, _) -> + | Ty.Tadt (name, params) -> begin - let Ty.Adt cases = Ty.type_body name params in + let Ty.{ cases; _ } = Ty.type_body name params in try let r = List.find diff --git a/src/lib/reasoners/theory.ml b/src/lib/reasoners/theory.ml index 9c8e098b2..5c5dddd8a 100644 --- a/src/lib/reasoners/theory.ml +++ b/src/lib/reasoners/theory.ml @@ -170,7 +170,7 @@ module Main_Default : S = struct (* cannot do better for records ? *) Uid.Map.add name ty mp - | Tadt (hs, _, _) -> + | Tadt (hs, _) -> (* cannot do better for ADT ? *) Uid.Map.add hs ty mp )sty Uid.Map.empty diff --git a/src/lib/structures/fpa_rounding.ml b/src/lib/structures/fpa_rounding.ml index 73fb653e1..d12e38d51 100644 --- a/src/lib/structures/fpa_rounding.ml +++ b/src/lib/structures/fpa_rounding.ml @@ -96,7 +96,7 @@ let fpa_rounding_mode_dty, d_cstrs, fpa_rounding_mode = let body = List.map (fun (c, _) -> Uid.of_dolmen c, []) d_cstrs in - let ty = Ty.t_adt ~kind:`Enum ~body:(Some body) (Uid.of_dolmen ty_cst) [] in + let ty = Ty.t_adt ~body:(Some body) (Uid.of_dolmen ty_cst) [] in DE.Ty.apply ty_cst [], d_cstrs, ty let rounding_mode_of_smt_hs = diff --git a/src/lib/structures/ty.ml b/src/lib/structures/ty.ml index ed9c2fe70..0f2dcafbf 100644 --- a/src/lib/structures/ty.ml +++ b/src/lib/structures/ty.ml @@ -34,7 +34,7 @@ type t = | Tbitv of int | Text of t list * Uid.t | Tfarray of t * t - | Tadt of Uid.t * t list * [`Adt | `Enum] + | Tadt of Uid.t * t list | Trecord of trecord and tvar = { v : int ; mutable value : t option } @@ -55,9 +55,9 @@ module Smtlib = struct | Tfarray (a_t, r_t) -> Fmt.pf ppf "(Array %a %a)" pp a_t pp r_t | Text ([], name) - | Trecord { args = []; name; _ } | Tadt (name, [], _) -> Uid.pp ppf name + | Trecord { args = []; name; _ } | Tadt (name, []) -> Uid.pp ppf name | Text (args, name) - | Trecord { args; name; _ } | Tadt (name, args, _) -> + | Trecord { args; name; _ } | Tadt (name, args) -> Fmt.(pf ppf "(@[%a %a@])" Uid.pp name (list ~sep:sp pp) args) | Tvar { v; value = None; _ } -> Fmt.pf ppf "A%d" v | Tvar { value = Some t; _ } -> pp ppf t @@ -72,8 +72,12 @@ type adt_constr = { constr : Uid.t ; destrs : (Uid.t * t) list } -type type_body = - | Adt of adt_constr list +type adt_kind = Enum | Adt + +type type_body = { + cases : adt_constr list; + kind : adt_kind +} let assoc_destrs hs cases = @@ -136,14 +140,12 @@ let print_generic body_of = fprintf fmt "}" end end - | Tadt (n, lv, _) -> + | Tadt (n, lv) -> fprintf fmt "%a %a" print_list lv Uid.pp n; begin match body_of with | None -> () | Some type_body -> - let cases = match type_body n lv with - | Adt cases -> cases - in + let { cases; _ } = type_body n lv in fprintf fmt " = {"; let first = ref true in List.iter @@ -219,11 +221,11 @@ let rec shorten ty = r.lbs <- List.map (fun (lb, ty) -> lb, shorten ty) r.lbs; ty - | Tadt (n, args, enum) -> + | Tadt (n, args) -> let args' = List.map shorten args in shorten_body n args; (* should not rebuild the type if no changes are made *) - Tadt (n, args', enum) + Tadt (n, args') | Tint | Treal | Tbool | Tunit | Tbitv _ -> ty @@ -255,7 +257,7 @@ let rec compare t1 t2 = compare_list l1 l2 | Trecord _, _ -> -1 | _ , Trecord _ -> 1 - | Tadt (s1, pars1, _), Tadt (s2, pars2, _) -> + | Tadt (s1, pars1), Tadt (s2, pars2) -> let c = Uid.compare s1 s2 in if c <> 0 then c else compare_list pars1 pars2 @@ -307,7 +309,7 @@ let rec equal t1 t2 = | Tint, Tint | Treal, Treal | Tbool, Tbool | Tunit, Tunit -> true | Tbitv n1, Tbitv n2 -> n1 =n2 - | Tadt (s1, pars1, _), Tadt (s2, pars2, _) -> + | Tadt (s1, pars1), Tadt (s2, pars2) -> begin try Uid.equal s1 s2 && List.for_all2 equal pars1 pars2 with Invalid_argument _ -> false @@ -338,7 +340,7 @@ let rec matching s pat t = (fun s (_, p) (_, ty) -> matching s p ty) s r1.lbs r2.lbs | Tint , Tint | Tbool , Tbool | Treal , Treal | Tunit, Tunit -> s | Tbitv n , Tbitv m when n=m -> s - | Tadt(n1, args1, _), Tadt(n2, args2, _) when Uid.equal n1 n2 -> + | Tadt(n1, args1), Tadt(n2, args2) when Uid.equal n1 n2 -> List.fold_left2 matching s args1 args2 | _ , _ -> raise (TypeClash(pat,t)) @@ -368,10 +370,10 @@ let apply_subst = name = r.name; lbs = lbs} - | Tadt(name, params, enum) + | Tadt(name, params) [@ocaml.ppwarning "TODO: detect when there are no changes "] -> - Tadt (name, List.map (apply_subst s) params, enum) + Tadt (name, List.map (apply_subst s) params) | Tint | Treal | Tbool | Tunit | Tbitv _ -> ty in @@ -402,9 +404,9 @@ let rec fresh ty subst = (x, ty)::lbs, subst) lbs ([], subst) in Trecord {r with args = args; name = n; lbs = lbs}, subst - | Tadt(s, args, enum) -> + | Tadt(s, args) -> let args, subst = fresh_list args subst in - Tadt (s, args, enum), subst + Tadt (s, args), subst | t -> t, subst and fresh_list lty subst = @@ -434,22 +436,21 @@ module Decls = struct let fresh_type params body = let params, subst = fresh_list params esubst in - match body with - | Adt cases -> - let _subst, cases = - List.fold_left - (fun (subst, cases) {constr; destrs} -> - let subst, destrs = - List.fold_left - (fun (subst, destrs) (d, ty) -> - let ty, subst = fresh ty subst in - subst, (d, ty) :: destrs - )(subst, []) (List.rev destrs) - in - subst, {constr; destrs} :: cases - )(subst, []) (List.rev cases) - in - params, Adt cases + let { cases; kind } = body in + let _subst, cases = + List.fold_left + (fun (subst, cases) {constr; destrs} -> + let subst, destrs = + List.fold_left + (fun (subst, destrs) (d, ty) -> + let ty, subst = fresh ty subst in + subst, (d, ty) :: destrs + )(subst, []) (List.rev destrs) + in + subst, {constr; destrs} :: cases + )(subst, []) (List.rev cases) + in + params, { cases; kind } let add name params body = @@ -488,16 +489,17 @@ module Decls = struct )M.empty params args with Invalid_argument _ -> assert false in - let body = match body with - | Adt cases -> - Adt( - List.map - (fun {constr; destrs} -> - {constr; - destrs = - List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs } - ) cases - ) + let body = + let { cases; kind } = body in + let cases = + List.map + (fun {constr; destrs} -> + {constr; + destrs = + List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs } + ) cases + in + { cases; kind } in let params = List.map (fun ty -> apply_subst sbt ty) params in add name params body; @@ -529,8 +531,8 @@ let fresh_empty_text = in text [] (Uid.of_dolmen id) -let t_adt ?(kind=`Adt) ?(body=None) s ty_vars = - let ty = Tadt (s, ty_vars, kind) in +let t_adt ?(body=None) s ty_vars = + let ty = Tadt (s, ty_vars) in begin match body with | None -> () | Some [] -> assert false @@ -541,12 +543,20 @@ let t_adt ?(kind=`Adt) ?(body=None) s ty_vars = let cases = List.map (fun (constr, destrs) -> {constr; destrs}) cases in - Decls.add s ty_vars (Adt cases) + let is_enum = + List.for_all (fun { destrs; _ } -> Lists.is_empty destrs) cases + in + let kind = if is_enum then Enum else Adt in + Decls.add s ty_vars { cases; kind } | Some cases -> let cases = List.map (fun (constr, destrs) -> {constr; destrs}) cases in - Decls.add s ty_vars (Adt cases) + let is_enum = + List.for_all (fun { destrs; _ } -> Lists.is_empty destrs) cases + in + let kind = if is_enum then Enum else Adt in + Decls.add s ty_vars { cases; kind } end; ty @@ -576,7 +586,7 @@ let rec hash t = in abs h - | Tadt (s, args, _) -> + | Tadt (s, args) -> (* We do not hash constructors. *) let h = List.fold_left (fun h ty -> 31 * h + hash ty) (Uid.hash s) args @@ -590,7 +600,7 @@ let occurs { v = n; _ } t = | Tvar { v = m; _ } -> n=m | Text(l,_) -> List.exists occursrec l | Tfarray (t1,t2) -> occursrec t1 || occursrec t2 - | Trecord { args ; _ } | Tadt (_, args, _) -> List.exists occursrec args + | Trecord { args ; _ } | Tadt (_, args) -> List.exists occursrec args | Tint | Treal | Tbool | Tunit | Tbitv _ -> false in occursrec t @@ -615,7 +625,7 @@ let rec unify t1 t2 = | Tint, Tint | Tbool, Tbool | Treal, Treal | Tunit, Tunit -> () | Tbitv n , Tbitv m when m=n -> () - | Tadt(n1, p1, _), Tadt (n2, p2, _) when Uid.equal n1 n2 -> + | Tadt(n1, p1), Tadt (n2, p2) when Uid.equal n1 n2 -> List.iter2 unify p1 p2 | _ , _ [@ocaml.ppwarning "TODO: remove fragile pattern "] -> @@ -658,7 +668,7 @@ let vty_of t = | Trecord { args; lbs; _ } -> let acc = List.fold_left vty_of_rec acc args in List.fold_left (fun acc (_, ty) -> vty_of_rec acc ty) acc lbs - | Tadt(_, args, _) -> + | Tadt(_, args) -> List.fold_left vty_of_rec acc args | Tvar { value = Some _ ; _ } @@ -683,8 +693,8 @@ let rec monomorphize ty = | Tvar {v=v; value=None} -> text [] (Uid.of_string ("'_c"^(string_of_int v))) | Tvar ({ value = Some ty1; _ } as r) -> Tvar { r with value = Some (monomorphize ty1)} - | Tadt(name, params, enum) -> - Tadt(name, List.map monomorphize params, enum) + | Tadt(name, params) -> + Tadt(name, List.map monomorphize params) let print_subst fmt sbt = M.iter (fun n ty -> Format.fprintf fmt "%d -> %a" n print ty) sbt; diff --git a/src/lib/structures/ty.mli b/src/lib/structures/ty.mli index c532b433b..2b730c67a 100644 --- a/src/lib/structures/ty.mli +++ b/src/lib/structures/ty.mli @@ -53,13 +53,13 @@ type t = (** Functional arrays. [TFarray (src,dst)] maps values of type [src] to values of type [dst]. *) - | Tadt of Uid.t * t list * [`Adt | `Enum] - (** Application of algebraic data types. [Tadt (a, params, kind)] denotes + | Tadt of Uid.t * t list + (** Application of algebraic data types. [Tadt (a, params)] denotes the application of the polymorphic datatype [a] to the types parameters - [params]. The flag [kind] determines if the ADT is an enum. + [params]. For instance the type of integer lists can be represented by the - value [Tadt (Hstring.make "list", [Tint], `Adt)] where the identifier + value [Tadt (Hstring.make "list", [Tint]] where the identifier {e list} denotes a polymorphic ADT defined by the user with [t_adt]. *) | Trecord of trecord @@ -99,13 +99,22 @@ type adt_constr = their respective types *) } -(** bodies of types definitions. Currently, bodies are inlined in the +type adt_kind = + | Enum (* ADT whose all the constructors have no payload. *) + | Adt + +(** Bodies of types definitions. Currently, bodies are inlined in the type [t] for records and enumerations. But, this is not possible for recursive ADTs *) -type type_body = - | Adt of adt_constr list +type type_body = { + cases : adt_constr list; (** body of an algebraic datatype *) + kind : adt_kind + (** This flag is used by the case splitting mechanism of the ADT theory. + We perform eager splitting on ADT of kind [enum]. *) +} + module Svty : Set.S with type elt = int (** Sets of type variables, indexed by their identifier. *) @@ -169,17 +178,13 @@ val text : t list -> Uid.t -> t given. *) val t_adt : - ?kind:[`Adt | `Enum] -> - ?body: ((Uid.t * (Uid.t * t) list) list) option -> Uid.t -> t list -> t + ?body:((Uid.t * (Uid.t * t) list) list) option -> Uid.t -> t list -> t (** Create an algebraic datatype. The body is a list of constructors, where each constructor is associated with the list of its destructors with their respective types. If [body] is none, then no definition will be registered for this type. The second argument is the name of the type. The third one provides its list - of arguments. - - The flag [kind] is used to annotate ADT that are enum types. [`Adt] - kind is the default. *) + of arguments. *) val trecord : ?sort_fields:bool ->