From fa52c963822365f5a37963747d483f74b34197c9 Mon Sep 17 00:00:00 2001 From: Zhicheng HUI Date: Mon, 26 Aug 2024 15:27:09 +0200 Subject: [PATCH] declare instruction-level block_type in rewriting process --- src/ast/binary_encoder.ml | 2 +- src/text_to_binary/grouped.ml | 28 ++++++++-- src/text_to_binary/rewrite.ml | 101 +++++++++++++++------------------- test/opt/if.t | 2 + test/wat2wasm/loop.t | 2 + 5 files changed, 72 insertions(+), 63 deletions(-) diff --git a/src/ast/binary_encoder.ml b/src/ast/binary_encoder.ml index c445c3520..78e82a27e 100644 --- a/src/ast/binary_encoder.ml +++ b/src/ast/binary_encoder.ml @@ -117,7 +117,7 @@ let write_block_type buf (typ : binary block_type option) = match typ with | None | Some (Bt_raw (None, ([], []))) -> Buffer.add_char buf '\x40' | Some (Bt_raw (None, ([], [ vt ]))) -> write_valtype buf vt - | Some (Bt_raw (None, (pt, _))) -> write_paramtype buf pt + | Some (Bt_raw (Some idx, _)) -> write_indice buf idx (* TODO: memo will this pattern matching be enough with the use of the new modul.types field? *) diff --git a/src/text_to_binary/grouped.ml b/src/text_to_binary/grouped.ml index 6f77fd367..afe6046bf 100644 --- a/src/text_to_binary/grouped.ml +++ b/src/text_to_binary/grouped.ml @@ -77,7 +77,7 @@ let init_curr () = ; data = ref 0 } -let declare_func_type type_f (fields : t) = +let declare_func_type (fields : t) type_f = match type_f with | Bt_ind _ -> fields | Bt_raw (id, typ) -> @@ -103,13 +103,31 @@ let add_mem value (fields : t) (curr : curr) = incr curr.mem; { fields with mem = Indexed.return index value :: fields.mem } +let rec extract_block_types expr = + let aux instr = + match instr with + | Block (_str_opt, bt, expr1) -> + Option.to_list bt @ extract_block_types expr1 + | Loop (_str_opt, bt, expr1) -> + Option.to_list bt @ extract_block_types expr1 + | If_else (_str_opt, bt, expr1, expr2) -> + Option.to_list bt @ extract_block_types expr1 @ extract_block_types expr2 + | Return_call_indirect (_ind, bt) -> [ bt ] + | Return_call_ref bt -> [ bt ] + | Call_indirect (_ind, bt) -> [ bt ] + | _ -> [] + in + List.concat_map aux expr + let add_func value (fields : t) (curr : curr) = - let func_type = + let fields = match value with - | Runtime.Local func -> func.type_f - | Imported func -> func.desc + | Runtime.Local func -> + let fields = declare_func_type fields func.type_f in + let temp = extract_block_types func.body in + List.fold_left declare_func_type fields temp + | Imported func -> declare_func_type fields func.desc in - let fields = declare_func_type func_type fields in let index = !(curr.func) in incr curr.func; { fields with func = Indexed.return index value :: fields.func } diff --git a/src/text_to_binary/rewrite.ml b/src/text_to_binary/rewrite.ml index 8e9fe6e17..51a7b4931 100644 --- a/src/text_to_binary/rewrite.ml +++ b/src/text_to_binary/rewrite.ml @@ -34,6 +34,26 @@ let find_global (modul : Assigned.t) id : binary indice = find modul.global id let find_memory (modul : Assigned.t) id : binary indice = find modul.mem id +let rewrite_block_type (typemap : binary indice TypeMap.t) (modul : Assigned.t) + (block_type : text block_type) : binary block_type Result.t = + match block_type with + | Bt_ind id -> begin + let+ v = get (`Unknown_type id) modul.typ id in + match Indexed.get v with + | Def_func_t t' -> + let idx = Indexed.get_index v in + Bt_raw (Some (Raw idx), t') + | _ -> assert false + end + | Bt_raw (_, func_type) -> + let+ t = Binary_types.convert_func_type None func_type in + let idx = + match TypeMap.find_opt (Def_func_t t) typemap with + | None -> assert false + | Some idx -> idx + in + Bt_raw (Some idx, t) + let rewrite_expr (modul : Assigned.t) (locals : binary param list) (iexpr : text expr) : binary expr Result.t = (* block_ids handling *) @@ -57,37 +77,6 @@ let rewrite_expr (modul : Assigned.t) (locals : binary param list) else Ok (Raw id) in - let bt_some_to_raw : text block_type -> binary block_type Result.t = function - | Bt_ind ind -> begin - let+ v = get (`Unknown_type ind) modul.typ ind in - match Indexed.get v with - | Def_func_t t' -> - let idx = Indexed.get_index v in - Bt_raw (Some (Raw idx), t') - | _ -> assert false - end - | Bt_raw (type_use, t) -> ( - let* t = Binary_types.convert_func_type None t in - match type_use with - | None -> Ok (Bt_raw (None, t)) - | Some ind -> - (* we check that the explicit type match the type_use, we have to remove parameters names to do so *) - let* t' = - let+ v = get (`Unknown_type ind) modul.typ ind in - match Indexed.get v with Def_func_t t' -> t' | _ -> assert false - in - let ok = Types.func_type_eq t t' in - if not ok then Error `Inline_function_type else Ok (Bt_raw (None, t)) ) - in - - let bt_to_raw : text block_type option -> binary block_type option Result.t = - function - | None -> Ok None - | Some bt -> - let+ raw = bt_some_to_raw bt in - Some raw - in - let* locals, _after_last_assigned_local = list_fold_left (fun (locals, next_free_int) ((name, _type) : binary param) -> @@ -143,32 +132,50 @@ let rewrite_expr (modul : Assigned.t) (locals : binary param list) let id = find_local id in Ok (Local_tee id) | If_else (id, bt, e1, e2) -> - let* bt = bt_to_raw bt in + let* bt = + match bt with + | Some bt -> + let+ bt = rewrite_block_type (typemap modul.typ) modul bt in + Some bt + | None -> Ok None + in let block_ids = id :: block_ids in let* e1 = expr e1 (loop_count, block_ids) in let+ e2 = expr e2 (loop_count, block_ids) in If_else (id, bt, e1, e2) | Loop (id, bt, e) -> - let* bt = bt_to_raw bt in + let* bt = + match bt with + | Some bt -> + let+ bt = rewrite_block_type (typemap modul.typ) modul bt in + Some bt + | None -> Ok None + in let+ e = expr e (loop_count + 1, id :: block_ids) in Loop (id, bt, e) | Block (id, bt, e) -> - let* bt = bt_to_raw bt in + let* bt = + match bt with + | Some bt -> + let+ bt = rewrite_block_type (typemap modul.typ) modul bt in + Some bt + | None -> Ok None + in let+ e = expr e (loop_count, id :: block_ids) in Block (id, bt, e) | Call_indirect (tbl_i, bt) -> let tbl_i = find_table tbl_i in - let+ bt = bt_some_to_raw bt in + let+ bt = rewrite_block_type (typemap modul.typ) modul bt in Call_indirect (tbl_i, bt) | Return_call_indirect (tbl_i, bt) -> let tbl_i = find_table tbl_i in - let+ bt = bt_some_to_raw bt in + let+ bt = rewrite_block_type (typemap modul.typ) modul bt in Return_call_indirect (tbl_i, bt) | Call_ref t -> let t = find_type t in Ok (Call_ref t) | Return_call_ref bt -> - let+ bt = bt_some_to_raw bt in + let+ bt = rewrite_block_type (typemap modul.typ) modul bt in Return_call_ref bt | Global_set id -> let idx = find_global modul id in @@ -270,26 +277,6 @@ let rewrite_expr (modul : Assigned.t) (locals : binary param list) in expr iexpr (0, []) -let rewrite_block_type (typemap : binary indice TypeMap.t) (modul : Assigned.t) - (block_type : text block_type) : binary block_type Result.t = - match block_type with - | Bt_ind id -> begin - let+ v = get (`Unknown_type id) modul.typ id in - match Indexed.get v with - | Def_func_t t' -> - let idx = Indexed.get_index v in - Bt_raw (Some (Raw idx), t') - | _ -> assert false - end - | Bt_raw (_, func_type) -> - let+ t = Binary_types.convert_func_type None func_type in - let idx = - match TypeMap.find_opt (Def_func_t t) typemap with - | None -> assert false - | Some idx -> idx - in - Bt_raw (Some idx, t) - let rewrite_global (modul : Assigned.t) (global : Text.global) : Binary.global Result.t = let* init = rewrite_expr modul [] global.init in diff --git a/test/opt/if.t b/test/opt/if.t index 3f390f50b..a7a76dd9e 100644 --- a/test/opt/if.t +++ b/test/opt/if.t @@ -4,6 +4,8 @@ if then else instruction: (module (type (sub final (func))) + + (type (sub final (func (result i32)))) (func $start (block (result i32) i32.const 42) diff --git a/test/wat2wasm/loop.t b/test/wat2wasm/loop.t index e58e75c31..5028eafd0 100644 --- a/test/wat2wasm/loop.t +++ b/test/wat2wasm/loop.t @@ -115,6 +115,8 @@ (type (sub final (func (param i32) (result i64)))) + (type (sub final (func (result i64)))) + (type (sub final (func))) (func (param i32) (result i64) (local i32) i32.const 0