diff --git a/src/lib/rewrites.ml b/src/lib/rewrites.ml index 3ef3fcb9c..43ba69e3c 100644 --- a/src/lib/rewrites.ml +++ b/src/lib/rewrites.ml @@ -4142,15 +4142,10 @@ let rewrite_ast_realize_mappings effect_info env ast = let ast = { ast with defs = List.map rewrite_def ast.defs |> List.flatten } in (ast, !effect_info, env) -(* Rewrite to make all pattern matches in Coq output exhaustive. - Assumes that guards, vector patterns, etc have been rewritten already, - and the scattered functions have been merged. - Will add escape effect where a default is needed, so effects will - need recalculated afterwards. - - Also detects and removes redundant wildcard patterns at the end of the match. - (We could do more, but this is sufficient to deal with the code generated by - the mappings rewrites.) +(* Rewrite to make all pattern matches in Coq output exhaustive and + remove redundant clauses. Assumes that guards, vector patterns, + etc have been rewritten already, and the scattered functions have + been merged. Note: if this naive implementation turns out to be too slow or buggy, we could look at implementing Maranget JFP 17(3), 2007. @@ -4255,19 +4250,20 @@ module MakeExhaustive = struct let subpats rm_pats res_pats = (* Pointwise removal *) let res_pats' = List.map2 (remove_clause_from_pattern ctx) rm_pats res_pats in + let progress = List.exists snd res_pats' in (* Form the list of residual tuples by combining one position from the pointwise removal with the original residual of the other positions. *) let rec aux acc fixed residual = match (fixed, residual) with | [], [] -> [] - | fh :: ft, rh :: rt -> + | fh :: ft, (rh, _) :: rt -> (* ... so order matters here *) let rt' = aux (acc @ [fh]) ft rt in let newr = List.map (fun x -> acc @ (x :: ft)) rh in newr @ rt' | _, _ -> assert false (* impossible because we managed map2 above *) in - aux [] res_pats res_pats' + (aux [] res_pats res_pats', progress) in let inconsistent () = raise @@ -4280,13 +4276,13 @@ module MakeExhaustive = struct let _ = printprefix := " " ^ !printprefix in*) let rp' = match rm_pat with - | P_wild -> [] - | P_id id when match Env.lookup_id id ctx.env with Unbound _ | Local _ -> true | _ -> false -> [] + | P_wild -> ([], true) + | P_id id when match Env.lookup_id id ctx.env with Unbound _ | Local _ -> true | _ -> false -> ([], true) | P_lit lit -> ( match res_pat with - | RP_any -> List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit) - | RP_lit RL_inf -> [res_pat] - | RP_lit lit' -> if lit' = rlit_of_lit lit then [] else [res_pat] + | RP_any -> (List.map (fun l -> RP_lit l) (inv_rlit_of_lit lit), true) + | RP_lit RL_inf -> ([res_pat], true (* TODO: check for duplicates *)) + | RP_lit lit' -> if lit' = rlit_of_lit lit then ([], true) else ([res_pat], false) | _ -> inconsistent () ) | P_as (p, _) | P_typ (_, p) | P_var (p, _) -> remove_clause_from_pattern ctx p res_pat @@ -4294,26 +4290,29 @@ module MakeExhaustive = struct match Env.lookup_id id ctx.env with | Enum enum -> ( match res_pat with - | RP_any -> Bindings.find id ctx.enum_to_rest - | RP_enum id' -> if Id.compare id id' == 0 then [] else [res_pat] + | RP_any -> (Bindings.find id ctx.enum_to_rest, true) + | RP_enum id' -> if Id.compare id id' == 0 then ([], true) else ([res_pat], false) | _ -> inconsistent () ) | _ -> assert false ) | P_tuple rm_pats -> - let previous_res_pats = - match res_pat with - | RP_tuple res_pats -> res_pats - | RP_any -> List.map (fun _ -> RP_any) rm_pats - | _ -> inconsistent () - in - let res_pats' = subpats rm_pats previous_res_pats in - List.map (fun rps -> RP_tuple rps) res_pats' + if Util.list_empty rm_pats then ([], true) + else ( + let previous_res_pats = + match res_pat with + | RP_tuple res_pats -> res_pats + | RP_any -> List.map (fun _ -> RP_any) rm_pats + | _ -> inconsistent () + in + let res_pats', progress = subpats rm_pats previous_res_pats in + (List.map (fun rps -> RP_tuple rps) res_pats', progress) + ) | P_app (id, args) -> ( match res_pat with | RP_app (id', residual_args) -> if Id.compare id id' == 0 then ( - let res_pats' = + let res_pats', progress = (* Constructors that were specified without a return type might get an extra tuple in their type; expand that here if necessary. TODO: this should go away if we enforce proper arities. *) @@ -4321,12 +4320,12 @@ module MakeExhaustive = struct | [], [RP_any] | _ :: _ :: _, [RP_any] -> subpats args (List.map (fun _ -> RP_any) args) | _, _ -> subpats args residual_args in - List.map (fun rps -> RP_app (id, rps)) res_pats' + (List.map (fun rps -> RP_app (id, rps)) res_pats', progress) ) - else [res_pat] + else ([res_pat], false) | RP_any -> - let res_args = subpats args (List.map (fun _ -> RP_any) args) in - List.map (fun l -> RP_app (id, l)) res_args @ Bindings.find id ctx.constructor_to_rest + let res_args, progress = subpats args (List.map (fun _ -> RP_any) args) in + (List.map (fun l -> RP_app (id, l)) res_args @ Bindings.find id ctx.constructor_to_rest, progress) | _ -> inconsistent () ) | P_struct (field_pats, _) -> @@ -4355,16 +4354,16 @@ module MakeExhaustive = struct ) all_ids in - let res_pats' = subpats cur_pats res_pats in - List.map (fun rps -> RP_struct (List.combine all_ids rps)) res_pats' + let res_pats', progress = subpats cur_pats res_pats in + (List.map (fun rps -> RP_struct (List.combine all_ids rps)) res_pats', progress) | P_list ps -> ( match ps with | p1 :: ptl -> remove_clause_from_pattern ctx (P_aux (P_cons (p1, P_aux (P_list ptl, ann)), ann)) res_pat | [] -> ( match res_pat with - | RP_any -> [RP_cons (RP_any, RP_any)] - | RP_cons _ -> [res_pat] - | RP_nil -> [] + | RP_any -> ([RP_cons (RP_any, RP_any)], true) + | RP_cons _ -> ([res_pat], false) + | RP_nil -> ([], true) | _ -> inconsistent () ) ) @@ -4377,10 +4376,10 @@ module MakeExhaustive = struct | _ -> inconsistent () in match rps with - | None -> rp' + | None -> (rp', false) | Some rps -> - let res_pats = subpats [p1; p2] rps in - rp' @ List.map (function [rp1; rp2] -> RP_cons (rp1, rp2) | _ -> assert false) res_pats + let res_pats, progress = subpats [p1; p2] rps in + (rp' @ List.map (function [rp1; rp2] -> RP_cons (rp1, rp2) | _ -> assert false) res_pats, progress) end | P_or _ -> raise (Reporting.err_unreachable (fst ann) __POS__ "Or pattern not supported") | P_not _ -> raise (Reporting.err_unreachable (fst ann) __POS__ "Negated pattern not supported") @@ -4401,18 +4400,25 @@ module MakeExhaustive = struct (*let _ = print_endline ("res_pats: " ^ String.concat "; " (List.map string_of_rp rps)) in let _ = print_endline ("pat: " ^ string_of_pexp patexp) in*) match patexp with - | Pat_aux (Pat_exp (p, _), _) -> List.concat (List.map (remove_clause_from_pattern ctx p) rps) + | Pat_aux (Pat_exp (p, _), _) -> + let rps, progress = List.split (List.map (remove_clause_from_pattern ctx p) rps) in + (List.concat rps, List.exists (fun b -> b) progress) | Pat_aux (Pat_when _, (l, _)) -> raise (Reporting.err_unreachable l __POS__ "Guarded pattern should have been rewritten away") - (* We do some minimal redundancy checking to remove bogus wildcard patterns here *) let check_cases process is_wild loc_of cases = let rec aux rps acc = function | [] -> (acc, rps) | [p] when is_wild p && match rps with [] -> true | _ -> false -> let () = Reporting.print_err (loc_of p) "Match checking" "Redundant wildcard clause" in (acc, []) - | h :: t -> aux (process rps h) (h :: acc) t + | h :: t -> + let rps', progress = process rps h in + if progress then aux rps' (h :: acc) t + else begin + Reporting.print_err (loc_of h) "Match checking" "Redundant clause"; + aux rps' acc t + end in let cases, rps = aux [RP_any] [] cases in (List.rev cases, rps) @@ -4458,7 +4464,7 @@ module MakeExhaustive = struct | E_let (LB_aux (LB_val (pat, e1), lb_ann), e2) -> begin let env = env_of_annot ann in let ctx = ctx_from_env env in - let rps = remove_clause_from_pattern ctx pat RP_any in + let rps, _ = remove_clause_from_pattern ctx pat RP_any in match rps with | [] -> E_aux (e, ann) | example :: _ ->