Skip to content

Commit

Permalink
Coq: Add proper redundant clause checking to exhaustivity rewrite
Browse files Browse the repository at this point in the history
Necessary for some of the C tests suite; useful for warnings.
  • Loading branch information
bacam committed Oct 5, 2023
1 parent a74281c commit 0a70ae3
Showing 1 changed file with 49 additions and 43 deletions.
92 changes: 49 additions & 43 deletions src/lib/rewrites.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4142,15 +4142,10 @@ let rewrite_ast_realize_mappings effect_info env ast =
let ast = { ast with defs = List.map rewrite_def ast.defs |> List.flatten } in
(ast, !effect_info, env)

(* Rewrite to make all pattern matches in Coq output exhaustive.
Assumes that guards, vector patterns, etc have been rewritten already,
and the scattered functions have been merged.
Will add escape effect where a default is needed, so effects will
need recalculated afterwards.
Also detects and removes redundant wildcard patterns at the end of the match.
(We could do more, but this is sufficient to deal with the code generated by
the mappings rewrites.)
(* Rewrite to make all pattern matches in Coq output exhaustive and
remove redundant clauses. Assumes that guards, vector patterns,
etc have been rewritten already, and the scattered functions have
been merged.
Note: if this naive implementation turns out to be too slow or buggy, we
could look at implementing Maranget JFP 17(3), 2007.
Expand Down Expand Up @@ -4255,19 +4250,20 @@ module MakeExhaustive = struct
let subpats rm_pats res_pats =
(* Pointwise removal *)
let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in
let progress = List.exists snd res_pats' in
(* Form the list of residual tuples by combining one position from the
pointwise removal with the original residual of the other positions. *)
let rec aux acc fixed residual =
match (fixed, residual) with
| [], [] -> []
| fh :: ft, rh :: rt ->
| fh :: ft, (rh, _) :: rt ->
(* ... so order matters here *)
let rt' = aux (acc @ [fh]) ft rt in
let newr = List.map (fun x -> acc @ (x :: ft)) rh in
newr @ rt'
| _, _ -> assert false (* impossible because we managed map2 above *)
in
aux [] res_pats res_pats'
(aux [] res_pats res_pats', progress)
in
let inconsistent () =
raise
Expand All @@ -4280,53 +4276,56 @@ module MakeExhaustive = struct
let _ = printprefix := " " ^ !printprefix in*)
let rp' =
match rm_pat with
| P_wild -> []
| P_id id when match Env.lookup_id id ctx.env with Unbound _ | Local _ -> true | _ -> false -> []
| P_wild -> ([], true)
| P_id id when match Env.lookup_id id ctx.env with Unbound _ | Local _ -> true | _ -> false -> ([], true)
| P_lit lit -> (
match res_pat with
| RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit)
| RP_lit RL_inf -> [res_pat]
| RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat]
| RP_any -> (List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit), true)
| RP_lit RL_inf -> ([res_pat], true (* TODO: check for duplicates *))
| RP_lit lit' -> if lit' = rlit_of_lit lit then ([], true) else ([res_pat], false)
| _ -> inconsistent ()
)
| P_as (p, _) | P_typ (_, p) | P_var (p, _) -> remove_clause_from_pattern ctx p res_pat
| P_id id -> (
match Env.lookup_id id ctx.env with
| Enum enum -> (
match res_pat with
| RP_any -> Bindings.find id ctx.enum_to_rest
| RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat]
| RP_any -> (Bindings.find id ctx.enum_to_rest, true)
| RP_enum id' -> if Id.compare id id' == 0 then ([], true) else ([res_pat], false)
| _ -> inconsistent ()
)
| _ -> assert false
)
| P_tuple rm_pats ->
let previous_res_pats =
match res_pat with
| RP_tuple res_pats -> res_pats
| RP_any -> List.map (fun _ -> RP_any) rm_pats
| _ -> inconsistent ()
in
let res_pats' = subpats rm_pats previous_res_pats in
List.map (fun rps -> RP_tuple rps) res_pats'
if Util.list_empty rm_pats then ([], true)
else (
let previous_res_pats =
match res_pat with
| RP_tuple res_pats -> res_pats
| RP_any -> List.map (fun _ -> RP_any) rm_pats
| _ -> inconsistent ()
in
let res_pats', progress = subpats rm_pats previous_res_pats in
(List.map (fun rps -> RP_tuple rps) res_pats', progress)
)
| P_app (id, args) -> (
match res_pat with
| RP_app (id', residual_args) ->
if Id.compare id id' == 0 then (
let res_pats' =
let res_pats', progress =
(* Constructors that were specified without a return type might get
an extra tuple in their type; expand that here if necessary.
TODO: this should go away if we enforce proper arities. *)
match (args, residual_args) with
| [], [RP_any] | _ :: _ :: _, [RP_any] -> subpats args (List.map (fun _ -> RP_any) args)
| _, _ -> subpats args residual_args
in
List.map (fun rps -> RP_app (id, rps)) res_pats'
(List.map (fun rps -> RP_app (id, rps)) res_pats', progress)
)
else [res_pat]
else ([res_pat], false)
| RP_any ->
let res_args = subpats args (List.map (fun _ -> RP_any) args) in
List.map (fun l -> RP_app (id, l)) res_args @ Bindings.find id ctx.constructor_to_rest
let res_args, progress = subpats args (List.map (fun _ -> RP_any) args) in
(List.map (fun l -> RP_app (id, l)) res_args @ Bindings.find id ctx.constructor_to_rest, progress)
| _ -> inconsistent ()
)
| P_struct (field_pats, _) ->
Expand Down Expand Up @@ -4355,16 +4354,16 @@ module MakeExhaustive = struct
)
all_ids
in
let res_pats' = subpats cur_pats res_pats in
List.map (fun rps -> RP_struct (List.combine all_ids rps)) res_pats'
let res_pats', progress = subpats cur_pats res_pats in
(List.map (fun rps -> RP_struct (List.combine all_ids rps)) res_pats', progress)
| P_list ps -> (
match ps with
| p1 :: ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1, P_aux (P_list ptl, ann)), ann)) res_pat
| [] -> (
match res_pat with
| RP_any -> [RP_cons (RP_any, RP_any)]
| RP_cons _ -> [res_pat]
| RP_nil -> []
| RP_any -> ([RP_cons (RP_any, RP_any)], true)
| RP_cons _ -> ([res_pat], false)
| RP_nil -> ([], true)
| _ -> inconsistent ()
)
)
Expand All @@ -4377,10 +4376,10 @@ module MakeExhaustive = struct
| _ -> inconsistent ()
in
match rps with
| None -> rp'
| None -> (rp', false)
| Some rps ->
let res_pats = subpats [p1; p2] rps in
rp' @ List.map (function [rp1; rp2] -> RP_cons (rp1, rp2) | _ -> assert false) res_pats
let res_pats, progress = subpats [p1; p2] rps in
(rp' @ List.map (function [rp1; rp2] -> RP_cons (rp1, rp2) | _ -> assert false) res_pats, progress)
end
| P_or _ -> raise (Reporting.err_unreachable (fst ann) __POS__ "Or pattern not supported")
| P_not _ -> raise (Reporting.err_unreachable (fst ann) __POS__ "Negated pattern not supported")
Expand All @@ -4401,18 +4400,25 @@ module MakeExhaustive = struct
(*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in
let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*)
match patexp with
| Pat_aux (Pat_exp (p, _), _) -> List.concat (List.map (remove_clause_from_pattern ctx p) rps)
| Pat_aux (Pat_exp (p, _), _) ->
let rps, progress = List.split (List.map (remove_clause_from_pattern ctx p) rps) in
(List.concat rps, List.exists (fun b -> b) progress)
| Pat_aux (Pat_when _, (l, _)) ->
raise (Reporting.err_unreachable l __POS__ "Guarded pattern should have been rewritten away")

(* We do some minimal redundancy checking to remove bogus wildcard patterns here *)
let check_cases process is_wild loc_of cases =
let rec aux rps acc = function
| [] -> (acc, rps)
| [p] when is_wild p && match rps with [] -> true | _ -> false ->
let () = Reporting.print_err (loc_of p) "Match checking" "Redundant wildcard clause" in
(acc, [])
| h :: t -> aux (process rps h) (h :: acc) t
| h :: t ->
let rps', progress = process rps h in
if progress then aux rps' (h :: acc) t
else begin
Reporting.print_err (loc_of h) "Match checking" "Redundant clause";
aux rps' acc t
end
in
let cases, rps = aux [RP_any] [] cases in
(List.rev cases, rps)
Expand Down Expand Up @@ -4458,7 +4464,7 @@ module MakeExhaustive = struct
| E_let (LB_aux (LB_val (pat, e1), lb_ann), e2) -> begin
let env = env_of_annot ann in
let ctx = ctx_from_env env in
let rps = remove_clause_from_pattern ctx pat RP_any in
let rps, _ = remove_clause_from_pattern ctx pat RP_any in
match rps with
| [] -> E_aux (e, ann)
| example :: _ ->
Expand Down

0 comments on commit 0a70ae3

Please sign in to comment.