Skip to content

Commit

Permalink
Merge Enum Theory into ADT Theory
Browse files Browse the repository at this point in the history
After refactoring both `Enum` and `ADT` theories, they shared most of
their implementation.

This PR merges `Enum` theory into `ADT` ones.

To prevent some regressions on the tests containing `enum` types, we
performs casesplits on it, even if the flag `--enable-adts-cs` is not
used. More precisely, the ADT casesplits works as follows:
- if we're not generating a model, we look for a tightenable constructor,
  that is a constructor without payload. We take a tightenable
  constructor with the largest domain.
- if there is no more such constructor and the flag `--enable-adts-cs`
  is turn on, we try to find a contradiction by propagating delayed
  destructors.
- if we're generating a model, we look for a constructor in a domain
  without restriction on it. The function `Ty.cons_weight` ensures the
  termination of this algorithm.
  • Loading branch information
Halbaroth committed Apr 13, 2024
1 parent feda342 commit 620dbbd
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 959 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
Cnf D_cnf D_loop D_state_option Input Frontend Parsed_interface Typechecker
Models
; reasoners
Ac Arith Arrays_rel Bitv Ccx Shostak Relation Enum Enum_rel
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals Ite_rel Matching Matching_types
Polynome Records Records_rel Satml_frontend_hybrid Satml_frontend Satml
Expand Down
33 changes: 13 additions & 20 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ let ty name ty =
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty

let builtin_enum = function
| Ty.Tsum (name, cstrs) as ty_ ->
| Ty.Tadt (name, params) as ty_ ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
let ty_cst =
DStd.Expr.Id.mk ~builtin:B.Base
(DStd.Path.global (Hstring.view name))
Expand Down Expand Up @@ -551,7 +553,6 @@ and handle_ty_app ?(update = false) ty_c l =
in
apply_ty_substs tysubsts ty

| Tsum _ as ty -> ty
| Text (_, s) -> Text (tyl, s)
| _ -> assert false

Expand Down Expand Up @@ -617,17 +618,14 @@ let mk_ty_decl (ty_c: DE.ty_cst) =
(name, List.rev rev_fields) :: accl, is_enum
) ([], true) cases
in
if is_enum
then
let cstrs =
List.map (fun s -> fst s) (List.rev rev_cs)
in
let ty = Ty.tsum name cstrs in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty
else
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body name tyvl in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty
let body =
if is_enum then
Some (List.map (fun s -> fst s, []) (List.rev rev_cs))
else
Some (List.rev rev_cs)
in
let ty = Ty.t_adt ~body name tyvl in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty

| None | Some Abstract ->
let name = get_basename ty_c.path in
Expand Down Expand Up @@ -749,7 +747,8 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
in
if is_enum
then (
let ty = Ty.tsum name cns in
let body = Some (List.map (fun c -> c, []) cns) in
let ty = Ty.t_adt ~body name [] in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty;
(* If it's an enum we don't need the second iteration. *)
acc
Expand Down Expand Up @@ -1078,12 +1077,6 @@ let rec mk_expr
match Cache.find_ty (DE.Ty.Const.hash ty_c) with
| Ty.Tadt _ ->
E.mk_builtin ~is_pos:true builtin [aux_mk_expr x]
| Ty.Tsum _ as ty ->
let cstr =
let sy = Sy.Op (Sy.Constr (Hstring.make name)) in
E.mk_term sy [] ty
in
E.mk_eq ~iff:false (aux_mk_expr x) cstr
| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
Expand Down
19 changes: 8 additions & 11 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ module Types = struct
if List.length lty <> List.length lty' then
Errors.typing_error (WrongNumberofArgs (Hstring.view s)) loc;
lty'
| Ty.Tsum (s, _) ->
if List.length lty <> 0 then
Errors.typing_error (WrongNumberofArgs (Hstring.view s)) loc;
[]
| _ -> assert false

let equal_pp_vars lpp lvars =
Expand Down Expand Up @@ -151,7 +147,8 @@ module Types = struct
| Enum lc ->
if not (Lists.is_empty ty_vars) then
Errors.typing_error (PolymorphicEnum id) loc;
let ty = Ty.tsum id lc in
let body = List.map (fun c -> c, []) lc in
let ty = Ty.t_adt ~body:(Some body) id [] in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Record (record_constr, lbs) ->
let lbs =
Expand Down Expand Up @@ -262,7 +259,9 @@ module Env = struct
let add_fpa_enum map =
let ty = Fpa_rounding.fpa_rounding_mode in
match ty with
| Ty.Tsum (_, cstrs) ->
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.fold_left
(fun m c ->
match Fpa_rounding.translate_smt_rounding_mode c with
Expand All @@ -285,7 +284,9 @@ module Env = struct

let find_builtin_cstr ty n =
match ty with
| Ty.Tsum (_, cstrs) ->
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.find (fun c -> String.equal n @@ Hstring.view c) cstrs
| _ -> assert false

Expand Down Expand Up @@ -979,8 +980,6 @@ let rec type_term ?(call_from_type_form=false) env f =
end
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr; destrs = lbs}]
| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e; destrs = []}) l
| _ -> Errors.typing_error (ShouldBeADT ty) loc
in
let pats =
Expand Down Expand Up @@ -1391,8 +1390,6 @@ and type_form ?(in_theory=false) env f =
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr ; destrs = lbs}]

| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e ; destrs = []}) l
| _ ->
Errors.typing_error (ShouldBeADT ty) f.pp_loc
in
Expand Down
141 changes: 84 additions & 57 deletions src/lib/reasoners/adt_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ module Domains = struct
in
acc, { t with changed = SX.empty }

let iter f t = MX.iter f t.domains
let fold f t = MX.fold f t.domains
end

let calc_destructor d e uf =
Expand Down Expand Up @@ -349,7 +349,8 @@ let is_tightenable ty c =
match ty with
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
Lists.is_empty @@ Ty.assoc_destrs c cases
let b = Lists.is_empty @@ Ty.assoc_destrs c cases in
b
| _ -> assert false

(* Update the domains of the semantic values [r1] and [r2] according to the
Expand Down Expand Up @@ -600,71 +601,97 @@ let constr_of_destr ty d =

| _ -> assert false

exception Found of X.r * Hstring.t

(* Pick a delayed destructor application in [env.delayed]. Returns [None]
if there is no delayed destructor. *)
let pick_delayed_destructor env =
try
Rel_utils.Delayed.iter_delayed
(fun r sy _e ->
match sy with
| Sy.Destruct destr ->
let d = Domains.get r env.domains in
if Domain.cardinal d > 1 then
raise_notrace @@ Found (r, destr)
else
let pick_delayed_destructor =
let exception Found of X.r * Hstring.t
in fun env ->
try
Rel_utils.Delayed.iter_delayed
(fun r sy _e ->
match sy with
| Sy.Destruct destr ->
let d = Domains.get r env.domains in
if Domain.cardinal d > 1 then
raise_notrace @@ Found (r, destr)
else
()
| _ ->
()
| _ ->
()
) env.delayed;
None
with Found (r, d) -> Some (r, d)
) env.delayed;
None
with Found (r, d) -> Some (r, d)

let can_split env n =
let m = Options.get_max_split () in
Numbers.Q.(compare (mult n env.size_splits) m) <= 0 || Numbers.Q.sign m < 0

(* Do a case-split by choosing a semantic value [r] and constructor [c]
for which there are delayed destructor applications and propagate the
literal [(not (_ is c) r)]. *)
let case_split env _uf ~for_model =
if Options.get_disable_adts ()
|| not (Options.get_enable_adts_cs () || for_model)
then
[]
else if for_model then
try
Domains.iter
(fun r d ->
if Domain.cardinal d > 1 then
let c = Domain.choose d in
raise_notrace @@ Found (r, c)
) env.domains;
[]
with Found (r, c) ->
match build_constr_eq r c with
| Some (_, cons) ->
let nr, _ = X.make cons in
let cs = LR.mkv_eq r nr in
if Options.get_debug_adt () then
Printer.print_dbg ~flushed:false
~module_name:"Adt_rel" ~function_name:"case_split"
"Assume %a = %a" X.print r Hstring.print c;
[ cs, true, Th_util.CS (Th_util.Th_adt, two) ]
| None -> assert false
else begin
let case_split env uf ~for_model =
if not @@ Options.get_disable_adts () then begin
if Options.get_debug_adt () then Debug.pp_env "before cs" env;
match pick_delayed_destructor env with
| Some (r, d) ->
if Options.get_debug_adt () then
Printer.print_dbg ~flushed:false
~module_name:"Adt_rel" ~function_name:"case_split"
"found r = %a and d = %a@ " X.print r Hstring.print d;
(* CS on negative version would be better in general. *)
let c = constr_of_destr (X.type_info r) d in
let cs = LR.mkv_builtin false (Sy.IsConstr c) [r] in
[ cs, true, Th_util.CS (Th_util.Th_adt, two) ]
(* Do a case-split by choosing a constructor for class representatives of
minimal size. *)
let best =
Domains.fold
(fun r d best ->
let rr, _ = Uf.find_r uf r in
match Th.embed rr with
| Constr _ ->
best
| _ ->
let cd = Domain.cardinal d in
let c = Domain.choose d in
if for_model || is_tightenable (X.type_info r) c then
match best with
| Some (n, _, _) when n <= cd -> best
| _ -> Some (cd, r, c)
else
best
) env.domains None
in
match best with
| Some (n, r, c) ->
let n = Numbers.Q.from_int n in
if for_model || can_split env n then
let _, cons = Option.get @@ build_constr_eq r c in
let nr, _ = X.make cons in
let cs = LR.mkv_eq r nr in
if Options.get_debug_adt () then
Printer.print_dbg ~flushed:false
~module_name:"Adt_rel" ~function_name:"case_split"
"Assume %a = %a" X.print r Hstring.print c;
[ cs, true, Th_util.CS (Th_util.Th_adt, two) ]
else begin
Debug.no_case_split ();
[]
end

| None ->
Debug.no_case_split ();
[]
if Options.get_enable_adts_cs () then begin
match pick_delayed_destructor env with
| Some (r, d) ->
if Options.get_debug_adt () then
Printer.print_dbg ~flushed:false
~module_name:"Adt_rel" ~function_name:"case_split"
"found r = %a and d = %a@ " X.print r Hstring.print d;
(* CS on negative version would be better in general. *)
let c = constr_of_destr (X.type_info r) d in
let cs = LR.mkv_builtin false (Sy.IsConstr c) [r] in
[ cs, true, Th_util.CS (Th_util.Th_adt, two) ]
| None ->
Debug.no_case_split ();
[]
end
else begin
Debug.no_case_split ();
[]
end
end
else
[]

let optimizing_objective _env _uf _o = None

Expand Down
Loading

0 comments on commit 620dbbd

Please sign in to comment.