From 1575ddca6b0769a820dbd3a03e113c02a8a043bc Mon Sep 17 00:00:00 2001 From: Brian Campbell Date: Thu, 5 Oct 2023 10:54:00 +0100 Subject: [PATCH] Coq: more fixes for "C" tests - rewrite complex patterns in top-level let bindings - print recursive function with implicit decreasing arguments properly - avoid built-in type "vec" - rerun effect inference when exhaustivity rewrite introduces a match failure - don't generate autocast for an argument with an existentially bound size - correct Coq built-ins in several tests - add expected test failures --- src/lib/rewrites.ml | 101 +++++++++++++++++++---- src/sail_coq_backend/pretty_print_coq.ml | 29 +++++-- src/sail_coq_backend/sail_plugin_coq.ml | 1 + test/c/cfold_reg.sail | 6 +- test/c/cheri128_hsb.sail | 4 +- test/c/enum_match.sail | 4 +- test/c/poly_pair.sail | 4 +- test/c/run_tests.py | 18 ++++ 8 files changed, 132 insertions(+), 35 deletions(-) diff --git a/src/lib/rewrites.ml b/src/lib/rewrites.ml index 43ba69e3c..2f302987a 100644 --- a/src/lib/rewrites.ml +++ b/src/lib/rewrites.ml @@ -4145,7 +4145,8 @@ let rewrite_ast_realize_mappings effect_info env ast = (* Rewrite to make all pattern matches in Coq output exhaustive and remove redundant clauses. Assumes that guards, vector patterns, etc have been rewritten already, and the scattered functions have - been merged. + been merged. It also reruns effect inference if a pattern match + failure has to be added. Note: if this naive implementation turns out to be too slow or buggy, we could look at implementing Maranget JFP 17(3), 2007. @@ -4438,7 +4439,7 @@ module MakeExhaustive = struct let funcl_loc (FCL_aux (_, (def_annot, _))) = def_annot.loc - let rewrite_case (e, ann) = + let rewrite_case redo_effects (e, ann) = match e with | E_match (e1, cases) | E_try (e1, cases) -> begin let env = env_of_annot ann in @@ -4456,9 +4457,11 @@ module MakeExhaustive = struct let l = Parse_ast.Generated Parse_ast.Unknown in let p = P_aux (P_wild, (l, empty_tannot)) in + let l_ann = mk_tannot env unit_typ in let ann' = mk_tannot env (typ_of_annot ann) in (* TODO: use an expression that specifically indicates a failed pattern match *) - let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, ann')) in + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, l_ann))), (l, ann')) in + redo_effects := true; E_aux (rebuild (cases @ [Pat_aux (Pat_exp (p, b), (l, empty_tannot))]), ann) end | E_let (LB_aux (LB_val (pat, e1), lb_ann), e2) -> begin @@ -4474,9 +4477,11 @@ module MakeExhaustive = struct in let l = Parse_ast.Generated Parse_ast.Unknown in let p = P_aux (P_wild, (l, empty_tannot)) in + let l_ann = mk_tannot env unit_typ in let ann' = mk_tannot env (typ_of_annot ann) in (* TODO: use an expression that specifically indicates a failed pattern match *) - let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, ann')) in + let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, l_ann))), (l, ann')) in + redo_effects := true; E_aux (E_match (e1, [Pat_aux (Pat_exp (pat, e2), ann); Pat_aux (Pat_exp (p, b), (l, empty_tannot))]), ann) end | _ -> E_aux (e, ann) @@ -4512,18 +4517,29 @@ module MakeExhaustive = struct FD_aux (FD_function (r, t, fcls' @ [default]), f_ann) - let rewrite env = - let alg = { id_exp_alg with e_aux = rewrite_case } in - rewrite_ast_base - { - rewrite_exp = (fun _ -> fold_exp alg); - rewrite_pat; - rewrite_let; - rewrite_lexp; - rewrite_fun; - rewrite_def; - rewrite_ast = rewrite_ast_base_progress "Make patterns exhaustive"; - } + let rewrite effect_info env ast = + let redo_effects = ref false in + let alg = { id_exp_alg with e_aux = rewrite_case redo_effects } in + let ast' = + rewrite_ast_base + { + rewrite_exp = (fun _ -> fold_exp alg); + rewrite_pat; + rewrite_let; + rewrite_lexp; + rewrite_fun; + rewrite_def; + rewrite_ast = rewrite_ast_base_progress "Make patterns exhaustive"; + } + ast + in + let effect_info' = + (* TODO: if we use this for anything other than Coq we'll need + to replace "true" with Target.asserts_termination target, + after plumbing target through to this rewrite. *) + if !redo_effects then Effects.infer_side_effects true ast' else effect_info + in + (ast', effect_info', env) end (* Splitting a function (e.g., an execute function on an AST) can produce @@ -4906,6 +4922,56 @@ let rewrite_truncate_hex_literals _type_env defs = { rewriters_base with rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux }) } defs +(* Coq's Definition command doesn't allow patterns, so rewrite + top-level let bindings with complex patterns into a sequence of + single definitions. *) +let rewrite_toplevel_let_patterns env ast = + let is_pat_simple = function + | P_aux (P_typ (_, P_aux (P_id _id, _)), _) | P_aux (P_id _id, _) -> true + | P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _) + | P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _) -> + Id.compare id (id_of_kid kid) == 0 + | P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _) + | P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)), _) + when Id.compare app_id (mk_id "atom") == 0 -> + Id.compare id (id_of_kid kid) == 0 + | _ -> false + in + let rewrite_def = function + | DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), (l, annot))), def_annot) as def -> + if is_pat_simple pat then [def] + else ( + let ids = pat_ids pat in + let base_id = fresh_id "let" l in + let base_annot = mk_tannot env (typ_of exp) in + let base_def = + mk_def (DEF_let (LB_aux (LB_val (P_aux (P_id base_id, (l, base_annot)), exp), (l, empty_tannot)))) + in + let id_defs = + List.map + (fun id -> + let id_typ = match Env.lookup_id id env with Local (_, t) -> t | _ -> assert false in + let id_annot = (Parse_ast.Unknown, mk_tannot env id_typ) in + let def_body = + E_aux + ( E_let + ( LB_aux (LB_val (pat, E_aux (E_id base_id, (l, base_annot))), (l, empty_tannot)), + E_aux (E_id id, id_annot) + ), + id_annot + ) + in + mk_def (DEF_let (LB_aux (LB_val (P_aux (P_id id, id_annot), def_body), (l, empty_tannot)))) + ) + (IdSet.elements ids) + in + base_def :: id_defs + ) + | d -> [d] + in + let defs = List.map rewrite_def ast.defs |> List.concat in + { ast with defs } + let opt_mono_rewrites = ref false let opt_mono_complex_nexps = ref true @@ -5020,7 +5086,7 @@ let all_rewriters = ("add_bitvector_casts", basic_rewriter Monomorphise.add_bitvector_casts); ("remove_impossible_int_cases", basic_rewriter Constant_propagation.remove_impossible_int_cases); ("const_prop_mutrec", String_rewriter (fun target -> Base_rewriter (Constant_propagation_mutrec.rewrite_ast target))); - ("make_cases_exhaustive", basic_rewriter MakeExhaustive.rewrite); + ("make_cases_exhaustive", Base_rewriter MakeExhaustive.rewrite); ("undefined", Bool_rewriter (fun b -> basic_rewriter (rewrite_undefined_if_gen b))); ("vector_string_pats_to_bit_list", basic_rewriter rewrite_ast_vector_string_pats_to_bit_list); ("remove_not_pats", basic_rewriter rewrite_ast_not_pats); @@ -5074,6 +5140,7 @@ let all_rewriters = ) ); ("add_unspecified_rec", basic_rewriter rewrite_add_unspecified_rec); + ("toplevel_let_patterns", basic_rewriter rewrite_toplevel_let_patterns); ] let rewrites_interpreter = diff --git a/src/sail_coq_backend/pretty_print_coq.ml b/src/sail_coq_backend/pretty_print_coq.ml index 79d1ec905..d391a4f07 100644 --- a/src/sail_coq_backend/pretty_print_coq.ml +++ b/src/sail_coq_backend/pretty_print_coq.ml @@ -130,8 +130,8 @@ type context = { constant_kids : Nat_big_num.num KBindings.t; (* type variables that should be replaced by a constant definition *) bound_nvars : KidSet.t; build_at_return : string option; - recursive_fns : (int * int) Bindings.t; - (* Number of implicit arguments and constraints for (mutually) recursive definitions *) + recursive_fns : (int * int * bool) Bindings.t; + (* Number of implicit arguments and constraints for (mutually) recursive definitions, and whether there is a measure *) debug : bool; ret_typ_pp : PPrint.document; (* Return type formatted for use with returnR *) effect_info : Effects.side_effect_info; @@ -197,7 +197,7 @@ let rec fix_id avoid remove_tick name = | "in" | "let" | "match" | "return" | "then" | "where" | "with" | "by" | "exists" | "exists2" | "using" (* other identifiers we shouldn't override *) | "assert" | "lsl" | "lsr" | "asr" | "type" | "function" | "raise" | "try" | "check" | "field" | "LT" | "GT" | "EQ" - | "Z" | "O" | "R" | "S" | "mod" | "M" | "tt" | "register_ref" -> + | "Z" | "O" | "R" | "S" | "mod" | "M" | "tt" | "register_ref" | "vec" -> name ^ "'" | _ -> if StringSet.mem name avoid then name ^ "'" @@ -1765,10 +1765,16 @@ let doc_exp, doc_let = | ExNone, _, t1 -> t1 ) in - let out_typ = match ann_typ with Typ_aux (Typ_exist (_, _, t1), _) -> t1 | t1 -> t1 in + let out_typ_bound, out_typ = + match ann_typ with Typ_aux (Typ_exist (ks, _, t1), _) -> (ks, t1) | t1 -> ([], t1) + in let autocast = (* Avoid using helper functions which simplify the nexps *) match (in_typ, out_typ) with + (* When we expect a bitvector of arbitrary length we don't need a cast *) + | _, Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp (Nexp_aux (Nexp_var v, _)), _)]), _) + when List.exists (fun k -> Kid.compare v (kopt_kid k) == 0) out_typ_bound -> + false | ( Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n1, _)]), _), Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n2, _)]), _) ) -> not (similar_nexps ctxt env n1 n2) @@ -1841,11 +1847,11 @@ let doc_exp, doc_let = let argspp = List.map2 (doc_arg true) args arg_typs in let all = match is_rec with - | Some (pre, post) -> + | Some (pre, post, is_measured) -> (call :: List.init pre (fun _ -> underscore)) @ argspp @ List.init post (fun _ -> underscore) - @ [parens (string "_limit_reduces _acc")] + @ if is_measured then [parens (string "_limit_reduces _acc")] else [] | None -> ( match f with | Id_aux (Id x, _) when is_prefix "#rec#" x -> @@ -2940,7 +2946,9 @@ let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_ in let intropp = match mutrec with NotMutrec -> intropp | FirstFn -> string "Fixpoint" | LaterFn -> string "with" in let ctxt = - if is_measured then { ctxt with recursive_fns = Bindings.singleton id (List.length quantspp, 0) } else ctxt + match mutrec with + | NotMutrec -> ctxt + | _ -> { ctxt with recursive_fns = Bindings.singleton id (List.length quantspp, 0, is_measured) } in let _ = match guard with @@ -3238,6 +3246,13 @@ let doc_val avoid_target_names pat exp = | P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _) when Id.compare id (id_of_kid kid) == 0 -> (id, Some typ) + | P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _) + when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 -> + (id, None) + | P_aux + (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)), _) + when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 -> + (id, Some typ) | _ -> raise (Reporting.err_todo (pat_loc pat) "Top-level value definition with complex pattern not supported for Coq yet") diff --git a/src/sail_coq_backend/sail_plugin_coq.ml b/src/sail_coq_backend/sail_plugin_coq.ml index 449e7b014..7e3561151 100644 --- a/src/sail_coq_backend/sail_plugin_coq.ml +++ b/src/sail_coq_backend/sail_plugin_coq.ml @@ -168,6 +168,7 @@ let coq_rewrites = ("remove_superfluous_letbinds", []); ("remove_superfluous_returns", []); ("bit_lists_to_lits", []); + ("toplevel_let_patterns", []); ("recheck_defs", []); ("attach_effects", []); ] diff --git a/test/c/cfold_reg.sail b/test/c/cfold_reg.sail index 53066bf59..a75541832 100644 --- a/test/c/cfold_reg.sail +++ b/test/c/cfold_reg.sail @@ -2,10 +2,6 @@ default Order dec $include -val eq_string = { lem: "eq", _: "eq_string" } : (string, string) -> bool - -overload operator == = {eq_string} - register R : bool val "print_endline" : string -> unit @@ -27,4 +23,4 @@ function main(() : unit) -> unit = { } else { print_endline("false") } -} \ No newline at end of file +} diff --git a/test/c/cheri128_hsb.sail b/test/c/cheri128_hsb.sail index d8501d88b..a5a7f86bc 100644 --- a/test/c/cheri128_hsb.sail +++ b/test/c/cheri128_hsb.sail @@ -6,7 +6,7 @@ $include $include $include -val modulus = {ocaml: "modulus", lem: "hardware_mod", coq: "euclid_modulo", _ : "tmod_int"} : forall 'n, 'n > 0 . (int, atom('n)) -> range(0, 'n - 1) +val modulus = {ocaml: "modulus", lem: "hardware_mod", coq: "ZEuclid.modulo", _ : "tmod_int"} : forall 'n, 'n > 0 . (int, atom('n)) -> range(0, 'n - 1) val add_range = {ocaml: "add_int", lem: "integerAdd", coq: "add_range", c: "add_int"} : forall 'n 'm 'o 'p. (range('n, 'm), range('o, 'p)) -> range('n + 'o, 'm + 'p) @@ -59,4 +59,4 @@ val main : unit -> unit effect {escape} function main() = { let _ = computeE(0xFFFF_FFFF_FFFF_FFFF @ 0b1); () -} \ No newline at end of file +} diff --git a/test/c/enum_match.sail b/test/c/enum_match.sail index 6c04d9dca..2de74a362 100644 --- a/test/c/enum_match.sail +++ b/test/c/enum_match.sail @@ -1,6 +1,6 @@ val "eq_anything" : forall ('a : Type). ('a, 'a) -> bool -val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int"} : forall 'n 'm. (atom('n), atom('m)) -> bool +val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : forall 'n 'm. (atom('n), atom('m)) -> bool overload operator == = {eq_atom, eq_anything} @@ -14,4 +14,4 @@ function main (() : unit) -> unit = { B => print("B"), A => print("A") } -} \ No newline at end of file +} diff --git a/test/c/poly_pair.sail b/test/c/poly_pair.sail index 7c86062dd..a2ccf93d9 100644 --- a/test/c/poly_pair.sail +++ b/test/c/poly_pair.sail @@ -2,7 +2,7 @@ default Order dec val print = "print_endline" : string -> unit -val eq_int = { lem: "eq", _: "eq_int" } : (int, int) -> bool +val eq_int = { lem: "eq", coq: "Z.eqb", _: "eq_int" } : (int, int) -> bool union test('a : Type, 'b : Type) = { Ctor1 : ('a, 'b), @@ -17,4 +17,4 @@ function main() = { Ctor1(y, z) if eq_int(y, 3) => print("1"), _ => print("2") }; -} \ No newline at end of file +} diff --git a/test/c/run_tests.py b/test/c/run_tests.py index fb9361d0c..ae3b3bc87 100755 --- a/test/c/run_tests.py +++ b/test/c/run_tests.py @@ -153,6 +153,23 @@ def test_lem(name): def test_coq(name): banner('Testing {}'.format(name)) results = Results(name) + results.expect_failure("inc_tests.sail", "missing built-in functions for increasing vectors in Coq library") + results.expect_failure("read_write_ram.sail", "uses memory primitives not provided by default in Coq") + results.expect_failure("for_shadow.sail", "Pure loops aren't current supported for Coq (and don't really make sense)") + results.expect_failure("fail_exception.sail", "try-blocks around pure expressions not supported in Coq (and a little silly)") + results.expect_failure("loop_exception.sail", "try-blocks around pure expressions not supported in Coq (and a little silly)") + results.expect_failure("concurrency_interface.sail", "test doesn't meet Coq library's expectations for the concurrency interface") + results.expect_failure("outcome_impl.sail", "test doesn't meet Coq backend's expectations for the concurrency interface") + results.expect_failure("pc_no_wildcard.sail", "register type unsupported by Coq backend") + results.expect_failure("cheri_capreg.sail", "test has strange 'pure' reg_deref") + results.expect_failure("poly_outcome.sail", "test doesn't meet Coq library's expectations for the concurrency interface") + results.expect_failure("poly_mapping.sail", "test requires non-standard hex built-ins") + results.expect_failure("real_prop.sail", "random_real not available for Coq at present") + results.expect_failure("fail_assert_mono_bug.sail", "test output checking not supported for Coq yet") + results.expect_failure("fail_issue203.sail", "test output checking not supported for Coq yet") + results.expect_failure("vector_example.sail", "bug: function defs and function calls treat 'len equation differently in Coq backedn") + results.expect_failure("list_torture.sail", "Coq backend doesn't remove a phantom type parameter") + results.expect_failure("tl_pat.sail", "Coq backend doesn't support constructors with the same name as a type") for filenames in chunks(os.listdir('.'), parallel()): tests = {} for filename in filenames: @@ -197,6 +214,7 @@ def test_coq(name): xml += test_ocaml('OCaml') #xml += test_lem('lem') +#xml += test_coq('coq') xml += '\n'