Skip to content

Commit

Permalink
Injecting AE type float rounding type into SMT rounding type
Browse files Browse the repository at this point in the history
  • Loading branch information
Stevendeo committed Oct 16, 2023
1 parent 5d1847e commit 8b94609
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 104 deletions.
63 changes: 44 additions & 19 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,8 @@ let builtin_enum = function
|> add_cstrs
| _ -> assert false

let ae_fpa_rounding_mode, ae_rounding_modes, ae_add_rounding_modes =
builtin_enum Fpa_rounding.AE.fpa_rounding_mode

let smt_fpa_rounding_mode, _smt_rounding_modes, smt_add_rounding_modes =
builtin_enum Fpa_rounding.SMT2.fpa_rounding_mode
let fpa_rounding_mode, rounding_modes, add_rounding_modes =
builtin_enum Fpa_rounding.fpa_rounding_mode

module Const = struct
open DE
Expand All @@ -270,7 +267,7 @@ module Const = struct
~name
~builtin:(AERound (n, m))
(DStd.Path.global name)
Ty.(arrow [smt_fpa_rounding_mode; real] real))
Ty.(arrow [fpa_rounding_mode; real] real))
end

let bv2nat t =
Expand Down Expand Up @@ -308,6 +305,32 @@ let bv_builtins env s =
end
| _ -> `Not_found

(** Takes a dolmen identifier [id] and injects it in Alt-Ergo's registered
identifier. For example, transforms "fpa_rounding_mode", the Alt-Ergo builtin
type into the SMT rounding type 'RoundingMode'. Also does it for the enums
of this type. *)
let inject_identifier id =
match id with
| Id.{name = Simple n; _} ->
begin
if String.equal n Fpa_rounding.fpa_rounding_mode_ae_type_name then
(* Injecting the type name as the SMT2 Type name. *)
let name =
Dolmen_std.Name.simple Fpa_rounding.fpa_rounding_mode_type_name
in
{id with name}
else
match Fpa_rounding.rounding_mode_of_ae_hs (Hstring.make n) with
| rm ->
let name =
Dolmen_std.Name.simple (Fpa_rounding.string_of_rounding_mode rm)
in
{id with name}
| exception (Failure _) ->
id
end
| id -> id

let ae_fpa_builtins =
let (->.) args ret = (args, ret) in
let dterm name f =
Expand All @@ -326,7 +349,7 @@ let ae_fpa_builtins =
(module Dl.Typer.T) env cst
in
let float_cst =
let ty = DT.(arrow [int; int; ae_fpa_rounding_mode; real] real) in
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
DE.Id.mk ~name:"float" ~builtin:Float (DStd.Path.global "float") ty
in
let float prec exp mode x =
Expand All @@ -338,7 +361,7 @@ let ae_fpa_builtins =
match cst.DE.path with
| Absolute { name; _ } -> String.equal name m
| Local _ -> false)
ae_rounding_modes
rounding_modes
in
DE.Term.apply_cst cst [] []
in
Expand Down Expand Up @@ -371,13 +394,13 @@ let ae_fpa_builtins =
let open DT in
Id.Map.empty

|> ae_add_rounding_modes
|> add_rounding_modes

(* the first argument is mantissas' size (including the implicit bit),
the second one is the exp of the min representable normalized number,
the third one is the rounding mode, and the last one is the real to
be rounded *)
|> op "float" Float ([int; int; ae_fpa_rounding_mode; real] ->. real)
|> op "float" Float ([int; int; fpa_rounding_mode; real] ->. real)

|> partial2 "float32" float32
|> partial1 "float32d" float32d
Expand All @@ -386,7 +409,7 @@ let ae_fpa_builtins =
|> partial1 "float64d" float64d

(* rounds to nearest integer *)
|> op "integer_round" Integer_round ([ae_fpa_rounding_mode; real] ->. int)
|> op "integer_round" Integer_round ([fpa_rounding_mode; real] ->. int)

(* type cast: from int to real *)
|> dterm "real_of_int" DE.Term.Int.to_real
Expand Down Expand Up @@ -441,13 +464,15 @@ let ae_fpa_builtins =
|> op "linear_dependency" Linear_dependency ([real; real] ->. prop)
in
fun env s ->
let search_id id =
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
in
match s with
| Dl.Typer.T.Id id ->
begin
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
end
let new_id = inject_identifier id in
search_id new_id
| Builtin _ -> `Not_found

let smt_fpa_builtins =
Expand All @@ -457,7 +482,7 @@ let smt_fpa_builtins =
in
let other_builtins =
Id.Map.empty
|> smt_add_rounding_modes
|> add_rounding_modes
in
fun env s ->
match s with
Expand Down Expand Up @@ -995,8 +1020,8 @@ let mk_add translate sy ty l =
E.mk_term sy args ty

let mk_rounding fpar =
let name = Fpa_rounding.SMT2.string_of_rounding_mode fpar in
let ty = Fpa_rounding.SMT2.fpa_rounding_mode in
let name = Fpa_rounding.string_of_rounding_mode fpar in
let ty = Fpa_rounding.fpa_rounding_mode in
let sy =
Sy.Op (Sy.Constr (Hstring.make name)) in
E.mk_term sy [] ty
Expand Down
10 changes: 5 additions & 5 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ module Env = struct
| _ -> assert false

let add_fpa_builtins env =
let module FPAU = Fpa_rounding.AE in
(* let module FPAU = Fpa_rounding.AE in *)
let (->.) args result = { args; result } in
let int n = {
c = { tt_desc = TTconst (Tint n); tt_ty = Ty.Tint} ;
annot = new_id () ;
} in
let rm = FPAU.fpa_rounding_mode in
let rm = Fpa_rounding.fpa_rounding_mode in
let mode m =
let h = find_builtin_cstr rm m in
{
Expand All @@ -299,8 +299,8 @@ module Env = struct
let float prec exp mode x =
TTapp (Symbols.Op Float, [prec; exp; mode; x])
in
let nte = FPAU.string_of_rounding_mode NearestTiesToEven in
let tname = FPAU.fpa_rounding_mode_type_name in
let nte = Fpa_rounding.string_of_rounding_mode NearestTiesToEven in
let tname = Fpa_rounding.fpa_rounding_mode_type_name in
let float32 = float (int "24") (int "149") in
let float32d = float32 (mode nte) in
let float64 = float (int "53") (int "1074") in
Expand All @@ -318,7 +318,7 @@ module Env = struct
types = Types.add_builtin env.types tname rm ;
builtins =
add_builtin_enum
FPAU.fpa_rounding_mode
Fpa_rounding.fpa_rounding_mode
env.builtins;
} in
let builtins =
Expand Down
7 changes: 2 additions & 5 deletions src/lib/structures/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2819,11 +2819,8 @@ let const_view t =
Fmt.failwith "error when trying to convert %a to an int" Z.pp_print n
end
| { f = Op (Constr c); ty; _ }
when Ty.equal ty Fpa_rounding.SMT2.fpa_rounding_mode ->
RoundingMode (Fpa_rounding.SMT2.rounding_mode_of_hs c)
| { f = Op (Constr c); ty; _ }
when Ty.equal ty Fpa_rounding.AE.fpa_rounding_mode ->
RoundingMode (Fpa_rounding.AE.rounding_mode_of_hs c)
when Ty.equal ty Fpa_rounding.fpa_rounding_mode ->
RoundingMode (Fpa_rounding.rounding_mode_of_smt_hs c)
| _ -> Fmt.failwith "unsupported constant: %a" print t

let int_view t =
Expand Down
135 changes: 68 additions & 67 deletions src/lib/structures/fpa_rounding.ml
Original file line number Diff line number Diff line change
Expand Up @@ -52,73 +52,74 @@ let cstrs =
NearestTiesToAway;
]

module type S = sig
val fpa_rounding_mode_type_name : string

val fpa_rounding_mode : Ty.t

val rounding_mode_of_hs : Hstring.t -> rounding_mode

val string_of_rounding_mode : rounding_mode -> string
end

module Make (I : sig
val name : string
val to_string : rounding_mode -> string
end) : S = struct

let fpa_rounding_mode_type_name = I.name

let string_of_rounding_mode = I.to_string

let fpa_rounding_mode, rounding_mode_of_hs =
let h_cstrs =
List.map (fun c -> Hs.make (I.to_string c)) cstrs
in
let ty = Ty.Tsum (Hs.make I.name, h_cstrs) in
let table =
let table = Hashtbl.create 5 in
List.iter2 (
fun key bnd ->
Hashtbl.add table key bnd
) h_cstrs cstrs;
table
in
ty,
(fun key -> match Hashtbl.find_opt table key with
| None ->
Fmt.failwith
"Error while searching for FPA value %a : %s"
Hstring.print key
I.name
| Some res -> res)
end

module AE : S =
Make (struct
let name = "fpa_rounding_mode"
let to_string =
function
| NearestTiesToEven -> "NearestTiesToEven"
| ToZero -> "ToZero"
| Up -> "Up"
| Down -> "Down"
| NearestTiesToAway -> "NearestTiesToAway"
end
)

module SMT2 : S =
Make (struct
let name = "RoundingMode"

let to_string =
function
| NearestTiesToEven -> "RNE"
| ToZero -> "RTZ"
| Up -> "RTP"
| Down -> "RTN"
| NearestTiesToAway -> "RNA"
end)
let to_smt_string =
function
| NearestTiesToEven -> "RNE"
| ToZero -> "RTZ"
| Up -> "RTP"
| Down -> "RTN"
| NearestTiesToAway -> "RNA"

let to_ae_string = function
| NearestTiesToEven -> "NearestTiesToEven"
| ToZero -> "ToZero"
| Up -> "Up"
| Down -> "Down"
| NearestTiesToAway -> "NearestTiesToAway"


let fpa_rounding_mode_ae_type_name = "fpa_rounding_mode"

let fpa_rounding_mode_type_name = "RoundingMode"

(* The exported 'to string' function is the SMT one. *)
let string_of_rounding_mode = to_smt_string

let hstring_smt_reprs =
List.map
(fun c -> Hs.make (to_smt_string c))
cstrs

let hstring_ae_reprs =
List.map
(fun c -> Hs.make (to_ae_string c))
cstrs

(* The rounding mode is the enum with the SMT values.
The Alt-Ergo values are injected in this type. *)
let fpa_rounding_mode =
Ty.Tsum (Hs.make "RoundingMode", hstring_smt_reprs)

let rounding_mode_of_smt_hs_opt =
let table = Hashtbl.create 5 in
List.iter2 (
fun key bnd ->
Hashtbl.add table key bnd
) hstring_smt_reprs cstrs;
fun key -> Hashtbl.find_opt table key

let rounding_mode_of_smt_hs hs =
match rounding_mode_of_smt_hs_opt hs with
| None ->
Fmt.failwith
"Error while searching for FPA value %a."
Hstring.print hs
fpa_rounding_mode_type_name
| Some res -> res

let rounding_mode_of_ae_hs =
let table = Hashtbl.create 5 in
List.iter2 (
fun key bnd ->
Hashtbl.add table key bnd
) hstring_ae_reprs cstrs;
fun key ->
match Hashtbl.find_opt table key with
| None ->
(* Alt-Ergo's legacy language also accepts the SMT representation
of rounding modes. *)
rounding_mode_of_smt_hs key
| Some res -> res

(** Helper functions **)

Expand Down
15 changes: 7 additions & 8 deletions src/lib/structures/fpa_rounding.mli
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,17 @@ type rounding_mode =
| Down
| NearestTiesToAway

module type S = sig
val fpa_rounding_mode_type_name : string
val fpa_rounding_mode_type_name : string

val fpa_rounding_mode : Ty.t
val fpa_rounding_mode_ae_type_name : string

val rounding_mode_of_hs : Hstring.t -> rounding_mode
val fpa_rounding_mode : Ty.t

val string_of_rounding_mode : rounding_mode -> string
end
val rounding_mode_of_smt_hs : Hstring.t -> rounding_mode

module AE : S
module SMT2 : S
val rounding_mode_of_ae_hs : Hstring.t -> rounding_mode

val string_of_rounding_mode : rounding_mode -> string

(** Integer part of binary logarithm for NON-ZERO POSITIVE number **)
val integer_log_2 : Numbers.Q.t -> int
Expand Down

0 comments on commit 8b94609

Please sign in to comment.