Skip to content

Commit

Permalink
Merge Enum theory into ADT theory (#1094)
Browse files Browse the repository at this point in the history
* Merge `Enum` Theory into `ADT` Theory

After refactoring both `Enum` and `ADT` theories, they shared most of
their implementation.
  • Loading branch information
Halbaroth authored Jun 12, 2024
1 parent 2069070 commit 84e29b9
Show file tree
Hide file tree
Showing 16 changed files with 243 additions and 590 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Cnf D_cnf D_loop D_state_option Input Frontend Parsed_interface Typechecker
Models
; reasoners
Ac Arith Arrays_rel Bitv Ccx Shostak Relation Enum Enum_rel
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals Ite_rel Matching Matching_types
Polynome Records Records_rel Satml_frontend_hybrid Satml_frontend Satml
Expand Down
86 changes: 24 additions & 62 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,6 @@ and handle_ty_app ?(update = false) ty_c l =
in
apply_ty_substs tysubsts ty

| Tsum _ as ty -> ty
| Text (_, s) -> Text (tyl, s)
| _ -> assert false

Expand Down Expand Up @@ -611,48 +610,29 @@ let mk_ty_decl (ty_c: DE.ty_cst) =
let ty = Ty.trecord ~record_constr tyvl (Uid.of_dolmen ty_c) lbs in
Cache.store_ty ty_c ty

| Some ((Adt { cases; _ } as adt)) ->
| Some (Adt { cases; _ } as adt) ->
Nest.add_nest [adt];
let uid = Uid.of_dolmen ty_c in
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let rev_cs, is_enum =
Cache.store_ty ty_c (Ty.t_adt uid tyvl);
let rev_cs =
Array.fold_left (
fun (accl, is_enum) DE.{ cstr; dstrs; _ } ->
let is_enum =
if is_enum
then
if Array.length dstrs = 0
then true
else (
let ty = Ty.t_adt uid tyvl in
Cache.store_ty ty_c ty;
false
)
else false
in
fun accl DE.{ cstr; dstrs; _ } ->
let rev_fields =
Array.fold_left (
fun acc tc_o ->
match tc_o with
| Some (DE.{ id_ty; _ } as id) ->
(Uid.of_dolmen id, dty_to_ty id_ty) :: acc
| Some (DE.{ id_ty; _ } as field) ->
(Uid.of_dolmen field, dty_to_ty id_ty) :: acc
| None -> assert false
) [] dstrs
in
(Uid.of_dolmen cstr, List.rev rev_fields) :: accl, is_enum
) ([], true) cases
(Uid.of_dolmen cstr, List.rev rev_fields) :: accl
) [] cases
in
if is_enum
then
let cstrs =
List.map (fun s -> fst s) (List.rev rev_cs)
in
let ty = Ty.tsum uid cstrs in
Cache.store_ty ty_c ty
else
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body uid tyvl in
Cache.store_ty ty_c ty
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body uid tyvl in
Cache.store_ty ty_c ty

| None | Some Abstract ->
let ty_params = []
Expand Down Expand Up @@ -728,8 +708,7 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
) [] cases
in
let body = Some (List.rev rev_cs) in
let args = tyl in
let ty = Ty.t_adt ~body hs args in
let ty = Ty.t_adt ~body hs tyl in
Cache.store_ty ty_c ty

| _ -> assert false
Expand Down Expand Up @@ -757,32 +736,17 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
match tdef with
| DE.Adt { cases; record; ty = ty_c; } as adt ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in

let cns, is_enum =
Array.fold_right (
fun DE.{ dstrs; cstr; _ } (nacc, is_enum) ->
Uid.of_dolmen cstr :: nacc,
Array.length dstrs = 0 && is_enum
) cases ([], true)
in
let uid = Uid.of_dolmen ty_c in
if is_enum
then (
let ty = Ty.tsum uid cns in
Cache.store_ty ty_c ty;
(* If it's an enum we don't need the second iteration. *)
acc
)
else (
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr:uid tyvl uid []
else Ty.t_adt uid tyvl
in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc
)
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr:uid tyvl uid []
else
Ty.t_adt uid tyvl
in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc

| Abstract ->
assert false (* unreachable in the second iteration *)
) [] (List.rev rev_tdefs)
Expand Down Expand Up @@ -1082,9 +1046,7 @@ let rec mk_expr
match Cache.find_ty ty_c with
| Ty.Tadt _ ->
E.mk_builtin ~is_pos:true builtin [aux_mk_expr x]
| Ty.Tsum _ as ty ->
let cstr = E.mk_constr (Uid.of_dolmen cstr) [] ty in
E.mk_eq ~iff:false (aux_mk_expr x) cstr

| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
Expand Down Expand Up @@ -1375,7 +1337,7 @@ let rec mk_expr
| B.Constructor _, _ ->
let ty = dty_to_ty term_ty in
begin match ty with
| Ty.Tadt (_, _) ->
| Ty.Tadt _ ->
let sy = Sy.constr @@ Uid.of_dolmen tcst in
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_term sy l ty
Expand Down
38 changes: 17 additions & 21 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,6 @@ module Types = struct
if List.length lty <> List.length lty' then
Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc;
lty'
| Ty.Tsum (s, _) ->
if List.length lty <> 0 then
Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc;
[]
| _ -> assert false

let equal_pp_vars lpp lvars =
Expand Down Expand Up @@ -145,12 +141,11 @@ module Types = struct
| Abstract ->
let ty = Ty.text ty_vars (Uid.of_string id) in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Enum lc ->
| Enum l ->
if not (Lists.is_empty ty_vars) then
Errors.typing_error (PolymorphicEnum id) loc;
let ty =
Ty.tsum (Uid.of_string id) (List.map Uid.of_string lc)
in
let body = List.map (fun constr -> Uid.of_string constr, []) l in
let ty = Ty.t_adt ~body:(Some body) (Uid.of_string id) [] in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Record (record_constr, lbs) ->
let lbs =
Expand Down Expand Up @@ -276,7 +271,10 @@ module Env = struct
let add_fpa_enum map =
let ty = Fpa_rounding.fpa_rounding_mode in
match ty with
| Ty.Tsum (_, cstrs) ->
| Ty.Tadt (name, []) ->
let Ty.{ cases; kind } = Ty.type_body name [] in
assert (Stdlib.(kind = Ty.Enum));
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.fold_left
(fun m c ->
match Fpa_rounding.translate_smt_rounding_mode
Expand All @@ -300,9 +298,13 @@ module Env = struct

let find_builtin_cstr ty n =
match ty with
| Ty.Tsum (_, cstrs) ->
| Ty.Tadt (name, []) ->
let Ty.{ cases; kind } = Ty.type_body name [] in
assert (Stdlib.(kind = Ty.Enum));
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.find (Uid.equal n) cstrs
| _ -> assert false
| _ ->
assert false

let add_fpa_builtins env =
let (->.) args result = { args; result } in
Expand Down Expand Up @@ -1002,13 +1004,10 @@ let rec type_term ?(call_from_type_form=false) env f =
let ty = Ty.shorten e.c.tt_ty in
let ty_body = match ty with
| Ty.Tadt (name, params) ->
begin match Ty.type_body name params with
| Ty.Adt cases -> cases
end
let Ty.{ cases; _ } = Ty.type_body name params in
cases
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr; destrs = lbs}]
| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e; destrs = []}) l
| _ -> Errors.typing_error (ShouldBeADT ty) loc
in
let pats =
Expand Down Expand Up @@ -1413,14 +1412,11 @@ and type_form ?(in_theory=false) env f =
let ty = e.c.tt_ty in
let ty_body = match ty with
| Ty.Tadt (name, params) ->
begin match Ty.type_body name params with
| Ty.Adt cases -> cases
end
let Ty.{ cases; _ } = Ty.type_body name params in
cases
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr ; destrs = lbs}]

| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e ; destrs = []}) l
| _ ->
Errors.typing_error (ShouldBeADT ty) f.pp_loc
in
Expand Down
4 changes: 2 additions & 2 deletions src/lib/reasoners/adt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ let constr_of_destr ty dest =
match ty with
| Ty.Tadt (s, params) ->
begin
let Ty.Adt cases = Ty.type_body s params in
let Ty.{ cases; _ } = Ty.type_body s params in
try
List.find
(fun { Ty.destrs; _ } ->
Expand Down Expand Up @@ -174,7 +174,7 @@ module Shostak (X : ALIEN) = struct
let xs = List.rev sx in
match f, xs, ty with
| Sy.Op Sy.Constr hs, _, Ty.Tadt (name, params) ->
let Ty.Adt cases = Ty.type_body name params in
let Ty.{ cases; _ } = Ty.type_body name params in
let case_hs =
try Ty.assoc_destrs hs cases with Not_found -> assert false
in
Expand Down
Loading

0 comments on commit 84e29b9

Please sign in to comment.