Skip to content

Commit

Permalink
Merge Enum Theory into ADT Theory
Browse files Browse the repository at this point in the history
After refactoring both `Enum` and `ADT` theories, they shared most of
their implementation.

This PR merges `Enum` theory into `ADT` ones.
  • Loading branch information
Halbaroth committed May 18, 2024
1 parent c49d903 commit 799ec51
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 275 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
Cnf D_cnf D_loop D_state_option Input Frontend Parsed_interface Typechecker
Models
; reasoners
Ac Arith Arrays_rel Bitv Ccx Shostak Relation Enum Enum_rel
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals Ite_rel Matching Matching_types
Polynome Records Records_rel Satml_frontend_hybrid Satml_frontend Satml
Expand Down
114 changes: 35 additions & 79 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,6 @@ and handle_ty_app ?(update = false) ty_c l =
in
apply_ty_substs tysubsts ty

| Tsum _ as ty -> ty
| Text (_, s) -> Text (tyl, s)
| _ -> assert false

Expand Down Expand Up @@ -573,50 +572,29 @@ let mk_ty_decl (ty_c: DE.ty_cst) =
let ty = Ty.trecord ~record_constr tyvl (Uid.of_dolmen ty_c) lbs in
Cache.store_ty ty_c ty

| Some (
(Adt { cases; _ } as adt)
) ->
(Topological_order.sort [adt];
let uid = Uid.of_dolmen ty_c in
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let rev_cs, is_enum =
Array.fold_left (
fun (accl, is_enum) DE.{ cstr; dstrs; _ } ->
let is_enum =
if is_enum
then
if Array.length dstrs = 0
then true
else (
let ty = Ty.t_adt uid tyvl in
Cache.store_ty ty_c ty;
false
)
else false
in
let rev_fields =
Array.fold_left (
fun acc tc_o ->
match tc_o with
| Some (DE.{ id_ty; _ } as field) ->
(Uid.of_dolmen field, dty_to_ty id_ty) :: acc
| None -> assert false
) [] dstrs
in
(Uid.of_dolmen cstr, List.rev rev_fields) :: accl, is_enum
) ([], true) cases
in
if is_enum
then
let cstrs =
List.map (fun s -> fst s) (List.rev rev_cs)
in
let ty = Ty.tsum uid cstrs in
Cache.store_ty ty_c ty
else
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body uid tyvl in
Cache.store_ty ty_c ty)
| Some (Adt { cases; _ } as adt) ->
Topological_order.sort [adt];
let uid = Uid.of_dolmen ty_c in
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
Cache.store_ty ty_c (Ty.t_adt uid tyvl);
let rev_cs =
Array.fold_left (
fun accl DE.{ cstr; dstrs; _ } ->
let rev_fields =
Array.fold_left (
fun acc tc_o ->
match tc_o with
| Some (DE.{ id_ty; _ } as field) ->
(Uid.of_dolmen field, dty_to_ty id_ty) :: acc
| None -> assert false
) [] dstrs
in
(Uid.of_dolmen cstr, List.rev rev_fields) :: accl
) [] cases
in
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body uid tyvl in
Cache.store_ty ty_c ty

| None | Some Abstract ->
let ty_params = []
Expand Down Expand Up @@ -675,10 +653,7 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
in
Cache.store_ty ty_c ty

| Tadt (hs, tyl),
Some (
Adt { cases; ty = ty_c; _ }
) ->
| Tadt (hs, tyl), Some (Adt { cases; ty = ty_c; _ }) ->
let rev_cs =
Array.fold_left (
fun accl DE.{ cstr; dstrs; _ } ->
Expand All @@ -695,8 +670,7 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
) [] cases
in
let body = Some (List.rev rev_cs) in
let args = tyl in
let ty = Ty.t_adt ~body hs args in
let ty = Ty.t_adt ~body hs tyl in
Cache.store_ty ty_c ty

| _ -> assert false
Expand Down Expand Up @@ -724,32 +698,16 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
match tdef with
| DE.Adt { cases; record; ty = ty_c; } as adt ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in

let cns, is_enum =
Array.fold_right (
fun DE.{ dstrs; cstr; _ } (nacc, is_enum) ->
Uid.of_dolmen cstr :: nacc,
Array.length dstrs = 0 && is_enum
) cases ([], true)
in
let uid = Uid.of_dolmen ty_c in
if is_enum && not contains_adts
then (
let ty = Ty.tsum uid cns in
Cache.store_ty ty_c ty;
(* If it's an enum we don't need the second iteration. *)
acc
)
else (
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr:uid tyvl uid []
else Ty.t_adt uid tyvl
in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc
)
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr:uid tyvl uid []
else Ty.t_adt uid tyvl
in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc

| Abstract ->
assert false (* unreachable in the second iteration *)
) [] (List.rev rev_tdefs)
Expand Down Expand Up @@ -1050,9 +1008,7 @@ let rec mk_expr
match Cache.find_ty ty_c with
| Ty.Tadt _ ->
E.mk_builtin ~is_pos:true builtin [aux_mk_expr x]
| Ty.Tsum _ as ty ->
let cstr = E.mk_constr (Uid.of_dolmen cstr) [] ty in
E.mk_eq ~iff:false (aux_mk_expr x) cstr

| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
Expand Down
33 changes: 13 additions & 20 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,6 @@ module Types = struct
if List.length lty <> List.length lty' then
Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc;
lty'
| Ty.Tsum (s, _) ->
if List.length lty <> 0 then
Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc;
[]
| _ -> assert false

let equal_pp_vars lpp lvars =
Expand Down Expand Up @@ -145,13 +141,6 @@ module Types = struct
| Abstract ->
let ty = Ty.text ty_vars (Uid.of_string id) in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Enum lc ->
if not (Lists.is_empty ty_vars) then
Errors.typing_error (PolymorphicEnum id) loc;
let ty =
Ty.tsum (Uid.of_string id) (List.map Uid.of_string lc)
in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Record (record_constr, lbs) ->
let lbs =
List.map (fun (x, pp) -> x, ty_of_pp loc env None pp) lbs in
Expand All @@ -171,6 +160,10 @@ module Types = struct
from_labels =
List.fold_left
(fun fl (l,_) -> MString.add l id fl) env.from_labels lbs }
| Enum l ->
let body = List.map (fun constr -> Uid.of_string constr, []) l in
let ty = Ty.t_adt ~body:(Some body) (Uid.of_string id) [] in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Algebraic l ->
let l = (* convert ppure_type to Ty.t in l *)
List.map (fun (constr, l) ->
Expand Down Expand Up @@ -276,7 +269,9 @@ module Env = struct
let add_fpa_enum map =
let ty = Fpa_rounding.fpa_rounding_mode in
match ty with
| Ty.Tsum (_, cstrs) ->
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.fold_left
(fun m c ->
match Fpa_rounding.translate_smt_rounding_mode
Expand All @@ -300,8 +295,10 @@ module Env = struct

let find_builtin_cstr ty n =
match ty with
| Ty.Tsum (_, cstrs) ->
List.find (Uid.equal n) cstrs
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.find (fun c -> String.equal n @@ Uid.show c) cstrs
| _ -> assert false

let add_fpa_builtins env =
Expand All @@ -327,9 +324,9 @@ module Env = struct
let nte = Fpa_rounding.string_of_rounding_mode NearestTiesToEven in
let tname = Fpa_rounding.fpa_rounding_mode_ae_type_name in
let float32 = float (int "24") (int "149") in
let float32d = float32 (mode (Uid.of_string nte)) in
let float32d = float32 (mode nte) in
let float64 = float (int "53") (int "1074") in
let float64d = float64 (mode (Uid.of_string nte)) in
let float64d = float64 (mode nte) in
let op n op profile =
MString.add n @@ `Term (Symbols.Op op, profile, Other)
in
Expand Down Expand Up @@ -1007,8 +1004,6 @@ let rec type_term ?(call_from_type_form=false) env f =
end
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr; destrs = lbs}]
| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e; destrs = []}) l
| _ -> Errors.typing_error (ShouldBeADT ty) loc
in
let pats =
Expand Down Expand Up @@ -1419,8 +1414,6 @@ and type_form ?(in_theory=false) env f =
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr ; destrs = lbs}]

| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e ; destrs = []}) l
| _ ->
Errors.typing_error (ShouldBeADT ty) f.pp_loc
in
Expand Down
10 changes: 5 additions & 5 deletions src/lib/reasoners/adt_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,12 @@ let pick_delayed_destructor env uf =
if Domain.cardinal d > 1 then
raise_notrace @@ Found (r, destr)
else
()
| _ ->
()
| _ ->
()
) env.delayed;
None
with Found (r, d) -> Some (r, d)
) env.delayed;
None
with Found (r, d) -> Some (r, d)

let is_enum ty c =
match ty with
Expand Down
59 changes: 19 additions & 40 deletions src/lib/reasoners/relation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ module Rel3 : Sig_rel.RELATION = Bitv_rel

module Rel4 : Sig_rel.RELATION = Arrays_rel

module Rel5 : Sig_rel.RELATION = Enum_rel
module Rel5 : Sig_rel.RELATION = Adt_rel

module Rel6 : Sig_rel.RELATION = Adt_rel

module Rel7 : Sig_rel.RELATION = Ite_rel
module Rel6 : Sig_rel.RELATION = Ite_rel

(* This value is unused. *)
let timer = Timers.M_None
Expand All @@ -53,7 +51,6 @@ type t = {
r4: Rel4.t;
r5: Rel5.t;
r6: Rel6.t;
r7: Rel7.t;
}

let empty uf =
Expand All @@ -63,8 +60,7 @@ let empty uf =
let r4, doms4 = Rel4.empty (Uf.set_domains uf doms3) in
let r5, doms5 = Rel5.empty (Uf.set_domains uf doms4) in
let r6, doms6 = Rel6.empty (Uf.set_domains uf doms5) in
let r7, doms7 = Rel7.empty (Uf.set_domains uf doms6) in
{r1; r2; r3; r4; r5; r6; r7}, doms7
{r1; r2; r3; r4; r5; r6}, doms6

let (|@|) l1 l2 =
if l1 == [] then l2
Expand Down Expand Up @@ -97,14 +93,10 @@ let assume env uf sa =
Timers.with_timer Rel6.timer Timers.F_assume @@ fun () ->
Rel6.assume env.r6 (Uf.set_domains uf doms5) sa
in
let env7, doms7, ({ assume = a7; remove = rm7}:_ Sig_rel.result) =
Timers.with_timer Rel7.timer Timers.F_assume @@ fun () ->
Rel7.assume env.r7 (Uf.set_domains uf doms6) sa
in
{r1=env1; r2=env2; r3=env3; r4=env4; r5=env5; r6=env6; r7=env7},
doms7,
({ assume = a1 |@| a2 |@| a3 |@| a4 |@| a5 |@| a6 |@| a7;
remove = rm1 |@| rm2 |@| rm3 |@| rm4 |@| rm5 |@| rm6 |@| rm7}
{r1=env1; r2=env2; r3=env3; r4=env4; r5=env5; r6=env6},
doms6,
({ assume = a1 |@| a2 |@| a3 |@| a4 |@| a5 |@| a6;
remove = rm1 |@| rm2 |@| rm3 |@| rm4 |@| rm5 |@| rm6 }
: _ Sig_rel.result)

let assume_th_elt env th_elt dep =
Expand All @@ -115,8 +107,7 @@ let assume_th_elt env th_elt dep =
let env4 = Rel4.assume_th_elt env.r4 th_elt dep in
let env5 = Rel5.assume_th_elt env.r5 th_elt dep in
let env6 = Rel6.assume_th_elt env.r6 th_elt dep in
let env7 = Rel7.assume_th_elt env.r7 th_elt dep in
{r1=env1; r2=env2; r3=env3; r4=env4; r5=env5; r6=env6; r7=env7}
{r1=env1; r2=env2; r3=env3; r4=env4; r5=env5; r6=env6}

let try_query (type a) (module R : Sig_rel.RELATION with type t = a) env uf a
k =
Expand All @@ -132,8 +123,7 @@ let query env uf a =
try_query (module Rel3) env.r3 uf a @@ fun () ->
try_query (module Rel4) env.r4 uf a @@ fun () ->
try_query (module Rel5) env.r5 uf a @@ fun () ->
try_query (module Rel6) env.r6 uf a @@ fun () ->
try_query (module Rel7) env.r7 uf a @@ fun () -> None
try_query (module Rel6) env.r6 uf a @@ fun () -> None

let case_split env uf ~for_model =
Options.exec_thread_yield ();
Expand All @@ -143,8 +133,7 @@ let case_split env uf ~for_model =
let seq4 = Rel4.case_split env.r4 uf ~for_model in
let seq5 = Rel5.case_split env.r5 uf ~for_model in
let seq6 = Rel6.case_split env.r6 uf ~for_model in
let seq7 = Rel7.case_split env.r7 uf ~for_model in
let splits = [seq1; seq2; seq3; seq4; seq5; seq6; seq7] in
let splits = [seq1; seq2; seq3; seq4; seq5; seq6] in
let splits = List.fold_left (|@|) [] splits in
List.fast_sort
(fun (_ ,_ , sz1) (_ ,_ , sz2) ->
Expand Down Expand Up @@ -181,8 +170,7 @@ let add env uf r t =
let r4, doms4, eqs4 =Rel4.add env.r4 (Uf.set_domains uf doms3) r t in
let r5, doms5, eqs5 =Rel5.add env.r5 (Uf.set_domains uf doms4) r t in
let r6, doms6, eqs6 =Rel6.add env.r6 (Uf.set_domains uf doms5) r t in
let r7, doms7, eqs7 =Rel7.add env.r7 (Uf.set_domains uf doms6) r t in
{r1;r2;r3;r4;r5;r6;r7;},doms7,eqs1|@|eqs2|@|eqs3|@|eqs4|@|eqs5|@|eqs6|@|eqs7
{r1;r2;r3;r4;r5;r6}, doms6, eqs1|@|eqs2|@|eqs3|@|eqs4|@|eqs5|@|eqs6


let instantiate ~do_syntactic_matching t_match env uf selector =
Expand All @@ -199,22 +187,13 @@ let instantiate ~do_syntactic_matching t_match env uf selector =
Rel5.instantiate ~do_syntactic_matching t_match env.r5 uf selector in
let r6, l6 =
Rel6.instantiate ~do_syntactic_matching t_match env.r6 uf selector in
let r7, l7 =
Rel7.instantiate ~do_syntactic_matching t_match env.r7 uf selector in
{r1=r1; r2=r2; r3=r3; r4=r4; r5=r5; r6=r6; r7=r7},
l7 |@| l6 |@| l5 |@| l4 |@| l3 |@| l2 |@| l1
{r1=r1; r2=r2; r3=r3; r4=r4; r5=r5; r6=r6},
l6 |@| l5 |@| l4 |@| l3 |@| l2 |@| l1

let new_terms env =
let t1 = Rel1.new_terms env.r1 in
let t2 = Rel2.new_terms env.r2 in
let t3 = Rel3.new_terms env.r3 in
let t4 = Rel4.new_terms env.r4 in
let t5 = Rel5.new_terms env.r5 in
let t6 = Rel6.new_terms env.r6 in
let t7 = Rel7.new_terms env.r7 in
Expr.Set.union t1
(Expr.Set.union t2
(Expr.Set.union t3
(Expr.Set.union t4
(Expr.Set.union t5
(Expr.Set.union t6 t7)) )))
Rel1.new_terms env.r1
|> Expr.Set.union @@ Rel2.new_terms env.r2
|> Expr.Set.union @@ Rel3.new_terms env.r3
|> Expr.Set.union @@ Rel4.new_terms env.r4
|> Expr.Set.union @@ Rel5.new_terms env.r5
|> Expr.Set.union @@ Rel6.new_terms env.r6
Loading

0 comments on commit 799ec51

Please sign in to comment.