From 8ed25bfef25ebf9ccd890f37eda498cd9c2db499 Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Thu, 11 Apr 2024 14:26:50 +0100 Subject: [PATCH] Coq: more efficient equality decision procedures for enums --- src/sail_coq_backend/pretty_print_coq.ml | 100 +++++++++++++++++++++-- 1 file changed, 93 insertions(+), 7 deletions(-) diff --git a/src/sail_coq_backend/pretty_print_coq.ml b/src/sail_coq_backend/pretty_print_coq.ml index 3e44b4922..b42687ca6 100644 --- a/src/sail_coq_backend/pretty_print_coq.ml +++ b/src/sail_coq_backend/pretty_print_coq.ml @@ -2377,7 +2377,7 @@ let rec doc_range ctxt (BF_aux(r,_)) = match r with *) (* TODO: check use of empty_ctxt below doesn't cause problems due to missing info *) -let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, annot))) = +let doc_typdef types_mod avoid_target_names generic_eq_types enum_number_defs (TD_aux (td, (l, annot))) = let bare_ctxt = { empty_ctxt with avoid_target_names } in match td with | TD_abbrev (id, typq, A_aux (A_typ typ, _)) -> @@ -2632,7 +2632,58 @@ let doc_typdef types_mod avoid_target_names generic_eq_types (TD_aux (td, (l, an let typ_pp = (doc_op coloneq) (concat [string "Inductive"; space; id_pp]) (ifflat empty (pipe ^^ space) ^^ enums_doc) in - let eq1_pp = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in + (* If we have conversion functions to Z, put them here and + derive a decision procedure that's efficient even for + large enums. *) + let eq1_pp = + let fallback = string "Scheme Equality for" ^^ space ^^ id_pp ^^ dot in + match (Bindings.find_opt id (fst enum_number_defs), Bindings.find_opt id (snd enum_number_defs)) with + | Some (num_of_id, num_of_pp), Some (of_num_id, of_num_pp) -> + let num_of_id_pp = doc_id bare_ctxt num_of_id in + let of_num_id_pp = doc_id bare_ctxt of_num_id in + let lemma1 = + separate hardline + [ + string "Lemma " ^^ id_pp ^^ string "_num_of_roundtrip " + ^^ parens (string "x : " ^^ id_pp) + ^^ string " : " ^^ of_num_id_pp ^^ space + ^^ parens (num_of_id_pp ^^ string " x") + ^^ string " = x."; + string "destruct x; reflexivity."; + string "Qed."; + ] + in + let lemma2 = + separate hardline + [ + string "Lemma " ^^ num_of_id_pp ^^ string "_injective " + ^^ parens (string "x y : " ^^ id_pp) + ^^ string " : " ^^ num_of_id_pp ^^ string " x = " ^^ num_of_id_pp ^^ string " y -> x = y."; + string "intro."; + string "rewrite <- (" ^^ id_pp ^^ string "_num_of_roundtrip x)."; + string "rewrite <- (" ^^ id_pp ^^ string "_num_of_roundtrip y)."; + string "congruence."; + string "Qed."; + ] + in + let eq_pp = + separate hardline + [ + string "Definition " ^^ id_pp ^^ string "_eq_dec (x y : " ^^ id_pp + ^^ string ") : {x = y} + {x <> y}."; + string "refine (match Z.eq_dec (" ^^ num_of_id_pp ^^ string " x) (" ^^ num_of_id_pp + ^^ string " y) with"; + string "| left e => left (" ^^ num_of_id_pp ^^ string "_injective x y e)"; + string "| right ne => right _"; + string "end)."; + string "congruence."; + string "Defined."; + ] + in + num_of_pp ^^ of_num_pp ^^ separate hardline [lemma1; lemma2; eq_pp] + | Some (_, pp), None | None, Some (_, pp) -> pp ^^ fallback + | None, None -> fallback + in let eq2_pp = string "#[export]" ^^ hardline ^^ group @@ -3364,14 +3415,15 @@ let doc_val avoid_target_names pat exp = ^^ group (separate space [string "#[export] Hint Unfold"; idpp; colon; string "sail."]) ^^ hardline -let doc_def types_mod unimplemented avoid_target_names generic_eq_types effect_info (DEF_aux (aux, def_annot) as def) = +let doc_def types_mod unimplemented avoid_target_names generic_eq_types enum_number_defs effect_info + (DEF_aux (aux, def_annot) as def) = match aux with | DEF_val v_spec -> doc_val_spec def_annot unimplemented avoid_target_names effect_info v_spec | DEF_fixity _ -> empty | DEF_overload _ -> empty | DEF_type t_def -> if List.mem (string_of_id (id_of_type_def t_def)) !opt_extern_types <> !opt_generate_extern_types then empty - else doc_typdef types_mod avoid_target_names generic_eq_types t_def + else doc_typdef types_mod avoid_target_names generic_eq_types enum_number_defs t_def | DEF_register dec -> group (doc_dec avoid_target_names dec) | DEF_default df -> empty | DEF_fundef fdef -> group (doc_fundef types_mod avoid_target_names effect_info fdef) ^/^ hardline @@ -3540,8 +3592,6 @@ let pp_ast_coq (types_file, types_modules) (defs_file, defs_modules) type_defs_m in let is_typ_def = function DEF_aux (DEF_type _, _) -> true | _ -> false in let exc_typ = find_exc_typ defs in - let typdefs, defs = List.partition is_typ_def defs in - let statedefs, defs = List.partition is_state_def defs in let unimplemented = find_unimplemented defs in let avoid_target_names = builtin_target_names defs in let bare_doc_id = doc_id { empty_ctxt with avoid_target_names } in @@ -3603,7 +3653,43 @@ let pp_ast_coq (types_file, types_modules) (defs_file, defs_modules) type_defs_m @ mr_m ) in - let doc_def = doc_def type_defs_module unimplemented avoid_target_names generic_eq_types effect_info in + let enums = Type_check.Env.get_enums type_env in + let defs, enum_number_defs = + let doc_def = + doc_def type_defs_module unimplemented avoid_target_names generic_eq_types (Bindings.empty, Bindings.empty) + effect_info + in + let num_of_map, of_num_map, rdefs = + List.fold_left + (fun (num_of_map, of_num_map, rdefs) def -> + match def with + | DEF_aux (DEF_fundef (FD_aux (FD_function (_, _, [FCL_aux (FCL_funcl (id, _), _)]), _)), _) -> begin + match Type_check.Env.get_val_spec id type_env with + | _, Typ_aux (Typ_fn ([arg_typ], ret_typ), _) -> begin + match (arg_typ, ret_typ) with + | Typ_aux (Typ_id arg_id, _), _ + when Bindings.mem arg_id enums && string_of_id id = "num_of_" ^ string_of_id arg_id -> + (Bindings.add arg_id (id, doc_def def) num_of_map, of_num_map, rdefs) + | _, Typ_aux (Typ_id ret_id, _) + when Bindings.mem ret_id enums && string_of_id id = string_of_id ret_id ^ "_of_num" -> + (num_of_map, Bindings.add ret_id (id, doc_def def) of_num_map, rdefs) + | _ -> (num_of_map, of_num_map, def :: rdefs) + end + | _ -> (num_of_map, of_num_map, def :: rdefs) + end + | _ -> (num_of_map, of_num_map, def :: rdefs) + ) + (Bindings.empty, Bindings.empty, []) defs + in + (List.rev rdefs, (num_of_map, of_num_map)) + in + + let typdefs, defs = List.partition is_typ_def defs in + let statedefs, defs = List.partition is_state_def defs in + + let doc_def = + doc_def type_defs_module unimplemented avoid_target_names generic_eq_types enum_number_defs effect_info + in let () = if !opt_undef_axioms || IdSet.is_empty unimplemented then () else