From 8b94609651f84b119451836e943eeb7e0cc3ae04 Mon Sep 17 00:00:00 2001 From: Steven de Oliveira Date: Mon, 16 Oct 2023 15:37:36 +0200 Subject: [PATCH] Injecting AE type float rounding type into SMT rounding type --- src/lib/frontend/d_cnf.ml | 63 +++++++++---- src/lib/frontend/typechecker.ml | 10 +-- src/lib/structures/expr.ml | 7 +- src/lib/structures/fpa_rounding.ml | 135 ++++++++++++++-------------- src/lib/structures/fpa_rounding.mli | 15 ++-- 5 files changed, 126 insertions(+), 104 deletions(-) diff --git a/src/lib/frontend/d_cnf.ml b/src/lib/frontend/d_cnf.ml index 3f3ba3e8ec..8f9324e271 100644 --- a/src/lib/frontend/d_cnf.ml +++ b/src/lib/frontend/d_cnf.ml @@ -242,11 +242,8 @@ let builtin_enum = function |> add_cstrs | _ -> assert false -let ae_fpa_rounding_mode, ae_rounding_modes, ae_add_rounding_modes = - builtin_enum Fpa_rounding.AE.fpa_rounding_mode - -let smt_fpa_rounding_mode, _smt_rounding_modes, smt_add_rounding_modes = - builtin_enum Fpa_rounding.SMT2.fpa_rounding_mode +let fpa_rounding_mode, rounding_modes, add_rounding_modes = + builtin_enum Fpa_rounding.fpa_rounding_mode module Const = struct open DE @@ -270,7 +267,7 @@ module Const = struct ~name ~builtin:(AERound (n, m)) (DStd.Path.global name) - Ty.(arrow [smt_fpa_rounding_mode; real] real)) + Ty.(arrow [fpa_rounding_mode; real] real)) end let bv2nat t = @@ -308,6 +305,32 @@ let bv_builtins env s = end | _ -> `Not_found +(** Takes a dolmen identifier [id] and injects it in Alt-Ergo's registered + identifier. For example, transforms "fpa_rounding_mode", the Alt-Ergo builtin + type into the SMT rounding type 'RoundingMode'. Also does it for the enums + of this type. *) +let inject_identifier id = + match id with + | Id.{name = Simple n; _} -> + begin + if String.equal n Fpa_rounding.fpa_rounding_mode_ae_type_name then + (* Injecting the type name as the SMT2 Type name. *) + let name = + Dolmen_std.Name.simple Fpa_rounding.fpa_rounding_mode_type_name + in + {id with name} + else + match Fpa_rounding.rounding_mode_of_ae_hs (Hstring.make n) with + | rm -> + let name = + Dolmen_std.Name.simple (Fpa_rounding.string_of_rounding_mode rm) + in + {id with name} + | exception (Failure _) -> + id + end + | id -> id + let ae_fpa_builtins = let (->.) args ret = (args, ret) in let dterm name f = @@ -326,7 +349,7 @@ let ae_fpa_builtins = (module Dl.Typer.T) env cst in let float_cst = - let ty = DT.(arrow [int; int; ae_fpa_rounding_mode; real] real) in + let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in DE.Id.mk ~name:"float" ~builtin:Float (DStd.Path.global "float") ty in let float prec exp mode x = @@ -338,7 +361,7 @@ let ae_fpa_builtins = match cst.DE.path with | Absolute { name; _ } -> String.equal name m | Local _ -> false) - ae_rounding_modes + rounding_modes in DE.Term.apply_cst cst [] [] in @@ -371,13 +394,13 @@ let ae_fpa_builtins = let open DT in Id.Map.empty - |> ae_add_rounding_modes + |> add_rounding_modes (* the first argument is mantissas' size (including the implicit bit), the second one is the exp of the min representable normalized number, the third one is the rounding mode, and the last one is the real to be rounded *) - |> op "float" Float ([int; int; ae_fpa_rounding_mode; real] ->. real) + |> op "float" Float ([int; int; fpa_rounding_mode; real] ->. real) |> partial2 "float32" float32 |> partial1 "float32d" float32d @@ -386,7 +409,7 @@ let ae_fpa_builtins = |> partial1 "float64d" float64d (* rounds to nearest integer *) - |> op "integer_round" Integer_round ([ae_fpa_rounding_mode; real] ->. int) + |> op "integer_round" Integer_round ([fpa_rounding_mode; real] ->. int) (* type cast: from int to real *) |> dterm "real_of_int" DE.Term.Int.to_real @@ -441,13 +464,15 @@ let ae_fpa_builtins = |> op "linear_dependency" Linear_dependency ([real; real] ->. prop) in fun env s -> + let search_id id = + try + Id.Map.find_exn id fpa_builtins env s + with Not_found -> `Not_found + in match s with | Dl.Typer.T.Id id -> - begin - try - Id.Map.find_exn id fpa_builtins env s - with Not_found -> `Not_found - end + let new_id = inject_identifier id in + search_id new_id | Builtin _ -> `Not_found let smt_fpa_builtins = @@ -457,7 +482,7 @@ let smt_fpa_builtins = in let other_builtins = Id.Map.empty - |> smt_add_rounding_modes + |> add_rounding_modes in fun env s -> match s with @@ -995,8 +1020,8 @@ let mk_add translate sy ty l = E.mk_term sy args ty let mk_rounding fpar = - let name = Fpa_rounding.SMT2.string_of_rounding_mode fpar in - let ty = Fpa_rounding.SMT2.fpa_rounding_mode in + let name = Fpa_rounding.string_of_rounding_mode fpar in + let ty = Fpa_rounding.fpa_rounding_mode in let sy = Sy.Op (Sy.Constr (Hstring.make name)) in E.mk_term sy [] ty diff --git a/src/lib/frontend/typechecker.ml b/src/lib/frontend/typechecker.ml index 2507cf3b38..b12e6342a2 100644 --- a/src/lib/frontend/typechecker.ml +++ b/src/lib/frontend/typechecker.ml @@ -279,13 +279,13 @@ module Env = struct | _ -> assert false let add_fpa_builtins env = - let module FPAU = Fpa_rounding.AE in + (* let module FPAU = Fpa_rounding.AE in *) let (->.) args result = { args; result } in let int n = { c = { tt_desc = TTconst (Tint n); tt_ty = Ty.Tint} ; annot = new_id () ; } in - let rm = FPAU.fpa_rounding_mode in + let rm = Fpa_rounding.fpa_rounding_mode in let mode m = let h = find_builtin_cstr rm m in { @@ -299,8 +299,8 @@ module Env = struct let float prec exp mode x = TTapp (Symbols.Op Float, [prec; exp; mode; x]) in - let nte = FPAU.string_of_rounding_mode NearestTiesToEven in - let tname = FPAU.fpa_rounding_mode_type_name in + let nte = Fpa_rounding.string_of_rounding_mode NearestTiesToEven in + let tname = Fpa_rounding.fpa_rounding_mode_type_name in let float32 = float (int "24") (int "149") in let float32d = float32 (mode nte) in let float64 = float (int "53") (int "1074") in @@ -318,7 +318,7 @@ module Env = struct types = Types.add_builtin env.types tname rm ; builtins = add_builtin_enum - FPAU.fpa_rounding_mode + Fpa_rounding.fpa_rounding_mode env.builtins; } in let builtins = diff --git a/src/lib/structures/expr.ml b/src/lib/structures/expr.ml index f30e029894..2135c39d60 100644 --- a/src/lib/structures/expr.ml +++ b/src/lib/structures/expr.ml @@ -2819,11 +2819,8 @@ let const_view t = Fmt.failwith "error when trying to convert %a to an int" Z.pp_print n end | { f = Op (Constr c); ty; _ } - when Ty.equal ty Fpa_rounding.SMT2.fpa_rounding_mode -> - RoundingMode (Fpa_rounding.SMT2.rounding_mode_of_hs c) - | { f = Op (Constr c); ty; _ } - when Ty.equal ty Fpa_rounding.AE.fpa_rounding_mode -> - RoundingMode (Fpa_rounding.AE.rounding_mode_of_hs c) + when Ty.equal ty Fpa_rounding.fpa_rounding_mode -> + RoundingMode (Fpa_rounding.rounding_mode_of_smt_hs c) | _ -> Fmt.failwith "unsupported constant: %a" print t let int_view t = diff --git a/src/lib/structures/fpa_rounding.ml b/src/lib/structures/fpa_rounding.ml index ffa3879c53..e2f0072323 100644 --- a/src/lib/structures/fpa_rounding.ml +++ b/src/lib/structures/fpa_rounding.ml @@ -52,73 +52,74 @@ let cstrs = NearestTiesToAway; ] -module type S = sig - val fpa_rounding_mode_type_name : string - - val fpa_rounding_mode : Ty.t - - val rounding_mode_of_hs : Hstring.t -> rounding_mode - - val string_of_rounding_mode : rounding_mode -> string -end - -module Make (I : sig - val name : string - val to_string : rounding_mode -> string - end) : S = struct - - let fpa_rounding_mode_type_name = I.name - - let string_of_rounding_mode = I.to_string - - let fpa_rounding_mode, rounding_mode_of_hs = - let h_cstrs = - List.map (fun c -> Hs.make (I.to_string c)) cstrs - in - let ty = Ty.Tsum (Hs.make I.name, h_cstrs) in - let table = - let table = Hashtbl.create 5 in - List.iter2 ( - fun key bnd -> - Hashtbl.add table key bnd - ) h_cstrs cstrs; - table - in - ty, - (fun key -> match Hashtbl.find_opt table key with - | None -> - Fmt.failwith - "Error while searching for FPA value %a : %s" - Hstring.print key - I.name - | Some res -> res) -end - -module AE : S = - Make (struct - let name = "fpa_rounding_mode" - let to_string = - function - | NearestTiesToEven -> "NearestTiesToEven" - | ToZero -> "ToZero" - | Up -> "Up" - | Down -> "Down" - | NearestTiesToAway -> "NearestTiesToAway" - end - ) - -module SMT2 : S = - Make (struct - let name = "RoundingMode" - - let to_string = - function - | NearestTiesToEven -> "RNE" - | ToZero -> "RTZ" - | Up -> "RTP" - | Down -> "RTN" - | NearestTiesToAway -> "RNA" - end) +let to_smt_string = + function + | NearestTiesToEven -> "RNE" + | ToZero -> "RTZ" + | Up -> "RTP" + | Down -> "RTN" + | NearestTiesToAway -> "RNA" + +let to_ae_string = function + | NearestTiesToEven -> "NearestTiesToEven" + | ToZero -> "ToZero" + | Up -> "Up" + | Down -> "Down" + | NearestTiesToAway -> "NearestTiesToAway" + + +let fpa_rounding_mode_ae_type_name = "fpa_rounding_mode" + +let fpa_rounding_mode_type_name = "RoundingMode" + +(* The exported 'to string' function is the SMT one. *) +let string_of_rounding_mode = to_smt_string + +let hstring_smt_reprs = + List.map + (fun c -> Hs.make (to_smt_string c)) + cstrs + +let hstring_ae_reprs = + List.map + (fun c -> Hs.make (to_ae_string c)) + cstrs + +(* The rounding mode is the enum with the SMT values. + The Alt-Ergo values are injected in this type. *) +let fpa_rounding_mode = + Ty.Tsum (Hs.make "RoundingMode", hstring_smt_reprs) + +let rounding_mode_of_smt_hs_opt = + let table = Hashtbl.create 5 in + List.iter2 ( + fun key bnd -> + Hashtbl.add table key bnd + ) hstring_smt_reprs cstrs; + fun key -> Hashtbl.find_opt table key + +let rounding_mode_of_smt_hs hs = + match rounding_mode_of_smt_hs_opt hs with + | None -> + Fmt.failwith + "Error while searching for FPA value %a." + Hstring.print hs + fpa_rounding_mode_type_name + | Some res -> res + +let rounding_mode_of_ae_hs = + let table = Hashtbl.create 5 in + List.iter2 ( + fun key bnd -> + Hashtbl.add table key bnd + ) hstring_ae_reprs cstrs; + fun key -> + match Hashtbl.find_opt table key with + | None -> + (* Alt-Ergo's legacy language also accepts the SMT representation + of rounding modes. *) + rounding_mode_of_smt_hs key + | Some res -> res (** Helper functions **) diff --git a/src/lib/structures/fpa_rounding.mli b/src/lib/structures/fpa_rounding.mli index 939a7da70b..13b631673d 100644 --- a/src/lib/structures/fpa_rounding.mli +++ b/src/lib/structures/fpa_rounding.mli @@ -35,18 +35,17 @@ type rounding_mode = | Down | NearestTiesToAway -module type S = sig - val fpa_rounding_mode_type_name : string +val fpa_rounding_mode_type_name : string - val fpa_rounding_mode : Ty.t +val fpa_rounding_mode_ae_type_name : string - val rounding_mode_of_hs : Hstring.t -> rounding_mode +val fpa_rounding_mode : Ty.t - val string_of_rounding_mode : rounding_mode -> string -end +val rounding_mode_of_smt_hs : Hstring.t -> rounding_mode -module AE : S -module SMT2 : S +val rounding_mode_of_ae_hs : Hstring.t -> rounding_mode + +val string_of_rounding_mode : rounding_mode -> string (** Integer part of binary logarithm for NON-ZERO POSITIVE number **) val integer_log_2 : Numbers.Q.t -> int