From a86cabf7d3d38dd95283fc06bcdddfc5f9433b16 Mon Sep 17 00:00:00 2001 From: Chris Casinghino Date: Fri, 12 Jan 2024 10:04:49 +0000 Subject: [PATCH] Turn back on switch optimization (And formatting) --- middle_end/flambda2/flambda2.ml | 58 +++++++++++------------ middle_end/flambda2/validate/normalize.ml | 14 +++--- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/middle_end/flambda2/flambda2.ml b/middle_end/flambda2/flambda2.ml index 594d223d524..dfb9c275da6 100644 --- a/middle_end/flambda2/flambda2.ml +++ b/middle_end/flambda2/flambda2.ml @@ -97,33 +97,35 @@ let print_flexpect name main_dump_ppf ~raw_flambda:old_unit new_unit = ~header:("Before and after " ^ name) ~f:pp_flambda_as_flexpect (old_unit, new_unit) -let dump_validator_files filename src_init src_normalized res_init res_normalized = +let dump_validator_files filename src_init src_normalized res_init + res_normalized = Misc.protect_writing_to_file ~filename ~f:(fun out -> - let ppf = Format.formatter_of_out_channel out in - Format.fprintf ppf - "\n\n\n------------------ Translated Original ------------------\n\n"; - Flambda2_core.print ppf src_init; - Format.fprintf ppf - "\n\n\n------------------ Normalized Original ------------------\n\n"; - Flambda2_core.print ppf src_normalized; - Format.fprintf ppf - "\n\n\n------------------ Translated Result ------------------\n\n"; - Flambda2_core.print ppf res_init; - Format.fprintf ppf - "\n\n\n------------------ Normalized Result ------------------\n\n"; - Flambda2_core.print ppf res_normalized) + let ppf = Format.formatter_of_out_channel out in + Format.fprintf ppf + "\n\n\n------------------ Translated Original ------------------\n\n"; + Flambda2_core.print ppf src_init; + Format.fprintf ppf + "\n\n\n------------------ Normalized Original ------------------\n\n"; + Flambda2_core.print ppf src_normalized; + Format.fprintf ppf + "\n\n\n------------------ Translated Result ------------------\n\n"; + Flambda2_core.print ppf res_init; + Format.fprintf ppf + "\n\n\n------------------ Normalized Result ------------------\n\n"; + Flambda2_core.print ppf res_normalized; + Format.fprintf ppf "\n%!") let validate filename (src : Flambda_unit.t) (res : Flambda_unit.t) = let src_core = - Profile.record ~accumulate:true "translate_src" Translate.flambda_unit_to_core - src + Profile.record ~accumulate:true "translate_src" + Translate.flambda_unit_to_core src in let src_core' = Profile.record ~accumulate:true "normalize_src" Normalize.normalize src_core in let res_core = - Profile.record ~accumulate:true "translate_res" Translate.flambda_unit_to_core - res + Profile.record ~accumulate:true "translate_res" + Translate.flambda_unit_to_core res in let res_core' = Profile.record ~accumulate:true "normalize_res" Normalize.normalize res_core @@ -131,10 +133,9 @@ let validate filename (src : Flambda_unit.t) (res : Flambda_unit.t) = let validated = Profile.record ~accumulate:true "equiv" (Equiv.core_eq src_core') res_core' in - begin match !Flambda_backend_flags.validate_debug with + (match !Flambda_backend_flags.validate_debug with | None -> () - | Some file -> dump_validator_files file src_core src_core' res_core res_core' - end; + | Some file -> dump_validator_files file src_core src_core' res_core res_core'); if validated then Format.eprintf "fλ2: %s PASS@." filename else Format.eprintf "fλ2: %s FAIL@." filename @@ -196,16 +197,13 @@ let lambda_to_cmm ~ppf_dump:ppf ~prefixname ~filename ~keep_symbol_tables Simplify.run ~cmx_loader ~round raw_flambda) in (* Run the validator *) - (if !Flambda_backend_flags.validate - then begin - Normalize.comp_unit := compilation_unit; - (try - Profile.record_call ~accumulate:true "validate" (fun () -> + if !Flambda_backend_flags.validate + then ( + Normalize.comp_unit := compilation_unit; + try + Profile.record_call ~accumulate:true "validate" (fun () -> validate filename raw_flambda flambda) - with - | _ -> Format.eprintf "fλ2: %s FAIL [ERROR]@." filename - ) - end); + with _ -> Format.eprintf "fλ2: %s FAIL [ERROR]@." filename); (if Flambda_features.inlining_report () then let output_prefix = Printf.sprintf "%s.%d" prefixname round in diff --git a/middle_end/flambda2/validate/normalize.ml b/middle_end/flambda2/validate/normalize.ml index 12c7687a93c..fc16c878973 100644 --- a/middle_end/flambda2/validate/normalize.ml +++ b/middle_end/flambda2/validate/normalize.ml @@ -403,14 +403,12 @@ and step_handler (e : continuation_handler) = and step_switch scrutinee arms : core_exp = let default = (* if the arms are all the same, collapse them to a single arm *) - Expr.create_switch {scrutinee; arms} - (* let bindings = Targetint_31_63.Map.bindings arms in - * let (_, hd) = List.hd bindings in - * Equiv.debug := false; - * if (List.for_all (fun (_, x) -> Equiv.core_eq hd x) bindings) - * then (Equiv.debug := false; hd) - * else (Equiv.debug := false; - * Expr.create_switch {scrutinee; arms})) *) + (* Expr.create_switch {scrutinee; arms} *) + let bindings = Targetint_31_63.Map.bindings arms in + let (_, hd) = List.hd bindings in + if (List.for_all (fun (_, x) -> Equiv.core_eq hd x) bindings) + then hd + else Expr.create_switch {scrutinee; arms} in (* if the scrutinee is exactly one of the arms, simplify *) match must_be_simple_or_immediate scrutinee with