Skip to content

Commit

Permalink
Concrete instantiation of arguments when eliminated variable appears …
Browse files Browse the repository at this point in the history
…in subexpressions
  • Loading branch information
euisuny committed May 28, 2024
1 parent ef501aa commit f6dd8e5
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 31 deletions.
76 changes: 45 additions & 31 deletions middle_end/flambda2/validate/normalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ let literal_var_eq l v =

(* Check whether a parameter occurs in an expression other than in argument position *)
let rec parameter_occurs_only_as_call_arg (p : Bound_parameter.t) (e : core_exp) =
(match descr e with
match descr e with
| Named n -> parameter_occurs_only_as_call_arg_named p n
| Let e ->
Core_let.pattern_match e
Expand Down Expand Up @@ -111,7 +111,7 @@ let rec parameter_occurs_only_as_call_arg (p : Bound_parameter.t) (e : core_exp)
List.for_all (fun x ->
match must_be_literal x with
| Some _ -> true
| None -> parameter_occurs_only_as_call_arg p x) apply_args)
| None -> parameter_occurs_only_as_call_arg p x) apply_args

and parameter_occurs_only_as_call_arg_prim p (e : primitive) =
(match e with
Expand All @@ -130,8 +130,8 @@ and parameter_occurs_only_as_call_arg_prim p (e : primitive) =

and parameter_occurs_only_as_call_arg_named p (e : named) =
match e with
| Literal l ->
not (literal_var_eq l (Bound_parameter.var p))
| Literal _ -> true
(* not (literal_var_eq l (Bound_parameter.var p)) *)
| Prim e -> parameter_occurs_only_as_call_arg_prim p e
| Closure_expr (_, _, {function_decls; value_slots}) ->
SlotMap.for_all (fun _ a -> parameter_occurs_only_as_call_arg p a) function_decls &&
Expand Down Expand Up @@ -163,25 +163,33 @@ and parameter_occurs_only_as_call_arg_named p (e : named) =
| Empty_array | Mutable_string _ | Immutable_string _ ) -> true)) e
| Rec_info _ -> true

let rec eliminate (p : Bound_parameter.t) (e : core_exp) =
let rec eliminate (p : Bound_parameter.t) v (e : core_exp) =
match Expr.descr e with
| Named e ->
named_fix (eliminate p) (fun () x -> Expr.create_named (Literal x)) () e
named_fix (eliminate p v)
(fun () x ->
(* If there is a literal equivalent to the invariant parameter
in a subexpression, substitute in the concrete value of the
parameter. *)
if literal_var_eq x (Bound_parameter.var p) then v
else Expr.create_named (Literal x)) () e
| Let e ->
let_fix (eliminate p) e
let_fix (eliminate p v) e
| Let_cont e ->
let_cont_fix (eliminate p) e
let_cont_fix (eliminate p v) e
| Apply e ->
eliminate_apply p e
eliminate_apply p v e
| Apply_cont e ->
eliminate_apply_cont p e
| Lambda e -> lambda_fix (eliminate p) e
eliminate_apply_cont p v e
| Lambda e -> lambda_fix (eliminate p v) e
| Handler e ->
handler_fix (eliminate p) e
| Switch e -> switch_fix (eliminate p) e
handler_fix (eliminate p v) e
| Switch e -> switch_fix (eliminate p v) e
| Invalid _ -> e

and eliminate_apply (p : Bound_parameter.t) {callee; continuation; exn_continuation; region; apply_args} =
and eliminate_apply (p : Bound_parameter.t) v {callee; continuation; exn_continuation; region; apply_args} =
(* On the application arguments, remove any exact literal equivalent
to the invariant parameter. *)
let apply_args =
List.filter
(fun x ->
Expand All @@ -190,14 +198,20 @@ and eliminate_apply (p : Bound_parameter.t) {callee; continuation; exn_continuat
| None -> true
) apply_args
in
(* Instantiate any subexpression that contains argument to the concrete value. *)
let apply_args =
List.map
(fun x -> eliminate p v x)
apply_args
in
Expr.create_apply
{callee = eliminate p callee ;
continuation = eliminate p continuation;
exn_continuation = eliminate p exn_continuation;
region = eliminate p region;
{callee = eliminate p v callee ;
continuation = eliminate p v continuation;
exn_continuation = eliminate p v exn_continuation;
region = eliminate p v region;
apply_args}

and eliminate_apply_cont (p : Bound_parameter.t) {k; args} =
and eliminate_apply_cont (p : Bound_parameter.t) v {k; args} =
let args' =
List.filter
(fun x ->
Expand All @@ -207,12 +221,17 @@ and eliminate_apply_cont (p : Bound_parameter.t) {k; args} =
| None -> true
) args
in
let args' =
List.map
(fun x -> eliminate p v x)
args'
in
Expr.create_apply_cont
{k = eliminate p k; args = args'}
{k = eliminate p v k; args = args'}

let eliminate_arguments_rec_call (p : Bound_parameter.t) (e : core_exp) =
let eliminate_arguments_rec_call (p : Bound_parameter.t) v (e : core_exp) =
if parameter_occurs_only_as_call_arg p e then
(true, eliminate p e)
(true, eliminate p v e)
else
(false, e)

Expand Down Expand Up @@ -897,11 +916,12 @@ and step_apply_cont k args : core_exp =
List.combine (Bound_parameters.to_list params) args
in
(* Loop invariant argument reduction :
for each parameter, remove the argument if it only occurs
in argument position of the recursive call and nowhere else. *)
for each parameter, remove the argument if
it only occurs in argument position of the recursive call
and nowhere else. *)
let (params, e', args) =
List.fold_left (fun (l, e, args) (x, arg) ->
let (b, e) = eliminate_arguments_rec_call x e in
let (b, e) = eliminate_arguments_rec_call x arg e in
if b then (l, e, args) else (x :: l, e, arg :: args)
) ([], e', []) concrete_args
in
Expand Down Expand Up @@ -1129,16 +1149,10 @@ and step_set_of_closures var
in
(* Step the body *)
let params_and_body' = step params_and_body in
(* Format.printf "Before reduction : %a\n\n After reduction %a \n\n\n" *)
(* print params_and_body *)
(* print params_and_body'; *)
params_and_body'
| _ -> x))
function_decls)[@ocaml.warning "-4"]
in
(* Format.printf "Before reduction : %a\n\n After reduction %a \n\n\n" *)
(* print_set_of_closures {function_decls = function_decls; value_slots} *)
(* print_set_of_closures {function_decls = function_decls'; value_slots}; *)
{ function_decls = function_decls' ; value_slots }

(* Inline non-recursive continuation handlers first *)
Expand Down
98 changes: 98 additions & 0 deletions middle_end/flambda2/validate/test-validate/tests/list_fold.fl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
let $camlList_fold__first_const57 = Block 0 () in
let code rec loopify(default tailrec) size(27)
rev_append_0
(l1 : [ 0 | 0 of val * [ 0 | 0 of val * val ] ],
l2 : [ 0 | 0 of val * [ 0 | 0 of val * val ] ])
my_closure my_region my_depth
-> k * k1
: [ 0 | 0 of val * [ 0 | 0 of val * val ] ] =
let next_depth = rec_info (succ my_depth) in
let prim = %is_int l1 in
let Pisint = %Tag_imm prim in
(let untagged = %untag_imm Pisint in
switch untagged
| 0 -> k2
| 1 -> k (l2))
where k2 =
let Pfield = %block_load (l1, 0) in
let Pmakeblock = %Block 0 (Pfield, l2) in
let Pfield_1 = %block_load (l1, 1) in
apply direct(rev_append_0)
(my_closure : _ -> [ 0 | 0 of val * [ 0 | 0 of val * val ] ])
(Pfield_1, Pmakeblock)
&my_region
-> k * k1
in
let code size(5)
rev_1 (l : [ 0 | 0 of val * [ 0 | 0 of val * val ] ])
my_closure my_region my_depth
-> k * k1
: [ 0 | 0 of val * [ 0 | 0 of val * val ] ] =
let rev_append = %project_value_slot rev.rev_append my_closure in
apply direct(rev_append_0)
(rev_append : _ -> [ 0 | 0 of val * [ 0 | 0 of val * val ] ])
(l, 0)
&my_region
-> k * k1
in
let code size(36)
aux_3
(f,
accu,
l_accu : [ 0 | 0 of val * [ 0 | 0 of val * val ] ],
param : [ 0 | 0 of val * [ 0 | 0 of val * val ] ])
my_closure my_region my_depth
-> k * k1
: [ 0 of val * [ 0 | 0 of val * val ] ] =
let rev = %project_value_slot aux.rev my_closure in
let prim = %is_int param in
let Pisint = %Tag_imm prim in
(let untagged = %untag_imm Pisint in
switch untagged
| 0 -> k2
| 1 -> k3)
where k3 =
(apply direct(rev_1)
(rev : _ -> [ 0 | 0 of val * [ 0 | 0 of val * val ] ])
(l_accu)
&my_region
-> k3 * k1
where k3 (apply_result : [ 0 | 0 of val * [ 0 | 0 of val * val ] ]) =
let Pmakeblock = %Block 0 (accu, apply_result) in
cont k (Pmakeblock))
where k2 =
((let Pfield = %block_load (param, 0) in
apply f (accu, Pfield) &my_region -> k2 * k1)
where k2 (`*match*` : [ 0 of [ 0 of val * val ] * val ]) =
let Pfield = %block_load (`*match*`, 0) in
cont k (Pfield))
in
let code size(50)
fold_left_map_2
(f, accu, l : [ 0 | 0 of val * [ 0 | 0 of val * val ] ])
my_closure my_region my_depth
-> k * k1
: [ 0 of val * [ 0 | 0 of val * val ] ] =
let rev = %project_value_slot fold_left_map.rev_1 my_closure in
let aux = closure aux_3 @aux with { rev = rev } in
apply direct(aux_3)
(aux : _ -> [ 0 of val * [ 0 | 0 of val * val ] ])
(f, accu, 0, l)
&my_region
-> k * k1
in
(let append = %block_load ($Stdlib.camlStdlib, 36) in
let rev_append = closure rev_append_0 @rev_append in
let rev = closure rev_1 @rev with { rev_append = rev_append } in
let fold_left_map = closure fold_left_map_2 @fold_left_map
with { rev_1 = rev }
in
let Pmakeblock = %Block 0 (append, rev_append, rev, fold_left_map) in
cont k (Pmakeblock))
where k define_root_symbol (module_block) =
let field_0 = %block_load tag(0) size(4) (module_block, 0) in
let field_1 = %block_load tag(0) size(4) (module_block, 1) in
let field_2 = %block_load tag(0) size(4) (module_block, 2) in
let field_3 = %block_load tag(0) size(4) (module_block, 3) in
let $camlList_fold = Block 0 (field_0, field_1, field_2, field_3) in
cont done ($camlList_fold)

0 comments on commit f6dd8e5

Please sign in to comment.