Skip to content

Commit

Permalink
Coq: more fixes for "C" tests
Browse files Browse the repository at this point in the history
- rewrite complex patterns in top-level let bindings
- print recursive function with implicit decreasing arguments properly
- avoid built-in type "vec"
- rerun effect inference when exhaustivity rewrite introduces a match
  failure
- don't generate autocast for an argument with an existentially bound size
- correct Coq built-ins in several tests
- add expected test failures
  • Loading branch information
bacam committed Oct 5, 2023
1 parent 0a70ae3 commit 1575ddc
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 35 deletions.
101 changes: 84 additions & 17 deletions src/lib/rewrites.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4145,7 +4145,8 @@ let rewrite_ast_realize_mappings effect_info env ast =
(* Rewrite to make all pattern matches in Coq output exhaustive and
remove redundant clauses. Assumes that guards, vector patterns,
etc have been rewritten already, and the scattered functions have
been merged.
been merged. It also reruns effect inference if a pattern match
failure has to be added.
Note: if this naive implementation turns out to be too slow or buggy, we
could look at implementing Maranget JFP 17(3), 2007.
Expand Down Expand Up @@ -4438,7 +4439,7 @@ module MakeExhaustive = struct

let funcl_loc (FCL_aux (_, (def_annot, _))) = def_annot.loc

let rewrite_case (e, ann) =
let rewrite_case redo_effects (e, ann) =
match e with
| E_match (e1, cases) | E_try (e1, cases) -> begin
let env = env_of_annot ann in
Expand All @@ -4456,9 +4457,11 @@ module MakeExhaustive = struct

let l = Parse_ast.Generated Parse_ast.Unknown in
let p = P_aux (P_wild, (l, empty_tannot)) in
let l_ann = mk_tannot env unit_typ in
let ann' = mk_tannot env (typ_of_annot ann) in
(* TODO: use an expression that specifically indicates a failed pattern match *)
let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, ann')) in
let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, l_ann))), (l, ann')) in
redo_effects := true;
E_aux (rebuild (cases @ [Pat_aux (Pat_exp (p, b), (l, empty_tannot))]), ann)
end
| E_let (LB_aux (LB_val (pat, e1), lb_ann), e2) -> begin
Expand All @@ -4474,9 +4477,11 @@ module MakeExhaustive = struct
in
let l = Parse_ast.Generated Parse_ast.Unknown in
let p = P_aux (P_wild, (l, empty_tannot)) in
let l_ann = mk_tannot env unit_typ in
let ann' = mk_tannot env (typ_of_annot ann) in
(* TODO: use an expression that specifically indicates a failed pattern match *)
let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, empty_tannot))), (l, ann')) in
let b = E_aux (E_exit (E_aux (E_lit (L_aux (L_unit, l)), (l, l_ann))), (l, ann')) in
redo_effects := true;
E_aux (E_match (e1, [Pat_aux (Pat_exp (pat, e2), ann); Pat_aux (Pat_exp (p, b), (l, empty_tannot))]), ann)
end
| _ -> E_aux (e, ann)
Expand Down Expand Up @@ -4512,18 +4517,29 @@ module MakeExhaustive = struct

FD_aux (FD_function (r, t, fcls' @ [default]), f_ann)

let rewrite env =
let alg = { id_exp_alg with e_aux = rewrite_case } in
rewrite_ast_base
{
rewrite_exp = (fun _ -> fold_exp alg);
rewrite_pat;
rewrite_let;
rewrite_lexp;
rewrite_fun;
rewrite_def;
rewrite_ast = rewrite_ast_base_progress "Make patterns exhaustive";
}
let rewrite effect_info env ast =
let redo_effects = ref false in
let alg = { id_exp_alg with e_aux = rewrite_case redo_effects } in
let ast' =
rewrite_ast_base
{
rewrite_exp = (fun _ -> fold_exp alg);
rewrite_pat;
rewrite_let;
rewrite_lexp;
rewrite_fun;
rewrite_def;
rewrite_ast = rewrite_ast_base_progress "Make patterns exhaustive";
}
ast
in
let effect_info' =
(* TODO: if we use this for anything other than Coq we'll need
to replace "true" with Target.asserts_termination target,
after plumbing target through to this rewrite. *)
if !redo_effects then Effects.infer_side_effects true ast' else effect_info
in
(ast', effect_info', env)
end

(* Splitting a function (e.g., an execute function on an AST) can produce
Expand Down Expand Up @@ -4906,6 +4922,56 @@ let rewrite_truncate_hex_literals _type_env defs =
{ rewriters_base with rewrite_exp = (fun _ -> fold_exp { id_exp_alg with e_aux = rewrite_aux }) }
defs

(* Coq's Definition command doesn't allow patterns, so rewrite
top-level let bindings with complex patterns into a sequence of
single definitions. *)
let rewrite_toplevel_let_patterns env ast =
let is_pat_simple = function
| P_aux (P_typ (_, P_aux (P_id _id, _)), _) | P_aux (P_id _id, _) -> true
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)
| P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _) ->
Id.compare id (id_of_kid kid) == 0
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)
| P_aux (P_typ (_, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)), _)
when Id.compare app_id (mk_id "atom") == 0 ->
Id.compare id (id_of_kid kid) == 0
| _ -> false
in
let rewrite_def = function
| DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), (l, annot))), def_annot) as def ->
if is_pat_simple pat then [def]
else (
let ids = pat_ids pat in
let base_id = fresh_id "let" l in
let base_annot = mk_tannot env (typ_of exp) in
let base_def =
mk_def (DEF_let (LB_aux (LB_val (P_aux (P_id base_id, (l, base_annot)), exp), (l, empty_tannot))))
in
let id_defs =
List.map
(fun id ->
let id_typ = match Env.lookup_id id env with Local (_, t) -> t | _ -> assert false in
let id_annot = (Parse_ast.Unknown, mk_tannot env id_typ) in
let def_body =
E_aux
( E_let
( LB_aux (LB_val (pat, E_aux (E_id base_id, (l, base_annot))), (l, empty_tannot)),
E_aux (E_id id, id_annot)
),
id_annot
)
in
mk_def (DEF_let (LB_aux (LB_val (P_aux (P_id id, id_annot), def_body), (l, empty_tannot))))
)
(IdSet.elements ids)
in
base_def :: id_defs
)
| d -> [d]
in
let defs = List.map rewrite_def ast.defs |> List.concat in
{ ast with defs }

let opt_mono_rewrites = ref false
let opt_mono_complex_nexps = ref true

Expand Down Expand Up @@ -5020,7 +5086,7 @@ let all_rewriters =
("add_bitvector_casts", basic_rewriter Monomorphise.add_bitvector_casts);
("remove_impossible_int_cases", basic_rewriter Constant_propagation.remove_impossible_int_cases);
("const_prop_mutrec", String_rewriter (fun target -> Base_rewriter (Constant_propagation_mutrec.rewrite_ast target)));
("make_cases_exhaustive", basic_rewriter MakeExhaustive.rewrite);
("make_cases_exhaustive", Base_rewriter MakeExhaustive.rewrite);
("undefined", Bool_rewriter (fun b -> basic_rewriter (rewrite_undefined_if_gen b)));
("vector_string_pats_to_bit_list", basic_rewriter rewrite_ast_vector_string_pats_to_bit_list);
("remove_not_pats", basic_rewriter rewrite_ast_not_pats);
Expand Down Expand Up @@ -5074,6 +5140,7 @@ let all_rewriters =
)
);
("add_unspecified_rec", basic_rewriter rewrite_add_unspecified_rec);
("toplevel_let_patterns", basic_rewriter rewrite_toplevel_let_patterns);
]

let rewrites_interpreter =
Expand Down
29 changes: 22 additions & 7 deletions src/sail_coq_backend/pretty_print_coq.ml
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ type context = {
constant_kids : Nat_big_num.num KBindings.t; (* type variables that should be replaced by a constant definition *)
bound_nvars : KidSet.t;
build_at_return : string option;
recursive_fns : (int * int) Bindings.t;
(* Number of implicit arguments and constraints for (mutually) recursive definitions *)
recursive_fns : (int * int * bool) Bindings.t;
(* Number of implicit arguments and constraints for (mutually) recursive definitions, and whether there is a measure *)
debug : bool;
ret_typ_pp : PPrint.document; (* Return type formatted for use with returnR *)
effect_info : Effects.side_effect_info;
Expand Down Expand Up @@ -197,7 +197,7 @@ let rec fix_id avoid remove_tick name =
| "in" | "let" | "match" | "return" | "then" | "where" | "with" | "by" | "exists" | "exists2" | "using"
(* other identifiers we shouldn't override *)
| "assert" | "lsl" | "lsr" | "asr" | "type" | "function" | "raise" | "try" | "check" | "field" | "LT" | "GT" | "EQ"
| "Z" | "O" | "R" | "S" | "mod" | "M" | "tt" | "register_ref" ->
| "Z" | "O" | "R" | "S" | "mod" | "M" | "tt" | "register_ref" | "vec" ->
name ^ "'"
| _ ->
if StringSet.mem name avoid then name ^ "'"
Expand Down Expand Up @@ -1765,10 +1765,16 @@ let doc_exp, doc_let =
| ExNone, _, t1 -> t1
)
in
let out_typ = match ann_typ with Typ_aux (Typ_exist (_, _, t1), _) -> t1 | t1 -> t1 in
let out_typ_bound, out_typ =
match ann_typ with Typ_aux (Typ_exist (ks, _, t1), _) -> (ks, t1) | t1 -> ([], t1)
in
let autocast =
(* Avoid using helper functions which simplify the nexps *)
match (in_typ, out_typ) with
(* When we expect a bitvector of arbitrary length we don't need a cast *)
| _, Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp (Nexp_aux (Nexp_var v, _)), _)]), _)
when List.exists (fun k -> Kid.compare v (kopt_kid k) == 0) out_typ_bound ->
false
| ( Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n1, _)]), _),
Typ_aux (Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp n2, _)]), _) ) ->
not (similar_nexps ctxt env n1 n2)
Expand Down Expand Up @@ -1841,11 +1847,11 @@ let doc_exp, doc_let =
let argspp = List.map2 (doc_arg true) args arg_typs in
let all =
match is_rec with
| Some (pre, post) ->
| Some (pre, post, is_measured) ->
(call :: List.init pre (fun _ -> underscore))
@ argspp
@ List.init post (fun _ -> underscore)
@ [parens (string "_limit_reduces _acc")]
@ if is_measured then [parens (string "_limit_reduces _acc")] else []
| None -> (
match f with
| Id_aux (Id x, _) when is_prefix "#rec#" x ->
Expand Down Expand Up @@ -2940,7 +2946,9 @@ let doc_funcl_init types_mod avoid_target_names effect_info mutrec rec_opt ?rec_
in
let intropp = match mutrec with NotMutrec -> intropp | FirstFn -> string "Fixpoint" | LaterFn -> string "with" in
let ctxt =
if is_measured then { ctxt with recursive_fns = Bindings.singleton id (List.length quantspp, 0) } else ctxt
match mutrec with
| NotMutrec -> ctxt
| _ -> { ctxt with recursive_fns = Bindings.singleton id (List.length quantspp, 0, is_measured) }
in
let _ =
match guard with
Expand Down Expand Up @@ -3238,6 +3246,13 @@ let doc_val avoid_target_names pat exp =
| P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _)
when Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, None)
| P_aux
(P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| _ ->
raise
(Reporting.err_todo (pat_loc pat) "Top-level value definition with complex pattern not supported for Coq yet")
Expand Down
1 change: 1 addition & 0 deletions src/sail_coq_backend/sail_plugin_coq.ml
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ let coq_rewrites =
("remove_superfluous_letbinds", []);
("remove_superfluous_returns", []);
("bit_lists_to_lits", []);
("toplevel_let_patterns", []);
("recheck_defs", []);
("attach_effects", []);
]
Expand Down
6 changes: 1 addition & 5 deletions test/c/cfold_reg.sail
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ default Order dec

$include <prelude.sail>

val eq_string = { lem: "eq", _: "eq_string" } : (string, string) -> bool

overload operator == = {eq_string}

register R : bool

val "print_endline" : string -> unit
Expand All @@ -27,4 +23,4 @@ function main(() : unit) -> unit = {
} else {
print_endline("false")
}
}
}
4 changes: 2 additions & 2 deletions test/c/cheri128_hsb.sail
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ $include <option.sail>
$include <vector_dec.sail>
$include <exception_basic.sail>

val modulus = {ocaml: "modulus", lem: "hardware_mod", coq: "euclid_modulo", _ : "tmod_int"} : forall 'n, 'n > 0 . (int, atom('n)) -> range(0, 'n - 1)
val modulus = {ocaml: "modulus", lem: "hardware_mod", coq: "ZEuclid.modulo", _ : "tmod_int"} : forall 'n, 'n > 0 . (int, atom('n)) -> range(0, 'n - 1)

val add_range = {ocaml: "add_int", lem: "integerAdd", coq: "add_range", c: "add_int"} : forall 'n 'm 'o 'p.
(range('n, 'm), range('o, 'p)) -> range('n + 'o, 'm + 'p)
Expand Down Expand Up @@ -59,4 +59,4 @@ val main : unit -> unit effect {escape}
function main() = {
let _ = computeE(0xFFFF_FFFF_FFFF_FFFF @ 0b1);
()
}
}
4 changes: 2 additions & 2 deletions test/c/enum_match.sail
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

val "eq_anything" : forall ('a : Type). ('a, 'a) -> bool
val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int"} : forall 'n 'm. (atom('n), atom('m)) -> bool
val eq_atom = {ocaml: "eq_int", lem: "eq", c: "eq_int", coq: "Z.eqb"} : forall 'n 'm. (atom('n), atom('m)) -> bool

overload operator == = {eq_atom, eq_anything}

Expand All @@ -14,4 +14,4 @@ function main (() : unit) -> unit = {
B => print("B"),
A => print("A")
}
}
}
4 changes: 2 additions & 2 deletions test/c/poly_pair.sail
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default Order dec

val print = "print_endline" : string -> unit

val eq_int = { lem: "eq", _: "eq_int" } : (int, int) -> bool
val eq_int = { lem: "eq", coq: "Z.eqb", _: "eq_int" } : (int, int) -> bool

union test('a : Type, 'b : Type) = {
Ctor1 : ('a, 'b),
Expand All @@ -17,4 +17,4 @@ function main() = {
Ctor1(y, z) if eq_int(y, 3) => print("1"),
_ => print("2")
};
}
}
18 changes: 18 additions & 0 deletions test/c/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ def test_lem(name):
def test_coq(name):
banner('Testing {}'.format(name))
results = Results(name)
results.expect_failure("inc_tests.sail", "missing built-in functions for increasing vectors in Coq library")
results.expect_failure("read_write_ram.sail", "uses memory primitives not provided by default in Coq")
results.expect_failure("for_shadow.sail", "Pure loops aren't current supported for Coq (and don't really make sense)")
results.expect_failure("fail_exception.sail", "try-blocks around pure expressions not supported in Coq (and a little silly)")
results.expect_failure("loop_exception.sail", "try-blocks around pure expressions not supported in Coq (and a little silly)")
results.expect_failure("concurrency_interface.sail", "test doesn't meet Coq library's expectations for the concurrency interface")
results.expect_failure("outcome_impl.sail", "test doesn't meet Coq backend's expectations for the concurrency interface")
results.expect_failure("pc_no_wildcard.sail", "register type unsupported by Coq backend")
results.expect_failure("cheri_capreg.sail", "test has strange 'pure' reg_deref")
results.expect_failure("poly_outcome.sail", "test doesn't meet Coq library's expectations for the concurrency interface")
results.expect_failure("poly_mapping.sail", "test requires non-standard hex built-ins")
results.expect_failure("real_prop.sail", "random_real not available for Coq at present")
results.expect_failure("fail_assert_mono_bug.sail", "test output checking not supported for Coq yet")
results.expect_failure("fail_issue203.sail", "test output checking not supported for Coq yet")
results.expect_failure("vector_example.sail", "bug: function defs and function calls treat 'len equation differently in Coq backedn")
results.expect_failure("list_torture.sail", "Coq backend doesn't remove a phantom type parameter")
results.expect_failure("tl_pat.sail", "Coq backend doesn't support constructors with the same name as a type")
for filenames in chunks(os.listdir('.'), parallel()):
tests = {}
for filename in filenames:
Expand Down Expand Up @@ -197,6 +214,7 @@ def test_coq(name):
xml += test_ocaml('OCaml')

#xml += test_lem('lem')
#xml += test_coq('coq')

xml += '</testsuites>\n'

Expand Down

0 comments on commit 1575ddc

Please sign in to comment.