diff --git a/example/define_host_function/dune b/example/define_host_function/dune index d78a500f1..86a5f8fe2 100644 --- a/example/define_host_function/dune +++ b/example/define_host_function/dune @@ -1,9 +1,11 @@ (executable (name extern) + (modules extern) (libraries owi)) (executable (name extern_mem) + (modules extern_mem) (libraries owi)) (mdx diff --git a/src/ast/binary_encoder.ml b/src/ast/binary_encoder.ml index a4d8ee345..b93a1e3a6 100644 --- a/src/ast/binary_encoder.ml +++ b/src/ast/binary_encoder.ml @@ -28,7 +28,7 @@ let write_bytes_8 buf i = let rec write_u64 buf i = let b = Int64.to_int (Int64.logand i 0x7fL) in - if 0L <= i && i < 128L then write_byte buf b + if Int64.le 0L i && Int64.lt i 128L then write_byte buf b else begin write_byte buf (b lor 0x80); write_u64 buf (Int64.shift_right_logical i 7) @@ -48,7 +48,7 @@ let write_string buf str = let rec write_s64 buf i = let b = Int64.to_int (Int64.logand i 0x7fL) in - if -64L <= i && i < 64L then write_byte buf b + if Int64.le (-64L) i && Int64.lt i 64L then write_byte buf b else begin write_byte buf (b lor 0x80); write_s64 buf (Int64.shift_right i 7) @@ -524,7 +524,8 @@ let write_locals buf locals = (fun compressed (_so, local_type) -> let c = get_char_valtype local_type in match compressed with - | (ch, cnt) :: compressed when ch = c -> (c, cnt + 1) :: compressed + | (ch, cnt) :: compressed when Char.equal ch c -> + (c, cnt + 1) :: compressed | compressed -> (c, 1) :: compressed ) [] locals in @@ -770,8 +771,8 @@ let encode (modul : Binary.modul) = let write_file filename content = let filename, _ext = Fpath.split_ext filename in - let filename = Fpath.filename filename in - let filename = filename ^ ".wasm" in + let filename = Fpath.add_ext ".wasm" filename in + let filename = Fpath.to_string filename in let oc = Out_channel.open_bin filename in Out_channel.output_string oc content; Out_channel.close oc diff --git a/src/ast/binary_to_text.ml b/src/ast/binary_to_text.ml index 30ac34f00..f30a9fb45 100644 --- a/src/ast/binary_to_text.ml +++ b/src/ast/binary_to_text.ml @@ -129,7 +129,8 @@ let from_types (types : Types.binary Types.rec_type Named.t) : let t = convert_rec_type t in (i, MType t) :: acc ) types [] - |> List.sort compare |> List.map snd + |> List.sort (fun (i1, _t1) (i2, _t2) -> Int.compare i1 i2) + |> List.map snd let from_global (global : (Binary.global, binary global_type) Runtime.t Named.t) : Text.module_field list = @@ -145,7 +146,8 @@ let from_global (global : (Binary.global, binary global_type) Runtime.t Named.t) let desc = Import_global (assigned_name, convert_global_type desc) in (i, MImport { modul; name; desc }) :: acc ) global [] - |> List.sort compare |> List.map snd + |> List.sort (fun (i1, _t1) (i2, _t2) -> Int.compare i1 i2) + |> List.map snd let from_table (table : (binary table, binary table_type) Runtime.t Named.t) : Text.module_field list = @@ -159,7 +161,8 @@ let from_table (table : (binary table, binary table_type) Runtime.t Named.t) : let desc = Import_table (assigned_name, convert_table_type desc) in (i, MImport { modul; name; desc }) :: acc ) table [] - |> List.sort compare |> List.map snd + |> List.sort (fun (i1, _t1) (i2, _t2) -> Int.compare i1 i2) + |> List.map snd let from_mem (mem : (mem, limits) Runtime.t Named.t) : Text.module_field list = Named.fold @@ -170,7 +173,8 @@ let from_mem (mem : (mem, limits) Runtime.t Named.t) : Text.module_field list = let desc = Import_mem (assigned_name, desc) in (i, MImport { modul; name; desc }) :: acc ) mem [] - |> List.sort compare |> List.map snd + |> List.sort (fun (i1, _t1) (i2, _t2) -> Int.compare i1 i2) + |> List.map snd let from_func (func : (binary func, binary block_type) Runtime.t Named.t) : Text.module_field list = @@ -187,7 +191,8 @@ let from_func (func : (binary func, binary block_type) Runtime.t Named.t) : let desc = Import_func (assigned_name, convert_block_type desc) in (i, MImport { modul; name; desc }) :: acc ) func [] - |> List.sort compare |> List.map snd + |> List.sort (fun (i1, _t1) (i2, _t2) -> Int.compare i1 i2) + |> List.map snd let from_elem (elem : Binary.elem Named.t) : Text.module_field list = Named.fold @@ -195,7 +200,8 @@ let from_elem (elem : Binary.elem Named.t) : Text.module_field list = let elem = convert_elem elem in (i, MElem elem) :: acc ) elem [] - |> List.sort compare |> List.map snd + |> List.sort (fun (i1, _t1) (i2, _t2) -> Int.compare i1 i2) + |> List.map snd let from_data (data : Binary.data Named.t) : Text.module_field list = Named.fold @@ -203,7 +209,8 @@ let from_data (data : Binary.data Named.t) : Text.module_field list = let data = convert_data data in (i, MData data) :: acc ) data [] - |> List.sort compare |> List.map snd + |> List.sort (fun (i1, _t1) (i2, _t2) -> Int.compare i1 i2) + |> List.map snd let from_exports (exports : Binary.exports) : Text.module_field list = let global = @@ -244,7 +251,6 @@ let from_start = function None -> [] | Some n -> [ MStart (Raw n) ] let modul { Binary.id; types; global; table; mem; func; elem; data; start; exports } = - ignore types; let fields = from_types types @ from_global global @ from_table table @ from_mem mem @ from_func func @ from_elem elem @ from_data data @ from_exports exports diff --git a/src/ast/binary_types.ml b/src/ast/binary_types.ml index f297d3146..167139f17 100644 --- a/src/ast/binary_types.ml +++ b/src/ast/binary_types.ml @@ -5,13 +5,6 @@ open Types open Syntax -let equal_func_types (a : binary func_type) (b : binary func_type) : bool = - let remove_param (pt, rt) = - let pt = List.map (fun (_id, vt) -> (None, vt)) pt in - (pt, rt) - in - remove_param a = remove_param b - type tbl = (string, int) Hashtbl.t Option.t let convert_heap_type tbl = function diff --git a/src/ast/binary_types.mli b/src/ast/binary_types.mli index 74f67f9d5..650e1035c 100644 --- a/src/ast/binary_types.mli +++ b/src/ast/binary_types.mli @@ -6,8 +6,6 @@ open Types type tbl = (string, int) Hashtbl.t Option.t -val equal_func_types : binary func_type -> binary func_type -> bool - val convert_val_type : tbl -> text val_type -> binary val_type Result.t val convert_heap_type : tbl -> text heap_type -> binary heap_type Result.t diff --git a/src/ast/text.ml b/src/ast/text.ml index 3242e384d..f71c9ef3e 100644 --- a/src/ast/text.ml +++ b/src/ast/text.ml @@ -2,7 +2,7 @@ (* Copyright © 2021-2024 OCamlPro *) (* Written by the Owi programmers *) -open Format +open Fmt open Types let symbolic v = Text v @@ -20,7 +20,7 @@ type global = } let pp_global fmt (g : global) = - pp fmt "(global%a %a %a)" pp_id_opt g.id pp_global_type g.typ pp_expr g.init + pf fmt "(global%a %a %a)" pp_id_opt g.id pp_global_type g.typ pp_expr g.init type data_mode = | Data_passive @@ -29,7 +29,7 @@ type data_mode = let pp_data_mode fmt = function | Data_passive -> () | Data_active (i, e) -> - pp fmt "(memory %a) (offset %a)" pp_indice_opt i pp_expr e + pf fmt "(memory %a) (offset %a)" pp_indice_opt i pp_expr e type data = { id : string option @@ -38,7 +38,7 @@ type data = } let pp_data fmt (d : data) = - pp fmt {|(data%a %a %S)|} pp_id_opt d.id pp_data_mode d.mode d.init + pf fmt {|(data%a %a %S)|} pp_id_opt d.id pp_data_mode d.mode d.init type elem_mode = | Elem_passive @@ -47,11 +47,11 @@ type elem_mode = let pp_elem_mode fmt = function | Elem_passive -> () - | Elem_declarative -> pp fmt "declare" + | Elem_declarative -> pf fmt "declare" | Elem_active (i, e) -> ( match i with - | None -> pp fmt "(offset %a)" pp_expr e - | Some i -> pp fmt "(table %a) (offset %a)" pp_indice i pp_expr e ) + | None -> pf fmt "(offset %a)" pp_expr e + | Some i -> pf fmt "(table %a) (offset %a)" pp_indice i pp_expr e ) type elem = { id : string option @@ -60,12 +60,12 @@ type elem = ; mode : elem_mode } -let pp_elem_expr fmt e = pp fmt "(item %a)" pp_expr e +let pp_elem_expr fmt e = pf fmt "(item %a)" pp_expr e let pp_elem fmt (e : elem) = - pp fmt "@[(elem%a %a %a %a)@]" pp_id_opt e.id pp_elem_mode e.mode + pf fmt "@[(elem%a %a %a %a)@]" pp_id_opt e.id pp_elem_mode e.mode pp_ref_type e.typ - (pp_list ~pp_sep:pp_newline pp_elem_expr) + (list ~sep:pp_newline pp_elem_expr) e.init type module_field = @@ -98,8 +98,8 @@ type modul = } let pp_modul fmt (m : modul) = - pp fmt "(module%a@\n @[%a@]@\n)" pp_id_opt m.id - (pp_list ~pp_sep:pp_newline pp_module_field) + pf fmt "(module%a@\n @[%a@]@\n)" pp_id_opt m.id + (list ~sep:pp_newline pp_module_field) m.fields type action = @@ -108,8 +108,8 @@ type action = let pp_action fmt = function | Invoke (mod_name, name, c) -> - pp fmt {|(invoke%a "%s" %a)|} pp_id_opt mod_name name pp_consts c - | Get _ -> pp fmt "" + pf fmt {|(invoke%a "%s" %a)|} pp_id_opt mod_name name pp_consts c + | Get _ -> pf fmt "" type result_const = | Literal of text const @@ -118,8 +118,8 @@ type result_const = let pp_result_const fmt = function | Literal c -> pp_const fmt c - | Nan_canon n -> pp fmt "float%a.const nan:canonical" pp_nn n - | Nan_arith n -> pp fmt "float%a.const nan:arithmetic" pp_nn n + | Nan_canon n -> pf fmt "float%a.const nan:canonical" pp_nn n + | Nan_arith n -> pf fmt "float%a.const nan:arithmetic" pp_nn n type result = | Result_const of result_const @@ -127,14 +127,14 @@ type result = | Result_func_ref let pp_result fmt = function - | Result_const c -> pp fmt "(%a)" pp_result_const c + | Result_const c -> pf fmt "(%a)" pp_result_const c | Result_func_ref | Result_extern_ref -> Log.err "not yet implemented" let pp_result_bis fmt = function - | Result_const c -> pp fmt "%a" pp_result_const c + | Result_const c -> pf fmt "%a" pp_result_const c | Result_extern_ref | Result_func_ref -> Log.err "not yet implemented" -let pp_results fmt r = pp_list ~pp_sep:pp_space pp_result_bis fmt r +let pp_results fmt r = list ~sep:sp pp_result_bis fmt r type assertion = | Assert_return of action * result list @@ -151,31 +151,31 @@ type assertion = let pp_assertion fmt = function | Assert_return (a, l) -> - pp fmt "(assert_return %a %a)" pp_action a pp_results l + pf fmt "(assert_return %a %a)" pp_action a pp_results l | Assert_exhaustion (a, msg) -> - pp fmt "(assert_exhaustion %a %s)" pp_action a msg - | Assert_trap (a, f) -> pp fmt {|(assert_trap %a "%s")|} pp_action a f + pf fmt "(assert_exhaustion %a %s)" pp_action a msg + | Assert_trap (a, f) -> pf fmt {|(assert_trap %a "%s")|} pp_action a f | Assert_trap_module (m, f) -> - pp fmt {|(assert_trap_module %a "%s")|} pp_modul m f + pf fmt {|(assert_trap_module %a "%s")|} pp_modul m f | Assert_invalid (m, msg) -> - pp fmt "(assert_invalid@\n @[%a@]@\n @[%S@]@\n)" pp_modul m msg + pf fmt "(assert_invalid@\n @[%a@]@\n @[%S@]@\n)" pp_modul m msg | Assert_unlinkable (m, msg) -> - pp fmt "(assert_unlinkable@\n @[%a@]@\n @[%S@]@\n)" pp_modul m msg + pf fmt "(assert_unlinkable@\n @[%a@]@\n @[%S@]@\n)" pp_modul m msg | Assert_malformed (m, msg) -> - pp fmt "(assert_malformed (module binary@\n @[%a@])@\n @[%S@]@\n)" + pf fmt "(assert_malformed (module binary@\n @[%a@])@\n @[%S@]@\n)" pp_modul m msg | Assert_malformed_quote (ls, msg) -> - pp fmt "(assert_malformed_quote@\n @[%S@]@\n @[%S@]@\n)" ls msg + pf fmt "(assert_malformed_quote@\n @[%S@]@\n @[%S@]@\n)" ls msg | Assert_invalid_quote (ls, msg) -> - pp fmt "(assert_invalid_quote@\n @[%S@]@\n @[%S@]@\n)" ls msg + pf fmt "(assert_invalid_quote@\n @[%S@]@\n @[%S@]@\n)" ls msg | Assert_malformed_binary (ls, msg) -> - pp fmt "(assert_malformed_binary@\n @[%S@]@\n @[%S@]@\n)" ls msg + pf fmt "(assert_malformed_binary@\n @[%S@]@\n @[%S@]@\n)" ls msg | Assert_invalid_binary (ls, msg) -> - pp fmt "(assert_invalid_binary@\n @[%S@]@\n @[%S@]@\n)" ls msg + pf fmt "(assert_invalid_binary@\n @[%S@]@\n @[%S@]@\n)" ls msg type register = string * string option -let pp_register fmt (s, _name) = pp fmt "(register %s)" s +let pp_register fmt (s, _name) = pf fmt "(register %s)" s type cmd = | Module of modul @@ -187,8 +187,8 @@ let pp_cmd fmt = function | Module m -> pp_modul fmt m | Assert a -> pp_assertion fmt a | Register (s, name) -> pp_register fmt (s, name) - | Action _a -> pp fmt "" + | Action _a -> pf fmt "" type script = cmd list -let pp_script fmt l = pp_list ~pp_sep:pp_newline pp_cmd fmt l +let pp_script fmt l = list ~sep:pp_newline pp_cmd fmt l diff --git a/src/ast/types.ml b/src/ast/types.ml index 2930729f5..319309571 100644 --- a/src/ast/types.ml +++ b/src/ast/types.ml @@ -2,7 +2,7 @@ (* Copyright © 2021-2024 OCamlPro *) (* Written by the Owi programmers *) -open Format +open Fmt exception Trap of string @@ -32,17 +32,17 @@ type _ indice = | Text : string -> < with_string_indices ; .. > indice | Raw : int -> < .. > indice -let pp_id fmt id = pp fmt "$%s" id +let pp_id fmt id = pf fmt "$%s" id -let pp_id_opt fmt = function None -> () | Some i -> pp fmt " %a" pp_id i +let pp_id_opt fmt = function None -> () | Some i -> pf fmt " %a" pp_id i let pp_indice (type kind) fmt : kind indice -> unit = function - | Raw u -> pp_int fmt u + | Raw u -> int fmt u | Text i -> pp_id fmt i let pp_indice_opt fmt = function None -> () | Some i -> pp_indice fmt i -let pp_indices fmt ids = pp_list ~pp_sep:pp_space pp_indice fmt ids +let pp_indices fmt ids = list ~sep:sp pp_indice fmt ids type nonrec num_type = | I32 @@ -51,10 +51,23 @@ type nonrec num_type = | F64 let pp_num_type fmt = function - | I32 -> pp fmt "i32" - | I64 -> pp fmt "i64" - | F32 -> pp fmt "f32" - | F64 -> pp fmt "f64" + | I32 -> pf fmt "i32" + | I64 -> pf fmt "i64" + | F32 -> pf fmt "f32" + | F64 -> pf fmt "f64" + +let num_type_eq t1 t2 = + match (t1, t2) with + | I32, I32 | I64, I64 | F32, F32 | F64, F64 -> true + | _, _ -> false + +let compare_num_type t1 t2 = + match (t1, t2) with + | I32, I32 | I64, I64 | F32, F32 | F64, F64 -> 0 + | I32, _ -> 1 + | I64, _ -> 1 + | F32, _ -> 1 + | F64, _ -> 1 type nullable = | No_null @@ -63,32 +76,35 @@ type nullable = let pp_nullable fmt = function | No_null -> (* TODO: no notation to enforce nonnull ? *) - pp fmt "" - | Null -> pp fmt "null" + pf fmt "" + | Null -> pf fmt "null" type nonrec packed_type = | I8 | I16 -let pp_packed_type fmt = function I8 -> pp fmt "i8" | I16 -> pp fmt "i16" +let pp_packed_type fmt = function I8 -> pf fmt "i8" | I16 -> pf fmt "i16" + +let packed_type_eq t1 t2 = + match (t1, t2) with I8, I8 | I16, I16 -> true | _, _ -> false type nonrec mut = | Const | Var -let pp_mut fmt = function Const -> () | Var -> pp fmt "mut" +let pp_mut fmt = function Const -> () | Var -> pf fmt "mut" type nonrec nn = | S32 | S64 -let pp_nn fmt = function S32 -> pp fmt "32" | S64 -> pp fmt "64" +let pp_nn fmt = function S32 -> pf fmt "32" | S64 -> pf fmt "64" type nonrec sx = | U | S -let pp_sx fmt = function U -> pp fmt "u" | S -> pp fmt "s" +let pp_sx fmt = function U -> pf fmt "u" | S -> pf fmt "s" type nonrec iunop = | Clz @@ -96,9 +112,9 @@ type nonrec iunop = | Popcnt let pp_iunop fmt = function - | Clz -> pp fmt "clz" - | Ctz -> pp fmt "ctz" - | Popcnt -> pp fmt "popcnt" + | Clz -> pf fmt "clz" + | Ctz -> pf fmt "ctz" + | Popcnt -> pf fmt "popcnt" type nonrec funop = | Abs @@ -110,13 +126,13 @@ type nonrec funop = | Nearest let pp_funop fmt = function - | Abs -> pp fmt "abs" - | Neg -> pp fmt "neg" - | Sqrt -> pp fmt "sqrt" - | Ceil -> pp fmt "ceil" - | Floor -> pp fmt "floor" - | Trunc -> pp fmt "trunc" - | Nearest -> pp fmt "nearest" + | Abs -> pf fmt "abs" + | Neg -> pf fmt "neg" + | Sqrt -> pf fmt "sqrt" + | Ceil -> pf fmt "ceil" + | Floor -> pf fmt "floor" + | Trunc -> pf fmt "trunc" + | Nearest -> pf fmt "nearest" type nonrec ibinop = | Add @@ -133,18 +149,18 @@ type nonrec ibinop = | Rotr let pp_ibinop fmt = function - | (Add : ibinop) -> pp fmt "add" - | Sub -> pp fmt "sub" - | Mul -> pp fmt "mul" - | Div s -> pp fmt "div_%a" pp_sx s - | Rem s -> pp fmt "rem_%a" pp_sx s - | And -> pp fmt "and" - | Or -> pp fmt "or" - | Xor -> pp fmt "xor" - | Shl -> pp fmt "shl" - | Shr s -> pp fmt "shr_%a" pp_sx s - | Rotl -> pp fmt "rotl" - | Rotr -> pp fmt "rotr" + | (Add : ibinop) -> pf fmt "add" + | Sub -> pf fmt "sub" + | Mul -> pf fmt "mul" + | Div s -> pf fmt "div_%a" pp_sx s + | Rem s -> pf fmt "rem_%a" pp_sx s + | And -> pf fmt "and" + | Or -> pf fmt "or" + | Xor -> pf fmt "xor" + | Shl -> pf fmt "shl" + | Shr s -> pf fmt "shr_%a" pp_sx s + | Rotl -> pf fmt "rotl" + | Rotr -> pf fmt "rotr" type nonrec fbinop = | Add @@ -156,17 +172,17 @@ type nonrec fbinop = | Copysign let pp_fbinop fmt = function - | (Add : fbinop) -> pp fmt "add" - | Sub -> pp fmt "sub" - | Mul -> pp fmt "mul" - | Div -> pp fmt "div" - | Min -> pp fmt "min" - | Max -> pp fmt "max" - | Copysign -> pp fmt "copysign" + | (Add : fbinop) -> pf fmt "add" + | Sub -> pf fmt "sub" + | Mul -> pf fmt "mul" + | Div -> pf fmt "div" + | Min -> pf fmt "min" + | Max -> pf fmt "max" + | Copysign -> pf fmt "copysign" type nonrec itestop = Eqz -let pp_itestop fmt = function Eqz -> pp fmt "eqz" +let pp_itestop fmt = function Eqz -> pf fmt "eqz" type nonrec irelop = | Eq @@ -177,12 +193,12 @@ type nonrec irelop = | Ge of sx let pp_irelop fmt : irelop -> Unit.t = function - | Eq -> pp fmt "eq" - | Ne -> pp fmt "ne" - | Lt sx -> pp fmt "lt_%a" pp_sx sx - | Gt sx -> pp fmt "gt_%a" pp_sx sx - | Le sx -> pp fmt "le_%a" pp_sx sx - | Ge sx -> pp fmt "ge_%a" pp_sx sx + | Eq -> pf fmt "eq" + | Ne -> pf fmt "ne" + | Lt sx -> pf fmt "lt_%a" pp_sx sx + | Gt sx -> pf fmt "gt_%a" pp_sx sx + | Le sx -> pf fmt "le_%a" pp_sx sx + | Ge sx -> pf fmt "ge_%a" pp_sx sx type nonrec frelop = | Eq @@ -193,12 +209,12 @@ type nonrec frelop = | Ge let frelop fmt : frelop -> Unit.t = function - | Eq -> pp fmt "eq" - | Ne -> pp fmt "ne" - | Lt -> pp fmt "lt" - | Gt -> pp fmt "gt" - | Le -> pp fmt "le" - | Ge -> pp fmt "ge" + | Eq -> pf fmt "eq" + | Ne -> pf fmt "ne" + | Lt -> pf fmt "lt" + | Gt -> pf fmt "gt" + | Le -> pf fmt "le" + | Ge -> pf fmt "ge" type nonrec memarg = { offset : Int32.t @@ -207,14 +223,14 @@ type nonrec memarg = let pp_memarg = let pow_2 n = - assert (n >= 0l); + assert (Int32.ge n 0l); Int32.shl 1l n in fun fmt { offset; align } -> let pp_offset fmt offset = - if offset > 0l then pp fmt "offset=%ld " offset + if Int32.gt offset 0l then pf fmt "offset=%ld " offset in - pp fmt "%aalign=%ld" pp_offset offset (pow_2 align) + pf fmt "%aalign=%ld" pp_offset offset (pow_2 align) type nonrec limits = { min : int @@ -222,19 +238,19 @@ type nonrec limits = } let pp_limits fmt { min; max } = - match max with None -> pp fmt "%d" min | Some max -> pp fmt "%d %d" min max + match max with None -> pf fmt "%d" min | Some max -> pf fmt "%d %d" min max type nonrec mem = string option * limits -let pp_mem fmt (id, ty) = pp fmt "(memory%a %a)" pp_id_opt id pp_limits ty +let pp_mem fmt (id, ty) = pf fmt "(memory%a %a)" pp_id_opt id pp_limits ty type nonrec final = | Final | No_final let pp_final fmt = function - | Final -> pp fmt "final" - | No_final -> pp fmt "no_final" + | Final -> pf fmt "final" + | No_final -> pf fmt "no_final" (** Structure *) @@ -254,37 +270,61 @@ type 'a heap_type = | Def_ht of 'a indice let pp_heap_type fmt = function - | Any_ht -> pp fmt "any" - | None_ht -> pp fmt "none" - | Eq_ht -> pp fmt "eq" - | I31_ht -> pp fmt "i31" - | Struct_ht -> pp fmt "struct" - | Array_ht -> pp fmt "array" - | Func_ht -> pp fmt "func" - | No_func_ht -> pp fmt "nofunc" - | Extern_ht -> pp fmt "extern" - | No_extern_ht -> pp fmt "noextern" - | Def_ht i -> pp fmt "%a" pp_indice i + | Any_ht -> pf fmt "any" + | None_ht -> pf fmt "none" + | Eq_ht -> pf fmt "eq" + | I31_ht -> pf fmt "i31" + | Struct_ht -> pf fmt "struct" + | Array_ht -> pf fmt "array" + | Func_ht -> pf fmt "func" + | No_func_ht -> pf fmt "nofunc" + | Extern_ht -> pf fmt "extern" + | No_extern_ht -> pf fmt "noextern" + | Def_ht i -> pf fmt "%a" pp_indice i let pp_heap_type_short fmt = function - | Any_ht -> pp fmt "anyref" - | None_ht -> pp fmt "(ref none)" - | Eq_ht -> pp fmt "eqref" - | I31_ht -> pp fmt "i31ref" - | Struct_ht -> pp fmt "(ref struct)" - | Array_ht -> pp fmt "(ref array)" - | Func_ht -> pp fmt "funcref" - | No_func_ht -> pp fmt "nofunc" - | Extern_ht -> pp fmt "externref" - | No_extern_ht -> pp fmt "(ref noextern)" - | Def_ht i -> pp fmt "(ref %a)" pp_indice i + | Any_ht -> pf fmt "anyref" + | None_ht -> pf fmt "(ref none)" + | Eq_ht -> pf fmt "eqref" + | I31_ht -> pf fmt "i31ref" + | Struct_ht -> pf fmt "(ref struct)" + | Array_ht -> pf fmt "(ref array)" + | Func_ht -> pf fmt "funcref" + | No_func_ht -> pf fmt "nofunc" + | Extern_ht -> pf fmt "externref" + | No_extern_ht -> pf fmt "(ref noextern)" + | Def_ht i -> pf fmt "(ref %a)" pp_indice i + +let heap_type_eq t1 t2 = + (* TODO: this is wrong *) + match (t1, t2) with + | Any_ht, Any_ht + | None_ht, None_ht + | Eq_ht, Eq_ht + | I31_ht, I31_ht + | Struct_ht, Struct_ht + | Array_ht, Array_ht + | Func_ht, Func_ht + | No_func_ht, No_func_ht + | Extern_ht, Extern_ht + | No_extern_ht, No_extern_ht -> + true + | Def_ht _, Def_ht _ -> assert false + | _, _ -> false type nonrec 'a ref_type = nullable * 'a heap_type let pp_ref_type fmt (n, ht) = match n with - | No_null -> pp fmt "%a" pp_heap_type_short ht - | Null -> pp fmt "(ref null %a)" pp_heap_type ht + | No_null -> pf fmt "%a" pp_heap_type_short ht + | Null -> pf fmt "(ref null %a)" pp_heap_type ht + +let ref_type_eq t1 t2 = + match (t1, t2) with + | (Null, t1), (Null, t2) | (No_null, t1), (No_null, t2) -> heap_type_eq t1 t2 + | _ -> false + +let compare_ref_type _ _ = assert false type nonrec 'a val_type = | Num_type of num_type @@ -294,24 +334,49 @@ let pp_val_type fmt = function | Num_type t -> pp_num_type fmt t | Ref_type t -> pp_ref_type fmt t +let val_type_eq t1 t2 = + match (t1, t2) with + | Num_type t1, Num_type t2 -> num_type_eq t1 t2 + | Ref_type t1, Ref_type t2 -> ref_type_eq t1 t2 + | _, _ -> false + +let compare_val_type t1 t2 = + match (t1, t2) with + | Num_type t1, Num_type t2 -> compare_num_type t1 t2 + | Ref_type t1, Ref_type t2 -> compare_ref_type t1 t2 + | Num_type _, _ -> 1 + | Ref_type _, _ -> -1 + type nonrec 'a param = string option * 'a val_type -let pp_param fmt (id, vt) = pp fmt "(param%a %a)" pp_id_opt id pp_val_type vt +let pp_param fmt (id, vt) = pf fmt "(param%a %a)" pp_id_opt id pp_val_type vt + +let param_eq (_, t1) (_, t2) = val_type_eq t1 t2 + +let compare_param (_, t1) (_, t2) = compare_val_type t1 t2 type nonrec 'a param_type = 'a param list -let pp_param_type fmt params = pp_list ~pp_sep:pp_space pp_param fmt params +let pp_param_type fmt params = list ~sep:sp pp_param fmt params + +let param_type_eq t1 t2 = List.equal param_eq t1 t2 + +let compare_param_type t1 t2 = List.compare compare_param t1 t2 type nonrec 'a result_type = 'a val_type list -let pp_result_ fmt vt = pp fmt "(result %a)" pp_val_type vt +let pp_result_ fmt vt = pf fmt "(result %a)" pp_val_type vt + +let pp_result_type fmt results = list ~sep:sp pp_result_ fmt results -let pp_result_type fmt results = pp_list ~pp_sep:pp_space pp_result_ fmt results +let result_type_eq t1 t2 = List.equal val_type_eq t1 t2 + +let compare_result_type t1 t2 = List.compare compare_val_type t1 t2 (* wrap printer to print a space before a non empty list *) (* TODO or make it an optional arg of pp_list? *) let with_space_list printer fmt l = - match l with [] -> () | _l -> pp fmt " %a" printer l + match l with [] -> () | _l -> pf fmt " %a" printer l (* TODO: add a third case that only has (pt * rt) and is the only one used in simplified *) type 'a block_type = @@ -321,9 +386,9 @@ type 'a block_type = -> (< .. > as 'a) block_type let pp_block_type (type kind) fmt : kind block_type -> unit = function - | Bt_ind ind -> pp fmt "(type %a)" pp_indice ind + | Bt_ind ind -> pf fmt "(type %a)" pp_indice ind | Bt_raw (_ind, (pt, rt)) -> - pp fmt "%a%a" + pf fmt "%a%a" (with_space_list pp_param_type) pt (with_space_list pp_result_type) @@ -336,23 +401,30 @@ let pp_block_type_opt fmt = function type nonrec 'a func_type = 'a param_type * 'a result_type let pp_func_type fmt (params, results) = - pp fmt "(func%a%a)" + pf fmt "(func%a%a)" (with_space_list pp_param_type) params (with_space_list pp_result_type) results +let func_type_eq (pt1, rt1) (pt2, rt2) = + param_type_eq pt1 pt2 && result_type_eq rt1 rt2 + +let compare_func_type (pt1, rt1) (pt2, rt2) = + let pt = compare_param_type pt1 pt2 in + if pt = 0 then compare_result_type rt1 rt2 else pt + type nonrec 'a table_type = limits * 'a ref_type let pp_table_type fmt (limits, ref_type) = - pp fmt "%a %a" pp_limits limits pp_ref_type ref_type + pf fmt "%a %a" pp_limits limits pp_ref_type ref_type type nonrec 'a global_type = mut * 'a val_type let pp_global_type fmt (mut, val_type) = match mut with - | Var -> pp fmt "(mut %a)" pp_val_type val_type - | Const -> pp fmt "%a" pp_val_type val_type + | Var -> pf fmt "(mut %a)" pp_val_type val_type + | Const -> pf fmt "%a" pp_val_type val_type type nonrec 'a extern_type = | Func of string option * 'a func_type @@ -481,141 +553,141 @@ type 'a instr = and 'a expr = 'a instr list +let pp_newline ppf () = string ppf "@\n" + let rec pp_instr fmt = function - | I32_const i -> pp fmt "i32.const %ld" i - | I64_const i -> pp fmt "i64.const %Ld" i - | F32_const f -> pp fmt "f32.const %a" Float32.pp f - | F64_const f -> pp fmt "f64.const %a" Float64.pp f - | I_unop (n, op) -> pp fmt "i%a.%a" pp_nn n pp_iunop op - | F_unop (n, op) -> pp fmt "f%a.%a" pp_nn n pp_funop op - | I_binop (n, op) -> pp fmt "i%a.%a" pp_nn n pp_ibinop op - | F_binop (n, op) -> pp fmt "f%a.%a" pp_nn n pp_fbinop op - | I_testop (n, op) -> pp fmt "i%a.%a" pp_nn n pp_itestop op - | I_relop (n, op) -> pp fmt "i%a.%a" pp_nn n pp_irelop op - | F_relop (n, op) -> pp fmt "f%a.%a" pp_nn n frelop op - | I_extend8_s n -> pp fmt "i%a.extend8_s" pp_nn n - | I_extend16_s n -> pp fmt "i%a.extend16_s" pp_nn n - | I64_extend32_s -> pp fmt "i64.extend32_s" - | I32_wrap_i64 -> pp fmt "i32.wrap_i64" - | I64_extend_i32 sx -> pp fmt "i64.extend_i32_%a" pp_sx sx - | I_trunc_f (n, n', sx) -> pp fmt "i%a.trunc_f%a_%a" pp_nn n pp_nn n' pp_sx sx + | I32_const i -> pf fmt "i32.const %ld" i + | I64_const i -> pf fmt "i64.const %Ld" i + | F32_const f -> pf fmt "f32.const %a" Float32.pp f + | F64_const f -> pf fmt "f64.const %a" Float64.pp f + | I_unop (n, op) -> pf fmt "i%a.%a" pp_nn n pp_iunop op + | F_unop (n, op) -> pf fmt "f%a.%a" pp_nn n pp_funop op + | I_binop (n, op) -> pf fmt "i%a.%a" pp_nn n pp_ibinop op + | F_binop (n, op) -> pf fmt "f%a.%a" pp_nn n pp_fbinop op + | I_testop (n, op) -> pf fmt "i%a.%a" pp_nn n pp_itestop op + | I_relop (n, op) -> pf fmt "i%a.%a" pp_nn n pp_irelop op + | F_relop (n, op) -> pf fmt "f%a.%a" pp_nn n frelop op + | I_extend8_s n -> pf fmt "i%a.extend8_s" pp_nn n + | I_extend16_s n -> pf fmt "i%a.extend16_s" pp_nn n + | I64_extend32_s -> pf fmt "i64.extend32_s" + | I32_wrap_i64 -> pf fmt "i32.wrap_i64" + | I64_extend_i32 sx -> pf fmt "i64.extend_i32_%a" pp_sx sx + | I_trunc_f (n, n', sx) -> pf fmt "i%a.trunc_f%a_%a" pp_nn n pp_nn n' pp_sx sx | I_trunc_sat_f (n, n', sx) -> - pp fmt "i%a.trunc_sat_f%a_%a" pp_nn n pp_nn n' pp_sx sx - | F32_demote_f64 -> pp fmt "f32.demote_f64" - | F64_promote_f32 -> pp fmt "f64.promote_f32" + pf fmt "i%a.trunc_sat_f%a_%a" pp_nn n pp_nn n' pp_sx sx + | F32_demote_f64 -> pf fmt "f32.demote_f64" + | F64_promote_f32 -> pf fmt "f64.promote_f32" | F_convert_i (n, n', sx) -> - pp fmt "f%a.convert_i%a_%a" pp_nn n pp_nn n' pp_sx sx - | I_reinterpret_f (n, n') -> pp fmt "i%a.reinterpret_f%a" pp_nn n pp_nn n' - | F_reinterpret_i (n, n') -> pp fmt "f%a.reinterpret_i%a" pp_nn n pp_nn n' - | Ref_null t -> pp fmt "ref.null %a" pp_heap_type t - | Ref_is_null -> pp fmt "ref.is_null" - | Ref_func fid -> pp fmt "ref.func %a" pp_indice fid - | Drop -> pp fmt "drop" + pf fmt "f%a.convert_i%a_%a" pp_nn n pp_nn n' pp_sx sx + | I_reinterpret_f (n, n') -> pf fmt "i%a.reinterpret_f%a" pp_nn n pp_nn n' + | F_reinterpret_i (n, n') -> pf fmt "f%a.reinterpret_i%a" pp_nn n pp_nn n' + | Ref_null t -> pf fmt "ref.null %a" pp_heap_type t + | Ref_is_null -> pf fmt "ref.is_null" + | Ref_func fid -> pf fmt "ref.func %a" pp_indice fid + | Drop -> pf fmt "drop" | Select vt -> begin match vt with - | None -> pp fmt "select" - | Some vt -> pp fmt "select (%a)" pp_result_type vt + | None -> pf fmt "select" + | Some vt -> pf fmt "select (%a)" pp_result_type vt (* TODO: are the parens needed ? *) end - | Local_get id -> pp fmt "local.get %a" pp_indice id - | Local_set id -> pp fmt "local.set %a" pp_indice id - | Local_tee id -> pp fmt "local.tee %a" pp_indice id - | Global_get id -> pp fmt "global.get %a" pp_indice id - | Global_set id -> pp fmt "global.set %a" pp_indice id - | Table_get id -> pp fmt "table.get %a" pp_indice id - | Table_set id -> pp fmt "table.set %a" pp_indice id - | Table_size id -> pp fmt "table.size %a" pp_indice id - | Table_grow id -> pp fmt "table.grow %a" pp_indice id - | Table_fill id -> pp fmt "table.fill %a" pp_indice id - | Table_copy (id, id') -> pp fmt "table.copy %a %a" pp_indice id pp_indice id' + | Local_get id -> pf fmt "local.get %a" pp_indice id + | Local_set id -> pf fmt "local.set %a" pp_indice id + | Local_tee id -> pf fmt "local.tee %a" pp_indice id + | Global_get id -> pf fmt "global.get %a" pp_indice id + | Global_set id -> pf fmt "global.set %a" pp_indice id + | Table_get id -> pf fmt "table.get %a" pp_indice id + | Table_set id -> pf fmt "table.set %a" pp_indice id + | Table_size id -> pf fmt "table.size %a" pp_indice id + | Table_grow id -> pf fmt "table.grow %a" pp_indice id + | Table_fill id -> pf fmt "table.fill %a" pp_indice id + | Table_copy (id, id') -> pf fmt "table.copy %a %a" pp_indice id pp_indice id' | Table_init (tid, eid) -> - pp fmt "table.init %a %a" pp_indice tid pp_indice eid - | Elem_drop id -> pp fmt "elem.drop %a" pp_indice id - | I_load (n, memarg) -> pp fmt "i%a.load %a" pp_nn n pp_memarg memarg - | F_load (n, memarg) -> pp fmt "f%a.load %a" pp_nn n pp_memarg memarg - | I_store (n, memarg) -> pp fmt "i%a.store %a" pp_nn n pp_memarg memarg - | F_store (n, memarg) -> pp fmt "f%a.store %a" pp_nn n pp_memarg memarg + pf fmt "table.init %a %a" pp_indice tid pp_indice eid + | Elem_drop id -> pf fmt "elem.drop %a" pp_indice id + | I_load (n, memarg) -> pf fmt "i%a.load %a" pp_nn n pp_memarg memarg + | F_load (n, memarg) -> pf fmt "f%a.load %a" pp_nn n pp_memarg memarg + | I_store (n, memarg) -> pf fmt "i%a.store %a" pp_nn n pp_memarg memarg + | F_store (n, memarg) -> pf fmt "f%a.store %a" pp_nn n pp_memarg memarg | I_load8 (n, sx, memarg) -> - pp fmt "i%a.load8_%a %a" pp_nn n pp_sx sx pp_memarg memarg + pf fmt "i%a.load8_%a %a" pp_nn n pp_sx sx pp_memarg memarg | I_load16 (n, sx, memarg) -> - pp fmt "i%a.load16_%a %a" pp_nn n pp_sx sx pp_memarg memarg + pf fmt "i%a.load16_%a %a" pp_nn n pp_sx sx pp_memarg memarg | I64_load32 (sx, memarg) -> - pp fmt "i64.load32_%a %a" pp_sx sx pp_memarg memarg - | I_store8 (n, memarg) -> pp fmt "i%a.store8 %a" pp_nn n pp_memarg memarg - | I_store16 (n, memarg) -> pp fmt "i%a.store16 %a" pp_nn n pp_memarg memarg - | I64_store32 memarg -> pp fmt "i64.store32 %a" pp_memarg memarg - | Memory_size -> pp fmt "memory.size" - | Memory_grow -> pp fmt "memory.grow" - | Memory_fill -> pp fmt "memory.fill" - | Memory_copy -> pp fmt "memory.copy" - | Memory_init id -> pp fmt "memory.init %a" pp_indice id - | Data_drop id -> pp fmt "data.drop %a" pp_indice id - | Nop -> pp fmt "nop" - | Unreachable -> pp fmt "unreachable" + pf fmt "i64.load32_%a %a" pp_sx sx pp_memarg memarg + | I_store8 (n, memarg) -> pf fmt "i%a.store8 %a" pp_nn n pp_memarg memarg + | I_store16 (n, memarg) -> pf fmt "i%a.store16 %a" pp_nn n pp_memarg memarg + | I64_store32 memarg -> pf fmt "i64.store32 %a" pp_memarg memarg + | Memory_size -> pf fmt "memory.size" + | Memory_grow -> pf fmt "memory.grow" + | Memory_fill -> pf fmt "memory.fill" + | Memory_copy -> pf fmt "memory.copy" + | Memory_init id -> pf fmt "memory.init %a" pp_indice id + | Data_drop id -> pf fmt "data.drop %a" pp_indice id + | Nop -> pf fmt "nop" + | Unreachable -> pf fmt "unreachable" | Block (id, bt, e) -> - pp fmt "(block%a%a@\n @[%a@])" pp_id_opt id pp_block_type_opt bt pp_expr + pf fmt "(block%a%a@\n @[%a@])" pp_id_opt id pp_block_type_opt bt pp_expr e | Loop (id, bt, e) -> - pp fmt "(loop%a%a@\n @[%a@])" pp_id_opt id pp_block_type_opt bt pp_expr + pf fmt "(loop%a%a@\n @[%a@])" pp_id_opt id pp_block_type_opt bt pp_expr e | If_else (id, bt, e1, e2) -> let pp_else fmt e = match e with | [] -> () - | e -> pp fmt "@\n(else@\n @[%a@]@\n)" pp_expr e + | e -> pf fmt "@\n(else@\n @[%a@]@\n)" pp_expr e in - pp fmt "(if%a%a@\n @[(then@\n @[%a@]@\n)%a@]@\n)" pp_id_opt id + pf fmt "(if%a%a@\n @[(then@\n @[%a@]@\n)%a@]@\n)" pp_id_opt id pp_block_type_opt bt pp_expr e1 pp_else e2 - | Br id -> pp fmt "br %a" pp_indice id - | Br_if id -> pp fmt "br_if %a" pp_indice id + | Br id -> pf fmt "br %a" pp_indice id + | Br_if id -> pf fmt "br_if %a" pp_indice id | Br_table (ids, id) -> - pp fmt "br_table %a %a" - (pp_array ~pp_sep:pp_space pp_indice) - ids pp_indice id - | Return -> pp fmt "return" - | Return_call id -> pp fmt "return_call %a" pp_indice id + pf fmt "br_table %a %a" (array ~sep:sp pp_indice) ids pp_indice id + | Return -> pf fmt "return" + | Return_call id -> pf fmt "return_call %a" pp_indice id | Return_call_indirect (tbl_id, ty_id) -> - pp fmt "return_call_indirect %a %a" pp_indice tbl_id pp_block_type ty_id - | Return_call_ref ty_id -> pp fmt "return_call_ref %a" pp_block_type ty_id - | Call id -> pp fmt "call %a" pp_indice id + pf fmt "return_call_indirect %a %a" pp_indice tbl_id pp_block_type ty_id + | Return_call_ref ty_id -> pf fmt "return_call_ref %a" pp_block_type ty_id + | Call id -> pf fmt "call %a" pp_indice id | Call_indirect (tbl_id, ty_id) -> - pp fmt "call_indirect %a %a" pp_indice tbl_id pp_block_type ty_id - | Call_ref ty_id -> pp fmt "call_ref %a" pp_indice ty_id - | Array_new id -> pp fmt "array.new %a" pp_indice id + pf fmt "call_indirect %a %a" pp_indice tbl_id pp_block_type ty_id + | Call_ref ty_id -> pf fmt "call_ref %a" pp_indice ty_id + | Array_new id -> pf fmt "array.new %a" pp_indice id | Array_new_data (id1, id2) -> - pp fmt "array.new_data %a %a" pp_indice id1 pp_indice id2 - | Array_new_default id -> pp fmt "array.new_default %a" pp_indice id + pf fmt "array.new_data %a %a" pp_indice id1 pp_indice id2 + | Array_new_default id -> pf fmt "array.new_default %a" pp_indice id | Array_new_elem (id1, id2) -> - pp fmt "array.new_elem %a %a" pp_indice id1 pp_indice id2 - | Array_new_fixed (id, i) -> pp fmt "array.new_fixed %a %d" pp_indice id i - | Array_get id -> pp fmt "array.get %a" pp_indice id - | Array_get_u id -> pp fmt "array.get_u %a" pp_indice id - | Array_set id -> pp fmt "array.set %a" pp_indice id - | Array_len -> pp fmt "array.len" - | Ref_i31 -> pp fmt "ref.i31" - | I31_get_s -> pp fmt "i31.get_s" - | I31_get_u -> pp fmt "i31.get_u" - | Struct_get (i1, i2) -> pp fmt "struct.get %a %a" pp_indice i1 pp_indice i2 + pf fmt "array.new_elem %a %a" pp_indice id1 pp_indice id2 + | Array_new_fixed (id, i) -> pf fmt "array.new_fixed %a %d" pp_indice id i + | Array_get id -> pf fmt "array.get %a" pp_indice id + | Array_get_u id -> pf fmt "array.get_u %a" pp_indice id + | Array_set id -> pf fmt "array.set %a" pp_indice id + | Array_len -> pf fmt "array.len" + | Ref_i31 -> pf fmt "ref.i31" + | I31_get_s -> pf fmt "i31.get_s" + | I31_get_u -> pf fmt "i31.get_u" + | Struct_get (i1, i2) -> pf fmt "struct.get %a %a" pp_indice i1 pp_indice i2 | Struct_get_s (i1, i2) -> - pp fmt "struct.get_s %a %a" pp_indice i1 pp_indice i2 - | Struct_new i -> pp fmt "struct.new %a" pp_indice i - | Struct_new_default i -> pp fmt "struct.new_default %a" pp_indice i - | Struct_set (i1, i2) -> pp fmt "struct.set %a %a" pp_indice i1 pp_indice i2 - | Extern_externalize -> pp fmt "extern.externalize" - | Extern_internalize -> pp fmt "extern.internalize" - | Ref_as_non_null -> pp fmt "ref.as_non_null" + pf fmt "struct.get_s %a %a" pp_indice i1 pp_indice i2 + | Struct_new i -> pf fmt "struct.new %a" pp_indice i + | Struct_new_default i -> pf fmt "struct.new_default %a" pp_indice i + | Struct_set (i1, i2) -> pf fmt "struct.set %a %a" pp_indice i1 pp_indice i2 + | Extern_externalize -> pf fmt "extern.externalize" + | Extern_internalize -> pf fmt "extern.internalize" + | Ref_as_non_null -> pf fmt "ref.as_non_null" | Ref_cast (n, t) -> - pp fmt "ref.cast (ref %a %a)" pp_nullable n pp_heap_type t - | Ref_test (n, t) -> pp fmt "ref.test %a %a" pp_nullable n pp_heap_type t - | Br_on_non_null id -> pp fmt "br_on_non_null %a" pp_indice id - | Br_on_null id -> pp fmt "br_on_null %a" pp_indice id + pf fmt "ref.cast (ref %a %a)" pp_nullable n pp_heap_type t + | Ref_test (n, t) -> pf fmt "ref.test %a %a" pp_nullable n pp_heap_type t + | Br_on_non_null id -> pf fmt "br_on_non_null %a" pp_indice id + | Br_on_null id -> pf fmt "br_on_null %a" pp_indice id | Br_on_cast (id, t1, t2) -> - pp fmt "br_on_cast %a %a %a" pp_indice id pp_ref_type t1 pp_ref_type t2 + pf fmt "br_on_cast %a %a %a" pp_indice id pp_ref_type t1 pp_ref_type t2 | Br_on_cast_fail (id, n, t) -> - pp fmt "br_on_cast_fail %a %a %a" pp_indice id pp_nullable n pp_heap_type t - | Ref_eq -> pp fmt "ref.eq" + pf fmt "br_on_cast_fail %a %a %a" pp_indice id pp_nullable n pp_heap_type t + | Ref_eq -> pf fmt "ref.eq" -and pp_expr fmt instrs = pp_list ~pp_sep:pp_newline pp_instr fmt instrs +and pp_expr fmt instrs = list ~sep:pp_newline pp_instr fmt instrs let rec iter_expr f (e : _ expr) = List.iter (iter_instr f) e @@ -691,25 +763,24 @@ type 'a func = ; id : string option } -let pp_local fmt (id, t) = pp fmt "(local%a %a)" pp_id_opt id pp_val_type t +let pp_local fmt (id, t) = pf fmt "(local%a %a)" pp_id_opt id pp_val_type t -let pp_locals fmt locals = pp_list ~pp_sep:pp_space pp_local fmt locals +let pp_locals fmt locals = list ~sep:sp pp_local fmt locals let pp_func : type kind. formatter -> kind func -> unit = fun fmt f -> (* TODO: typeuse ? *) - pp fmt "(func%a%a%a@\n @[%a@]@\n)" pp_id_opt f.id pp_block_type f.type_f + pf fmt "(func%a%a%a@\n @[%a@]@\n)" pp_id_opt f.id pp_block_type f.type_f (with_space_list pp_locals) f.locals pp_expr f.body -let pp_funcs fmt (funcs : 'a func list) = - pp_list ~pp_sep:pp_newline pp_func fmt funcs +let pp_funcs fmt (funcs : 'a func list) = list ~sep:pp_newline pp_func fmt funcs (* Tables & Memories *) type 'a table = string option * 'a table_type -let pp_table fmt (id, ty) = pp fmt "(table%a %a)" pp_id_opt id pp_table_type ty +let pp_table fmt (id, ty) = pf fmt "(table%a %a)" pp_id_opt id pp_table_type ty (* Modules *) @@ -720,11 +791,11 @@ type 'a import_desc = | Import_global of string option * 'a global_type let import_desc fmt : 'a import_desc -> Unit.t = function - | Import_func (id, t) -> pp fmt "(func%a %a)" pp_id_opt id pp_block_type t - | Import_table (id, t) -> pp fmt "(table%a %a)" pp_id_opt id pp_table_type t - | Import_mem (id, t) -> pp fmt "(memory%a %a)" pp_id_opt id pp_limits t + | Import_func (id, t) -> pf fmt "(func%a %a)" pp_id_opt id pp_block_type t + | Import_table (id, t) -> pf fmt "(table%a %a)" pp_id_opt id pp_table_type t + | Import_mem (id, t) -> pf fmt "(memory%a %a)" pp_id_opt id pp_limits t | Import_global (id, t) -> - pp fmt "(global%a %a)" pp_id_opt id pp_global_type t + pf fmt "(global%a %a)" pp_id_opt id pp_global_type t type 'a import = { modul : string (** The name of the module from which the import is done *) @@ -736,8 +807,8 @@ type 'a import = } let pp_import fmt i = - pp fmt {|(import "%a" "%a" %a)|} pp_string i.modul pp_string i.name - import_desc i.desc + pf fmt {|(import "%a" "%a" %a)|} string i.modul string i.name import_desc + i.desc type 'a export_desc = | Export_func of 'a indice option @@ -746,10 +817,10 @@ type 'a export_desc = | Export_global of 'a indice option let pp_export_desc fmt = function - | Export_func id -> pp fmt "(func %a)" pp_indice_opt id - | Export_table id -> pp fmt "(table %a)" pp_indice_opt id - | Export_mem id -> pp fmt "(memory %a)" pp_indice_opt id - | Export_global id -> pp fmt "(global %a)" pp_indice_opt id + | Export_func id -> pf fmt "(func %a)" pp_indice_opt id + | Export_table id -> pf fmt "(table %a)" pp_indice_opt id + | Export_mem id -> pf fmt "(memory %a)" pp_indice_opt id + | Export_global id -> pf fmt "(global %a)" pp_indice_opt id type 'a export = { name : string @@ -757,7 +828,7 @@ type 'a export = } let pp_export fmt (e : text export) = - pp fmt {|(export "%s" %a)|} e.name pp_export_desc e.desc + pf fmt {|(export "%s" %a)|} e.name pp_export_desc e.desc type 'a storage_type = | Val_storage_t of 'a val_type @@ -767,26 +838,40 @@ let pp_storage_type fmt = function | Val_storage_t t -> pp_val_type fmt t | Val_packed_t t -> pp_packed_type fmt t +let storage_type_eq t1 t2 = + match (t1, t2) with + | Val_storage_t t1, Val_storage_t t2 -> val_type_eq t1 t2 + | Val_packed_t t1, Val_packed_t t2 -> packed_type_eq t1 t2 + | _, _ -> false + type 'a field_type = mut * 'a storage_type let pp_field_type fmt (m, t) = match m with - | Const -> pp fmt " %a" pp_storage_type t - | Var -> pp fmt "(%a %a)" pp_mut m pp_storage_type t + | Const -> pf fmt " %a" pp_storage_type t + | Var -> pf fmt "(%a %a)" pp_mut m pp_storage_type t + +let field_type_eq t1 t2 = + match (t1, t2) with + | (Const, t1), (Const, t2) | (Var, t1), (Var, t2) -> storage_type_eq t1 t2 + | _, _ -> false type 'a struct_field = string option * 'a field_type list -let pp_fields fmt = pp_list ~pp_sep:pp_space pp_field_type fmt +let pp_fields fmt = list ~sep:sp pp_field_type fmt let pp_struct_field fmt ((n : string option), f) = - pp fmt "@\n @[(field%a%a)@]" pp_id_opt n pp_fields f + pf fmt "@\n @[(field%a%a)@]" pp_id_opt n pp_fields f + +let struct_field_eq (_, t1) (_, t2) = List.equal field_type_eq t1 t2 type 'a struct_type = 'a struct_field list -let pp_struct_type fmt = - pp fmt "(struct %a)" (pp_list ~pp_sep:pp_space pp_struct_field) +let pp_struct_type fmt = pf fmt "(struct %a)" (list ~sep:sp pp_struct_field) + +let struct_type_eq t1 t2 = List.equal struct_field_eq t1 t2 -let pp_array_type fmt = pp fmt "(array %a)" pp_field_type +let pp_array_type fmt = pf fmt "(array %a)" pp_field_type type 'a str_type = | Def_struct_t of 'a struct_type @@ -798,27 +883,39 @@ let str_type fmt = function | Def_array_t t -> pp_array_type fmt t | Def_func_t t -> pp_func_type fmt t +let str_type_eq t1 t2 = + match (t1, t2) with + | Def_struct_t t1, Def_struct_t t2 -> struct_type_eq t1 t2 + | Def_array_t t1, Def_array_t t2 -> field_type_eq t1 t2 + | Def_func_t t1, Def_func_t t2 -> func_type_eq t1 t2 + | _, _ -> false + +let compare_str_type t1 t2 = + match (t1, t2) with + | Def_func_t t1, Def_func_t t2 -> compare_func_type t1 t2 + | _, _ -> assert false + type 'a sub_type = final * 'a indice list * 'a str_type let pp_sub_type fmt (f, ids, t) = - pp fmt "(sub %a %a %a)" pp_final f pp_indices ids str_type t + pf fmt "(sub %a %a %a)" pp_final f pp_indices ids str_type t type 'a type_def = string option * 'a sub_type let pp_type_def_no_indent fmt (id, t) = - pp fmt "(type%a %a)" pp_id_opt id pp_sub_type t + pf fmt "(type%a %a)" pp_id_opt id pp_sub_type t -let pp_type_def fmt t = pp fmt "@\n @[%a@]" pp_type_def_no_indent t +let pp_type_def fmt t = pf fmt "@\n @[%a@]" pp_type_def_no_indent t type 'a rec_type = 'a type_def list let pp_rec_type fmt l = match l with | [] -> () - | [ t ] -> pp fmt "@\n%a" pp_type_def_no_indent t - | l -> pp fmt "(rec %a)" (pp_list ~pp_sep:pp_space pp_type_def) l + | [ t ] -> pf fmt "@\n%a" pp_type_def_no_indent t + | l -> pf fmt "(rec %a)" (list ~sep:sp pp_type_def) l -let pp_start fmt start = pp fmt "(start %a)" pp_indice start +let pp_start fmt start = pf fmt "(start %a)" pp_indice start type 'a const = | Const_I32 of Int32.t @@ -834,20 +931,20 @@ type 'a const = | Const_struct let pp_const fmt c = - pp fmt "(%a)" + pf fmt "(%a)" (fun fmt c -> match c with - | Const_I32 i -> pp fmt "i32.const %ld" i - | Const_I64 i -> pp fmt "i64.const %Ld" i - | Const_F32 f -> pp fmt "f32.const %a" Float32.pp f - | Const_F64 f -> pp fmt "f64.const %a" Float64.pp f - | Const_null rt -> pp fmt "ref.null %a" pp_heap_type rt - | Const_host i -> pp fmt "ref.host %d" i - | Const_extern i -> pp fmt "ref.extern %d" i - | Const_array -> pp fmt "ref.array" - | Const_eq -> pp fmt "ref.eq" - | Const_i31 -> pp fmt "ref.i31" - | Const_struct -> pp fmt "ref.struct" ) + | Const_I32 i -> pf fmt "i32.const %ld" i + | Const_I64 i -> pf fmt "i64.const %Ld" i + | Const_F32 f -> pf fmt "f32.const %a" Float32.pp f + | Const_F64 f -> pf fmt "f64.const %a" Float64.pp f + | Const_null rt -> pf fmt "ref.null %a" pp_heap_type rt + | Const_host i -> pf fmt "ref.host %d" i + | Const_extern i -> pf fmt "ref.extern %d" i + | Const_array -> pf fmt "ref.array" + | Const_eq -> pf fmt "ref.eq" + | Const_i31 -> pf fmt "ref.i31" + | Const_struct -> pf fmt "ref.struct" ) c -let pp_consts fmt c = pp_list ~pp_sep:pp_space pp_const fmt c +let pp_consts fmt c = list ~sep:sp pp_const fmt c diff --git a/src/bin/dune b/src/bin/dune index 2e55e8573..f21d3726f 100644 --- a/src/bin/dune +++ b/src/bin/dune @@ -4,6 +4,8 @@ (public_name owi) (modules owi) (package owi) - (libraries sedlex owi cmdliner) + (libraries cmdliner owi prelude sedlex) (instrumentation - (backend bisect_ppx))) + (backend bisect_ppx)) + (flags + (:standard -open Prelude))) diff --git a/src/bin/owi.ml b/src/bin/owi.ml index b7b50dc21..588730178 100644 --- a/src/bin/owi.ml +++ b/src/bin/owi.ml @@ -13,7 +13,7 @@ let existing_non_dir_file = let path = Fpath.v s in match Bos.OS.File.exists path with | Ok true -> `Ok path - | Ok false -> `Error (Format.asprintf "no file '%a'" Fpath.pp path) + | Ok false -> `Error (Fmt.str "no file '%a'" Fpath.pp path) | Error (`Msg s) -> `Error s in (parse, Fpath.pp) @@ -272,7 +272,7 @@ let exit_code = match result with | Ok () -> ok | Error e -> begin - Format.pp_err "%s" (Result.err_to_string e); + Fmt.epr "%s" (Result.err_to_string e); match e with | `No_error -> ok | `Alignment_too_large -> 1 diff --git a/src/cmd/cmd_c.ml b/src/cmd/cmd_c.ml index 77b4a265c..703d70b07 100644 --- a/src/cmd/cmd_c.ml +++ b/src/cmd/cmd_c.ml @@ -27,7 +27,7 @@ let find location file : Fpath.t Result.t = location in let rec loop = function - | [] -> Error (`Msg (Format.asprintf "can't find file %a" Fpath.pp file)) + | [] -> Error (`Msg (Fmt.str "can't find file %a" Fpath.pp file)) | None :: tl -> loop tl | Some file :: _tl -> Ok file in @@ -39,7 +39,7 @@ let compile ~includes ~opt_lvl (files : Fpath.t list) : Fpath.t Result.t = let includes = Cmd.of_list ~slip:"-I" (List.map Fpath.to_string includes) in Cmd.( of_list - [ "-O" ^ opt_lvl + [ Fmt.str "-O%s" opt_lvl ; "--target=wasm32" ; "-m32" ; "-ffreestanding" @@ -51,7 +51,7 @@ let compile ~includes ~opt_lvl (files : Fpath.t list) : Fpath.t Result.t = ; "-Wl,--export=main" (* TODO: allow this behind a flag, this is slooooow *) ; "-Wl,--lto-O0" - ; "-Wl,-z,stack-size=" ^ stack_size + ; Fmt.str "-Wl,-z,stack-size=%s" stack_size ] %% includes ) in @@ -71,7 +71,7 @@ let compile ~includes ~opt_lvl (files : Fpath.t list) : Fpath.t Result.t = let pp_tm fmt Unix.{ tm_year; tm_mon; tm_mday; tm_hour; tm_min; tm_sec; _ } : unit = - Format.pp fmt "%04d-%02d-%02dT%02d:%02d:%02dZ" (tm_year + 1900) tm_mon tm_mday + Fmt.pf fmt "%04d-%02d-%02dT%02d:%02d:%02dZ" (tm_year + 1900) tm_mon tm_mday tm_hour tm_min tm_sec let metadata ~workspace arch property files : unit Result.t = @@ -102,8 +102,8 @@ let metadata ~workspace arch property files : unit Result.t = ; el "programfile" file ; el "programhash" hash ; el "entryfunction" "main" - ; el "architecture" (Format.sprintf "%dbit" arch) - ; el "creationtime" (Format.asprintf "%a" pp_tm time) + ; el "architecture" (Fmt.str "%dbit" arch) + ; el "creationtime" (Fmt.str "%a" pp_tm time) ] ) in let dtd = diff --git a/src/cmd/cmd_conc.ml b/src/cmd/cmd_conc.ml index 0b2867bff..8955d8d81 100644 --- a/src/cmd/cmd_conc.ml +++ b/src/cmd/cmd_conc.ml @@ -82,7 +82,7 @@ type trace = ; end_of_trace : end_of_trace } -module IMap = Map.Make (Stdlib.Int32) +module IMap = Map.Make (Prelude.Int32) module Unexplored : sig type t @@ -288,8 +288,8 @@ let run_once tree link_state modules_to_run forced_values = { assignments = symbols_value; remaining_pc = List.rev pc; end_of_trace } in if debug then begin - Format.pp_std "Add trace:@\n"; - Format.pp_std "%a@\n" Concolic_choice.pp_pc trace.remaining_pc + Fmt.pr "Add trace:@\n"; + Fmt.pr "%a@\n" Concolic_choice.pp_pc trace.remaining_pc end; add_trace tree trace; r @@ -299,7 +299,7 @@ let rec find_node_to_run tree = match tree.node with | Not_explored -> if debug then begin - Format.pp_std "Try unexplored@.%a@.@." Concolic_choice.pp_pc tree.pc + Fmt.pr "Try unexplored@.%a@.@." Concolic_choice.pp_pc tree.pc end; Some tree.pc | Select { cond = _; if_true; if_false } -> @@ -309,7 +309,7 @@ let rec find_node_to_run tree = else Random.bool () in if debug then begin - Format.pp_std "Select bool %b@." b + Fmt.pr "Select bool %b@." b end; let tree = if b then if_true else if_false in find_node_to_run tree @@ -325,18 +325,18 @@ let rec find_node_to_run tree = let i = Random.int n in let i, branch = List.nth branches i in if debug then begin - Format.pp_std "Select_i32 %li@." i + Fmt.pr "Select_i32 %li@." i end; find_node_to_run branch end | Assume { cond = _; cont } -> find_node_to_run cont | Assert { cond; cont = _; disproved = None } -> let pc : Concolic_choice.pc = Select (cond, false) :: tree.pc in - Format.pp_std "Try Assert@.%a@.@." Concolic_choice.pp_pc pc; + Fmt.pr "Try Assert@.%a@.@." Concolic_choice.pp_pc pc; Some pc | Assert { cond = _; cont; disproved = Some _ } -> find_node_to_run cont | Unreachable -> - Format.pp_std "Unreachable (Retry)@.%a@." Concolic_choice.pp_pc tree.pc; + Fmt.pr "Unreachable (Retry)@.%a@." Concolic_choice.pp_pc tree.pc; None let pc_model solver pc = @@ -356,7 +356,7 @@ let find_model_to_run solver tree = let launch solver tree link_state modules_to_run = let rec find_model n = if n = 0 then begin - Format.pp_std "Failed to find something to run@\n"; + Fmt.pr "Failed to find something to run@\n"; None end else @@ -364,7 +364,7 @@ let launch solver tree link_state modules_to_run = | None -> find_model (n - 1) | Some m -> if debug then begin - Format.pp_std "Found something to run %a@\n" + Fmt.pr "Found something to run %a@\n" (Smtml.Model.pp ~no_values:false) m end; @@ -382,14 +382,14 @@ let launch solver tree link_state modules_to_run = | Ok (Error e) -> Result.failwith e | Error (Assume_fail c) -> begin if debug then begin - Format.pp_std "Assume_fail: %a@\n" Smtml.Expr.pp c; - Format.pp_std "Assignments:@\n%a@\n" Concolic_choice.pp_assignments + Fmt.pr "Assume_fail: %a@\n" Smtml.Expr.pp c; + Fmt.pr "Assignments:@\n%a@\n" Concolic_choice.pp_assignments thread.symbols_value; - Format.pp_std "Retry !@\n" + Fmt.pr "Retry !@\n" end; match pc_model solver thread.pc with | None -> - Format.pp_err "Can't satisfy assume !@\n"; + Fmt.epr "Can't satisfy assume !@\n"; loop (count - 1) | Some _model as model -> run_model model (count - 1) end @@ -402,10 +402,8 @@ let launch solver tree link_state modules_to_run = during evaluation (OS, syntax error, etc.), except for Trap and Assert, which are handled here. Most of the computations are done in the Result monad, hence the let*. *) -let cmd profiling debug unsafe optimize workers no_stop_at_failure no_values - deterministic_result_order (workspace : Fpath.t) solver files = - ignore (workers, no_stop_at_failure, deterministic_result_order, workspace); - +let cmd profiling debug unsafe optimize _workers _no_stop_at_failure no_values + _deterministic_result_order (workspace : Fpath.t) solver files = if profiling then Log.profiling_on := true; if debug then Log.debug_on := true; @@ -420,22 +418,30 @@ let cmd profiling debug unsafe optimize workers no_stop_at_failure no_values let result = launch solver tree link_state modules_to_run in let print_pc pc = - Format.pp_std "PC:@\n"; - Format.pp_std "%a@\n" Concolic_choice.pp_pc pc + Fmt.pr "PC:@\n"; + Fmt.pr "%a@\n" Concolic_choice.pp_pc pc in let print_values symbols_value = - Format.pp_std "Assignments:@\n"; + Fmt.pr "Assignments:@\n"; List.iter - (fun (s, v) -> - Format.pp_std " %a: %a" Smtml.Symbol.pp s Concrete_value.pp v ) + (fun (s, v) -> Fmt.pr " %a: %a" Smtml.Symbol.pp s Concrete_value.pp v) symbols_value; - Format.pp_std "@\n" + Fmt.pr "@\n" in let testcase model = if not no_values then let testcase = - List.sort compare (Smtml.Model.get_bindings model) |> List.map snd + let compare_pair fx fy (x1, y1) (x2, y2) = + let cx = fx x1 x2 in + if cx = 0 then fy y1 y2 else cx + in + (* TODO: add a function for this in smtml *) + (* TODO: merge this code with cmd_sym, it's almost the same.. *) + List.sort + (compare_pair Smtml.Symbol.compare Smtml.Value.compare) + (Smtml.Model.get_bindings model) + |> List.map snd in Cmd_utils.write_testcase ~dir:workspace testcase else Ok () @@ -443,27 +449,27 @@ let cmd profiling debug unsafe optimize workers no_stop_at_failure no_values match result with | None -> - Format.pp_std "OK@\n"; + Fmt.pr "OK@\n"; Ok () | Some (`Trap trap, thread) -> - Format.pp_std "Trap: %s@\n" (Trap.to_string trap); + Fmt.pr "Trap: %s@\n" (Trap.to_string trap); if debug then begin print_pc thread.pc; print_values thread.symbols_value end; let symbols = None in let model = get_model ~symbols solver thread.pc in - Format.pp_std "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model; + Fmt.pr "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model; let* () = testcase model in Error (`Found_bug 1) | Some (`Assert_fail, thread) -> - Format.pp_std "Assert failure@\n"; + Fmt.pr "Assert failure@\n"; if debug then begin print_pc thread.pc; print_values thread.symbols_value end; let symbols = None in let model = get_model ~symbols solver thread.pc in - Format.pp_std "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model; + Fmt.pr "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model; let* () = testcase model in Error (`Found_bug 1) diff --git a/src/cmd/cmd_fmt.ml b/src/cmd/cmd_fmt.ml index 626cde5c0..b380f8b58 100644 --- a/src/cmd/cmd_fmt.ml +++ b/src/cmd/cmd_fmt.ml @@ -19,19 +19,11 @@ let cmd_one inplace file = match get_printer file with | Error _e as e -> e | Ok pp -> - if inplace then - let* res = - Bos.OS.File.with_oc file - (fun chan () -> - let fmt = Stdlib.Format.formatter_of_out_channel chan in - Ok (Format.pp fmt "%a@\n" pp ()) ) - () - in - res - else Ok (Format.pp_std "%a@\n" pp ()) + if inplace then Bos.OS.File.writef file "%a@\n" pp () + else Ok (Fmt.pr "%a@\n" pp ()) let cmd inplace files = list_iter (cmd_one inplace) files let format_file_to_string file = let+ pp = get_printer file in - Format.asprintf "%a@\n" pp () + Fmt.str "%a@\n" pp () diff --git a/src/cmd/cmd_opt.ml b/src/cmd/cmd_opt.ml index 5c8735877..2ed335503 100644 --- a/src/cmd/cmd_opt.ml +++ b/src/cmd/cmd_opt.ml @@ -21,7 +21,7 @@ let cmd debug unsafe files = match optimize_file ~unsafe file with | Ok m -> let m = Binary_to_text.modul m in - Format.pp_std "%a@\n" Text.pp_modul m; + Fmt.pr "%a@\n" Text.pp_modul m; Ok () | Error _ as e -> e ) files diff --git a/src/cmd/cmd_sym.ml b/src/cmd/cmd_sym.ml index 1c916a308..771a6c188 100644 --- a/src/cmd/cmd_sym.ml +++ b/src/cmd/cmd_sym.ml @@ -75,11 +75,11 @@ let cmd profiling debug unsafe optimize workers no_stop_at_failure no_values in let print_bug = function | `ETrap (tr, model) -> - Format.pp_std "Trap: %s@\n" (Trap.to_string tr); - Format.pp_std "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model + Fmt.pr "Trap: %s@\n" (Trap.to_string tr); + Fmt.pr "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model | `EAssert (assertion, model) -> - Format.pp_std "Assert failure: %a@\n" Expr.pp assertion; - Format.pp_std "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model + Fmt.pr "Assert failure: %a@\n" Expr.pp assertion; + Fmt.pr "Model:@\n @[%a@]@." (Smtml.Model.pp ~no_values) model in let rec print_and_count_failures count_acc results = match results () with @@ -95,14 +95,22 @@ let cmd profiling debug unsafe optimize workers no_stop_at_failure no_values print_bug (`ETrap (tr, model)); Ok model | EVal (Ok ()) -> - failwith "unreachable: callback should have filtered eval ok out." + Fmt.failwith "unreachable: callback should have filtered eval ok out." | EVal (Error e) -> Error e in let count_acc = succ count_acc in let* () = if not no_values then let testcase = - List.sort compare (Smtml.Model.get_bindings model) |> List.map snd + let compare_pair fx fy (x1, y1) (x2, y2) = + let cx = fx x1 x2 in + if cx = 0 then fy y1 y2 else cx + in + (* TODO: add a function for this in smtml *) + List.sort + (compare_pair Smtml.Symbol.compare Smtml.Value.compare) + (Smtml.Model.get_bindings model) + |> List.map snd in Cmd_utils.write_testcase ~dir:workspace testcase else Ok () @@ -117,13 +125,13 @@ let cmd profiling debug unsafe optimize workers no_stop_at_failure no_values (x, List.rev @@ Thread.breadcrumbs thread) ) |> List.of_seq |> List.sort (fun (_, bc1) (_, bc2) -> - List.compare Stdlib.Int32.compare bc1 bc2 ) + List.compare Prelude.Int32.compare bc1 bc2 ) |> List.to_seq |> Seq.map fst else results in let* count = print_and_count_failures 0 results in if count > 0 then Error (`Found_bug count) else begin - Format.pp_std "All OK"; + Fmt.pr "All OK"; Ok () end diff --git a/src/cmd/cmd_utils.ml b/src/cmd/cmd_utils.ml index be7ac19d0..f16bfceb8 100644 --- a/src/cmd/cmd_utils.ml +++ b/src/cmd/cmd_utils.ml @@ -8,7 +8,7 @@ let out_testcase ~dst testcase = let o = Xmlm.make_output ~nl:true ~indent:(Some 2) dst in let tag atts name = (("", name), atts) in let atts = [ (("", "coversError"), "true") ] in - let to_string v = Format.asprintf "%a" Smtml.Value.pp_num v in + let to_string v = Fmt.str "%a" Smtml.Value.pp_num v in let input v = `El (tag [] "input", [ `Data (to_string v) ]) in let testcase = `El (tag atts "testcase", List.map input testcase) in let dtd = @@ -21,7 +21,7 @@ let write_testcase = let cnt = ref 0 in fun ~dir testcase -> incr cnt; - let name = Format.ksprintf Fpath.v "testcase-%d.xml" !cnt in + let name = Fmt.kstr Fpath.v "testcase-%d.xml" !cnt in let path = Fpath.append dir name in let* res = Bos.OS.File.with_oc path @@ -61,7 +61,7 @@ let add_main_as_start (m : Binary.modul) = | Ref_type (Types.No_null, t) -> Error (`Msg - (Format.asprintf "can not create default value of type %a" + (Fmt.str "can not create default value of type %a" Types.pp_heap_type t ) ) in let+ body = diff --git a/src/cmd/cmd_wasm2wat.ml b/src/cmd/cmd_wasm2wat.ml index ff7ffa4f0..4559e79d5 100644 --- a/src/cmd/cmd_wasm2wat.ml +++ b/src/cmd/cmd_wasm2wat.ml @@ -10,7 +10,7 @@ let cmd_one file = | ".wasm" -> let* m = Parse.Binary.Module.from_file file in let m = Binary_to_text.modul m in - Ok (Format.pp_std "%a@\n" Text.pp_modul m) + Ok (Fmt.pr "%a@\n" Text.pp_modul m) | ext -> Error (`Unsupported_file_extension ext) let cmd files = list_iter cmd_one files diff --git a/src/concolic/concolic.ml b/src/concolic/concolic.ml index dcf7723bd..3889fc3dd 100644 --- a/src/concolic/concolic.ml +++ b/src/concolic/concolic.ml @@ -53,7 +53,7 @@ module P = struct } | Ref _, Ref _ -> (* Concretization: add something to the PC *) - failwith "TODO" + Fmt.failwith "TODO" | _, _ -> assert false module Global = struct @@ -178,7 +178,7 @@ module P = struct ~dst:dst.symbolic ~len:len.symbolic } - let get_limit_max _ = failwith "TODO" + let get_limit_max _ = Fmt.failwith "TODO" end module Extern_func = Concrete_value.Make_extern_func (Value) (Choice) (Memory) diff --git a/src/concolic/concolic_choice.ml b/src/concolic/concolic_choice.ml index f897c0df4..126e40c28 100644 --- a/src/concolic/concolic_choice.ml +++ b/src/concolic/concolic_choice.ml @@ -15,17 +15,17 @@ type pc_elt = | Assert of Symbolic_value.vbool let pp_pc_elt fmt = function - | Select (c, v) -> Format.pp fmt "Select(%a, %b)" Smtml.Expr.pp c v - | Select_i32 (c, v) -> Format.pp fmt "Select_i32(%a, %li)" Smtml.Expr.pp c v - | Assume c -> Format.pp fmt "Assume(%a)" Smtml.Expr.pp c - | Assert c -> Format.pp fmt "Assert(%a)" Smtml.Expr.pp c + | Select (c, v) -> Fmt.pf fmt "Select(%a, %b)" Smtml.Expr.pp c v + | Select_i32 (c, v) -> Fmt.pf fmt "Select_i32(%a, %li)" Smtml.Expr.pp c v + | Assume c -> Fmt.pf fmt "Assume(%a)" Smtml.Expr.pp c + | Assert c -> Fmt.pf fmt "Assert(%a)" Smtml.Expr.pp c -let pp_pc fmt pc = List.iter (fun e -> Format.pp fmt " %a@\n" pp_pc_elt e) pc +let pp_pc fmt pc = List.iter (fun e -> Fmt.pf fmt " %a@\n" pp_pc_elt e) pc let pp_assignments fmt assignments = List.iter (fun (sym, v) -> - Format.pp fmt " %a : %a@\n" Smtml.Symbol.pp sym Concrete_value.pp v ) + Fmt.pf fmt " %a : %a@\n" Smtml.Symbol.pp sym Concrete_value.pp v ) assignments let pc_elt_to_expr = function @@ -55,7 +55,7 @@ type thread = let init_thread preallocated_values shared = { symbols = 0; pc = []; symbols_value = []; preallocated_values; shared } -type 'a run_result = ('a, err) Stdlib.Result.t * thread +type 'a run_result = ('a, err) Prelude.Result.t * thread type 'a t = M of (thread -> 'a run_result) [@@unboxed] @@ -134,7 +134,7 @@ let with_new_symbol ty f = M (fun st -> let id = st.symbols + 1 in - let sym = Format.kasprintf (Smtml.Symbol.make ty) "symbol_%d" id in + let sym = Fmt.kstr (Smtml.Symbol.make ty) "symbol_%d" id in let value = Hashtbl.find_opt st.preallocated_values sym in let concrete, v = f sym value in let st = diff --git a/src/concolic/concolic_value.ml b/src/concolic/concolic_value.ml index 39a09d93c..db607edbf 100644 --- a/src/concolic/concolic_value.ml +++ b/src/concolic/concolic_value.ml @@ -17,25 +17,23 @@ module T_pair (C : Value_intf.T) (S : Value_intf.T) = struct type int32 = (C.int32, S.int32) cs let pp_int32 fmt v = - Format.pp fmt "{ c = %a ; s = %a }" C.pp_int32 v.concrete S.pp_int32 - v.symbolic + Fmt.pf fmt "{ c = %a ; s = %a }" C.pp_int32 v.concrete S.pp_int32 v.symbolic type int64 = (C.int64, S.int64) cs let pp_int64 fmt v = - Format.pp fmt "{ c = %a ; s = %a }" C.pp_int64 v.concrete S.pp_int64 - v.symbolic + Fmt.pf fmt "{ c = %a ; s = %a }" C.pp_int64 v.concrete S.pp_int64 v.symbolic type float32 = (C.float32, S.float32) cs let pp_float32 fmt v = - Format.pp fmt "{ c = %a ; s = %a }" C.pp_float32 v.concrete S.pp_float32 + Fmt.pf fmt "{ c = %a ; s = %a }" C.pp_float32 v.concrete S.pp_float32 v.symbolic type float64 = (C.float64, S.float64) cs let pp_float64 fmt v = - Format.pp fmt "{ c = %a ; s = %a }" C.pp_float64 v.concrete S.pp_float64 + Fmt.pf fmt "{ c = %a ; s = %a }" C.pp_float64 v.concrete S.pp_float64 v.symbolic (* TODO: Probably beter not to have a different value for both, @@ -43,7 +41,7 @@ module T_pair (C : Value_intf.T) (S : Value_intf.T) = struct type ref_value = (C.ref_value, S.ref_value) cs let pp_ref_value fmt v = - Format.pp fmt "{ c = %a ; s = %a }" C.pp_ref_value v.concrete S.pp_ref_value + Fmt.pf fmt "{ c = %a ; s = %a }" C.pp_ref_value v.concrete S.pp_ref_value v.symbolic type t = @@ -126,8 +124,7 @@ module T_pair (C : Value_intf.T) (S : Value_intf.T) = struct let ref_is_null v = f_pair_1 C.ref_is_null S.ref_is_null v let mk_pp c symbolic ppf v = - Stdlib.Format.fprintf ppf "@[{c: %a@, s: %a}@]" c v.concrete symbolic - v.symbolic + Fmt.pf ppf "@[{c: %a@, s: %a}@]" c v.concrete symbolic v.symbolic let pp fmt = function | I32 i -> pp_int32 fmt i @@ -138,7 +135,7 @@ module T_pair (C : Value_intf.T) (S : Value_intf.T) = struct module Ref = struct let equal_func_intf (_ : Func_intf.t) (_ : Func_intf.t) : bool = - failwith "TODO equal_func_intf" + Fmt.failwith "TODO equal_func_intf" let get_func ref : Func_intf.t Value_intf.get_ref = match (C.Ref.get_func ref.concrete, S.Ref.get_func ref.symbolic) with @@ -445,7 +442,7 @@ module T_pair (C : Value_intf.T) (S : Value_intf.T) = struct include MK_Iop (struct - type t = Stdlib.Int32.t + type t = Int32.t end) (struct type t = C.int32 @@ -467,7 +464,7 @@ module T_pair (C : Value_intf.T) (S : Value_intf.T) = struct include MK_Iop (struct - type t = Stdlib.Int64.t + type t = Int64.t end) (struct type t = C.int64 diff --git a/src/concolic/concolic_wasm_ffi.ml b/src/concolic/concolic_wasm_ffi.ml index e1cba8924..ce532e396 100644 --- a/src/concolic/concolic_wasm_ffi.ml +++ b/src/concolic/concolic_wasm_ffi.ml @@ -107,9 +107,7 @@ module M : Log.debug2 {|free: cannot fetch pointer base of "%a"|} Expr.pp v.symbolic; Choice.bind (abort ()) (fun () -> assert false) - let exit (p : Value.int32) : unit Choice.t = - ignore p; - abort () + let exit (_p : Value.int32) : unit Choice.t = abort () let alloc _ (base : Value.int32) (_size : Value.int32) : Value.int32 Choice.t = diff --git a/src/concrete/concrete_value.ml b/src/concrete/concrete_value.ml index 7c062a46a..265a25a25 100644 --- a/src/concrete/concrete_value.ml +++ b/src/concrete/concrete_value.ml @@ -3,7 +3,7 @@ (* Written by the Owi programmers *) open Types -open Format +open Fmt module Make_extern_func (V : Func_intf.Value_types) @@ -114,9 +114,9 @@ type ref_value = | Arrayref of unit Array.t option let pp_ref_value fmt = function - | Externref _ -> pp fmt "externref" - | Funcref _ -> pp fmt "funcref" - | Arrayref _ -> pp fmt "array" + | Externref _ -> pf fmt "externref" + | Funcref _ -> pf fmt "funcref" + | Arrayref _ -> pf fmt "array" type t = | I32 of Int32.t @@ -142,10 +142,10 @@ let to_instr = function | Ref _ -> assert false let pp fmt = function - | I32 i -> pp fmt "i32.const %ld" i - | I64 i -> pp fmt "i64.const %Ld" i - | F32 f -> pp fmt "f32.const %a" Float32.pp f - | F64 f -> pp fmt "f64.const %a" Float64.pp f + | I32 i -> pf fmt "i32.const %ld" i + | I64 i -> pf fmt "i64.const %Ld" i + | F32 f -> pf fmt "f32.const %a" Float32.pp f + | F64 f -> pf fmt "f64.const %a" Float64.pp f | Ref r -> pp_ref_value fmt r let ref_null' = function diff --git a/src/concrete/concrete_value.mli b/src/concrete/concrete_value.mli index 79aa4eb94..15446e5af 100644 --- a/src/concrete/concrete_value.mli +++ b/src/concrete/concrete_value.mli @@ -35,7 +35,7 @@ type ref_value = | Funcref of Func_intf.t option | Arrayref of unit array option -val pp_ref_value : Format.formatter -> ref_value -> unit +val pp_ref_value : Fmt.formatter -> ref_value -> unit type t = | I32 of Int32.t @@ -60,4 +60,4 @@ val ref_externref : 'a Type.Id.t -> 'a -> t val ref_is_null : ref_value -> bool -val pp : Format.formatter -> t -> unit +val pp : Fmt.formatter -> t -> unit diff --git a/src/concrete/v.ml b/src/concrete/v.ml index 2f7a751e2..312431682 100644 --- a/src/concrete/v.ml +++ b/src/concrete/v.ml @@ -8,11 +8,11 @@ include ( type int32 = Int32.t - let pp_int32 fmt i = Format.pp fmt "%ld" i + let pp_int32 fmt i = Fmt.pf fmt "%ld" i type int64 = Int64.t - let pp_int64 fmt i = Format.pp fmt "%Ld" i + let pp_int64 fmt i = Fmt.pf fmt "%Ld" i type float32 = Float32.t @@ -62,7 +62,7 @@ include ( let int32 = function true -> 1l | false -> 0l - let pp = Format.pp_bool + let pp = Fmt.bool end module I32 = struct @@ -70,11 +70,15 @@ include ( include Convert.Int32 let to_bool i = Int32.ne i 0l + + let eq_const = eq end module I64 = struct include Int64 include Convert.Int64 + + let eq_const = eq end module F32 = struct diff --git a/src/data_structures/indexed.mli b/src/data_structures/indexed.mli index 1aab6820b..08d335838 100644 --- a/src/data_structures/indexed.mli +++ b/src/data_structures/indexed.mli @@ -18,4 +18,4 @@ val get_at_exn : int -> 'a t list -> 'a val has_index : int -> 'a t -> bool -val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit +val pp : (Fmt.formatter -> 'a -> unit) -> Fmt.formatter -> 'a t -> unit diff --git a/src/data_structures/stack.ml b/src/data_structures/stack.ml index f73512c2c..02e56236d 100644 --- a/src/data_structures/stack.ml +++ b/src/data_structures/stack.ml @@ -21,7 +21,7 @@ module type S = sig val empty : t - val pp : Format.formatter -> t -> unit + val pp : Fmt.formatter -> t -> unit (** pop operations *) @@ -130,7 +130,7 @@ module Make (V : Value_intf.T) : let push_array _ _ = assert false let pp fmt (s : t) = - Format.pp_list ~pp_sep:(fun fmt () -> Format.pp_string fmt " ; ") V.pp fmt s + Fmt.list ~sep:(fun fmt () -> Fmt.string fmt " ; ") V.pp fmt s let pop = function [] -> raise Empty | hd :: tl -> (hd, tl) @@ -213,5 +213,8 @@ module Make (V : Value_intf.T) : let rec drop_n s n = if n = 0 then s - else match s with [] -> invalid_arg "drop_n" | _ :: tl -> drop_n tl (n - 1) + else + match s with + | [] -> Fmt.invalid_arg "drop_n" + | _ :: tl -> drop_n tl (n - 1) end diff --git a/src/dune b/src/dune index 738236bb2..8df966e3f 100644 --- a/src/dune +++ b/src/dune @@ -38,7 +38,6 @@ env_id float32 float64 - format func_id func_intf grouped @@ -102,11 +101,14 @@ menhirLib ocaml_intrinsics ppxlib + prelude processor sedlex uutf runtime_events xmlm) + (flags + (:standard -open Prelude)) (preprocess (pps sedlex.ppx)) (instrumentation diff --git a/src/interpret/interpret.ml b/src/interpret/interpret.ml index ca41ca45d..3c761c975 100644 --- a/src/interpret/interpret.ml +++ b/src/interpret/interpret.ml @@ -89,8 +89,6 @@ module Make (P : Interpret_intf.P) : let* b = select b in return (b, stack) - let p_type_eq (_id1, t1) (_id2, t2) = t1 = t2 - let ( let> ) v f = let* v = select v in f v @@ -400,28 +398,29 @@ module Make (P : Interpret_intf.P) : end let exec_fconverti stack nn nn' sx = + let is_signed = match sx with S -> true | U -> false in match nn with | S32 -> ( let open F32 in match nn' with | S32 -> let n, stack = Stack.pop_i32 stack in - let n = if sx = S then convert_i32_s n else convert_i32_u n in + let n = if is_signed then convert_i32_s n else convert_i32_u n in Stack.push_f32 stack n | S64 -> let n, stack = Stack.pop_i64 stack in - let n = if sx = S then convert_i64_s n else convert_i64_u n in + let n = if is_signed then convert_i64_s n else convert_i64_u n in Stack.push_f32 stack n ) | S64 -> ( let open F64 in match nn' with | S32 -> let n, stack = Stack.pop_i32 stack in - let n = if sx = S then convert_i32_s n else convert_i32_u n in + let n = if is_signed then convert_i32_s n else convert_i32_u n in Stack.push_f64 stack n | S64 -> let n, stack = Stack.pop_i64 stack in - let n = if sx = S then convert_i64_s n else convert_i64_u n in + let n = if is_signed then convert_i64_s n else convert_i64_u n in Stack.push_f64 stack n ) let exec_ireinterpretf stack nn nn' = @@ -632,25 +631,29 @@ module Make (P : Interpret_intf.P) : let rec print_count ppf count = let calls ppf tbl = let l = - List.sort (fun (id1, _) (id2, _) -> compare id1 id2) + (* TODO: move this to Types.ml *) + List.sort + (fun + ((Raw id1 : binary indice), _) ((Raw id2 : binary indice), _) -> + compare id1 id2 ) @@ List.of_seq @@ Hashtbl.to_seq tbl in match l with | [] -> () | _ :: _ -> - Format.pp ppf "@ @[calls@ %a@]" - (Format.pp_list - ~pp_sep:(fun ppf () -> Format.pp ppf "@ ") + Fmt.pf ppf "@ @[calls@ %a@]" + (Fmt.list + ~sep:(fun ppf () -> Fmt.pf ppf "@ ") (fun ppf ((Raw id : binary indice), count) -> let name ppf = function | None -> () - | Some name -> Format.pp ppf " %s" name + | Some name -> Fmt.pf ppf " %s" name in - Format.pp ppf "@[id %i%a@ %a@]" id name count.name + Fmt.pf ppf "@[id %i%a@ %a@]" id name count.name print_count count ) ) l in - Format.pp ppf "@[enter %i@ intrs %i%a@]" count.enter count.instructions + Fmt.pf ppf "@[enter %i@ intrs %i%a@]" count.enter count.instructions calls count.calls let empty_count name = @@ -772,8 +775,7 @@ module Make (P : Interpret_intf.P) : let f = Env.get_extern_func state.env f in Extern_func.extern_type f - let call_ref ~return (state : State.exec_state) typ_i = - ignore (return, state, typ_i); + let call_ref ~return:_ (_state : State.exec_state) _typ_i = (* TODO *) assert false (* let fun_ref, stack = Stack.pop_as_ref state.stack in *) @@ -797,8 +799,8 @@ module Make (P : Interpret_intf.P) : let state = { state with stack } in let* t = Env.get_table state.env tbl_i in let _null, ref_kind = Table.typ t in - if ref_kind <> Func_ht then Choice.trap Indirect_call_type_mismatch - else + match ref_kind with + | Func_ht -> let size = Table.size t in let> out_of_bounds = Bool.or_ I32.(fun_i < const 0l) @@ I32.(consti size <= fun_i) @@ -808,15 +810,18 @@ module Make (P : Interpret_intf.P) : let* fun_i = Choice.select_i32 fun_i in let fun_i = Int32.to_int fun_i in let f_ref = Table.get t fun_i in - match Ref.get_func f_ref with - | Null -> Choice.trap (Uninitialized_element fun_i) - | Type_mismatch -> Choice.trap Element_type_error - | Ref_value func -> - let pt, rt = func_type state func in - let pt', rt' = typ_i in - if not (rt = rt' && List.equal p_type_eq pt pt') then - Choice.trap Indirect_call_type_mismatch - else exec_vfunc ~return state func + begin + match Ref.get_func f_ref with + | Null -> Choice.trap (Uninitialized_element fun_i) + | Type_mismatch -> Choice.trap Element_type_error + | Ref_value func -> + let ft = func_type state func in + let ft' = typ_i in + if not (Types.func_type_eq ft ft') then + Choice.trap Indirect_call_type_mismatch + else exec_vfunc ~return state func + end + | _ -> Choice.trap Indirect_call_type_mismatch let exec_instr instr (state : State.exec_state) : State.instr_result Choice.t = @@ -1018,7 +1023,6 @@ module Make (P : Interpret_intf.P) : st @@ Stack.push stack (Global.value g) | Global_set (Raw i) -> let* global = Env.get_global env i in - if Global.mut global = Const then Log.err "Can't set const global"; let v, stack = match Global.typ global with | Ref_type _rt -> Stack.pop_ref stack @@ -1071,13 +1075,13 @@ module Make (P : Interpret_intf.P) : let delta, stack = Stack.pop_i32 stack in let new_size = I32.(size + delta) in let allowed = - ( match Table.max_size t with - | None -> true - | Some max -> consti max >= new_size ) - && new_size >= const 0l - && new_size >= size + Bool.and_ + ( match Table.max_size t with + | None -> Bool.const true + | Some max -> I32.ge (consti max) new_size ) + @@ Bool.and_ (I32.ge new_size (const 0l)) (I32.ge new_size size) in - + let> allowed in if not allowed then let stack = Stack.drop stack in st @@ Stack.push_i32_of_int stack (-1) @@ -1091,10 +1095,11 @@ module Make (P : Interpret_intf.P) : let len, stack = Stack.pop_i32 stack in let x, stack = Stack.pop_as_ref stack in let pos, stack = Stack.pop_i32 stack in - let out_of_bounds = - len < const 0l - || pos < const 0l - || I32.(pos + len) > consti (Table.size t) + let> out_of_bounds = + Bool.or_ (I32.lt len (const 0l)) + @@ Bool.or_ + (I32.lt pos (const 0l)) + (I32.gt I32.(pos + len) (consti (Table.size t))) in if out_of_bounds then Choice.trap Out_of_bounds_table_access else begin @@ -1109,19 +1114,19 @@ module Make (P : Interpret_intf.P) : let len, stack = Stack.pop_i32 stack in let src, stack = Stack.pop_i32 stack in let dst, stack = Stack.pop_i32 stack in - let out_of_bounds = + let> out_of_bounds = let t_src_len = Table.size t_src in let t_dst_len = Table.size t_dst in - I32.(src + len) > consti t_src_len - || I32.(dst + len) > consti t_dst_len - || len < const 0l - || src < const 0l - || dst < const 0l + Bool.or_ (I32.gt I32.(src + len) (consti t_src_len)) + @@ Bool.or_ (I32.gt I32.(dst + len) (consti t_dst_len)) + @@ Bool.or_ (I32.lt len (const 0l)) + @@ Bool.or_ (I32.lt src (const 0l)) (I32.lt dst (const 0l)) in if out_of_bounds then Choice.trap Out_of_bounds_table_access else begin let* () = - if len <> const 0l then begin + let> len_is_not_zero = I32.ne len (const 0l) in + if len_is_not_zero then begin let* src = Choice.select_i32 src in let* dst = Choice.select_i32 dst in let+ len = Choice.select_i32 len in @@ -1189,7 +1194,8 @@ module Make (P : Interpret_intf.P) : if out_of_bounds then Choice.trap Out_of_bounds_memory_access else let* res = - (if sx = S then Memory.load_16_s else Memory.load_16_u) mem addr + (match sx with S -> Memory.load_16_s | U -> Memory.load_16_u) + mem addr in st @@ @@ -1211,7 +1217,7 @@ module Make (P : Interpret_intf.P) : if out_of_bounds then Choice.trap Out_of_bounds_memory_access else let* res = - (if sx = S then Memory.load_8_s else Memory.load_8_u) mem addr + (match sx with S -> Memory.load_8_s | U -> Memory.load_8_u) mem addr in st @@ @@ -1544,7 +1550,7 @@ module Make (P : Interpret_intf.P) : match end_stack with | [] -> () | _ :: _ -> - Format.pp_err "non empty stack@\n%a@." Stack.pp end_stack; + Fmt.epr "non empty stack@\n%a@." Stack.pp end_stack; assert false ) (Choice.return ()) (Module_to_run.to_run modul) diff --git a/src/interpret/trap.ml b/src/interpret/trap.ml index 24bc42788..dcf1bf4ec 100644 --- a/src/interpret/trap.ml +++ b/src/interpret/trap.ml @@ -21,8 +21,7 @@ let to_string = function | Out_of_bounds_table_access -> "out of bounds table access" | Out_of_bounds_memory_access -> "out of bounds memory access" | Undefined_element -> "undefined element" - | Uninitialized_element fun_i -> - Printf.sprintf "uninitialized element %i" fun_i + | Uninitialized_element fun_i -> Fmt.str "uninitialized element %i" fun_i | Integer_overflow -> "integer overflow" | Integer_divide_by_zero -> "integer divide by zero" | Element_type_error -> "element_type_error" diff --git a/src/intf/value_intf.ml b/src/intf/value_intf.ml index 300d5b516..5549191bd 100644 --- a/src/intf/value_intf.ml +++ b/src/intf/value_intf.ml @@ -170,23 +170,23 @@ module type T = sig type int32 - val pp_int32 : Format.formatter -> int32 -> unit + val pp_int32 : Fmt.formatter -> int32 -> unit type int64 - val pp_int64 : Format.formatter -> int64 -> unit + val pp_int64 : Fmt.formatter -> int64 -> unit type float32 - val pp_float32 : Format.formatter -> float32 -> unit + val pp_float32 : Fmt.formatter -> float32 -> unit type float64 - val pp_float64 : Format.formatter -> float64 -> unit + val pp_float64 : Fmt.formatter -> float64 -> unit type ref_value - val pp_ref_value : Format.formatter -> ref_value -> unit + val pp_ref_value : Fmt.formatter -> ref_value -> unit type t = | I32 of int32 @@ -195,7 +195,7 @@ module type T = sig | F64 of float64 | Ref of ref_value - val pp : Format.formatter -> t -> unit + val pp : Fmt.formatter -> t -> unit val const_i32 : Int32.t -> int32 @@ -231,7 +231,7 @@ module type T = sig val int32 : vbool -> int32 - val pp : Format.formatter -> vbool -> unit + val pp : Fmt.formatter -> vbool -> unit end module F32 : sig diff --git a/src/link/link.ml b/src/link/link.ml index c12ecc4f4..bcb17e5df 100644 --- a/src/link/link.ml +++ b/src/link/link.ml @@ -68,7 +68,7 @@ let load_global (ls : 'f state) (import : binary global_type Imported.t) : | Var, Const | Const, Var -> Error `Incompatible_import_type | Const, Const | Var, Var -> Ok () in - if snd import.desc <> global.typ then begin + if not @@ Types.val_type_eq (snd import.desc) global.typ then begin Error `Incompatible_import_type end else Ok global @@ -203,7 +203,7 @@ let eval_memories ls env memories = memories (Ok env) let table_types_are_compatible (import, (t1 : binary ref_type)) (imported, t2) = - limit_is_included ~import ~imported && t1 = t2 + limit_is_included ~import ~imported && Types.ref_type_eq t1 t2 let load_table (ls : 'f state) (import : binary table_type Imported.t) : table Result.t = @@ -226,14 +226,6 @@ let eval_tables ls env tables = Ok env ) tables (Ok env) -let func_types_are_compatible a b = - (* TODO: copied from Simplify_bis.equal_func_types => should factorize *) - let remove_param (pt, rt) = - let pt = List.map (fun (_id, vt) -> (None, vt)) pt in - (pt, rt) - in - remove_param a = remove_param b - let load_func (ls : 'f state) (import : binary block_type Imported.t) : func Result.t = let (Bt_raw ((None | Some _), typ)) = import.desc in @@ -245,7 +237,7 @@ let load_func (ls : 'f state) (import : binary block_type Imported.t) : t | Extern func_id -> Func_id.get_typ func_id ls.collection in - if func_types_are_compatible typ type' then Ok func + if Types.func_type_eq typ type' then Ok func else Error `Incompatible_import_type let eval_func ls (finished_env : Link_env.t') func : func Result.t = diff --git a/src/link/link_env.ml b/src/link/link_env.ml index f384d7739..9ff62bffb 100644 --- a/src/link/link_env.ml +++ b/src/link/link_env.ml @@ -168,7 +168,7 @@ module type T = sig val get_func_typ : t -> func -> binary func_type - val pp : Format.formatter -> t -> unit + val pp : Fmt.formatter -> t -> unit val freeze : Build.t -> extern_func Func_id.collection -> t end diff --git a/src/link/link_env.mli b/src/link/link_env.mli index ee7af37af..3d4298a9b 100644 --- a/src/link/link_env.mli +++ b/src/link/link_env.mli @@ -95,7 +95,7 @@ module type T = sig val get_func_typ : t -> func -> binary func_type - val pp : Format.formatter -> t -> unit + val pp : Fmt.formatter -> t -> unit val freeze : Build.t -> extern_func Func_id.collection -> t end diff --git a/src/optimize/optimize.ml b/src/optimize/optimize.ml index 171a7921a..03a169438 100644 --- a/src/optimize/optimize.ml +++ b/src/optimize/optimize.ml @@ -252,8 +252,8 @@ let rec optimize_expr expr : bool * binary instr list = | (I32_const _ | I64_const _ | F32_const _ | F64_const _) :: Drop :: tl -> let _has_changed, e = optimize_expr tl in (true, e) - | Local_set x :: Local_get y :: tl when x = y -> - let _has_changed, e = optimize_expr (Local_tee x :: tl) in + | Local_set (Raw x) :: Local_get (Raw y) :: tl when x = y -> + let _has_changed, e = optimize_expr (Local_tee (Raw x) :: tl) in (true, e) | Local_get _ :: Drop :: tl -> let _has_changed, e = optimize_expr tl in diff --git a/src/parser/binary_parser.ml b/src/parser/binary_parser.ml index 50212902a..677c9494c 100644 --- a/src/parser/binary_parser.ml +++ b/src/parser/binary_parser.ml @@ -43,7 +43,7 @@ end = struct let sub ~pos ~len input = if pos <= input.size && len <= input.size - pos then Ok { input with pt = input.pt + pos; size = len } - else Error (`Msg (Format.sprintf "length out of bounds in section")) + else Error (`Msg (Fmt.str "length out of bounds in section")) let sub_suffix pos input = sub ~pos ~len:(input.size - pos) input @@ -190,9 +190,7 @@ let check_end_opcode ?unexpected_eoi_msg input = | Ok ('\x0B', input) -> Ok input | Ok (c, _input) -> Error - (`Msg - (Format.sprintf "END opcode expected (got %s instead)" (Char.escaped c)) - ) + (`Msg (Fmt.str "END opcode expected (got %s instead)" (Char.escaped c))) | Error _ as e -> e let check_zero_opcode input = @@ -200,7 +198,7 @@ let check_zero_opcode input = match read_byte ~msg input with | Ok ('\x00', input) -> Ok input | Ok (c, _input) -> - Error (`Msg (Format.sprintf "%s (got %s instead)" msg (Char.escaped c))) + Error (`Msg (Fmt.str "%s (got %s instead)" msg (Char.escaped c))) | Error _ as e -> e let read_bytes ~msg input = vector_no_id (read_byte ~msg) input @@ -216,7 +214,7 @@ let read_numtype input = | -0x02 -> Ok (I64, input) | -0x03 -> Ok (F32, input) | -0x04 -> Ok (F64, input) - | b -> Error (`Msg (Format.sprintf "malformed number type: %d" b)) + | b -> Error (`Msg (Fmt.str "malformed number type: %d" b)) let read_vectype input = let* b, _input = read_S7 input in @@ -224,14 +222,14 @@ let read_vectype input = | -0x05 -> (* V128 *) assert false - | b -> Error (`Msg (Format.sprintf "malformed vector type: %d" b)) + | b -> Error (`Msg (Fmt.str "malformed vector type: %d" b)) let read_reftype input = let* b, input = read_S7 input in match b with | -0x10 -> Ok ((Null, Func_ht), input) | -0x11 -> Ok ((Null, Extern_ht), input) - | b -> Error (`Msg (Format.sprintf "malformed reference type: %d" b)) + | b -> Error (`Msg (Fmt.str "malformed reference type: %d" b)) let read_valtype input = match read_numtype input with @@ -319,7 +317,7 @@ let read_FC input = | 17 -> let+ tableidx, input = read_indice input in (Table_fill tableidx, input) - | i -> Error (`Msg (Format.sprintf "illegal opcode (1) %i" i)) + | i -> Error (`Msg (Fmt.str "illegal opcode (1) %i" i)) let block_type_of_rec_type t = (* TODO: this is a ugly hack, it is necessary for now and should be removed at some point... *) @@ -330,7 +328,7 @@ let block_type_of_rec_type t = let read_block_type types input = match read_S33 input with - | Ok (i, input) when i >= 0L -> + | Ok (i, input) when Int64.ge i 0L -> let block_type = block_type_of_rec_type types.(Int64.to_int i) in Ok (block_type, input) | Error _ | Ok _ -> begin @@ -639,7 +637,7 @@ let rec read_instr types input = let+ funcidx, input = read_indice input in (Ref_func funcidx, input) | '\xFC' -> read_FC input - | c -> Error (`Msg (Format.sprintf "illegal opcode (2) %s" (Char.escaped c))) + | c -> Error (`Msg (Fmt.str "illegal opcode (2) %s" (Char.escaped c))) and read_expr types input = let rec aux acc input = @@ -679,11 +677,11 @@ let version_check str = let check_section_id = function | '\x00' .. '\x0C' -> Ok () - | c -> Error (`Msg (Format.sprintf "malformed section id %s" (Char.escaped c))) + | c -> Error (`Msg (Fmt.str "malformed section id %s" (Char.escaped c))) let section_parse input ~expected_id default section_content_parse = match Input.get 0 input with - | Some id when id = expected_id -> + | Some id when Char.equal id expected_id -> let* () = check_section_id id in let* input = Input.sub_suffix 1 input in let* () = @@ -727,8 +725,9 @@ let section_custom input = let read_type _id input = let* fcttype, input = read_byte ~msg:"read_type" input in let* () = - if fcttype <> '\x60' then Error (`Msg "integer representation too long") - else Ok () + match fcttype with + | '\x60' -> Ok () + | _ -> Error (`Msg "integer representation too long") in let* params, input = read_valtypes input in let+ results, input = read_valtypes input in @@ -799,9 +798,7 @@ let read_elem_kind input = match read_byte ~msg input with | Ok ('\x00', input) -> Ok ((Null, Func_ht), input) | Ok (c, _input) -> - Error - (`Msg - (Format.sprintf "%s (expected 0x00 but got %s)" msg (Char.escaped c)) ) + Error (`Msg (Fmt.str "%s (expected 0x00 but got %s)" msg (Char.escaped c))) | Error _ as e -> e let read_element types input = @@ -848,7 +845,7 @@ let read_element types input = let* typ, input = read_reftype input in let+ init, input = vector_no_id (read_const types) input in ({ id; typ; init; mode }, input) - | i -> Error (`Msg (Format.sprintf "malformed elements segment kind: %d" i)) + | i -> Error (`Msg (Fmt.str "malformed elements segment kind: %d" i)) let read_local input = let* n, input = read_U32 input in @@ -917,7 +914,7 @@ let read_data types input = let+ init, input = read_bytes ~msg:"read_data 2" input in let init = string_of_char_list init in ({ id; init; mode }, input) - | i -> Error (`Msg (Format.sprintf "malformed data segment kind %d" i)) + | i -> Error (`Msg (Fmt.str "malformed data segment kind %d" i)) let parse_many_custom_section input = let rec aux acc input = @@ -1202,7 +1199,7 @@ let sections_iterate (input : Input.t) = | '\x03' -> let global = export :: exports.global in { exports with global } - | _ -> failwith "read_exportdesc error" ) + | _ -> Fmt.failwith "read_exportdesc error" ) empty_exports export_section in let exports = diff --git a/src/parser/parse.ml b/src/parser/parse.ml index 28aea3f05..f8c94d1f5 100644 --- a/src/parser/parse.ml +++ b/src/parser/parse.ml @@ -315,9 +315,9 @@ let token_to_string = function | ANY_REF -> "anyref" | ANY -> "any" | ALIGN -> "align" - | NUM s -> Format.sprintf "%s" s - | NAME s -> Format.sprintf {|"%s"|} s - | ID s -> Format.sprintf "$%s" s + | NUM s -> Fmt.str "%s" s + | NAME s -> Fmt.str {|"%s"|} s + | ID s -> Fmt.str "$%s" s module Make (M : sig type t diff --git a/src/parser/text_lexer.ml b/src/parser/text_lexer.ml index 3e7bbd5c9..03a97648a 100644 --- a/src/parser/text_lexer.ml +++ b/src/parser/text_lexer.ml @@ -13,22 +13,22 @@ exception Unexpected_character of string let illegal_escape buf = let tok = Utf8.lexeme buf in - raise @@ Illegal_escape (Printf.sprintf "illegal escape %S" tok) + raise @@ Illegal_escape (Fmt.str "illegal escape %S" tok) let unknown_operator buf = let tok = Utf8.lexeme buf in - raise @@ Unknown_operator (Printf.sprintf "unknown operator %S" tok) + raise @@ Unknown_operator (Fmt.str "unknown operator %S" tok) let unexpected_character buf = let tok = Utf8.lexeme buf in - raise @@ Unexpected_character (Printf.sprintf "unexpected character `%S`" tok) + raise @@ Unexpected_character (Fmt.str "unexpected character `%S`" tok) let mk_string buf s = let b = Buffer.create (String.length s) in let i = ref 0 in while !i < String.length s do let c = - if s.[!i] <> '\\' then s.[!i] + if not @@ Char.equal s.[!i] '\\' then s.[!i] else match incr i; @@ -43,16 +43,17 @@ let mk_string buf s = | 'u' -> let j = !i + 2 in i := String.index_from s j '}'; - let n = int_of_string ("0x" ^ String.sub s j (!i - j)) in + let n = int_of_string (Fmt.str "0x%s" (String.sub s j (!i - j))) in + let n = match n with None -> assert false | Some n -> n in let bs = Wutf8.encode [ n ] in Buffer.add_substring b bs 0 (String.length bs - 1); bs.[String.length bs - 1] | h -> incr i; if !i >= String.length s then illegal_escape buf; - let str = Format.sprintf "0x%c%c" h s.[!i] in + let str = Fmt.str "0x%c%c" h s.[!i] in begin - match int_of_string_opt str with + match int_of_string str with | None -> illegal_escape buf | Some n -> Char.chr n end diff --git a/src/parser/text_parser.mly b/src/parser/text_parser.mly index 138b25932..058dcf7df 100644 --- a/src/parser/text_parser.mly +++ b/src/parser/text_parser.mly @@ -210,7 +210,7 @@ let num_type == let align == | ALIGN; EQUAL; n = NUM; { let n = i32 n in - if n = 0l || Int32.(logand n (sub n 1l)) <> 0l then failwith "alignment" + if Int32.eq n 0l || Int32.ne Int32.(logand n (sub n 1l)) 0l then failwith "alignment" else Int32.div n 2l } @@ -543,20 +543,19 @@ let call_instr_results_instr_list := let block_instr == | BLOCK; id = option(id); (bt, es) = block; END; id2 = option(id); { - if Option.is_some id2 && id <> id2 then failwith "mismatching label"; + if not @@ Option.equal String.equal id id2 then Fmt.failwith "mismatching label"; Block (id, bt, es) } | LOOP; id = option(id); (bt, es) = block; END; id2 = option(id); { - if Option.is_some id2 && id <> id2 then failwith "mismatching label"; + if not @@ Option.equal String.equal id id2 then Fmt.failwith "mismatching label"; Loop (id, bt, es) } | IF; id = option(id); (bt, es) = block; END; id2 = option(id); { - if Option.is_some id2 && id <> id2 then failwith "mismatching label"; + if not @@ Option.equal String.equal id id2 then Fmt.failwith "mismatching label"; If_else (id, bt, es, []) } | IF; id = option(id); (bt, es1) = block; ELSE; id2 = option(id); ~ = instr_list; END; id3 = option(id); { - if Option.is_some id2 && id <> id2 then failwith "mismatching label"; - if Option.is_some id3 && id <> id3 then failwith "mismatching label"; + if not @@ Option.equal String.equal id id2 || not @@ Option.equal String.equal id id3 then Fmt.failwith "mismatching label"; If_else (id, bt, es1, instr_list) } @@ -711,7 +710,7 @@ let func == | MExport e -> MExport { e with desc = Export_func func_id } | MFunc f -> MFunc { f with id } | MData _ | MElem _ | MGlobal _ | MStart _ | MType _ | MTable _ | MMem _ as field -> begin - Format.pp_err "got invalid field: `%a`@." pp_module_field field; + Fmt.epr "got invalid field: `%a`@." pp_module_field field; assert false end ) func_fields @@ -832,7 +831,7 @@ let table == | Import_func _ | Import_global _ | Import_mem _ -> assert false end | MMem _ | MData _ | MStart _ | MFunc _ | MGlobal _ | MType _ as field -> begin - Format.pp_err "got invalid field: `%a`@." pp_module_field field; + Fmt.epr "got invalid field: `%a`@." pp_module_field field; assert false end ) table_fields @@ -879,7 +878,7 @@ let memory == | Import_table _ | Import_func _ | Import_global _ -> assert false end | MElem _ | MType _ | MTable _ | MFunc _ | MGlobal _ | MStart _ as field -> begin - Format.pp_err "got invalid field: `%a`@." pp_module_field field; + Fmt.epr "got invalid field: `%a`@." pp_module_field field; assert false end ) memory_fields @@ -914,7 +913,7 @@ let global == | Import_mem _ | Import_table _ | Import_func _ -> assert false end | MStart _ | MFunc _ | MData _ | MElem _ | MMem _ | MTable _ | MType _ as field -> begin - Format.pp_err "got invalid field: `%a`@." pp_module_field field; + Fmt.epr "got invalid field: `%a`@." pp_module_field field; assert false end ) global_fields @@ -1016,8 +1015,16 @@ let literal_const == | F32_CONST; num = NUM; { Const_F32 (f32 num) } | F64_CONST; num = NUM; { Const_F64 (f64 num) } | REF_NULL; ~ = heap_type; - | REF_EXTERN; num = NUM; { Const_extern (int_of_string num) } - | REF_HOST; num = NUM; { Const_host (int_of_string num) } + | REF_EXTERN; num = NUM; { + match int_of_string num with + | None -> assert false + | Some num -> Const_extern num + } + | REF_HOST; num = NUM; { + match int_of_string num with + | None -> assert false + | Some num -> Const_host num + } | REF_ARRAY; { Const_array } | REF_EQ; { Const_eq } | REF_I31; { Const_i31 } diff --git a/src/primitives/convert.ml b/src/primitives/convert.ml index 6557d5baa..8bd17ebd7 100644 --- a/src/primitives/convert.ml +++ b/src/primitives/convert.ml @@ -15,67 +15,74 @@ module MInt32 = struct if Float32.ne x x then raise @@ Types.Trap "invalid conversion to integer" else let xf = Float32.to_float x in - if xf >= -.Int32.(to_float min_int) || xf < Int32.(to_float min_int) then - raise @@ Types.Trap "integer overflow" + if + let xf = Float64.of_float xf in + let mif = Int32.(to_float min_int) in + Float64.(ge xf (of_float ~-.mif)) || Float64.(le xf (of_float mif)) + then raise @@ Types.Trap "integer overflow" else Int32.of_float xf let trunc_f32_u x = if Float32.ne x x then raise @@ Types.Trap "invalid conversion to integer" else let xf = Float32.to_float x in - if xf >= -.Int32.(to_float min_int) *. 2.0 || xf <= -1.0 then - raise @@ Types.Trap "integer overflow" + if + let xf = Float64.of_float xf in + Float64.(ge xf (of_float @@ (-.Int32.(to_float min_int) *. 2.0))) + || Float64.(ge xf (Float64.of_float ~-.1.0)) + then raise @@ Types.Trap "integer overflow" else Int64.(to_int32 (of_float xf)) let trunc_f64_s x = if Float64.ne x x then raise @@ Types.Trap "invalid conversion to integer" - else - let xf = Float64.to_float x in - if - xf >= -.Int32.(to_float min_int) - || xf <= Int32.(to_float min_int) -. 1.0 - then raise @@ Types.Trap "integer overflow" - else Int32.of_float xf + else if + let mif = Int32.(to_float min_int) in + Float64.(ge x (of_float @@ -.mif)) + || Float64.(le x (of_float @@ (mif -. 1.0))) + then raise @@ Types.Trap "integer overflow" + else Int32.of_float (Float64.to_float x) let trunc_f64_u x = if Float64.ne x x then raise @@ Types.Trap "invalid conversion to integer" - else - let xf = Float64.to_float x in - if xf >= -.Int32.(to_float min_int) *. 2.0 || xf <= -1.0 then - raise @@ Types.Trap "integer overflow" - else Int64.(to_int32 (of_float xf)) + else if + Float64.( + ge x (mul (of_float @@ -.Int32.(to_float min_int)) (of_float 2.0)) ) + || Float64.(le x (of_float ~-.1.0)) + then raise @@ Types.Trap "integer overflow" + else Int64.(to_int32 (of_float (Float64.to_float x))) let trunc_sat_f32_s x = if Float32.ne x x then 0l else - let xf = Float32.to_float x in - if xf < Int32.(to_float min_int) then Int32.min_int - else if xf >= -.Int32.(to_float min_int) then Int32.max_int - else Int32.of_float xf + let xf = Float32.to_float x |> Float64.of_float in + let mif = Int32.(to_float min_int) in + if Float64.(lt xf (of_float mif)) then Int32.min_int + else if Float64.(ge xf (of_float ~-.mif)) then Int32.max_int + else Int32.of_float (Float64.to_float xf) let trunc_sat_f32_u x = if Float32.ne x x then 0l else - let xf = Float32.to_float x in - if xf <= -1.0 then 0l - else if xf >= -.Int32.(to_float min_int) *. 2.0 then -1l - else Int64.(to_int32 (of_float xf)) + let xf = Float32.to_float x |> Float64.of_float in + if Float64.(le xf (of_float ~-.1.0)) then 0l + else if Float64.(ge xf @@ of_float @@ (~-.Int32.(to_float min_int) *. 2.0)) + then -1l + else Int64.(to_int32 @@ of_float (Float64.to_float xf)) let trunc_sat_f64_s x = if Float64.ne x x then 0l - else - let xf = Float64.to_float x in - if xf < Int32.(to_float min_int) then Int32.min_int - else if xf >= -.Int32.(to_float min_int) then Int32.max_int - else Int32.of_float xf + else if Float64.(le x @@ of_float @@ Int32.(to_float min_int)) then + Int32.min_int + else if Float64.(ge x @@ of_float @@ ~-.Int32.(to_float min_int)) then + Int32.max_int + else Int32.of_float @@ Float64.to_float x let trunc_sat_f64_u x = if Float64.ne x x then 0l - else - let xf = Float64.to_float x in - if xf <= -1.0 then 0l - else if xf >= -.Int32.(to_float min_int) *. 2.0 then -1l - else Int64.(to_int32 (of_float xf)) + else if Float64.(le x @@ of_float ~-.1.0) then 0l + else if Float64.(ge x @@ of_float @@ ~-.(Int32.(to_float min_int) *. 2.0)) + then -1l + else Int64.(to_int32 (of_float @@ Float64.to_float x)) let reinterpret_f32 = Float32.to_bits end @@ -89,75 +96,83 @@ module MInt64 = struct let trunc_f32_s x = if Float32.ne x x then raise @@ Types.Trap "invalid conversion to integer" - else - let xf = Float32.to_float x in - if xf >= -.Int64.(to_float min_int) || xf < Int64.(to_float min_int) then - raise @@ Types.Trap "integer overflow" - else Int64.of_float xf + else if + let mif = Int64.(to_float min_int) in + Float32.(ge x @@ of_float @@ ~-.mif) || Float32.(lt x @@ of_float mif) + then raise @@ Types.Trap "integer overflow" + else Int64.of_float @@ Float32.to_float x let trunc_f32_u x = + let mif = Int64.(to_float min_int) in if Float32.ne x x then raise @@ Types.Trap "invalid conversion to integer" - else - let xf = Float32.to_float x in - if xf >= -.Int64.(to_float min_int) *. 2.0 || xf <= -1.0 then - raise @@ Types.Trap "integer overflow" - else if xf >= -.Int64.(to_float min_int) then - Int64.(logxor (of_float (xf -. 0x1p63)) min_int) - else Int64.of_float xf + else if + Float32.(ge x @@ of_float ~-.(mif *. 2.0)) + || Float32.(le x @@ of_float ~-.1.0) + then raise @@ Types.Trap "integer overflow" + else if Float32.(le x @@ of_float ~-.mif) then + Int64.(logxor (of_float (Float32.to_float x -. 0x1p63)) min_int) + else Int64.of_float @@ Float32.to_float x let trunc_f64_s x = if Float64.ne x x then raise @@ Types.Trap "invalid conversion to integer" - else - let xf = Float64.to_float x in - if xf >= -.Int64.(to_float min_int) || xf < Int64.(to_float min_int) then - raise @@ Types.Trap "integer overflow" - else Int64.of_float xf + else if + let mif = Int64.(to_float min_int) in + Float64.(ge x @@ of_float ~-.mif) || Float64.(lt x @@ of_float mif) + then raise @@ Types.Trap "integer overflow" + else Int64.of_float @@ Float64.to_float x let trunc_f64_u x = + let mif = Int64.(to_float min_int) in if Float64.ne x x then raise @@ Types.Trap "invalid conversion to integer" - else - let xf = Float64.to_float x in - if xf >= -.Int64.(to_float min_int) *. 2.0 || xf <= -1.0 then - raise @@ Types.Trap "integer overflow" - else if xf >= -.Int64.(to_float min_int) then - Int64.(logxor (of_float (xf -. 0x1p63)) min_int) - else Int64.of_float xf + else if + Float64.(ge x @@ of_float (~-.mif *. 2.0)) + || Float64.(le x @@ of_float ~-.1.0) + then raise @@ Types.Trap "integer overflow" + else if Float64.(ge x @@ of_float ~-.mif) then + Int64.(logxor (of_float (Float64.to_float x -. 0x1p63)) min_int) + else Int64.of_float @@ Float64.to_float x let trunc_sat_f32_s x = if Float32.ne x x then 0L else - let xf = Float32.to_float x in - if xf < Int64.(to_float min_int) then Int64.min_int - else if xf >= -.Int64.(to_float min_int) then Int64.max_int - else Int64.of_float xf + let mif = Int64.(to_float min_int) in + if Float32.(lt x @@ of_float mif) then Int64.min_int + else if Float32.(ge x @@ of_float ~-.mif) then Int64.max_int + else Int64.of_float (Float32.to_float x) let trunc_sat_f32_u x = if Float32.ne x x then 0L else - let xf = Float32.to_float x in - if xf <= -1.0 then 0L - else if xf >= -.Int64.(to_float min_int) *. 2.0 then -1L - else if xf >= -.Int64.(to_float min_int) then - Int64.(logxor (of_float (xf -. 9223372036854775808.0)) min_int) - else Int64.of_float xf + let mif = Int64.(to_float min_int) in + if Float32.(le x @@ of_float ~-.1.0) then 0L + else if Float32.(ge x @@ of_float (~-.mif *. 2.0)) then -1L + else if Float32.(le x @@ of_float ~-.mif) then + Int64.( + logxor + (of_float (Float32.to_float x -. 9223372036854775808.0)) + min_int ) + else Int64.of_float @@ Float32.to_float x let trunc_sat_f64_s x = if Float64.ne x x then 0L else - let xf = Float64.to_float x in - if xf < Int64.(to_float min_int) then Int64.min_int - else if xf >= -.Int64.(to_float min_int) then Int64.max_int - else Int64.of_float xf + let mif = Int64.(to_float min_int) in + if Float64.(lt x @@ of_float mif) then Int64.min_int + else if Float64.(ge x @@ of_float ~-.mif) then Int64.max_int + else Int64.of_float @@ Float64.to_float x let trunc_sat_f64_u x = if Float64.ne x x then 0L else - let xf = Float64.to_float x in - if xf <= -1.0 then 0L - else if xf >= -.Int64.(to_float min_int) *. 2.0 then -1L - else if xf >= -.Int64.(to_float min_int) then - Int64.(logxor (of_float (xf -. 9223372036854775808.0)) min_int) - else Int64.of_float xf + let mif = Int64.(to_float min_int) in + if Float64.(le x @@ of_float ~-.1.0) then 0L + else if Float64.(ge x @@ of_float @@ (~-.mif *. 2.0)) then -1L + else if Float64.(ge x @@ of_float ~-.mif) then + Int64.( + logxor + (of_float (Float64.to_float x -. 9223372036854775808.0)) + min_int ) + else Int64.of_float @@ Float64.to_float x let reinterpret_f64 = Float64.to_bits end @@ -166,8 +181,7 @@ module MFloat32 = struct type t = Float32.t let demote_f64 x = - let xf = Float64.to_float x in - if xf = xf then Float32.of_float xf + if Float64.eq x x then Float32.of_float @@ Float64.to_float x else let nan64bits = Float64.to_bits x in let sign_field = @@ -190,7 +204,7 @@ module MFloat32 = struct let convert_i32_u x = Float32.of_float Int32.( - if x >= zero then to_float x + if Int32.ge x zero then to_float x else to_float (logor (shift_right_logical x 1) (logand x 1l)) *. 2.0 ) (* @@ -202,17 +216,17 @@ module MFloat32 = struct let convert_i64_s (x : int64) = Float32.of_float Int64.( - if abs x < 0x10_0000_0000_0000L then to_float x + if Int64.lt (abs x) 0x10_0000_0000_0000L then to_float x else - let r = if logand x 0xfffL = 0L then 0L else 1L in + let r = if Int64.eq (logand x 0xfffL) 0L then 0L else 1L in to_float (logor (shift_right x 12) r) *. 0x1p12 ) let convert_i64_u x = Float32.of_float Int64.( - if lt_u x 0x10_0000_0000_0000L then to_float x + if Int64.lt_u x 0x10_0000_0000_0000L then to_float x else - let r = if logand x 0xfffL = 0L then 0L else 1L in + let r = if Int64.eq (logand x 0xfffL) 0L then 0L else 1L in to_float (logor (shift_right_logical x 12) r) *. 0x1p12 ) let reinterpret_i32 = Float32.of_bits @@ -222,8 +236,7 @@ module MFloat64 = struct type t = Float64.t let promote_f32 x = - let xf = Float32.to_float x in - if xf = xf then Float64.of_float xf + if Float32.eq x x then Float64.of_float @@ Float32.to_float x else let nan32bits = MInt64.extend_i32_u (Float32.to_bits x) in let sign_field = @@ -258,7 +271,7 @@ module MFloat64 = struct let convert_i64_u x = Float64.of_float Int64.( - if x >= zero then to_float x + if Int64.ge x zero then to_float x else to_float (logor (shift_right_logical x 1) (logand x 1L)) *. 2.0 ) let reinterpret_i64 = Float64.of_bits diff --git a/src/primitives/float32.ml b/src/primitives/float32.ml index 5dd069040..25a56eca7 100644 --- a/src/primitives/float32.ml +++ b/src/primitives/float32.ml @@ -12,7 +12,7 @@ let neg_nan = 0xffc0_0000l let bare_nan = 0x7f80_0000l -let to_hex_string = Printf.sprintf "%lx" +let to_hex_string = Fmt.str "%lx" type t = Int32.t @@ -28,11 +28,15 @@ let of_bits x = x let to_bits x = x -let is_inf x = x = pos_inf || x = neg_inf +let is_inf x = Int32.eq x pos_inf || Int32.eq x neg_inf let is_nan x = let xf = Int32.float_of_bits x in - xf <> xf + Float.is_nan xf + +let is_pos_nan f = Int32.eq f pos_nan + +let is_neg_nan f = Int32.eq f neg_nan (* * When the result of an arithmetic operation is NaN, the most significant @@ -73,11 +77,11 @@ let binary x op y = let xf = to_float x in let yf = to_float y in let t = op xf yf in - if t = t then of_float t else determine_binary_nan x y + if not @@ Float.is_nan t then of_float t else determine_binary_nan x y let unary op x = let t = op (to_float x) in - if t = t then of_float t else determine_unary_nan x + if not @@ Float.is_nan t then of_float t else determine_unary_nan x let zero = of_float 0.0 @@ -89,38 +93,41 @@ let mul x y = binary x ( *. ) y let div x y = binary x ( /. ) y -let sqrt x = unary Stdlib.sqrt x +let sqrt x = unary Float.sqrt x -let ceil x = unary Stdlib.ceil x +let ceil x = unary Float.ceil x -let floor x = unary Stdlib.floor x +let floor x = unary Float.floor x let trunc x = let xf = to_float x in (* preserve the sign of zero *) - if xf = 0.0 then x + if Float.equal xf 0.0 then x else (* trunc is either ceil or floor depending on which one is toward zero *) - let f = if xf < 0.0 then Stdlib.ceil xf else Stdlib.floor xf in + let f = + if Float.compare xf 0.0 < 0 then Float.ceil xf else Float.floor xf + in let result = of_float f in if is_nan result then determine_unary_nan result else result let nearest x = let xf = to_float x in (* preserve the sign of zero *) - if xf = 0.0 then x + if Float.equal xf 0.0 then x else (* nearest is either ceil or floor depending on which is nearest or even *) - let u = Stdlib.ceil xf in - let d = Stdlib.floor xf in + let u = Float.ceil xf in + let d = Float.floor xf in let um = abs_float (xf -. u) in let dm = abs_float (xf -. d) in + let delta_u_d = Float.compare um dm in let u_or_d = - um < dm - || um = dm + delta_u_d < 0 + || delta_u_d = 0 && let h = u /. 2. in - Stdlib.floor h = h + Float.equal (Float.floor h) h in let f = if u_or_d then u else d in let result = of_float f in @@ -130,18 +137,20 @@ let min x y = let xf = to_float x in let yf = to_float y in (* min -0 0 is -0 *) - if xf = yf then Int32.logor x y - else if xf < yf then x - else if xf > yf then y + let delta = Float.compare xf yf in + if delta = 0 && (not @@ Float.is_nan xf) then Int32.logor x y + else if delta < 0 then x + else if delta > 0 then y else determine_binary_nan x y let max x y = let xf = to_float x in let yf = to_float y in (* max -0 0 is 0 *) - if xf = yf then Int32.logand x y - else if xf > yf then x - else if xf < yf then y + let delta = Float.compare xf yf in + if delta = 0 && (not @@ Float.is_nan xf) then Int32.logand x y + else if delta < 0 then x + else if delta > 0 then y else determine_binary_nan x y (* abs, neg, copysign are purely bitwise operations, even on NaN values *) @@ -151,25 +160,30 @@ let neg x = Int32.logxor x Int32.min_int let copy_sign x y = Int32.logor (abs x) (Int32.logand y Int32.min_int) -let eq x y = to_float x = to_float y +let eq x y = + let x = to_float x in + let y = to_float y in + Float.compare x y = 0 && (not @@ Float.is_nan x) -let ne x y = to_float x <> to_float y +let ne x y = Float.compare (to_float x) (to_float y) <> 0 -let lt x y = to_float x < to_float y +let lt x y = Float.compare (to_float x) (to_float y) < 0 -let gt x y = to_float x > to_float y +let gt x y = Float.compare (to_float x) (to_float y) > 0 -let le x y = to_float x <= to_float y +let le x y = Float.compare (to_float x) (to_float y) <= 0 -let ge x y = to_float x >= to_float y +let ge x y = Float.compare (to_float x) (to_float y) >= 0 (* * Compare mantissa of two floats in string representation (hex or dec). * This is a gross hack to detect rounding during parsing of floats. *) -let is_hex c = ('0' <= c && c <= '9') || ('A' <= c && c <= 'F') +let is_hex c = + (Char.compare '0' c <= 0 && Char.compare c '9' <= 0) + || (Char.compare 'A' c <= 0 && Char.compare c 'F' <= 0) -let is_exp hex c = c = if hex then 'P' else 'E' +let is_exp hex c = Char.compare c (if hex then 'P' else 'E') = 0 let at_end hex s i = i = String.length s || is_exp hex s.[i] @@ -179,7 +193,8 @@ let rec skip_non_hex s i = let rec skip_zeroes s i = let i' = skip_non_hex s i in - if at_end true s i' || s.[i'] <> '0' then i' else skip_zeroes s (i' + 1) + if at_end true s i' || Char.compare s.[i'] '0' <> 0 then i' + else skip_zeroes s (i' + 1) let rec compare_mantissa_str' hex s1 i1 s2 i2 = let i1' = skip_non_hex s1 i1 in @@ -189,7 +204,7 @@ let rec compare_mantissa_str' hex s1 i1 s2 i2 = | true, false -> if at_end hex s2 (skip_zeroes s2 i2') then 0 else -1 | false, true -> if at_end hex s1 (skip_zeroes s1 i1') then 0 else 1 | false, false -> ( - match compare s1.[i1'] s2.[i2'] with + match Char.compare s1.[i1'] s2.[i2'] with | 0 -> compare_mantissa_str' hex s1 (i1' + 1) s2 (i2' + 1) | n -> n ) @@ -208,9 +223,9 @@ let compare_mantissa_str hex s1 s2 = *) let float_of_string_prevent_double_rounding s = (* First parse to a 64 bit float. *) - let z = float_of_string s in + let z = match float_of_string s with None -> assert false | Some z -> z in (* If value is already infinite we are done. *) - if abs_float z = 1.0 /. 0.0 then z + if Float.compare (abs_float z) (1.0 /. 0.0) = 0 then z else (* Else, bit twiddling to see what rounding to target precision will do. *) let open Int64 in @@ -220,7 +235,7 @@ let float_of_string_prevent_double_rounding s = let tie = shift_right lsb 1 in let mask = lognot (shift_left (-1L) 29) in (* If we have no tie, we are good. *) - if logand bits mask <> tie then z + if Int64.ne (logand bits mask) tie then z else (* Else, define epsilon to be the value of the tie bit. *) let exp = float_of_bits (logand bits 0xfff0_0000_0000_0000L) in @@ -228,14 +243,14 @@ let float_of_string_prevent_double_rounding s = (* Convert 64 bit float back to string to compare to input. *) let hex = String.contains s 'x' in let s' = - if not hex then Printf.sprintf "%.*g" (String.length s) z + if not hex then Fmt.str "%.*g" (String.length s) z else let m = logor (logand bits 0xf_ffff_ffff_ffffL) 0x10_0000_0000_0000L in (* Shift mantissa to match msb position in most significant hex digit *) let i = skip_zeroes (String.uppercase_ascii s) 0 in - if i = String.length s then Printf.sprintf "%.*g" (String.length s) z + if i = String.length s then Fmt.str "%.*g" (String.length s) z else let sh = match s.[i] with @@ -244,7 +259,7 @@ let float_of_string_prevent_double_rounding s = | '4' .. '7' -> 2 | _ -> 3 in - Printf.sprintf "%Lx" (shift_left m sh) + Fmt.str "%Lx" (shift_left m sh) in (* - If mantissa became larger, float was rounded up to tie already; * round-to-even might round up again: sub epsilon to round down. @@ -258,15 +273,15 @@ let float_of_string_prevent_double_rounding s = | _ -> z let of_signless_string s = - if s = "inf" then pos_inf - else if s = "nan" then pos_nan - else if String.length s > 6 && String.sub s 0 6 = "nan:0x" then + if String.equal s "inf" then pos_inf + else if String.equal s "nan" then pos_nan + else if String.length s > 6 && String.equal (String.sub s 0 6) "nan:0x" then let x = Int32.of_string (String.sub s 4 (String.length s - 4)) in - if x = Int32.zero then failwith "nan payload must not be zero" - else if Int32.logand x bare_nan <> Int32.zero then - failwith "nan payload must not overlap with exponent bits" - else if x < Int32.zero then - failwith "nan payload must not overlap with sign bit" + if Int32.eq x Int32.zero then Fmt.failwith "nan payload must not be zero" + else if Int32.ne (Int32.logand x bare_nan) Int32.zero then + Fmt.failwith "nan payload must not overlap with exponent bits" + else if Int32.ne x Int32.zero then + Fmt.failwith "nan payload must not overlap with sign bit" else Int32.logor x bare_nan else let s' = String.concat "" (String.split_on_char '_' s) in @@ -274,17 +289,17 @@ let of_signless_string s = if is_inf x then Log.err "of_string" else x let of_string s = - if s = "" then Log.err "of_string" - else if s.[0] = '+' || s.[0] = '-' then + if String.equal s "" then Log.err "of_string" + else if Char.equal s.[0] '+' || Char.equal s.[0] '-' then let x = of_signless_string (String.sub s 1 (String.length s - 1)) in - if s.[0] = '+' then x else neg x + if Char.equal s.[0] '+' then x else neg x else of_signless_string s (* String conversion that groups digits for readability *) -let is_digit c = '0' <= c && c <= '9' +let is_digit = function '0' .. '9' -> true | _ -> false -let is_hex_digit c = is_digit c || ('a' <= c && c <= 'f') +let is_hex_digit = function '0' .. '9' | 'a' .. 'f' -> true | _ -> false let rec add_digits buf s i j k n = if i < j then begin @@ -302,7 +317,7 @@ let group_digits = fun is_digit n s -> let isnt_digit c = not (is_digit c) in let len = String.length s in - let x = Option.value (find_from_opt (( = ) 'x') s 0) ~default:0 in + let x = Option.value (find_from_opt (Char.equal 'x') s 0) ~default:0 in let mant = Option.value (find_from_opt is_digit s x) ~default:len in let point = Option.value (find_from_opt isnt_digit s mant) ~default:len in let frac = Option.value (find_from_opt is_digit s point) ~default:len in @@ -315,18 +330,20 @@ let group_digits = Buffer.add_substring buf s exp (len - exp); Buffer.contents buf -(* TODO: convert all the following to a proper use of Format and stop concatenating strings *) +(* TODO: convert all the following to a proper use of Fmt and stop concatenating strings *) let to_string' convert is_digit n x = - (if x < Int32.zero then "-" else "") - ^ - if is_nan x then - let payload = Int32.logand (abs x) (Int32.lognot bare_nan) in - "nan:0x" ^ group_digits is_hex_digit 4 (to_hex_string payload) - else - let s = convert (to_float (abs x)) in - group_digits is_digit n - (if s.[String.length s - 1] = '.' then s ^ "0" else s) - -let to_string = to_string' (Printf.sprintf "%.17g") is_digit 3 - -let pp fmt v = Format.pp_string fmt (to_string v) + Fmt.str "%s%s" + (if Int32.lt x Int32.zero then "-" else "") + ( if is_nan x then + let payload = Int32.logand (abs x) (Int32.lognot bare_nan) in + Fmt.str "%s%s" "nan:0x" + (group_digits is_hex_digit 4 (to_hex_string payload)) + else + let s = convert (to_float (abs x)) in + group_digits is_digit n + (if Char.equal s.[String.length s - 1] '.' then Fmt.str "%s0" s else s) + ) + +let to_string = to_string' (Fmt.str "%.17g") is_digit 3 + +let pp fmt v = Fmt.string fmt (to_string v) diff --git a/src/primitives/float32.mli b/src/primitives/float32.mli index a380e6716..7b981048f 100644 --- a/src/primitives/float32.mli +++ b/src/primitives/float32.mli @@ -10,6 +10,10 @@ val neg_nan : t val pos_nan : t +val is_neg_nan : t -> bool + +val is_pos_nan : t -> bool + val of_bits : Int32.t -> t val to_bits : t -> Int32.t @@ -66,4 +70,4 @@ val to_float : t -> Float.t val of_float : Float.t -> t -val pp : Format.formatter -> t -> unit +val pp : Fmt.formatter -> t -> unit diff --git a/src/primitives/float64.ml b/src/primitives/float64.ml index 4f3801cb1..27a526f52 100644 --- a/src/primitives/float64.ml +++ b/src/primitives/float64.ml @@ -12,7 +12,7 @@ let neg_nan = 0xfff8_0000_0000_0000L let bare_nan = 0x7ff0_0000_0000_0000L -let to_hex_string = Printf.sprintf "%Lx" +let to_hex_string = Fmt.str "%Lx" type t = Int64.t @@ -28,11 +28,15 @@ let of_bits x = x let to_bits x = x -let is_inf x = x = pos_inf || x = neg_inf +let is_inf x = Int64.eq x pos_inf || Int64.eq x neg_inf let is_nan x = let xf = Int64.float_of_bits x in - xf <> xf + Float.is_nan xf + +let is_pos_nan f = Int64.eq f pos_nan + +let is_neg_nan f = Int64.eq f neg_nan (* * When the result of an arithmetic operation is NaN, the most significant @@ -73,11 +77,11 @@ let binary x op y = let xf = to_float x in let yf = to_float y in let t = op xf yf in - if t = t then of_float t else determine_binary_nan x y + if Float.is_nan t then of_float t else determine_binary_nan x y let unary op x = let t = op (to_float x) in - if t = t then of_float t else determine_unary_nan x + if Float.is_nan t then of_float t else determine_unary_nan x let zero = of_float 0.0 @@ -89,38 +93,40 @@ let mul x y = binary x ( *. ) y let div x y = binary x ( /. ) y -let sqrt x = unary Stdlib.sqrt x +let sqrt x = unary Float.sqrt x -let ceil x = unary Stdlib.ceil x +let ceil x = unary Float.ceil x -let floor x = unary Stdlib.floor x +let floor x = unary Float.floor x let trunc x = let xf = to_float x in (* preserve the sign of zero *) - if xf = 0.0 then x + if Float.equal xf 0.0 then x else (* trunc is either ceil or floor depending on which one is toward zero *) - let f = if xf < 0.0 then Stdlib.ceil xf else Stdlib.floor xf in + let f = + if Float.compare xf 0.0 < 0 then Float.ceil xf else Float.floor xf + in let result = of_float f in if is_nan result then determine_unary_nan result else result let nearest x = let xf = to_float x in (* preserve the sign of zero *) - if xf = 0.0 then x + if Float.compare xf 0.0 = 0 then x else (* nearest is either ceil or floor depending on which is nearest or even *) - let u = Stdlib.ceil xf in - let d = Stdlib.floor xf in + let u = Float.ceil xf in + let d = Float.floor xf in let um = abs_float (xf -. u) in let dm = abs_float (xf -. d) in let u_or_d = - um < dm - || um = dm + Float.compare um dm < 0 + || Float.compare um dm = 0 && let h = u /. 2. in - Stdlib.floor h = h + Float.compare (Float.floor h) h = 0 in let f = if u_or_d then u else d in let result = of_float f in @@ -130,18 +136,18 @@ let min x y = let xf = to_float x in let yf = to_float y in (* min -0 0 is -0 *) - if xf = yf then Int64.logor x y - else if xf < yf then x - else if xf > yf then y + if Float.compare xf yf = 0 then Int64.logor x y + else if Float.compare xf yf = 0 then x + else if Float.compare xf yf = 0 then y else determine_binary_nan x y let max x y = let xf = to_float x in let yf = to_float y in (* max -0 0 is 0 *) - if xf = yf then Int64.logand x y - else if xf > yf then x - else if xf < yf then y + if Float.compare xf yf = 0 then Int64.logand x y + else if Float.compare xf yf > 0 then x + else if Float.compare xf yf < 0 then y else determine_binary_nan x y (* abs, neg, copysign are purely bitwise operations, even on NaN values *) @@ -151,25 +157,25 @@ let neg x = Int64.logxor x Int64.min_int let copy_sign x y = Int64.logor (abs x) (Int64.logand y Int64.min_int) -let eq x y = to_float x = to_float y +let eq x y = Float.compare (to_float x) (to_float y) = 0 -let ne x y = to_float x <> to_float y +let ne x y = Float.compare (to_float x) (to_float y) <> 0 -let lt x y = to_float x < to_float y +let lt x y = Float.compare (to_float x) (to_float y) < 0 -let gt x y = to_float x > to_float y +let gt x y = Float.compare (to_float x) (to_float y) > 0 -let le x y = to_float x <= to_float y +let le x y = Float.compare (to_float x) (to_float y) <= 0 -let ge x y = to_float x >= to_float y +let ge x y = Float.compare (to_float x) (to_float y) >= 0 (* * Compare mantissa of two floats in string representation (hex or dec). * This is a gross hack to detect rounding during parsing of floats. *) -let is_hex c = ('0' <= c && c <= '9') || ('A' <= c && c <= 'F') +let is_hex = function '0' .. '9' | 'A' .. 'F' -> true | _ -> false -let is_exp hex c = c = if hex then 'P' else 'E' +let is_exp hex c = Char.compare c (if hex then 'P' else 'E') = 0 let at_end hex s i = i = String.length s || is_exp hex s.[i] @@ -179,7 +185,8 @@ let rec skip_non_hex s i = let rec skip_zeroes s i = let i' = skip_non_hex s i in - if at_end true s i' || s.[i'] <> '0' then i' else skip_zeroes s (i' + 1) + if at_end true s i' || Char.compare s.[i'] '0' <> 0 then i' + else skip_zeroes s (i' + 1) let rec compare_mantissa_str' hex s1 i1 s2 i2 = let i1' = skip_non_hex s1 i1 in @@ -189,7 +196,7 @@ let rec compare_mantissa_str' hex s1 i1 s2 i2 = | true, false -> if at_end hex s2 (skip_zeroes s2 i2') then 0 else -1 | false, true -> if at_end hex s1 (skip_zeroes s1 i1') then 0 else 1 | false, false -> ( - match compare s1.[i1'] s2.[i2'] with + match Char.compare s1.[i1'] s2.[i2'] with | 0 -> compare_mantissa_str' hex s1 (i1' + 1) s2 (i2' + 1) | n -> n ) @@ -208,9 +215,9 @@ let compare_mantissa_str hex s1 s2 = *) let float_of_string_prevent_double_rounding s = (* First parse to a 64 bit float. *) - let z = float_of_string s in + let z = match float_of_string s with None -> assert false | Some z -> z in (* If value is already infinite we are done. *) - if abs_float z = 1.0 /. 0.0 then z + if Float.equal (abs_float z) (1.0 /. 0.0) then z else (* Else, bit twiddling to see what rounding to target precision will do. *) let open Int64 in @@ -220,7 +227,7 @@ let float_of_string_prevent_double_rounding s = let tie = shift_right lsb 1 in let mask = lognot (shift_left (-1L) 0) in (* If we have no tie, we are good. *) - if logand bits mask <> tie then z + if Int64.ne (logand bits mask) tie then z else (* Else, define epsilon to be the value of the tie bit. *) let exp = float_of_bits (logand bits 0xfff0_0000_0000_0000L) in @@ -228,14 +235,14 @@ let float_of_string_prevent_double_rounding s = (* Convert 64 bit float back to string to compare to input. *) let hex = String.contains s 'x' in let s' = - if not hex then Printf.sprintf "%.*g" (String.length s) z + if not hex then Fmt.str "%.*g" (String.length s) z else let m = logor (logand bits 0xf_ffff_ffff_ffffL) 0x10_0000_0000_0000L in (* Shift mantissa to match msb position in most significant hex digit *) let i = skip_zeroes (String.uppercase_ascii s) 0 in - if i = String.length s then Printf.sprintf "%.*g" (String.length s) z + if i = String.length s then Fmt.str "%.*g" (String.length s) z else let sh = match s.[i] with @@ -244,7 +251,7 @@ let float_of_string_prevent_double_rounding s = | '4' .. '7' -> 2 | _ -> 3 in - Printf.sprintf "%Lx" (shift_left m sh) + Fmt.str "%Lx" (shift_left m sh) in (* - If mantissa became larger, float was rounded up to tie already; * round-to-even might round up again: sub epsilon to round down. @@ -258,15 +265,15 @@ let float_of_string_prevent_double_rounding s = | _ -> z let of_signless_string s = - if s = "inf" then pos_inf - else if s = "nan" then pos_nan - else if String.length s > 6 && String.sub s 0 6 = "nan:0x" then + if String.equal s "inf" then pos_inf + else if String.equal s "nan" then pos_nan + else if String.length s > 6 && String.equal (String.sub s 0 6) "nan:0x" then let x = Int64.of_string (String.sub s 4 (String.length s - 4)) in - if x = Int64.zero then failwith "nan payload must not be zero" - else if Int64.logand x bare_nan <> Int64.zero then - failwith "nan payload must not overlap with exponent bits" - else if x < Int64.zero then - failwith "nan payload must not overlap with sign bit" + if Int64.eq x Int64.zero then Fmt.failwith "nan payload must not be zero" + else if Int64.ne (Int64.logand x bare_nan) Int64.zero then + Fmt.failwith "nan payload must not overlap with exponent bits" + else if Int64.lt x Int64.zero then + Fmt.failwith "nan payload must not overlap with sign bit" else Int64.logor x bare_nan else let s' = String.concat "" (String.split_on_char '_' s) in @@ -274,17 +281,17 @@ let of_signless_string s = if is_inf x then Log.err "of_string" else x let of_string s = - if s = "" then Log.err "of_string" - else if s.[0] = '+' || s.[0] = '-' then + if String.equal s "" then Log.err "of_string" + else if Char.equal s.[0] '+' || Char.equal s.[0] '-' then let x = of_signless_string (String.sub s 1 (String.length s - 1)) in - if s.[0] = '+' then x else neg x + if Char.equal s.[0] '+' then x else neg x else of_signless_string s (* String conversion that groups digits for readability *) -let is_digit c = '0' <= c && c <= '9' +let is_digit = function '0' .. '9' -> true | _ -> false -let is_hex_digit c = is_digit c || ('a' <= c && c <= 'f') +let is_hex_digit = function 'a' .. 'f' -> true | _ -> false let rec add_digits buf s i j k n = if i < j then begin @@ -302,7 +309,7 @@ let group_digits = fun is_digit n s -> let isnt_digit c = not (is_digit c) in let len = String.length s in - let x = Option.value (find_from_opt (( = ) 'x') s 0) ~default:0 in + let x = Option.value (find_from_opt (Char.equal 'x') s 0) ~default:0 in let mant = Option.value (find_from_opt is_digit s x) ~default:len in let point = Option.value (find_from_opt isnt_digit s mant) ~default:len in let frac = Option.value (find_from_opt is_digit s point) ~default:len in @@ -317,20 +324,21 @@ let group_digits = (* TODO: convert all the following to a proper use of Format and stop concatenating strings *) let to_string' convert is_digit n x = - (if x < Int64.zero then "-" else "") - ^ - if is_nan x then - let payload = Int64.logand (abs x) (Int64.lognot bare_nan) in - "nan:0x" ^ group_digits is_hex_digit 4 (to_hex_string payload) - else - let s = convert (to_float (abs x)) in - group_digits is_digit n - (if s.[String.length s - 1] = '.' then s ^ "0" else s) - -let to_string = to_string' (Printf.sprintf "%.17g") is_digit 3 + Fmt.str "%s%s" + (if Int64.lt x Int64.zero then "-" else "") + ( if is_nan x then + let payload = Int64.logand (abs x) (Int64.lognot bare_nan) in + Fmt.str "%s%s" "nan:0x" + (group_digits is_hex_digit 4 (to_hex_string payload)) + else + let s = convert (to_float (abs x)) in + group_digits is_digit n + (if Char.equal s.[String.length s - 1] '.' then Fmt.str "%s0" s else s) + ) + +let to_string = to_string' (Fmt.str "%.17g") is_digit 3 let to_hex_string x = - if is_inf x then to_string x - else to_string' (Printf.sprintf "%h") is_hex_digit 4 x + if is_inf x then to_string x else to_string' (Fmt.str "%h") is_hex_digit 4 x -let pp fmt v = Format.pp_string fmt (to_string v) +let pp fmt v = Fmt.string fmt (to_string v) diff --git a/src/primitives/float64.mli b/src/primitives/float64.mli index c31f558b8..54ea182ee 100644 --- a/src/primitives/float64.mli +++ b/src/primitives/float64.mli @@ -10,6 +10,10 @@ val neg_nan : t val pos_nan : t +val is_neg_nan : t -> bool + +val is_pos_nan : t -> bool + val of_bits : Int64.t -> t val to_bits : t -> Int64.t @@ -66,4 +70,4 @@ val to_float : t -> Float.t val of_float : Float.t -> t -val pp : Format.formatter -> t -> unit +val pp : Fmt.formatter -> t -> unit diff --git a/src/primitives/int32.ml b/src/primitives/int32.ml index bec4f9b2c..79bf15d85 100644 --- a/src/primitives/int32.ml +++ b/src/primitives/int32.ml @@ -10,17 +10,17 @@ (* Copyright © 2021-2024 OCamlPro *) (* Modified by the Owi programmers *) -include Stdlib.Int32 +include Prelude.Int32 -let clz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_leading_zeros n) +let clz n = of_int (Ocaml_intrinsics.Int32.count_leading_zeros n) -let ctz n = Stdlib.Int32.of_int (Ocaml_intrinsics.Int32.count_trailing_zeros n) +let ctz n = of_int (Ocaml_intrinsics.Int32.count_trailing_zeros n) (* Taken from Base *) let popcnt = let mask = 0xffff_ffffL in fun [@inline] x -> - Stdlib.Int64.to_int32 (Int64.popcnt (Int64.logand (Int64.of_int32 x) mask)) + Int64.to_int32 (Int64.popcnt (Int64.logand (Int64.of_int32 x) mask)) let of_int64 = Int64.to_int32 @@ -29,6 +29,26 @@ let to_int64 = Int64.of_int32 (* Unsigned comparison in terms of signed comparison. *) let cmp_u x op y = op (add x min_int) (add y min_int) +let eq (x : int32) y = equal x y + +let ne (x : int32) y = compare x y <> 0 + +let lt (x : int32) y = compare x y < 0 + +let gt (x : int32) y = compare x y > 0 + +let le (x : int32) y = compare x y >= 0 + +let ge (x : int32) y = compare x y <= 0 + +let lt_u x y = cmp_u x lt y + +let le_u x y = cmp_u x le y + +let gt_u x y = cmp_u x gt y + +let ge_u x y = cmp_u x ge y + (* If bit (32 - 1) is set, sx will sign-extend t to maintain the * invariant that small ints are stored sign-extended inside a wider int. *) let sx x = @@ -65,26 +85,6 @@ let extend_s n x = let shift = 32 - n in shift_right (shift_left x shift) shift -let eq (x : int32) y = x = y - -let ne (x : int32) y = x <> y - -let lt (x : int32) y = x < y - -let gt (x : int32) y = x > y - -let le (x : int32) y = x <= y - -let ge (x : int32) y = x >= y - -let lt_u x y = cmp_u x ( < ) y - -let le_u x y = cmp_u x ( <= ) y - -let gt_u x y = cmp_u x ( > ) y - -let ge_u x y = cmp_u x ( >= ) y - (* String conversion that allows leading signs and unsigned values *) let require b = if not b then Log.err "of_string (int32)" @@ -105,7 +105,7 @@ let max_lower = unsigned_rem minus_one 10l let sign_extend i = let sign_bit = logand (of_int (1 lsl (32 - 1))) i in - if sign_bit = zero then i + if eq sign_bit zero then i else (* Build a sign-extension mask *) let sign_mask = shift_left minus_one 32 in @@ -115,7 +115,7 @@ let of_string s = let len = String.length s in let rec parse_hex i num = if i = len then num - else if s.[i] = '_' then parse_hex (i + 1) num + else if Char.equal s.[i] '_' then parse_hex (i + 1) num else let digit = of_int (hex_digit s.[i]) in require (le_u num (shr_u minus_one (of_int 4))); @@ -123,15 +123,15 @@ let of_string s = in let rec parse_dec i num = if i = len then num - else if s.[i] = '_' then parse_dec (i + 1) num + else if Char.equal s.[i] '_' then parse_dec (i + 1) num else let digit = of_int (dec_digit s.[i]) in - require (lt_u num max_upper || (num = max_upper && le_u digit max_lower)); + require (lt_u num max_upper || (eq num max_upper && le_u digit max_lower)); parse_dec (i + 1) (add (mul num 10l) digit) in let parse_int i = require (len - i > 0); - if i + 2 <= len && s.[i] = '0' && s.[i + 1] = 'x' then + if i + 2 <= len && Char.equal s.[i] '0' && Char.equal s.[i + 1] 'x' then parse_hex (i + 2) zero else parse_dec i zero in @@ -141,12 +141,10 @@ let of_string s = | '+' -> parse_int 1 | '-' -> let n = parse_int 1 in - require (sub n one >= minus_one); + require (ge (sub n one) minus_one); neg n | _ -> parse_int 0 in let parsed = sign_extend parsed in - require (low_int <= parsed && parsed <= high_int); + require (le low_int parsed && le parsed high_int); parsed - -let eq_const (i : int32) j = i = j diff --git a/src/primitives/int32.mli b/src/primitives/int32.mli index a94f90dc5..1c2685a1b 100644 --- a/src/primitives/int32.mli +++ b/src/primitives/int32.mli @@ -101,5 +101,3 @@ val unsigned_div : t -> t -> t val rem : t -> t -> t val unsigned_rem : t -> t -> t - -val eq_const : t -> int32 -> bool diff --git a/src/primitives/int64.ml b/src/primitives/int64.ml index bf8d0d604..c4dfbf4bd 100644 --- a/src/primitives/int64.ml +++ b/src/primitives/int64.ml @@ -10,11 +10,11 @@ (* Copyright © 2021-2024 OCamlPro *) (* Modified by the Owi programmers *) -include Stdlib.Int64 +include Prelude.Int64 -let clz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_leading_zeros n) +let clz n = of_int (Ocaml_intrinsics.Int64.count_leading_zeros n) -let ctz n = Stdlib.Int64.of_int (Ocaml_intrinsics.Int64.count_trailing_zeros n) +let ctz n = of_int (Ocaml_intrinsics.Int64.count_trailing_zeros n) (* Taken from Base *) let popcnt = @@ -46,19 +46,39 @@ let popcnt = *) let cmp_u x op y = op (add x min_int) (add y min_int) +let eq (x : int64) y = equal x y + +let ne (x : int64) y = not (equal x y) + +let lt (x : int64) y = compare x y < 0 + +let gt (x : int64) y = compare x y > 0 + +let le (x : int64) y = compare x y <= 0 + +let ge (x : int64) y = compare x y >= 0 + +let lt_u x y = cmp_u x lt y + +let le_u x y = cmp_u x le y + +let gt_u x y = cmp_u x gt y + +let ge_u x y = cmp_u x ge y + (* * Unsigned division and remainder in terms of signed division; algorithm from * Hacker's Delight, Second Edition, by Henry S. Warren, Jr., section 9-3 * "Unsigned Short Division from Signed Division". *) let divrem_u n d = - if d = zero then raise Division_by_zero + if equal d zero then raise Division_by_zero else let t = shift_right d 63 in let n' = logand n (lognot t) in let q = shift_left (div (shift_right_logical n' 1) d) 1 in let r = sub n (mul q d) in - if cmp_u r ( < ) d then (q, r) else (add q one, sub r d) + if cmp_u r lt d then (q, r) else (add q one, sub r d) (* We don't override min_int and max_int since those are used * by other functions (like parsing), and rely on it being @@ -90,26 +110,6 @@ let extend_s n x = let shift = 64 - n in shift_right (shift_left x shift) shift -let eq (x : int64) y = x = y - -let ne (x : int64) y = x <> y - -let lt (x : int64) y = x < y - -let gt (x : int64) y = x > y - -let le (x : int64) y = x <= y - -let ge (x : int64) y = x >= y - -let lt_u x y = cmp_u x ( < ) y - -let le_u x y = cmp_u x ( <= ) y - -let gt_u x y = cmp_u x ( > ) y - -let ge_u x y = cmp_u x ( >= ) y - (* String conversion that allows leading signs and unsigned values *) let require b = if not b then Log.err "of_string (int64)" @@ -130,7 +130,7 @@ let of_string s = let len = String.length s in let rec parse_hex i num = if i = len then num - else if s.[i] = '_' then parse_hex (i + 1) num + else if Char.equal s.[i] '_' then parse_hex (i + 1) num else let digit = of_int (hex_digit s.[i]) in require (le_u num (shr_u minus_one (of_int 4))); @@ -138,15 +138,15 @@ let of_string s = in let rec parse_dec i num = if i = len then num - else if s.[i] = '_' then parse_dec (i + 1) num + else if Char.equal s.[i] '_' then parse_dec (i + 1) num else let digit = of_int (dec_digit s.[i]) in - require (lt_u num max_upper || (num = max_upper && le_u digit max_lower)); + require (lt_u num max_upper || (eq num max_upper && le_u digit max_lower)); parse_dec (i + 1) (add (mul num 10L) digit) in let parse_int i = require (len - i > 0); - if i + 2 <= len && s.[i] = '0' && s.[i + 1] = 'x' then + if i + 2 <= len && Char.equal s.[i] '0' && Char.equal s.[i + 1] 'x' then parse_hex (i + 2) zero else parse_dec i zero in @@ -156,11 +156,9 @@ let of_string s = | '+' -> parse_int 1 | '-' -> let n = parse_int 1 in - require (sub n one >= minus_one); + require (ge (sub n one) minus_one); neg n | _ -> parse_int 0 in - require (low_int <= parsed && parsed <= high_int); + require (le low_int parsed && le parsed high_int); parsed - -let eq_const (i : int64) j = i = j diff --git a/src/primitives/int64.mli b/src/primitives/int64.mli index 1fbefcaf3..11939ebb8 100644 --- a/src/primitives/int64.mli +++ b/src/primitives/int64.mli @@ -105,5 +105,3 @@ val unsigned_div : t -> t -> t val rem : t -> t -> t val unsigned_rem : t -> t -> t - -val eq_const : t -> int64 -> bool diff --git a/src/script/script.ml b/src/script/script.ml index 64753ad4d..e799c6c19 100644 --- a/src/script/script.ml +++ b/src/script/script.ml @@ -16,14 +16,17 @@ end let check_error ~expected ~got : unit Result.t = let ok = - Result.err_to_string got = expected + String.equal (Result.err_to_string got) expected || String.starts_with ~prefix:expected (Result.err_to_string got) - || ( got = `Constant_out_of_range - || got = `Msg "constant out of range" - || got = `Parse_fail "constant out of range" ) - && (expected = "i32 constant out of range" || expected = "i32 constant") - || got = `Msg "unexpected end of section or function" - && expected = "section size mismatch" + || + match got with + | `Constant_out_of_range + | `Msg "constant out of range" + | `Parse_fail "constant out of range" -> + String.starts_with ~prefix:"i32 constant" expected + | `Msg "unexpected end of section or function" -> + String.equal expected "section size mismatch" + | _ -> false in if not ok then begin Error (`Failed_with_but_expected (got, expected)) @@ -70,10 +73,10 @@ let load_global_from_module ls mod_id name = let compare_result_const result (const : Concrete_value.t) = match (result, const) with - | Text.Result_const (Literal (Const_I32 n)), I32 n' -> n = n' - | Result_const (Literal (Const_I64 n)), I64 n' -> n = n' - | Result_const (Literal (Const_F32 n)), F32 n' -> n = n' - | Result_const (Literal (Const_F64 n)), F64 n' -> n = n' + | Text.Result_const (Literal (Const_I32 n)), I32 n' -> Int32.eq n n' + | Result_const (Literal (Const_I64 n)), I64 n' -> Int64.eq n n' + | Result_const (Literal (Const_F32 n)), F32 n' -> Float32.eq n n' + | Result_const (Literal (Const_F64 n)), F64 n' -> Float64.eq n n' | Result_const (Literal (Const_null Func_ht)), Ref (Funcref None) -> true | Result_const (Literal (Const_null Extern_ht)), Ref (Externref None) -> true | Result_const (Literal (Const_extern n)), Ref (Externref (Some ref)) -> begin @@ -82,15 +85,15 @@ let compare_result_const result (const : Concrete_value.t) = | Some n' -> n = n' end | Result_const (Nan_canon S32), F32 f -> - f = Float32.pos_nan || f = Float32.neg_nan + Float32.is_pos_nan f || Float32.is_neg_nan f | Result_const (Nan_canon S64), F64 f -> - f = Float64.pos_nan || f = Float64.neg_nan + Float64.is_pos_nan f || Float64.is_neg_nan f | Result_const (Nan_arith S32), F32 f -> let pos_nan = Float32.to_bits Float32.pos_nan in - Int32.logand (Float32.to_bits f) pos_nan = pos_nan + Int32.eq (Int32.logand (Float32.to_bits f) pos_nan) pos_nan | Result_const (Nan_arith S64), F64 f -> let pos_nan = Float64.to_bits Float64.pos_nan in - Int64.logand (Float64.to_bits f) pos_nan = pos_nan + Int64.eq (Int64.logand (Float64.to_bits f) pos_nan) pos_nan | Result_const (Nan_arith _), _ | Result_const (Nan_canon _), _ | Result_const (Literal (Const_I32 _)), _ @@ -120,7 +123,7 @@ let value_of_const : text const -> V.t Result.t = function let action (link_state : Concrete_value.Func.extern_func Link.state) = function | Text.Invoke (mod_id, f, args) -> begin Log.debug5 "invoke %a %s %a...@\n" - (Format.pp_option ~none:Format.pp_nothing Format.pp_string) + (Fmt.option ~none:Fmt.nop Fmt.string) mod_id f Types.pp_consts args; let* f, env_id = load_func_from_module link_state mod_id f in let* stack = list_map value_of_const args in @@ -230,7 +233,7 @@ let run ~no_exhaustion ~optimize script = List.compare_lengths res stack <> 0 || not (List.for_all2 compare_result_const res (List.rev stack)) then begin - Format.pp_err "got: %a@.expected: %a@." Stack.pp (List.rev stack) + Fmt.epr "got: %a@.expected: %a@." Stack.pp (List.rev stack) Text.pp_results res; Error `Bad_result end @@ -250,7 +253,7 @@ let run ~no_exhaustion ~optimize script = in link_state | Register (name, mod_name) -> - if !curr_module = 1 && !registered = false then Log.debug_on := false; + if !curr_module = 1 && not !registered then Log.debug_on := false; Log.debug0 "*** register@\n"; let+ state = Link.register_module link_state ~name ~id:mod_name in Log.debug_on := debug_on; diff --git a/src/script/spectest.ml b/src/script/spectest.ml index dbfcfc092..77ef66a50 100644 --- a/src/script/spectest.ml +++ b/src/script/spectest.ml @@ -2,7 +2,7 @@ (* Copyright © 2021-2024 OCamlPro *) (* Written by the Owi programmers *) -open Format +open Fmt open Types open Concrete_value.Func @@ -10,10 +10,10 @@ type extern_module = extern_func Link.extern_module let extern_m = let print = () in - let print_i32 i = pp_std "%li@\n%!" i in - let print_i64 i = pp_std "%Li@\n%!" i in - let print_f32 f = pp_std "%a@\n%!" Float32.pp f in - let print_f64 f = pp_std "%a@\n%!" Float64.pp f in + let print_i32 i = pr "%li@\n%!" i in + let print_i64 i = pr "%Li@\n%!" i in + let print_f32 f = pr "%a@\n%!" Float32.pp f in + let print_f64 f = pr "%a@\n%!" Float64.pp f in let print_i32_f32 i f = print_i32 i; print_f32 f diff --git a/src/symbolic/solver.ml b/src/symbolic/solver.ml index 76316f5f6..3d57f5070 100644 --- a/src/symbolic/solver.ml +++ b/src/symbolic/solver.ml @@ -20,7 +20,10 @@ let check (S (solver_module, s)) pc = let model (S (solver_module, s)) ~symbols ~pc = let module Solver = (val solver_module) in - assert (Solver.check s pc = `Sat); - match Solver.model ?symbols s with - | None -> assert false - | Some model -> model + match Solver.check s pc with + | `Sat -> begin + match Solver.model ?symbols s with + | None -> assert false + | Some model -> model + end + | `Unsat | `Unknown -> assert false diff --git a/src/symbolic/symbolic_choice.ml b/src/symbolic/symbolic_choice.ml index 46e286d48..05de59027 100644 --- a/src/symbolic/symbolic_choice.ml +++ b/src/symbolic/symbolic_choice.ml @@ -339,7 +339,7 @@ let add_breadcrumb crumb = let with_new_symbol ty f = let* thread in let n = Thread.symbols thread in - let sym = Format.ksprintf (Smtml.Symbol.make ty) "symbol_%d" n in + let sym = Fmt.kstr (Smtml.Symbol.make ty) "symbol_%d" n in let+ () = modify_thread (fun thread -> let thread = Thread.add_symbol thread sym in @@ -372,7 +372,8 @@ let get_model_or_stop symbol = let model = Solver.model solver ~symbols ~pc in match Smtml.Model.evaluate model symbol with | None -> - failwith "Unreachable: The model exists so this symbol should evaluate" + Fmt.failwith + "Unreachable: The model exists so this symbol should evaluate" | Some v -> return v end @@ -381,7 +382,7 @@ let select_inner ~explore_first (cond : Symbolic_value.vbool) = match Smtml.Expr.view v with | Val True -> return true | Val False -> return false - | Val (Num (I32 _)) -> failwith "unreachable (type error)" + | Val (Num (I32 _)) -> Fmt.failwith "unreachable (type error)" | _ -> let true_branch = let* () = add_pc v in @@ -409,7 +410,7 @@ let summary_symbol (e : Smtml.Expr.t) = | _ -> let num_symbols = Thread.symbols thread in let+ () = modify_thread Thread.incr_symbols in - let sym_name = Format.sprintf "choice_i32_%i" num_symbols in + let sym_name = Fmt.str "choice_i32_%i" num_symbols in let sym_type = Smtml.Ty.Ty_bitv 32 in let sym = Smtml.Symbol.make sym_type sym_name in let assign = Smtml.Expr.(relop Ty_bool Eq (mk_symbol sym) e) in @@ -430,7 +431,7 @@ let select_i32 (i : Symbolic_value.int32) = let i = match possible_value with | Smtml.Value.Num (I32 i) -> i - | _ -> failwith "Unreachable: found symbol must be a value" + | _ -> Fmt.failwith "Unreachable: found symbol must be a value" in let s = Smtml.Expr.mk_symbol symbol in let this_value_cond = diff --git a/src/symbolic/symbolic_choice_minimalist.ml b/src/symbolic/symbolic_choice_minimalist.ml index 3213da592..24e55a32a 100644 --- a/src/symbolic/symbolic_choice_minimalist.ml +++ b/src/symbolic/symbolic_choice_minimalist.ml @@ -8,10 +8,11 @@ type err = | Assert_fail | Trap of Trap.t -type 'a t = M of (Thread.t -> Solver.t -> ('a, err) Stdlib.Result.t * Thread.t) +type 'a t = + | M of (Thread.t -> Solver.t -> ('a, err) Prelude.Result.t * Thread.t) [@@unboxed] -type 'a run_result = ('a, err) Stdlib.Result.t * Thread.t +type 'a run_result = ('a, err) Prelude.Result.t * Thread.t let return v = M (fun t _sol -> (Ok v, t)) @@ -38,7 +39,7 @@ let select (vb : vbool) = match Smtml.Expr.view v with | Val True -> return true | Val False -> return false - | _ -> Format.kasprintf failwith "%a" Smtml.Expr.pp v + | _ -> Fmt.failwith "%a" Smtml.Expr.pp v let select_i32 (i : int32) = let v = Smtml.Expr.simplify i in diff --git a/src/symbolic/symbolic_choice_minimalist.mli b/src/symbolic/symbolic_choice_minimalist.mli index 7c032c500..a4a264317 100644 --- a/src/symbolic/symbolic_choice_minimalist.mli +++ b/src/symbolic/symbolic_choice_minimalist.mli @@ -9,7 +9,7 @@ type err = private include Choice_intf.Complete with type thread := Thread.t - and type 'a run_result = ('a, err) Stdlib.Result.t * Thread.t + and type 'a run_result = ('a, err) Prelude.Result.t * Thread.t and module V := Symbolic_value val run : diff --git a/src/symbolic/symbolic_memory.ml b/src/symbolic/symbolic_memory.ml index 3ac62f85e..fc55b8ffb 100644 --- a/src/symbolic/symbolic_memory.ml +++ b/src/symbolic/symbolic_memory.ml @@ -154,7 +154,8 @@ let extract v pos = value (Num (I8 i')) | Cvtop (_, Zero_extend 24, ({ node = Symbol _; _ } as sym)) | Cvtop (_, Sign_extend 24, ({ node = Symbol _; _ } as sym)) - when ty sym = Ty_bitv 8 -> + when match ty sym with Ty_bitv 8 -> true | _ -> false -> + (* TODO: implement an equal function in smtml for this *) sym | _ -> make (Extract (v, pos + 1, pos)) @@ -187,12 +188,14 @@ let check_within_bounds m a = let upper_bound = Value.(I32.ge (const_i32 ptr) (I32.add (const_i32 base) size)) in - Ok (Value.Bool.(or_ (const (ptr < base)) upper_bound), Value.const_i32 ptr) - ) + Ok + ( Value.Bool.(or_ (const (Int32.lt ptr base)) upper_bound) + , Value.const_i32 ptr ) ) | _ -> Log.err {|Unable to calculate address of: "%a"|} Expr.pp a let free m base = - if not @@ Hashtbl.mem m.chunks base then failwith "Memory leak double free"; + if not @@ Hashtbl.mem m.chunks base then + Fmt.failwith "Memory leak double free"; Hashtbl.remove m.chunks base let replace_size m base size = Hashtbl.replace m.chunks base size diff --git a/src/symbolic/symbolic_value.ml b/src/symbolic/symbolic_value.ml index d32f69fdc..c3ee37722 100644 --- a/src/symbolic/symbolic_value.ml +++ b/src/symbolic/symbolic_value.ml @@ -169,25 +169,33 @@ module I32 = struct end | _ -> relop Ty_bool Eq e (const_i32 c) - let eq e1 e2 = if e1 == e2 then Bool.const true else relop Ty_bool Eq e1 e2 + let eq e1 e2 = + if phys_equal e1 e2 then Bool.const true else relop Ty_bool Eq e1 e2 - let ne e1 e2 = if e1 == e2 then Bool.const false else relop Ty_bool Ne e1 e2 + let ne e1 e2 = + if phys_equal e1 e2 then Bool.const false else relop Ty_bool Ne e1 e2 - let lt e1 e2 = if e1 == e2 then Bool.const false else relop ty Lt e1 e2 + let lt e1 e2 = + if phys_equal e1 e2 then Bool.const false else relop ty Lt e1 e2 - let gt e1 e2 = if e1 == e2 then Bool.const false else relop ty Gt e1 e2 + let gt e1 e2 = + if phys_equal e1 e2 then Bool.const false else relop ty Gt e1 e2 - let lt_u e1 e2 = if e1 == e2 then Bool.const false else relop ty LtU e1 e2 + let lt_u e1 e2 = + if phys_equal e1 e2 then Bool.const false else relop ty LtU e1 e2 - let gt_u e1 e2 = if e1 == e2 then Bool.const false else relop ty GtU e1 e2 + let gt_u e1 e2 = + if phys_equal e1 e2 then Bool.const false else relop ty GtU e1 e2 - let le e1 e2 = if e1 == e2 then Bool.const true else relop ty Le e1 e2 + let le e1 e2 = if phys_equal e1 e2 then Bool.const true else relop ty Le e1 e2 - let ge e1 e2 = if e1 == e2 then Bool.const true else relop ty Ge e1 e2 + let ge e1 e2 = if phys_equal e1 e2 then Bool.const true else relop ty Ge e1 e2 - let le_u e1 e2 = if e1 == e2 then Bool.const true else relop ty LeU e1 e2 + let le_u e1 e2 = + if phys_equal e1 e2 then Bool.const true else relop ty LeU e1 e2 - let ge_u e1 e2 = if e1 == e2 then Bool.const true else relop ty GeU e1 e2 + let ge_u e1 e2 = + if phys_equal e1 e2 then Bool.const true else relop ty GeU e1 e2 let to_bool (e : vbool) = match view e with diff --git a/src/symbolic/symbolic_wasm_ffi.ml b/src/symbolic/symbolic_wasm_ffi.ml index da589cb4a..962dec31c 100644 --- a/src/symbolic/symbolic_wasm_ffi.ml +++ b/src/symbolic/symbolic_wasm_ffi.ml @@ -68,9 +68,7 @@ module M : let+ base = ptr p in Memory.free m base - let exit (p : Value.int32) : unit Choice.t = - ignore p; - abort () + let exit (_p : Value.int32) : unit Choice.t = abort () end type extern_func = Symbolic.P.Extern_func.extern_func diff --git a/src/text_to_binary/assigned.ml b/src/text_to_binary/assigned.ml index 5244bac7c..33e93e047 100644 --- a/src/text_to_binary/assigned.ml +++ b/src/text_to_binary/assigned.ml @@ -8,18 +8,11 @@ open Syntax module StrType = struct type t = binary str_type - let compare = compare + let compare = Types.compare_str_type end module TypeMap = Map.Make (StrType) -let equal_func_types (a : binary func_type) (b : binary func_type) : bool = - let remove_param (pt, rt) = - let pt = List.map (fun (_id, vt) -> (None, vt)) pt in - (pt, rt) - in - remove_param a = remove_param b - type t = { id : string option ; typ : binary str_type Named.t @@ -112,7 +105,7 @@ let name kind ~get_name values = | Some name -> let index = Indexed.get_index elt in if String_map.mem name named then - Error (`Msg (Format.sprintf "duplicate %s %s" kind name)) + Error (`Msg (Fmt.str "duplicate %s %s" kind name)) else ok @@ String_map.add name index named in let+ named = list_fold_left assign_one String_map.empty values in @@ -128,7 +121,7 @@ let check_type_id (types : binary str_type Named.t) | None -> Error (`Unknown_type (Raw id)) | Some (Def_func_t func_type') -> let* func_type = Binary_types.convert_func_type None func_type in - if not (equal_func_types func_type func_type') then + if not (Types.func_type_eq func_type func_type') then Error `Inline_function_type else Ok () | Some _ -> assert false diff --git a/src/text_to_binary/rewrite.ml b/src/text_to_binary/rewrite.ml index d456b5649..ac0ff4749 100644 --- a/src/text_to_binary/rewrite.ml +++ b/src/text_to_binary/rewrite.ml @@ -8,7 +8,7 @@ open Syntax module StrType = struct type t = binary str_type - let compare = compare + let compare x y = if Types.str_type_eq x y then 0 else 1 end module TypeMap = Map.Make (StrType) @@ -47,11 +47,11 @@ let rewrite_expr (modul : Assigned.t) (locals : binary param list) begin try List.iteri - (fun i n -> - if n = Some id then begin + (fun i -> function + | Some id' when String.equal id id' -> pos := i; raise Exit - end ) + | None | Some _ -> () ) block_ids with Exit -> () end; @@ -83,7 +83,7 @@ let rewrite_expr (modul : Assigned.t) (locals : binary param list) let+ v = get (`Unknown_type ind) modul.typ ind in match Indexed.get v with Def_func_t t' -> t' | _ -> assert false in - let ok = Binary_types.equal_func_types t t' in + let ok = Types.func_type_eq t t' in if not ok then Error `Inline_function_type else Ok (Bt_raw (None, t)) ) in @@ -296,9 +296,7 @@ let rewrite_block_type (typemap : binary indice TypeMap.t) (modul : Assigned.t) try Ok (TypeMap.find (Def_func_t t) typemap) with Not_found -> Error - (`Msg - (Format.asprintf "Missing func type in index table %a" pp_func_type - t ) ) + (`Msg (Fmt.str "Missing func type in index table %a" pp_func_type t)) in Bt_raw (Some idx, t) diff --git a/src/utils/format.ml b/src/utils/format.ml deleted file mode 100644 index 59d76c66e..000000000 --- a/src/utils/format.ml +++ /dev/null @@ -1,35 +0,0 @@ -(* SPDX-License-Identifier: AGPL-3.0-or-later *) -(* Copyright © 2021-2024 OCamlPro *) -(* Written by the Owi programmers *) - -include Stdlib.Format - -let pp = fprintf - -let pp_err = eprintf - -let pp_std = printf - -let pp_nothing _fmt () = () - -let pp_char = pp_print_char - -let pp_list = pp_print_list - -let pp_array = pp_print_array - -let pp_iter = pp_print_iter - -let pp_string = pp_print_string - -let pp_option = pp_print_option - -let pp_bool = pp_print_bool - -let pp_flush = pp_print_flush - -let pp_space fmt () = pp_string fmt " " - -let pp_newline fmt () = pp fmt "@\n" - -let pp_int = pp_print_int diff --git a/src/utils/format.mli b/src/utils/format.mli deleted file mode 100644 index f4e1c432b..000000000 --- a/src/utils/format.mli +++ /dev/null @@ -1,64 +0,0 @@ -(* SPDX-License-Identifier: AGPL-3.0-or-later *) -(* Copyright © 2021-2024 OCamlPro *) -(* Written by the Owi programmers *) - -type formatter = Stdlib.Format.formatter - -val pp : formatter -> ('a, formatter, unit) format -> 'a - -val pp_err : ('a, formatter, unit) format -> 'a - -val pp_std : ('a, formatter, unit) format -> 'a - -val pp_nothing : formatter -> unit -> unit - -val pp_space : formatter -> unit -> unit - -val pp_bool : formatter -> bool -> unit - -val pp_char : formatter -> char -> unit - -val pp_int : formatter -> int -> unit - -val pp_flush : formatter -> unit -> unit - -val pp_list : - ?pp_sep:(formatter -> unit -> unit) - -> (formatter -> 'a -> unit) - -> formatter - -> 'a list - -> unit - -val pp_array : - ?pp_sep:(formatter -> unit -> unit) - -> (formatter -> 'a -> unit) - -> formatter - -> 'a array - -> unit - -val pp_iter : - ?pp_sep:(formatter -> unit -> unit) - -> (('a -> unit) -> 'b -> unit) - -> (formatter -> 'a -> unit) - -> formatter - -> 'b - -> unit - -val pp_string : formatter -> string -> unit - -val pp_option : - ?none:(formatter -> unit -> unit) - -> (formatter -> 'a -> unit) - -> formatter - -> 'a option - -> unit - -val pp_newline : formatter -> unit -> unit - -val sprintf : ('a, unit, string) format -> 'a - -val ksprintf : (string -> 'a) -> ('b, unit, string, 'a) format4 -> 'b - -val asprintf : ('a, formatter, unit, string) format4 -> 'a - -val kasprintf : (string -> 'a) -> ('b, formatter, unit, 'a) format4 -> 'b diff --git a/src/utils/log.ml b/src/utils/log.ml index 2917c7f61..476ca4cb8 100644 --- a/src/utils/log.ml +++ b/src/utils/log.ml @@ -6,14 +6,15 @@ let debug_on = ref false let profiling_on = ref false -let debug0 t : unit = if !debug_on then Format.pp_err t +let debug0 t : unit = if !debug_on then Fmt.epr t -let debug1 t a : unit = if !debug_on then Format.pp_err t a +let debug1 t a : unit = if !debug_on then Fmt.epr t a -let debug2 t a b : unit = if !debug_on then Format.pp_err t a b +let debug2 t a b : unit = if !debug_on then Fmt.epr t a b -let debug5 t a b c d e : unit = if !debug_on then Format.pp_err t a b c d e +let debug5 t a b c d e : unit = if !debug_on then Fmt.epr t a b c d e -let profile3 t a b c : unit = if !profiling_on then Format.pp_err t a b c +let profile3 t a b c : unit = if !profiling_on then Fmt.epr t a b c -let err f = Format.kasprintf failwith f +(* TODO: remove this *) +let err f = Fmt.failwith f diff --git a/src/utils/log.mli b/src/utils/log.mli index f5adcecd2..385ec587f 100644 --- a/src/utils/log.mli +++ b/src/utils/log.mli @@ -11,15 +11,14 @@ val debug_on : bool ref val profiling_on : bool ref (** print some debug info *) -val debug0 : (unit, Format.formatter, unit) format -> unit +val debug0 : (unit, Fmt.formatter, unit) format -> unit -val debug1 : ('a -> unit, Format.formatter, unit) format -> 'a -> unit +val debug1 : ('a -> unit, Fmt.formatter, unit) format -> 'a -> unit -val debug2 : - ('a -> 'b -> unit, Format.formatter, unit) format -> 'a -> 'b -> unit +val debug2 : ('a -> 'b -> unit, Fmt.formatter, unit) format -> 'a -> 'b -> unit val debug5 : - ('a -> 'b -> 'c -> 'd -> 'e -> unit, Format.formatter, unit) format + ('a -> 'b -> 'c -> 'd -> 'e -> unit, Fmt.formatter, unit) format -> 'a -> 'b -> 'c @@ -29,11 +28,7 @@ val debug5 : (** print some profiling info *) val profile3 : - ('a -> 'b -> 'c -> unit, Format.formatter, unit) format - -> 'a - -> 'b - -> 'c - -> unit + ('a -> 'b -> 'c -> unit, Fmt.formatter, unit) format -> 'a -> 'b -> 'c -> unit (** print some error and exit *) -val err : ('a, Format.formatter, unit, 'b) format4 -> 'a +val err : ('a, Fmt.formatter, unit, 'b) format4 -> 'a diff --git a/src/utils/result.ml b/src/utils/result.ml index 0a090ec7c..99d6d5c4f 100644 --- a/src/utils/result.ml +++ b/src/utils/result.ml @@ -2,7 +2,7 @@ (* Copyright © 2021-2024 OCamlPro *) (* Written by the Owi programmers *) -include Stdlib.Result +include Prelude.Result type err = [ `Alignment_too_large @@ -61,7 +61,7 @@ type err = | `Unsupported_file_extension of string ] -type 'a t = ('a, err) Stdlib.Result.t +type 'a t = ('a, err) Prelude.Result.t let rec err_to_string = function | `Alignment_too_large -> "alignment must not be larger than natural" @@ -71,19 +71,19 @@ let rec err_to_string = function | `Constant_expression_required -> "constant expression required" | `Constant_out_of_range -> "constant out of range" | `Did_not_fail_but_expected expected -> - Format.sprintf "expected %s but there was no error" expected + Fmt.str "expected %s but there was no error" expected | `Duplicate_export_name -> "duplicate export name" - | `Duplicate_global id -> Format.sprintf "duplicate global %s" id - | `Duplicate_local id -> Format.sprintf "duplicate local %s" id - | `Duplicate_memory id -> Format.sprintf "duplicate memory %s" id - | `Duplicate_table id -> Format.sprintf "duplicate table %s" id + | `Duplicate_global id -> Fmt.str "duplicate global %s" id + | `Duplicate_local id -> Fmt.str "duplicate local %s" id + | `Duplicate_memory id -> Fmt.str "duplicate memory %s" id + | `Duplicate_table id -> Fmt.str "duplicate table %s" id | `Failed_with_but_expected (got, expected) -> - Format.sprintf "expected %s but got (%s)" expected (err_to_string got) + Fmt.str "expected %s but got (%s)" expected (err_to_string got) | `Found_bug n -> - if n > 1 then Format.sprintf "Reached %d problems!" n - else Format.sprintf "Reached problem!" + if n > 1 then Fmt.str "Reached %d problems!" n + else Fmt.str "Reached problem!" | `Global_is_immutable -> "global is immutable" - | `Illegal_escape txt -> Format.sprintf "illegal escape %S" txt + | `Illegal_escape txt -> Fmt.str "illegal escape %S" txt | `Import_after_function -> "import after function" | `Import_after_global -> "import after global" | `Import_after_memory -> "import after memory" @@ -91,9 +91,8 @@ let rec err_to_string = function | `Incompatible_import_type -> "incompatible import type" | `Inline_function_type -> "inline function type" | `Invalid_result_arity -> "invalid result arity" - | `Lexer_unknown_operator op -> Format.sprintf "unknown operator %s" op - | `Malformed_utf8_encoding txt -> - Format.sprintf "malformed UTF-8 encoding %S" txt + | `Lexer_unknown_operator op -> Fmt.str "unknown operator %s" op + | `Malformed_utf8_encoding txt -> Fmt.str "malformed UTF-8 encoding %S" txt | `Memory_size_too_large -> "memory size must be at most 65536 pages (4GiB)" | `Msg msg -> msg | `Multiple_memories -> "multiple memories" @@ -104,29 +103,26 @@ let rec err_to_string = function "size minimum must not be greater than maximum" | `Start_function -> "start function must have type [] -> []" | `Timeout -> "timeout" - | `Trap t -> Format.sprintf "trap: %s" (Trap.to_string t) - | `Type_mismatch msg -> Format.sprintf "type mismatch (%s)" msg + | `Trap t -> Fmt.str "trap: %s" (Trap.to_string t) + | `Type_mismatch msg -> Fmt.str "type mismatch (%s)" msg | `Unbound_last_module -> "unbound last module" - | `Unbound_module id -> Format.sprintf "unbound module %s" id - | `Unbound_name id -> Format.sprintf "unbound name %s" id + | `Unbound_module id -> Fmt.str "unbound module %s" id + | `Unbound_name id -> Fmt.str "unbound name %s" id | `Undeclared_function_reference -> "undeclared function reference" - | `Unexpected_token s -> Format.sprintf "unexpected token %S" s - | `Unknown_data id -> - Format.asprintf "unknown data segment %a" Types.pp_indice id - | `Unknown_elem id -> - Format.asprintf "unknown elem segment %a" Types.pp_indice id - | `Unknown_func id -> Format.asprintf "unknown function %a" Types.pp_indice id - | `Unknown_global id -> Format.asprintf "unknown global %a" Types.pp_indice id - | `Unknown_import (modul, value) -> - Format.sprintf "unknown import %S %S" modul value - | `Unknown_label id -> Format.asprintf "unknown label %a" Types.pp_indice id - | `Unknown_local id -> Format.asprintf "unknown local %a" Types.pp_indice id - | `Unknown_memory id -> Format.asprintf "unknown memory %a" Types.pp_indice id - | `Unknown_module name -> Format.sprintf "unknown module %s" name - | `Unknown_operator -> Format.sprintf "unknown operator" - | `Unknown_table id -> Format.asprintf "unknown table %a" Types.pp_indice id - | `Unknown_type id -> Format.asprintf "unknown type %a" Types.pp_indice id + | `Unexpected_token s -> Fmt.str "unexpected token %S" s + | `Unknown_data id -> Fmt.str "unknown data segment %a" Types.pp_indice id + | `Unknown_elem id -> Fmt.str "unknown elem segment %a" Types.pp_indice id + | `Unknown_func id -> Fmt.str "unknown function %a" Types.pp_indice id + | `Unknown_global id -> Fmt.str "unknown global %a" Types.pp_indice id + | `Unknown_import (modul, value) -> Fmt.str "unknown import %S %S" modul value + | `Unknown_label id -> Fmt.str "unknown label %a" Types.pp_indice id + | `Unknown_local id -> Fmt.str "unknown local %a" Types.pp_indice id + | `Unknown_memory id -> Fmt.str "unknown memory %a" Types.pp_indice id + | `Unknown_module name -> Fmt.str "unknown module %s" name + | `Unknown_operator -> Fmt.str "unknown operator" + | `Unknown_table id -> Fmt.str "unknown table %a" Types.pp_indice id + | `Unknown_type id -> Fmt.str "unknown type %a" Types.pp_indice id | `Unsupported_file_extension ext -> - Format.sprintf "unsupported file_extension %S" ext + Fmt.str "unsupported file_extension %S" ext -let failwith e = failwith (err_to_string e) +let failwith e = Fmt.failwith "%s" (err_to_string e) diff --git a/src/utils/result.mli b/src/utils/result.mli index 242afb9f0..56e93e634 100644 --- a/src/utils/result.mli +++ b/src/utils/result.mli @@ -2,7 +2,7 @@ (* Copyright © 2021-2024 OCamlPro *) (* Written by the Owi programmers *) -include module type of Stdlib.Result +include module type of Prelude.Result type err = [ `Alignment_too_large @@ -61,7 +61,7 @@ type err = | `Unsupported_file_extension of string ] -type 'a t = ('a, err) Stdlib.Result.t +type 'a t = ('a, err) Prelude.Result.t val err_to_string : err -> string diff --git a/src/utils/syntax.ml b/src/utils/syntax.ml index 96af9b1ab..81d77970c 100644 --- a/src/utils/syntax.ml +++ b/src/utils/syntax.ml @@ -2,7 +2,7 @@ (* Copyright © 2021-2024 OCamlPro *) (* Written by the Owi programmers *) -open Stdlib.Result +open Prelude.Result let ( let* ) o f = match o with Ok v -> f v | Error _ as e -> e @@ -10,8 +10,6 @@ let ( let+ ) o f = match o with Ok v -> Ok (f v) | Error _ as e -> e let error v = Error v -let error_s format = Format.kasprintf error format - let ok v = Ok v let list_iter f l = diff --git a/src/utils/syntax.mli b/src/utils/syntax.mli index 88be9df4f..b02f2b83a 100644 --- a/src/utils/syntax.mli +++ b/src/utils/syntax.mli @@ -3,48 +3,45 @@ (* Written by the Owi programmers *) val ( let* ) : - ('a, 'err) Stdlib.Result.t - -> ('a -> ('b, 'err) Stdlib.Result.t) - -> ('b, 'err) Stdlib.Result.t + ('a, 'err) Prelude.Result.t + -> ('a -> ('b, 'err) Prelude.Result.t) + -> ('b, 'err) Prelude.Result.t val ( let+ ) : - ('a, 'err) Stdlib.Result.t -> ('a -> 'b) -> ('b, 'err) Stdlib.Result.t + ('a, 'err) Prelude.Result.t -> ('a -> 'b) -> ('b, 'err) Prelude.Result.t -val error : string -> ('a, string) Stdlib.Result.t +val error : string -> ('a, string) Prelude.Result.t -val error_s : - ('a, Format.formatter, unit, ('b, string) Stdlib.Result.t) format4 -> 'a - -val ok : 'a -> ('a, 'err) Stdlib.Result.t +val ok : 'a -> ('a, 'err) Prelude.Result.t val list_iter : - ('a -> (unit, 'err) Stdlib.Result.t) + ('a -> (unit, 'err) Prelude.Result.t) -> 'a list - -> (unit, 'err) Stdlib.Result.t + -> (unit, 'err) Prelude.Result.t val list_map : - ('a -> ('b, 'err) Stdlib.Result.t) + ('a -> ('b, 'err) Prelude.Result.t) -> 'a list - -> ('b list, 'err) Stdlib.Result.t + -> ('b list, 'err) Prelude.Result.t val list_fold_left : - ('a -> 'b -> ('a, 'err) Stdlib.Result.t) + ('a -> 'b -> ('a, 'err) Prelude.Result.t) -> 'a -> 'b list - -> ('a, 'err) Stdlib.Result.t + -> ('a, 'err) Prelude.Result.t val array_iter : - ('a -> (unit, 'err) Stdlib.Result.t) + ('a -> (unit, 'err) Prelude.Result.t) -> 'a array - -> (unit, 'err) Stdlib.Result.t + -> (unit, 'err) Prelude.Result.t val array_map : - ('a -> ('b, 'err) Stdlib.Result.t) + ('a -> ('b, 'err) Prelude.Result.t) -> 'a array - -> ('b array, 'err) Stdlib.Result.t + -> ('b array, 'err) Prelude.Result.t val array_fold_left : - ('a -> 'b -> ('a, 'err) Stdlib.Result.t) + ('a -> 'b -> ('a, 'err) Prelude.Result.t) -> 'a -> 'b array - -> ('a, 'err) Stdlib.Result.t + -> ('a, 'err) Prelude.Result.t diff --git a/src/validate/typecheck.ml b/src/validate/typecheck.ml index d386e7537..36a2a35bb 100644 --- a/src/validate/typecheck.ml +++ b/src/validate/typecheck.ml @@ -5,7 +5,7 @@ open Types open Binary open Syntax -open Format +open Fmt type typ = | Num_type of num_type @@ -13,13 +13,21 @@ type typ = | Any | Something +let typ_equal t1 t2 = + match (t1, t2) with + | Num_type t1, Num_type t2 -> Types.num_type_eq t1 t2 + | Ref_type t1, Ref_type t2 -> Types.heap_type_eq t1 t2 + | Any, _ | _, Any -> true + | Something, _ | _, Something -> true + | _, _ -> false + let pp_typ fmt = function | Num_type t -> pp_num_type fmt t | Ref_type t -> pp_heap_type fmt t - | Any -> pp_string fmt "any" - | Something -> pp_string fmt "something" + | Any -> string fmt "any" + | Something -> string fmt "something" -let pp_typ_list fmt l = pp_list ~pp_sep:pp_space pp_typ fmt l +let pp_typ_list fmt l = list ~sep:sp pp_typ fmt l let typ_of_val_type = function | Types.Ref_type (_null, t) -> Ref_type t @@ -42,7 +50,7 @@ let check_data modul n = else Ok () let check_align memarg_align align = - if memarg_align >= align then Error `Alignment_too_large else Ok () + if Int32.ge memarg_align align then Error `Alignment_too_large else Ok () module Env = struct type t = @@ -147,7 +155,7 @@ module Stack : sig end = struct type t = typ list - let pp fmt (s : stack) = pp fmt "[%a]" pp_typ_list s + let pp fmt (s : stack) = pf fmt "[%a]" pp_typ_list s let match_num_type (required : num_type) (got : num_type) = match (required, got) with @@ -183,7 +191,7 @@ end = struct let rec equal s s' = match (s, s') with - | [], s | s, [] -> List.for_all (( = ) Any) s + | [], s | s, [] -> List.for_all (function Any -> true | _ -> false) s | Any :: tl, Any :: tl' -> equal tl s' || equal s tl' | Any :: tl, hd :: tl' | hd :: tl', Any :: tl -> equal tl (hd :: tl') || equal (Any :: tl) tl' @@ -449,7 +457,7 @@ let rec typecheck_instr (env : Env.t) (stack : stack) (instr : binary instr) : | Table_copy (Raw i, Raw i') -> let* typ = Env.table_type_get i env.modul in let* typ' = Env.table_type_get i' env.modul in - if typ <> typ' then Error (`Type_mismatch "table_copy") + if not @@ Types.ref_type_eq typ typ' then Error (`Type_mismatch "table_copy") else Stack.pop [ i32; i32; i32 ] stack | Table_fill (Raw i) -> let* _null, t = Env.table_type_get i env.modul in @@ -557,8 +565,8 @@ and typecheck_expr env expr ~is_loop (block_type : binary block_type option) | None -> Error (`Type_mismatch - (Format.asprintf "expected a prefix of %a but stack has type %a" - Stack.pp pt Stack.pp previous_stack ) ) + (Fmt.str "expected a prefix of %a but stack has type %a" Stack.pp pt + Stack.pp previous_stack ) ) | Some stack_to_push -> Stack.push rt stack_to_push let typecheck_function (modul : modul) func refs = @@ -628,7 +636,8 @@ let typecheck_global (modul : modul) refs match real_type with | [ real_type ] -> let expected = typ_of_val_type @@ snd typ in - if expected <> real_type then Error (`Type_mismatch "typecheck global 1") + if not @@ typ_equal expected real_type then + Error (`Type_mismatch "typecheck global 1") else Ok () | _whatever -> Error (`Type_mismatch "typecheck_global 2") ) @@ -641,7 +650,7 @@ let typecheck_elem modul refs (elem : elem Indexed.t) = let* real_type = typecheck_const_expr modul refs init in match real_type with | [ real_type ] -> - if Ref_type expected_type <> real_type then + if not @@ typ_equal (Ref_type expected_type) real_type then Error (`Type_mismatch "typecheck_elem 1") else Ok () | _whatever -> Error (`Type_mismatch "typecheck elem 2") ) @@ -652,12 +661,14 @@ let typecheck_elem modul refs (elem : elem Indexed.t) = | Elem_active (None, _e) -> assert false | Elem_active (Some tbl_i, e) -> ( let* _null, tbl_type = Env.table_type_get tbl_i modul in - if tbl_type <> expected_type then Error (`Type_mismatch "typecheck elem 3") + if not @@ Types.heap_type_eq tbl_type expected_type then + Error (`Type_mismatch "typecheck elem 3") else let* t = typecheck_const_expr modul refs e in match t with | [ Ref_type t ] -> - if t <> tbl_type then Error (`Type_mismatch "typecheck_elem 4") + if not @@ Types.heap_type_eq t tbl_type then + Error (`Type_mismatch "typecheck_elem 4") else Ok () | [ _t ] -> Ok () | _whatever -> Error (`Type_mismatch "typecheck_elem 5") )