Skip to content

Commit

Permalink
More optimization on exp fix operators
Browse files Browse the repository at this point in the history
  • Loading branch information
euisuny committed Jan 4, 2024
1 parent 1a34c40 commit 56b712f
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 41 deletions.
43 changes: 33 additions & 10 deletions middle_end/flambda2/validate/flambda2_core.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1496,18 +1496,41 @@ let static_const_group_fix (fix : core_exp -> core_exp)
Named (Static_consts (List.map (static_const_or_code_fix fix) e))
|> With_delayed_renaming.create

let prim_fix (fix : core_exp -> core_exp) (e : primitive) =
(match e with
| Nullary _ -> Named (Prim e)
| Unary (p, e) ->
Named (Prim (Unary (p, fix e)))
let prim_fix (fix : core_exp -> core_exp) (e : core_exp) =
let p =
(match must_be_prim e with
| Some p -> p
| None -> Misc.fatal_error "Expected primitive expr")
in
match p with
| Nullary _ -> e
| Unary (p, e1) ->
let e1' = fix e1 in
if e1 == e1' then e
else
Named (Prim (Unary (p, e1')))
|> With_delayed_renaming.create
| Binary (p, e1, e2) ->
Named (Prim (Binary (p, fix e1, fix e2)))
let e1' = fix e1 in
let e2' = fix e2 in
if e1 == e1' && e2 == e2' then e
else
Named (Prim (Binary (p, e1', e2')))
|> With_delayed_renaming.create
| Ternary (p, e1, e2, e3) ->
Named (Prim (Ternary (p, fix e1, fix e2, fix e3)))
let e1' = fix e1 in
let e2' = fix e2 in
let e3' = fix e3 in
if e1 == e1' && e2 == e2' && e3 == e3' then e
else
Named (Prim (Ternary (p, e1', e2', e3')))
|> With_delayed_renaming.create
| Variadic (p, list) ->
Named (Prim (Variadic (p, List.map fix list))))
|> With_delayed_renaming.create
let list' = Misc.Stdlib.List.map_sharing fix list in
if list == list' then e
else
Named (Prim (Variadic (p, list')))
|> With_delayed_renaming.create

let named_fix (fix : core_exp -> core_exp)
(f : 'a -> core_exp -> core_exp) arg (e : core_exp) =
Expand All @@ -1520,7 +1543,7 @@ let named_fix (fix : core_exp -> core_exp)
| Literal _ ->
let f' = f arg e in
if e == f' then e else f'
| Prim e -> prim_fix fix e
| Prim _ -> prim_fix fix e
| Closure_expr (phi, slot, clo) ->
let ({function_decls; value_slots} as e') = set_of_closures_fix fix clo in
if clo == e' then e
Expand Down
4 changes: 3 additions & 1 deletion middle_end/flambda2/validate/flambda2_core.mli
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ val named_fix :
val set_of_closures_fix :
(core_exp -> core_exp) ->
set_of_closures -> set_of_closures
val prim_fix : (core_exp -> core_exp) -> primitive -> core_exp
val prim_fix :
(core_exp -> core_exp) ->
core_exp -> core_exp
val static_const_group_fix :
(core_exp -> core_exp) ->
static_const_group -> core_exp
Expand Down
48 changes: 24 additions & 24 deletions middle_end/flambda2/validate/normalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ let rec subst_pattern ~(bound : Bound_for_let.t) ~(let_body : core_exp)
and subst_singleton_set_of_closures ~(bound: Variable.t)
~(clo : set_of_closures) (e : core_exp) : core_exp =
match descr e with
| Named e -> subst_singleton_set_of_closures_named ~bound ~clo e
| Named n -> subst_singleton_set_of_closures_named ~bound ~clo n e
| Let e ->
let_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Let_cont e ->
Expand All @@ -103,7 +103,7 @@ and subst_singleton_set_of_closures ~(bound: Variable.t)
switch_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Invalid _ -> e

and subst_singleton_set_of_closures_named ~bound ~clo (e : named) : core_exp =
and subst_singleton_set_of_closures_named ~bound ~clo (n : named) (e : core_exp) : core_exp =
let f bound (v : literal) =
(match v with
| Simple v ->
Expand All @@ -120,15 +120,16 @@ and subst_singleton_set_of_closures_named ~bound ~clo (e : named) : core_exp =
(let decls = SlotMap.bindings clo.function_decls in
let bound_closure = List.find_opt (fun (x, _) -> x = slot) decls in
(match bound_closure with
| None -> Expr.create_named e
| None -> e
| Some (k, _) -> Expr.create_named (Closure_expr (phi, k, clo))
))
| (Cont _ | Res_cont _ | Slot (_, Value_slot _) | Code_id _) ->
Expr.create_named (Literal v))
in
match e with
match n with
| Literal v -> f bound v
| Prim e -> prim_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Prim _ ->
prim_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Closure_expr (phi, slot, set) ->
let set =
set_of_closures_fix (subst_singleton_set_of_closures ~bound ~clo) set
Expand All @@ -141,7 +142,7 @@ and subst_singleton_set_of_closures_named ~bound ~clo (e : named) : core_exp =
Expr.create_named (Set_of_closures set)
| Static_consts group ->
static_const_group_fix (subst_singleton_set_of_closures ~bound ~clo) group
| Rec_info _ -> Expr.create_named e
| Rec_info _ -> e

and subst_static_list ~(bound : Bound_codelike.t) ~let_body e : core_exp =
let rec subst_static_list_ bound body e =
Expand Down Expand Up @@ -181,9 +182,9 @@ and subst_pattern_static
| Block_like bound ->
subst_block_like ~bound ~let_body named
| Set_of_closures set ->
subst_bound_set_of_closures set ~let_body named
subst_bound_set_of_closures set ~let_body named e
| Code id ->
subst_code_id id ~let_body named)
subst_code_id id ~let_body named e)
| Invalid _ -> e

(* [Set of closures]
Expand All @@ -193,8 +194,8 @@ and subst_pattern_static
[let f = closure f_0 @f] where [@f] is the function slot and [f_0] refers
to the code *)
and subst_bound_set_of_closures (bound : Bound_var.t) ~(let_body : static_const_or_code)
(e : named) =
match e with
(n : named) (e : core_exp) =
match n with
| Literal (Simple v) ->
(match let_body with
| Static_const const ->
Expand All @@ -203,10 +204,10 @@ and subst_bound_set_of_closures (bound : Bound_var.t) ~(let_body : static_const_
if Simple.same v (Simple.var (Bound_var.var bound)) then
Expr.create_named
(Static_consts [Static_const (Static_set_of_closures set)])
else Expr.create_named e
| None -> Expr.create_named e)
else e
| None -> e)
| (Deleted_code | Code _) -> Misc.fatal_error "Cannot be reached")
| Prim e ->
| Prim _ ->
prim_fix (subst_pattern_static
~bound:(Bound_codelike.Pattern.set_of_closures bound)
~let_body) e
Expand All @@ -227,13 +228,12 @@ and subst_bound_set_of_closures (bound : Bound_var.t) ~(let_body : static_const_
List.find_opt (fun (x, _) -> x = slot) bound
in
(match bound_closure with
| None -> Expr.create_named e
| None -> e
| Some (k, _) -> Expr.create_named (Closure_expr (phi, k, set)))
| None -> Misc.fatal_error "Cannot be reached")
| (Deleted_code | Code _) -> Misc.fatal_error "Cannot be reached")
| Literal (Res_cont _ | Code_id _ | Cont _ | Slot (_, Value_slot _))
| Closure_expr _ | Set_of_closures _ | Rec_info _ ->
Expr.create_named e
| Closure_expr _ | Set_of_closures _ | Rec_info _ -> e

and subst_code_id_set_of_closures (bound : Code_id.t) ~let_body
{function_decls; value_slots}
Expand All @@ -247,16 +247,16 @@ and subst_code_id_set_of_closures (bound : Code_id.t) ~let_body
in
{function_decls; value_slots}

and subst_code_id (bound : Code_id.t) ~(let_body : static_const_or_code) (e : named) : core_exp =
match e with
| Literal e ->
(match e with
and subst_code_id (bound : Code_id.t) ~(let_body : static_const_or_code) (n : named) (e : core_exp) : core_exp =
match n with
| Literal l ->
(match l with
| Code_id code_id ->
if (Code_id.compare code_id bound = 0)
then Expr.create_named (Static_consts [let_body])
else (Expr.create_named (Literal e))
| (Simple _ | Cont _ | Res_cont _ | Slot _) -> Expr.create_named (Literal e))
| Prim e ->
else e
| (Simple _ | Cont _ | Res_cont _ | Slot _) -> e)
| Prim _ ->
prim_fix
(subst_pattern_static
~bound:(Bound_codelike.Pattern.code bound) ~let_body) e
Expand All @@ -270,7 +270,7 @@ and subst_code_id (bound : Code_id.t) ~(let_body : static_const_or_code) (e : na
static_const_group_fix
(subst_pattern_static ~bound:(Bound_codelike.Pattern.code bound) ~let_body)
consts
| Rec_info _ -> Expr.create_named e
| Rec_info _ -> e

and subst_block_like
~(bound : Symbol.t) ~(let_body : static_const_or_code) (e : named) : core_exp =
Expand Down
12 changes: 6 additions & 6 deletions middle_end/flambda2/validate/translate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ and let_cont_to_core (e : Let_cont_expr.t) (sub : env) :
and subst_singleton_set_of_closures ~(bound: Variable.t)
~(clo : set_of_closures) (e : core_exp) : core_exp =
match descr e with
| Named e -> subst_singleton_set_of_closures_named ~bound ~clo e
| Named n -> subst_singleton_set_of_closures_named ~bound ~clo n e
| Let e ->
let_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Let_cont e ->
Expand All @@ -358,7 +358,7 @@ and subst_singleton_set_of_closures ~(bound: Variable.t)
switch_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Invalid _ -> e

and subst_singleton_set_of_closures_named ~bound ~clo (e : named) : core_exp =
and subst_singleton_set_of_closures_named ~bound ~clo (n : named) (e : core_exp) : core_exp =
let f bound (v : literal) =
(match v with
| Simple v ->
Expand All @@ -375,15 +375,15 @@ and subst_singleton_set_of_closures_named ~bound ~clo (e : named) : core_exp =
(let decls = SlotMap.bindings clo.function_decls in
let bound_closure = List.find_opt (fun (x, _) -> x = slot) decls in
(match bound_closure with
| None -> Expr.create_named e
| None -> e
| Some (k, _) -> Expr.create_named (Closure_expr (phi, k, clo))
))
| (Cont _ | Res_cont _ | Slot (_, Value_slot _) | Code_id _) ->
Expr.create_named (Literal v))
in
match e with
match n with
| Literal v -> f bound v
| Prim e -> prim_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Prim _ -> prim_fix (subst_singleton_set_of_closures ~bound ~clo) e
| Closure_expr (phi, slot, set) ->
let set =
set_of_closures_fix (subst_singleton_set_of_closures ~bound ~clo) set
Expand All @@ -396,7 +396,7 @@ and subst_singleton_set_of_closures_named ~bound ~clo (e : named) : core_exp =
Expr.create_named (Set_of_closures set)
| Static_consts group ->
static_const_group_fix (subst_singleton_set_of_closures ~bound ~clo) group
| Rec_info _ -> Expr.create_named e
| Rec_info _ -> e

and cont_handler_to_core
(e : Continuation_handler.t) (s : env)
Expand Down

0 comments on commit 56b712f

Please sign in to comment.