Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coq: more efficient equality decision procedures for enums #495

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 93 additions & 7 deletions src/sail_coq_backend/pretty_print_coq.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2377,7 +2377,7 @@ let rec doc_range ctxt (BF_aux(r,_)) = match r with
*)

(* TODO: check use of empty_ctxt below doesn't cause problems due to missing info *)
let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, annot))) =
let doc_typdef types_mod avoid_target_names generic_eq_types enum_number_defs (TD_aux (td, (l, annot))) =
let bare_ctxt = { empty_ctxt with avoid_target_names } in
match td with
| TD_abbrev (id, typq, A_aux (A_typ typ, _)) ->
Expand Down Expand Up @@ -2632,7 +2632,58 @@ let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, an
let typ_pp =
(doc_op coloneq) (concat [string "Inductive"; space; id_pp]) (ifflat empty (pipe ^^ space) ^^ enums_doc)
in
let eq1_pp = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in
(* If we have conversion functions to Z, put them here and
derive a decision procedure that's efficient even for
large enums. *)
let eq1_pp =
let fallback = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in
match (Bindings.find_opt id (fst enum_number_defs), Bindings.find_opt id (snd enum_number_defs)) with
| Some (num_of_id, num_of_pp), Some (of_num_id, of_num_pp) ->
let num_of_id_pp = doc_id bare_ctxt num_of_id in
let of_num_id_pp = doc_id bare_ctxt of_num_id in
let lemma1 =
separate hardline
[
string "Lemma " ^^ id_pp ^^ string "_num_of_roundtrip "
^^ parens (string "x : " ^^ id_pp)
^^ string " : " ^^ of_num_id_pp ^^ space
^^ parens (num_of_id_pp ^^ string " x")
^^ string " = x.";
string "destruct x; reflexivity.";
string "Qed.";
]
in
let lemma2 =
separate hardline
[
string "Lemma " ^^ num_of_id_pp ^^ string "_injective "
^^ parens (string "x y : " ^^ id_pp)
^^ string " : " ^^ num_of_id_pp ^^ string " x = " ^^ num_of_id_pp ^^ string " y -> x = y.";
string "intro.";
string "rewrite <- (" ^^ id_pp ^^ string "_num_of_roundtrip x).";
string "rewrite <- (" ^^ id_pp ^^ string "_num_of_roundtrip y).";
string "congruence.";
string "Qed.";
]
in
let eq_pp =
separate hardline
[
string "Definition " ^^ id_pp ^^ string "_eq_dec (x y : " ^^ id_pp
^^ string ") : {x = y} + {x <> y}.";
string "refine (match Z.eq_dec (" ^^ num_of_id_pp ^^ string " x) (" ^^ num_of_id_pp
^^ string " y) with";
string "| left e => left (" ^^ num_of_id_pp ^^ string "_injective x y e)";
string "| right ne => right _";
string "end).";
string "congruence.";
string "Defined.";
]
in
num_of_pp ^^ of_num_pp ^^ separate hardline [lemma1; lemma2; eq_pp]
| Some (_, pp), None | None, Some (_, pp) -> pp ^^ fallback
| None, None -> fallback
in
let eq2_pp =
string "#[export]" ^^ hardline
^^ group
Expand Down Expand Up @@ -3364,14 +3415,15 @@ let doc_val avoid_target_names pat exp =
^^ group (separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."])
^^ hardline

let doc_def types_mod unimplemented avoid_target_names generic_eq_types effect_info (DEF_aux (aux, def_annot) as def) =
let doc_def types_mod unimplemented avoid_target_names generic_eq_types enum_number_defs effect_info
(DEF_aux (aux, def_annot) as def) =
match aux with
| DEF_val v_spec -> doc_val_spec def_annot unimplemented avoid_target_names effect_info v_spec
| DEF_fixity _ -> empty
| DEF_overload _ -> empty
| DEF_type t_def ->
if List.mem (string_of_id (id_of_type_def t_def)) !opt_extern_types <> !opt_generate_extern_types then empty
else doc_typdef types_mod avoid_target_names generic_eq_types t_def
else doc_typdef types_mod avoid_target_names generic_eq_types enum_number_defs t_def
| DEF_register dec -> group (doc_dec avoid_target_names dec)
| DEF_default df -> empty
| DEF_fundef fdef -> group (doc_fundef types_mod avoid_target_names effect_info fdef) ^/^ hardline
Expand Down Expand Up @@ -3540,8 +3592,6 @@ let pp_ast_coq (types_file, types_modules) (defs_file, defs_modules) type_defs_m
in
let is_typ_def = function DEF_aux (DEF_type _, _) -> true | _ -> false in
let exc_typ = find_exc_typ defs in
let typdefs, defs = List.partition is_typ_def defs in
let statedefs, defs = List.partition is_state_def defs in
let unimplemented = find_unimplemented defs in
let avoid_target_names = builtin_target_names defs in
let bare_doc_id = doc_id { empty_ctxt with avoid_target_names } in
Expand Down Expand Up @@ -3603,7 +3653,43 @@ let pp_ast_coq (types_file, types_modules) (defs_file, defs_modules) type_defs_m
@ mr_m
)
in
let doc_def = doc_def type_defs_module unimplemented avoid_target_names generic_eq_types effect_info in
let enums = Type_check.Env.get_enums type_env in
let defs, enum_number_defs =
let doc_def =
doc_def type_defs_module unimplemented avoid_target_names generic_eq_types (Bindings.empty, Bindings.empty)
effect_info
in
let num_of_map, of_num_map, rdefs =
List.fold_left
(fun (num_of_map, of_num_map, rdefs) def ->
match def with
| DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, _), _)]), _)), _) -> begin
match Type_check.Env.get_val_spec id type_env with
| _, Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> begin
match (arg_typ, ret_typ) with
| Typ_aux (Typ_id arg_id, _), _
when Bindings.mem arg_id enums && string_of_id id = "num_of_" ^ string_of_id arg_id ->
(Bindings.add arg_id (id, doc_def def) num_of_map, of_num_map, rdefs)
| _, Typ_aux (Typ_id ret_id, _)
when Bindings.mem ret_id enums && string_of_id id = string_of_id ret_id ^ "_of_num" ->
(num_of_map, Bindings.add ret_id (id, doc_def def) of_num_map, rdefs)
| _ -> (num_of_map, of_num_map, def :: rdefs)
end
| _ -> (num_of_map, of_num_map, def :: rdefs)
end
| _ -> (num_of_map, of_num_map, def :: rdefs)
)
(Bindings.empty, Bindings.empty, []) defs
in
(List.rev rdefs, (num_of_map, of_num_map))
in

let typdefs, defs = List.partition is_typ_def defs in
let statedefs, defs = List.partition is_state_def defs in

let doc_def =
doc_def type_defs_module unimplemented avoid_target_names generic_eq_types enum_number_defs effect_info
in
let () =
if !opt_undef_axioms || IdSet.is_empty unimplemented then ()
else
Expand Down
Loading