Skip to content

Commit

Permalink
Merge Records into ADT
Browse files Browse the repository at this point in the history
  • Loading branch information
Halbaroth committed Apr 19, 2024
1 parent f16ac0e commit bbb4feb
Show file tree
Hide file tree
Showing 19 changed files with 188 additions and 638 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals Ite_rel Matching Matching_types
Polynome Records Records_rel Satml_frontend_hybrid Satml_frontend Satml
Polynome Satml_frontend_hybrid Satml_frontend Satml
Sat_solver Sat_solver_sig Sig Sig_rel Theory Uf Use Rel_utils Bitlist
; structures
Commands Errors Explanation Fpa_rounding
Expand Down
18 changes: 13 additions & 5 deletions src/lib/frontend/cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,20 @@ let rec make_term quant_basename t =
E.mk_term (Sy.Op Sy.Concat)
[mk_term t1; mk_term t2] ty

| TTdot (t, s) ->
E.mk_term (Sy.Op (Sy.Access s)) [mk_term t] ty

| TTrecord lbs ->
| TTrecord (ty, lbs) ->
let lbs = List.map (fun (_, t) -> mk_term t) lbs in
E.mk_term (Sy.Op Sy.Record) lbs ty
let cstr =
match ty with
| Tadt (name, params, true) ->
begin
let Adt cases = Ty.type_body name params in
match cases with
| [{ constr; _ }] -> Hstring.view constr
| _ -> assert false
end
| _ -> assert false
in
E.mk_constr cstr lbs ty

| TTlet (binders, t2) ->
let binders =
Expand Down
145 changes: 10 additions & 135 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ let ty name ty =
Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty

let builtin_enum = function
| Ty.Tadt (name, params) as ty_ ->
| Ty.Tadt (name, params, _) as ty_ ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
let ty_cst =
Expand Down Expand Up @@ -505,86 +505,18 @@ let rec dty_to_ty ?(update = false) ?(is_var = false) dty =
| _ -> unsupported "Type %a" DE.Ty.print dty

and handle_ty_app ?(update = false) ty_c l =
(* Applies the substitutions in [tysubsts] to each encountered type
variable. *)
let rec apply_ty_substs tysubsts ty =
match ty with
| Ty.Tvar { v; _ } ->
Ty.M.find v tysubsts

| Text (tyl, hs) ->
Ty.Text (List.map (apply_ty_substs tysubsts) tyl, hs)

| Tfarray (ti, tv) ->
Tfarray (
apply_ty_substs tysubsts ti,
apply_ty_substs tysubsts tv
)

| Tadt (hs, tyl) ->
Tadt (hs, List.map (apply_ty_substs tysubsts) tyl)

| Trecord ({ args; lbs; _ } as rcrd) ->
Trecord {
rcrd with
args = List.map (apply_ty_substs tysubsts) args;
lbs = List.map (
fun (hs, t) ->
hs, apply_ty_substs tysubsts t
) lbs;
}

| _ -> ty
in
let tyl = List.map (dty_to_ty ~update) l in
(* Recover the initial versions of the types and apply them on the provided
type arguments stored in [tyl]. *)
match Cache.find_ty (DE.Ty.Const.hash ty_c) with
| Tadt (hs, _) -> Tadt (hs, tyl)

| Trecord { args; _ } as ty ->
let tysubsts =
List.fold_left2 (
fun acc tv ty ->
match tv with
| Ty.Tvar { v; _ } -> Ty.M.add v ty acc
| _ -> assert false
) Ty.M.empty args tyl
in
apply_ty_substs tysubsts ty
| Tadt (hs, _, record) -> Tadt (hs, tyl, record)

| Text (_, s) -> Text (tyl, s)
| _ -> assert false

(** Handles a simple type declaration. *)
let mk_ty_decl (ty_c: DE.ty_cst) =
match DT.definition ty_c with
| Some (
Adt { cases = [| { cstr = { id_ty; path; _ }; dstrs; _ } |]; _ }
) ->
(* Records and adts that only have one case are treated in the same way,
and considered as records. *)
let tyvl = Cache.store_ty_vars_ret id_ty in
let rev_lbs =
Array.fold_left (
fun acc c ->
match c with
| Some DE.{ path; id_ty; _ } ->
let pn = get_basename path in
let pty = dty_to_ty id_ty in
(pn, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c

) [] dstrs
in
let lbs = List.rev rev_lbs in
let record_constr = Format.asprintf "%a" DStd.Path.print path in
let ty = Ty.trecord ~record_constr tyvl (get_basename ty_c.path) lbs in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty

| Some (
(Adt { cases; _ } as _adt)
) ->
Expand Down Expand Up @@ -662,35 +594,9 @@ let mk_term_decl ({ id_ty; path; tags; _ } as tcst: DE.term_cst) =
let mk_mr_ty_decls (tdl: DE.ty_cst list) =
let handle_ty_decl (ty: Ty.t) (tdef: DE.Ty.def option) =
match ty, tdef with
| Trecord { args; name; record_constr; _ },
| Tadt (hs, tyl, _),
Some (
Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ }
) ->
let rev_lbs =
Array.fold_left (
fun acc c ->
match c with
| Some DE.{ path; id_ty; _ } ->
let pn = get_basename path in
let pty = dty_to_ty id_ty in
(pn, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c
) [] dstrs
in
let lbs = List.rev rev_lbs in
let name = Hstring.view name in
let record_constr = Hstring.view record_constr in
let ty =
Ty.trecord ~record_constr args name lbs
in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty

| Tadt (hs, tyl),
Some (
Adt { cases; ty = ty_c; _ }
Adt { cases; ty = ty_c; record }
) ->
let rev_cs =
Array.fold_left (
Expand All @@ -710,30 +616,18 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
in
let body = Some (List.rev rev_cs) in
let args = tyl in
let ty = Ty.t_adt ~body (Hstring.view hs) args in
let ty = Ty.t_adt ~record ~body (Hstring.view hs) args in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty

| _ -> assert false
in
(* If there are adts in the list of type declarations then records are
converted to adts, because that's how it's done in the legacy typechecker.
But it might be more efficient not to do that. *)
let rev_tdefs, contains_adts =
List.fold_left (
fun (acc, ca) ty_c ->
match DT.definition ty_c with
| Some (Adt { record; cases; _ }) as df
when not record && Array.length cases > 1 ->
df :: acc, true
| df -> df :: acc, ca
) ([], false) tdl
in
let rev_tdefs = List.rev_map DT.definition tdl in
let rev_l =
List.fold_left (
fun acc tdef ->
match tdef with
| Some (
(DE.Adt { cases; record; ty = ty_c; }) as adt
(DE.Adt { cases; ty = ty_c; _ }) as adt
) ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let name = get_basename ty_c.path in
Expand All @@ -754,15 +648,7 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
acc
)
else (
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
let record_constr =
Format.asprintf "%a" DStd.Path.print ty_c.path
in
Ty.trecord ~record_constr tyvl name []
else Ty.t_adt name tyvl
in
let ty = Ty.t_adt name tyvl in
Cache.store_ty (DE.Ty.Const.hash ty_c) ty;
(ty, Some adt) :: acc
)
Expand Down Expand Up @@ -1042,10 +928,7 @@ let rec mk_expr
let e = aux_mk_expr x in
let sy =
match Cache.find_ty (DE.Ty.Const.hash adt) with
| Trecord _ ->
Sy.Op (Sy.Access (Hstring.make name))
| Tadt _ ->
Sy.destruct name
| Tadt _ -> Sy.destruct name
| _ -> assert false
in
E.mk_term sy [e] ty
Expand Down Expand Up @@ -1077,11 +960,6 @@ let rec mk_expr
match Cache.find_ty (DE.Ty.Const.hash ty_c) with
| Ty.Tadt _ ->
E.mk_builtin ~is_pos:true builtin [aux_mk_expr x]
| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
replace the tester of a record by the true literal. *)
E.vrai
| _ -> assert false
end

Expand Down Expand Up @@ -1368,13 +1246,10 @@ let rec mk_expr
let name = get_basename path in
let ty = dty_to_ty term_ty in
begin match ty with
| Ty.Tadt (_, _) ->
| Ty.Tadt _ ->
let sy = Sy.Op (Sy.Constr (Hstring.make name)) in
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_term sy l ty
| Ty.Trecord _ ->
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_term (Sy.Op Sy.Record) l ty
| _ ->
Fmt.failwith
"Constructor error: %a does not belong to a record nor an\
Expand Down
22 changes: 1 addition & 21 deletions src/lib/frontend/models.ml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ module Pp_smtlib_term = struct
asprintf "%a" Ty.pp_smtlib t

let rec print fmt t =
let {Expr.f;xs;ty; _} = Expr.term_view t in
let {Expr.f;xs; _} = Expr.term_view t in
match f, xs with

| Sy.Lit lit, xs ->
Expand Down Expand Up @@ -149,26 +149,6 @@ module Pp_smtlib_term = struct
| Sy.Op Sy.Extract (i, j), [e] ->
fprintf fmt "%a^{%d,%d}" print e i j

| Sy.Op (Sy.Access field), [e] ->
if Options.get_output_smtlib () then
fprintf fmt "(%s %a)" (Hstring.view field) print e
else
fprintf fmt "%a.%s" print e (Hstring.view field)

| Sy.Op (Sy.Record), _ ->
begin match ty with
| Ty.Trecord { Ty.lbs = lbs; _ } ->
assert (List.length xs = List.length lbs);
fprintf fmt "{";
ignore (List.fold_left2 (fun first (field,_) e ->
fprintf fmt "%s%s = %a" (if first then "" else "; ")
(Hstring.view field) print e;
false
) true lbs xs);
fprintf fmt "}";
| _ -> assert false
end

(* TODO: introduce PrefixOp in the future to simplify this ? *)
| Sy.Op op, [e1; e2] when op == Sy.Pow || op == Sy.Integer_round ||
op == Sy.Max_real || op == Sy.Max_int ||
Expand Down
Loading

0 comments on commit bbb4feb

Please sign in to comment.