Skip to content

Commit

Permalink
== optimization in flambda2_core
Browse files Browse the repository at this point in the history
  • Loading branch information
euisuny committed Jan 5, 2024
1 parent 56b712f commit e81177c
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 182 deletions.
252 changes: 167 additions & 85 deletions middle_end/flambda2/validate/flambda2_core.ml
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,6 @@ and apply_renaming_function_params_and_body ({expr; anon} as t) renaming =
in
if expr == expr' then t else { expr = expr'; anon = anon }


(* renaming for [Let_cont] *)
and apply_renaming_let_cont ({handler; body} as t) renaming : let_cont_expr =
let handler' =
Expand Down Expand Up @@ -505,6 +504,24 @@ and apply_renaming_switch ({scrutinee; arms} as t) renaming : switch_expr =
then t
else { scrutinee = scrutinee'; arms = arms' }

let must_be_let (e : core_exp) : let_expr option =
match descr e with
| Let n -> Some n
| (Named _ | Let_cont _ | Apply _ | Apply_cont _ | Lambda _ | Switch _
| Handler _ | Invalid _) -> None

let must_be_letcont (e : core_exp) : let_cont_expr option =
match descr e with
| Let_cont n -> Some n
| (Named _ | Let _ | Apply _ | Apply_cont _ | Lambda _ | Switch _
| Handler _ | Invalid _) -> None

let must_be_switch (e : core_exp) : switch_expr option =
match descr e with
| Switch n -> Some n
| (Named _ | Let _ | Let_cont _ | Apply _ | Apply_cont _ | Lambda _
| Handler _ | Invalid _) -> None

let must_be_named (e : core_exp) : named option =
match descr e with
| Named n -> Some n
Expand Down Expand Up @@ -560,6 +577,12 @@ let must_be_apply (e : core_exp) : apply_expr option =
| (Named _ | Let _ | Let_cont _ | Lambda _ | Apply_cont _ | Switch _
| Handler _ | Invalid _) -> None

let must_be_apply_cont (e : core_exp) : apply_cont_expr option =
match descr e with
| Apply_cont e -> Some e
| (Named _ | Let _ | Let_cont _ | Lambda _ | Apply _ | Switch _
| Handler _ | Invalid _) -> None

let must_be_static_consts (e : core_exp) : static_const_group option =
match must_be_named e with
| Some (Static_consts g) -> Some g
Expand Down Expand Up @@ -1397,104 +1420,165 @@ let lambda_to_handler (e : lambda_expr) : continuation_handler =
Core_continuation_handler.create params e)

(* Fixpoint combinator for core expressions *)
let let_fix (f : core_exp -> core_exp) {let_abst; expr_body} =
Core_let.pattern_match {let_abst; expr_body}
~f:(fun ~x ~e1 ~e2 ->
(Core_let.create
~x
~e1:(f e1)
~e2:(f e2)))

let let_cont_fix (f : core_exp -> core_exp) ({handler; body} : let_cont_expr) =
let handler =
Core_continuation_handler.pattern_match handler
(fun param exp ->
Core_continuation_handler.create param (f exp))
let let_fix (f : core_exp -> core_exp) e =
let _let_fix ~x ~e1 ~e2 =
let e1' = f e1 in
let e2' = f e2 in
if e1 == e1' && e2 == e2' then e
else Core_let.create ~x ~e1:e1' ~e2:e2'
in
let body =
Core_letcont_body.pattern_match body
(fun cont exp ->
Core_letcont_body.create cont (f exp))
let {let_abst; expr_body} =
match must_be_let e with
| Some e -> e
| None -> Misc.fatal_error "Expected let expr"
in
Core_let.pattern_match {let_abst; expr_body} ~f:_let_fix

let let_cont_fix (f : core_exp -> core_exp) e =
let {handler; body} =
match must_be_letcont e with
| Some e -> e
| None -> Misc.fatal_error "Expected let expr"
in
let handler_fix param exp =
let exp' = f exp in
if exp == exp' then handler
else Core_continuation_handler.create param exp'
in
let body_fix cont exp =
let exp' = f exp in
if exp == exp' then body
else Core_letcont_body.create cont exp'
in
let handler = Core_continuation_handler.pattern_match handler handler_fix in
let body = Core_letcont_body.pattern_match body body_fix in
Core_letcont.create handler ~body

let handler_fix (f : core_exp -> core_exp)
(handler : continuation_handler) =
(Core_continuation_handler.pattern_match handler
(fun param exp -> Core_continuation_handler.create param (f exp)))
|> Expr.create_handler
let handler_fix (f : core_exp -> core_exp) e =
let handler =
match must_be_handler e with
| Some e -> e
| None -> Misc.fatal_error "Expected handler"
in
let _handler_fix param exp =
let exp' = f exp in
if exp == exp' then e
else Core_continuation_handler.create param (f exp)
|> Expr.create_handler
in
Core_continuation_handler.pattern_match handler _handler_fix

let apply_fix (f : core_exp -> core_exp)
({callee; continuation; exn_continuation; region; apply_args} : apply_expr) =
Apply
{callee = f callee;
continuation = f continuation;
exn_continuation = f exn_continuation;
region = f region;
apply_args = List.map f apply_args;}
|> With_delayed_renaming.create

let apply_cont_fix (f : core_exp -> core_exp)
({k; args} : apply_cont_expr) =
Expr.create_apply_cont
{k = f k;
args = List.map f args}

let lambda_fix (f : core_exp -> core_exp) (e : lambda_expr) =
Core_lambda.pattern_match e
~f:(fun b e ->
(Core_lambda.create b (f e)))
|> Expr.create_lambda
let apply_fix (f : core_exp -> core_exp) e =
let {callee; continuation; exn_continuation; region; apply_args} =
match must_be_apply e with
| Some e -> e
| None -> Misc.fatal_error "Expected apply expr"
in
let callee' = f callee in
let continuation' = f continuation in
let exn_continuation' = f exn_continuation in
let region' = f region in
let apply_args' = Misc.Stdlib.List.map_sharing f apply_args in
if callee == callee'
&& continuation == continuation'
&& exn_continuation == exn_continuation'
&& region = region'
&& apply_args = apply_args' then e
else
Apply
{callee = callee';
continuation = continuation';
exn_continuation = exn_continuation';
region = region';
apply_args = apply_args'}
|> With_delayed_renaming.create

let apply_cont_fix (f : core_exp -> core_exp) e =
let {k; args} =
match must_be_apply_cont e with
| Some e -> e
| None -> Misc.fatal_error "Expected apply cont expr"
in
let k' = f k in
let args' = Misc.Stdlib.List.map_sharing f args in
if k == k' && args == args' then e
else Expr.create_apply_cont {k = k'; args = args'}

let lambda_fix (f : core_exp -> core_exp) e =
let _lambda_fix b x =
let x' = f x in
if x == x' then e else Core_lambda.create b x' |> Expr.create_lambda
in
let e =
match must_be_lambda e with
| Some e -> e
| None -> Misc.fatal_error "Expected lambda expr"
in
Core_lambda.pattern_match e ~f:_lambda_fix

let switch_fix (f : core_exp -> core_exp)
({scrutinee; arms} : switch_expr) =
{scrutinee = f scrutinee;
arms = Targetint_31_63.Map.map f arms}
|> Expr.create_switch
let switch_fix (f : core_exp -> core_exp) e =
let {scrutinee; arms} =
match must_be_switch e with
| Some e -> e
| None -> Misc.fatal_error "Expected switch expr"
in
let scrutinee' = f scrutinee in
let arms' = Targetint_31_63.Map.map f arms in
if scrutinee == scrutinee' && arms == arms' then e
else { scrutinee = scrutinee'; arms = arms'} |> Expr.create_switch

let set_of_closures_fix
(fix : core_exp -> core_exp) {function_decls; value_slots} =
let function_decls = SlotMap.map fix function_decls in
let value_slots =
Value_slot.Map.map (fun x -> fix x) value_slots
in
{function_decls; value_slots}
(fix : core_exp -> core_exp) ({function_decls; value_slots} as t) =
let function_decls' = SlotMap.map fix function_decls in
let value_slots' = Value_slot.Map.map fix value_slots in
if function_decls == function_decls' && value_slots == value_slots'
then t else {function_decls = function_decls'; value_slots = value_slots'}

let static_const_fix (fix : core_exp -> core_exp) (e : static_const) =
match e with
| Static_set_of_closures clo ->
let {function_decls; value_slots} = set_of_closures_fix fix clo in
Static_set_of_closures {function_decls; value_slots}
let t = set_of_closures_fix fix clo in
if Static_set_of_closures t == e then e else Static_set_of_closures t
| Block (tag, mut, list) ->
let list = List.map fix list in
Block (tag, mut, list)
let list' = Misc.Stdlib.List.map_sharing fix list in
if list == list' then e else Block (tag, mut, list')
| ( Boxed_float _ | Boxed_int32 _ | Boxed_int64 _ | Boxed_nativeint _
| Immutable_float_block _ | Immutable_float_array _ | Immutable_value_array _
| Empty_array | Mutable_string _ | Immutable_string _ ) -> e

let static_const_or_code_fix (fix : core_exp -> core_exp)
(e : static_const_or_code) =
let code_lambda_fix (anon : bool) params bound body =
let body' = fix body in
if body == body' then e
else
let expr =
body'
|> Core_lambda.create bound
|> Core_function_params_and_body.create params
in Code {expr; anon}
in
let code_fix anon params body =
Core_lambda.pattern_match body ~f:(code_lambda_fix anon params)
in
(match e with
| Code {expr; anon}->
Code
{expr =
Core_function_params_and_body.pattern_match expr
~f:(fun
params body ->
Core_function_params_and_body.create
params
(Core_lambda.pattern_match body
~f:(fun bound body ->
Core_lambda.create bound (fix body))));
anon}
| Code {expr; anon} ->
Core_function_params_and_body.pattern_match expr ~f:(code_fix anon)
| Deleted_code -> e
| Static_const const ->
Static_const (static_const_fix fix const))

let static_const_group_fix (fix : core_exp -> core_exp)
(e : static_const_group) =
Named (Static_consts (List.map (static_const_or_code_fix fix) e))
|> With_delayed_renaming.create
let static_const_group_fix (fix : core_exp -> core_exp) e =
let g =
match must_be_static_consts e with
| Some e -> e
| None -> Misc.fatal_error "Expected static const group"
in
let g' =
Misc.Stdlib.List.map_sharing (static_const_or_code_fix fix) g
in
if g == g' then e
else Named (Static_consts g') |> With_delayed_renaming.create

let prim_fix (fix : core_exp -> core_exp) (e : core_exp) =
let p =
Expand Down Expand Up @@ -1556,22 +1640,20 @@ let named_fix (fix : core_exp -> core_exp)
else
Named (Set_of_closures {function_decls; value_slots})
|> With_delayed_renaming.create
| Static_consts group ->
static_const_group_fix fix group
| Static_consts _-> static_const_group_fix fix e
| Rec_info _ -> e

(* LATER: Make this first order? *)
let rec core_fmap
(f : 'a -> core_exp -> core_exp) (arg : 'a) (e : core_exp) : core_exp =
match descr e with
| Named _ -> named_fix (core_fmap f arg) f arg e
| Let e -> let_fix (core_fmap f arg) e
| Let_cont e -> let_cont_fix (core_fmap f arg) e
| Apply e -> apply_fix (core_fmap f arg) e
| Apply_cont e -> apply_cont_fix (core_fmap f arg) e
| Lambda e -> lambda_fix (core_fmap f arg) e
| Handler e -> handler_fix (core_fmap f arg) e
| Switch e -> switch_fix (core_fmap f arg) e
| Let _ -> let_fix (core_fmap f arg) e
| Let_cont _ -> let_cont_fix (core_fmap f arg) e
| Apply _ -> apply_fix (core_fmap f arg) e
| Apply_cont _ -> apply_cont_fix (core_fmap f arg) e
| Lambda _ -> lambda_fix (core_fmap f arg) e
| Handler _ -> handler_fix (core_fmap f arg) e
| Switch _ -> switch_fix (core_fmap f arg) e
| Invalid _ -> e

let literal_contained (literal1 : literal) (literal2 : literal) : bool =
Expand Down
34 changes: 14 additions & 20 deletions middle_end/flambda2/validate/flambda2_core.mli
Original file line number Diff line number Diff line change
Expand Up @@ -306,28 +306,22 @@ val lambda_to_handler : lambda_expr -> continuation_handler

val core_fmap : ('a -> core_exp -> core_exp) -> 'a -> core_exp -> core_exp

(* Fixpoint functions for core expressions *)
val let_fix : (core_exp -> core_exp) -> let_expr -> core_exp
val let_cont_fix : (core_exp -> core_exp) -> let_cont_expr -> core_exp
val apply_fix : (core_exp -> core_exp) -> apply_expr -> core_exp
val apply_cont_fix : (core_exp -> core_exp) -> apply_cont_expr -> core_exp
val lambda_fix : (core_exp -> core_exp) -> lambda_expr -> core_exp
val handler_fix : (core_exp -> core_exp) -> continuation_handler -> core_exp
val switch_fix : (core_exp -> core_exp) -> switch_expr -> core_exp

val named_fix :
(core_exp -> core_exp) ->
('a -> core_exp -> core_exp) ->
(* Fixpoint functions for core expressions.
The signature expects a `core_exp` for performance *)
val let_fix : (core_exp -> core_exp) -> core_exp -> core_exp
val let_cont_fix : (core_exp -> core_exp) -> core_exp -> core_exp
val apply_fix : (core_exp -> core_exp) -> core_exp -> core_exp
val apply_cont_fix : (core_exp -> core_exp) -> core_exp -> core_exp
val lambda_fix : (core_exp -> core_exp) -> core_exp -> core_exp
val handler_fix : (core_exp -> core_exp) -> core_exp -> core_exp
val switch_fix : (core_exp -> core_exp) -> core_exp -> core_exp

val named_fix : (core_exp -> core_exp) -> ('a -> core_exp -> core_exp) ->
'a -> core_exp -> core_exp
val set_of_closures_fix :
(core_exp -> core_exp) ->
set_of_closures -> set_of_closures
val prim_fix :
(core_exp -> core_exp) ->
core_exp -> core_exp
val static_const_group_fix :
(core_exp -> core_exp) ->
static_const_group -> core_exp
(core_exp -> core_exp) -> set_of_closures -> set_of_closures
val prim_fix : (core_exp -> core_exp) -> core_exp -> core_exp
val static_const_group_fix : (core_exp -> core_exp) -> core_exp -> core_exp

val literal_contained : literal -> literal -> bool

Expand Down
Loading

0 comments on commit e81177c

Please sign in to comment.