Skip to content

Commit

Permalink
Adding round as a non indexed primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
Stevendeo committed Oct 13, 2023
1 parent eb8bbbf commit b809f5b
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 109 deletions.
235 changes: 127 additions & 108 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,17 @@ let bv_builtins env s =
end
| _ -> `Not_found

let fpa_builtins =
let fpa_builtins map env s =
match s with
| Dl.Typer.T.Id id ->
begin
try
Id.Map.find_exn id map env s
with Not_found -> `Not_found
end
| Builtin _ -> `Not_found

let fpa_builtins_map l =
let (->.) args ret = (args, ret) in
let builtin_term t = Dl.Typer.T.builtin_term t in
let builtin_ty t = Dl.Typer.T.builtin_ty t in
Expand Down Expand Up @@ -294,31 +304,14 @@ let fpa_builtins =
|> add_cstrs
| _ -> assert false
in
let (module FPAU : Fpa_rounding.S) =
match l with
| `Smtlib2 -> (module Fpa_rounding.SMT2)
| `Ae -> (module Fpa_rounding.AE)
in
let fpa_rounding_mode, rounding_modes, add_rounding_modes =
let module FPAU : Fpa_rounding.S = (val (Fpa_rounding.fpa_rounding_utils ())) in
builtin_enum FPAU.fpa_rounding_mode
in
let float_cst =
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
DE.Id.mk ~name:"float" ~builtin:Float (DStd.Path.global "float") ty
in
let float prec exp mode x =
DE.Term.apply_cst float_cst [] [prec; exp; mode; x]
in
let mode m =
let cst, _ =
List.find (fun (cst, _args) ->
match cst.DE.path with
| Absolute { name; _ } -> String.equal name m
| Local _ -> false)
rounding_modes
in
DE.Term.apply_cst cst [] []
in
let float32 = float (DE.Term.int "24") (DE.Term.int "149") in
let float32d x = float32 (mode "NearestTiesToEven") x in
let float64 = float (DE.Term.int "53") (DE.Term.int "1074") in
let float64d x = float64 (mode "NearestTiesToEven") x in
let op ?(tyvars = []) name builtin (args, ret) =
let ty = DT.pi tyvars @@ DT.arrow args ret in
let cst = DE.Id.mk ~name ~builtin (DStd.Path.global name) ty in
Expand All @@ -328,120 +321,146 @@ let fpa_builtins =
Dolmen_type.Base.term_app_cst
(module Dl.Typer.T) env cst
in
let partial1 name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
builtin_term @@
Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f
in
let partial2 name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
builtin_term @@
Dolmen_type.Base.term_app2 (module Dl.Typer.T) env s f
in
let is_theory_constant =
let open DT in
let a = Var.mk "alpha" in
op ~tyvars:[a] "is_theory_constant" Is_theory_constant ([of_var a] ->. prop)
in
let fpa_builtins =
let open DT in
Id.Map.empty
let open DT in
match l with
| `Ae ->
begin (* Legacy fpa builtins *)
let float_cst =
let ty = DT.(arrow [int; int; fpa_rounding_mode; real] real) in
DE.Id.mk ~name:"float" ~builtin:Float (DStd.Path.global "float") ty
in
let float prec exp mode x =
DE.Term.apply_cst float_cst [] [prec; exp; mode; x]
in
let mode m =
let cst, _ =
List.find (fun (cst, _args) ->
match cst.DE.path with
| Absolute { name; _ } -> String.equal name m
| Local _ -> false)
rounding_modes
in
DE.Term.apply_cst cst [] []
in
let float32 = float (DE.Term.int "24") (DE.Term.int "149") in
let float32d x = float32 (mode "NearestTiesToEven") x in
let float64 = float (DE.Term.int "53") (DE.Term.int "1074") in
let float64d x = float64 (mode "NearestTiesToEven") x in
let partial1 name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
builtin_term @@
Dolmen_type.Base.term_app1 (module Dl.Typer.T) env s f
in
let partial2 name f =
Id.Map.add { name = DStd.Name.simple name; ns = Term } @@
fun env s ->
builtin_term @@
Dolmen_type.Base.term_app2 (module Dl.Typer.T) env s f
in
let is_theory_constant =
let open DT in
let a = Var.mk "alpha" in
op ~tyvars:[a] "is_theory_constant" Is_theory_constant ([of_var a] ->. prop)
in
Id.Map.empty

|> add_rounding_modes
|> add_rounding_modes

(* the first argument is mantissas' size (including the implicit bit),
the second one is the exp of the min representable normalized number,
the third one is the rounding mode, and the last one is the real to
be rounded *)
|> op "float" Float ([int; int; fpa_rounding_mode; real] ->. real)
(* the first argument is mantissas' size (including the implicit bit),
the second one is the exp of the min representable normalized number,
the third one is the rounding mode, and the last one is the real to
be rounded *)
|> op "float" Float ([int; int; fpa_rounding_mode; real] ->. real)

|> partial2 "float32" float32
|> partial1 "float32d" float32d
|> partial2 "float32" float32
|> partial1 "float32d" float32d

|> partial2 "float64" float64
|> partial1 "float64d" float64d
|> partial2 "float64" float64
|> partial1 "float64d" float64d

(* rounds to nearest integer *)
|> op "integer_round" Integer_round ([fpa_rounding_mode; real] ->. int)
(* rounds to nearest integer *)
|> op "integer_round" Integer_round ([fpa_rounding_mode; real] ->. int)

(* type cast: from int to real *)
|> dterm "real_of_int" DE.Term.Int.to_real
(* type cast: from int to real *)
|> dterm "real_of_int" DE.Term.Int.to_real

(* type check: integers *)
|> dterm "real_is_int" DE.Term.Real.is_int
(* type check: integers *)
|> dterm "real_is_int" DE.Term.Real.is_int

(* abs value of a real *)
|> op "abs_real" Abs_real ([real] ->. real)
(* abs value of a real *)
|> op "abs_real" Abs_real ([real] ->. real)

(* sqrt value of a real *)
|> op "sqrt_real" Sqrt_real ([real] ->. real)
(* sqrt value of a real *)
|> op "sqrt_real" Sqrt_real ([real] ->. real)

(* sqrt value of a real by default *)
|> op "sqrt_real_default" Sqrt_real_default ([real] ->. real)
(* sqrt value of a real by default *)
|> op "sqrt_real_default" Sqrt_real_default ([real] ->. real)

(* sqrt value of a real by excess *)
|> op "sqrt_real_excess" Sqrt_real_excess ([real] ->. real)
(* sqrt value of a real by excess *)
|> op "sqrt_real_excess" Sqrt_real_excess ([real] ->. real)

(* abs value of an int *)
|> dterm "abs_int" DE.Term.Int.abs
(* abs value of an int *)
|> dterm "abs_int" DE.Term.Int.abs

(* (integer) floor of a rational *)
|> dterm "int_floor" DE.Term.Real.floor_to_int
(* (integer) floor of a rational *)
|> dterm "int_floor" DE.Term.Real.floor_to_int

(* (integer) ceiling of a ratoinal *)
|> op "int_ceil" (Ceiling_to_int `Real) ([real] ->. int)
(* (integer) ceiling of a ratoinal *)
|> op "int_ceil" (Ceiling_to_int `Real) ([real] ->. int)

(* The functions below are only interpreted when applied on constants.
Aximatization for the general case are not currently imlemented *)
(* The functions below are only interpreted when applied on constants.
Aximatization for the general case are not currently imlemented *)

(* maximum of two reals *)
|> op "max_real" Max_real ([real; real] ->. real)
(* maximum of two reals *)
|> op "max_real" Max_real ([real; real] ->. real)

(* maximum of two ints *)
|> op "max_int" Max_int ([int; int] ->. int)
(* maximum of two ints *)
|> op "max_int" Max_int ([int; int] ->. int)

(* minimum of two ints *)
|> op "min_int" Min_int ([int; int] ->. int)
(* minimum of two ints *)
|> op "min_int" Min_int ([int; int] ->. int)

(* computes an integer log2 of a real. The function is only
interpreted on (non-zero) positive real constants. When applied on a
real 'm', the result 'res' of the function is such that: 2^res <= m <
2^(res+1) *)
|> op "integer_log2" Integer_log2 ([real] ->. int)
(* computes an integer log2 of a real. The function is only
interpreted on (non-zero) positive real constants. When applied on a
real 'm', the result 'res' of the function is such that: 2^res <= m <
2^(res+1) *)
|> op "integer_log2" Integer_log2 ([real] ->. int)

(* only used for arithmetic. It should not be used for x in float(x)
to enable computations modulo equality *)
(* only used for arithmetic. It should not be used for x in float(x)
to enable computations modulo equality *)

|> op "not_theory_constant" Not_theory_constant ([real] ->. prop)
|> is_theory_constant
|> op "linear_dependency" Linear_dependency ([real; real] ->. prop)
|> op "not_theory_constant" Not_theory_constant ([real] ->. prop)
|> is_theory_constant
|> op "linear_dependency" Linear_dependency ([real; real] ->. prop)
end (* Legacy fpa builtins *)

in
fun env s ->
begin match s with
| Dl.Typer.T.Id id ->
begin
try
Id.Map.find_exn id fpa_builtins env s
with Not_found -> `Not_found
end
| Builtin _ -> `Not_found
| `Smtlib2 ->
begin (* SMTLib2 builtins *)
Id.Map.empty

|> add_rounding_modes

(* the first argument is mantissas' size (including the implicit bit),
the second one is the exp of the min representable normalized number,
the third one is the rounding mode, and the last one is the real to
be rounded *)
|> op "ae.round" Float ([int; int; fpa_rounding_mode; real] ->. real)
end

(** Concatenation of builtins handlers. *)
(* let (++) bt1 bt2 =
* fun a b ->
* match bt1 a b with
* | `Not_found -> bt2 a b
* | res -> res *)
let (++) bt1 bt2 =
fun a b ->
match bt1 a b with
| `Not_found -> bt2 a b
| res -> res

let builtins =
fun _st (lang : Typer.lang) ->
match lang with
| `Logic Alt_ergo -> fpa_builtins
| `Logic (Smtlib2 _) -> (* fpa_builtins ++ *) bv_builtins
| `Logic Alt_ergo -> fpa_builtins (fpa_builtins_map `Ae)
| `Logic (Smtlib2 _) ->
(fpa_builtins (fpa_builtins_map `Smtlib2)) ++ bv_builtins
| _ -> fun _ _ -> `Not_found

(** Translates dolmen locs to Alt-Ergo's locs *)
Expand Down
2 changes: 1 addition & 1 deletion src/lib/structures/fpa_rounding.mli
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
module AE : S
module SMT2 : S

(** Returns (module SMT2) if the input format is [None] or [Some Smtlib2],
(** Returns (module SMT2) if the input format is [None] or [Some Smtlib2],
otherwise returns [AE]. *)
val fpa_rounding_utils : unit -> (module S)

Expand Down

0 comments on commit b809f5b

Please sign in to comment.