From 2816f11288711b157534b8779f5be34c1aa84cf8 Mon Sep 17 00:00:00 2001 From: Steven de Oliveira Date: Fri, 13 Oct 2023 18:53:57 +0200 Subject: [PATCH] Indexed identifier --- src/lib/frontend/d_cnf.ml | 389 +++++++++++++++++------------ src/lib/structures/fpa_rounding.ml | 10 +- 2 files changed, 229 insertions(+), 170 deletions(-) diff --git a/src/lib/frontend/d_cnf.ml b/src/lib/frontend/d_cnf.ml index 78950fdd71..269ebff57c 100644 --- a/src/lib/frontend/d_cnf.ml +++ b/src/lib/frontend/d_cnf.ml @@ -178,6 +178,7 @@ end (** Builtins *) type _ DStd.Builtin.t += | Float + | SMTFloat of int * int * Fpa_rounding.rounding_mode | Integer_round | Abs_real | Sqrt_real @@ -200,6 +201,52 @@ let with_cache ~cache f x = Hashtbl.add cache x res; res +let builtin_term t = Dl.Typer.T.builtin_term t + +let builtin_ty t = Dl.Typer.T.builtin_ty t + +let ty name ty = + Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@ + fun env s -> + builtin_ty @@ + Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty + +let builtin_enum = function + | Ty.Tsum (name, cstrs) as ty_ -> + let ty_cst = + DStd.Expr.Id.mk ~builtin:B.Base + (DStd.Path.global (Hstring.view name)) + DStd.Expr.{ arity = 0; alias = No_alias } + in + let cstrs = + List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs + in + let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in + let dty = DT.apply ty_cst [] in + let add_cstrs map = + List.fold_left (fun map ((c : DE.term_cst), _) -> + let name = get_basename c.path in + Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ -> + builtin_term @@ + Dolmen_type.Base.term_app_cst + (module Dl.Typer.T) env c) map) + map cstrs + in + Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_; + dty, + cstrs, + fun map -> + map + |> ty (Hstring.view name) dty + |> add_cstrs + | _ -> assert false + +let ae_fpa_rounding_mode, ae_rounding_modes, ae_add_rounding_modes = + builtin_enum Fpa_rounding.AE.fpa_rounding_mode + +let smt_fpa_rounding_mode, _smt_rounding_modes, smt_add_rounding_modes = + builtin_enum Fpa_rounding.SMT2.fpa_rounding_mode + module Const = struct open DE @@ -214,6 +261,16 @@ module Const = struct let name = "int2bv" in Id.mk ~name ~builtin:(Int2BV n) (DStd.Path.global name) Ty.(arrow [int] (bitv n))) + + let smt_round = + with_cache ~cache:(Hashtbl.create 13) (fun n -> + with_cache ~cache:(Hashtbl.create 13) (fun m rm -> + let name = "ae.round" in + Id.mk + ~name + ~builtin:(SMTFloat (n, m, rm)) + (DStd.Path.global name) + Ty.(arrow [real] real))) end let bv2nat t = @@ -227,6 +284,9 @@ let bv2nat t = let int2bv n t = DE.Term.apply_cst (Const.int2bv n) [] [t] +let smt_round n m rm t = + DE.Term.apply_cst (Const.smt_round n m rm) [] [t] + let bv_builtins env s = let term_app1 f = Dl.Typer.T.builtin_term @@ @@ -248,70 +308,14 @@ let bv_builtins env s = end | _ -> `Not_found -let fpa_builtins map env s = - match s with - | Dl.Typer.T.Id id -> - begin - try - Id.Map.find_exn id map env s - with Not_found -> `Not_found - end - | Builtin _ -> `Not_found - -let fpa_builtins_map l = +let ae_fpa_builtins = let (->.) args ret = (args, ret) in - let builtin_term t = Dl.Typer.T.builtin_term t in - let builtin_ty t = Dl.Typer.T.builtin_ty t in let dterm name f = Id.Map.add { name = DStd.Name.simple name; ns = Term } @@ fun env s -> builtin_term @@ Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f in - let ty name ty = - Id.Map.add { name = DStd.Name.simple name; ns = Sort } @@ - fun env s -> - builtin_ty @@ - Dolmen_type.Base.app0 (module Dl.Typer.T) env s ty - in - let builtin_enum = function - | Ty.Tsum (name, cstrs) as ty_ -> - let ty_cst = - DStd.Expr.Id.mk ~builtin:B.Base - (DStd.Path.global (Hstring.view name)) - DStd.Expr.{ arity = 0; alias = No_alias } - in - let cstrs = - List.map (fun c -> DStd.Path.global (Hstring.view c), []) cstrs - in - let _, cstrs = DStd.Expr.Term.define_adt ty_cst [] cstrs in - let dty = DT.apply ty_cst [] in - let add_cstrs map = - List.fold_left (fun map ((c : DE.term_cst), _) -> - let name = get_basename c.path in - Id.Map.add { name = DStd.Name.simple name; ns = Term } (fun env _ -> - builtin_term @@ - Dolmen_type.Base.term_app_cst - (module Dl.Typer.T) env c) map) - map cstrs - in - Cache.store_ty (DE.Ty.Const.hash ty_cst) ty_; - dty, - cstrs, - fun map -> - map - |> ty (Hstring.view name) dty - |> add_cstrs - | _ -> assert false - in - let (module FPAU : Fpa_rounding.S) = - match l with - | `Smtlib2 -> (module Fpa_rounding.SMT2) - | `Ae -> (module Fpa_rounding.AE) - in - let fpa_rounding_mode, rounding_modes, add_rounding_modes = - builtin_enum FPAU.fpa_rounding_mode - in let op ?(tyvars = []) name builtin (args, ret) = let ty = DT.pi tyvars @@ DT.arrow args ret in let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in @@ -322,131 +326,160 @@ let fpa_builtins_map l = (module Dl.Typer.T) env cst in let open DT in - match l with - | `Ae -> - begin (* Legacy fpa builtins *) - let float_cst = - let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in - DE.Id.mk ~name:"float" ~builtin:Float (DStd.Path.global "float") ty - in - let float prec exp mode x = - DE.Term.apply_cst float_cst [] [prec; exp; mode; x] - in - let mode m = - let cst, _ = - List.find (fun (cst, _args) -> - match cst.DE.path with - | Absolute { name; _ } -> String.equal name m - | Local _ -> false) - rounding_modes - in - DE.Term.apply_cst cst [] [] - in - let float32 = float (DE.Term.int "24") (DE.Term.int "149") in - let float32d x = float32 (mode "NearestTiesToEven") x in - let float64 = float (DE.Term.int "53") (DE.Term.int "1074") in - let float64d x = float64 (mode "NearestTiesToEven") x in - let partial1 name f = - Id.Map.add { name = DStd.Name.simple name; ns = Term } @@ - fun env s -> - builtin_term @@ - Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f - in - let partial2 name f = - Id.Map.add { name = DStd.Name.simple name; ns = Term } @@ - fun env s -> - builtin_term @@ - Dolmen_type.Base.term_app2 (module Dl.Typer.T) env s f - in - let is_theory_constant = - let open DT in - let a = Var.mk "alpha" in - op ~tyvars:[a] "is_theory_constant" Is_theory_constant ([of_var a] ->. prop) - in - Id.Map.empty - - |> add_rounding_modes - - (* the first argument is mantissas' size (including the implicit bit), - the second one is the exp of the min representable normalized number, - the third one is the rounding mode, and the last one is the real to - be rounded *) - |> op "float" Float ([int; int; fpa_rounding_mode; real] ->. real) + let float_cst = + let ty = DT.(arrow [int; int; ae_fpa_rounding_mode; real] real) in + DE.Id.mk ~name:"float" ~builtin:Float (DStd.Path.global "float") ty + in + let float prec exp mode x = + DE.Term.apply_cst float_cst [] [prec; exp; mode; x] + in + let mode m = + let cst, _ = + List.find (fun (cst, _args) -> + match cst.DE.path with + | Absolute { name; _ } -> String.equal name m + | Local _ -> false) + ae_rounding_modes + in + DE.Term.apply_cst cst [] [] + in + let float32 = float (DE.Term.int "24") (DE.Term.int "149") in + let float32d x = float32 (mode "NearestTiesToEven") x in + let float64 = float (DE.Term.int "53") (DE.Term.int "1074") in + let float64d x = float64 (mode "NearestTiesToEven") x in + let partial1 name f = + Id.Map.add { name = DStd.Name.simple name; ns = Term } @@ + fun env s -> + builtin_term @@ + Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f + in + let partial2 name f = + Id.Map.add { name = DStd.Name.simple name; ns = Term } @@ + fun env s -> + builtin_term @@ + Dolmen_type.Base.term_app2 (module Dl.Typer.T) env s f + in + let is_theory_constant = + let open DT in + let a = Var.mk "alpha" in + op + ~tyvars:[a] + "is_theory_constant" + Is_theory_constant + ([of_var a] ->. prop) + in + let fpa_builtins = + Id.Map.empty - |> partial2 "float32" float32 - |> partial1 "float32d" float32d + |> ae_add_rounding_modes - |> partial2 "float64" float64 - |> partial1 "float64d" float64d + (* the first argument is mantissas' size (including the implicit bit), + the second one is the exp of the min representable normalized number, + the third one is the rounding mode, and the last one is the real to + be rounded *) + |> op "float" Float ([int; int; ae_fpa_rounding_mode; real] ->. real) - (* rounds to nearest integer *) - |> op "integer_round" Integer_round ([fpa_rounding_mode; real] ->. int) + |> partial2 "float32" float32 + |> partial1 "float32d" float32d - (* type cast: from int to real *) - |> dterm "real_of_int" DE.Term.Int.to_real + |> partial2 "float64" float64 + |> partial1 "float64d" float64d - (* type check: integers *) - |> dterm "real_is_int" DE.Term.Real.is_int + (* rounds to nearest integer *) + |> op "integer_round" Integer_round ([ae_fpa_rounding_mode; real] ->. int) - (* abs value of a real *) - |> op "abs_real" Abs_real ([real] ->. real) + (* type cast: from int to real *) + |> dterm "real_of_int" DE.Term.Int.to_real - (* sqrt value of a real *) - |> op "sqrt_real" Sqrt_real ([real] ->. real) + (* type check: integers *) + |> dterm "real_is_int" DE.Term.Real.is_int - (* sqrt value of a real by default *) - |> op "sqrt_real_default" Sqrt_real_default ([real] ->. real) + (* abs value of a real *) + |> op "abs_real" Abs_real ([real] ->. real) - (* sqrt value of a real by excess *) - |> op "sqrt_real_excess" Sqrt_real_excess ([real] ->. real) + (* sqrt value of a real *) + |> op "sqrt_real" Sqrt_real ([real] ->. real) - (* abs value of an int *) - |> dterm "abs_int" DE.Term.Int.abs + (* sqrt value of a real by default *) + |> op "sqrt_real_default" Sqrt_real_default ([real] ->. real) - (* (integer) floor of a rational *) - |> dterm "int_floor" DE.Term.Real.floor_to_int + (* sqrt value of a real by excess *) + |> op "sqrt_real_excess" Sqrt_real_excess ([real] ->. real) - (* (integer) ceiling of a ratoinal *) - |> op "int_ceil" (Ceiling_to_int `Real) ([real] ->. int) + (* abs value of an int *) + |> dterm "abs_int" DE.Term.Int.abs - (* The functions below are only interpreted when applied on constants. - Aximatization for the general case are not currently imlemented *) + (* (integer) floor of a rational *) + |> dterm "int_floor" DE.Term.Real.floor_to_int - (* maximum of two reals *) - |> op "max_real" Max_real ([real; real] ->. real) + (* (integer) ceiling of a ratoinal *) + |> op "int_ceil" (Ceiling_to_int `Real) ([real] ->. int) - (* maximum of two ints *) - |> op "max_int" Max_int ([int; int] ->. int) + (* The functions below are only interpreted when applied on constants. + Aximatization for the general case are not currently imlemented *) - (* minimum of two ints *) - |> op "min_int" Min_int ([int; int] ->. int) + (* maximum of two reals *) + |> op "max_real" Max_real ([real; real] ->. real) - (* computes an integer log2 of a real. The function is only - interpreted on (non-zero) positive real constants. When applied on a - real 'm', the result 'res' of the function is such that: 2^res <= m < - 2^(res+1) *) - |> op "integer_log2" Integer_log2 ([real] ->. int) + (* maximum of two ints *) + |> op "max_int" Max_int ([int; int] ->. int) - (* only used for arithmetic. It should not be used for x in float(x) - to enable computations modulo equality *) + (* minimum of two ints *) + |> op "min_int" Min_int ([int; int] ->. int) - |> op "not_theory_constant" Not_theory_constant ([real] ->. prop) - |> is_theory_constant - |> op "linear_dependency" Linear_dependency ([real; real] ->. prop) - end (* Legacy fpa builtins *) + (* computes an integer log2 of a real. The function is only + interpreted on (non-zero) positive real constants. When applied on a + real 'm', the result 'res' of the function is such that: 2^res <= m < + 2^(res+1) *) + |> op "integer_log2" Integer_log2 ([real] ->. int) - | `Smtlib2 -> - begin (* SMTLib2 builtins *) - Id.Map.empty + (* only used for arithmetic. It should not be used for x in float(x) + to enable computations modulo equality *) - |> add_rounding_modes + |> op "not_theory_constant" Not_theory_constant ([real] ->. prop) + |> is_theory_constant + |> op "linear_dependency" Linear_dependency ([real; real] ->. prop) + in + fun env s -> + match s with + | Dl.Typer.T.Id id -> + begin + try + Id.Map.find_exn id fpa_builtins env s + with Not_found -> `Not_found + end + | Builtin _ -> `Not_found - (* the first argument is mantissas' size (including the implicit bit), - the second one is the exp of the min representable normalized number, - the third one is the rounding mode, and the last one is the real to - be rounded *) - |> op "ae.round" Float ([int; int; fpa_rounding_mode; real] ->. real) - end +let smt_fpa_builtins = + let term_app1 env s f = + Dl.Typer.T.builtin_term @@ + Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f + in + let other_builtins = + Id.Map.empty + |> smt_add_rounding_modes + in + fun env s -> + match s with + | Dl.Typer.T.Id { + ns = Term ; + name = Indexed { + basename = "ae.round" ; + indexes = [ i; j; rm ] } } -> + begin match + int_of_string i, + int_of_string j, + Fpa_rounding.SMT2.rounding_mode_of_hs (Hstring.make rm) + with + | n, m, rm -> term_app1 env s (smt_round n m rm) + | exception Failure _ -> `Not_found + end + | Dl.Typer.T.Id id -> begin + match Id.Map.find_exn id other_builtins env s with + | e -> e + | exception Not_found -> `Not_found + end + | _ -> `Not_found (** Concatenation of builtins handlers. *) let (++) bt1 bt2 = @@ -458,9 +491,8 @@ let (++) bt1 bt2 = let builtins = fun _st (lang : Typer.lang) -> match lang with - | `Logic Alt_ergo -> fpa_builtins (fpa_builtins_map `Ae) - | `Logic (Smtlib2 _) -> - (fpa_builtins (fpa_builtins_map `Smtlib2)) ++ bv_builtins + | `Logic Alt_ergo -> ae_fpa_builtins + | `Logic (Smtlib2 _) -> bv_builtins ++ smt_fpa_builtins | _ -> fun _ _ -> `Not_found (** Translates dolmen locs to Alt-Ergo's locs *) @@ -963,6 +995,13 @@ let mk_add translate sy ty l = let args = aux_mk_add l in E.mk_term sy args ty +let mk_rounding fpar = + let name = Fpa_rounding.SMT2.string_of_rounding_mode fpar in + let ty = Fpa_rounding.SMT2.fpa_rounding_mode in + let sy = + Sy.Op (Sy.Constr (Hstring.make name)) in + E.mk_term sy [] ty + (** [mk_expr ~loc ~name_base ~toplevel ~decl_kind term] Builds an Alt-Ergo hashconsed expression from a dolmen term @@ -1389,6 +1428,16 @@ let rec mk_expr | _ -> unsupported "coercion: %a" DE.Term.print term end | Float, _ -> op Float + | SMTFloat(i, j, rm), _ -> + let args = + let i = E.Ints.of_int i in + let j = E.Ints.of_int j in + let rm = mk_rounding rm in + i :: j :: rm :: List.map (fun a -> aux_mk_expr a) args in + E.mk_term + (Sy.Op Float) + args + (dty_to_ty term_ty) | Integer_round, _ -> op Integer_round | Abs_real, _ -> op Abs_real | Sqrt_real, _ -> op Sqrt_real @@ -1407,6 +1456,20 @@ let rec mk_expr | Not_theory_constant, _ -> op Not_theory_constant | Is_theory_constant, _ -> op Is_theory_constant | Linear_dependency, _ -> op Linear_dependency + | (B.RoundNearestTiesToEven + | B.RoundNearestTiesToAway + | B.RoundTowardPositive + | B.RoundTowardNegative + | B.RoundTowardZero as b), _ -> + let fpa_rounding = match b with + B.RoundNearestTiesToEven -> Fpa_rounding.NearestTiesToEven + | B.RoundNearestTiesToAway -> NearestTiesToAway + | B.RoundTowardPositive -> Up + | B.RoundTowardNegative -> Down + | B.RoundTowardZero -> ToZero + | _ -> assert false + in + mk_rounding fpa_rounding | _, _ -> unsupported "Application Term %a" DE.Term.print term end diff --git a/src/lib/structures/fpa_rounding.ml b/src/lib/structures/fpa_rounding.ml index 373004c2ce..18fba3e773 100644 --- a/src/lib/structures/fpa_rounding.ml +++ b/src/lib/structures/fpa_rounding.ml @@ -88,13 +88,9 @@ module Make (I : sig (fun key -> match Hashtbl.find_opt table key with | None -> Fmt.failwith - "%a" - (fun fmt k -> - Format.pp_print_string fmt I.name; - Hashtbl.iter (fun key _ -> Format.fprintf fmt "%a --" Hstring.print key) table; - Hstring.print fmt k - ) - key + "Error while searching for FPA value %a : %s" + Hstring.print key + I.name | Some res -> res) end