From 1a34c40f771c4d02b4c74d13676966b23d6e2b74 Mon Sep 17 00:00:00 2001 From: Irene Yoon Date: Thu, 4 Jan 2024 17:54:34 +0100 Subject: [PATCH] Modifying core_fmap for more efficient core exp manipulation --- middle_end/flambda2/validate/flambda2_core.ml | 38 ++++++----- .../flambda2/validate/flambda2_core.mli | 6 +- middle_end/flambda2/validate/normalize.ml | 65 +++++++++---------- middle_end/flambda2/validate/translate.ml | 17 ++--- 4 files changed, 62 insertions(+), 64 deletions(-) diff --git a/middle_end/flambda2/validate/flambda2_core.ml b/middle_end/flambda2/validate/flambda2_core.ml index 490dabba485..3142f502df5 100644 --- a/middle_end/flambda2/validate/flambda2_core.ml +++ b/middle_end/flambda2/validate/flambda2_core.ml @@ -1510,30 +1510,38 @@ let prim_fix (fix : core_exp -> core_exp) (e : primitive) = |> With_delayed_renaming.create let named_fix (fix : core_exp -> core_exp) - (f : 'a -> literal -> core_exp) arg (e : named) = - match e with - | Literal l -> f arg l + (f : 'a -> core_exp -> core_exp) arg (e : core_exp) = + let n = + (match must_be_named e with + | Some n -> n + | None -> Misc.fatal_error "Expected name expr") + in + match n with + | Literal _ -> + let f' = f arg e in + if e == f' then e else f' | Prim e -> prim_fix fix e | Closure_expr (phi, slot, clo) -> - let {function_decls; value_slots} = set_of_closures_fix fix clo in - Named (Closure_expr (phi, slot, {function_decls; value_slots})) - |> With_delayed_renaming.create + let ({function_decls; value_slots} as e') = set_of_closures_fix fix clo in + if clo == e' then e + else + Named (Closure_expr (phi, slot, {function_decls; value_slots})) + |> With_delayed_renaming.create | Set_of_closures clo -> - let {function_decls; value_slots} = set_of_closures_fix fix clo in - Named (Set_of_closures {function_decls; value_slots}) - |> With_delayed_renaming.create + let ({function_decls; value_slots} as e') = set_of_closures_fix fix clo in + if clo == e' then e + else + Named (Set_of_closures {function_decls; value_slots}) + |> With_delayed_renaming.create | Static_consts group -> static_const_group_fix fix group - | Rec_info _ -> - Named e - |> With_delayed_renaming.create + | Rec_info _ -> e (* LATER: Make this first order? *) let rec core_fmap - (f : 'a -> literal -> core_exp) (arg : 'a) (e : core_exp) : core_exp = + (f : 'a -> core_exp -> core_exp) (arg : 'a) (e : core_exp) : core_exp = match descr e with - | Named e -> - named_fix (core_fmap f arg) f arg e + | Named _ -> named_fix (core_fmap f arg) f arg e | Let e -> let_fix (core_fmap f arg) e | Let_cont e -> let_cont_fix (core_fmap f arg) e | Apply e -> apply_fix (core_fmap f arg) e diff --git a/middle_end/flambda2/validate/flambda2_core.mli b/middle_end/flambda2/validate/flambda2_core.mli index cffed3cffcb..db39deba105 100644 --- a/middle_end/flambda2/validate/flambda2_core.mli +++ b/middle_end/flambda2/validate/flambda2_core.mli @@ -304,7 +304,7 @@ val apply_renaming : core_exp -> Renaming.t -> core_exp val lambda_to_handler : lambda_expr -> continuation_handler -val core_fmap : ('a -> literal -> core_exp) -> 'a -> core_exp -> core_exp +val core_fmap : ('a -> core_exp -> core_exp) -> 'a -> core_exp -> core_exp (* Fixpoint functions for core expressions *) val let_fix : (core_exp -> core_exp) -> let_expr -> core_exp @@ -317,8 +317,8 @@ val switch_fix : (core_exp -> core_exp) -> switch_expr -> core_exp val named_fix : (core_exp -> core_exp) -> - ('a -> literal -> core_exp) -> - 'a -> named -> core_exp + ('a -> core_exp -> core_exp) -> + 'a -> core_exp -> core_exp val set_of_closures_fix : (core_exp -> core_exp) -> set_of_closures -> set_of_closures diff --git a/middle_end/flambda2/validate/normalize.ml b/middle_end/flambda2/validate/normalize.ml index 12c7687a93c..3ff01836237 100644 --- a/middle_end/flambda2/validate/normalize.ml +++ b/middle_end/flambda2/validate/normalize.ml @@ -72,14 +72,13 @@ let rec subst_pattern ~(bound : Bound_for_let.t) ~(let_body : core_exp) | None -> core_fmap (fun (bound, let_body) s -> - match s with - | Simple s -> + match must_be_simple s with + | Some s -> let bound = Simple.var (Bound_var.var bound) in if (Simple.equal s bound) then let_body else Expr.create_named (Literal (Simple s)) - | (Cont _ | Res_cont _ | Slot _ | Code_id _) -> - Expr.create_named (Literal s)) + | None -> s) (bound, let_body) e) | Static bound -> subst_static_list ~bound ~let_body e @@ -277,14 +276,12 @@ and subst_block_like ~(bound : Symbol.t) ~(let_body : static_const_or_code) (e : named) : core_exp = core_fmap (fun _ v -> - match v with - | Simple v -> - if Simple.equal v (Simple.symbol bound) + match must_be_simple v with + | Some s -> + if Simple.equal s (Simple.symbol bound) then Expr.create_named (Static_consts [let_body]) - else Expr.create_named (Literal (Simple v)) - | (Cont _ | Res_cont _ | Slot _ | Code_id _) -> - Expr.create_named (Literal v)) - () (Expr.create_named e) + else v + | _ -> v) () (Expr.create_named e) let partial_combine l1 l2 = let rec partial_combine (l1 : 'a list) (l2 : 'b list) acc @@ -317,15 +314,13 @@ let subst_params let param_args = List.map (fun (x, y) -> (Bound_parameter.simple x, y)) param_args in core_fmap - (fun () s -> - match s with - | Simple s -> + (fun () v -> + match must_be_simple v with + | Some s -> (match List.assoc_opt s param_args with | Some arg_v -> arg_v - | None -> Expr.create_named (Literal (Simple s))) - | (Cont _ | Res_cont _ | Slot _ | Code_id _) -> - Expr.create_named (Literal s)) - () e + | None -> v) + | None -> v) () e (* [LetCont-β] *) let rec subst_cont (cont_e1: core_exp) (k: Bound_continuation.t) @@ -572,20 +567,19 @@ and step_apply_lambda lambda_expr continuation exn_continuation region apply_arg let params = bound.params in let exp = core_fmap - (fun _ l -> - match l with - | (Cont k | Res_cont (Return k)) -> + (fun _ v -> + match must_be_cont v with + | Some k -> if Continuation.equal k bound.return_continuation then continuation else if Continuation.equal k bound.exn_continuation then exn_continuation - else (Expr.create_named (Literal l)) - | Simple s -> - if (Simple.same (Simple.var bound.my_region) s) - then region - else (Expr.create_named (Literal l)) - | (Res_cont Never_returns | Slot _ | Code_id _) -> - Expr.create_named (Literal l) + else v + | None -> + (match must_be_simple v with + | Some s -> if (Simple.same (Simple.var bound.my_region) s) + then region else v + | None -> v) ) () exp in subst_params params exp apply_args) @@ -805,15 +799,14 @@ and concretize_my_closure phi (slot : Function_slot.t) the closure [phi] variable. *) let body = core_fmap - (fun _ s -> - match s with - | Simple simple -> + (fun _ v -> + match must_be_simple v with + | Some simple -> if (Simple.same (Simple.var (Bound_var.var bff)) simple) then Expr.create_named (Literal (Slot (phi, Function_slot slot))) - else (Expr.create_named (Literal s)) - | (Cont _ | Res_cont _ | Slot _ | Code_id _) -> - Expr.create_named (Literal s)) + else v + | _ -> v) () body in Core_lambda.create bound (subst_my_closure_body phi clo body)))) @@ -841,8 +834,8 @@ and step_set_of_closures var (* Inline non-recursive continuation handlers first *) let rec inline_handlers (e : core_exp) = match Expr.descr e with - | Named e -> - named_fix inline_handlers (fun () x -> Expr.create_named (Literal x)) () e + | Named _ -> + named_fix inline_handlers (fun () x -> x) () e | Let e -> let_fix inline_handlers e | Let_cont e -> inline_let_cont e diff --git a/middle_end/flambda2/validate/translate.ml b/middle_end/flambda2/validate/translate.ml index 208082ad09a..b62abf3b306 100644 --- a/middle_end/flambda2/validate/translate.ml +++ b/middle_end/flambda2/validate/translate.ml @@ -28,12 +28,12 @@ let tagged_immediate_to_core (e : Targetint_31_63.t) : core_exp = let apply_subst (s : substitutions) (e : core_exp) : core_exp = core_fmap (fun () v -> - match v with - | Simple v -> + match must_be_simple v with + | Some v -> (match Sub.find_opt v s with | Some exp -> exp | None -> Expr.create_named (Literal (Simple v))) - | (Cont _ | Res_cont _ | Slot _ | Code_id _) -> Expr.create_named (Literal v)) + | None -> v) () e let subst_var_slot @@ -253,13 +253,10 @@ and function_params_and_body_to_core s and subst_cont_id (cont : Continuation.t) (e1 : core_exp) (e2 : core_exp) : core_exp = core_fmap (fun _ x -> - match x with - | (Cont k | Res_cont (Return k)) -> - if Continuation.equal cont k - then e1 - else Expr.create_named (Literal x) - | (Simple _ | Res_cont Never_returns | Slot _ | Code_id _) -> - Expr.create_named (Literal x)) () e2 + match must_be_cont x with + | Some k -> + if Continuation.equal cont k then e1 else x + | _ -> x) () e2 and handler_map_to_closures (phi : Variable.t) (v : Bound_parameter.t list) (m : continuation_handler_map) : set_of_closures =