From 9218c4916b7024176e5a8884aa0394e749c4bc88 Mon Sep 17 00:00:00 2001 From: Alasdair Date: Tue, 9 Apr 2024 07:08:07 +0100 Subject: [PATCH] Refactor Sail->SMT backend and rework Sail->SV to share code This commit reworks the Sail->SMT backend to use the new shared primtive generation logic that Sail->SV uses. Various older features have been removed to simplify the code. The Sail->SV backend has been refactored to use modules rather than just functions, but this has left it in quite the broken state for the time being. There are a few known issues with the Sail->SMT pipeline that pre-date this commit, for example function arguments with complex constraints are sometimes generalised, and we would require additional well-formedness checks to be inserted to guard against this. There is now a small test that demonstrates the lack of these well-formedness checks in Sail->SMT Fix path condition computation to use post-dominators in order to handle cases where the graph has some non-structured control flow --- aarch64_small/armV8_A64_sys_regs.sail | 2 +- aarch64_small/prelude.sail | 7 +- language/jib.ott | 12 +- lib/sv/sail.sv | 8 - src/lib/ast_util.mli | 2 +- src/lib/jib_compile.ml | 21 +- src/lib/jib_compile.mli | 4 +- src/lib/jib_optimize.ml | 25 +- src/{sail_smt_backend => lib}/jib_ssa.ml | 199 +- src/{sail_smt_backend => lib}/jib_ssa.mli | 15 +- src/lib/jib_util.ml | 75 +- src/lib/jib_util.mli | 2 + src/lib/jib_visitor.ml | 26 +- src/lib/jib_visitor.mli | 14 +- src/lib/property.ml | 116 +- src/lib/smt_exp.ml | 392 ++- src/lib/smt_gen.ml | 264 +- src/lib/smt_gen.mli | 30 +- src/lib/type_check.ml | 3 +- src/sail_c_backend/c_backend.ml | 23 +- src/sail_smt_backend/jib_smt.ml | 2909 ++++++--------------- src/sail_smt_backend/jib_smt.mli | 128 +- src/sail_smt_backend/sail_plugin_smt.ml | 59 +- src/sail_smt_backend/smtlib.ml | 750 ------ src/sail_sv_backend/generate_primop2.ml | 128 + src/sail_sv_backend/jib_sv.ml | 1000 +++++-- src/sail_sv_backend/jib_sv.mli | 140 + src/sail_sv_backend/sail_plugin_sv.ml | 53 +- src/sail_sv_backend/sv_ir.ml | 341 +++ src/sail_sv_backend/sv_ir.mli | 171 ++ test/smt/issue573_1.sat.sail | 10 + test/smt/issue573_2.sat.sail | 10 + test/smt/linked_int.unsat.sail | 6 + test/smt/linked_int2.unsat.sail | 6 + test/smt/lzcnt.unsat.sail | 2 +- test/smt/revrev_endianness2.unsat.sail | 2 + test/smt/revrev_endianness3.unsat.sail | 22 + test/smt/run_tests.py | 9 +- test/smt/rv_add_1.unsat.sail | 2 + test/smt/string.unsat.sail | 4 - test/sv/.gitignore | 2 + test/sv/run_tests.py | 1 + 42 files changed, 3588 insertions(+), 3407 deletions(-) rename src/{sail_smt_backend => lib}/jib_ssa.ml (83%) rename src/{sail_smt_backend => lib}/jib_ssa.mli (94%) delete mode 100644 src/sail_smt_backend/smtlib.ml create mode 100644 src/sail_sv_backend/generate_primop2.ml create mode 100644 src/sail_sv_backend/jib_sv.mli create mode 100644 src/sail_sv_backend/sv_ir.ml create mode 100644 src/sail_sv_backend/sv_ir.mli create mode 100644 test/smt/issue573_1.sat.sail create mode 100644 test/smt/issue573_2.sat.sail create mode 100644 test/smt/linked_int.unsat.sail create mode 100644 test/smt/linked_int2.unsat.sail create mode 100644 test/smt/revrev_endianness3.unsat.sail create mode 100644 test/sv/.gitignore diff --git a/aarch64_small/armV8_A64_sys_regs.sail b/aarch64_small/armV8_A64_sys_regs.sail index 20aa4a5c5..43f5de7e0 100644 --- a/aarch64_small/armV8_A64_sys_regs.sail +++ b/aarch64_small/armV8_A64_sys_regs.sail @@ -176,7 +176,7 @@ register SCTLR_EL3 : SCTLR_type /* System Control Register (EL3) */ /* CP: added coercion from SCTLR_EL1_type to SCTLR_type for the SCTLR function */ -val cast "SCTLR_EL1_type_to_SCTLR_type" : SCTLR_EL1_type -> SCTLR_type +val "SCTLR_EL1_type_to_SCTLR_type" : SCTLR_EL1_type -> SCTLR_type bitfield TCR_EL1_type : bits(64) = { diff --git a/aarch64_small/prelude.sail b/aarch64_small/prelude.sail index 2e5412a56..871ed742b 100644 --- a/aarch64_small/prelude.sail +++ b/aarch64_small/prelude.sail @@ -59,17 +59,14 @@ function operator <=_u (x, y) = unsigned(x) <= unsigned(y) val pow2_atom = "pow2" : forall 'n. int('n) -> int(2 ^ 'n) val pow2_int = "pow2" : int -> int -overload pow2 = {pow2_atom, pow2_int} - - -val cast cast_bool_bit : bool -> bit +val cast_bool_bit : bool -> bit function cast_bool_bit(b) = match b { true => b1, false => b0 } -val cast cast_bit_bool : bit -> bool +val cast_bit_bool : bit -> bool function cast_bit_bool (b) = match b { bitzero => false, diff --git a/language/jib.ott b/language/jib.ott index 2e9ee337a..771072d11 100644 --- a/language/jib.ott +++ b/language/jib.ott @@ -72,11 +72,16 @@ type iannot = int * ocaml_l grammar +chan :: 'Chan_' ::= + | stdout :: :: stdout + | stderr :: :: stderr + name :: '' ::= | id nat :: :: name | have_exception nat :: :: have_exception | current_exception nat :: :: current_exception | throw_location nat :: :: throw_location + | channel chan nat :: :: channel | return nat :: :: return op :: '' ::= @@ -121,6 +126,7 @@ uid :: 'UId_' ::= cval :: 'V_' ::= | name : ctyp :: :: id + | id : ctyp :: :: member | value : ctyp :: :: lit | ( cval0 , ... , cvaln ) ctyp :: :: tuple | struct { id0 = cval0 , ... , idn = cvaln } ctyp :: :: struct @@ -205,6 +211,10 @@ iannot :: '' ::= {{ lem iannot }} {{ ocaml iannot }} +creturn :: 'CR_' ::= + | clexp :: :: one + | ( clexp0 , ... , clexpn ) :: :: multi + instr :: 'I_' ::= {{ aux _ iannot }} % The following are the minimal set of instructions output by @@ -214,7 +224,7 @@ instr :: 'I_' ::= | jump ( cval ) string :: :: jump | goto string :: :: goto | string : :: :: label - | clexp = bool uid ( cval0 , ... , cvaln ) :: :: funcall + | creturn = bool uid ( cval0 , ... , cvaln ) :: :: funcall | clexp = cval :: :: copy | clear ctyp name :: :: clear | undefined ctyp :: :: undefined diff --git a/lib/sv/sail.sv b/lib/sv/sail.sv index aee8103ff..28d3eb197 100644 --- a/lib/sv/sail.sv +++ b/lib/sv/sail.sv @@ -33,10 +33,6 @@ function automatic bit sail_eq_string(sail_unit s1, sail_unit s2); return 0; endfunction -function automatic sail_unit sail_concat_str(sail_unit s1, sail_unit s2); - return SAIL_UNIT; -endfunction - `else function automatic sail_unit sail_print_endline(string s); @@ -70,10 +66,6 @@ function automatic bit sail_eq_string(string s1, string s2); return s1 == s2; endfunction -function automatic string sail_concat_str(string s1, string s2); - return {s1, s2}; -endfunction - `endif typedef enum logic [0:0] {SAIL_REAL} sail_real; diff --git a/src/lib/ast_util.mli b/src/lib/ast_util.mli index 5f9f47e55..f69d49ece 100644 --- a/src/lib/ast_util.mli +++ b/src/lib/ast_util.mli @@ -383,7 +383,7 @@ module Typ : sig end module IdSet : sig - include Set.S with type elt = id + include Set.S with type elt = id and type t = Set.Make(Id).t end module NexpSet : sig diff --git a/src/lib/jib_compile.ml b/src/lib/jib_compile.ml index 0fa2f5e8c..82d6c8a46 100644 --- a/src/lib/jib_compile.ml +++ b/src/lib/jib_compile.ml @@ -258,6 +258,7 @@ module type CONFIG = sig val use_real : bool val branch_coverage : out_channel option val track_throw : bool + val use_void : bool end module IdGraph = Graph.Make (Id) @@ -361,6 +362,7 @@ module Make (C : CONFIG) = struct ([iinit l ctyp' gs cval], V_id (gs, ctyp'), [iclear ctyp' gs]) ) else ([], cval, []) + | AV_id (id, Enum typ) -> ([], V_member (id, ctyp_of_typ ctx typ), []) | AV_id (id, typ) -> begin match Bindings.find_opt id ctx.locals with | Some (_, ctyp) -> ([], V_id (name id, ctyp), []) @@ -635,7 +637,7 @@ module Make (C : CONFIG) = struct | AP_id (pid, _) when is_ct_enum ctyp -> begin match Env.lookup_id pid ctx.tc_env with | Unbound _ -> ([], [idecl l ctyp (name pid); icopy l (CL_id (name pid, ctyp)) cval], [], ctx) - | _ -> ([ijump l (V_call (Neq, [V_id (name pid, ctyp); cval])) case_label], [], [], ctx) + | _ -> ([ijump l (V_call (Neq, [V_member (pid, ctyp); cval])) case_label], [], [], ctx) end | AP_id (pid, typ) -> let id_ctyp = ctyp_of_typ ctx typ in @@ -1132,8 +1134,11 @@ module Make (C : CONFIG) = struct | (AE_aux (_, { loc = l; _ }) as exp) :: exps -> let setup, call, cleanup = compile_aexp ctx exp in let rest = compile_block ctx exps in - let gs = ngensym () in - iblock (setup @ [idecl l CT_unit gs; call (CL_id (gs, CT_unit))] @ cleanup) :: rest + if C.use_void then iblock (setup @ [call CL_void] @ cleanup) :: rest + else ( + let gs = ngensym () in + iblock (setup @ [idecl l CT_unit gs; call (CL_id (gs, CT_unit))] @ cleanup) :: rest + ) let fast_int = function CT_lint when !optimize_aarch64_fast_struct -> CT_fint 64 | ctyp -> ctyp @@ -1825,7 +1830,7 @@ module Make (C : CONFIG) = struct let rec specialize_variants ctx prior = let instantiations = ref CTListSet.empty in - let fix_variants ctx var_id = visit_ctyp (new fix_variants_visitor ctx var_id) in + let fix_variants ctx var_id = visit_ctyp (new fix_variants_visitor ctx var_id :> common_visitor) in let specialize_constructor ctx ctor_id = visit_cdefs (new specialize_constructor_visitor instantiations ctx ctor_id) @@ -1995,7 +2000,7 @@ module Make (C : CONFIG) = struct let precise_call call tail = match call with - | I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), ((_, l) as aux)) as instr -> begin + | I_aux (I_funcall (CR_one clexp, extern, (id, ctyp_args), args), ((_, l) as aux)) as instr -> begin match get_function_typ id with | None when string_of_id id = "sail_cons" -> begin match (ctyp_args, args) with @@ -2007,7 +2012,9 @@ module Make (C : CONFIG) = struct [ iblock (cast - @ [I_aux (I_funcall (clexp, extern, (id, ctyp_args), [V_id (gs, ctyp_arg); tl_arg]), aux)] + @ [ + I_aux (I_funcall (CR_one clexp, extern, (id, ctyp_args), [V_id (gs, ctyp_arg); tl_arg]), aux); + ] @ tail @ cleanup ); ] @@ -2051,7 +2058,7 @@ module Make (C : CONFIG) = struct [ iblock1 (casts @ ret_setup - @ [I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), aux)] + @ [I_aux (I_funcall (CR_one clexp, extern, (id, ctyp_args), args), aux)] @ tail @ ret_cleanup @ cleanup ); ] diff --git a/src/lib/jib_compile.mli b/src/lib/jib_compile.mli index ab28de601..625eafb1b 100644 --- a/src/lib/jib_compile.mli +++ b/src/lib/jib_compile.mli @@ -158,10 +158,12 @@ module type CONFIG = sig for debugging C but we want to turn it off for SMT generation where we can't use strings *) val track_throw : bool + + val use_void : bool end module IdGraph : sig - include Graph.S with type node = id + include Graph.S with type node = id and type node_set = IdSet.t end val callgraph : cdef list -> IdGraph.graph diff --git a/src/lib/jib_optimize.ml b/src/lib/jib_optimize.ml index 5a25e2853..4f49cdf9d 100644 --- a/src/lib/jib_optimize.ml +++ b/src/lib/jib_optimize.ml @@ -73,9 +73,9 @@ open Jib_util let optimize_unit instrs = let unit_cval cval = match cval_ctyp cval with CT_unit -> V_lit (VL_unit, CT_unit) | _ -> cval in let unit_instr = function - | I_aux (I_funcall (clexp, extern, id, args), annot) as instr -> begin + | I_aux (I_funcall (CR_one clexp, extern, id, args), annot) as instr -> begin match clexp_ctyp clexp with - | CT_unit -> I_aux (I_funcall (CL_void, extern, id, List.map unit_cval args), annot) + | CT_unit -> I_aux (I_funcall (CR_one CL_void, extern, id, List.map unit_cval args), annot) | _ -> instr end | I_aux (I_copy (clexp, cval), annot) as instr -> begin @@ -163,6 +163,7 @@ let unique_per_function_ids cdefs = let rec cval_subst id subst = function | V_id (id', ctyp) -> if Name.compare id id' = 0 then subst else V_id (id', ctyp) + | V_member (id, ctyp) -> V_member (id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (op, cvals) -> V_call (op, List.map (cval_subst id subst) cvals) | V_field (cval, field) -> V_field (cval_subst id subst cval, field) @@ -174,6 +175,7 @@ let rec cval_subst id subst = function let rec cval_map_id f = function | V_id (id, ctyp) -> V_id (f id, ctyp) + | V_member (id, ctyp) -> V_member (id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (call, cvals) -> V_call (call, List.map (cval_map_id f) cvals) | V_field (cval, field) -> V_field (cval_map_id f cval, field) @@ -259,6 +261,10 @@ let rec clexp_subst id subst = function | CL_void -> CL_void | CL_rmw _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Cannot substitute into read-modify-write construct" +let creturn_subst id subst = function + | CR_one clexp -> CR_one (clexp_subst id subst clexp) + | CR_multi clexps -> CR_multi (List.map (clexp_subst id subst) clexps) + let rec find_function fid = function | CDEF_aux (CDEF_fundef (fid', heap_return, args, body), _) :: _ when Id.compare fid fid' = 0 -> Some (heap_return, args, body) @@ -271,14 +277,15 @@ let ssa_name i = function | Current_exception _ -> Current_exception i | Throw_location _ -> Throw_location i | Return _ -> Return i + | Channel (chan, _) -> Channel (chan, i) let inline cdefs should_inline instrs = let inlines = ref (-1) in let label_count = ref (-1) in let replace_return subst = function - | I_aux (I_funcall (clexp, extern, fid, args), aux) -> - I_aux (I_funcall (clexp_subst return subst clexp, extern, fid, args), aux) + | I_aux (I_funcall (creturn, extern, fid, args), aux) -> + I_aux (I_funcall (creturn_subst return subst creturn, extern, fid, args), aux) | I_aux (I_copy (clexp, cval), aux) -> I_aux (I_copy (clexp_subst return subst clexp, cval), aux) | instr -> instr in @@ -314,7 +321,8 @@ let inline cdefs should_inline instrs = in let inline_instr = function - | I_aux (I_funcall (clexp, false, function_id, args), aux) as instr when should_inline (fst function_id) -> begin + | I_aux (I_funcall (CR_one clexp, false, function_id, args), aux) as instr when should_inline (fst function_id) -> + begin match find_function (fst function_id) cdefs with | Some (None, ids, body) -> incr inlines; @@ -446,6 +454,7 @@ let remove_tuples cdefs ctx = ctyp and fix_cval = function | V_id (id, ctyp) -> V_id (id, ctyp) + | V_member (id, ctyp) -> V_member (id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_ctor_kind (cval, ctor, ctyp) -> V_ctor_kind (fix_cval cval, ctor, ctyp) | V_ctor_unwrap (cval, ctor, ctyp) -> V_ctor_unwrap (fix_cval cval, ctor, ctyp) @@ -485,8 +494,12 @@ let remove_tuples cdefs ctx = | CL_void -> CL_void | CL_rmw (read, write, ctyp) -> CL_rmw (read, write, ctyp) in + let fix_creturn = function + | CR_one clexp -> CR_one (fix_clexp clexp) + | CR_multi clexps -> CR_multi (List.map fix_clexp clexps) + in let rec fix_instr_aux = function - | I_funcall (clexp, extern, id, args) -> I_funcall (fix_clexp clexp, extern, id, List.map fix_cval args) + | I_funcall (creturn, extern, id, args) -> I_funcall (fix_creturn creturn, extern, id, List.map fix_cval args) | I_copy (clexp, cval) -> I_copy (fix_clexp clexp, fix_cval cval) | I_init (ctyp, id, cval) -> I_init (ctyp, id, fix_cval cval) | I_reinit (ctyp, id, cval) -> I_reinit (ctyp, id, fix_cval cval) diff --git a/src/sail_smt_backend/jib_ssa.ml b/src/lib/jib_ssa.ml similarity index 83% rename from src/sail_smt_backend/jib_ssa.ml rename to src/lib/jib_ssa.ml index 5fb86ced3..12efc9138 100644 --- a/src/sail_smt_backend/jib_ssa.ml +++ b/src/lib/jib_ssa.ml @@ -65,21 +65,32 @@ (* SUCH DAMAGE. *) (****************************************************************************) -open Libsail - open Ast_util open Jib open Jib_util -module IntSet = Set.Make (struct - type t = int - let compare = compare -end) +module IntSet = Util.IntSet module IntMap = Map.Make (struct type t = int let compare = compare end) +let ssa_name i = function + | Name (id, _) -> Name (id, i) + | Have_exception _ -> Have_exception i + | Current_exception _ -> Current_exception i + | Throw_location _ -> Throw_location i + | Channel (c, _) -> Channel (c, i) + | Return _ -> Return i + +let unssa_name = function + | Name (id, n) -> (Name (id, -1), n) + | Have_exception n -> (Have_exception (-1), n) + | Current_exception n -> (Current_exception (-1), n) + | Throw_location n -> (Throw_location (-1), n) + | Channel (c, n) -> (Channel (c, -1), n) + | Return n -> (Return (-1), n) + (**************************************************************************) (* 1. Mutable graph type *) (**************************************************************************) @@ -215,7 +226,12 @@ type terminator = | T_label of string | T_none -type cf_node = CF_label of string | CF_block of instr list * terminator | CF_guard of int | CF_start of ctyp NameMap.t +type cf_node = + | CF_label of string + | CF_block of instr list * terminator + | CF_guard of int + | CF_start of ctyp NameMap.t + | CF_end let to_terminator graph = function | I_label label -> T_label label @@ -247,6 +263,9 @@ let control_flow_graph instrs = match aux with I_label _ | I_goto _ | I_jump _ | I_end _ | I_exit _ | I_undefined _ -> true | _ -> false in + let start = add_vertex ([], CF_start NameMap.empty) graph in + let finish = add_vertex ([], CF_end) graph in + let rec cfg preds instrs = let before, after = instr_split_at cf_split instrs in let terminator, after = @@ -261,7 +280,9 @@ let control_flow_graph instrs = [n] in match terminator with - | T_end _ | T_exit _ | T_undefined _ -> cfg [] after + | T_end _ | T_exit _ | T_undefined _ -> + List.iter (fun p -> add_edge p finish graph) preds; + cfg [] after | T_goto label -> List.iter (fun p -> add_edge p (StringMap.find label !labels) graph) preds; cfg [] after @@ -280,8 +301,7 @@ let control_flow_graph instrs = | T_none -> preds in - let start = add_vertex ([], CF_start NameMap.empty) graph in - let finish = cfg [start] instrs in + let _ = cfg [start] instrs in let visited = reachable (IntSet.singleton start) graph in prune visited graph; @@ -292,11 +312,17 @@ let control_flow_graph instrs = (* 3. Computing dominators *) (**************************************************************************) +(* If we are computing post-dominators rather than dominators, we + swap the graph ordering. *) +let graph_order ~post predecessors successors = if post then (successors, predecessors) else (predecessors, successors) + (** Calculate the (immediate) dominators of a graph using the Lengauer-Tarjan algorithm. This is the slightly less sophisticated version from Appel's book 'Modern compiler implementation in ML' - which runs in O(n log(n)) time. *) -let immediate_dominators graph root = + which runs in O(n log(n)) time. + + If the post flag is set this computes the post-dominators. *) +let immediate_dominators ?(post = false) graph root = let none = -1 in let vertex = Array.make (cardinal graph) 0 in let parent = Array.make (cardinal graph) none in @@ -333,7 +359,9 @@ let immediate_dominators graph root = parent.(n) <- p; incr count; match graph.nodes.(n) with - | Some (_, _, successors) -> IntSet.iter (fun w -> dfs n w) successors + | Some (_, predecessors, successors) -> + let predecessors, successors = graph_order ~post predecessors successors in + IntSet.iter (fun w -> dfs n w) successors | None -> assert false end in @@ -346,7 +374,8 @@ let immediate_dominators graph root = begin match graph.nodes.(n) with - | Some (_, predecessors, _) -> + | Some (_, predecessors, successors) -> + let predecessors, successors = graph_order ~post predecessors successors in IntSet.iter (fun v -> let s' = if dfnum.(v) <= dfnum.(n) then v else semi.(ancestor_with_lowest_semi v) in @@ -390,7 +419,7 @@ let rec dominate idom n w = let p = idom.(n) in if p = none then false else if p = w then true else dominate idom p w -let dominance_frontiers graph root idom children = +let dominance_frontiers ?(post = false) graph root idom children = let df = Array.make (cardinal graph) IntSet.empty in let rec compute_df n = @@ -398,7 +427,9 @@ let dominance_frontiers graph root idom children = begin match graph.nodes.(n) with - | Some (content, _, succs) -> IntSet.iter (fun y -> if idom.(y) <> n then set := IntSet.add y !set) succs + | Some (content, predecessors, successors) -> + let predecessors, successors = graph_order ~post predecessors successors in + IntSet.iter (fun y -> if idom.(y) <> n then set := IntSet.add y !set) successors | None -> () end; IntSet.iter @@ -486,14 +517,6 @@ let rename_variables graph root children = let phi_zeros = ref NameMap.empty in - let ssa_name i = function - | Name (id, _) -> Name (id, i) - | Have_exception _ -> Have_exception i - | Current_exception _ -> Current_exception i - | Throw_location _ -> Throw_location i - | Return _ -> Return i - in - let get_count id = match NameMap.find_opt id !counts with Some n -> n | None -> 0 in let top_stack id = match NameMap.find_opt id !stacks with Some (x :: _) -> x | Some [] -> 0 | None -> 0 in let top_stack_phi id ctyp = @@ -512,6 +535,7 @@ let rename_variables graph root children = | V_id (id, ctyp) -> let i = top_stack id in V_id (ssa_name i id, ctyp) + | V_member (id, ctyp) -> V_member (id, ctyp) | V_lit (vl, ctyp) -> V_lit (vl, ctyp) | V_call (id, fs) -> V_call (id, List.map fold_cval fs) | V_field (f, field) -> V_field (fold_cval f, field) @@ -540,13 +564,17 @@ let rename_variables graph root children = | CL_tuple (clexp, n) -> CL_tuple (fold_clexp true clexp, n) | CL_void -> CL_void in + let fold_creturn = function + | CR_one clexp -> CR_one (fold_clexp false clexp) + | CR_multi clexps -> CR_multi (List.map (fold_clexp false) clexps) + in let ssa_instr (I_aux (aux, annot)) = let aux = match aux with - | I_funcall (clexp, extern, id, args) -> + | I_funcall (creturn, extern, id, args) -> let args = List.map fold_cval args in - I_funcall (fold_clexp false clexp, extern, id, args) + I_funcall (fold_creturn creturn, extern, id, args) | I_copy (clexp, cval) -> let cval = fold_cval cval in I_copy (fold_clexp false clexp, cval) @@ -587,6 +615,7 @@ let rename_variables graph root children = CF_block (instrs, ssa_terminator terminator) | CF_label label -> CF_label label | CF_guard cond -> CF_guard cond + | CF_end -> CF_end in let ssa_ssanode = function @@ -642,35 +671,79 @@ let rename_variables graph root children = | Some ((ssa, CF_start _), preds, succs) -> graph.nodes.(root) <- Some ((ssa, CF_start !phi_zeros), preds, succs) | _ -> failwith "root node is not CF_start" -let place_pi_functions graph start idom children = +let is_true_literal = function V_lit (VL_bool true, _) -> true | _ -> false + +let is_false_literal = function V_lit (VL_bool false, _) -> true | _ -> false + +let simp_disj = function + | [x; V_call (Bnot, [y])] when x = y -> [V_lit (VL_bool true, CT_bool)] + | xs -> + if List.exists is_true_literal xs then [V_lit (VL_bool true, CT_bool)] + else List.filter (fun x -> not (is_false_literal x)) xs + +let simp_conj = function + | [x; V_call (Bnot, [y])] when x = y -> [V_lit (VL_bool false, CT_bool)] + | xs -> + if List.exists is_false_literal xs then [V_lit (VL_bool false, CT_bool)] + else List.filter (fun x -> not (is_true_literal x)) xs + +let place_pi_functions ~start ~finish ~post_idom ~post_df graph = let get_guard = function | CF_guard cond -> begin match IntMap.find_opt (abs cond) graph.conds with - | Some guard when cond > 0 -> [guard] - | Some guard -> [V_call (Bnot, [guard])] + | Some guard when cond > 0 -> Some guard + | Some guard -> Some (V_call (Bnot, [guard])) | None -> assert false end - | _ -> [] + | _ -> None in - let get_pi_contents ssanodes = List.concat (List.map (function Pi guards -> guards | _ -> []) ssanodes) in + let get_pi ssanode = List.concat (List.map (function Pi guards -> guards | _ -> []) ssanode) in + + let mk_disj xs = match simp_disj xs with [] -> V_lit (VL_bool false, CT_bool) | [x] -> x | xs -> V_call (Bor, xs) in + let mk_conj xs = match simp_conj xs with [] -> V_lit (VL_bool true, CT_bool) | [x] -> x | xs -> V_call (Band, xs) in + let visited = ref IntSet.empty in let rec go n = - begin + if not (IntSet.mem n !visited) then ( match graph.nodes.(n) with | Some ((ssa, cfnode), preds, succs) -> - let p = idom.(n) in - if p <> -1 then begin - match graph.nodes.(p) with - | Some ((dom_ssa, _), _, _) -> - let args = get_guard cfnode @ get_pi_contents dom_ssa in - graph.nodes.(n) <- Some ((Pi args :: ssa, cfnode), preds, succs) - | None -> assert false - end - | None -> assert false - end; - IntSet.iter go children.(n) + let disj = + List.map + (fun post_frontier -> + assert (post_frontier <> n); + go post_frontier; + match graph.nodes.(post_frontier) with + | Some ((ssanode, _), _, succs) -> + let pathcond = get_pi ssanode in + let disj = + List.filter_map + (fun s -> + if s = n || dominate post_idom s n then ( + let (_, cfnode), _, _ = Option.get graph.nodes.(s) in + get_guard cfnode + ) + else None + ) + (IntSet.elements succs) + in + mk_disj disj :: pathcond + | None -> assert false + ) + (IntSet.elements post_df.(n)) + in + let mk_pi = function + | [] -> Pi [] + | [conj] -> Pi conj + | conjs -> Pi [mk_disj (List.map (fun conj -> mk_conj conj) conjs)] + in + visited := IntSet.add n !visited; + graph.nodes.(n) <- Some ((mk_pi disj :: ssa, cfnode), preds, succs) + | None -> () + ) in - go start + for n = 0 to graph.next - 1 do + go n + done (** Remove p nodes. Assumes the graph is acyclic. *) let remove_nodes remove_cf graph = @@ -697,16 +770,6 @@ let remove_nodes remove_cf graph = | _ -> () done -let ssa instrs = - let start, finish, cfg = control_flow_graph instrs in - let idom = immediate_dominators cfg start in - let children = dominator_children idom in - let df = dominance_frontiers cfg start idom children in - place_phi_functions cfg df; - rename_variables cfg start children; - place_pi_functions cfg start idom children; - (start, cfg) - (* Debugging utilities for outputing Graphviz files. *) let string_of_ssainstr = function @@ -719,20 +782,17 @@ let string_of_phis = function [] -> "" | phis -> Util.string_of_list "\\l" strin let string_of_node = function | phis, CF_label label -> string_of_phis phis ^ label | phis, CF_block (instrs, terminator) -> - let string_of_instr instr = - let buf = Buffer.create 128 in - Jib_ir.Flat_ir_formatter.output_instr 0 buf 0 Jib_ir.StringMap.empty instr; - Buffer.contents buf - in string_of_phis phis ^ Util.string_of_list "\\l" (fun instr -> String.escaped (string_of_instr instr)) instrs | phis, CF_start inits -> string_of_phis phis ^ "START" | phis, CF_guard cval -> string_of_phis phis ^ string_of_int cval + | phis, CF_end -> string_of_phis phis ^ "END" let vertex_color = function | _, CF_start _ -> "peachpuff" | _, CF_block _ -> "white" | _, CF_label _ -> "springgreen" | _, CF_guard _ -> "yellow" + | _, CF_end -> "red" let make_dot out_chan graph = Util.opt_colors := false; @@ -776,3 +836,24 @@ let make_dominators_dot out_chan idom graph = done; output_string out_chan "}\n"; Util.opt_colors := true + +let ssa ?debug_prefix instrs = + let start, finish, cfg = control_flow_graph instrs in + let idom = immediate_dominators cfg start in + let post_idom = immediate_dominators ~post:true cfg finish in + begin + match debug_prefix with + | Some prefix -> + let out_chan = open_out (prefix ^ "_post_doms.gv") in + make_dominators_dot out_chan post_idom cfg; + close_out out_chan + | None -> () + end; + let children = dominator_children idom in + let post_children = dominator_children post_idom in + let df = dominance_frontiers cfg start idom children in + let post_df = dominance_frontiers ~post:true cfg finish post_idom post_children in + place_phi_functions cfg df; + rename_variables cfg start children; + place_pi_functions ~start ~finish ~post_idom ~post_df cfg; + (start, cfg) diff --git a/src/sail_smt_backend/jib_ssa.mli b/src/lib/jib_ssa.mli similarity index 94% rename from src/sail_smt_backend/jib_ssa.mli rename to src/lib/jib_ssa.mli index 89097d298..cd55226af 100644 --- a/src/sail_smt_backend/jib_ssa.mli +++ b/src/lib/jib_ssa.mli @@ -65,11 +65,13 @@ (* SUCH DAMAGE. *) (****************************************************************************) -open Libsail - open Array open Jib_util +val ssa_name : int -> Jib.name -> Jib.name + +val unssa_name : Jib.name -> Jib.name * int + (** A mutable array based graph type, with nodes indexed by integers. *) type 'a array_graph @@ -77,7 +79,7 @@ type 'a array_graph underlying array. *) val make : initial_size:int -> unit -> 'a array_graph -module IntSet : Set.S with type elt = int +module IntSet = Util.IntSet val get_cond : 'a array_graph -> int -> Jib.cval @@ -112,17 +114,18 @@ type cf_node = | CF_block of Jib.instr list * terminator | CF_guard of int | CF_start of Jib.ctyp NameMap.t + | CF_end -val control_flow_graph : Jib.instr list -> int * int list * ('a list * cf_node) array_graph +val control_flow_graph : Jib.instr list -> int * int * ('a list * cf_node) array_graph (** [immediate_dominators graph root] will calculate the immediate dominators for a control flow graph with a specified root node. *) -val immediate_dominators : 'a array_graph -> int -> int array +val immediate_dominators : ?post:bool -> 'a array_graph -> int -> int array type ssa_elem = Phi of Jib.name * Jib.ctyp * Jib.name list | Pi of Jib.cval list (** Convert a list of instructions into SSA form *) -val ssa : Jib.instr list -> int * (ssa_elem list * cf_node) array_graph +val ssa : ?debug_prefix:string -> Jib.instr list -> int * (ssa_elem list * cf_node) array_graph (** Output the control-flow graph in graphviz format for debugging. Can use 'dot -Tpng X.gv -o X.png' to generate a png diff --git a/src/lib/jib_util.ml b/src/lib/jib_util.ml index 9fdd9c322..1d3da1769 100644 --- a/src/lib/jib_util.ml +++ b/src/lib/jib_util.ml @@ -101,9 +101,11 @@ let iinit l ctyp id cval = I_aux (I_init (ctyp, id, cval), (instr_number (), l)) let iif l cval then_instrs else_instrs ctyp = I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (instr_number (), l)) -let ifuncall l clexp id cvals = I_aux (I_funcall (clexp, false, id, cvals), (instr_number (), l)) +let ifuncall l clexp id cvals = I_aux (I_funcall (CR_one clexp, false, id, cvals), (instr_number (), l)) -let iextern l clexp id cvals = I_aux (I_funcall (clexp, true, id, cvals), (instr_number (), l)) +let ifuncall_multi l clexps id cvals = I_aux (I_funcall (CR_multi clexps, false, id, cvals), (instr_number (), l)) + +let iextern l clexp id cvals = I_aux (I_funcall (CR_one clexp, true, id, cvals), (instr_number (), l)) let icopy l clexp cval = I_aux (I_copy (clexp, cval), (instr_number (), l)) @@ -145,6 +147,13 @@ module Name = struct | Have_exception n, Have_exception m -> compare n m | Current_exception n, Current_exception m -> compare n m | Return n, Return m -> compare n m + | Channel (c1, n), Channel (c2, m) -> begin + match (c1, c2) with + | Chan_stdout, Chan_stdout -> compare n m + | Chan_stderr, Chan_stderr -> compare n m + | Chan_stdout, Chan_stderr -> 1 + | Chan_stderr, Chan_stdout -> -1 + end | Name _, _ -> 1 | _, Name _ -> -1 | Have_exception _, _ -> 1 @@ -153,6 +162,8 @@ module Name = struct | _, Current_exception _ -> -1 | Throw_location _, _ -> 1 | _, Throw_location _ -> -1 + | Return _, _ -> 1 + | _, Return _ -> -1 end module NameSet = Set.Make (Name) @@ -207,6 +218,9 @@ let string_of_name ?deref_current_exception:(dce = false) ?(zencode = true) = | Current_exception n when dce -> "(*current_exception)" ^ ssa_num n | Current_exception n -> "current_exception" ^ ssa_num n | Throw_location n -> "throw_location" ^ ssa_num n + | Channel (chan, n) -> ( + match chan with Chan_stdout -> "stdout" ^ ssa_num n | Chan_stderr -> "stderr" ^ ssa_num n + ) let string_of_op = function | Bnot -> "@not" @@ -307,6 +321,7 @@ let string_of_value = function let rec string_of_cval = function | V_id (id, _) -> string_of_name id + | V_member (id, _) -> string_of_id id | V_lit (VL_undefined, ctyp) -> string_of_value VL_undefined ^ " : " ^ string_of_ctyp ctyp | V_lit (vl, ctyp) -> string_of_value vl | V_call (op, cvals) -> Printf.sprintf "%s(%s)" (string_of_op op) (Util.string_of_list ", " string_of_cval cvals) @@ -335,6 +350,10 @@ let rec string_of_clexp = function | CL_void -> "void" | CL_rmw (id1, id2, ctyp) -> Printf.sprintf "rmw(%s, %s)" (string_of_name id1) (string_of_name id2) +let string_of_creturn = function + | CR_one clexp -> string_of_clexp clexp + | CR_multi clexps -> "(" ^ Util.string_of_list ", " string_of_clexp clexps ^ ")" + let rec doc_instr (I_aux (aux, _)) = let open Printf in let instr s = twice space ^^ string s in @@ -356,11 +375,11 @@ let rec doc_instr (I_aux (aux, _)) = | I_comment str -> twice space ^^ string "//" ^^ string str | I_throw cval -> ksprintf instr "throw %s" (string_of_cval cval) | I_return cval -> ksprintf instr "return %s" (string_of_cval cval) - | I_funcall (clexp, false, uid, args) -> - ksprintf instr "%s = %s(%s)" (string_of_clexp clexp) (string_of_uid uid) + | I_funcall (creturn, false, uid, args) -> + ksprintf instr "%s = %s(%s)" (string_of_creturn creturn) (string_of_uid uid) (Util.string_of_list ", " string_of_cval args) - | I_funcall (clexp, true, uid, args) -> - ksprintf instr "%s = $%s(%s)" (string_of_clexp clexp) (string_of_uid uid) + | I_funcall (creturn, true, uid, args) -> + ksprintf instr "%s = $%s(%s)" (string_of_creturn creturn) (string_of_uid uid) (Util.string_of_list ", " string_of_cval args) | I_copy (clexp, cval) -> ksprintf instr "%s = %s" (string_of_clexp clexp) (string_of_cval cval) | I_block instrs -> @@ -635,7 +654,7 @@ let rec is_polymorphic = function let rec cval_deps = function | V_id (id, _) -> NameSet.singleton id - | V_lit _ -> NameSet.empty + | V_lit _ | V_member _ -> NameSet.empty | V_field (cval, _) | V_tuple_member (cval, _, _) -> cval_deps cval | V_call (_, cvals) | V_tuple (cvals, _) -> List.fold_left NameSet.union NameSet.empty (List.map cval_deps cvals) | V_ctor_kind (cval, _, _) -> cval_deps cval @@ -650,6 +669,16 @@ let rec clexp_deps = function | CL_addr clexp -> clexp_deps clexp | CL_void -> (NameSet.empty, NameSet.empty) +let creturn_deps = function + | CR_one clexp -> clexp_deps clexp + | CR_multi clexps -> + List.fold_left + (fun (reads, writes) clexp -> + let new_reads, new_writes = clexp_deps clexp in + (NameSet.union reads new_reads, NameSet.union writes new_writes) + ) + (NameSet.empty, NameSet.empty) clexps + (* Return the direct, read/write dependencies of a single instruction *) let instr_deps = function | I_decl (_, id) -> (NameSet.empty, NameSet.singleton id) @@ -657,8 +686,8 @@ let instr_deps = function | I_init (_, id, cval) | I_reinit (_, id, cval) -> (cval_deps cval, NameSet.singleton id) | I_if (cval, _, _, _) -> (cval_deps cval, NameSet.empty) | I_jump (cval, _) -> (cval_deps cval, NameSet.empty) - | I_funcall (clexp, _, _, cvals) -> - let reads, writes = clexp_deps clexp in + | I_funcall (creturn, _, _, cvals) -> + let reads, writes = creturn_deps creturn in (List.fold_left NameSet.union reads (List.map cval_deps cvals), writes) | I_copy (clexp, cval) -> let reads, writes = clexp_deps clexp in @@ -691,11 +720,17 @@ let rec clexp_typed_writes = function | CL_addr clexp -> clexp_typed_writes clexp | CL_void -> NameCTSet.empty +let creturn_typed_writes = function + | CR_one clexp -> clexp_typed_writes clexp + | CR_multi clexps -> + List.fold_left (fun writes clexp -> NameCTSet.union writes (clexp_typed_writes clexp)) NameCTSet.empty clexps + let instr_typed_writes (I_aux (aux, _)) = match aux with | I_decl (ctyp, id) | I_reset (ctyp, id) -> NameCTSet.singleton (id, ctyp) | I_init (ctyp, id, _) | I_reinit (ctyp, id, _) -> NameCTSet.singleton (id, ctyp) - | I_funcall (clexp, _, _, _) | I_copy (clexp, _) -> clexp_typed_writes clexp + | I_copy (clexp, _) -> clexp_typed_writes clexp + | I_funcall (creturn, _, _, _) -> creturn_typed_writes creturn | _ -> NameCTSet.empty let rec map_clexp_ctyp f = function @@ -708,6 +743,7 @@ let rec map_clexp_ctyp f = function let rec map_cval_ctyp f = function | V_id (id, ctyp) -> V_id (id, f ctyp) + | V_member (id, ctyp) -> V_member (id, f ctyp) | V_lit (vl, ctyp) -> V_lit (vl, f ctyp) | V_ctor_kind (cval, (id, unifiers), ctyp) -> V_ctor_kind (map_cval_ctyp f cval, (id, List.map f unifiers), f ctyp) | V_ctor_unwrap (cval, (id, unifiers), ctyp) -> V_ctor_unwrap (map_cval_ctyp f cval, (id, List.map f unifiers), f ctyp) @@ -717,6 +753,10 @@ let rec map_cval_ctyp f = function | V_struct (fields, ctyp) -> V_struct (List.map (fun (id, cval) -> (id, map_cval_ctyp f cval)) fields, f ctyp) | V_tuple (members, ctyp) -> V_tuple (List.map (map_cval_ctyp f) members, f ctyp) +let map_creturn_ctyp f = function + | CR_one clexp -> CR_one (map_clexp_ctyp f clexp) + | CR_multi clexps -> CR_multi (List.map (map_clexp_ctyp f) clexps) + let rec map_instr_ctyp f (I_aux (instr, aux)) = let instr = match instr with @@ -730,8 +770,8 @@ let rec map_instr_ctyp f (I_aux (instr, aux)) = f ctyp ) | I_jump (cval, label) -> I_jump (map_cval_ctyp f cval, label) - | I_funcall (clexp, extern, (id, ctyps), cvals) -> - I_funcall (map_clexp_ctyp f clexp, extern, (id, List.map f ctyps), List.map (map_cval_ctyp f) cvals) + | I_funcall (creturn, extern, (id, ctyps), cvals) -> + I_funcall (map_creturn_ctyp f creturn, extern, (id, List.map f ctyps), List.map (map_cval_ctyp f) cvals) | I_copy (clexp, cval) -> I_copy (map_clexp_ctyp f clexp, map_cval_ctyp f cval) | I_clear (ctyp, id) -> I_clear (f ctyp, id) | I_return cval -> I_return (map_cval_ctyp f cval) @@ -964,6 +1004,7 @@ let rec infer_call op vs = and cval_ctyp = function | V_id (_, ctyp) -> ctyp + | V_member (_, ctyp) -> ctyp | V_lit (_, ctyp) -> ctyp | V_ctor_kind _ -> CT_bool | V_ctor_unwrap (_, _, ctyp) -> ctyp @@ -1009,16 +1050,18 @@ let rec clexp_ctyp = function end | CL_void -> CT_unit +let creturn_ctyp = function CR_one clexp -> clexp_ctyp clexp | CR_multi clexps -> CT_tup (List.map clexp_ctyp clexps) + let rec instr_ctyps (I_aux (instr, aux)) = match instr with | I_decl (ctyp, _) | I_reset (ctyp, _) | I_clear (ctyp, _) | I_undefined ctyp -> CTSet.singleton ctyp | I_init (ctyp, _, cval) | I_reinit (ctyp, _, cval) -> CTSet.add ctyp (CTSet.singleton (cval_ctyp cval)) | I_if (cval, instrs1, instrs2, ctyp) -> CTSet.union (instrs_ctyps instrs1) (instrs_ctyps instrs2) |> CTSet.add (cval_ctyp cval) |> CTSet.add ctyp - | I_funcall (clexp, _, (_, ctyps), cvals) -> + | I_funcall (creturn, _, (_, ctyps), cvals) -> List.fold_left (fun m ctyp -> CTSet.add ctyp m) CTSet.empty (List.map cval_ctyp cvals) |> CTSet.union (CTSet.of_list ctyps) - |> CTSet.add (clexp_ctyp clexp) + |> CTSet.add (creturn_ctyp creturn) | I_copy (clexp, cval) -> CTSet.add (clexp_ctyp clexp) (CTSet.singleton (cval_ctyp cval)) | I_block instrs | I_try_block instrs -> instrs_ctyps instrs | I_throw cval | I_jump (cval, _) | I_return cval -> CTSet.singleton (cval_ctyp cval) @@ -1032,8 +1075,8 @@ let rec instr_ctyps_exist pred (I_aux (instr, aux)) = | I_init (ctyp, _, cval) | I_reinit (ctyp, _, cval) -> pred ctyp || pred (cval_ctyp cval) | I_if (cval, instrs1, instrs2, ctyp) -> pred (cval_ctyp cval) || instrs_ctyps_exist pred instrs1 || instrs_ctyps_exist pred instrs2 || pred ctyp - | I_funcall (clexp, _, (_, ctyps), cvals) -> - pred (clexp_ctyp clexp) || List.exists pred ctyps || Util.map_exists pred cval_ctyp cvals + | I_funcall (creturn, _, (_, ctyps), cvals) -> + pred (creturn_ctyp creturn) || List.exists pred ctyps || Util.map_exists pred cval_ctyp cvals | I_copy (clexp, cval) -> pred (clexp_ctyp clexp) || pred (cval_ctyp cval) | I_block instrs | I_try_block instrs -> instrs_ctyps_exist pred instrs | I_throw cval | I_jump (cval, _) | I_return cval -> pred (cval_ctyp cval) diff --git a/src/lib/jib_util.mli b/src/lib/jib_util.mli index 3f97050de..4cdfcc58e 100644 --- a/src/lib/jib_util.mli +++ b/src/lib/jib_util.mli @@ -83,6 +83,7 @@ val ireset : l -> ctyp -> name -> instr val iinit : l -> ctyp -> name -> cval -> instr val iif : l -> cval -> instr list -> instr list -> ctyp -> instr val ifuncall : l -> clexp -> id * ctyp list -> cval list -> instr +val ifuncall_multi : l -> clexp list -> id * ctyp list -> cval list -> instr val iextern : l -> clexp -> id * ctyp list -> cval list -> instr val icopy : l -> clexp -> cval -> instr val iclear : ?loc:l -> ctyp -> name -> instr @@ -185,6 +186,7 @@ val subst_poly : ctyp KBindings.t -> ctyp -> ctyp val cval_ctyp : cval -> ctyp val clexp_ctyp : clexp -> ctyp +val creturn_ctyp : creturn -> ctyp val cdef_ctyps : cdef -> CTSet.t val cdef_ctyps_has : (ctyp -> bool) -> cdef -> bool diff --git a/src/lib/jib_visitor.ml b/src/lib/jib_visitor.ml index 59b1a1c36..0f3c1d60b 100644 --- a/src/lib/jib_visitor.ml +++ b/src/lib/jib_visitor.ml @@ -1,10 +1,14 @@ open Jib include Visitor -class type jib_visitor = object +class type common_visitor = object method vid : Ast.id -> Ast.id option method vname : name -> name option method vctyp : ctyp -> ctyp visit_action +end + +class type jib_visitor = object + inherit common_visitor method vcval : cval -> cval visit_action method vclexp : clexp -> clexp visit_action method vinstrs : instr list -> instr list visit_action @@ -100,6 +104,15 @@ let rec visit_clexp vis outer_clexp = in do_visit vis (vis#vclexp outer_clexp) aux outer_clexp +let visit_creturn vis no_change = + match no_change with + | CR_one clexp -> + let clexp' = visit_clexp vis clexp in + if clexp == clexp' then no_change else CR_one clexp' + | CR_multi clexps -> + let clexps' = map_no_copy (visit_clexp vis) clexps in + if clexps == clexps' then no_change else CR_multi clexps' + let rec visit_cval vis outer_cval = let aux vis no_change = match no_change with @@ -107,6 +120,9 @@ let rec visit_cval vis outer_cval = let name' = visit_name vis name in let ctyp' = visit_ctyp vis ctyp in if name == name' && ctyp == ctyp' then no_change else V_id (name', ctyp') + | V_member (id, ctyp) -> + let ctyp' = visit_ctyp vis ctyp in + if ctyp == ctyp' then no_change else V_member (id, ctyp') | V_lit (value, ctyp) -> let ctyp' = visit_ctyp vis ctyp in if ctyp == ctyp' then no_change else V_lit (value, ctyp') @@ -169,13 +185,13 @@ let rec visit_instr vis outer_instr = if cval == cval' then no_change else I_aux (I_jump (cval', label), aux) | I_aux (I_goto _, aux) -> no_change | I_aux (I_label _, aux) -> no_change - | I_aux (I_funcall (clexp, extern, (id, ctyps), cvals), aux) -> - let clexp' = visit_clexp vis clexp in + | I_aux (I_funcall (creturn, extern, (id, ctyps), cvals), aux) -> + let creturn' = visit_creturn vis creturn in let id' = visit_id vis id in let ctyps' = visit_ctyps vis ctyps in let cvals' = visit_cvals vis cvals in - if clexp == clexp' && id == id' && ctyps == ctyps' && cvals == cvals' then no_change - else I_aux (I_funcall (clexp', extern, (id', ctyps'), cvals'), aux) + if creturn == creturn' && id == id' && ctyps == ctyps' && cvals == cvals' then no_change + else I_aux (I_funcall (creturn', extern, (id', ctyps'), cvals'), aux) | I_aux (I_copy (clexp, cval), aux) -> let clexp' = visit_clexp vis clexp in let cval' = visit_cval vis cval in diff --git a/src/lib/jib_visitor.mli b/src/lib/jib_visitor.mli index 0d9c2bb3d..6406552d1 100644 --- a/src/lib/jib_visitor.mli +++ b/src/lib/jib_visitor.mli @@ -12,14 +12,22 @@ type 'a visit_action = node if any of the children has changed and then apply the function on the node *) +val do_visit : 'v -> 'a visit_action -> ('v -> 'a -> 'a) -> 'a -> 'a + val change_do_children : 'a -> 'a visit_action val map_no_copy : ('a -> 'a) -> 'a list -> 'a list -class type jib_visitor = object +val map_no_copy_opt : ('a -> 'a) -> 'a option -> 'a option + +class type common_visitor = object method vid : Ast.id -> Ast.id option method vname : name -> name option method vctyp : ctyp -> ctyp visit_action +end + +class type jib_visitor = object + inherit common_visitor method vcval : cval -> cval visit_action method vclexp : clexp -> clexp visit_action method vinstrs : instr list -> instr list visit_action @@ -29,7 +37,9 @@ end class empty_jib_visitor : jib_visitor -val visit_ctyp : jib_visitor -> ctyp -> ctyp +val visit_name : common_visitor -> name -> name + +val visit_ctyp : common_visitor -> ctyp -> ctyp val visit_cval : jib_visitor -> cval -> cval diff --git a/src/lib/property.ml b/src/lib/property.ml index a2c6fc02d..4108bb8be 100644 --- a/src/lib/property.ml +++ b/src/lib/property.ml @@ -84,6 +84,50 @@ let find_properties { defs; _ } = |> List.map (fun ((_, _, _, vs) as prop) -> (id_of_val_spec vs, prop)) |> List.fold_left (fun m (id, prop) -> Bindings.add id prop m) Bindings.empty +let well_formedness_check (Typ_aux (aux, _)) = + match aux with + | Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp nexp, _)]) -> + Some (fun exp -> mk_exp (E_app (mk_id "eq_int", [exp; mk_exp (E_sizeof nexp)]))) + | Typ_app (Id_aux (Id "atom_bool", _), [A_aux (A_bool nc, _)]) -> + Some (fun exp -> mk_exp (E_app (mk_id "eq_bool", [exp; mk_exp (E_constraint nc)]))) + | Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp nexp, _)]) -> + Some + (fun exp -> + mk_exp (E_app (mk_id "eq_int", [mk_exp (E_app (mk_id "bitvector_length", [exp])); mk_exp (E_sizeof nexp)])) + ) + | _ -> None + +let destruct_tuple_pat = function P_aux (P_tuple pats, annot) -> (pats, Some annot) | pat -> ([pat], None) + +let reconstruct_tuple_pat pats = function + | Some (l, tannot) -> P_aux (P_tuple pats, (l, Type_check.untyped_annot tannot)) + | None -> List.hd pats + +let well_formed_function_arguments pragma_l pat = + let wf_var n = mk_id ("wf_arg" ^ string_of_int n ^ "#") in + function + | Typ_aux (Typ_fn (arg_typs, _), _) -> + let pats, pats_annot = destruct_tuple_pat pat in + if List.compare_lengths pats arg_typs = 0 then ( + let pats, checks = + List.combine pats arg_typs + |> List.mapi (fun n (pat, arg_typ) -> + let id = wf_var n in + match well_formedness_check arg_typ with + | Some check -> + let pat = mk_pat (P_as (Type_check.strip_pat pat, id)) in + (pat, Some (check (mk_exp (E_id id)))) + | None -> (Type_check.strip_pat pat, None) + ) + |> List.split + in + (reconstruct_tuple_pat pats pats_annot, Util.option_these checks) + ) + else + Reporting.unreachable pragma_l __POS__ + "Function pattern and type do not match when generating well-formedness check for property" + | _ -> Reporting.unreachable pragma_l __POS__ "Found function with non-function type" + let add_property_guards props ast = let open Type_check in let open Type_error in @@ -91,49 +135,45 @@ let add_property_guards props ast = | (DEF_aux (DEF_fundef (FD_aux (FD_function (r_opt, t_opt, funcls), fd_aux) as fdef), def_annot) as def) :: defs -> begin match Bindings.find_opt (id_of_fundef fdef) props with - | Some (_, _, pragma_l, VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (quant, _), _), _, _), _)) -> begin + | Some (_, _, pragma_l, VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (quant, fn_typ), _), _, _), _)) -> begin match quant_split quant with - | _, [] -> add_property_guards' (def :: acc) defs | _, constraints -> - let add_constraints_to_funcl (FCL_aux (FCL_funcl (id, Pat_aux (pexp, pexp_aux)), fcl_aux)) = - let add_guard exp = - (* FIXME: Use an assert *) - let exp' = - mk_exp - (E_block - [ - mk_exp - (E_app - ( mk_id "sail_assume", - [mk_exp (E_constraint (List.fold_left nc_and nc_true constraints))] - ) - ); - strip_exp exp; - ] - ) - in - try Type_check.check_exp (env_of exp) exp' (typ_of exp) - with Type_error (l, err) -> - let msg = - "\n\ - Type error when generating guard for a property.\n\ - When generating guards we convert type quantifiers from the function signature\n\ - into runtime checks so it must be possible to reconstruct the quantifier from\n\ - the function arguments. For example:\n\n\ - function f : forall 'n, 'n <= 100. (x: int('n)) -> bool\n\n\ - would cause the runtime check x <= 100 to be added to the function body.\n\ - To fix this error, ensure that all quantifiers have corresponding function arguments.\n" - in - let original_msg, hint = Type_error.string_of_type_error err in - raise (Reporting.err_typ ?hint pragma_l (original_msg ^ msg)) + let add_checks_to_funcl (FCL_aux (FCL_funcl (id, pexp), (def_annot, fcl_tannot))) = + let pat, guard, exp, (pexp_l, pexp_tannot) = destruct_pexp pexp in + let pat, checks = well_formed_function_arguments pragma_l pat fn_typ in + let exp = + mk_exp + (E_block + (List.map + (fun check -> mk_exp (E_app (mk_id "sail_assume", [check]))) + (mk_exp (E_constraint (List.fold_left nc_and nc_true constraints)) :: checks) + @ [strip_exp exp] + ) + ) in - let mk_funcl p = FCL_aux (FCL_funcl (id, Pat_aux (p, pexp_aux)), fcl_aux) in - match pexp with - | Pat_exp (pat, exp) -> mk_funcl (Pat_exp (pat, add_guard exp)) - | Pat_when (pat, guard, exp) -> mk_funcl (Pat_when (pat, guard, add_guard exp)) + let pexp = + construct_pexp (pat, Option.map strip_exp guard, exp, (pexp_l, Type_check.untyped_annot pexp_tannot)) + in + try + Type_check.check_funcl (env_of_tannot fcl_tannot) + (FCL_aux (FCL_funcl (id, pexp), (def_annot, Type_check.untyped_annot fcl_tannot))) + (typ_of_tannot fcl_tannot) + with Type_error (l, err) -> + let msg = + "\n\ + Type error when generating guard for a property.\n\ + When generating guards we convert type quantifiers from the function signature\n\ + into runtime checks so it must be possible to reconstruct the quantifier from\n\ + the function arguments. For example:\n\n\ + function f : forall 'n, 'n <= 100. (x: int('n)) -> bool\n\n\ + would cause the runtime check x <= 100 to be added to the function body.\n\ + To fix this error, ensure that all quantifiers have corresponding function arguments.\n" + in + let original_msg, hint = Type_error.string_of_type_error err in + raise (Reporting.err_typ ?hint pragma_l (original_msg ^ msg)) in - let funcls = List.map add_constraints_to_funcl funcls in + let funcls = List.map add_checks_to_funcl funcls in let fdef = FD_aux (FD_function (r_opt, t_opt, funcls), fd_aux) in add_property_guards' (DEF_aux (DEF_fundef fdef, def_annot) :: acc) defs diff --git a/src/lib/smt_exp.ml b/src/lib/smt_exp.ml index ed3cb2132..41bf66a9e 100644 --- a/src/lib/smt_exp.ml +++ b/src/lib/smt_exp.ml @@ -67,6 +67,10 @@ open Ast_util +let zencode_id id = Util.zencode_string (string_of_id id) +let zencode_upper_id id = Util.zencode_upper_string (string_of_id id) +let zencode_name id = Jib_util.string_of_name ~deref_current_exception:false ~zencode:true id + type smt_typ = | Bitvec of int | Bool @@ -115,13 +119,9 @@ let rec fold_smt_exp f = function | Store (info, store_fn, arr, index, x) -> f (Store (info, store_fn, fold_smt_exp f arr, fold_smt_exp f index, fold_smt_exp f x)) | Hd (hd_op, xs) -> f (Hd (hd_op, fold_smt_exp f xs)) - | Tl (hd_op, xs) -> f (Tl (hd_op, fold_smt_exp f xs)) + | Tl (tl_op, xs) -> f (Tl (tl_op, fold_smt_exp f xs)) | (Bool_lit _ | Bitvec_lit _ | Real_lit _ | String_lit _ | Var _ | Enum _ | Empty_list) as exp -> f exp -let smt_conj = function [] -> Bool_lit true | [x] -> x | xs -> Fn ("and", xs) - -let smt_disj = function [] -> Bool_lit false | [x] -> x | xs -> Fn ("or", xs) - let extract i j x = Extract (i, j, x) let bvnot x = Fn ("bvnot", [x]) @@ -145,76 +145,358 @@ let bvone n = if n > 0 then Bitvec_lit (Sail2_operators_bitlists.zeros (Big_int.of_int (n - 1)) @ [Sail2_values.B1]) else Bitvec_lit [] -let rec simp exp = +let smt_conj = function [] -> Bool_lit true | [x] -> x | xs -> Fn ("and", xs) + +let smt_disj = function [] -> Bool_lit false | [x] -> x | xs -> Fn ("or", xs) + +let simp_and xs = + let xs = List.filter (function Bool_lit true -> false | _ -> true) xs in + match xs with + | [] -> Bool_lit true + | [x] -> x + | _ -> if List.exists (function Bool_lit false -> true | _ -> false) xs then Bool_lit false else Fn ("and", xs) + +let simp_or xs = + let xs = List.filter (function Bool_lit false -> false | _ -> true) xs in + match xs with + | [] -> Bool_lit false + | [x] -> x + | _ -> if List.exists (function Bool_lit true -> true | _ -> false) xs then Bool_lit true else Fn ("or", xs) + +let simp_fn f args = let open Sail2_operators_bitlists in - match exp with - | Fn (f, args) -> - let args = List.map simp args in + match (f, args) with + | "not", [Fn ("not", [exp])] -> exp + | "not", [Bool_lit b] -> Bool_lit (not b) + | "contents", [Fn ("Bits", [_; bv])] -> bv + | "len", [Fn ("Bits", [len; _])] -> len + | "or", _ -> simp_or args + | "and", _ -> simp_and args + | "concat", _ -> + let chunks, bv = + List.fold_left + (fun (chunks, current_literal) arg -> + match (current_literal, arg) with + | Some bv1, Bitvec_lit bv2 -> (chunks, Some (bv1 @ bv2)) + | None, Bitvec_lit bv -> (chunks, Some bv) + | Some bv, exp -> (exp :: Bitvec_lit bv :: chunks, None) + | None, exp -> (exp :: chunks, None) + ) + ([], None) args + in begin - match (f, args) with - | "contents", [Fn ("Bits", [_; bv])] -> bv - | "len", [Fn ("Bits", [len; _])] -> len - | "concat", _ -> - let chunks, bv = - List.fold_left - (fun (chunks, current_literal) arg -> - match (current_literal, arg) with - | Some bv1, Bitvec_lit bv2 -> (chunks, Some (bv1 @ bv2)) - | None, Bitvec_lit bv -> (chunks, Some bv) - | Some bv, exp -> (exp :: Bitvec_lit bv :: chunks, None) - | None, exp -> (exp :: chunks, None) - ) - ([], None) args - in - begin - match (chunks, bv) with - | [], Some bv -> Bitvec_lit bv - | chunks, Some bv -> Fn ("concat", List.rev (Bitvec_lit bv :: chunks)) - | chunks, None -> Fn ("concat", List.rev chunks) - end - | "bvnot", [Bitvec_lit bv] -> Bitvec_lit (not_vec bv) - | "bvand", [Bitvec_lit lhs; Bitvec_lit rhs] -> Bitvec_lit (and_vec lhs rhs) - | "bvor", [Bitvec_lit lhs; Bitvec_lit rhs] -> Bitvec_lit (or_vec lhs rhs) - | "bvxor", [Bitvec_lit lhs; Bitvec_lit rhs] -> Bitvec_lit (xor_vec lhs rhs) - | "bvshl", [Bitvec_lit lhs; Bitvec_lit rhs] -> begin - match sint_maybe rhs with Some shift -> Bitvec_lit (shiftl lhs shift) | None -> Fn (f, args) - end - | "bvlshr", [Bitvec_lit lhs; Bitvec_lit rhs] -> begin - match sint_maybe rhs with Some shift -> Bitvec_lit (shiftr lhs shift) | None -> Fn (f, args) - end - | "bvashr", [Bitvec_lit lhs; Bitvec_lit rhs] -> begin - match sint_maybe rhs with Some shift -> Bitvec_lit (shiftr lhs shift) | None -> Fn (f, args) - end - | f, args -> Fn (f, args) + match (chunks, bv) with + | [], Some bv -> Bitvec_lit bv + | chunks, Some bv -> Fn ("concat", List.rev (Bitvec_lit bv :: chunks)) + | chunks, None -> Fn ("concat", List.rev chunks) end - | ZeroExtend (to_len, from_len, arg) -> - let arg = simp arg in + | "bvnot", [Bitvec_lit bv] -> Bitvec_lit (not_vec bv) + | "bvand", [Bitvec_lit lhs; Bitvec_lit rhs] -> Bitvec_lit (and_vec lhs rhs) + | "bvor", [Bitvec_lit lhs; Bitvec_lit rhs] -> Bitvec_lit (or_vec lhs rhs) + | "bvxor", [Bitvec_lit lhs; Bitvec_lit rhs] -> Bitvec_lit (xor_vec lhs rhs) + | "bvshl", [Bitvec_lit lhs; Bitvec_lit rhs] -> begin + match sint_maybe rhs with Some shift -> Bitvec_lit (shiftl lhs shift) | None -> Fn (f, args) + end + | "bvlshr", [Bitvec_lit lhs; Bitvec_lit rhs] -> begin + match sint_maybe rhs with Some shift -> Bitvec_lit (shiftr lhs shift) | None -> Fn (f, args) + end + | "bvashr", [Bitvec_lit lhs; Bitvec_lit rhs] -> begin + match sint_maybe rhs with Some shift -> Bitvec_lit (shiftr lhs shift) | None -> Fn (f, args) + end + | _, _ -> Fn (f, args) + +let rec simp vars exp = + let open Sail2_operators_bitlists in + match exp with + | Var v -> begin match vars v with Some exp -> simp vars exp | None -> Var v end + | Fn (f, args) -> + let args = List.map (simp vars) args in + simp_fn f args + | ZeroExtend (to_len, by_len, arg) -> + let arg = simp vars arg in begin match arg with | Bitvec_lit bv -> Bitvec_lit (zero_extend bv (Big_int.of_int to_len)) - | _ -> ZeroExtend (to_len, from_len, arg) + | _ -> ZeroExtend (to_len, by_len, arg) end - | SignExtend (to_len, from_len, arg) -> - let arg = simp arg in + | SignExtend (to_len, by_len, arg) -> + let arg = simp vars arg in begin match arg with | Bitvec_lit bv -> Bitvec_lit (sign_extend bv (Big_int.of_int to_len)) - | _ -> SignExtend (to_len, from_len, arg) + | _ -> SignExtend (to_len, by_len, arg) end | Extract (n, m, arg) -> begin - match simp arg with + match simp vars arg with | Bitvec_lit bv -> Bitvec_lit (subrange_vec_dec bv (Big_int.of_int n) (Big_int.of_int m)) | exp -> Extract (n, m, exp) end - | Store (info, store_fn, arr, i, x) -> Store (info, store_fn, simp arr, simp i, simp x) - | exp -> exp + | Store (info, store_fn, arr, i, x) -> Store (info, store_fn, simp vars arr, simp vars i, simp vars x) + | Ite (Fn ("not", [cond]), then_exp, else_exp) -> simp vars (Ite (cond, else_exp, then_exp)) + | Ite (cond, then_exp, else_exp) -> Ite (simp vars cond, simp vars then_exp, simp vars else_exp) + | Tester (ctor, exp) -> Tester (ctor, simp vars exp) + | Unwrap (ctor, b, exp) -> Unwrap (ctor, b, simp vars exp) + | Field (struct_id, field_id, exp) -> Field (struct_id, field_id, simp vars exp) + | Empty_list | Bool_lit _ | Bitvec_lit _ | Real_lit _ | String_lit _ | Enum _ | Hd _ | Tl _ -> exp type smt_def = | Define_fun of string * (string * smt_typ) list * smt_typ * smt_exp | Declare_fun of string * smt_typ list * smt_typ - | Declare_const of string * smt_typ - | Define_const of string * smt_typ * smt_exp + | Declare_const of Jib.name * smt_typ + | Define_const of Jib.name * smt_typ * smt_exp | Declare_datatypes of string * (string * (string * smt_typ) list) list | Assert of smt_exp let declare_datatypes = function Datatype (name, ctors) -> Declare_datatypes (name, ctors) | _ -> assert false + +let pp_sfun str docs = + let open PPrint in + parens (separate space (string str :: docs)) + +let rec pp_smt_typ = + let open PPrint in + function + | Bool -> string "Bool" + | String -> string "String" + | Real -> string "Real" + | Bitvec n -> string (Printf.sprintf "(_ BitVec %d)" n) + | Datatype (name, _) -> string name + | Array (ty1, ty2) -> pp_sfun "Array" [pp_smt_typ ty1; pp_smt_typ ty2] + +let pp_str_smt_typ (str, ty) = + let open PPrint in + parens (string str ^^ space ^^ pp_smt_typ ty) + +let rec pp_smt_exp = + let open PPrint in + function + | Bool_lit b -> string (string_of_bool b) + | Real_lit str -> string str + | String_lit str -> string ("\"" ^ str ^ "\"") + | Bitvec_lit bv -> string (Sail2_values.show_bitlist_prefix '#' bv) + | Var id -> string (zencode_name id) + | Enum str -> string str + | Fn (str, exps) -> parens (string str ^^ space ^^ separate_map space pp_smt_exp exps) + | Field (struct_id, field_id, exp) -> + parens (string (zencode_upper_id struct_id ^ "_" ^ zencode_id field_id) ^^ space ^^ pp_smt_exp exp) + | Unwrap (ctor, _, exp) -> parens (string ("un" ^ zencode_id ctor) ^^ space ^^ pp_smt_exp exp) + | Ite (cond, then_exp, else_exp) -> + parens (separate space [string "ite"; pp_smt_exp cond; pp_smt_exp then_exp; pp_smt_exp else_exp]) + | Extract (i, j, exp) -> parens (string (Printf.sprintf "(_ extract %d %d)" i j) ^^ space ^^ pp_smt_exp exp) + | Tester (kind, exp) -> parens (string (Printf.sprintf "(_ is %s)" kind) ^^ space ^^ pp_smt_exp exp) + | SignExtend (_, i, exp) -> parens (string (Printf.sprintf "(_ sign_extend %d)" i) ^^ space ^^ pp_smt_exp exp) + | ZeroExtend (_, i, exp) -> parens (string (Printf.sprintf "(_ zero_extend %d)" i) ^^ space ^^ pp_smt_exp exp) + | Store (_, _, arr, index, x) -> parens (string "store" ^^ space ^^ separate_map space pp_smt_exp [arr; index; x]) + | Hd (op, exp) | Tl (op, exp) -> parens (string op ^^ space ^^ pp_smt_exp exp) + | Empty_list -> string "empty_list" + +let pp_smt_def = + let open PPrint in + let open Printf in + function + | Define_fun (name, args, ty, exp) -> + parens + (string "define-fun" ^^ space ^^ string name ^^ space + ^^ parens (separate_map space pp_str_smt_typ args) + ^^ space ^^ pp_smt_typ ty ^//^ pp_smt_exp exp + ) + | Declare_fun (name, args, ty) -> + parens + (string "declare-fun" ^^ space ^^ string name ^^ space + ^^ parens (separate_map space pp_smt_typ args) + ^^ space ^^ pp_smt_typ ty + ) + | Declare_const (name, ty) -> pp_sfun "declare-const" [string (zencode_name name); pp_smt_typ ty] + | Define_const (name, ty, exp) -> pp_sfun "define-const" [string (zencode_name name); pp_smt_typ ty; pp_smt_exp exp] + | Declare_datatypes (name, ctors) -> + let pp_ctor (ctor_name, fields) = + match fields with [] -> parens (string ctor_name) | _ -> pp_sfun ctor_name (List.map pp_str_smt_typ fields) + in + pp_sfun "declare-datatypes" + [Printf.ksprintf string "((%s 0))" name; parens (parens (separate_map space pp_ctor ctors))] + | Assert exp -> pp_sfun "assert" [pp_smt_exp exp] + +let string_of_smt_def def = Pretty_print_sail.Document.to_string (pp_smt_def def) + +(**************************************************************************) +(* 2. Parser for SMT solver output *) +(**************************************************************************) + +(* Output format from each SMT solver does not seem to be completely + standardised, unlike the SMTLIB input format, but usually is some + form of s-expression based representation. Therefore we define a + simple parser for s-expressions using monadic parser combinators. *) + +type counterexample_solver = Cvc5 | Cvc4 | Z3 + +let counterexample_command = function Cvc5 -> "cvc5 --lang=smt2.6" | Cvc4 -> "cvc4 --lang=smt2.6" | Z3 -> "z3" + +let counterexample_solver_from_name name = + match String.lowercase_ascii name with "cvc4" -> Some Cvc4 | "cvc5" -> Some Cvc5 | "z3" -> Some Z3 | _ -> None + +module type COUNTEREXAMPLE_CONFIG = sig + val max_unknown_integer_width : int +end + +module Counterexample (Config : COUNTEREXAMPLE_CONFIG) = struct + type sexpr = List of sexpr list | Atom of string + + let rec string_of_sexpr = function + | List sexprs -> "(" ^ Util.string_of_list " " string_of_sexpr sexprs ^ ")" + | Atom str -> str + + open Parser_combinators + + let lparen = token (function Str.Delim "(" -> Some () | _ -> None) + let rparen = token (function Str.Delim ")" -> Some () | _ -> None) + let atom = token (function Str.Text str -> Some str | _ -> None) + + let rec sexp toks = + let parse = + pchoose + (atom >>= fun str -> preturn (Atom str)) + ( lparen >>= fun _ -> + plist sexp >>= fun xs -> + rparen >>= fun _ -> preturn (List xs) + ) + in + parse toks + + let parse_sexps input = + let delim = Str.regexp "[ \n\t]+\\|(\\|)" in + let tokens = Str.full_split delim input in + let non_whitespace = function Str.Delim d when String.trim d = "" -> false | _ -> true in + let tokens = List.filter non_whitespace tokens in + match plist sexp tokens with Ok (result, _) -> Some result | Fail -> None + + let parse_sexpr_int width = function + | List [Atom "_"; Atom v; Atom m] when int_of_string m = width && String.length v > 2 && String.sub v 0 2 = "bv" -> + let v = String.sub v 2 (String.length v - 2) in + Some (Big_int.of_string v) + | Atom v when String.length v > 2 && String.sub v 0 2 = "#b" -> + let v = String.sub v 2 (String.length v - 2) in + Some (Big_int.of_string ("0b" ^ v)) + | Atom v when String.length v > 2 && String.sub v 0 2 = "#x" -> + let v = String.sub v 2 (String.length v - 2) in + Some (Big_int.of_string ("0x" ^ v)) + | _ -> None + + let rec value_of_sexpr sexpr = + let open Jib in + let open Value in + function + | CT_fbits width -> begin + match parse_sexpr_int width sexpr with + | Some value -> mk_vector (Sail_lib.get_slice_int' (width, value, 0)) + | None -> failwith ("Cannot parse sexpr as bitvector: " ^ string_of_sexpr sexpr) + end + | CT_struct (_, fields) -> begin + match sexpr with + | List (Atom name :: smt_fields) -> + V_record + (List.fold_left2 + (fun m (field_id, ctyp) sexpr -> StringMap.add (string_of_id field_id) (value_of_sexpr sexpr ctyp) m) + StringMap.empty fields smt_fields + ) + | _ -> failwith ("Cannot parse sexpr as struct " ^ string_of_sexpr sexpr) + end + | CT_enum (_, members) -> begin + match sexpr with + | Atom name -> begin + match List.find_opt (fun member -> Util.zencode_string (string_of_id member) = name) members with + | Some member -> V_member (string_of_id member) + | None -> + failwith + ("Could not find enum member for " ^ name ^ " in " ^ Util.string_of_list ", " string_of_id members) + end + | _ -> failwith ("Cannot parse sexpr as enum " ^ string_of_sexpr sexpr) + end + | CT_bool -> begin + match sexpr with + | Atom "true" -> V_bool true + | Atom "false" -> V_bool false + | _ -> failwith ("Cannot parse sexpr as bool " ^ string_of_sexpr sexpr) + end + | CT_fint width -> begin + match parse_sexpr_int width sexpr with + | Some value -> V_int value + | None -> failwith ("Cannot parse sexpr as fixed-width integer: " ^ string_of_sexpr sexpr) + end + | CT_lint -> begin + match parse_sexpr_int Config.max_unknown_integer_width sexpr with + | Some value -> V_int value + | None -> failwith ("Cannot parse sexpr as integer: " ^ string_of_sexpr sexpr) + end + | ctyp -> failwith ("Unsupported type in sexpr: " ^ Jib_util.string_of_ctyp ctyp) + + let rec find_arg id ctyp arg_smt_names = function + | List [Atom "define-fun"; Atom str; List []; _; value] :: _ + when Util.assoc_compare_opt Id.compare id arg_smt_names = Some (Some str) -> + (id, value_of_sexpr value ctyp) + | _ :: sexps -> find_arg id ctyp arg_smt_names sexps + | [] -> (id, V_unit) + + let build_counterexample args arg_ctyps arg_smt_names model = + List.map2 (fun id ctyp -> find_arg id ctyp arg_smt_names model) args arg_ctyps + + let rec run frame = + match frame with + | Interpreter.Done (state, v) -> Some v + | Interpreter.Step (lazy_str, _, _, _) -> run (Interpreter.eval_frame frame) + | Interpreter.Break frame -> run (Interpreter.eval_frame frame) + | Interpreter.Fail (_, _, _, _, msg) -> None + | Interpreter.Effect_request (out, state, stack, eff) -> run (Interpreter.default_effect_interp state eff) + + let check ~env ~ast ~solver ~file_name ~function_id ~args ~arg_ctyps ~arg_smt_names = + let open Printf in + let open Ast in + print_endline ("Checking counterexample: " ^ file_name); + let in_chan = ksprintf Unix.open_process_in "%s %s" (counterexample_command solver) file_name in + let lines = ref [] in + begin + try + while true do + lines := input_line in_chan :: !lines + done + with End_of_file -> () + end; + let solver_output = List.rev !lines |> String.concat "\n" in + begin + match parse_sexps solver_output with + | Some (Atom "sat" :: (List (Atom "model" :: model) | List model) :: _) -> + let open Value in + let open Interpreter in + print_endline (sprintf "Solver found counterexample: %s" Util.("ok" |> green |> clear)); + let counterexample = build_counterexample args arg_ctyps arg_smt_names model in + List.iter (fun (id, v) -> print_endline (" " ^ string_of_id id ^ " -> " ^ string_of_value v)) counterexample; + let istate = initial_state ast env !primops in + let annot = (Parse_ast.Unknown, Type_check.mk_tannot env bool_typ) in + let call = + E_aux + ( E_app + ( function_id, + List.map + (fun (_, v) -> E_aux (E_internal_value v, (Parse_ast.Unknown, Type_check.empty_tannot))) + counterexample + ), + annot + ) + in + let result = run (Step (lazy "", istate, return call, [])) in + begin + match result with + | Some (V_bool false) | None -> + ksprintf print_endline "Replaying counterexample: %s" Util.("ok" |> green |> clear) + | _ -> () + end + | Some (Atom "unsat" :: _) -> + print_endline "Solver could not find counterexample"; + print_endline "Solver output:"; + print_endline solver_output + | _ -> + print_endline "Unexpected solver output:"; + print_endline solver_output + end; + let _ = Unix.close_process_in in_chan in + () +end diff --git a/src/lib/smt_gen.ml b/src/lib/smt_gen.ml index 5f2b2b758..46e1a6339 100644 --- a/src/lib/smt_gen.ml +++ b/src/lib/smt_gen.ml @@ -80,6 +80,8 @@ let zencode_uid (id, ctyps) = type checks = { overflows : smt_exp list; strings_used : bool; reals_used : bool } +let get_overflows c = c.overflows + let empty_checks = { overflows = []; strings_used = false; reals_used = false } let append_checks c1 c2 = @@ -118,6 +120,12 @@ let rec mapM f = function let* xs = mapM f xs in return (x :: xs) +let rec iterM f = function + | [] -> return () + | x :: xs -> + let* _ = f x in + iterM f xs + let run m l = let state = m l in (state.value, state.checks) @@ -216,7 +224,9 @@ let required_width n = module type CONFIG = sig val max_unknown_integer_width : int val max_unknown_bitvector_width : int + val max_unknown_generic_vector_length : int val union_ctyp_classify : ctyp -> bool + val register_ref : string -> smt_exp end module type PRIMOP_GEN = sig @@ -241,10 +251,14 @@ let builtin_type_error fn cvals ret_ctyp_opt = raise (Reporting.err_todo l message) | None -> raise (Reporting.err_todo l (Printf.sprintf "%s : (%s)" fn args)) +type undefined_mode = Undefined_zeros | Undefined_bits | Undefined_disable + +let undefined_enabled = function Undefined_disable -> false | _ -> true + module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct let lint_size = Config.max_unknown_integer_width let lbits_size = Config.max_unknown_bitvector_width - let lbits_index = required_width (Big_int.of_int (lbits_size - 1)) + let lbits_index = required_width (Big_int.of_int lbits_size) let int_size = function | CT_constant n -> required_width n @@ -257,6 +271,13 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct | CT_lbits -> lbits_size | _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "Argument to bv_size must be a bitvector type" + let generic_vector_length = function + | CT_fvector (n, _) -> n + | CT_vector _ -> Config.max_unknown_generic_vector_length + | _ -> + Reporting.unreachable Parse_ast.Unknown __POS__ + "Argument to generic_vector_length must be a generic vector type" + let to_fbits ctyp bv = match ctyp with | CT_fbits n -> (n, bv) @@ -281,7 +302,7 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct | VL_real str, _ -> let* _ = real_used in return (if str.[0] = '-' then Fn ("-", [Real_lit (String.sub str 1 (String.length str - 1))]) else Real_lit str) - | VL_ref str, _ -> return (Fn ("reg_ref", [String_lit str])) + | VL_ref str, _ -> return (Config.register_ref str) | _ -> let* l = current_location in Reporting.unreachable l __POS__ @@ -291,9 +312,9 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct match (op, args) with | Bnot, [arg] -> Fn ("not", [arg]) | Bor, [arg] -> arg - | Bor, args -> Fn ("or", args) + | Bor, args -> smt_disj args | Band, [arg] -> arg - | Band, args -> Fn ("and", args) + | Band, args -> smt_conj args | Eq, args -> Fn ("=", args) | Neq, args -> Fn ("not", [Fn ("=", args)]) | Ilt, [lhs; rhs] -> Fn ("bvslt", [lhs; rhs]) @@ -325,6 +346,7 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct match cval with | V_lit (vl, ctyp) -> literal vl ctyp | V_id (id, _) -> return (Var id) + | V_member (id, _) -> return (Var (Name (id, -1))) | V_call (List_hd, [arg]) -> let* l = current_location in let op = Primop_gen.hd l (cval_ctyp arg) in @@ -466,10 +488,30 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct else if sz1 > sz2 then Fn (fn, [sv1; SignExtend (sz1, sz1 - sz2, sv2)]) else Fn (fn, [SignExtend (sz2, sz2 - sz1, sv1); sv2]) ) - | CT_constant c, CT_fint sz -> return (Fn (fn, [bvint sz c; sv2])) - | CT_fint sz, CT_constant c -> return (Fn (fn, [sv1; bvint sz c])) - | CT_constant c, CT_lint -> return (Fn (fn, [bvint lint_size c; sv2])) - | CT_lint, CT_constant c -> return (Fn (fn, [sv1; bvint lint_size c])) + | CT_constant c, CT_fint sz -> + let constant_sz = required_width c in + if constant_sz <= sz then return (Fn (fn, [bvint sz c; sv2])) + else + let* sv2 = signed_size ~checked:false ~into:constant_sz ~from:sz sv2 in + return (Fn (fn, [bvint constant_sz c; sv2])) + | CT_fint sz, CT_constant c -> + let constant_sz = required_width c in + if constant_sz <= sz then return (Fn (fn, [sv1; bvint sz c])) + else + let* sv1 = signed_size ~checked:false ~into:constant_sz ~from:sz sv1 in + return (Fn (fn, [sv1; bvint constant_sz c])) + | CT_constant c, CT_lint -> + let constant_sz = required_width c in + if constant_sz <= lint_size then return (Fn (fn, [bvint lint_size c; sv2])) + else + let* sv2 = signed_size ~checked:false ~into:constant_sz ~from:lint_size sv2 in + return (Fn (fn, [bvint constant_sz c; sv2])) + | CT_lint, CT_constant c -> + let constant_sz = required_width c in + if constant_sz <= lint_size then return (Fn (fn, [sv1; bvint lint_size c])) + else + let* sv1 = signed_size ~checked:false ~into:constant_sz ~from:lint_size sv1 in + return (Fn (fn, [sv1; bvint constant_sz c])) | CT_fint sz, CT_lint when sz < lint_size -> return (Fn (fn, [SignExtend (lint_size, lint_size - sz, sv1); sv2])) | CT_lint, CT_fint sz when sz < lint_size -> return (Fn (fn, [sv1; SignExtend (lint_size, lint_size - sz, sv2)])) | _, _ -> builtin_type_error fn [v1; v2] None @@ -511,6 +553,10 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct let shift = Fn ("concat", [bvzero (lbits_size - lbits_index); len]) in bvnot (bvshl all_ones shift) + let wf_lbits bv = + let mask = bvnot (bvmask (Fn ("len", [bv]))) in + Fn ("=", [bvand mask (Fn ("contents", [bv])); bvzero lbits_size]) + let builtin_shift shiftop vbits vshift ret_ctyp = match cval_ctyp vbits with | CT_fbits n -> @@ -575,22 +621,19 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct | CT_fbits n, _, CT_fbits m -> let* bv = smt_cval vbits in return (Fn ("concat", [bvzero (m - n); bv])) - | CT_fbits n, CT_fint m, CT_lbits -> - let* bv = smt_cval vbits in - return (Fn ("concat", [bvzero (m - n); bv])) | CT_lbits, _, CT_fbits m -> let* bv = smt_cval vbits in return (Extract (m - 1, 0, Fn ("contents", [bv]))) - (* - | CT_fbits n, CT_lbits -> - assert (lbits_size ctx >= n); - let vbits = - if lbits_size ctx = n then smt_cval ctx vbits - else if lbits_size ctx > n then Fn ("concat", [bvzero (lbits_size ctx - n); smt_cval ctx vbits]) - else assert false - in - Fn ("Bits", [bvzeint ctx ctx.lbits_index vlen; vbits]) - *) + | CT_fbits n, _, CT_lbits -> + let* bits = + if lbits_size = n then smt_cval vbits + else if lbits_size > n then + let* unextended = smt_cval vbits in + return (Fn ("concat", [bvzero (lbits_size - n); unextended])) + else assert false + in + let* len = bvzeint lbits_index vlen in + return (Fn ("Bits", [len; bits])) | CT_lbits, CT_lint, CT_lbits -> let* len = smt_cval vlen in let* bv = smt_cval vbits in @@ -746,13 +789,17 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with | CT_fbits n, CT_fbits m, CT_fbits o -> assert (n + m = o); - let* smt1 = smt_cval v1 in - let* smt2 = smt_cval v2 in - return (Fn ("concat", [smt1; smt2])) + if n = 0 then smt_cval v2 + else if m = 0 then smt_cval v1 + else + let* smt1 = smt_cval v1 in + let* smt2 = smt_cval v2 in + return (Fn ("concat", [smt1; smt2])) + | CT_fbits n, CT_lbits, CT_lbits when n = 0 -> smt_cval v2 | CT_fbits n, CT_lbits, CT_lbits -> let* smt1 = smt_cval v1 in let* smt2 = smt_cval v2 in - let x = Fn ("concat", [bvzero (lbits_size - n); smt1]) in + let x = if lbits_size = n then smt1 else Fn ("concat", [bvzero (lbits_size - n); smt1]) in let shift = Fn ("concat", [bvzero (lbits_size - lbits_index); Fn ("len", [smt2])]) in return (Fn @@ -767,28 +814,23 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct let* smt1 = smt_cval v1 in let* smt2 = smt_cval v2 in return (Extract (m - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2]))) - (* | CT_lbits, CT_fbits n, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn - ( "Bits", - [ - bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt1])); - Extract (lbits_size ctx - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])); - ] - ) + let* smt1 = smt_cval v1 in + let* smt2 = smt_cval v2 in + return + (Fn + ( "Bits", + [ + bvadd (bvint lbits_index (Big_int.of_int n)) (Fn ("len", [smt1])); + Extract (lbits_size - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])); + ] + ) + ) | CT_fbits n, CT_fbits m, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn - ( "Bits", - [ - bvint ctx.lbits_index (Big_int.of_int (n + m)); - unsigned_size ctx (lbits_size ctx) (n + m) (Fn ("concat", [smt1; smt2])); - ] - ) - *) + let* smt1 = smt_cval v1 in + let* smt2 = smt_cval v2 in + let* appended = unsigned_size ~into:lbits_size ~from:(n + m) (Fn ("concat", [smt1; smt2])) in + return (Fn ("Bits", [bvint lbits_index (Big_int.of_int (n + m)); appended])) | CT_lbits, CT_lbits, CT_lbits -> let* smt1 = smt_cval v1 in let* smt2 = smt_cval v2 in @@ -797,14 +839,12 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct return (Fn ("Bits", [bvadd (Fn ("len", [smt1])) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) ) - (* | CT_lbits, CT_lbits, CT_fbits n -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - let x = Fn ("contents", [smt1]) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in - unsigned_size ctx n (lbits_size ctx) (bvor (bvshl x shift) (Fn ("contents", [smt2]))) - *) + let* smt1 = smt_cval v1 in + let* smt2 = smt_cval v2 in + let x = Fn ("contents", [smt1]) in + let shift = Fn ("concat", [bvzero (lbits_size - lbits_index); Fn ("len", [smt2])]) in + unsigned_size ~into:n ~from:lbits_size (bvor (bvshl x shift) (Fn ("contents", [smt2]))) | _ -> builtin_type_error "append" [v1; v2] (Some ret_ctyp) let builtin_sail_truncate v1 v2 ret_ctyp = @@ -878,7 +918,7 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct let* vec = smt_cval vec in let* i = bind (smt_cval i) - (unsigned_size ~checked:false ~into:(required_width (Big_int.of_int (len - 1)) - 1) ~from:(int_size i_ctyp)) + (unsigned_size ~checked:false ~into:(required_width (Big_int.of_int (len - 1))) ~from:(int_size i_ctyp)) in return (Fn ("select", [vec; i])) (* @@ -957,7 +997,7 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct let* x = smt_cval x in let* i = bind (smt_cval i) - (unsigned_size ~checked:false ~into:(required_width (Big_int.of_int (len - 1)) - 1) ~from:(int_size i_ctyp)) + (unsigned_size ~checked:false ~into:(required_width (Big_int.of_int (len - 1))) ~from:(int_size i_ctyp)) in return (Store (Fixed len, store_fn, vec, i, x)) (* @@ -1052,25 +1092,70 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct return (bvshl (bvone lint_size) v) | _ -> builtin_type_error "pow2" [v] (Some ret_ctyp) - (* Technically, there's no bvclz in SMTLIB, but we can't generate - anything nice, so leave it in case a backend like SystemVerilog - can do better *) let builtin_count_leading_zeros v ret_ctyp = - let* l = current_location in + let ret_sz = int_size ret_ctyp in + let rec lzcnt sz smt = + if sz == 1 then + Ite + ( Fn ("=", [Extract (0, 0, smt); Bitvec_lit [Sail2_values.B0]]), + bvint ret_sz (Big_int.of_int 1), + bvint ret_sz Big_int.zero + ) + else ( + assert (sz land (sz - 1) = 0); + let hsz = sz / 2 in + Ite + ( Fn ("=", [Extract (sz - 1, hsz, smt); bvzero hsz]), + Fn ("bvadd", [bvint ret_sz (Big_int.of_int hsz); lzcnt hsz (Extract (hsz - 1, 0, smt))]), + lzcnt hsz (Extract (sz - 1, hsz, smt)) + ) + ) + in + let smallest_greater_power_of_two n = + let m = ref 1 in + while !m < n do + m := !m lsl 1 + done; + assert (!m land (!m - 1) = 0); + !m + in + let* smt = smt_cval v in match cval_ctyp v with + | CT_fbits sz when sz land (sz - 1) = 0 -> return (lzcnt sz smt) | CT_fbits sz -> - let bvclz = Primop_gen.count_leading_zeros l sz in - let* bv = smt_cval v in - unsigned_size ~max_value:sz ~into:(int_size ret_ctyp) ~from:sz (Fn (bvclz, [bv])) + let padded_sz = smallest_greater_power_of_two sz in + let padding = bvzero (padded_sz - sz) in + return + (Fn + ("bvsub", [lzcnt padded_sz (Fn ("concat", [padding; smt])); bvint ret_sz (Big_int.of_int (padded_sz - sz))]) + ) | CT_lbits -> - let bvclz = Primop_gen.count_leading_zeros l lbits_size in - let* bv = smt_cval v in - let contents_clz = Fn (bvclz, [Fn ("contents", [bv])]) in - let* len = unsigned_size ~into:lbits_size ~from:lbits_index (Fn ("len", [bv])) in - let lz = bvsub contents_clz (bvsub (bvpint lbits_size (Big_int.of_int lbits_size)) len) in - unsigned_size ~max_value:lbits_size ~into:(int_size ret_ctyp) ~from:lbits_size lz + return + (Fn + ( "bvsub", + [ + lzcnt lbits_size (Fn ("contents", [smt])); + Fn + ( "bvsub", + [ + bvint ret_sz (Big_int.of_int lbits_size); + Fn ("concat", [bvzero (ret_sz - lbits_index); Fn ("len", [smt])]); + ] + ); + ] + ) + ) | _ -> builtin_type_error "count_leading_zeros" [v] (Some ret_ctyp) + let unary_smt op v _ = + let* smt = smt_cval v in + return (Fn (op, [smt])) + + let binary_smt op v1 v2 _ = + let* smt1 = smt_cval v1 in + let* smt2 = smt_cval v2 in + return (Fn (op, [smt1; smt2])) + let arity_error = let* l = current_location in raise (Reporting.unreachable l __POS__ "Trying to generate primitive with incorrect number of arguments") @@ -1086,25 +1171,12 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct let ternary_primop f = Some (fun args ret_ctyp -> match args with [v1; v2; v3] -> f v1 v2 v3 ret_ctyp | _ -> arity_error) - let builtin = function - | "eq_bit" -> - binary_primop_simple (fun v1 v2 -> - let* smt1 = smt_cval v1 in - let* smt2 = smt_cval v2 in - return (Fn ("=", [smt1; smt2])) - ) - | "eq_bool" -> - binary_primop_simple (fun v1 v2 -> - let* smt1 = smt_cval v1 in - let* smt2 = smt_cval v2 in - return (Fn ("=", [smt1; smt2])) - ) + let builtin ?(allow_io = true) ?(undefined = Undefined_disable) = function + | "eq_bit" -> binary_primop (binary_smt "=") + | "eq_bool" -> binary_primop (binary_smt "=") + | "eq_string" -> binary_primop (binary_smt "=") | "eq_int" -> binary_primop_simple builtin_eq_int - | "not" -> - unary_primop_simple (fun v -> - let* v = smt_cval v in - return (Fn ("not", [v])) - ) + | "not" -> unary_primop (unary_smt "not") | "lt" -> binary_primop_simple builtin_lt | "lteq" -> binary_primop_simple builtin_lteq | "gt" -> binary_primop_simple builtin_gt @@ -1152,7 +1224,23 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct | "length" -> unary_primop builtin_length | "replicate_bits" -> binary_primop builtin_replicate_bits | "count_leading_zeros" -> unary_primop builtin_count_leading_zeros - | "print_bits" -> + | "eq_real" -> binary_primop (binary_smt "=") + | "neg_real" -> unary_primop (unary_smt "-") + | "add_real" -> binary_primop (binary_smt "+") + | "sub_real" -> binary_primop (binary_smt "-") + | "mult_real" -> binary_primop (binary_smt "*") + | "div_real" -> binary_primop (binary_smt "/") + | "lt_real" -> binary_primop (binary_smt "<") + | "gt_real" -> binary_primop (binary_smt ">") + | "lteq_real" -> binary_primop (binary_smt "<=") + | "gteq_real" -> binary_primop (binary_smt ">=") + | "concat_str" -> + binary_primop_simple (fun str1 str2 -> + let* str1 = smt_cval str1 in + let* str2 = smt_cval str2 in + return (Fn ("str.++", [str1; str2])) + ) + | "print_bits" when allow_io -> binary_primop_simple (fun str bv -> let* l = current_location in let op = Primop_gen.print_bits l (cval_ctyp bv) in @@ -1188,13 +1276,13 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) = struct let* bv = smt_cval bv in return (Fn (op, [bv])) ) - | "sail_assert" -> + | "sail_assert" when allow_io -> binary_primop_simple (fun b msg -> let* b = smt_cval b in let* msg = smt_cval msg in return (Fn ("sail_assert", [b; msg])) ) - | "reg_deref" -> + | "reg_deref" when allow_io -> unary_primop_simple (fun reg_ref -> match cval_ctyp reg_ref with | CT_ref ctyp -> diff --git a/src/lib/smt_gen.mli b/src/lib/smt_gen.mli index f1bf4c2b0..d0c00c28d 100644 --- a/src/lib/smt_gen.mli +++ b/src/lib/smt_gen.mli @@ -78,6 +78,8 @@ open Jib features like strings or real numbers. *) type checks +val get_overflows : checks -> Smt_exp.smt_exp list + (** We generate primitives in a monad that accumulates any required dynamic checks, and contains the location information for any error messages. *) @@ -99,8 +101,14 @@ val ( let+ ) : 'a check_writer -> ('a -> 'b) -> 'b check_writer val mapM : ('a -> 'b check_writer) -> 'a list -> 'b list check_writer +val iterM : ('a -> unit check_writer) -> 'a list -> unit check_writer + val run : 'a check_writer -> Parse_ast.l -> 'a * checks +val string_used : unit check_writer + +val real_used : unit check_writer + (** Convert a SMT bitvector expression of size [from] into a SMT bitvector expression of size [into] with the same signed value. When [into < from] inserts a dynamic check that the @@ -131,12 +139,20 @@ module type CONFIG = sig sufficient. *) val max_unknown_bitvector_width : int + (** If we have a generic vector, [vector('n, 'a)], where ['n] is + unconstrained, then we represent it as a vector of at most this + length. *) + val max_unknown_generic_vector_length : int + (** Some SystemVerilog implementations (e.g. Verilator), don't support unpacked union types, which forces us to generate different code for different unions depending on the types the contain. This is abstracted into a classify function that the instantiator of this module can supply. *) val union_ctyp_classify : ctyp -> bool + + (** How we handle register references differs between backends *) + val register_ref : string -> Smt_exp.smt_exp end (** Some Sail primitives we can't directly convert to pure SMT @@ -157,6 +173,13 @@ module type PRIMOP_GEN = sig val tl : Parse_ast.l -> ctyp -> string end +(** We have various options for handling undefined bits for SMT + generation, either we can treat them all as zero (which is + consistent with the default emulator behavior), or generated + undefined bits, or have the builtin generator skip these + functions. *) +type undefined_mode = Undefined_zeros | Undefined_bits | Undefined_disable + module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) : sig (** Convert a Jib IR cval into an SMT expression *) val smt_cval : cval -> Smt_exp.smt_exp check_writer @@ -165,6 +188,10 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) : sig val bv_size : ctyp -> int + val generic_vector_length : ctyp -> int + + val wf_lbits : Smt_exp.smt_exp -> Smt_exp.smt_exp + (** Create an SMT expression that converts an expression of the jib type [from] into an SMT expression for the jib type [into]. Note that this function assumes that the input is of the correct @@ -174,5 +201,6 @@ module Make (Config : CONFIG) (Primop_gen : PRIMOP_GEN) : sig (** Compile a call to a Sail builtin function into an SMT expression implementing that call. Returns None if that builtin is unsupported by this module. *) - val builtin : string -> (cval list -> ctyp -> Smt_exp.smt_exp check_writer) option + val builtin : + ?allow_io:bool -> ?undefined:undefined_mode -> string -> (cval list -> ctyp -> Smt_exp.smt_exp check_writer) option end diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index eb017f4ee..5f27008b6 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -5180,8 +5180,9 @@ let initial_env = ( TypQ_aux (TypQ_tq [QI_aux (QI_id (mk_kopt K_int (mk_kid "n")), Parse_ast.Unknown)], Parse_ast.Unknown), function_typ [atom_typ (nvar (mk_kid "n"))] (app_typ (mk_id "itself") [mk_typ_arg (A_nexp (nvar (mk_kid "n")))]) ) - (* __assume is used by property.ml to add guards for SMT generation, + (* sail_assume is used by property.ml to add guards for SMT generation, but which don't affect flow-typing. *) + |> Env.add_extern (mk_id "sail_assume") { pure = true; bindings = [("_", "sail_assume")] } |> Env.add_val_spec (mk_id "sail_assume") (TypQ_aux (TypQ_no_forall, Parse_ast.Unknown), function_typ [bool_typ] unit_typ) diff --git a/src/sail_c_backend/c_backend.ml b/src/sail_c_backend/c_backend.ml index ffc77cca6..9b0bb43d3 100644 --- a/src/sail_c_backend/c_backend.ml +++ b/src/sail_c_backend/c_backend.ml @@ -593,6 +593,7 @@ end) : CONFIG = struct let use_real = false let branch_coverage = Opts.branch_coverage let track_throw = true + let use_void = false end (** Functions that have heap-allocated return types are implemented by @@ -619,8 +620,10 @@ let fix_early_heap_return ret instrs = before @ [itry_block l (rewrite_return instrs)] @ rewrite_return after | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> before @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] @ rewrite_return after - | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after -> - before @ [I_aux (I_funcall (CL_addr (CL_id (ret, CT_ref ctyp)), extern, fid, args), aux)] @ rewrite_return after + | before, I_aux (I_funcall (CR_one (CL_id (Return _, ctyp)), extern, fid, args), aux) :: after -> + before + @ [I_aux (I_funcall (CR_one (CL_addr (CL_id (ret, CT_ref ctyp))), extern, fid, args), aux)] + @ rewrite_return after | before, I_aux (I_copy (CL_id (Return _, ctyp), cval), aux) :: after -> before @ [I_aux (I_copy (CL_addr (CL_id (ret, CT_ref ctyp)), cval), aux)] @ rewrite_return after | before, I_aux ((I_end _ | I_undefined _), _) :: after -> @@ -643,8 +646,8 @@ let fix_early_stack_return ret ret_ctyp instrs = before @ [itry_block l (rewrite_return instrs)] @ rewrite_return after | before, I_aux (I_if (cval, then_instrs, else_instrs, ctyp), (_, l)) :: after -> before @ [iif l cval (rewrite_return then_instrs) (rewrite_return else_instrs) ctyp] @ rewrite_return after - | before, I_aux (I_funcall (CL_id (Return _, ctyp), extern, fid, args), aux) :: after -> - before @ [I_aux (I_funcall (CL_id (ret, ctyp), extern, fid, args), aux)] @ rewrite_return after + | before, I_aux (I_funcall (CR_one (CL_id (Return _, ctyp)), extern, fid, args), aux) :: after -> + before @ [I_aux (I_funcall (CR_one (CL_id (ret, ctyp)), extern, fid, args), aux)] @ rewrite_return after | before, I_aux (I_copy (CL_id (Return _, ctyp), cval), aux) :: after -> before @ [I_aux (I_copy (CL_id (ret, ctyp), cval), aux)] @ rewrite_return after | before, I_aux (I_end _, _) :: after -> before @ [ireturn (V_id (ret, ret_ctyp))] @ rewrite_return after @@ -965,6 +968,7 @@ let sgen_value = function let rec sgen_cval = function | V_id (id, _) -> sgen_name id + | V_member (id, _) -> sgen_id id | V_lit (vl, _) -> sgen_value vl | V_call (op, cvals) -> sgen_call op cvals | V_field (f, field) -> Printf.sprintf "%s.%s" (sgen_cval f) (sgen_id field) @@ -1108,7 +1112,8 @@ let rec sgen_clexp l = function | CL_id (Have_exception _, _) -> "have_exception" | CL_id (Current_exception _, _) -> "current_exception" | CL_id (Throw_location _, _) -> "throw_location" - | CL_id (Return _, _) -> Reporting.unreachable l __POS__ "CL_return should have been removed" + | CL_id (Channel _, _) -> Reporting.unreachable l __POS__ "CL_id Channel should not appear in C backend" + | CL_id (Return _, _) -> Reporting.unreachable l __POS__ "CL_id Return should have been removed" | CL_id (Name (id, _), _) -> "&" ^ sgen_id id | CL_field (clexp, field) -> "&((" ^ sgen_clexp l clexp ^ ")->" ^ zencode_id field ^ ")" | CL_tuple (clexp, n) -> "&((" ^ sgen_clexp l clexp ^ ")->ztup" ^ string_of_int n ^ ")" @@ -1120,7 +1125,8 @@ let rec sgen_clexp_pure l = function | CL_id (Have_exception _, _) -> "have_exception" | CL_id (Current_exception _, _) -> "current_exception" | CL_id (Throw_location _, _) -> "throw_location" - | CL_id (Return _, _) -> Reporting.unreachable l __POS__ "CL_return should have been removed" + | CL_id (Channel _, _) -> Reporting.unreachable l __POS__ "CL_id Channel should not appear in C backend" + | CL_id (Return _, _) -> Reporting.unreachable l __POS__ "CL_id Return should have been removed" | CL_id (Name (id, _), _) -> sgen_id id | CL_field (clexp, field) -> sgen_clexp_pure l clexp ^ "." ^ zencode_id field | CL_tuple (clexp, n) -> sgen_clexp_pure l clexp ^ ".ztup" ^ string_of_int n @@ -1240,6 +1246,11 @@ let rec codegen_instr fid ctx (I_aux (instr, (_, l))) = ^^ jump 2 2 (sq_separate_map hardline (codegen_instr fid ctx) instrs) ^^ hardline ^^ string " }" | I_funcall (x, special_extern, f, args) -> + let x = + match x with + | CR_one x -> x + | CR_multi _ -> Reporting.unreachable l __POS__ "Multiple returns should not exist in C backend" + in let c_args = Util.string_of_list ", " sgen_cval args in let ctyp = clexp_ctyp x in let is_extern = ctx_is_extern (fst f) ctx || special_extern in diff --git a/src/sail_smt_backend/jib_smt.ml b/src/sail_smt_backend/jib_smt.ml index 40a4471ea..8890eb3a9 100644 --- a/src/sail_smt_backend/jib_smt.ml +++ b/src/sail_smt_backend/jib_smt.ml @@ -25,6 +25,7 @@ (* Stephen Kell *) (* Mark Wassell *) (* Alastair Reid (Arm Ltd) *) +(* Louis-Emile Ploix *) (* *) (* All rights reserved. *) (* *) @@ -69,1305 +70,914 @@ open Libsail open Anf open Ast -open Ast_defs open Ast_util open Jib +open Jib_compile open Jib_util -open Smtlib +open PPrint +open Printf +open Smt_exp open Property -module IntSet = Set.Make (struct - type t = int - let compare = compare -end) -module IntMap = Map.Make (struct - type t = int - let compare = compare -end) +let opt_debug_graphs = ref false let zencode_upper_id id = Util.zencode_upper_string (string_of_id id) let zencode_id id = Util.zencode_string (string_of_id id) let zencode_name id = string_of_name ~deref_current_exception:false ~zencode:true id -let zencode_uid (id, ctyps) = - match ctyps with - | [] -> Util.zencode_string (string_of_id id) - | _ -> Util.zencode_string (string_of_id id ^ "#" ^ Util.string_of_list "_" string_of_ctyp ctyps) -let opt_ignore_overflow = ref false +let max_int n = Big_int.pred (Big_int.pow_int_positive 2 (n - 1)) +let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1)) -let opt_auto = ref false +let required_width n = + let rec required_width' n = + if Big_int.equal n Big_int.zero then 1 else 1 + required_width' (Big_int.shift_right n 1) + in + required_width' (Big_int.abs n) -let opt_debug_graphs = ref false +module type Sequence = sig + type 'a t + val create : unit -> 'a t + val add : 'a -> 'a t -> unit +end + +module Make_optimizer (S : Sequence) = struct + module NameHash = struct + type t = Jib.name + let equal x y = Name.compare x y = 0 + let hash = function + | Name (Id_aux (aux, _), n) -> Hashtbl.hash (0, (aux, n)) + | Have_exception n -> Hashtbl.hash (1, n) + | Current_exception n -> Hashtbl.hash (2, n) + | Throw_location n -> Hashtbl.hash (3, n) + | Return n -> Hashtbl.hash (4, n) + | Channel (chan, n) -> Hashtbl.hash (5, (chan, n)) + end + + module NameHashtbl = Hashtbl.Make (NameHash) + + let optimize stack = + let stack' = Stack.create () in + let uses = NameHashtbl.create (Stack.length stack) in + + let rec uses_in_exp = function + | Var var -> begin + match NameHashtbl.find_opt uses var with + | Some n -> NameHashtbl.replace uses var (n + 1) + | None -> NameHashtbl.add uses var 1 + end + | Enum _ | Bitvec_lit _ | Bool_lit _ | String_lit _ | Real_lit _ | Empty_list -> () + | Fn (_, exps) -> List.iter uses_in_exp exps + | Field (_, _, exp) -> uses_in_exp exp + | Ite (cond, t, e) -> + uses_in_exp cond; + uses_in_exp t; + uses_in_exp e + | Extract (_, _, exp) + | Unwrap (_, _, exp) + | Tester (_, exp) + | SignExtend (_, _, exp) + | ZeroExtend (_, _, exp) + | Hd (_, exp) + | Tl (_, exp) -> + uses_in_exp exp + | Store (_, _, arr, index, x) -> + uses_in_exp arr; + uses_in_exp index; + uses_in_exp x + in + + let remove_unused () = function + | Declare_const (var, _) as def -> begin + match NameHashtbl.find_opt uses var with None -> () | Some _ -> Stack.push def stack' + end + | Declare_fun _ as def -> Stack.push def stack' + | Define_const (var, _, exp) as def -> begin + match NameHashtbl.find_opt uses var with + | None -> () + | Some _ -> + uses_in_exp exp; + Stack.push def stack' + end + | Declare_datatypes _ as def -> Stack.push def stack' + | Assert exp as def -> + uses_in_exp exp; + Stack.push def stack' + | Define_fun _ -> assert false + in + Stack.fold remove_unused () stack; + + let vars = NameHashtbl.create (Stack.length stack') in + let seq = S.create () in -let opt_propagate_vars = ref false + let constant_propagate = function + | Declare_const _ as def -> S.add def seq + | Declare_fun _ as def -> S.add def seq + | Define_const (var, typ, exp) -> + let exp = Smt_exp.simp (NameHashtbl.find_opt vars) exp in + begin + match (NameHashtbl.find_opt uses var, Smt_exp.simp (NameHashtbl.find_opt vars) exp) with + | _, (Bitvec_lit _ | Bool_lit _) -> NameHashtbl.add vars var exp + | Some 1, _ -> NameHashtbl.add vars var exp + | Some _, exp -> S.add (Define_const (var, typ, exp)) seq + | None, _ -> assert false + end + | Assert exp -> S.add (Assert (Smt_exp.simp (NameHashtbl.find_opt vars) exp)) seq + | Declare_datatypes _ as def -> S.add def seq + | Define_fun _ -> assert false + in + Stack.iter constant_propagate stack'; + seq +end -let opt_unroll_limit = ref 10 +module Queue_optimizer = Make_optimizer (struct + type 'a t = 'a Queue.t + let create = Queue.create + let add = Queue.add + let iter = Queue.iter +end) module EventMap = Map.Make (Event) -(* Note that we have to use x : ty ref rather than mutable x : ty, to - make sure { ctx with x = ... } doesn't break the mutable state. *) - -(* See mli file for a description of each field *) -type ctx = { - lbits_index : int; - lint_size : int; - vector_index : int; - register_map : id list CTMap.t; - tuple_sizes : IntSet.t ref; - tc_env : Type_check.Env.t; - pragma_l : Ast.l; - arg_stack : (int * string) Stack.t; - ast : Type_check.typed_ast; - shared : ctyp Bindings.t; - preserved : IdSet.t; +type state = { events : smt_exp Stack.t EventMap.t ref; node : int; - pathcond : smt_exp Lazy.t; - use_string : bool ref; - use_real : bool ref; + cfg : (Jib_ssa.ssa_elem list * Jib_ssa.cf_node) Jib_ssa.array_graph; + arg_stack : (int * string) Stack.t; } -(* These give the default bounds for various SMT types, stored in the - initial_ctx. They shouldn't be read or written by anything else! If - they are changed the output of sail -help needs to be updated to - reflect this. *) -let opt_default_lint_size = ref 128 -let opt_default_lbits_index = ref 8 -let opt_default_vector_index = ref 5 - -let initial_ctx () = - { - lbits_index = !opt_default_lbits_index; - lint_size = !opt_default_lint_size; - vector_index = !opt_default_vector_index; - register_map = CTMap.empty; - tuple_sizes = ref IntSet.empty; - tc_env = Type_check.initial_env; - pragma_l = Parse_ast.Unknown; - arg_stack = Stack.create (); - ast = empty_ast; - shared = Bindings.empty; - preserved = IdSet.empty; - events = ref EventMap.empty; - node = -1; - pathcond = lazy (Bool_lit true); - use_string = ref false; - use_real = ref false; - } - -let event_stack ctx ev = - match EventMap.find_opt ev !(ctx.events) with +let event_stack state ev = + match EventMap.find_opt ev !(state.events) with | Some stack -> stack | None -> let stack = Stack.create () in - ctx.events := EventMap.add ev stack !(ctx.events); + state.events := EventMap.add ev stack !(state.events); stack -let add_event ctx ev smt = - let stack = event_stack ctx ev in - Stack.push (Fn ("and", [Lazy.force ctx.pathcond; smt])) stack +module type CONFIG = sig + val max_unknown_integer_width : int + val max_unknown_bitvector_width : int + val max_unknown_generic_vector_length : int + val register_map : id list CTMap.t + val ignore_overflow : bool +end + +module Make (Config : CONFIG) = struct + open Jib_visitor + + let lbits_index_width = required_width (Big_int.of_int Config.max_unknown_bitvector_width) + let vector_index_width = required_width (Big_int.of_int (Config.max_unknown_generic_vector_length - 1)) + + module Smt = + Smt_gen.Make + (struct + let max_unknown_integer_width = Config.max_unknown_integer_width + let max_unknown_bitvector_width = Config.max_unknown_bitvector_width + let max_unknown_generic_vector_length = Config.max_unknown_generic_vector_length + let union_ctyp_classify _ = true + let register_ref reg_name = + let id = mk_id reg_name in + let rmap = + CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) Config.register_map + in + assert (CTMap.cardinal rmap = 1); + match CTMap.min_binding_opt rmap with + | Some (ctyp, regs) -> begin + match Util.list_index (fun reg -> Id.compare reg id = 0) regs with + | Some i -> Smt_gen.bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) + | None -> assert false + end + | _ -> assert false + end) + (struct + let print_bits l = function _ -> Reporting.unreachable l __POS__ "print_bits" -let add_pathcond_event ctx ev = Stack.push (Lazy.force ctx.pathcond) (event_stack ctx ev) + let string_of_bits l = function _ -> Reporting.unreachable l __POS__ "string_of_bits" -let overflow_check ctx smt = - if not !opt_ignore_overflow then ( - Reporting.warn "Overflow check in generated SMT for" ctx.pragma_l ""; - add_event ctx Overflow smt - ) + let dec_str l = function _ -> Reporting.unreachable l __POS__ "dec_str" -let lbits_size ctx = Util.power 2 ctx.lbits_index + let hex_str l = function _ -> Reporting.unreachable l __POS__ "hex_str" -let vector_index = ref 5 + let hex_str_upper l = function _ -> Reporting.unreachable l __POS__ "hex_str_upper" -let smt_unit = mk_enum "Unit" ["Unit"] -let smt_lbits ctx = mk_record "Bits" [("size", Bitvec ctx.lbits_index); ("bits", Bitvec (lbits_size ctx))] + let count_leading_zeros l = function _ -> Reporting.unreachable l __POS__ "count_leading_zeros" -(* [required_width n] is the required number of bits to losslessly - represent an integer n *) -let required_width n = - let rec required_width' n = - if Big_int.equal n Big_int.zero then 1 else 1 + required_width' (Big_int.shift_right n 1) - in - required_width' (Big_int.abs n) + let fvector_store l _ _ = "store" -let rec smt_ctyp ctx = function - | CT_constant n -> Bitvec (required_width n) - | CT_fint n -> Bitvec n - | CT_lint -> Bitvec ctx.lint_size - | CT_unit -> smt_unit - | CT_bit -> Bitvec 1 - | CT_fbits n -> Bitvec n - | CT_sbits n -> smt_lbits ctx - | CT_lbits -> smt_lbits ctx - | CT_bool -> Bool - | CT_enum (id, elems) -> mk_enum (zencode_upper_id id) (List.map zencode_id elems) - | CT_struct (id, fields) -> - mk_record (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) fields) - | CT_variant (id, ctors) -> - mk_variant (zencode_upper_id id) (List.map (fun (id, ctyp) -> (zencode_id id, smt_ctyp ctx ctyp)) ctors) - | CT_tup ctyps -> - ctx.tuple_sizes := IntSet.add (List.length ctyps) !(ctx.tuple_sizes); - Tuple (List.map (smt_ctyp ctx) ctyps) - | CT_vector ctyp | CT_fvector (_, ctyp) -> Array (Bitvec !vector_index, smt_ctyp ctx ctyp) - | CT_string -> - ctx.use_string := true; - String - | CT_real -> - ctx.use_real := true; - Real - | CT_ref ctyp -> begin - match CTMap.find_opt ctyp ctx.register_map with - | Some regs -> Bitvec (required_width (Big_int.of_int (List.length regs))) - | _ -> failwith ("No registers with ctyp: " ^ string_of_ctyp ctyp) - end - | CT_list _ -> raise (Reporting.err_todo ctx.pragma_l "Lists not yet supported in SMT generation") - | CT_float _ | CT_rounding_mode -> Reporting.unreachable ctx.pragma_l __POS__ "Floating point in SMT property" - | CT_poly _ -> Reporting.unreachable ctx.pragma_l __POS__ "Found polymorphic type in SMT property" - -(* We often need to create a SMT bitvector of a length sz with integer - value x. [bvpint sz x] does this for positive integers, and [bvint sz x] - does this for all integers. It's quite awkward because we - don't have a very good way to get the binary representation of - either an ocaml integer or a big integer. *) -let bvpint sz x = - let open Sail2_values in - if Big_int.less_equal Big_int.zero x && Big_int.less_equal x (Big_int.of_int max_int) then ( - let x = Big_int.to_int x in - match Printf.sprintf "%X" x |> Util.string_to_list |> List.map nibble_of_char |> Util.option_all with - | Some nibbles -> - let bin = List.map (fun (a, b, c, d) -> [a; b; c; d]) nibbles |> List.concat in - let _, bin = Util.take_drop (function B0 -> true | _ -> false) bin in - let padding = List.init (sz - List.length bin) (fun _ -> B0) in - Bitvec_lit (padding @ bin) - | None -> assert false - ) - else if Big_int.greater x (Big_int.of_int max_int) then ( - let y = ref x in - let bin = ref [] in - while not (Big_int.equal !y Big_int.zero) do - let q, m = Big_int.quomod !y (Big_int.of_int 2) in - bin := (if Big_int.equal m Big_int.zero then B0 else B1) :: !bin; - y := q - done; - let padding_size = sz - List.length !bin in - if padding_size < 0 then - raise - (Reporting.err_general Parse_ast.Unknown - (Printf.sprintf "Could not create a %d-bit integer with value %s.\nTry increasing the maximum integer size" - sz (Big_int.to_string x) - ) - ); - let padding = List.init padding_size (fun _ -> B0) in - Bitvec_lit (padding @ !bin) - ) - else failwith "Invalid bvpint" - -let bvint sz x = - if Big_int.less x Big_int.zero then - Fn ("bvadd", [Fn ("bvnot", [bvpint sz (Big_int.abs x)]); bvpint sz (Big_int.of_int 1)]) - else bvpint sz x - -(** [force_size ctx n m exp] takes a smt expression assumed to be a - integer (signed bitvector) of length m and forces it to be length n - by either sign extending it or truncating it as required *) -let force_size ?(checked = true) ctx n m smt = - if n = m then smt - else if n > m then SignExtend (n - m, smt) - else ( - let check = - (* If the top bit of the truncated number is one *) - Ite - ( Fn ("=", [Extract (n - 1, n - 1, smt); Bitvec_lit [Sail2_values.B1]]), - (* Then we have an overflow, unless all bits we truncated were also one *) - Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvones (m - n)])]), - (* Otherwise, all the top bits must be zero *) - Fn ("not", [Fn ("=", [Extract (m - 1, n, smt); bvzero (m - n)])]) - ) - in - if checked then overflow_check ctx check else (); - Extract (n - 1, 0, smt) - ) - -(** [unsigned_size ctx n m exp] is much like force_size, but it - assumes that the bitvector is unsigned *) -let unsigned_size ?(checked = true) ctx n m smt = - if n = m then smt else if n > m then Fn ("concat", [bvzero (n - m); smt]) else Extract (n - 1, 0, smt) - -let smt_conversion ctx from_ctyp to_ctyp x = - match (from_ctyp, to_ctyp) with - | _, _ when ctyp_equal from_ctyp to_ctyp -> x - | CT_constant c, CT_fint sz -> bvint sz c - | CT_constant c, CT_lint -> bvint ctx.lint_size c - | CT_fint sz, CT_lint -> force_size ctx ctx.lint_size sz x - | CT_lint, CT_fint sz -> force_size ctx sz ctx.lint_size x - | CT_lint, CT_fbits n -> force_size ctx n ctx.lint_size x - | CT_lint, CT_lbits -> - Fn - ("Bits", [bvint ctx.lbits_index (Big_int.of_int ctx.lint_size); force_size ctx (lbits_size ctx) ctx.lint_size x]) - | CT_fint n, CT_lbits -> Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int n); force_size ctx (lbits_size ctx) n x]) - | CT_lbits, CT_fbits n -> unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [x])) - | CT_fbits n, CT_fbits m -> unsigned_size ctx m n x - | CT_fbits n, CT_lbits -> - Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int n); unsigned_size ctx (lbits_size ctx) n x]) - | CT_fvector _, CT_vector _ -> x - | CT_vector _, CT_fvector _ -> x - | _, _ -> - failwith - (Printf.sprintf "Cannot perform conversion from %s to %s" (string_of_ctyp from_ctyp) (string_of_ctyp to_ctyp)) - -(* Translate Jib literals into SMT *) -let smt_value ctx vl ctyp = - let open Value2 in - match (vl, ctyp) with - | VL_bits bv, CT_fbits n -> unsigned_size ctx n (List.length bv) (Bitvec_lit bv) - | VL_bits bv, CT_lbits -> - let sz = List.length bv in - Fn ("Bits", [bvint ctx.lbits_index (Big_int.of_int sz); unsigned_size ctx (lbits_size ctx) sz (Bitvec_lit bv)]) - | VL_bool b, _ -> Bool_lit b - | VL_int n, CT_constant m -> bvint (required_width n) n - | VL_int n, CT_fint sz -> bvint sz n - | VL_int n, CT_lint -> bvint ctx.lint_size n - | VL_bit b, CT_bit -> Bitvec_lit [b] - | VL_unit, _ -> Enum "unit" - | VL_string str, _ -> - ctx.use_string := true; - String_lit (String.escaped str) - | VL_real str, _ -> - ctx.use_real := true; - if str.[0] = '-' then Fn ("-", [Real_lit (String.sub str 1 (String.length str - 1))]) else Real_lit str - | VL_enum str, _ -> Enum (Util.zencode_string str) - | VL_ref reg_name, _ -> - let id = mk_id reg_name in - let rmap = CTMap.filter (fun ctyp regs -> List.exists (fun reg -> Id.compare reg id = 0) regs) ctx.register_map in - assert (CTMap.cardinal rmap = 1); - begin - match CTMap.min_binding_opt rmap with - | Some (ctyp, regs) -> begin - match Util.list_index (fun reg -> Id.compare reg id = 0) regs with - | Some i -> bvint (required_width (Big_int.of_int (List.length regs))) (Big_int.of_int i) - | None -> assert false - end - | _ -> assert false + let is_empty l = function _ -> Reporting.unreachable l __POS__ "is_empty" + + let hd l = function _ -> Reporting.unreachable l __POS__ "hd" + + let tl l = function _ -> Reporting.unreachable l __POS__ "tl" + end) + + let ( let* ) = Smt_gen.bind + let return = Smt_gen.return + let mapM = Smt_gen.mapM + + let rec sequence = function + | x :: xs -> + let* y = x in + let* ys = sequence xs in + return (y :: ys) + | [] -> return [] + + let smt_unit = mk_enum "Unit" ["Unit"] + let smt_lbits = + mk_record "Bits" [("len", Bitvec lbits_index_width); ("bits", Bitvec Config.max_unknown_bitvector_width)] + + let rec wf_smt_ctyp = function CT_lbits -> Some (fun exp -> Smt.wf_lbits exp) | _ -> None + + let rec smt_ctyp = function + | CT_constant n -> return (Bitvec (required_width n)) + | CT_fint n -> return (Bitvec n) + | CT_lint -> return (Bitvec Config.max_unknown_integer_width) + | CT_unit -> return smt_unit + | CT_bit -> return (Bitvec 1) + | CT_fbits n -> return (Bitvec n) + | CT_sbits n -> return smt_lbits + | CT_lbits -> return smt_lbits + | CT_bool -> return Bool + | CT_enum (id, elems) -> return (mk_enum (zencode_upper_id id) (List.map zencode_id elems)) + | CT_struct (id, fields) -> + let* fields = + mapM + (fun (id, ctyp) -> + let* ctyp = smt_ctyp ctyp in + return (zencode_id id, ctyp) + ) + fields + in + return (mk_record (zencode_upper_id id) fields) + | CT_variant (id, ctors) -> + let* ctors = + mapM + (fun (id, ctyp) -> + let* ctyp = smt_ctyp ctyp in + return (zencode_id id, ctyp) + ) + ctors + in + return (mk_variant (zencode_upper_id id) ctors) + | CT_fvector (n, ctyp) -> + let* ctyp = smt_ctyp ctyp in + return (Array (Bitvec (required_width (Big_int.of_int (n - 1))), ctyp)) + | CT_vector ctyp -> + let* ctyp = smt_ctyp ctyp in + return (Array (Bitvec vector_index_width, ctyp)) + | CT_string -> + let* _ = Smt_gen.string_used in + return String + | CT_real -> + let* _ = Smt_gen.real_used in + return Real + | CT_ref ctyp -> begin + match CTMap.find_opt ctyp Config.register_map with + | Some regs -> return (Bitvec (required_width (Big_int.of_int (List.length regs)))) + | _ -> + let* l = Smt_gen.current_location in + Reporting.unreachable l __POS__ ("No registers with ctyp: " ^ string_of_ctyp ctyp) end - | _ -> failwith ("Cannot translate literal to SMT: " ^ string_of_value vl ^ " : " ^ string_of_ctyp ctyp) - -let rec smt_cval ctx cval = - match cval_ctyp cval with - | CT_constant n -> bvint (required_width n) n - | _ -> ( - match cval with - | V_lit (vl, ctyp) -> smt_value ctx vl ctyp - | V_id ((Name (id, _) as ssa_id), _) -> begin - match Type_check.Env.lookup_id id ctx.tc_env with - | Enum _ -> Enum (zencode_id id) - | _ when Bindings.mem id ctx.shared -> Shared (zencode_id id) - | _ -> Var (zencode_name ssa_id) - end - | V_id (ssa_id, _) -> Var (zencode_name ssa_id) - | V_call (Neq, [cval1; cval2]) -> Fn ("not", [Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2])]) - | V_call (Bvor, [cval1; cval2]) -> Fn ("bvor", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Eq, [cval1; cval2]) -> Fn ("=", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Bnot, [cval]) -> Fn ("not", [smt_cval ctx cval]) - | V_call (Band, cvals) -> smt_conj (List.map (smt_cval ctx) cvals) - | V_call (Bor, cvals) -> smt_disj (List.map (smt_cval ctx) cvals) - | V_call (Igt, [cval1; cval2]) -> Fn ("bvsgt", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_call (Iadd, [cval1; cval2]) -> Fn ("bvadd", [smt_cval ctx cval1; smt_cval ctx cval2]) - | V_ctor_kind (union, ctor, _) -> Fn ("not", [Tester (zencode_uid ctor, smt_cval ctx union)]) - | V_ctor_unwrap (union, ctor, _) -> Fn ("un" ^ zencode_uid ctor, [smt_cval ctx union]) - | V_field (record, field) -> begin - match cval_ctyp record with - | CT_struct (struct_id, _) -> Field (zencode_upper_id struct_id ^ "_" ^ zencode_id field, smt_cval ctx record) - | _ -> failwith "Field for non-struct type" + | CT_list _ -> + let* l = Smt_gen.current_location in + raise (Reporting.err_todo l "Lists not yet supported in SMT generation") + | CT_float _ | CT_rounding_mode -> + let* l = Smt_gen.current_location in + Reporting.unreachable l __POS__ "Floating point in SMT property" + | CT_tup _ -> + let* l = Smt_gen.current_location in + Reporting.unreachable l __POS__ "Tuples should be re-written before SMT generation" + | CT_poly _ -> + let* l = Smt_gen.current_location in + Reporting.unreachable l __POS__ "Found polymorphic type in SMT property" + + (* When generating SMT when we encounter joins between two or more + blocks such as in the example below, we have to generate a muxer + that chooses the correct value of v_n or v_m to assign to v_o. We + use the pi nodes that contain the path condition for each + block to generate an if-then-else for each phi function. The order + of the arguments to each phi function is based on the graph node + index for the predecessor nodes. + + +---------------+ +---------------+ + | pi(cond_1) | | pi(cond_2) | + | ... | | ... | + | Basic block 1 | | Basic block 2 | + +---------------+ +---------------+ + \ / + \ / + +---------------------+ + | v/o = phi(v/n, v/m) | + | ... | + +---------------------+ + + would generate: + + (define-const v/o (ite cond_1 v/n v/m)) + *) + let smt_ssanode cfg preds = + let open Jib_ssa in + function + | Pi _ -> return [] + | Phi (id, ctyp, ids) -> ( + let get_pi n = + match get_vertex cfg n with + | Some ((ssa_elems, _), _, _) -> List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems) + | None -> failwith "Predecessor node does not exist" + in + let pis = List.map get_pi (IntSet.elements preds) in + let* mux = + List.fold_right2 + (fun pi id chain -> + let* chain = chain in + let* pi = mapM Smt.smt_cval pi in + let pathcond = smt_conj pi in + match chain with Some smt -> return (Some (Ite (pathcond, Var id, smt))) | None -> return (Some (Var id)) + ) + pis ids (return None) + in + let* ctyp = smt_ctyp ctyp in + match mux with None -> assert false | Some mux -> return [Define_const (id, ctyp, mux)] + ) + + (* The pi condition are computed by traversing the dominator tree, + with each node having a pi condition defined as the conjunction of + all guards between it and the start node in the dominator + tree. This is imprecise because we have situations like: + + 1 + / \ + 2 3 + | | + | 4 + | |\ + 5 6 9 + \ / | + 7 10 + | + 8 + + where 8 = match_failure, 1 = start and 10 = return. + 2, 3, 6 and 9 are guards as they come directly after a control flow + split, which always follows a conditional jump. + + Here the path through the dominator tree for the match_failure is + 1->7->8 which contains no guards so the pi condition would be empty. + What we do now is walk backwards (CFG must be acyclic at this point) + until we hit the join point prior to where we require a path + condition. We then take the disjunction of the pi conditions for the + join point's predecessors, so 5 and 6 in this case. Which gives us a + path condition of 2 | (3 & 6) as the dominator chains are 1->2->5 and + 1->3->4->6. + + This should work as any split in the control flow must have been + caused by a conditional jump followed by distinct guards, so each of + the nodes immediately prior to a join point must be dominated by at + least one unique guard. It also explains why the pi conditions are + sufficient to choose outcomes of phi functions above. + + If we hit a guard before a join (such as 9 for return's path + conditional) we just return the pi condition for that guard, i.e. + (3 & 9) for 10. If we reach start then the path condition is simply + true. + *) + let rec get_pathcond n cfg = + let open Jib_ssa in + let get_pi m = + match get_vertex cfg m with + | Some ((ssa_elems, _), _, _) -> + V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems)) + | None -> failwith "Node does not exist" + in + match get_vertex cfg n with + | Some ((_, CF_guard cond), _, _) -> Smt.smt_cval (get_pi n) + | Some (_, preds, succs) -> + if IntSet.cardinal preds = 0 then return (Bool_lit true) + else if IntSet.cardinal preds = 1 then get_pathcond (IntSet.min_elt preds) cfg + else ( + let pis = List.map get_pi (IntSet.elements preds) in + Smt.smt_cval (V_call (Bor, pis)) + ) + | None -> assert false (* Should never be called for a non-existent node *) + + let add_event state ev smt = + let stack = event_stack state ev in + let* pathcond = get_pathcond state.node state.cfg in + Stack.push (Fn ("and", [pathcond; smt])) stack; + return () + + let add_pathcond_event state ev = + let* pathcond = get_pathcond state.node state.cfg in + Stack.push pathcond (event_stack state ev); + return () + + let define_const id ctyp exp = + let* ty = smt_ctyp ctyp in + return (Define_const (id, ty, exp)) + + let declare_const id ctyp = + let* ty = smt_ctyp ctyp in + return (Declare_const (id, ty)) + + let singleton = Smt_gen.fmap (fun x -> [x]) + + (* For any complex l-expression we need to turn it into a + read-modify-write in the SMT solver. The SSA transform turns CL_id + nodes into CL_rmw (read, write, ctyp) nodes when CL_id is wrapped + in any other l-expression. The read and write must have the same + name but different SSA numbers. + *) + let rec rmw_write = function + | CL_rmw (_, write, ctyp) -> (write, ctyp) + | CL_id _ -> assert false + | CL_tuple (clexp, _) -> rmw_write clexp + | CL_field (clexp, _) -> rmw_write clexp + | clexp -> failwith "Could not understand l-expression" + + let rmw_read = function CL_rmw (read, _, _) -> read | _ -> assert false + + let rmw_modify smt = function + | CL_tuple (clexp, n) -> + let ctyp = clexp_ctyp clexp in + begin + match ctyp with + | CT_tup ctyps -> + let len = List.length ctyps in + let set_tup i = if i == n then smt else Fn (Printf.sprintf "tup_%d_%d" len i, [Var (rmw_read clexp)]) in + Fn ("tup" ^ string_of_int len, List.init len set_tup) + | _ -> failwith "Tuple modify does not have tuple type" end - | V_struct (fields, ctyp) -> begin + | CL_field (clexp, field) -> + let ctyp = clexp_ctyp clexp in + begin match ctyp with - | CT_struct (struct_id, field_ctyps) -> - let set_field (field, cval) = - match Util.assoc_compare_opt Id.compare field field_ctyps with - | None -> failwith "Field type not found" - | Some ctyp -> - ( zencode_upper_id struct_id ^ "_" ^ zencode_id field, - smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval) - ) + | CT_struct (struct_id, fields) -> + let set_field (field', _) = + if Id.compare field field' = 0 then smt else Field (struct_id, field', Var (rmw_read clexp)) in - Struct (zencode_upper_id struct_id, List.map set_field fields) - | _ -> failwith "Struct does not have struct type" + Fn (zencode_upper_id struct_id, List.map set_field fields) + | _ -> failwith "Struct modify does not have struct type" end - | V_tuple_member (frag, len, n) -> - ctx.tuple_sizes := IntSet.add len !(ctx.tuple_sizes); - Fn (Printf.sprintf "tup_%d_%d" len n, [smt_cval ctx frag]) - | cval -> failwith ("Unrecognised cval " ^ string_of_cval cval) - ) - -(**************************************************************************) -(* 1. Generating SMT for Sail builtins *) -(**************************************************************************) - -let builtin_type_error ctx fn cvals = - let args = Util.string_of_list ", " (fun cval -> string_of_ctyp (cval_ctyp cval)) cvals in - function - | Some ret_ctyp -> - let message = Printf.sprintf "%s : (%s) -> %s" fn args (string_of_ctyp ret_ctyp) in - raise (Reporting.err_todo ctx.pragma_l message) - | None -> raise (Reporting.err_todo ctx.pragma_l (Printf.sprintf "%s : (%s)" fn args)) - -(* ***** Basic comparisons: lib/flow.sail ***** *) - -let builtin_int_comparison fn big_int_fn ctx v1 v2 = - match (cval_ctyp v1, cval_ctyp v2) with - | CT_lint, CT_lint -> Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) - | CT_fint sz1, CT_fint sz2 -> - if sz1 == sz2 then Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) - else if sz1 > sz2 then Fn (fn, [smt_cval ctx v1; SignExtend (sz1 - sz2, smt_cval ctx v2)]) - else Fn (fn, [SignExtend (sz2 - sz1, smt_cval ctx v1); smt_cval ctx v2]) - | CT_constant c, CT_fint sz -> Fn (fn, [bvint sz c; smt_cval ctx v2]) - | CT_constant c, CT_lint -> Fn (fn, [bvint ctx.lint_size c; smt_cval ctx v2]) - | CT_fint sz, CT_constant c -> Fn (fn, [smt_cval ctx v1; bvint sz c]) - | CT_fint sz, CT_lint when sz < ctx.lint_size -> - Fn (fn, [SignExtend (ctx.lint_size - sz, smt_cval ctx v1); smt_cval ctx v2]) - | CT_lint, CT_fint sz when sz < ctx.lint_size -> - Fn (fn, [smt_cval ctx v1; SignExtend (ctx.lint_size - sz, smt_cval ctx v2)]) - | CT_lint, CT_constant c -> Fn (fn, [smt_cval ctx v1; bvint ctx.lint_size c]) - | CT_constant c1, CT_constant c2 -> Bool_lit (big_int_fn c1 c2) - | _, _ -> builtin_type_error ctx fn [v1; v2] None - -let builtin_eq_int = builtin_int_comparison "=" Big_int.equal - -let builtin_lt = builtin_int_comparison "bvslt" Big_int.less -let builtin_lteq = builtin_int_comparison "bvsle" Big_int.less_equal -let builtin_gt = builtin_int_comparison "bvsgt" Big_int.greater -let builtin_gteq = builtin_int_comparison "bvsge" Big_int.greater_equal - -(* ***** Arithmetic operations: lib/arith.sail ***** *) - -let int_size ctx = function - | CT_constant n -> required_width n - | CT_fint sz -> sz - | CT_lint -> ctx.lint_size - | _ -> Reporting.unreachable ctx.pragma_l __POS__ "Argument to int_size must be an integer type" - -let builtin_arith fn big_int_fn padding ctx v1 v2 ret_ctyp = - (* To detect arithmetic overflow we can expand the input bitvectors - to some size determined by a padding function, then check we - don't lose precision when going back after performing the - operation. *) - let padding = if !opt_ignore_overflow then fun x -> x else padding in - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | _, _, CT_constant c -> bvint (required_width c) c - | CT_constant c1, CT_constant c2, _ -> bvint (int_size ctx ret_ctyp) (big_int_fn c1 c2) - | ctyp1, ctyp2, _ -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - force_size ctx ret_sz (padding ret_sz) - (Fn - ( fn, - [ - force_size ctx (padding ret_sz) (int_size ctx ctyp1) smt1; - force_size ctx (padding ret_sz) (int_size ctx ctyp2) smt2; - ] - ) - ) + | _ -> assert false -let builtin_add_int = builtin_arith "bvadd" Big_int.add (fun x -> x + 1) -let builtin_sub_int = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) -let builtin_mult_int = builtin_arith "bvmul" Big_int.mul (fun x -> x * 2) - -let builtin_sub_nat ctx v1 v2 ret_ctyp = - let result = builtin_arith "bvsub" Big_int.sub (fun x -> x + 1) ctx v1 v2 ret_ctyp in - Ite - ( Fn ("bvslt", [result; bvint (int_size ctx ret_ctyp) Big_int.zero]), - bvint (int_size ctx ret_ctyp) Big_int.zero, - result - ) - -let builtin_negate_int ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | _, CT_constant c -> bvint (required_width c) c - | CT_constant c, _ -> bvint (int_size ctx ret_ctyp) (Big_int.negate c) - | ctyp, _ -> - let open Sail2_values in - let smt = force_size ctx (int_size ctx ret_ctyp) (int_size ctx ctyp) (smt_cval ctx v) in - overflow_check ctx (Fn ("=", [smt; Bitvec_lit (B1 :: List.init (int_size ctx ret_ctyp - 1) (fun _ -> B0))])); - Fn ("bvneg", [smt]) - -let builtin_shift_int fn big_int_fn ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | _, _, CT_constant c -> bvint (required_width c) c - | CT_constant c1, CT_constant c2, _ -> bvint (int_size ctx ret_ctyp) (big_int_fn c1 (Big_int.to_int c2)) - | ctyp, CT_constant c, _ -> - let n = int_size ctx ctyp in - force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [smt_cval ctx v1; bvint n c])) - | CT_constant c, ctyp, _ -> - let n = int_size ctx ctyp in - force_size ctx (int_size ctx ret_ctyp) n (Fn (fn, [bvint n c; smt_cval ctx v2])) - | ctyp1, ctyp2, _ -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn (fn, [force_size ctx ret_sz (int_size ctx ctyp1) smt1; force_size ctx ret_sz (int_size ctx ctyp2) smt2]) - -let builtin_shl_int = builtin_shift_int "bvshl" Big_int.shift_left -let builtin_shr_int = builtin_shift_int "bvashr" Big_int.shift_right - -let builtin_abs_int ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | _, CT_constant c -> bvint (required_width c) c - | CT_constant c, _ -> bvint (int_size ctx ret_ctyp) (Big_int.abs c) - | ctyp, _ -> - let sz = int_size ctx ctyp in - let smt = smt_cval ctx v in - Ite - ( Fn ("=", [Extract (sz - 1, sz - 1, smt); Bitvec_lit [Sail2_values.B1]]), - force_size ctx (int_size ctx ret_ctyp) sz (Fn ("bvneg", [smt])), - force_size ctx (int_size ctx ret_ctyp) sz smt - ) + let builtin_sqrt_real root v = + let* smt = Smt.smt_cval v in + return + [ + Declare_const (root, Real); + Assert (Fn ("and", [Fn ("=", [smt; Fn ("*", [Var root; Var root])]); Fn (">=", [Var root; Real_lit "0.0"])])); + ] -let builtin_pow2 ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | CT_constant n, _ when Big_int.greater_equal n Big_int.zero -> - bvint (int_size ctx ret_ctyp) (Big_int.pow_int_positive 2 (Big_int.to_int n)) - | _ -> builtin_type_error ctx "pow2" [v] (Some ret_ctyp) - -let builtin_max_int ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2) with - | CT_constant n, CT_constant m -> bvint (int_size ctx ret_ctyp) (max n m) - | ctyp1, ctyp2 -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in - let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in - Ite (Fn ("bvslt", [smt1; smt2]), smt2, smt1) - -let builtin_min_int ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2) with - | CT_constant n, CT_constant m -> bvint (int_size ctx ret_ctyp) (min n m) - | ctyp1, ctyp2 -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in - let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in - Ite (Fn ("bvslt", [smt1; smt2]), smt1, smt2) - -let builtin_min_int ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2) with - | CT_constant n, CT_constant m -> bvint (int_size ctx ret_ctyp) (min n m) - | ctyp1, ctyp2 -> - let ret_sz = int_size ctx ret_ctyp in - let smt1 = force_size ctx ret_sz (int_size ctx ctyp1) (smt_cval ctx v1) in - let smt2 = force_size ctx ret_sz (int_size ctx ctyp2) (smt_cval ctx v2) in - Ite (Fn ("bvslt", [smt1; smt2]), smt1, smt2) - -let builtin_tdiv_int = builtin_arith "bvudiv" Sail2_values.tdiv_int (fun x -> x) - -let builtin_tmod_int = builtin_arith "bvurem" Sail2_values.tmod_int (fun x -> x) - -let bvmask ctx len = - let all_ones = bvones (lbits_size ctx) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); len]) in - bvnot (bvshl all_ones shift) - -let fbits_mask ctx n len = bvnot (bvshl (bvones n) len) - -let builtin_eq_bits ctx v1 v2 = - match (cval_ctyp v1, cval_ctyp v2) with - | CT_fbits n, CT_fbits m -> - let o = max n m in - let smt1 = unsigned_size ctx o n (smt_cval ctx v1) in - let smt2 = unsigned_size ctx o n (smt_cval ctx v2) in - Fn ("=", [smt1; smt2]) - | CT_lbits, CT_lbits -> - let len1 = Fn ("len", [smt_cval ctx v1]) in - let contents1 = Fn ("contents", [smt_cval ctx v1]) in - let len2 = Fn ("len", [smt_cval ctx v2]) in - let contents2 = Fn ("contents", [smt_cval ctx v2]) in - Fn - ( "and", - [ - Fn ("=", [len1; len2]); - Fn ("=", [Fn ("bvand", [bvmask ctx len1; contents1]); Fn ("bvand", [bvmask ctx len2; contents2])]); - ] - ) - | CT_lbits, CT_fbits n -> - let smt1 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v1])) in - Fn ("=", [smt1; smt_cval ctx v2]) - | CT_fbits n, CT_lbits -> - let smt2 = unsigned_size ctx n (lbits_size ctx) (Fn ("contents", [smt_cval ctx v2])) in - Fn ("=", [smt_cval ctx v1; smt2]) - | _ -> builtin_type_error ctx "eq_bits" [v1; v2] None - -let builtin_zeros ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | _, CT_fbits n -> bvzero n - | CT_constant c, CT_lbits -> Fn ("Bits", [bvint ctx.lbits_index c; bvzero (lbits_size ctx)]) - | ctyp, CT_lbits when int_size ctx ctyp >= ctx.lbits_index -> - Fn ("Bits", [extract (ctx.lbits_index - 1) 0 (smt_cval ctx v); bvzero (lbits_size ctx)]) - | _ -> builtin_type_error ctx "zeros" [v] (Some ret_ctyp) - -let builtin_ones ctx cval = function - | CT_fbits n -> bvones n - | CT_lbits -> - let len = extract (ctx.lbits_index - 1) 0 (smt_cval ctx cval) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvones (lbits_size ctx)])]) - | ret_ctyp -> builtin_type_error ctx "ones" [cval] (Some ret_ctyp) - -(* [bvzeint ctx esz cval] (BitVector Zero Extend INTeger), takes a cval - which must be an integer type (either CT_fint, or CT_lint), and - produces a bitvector which is either zero extended or truncated to - exactly esz bits. *) -let bvzeint ctx esz cval = - let sz = int_size ctx (cval_ctyp cval) in - match cval with - | V_lit (VL_int n, _) -> bvint esz n - | _ -> - let smt = smt_cval ctx cval in - if esz = sz then smt else if esz > sz then Fn ("concat", [bvzero (esz - sz); smt]) else Extract (esz - 1, 0, smt) - -let builtin_zero_extend ctx vbits vlen ret_ctyp = - match (cval_ctyp vbits, ret_ctyp) with - | CT_fbits n, CT_fbits m when n = m -> smt_cval ctx vbits - | CT_fbits n, CT_fbits m -> - let bv = smt_cval ctx vbits in - Fn ("concat", [bvzero (m - n); bv]) - | CT_lbits, CT_fbits m -> - assert (lbits_size ctx >= m); - Extract (m - 1, 0, Fn ("contents", [smt_cval ctx vbits])) - | CT_fbits n, CT_lbits -> - assert (lbits_size ctx >= n); - let vbits = - if lbits_size ctx = n then smt_cval ctx vbits - else if lbits_size ctx > n then Fn ("concat", [bvzero (lbits_size ctx - n); smt_cval ctx vbits]) - else assert false - in - Fn ("Bits", [bvzeint ctx ctx.lbits_index vlen; vbits]) - | _ -> builtin_type_error ctx "zero_extend" [vbits; vlen] (Some ret_ctyp) - -let builtin_sign_extend ctx vbits vlen ret_ctyp = - match (cval_ctyp vbits, ret_ctyp) with - | CT_fbits n, CT_fbits m when n = m -> smt_cval ctx vbits - | CT_fbits n, CT_fbits m -> - let bv = smt_cval ctx vbits in - let top_bit_one = Fn ("=", [Extract (n - 1, n - 1, bv); Bitvec_lit [Sail2_values.B1]]) in - Ite (top_bit_one, Fn ("concat", [bvones (m - n); bv]), Fn ("concat", [bvzero (m - n); bv])) - | _ -> builtin_type_error ctx "sign_extend" [vbits; vlen] (Some ret_ctyp) - -let builtin_shift shiftop ctx vbits vshift ret_ctyp = - match cval_ctyp vbits with - | CT_fbits n -> - let bv = smt_cval ctx vbits in - let len = bvzeint ctx n vshift in - Fn (shiftop, [bv; len]) - | CT_lbits -> - let bv = smt_cval ctx vbits in - let shift = bvzeint ctx (lbits_size ctx) vshift in - Fn ("Bits", [Fn ("len", [bv]); Fn (shiftop, [Fn ("contents", [bv]); shift])]) - | _ -> builtin_type_error ctx shiftop [vbits; vshift] (Some ret_ctyp) - -let builtin_not_bits ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | CT_lbits, CT_fbits n -> bvnot (Extract (n - 1, 0, Fn ("contents", [smt_cval ctx v]))) - | CT_lbits, CT_lbits -> - let bv = smt_cval ctx v in - let len = Fn ("len", [bv]) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; bvnot (Fn ("contents", [bv]))])]) - | CT_fbits n, CT_fbits m when n = m -> bvnot (smt_cval ctx v) - | _, _ -> builtin_type_error ctx "not_bits" [v] (Some ret_ctyp) - -let builtin_bitwise fn ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_fbits m, CT_fbits o -> - assert (n = m && m = o); - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn (fn, [smt1; smt2]) - | CT_lbits, CT_lbits, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("Bits", [Fn ("len", [smt1]); Fn (fn, [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) - | _ -> builtin_type_error ctx fn [v1; v2] (Some ret_ctyp) - -let builtin_and_bits = builtin_bitwise "bvand" -let builtin_or_bits = builtin_bitwise "bvor" -let builtin_xor_bits = builtin_bitwise "bvxor" - -let builtin_append ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_fbits m, CT_fbits o -> - assert (n + m = o); - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("concat", [smt1; smt2]) - | CT_fbits n, CT_lbits, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - let x = Fn ("concat", [bvzero (lbits_size ctx - n); smt1]) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in - Fn - ( "Bits", - [ - bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt2])); - bvor (bvshl x shift) (Fn ("contents", [smt2])); - ] - ) - | CT_lbits, CT_fbits n, CT_fbits m -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Extract (m - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])) - | CT_lbits, CT_fbits n, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn - ( "Bits", - [ - bvadd (bvint ctx.lbits_index (Big_int.of_int n)) (Fn ("len", [smt1])); - Extract (lbits_size ctx - 1, 0, Fn ("concat", [Fn ("contents", [smt1]); smt2])); - ] - ) - | CT_fbits n, CT_fbits m, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn - ( "Bits", - [ - bvint ctx.lbits_index (Big_int.of_int (n + m)); - unsigned_size ctx (lbits_size ctx) (n + m) (Fn ("concat", [smt1; smt2])); - ] - ) - | CT_lbits, CT_lbits, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - let x = Fn ("contents", [smt1]) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in - Fn ("Bits", [bvadd (Fn ("len", [smt1])) (Fn ("len", [smt2])); bvor (bvshl x shift) (Fn ("contents", [smt2]))]) - | CT_lbits, CT_lbits, CT_fbits n -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - let x = Fn ("contents", [smt1]) in - let shift = Fn ("concat", [bvzero (lbits_size ctx - ctx.lbits_index); Fn ("len", [smt2])]) in - unsigned_size ctx n (lbits_size ctx) (bvor (bvshl x shift) (Fn ("contents", [smt2]))) - | _ -> builtin_type_error ctx "append" [v1; v2] (Some ret_ctyp) - -let builtin_length ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | CT_fbits n, (CT_constant _ | CT_fint _ | CT_lint) -> bvint (int_size ctx ret_ctyp) (Big_int.of_int n) - | CT_lbits, (CT_constant _ | CT_fint _ | CT_lint) -> - let sz = ctx.lbits_index in - let m = int_size ctx ret_ctyp in - let len = Fn ("len", [smt_cval ctx v]) in - if m = sz then len else if m > sz then Fn ("concat", [bvzero (m - sz); len]) else Extract (m - 1, 0, len) - | _, _ -> builtin_type_error ctx "length" [v] (Some ret_ctyp) - -let builtin_vector_subrange ctx vec i j ret_ctyp = - match (cval_ctyp vec, cval_ctyp i, cval_ctyp j, ret_ctyp) with - | CT_fbits n, CT_constant i, CT_constant j, CT_fbits _ -> - Extract (Big_int.to_int i, Big_int.to_int j, smt_cval ctx vec) - | CT_lbits, CT_constant i, CT_constant j, CT_fbits _ -> - Extract (Big_int.to_int i, Big_int.to_int j, Fn ("contents", [smt_cval ctx vec])) - | CT_fbits n, i_ctyp, CT_constant j, CT_lbits when Big_int.equal j Big_int.zero -> - let i' = force_size ~checked:false ctx ctx.lbits_index (int_size ctx i_ctyp) (smt_cval ctx i) in - let len = bvadd i' (bvint ctx.lbits_index (Big_int.of_int 1)) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; unsigned_size ctx (lbits_size ctx) n (smt_cval ctx vec)])]) - | CT_fbits n, i_ctyp, j_ctyp, ret_ctyp -> - let i' = force_size ctx n (int_size ctx i_ctyp) (smt_cval ctx i) in - let j' = force_size ctx n (int_size ctx j_ctyp) (smt_cval ctx j) in - let len = bvadd (bvadd i' (bvneg j')) (bvint n (Big_int.of_int 1)) in - let vec' = bvand (bvlshr (smt_cval ctx vec) j') (fbits_mask ctx n len) in - smt_conversion ctx (CT_fbits n) ret_ctyp vec' - | _ -> builtin_type_error ctx "vector_subrange" [vec; i; j] (Some ret_ctyp) - -let builtin_vector_access ctx vec i ret_ctyp = - match (cval_ctyp vec, cval_ctyp i, ret_ctyp) with - | CT_fbits n, CT_constant i, CT_bit -> Extract (Big_int.to_int i, Big_int.to_int i, smt_cval ctx vec) - | CT_lbits, CT_constant i, CT_bit -> Extract (Big_int.to_int i, Big_int.to_int i, Fn ("contents", [smt_cval ctx vec])) - | CT_lbits, i_ctyp, CT_bit -> - let shift = force_size ~checked:false ctx (lbits_size ctx) (int_size ctx i_ctyp) (smt_cval ctx i) in - Extract (0, 0, Fn ("bvlshr", [Fn ("contents", [smt_cval ctx vec]); shift])) - | CT_vector _, CT_constant i, _ -> Fn ("select", [smt_cval ctx vec; bvint !vector_index i]) - | CT_vector _, index_ctyp, _ -> - Fn ("select", [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i)]) - | _ -> builtin_type_error ctx "vector_access" [vec; i] (Some ret_ctyp) - -let builtin_vector_update ctx vec i x ret_ctyp = - match (cval_ctyp vec, cval_ctyp i, cval_ctyp x, ret_ctyp) with - | CT_fbits n, CT_constant i, CT_bit, CT_fbits m when n - 1 > Big_int.to_int i && Big_int.to_int i > 0 -> - assert (n = m); - let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in - let bot = Extract (Big_int.to_int i - 1, 0, smt_cval ctx vec) in - Fn ("concat", [top; Fn ("concat", [smt_cval ctx x; bot])]) - | CT_fbits n, CT_constant i, CT_bit, CT_fbits m when n - 1 = Big_int.to_int i && Big_int.to_int i > 0 -> - let bot = Extract (Big_int.to_int i - 1, 0, smt_cval ctx vec) in - Fn ("concat", [smt_cval ctx x; bot]) - | CT_fbits n, CT_constant i, CT_bit, CT_fbits m when n - 1 > Big_int.to_int i && Big_int.to_int i = 0 -> - let top = Extract (n - 1, 1, smt_cval ctx vec) in - Fn ("concat", [top; smt_cval ctx x]) - | CT_fbits n, CT_constant i, CT_bit, CT_fbits m when n - 1 = 0 && Big_int.to_int i = 0 -> smt_cval ctx x - | CT_vector _, CT_constant i, ctyp, CT_vector _ -> - Fn ("store", [smt_cval ctx vec; bvint !vector_index i; smt_cval ctx x]) - | CT_vector _, index_ctyp, _, CT_vector _ -> - Fn - ( "store", - [smt_cval ctx vec; force_size ctx !vector_index (int_size ctx index_ctyp) (smt_cval ctx i); smt_cval ctx x] - ) - | _ -> builtin_type_error ctx "vector_update" [vec; i; x] (Some ret_ctyp) - -let builtin_vector_update_subrange ctx vec i j x ret_ctyp = - match (cval_ctyp vec, cval_ctyp i, cval_ctyp j, cval_ctyp x, ret_ctyp) with - | CT_fbits n, CT_constant i, CT_constant j, CT_fbits sz, CT_fbits m - when n - 1 > Big_int.to_int i && Big_int.to_int j > 0 -> - assert (n = m); - let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in - let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in - Fn ("concat", [top; Fn ("concat", [smt_cval ctx x; bot])]) - | CT_fbits n, CT_constant i, CT_constant j, CT_fbits sz, CT_fbits m - when n - 1 = Big_int.to_int i && Big_int.to_int j > 0 -> - assert (n = m); - let bot = Extract (Big_int.to_int j - 1, 0, smt_cval ctx vec) in - Fn ("concat", [smt_cval ctx x; bot]) - | CT_fbits n, CT_constant i, CT_constant j, CT_fbits sz, CT_fbits m - when n - 1 > Big_int.to_int i && Big_int.to_int j = 0 -> - assert (n = m); - let top = Extract (n - 1, Big_int.to_int i + 1, smt_cval ctx vec) in - Fn ("concat", [top; smt_cval ctx x]) - | CT_fbits n, CT_constant i, CT_constant j, CT_fbits sz, CT_fbits m - when n - 1 = Big_int.to_int i && Big_int.to_int j = 0 -> - smt_cval ctx x - | CT_fbits n, ctyp_i, ctyp_j, ctyp_x, CT_fbits m -> - assert (n = m); - let i' = force_size ctx n (int_size ctx ctyp_i) (smt_cval ctx i) in - let j' = force_size ctx n (int_size ctx ctyp_j) (smt_cval ctx j) in - let x' = smt_conversion ctx ctyp_x (CT_fbits n) (smt_cval ctx x) in - let len = bvadd (bvadd i' (bvneg j')) (bvint n (Big_int.of_int 1)) in - let mask = bvshl (fbits_mask ctx n len) j' in - bvor (bvand (smt_cval ctx vec) (bvnot mask)) (bvand (bvshl x' j') mask) - | _ -> builtin_type_error ctx "vector_update_subrange" [vec; i; j; x] (Some ret_ctyp) - -let builtin_unsigned ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | CT_fbits n, CT_fint m when m > n -> - let smt = smt_cval ctx v in - Fn ("concat", [bvzero (m - n); smt]) - | CT_fbits n, CT_lint -> - if n >= ctx.lint_size then failwith "Overflow detected" - else ( - let smt = smt_cval ctx v in - Fn ("concat", [bvzero (ctx.lint_size - n); smt]) - ) - | CT_lbits, CT_lint -> Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) - | CT_lbits, CT_fint m -> - let smt = Fn ("contents", [smt_cval ctx v]) in - force_size ctx m (lbits_size ctx) smt - | ctyp, _ -> builtin_type_error ctx "unsigned" [v] (Some ret_ctyp) - -let builtin_signed ctx v ret_ctyp = - match (cval_ctyp v, ret_ctyp) with - | CT_fbits n, CT_fint m when m >= n -> SignExtend (m - n, smt_cval ctx v) - | CT_fbits n, CT_lint -> SignExtend (ctx.lint_size - n, smt_cval ctx v) - | CT_lbits, CT_lint -> Extract (ctx.lint_size - 1, 0, Fn ("contents", [smt_cval ctx v])) - | ctyp, _ -> builtin_type_error ctx "signed" [v] (Some ret_ctyp) - -let builtin_add_bits ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_fbits m, CT_fbits o -> - assert (n = m && m = o); - Fn ("bvadd", [smt_cval ctx v1; smt_cval ctx v2]) - | CT_lbits, CT_lbits, CT_lbits -> - let smt1 = smt_cval ctx v1 in - let smt2 = smt_cval ctx v2 in - Fn ("Bits", [Fn ("len", [smt1]); Fn ("bvadd", [Fn ("contents", [smt1]); Fn ("contents", [smt2])])]) - | _ -> builtin_type_error ctx "add_bits" [v1; v2] (Some ret_ctyp) - -let builtin_sub_bits ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_fbits m, CT_fbits o -> - assert (n = m && m = o); - Fn ("bvadd", [smt_cval ctx v1; Fn ("bvneg", [smt_cval ctx v2])]) - | _ -> failwith "Cannot compile sub_bits" - -let builtin_add_bits_int ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_constant c, CT_fbits o when n = o -> Fn ("bvadd", [smt_cval ctx v1; bvint o c]) - | CT_fbits n, CT_fint m, CT_fbits o when n = o -> Fn ("bvadd", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) - | CT_fbits n, CT_lint, CT_fbits o when n = o -> - Fn ("bvadd", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) - | CT_lbits, CT_fint n, CT_lbits when n < lbits_size ctx -> - let smt1 = smt_cval ctx v1 in - let smt2 = force_size ctx (lbits_size ctx) n (smt_cval ctx v2) in - Fn ("Bits", [Fn ("len", [smt1]); Fn ("bvadd", [Fn ("contents", [smt1]); smt2])]) - | _ -> builtin_type_error ctx "add_bits_int" [v1; v2] (Some ret_ctyp) - -let builtin_sub_bits_int ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_constant c, CT_fbits o when n = o -> Fn ("bvadd", [smt_cval ctx v1; bvint o (Big_int.negate c)]) - | CT_fbits n, CT_fint m, CT_fbits o when n = o -> Fn ("bvsub", [smt_cval ctx v1; force_size ctx o m (smt_cval ctx v2)]) - | CT_fbits n, CT_lint, CT_fbits o when n = o -> - Fn ("bvsub", [smt_cval ctx v1; force_size ctx o ctx.lint_size (smt_cval ctx v2)]) - | _ -> builtin_type_error ctx "sub_bits_int" [v1; v2] (Some ret_ctyp) - -let builtin_replicate_bits ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_constant c, CT_fbits m -> - assert (n * Big_int.to_int c = m); - let smt = smt_cval ctx v1 in - Fn ("concat", List.init (Big_int.to_int c) (fun _ -> smt)) - | CT_fbits n, _, CT_fbits m -> - let smt = smt_cval ctx v1 in - let c = m / n in - Fn ("concat", List.init c (fun _ -> smt)) - | CT_fbits n, v2_ctyp, CT_lbits -> - let times = (lbits_size ctx / n) + 1 in - let len = force_size ~checked:false ctx ctx.lbits_index (int_size ctx v2_ctyp) (smt_cval ctx v2) in - let smt1 = smt_cval ctx v1 in - let contents = Extract (lbits_size ctx - 1, 0, Fn ("concat", List.init times (fun _ -> smt1))) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) - | _ -> builtin_type_error ctx "replicate_bits" [v1; v2] (Some ret_ctyp) - -let builtin_sail_truncate ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_constant c, CT_fbits m -> - assert (Big_int.to_int c = m); - Extract (Big_int.to_int c - 1, 0, smt_cval ctx v1) - | CT_lbits, CT_constant c, CT_fbits m -> - assert (Big_int.to_int c = m && m < lbits_size ctx); - Extract (Big_int.to_int c - 1, 0, Fn ("contents", [smt_cval ctx v1])) - | CT_fbits n, _, CT_lbits -> - let smt1 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v1) in - let smt2 = bvzeint ctx ctx.lbits_index v2 in - Fn ("Bits", [smt2; Fn ("bvand", [bvmask ctx smt2; smt1])]) - | _ -> builtin_type_error ctx "sail_truncate" [v1; v2] (Some ret_ctyp) - -let builtin_sail_truncateLSB ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, ret_ctyp) with - | CT_fbits n, CT_constant c, CT_fbits m -> - assert (Big_int.to_int c = m); - Extract (n - 1, n - Big_int.to_int c, smt_cval ctx v1) - | _ -> builtin_type_error ctx "sail_truncateLSB" [v1; v2] (Some ret_ctyp) - -let builtin_slice ctx v1 v2 v3 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, ret_ctyp) with - | CT_lbits, CT_constant start, CT_constant len, CT_fbits _ -> - let top = Big_int.pred (Big_int.add start len) in - Extract (Big_int.to_int top, Big_int.to_int start, Fn ("contents", [smt_cval ctx v1])) - | CT_fbits _, CT_constant start, CT_constant len, CT_fbits _ -> - let top = Big_int.pred (Big_int.add start len) in - Extract (Big_int.to_int top, Big_int.to_int start, smt_cval ctx v1) - | CT_fbits _, CT_fint _, CT_constant len, CT_fbits _ -> - Extract (Big_int.to_int (Big_int.pred len), 0, builtin_shift "bvlshr" ctx v1 v2 (cval_ctyp v1)) - | CT_fbits n, ctyp2, _, CT_lbits -> - let smt1 = force_size ctx (lbits_size ctx) n (smt_cval ctx v1) in - let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in - let smt3 = bvzeint ctx ctx.lbits_index v3 in - Fn ("Bits", [smt3; Fn ("bvand", [Fn ("bvlshr", [smt1; smt2]); bvmask ctx smt3])]) - | _ -> builtin_type_error ctx "slice" [v1; v2; v3] (Some ret_ctyp) - -let builtin_get_slice_int ctx v1 v2 v3 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, ret_ctyp) with - | CT_constant len, ctyp, CT_constant start, CT_fbits ret_sz -> - let len = Big_int.to_int len in - let start = Big_int.to_int start in - let in_sz = int_size ctx ctyp in - let smt = if in_sz < len + start then force_size ctx (len + start) in_sz (smt_cval ctx v2) else smt_cval ctx v2 in - Extract (start + len - 1, start, smt) - | CT_lint, CT_lint, CT_constant start, CT_lbits when Big_int.equal start Big_int.zero -> - let len = Extract (ctx.lbits_index - 1, 0, smt_cval ctx v1) in - let contents = unsigned_size ~checked:false ctx (lbits_size ctx) ctx.lint_size (smt_cval ctx v2) in - Fn ("Bits", [len; Fn ("bvand", [bvmask ctx len; contents])]) - | CT_lint, ctyp2, ctyp3, ret_ctyp -> - let len = Extract (ctx.lbits_index - 1, 0, smt_cval ctx v1) in - let smt2 = force_size ctx (lbits_size ctx) (int_size ctx ctyp2) (smt_cval ctx v2) in - let smt3 = force_size ctx (lbits_size ctx) (int_size ctx ctyp3) (smt_cval ctx v3) in - let result = bvand (bvmask ctx len) (bvlshr smt2 smt3) in - smt_conversion ctx CT_lint ret_ctyp result - | _ -> builtin_type_error ctx "get_slice_int" [v1; v2; v3] (Some ret_ctyp) - -let builtin_count_leading_zeros ctx v ret_ctyp = - let ret_sz = int_size ctx ret_ctyp in - let rec lzcnt sz smt = - if sz == 1 then - Ite - ( Fn ("=", [Extract (0, 0, smt); Bitvec_lit [Sail2_values.B0]]), - bvint ret_sz (Big_int.of_int 1), - bvint ret_sz Big_int.zero - ) - else ( - assert (sz land (sz - 1) = 0); - let hsz = sz / 2 in - Ite - ( Fn ("=", [Extract (sz - 1, hsz, smt); bvzero hsz]), - Fn ("bvadd", [bvint ret_sz (Big_int.of_int hsz); lzcnt hsz (Extract (hsz - 1, 0, smt))]), - lzcnt hsz (Extract (sz - 1, hsz, smt)) - ) - ) - in - let smallest_greater_power_of_two n = - let m = ref 1 in - while !m < n do - m := !m lsl 1 - done; - assert (!m land (!m - 1) = 0); - !m - in - match cval_ctyp v with - | CT_fbits sz when sz land (sz - 1) = 0 -> lzcnt sz (smt_cval ctx v) - | CT_fbits sz -> - let padded_sz = smallest_greater_power_of_two sz in - let padding = bvzero (padded_sz - sz) in - Fn - ( "bvsub", - [lzcnt padded_sz (Fn ("concat", [padding; smt_cval ctx v])); bvint ret_sz (Big_int.of_int (padded_sz - sz))] - ) - | CT_lbits -> - let smt = smt_cval ctx v in - Fn - ( "bvsub", - [ - lzcnt (lbits_size ctx) (Fn ("contents", [smt])); - Fn - ( "bvsub", - [ - bvint ret_sz (Big_int.of_int (lbits_size ctx)); - Fn ("concat", [bvzero (ret_sz - ctx.lbits_index); Fn ("len", [smt])]); - ] - ); - ] + (* For a basic block (contained in a control-flow node / cfnode), we + turn the instructions into a sequence of define-const and + declare-const expressions. Because we are working with a SSA graph, + each variable is guaranteed to only be declared once. + *) + let smt_instr state ctx (I_aux (aux, (_, l)) as instr) = + let open Type_check in + match aux with + | I_funcall (CR_one (CL_id (id, ret_ctyp)), extern, (function_id, _), args) -> + if ctx_is_extern function_id ctx then ( + let name = ctx_get_extern function_id ctx in + if name = "sail_assert" then ( + match args with + | [assertion; _] -> + let* smt = Smt.smt_cval assertion in + let* _ = add_event state Assertion (Fn ("not", [smt])) in + return [] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for assertion" + ) + else if name = "sail_assume" then ( + match args with + | [assumption] -> + let* smt = Smt.smt_cval assumption in + let* _ = add_event state Assumption smt in + return [] + | _ -> Reporting.unreachable l __POS__ "Bad arguments for assertion" + ) + else if name = "sqrt_real" then ( + match args with + | [v] -> builtin_sqrt_real id v + | _ -> Reporting.unreachable l __POS__ "Bad arguments for sqrt_real" + ) + else ( + match Smt.builtin ~allow_io:false name with + | Some generator -> + let* value = generator args ret_ctyp in + singleton (define_const id ret_ctyp value) + | None -> failwith ("No generator " ^ string_of_id function_id) + ) ) - | _ -> builtin_type_error ctx "count_leading_zeros" [v] (Some ret_ctyp) - -let builtin_set_slice_bits ctx v1 v2 v3 v4 v5 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2, cval_ctyp v3, cval_ctyp v4, cval_ctyp v5, ret_ctyp) with - | CT_constant n', CT_constant m', CT_fbits n, CT_constant pos, CT_fbits m, CT_fbits n'' - when Big_int.to_int m' = m && Big_int.to_int n' = n && n'' = n && Big_int.less_equal (Big_int.add pos m') n' -> - let pos = Big_int.to_int pos in - if pos = 0 then ( - let mask = Fn ("concat", [bvones (n - m); bvzero m]) in - let smt5 = Fn ("concat", [bvzero (n - m); smt_cval ctx v5]) in - Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) - ) - else if n - m - pos = 0 then ( - let mask = Fn ("concat", [bvzero m; bvones pos]) in - let smt5 = Fn ("concat", [smt_cval ctx v5; bvzero pos]) in - Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) + else if extern && string_of_id function_id = "internal_vector_init" then singleton (declare_const id ret_ctyp) + else if extern && string_of_id function_id = "internal_vector_update" then begin + match args with + | [vec; i; x] -> + let sz = required_width (Big_int.of_int (Smt.generic_vector_length (cval_ctyp vec) - 1)) in + let* vec = Smt.smt_cval vec in + let* i = + Smt_gen.bind (Smt.smt_cval i) (Smt_gen.signed_size ~into:sz ~from:(Smt.int_size (cval_ctyp i))) + in + let* x = Smt.smt_cval x in + singleton (define_const id ret_ctyp (Fn ("store", [vec; i; x]))) + | _ -> Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" + end + else if not extern then + let* smt_args = mapM Smt.smt_cval args in + singleton (define_const id ret_ctyp (Fn (zencode_id function_id, smt_args))) + else failwith ("Unrecognised function " ^ string_of_id function_id) + | I_init (ctyp, id, cval) | I_copy (CL_id (id, ctyp), cval) -> + let* cval_smt = Smt.smt_cval cval in + let* converted_smt = Smt.smt_conversion ~into:ctyp ~from:(cval_ctyp cval) cval_smt in + singleton (define_const id ctyp converted_smt) + | I_copy (clexp, cval) -> + let* smt = Smt.smt_cval cval in + let write, ctyp = rmw_write clexp in + singleton (define_const write ctyp (rmw_modify smt clexp)) + | I_decl (ctyp, id) -> begin + begin + match l with Unique (n, _) -> Stack.push (n, zencode_name id) state.arg_stack | _ -> () + end; + let* ty = smt_ctyp ctyp in + let wf_pred = wf_smt_ctyp ctyp in + match wf_pred with + | Some p -> return [Declare_const (id, ty); Assert (p (Var id))] + | None -> return [Declare_const (id, ty)] + end + | I_clear _ -> return [] + (* Should only appear as terminators for basic blocks. *) + | I_jump _ | I_goto _ | I_end _ | I_exit _ | I_undefined _ -> + Reporting.unreachable l __POS__ "SMT: Instruction should only appear as block terminator" + | _ -> Reporting.unreachable l __POS__ (string_of_instr instr) + + let generate_reg_decs inits cdefs = + let rec go acc = function + | CDEF_aux (CDEF_register (id, ctyp, _), _) :: cdefs when not (NameMap.mem (Name (id, 0)) inits) -> + let* smt_typ = smt_ctyp ctyp in + go (Declare_const (Name (id, 0), smt_typ) :: acc) cdefs + | _ :: cdefs -> go acc cdefs + | [] -> return (List.rev acc) + in + go [] cdefs + + let smt_terminator ctx state = + let open Jib_ssa in + function + | T_end id -> add_event state Return (Var id) + | T_exit _ -> add_pathcond_event state Match + | T_undefined _ | T_goto _ | T_jump _ | T_label _ | T_none -> return () + + let smt_cfnode all_cdefs ctx state = + let open Jib_ssa in + function + | CF_start inits -> + let* smt_reg_decs = generate_reg_decs inits all_cdefs in + let smt_start (id, ctyp) = + match id with Have_exception _ -> define_const id ctyp (Bool_lit false) | _ -> declare_const id ctyp + in + let* smt_inits = mapM smt_start (NameMap.bindings inits) in + return (smt_reg_decs @ smt_inits) + | CF_block (instrs, terminator) -> + let* smt_instrs = mapM (smt_instr state ctx) instrs in + let* _ = smt_terminator ctx state terminator in + return (List.concat smt_instrs) + (* We can ignore any non basic-block/start control-flow nodes *) + | _ -> return [] + + let smt_ctype_def = function + | CTD_enum (id, elems) -> return (declare_datatypes (mk_enum (zencode_upper_id id) (List.map zencode_id elems))) + | CTD_struct (id, fields) -> + let* fields = + mapM + (fun (field, ctyp) -> + let* smt_typ = smt_ctyp ctyp in + return (zencode_upper_id id ^ "_" ^ zencode_id field, smt_typ) + ) + fields + in + return (declare_datatypes (mk_record (zencode_upper_id id) fields)) + | CTD_variant (id, ctors) -> + let* ctors = + mapM + (fun (ctor, ctyp) -> + let* smt_typ = smt_ctyp ctyp in + return (zencode_id ctor, smt_typ) + ) + ctors + in + return (declare_datatypes (mk_variant (zencode_upper_id id) ctors)) + + let rec generate_ctype_defs acc = function + | CDEF_aux (CDEF_type ctd, _) :: cdefs -> + let* smt_type_def = smt_ctype_def ctd in + generate_ctype_defs (smt_type_def :: acc) cdefs + | _ :: cdefs -> generate_ctype_defs acc cdefs + | [] -> return (List.rev acc) + + (* [smt_header ctx cdefs] produces a list of smt definitions for all + the datatypes in a specification *) + let smt_header cdefs = + let* smt_type_defs = generate_ctype_defs [] cdefs in + return + ([declare_datatypes (mk_enum "Unit" ["unit"])] + @ [ + declare_datatypes + (mk_record "Bits" + [("len", Bitvec lbits_index_width); ("contents", Bitvec Config.max_unknown_bitvector_width)] + ); + ] + @ smt_type_defs ) - else ( - let mask = Fn ("concat", [bvones (n - m - pos); Fn ("concat", [bvzero m; bvones pos])]) in - let smt5 = Fn ("concat", [bvzero (n - m - pos); Fn ("concat", [smt_cval ctx v5; bvzero pos])]) in - Fn ("bvor", [Fn ("bvand", [smt_cval ctx v3; mask]); smt5]) + + let dump_graph name cfg = + let gv_file = name ^ ".gv" in + prerr_endline Util.("Dumping graph: " ^ gv_file |> bold |> yellow |> clear); + let out_chan = open_out gv_file in + Jib_ssa.make_dot out_chan cfg; + close_out out_chan + + let push_smt_defs stack smt_defs = List.iter (fun def -> Stack.push def stack) smt_defs + + let smt_instr_list debug_attr name ctx all_cdefs instrs = + let stack = Stack.create () in + + let open Jib_ssa in + let start, cfg = ssa ?debug_prefix:(Option.map (fun _ -> name) debug_attr) instrs in + let visit_order = + try topsort cfg + with Not_a_DAG n -> + dump_graph name cfg; + raise + (Reporting.err_general Parse_ast.Unknown + (Printf.sprintf "%s: control flow graph is not acyclic (node %d is in cycle)\nWrote graph to %s.gv" name n + name + ) + ) + in + if Option.is_some debug_attr then dump_graph name cfg; + + let state = { events = ref EventMap.empty; cfg; node = -1; arg_stack = Stack.create () } in + + List.iter + (fun n -> + match get_vertex cfg n with + | None -> () + | Some ((ssa_elems, cfnode), preds, succs) -> + let pathcond, checks = + Smt_gen.run + (let* muxers = Smt_gen.fmap List.concat (mapM (smt_ssanode cfg preds) ssa_elems) in + let state = { state with node = n } in + let* basic_block = smt_cfnode all_cdefs ctx state cfnode in + push_smt_defs stack muxers; + push_smt_defs stack basic_block; + get_pathcond state.node state.cfg + ) + Parse_ast.Unknown + in + if not Config.ignore_overflow then ( + let overflow_stack = event_stack state Overflow in + List.iter + (fun overflow_smt -> Stack.push (Fn ("and", [pathcond; overflow_smt])) overflow_stack) + (Smt_gen.get_overflows checks) + ) ) - (* set_slice_bits(len, slen, x, pos, y) = - let mask = slice_mask(len, pos, slen) in - (x AND NOT(mask)) OR ((unsigned_size(len, y) << pos) AND mask) *) - | CT_constant n', _, CT_fbits n, _, CT_lbits, CT_fbits n'' when Big_int.to_int n' = n && n'' = n -> - let pos = bvzeint ctx (lbits_size ctx) v4 in - let slen = bvzeint ctx ctx.lbits_index v2 in - let mask = Fn ("bvshl", [bvmask ctx slen; pos]) in - let smt3 = unsigned_size ctx (lbits_size ctx) n (smt_cval ctx v3) in - let smt3' = Fn ("bvand", [smt3; Fn ("bvnot", [mask])]) in - let smt5 = Fn ("contents", [smt_cval ctx v5]) in - let smt5' = Fn ("bvand", [Fn ("bvshl", [smt5; pos]); mask]) in - Extract (n - 1, 0, Fn ("bvor", [smt3'; smt5'])) - | _ -> builtin_type_error ctx "set_slice" [v1; v2; v3; v4; v5] (Some ret_ctyp) - -let builtin_compare_bits fn ctx v1 v2 ret_ctyp = - match (cval_ctyp v1, cval_ctyp v2) with - | CT_fbits n, CT_fbits m when n = m -> Fn (fn, [smt_cval ctx v1; smt_cval ctx v2]) - | _ -> builtin_type_error ctx fn [v1; v2] (Some ret_ctyp) - -(* ***** String operations: lib/real.sail ***** *) - -let builtin_decimal_string_of_bits ctx v = - begin - match cval_ctyp v with - | CT_fbits n -> Fn ("int.to.str", [Fn ("bv2nat", [smt_cval ctx v])]) - | _ -> builtin_type_error ctx "decimal_string_of_bits" [v] None - end + visit_order; -(* ***** Real number operations: lib/real.sail ***** *) - -let builtin_sqrt_real ctx root v = - ctx.use_real := true; - let smt = smt_cval ctx v in - [ - Declare_const (root, Real); - Assert (Fn ("and", [Fn ("=", [smt; Fn ("*", [Var root; Var root])]); Fn (">=", [Var root; Real_lit "0.0"])])); - ] - -let smt_builtin ctx name args ret_ctyp = - match (name, args, ret_ctyp) with - | "eq_anything", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - (* lib/flow.sail *) - | "eq_bit", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - | "eq_bool", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - | "eq_unit", [v1; v2], CT_bool -> Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - | "eq_int", [v1; v2], CT_bool -> builtin_eq_int ctx v1 v2 - | "not", [v], _ -> Fn ("not", [smt_cval ctx v]) - | "lt", [v1; v2], _ -> builtin_lt ctx v1 v2 - | "lteq", [v1; v2], _ -> builtin_lteq ctx v1 v2 - | "gt", [v1; v2], _ -> builtin_gt ctx v1 v2 - | "gteq", [v1; v2], _ -> builtin_gteq ctx v1 v2 - (* lib/arith.sail *) - | "add_int", [v1; v2], _ -> builtin_add_int ctx v1 v2 ret_ctyp - | "sub_int", [v1; v2], _ -> builtin_sub_int ctx v1 v2 ret_ctyp - | "sub_nat", [v1; v2], _ -> builtin_sub_nat ctx v1 v2 ret_ctyp - | "mult_int", [v1; v2], _ -> builtin_mult_int ctx v1 v2 ret_ctyp - | "neg_int", [v], _ -> builtin_negate_int ctx v ret_ctyp - | "shl_int", [v1; v2], _ -> builtin_shl_int ctx v1 v2 ret_ctyp - | "shr_int", [v1; v2], _ -> builtin_shr_int ctx v1 v2 ret_ctyp - | "shl_mach_int", [v1; v2], _ -> builtin_shl_int ctx v1 v2 ret_ctyp - | "shr_mach_int", [v1; v2], _ -> builtin_shr_int ctx v1 v2 ret_ctyp - | "abs_int", [v], _ -> builtin_abs_int ctx v ret_ctyp - | "pow2", [v], _ -> builtin_pow2 ctx v ret_ctyp - | "max_int", [v1; v2], _ -> builtin_max_int ctx v1 v2 ret_ctyp - | "min_int", [v1; v2], _ -> builtin_min_int ctx v1 v2 ret_ctyp - | "ediv_int", [v1; v2], _ -> builtin_tdiv_int ctx v1 v2 ret_ctyp - (* All signed and unsigned bitvector comparisons *) - | "slt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvslt" ctx v1 v2 ret_ctyp - | "ult_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvult" ctx v1 v2 ret_ctyp - | "sgt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsgt" ctx v1 v2 ret_ctyp - | "ugt_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvugt" ctx v1 v2 ret_ctyp - | "slteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsle" ctx v1 v2 ret_ctyp - | "ulteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvule" ctx v1 v2 ret_ctyp - | "sgteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvsge" ctx v1 v2 ret_ctyp - | "ugteq_bits", [v1; v2], CT_bool -> builtin_compare_bits "bvuge" ctx v1 v2 ret_ctyp - (* lib/vector_dec.sail *) - | "eq_bits", [v1; v2], CT_bool -> builtin_eq_bits ctx v1 v2 - | "zeros", [v], _ -> builtin_zeros ctx v ret_ctyp - | "sail_zeros", [v], _ -> builtin_zeros ctx v ret_ctyp - | "ones", [v], _ -> builtin_ones ctx v ret_ctyp - | "sail_ones", [v], _ -> builtin_ones ctx v ret_ctyp - | "zero_extend", [v1; v2], _ -> builtin_zero_extend ctx v1 v2 ret_ctyp - | "sign_extend", [v1; v2], _ -> builtin_sign_extend ctx v1 v2 ret_ctyp - | "sail_truncate", [v1; v2], _ -> builtin_sail_truncate ctx v1 v2 ret_ctyp - | "sail_truncateLSB", [v1; v2], _ -> builtin_sail_truncateLSB ctx v1 v2 ret_ctyp - | "shiftl", [v1; v2], _ -> builtin_shift "bvshl" ctx v1 v2 ret_ctyp - | "shiftr", [v1; v2], _ -> builtin_shift "bvlshr" ctx v1 v2 ret_ctyp - | "arith_shiftr", [v1; v2], _ -> builtin_shift "bvashr" ctx v1 v2 ret_ctyp - | "and_bits", [v1; v2], _ -> builtin_and_bits ctx v1 v2 ret_ctyp - | "or_bits", [v1; v2], _ -> builtin_or_bits ctx v1 v2 ret_ctyp - | "xor_bits", [v1; v2], _ -> builtin_xor_bits ctx v1 v2 ret_ctyp - | "not_bits", [v], _ -> builtin_not_bits ctx v ret_ctyp - | "add_bits", [v1; v2], _ -> builtin_add_bits ctx v1 v2 ret_ctyp - | "add_bits_int", [v1; v2], _ -> builtin_add_bits_int ctx v1 v2 ret_ctyp - | "sub_bits", [v1; v2], _ -> builtin_sub_bits ctx v1 v2 ret_ctyp - | "sub_bits_int", [v1; v2], _ -> builtin_sub_bits_int ctx v1 v2 ret_ctyp - | "append", [v1; v2], _ -> builtin_append ctx v1 v2 ret_ctyp - | "length", [v], ret_ctyp -> builtin_length ctx v ret_ctyp - | "vector_access", [v1; v2], ret_ctyp -> builtin_vector_access ctx v1 v2 ret_ctyp - | "vector_subrange", [v1; v2; v3], ret_ctyp -> builtin_vector_subrange ctx v1 v2 v3 ret_ctyp - | "vector_update", [v1; v2; v3], ret_ctyp -> builtin_vector_update ctx v1 v2 v3 ret_ctyp - | "vector_update_subrange", [v1; v2; v3; v4], ret_ctyp -> builtin_vector_update_subrange ctx v1 v2 v3 v4 ret_ctyp - | "sail_unsigned", [v], ret_ctyp -> builtin_unsigned ctx v ret_ctyp - | "sail_signed", [v], ret_ctyp -> builtin_signed ctx v ret_ctyp - | "replicate_bits", [v1; v2], ret_ctyp -> builtin_replicate_bits ctx v1 v2 ret_ctyp - | "count_leading_zeros", [v], ret_ctyp -> builtin_count_leading_zeros ctx v ret_ctyp - | "slice", [v1; v2; v3], ret_ctyp -> builtin_slice ctx v1 v2 v3 ret_ctyp - | "get_slice_int", [v1; v2; v3], ret_ctyp -> builtin_get_slice_int ctx v1 v2 v3 ret_ctyp - | "set_slice", [v1; v2; v3; v4; v5], ret_ctyp -> builtin_set_slice_bits ctx v1 v2 v3 v4 v5 ret_ctyp - (* string builtins *) - | "concat_str", [v1; v2], CT_string -> - ctx.use_string := true; - Fn ("str.++", [smt_cval ctx v1; smt_cval ctx v2]) - | "eq_string", [v1; v2], CT_bool -> - ctx.use_string := true; - Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - | "decimal_string_of_bits", [v], CT_string -> - ctx.use_string := true; - builtin_decimal_string_of_bits ctx v - (* lib/real.sail *) - (* Note that sqrt_real is special and is handled by smt_instr. *) - | "eq_real", [v1; v2], CT_bool -> - ctx.use_real := true; - Fn ("=", [smt_cval ctx v1; smt_cval ctx v2]) - | "neg_real", [v], CT_real -> - ctx.use_real := true; - Fn ("-", [smt_cval ctx v]) - | "add_real", [v1; v2], CT_real -> - ctx.use_real := true; - Fn ("+", [smt_cval ctx v1; smt_cval ctx v2]) - | "sub_real", [v1; v2], CT_real -> - ctx.use_real := true; - Fn ("-", [smt_cval ctx v1; smt_cval ctx v2]) - | "mult_real", [v1; v2], CT_real -> - ctx.use_real := true; - Fn ("*", [smt_cval ctx v1; smt_cval ctx v2]) - | "div_real", [v1; v2], CT_real -> - ctx.use_real := true; - Fn ("/", [smt_cval ctx v1; smt_cval ctx v2]) - | "lt_real", [v1; v2], CT_bool -> - ctx.use_real := true; - Fn ("<", [smt_cval ctx v1; smt_cval ctx v2]) - | "gt_real", [v1; v2], CT_bool -> - ctx.use_real := true; - Fn (">", [smt_cval ctx v1; smt_cval ctx v2]) - | "lteq_real", [v1; v2], CT_bool -> - ctx.use_real := true; - Fn ("<=", [smt_cval ctx v1; smt_cval ctx v2]) - | "gteq_real", [v1; v2], CT_bool -> - ctx.use_real := true; - Fn (">=", [smt_cval ctx v1; smt_cval ctx v2]) - | _ -> - Reporting.unreachable ctx.pragma_l __POS__ - ("Unknown builtin " ^ name ^ " " - ^ Util.string_of_list ", " string_of_ctyp (List.map cval_ctyp args) - ^ " -> " ^ string_of_ctyp ret_ctyp - ) + return (stack, state) -let loc_doc _ = "UNKNOWN" - -(* Memory reads and writes as defined in lib/regfp.sail *) -let writes = ref (-1) - -let builtin_write_mem l ctx wk addr_size addr data_size data = - incr writes; - let name = "W" ^ string_of_int !writes in - ( [ - Write_mem - { - name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - kind = smt_cval ctx wk; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - data = smt_cval ctx data; - data_type = smt_ctyp ctx (cval_ctyp data); - doc = loc_doc l; - }; - ], - Var (name ^ "_ret") - ) - -let ea_writes = ref (-1) - -let builtin_write_mem_ea ctx wk addr_size addr data_size = - incr ea_writes; - let name = "A" ^ string_of_int !ea_writes in - ( [ - Write_mem_ea - ( name, - ctx.node, - Lazy.force ctx.pathcond, - smt_cval ctx wk, - smt_cval ctx addr, - smt_ctyp ctx (cval_ctyp addr), - smt_cval ctx data_size, - smt_ctyp ctx (cval_ctyp data_size) - ); - ], - Enum "unit" - ) - -let reads = ref (-1) - -let builtin_read_mem l ctx rk addr_size addr data_size ret_ctyp = - incr reads; - let name = "R" ^ string_of_int !reads in - ( [ - Read_mem - { - name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - ret_type = smt_ctyp ctx ret_ctyp; - kind = smt_cval ctx rk; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - doc = loc_doc l; - }; - ], - Read_res name - ) - -let excl_results = ref (-1) - -let builtin_excl_res ctx = - incr excl_results; - let name = "E" ^ string_of_int !excl_results in - ([Excl_res (name, ctx.node, Lazy.force ctx.pathcond)], Var (name ^ "_ret")) - -let barriers = ref (-1) - -let builtin_barrier l ctx bk = - incr barriers; - let name = "B" ^ string_of_int !barriers in - ( [Barrier { name; node = ctx.node; active = Lazy.force ctx.pathcond; kind = smt_cval ctx bk; doc = loc_doc l }], - Enum "unit" - ) - -let cache_maintenances = ref (-1) - -let builtin_cache_maintenance l ctx cmk addr_size addr = - incr cache_maintenances; - let name = "M" ^ string_of_int !cache_maintenances in - ( [ - Cache_maintenance - { - name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - kind = smt_cval ctx cmk; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - doc = loc_doc l; - }; - ], - Enum "unit" - ) - -let branch_announces = ref (-1) - -let builtin_branch_announce l ctx addr_size addr = - incr branch_announces; - let name = "C" ^ string_of_int !branch_announces in - ( [ - Branch_announce - { - name; - node = ctx.node; - active = Lazy.force ctx.pathcond; - addr = smt_cval ctx addr; - addr_type = smt_ctyp ctx (cval_ctyp addr); - doc = loc_doc l; - }; - ], - Enum "unit" - ) - -let define_const ctx id ctyp exp = Define_const (zencode_name id, smt_ctyp ctx ctyp, exp) -let preserve_const ctx id ctyp exp = Preserve_const (string_of_id id, smt_ctyp ctx ctyp, exp) -let declare_const ctx id ctyp = Declare_const (zencode_name id, smt_ctyp ctx ctyp) - -let smt_ctype_def ctx = function - | CTD_enum (id, elems) -> [declare_datatypes (mk_enum (zencode_upper_id id) (List.map zencode_id elems))] - | CTD_struct (id, fields) -> - [ - declare_datatypes - (mk_record (zencode_upper_id id) - (List.map (fun (field, ctyp) -> (zencode_upper_id id ^ "_" ^ zencode_id field, smt_ctyp ctx ctyp)) fields) - ); - ] - | CTD_variant (id, ctors) -> - [ - declare_datatypes - (mk_variant (zencode_upper_id id) (List.map (fun (ctor, ctyp) -> (zencode_id ctor, smt_ctyp ctx ctyp)) ctors)); - ] + (** When we generate a property for a CDEF_val, we find it's + associated function body in a CDEF_fundef node. However, we must + keep track of any global letbindings between the spec and the + fundef, so they can appear in the generated SMT. *) + let rec find_function lets id = function + | CDEF_aux (CDEF_fundef (id', heap_return, args, body), def_annot) :: _ when Id.compare id id' = 0 -> + (lets, Some (heap_return, args, body, def_annot)) + | CDEF_aux (CDEF_let (_, vars, setup), _) :: cdefs -> + let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (name id)) vars in + find_function (lets @ vars @ setup) id cdefs + | _ :: cdefs -> find_function lets id cdefs + | [] -> (lets, None) + + let rec smt_query state = function + | Q_all ev -> + let stack = event_stack state ev in + smt_conj (Stack.fold (fun xs x -> x :: xs) [] stack) + | Q_exist ev -> + let stack = event_stack state ev in + smt_disj (Stack.fold (fun xs x -> x :: xs) [] stack) + | Q_not q -> Fn ("not", [smt_query state q]) + | Q_and qs -> smt_conj (List.map (smt_query state) qs) + | Q_or qs -> smt_disj (List.map (smt_query state) qs) + + type generated_smt_info = { + file_name : string; + function_id : id; + args : id list; + arg_ctyps : ctyp list; + arg_smt_names : (id * string option) list; + } -let rec generate_ctype_defs ctx = function - | CDEF_aux (CDEF_type ctd, _) :: cdefs -> smt_ctype_def ctx ctd :: generate_ctype_defs ctx cdefs - | _ :: cdefs -> generate_ctype_defs ctx cdefs - | [] -> [] + let smt_cdef props lets name_file ctx all_cdefs smt_includes (CDEF_aux (aux, def_annot)) = + match aux with + | CDEF_val (function_id, _, arg_ctyps, ret_ctyp) when Bindings.mem function_id props -> begin + match find_function [] function_id all_cdefs with + | intervening_lets, Some (None, args, instrs, function_def_annot) -> + let debug_attr = get_def_attribute "jib_debug" function_def_annot in + let prop_type, prop_args, pragma_l, vs = Bindings.find function_id props in + + let pragma = Property.parse_pragma pragma_l prop_args in + + (* When we create each argument declaration, give it a unique + location from the $property pragma, so we can identify it later. *) + let arg_decls = + List.map2 + (fun id ctyp -> + let l = unique pragma_l in + idecl l ctyp (name id) + ) + args arg_ctyps + in + let instrs = + let open Jib_optimize in + lets @ intervening_lets @ arg_decls @ instrs + |> inline all_cdefs (fun _ -> true) + (* |> List.map (map_instr (expand_reg_deref ctx.tc_env Config.register_map)) *) + |> flatten_instrs + |> remove_unused_labels |> remove_pointless_goto + in -let rec generate_reg_decs ctx inits = function - | CDEF_aux (CDEF_register (id, ctyp, _), _) :: cdefs when not (NameMap.mem (Name (id, 0)) inits) -> - Declare_const (zencode_name (Name (id, 0)), smt_ctyp ctx ctyp) :: generate_reg_decs ctx inits cdefs - | _ :: cdefs -> generate_reg_decs ctx inits cdefs - | [] -> [] + if Option.is_some debug_attr then ( + prerr_endline Util.("Pre-SMT IR for " ^ string_of_id function_id ^ ":" |> yellow |> bold |> clear); + List.iter (fun instr -> prerr_endline (string_of_instr instr)) instrs + ); -(**************************************************************************) -(* 2. Converting sail types to Jib types for SMT *) -(**************************************************************************) + let (stack, state), _ = + Smt_gen.run (smt_instr_list debug_attr (string_of_id function_id) ctx all_cdefs instrs) pragma_l + in -let max_int n = Big_int.pred (Big_int.pow_int_positive 2 (n - 1)) -let min_int n = Big_int.negate (Big_int.pow_int_positive 2 (n - 1)) + let query = smt_query state pragma.query in + push_smt_defs stack [Assert (Fn ("not", [query]))]; + + let fname = name_file (string_of_id function_id) in + let out_chan = open_out fname in + if prop_type = "counterexample" then output_string out_chan "(set-option :produce-models true)\n"; + + let header, _ = Smt_gen.run (smt_header all_cdefs) pragma_l in + List.iter + (fun def -> + output_string out_chan (string_of_smt_def def); + output_string out_chan "\n" + ) + header; + + (* Include custom SMT definitions. *) + List.iter (fun include_file -> output_string out_chan (Util.read_whole_file include_file)) smt_includes; + + let queue = Queue_optimizer.optimize stack in + Queue.iter + (fun def -> + output_string out_chan (string_of_smt_def def); + output_string out_chan "\n" + ) + queue; + + (* (Queue.of_seq (List.to_seq (List.rev (List.of_seq (Stack.to_seq stack))))); *) + output_string out_chan "(check-sat)\n"; + if prop_type = "counterexample" then output_string out_chan "(get-model)\n"; + + close_out out_chan; + let arg_names = Stack.fold (fun m (k, v) -> (k, v) :: m) [] state.arg_stack in + let arg_smt_names = + List.map + (function + | I_aux (I_decl (_, Name (id, _)), (_, Unique (n, _))) -> (id, List.assoc_opt n arg_names) + | _ -> assert false + ) + arg_decls + in + Some { file_name = fname; function_id; args; arg_ctyps; arg_smt_names } + | _ -> + let _, _, pragma_l, _ = Bindings.find function_id props in + raise (Reporting.err_general pragma_l "No function body found") + end + | _ -> None + + let rec smt_cdefs acc props lets name_file ctx all_cdefs smt_includes = function + | CDEF_aux (CDEF_let (_, vars, setup), _) :: cdefs -> + let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (name id)) vars in + smt_cdefs acc props (lets @ vars @ setup) name_file ctx all_cdefs smt_includes cdefs + | cdef :: cdefs -> begin + match smt_cdef props lets name_file ctx all_cdefs smt_includes cdef with + | Some generation_info -> + smt_cdefs (generation_info :: acc) props lets name_file ctx all_cdefs smt_includes cdefs + | None -> smt_cdefs acc props lets name_file ctx all_cdefs smt_includes cdefs + end + | [] -> acc + + (* For generating SMT when we have a reg_deref(r : register(t)) + function, we have to expand it into a if-then-else cascade that + checks if r is any one of the registers with type t, and reads that + register if it is. We also do a similar thing for *r = x + *) + class expand_reg_deref_visitor env : jib_visitor = + object + inherit empty_jib_visitor + + method! vcval _ = SkipChildren + method! vctyp _ = SkipChildren + method! vclexp _ = SkipChildren + + method! vinstr = + function + | I_aux (I_funcall (CR_one (CL_addr (CL_id (id, ctyp))), false, function_id, args), (_, l)) -> begin + match ctyp with + | CT_ref reg_ctyp -> begin + match CTMap.find_opt reg_ctyp Config.register_map with + | Some regs -> + let end_label = label "end_reg_write_" in + let try_reg r = + let next_label = label "next_reg_write_" in + [ + ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; + ifuncall l (CL_id (name r, reg_ctyp)) function_id args; + igoto end_label; + ilabel next_label; + ] + in + ChangeTo (iblock (List.concat (List.map try_reg regs) @ [ilabel end_label])) + | None -> + raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) + end + | _ -> + raise + (Reporting.err_general l "Register reference assignment must take a register reference as an argument") + end + | I_aux (I_funcall (CR_one clexp, false, function_id, [reg_ref]), (_, l)) as instr -> + let open Type_check in + begin + match + if Env.is_extern (fst function_id) env "smt" then Some (Env.get_extern (fst function_id) env "smt") + else None + with + | Some "reg_deref" -> begin + match cval_ctyp reg_ref with + | CT_ref reg_ctyp -> begin + (* Not find all the registers with this ctyp *) + match CTMap.find_opt reg_ctyp Config.register_map with + | Some regs -> + let end_label = label "end_reg_deref_" in + let try_reg r = + let next_label = label "next_reg_deref_" in + [ + ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); reg_ref])) next_label; + icopy l clexp (V_id (name r, reg_ctyp)); + igoto end_label; + ilabel next_label; + ] + in + ChangeTo (iblock (List.concat (List.map try_reg regs) @ [ilabel end_label])) + | None -> + raise + (Reporting.err_general l + ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp) + ) + end + | _ -> + raise + (Reporting.err_general l "Register dereference must have a register reference as an argument") + end + | _ -> SkipChildren + end + | I_aux (I_copy (CL_addr (CL_id (id, ctyp)), cval), (_, l)) -> begin + match ctyp with + | CT_ref reg_ctyp -> begin + match CTMap.find_opt reg_ctyp Config.register_map with + | Some regs -> + let end_label = label "end_reg_write_" in + let try_reg r = + let next_label = label "next_reg_write_" in + [ + ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; + icopy l (CL_id (name r, reg_ctyp)) cval; + igoto end_label; + ilabel next_label; + ] + in + ChangeTo (iblock (List.concat (List.map try_reg regs) @ [ilabel end_label])) + | None -> + raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) + end + | _ -> + raise + (Reporting.err_general l "Register reference assignment must take a register reference as an argument") + end + | _ -> DoChildren + end + + let generate_smt ~properties ~name_file ~smt_includes ctx cdefs = + let cdefs = visit_cdefs (new expand_reg_deref_visitor ctx.tc_env) cdefs in + smt_cdefs [] properties [] name_file ctx cdefs smt_includes cdefs +end -module SMT_config (Opts : sig +module CompileConfig (Opts : sig val unroll_limit : int end) : Jib_compile.CONFIG = struct open Jib_compile @@ -1422,8 +1032,11 @@ end) : Jib_compile.CONFIG = struct | Nexp_aux (Nexp_constant n, _) -> CT_fbits (Big_int.to_int n) | _ -> CT_lbits end - | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_typ typ, _)]) when string_of_id id = "vector" -> - CT_vector (convert_typ ctx typ) + | Typ_app (id, [A_aux (A_nexp n, _); A_aux (A_typ typ, _)]) when string_of_id id = "vector" -> begin + match nexp_simp n with + | Nexp_aux (Nexp_constant c, _) -> CT_fvector (Big_int.to_int c, convert_typ ctx typ) + | _ -> CT_vector (convert_typ ctx typ) + end | Typ_app (id, [A_aux (A_typ typ, _)]) when string_of_id id = "register" -> CT_ref (convert_typ ctx typ) | Typ_id id when Bindings.mem id ctx.records -> CT_struct (id, Bindings.find id ctx.records |> snd |> Bindings.bindings) @@ -1557,766 +1170,9 @@ end) : Jib_compile.CONFIG = struct let use_real = true let branch_coverage = None let track_throw = false + let use_void = false end -(**************************************************************************) -(* 3. Generating SMT *) -(**************************************************************************) - -let push_smt_defs stack smt_defs = List.iter (fun def -> Stack.push def stack) smt_defs - -(* When generating SMT when we encounter joins between two or more - blocks such as in the example below, we have to generate a muxer - that chooses the correct value of v_n or v_m to assign to v_o. We - use the pi nodes that contain the path condition for each - block to generate an if-then-else for each phi function. The order - of the arguments to each phi function is based on the graph node - index for the predecessor nodes. - - +---------------+ +---------------+ - | pi(cond_1) | | pi(cond_2) | - | ... | | ... | - | Basic block 1 | | Basic block 2 | - +---------------+ +---------------+ - \ / - \ / - +---------------------+ - | v/o = phi(v/n, v/m) | - | ... | - +---------------------+ - - would generate: - - (define-const v/o (ite cond_1 v/n v/m_)) -*) -let smt_ssanode ctx cfg preds = - let open Jib_ssa in - function - | Pi _ -> [] - | Phi (id, ctyp, ids) -> ( - let get_pi n = - match get_vertex cfg n with - | Some ((ssa_elems, _), _, _) -> List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems) - | None -> failwith "Predecessor node does not exist" - in - let pis = List.map get_pi (IntSet.elements preds) in - let mux = - List.fold_right2 - (fun pi id chain -> - let pathcond = smt_conj (List.map (smt_cval ctx) pi) in - match chain with - | Some smt -> Some (Ite (pathcond, Var (zencode_name id), smt)) - | None -> Some (Var (zencode_name id)) - ) - pis ids None - in - match mux with None -> assert false | Some mux -> [Define_const (zencode_name id, smt_ctyp ctx ctyp, mux)] - ) - -(* The pi condition are computed by traversing the dominator tree, - with each node having a pi condition defined as the conjunction of - all guards between it and the start node in the dominator - tree. This is imprecise because we have situations like: - - 1 - / \ - 2 3 - | | - | 4 - | |\ - 5 6 9 - \ / | - 7 10 - | - 8 - - where 8 = match_failure, 1 = start and 10 = return. - 2, 3, 6 and 9 are guards as they come directly after a control flow - split, which always follows a conditional jump. - - Here the path through the dominator tree for the match_failure is - 1->7->8 which contains no guards so the pi condition would be empty. - What we do now is walk backwards (CFG must be acyclic at this point) - until we hit the join point prior to where we require a path - condition. We then take the disjunction of the pi conditions for the - join point's predecessors, so 5 and 6 in this case. Which gives us a - path condition of 2 | (3 & 6) as the dominator chains are 1->2->5 and - 1->3->4->6. - - This should work as any split in the control flow must have been - caused by a conditional jump followed by distinct guards, so each of - the nodes immediately prior to a join point must be dominated by at - least one unique guard. It also explains why the pi conditions are - sufficient to choose outcomes of phi functions above. - - If we hit a guard before a join (such as 9 for return's path - conditional) we just return the pi condition for that guard, i.e. - (3 & 9) for 10. If we reach start then the path condition is simply - true. -*) -let rec get_pathcond n cfg ctx = - let open Jib_ssa in - let get_pi m = - match get_vertex cfg m with - | Some ((ssa_elems, _), _, _) -> - V_call (Band, List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems)) - | None -> failwith "Node does not exist" - in - match get_vertex cfg n with - | Some ((_, CF_guard cond), _, _) -> smt_cval ctx (get_pi n) - | Some (_, preds, succs) -> - if IntSet.cardinal preds = 0 then Bool_lit true - else if IntSet.cardinal preds = 1 then get_pathcond (IntSet.min_elt preds) cfg ctx - else ( - let pis = List.map get_pi (IntSet.elements preds) in - smt_cval ctx (V_call (Bor, pis)) - ) - | None -> assert false (* Should never be called for a non-existent node *) - -(* For any complex l-expression we need to turn it into a - read-modify-write in the SMT solver. The SSA transform turns CL_id - nodes into CL_rmw (read, write, ctyp) nodes when CL_id is wrapped - in any other l-expression. The read and write must have the same - name but different SSA numbers. -*) -let rec rmw_write = function - | CL_rmw (_, write, ctyp) -> (write, ctyp) - | CL_id _ -> assert false - | CL_tuple (clexp, _) -> rmw_write clexp - | CL_field (clexp, _) -> rmw_write clexp - | clexp -> failwith "Could not understand l-expression" - -let rmw_read = function CL_rmw (read, _, _) -> zencode_name read | _ -> assert false - -let rmw_modify smt = function - | CL_tuple (clexp, n) -> - let ctyp = clexp_ctyp clexp in - begin - match ctyp with - | CT_tup ctyps -> - let len = List.length ctyps in - let set_tup i = if i == n then smt else Fn (Printf.sprintf "tup_%d_%d" len i, [Var (rmw_read clexp)]) in - Fn ("tup" ^ string_of_int len, List.init len set_tup) - | _ -> failwith "Tuple modify does not have tuple type" - end - | CL_field (clexp, field) -> - let ctyp = clexp_ctyp clexp in - begin - match ctyp with - | CT_struct (struct_id, fields) -> - let set_field (field', _) = - if Id.compare field field' = 0 then smt - else Field (zencode_upper_id struct_id ^ "_" ^ zencode_id field', Var (rmw_read clexp)) - in - Fn (zencode_upper_id struct_id, List.map set_field fields) - | _ -> failwith "Struct modify does not have struct type" - end - | _ -> assert false - -let smt_terminator ctx = - let open Jib_ssa in - function - | T_end id -> - add_event ctx Return (Var (zencode_name id)); - [] - | T_exit _ -> - add_pathcond_event ctx Match; - [] - | T_undefined _ | T_goto _ | T_jump _ | T_label _ | T_none -> [] - -(* For a basic block (contained in a control-flow node / cfnode), we - turn the instructions into a sequence of define-const and - declare-const expressions. Because we are working with a SSA graph, - each variable is guaranteed to only be declared once. -*) -let smt_instr ctx = - let open Type_check in - function - | I_aux (I_funcall (CL_id (id, ret_ctyp), extern, function_id, args), (_, l)) -> - if Env.is_extern (fst function_id) ctx.tc_env "c" && not extern then ( - let name = Env.get_extern (fst function_id) ctx.tc_env "c" in - if name = "sqrt_real" then begin - match args with - | [v] -> builtin_sqrt_real ctx (zencode_name id) v - | _ -> Reporting.unreachable l __POS__ "Bad arguments for sqrt_real" - (* See lib/regfp.sail *) - end - else if name = "platform_write_mem" then begin - match args with - | [wk; addr_size; addr; data_size; data] -> - let mem_event, var = builtin_write_mem l ctx wk addr_size addr data_size data in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for __write_mem" - end - else if name = "platform_write_mem_ea" then begin - match args with - | [wk; addr_size; addr; data_size] -> - let mem_event, var = builtin_write_mem_ea ctx wk addr_size addr data_size in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for __write_mem_ea" - end - else if name = "platform_read_mem" then begin - match args with - | [rk; addr_size; addr; data_size] -> - let mem_event, var = builtin_read_mem l ctx rk addr_size addr data_size ret_ctyp in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for __read_mem" - end - else if name = "platform_barrier" then begin - match args with - | [bk] -> - let mem_event, var = builtin_barrier l ctx bk in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for __barrier" - end - else if name = "platform_cache_maintenance" then begin - match args with - | [cmk; addr_size; addr] -> - let mem_event, var = builtin_cache_maintenance l ctx cmk addr_size addr in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for __barrier" - end - else if name = "platform_branch_announce" then begin - match args with - | [addr_size; addr] -> - let mem_event, var = builtin_branch_announce l ctx addr_size addr in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for __barrier" - end - else if name = "platform_excl_res" then begin - match args with - | [_] -> - let mem_event, var = builtin_excl_res ctx in - mem_event @ [define_const ctx id ret_ctyp var] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for __excl_res" - end - else if name = "sail_exit" then ( - add_event ctx Assertion (Bool_lit false); - [] - ) - else if name = "sail_assert" then begin - match args with - | [assertion; _] -> - let smt = smt_cval ctx assertion in - add_event ctx Assertion (Fn ("not", [smt])); - [] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for assertion" - end - else ( - let value = smt_builtin ctx name args ret_ctyp in - [define_const ctx id ret_ctyp (Syntactic (value, List.map (smt_cval ctx) args))] - ) - ) - else if extern && string_of_id (fst function_id) = "internal_vector_init" then [declare_const ctx id ret_ctyp] - else if extern && string_of_id (fst function_id) = "internal_vector_update" then begin - match args with - | [vec; i; x] -> - let sz = int_size ctx (cval_ctyp i) in - [ - define_const ctx id ret_ctyp - (Fn - ( "store", - [ - smt_cval ctx vec; - force_size ~checked:false ctx ctx.vector_index sz (smt_cval ctx i); - smt_cval ctx x; - ] - ) - ); - ] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for internal_vector_update" - end - else if - (string_of_id (fst function_id) = "update_fbits" || string_of_id (fst function_id) = "update_lbits") && extern - then begin - match args with - | [vec; i; x] -> [define_const ctx id ret_ctyp (builtin_vector_update ctx vec i x ret_ctyp)] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for update_{f,l}bits" - end - else if string_of_id (fst function_id) = "sail_assume" then begin - match args with - | [assumption] -> - let smt = smt_cval ctx assumption in - add_event ctx Assumption smt; - [] - | _ -> Reporting.unreachable l __POS__ "Bad arguments for assumption" - end - else if not extern then ( - let smt_args = List.map (smt_cval ctx) args in - [define_const ctx id ret_ctyp (Ctor (zencode_uid function_id, smt_args))] - ) - else failwith ("Unrecognised function " ^ string_of_uid function_id) - | I_aux (I_copy (CL_addr (CL_id (_, _)), _), (_, l)) -> - Reporting.unreachable l __POS__ "Register reference write should be re-written by now" - | I_aux (I_init (ctyp, id, cval), _) | I_aux (I_copy (CL_id (id, ctyp), cval), _) -> begin - match (id, cval) with - | Name (id, _), _ when IdSet.mem id ctx.preserved -> - [preserve_const ctx id ctyp (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] - | _, V_lit (VL_undefined, _) -> - (* Declare undefined variables as arbitrary but fixed *) - [declare_const ctx id ctyp] - | _, _ -> [define_const ctx id ctyp (smt_conversion ctx (cval_ctyp cval) ctyp (smt_cval ctx cval))] - end - | I_aux (I_copy (clexp, cval), _) -> - let smt = smt_cval ctx cval in - let write, ctyp = rmw_write clexp in - [define_const ctx write ctyp (rmw_modify smt clexp)] - | I_aux (I_decl (ctyp, id), (_, l)) -> - (* Function arguments have unique locations defined from the - $property pragma. We record how they will appear in the - generated SMT so we can check models. *) - begin - match l with Unique (n, l') when l' = ctx.pragma_l -> Stack.push (n, zencode_name id) ctx.arg_stack | _ -> () - end; - [declare_const ctx id ctyp] - | I_aux (I_clear _, _) -> [] - (* Should only appear as terminators for basic blocks. *) - | I_aux ((I_jump _ | I_goto _ | I_end _ | I_exit _ | I_undefined _), (_, l)) -> - Reporting.unreachable l __POS__ "SMT: Instruction should only appear as block terminator" - | I_aux (_, (_, l)) -> Reporting.unreachable l __POS__ "Cannot translate instruction" - -let smt_cfnode all_cdefs ctx ssa_elems = - let open Jib_ssa in - function - | CF_start inits -> - let smt_reg_decs = generate_reg_decs ctx inits all_cdefs in - let smt_start (id, ctyp) = - match id with Have_exception _ -> define_const ctx id ctyp (Bool_lit false) | _ -> declare_const ctx id ctyp - in - smt_reg_decs @ List.map smt_start (NameMap.bindings inits) - | CF_block (instrs, terminator) -> - let smt_instrs = List.map (smt_instr ctx) instrs in - let smt_term = smt_terminator ctx terminator in - List.concat (smt_instrs @ [smt_term]) - (* We can ignore any non basic-block/start control-flow nodes *) - | _ -> [] - -(** When we generate a property for a CDEF_val, we find it's - associated function body in a CDEF_fundef node. However, we must - keep track of any global letbindings between the spec and the - fundef, so they can appear in the generated SMT. *) -let rec find_function lets id = function - | CDEF_aux (CDEF_fundef (id', heap_return, args, body), _) :: _ when Id.compare id id' = 0 -> - (lets, Some (heap_return, args, body)) - | CDEF_aux (CDEF_let (_, vars, setup), _) :: cdefs -> - let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (name id)) vars in - find_function (lets @ vars @ setup) id cdefs - | _ :: cdefs -> find_function lets id cdefs - | [] -> (lets, None) - -module type Sequence = sig - type 'a t - val create : unit -> 'a t - val add : 'a -> 'a t -> unit -end - -module Make_optimizer (S : Sequence) = struct - let optimize stack = - let stack' = Stack.create () in - let uses = Hashtbl.create (Stack.length stack) in - - let rec uses_in_exp = function - | Var var -> begin - match Hashtbl.find_opt uses var with - | Some n -> Hashtbl.replace uses var (n + 1) - | None -> Hashtbl.add uses var 1 - end - | Syntactic (exp, _) -> uses_in_exp exp - | Shared _ | Enum _ | Read_res _ | Bitvec_lit _ | Bool_lit _ | String_lit _ | Real_lit _ -> () - | Fn (_, exps) | Ctor (_, exps) -> List.iter uses_in_exp exps - | Field (_, exp) -> uses_in_exp exp - | Struct (_, fields) -> List.iter (fun (_, exp) -> uses_in_exp exp) fields - | Ite (cond, t, e) -> - uses_in_exp cond; - uses_in_exp t; - uses_in_exp e - | Extract (_, _, exp) | Tester (_, exp) | SignExtend (_, exp) -> uses_in_exp exp - | Forall _ -> assert false - in - - let remove_unused () = function - | Declare_const (var, _) as def -> begin - match Hashtbl.find_opt uses var with None -> () | Some _ -> Stack.push def stack' - end - | Declare_fun _ as def -> Stack.push def stack' - | Preserve_const (_, _, exp) as def -> - uses_in_exp exp; - Stack.push def stack' - | Define_const (var, _, exp) as def -> begin - match Hashtbl.find_opt uses var with - | None -> () - | Some _ -> - uses_in_exp exp; - Stack.push def stack' - end - | (Declare_datatypes _ | Declare_tuple _) as def -> Stack.push def stack' - | Write_mem w as def -> - uses_in_exp w.active; - uses_in_exp w.kind; - uses_in_exp w.addr; - uses_in_exp w.data; - Stack.push def stack' - | Write_mem_ea (_, _, active, wk, addr, _, data_size, _) as def -> - uses_in_exp active; - uses_in_exp wk; - uses_in_exp addr; - uses_in_exp data_size; - Stack.push def stack' - | Read_mem r as def -> - uses_in_exp r.active; - uses_in_exp r.kind; - uses_in_exp r.addr; - Stack.push def stack' - | Barrier b as def -> - uses_in_exp b.active; - uses_in_exp b.kind; - Stack.push def stack' - | Cache_maintenance m as def -> - uses_in_exp m.active; - uses_in_exp m.kind; - uses_in_exp m.addr; - Stack.push def stack' - | Branch_announce c as def -> - uses_in_exp c.active; - uses_in_exp c.addr; - Stack.push def stack' - | Excl_res (_, _, active) as def -> - uses_in_exp active; - Stack.push def stack' - | Assert exp as def -> - uses_in_exp exp; - Stack.push def stack' - | Define_fun _ -> assert false - in - Stack.fold remove_unused () stack; - - let vars = Hashtbl.create (Stack.length stack') in - let kinds = Hashtbl.create (Stack.length stack') in - let seq = S.create () in - - let constant_propagate = function - | Declare_const _ as def -> S.add def seq - | Declare_fun _ as def -> S.add def seq - | Preserve_const (var, typ, exp) -> S.add (Preserve_const (var, typ, simp_smt_exp vars kinds exp)) seq - | Define_const (var, typ, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin - match (Hashtbl.find_opt uses var, simp_smt_exp vars kinds exp) with - | _, (Bitvec_lit _ | Bool_lit _) -> Hashtbl.add vars var exp - | _, Var _ when !opt_propagate_vars -> Hashtbl.add vars var exp - | _, Ctor (str, _) -> - Hashtbl.add kinds var str; - S.add (Define_const (var, typ, exp)) seq - | Some 1, _ -> Hashtbl.add vars var exp - | Some _, exp -> S.add (Define_const (var, typ, exp)) seq - | None, _ -> assert false - end - | Write_mem w -> - S.add - (Write_mem - { - w with - active = simp_smt_exp vars kinds w.active; - kind = simp_smt_exp vars kinds w.kind; - addr = simp_smt_exp vars kinds w.addr; - data = simp_smt_exp vars kinds w.data; - } - ) - seq - | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> - S.add - (Write_mem_ea - ( name, - node, - simp_smt_exp vars kinds active, - simp_smt_exp vars kinds wk, - simp_smt_exp vars kinds addr, - addr_ty, - simp_smt_exp vars kinds data_size, - data_size_ty - ) - ) - seq - | Read_mem r -> - S.add - (Read_mem - { - r with - active = simp_smt_exp vars kinds r.active; - kind = simp_smt_exp vars kinds r.kind; - addr = simp_smt_exp vars kinds r.addr; - } - ) - seq - | Barrier b -> - S.add - (Barrier { b with active = simp_smt_exp vars kinds b.active; kind = simp_smt_exp vars kinds b.kind }) - seq - | Cache_maintenance m -> - S.add - (Cache_maintenance - { - m with - active = simp_smt_exp vars kinds m.active; - kind = simp_smt_exp vars kinds m.kind; - addr = simp_smt_exp vars kinds m.addr; - } - ) - seq - | Branch_announce c -> - S.add - (Branch_announce { c with active = simp_smt_exp vars kinds c.active; addr = simp_smt_exp vars kinds c.addr }) - seq - | Excl_res (name, node, active) -> S.add (Excl_res (name, node, simp_smt_exp vars kinds active)) seq - | Assert exp -> S.add (Assert (simp_smt_exp vars kinds exp)) seq - | (Declare_datatypes _ | Declare_tuple _) as def -> S.add def seq - | Define_fun _ -> assert false - in - Stack.iter constant_propagate stack'; - seq -end - -module Queue_optimizer = Make_optimizer (struct - type 'a t = 'a Queue.t - let create = Queue.create - let add = Queue.add - let iter = Queue.iter -end) - -(** [smt_header ctx cdefs] produces a list of smt definitions for all the datatypes in a specification *) -let smt_header ctx cdefs = - let smt_ctype_defs = List.concat (generate_ctype_defs ctx cdefs) in - [declare_datatypes (mk_enum "Unit" ["unit"])] - @ (IntSet.elements !(ctx.tuple_sizes) |> List.map (fun n -> Declare_tuple n)) - @ [declare_datatypes (mk_record "Bits" [("len", Bitvec ctx.lbits_index); ("contents", Bitvec (lbits_size ctx))])] - @ smt_ctype_defs - -(* For generating SMT when we have a reg_deref(r : register(t)) - function, we have to expand it into a if-then-else cascade that - checks if r is any one of the registers with type t, and reads that - register if it is. We also do a similar thing for *r = x -*) -let expand_reg_deref env register_map = function - | I_aux (I_funcall (CL_addr (CL_id (id, ctyp)), false, function_id, args), (_, l)) -> begin - match ctyp with - | CT_ref reg_ctyp -> begin - match CTMap.find_opt reg_ctyp register_map with - | Some regs -> - let end_label = label "end_reg_write_" in - let try_reg r = - let next_label = label "next_reg_write_" in - [ - ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; - ifuncall l (CL_id (name r, reg_ctyp)) function_id args; - igoto end_label; - ilabel next_label; - ] - in - iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) - | None -> raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) - end - | _ -> - raise (Reporting.err_general l "Register reference assignment must take a register reference as an argument") - end - | I_aux (I_funcall (clexp, false, function_id, [reg_ref]), (_, l)) as instr -> - let open Type_check in - begin - match - if Env.is_extern (fst function_id) env "smt" then Some (Env.get_extern (fst function_id) env "smt") else None - with - | Some "reg_deref" -> begin - match cval_ctyp reg_ref with - | CT_ref reg_ctyp -> begin - (* Not find all the registers with this ctyp *) - match CTMap.find_opt reg_ctyp register_map with - | Some regs -> - let end_label = label "end_reg_deref_" in - let try_reg r = - let next_label = label "next_reg_deref_" in - [ - ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); reg_ref])) next_label; - icopy l clexp (V_id (name r, reg_ctyp)); - igoto end_label; - ilabel next_label; - ] - in - iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) - | None -> - raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) - end - | _ -> raise (Reporting.err_general l "Register dereference must have a register reference as an argument") - end - | _ -> instr - end - | I_aux (I_copy (CL_addr (CL_id (id, ctyp)), cval), (_, l)) -> begin - match ctyp with - | CT_ref reg_ctyp -> begin - match CTMap.find_opt reg_ctyp register_map with - | Some regs -> - let end_label = label "end_reg_write_" in - let try_reg r = - let next_label = label "next_reg_write_" in - [ - ijump l (V_call (Neq, [V_lit (VL_ref (string_of_id r), reg_ctyp); V_id (id, ctyp)])) next_label; - icopy l (CL_id (name r, reg_ctyp)) cval; - igoto end_label; - ilabel next_label; - ] - in - iblock (List.concat (List.map try_reg regs) @ [ilabel end_label]) - | None -> raise (Reporting.err_general l ("Could not find any registers with type " ^ string_of_ctyp reg_ctyp)) - end - | _ -> - raise (Reporting.err_general l "Register reference assignment must take a register reference as an argument") - end - | instr -> instr - -let rec smt_query ctx = function - | Q_all ev -> - let stack = event_stack ctx ev in - smt_conj (Stack.fold (fun xs x -> x :: xs) [] stack) - | Q_exist ev -> - let stack = event_stack ctx ev in - smt_disj (Stack.fold (fun xs x -> x :: xs) [] stack) - | Q_not q -> Fn ("not", [smt_query ctx q]) - | Q_and qs -> Fn ("and", List.map (smt_query ctx) qs) - | Q_or qs -> Fn ("or", List.map (smt_query ctx) qs) - -let dump_graph name cfg = - let gv_file = name ^ ".gv" in - let out_chan = open_out gv_file in - Jib_ssa.make_dot out_chan cfg; - close_out out_chan - -let smt_instr_list name ctx all_cdefs instrs = - let stack = Stack.create () in - - let open Jib_ssa in - let start, cfg = ssa instrs in - let visit_order = - try topsort cfg - with Not_a_DAG n -> - dump_graph name cfg; - raise - (Reporting.err_general ctx.pragma_l - (Printf.sprintf "%s: control flow graph is not acyclic (node %d is in cycle)\nWrote graph to %s.gv" name n - name - ) - ) - in - if !opt_debug_graphs then dump_graph name cfg; - - List.iter - (fun n -> - match get_vertex cfg n with - | None -> () - | Some ((ssa_elems, cfnode), preds, succs) -> - let muxers = ssa_elems |> List.map (smt_ssanode ctx cfg preds) |> List.concat in - let ctx = { ctx with node = n; pathcond = lazy (get_pathcond n cfg ctx) } in - let basic_block = smt_cfnode all_cdefs ctx ssa_elems cfnode in - push_smt_defs stack muxers; - push_smt_defs stack basic_block - ) - visit_order; - - (stack, start, cfg) - -let smt_cdef props lets name_file ctx all_cdefs smt_includes (CDEF_aux (aux, _)) = - match aux with - | CDEF_val (function_id, _, arg_ctyps, ret_ctyp) when Bindings.mem function_id props -> begin - match find_function [] function_id all_cdefs with - | intervening_lets, Some (None, args, instrs) -> - let prop_type, prop_args, pragma_l, vs = Bindings.find function_id props in - - let pragma = parse_pragma pragma_l prop_args in - - let ctx = { ctx with events = ref EventMap.empty; pragma_l; arg_stack = Stack.create () } in - - (* When we create each argument declaration, give it a unique - location from the $property pragma, so we can identify it later. *) - let arg_decls = - List.map2 - (fun id ctyp -> - let l = unique pragma_l in - idecl l ctyp (name id) - ) - args arg_ctyps - in - let instrs = - let open Jib_optimize in - lets @ intervening_lets @ arg_decls @ instrs - |> inline all_cdefs (fun _ -> true) - |> List.map (map_instr (expand_reg_deref ctx.tc_env ctx.register_map)) - |> flatten_instrs |> remove_unused_labels |> remove_pointless_goto - in - - let stack, _, _ = smt_instr_list (string_of_id function_id) ctx all_cdefs instrs in - - let query = smt_query ctx pragma.query in - push_smt_defs stack [Assert (Fn ("not", [query]))]; - - let fname = name_file (string_of_id function_id) in - let out_chan = open_out fname in - if prop_type = "counterexample" then output_string out_chan "(set-option :produce-models true)\n"; - - let header = smt_header ctx all_cdefs in - - (* If the solver is Z3, don't output a logic as Z3 will infer it. *) - begin - match !opt_auto_solver with - | Z3 -> () - | _ -> - if !(ctx.use_string) || !(ctx.use_real) then output_string out_chan "(set-logic ALL)\n" - else output_string out_chan "(set-logic QF_AUFBVFPDT)\n" - end; - - List.iter - (fun def -> - output_string out_chan (string_of_smt_def def); - output_string out_chan "\n" - ) - header; - - (* Include custom SMT definitions. *) - List.iter (fun include_file -> output_string out_chan (Util.read_whole_file include_file)) smt_includes; - - let queue = Queue_optimizer.optimize stack in - Queue.iter - (fun def -> - output_string out_chan (string_of_smt_def def); - output_string out_chan "\n" - ) - queue; - - output_string out_chan "(check-sat)\n"; - if prop_type = "counterexample" then output_string out_chan "(get-model)\n"; - - close_out out_chan; - if prop_type = "counterexample" && !opt_auto then ( - let arg_names = Stack.fold (fun m (k, v) -> (k, v) :: m) [] ctx.arg_stack in - let arg_smt_names = - List.map - (function - | I_aux (I_decl (_, Name (id, _)), (_, Unique (n, _))) -> (id, List.assoc_opt n arg_names) - | _ -> assert false - ) - arg_decls - in - check_counterexample ctx.ast ctx.tc_env fname function_id args arg_ctyps arg_smt_names - ) - | _ -> failwith "Bad function body" - end - | _ -> () - -let rec smt_cdefs props lets name_file ctx ast smt_includes = function - | CDEF_aux (CDEF_let (_, vars, setup), _) :: cdefs -> - let vars = List.map (fun (id, ctyp) -> idecl (id_loc id) ctyp (name id)) vars in - smt_cdefs props (lets @ vars @ setup) name_file ctx ast smt_includes cdefs - | cdef :: cdefs -> - smt_cdef props lets name_file ctx ast smt_includes cdef; - smt_cdefs props lets name_file ctx ast smt_includes cdefs - | [] -> () - (* In order to support register references, we need to build a map from each ctyp to a list of registers with that ctyp, then when we see a type like register(bits(32)) we can use the map to figure out @@ -2333,40 +1189,19 @@ let rec build_register_map rmap = function | _ :: cdefs -> build_register_map rmap cdefs | [] -> rmap -let compile env effect_info ast = +let compile ~unroll_limit env effect_info ast = let cdefs, jib_ctx = - let module Jibc = Jib_compile.Make (SMT_config (struct - let unroll_limit = !opt_unroll_limit + let module Jibc = Jib_compile.Make (CompileConfig (struct + let unroll_limit = unroll_limit end)) in let env, effect_info = Jib_compile.add_special_functions env effect_info in let ctx = Jib_compile.initial_ctx env effect_info in let t = Profile.start () in let cdefs, ctx = Jibc.compile_ast ctx ast in + let cdefs, ctx = Jib_optimize.remove_tuples cdefs ctx in + let cdefs = Jib_optimize.unique_per_function_ids cdefs in Profile.finish "Compiling to Jib IR" t; (cdefs, ctx) in - let cdefs = Jib_optimize.unique_per_function_ids cdefs in - let rmap = build_register_map CTMap.empty cdefs in - (cdefs, jib_ctx, { (initial_ctx ()) with tc_env = jib_ctx.tc_env; register_map = rmap; ast }) - -let serialize_smt_model file env effect_info ast = - let cdefs, _, ctx = compile env effect_info ast in - let out_chan = open_out file in - Marshal.to_channel out_chan cdefs []; - Marshal.to_channel out_chan (Type_check.Env.set_prover None ctx.tc_env) []; - Marshal.to_channel out_chan ctx.register_map []; - close_out out_chan - -let deserialize_smt_model file = - let in_chan = open_in file in - let cdefs = (Marshal.from_channel in_chan : cdef list) in - let env = (Marshal.from_channel in_chan : Type_check.env) in - let rmap = (Marshal.from_channel in_chan : id list CTMap.t) in - close_in in_chan; - (cdefs, { (initial_ctx ()) with tc_env = env; register_map = rmap }) - -let generate_smt props name_file env effect_info smt_includes ast = - try - let cdefs, _, ctx = compile env effect_info ast in - smt_cdefs props [] name_file ctx cdefs smt_includes cdefs - with Type_error.Type_error (l, err) -> raise (Type_error.to_reporting_exn l err) + let register_map = build_register_map CTMap.empty cdefs in + (cdefs, jib_ctx, register_map) diff --git a/src/sail_smt_backend/jib_smt.mli b/src/sail_smt_backend/jib_smt.mli index 626e46b3a..3fac7aaa6 100644 --- a/src/sail_smt_backend/jib_smt.mli +++ b/src/sail_smt_backend/jib_smt.mli @@ -72,113 +72,41 @@ open Ast_defs open Ast_util open Jib open Jib_util -open Jib_ssa -open Smtlib -val opt_ignore_overflow : bool ref -val opt_auto : bool ref val opt_debug_graphs : bool ref -val opt_propagate_vars : bool ref -val zencode_name : name -> string - -module IntSet : Set.S with type elt = int -module EventMap : Map.S with type key = Property.event - -(** These give the default bounds for various SMT types, stored in the - initial_ctx. *) - -val opt_default_lint_size : int ref -val opt_default_lbits_index : int ref -val opt_default_vector_index : int ref - -type ctx = { - lbits_index : int; - (** Arbitrary-precision bitvectors are represented as a (BitVec lbits_index, BitVec (2 ^ lbits_index)) pair. *) - lint_size : int; (** The size we use for integers where we don't know how large they are statically. *) - vector_index : int; - (** A generic vector, vector('a) becomes Array (BitVec vector_index) 'a. - We need to take care that vector_index is large enough for all generic vectors. *) - register_map : id list CTMap.t; (** A map from each ctyp to a list of registers of that ctyp *) - tuple_sizes : IntSet.t ref; (** A set to keep track of all the tuple sizes we need to generate types for *) - tc_env : Type_check.Env.t; (** tc_env is the global type-checking environment *) - pragma_l : Ast.l; - (** A location, usually the $counterexample or $property we are - generating the SMT for. Used for error messages. *) - arg_stack : (int * string) Stack.t; (** Used internally to keep track of function argument names *) - ast : Type_check.typed_ast; (** The fully type-checked ast *) - shared : ctyp Bindings.t; - (** Shared variables. These variables do not get renamed by - Smtlib.suffix_variables_def, and their SSA number is - omitted. They should therefore only ever be read and never - written. Used by sail-axiomatic for symbolic values in the - initial litmus state. *) - preserved : IdSet.t; - (** icopy instructions to an id in preserved will generated a - define-const (by using Smtlib.Preserved_const) that will not be - simplified away or renamed. It will also not get a SSA - number. Such variables can therefore only ever be written to - once, and never read. They are used by sail-axiomatic to - extract information from the generated SMT. *) - events : smt_exp Stack.t EventMap.t ref; - (** For every event type we have a stack of boolean SMT - expressions for each occurance of that event. See - src/property.ml for the event types *) - node : int; - pathcond : smt_exp Lazy.t; - (** When generating SMT for an instruction pathcond will contain - the global path conditional of the containing block/node in the - control flow graph *) - use_string : bool ref; - use_real : bool ref; - (** Set if we need to use strings or real numbers in the generated - SMT, which then requires set-logic ALL or similar depending on - the solver *) -} - -(** Compile an AST into Jib suitable for SMT generation, and initialise a context. *) -val compile : Type_check.Env.t -> Effects.side_effect_info -> Type_check.typed_ast -> cdef list * Jib_compile.ctx * ctx - -(* TODO: Currently we internally use mutable stacks and queues to - avoid any issues with stack overflows caused by some non - tail-recursive list functions, as the generated SMT can be very - long, especially without any optimization. Not clear that this is - really better than just using lists. *) - -val smt_header : ctx -> cdef list -> smt_def list - -val smt_query : ctx -> Property.query -> smt_exp - -val smt_instr_list : - string -> ctx -> cdef list -> instr list -> smt_def Stack.t * int * (ssa_elem list * cf_node) Jib_ssa.array_graph - -module type Sequence = sig - type 'a t - val create : unit -> 'a t - val add : 'a -> 'a t -> unit +module type CONFIG = sig + val max_unknown_integer_width : int + val max_unknown_bitvector_width : int + val max_unknown_generic_vector_length : int + val register_map : id list CTMap.t + val ignore_overflow : bool end -(** Optimize SMT generated by smt_instr_list. SMT definitions are - added to the result sequence in the order they should appear in the - final SMTLIB file. Depending on the order in which we want to - process the results we can either use a FIFO queue or a LIFO - stack, or any other structure. *) -module Make_optimizer (S : Sequence) : sig - val optimize : smt_def Stack.t -> smt_def S.t +module Make (Config : CONFIG) : sig + type generated_smt_info = { + file_name : string; + function_id : id; + args : id list; + arg_ctyps : ctyp list; + arg_smt_names : (id * string option) list; + } + + (** Generate SMT for all the $property and $counterexample pragmas + provided, and write the generated SMT to appropriately named + files. *) + val generate_smt : + properties:(string * string * l * 'a val_spec) Bindings.t (** See Property.find_properties *) -> + name_file:(string -> string) (** Applied to each function name to generate the file name for the smtlib file *) -> + smt_includes:string list (** Extra files to include in each generated SMT problem *) -> + Jib_compile.ctx -> + cdef list -> + generated_smt_info list end -val serialize_smt_model : string -> Type_check.Env.t -> Effects.side_effect_info -> Type_check.typed_ast -> unit - -val deserialize_smt_model : string -> cdef list * ctx - -(** Generate SMT for all the $property and $counterexample pragmas in - an AST, and write it to appropriately named files. *) -val generate_smt : - (string * string * l * 'a val_spec) Bindings.t (* See Property.find_properties *) -> - (string -> string) -> - (* Applied to each function name to generate the file name for the smtlib file *) +val compile : + unroll_limit:int -> Type_check.Env.t -> Effects.side_effect_info -> - string list -> Type_check.typed_ast -> - unit + cdef list * Jib_compile.ctx * id list CTMap.t diff --git a/src/sail_smt_backend/sail_plugin_smt.ml b/src/sail_smt_backend/sail_plugin_smt.ml index 200e9f3c3..5161fed49 100644 --- a/src/sail_smt_backend/sail_plugin_smt.ml +++ b/src/sail_smt_backend/sail_plugin_smt.ml @@ -67,35 +67,43 @@ open Libsail -let opt_includes_smt : string list ref = ref [] +open Jib_smt -let set_auto_solver arg = - let open Smtlib in - match counterexample_solver_from_name arg with Some solver -> opt_auto_solver := solver | None -> () +let opt_smt_auto = ref false +let opt_smt_auto_solver = ref Smt_exp.Cvc5 +let opt_smt_includes : string list ref = ref [] +let opt_smt_ignore_overflow = ref false +let opt_smt_unknown_integer_width = ref 128 +let opt_smt_unknown_bitvector_width = ref 64 +let opt_smt_unknown_generic_vector_width = ref 32 + +let set_smt_auto_solver arg = + let open Smt_exp in + match counterexample_solver_from_name arg with Some solver -> opt_smt_auto_solver := solver | None -> () let smt_options = [ - ("-smt_auto", Arg.Tuple [Arg.Set Jib_smt.opt_auto], " automatically call the smt solver on generated SMT"); + ("-smt_auto", Arg.Tuple [Arg.Set opt_smt_auto], " automatically call the smt solver on generated SMT"); ( "-smt_auto_solver", - Arg.String set_auto_solver, + Arg.Tuple [Arg.Set opt_smt_auto; Arg.String set_smt_auto_solver], " set the solver to use for counterexample checks (default cvc5)" ); - ("-smt_ignore_overflow", Arg.Set Jib_smt.opt_ignore_overflow, " ignore integer overflow in generated SMT"); - ("-smt_propagate_vars", Arg.Set Jib_smt.opt_propagate_vars, " propgate variables through generated SMT"); + ("-smt_ignore_overflow", Arg.Set opt_smt_ignore_overflow, " ignore integer overflow in generated SMT"); ( "-smt_int_size", - Arg.String (fun n -> Jib_smt.opt_default_lint_size := int_of_string n), + Arg.String (fun n -> opt_smt_unknown_integer_width := int_of_string n), " set a bound of n on the maximum integer bitwidth for generated SMT (default 128)" ); + ("-smt_propagate_vars", Arg.Unit (fun () -> ()), " (deprecated) propgate variables through generated SMT"); ( "-smt_bits_size", - Arg.String (fun n -> Jib_smt.opt_default_lbits_index := int_of_string n), - " set a bound of 2 ^ n for bitvector bitwidth in generated SMT (default 8)" + Arg.String (fun n -> opt_smt_unknown_bitvector_width := int_of_string n), + " set a size bound of n for unknown-length bitvectors in generated SMT (default 64)" ); ( "-smt_vector_size", - Arg.String (fun n -> Jib_smt.opt_default_vector_index := int_of_string n), + Arg.String (fun n -> opt_smt_unknown_generic_vector_width := int_of_string n), " set a bound of 2 ^ n for generic vectors in generated SMT (default 5)" ); ( "-smt_include", - Arg.String (fun i -> opt_includes_smt := i :: !opt_includes_smt), + Arg.String (fun i -> opt_smt_includes := i :: !opt_smt_includes), " insert additional file in SMT output" ); ] @@ -134,8 +142,8 @@ let smt_rewrites = let smt_target _ _ out_file ast effect_info env = let open Ast_util in - let props = Property.find_properties ast in - let prop_ids = Bindings.bindings props |> List.map fst |> IdSet.of_list in + let properties = Property.find_properties ast in + let prop_ids = Bindings.bindings properties |> List.map fst |> IdSet.of_list in let ast = Callgraph.filter_ast_ids prop_ids IdSet.empty ast in Specialize.add_initial_calls prop_ids; let ast_smt, env, effect_info = Specialize.(specialize typ_specialization env ast effect_info) in @@ -146,6 +154,25 @@ let smt_target _ _ out_file ast effect_info env = match out_file with Some f -> fun str -> f ^ "_" ^ str ^ ".smt2" | None -> fun str -> str ^ ".smt2" in Reporting.opt_warnings := true; - Jib_smt.generate_smt props name_file env effect_info !opt_includes_smt ast_smt + let cdefs, ctx, register_map = Jib_smt.compile ~unroll_limit:10 env effect_info ast_smt in + let module SMTGen = Jib_smt.Make (struct + let max_unknown_integer_width = !opt_smt_unknown_integer_width + let max_unknown_bitvector_width = !opt_smt_unknown_bitvector_width + let max_unknown_generic_vector_length = !opt_smt_unknown_generic_vector_width + let register_map = register_map + let ignore_overflow = !opt_smt_ignore_overflow + end) in + let module Counterexample = Smt_exp.Counterexample (struct + let max_unknown_integer_width = !opt_smt_unknown_integer_width + end) in + let generated_smt = SMTGen.generate_smt ~properties ~name_file ~smt_includes:!opt_smt_includes ctx cdefs in + if !opt_smt_auto then + List.iter + (fun ({ file_name; function_id; args; arg_ctyps; arg_smt_names } : SMTGen.generated_smt_info) -> + Counterexample.check ~env:ctx.tc_env ~ast ~solver:!opt_smt_auto_solver ~file_name ~function_id ~args ~arg_ctyps + ~arg_smt_names + ) + generated_smt; + () let _ = Target.register ~name:"smt" ~options:smt_options ~rewrites:smt_rewrites smt_target diff --git a/src/sail_smt_backend/smtlib.ml b/src/sail_smt_backend/smtlib.ml deleted file mode 100644 index 33a0c5601..000000000 --- a/src/sail_smt_backend/smtlib.ml +++ /dev/null @@ -1,750 +0,0 @@ -(****************************************************************************) -(* Sail *) -(* *) -(* Sail and the Sail architecture models here, comprising all files and *) -(* directories except the ASL-derived Sail code in the aarch64 directory, *) -(* are subject to the BSD two-clause licence below. *) -(* *) -(* The ASL derived parts of the ARMv8.3 specification in *) -(* aarch64/no_vector and aarch64/full are copyright ARM Ltd. *) -(* *) -(* Copyright (c) 2013-2021 *) -(* Kathyrn Gray *) -(* Shaked Flur *) -(* Stephen Kell *) -(* Gabriel Kerneis *) -(* Robert Norton-Wright *) -(* Christopher Pulte *) -(* Peter Sewell *) -(* Alasdair Armstrong *) -(* Brian Campbell *) -(* Thomas Bauereiss *) -(* Anthony Fox *) -(* Jon French *) -(* Dominic Mulligan *) -(* Stephen Kell *) -(* Mark Wassell *) -(* Alastair Reid (Arm Ltd) *) -(* *) -(* All rights reserved. *) -(* *) -(* This work was partially supported by EPSRC grant EP/K008528/1 REMS: Rigorous *) -(* Engineering for Mainstream Systems, an ARM iCASE award, EPSRC IAA *) -(* KTF funding, and donations from Arm. This project has received *) -(* funding from the European Research Council (ERC) under the European *) -(* Union’s Horizon 2020 research and innovation programme (grant *) -(* agreement No 789108, ELVER). *) -(* *) -(* This software was developed by SRI International and the University of *) -(* Cambridge Computer Laboratory (Department of Computer Science and *) -(* Technology) under DARPA/AFRL contracts FA8650-18-C-7809 ("CIFV") *) -(* and FA8750-10-C-0237 ("CTSRD"). *) -(* *) -(* Redistribution and use in source and binary forms, with or without *) -(* modification, are permitted provided that the following conditions *) -(* are met: *) -(* 1. Redistributions of source code must retain the above copyright *) -(* notice, this list of conditions and the following disclaimer. *) -(* 2. Redistributions in binary form must reproduce the above copyright *) -(* notice, this list of conditions and the following disclaimer in *) -(* the documentation and/or other materials provided with the *) -(* distribution. *) -(* *) -(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) -(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) -(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) -(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) -(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) -(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) -(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) -(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) -(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) -(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) -(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) -(* SUCH DAMAGE. *) -(****************************************************************************) - -open Libsail - -open Ast -open Ast_util - -type counterexample_solver = Cvc5 | Cvc4 | Z3 - -let counterexample_command = function Cvc5 -> "cvc5 --lang=smt2.6" | Cvc4 -> "cvc4 --lang=smt2.6" | Z3 -> "z3" - -let counterexample_solver_from_name name = - match String.lowercase_ascii name with "cvc4" -> Some Cvc4 | "cvc5" -> Some Cvc5 | "z3" -> Some Z3 | _ -> None - -let opt_auto_solver = ref Cvc5 - -type smt_typ = - | Bitvec of int - | Bool - | String - | Real - | Datatype of string * (string * (string * smt_typ) list) list - | Tuple of smt_typ list - | Array of smt_typ * smt_typ - -let rec smt_typ_compare t1 t2 = - match (t1, t2) with - | Bitvec n, Bitvec m -> compare n m - | Bool, Bool -> 0 - | String, String -> 0 - | Real, Real -> 0 - | Datatype (name1, _), Datatype (name2, _) -> String.compare name1 name2 - | Tuple ts1, Tuple ts2 -> Util.lex_ord_list smt_typ_compare ts1 ts2 - | Array (t11, t12), Array (t21, t22) -> - let c = smt_typ_compare t11 t21 in - if c = 0 then smt_typ_compare t12 t22 else c - | Bitvec _, _ -> 1 - | _, Bitvec _ -> -1 - | Bool, _ -> 1 - | _, Bool -> -1 - | String, _ -> 1 - | _, String -> -1 - | Real, _ -> 1 - | _, Real -> -1 - | Datatype _, _ -> 1 - | _, Datatype _ -> -1 - | Tuple _, _ -> 1 - | _, Tuple _ -> -1 - -let rec smt_typ_equal t1 t2 = - match (t1, t2) with - | Bitvec n, Bitvec m -> n = m - | Bool, Bool -> true - | Datatype (name1, ctors1), Datatype (name2, ctors2) -> - let field_equal (field_name1, typ1) (field_name2, typ2) = field_name1 = field_name2 && smt_typ_equal typ1 typ2 in - let ctor_equal (ctor_name1, fields1) (ctor_name2, fields2) = - ctor_name1 = ctor_name2 - && List.length fields1 = List.length fields2 - && List.for_all2 field_equal fields1 fields2 - in - name1 = name2 && List.length ctors1 = List.length ctors2 && List.for_all2 ctor_equal ctors1 ctors2 - | _, _ -> false - -let mk_enum name elems = Datatype (name, List.map (fun elem -> (elem, [])) elems) - -let mk_record name fields = Datatype (name, [(name, fields)]) - -let mk_variant name ctors = Datatype (name, List.map (fun (ctor, ty) -> (ctor, [("un" ^ ctor, ty)])) ctors) - -type smt_exp = - | Bool_lit of bool - | Bitvec_lit of Sail2_values.bitU list - | Real_lit of string - | String_lit of string - | Var of string - | Shared of string - | Read_res of string - | Enum of string - | Fn of string * smt_exp list - | Ctor of string * smt_exp list - | Ite of smt_exp * smt_exp * smt_exp - | SignExtend of int * smt_exp - | Extract of int * int * smt_exp - | Tester of string * smt_exp - | Syntactic of smt_exp * smt_exp list - | Struct of string * (string * smt_exp) list - | Field of string * smt_exp - (* Used by sail-axiomatic, should never be generated by sail -smt! *) - | Forall of (string * smt_typ) list * smt_exp - -let rec fold_smt_exp f = function - | Fn (name, args) -> f (Fn (name, List.map (fold_smt_exp f) args)) - | Ctor (name, args) -> f (Ctor (name, List.map (fold_smt_exp f) args)) - | Ite (cond, t, e) -> f (Ite (fold_smt_exp f cond, fold_smt_exp f t, fold_smt_exp f e)) - | SignExtend (n, exp) -> f (SignExtend (n, fold_smt_exp f exp)) - | Extract (n, m, exp) -> f (Extract (n, m, fold_smt_exp f exp)) - | Tester (ctor, exp) -> f (Tester (ctor, fold_smt_exp f exp)) - | Forall (binders, exp) -> f (Forall (binders, fold_smt_exp f exp)) - | Syntactic (exp, exps) -> f (Syntactic (fold_smt_exp f exp, List.map (fold_smt_exp f) exps)) - | Field (name, exp) -> f (Field (name, fold_smt_exp f exp)) - | Struct (name, fields) -> f (Struct (name, List.map (fun (field, exp) -> (field, fold_smt_exp f exp)) fields)) - | (Bool_lit _ | Bitvec_lit _ | Real_lit _ | String_lit _ | Var _ | Shared _ | Read_res _ | Enum _) as exp -> f exp - -let smt_conj = function [] -> Bool_lit true | [x] -> x | xs -> Fn ("and", xs) - -let smt_disj = function [] -> Bool_lit false | [x] -> x | xs -> Fn ("or", xs) - -let extract i j x = Extract (i, j, x) - -let bvnot x = Fn ("bvnot", [x]) -let bvand x y = Fn ("bvand", [x; y]) -let bvor x y = Fn ("bvor", [x; y]) -let bvneg x = Fn ("bvneg", [x]) -let bvadd x y = Fn ("bvadd", [x; y]) -let bvmul x y = Fn ("bvmul", [x; y]) -let bvudiv x y = Fn ("bvudiv", [x; y]) -let bvurem x y = Fn ("bvurem", [x; y]) -let bvshl x y = Fn ("bvshl", [x; y]) -let bvlshr x y = Fn ("bvlshr", [x; y]) -let bvult x y = Fn ("bvult", [x; y]) - -let bvzero n = Bitvec_lit (Sail2_operators_bitlists.zeros (Big_int.of_int n)) - -let bvones n = Bitvec_lit (Sail2_operators_bitlists.ones (Big_int.of_int n)) - -let simp_equal x y = - match (x, y) with Bitvec_lit bv1, Bitvec_lit bv2 -> Some (Sail2_operators_bitlists.eq_vec bv1 bv2) | _, _ -> None - -let simp_and xs = - let xs = List.filter (function Bool_lit true -> false | _ -> true) xs in - match xs with - | [] -> Bool_lit true - | [x] -> x - | _ -> if List.exists (function Bool_lit false -> true | _ -> false) xs then Bool_lit false else Fn ("and", xs) - -let simp_or xs = - let xs = List.filter (function Bool_lit false -> false | _ -> true) xs in - match xs with - | [] -> Bool_lit false - | [x] -> x - | _ -> if List.exists (function Bool_lit true -> true | _ -> false) xs then Bool_lit true else Fn ("or", xs) - -let rec all_bitvec_lit = function Bitvec_lit _ :: rest -> all_bitvec_lit rest | [] -> true | _ :: _ -> false - -let rec merge_bitvec_lit = function - | Bitvec_lit b :: rest -> b @ merge_bitvec_lit rest - | [] -> [] - | _ :: _ -> assert false - -let simp_fn = function - | Fn ("not", [Fn ("not", [exp])]) -> exp - | Fn ("not", [Bool_lit b]) -> Bool_lit (not b) - | Fn ("contents", [Fn ("Bits", [_; contents])]) -> contents - | Fn ("len", [Fn ("Bits", [len; _])]) -> len - | Fn ("or", xs) -> simp_or xs - | Fn ("and", xs) -> simp_and xs - | Fn ("=>", [Bool_lit true; y]) -> y - | Fn ("=>", [Bool_lit false; y]) -> Bool_lit true - | Fn ("bvsub", [Bitvec_lit bv1; Bitvec_lit bv2]) -> Bitvec_lit (Sail2_operators_bitlists.sub_vec bv1 bv2) - | Fn ("bvadd", [Bitvec_lit bv1; Bitvec_lit bv2]) -> Bitvec_lit (Sail2_operators_bitlists.add_vec bv1 bv2) - | Fn ("concat", xs) when all_bitvec_lit xs -> Bitvec_lit (merge_bitvec_lit xs) - | Fn ("=", [x; y]) as exp -> begin match simp_equal x y with Some b -> Bool_lit b | None -> exp end - | exp -> exp - -let simp_ite = function - | Ite (cond, Bool_lit true, Bool_lit false) -> cond - | Ite (cond, Bool_lit x, Bool_lit y) when x = y -> Bool_lit x - | Ite (_, Var v, Var v') when v = v' -> Var v - | Ite (Bool_lit true, then_exp, _) -> then_exp - | Ite (Bool_lit false, _, else_exp) -> else_exp - | exp -> exp - -let rec simp_smt_exp vars kinds = function - | Var v -> begin match Hashtbl.find_opt vars v with Some exp -> simp_smt_exp vars kinds exp | None -> Var v end - | (Read_res _ | Shared _ | Enum _ | Bitvec_lit _ | Bool_lit _ | String_lit _ | Real_lit _) as exp -> exp - | Field (field, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin - match exp with Struct (_, fields) -> List.assoc field fields | _ -> Field (field, exp) - end - | Struct (name, fields) -> Struct (name, List.map (fun (field, exp) -> (field, simp_smt_exp vars kinds exp)) fields) - | Fn (f, exps) -> - let exps = List.map (simp_smt_exp vars kinds) exps in - simp_fn (Fn (f, exps)) - | Ctor (f, exps) -> - let exps = List.map (simp_smt_exp vars kinds) exps in - simp_fn (Ctor (f, exps)) - | Ite (cond, t, e) -> - let cond = simp_smt_exp vars kinds cond in - let t = simp_smt_exp vars kinds t in - let e = simp_smt_exp vars kinds e in - simp_ite (Ite (cond, t, e)) - | Extract (i, j, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin - match exp with - | Bitvec_lit bv -> - Bitvec_lit (Sail2_operators_bitlists.subrange_vec_dec bv (Big_int.of_int i) (Big_int.of_int j)) - | _ -> Extract (i, j, exp) - end - | Tester (str, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin - match exp with - | Var v -> begin - match Hashtbl.find_opt kinds v with - | Some str' when str = str' -> Bool_lit true - | Some str' -> Bool_lit false - | None -> Tester (str, exp) - end - | _ -> Tester (str, exp) - end - | Syntactic (exp, _) -> exp - | SignExtend (i, exp) -> - let exp = simp_smt_exp vars kinds exp in - begin - match exp with - | Bitvec_lit bv -> Bitvec_lit (Sail2_operators_bitlists.sign_extend bv (Big_int.of_int (i + List.length bv))) - | _ -> SignExtend (i, exp) - end - | Forall (binders, exp) -> Forall (binders, exp) - -type read_info = { - name : string; - node : int; - active : smt_exp; - kind : smt_exp; - addr_type : smt_typ; - addr : smt_exp; - ret_type : smt_typ; - doc : string; -} - -type write_info = { - name : string; - node : int; - active : smt_exp; - kind : smt_exp; - addr_type : smt_typ; - addr : smt_exp; - data_type : smt_typ; - data : smt_exp; - doc : string; -} - -type barrier_info = { name : string; node : int; active : smt_exp; kind : smt_exp; doc : string } - -type branch_info = { name : string; node : int; active : smt_exp; addr_type : smt_typ; addr : smt_exp; doc : string } - -type cache_op_info = { - name : string; - node : int; - active : smt_exp; - kind : smt_exp; - addr_type : smt_typ; - addr : smt_exp; - doc : string; -} - -type smt_def = - | Define_fun of string * (string * smt_typ) list * smt_typ * smt_exp - | Declare_fun of string * smt_typ list * smt_typ - | Declare_const of string * smt_typ - | Define_const of string * smt_typ * smt_exp - (* Same as Define_const, but it'll never be removed by simplification *) - | Preserve_const of string * smt_typ * smt_exp - | Write_mem of write_info - | Write_mem_ea of string * int * smt_exp * smt_exp * smt_exp * smt_typ * smt_exp * smt_typ - | Read_mem of read_info - | Barrier of barrier_info - | Branch_announce of branch_info - | Cache_maintenance of cache_op_info - | Excl_res of string * int * smt_exp - | Declare_datatypes of string * (string * (string * smt_typ) list) list - | Declare_tuple of int - | Assert of smt_exp - -let smt_def_map_exp f = function - | Define_fun (name, args, ty, exp) -> Define_fun (name, args, ty, f exp) - | Declare_fun (name, args, ty) -> Declare_fun (name, args, ty) - | Declare_const (name, ty) -> Declare_const (name, ty) - | Define_const (name, ty, exp) -> Define_const (name, ty, f exp) - | Preserve_const (name, ty, exp) -> Preserve_const (name, ty, f exp) - | Write_mem w -> Write_mem { w with active = f w.active; kind = f w.kind; addr = f w.addr; data = f w.data } - | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> - Write_mem_ea (name, node, f active, f wk, f addr, addr_ty, f data_size, data_size_ty) - | Read_mem r -> Read_mem { r with active = f r.active; kind = f r.kind; addr = f r.addr } - | Barrier b -> Barrier { b with active = f b.active; kind = f b.kind } - | Cache_maintenance m -> Cache_maintenance { m with active = f m.active; kind = f m.kind; addr = f m.addr } - | Branch_announce c -> Branch_announce { c with active = f c.active; addr = f c.addr } - | Excl_res (name, node, active) -> Excl_res (name, node, f active) - | Declare_datatypes (name, ctors) -> Declare_datatypes (name, ctors) - | Declare_tuple n -> Declare_tuple n - | Assert exp -> Assert (f exp) - -let smt_def_iter_exp f = function - | Define_fun (name, args, ty, exp) -> f exp - | Define_const (name, ty, exp) -> f exp - | Preserve_const (name, ty, exp) -> f exp - | Write_mem w -> - f w.active; - f w.kind; - f w.addr; - f w.data - | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> - f active; - f wk; - f addr; - f data_size - | Read_mem r -> - f r.active; - f r.kind; - f r.addr - | Barrier b -> - f b.active; - f b.kind - | Cache_maintenance m -> - f m.active; - f m.kind; - f m.addr - | Branch_announce c -> - f c.active; - f c.addr - | Excl_res (name, node, active) -> f active - | Assert exp -> f exp - | Declare_fun _ | Declare_const _ | Declare_tuple _ | Declare_datatypes _ -> () - -let declare_datatypes = function Datatype (name, ctors) -> Declare_datatypes (name, ctors) | _ -> assert false - -(** For generating SMT with multiple threads (i.e. for litmus tests), - we suffix all the variables in the generated SMT with a thread - identifier to avoid any name clashes between the two threads. *) - -let suffix_variables_exp sfx = - fold_smt_exp (function Var v -> Var (v ^ sfx) | Read_res v -> Read_res (v ^ sfx) | exp -> exp) - -let suffix_variables_read_info sfx (r : read_info) = - let suffix exp = suffix_variables_exp sfx exp in - { r with name = r.name ^ sfx; active = suffix r.active; kind = suffix r.kind; addr = suffix r.addr } - -let suffix_variables_write_info sfx (w : write_info) = - let suffix exp = suffix_variables_exp sfx exp in - { - w with - name = w.name ^ sfx; - active = suffix w.active; - kind = suffix w.kind; - addr = suffix w.addr; - data = suffix w.data; - } - -let suffix_variables_barrier_info sfx (b : barrier_info) = - let suffix exp = suffix_variables_exp sfx exp in - { b with name = b.name ^ sfx; active = suffix b.active; kind = suffix b.kind } - -let suffix_variables_branch_info sfx (c : branch_info) = - let suffix exp = suffix_variables_exp sfx exp in - { c with name = c.name ^ sfx; active = suffix c.active; addr = suffix c.addr } - -let suffix_variables_cache_op_info sfx (m : cache_op_info) = - let suffix exp = suffix_variables_exp sfx exp in - { m with name = m.name ^ sfx; kind = suffix m.kind; active = suffix m.active; addr = suffix m.addr } - -let suffix_variables_def sfx = function - | Define_fun (name, args, ty, exp) -> - Define_fun (name ^ sfx, List.map (fun (arg, ty) -> (sfx ^ arg, ty)) args, ty, suffix_variables_exp sfx exp) - | Declare_fun (name, tys, ty) -> Declare_fun (name ^ sfx, tys, ty) - | Declare_const (name, ty) -> Declare_const (name ^ sfx, ty) - | Define_const (name, ty, exp) -> Define_const (name ^ sfx, ty, suffix_variables_exp sfx exp) - | Preserve_const (name, ty, exp) -> Preserve_const (name, ty, suffix_variables_exp sfx exp) - | Write_mem w -> Write_mem (suffix_variables_write_info sfx w) - | Write_mem_ea (name, node, active, wk, addr, addr_ty, data_size, data_size_ty) -> - Write_mem_ea - ( name ^ sfx, - node, - suffix_variables_exp sfx active, - suffix_variables_exp sfx wk, - suffix_variables_exp sfx addr, - addr_ty, - suffix_variables_exp sfx data_size, - data_size_ty - ) - | Read_mem r -> Read_mem (suffix_variables_read_info sfx r) - | Barrier b -> Barrier (suffix_variables_barrier_info sfx b) - | Cache_maintenance m -> Cache_maintenance (suffix_variables_cache_op_info sfx m) - | Branch_announce c -> Branch_announce (suffix_variables_branch_info sfx c) - | Excl_res (name, node, active) -> Excl_res (name ^ sfx, node, suffix_variables_exp sfx active) - | Declare_datatypes (name, ctors) -> Declare_datatypes (name, ctors) - | Declare_tuple n -> Declare_tuple n - | Assert exp -> Assert (suffix_variables_exp sfx exp) - -let pp_sfun str docs = - let open PPrint in - parens (separate space (string str :: docs)) - -let rec pp_smt_typ = - let open PPrint in - function - | Bool -> string "Bool" - | String -> string "String" - | Real -> string "Real" - | Bitvec n -> string (Printf.sprintf "(_ BitVec %d)" n) - | Datatype (name, _) -> string name - | Tuple tys -> pp_sfun ("Tup" ^ string_of_int (List.length tys)) (List.map pp_smt_typ tys) - | Array (ty1, ty2) -> pp_sfun "Array" [pp_smt_typ ty1; pp_smt_typ ty2] - -let pp_str_smt_typ (str, ty) = - let open PPrint in - parens (string str ^^ space ^^ pp_smt_typ ty) - -let rec pp_smt_exp = - let open PPrint in - function - | Bool_lit b -> string (string_of_bool b) - | Real_lit str -> string str - | String_lit str -> string ("\"" ^ str ^ "\"") - | Bitvec_lit bv -> string (Sail2_values.show_bitlist_prefix '#' bv) - | Var str -> string str - | Shared str -> string str - | Read_res str -> string (str ^ "_ret") - | Enum str -> string str - | Fn (str, exps) -> parens (string str ^^ space ^^ separate_map space pp_smt_exp exps) - | Field (str, exp) -> parens (string str ^^ space ^^ pp_smt_exp exp) - | Struct (str, fields) -> parens (string str ^^ space ^^ separate_map space (fun (_, exp) -> pp_smt_exp exp) fields) - | Ctor (str, exps) -> parens (string str ^^ space ^^ separate_map space pp_smt_exp exps) - | Ite (cond, then_exp, else_exp) -> - parens (separate space [string "ite"; pp_smt_exp cond; pp_smt_exp then_exp; pp_smt_exp else_exp]) - | Extract (i, j, exp) -> parens (string (Printf.sprintf "(_ extract %d %d)" i j) ^^ space ^^ pp_smt_exp exp) - | Tester (kind, exp) -> parens (string (Printf.sprintf "(_ is %s)" kind) ^^ space ^^ pp_smt_exp exp) - | SignExtend (i, exp) -> parens (string (Printf.sprintf "(_ sign_extend %d)" i) ^^ space ^^ pp_smt_exp exp) - | Syntactic (exp, _) -> pp_smt_exp exp - | Forall (binders, exp) -> - parens (string "forall" ^^ space ^^ parens (separate_map space pp_str_smt_typ binders) ^^ space ^^ pp_smt_exp exp) - -let pp_smt_def = - let open PPrint in - let open Printf in - function - | Define_fun (name, args, ty, exp) -> - parens - (string "define-fun" ^^ space ^^ string name ^^ space - ^^ parens (separate_map space pp_str_smt_typ args) - ^^ space ^^ pp_smt_typ ty ^//^ pp_smt_exp exp - ) - | Declare_fun (name, args, ty) -> - parens - (string "declare-fun" ^^ space ^^ string name ^^ space - ^^ parens (separate_map space pp_smt_typ args) - ^^ space ^^ pp_smt_typ ty - ) - | Declare_const (name, ty) -> pp_sfun "declare-const" [string name; pp_smt_typ ty] - | Define_const (name, ty, exp) | Preserve_const (name, ty, exp) -> - pp_sfun "define-const" [string name; pp_smt_typ ty; pp_smt_exp exp] - | Write_mem w -> - pp_sfun "define-const" [string (w.name ^ "_kind"); string "Zwrite_kind"; pp_smt_exp w.kind] - ^^ hardline - ^^ pp_sfun "define-const" [string (w.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp w.active] - ^^ hardline - ^^ pp_sfun "define-const" [string (w.name ^ "_data"); pp_smt_typ w.data_type; pp_smt_exp w.data] - ^^ hardline - ^^ pp_sfun "define-const" [string (w.name ^ "_addr"); pp_smt_typ w.addr_type; pp_smt_exp w.addr] - ^^ hardline - ^^ pp_sfun "declare-const" [string (w.name ^ "_ret"); pp_smt_typ Bool] - | Write_mem_ea (name, _, active, wk, addr, addr_ty, data_size, data_size_ty) -> - pp_sfun "define-const" [string (name ^ "_kind"); string "Zwrite_kind"; pp_smt_exp wk] - ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_active"); pp_smt_typ Bool; pp_smt_exp active] - ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_size"); pp_smt_typ data_size_ty; pp_smt_exp data_size] - ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_addr"); pp_smt_typ addr_ty; pp_smt_exp addr] - | Read_mem r -> - pp_sfun "define-const" [string (r.name ^ "_kind"); string "Zread_kind"; pp_smt_exp r.kind] - ^^ hardline - ^^ pp_sfun "define-const" [string (r.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp r.active] - ^^ hardline - ^^ pp_sfun "define-const" [string (r.name ^ "_addr"); pp_smt_typ r.addr_type; pp_smt_exp r.addr] - ^^ hardline - ^^ pp_sfun "declare-const" [string (r.name ^ "_ret"); pp_smt_typ r.ret_type] - | Barrier b -> - pp_sfun "define-const" [string (b.name ^ "_kind"); string "Zbarrier_kind"; pp_smt_exp b.kind] - ^^ hardline - ^^ pp_sfun "define-const" [string (b.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp b.active] - | Cache_maintenance m -> - pp_sfun "define-const" [string (m.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp m.active] - ^^ hardline - ^^ pp_sfun "define-const" [string (m.name ^ "_kind"); string "Zcache_op_kind"; pp_smt_exp m.kind] - ^^ hardline - ^^ pp_sfun "define-const" [string (m.name ^ "_addr"); pp_smt_typ m.addr_type; pp_smt_exp m.addr] - | Branch_announce c -> - pp_sfun "define-const" [string (c.name ^ "_active"); pp_smt_typ Bool; pp_smt_exp c.active] - ^^ hardline - ^^ pp_sfun "define-const" [string (c.name ^ "_addr"); pp_smt_typ c.addr_type; pp_smt_exp c.addr] - | Excl_res (name, _, active) -> - pp_sfun "declare-const" [string (name ^ "_res"); pp_smt_typ Bool] - ^^ hardline - ^^ pp_sfun "define-const" [string (name ^ "_active"); pp_smt_typ Bool; pp_smt_exp active] - | Declare_datatypes (name, ctors) -> - let pp_ctor (ctor_name, fields) = - match fields with [] -> parens (string ctor_name) | _ -> pp_sfun ctor_name (List.map pp_str_smt_typ fields) - in - pp_sfun "declare-datatypes" - [Printf.ksprintf string "((%s 0))" name; parens (parens (separate_map space pp_ctor ctors))] - | Declare_tuple n -> - let par = separate_map space string (Util.list_init n (fun i -> "T" ^ string_of_int i)) in - let fields = separate space (Util.list_init n (fun i -> Printf.ksprintf string "(tup_%d_%d T%d)" n i i)) in - pp_sfun "declare-datatypes" - [ - Printf.ksprintf string "((Tup%d %d))" n n; - parens - (parens - (separate space - [string "par"; parens par; parens (parens (ksprintf string "tup%d" n ^^ space ^^ fields))] - ) - ); - ] - | Assert exp -> pp_sfun "assert" [pp_smt_exp exp] - -let string_of_smt_def def = Pretty_print_sail.Document.to_string (pp_smt_def def) - -let output_smt_defs out_chan smt = List.iter (fun def -> output_string out_chan (string_of_smt_def def ^ "\n")) smt - -(**************************************************************************) -(* 2. Parser for SMT solver output *) -(**************************************************************************) - -(* Output format from each SMT solver does not seem to be completely - standardised, unlike the SMTLIB input format, but usually is some - form of s-expression based representation. Therefore we define a - simple parser for s-expressions using monadic parser combinators. *) - -type sexpr = List of sexpr list | Atom of string - -let rec string_of_sexpr = function - | List sexprs -> "(" ^ Util.string_of_list " " string_of_sexpr sexprs ^ ")" - | Atom str -> str - -open Parser_combinators - -let lparen = token (function Str.Delim "(" -> Some () | _ -> None) -let rparen = token (function Str.Delim ")" -> Some () | _ -> None) -let atom = token (function Str.Text str -> Some str | _ -> None) - -let rec sexp toks = - let parse = - pchoose - (atom >>= fun str -> preturn (Atom str)) - ( lparen >>= fun _ -> - plist sexp >>= fun xs -> - rparen >>= fun _ -> preturn (List xs) - ) - in - parse toks - -let parse_sexps input = - let delim = Str.regexp "[ \n\t]+\\|(\\|)" in - let tokens = Str.full_split delim input in - let non_whitespace = function Str.Delim d when String.trim d = "" -> false | _ -> true in - let tokens = List.filter non_whitespace tokens in - match plist sexp tokens with Ok (result, _) -> result | Fail -> failwith "Parse failure" - -let parse_sexpr_int width = function - | List [Atom "_"; Atom v; Atom m] when int_of_string m = width && String.length v > 2 && String.sub v 0 2 = "bv" -> - let v = String.sub v 2 (String.length v - 2) in - Some (Big_int.of_string v) - | Atom v when String.length v > 2 && String.sub v 0 2 = "#b" -> - let v = String.sub v 2 (String.length v - 2) in - Some (Big_int.of_string ("0b" ^ v)) - | Atom v when String.length v > 2 && String.sub v 0 2 = "#x" -> - let v = String.sub v 2 (String.length v - 2) in - Some (Big_int.of_string ("0x" ^ v)) - | _ -> None - -let rec value_of_sexpr sexpr = - let open Jib in - let open Value in - function - | CT_fbits width -> begin - match parse_sexpr_int width sexpr with - | Some value -> mk_vector (Sail_lib.get_slice_int' (width, value, 0)) - | None -> failwith ("Cannot parse sexpr as bitvector: " ^ string_of_sexpr sexpr) - end - | CT_struct (_, fields) -> begin - match sexpr with - | List (Atom name :: smt_fields) -> - V_record - (List.fold_left2 - (fun m (field_id, ctyp) sexpr -> StringMap.add (string_of_id field_id) (value_of_sexpr sexpr ctyp) m) - StringMap.empty fields smt_fields - ) - | _ -> failwith ("Cannot parse sexpr as struct " ^ string_of_sexpr sexpr) - end - | CT_enum (_, members) -> begin - match sexpr with - | Atom name -> begin - match List.find_opt (fun member -> Util.zencode_string (string_of_id member) = name) members with - | Some member -> V_member (string_of_id member) - | None -> - failwith - ("Could not find enum member for " ^ name ^ " in " ^ Util.string_of_list ", " string_of_id members) - end - | _ -> failwith ("Cannot parse sexpr as enum " ^ string_of_sexpr sexpr) - end - | CT_bool -> begin - match sexpr with - | Atom "true" -> V_bool true - | Atom "false" -> V_bool false - | _ -> failwith ("Cannot parse sexpr as bool " ^ string_of_sexpr sexpr) - end - | CT_fint width -> begin - match parse_sexpr_int width sexpr with - | Some value -> V_int value - | None -> failwith ("Cannot parse sexpr as fixed-width integer: " ^ string_of_sexpr sexpr) - end - | ctyp -> failwith ("Unsupported type in sexpr: " ^ Jib_util.string_of_ctyp ctyp) - -let rec find_arg id ctyp arg_smt_names = function - | List [Atom "define-fun"; Atom str; List []; _; value] :: _ - when Util.assoc_compare_opt Id.compare id arg_smt_names = Some (Some str) -> - (id, value_of_sexpr value ctyp) - | _ :: sexps -> find_arg id ctyp arg_smt_names sexps - | [] -> (id, V_unit) - -let build_counterexample args arg_ctyps arg_smt_names model = - List.map2 (fun id ctyp -> find_arg id ctyp arg_smt_names model) args arg_ctyps - -let rec run frame = - match frame with - | Interpreter.Done (state, v) -> Some v - | Interpreter.Step (lazy_str, _, _, _) -> run (Interpreter.eval_frame frame) - | Interpreter.Break frame -> run (Interpreter.eval_frame frame) - | Interpreter.Fail (_, _, _, _, msg) -> None - | Interpreter.Effect_request (out, state, stack, eff) -> run (Interpreter.default_effect_interp state eff) - -let check_counterexample ast env fname function_id args arg_ctyps arg_smt_names = - let open Printf in - print_endline ("Checking counterexample: " ^ fname); - let in_chan = ksprintf Unix.open_process_in "%s %s" (counterexample_command !opt_auto_solver) fname in - let lines = ref [] in - begin - try - while true do - lines := input_line in_chan :: !lines - done - with End_of_file -> () - end; - let solver_output = List.rev !lines |> String.concat "\n" in - begin - match solver_output |> parse_sexps with - | Atom "sat" :: (List (Atom "model" :: model) | List model) :: _ -> - let open Value in - let open Interpreter in - print_endline (sprintf "Solver found counterexample: %s" Util.("ok" |> green |> clear)); - let counterexample = build_counterexample args arg_ctyps arg_smt_names model in - List.iter (fun (id, v) -> print_endline (" " ^ string_of_id id ^ " -> " ^ string_of_value v)) counterexample; - let istate = initial_state ast env !primops in - let annot = (Parse_ast.Unknown, Type_check.mk_tannot env bool_typ) in - let call = - E_aux - ( E_app - ( function_id, - List.map - (fun (_, v) -> E_aux (E_internal_value v, (Parse_ast.Unknown, Type_check.empty_tannot))) - counterexample - ), - annot - ) - in - let result = run (Step (lazy "", istate, return call, [])) in - begin - match result with - | Some (V_bool false) | None -> - ksprintf print_endline "Replaying counterexample: %s" Util.("ok" |> green |> clear) - | _ -> () - end - | Atom "unsat" :: _ -> - print_endline "Solver could not find counterexample"; - print_endline "Solver output:"; - print_endline solver_output - | _ -> - print_endline "Unexpected solver output:"; - print_endline solver_output - end; - let status = Unix.close_process_in in_chan in - () diff --git a/src/sail_sv_backend/generate_primop2.ml b/src/sail_sv_backend/generate_primop2.ml new file mode 100644 index 000000000..c08c193de --- /dev/null +++ b/src/sail_sv_backend/generate_primop2.ml @@ -0,0 +1,128 @@ +(**************************************************************************) +(* Sail to verilog *) +(* *) +(* Copyright (c) 2023 *) +(* Alasdair Armstrong *) +(* *) +(* All rights reserved. *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(**************************************************************************) + +open Libsail + +open Ast_util +open Jib +open Jib_util +open Printf +open Sv_ir + +module StringSet = Set.Make (String) + +let generated_library_defs = ref (StringSet.empty, []) + +let register_library_def name def = + let names, _ = !generated_library_defs in + if StringSet.mem name names then name + else ( + let source = def () in + let names, defs = !generated_library_defs in + generated_library_defs := (StringSet.add name names, source :: defs); + name + ) + +let get_generated_library_defs () = List.rev (snd !generated_library_defs) + +let primop_name s = Jib_util.name (mk_id s) + +let print_fbits width = + let name = sprintf "sail_print_fixed_bits_%d" width in + register_library_def name (fun () -> + let b = primop_name "b" in + let s = primop_name "s" in + let in_str = primop_name "in_str" in + let out_str = primop_name "out_str" in + let always_comb = + (* If the width is a multiple of four, format as hexadecimal. + We take care to ensure the formatting is identical to other + Sail backends. *) + if width mod 4 = 0 then ( + let zeros_init = String.make (width / 4) '0' in + let zeros = Jib_util.name (mk_id "zeros") in + let bstr = Jib_util.name (mk_id "bstr") in + [ + SVS_var (zeros, CT_string, None); + SVS_var (bstr, CT_string, None); + svs_raw "bstr.hextoa(b)" ~inputs:[b] ~outputs:[bstr]; + svs_raw (sprintf "zeros = \"%s\"" zeros_init) ~inputs:[zeros]; + svs_raw + (sprintf + "out_str = {in_str, s, $sformatf(\"0x%%s\", zeros.substr(0, %d - bstr.len()), bstr.toupper()), \ + \"\\n\"}" + ((width / 4) - 1) + ) + ~inputs:[in_str; s; zeros; bstr] ~outputs:[out_str]; + SVS_assign (SVP_id Jib_util.return, Enum "unit"); + ] + |> List.map mk_statement + ) + else + [ + svs_raw "out_str = {in_str, s, $sformatf(\"0b%b\", b), \"\\n\"}" ~inputs:[in_str; s; b] ~outputs:[out_str] + |> mk_statement; + ] + in + SVD_module + { + name = SVN_string name; + input_ports = [mk_port s CT_string; mk_port b (CT_fbits width); mk_port in_str CT_string]; + output_ports = [mk_port Jib_util.return CT_unit; mk_port out_str CT_string]; + defs = [SVD_always_comb (mk_statement (SVS_block always_comb))]; + } + ) + +let binary_module l gen = + Some + (fun args ret_ctyp -> + match (args, ret_ctyp) with + | [v1; v2], ret_ctyp -> gen v1 v2 ret_ctyp + | _ -> Reporting.unreachable l __POS__ "Incorrect arity given to binary module generator" + ) + +let ternary_module l gen = + Some + (fun args ret_ctyp -> + match (args, ret_ctyp) with + | [v1; v2; v3], ret_ctyp -> gen v1 v2 v3 ret_ctyp + | _ -> Reporting.unreachable l __POS__ "Incorrect arity given to binary module generator" + ) + +let generate_module ~at:l = function + | "print_bits" -> + ternary_module l (fun _ v2 _ _ -> + match cval_ctyp v2 with + | CT_fbits width -> print_fbits width + | _ -> Reporting.unreachable l __POS__ "Invalid types given to print_bits generator" + ) + | _ -> None diff --git a/src/sail_sv_backend/jib_sv.ml b/src/sail_sv_backend/jib_sv.ml index f7d69bbfb..c508bf94b 100644 --- a/src/sail_sv_backend/jib_sv.ml +++ b/src/sail_sv_backend/jib_sv.ml @@ -72,11 +72,153 @@ open Ast_util open Jib open Jib_compile open Jib_util +open Jib_visitor open PPrint open Printf open Smt_exp open Generate_primop +open Sv_ir + +module IntSet = Util.IntSet + +class footprint_visitor ctx registers reads writes need_stdout need_stderr : jib_visitor = + object + inherit empty_jib_visitor + + method! vctyp _ = SkipChildren + + method! vcval = + function + | V_id (Name (id, _), local_ctyp) -> + begin + match Bindings.find_opt id registers with + | Some ctyp -> + assert (ctyp_equal local_ctyp ctyp); + prerr_endline Util.(string_of_id id |> green |> clear); + reads := IdSet.add id !reads + | None -> () + end; + SkipChildren + | _ -> DoChildren + + method! vinstr = + function + | I_aux (I_funcall (_, _, (id, _), _), _) -> + if ctx_is_extern id ctx then ( + let name = ctx_get_extern id ctx in + prerr_endline ("NAME: " ^ name); + if name = "print" || name = "print_endline" || name = "print_bits" then need_stdout := true + else if name = "prerr" || name = "prerr_endline" || name = "prerr_bits" then need_stderr := true + ); + DoChildren + | _ -> DoChildren + + method! vclexp = + function + | CL_id (Name (id, _), local_ctyp) -> + begin + match Bindings.find_opt id registers with + | Some ctyp -> + assert (ctyp_equal local_ctyp ctyp); + prerr_endline Util.(string_of_id id |> yellow |> clear); + writes := IdSet.add id !writes + | None -> () + end; + SkipChildren + | _ -> DoChildren + end + +type footprint = { + direct_reads : IdSet.t; + direct_writes : IdSet.t; + all_reads : IdSet.t; + all_writes : IdSet.t; + need_stdout : bool; + need_stderr : bool; +} + +type spec_info = { + (* A map from register types to all the registers with that type *) + register_ctyp_map : IdSet.t CTMap.t; + (* A map from register names to types *) + registers : ctyp Bindings.t; + (* Function footprint information *) + footprints : footprint Bindings.t; + (* Specification callgraph *) + callgraph : IdGraph.graph; +} + +let collect_spec_info ctx cdefs = + let register_ctyp_map, registers = + List.fold_left + (fun (ctyp_map, regs) cdef -> + match cdef with + | CDEF_aux (CDEF_register (id, ctyp, _), _) -> + ( CTMap.update ctyp + (function Some ids -> Some (IdSet.add id ids) | None -> Some (IdSet.singleton id)) + ctyp_map, + Bindings.add id ctyp regs + ) + | _ -> (ctyp_map, regs) + ) + (CTMap.empty, Bindings.empty) cdefs + in + let footprints = + List.fold_left + (fun footprints cdef -> + match cdef with + | CDEF_aux (CDEF_fundef (f, _, _, body), _) -> + let reads = ref IdSet.empty in + let writes = ref IdSet.empty in + let need_stdout = ref false in + let need_stderr = ref false in + let _ = visit_cdef (new footprint_visitor ctx registers reads writes need_stdout need_stderr) cdef in + Bindings.add f + { + direct_reads = !reads; + direct_writes = !writes; + all_reads = IdSet.empty; + all_writes = IdSet.empty; + need_stdout = !need_stdout; + need_stderr = !need_stderr; + } + footprints + | _ -> footprints + ) + Bindings.empty cdefs + in + let cfg = callgraph cdefs in + let footprints = + List.fold_left + (fun footprints cdef -> + match cdef with + | CDEF_aux (CDEF_fundef (f, _, _, body), _) -> + let footprint = Bindings.find f footprints in + let callees = cfg |> IdGraph.reachable (IdSet.singleton f) IdSet.empty |> IdSet.remove f in + let all_reads, all_writes, need_stdout, need_stderr = + List.fold_left + (fun (all_reads, all_writes, need_stdout, need_stderr) callee -> + match Bindings.find_opt callee footprints with + | Some footprint -> + ( IdSet.union all_reads footprint.direct_reads, + IdSet.union all_writes footprint.direct_writes, + need_stdout || footprint.need_stdout, + need_stderr || footprint.need_stderr + ) + | _ -> (all_reads, all_writes, need_stdout, need_stderr) + ) + (footprint.direct_reads, footprint.direct_writes, footprint.need_stdout, footprint.need_stderr) + (IdSet.elements callees) + in + Bindings.update f + (fun _ -> Some { footprint with all_reads; all_writes; need_stdout; need_stderr }) + footprints + | _ -> footprints + ) + footprints cdefs + in + { register_ctyp_map; registers; footprints; callgraph = cfg } module type CONFIG = sig val max_unknown_integer_width : int @@ -112,7 +254,7 @@ module Make (Config : CONFIG) = struct in Str.string_match regexp s 0 - let sv_id_string id = + let pp_id_string id = let s = string_of_id id in if valid_sv_identifier s @@ -122,9 +264,11 @@ module Make (Config : CONFIG) = struct then s else Util.zencode_string s - let sv_id id = string (sv_id_string id) + let pp_id id = string (pp_id_string id) + + let pp_sv_name = function SVN_id id -> pp_id id | SVN_string s -> string s - let sv_type_id_string id = "t_" ^ sv_id_string id + let sv_type_id_string id = "t_" ^ pp_id_string id let sv_type_id id = string (sv_type_id_string id) @@ -185,7 +329,9 @@ module Make (Config : CONFIG) = struct (struct let max_unknown_integer_width = Config.max_unknown_integer_width let max_unknown_bitvector_width = Config.max_unknown_bitvector_width + let max_unknown_generic_vector_length = 32 let union_ctyp_classify = is_packed + let register_ref reg_name = Fn ("reg_ref", [String_lit reg_name]) end) (struct let print_bits l = function @@ -239,13 +385,18 @@ module Make (Config : CONFIG) = struct let ( let* ) = Smt_gen.bind let return = Smt_gen.return let mapM = Smt_gen.mapM + let fmap = Smt_gen.fmap - let sv_name = function - | Name (id, _) -> sv_id id - | Have_exception _ -> string "sail_have_exception" - | Current_exception _ -> string "sail_current_exception" - | Throw_location _ -> string "sail_throw_location" - | Return _ -> string "sail_return" + let pp_name = + let ssa_num n = if n = -1 then empty else string ("_" ^ string_of_int n) in + function + | Name (id, n) -> pp_id id ^^ ssa_num n + | Have_exception n -> string "sail_have_exception" ^^ ssa_num n + | Current_exception n -> string "sail_current_exception" ^^ ssa_num n + | Throw_location n -> string "sail_throw_location" ^^ ssa_num n + | Channel (Chan_stdout, n) -> string "sail_stdout" ^^ ssa_num n + | Channel (Chan_stderr, n) -> string "sail_stderr" ^^ ssa_num n + | Return n -> string "sail_return" ^^ ssa_num n let wrap_type ctyp doc = match sv_ctyp ctyp with @@ -269,10 +420,10 @@ module Make (Config : CONFIG) = struct let sv_type_def = function | CTD_enum (id, ids) -> string "typedef" ^^ space ^^ string "enum" ^^ space - ^^ group (lbrace ^^ nest 4 (hardline ^^ separate_map (comma ^^ hardline) sv_id ids) ^^ hardline ^^ rbrace) + ^^ group (lbrace ^^ nest 4 (hardline ^^ separate_map (comma ^^ hardline) pp_id ids) ^^ hardline ^^ rbrace) ^^ space ^^ sv_type_id id ^^ semi | CTD_struct (id, fields) -> - let sv_field (id, ctyp) = wrap_type ctyp (sv_id id) in + let sv_field (id, ctyp) = wrap_type ctyp (pp_id id) in let can_be_packed = List.for_all (fun (_, ctyp) -> is_packed ctyp) fields in string "typedef" ^^ space ^^ string "struct" ^^ (if can_be_packed then space ^^ string "packed" else empty) @@ -290,9 +441,9 @@ module Make (Config : CONFIG) = struct else empty in let kind_id (id, _) = string_of_id id |> Util.zencode_string |> String.uppercase_ascii |> string in - let sv_ctor (id, ctyp) = wrap_type ctyp (sv_id id) in - let tag_type = string ("sailtag_" ^ sv_id_string id) in - let value_type = string ("sailunion_" ^ sv_id_string id) in + let sv_ctor (id, ctyp) = wrap_type ctyp (pp_id id) in + let tag_type = string ("sailtag_" ^ pp_id_string id) in + let value_type = string ("sailunion_" ^ pp_id_string id) in let kind_enum = separate space [ @@ -313,7 +464,7 @@ module Make (Config : CONFIG) = struct let padding_structs = List.map (fun (ctor_id, ctyp) -> - let padding_type = string ("sailpadding_" ^ sv_id_string ctor_id) in + let padding_type = string ("sailpadding_" ^ pp_id_string ctor_id) in let required_padding = max_width - Option.get (bit_width ctyp) in let padded = separate space @@ -345,12 +496,12 @@ module Make (Config : CONFIG) = struct if Config.union_padding then List.map (fun (_, (ctor_id, ctyp, padding_type, required_padding)) -> - separate space [string "function"; string "automatic"; sv_type_id id; sv_id ctor_id] + separate space [string "function"; string "automatic"; sv_type_id id; pp_id ctor_id] ^^ parens (wrap_type ctyp (char 'v')) ^^ semi ^^ nest 4 (hardline ^^ sv_type_id id ^^ space ^^ char 'r' ^^ semi ^^ hardline - ^^ string ("sailunion_" ^ sv_id_string id) + ^^ string ("sailunion_" ^ pp_id_string id) ^^ space ^^ char 'u' ^^ semi ^^ hardline ^^ padding_type ^^ space ^^ char 'p' ^^ semi ^^ hardline ^^ separate space [ @@ -359,7 +510,7 @@ module Make (Config : CONFIG) = struct string_of_id ctor_id |> Util.zencode_string |> String.uppercase_ascii |> string; ] ^^ semi ^^ hardline - ^^ separate space [char 'p' ^^ dot ^^ sv_id ctor_id; equals; char 'v'] + ^^ separate space [char 'p' ^^ dot ^^ pp_id ctor_id; equals; char 'v'] ^^ semi ^^ hardline ^^ ( if required_padding > 0 then separate space @@ -371,7 +522,7 @@ module Make (Config : CONFIG) = struct ^^ semi ^^ hardline else empty ) - ^^ separate space [char 'u' ^^ dot ^^ sv_id ctor_id; equals; char 'p'] + ^^ separate space [char 'u' ^^ dot ^^ pp_id ctor_id; equals; char 'p'] ^^ semi ^^ hardline ^^ separate space [string "r.value"; equals; char 'u'] ^^ semi ^^ hardline ^^ string "return" ^^ space ^^ char 'r' ^^ semi @@ -382,12 +533,12 @@ module Make (Config : CONFIG) = struct else List.map (fun (ctor_id, ctyp) -> - separate space [string "function"; string "automatic"; sv_type_id id; sv_id ctor_id] + separate space [string "function"; string "automatic"; sv_type_id id; pp_id ctor_id] ^^ parens (wrap_type ctyp (char 'v')) ^^ semi ^^ nest 4 (hardline ^^ sv_type_id id ^^ space ^^ char 'r' ^^ semi ^^ hardline - ^^ string ("sailunion_" ^ sv_id_string id) + ^^ string ("sailunion_" ^ pp_id_string id) ^^ space ^^ char 'u' ^^ semi ^^ hardline ^^ separate space [ @@ -396,7 +547,7 @@ module Make (Config : CONFIG) = struct string_of_id ctor_id |> Util.zencode_string |> String.uppercase_ascii |> string; ] ^^ semi ^^ hardline - ^^ separate space [char 'u' ^^ dot ^^ sv_id ctor_id; equals; char 'v'] + ^^ separate space [char 'u' ^^ dot ^^ pp_id ctor_id; equals; char 'v'] ^^ semi ^^ hardline ^^ separate space [string "r.value"; equals; char 'u'] ^^ semi ^^ hardline ^^ string "return" ^^ space ^^ char 'r' ^^ semi @@ -405,7 +556,7 @@ module Make (Config : CONFIG) = struct ) ctors in - let sv_padded_ctor (_, (ctor_id, _, padding_type, _)) = padding_type ^^ space ^^ sv_id ctor_id in + let sv_padded_ctor (_, (ctor_id, _, padding_type, _)) = padding_type ^^ space ^^ pp_id ctor_id in (if Config.union_padding then separate_map (twice hardline) fst padding_structs ^^ twice hardline else empty) ^^ separate space [ @@ -448,7 +599,7 @@ module Make (Config : CONFIG) = struct let constructors = List.map (fun (ctor_id, ctyp) -> - separate space [string "function"; string "automatic"; sv_type_id id; sv_id ctor_id] + separate space [string "function"; string "automatic"; sv_type_id id; pp_id ctor_id] ^^ parens (wrap_type ctyp (char 'v')) ^^ semi ^^ nest 4 @@ -460,7 +611,7 @@ module Make (Config : CONFIG) = struct string_of_id ctor_id |> Util.zencode_string |> String.uppercase_ascii |> string; ] ^^ semi ^^ hardline - ^^ separate space [char 'r' ^^ dot ^^ sv_id ctor_id; equals; char 'v'] + ^^ separate space [char 'r' ^^ dot ^^ pp_id ctor_id; equals; char 'v'] ^^ semi ^^ hardline ^^ string "return" ^^ space ^^ char 'r' ^^ semi ) ^^ hardline ^^ string "endfunction" @@ -519,8 +670,8 @@ module Make (Config : CONFIG) = struct | _ -> None (* Convert a SMTLIB expression into SystemVerilog *) - let rec sv_smt ?(need_parens = false) = - let sv_smt_parens exp = sv_smt ~need_parens:true exp in + let rec pp_smt ?(need_parens = false) = + let pp_smt_parens exp = pp_smt ~need_parens:true exp in let opt_parens doc = if need_parens then parens doc else doc in function | Bitvec_lit bits -> @@ -532,90 +683,91 @@ module Make (Config : CONFIG) = struct | String_lit s -> if Config.nostrings then string "SAIL_UNIT" else ksprintf string "\"%s\"" s | Enum "unit" -> string "SAIL_UNIT" | Fn ("reg_ref", [String_lit r]) -> ksprintf string "SAIL_REG_%s" (Util.zencode_upper_string r) - | Fn ("Bits", [size; bv]) -> squote ^^ lbrace ^^ sv_smt size ^^ comma ^^ space ^^ sv_smt bv ^^ rbrace - | Fn ("concat", xs) -> lbrace ^^ separate_map (comma ^^ space) sv_smt xs ^^ rbrace - | Fn ("not", [Fn ("not", [x])]) -> sv_smt ~need_parens x - | Fn ("not", [Fn ("=", [x; y])]) -> opt_parens (separate space [sv_smt_parens x; string "!="; sv_smt_parens y]) - | Fn ("not", [x]) -> opt_parens (char '!' ^^ sv_smt_parens x) - | Fn ("=", [x; y]) -> opt_parens (separate space [sv_smt_parens x; string "=="; sv_smt_parens y]) - | Fn ("and", xs) -> opt_parens (separate_map (space ^^ string "&&" ^^ space) sv_smt_parens xs) - | Fn ("or", xs) -> opt_parens (separate_map (space ^^ string "||" ^^ space) sv_smt_parens xs) - | Fn ("bvnot", [x]) -> opt_parens (char '~' ^^ sv_smt_parens x) - | Fn ("bvneg", [x]) -> opt_parens (char '-' ^^ sv_smt_parens x) - | Fn ("bvand", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '&'; sv_smt_parens y]) + | Fn ("Bits", [size; bv]) -> squote ^^ lbrace ^^ pp_smt size ^^ comma ^^ space ^^ pp_smt bv ^^ rbrace + | Fn ("concat", xs) -> lbrace ^^ separate_map (comma ^^ space) pp_smt xs ^^ rbrace + | Fn ("not", [Fn ("not", [x])]) -> pp_smt ~need_parens x + | Fn ("not", [Fn ("=", [x; y])]) -> opt_parens (separate space [pp_smt_parens x; string "!="; pp_smt_parens y]) + | Fn ("not", [x]) -> opt_parens (char '!' ^^ pp_smt_parens x) + | Fn ("=", [x; y]) -> opt_parens (separate space [pp_smt_parens x; string "=="; pp_smt_parens y]) + | Fn ("and", xs) -> opt_parens (separate_map (space ^^ string "&&" ^^ space) pp_smt_parens xs) + | Fn ("or", xs) -> opt_parens (separate_map (space ^^ string "||" ^^ space) pp_smt_parens xs) + | Fn ("bvnot", [x]) -> opt_parens (char '~' ^^ pp_smt_parens x) + | Fn ("bvneg", [x]) -> opt_parens (char '-' ^^ pp_smt_parens x) + | Fn ("bvand", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '&'; pp_smt_parens y]) | Fn ("bvnand", [x; y]) -> - opt_parens (char '~' ^^ parens (separate space [sv_smt_parens x; char '&'; sv_smt_parens y])) - | Fn ("bvor", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '|'; sv_smt_parens y]) + opt_parens (char '~' ^^ parens (separate space [pp_smt_parens x; char '&'; pp_smt_parens y])) + | Fn ("bvor", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '|'; pp_smt_parens y]) | Fn ("bvnor", [x; y]) -> - opt_parens (char '~' ^^ parens (separate space [sv_smt_parens x; char '|'; sv_smt_parens y])) - | Fn ("bvxor", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '^'; sv_smt_parens y]) + opt_parens (char '~' ^^ parens (separate space [pp_smt_parens x; char '|'; pp_smt_parens y])) + | Fn ("bvxor", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '^'; pp_smt_parens y]) | Fn ("bvxnor", [x; y]) -> - opt_parens (char '~' ^^ parens (separate space [sv_smt_parens x; char '^'; sv_smt_parens y])) - | Fn ("bvadd", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '+'; sv_smt_parens y]) - | Fn ("bvsub", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '-'; sv_smt_parens y]) - | Fn ("bvmul", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '*'; sv_smt_parens y]) - | Fn ("bvult", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '<'; sv_smt_parens y]) - | Fn ("bvule", [x; y]) -> opt_parens (separate space [sv_smt_parens x; string "<="; sv_smt_parens y]) - | Fn ("bvugt", [x; y]) -> opt_parens (separate space [sv_smt_parens x; char '>'; sv_smt_parens y]) - | Fn ("bvuge", [x; y]) -> opt_parens (separate space [sv_smt_parens x; string ">="; sv_smt_parens y]) - | Fn ("bvslt", [x; y]) -> opt_parens (separate space [sv_signed (sv_smt x); char '<'; sv_signed (sv_smt y)]) - | Fn ("bvsle", [x; y]) -> opt_parens (separate space [sv_signed (sv_smt x); string "<="; sv_signed (sv_smt y)]) - | Fn ("bvsgt", [x; y]) -> opt_parens (separate space [sv_signed (sv_smt x); char '>'; sv_signed (sv_smt y)]) - | Fn ("bvsge", [x; y]) -> opt_parens (separate space [sv_signed (sv_smt x); string ">="; sv_signed (sv_smt y)]) - | Fn ("bvshl", [x; y]) -> opt_parens (separate space [sv_smt_parens x; string "<<"; sv_signed (sv_smt y)]) - | Fn ("bvlshr", [x; y]) -> opt_parens (separate space [sv_smt_parens x; string ">>"; sv_signed (sv_smt y)]) - | Fn ("bvashr", [x; y]) -> opt_parens (separate space [sv_smt_parens x; string ">>>"; sv_signed (sv_smt y)]) - | Fn ("select", [x; i]) -> sv_smt_parens x ^^ lbracket ^^ sv_smt i ^^ rbracket - | Fn ("contents", [Var v]) -> sv_name v ^^ dot ^^ string "bits" - | Fn ("contents", [x]) -> string "sail_bits_value" ^^ parens (sv_smt x) - | Fn ("len", [Var v]) -> sv_name v ^^ dot ^^ string "size" - | Fn ("len", [x]) -> string "sail_bits_size" ^^ parens (sv_smt x) - | Fn ("cons", [x; xs]) -> lbrace ^^ sv_smt x ^^ comma ^^ space ^^ sv_smt xs ^^ rbrace - | Fn (f, args) -> string f ^^ parens (separate_map (comma ^^ space) sv_smt args) - | Store (_, store_fn, arr, i, x) -> string store_fn ^^ parens (separate_map (comma ^^ space) sv_smt [arr; i; x]) - | SignExtend (len, _, x) -> ksprintf string "unsigned'(%d'(signed'({" len ^^ sv_smt x ^^ string "})))" - | ZeroExtend (len, _, x) -> ksprintf string "%d'({" len ^^ sv_smt x ^^ string "})" + opt_parens (char '~' ^^ parens (separate space [pp_smt_parens x; char '^'; pp_smt_parens y])) + | Fn ("bvadd", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '+'; pp_smt_parens y]) + | Fn ("bvsub", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '-'; pp_smt_parens y]) + | Fn ("bvmul", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '*'; pp_smt_parens y]) + | Fn ("bvult", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '<'; pp_smt_parens y]) + | Fn ("bvule", [x; y]) -> opt_parens (separate space [pp_smt_parens x; string "<="; pp_smt_parens y]) + | Fn ("bvugt", [x; y]) -> opt_parens (separate space [pp_smt_parens x; char '>'; pp_smt_parens y]) + | Fn ("bvuge", [x; y]) -> opt_parens (separate space [pp_smt_parens x; string ">="; pp_smt_parens y]) + | Fn ("bvslt", [x; y]) -> opt_parens (separate space [sv_signed (pp_smt x); char '<'; sv_signed (pp_smt y)]) + | Fn ("bvsle", [x; y]) -> opt_parens (separate space [sv_signed (pp_smt x); string "<="; sv_signed (pp_smt y)]) + | Fn ("bvsgt", [x; y]) -> opt_parens (separate space [sv_signed (pp_smt x); char '>'; sv_signed (pp_smt y)]) + | Fn ("bvsge", [x; y]) -> opt_parens (separate space [sv_signed (pp_smt x); string ">="; sv_signed (pp_smt y)]) + | Fn ("bvshl", [x; y]) -> opt_parens (separate space [pp_smt_parens x; string "<<"; sv_signed (pp_smt y)]) + | Fn ("bvlshr", [x; y]) -> opt_parens (separate space [pp_smt_parens x; string ">>"; sv_signed (pp_smt y)]) + | Fn ("bvashr", [x; y]) -> opt_parens (separate space [pp_smt_parens x; string ">>>"; sv_signed (pp_smt y)]) + | Fn ("select", [x; i]) -> pp_smt_parens x ^^ lbracket ^^ pp_smt i ^^ rbracket + | Fn ("contents", [Var v]) -> pp_name v ^^ dot ^^ string "bits" + | Fn ("contents", [x]) -> string "sail_bits_value" ^^ parens (pp_smt x) + | Fn ("len", [Var v]) -> pp_name v ^^ dot ^^ string "size" + | Fn ("len", [x]) -> string "sail_bits_size" ^^ parens (pp_smt x) + | Fn ("cons", [x; xs]) -> lbrace ^^ pp_smt x ^^ comma ^^ space ^^ pp_smt xs ^^ rbrace + | Fn ("str.++", xs) -> lbrace ^^ separate_map (comma ^^ space) pp_smt xs ^^ rbrace + | Fn (f, args) -> string f ^^ parens (separate_map (comma ^^ space) pp_smt args) + | Store (_, store_fn, arr, i, x) -> string store_fn ^^ parens (separate_map (comma ^^ space) pp_smt [arr; i; x]) + | SignExtend (len, _, x) -> ksprintf string "unsigned'(%d'(signed'({" len ^^ pp_smt x ^^ string "})))" + | ZeroExtend (len, _, x) -> ksprintf string "%d'({" len ^^ pp_smt x ^^ string "})" | Extract (n, m, Bitvec_lit bits) -> - sv_smt (Bitvec_lit (Sail2_operators_bitlists.subrange_vec_dec bits (Big_int.of_int n) (Big_int.of_int m))) + pp_smt (Bitvec_lit (Sail2_operators_bitlists.subrange_vec_dec bits (Big_int.of_int n) (Big_int.of_int m))) | Extract (n, m, Var v) -> - if n = m then sv_name v ^^ lbracket ^^ string (string_of_int n) ^^ rbracket - else sv_name v ^^ lbracket ^^ string (string_of_int n) ^^ colon ^^ string (string_of_int m) ^^ rbracket + if n = m then pp_name v ^^ lbracket ^^ string (string_of_int n) ^^ rbracket + else pp_name v ^^ lbracket ^^ string (string_of_int n) ^^ colon ^^ string (string_of_int m) ^^ rbracket | Extract (n, m, x) -> - if n = m then braces (sv_smt x) ^^ lbracket ^^ string (string_of_int n) ^^ rbracket - else braces (sv_smt x) ^^ lbracket ^^ string (string_of_int n) ^^ colon ^^ string (string_of_int m) ^^ rbracket - | Var v -> sv_name v + if n = m then braces (pp_smt x) ^^ lbracket ^^ string (string_of_int n) ^^ rbracket + else braces (pp_smt x) ^^ lbracket ^^ string (string_of_int n) ^^ colon ^^ string (string_of_int m) ^^ rbracket + | Var v -> pp_name v | Tester (ctor, v) -> opt_parens - (separate space [sv_smt v ^^ dot ^^ string "tag"; string "=="; string (ctor |> String.uppercase_ascii)]) + (separate space [pp_smt v ^^ dot ^^ string "tag"; string "=="; string (ctor |> String.uppercase_ascii)]) | Unwrap (ctor, packed, v) -> - let packed_ctor = if Config.union_padding then sv_id ctor ^^ dot ^^ sv_id ctor else sv_id ctor in - if packed then sv_smt v ^^ dot ^^ string "value" ^^ dot ^^ packed_ctor else sv_smt v ^^ dot ^^ sv_id ctor - | Field (_, field, v) -> sv_smt v ^^ dot ^^ sv_id field + let packed_ctor = if Config.union_padding then pp_id ctor ^^ dot ^^ pp_id ctor else pp_id ctor in + if packed then pp_smt v ^^ dot ^^ string "value" ^^ dot ^^ packed_ctor else pp_smt v ^^ dot ^^ pp_id ctor + | Field (_, field, v) -> pp_smt v ^^ dot ^^ pp_id field | Ite (cond, then_exp, else_exp) -> - separate space [sv_smt_parens cond; char '?'; sv_smt_parens then_exp; char ':'; sv_smt_parens else_exp] + separate space [pp_smt_parens cond; char '?'; pp_smt_parens then_exp; char ':'; pp_smt_parens else_exp] | Enum e -> failwith "Unknown enum" | Empty_list -> string "{}" | Hd (op, arg) -> begin match tails arg with - | Some (index, v) -> sv_name v ^^ brackets (string (string_of_int index)) - | None -> string op ^^ parens (sv_smt arg) + | Some (index, v) -> pp_name v ^^ brackets (string (string_of_int index)) + | None -> string op ^^ parens (pp_smt arg) end - | Tl (op, arg) -> string op ^^ parens (sv_smt arg) + | Tl (op, arg) -> string op ^^ parens (pp_smt arg) | _ -> empty let sv_cval cval = let* smt = Smt.smt_cval cval in - return (sv_smt smt) + return (pp_smt smt) let rec sv_clexp = function - | CL_id (id, _) -> sv_name id - | CL_field (clexp, field) -> sv_clexp clexp ^^ dot ^^ sv_id field + | CL_id (id, _) -> pp_name id + | CL_field (clexp, field) -> sv_clexp clexp ^^ dot ^^ pp_id field | clexp -> string ("// CLEXP " ^ Jib_util.string_of_clexp clexp) - let sv_update_fbits = function + let svir_update_fbits = function | [bv; index; bit] -> begin match (cval_ctyp bv, cval_ctyp index) with - | CT_fbits 1, _ -> Smt_gen.fmap sv_smt (Smt.smt_cval bit) + | CT_fbits 1, _ -> Smt.smt_cval bit | CT_fbits sz, CT_constant c -> let c = Big_int.to_int c in let* bv_smt = Smt.smt_cval bv in @@ -627,7 +779,7 @@ module Make (Config : CONFIG) = struct else if c = sz - 1 then Fn ("concat", [bit_smt; bv_smt_2]) else Fn ("concat", [bv_smt_1; bit_smt; bv_smt_2]) in - return (sv_smt smt) + return smt | _, _ -> failwith "update_fbits 1" end | _ -> failwith "update_fbits 2" @@ -648,53 +800,73 @@ module Make (Config : CONFIG) = struct match clexp with | CL_addr (CL_id (id, CT_ref reg_ctyp)) -> let encoded = Util.zencode_string (string_of_ctyp reg_ctyp) in - ksprintf string "sail_reg_assign_%s" encoded ^^ parens (sv_name id ^^ comma ^^ space ^^ value) ^^ semi + ksprintf string "sail_reg_assign_%s" encoded ^^ parens (pp_name id ^^ comma ^^ space ^^ value) ^^ semi | _ -> sv_clexp clexp ^^ space ^^ equals ^^ space ^^ value ^^ semi - let rec sv_instr ctx (I_aux (aux, (_, l))) = - let ld = sv_line_directive l in + let rec svir_clexp = function + | CL_id (id, _) -> SVP_id id + | CL_field (clexp, field) -> SVP_field (svir_clexp clexp, field) + | CL_void -> SVP_void + | CL_rmw _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "RMW" + | CL_addr _ | CL_tuple _ -> Reporting.unreachable Parse_ast.Unknown __POS__ "addr/tuple" + + let svir_creturn = function + | CR_one clexp -> svir_clexp clexp + | CR_multi clexps -> SVP_multi (List.map svir_clexp clexps) + + let rec svir_instr ctx (I_aux (aux, (_, l))) = + let wrap aux = return (Some (SVS_aux (aux, l))) in match aux with - | I_comment str -> return (concat_map string ["/* "; str; " */"]) - | I_decl (ctyp, id) -> return (ld ^^ wrap_type ctyp (sv_name id) ^^ semi) + | I_comment str -> wrap (SVS_comment str) + | I_decl (ctyp, id) -> wrap (SVS_var (id, ctyp, None)) | I_init (ctyp, id, cval) -> - let* value = sv_cval cval in - return (ld ^^ separate space [wrap_type ctyp (sv_name id); equals; value] ^^ semi) + let* value = Smt.smt_cval cval in + wrap (SVS_var (id, ctyp, Some value)) | I_return cval -> - let* value = sv_cval cval in - return (string "return" ^^ space ^^ value ^^ semi) - | I_end id -> return (string "return" ^^ space ^^ sv_name id ^^ semi) - | I_exit _ -> return (if Config.comb then string "sail_reached_unreachable = 1;" else string "$finish" ^^ semi) + let* value = Smt.smt_cval cval in + wrap (SVS_return value) + | I_end id -> wrap (SVS_return (Var id)) + | I_exit _ -> wrap (svs_raw "$finish") + | I_copy (CL_void, cval) -> return None | I_copy (clexp, cval) -> let* value = Smt_gen.bind (Smt.smt_cval cval) (Smt.smt_conversion ~into:(clexp_ctyp clexp) ~from:(cval_ctyp cval)) in - return (sv_assign clexp (sv_smt value)) - | I_funcall (clexp, _, (id, _), args) -> + wrap (SVS_assign (svir_clexp clexp, value)) + | I_funcall (creturn, _, (id, _), args) -> if ctx_is_extern id ctx then ( let name = ctx_get_extern id ctx in - match Smt.builtin name with + match Smt.builtin ~allow_io:false name with | Some generator -> - let* value = Smt_gen.fmap Smt_exp.simp (generator args (clexp_ctyp clexp)) in + let clexp = + match creturn with + | CR_one clexp -> clexp + | CR_multi _ -> Reporting.unreachable l __POS__ "Multiple return generator primitive found" + in + let* value = Smt_gen.fmap (Smt_exp.simp (fun _ -> None)) (generator args (clexp_ctyp clexp)) in begin (* We can optimize R = store(R, i x) into R[i] = x *) match (clexp, value) with | CL_id (v, _), Store (_, _, Var v', i, x) when Name.compare v v' = 0 -> - return - (ld - ^^ separate space [sv_clexp clexp ^^ lbracket ^^ sv_smt i ^^ rbracket; equals; sv_smt x] - ^^ semi - ) - | _, _ -> return (ld ^^ sv_assign clexp (sv_smt value)) + wrap (SVS_assign (SVP_index (svir_clexp clexp, i), x)) + | _, _ -> wrap (SVS_assign (svir_clexp clexp, value)) end - | None -> - let* args = mapM Smt.smt_cval args in - let value = Fn ("sail_" ^ name, args) in - return (ld ^^ sv_assign clexp (sv_smt value)) + | None -> ( + match Generate_primop2.generate_module ~at:l name with + | Some generator -> + let generated_name = generator args (creturn_ctyp creturn) in + let* args = mapM Smt.smt_cval args in + wrap (SVS_call (svir_creturn creturn, SVN_string generated_name, args)) + | None -> + let* args = mapM Smt.smt_cval args in + let value = Fn ("sail_" ^ name, args) in + wrap (SVS_call (svir_creturn creturn, SVN_id id, args)) + ) ) else if Id.compare id (mk_id "update_fbits") = 0 then - let* rhs = sv_update_fbits args in - return (ld ^^ sv_clexp clexp ^^ space ^^ equals ^^ space ^^ rhs ^^ semi) - else if Id.compare id (mk_id "internal_vector_init") = 0 then return empty + let* rhs = svir_update_fbits args in + wrap (SVS_assign (svir_creturn creturn, rhs)) + else if Id.compare id (mk_id "internal_vector_init") = 0 then return None else if Id.compare id (mk_id "internal_vector_update") = 0 then ( match args with | [arr; i; x] -> begin @@ -708,46 +880,29 @@ module Make (Config : CONFIG) = struct ) in let* x = Smt.smt_cval x in - return - (sv_clexp clexp ^^ lbracket ^^ sv_smt i ^^ rbracket ^^ space ^^ equals ^^ space ^^ sv_smt x ^^ semi) + wrap (SVS_assign (SVP_index (svir_creturn creturn, i), x)) | _ -> Reporting.unreachable l __POS__ "Invalid vector type for internal vector update" end | _ -> Reporting.unreachable l __POS__ "Invalid number of arguments to internal vector update" ) else - let* args = mapM sv_cval args in - let call = sv_id id ^^ parens (separate (comma ^^ space) args) in - return (ld ^^ sv_assign clexp call) - | I_if (cond, [], else_instrs, _) -> - let* cond = sv_cval (V_call (Bnot, [cond])) in - return - (string "if" ^^ space ^^ parens cond ^^ space ^^ string "begin" - ^^ nest 4 (hardline ^^ separate_map hardline (sv_checked_instr ctx) else_instrs) - ^^ hardline ^^ string "end" ^^ semi - ) - | I_if (cond, then_instrs, [], _) -> - let* cond = sv_cval cond in - return - (string "if" ^^ space ^^ parens cond ^^ space ^^ string "begin" - ^^ nest 4 (hardline ^^ separate_map hardline (sv_checked_instr ctx) then_instrs) - ^^ hardline ^^ string "end" ^^ semi - ) - | I_if (cond, then_instrs, else_instrs, _) -> - let* cond = sv_cval cond in - return - (string "if" ^^ space ^^ parens cond ^^ space ^^ string "begin" - ^^ nest 4 (hardline ^^ separate_map hardline (sv_checked_instr ctx) then_instrs) - ^^ hardline ^^ string "end" ^^ space ^^ string "else" ^^ space ^^ string "begin" - ^^ nest 4 (hardline ^^ separate_map hardline (sv_checked_instr ctx) else_instrs) - ^^ hardline ^^ string "end" ^^ semi - ) + let* args = mapM Smt.smt_cval args in + wrap (SVS_call (svir_creturn creturn, SVN_id id, args)) | I_block instrs -> - return - (string "begin" - ^^ nest 4 (hardline ^^ separate_map hardline (sv_checked_instr ctx) instrs) - ^^ hardline ^^ string "end" ^^ semi - ) - | I_raw s -> return (string s ^^ semi) + let* statements = fmap Util.option_these (mapM (svir_instr ctx) instrs) in + wrap (SVS_block statements) + | I_if (cond, then_instrs, else_instrs, _) -> + let* cond = Smt.smt_cval cond in + let to_block statements = + match Util.option_these statements with + | [] -> None + | [statement] -> Some statement + | statements -> Some (SVS_aux (SVS_block statements, Parse_ast.Unknown)) + in + let* then_block = fmap to_block (mapM (svir_instr ctx) then_instrs) in + let* else_block = fmap to_block (mapM (svir_instr ctx) else_instrs) in + wrap (SVS_if (cond, then_block, else_block)) + | I_raw str -> wrap (svs_raw str) | I_undefined ctyp -> Reporting.unreachable l __POS__ "Unreachable instruction should not reach SystemVerilog backend" | I_jump _ | I_goto _ | I_label _ -> @@ -757,10 +912,503 @@ module Make (Config : CONFIG) = struct | I_clear _ | I_reset _ | I_reinit _ -> Reporting.unreachable l __POS__ "Cleanup commands should not appear in SystemVerilog backend" - and sv_checked_instr ctx (I_aux (_, (_, l)) as instr) = + let rec pp_place = function + | SVP_id id -> pp_name id + | SVP_index (place, i) -> pp_place place ^^ lbracket ^^ pp_smt i ^^ rbracket + | SVP_field (place, field) -> pp_place place ^^ dot ^^ pp_id field + | SVP_multi places -> parens (separate_map (comma ^^ space) pp_place places) + | SVP_void -> string "void" + + let pp_sv_name = function SVN_id id -> pp_id id | SVN_string s -> string s + + let rec pp_statement ?(terminator = semi ^^ hardline) (SVS_aux (aux, l)) = + let ld = sv_line_directive l in + match aux with + | SVS_comment str -> concat_map string ["/* "; str; " */"] + | SVS_var (id, ctyp, init_opt) -> begin + match init_opt with + | Some init -> ld ^^ separate space [wrap_type ctyp (pp_name id); equals; pp_smt init] ^^ terminator + | None -> ld ^^ wrap_type ctyp (pp_name id) ^^ terminator + end + | SVS_return smt -> string "return" ^^ space ^^ pp_smt smt ^^ terminator + | SVS_assign (place, value) -> ld ^^ separate space [pp_place place; equals; pp_smt value] ^^ terminator + | SVS_call (place, ctor, args) -> + ld + ^^ separate space [pp_place place; equals; pp_sv_name ctor] + ^^ parens (separate_map (comma ^^ space) pp_smt args) + ^^ terminator + | SVS_if (_, None, None) -> empty + | SVS_if (cond, None, Some else_block) -> + let cond = pp_smt (Fn ("not", [cond])) in + string "if" ^^ space ^^ parens cond ^^ space ^^ pp_statement else_block + | SVS_if (cond, Some then_block, None) -> + string "if" ^^ space ^^ parens (pp_smt cond) ^^ space ^^ pp_statement then_block + | SVS_if (cond, Some then_block, Some else_block) -> empty + | SVS_case { head_exp; cases; fallthrough } -> + let pp_case (ids, statement) = + separate space [separate_map (comma ^^ space) pp_id ids; colon; pp_statement statement] + in + let pp_fallthrough = function + | None -> empty + | Some statement -> hardline ^^ separate space [string "default"; colon; pp_statement statement] + in + string "case" ^^ space + ^^ parens (pp_smt head_exp) + ^^ nest 4 (hardline ^^ separate_map hardline pp_case cases ^^ pp_fallthrough fallthrough) + ^^ hardline ^^ string "endcase" ^^ terminator + | SVS_block statements -> + let block_terminator last = if last then semi else semi ^^ hardline in + string "begin" + ^^ nest 4 + (hardline + ^^ concat (Util.map_last (fun last -> pp_statement ~terminator:(block_terminator last)) statements) + ) + ^^ hardline ^^ string "end" ^^ terminator + | SVS_raw (s, _, _) -> string s ^^ terminator + | SVS_skip -> empty + + let sv_instr ctx instr = + let* statement_opt = svir_instr ctx instr in + match statement_opt with Some statement -> return (pp_statement statement) | None -> return empty + + let sv_checked_instr ctx (I_aux (_, (_, l)) as instr) = let v, _ = Smt_gen.run (sv_instr ctx instr) l in v + let smt_ssanode cfg preds = + let open Jib_ssa in + function + | Pi _ -> return None + | Phi (id, ctyp, ids) -> ( + let get_pi n = + match get_vertex cfg n with + | Some ((ssa_elems, _), _, _) -> List.concat (List.map (function Pi guards -> guards | _ -> []) ssa_elems) + | None -> failwith "Predecessor node does not exist" + in + let pis = List.map get_pi (IntSet.elements preds) in + let* mux = + List.fold_right2 + (fun pi id chain -> + let* chain = chain in + let* pi = mapM Smt.smt_cval pi in + let pathcond = smt_conj pi in + match chain with Some smt -> return (Some (Ite (pathcond, Var id, smt))) | None -> return (Some (Var id)) + ) + pis ids (return None) + in + match mux with None -> assert false | Some mux -> return (Some (id, ctyp, mux)) + ) + + let svir_cfnode spec_info ctx = + let open Jib_ssa in + function + | CF_start inits -> + let svir_start (id, ctyp) = + prerr_endline (string_of_name id); + SVS_aux (SVS_var (id, ctyp, None), Parse_ast.Unknown) + in + let svir_inits = List.map svir_start (NameMap.bindings inits) in + return svir_inits + | CF_block (instrs, _) -> + let* statements = fmap Util.option_these (mapM (svir_instr ctx) instrs) in + return statements + | _ -> return [] + + class register_fix_visitor spec_info ssa_nums : svir_visitor = + object + inherit empty_svir_visitor + + method! vctyp _ = SkipChildren + + method! vname name = + let name, n = Jib_ssa.unssa_name name in + ssa_nums := + NameMap.update name + (function None -> Some (IntSet.singleton n) | Some ns -> Some (IntSet.add n ns)) + !ssa_nums; + None + end + + class thread_registers ctx spec_info : jib_visitor = + object + inherit empty_jib_visitor + + method! vctyp _ = SkipChildren + + method! vinstr (I_aux (aux, iannot) as no_change) = + match aux with + | I_funcall (CR_one clexp, ext, (f, []), args) -> begin + match Bindings.find_opt f spec_info.footprints with + | Some footprint -> + prerr_endline ("Threading " ^ string_of_id f); + let reads = + List.map + (fun id -> V_id (Name (id, -1), Bindings.find id spec_info.registers)) + (IdSet.elements footprint.all_reads) + in + let writes = + List.map + (fun id -> CL_id (Name (id, -1), Bindings.find id spec_info.registers)) + (IdSet.elements footprint.all_writes) + in + let channels = + (if footprint.need_stdout then [Channel (Chan_stdout, -1)] else []) + @ if footprint.need_stderr then [Channel (Chan_stderr, -1)] else [] + in + let input_channels = List.map (fun c -> V_id (c, CT_string)) channels in + let output_channels = List.map (fun c -> CL_id (c, CT_string)) channels in + ChangeTo + (I_aux + ( I_funcall + (CR_multi ((clexp :: writes) @ output_channels), ext, (f, []), args @ reads @ input_channels), + iannot + ) + ) + | None -> + if ctx_is_extern f ctx then ( + let name = ctx_get_extern f ctx in + if name = "print" || name = "print_endline" || name = "print_bits" then + ChangeTo + (I_aux + ( I_funcall + ( CR_multi (clexp :: [CL_id (Channel (Chan_stdout, -1), CT_string)]), + ext, + (f, []), + args @ [V_id (Channel (Chan_stdout, -1), CT_string)] + ), + iannot + ) + ) + else SkipChildren + ) + else ( + prerr_endline ("No footprint: " ^ string_of_id f); + SkipChildren + ) + end + | _ -> DoChildren + end + + class find_final_names ssa_nums final_names : svir_visitor = + object + inherit empty_svir_visitor + + method! vctyp _ = SkipChildren + + method! vname ssa_name = + let name, n = Jib_ssa.unssa_name ssa_name in + match NameMap.find_opt name ssa_nums with + | Some ns when n = IntSet.max_elt ns -> + final_names := NameMap.add name ssa_name !final_names; + None + | _ -> None + end + + (* This rewrite lifts statements out of an always_comb block in a + module, that need to appear in the toplevel of the module as + definitions. *) + class hoist_module_statements decls instantiations : svir_visitor = + object + inherit empty_svir_visitor + + method! vctyp _ = SkipChildren + + method! vstatement (SVS_aux (aux, l)) = + match aux with + | SVS_var (Name (id, n), ctyp, exp_opt) -> + decls := Bindings.add id ctyp !decls; + begin + match exp_opt with + | Some exp -> ChangeTo (SVS_aux (SVS_assign (SVP_id (Name (id, n)), exp), l)) + | None -> ChangeTo (SVS_aux (SVS_skip, l)) + end + | SVS_call (place, f, args) -> + Queue.add (place, f, args) instantiations; + ChangeTo (SVS_aux (SVS_skip, l)) + | _ -> DoChildren + end + + let svir_module spec_info ctx f params param_ctyps ret_ctyp body = + prerr_endline Util.(string_of_id f |> red |> clear); + let footprint = Bindings.find f spec_info.footprints in + let always_comb = Queue.create () in + let declvars = ref Bindings.empty in + let ssa_vars = ref NameMap.empty in + + (* Add a statment to the always_comb block *) + let add_comb_statement stmt = + let stmt = visit_sv_statement (new register_fix_visitor spec_info ssa_vars) stmt in + Queue.add stmt always_comb + in + + List.iter prerr_endline + (List.map (fun (I_aux (_, (_, l)) as instr) -> string_of_instr instr ^ " " ^ Reporting.short_loc_to_string l) body); + + let open Jib_ssa in + let start, cfg = ssa (visit_instrs (new thread_registers ctx spec_info) body) in + let visit_order = + try topsort cfg + with Not_a_DAG n -> + raise + (Reporting.err_general Parse_ast.Unknown + (Printf.sprintf "%s: control flow graph is not acyclic (node %d is in cycle)" (string_of_id f) n) + ) + in + + (* Generate the contents of the always_comb block *) + let _ = + Smt_gen.iterM + (fun n -> + match get_vertex cfg n with + | None -> return () + | Some ((ssa_elems, cfnode), preds, _) -> + let* muxers = fmap Util.option_these (mapM (smt_ssanode cfg preds) ssa_elems) in + List.iter + (fun (id, ctyp, mux) -> + add_comb_statement + (SVS_aux (SVS_assign (SVP_id id, Smt_exp.simp (fun _ -> None) mux), Parse_ast.Unknown)) + ) + muxers; + let* block = svir_cfnode spec_info ctx cfnode in + List.iter add_comb_statement block; + return () + ) + visit_order + |> fun m -> Smt_gen.run m (id_loc f) + in + + (* Create the always_comb definition, lifting/hoisting the module instantations out of the block *) + let final_names = ref NameMap.empty in + let module_instantiations = Queue.create () in + let always_comb_def = + let fix s = + s + |> visit_sv_statement (new find_final_names !ssa_vars final_names) + |> visit_sv_statement (new hoist_module_statements declvars module_instantiations) + in + match List.of_seq (Queue.to_seq always_comb) with + | [] -> [] + | [statement] -> [SVD_always_comb (fix statement)] + | statements -> [SVD_always_comb (fix (SVS_aux (SVS_block statements, Parse_ast.Unknown)))] + in + let module_instantiation_defs, _ = + Queue.fold + (fun (defs, numbers) (place, module_name, inputs) -> + let number = match SVNameMap.find_opt module_name numbers with None -> 0 | Some n -> n in + let instance_name = + string_of_sv_name (modify_sv_name ~prefix:("inst_" ^ string_of_int number ^ "_") module_name) + in + let outputs = match place with SVP_multi places -> places | place -> [place] in + ( SVD_instantiate { module_name; instance_name; input_connections = inputs; output_connections = outputs } + :: defs, + SVNameMap.add module_name (number + 1) numbers + ) + ) + ([], SVNameMap.empty) module_instantiations + in + + (* Create the input and output ports *) + let input_ports : sv_module_port list = + List.map2 (fun id ctyp -> { name = Name (id, 0); external_name = string_of_id id; typ = ctyp }) params param_ctyps + @ List.map + (fun id -> + { + name = Name (id, 0); + external_name = string_of_id (prepend_id "in_" id); + typ = Bindings.find id spec_info.registers; + } + ) + (IdSet.elements footprint.all_reads) + @ ( if footprint.need_stdout then + [{ name = Channel (Chan_stdout, 0); external_name = "in_stdout"; typ = CT_string }] + else [] + ) + @ + if footprint.need_stderr then [{ name = Channel (Chan_stderr, 0); external_name = "in_stderr"; typ = CT_string }] + else [] + in + + let output_ports : sv_module_port list = + [{ name = NameMap.find Jib_util.return !final_names; external_name = "sail_return"; typ = ret_ctyp }] + @ List.map + (fun id -> + { + name = NameMap.find (Name (id, -1)) !final_names; + external_name = string_of_id (prepend_id "out_" id); + typ = Bindings.find id spec_info.registers; + } + ) + (IdSet.elements footprint.all_writes) + @ ( if footprint.need_stdout then + [ + { + name = NameMap.find (Channel (Chan_stdout, -1)) !final_names; + external_name = "out_stdout"; + typ = CT_string; + }; + ] + else [] + ) + @ + if footprint.need_stderr then + [ + { name = NameMap.find (Channel (Chan_stderr, -1)) !final_names; external_name = "out_stderr"; typ = CT_string }; + ] + else [] + in + + (* Create toplevel variables for all things in the always_comb + block that aren't ports. We can push variables that aren't used + in the block back down later if we want. *) + let module_vars = Queue.create () in + NameMap.iter + (fun name nums -> + let get_ctyp = function + | Return _ -> ret_ctyp + | Name (id, _) -> begin + match Bindings.find_opt id spec_info.registers with + | Some ctyp -> ctyp + | None -> ( + match Bindings.find_opt id !declvars with + | Some ctyp -> ctyp + | None -> ( + match Util.list_index (fun p -> Id.compare p id = 0) params with + | Some i -> List.nth param_ctyps i + | None -> failwith (string_of_id id) + ) + ) + end + | Channel _ -> CT_string + | Have_exception _ -> CT_bool + | Throw_location _ -> CT_string + | Current_exception _ -> failwith "current_exception" + in + let ctyp = get_ctyp name in + IntSet.iter + (fun n -> + let name = Jib_ssa.ssa_name n name in + if + List.for_all (fun (port : sv_module_port) -> Name.compare name port.name <> 0) (input_ports @ output_ports) + then Queue.add (SVD_var (name, ctyp)) module_vars + ) + nums + ) + !ssa_vars; + + let defs = List.of_seq (Queue.to_seq module_vars) @ List.rev module_instantiation_defs @ always_comb_def in + { name = SVN_id f; input_ports; output_ports; defs } + + let toplevel_module spec_info = + match Bindings.find_opt (mk_id "main") spec_info.footprints with + | None -> None + | Some footprint -> + let register_inputs, register_outputs = + Bindings.fold + (fun reg ctyp (ins, outs) -> + ( SVD_var (Name (prepend_id "in_" reg, -1), ctyp) :: ins, + SVD_var (Name (prepend_id "out_" reg, -1), ctyp) :: outs + ) + ) + spec_info.registers ([], []) + in + let channel_outputs = + (if footprint.need_stdout then [SVD_var (Name (mk_id "out_stdout", -1), CT_string)] else []) + @ if footprint.need_stderr then [SVD_var (Name (mk_id "out_stderr", -1), CT_string)] else [] + in + let instantiate_main = + SVD_instantiate + { + module_name = SVN_id (mk_id "main"); + instance_name = "inst_main"; + input_connections = + ([Enum "unit"] + @ List.map (fun reg -> Var (Name (prepend_id "in_" reg, -1))) (IdSet.elements footprint.all_reads) + @ (if footprint.need_stdout then [String_lit ""] else []) + @ if footprint.need_stderr then [String_lit ""] else [] + ); + output_connections = + ([SVP_id Jib_util.return] + @ List.map (fun reg -> SVP_id (Name (prepend_id "out_" reg, -1))) (IdSet.elements footprint.all_writes) + @ (if footprint.need_stdout then [SVP_id (Name (mk_id "out_stdout", -1))] else []) + @ if footprint.need_stderr then [SVP_id (Name (mk_id "out_stderr", -1))] else [] + ); + } + in + let always_comb = + let unchanged_registers = + Bindings.fold + (fun reg _ unchanged -> + if not (IdSet.mem reg footprint.all_writes) then + mk_statement + (SVS_assign (SVP_id (Name (prepend_id "out_" reg, -1)), Var (Name (prepend_id "in_" reg, -1)))) + :: unchanged + else unchanged + ) + spec_info.registers [] + in + let channel_writes = + ( if footprint.need_stdout then + [mk_statement (svs_raw "$write(out_stdout)" ~inputs:[Name (mk_id "out_stdout", -1)])] + else [] + ) + @ + if footprint.need_stderr then + [mk_statement (svs_raw "$write(out_stderr)" ~inputs:[Name (mk_id "out_stderr", -1)])] + else [] + in + SVD_always_comb + (mk_statement (SVS_block (unchanged_registers @ channel_writes @ [mk_statement (svs_raw "$finish")]))) + in + Some + { + name = SVN_string "sail_toplevel"; + input_ports = []; + output_ports = []; + defs = + register_inputs @ register_outputs @ channel_outputs + @ [SVD_var (Jib_util.return, CT_unit)] + @ [instantiate_main; always_comb]; + } + + let rec pp_module m = + let ports = + match (m.input_ports, m.output_ports) with + | [], [] -> semi + | inputs, outputs -> + let ports = List.map (fun port -> ("input", port)) inputs @ List.map (fun port -> ("output", port)) outputs in + let pp_port (dir, { name; external_name; typ }) = + let external_name_hint = + if external_name = "" then empty else space ^^ string (Printf.sprintf "/* %s */" external_name) + in + string dir ^^ space ^^ wrap_type typ (pp_name name) ^^ external_name_hint + in + nest 4 (char '(' ^^ hardline ^^ separate_map (comma ^^ hardline) pp_port ports) + ^^ hardline ^^ char ')' ^^ semi + in + string "module" ^^ space ^^ pp_sv_name m.name ^^ ports + ^^ nest 4 (hardline ^^ separate_map (semi ^^ hardline) pp_def m.defs) + ^^ hardline ^^ string "endmodule" + + and pp_fundef f = string "function" + + and pp_def = function + | SVD_var (id, ctyp) -> wrap_type ctyp (pp_name id) + | SVD_always_comb statement -> string "always_comb" ^^ space ^^ pp_statement statement + | SVD_instantiate { module_name; instance_name; input_connections; output_connections } -> + let inputs = List.map (fun exp -> pp_smt exp) input_connections in + let outputs = List.map pp_place output_connections in + let connections = + match inputs @ outputs with [] -> empty | connections -> parens (separate (comma ^^ space) connections) + in + pp_sv_name module_name ^^ space ^^ string instance_name ^^ connections + | SVD_fundef f -> pp_fundef f + | SVD_module m -> pp_module m + | _ -> string "def" + + (* + let svir_fundef f params param_ctyps ret_ctyp +*) + let sv_fundef_with ctx f params param_ctyps ret_ctyp fun_body = let sv_ret_ty, index_ty = sv_ctyp ret_ctyp in let sv_ret_ty, typedef = @@ -772,7 +1420,7 @@ module Make (Config : CONFIG) = struct | None -> (string sv_ret_ty, empty) in let param_docs = - try List.map2 (fun param ctyp -> wrap_type ctyp (sv_id param)) params param_ctyps + try List.map2 (fun param ctyp -> wrap_type ctyp (pp_id param)) params param_ctyps with Invalid_argument _ -> Reporting.unreachable Unknown __POS__ "Function arity mismatch" in typedef @@ -782,15 +1430,8 @@ module Make (Config : CONFIG) = struct ^^ nest 4 (hardline ^^ fun_body) ^^ hardline ^^ string "endfunction" - let sv_fundef ctx f params param_ctyps ret_ctyp body = - let fun_body = - if List.exists (fun unrf -> unrf = string_of_id f) Config.unreachable then string "sail_reached_unreachable = 1;" - else - wrap_type ret_ctyp (sv_name Jib_util.return) - ^^ semi ^^ hardline - ^^ separate_map hardline (sv_checked_instr ctx) body - in - sv_fundef_with ctx (sv_id_string f) params param_ctyps ret_ctyp fun_body + let sv_fundef spec_info ctx f params param_ctyps ret_ctyp body = + pp_module (svir_module spec_info ctx f params param_ctyps ret_ctyp body) let filter_clear = filter_instrs (function I_aux (I_clear _, _) -> false | _ -> true) @@ -873,7 +1514,7 @@ module Make (Config : CONFIG) = struct parens (separate space [char 'r'; string "=="; string (reg_ref reg)]); string "begin"; ] - ^^ nest 4 (hardline ^^ string "return" ^^ space ^^ sv_id reg ^^ semi) + ^^ nest 4 (hardline ^^ string "return" ^^ space ^^ pp_id reg ^^ semi) ^^ hardline ^^ string "end" ^^ semi ) regs @@ -894,7 +1535,7 @@ module Make (Config : CONFIG) = struct parens (separate space [char 'r'; string "=="; string (reg_ref reg)]); string "begin"; ] - ^^ nest 4 (hardline ^^ sv_id reg ^^ space ^^ equals ^^ space ^^ char 'v' ^^ semi) + ^^ nest 4 (hardline ^^ pp_id reg ^^ space ^^ equals ^^ space ^^ char 'v' ^^ semi) ^^ hardline ^^ string "end" ^^ semi ) regs @@ -910,11 +1551,30 @@ module Make (Config : CONFIG) = struct let empty_cdef_doc = { outside_module = empty; inside_module_prefix = empty; inside_module = empty } - let sv_cdef ctx fn_ctyps setup_calls (CDEF_aux (aux, _)) = + let svir_cdef spec_info ctx fn_ctyps (CDEF_aux (aux, _)) = + match aux with + | CDEF_val (f, _, param_ctyps, ret_ctyp) -> ([], Bindings.add f (param_ctyps, ret_ctyp) fn_ctyps) + | CDEF_fundef (f, _, params, body) -> + if List.mem (string_of_id f) Config.ignore then ([], fn_ctyps) + else ( + let body = + Jib_optimize.( + body |> flatten_instrs |> remove_dead_code |> variable_decls_to_top (* |> structure_control_flow_block *) + |> remove_undefined |> filter_clear + ) + in + match Bindings.find_opt f fn_ctyps with + | Some (param_ctyps, ret_ctyp) -> + ([SVD_module (svir_module spec_info ctx f params param_ctyps ret_ctyp body)], fn_ctyps) + | None -> Reporting.unreachable (id_loc f) __POS__ ("No function type found for " ^ string_of_id f) + ) + | _ -> ([], fn_ctyps) + + let sv_cdef spec_info ctx fn_ctyps setup_calls (CDEF_aux (aux, _)) = match aux with | CDEF_register (id, ctyp, setup) -> - let binding_doc = wrap_type ctyp (sv_id id) ^^ semi ^^ twice hardline in - let name = sprintf "sail_setup_reg_%s" (sv_id_string id) in + let binding_doc = wrap_type ctyp (pp_id id) ^^ semi ^^ twice hardline in + let name = sprintf "sail_setup_reg_%s" (pp_id_string id) in ( { empty_cdef_doc with inside_module_prefix = binding_doc; inside_module = sv_setup_function ctx name setup }, fn_ctyps, name :: setup_calls @@ -927,7 +1587,7 @@ module Make (Config : CONFIG) = struct else ( let body = Jib_optimize.( - body |> flatten_instrs |> remove_dead_code |> variable_decls_to_top |> structure_control_flow_block + body |> flatten_instrs |> remove_dead_code |> variable_decls_to_top (* |> structure_control_flow_block *) |> remove_undefined |> filter_clear ) in @@ -936,7 +1596,7 @@ module Make (Config : CONFIG) = struct | Some (param_ctyps, ret_ctyp) -> ( { empty_cdef_doc with - inside_module = sv_fundef ctx f params param_ctyps ret_ctyp body ^^ twice hardline; + inside_module = sv_fundef spec_info ctx f params param_ctyps ret_ctyp body ^^ twice hardline; }, fn_ctyps, setup_calls @@ -946,7 +1606,7 @@ module Make (Config : CONFIG) = struct ) | CDEF_let (n, bindings, setup) -> let bindings_doc = - separate_map (semi ^^ hardline) (fun (id, ctyp) -> wrap_type ctyp (sv_id id)) bindings + separate_map (semi ^^ hardline) (fun (id, ctyp) -> wrap_type ctyp (pp_id id)) bindings ^^ semi ^^ twice hardline in let name = sprintf "sail_setup_let_%d" n in @@ -963,7 +1623,7 @@ module Make (Config : CONFIG) = struct | Some (param_ctyps, ret_ctyp) -> begin let main_args = List.map2 - (fun param param_ctyp -> match param_ctyp with CT_unit -> string "SAIL_UNIT" | _ -> sv_id param) + (fun param param_ctyp -> match param_ctyp with CT_unit -> string "SAIL_UNIT" | _ -> pp_id param) params param_ctyps in let non_unit = @@ -976,7 +1636,7 @@ module Make (Config : CONFIG) = struct in let module_main_in = List.map - (fun (param, param_ctyp) -> string "input" ^^ space ^^ wrap_type param_ctyp (sv_id param)) + (fun (param, param_ctyp) -> string "input" ^^ space ^^ wrap_type param_ctyp (pp_id param)) non_unit in match ret_ctyp with diff --git a/src/sail_sv_backend/jib_sv.mli b/src/sail_sv_backend/jib_sv.mli new file mode 100644 index 000000000..a26af9180 --- /dev/null +++ b/src/sail_sv_backend/jib_sv.mli @@ -0,0 +1,140 @@ +(****************************************************************************) +(* Sail *) +(* *) +(* Sail and the Sail architecture models here, comprising all files and *) +(* directories except the ASL-derived Sail code in the aarch64 directory, *) +(* are subject to the BSD two-clause licence below. *) +(* *) +(* The ASL derived parts of the ARMv8.3 specification in *) +(* aarch64/no_vector and aarch64/full are copyright ARM Ltd. *) +(* *) +(* Copyright (c) 2013-2021 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* Alasdair Armstrong *) +(* Brian Campbell *) +(* Thomas Bauereiss *) +(* Anthony Fox *) +(* Jon French *) +(* Dominic Mulligan *) +(* Stephen Kell *) +(* Mark Wassell *) +(* Alastair Reid (Arm Ltd) *) +(* Louis-Emile Ploix *) +(* *) +(* All rights reserved. *) +(* *) +(* This work was partially supported by EPSRC grant EP/K008528/1 REMS: Rigorous *) +(* Engineering for Mainstream Systems, an ARM iCASE award, EPSRC IAA *) +(* KTF funding, and donations from Arm. This project has received *) +(* funding from the European Research Council (ERC) under the European *) +(* Union’s Horizon 2020 research and innovation programme (grant *) +(* agreement No 789108, ELVER). *) +(* *) +(* This software was developed by SRI International and the University of *) +(* Cambridge Computer Laboratory (Department of Computer Science and *) +(* Technology) under DARPA/AFRL contracts FA8650-18-C-7809 ("CIFV") *) +(* and FA8750-10-C-0237 ("CTSRD"). *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(****************************************************************************) + +open Libsail + +open Ast_util + +type spec_info + +val collect_spec_info : Jib_compile.ctx -> Jib.cdef list -> spec_info + +module type CONFIG = sig + val max_unknown_integer_width : int + val max_unknown_bitvector_width : int + + (** Output SystemVerilog line directives where possible *) + val line_directives : bool + + (** If true, treat all strings as if they were the unit type. + Obviously this is only sound when the semantics does not depend + on strings, and they are only used for output. *) + val nostrings : bool + + val nopacked : bool + val union_padding : bool + val unreachable : string list + val comb : bool + val ignore : string list +end + +module Make (Config : CONFIG) : sig + type cdef_doc = { + outside_module : PPrint.document; + inside_module_prefix : PPrint.document; + inside_module : PPrint.document; + } + + val svir_cdef : + spec_info -> + Jib_compile.ctx -> + (Jib.ctyp list * Libsail.Jib.ctyp) Bindings.t -> + Jib.cdef -> + Sv_ir.sv_def list * (Jib.ctyp list * Jib.ctyp) Bindings.t + + val pp_def : Sv_ir.sv_def -> PPrint.document + + val toplevel_module : spec_info -> Sv_ir.sv_module option + + val sv_cdef : + spec_info -> + Jib_compile.ctx -> + (Jib.ctyp list * Libsail.Jib.ctyp) Bindings.t -> + string list -> + Jib.cdef -> + cdef_doc * (Jib.ctyp list * Jib.ctyp) Bindings.t * string list + + val sv_register_references : Jib.cdef list -> PPrint.document * PPrint.document + + val sv_fundef_with : + Jib_compile.ctx -> string -> Ast.id list -> Jib.ctyp list -> Jib.ctyp -> PPrint.document -> PPrint.document + + val sv_ctyp : Jib.ctyp -> string * string option + + val wrap_type : Jib.ctyp -> PPrint.document -> PPrint.document + + val pp_id_string : Ast.id -> string + + val pp_id : Ast.id -> PPrint.document + val main_args : + Jib.cdef option -> + (Jib.ctyp list * Jib.ctyp) Bindings.t -> + PPrint.document list * PPrint.document option * PPrint.document list + + val make_call_precise : Jib_compile.ctx -> Ast.id -> bool +end diff --git a/src/sail_sv_backend/sail_plugin_sv.ml b/src/sail_sv_backend/sail_plugin_sv.ml index e2f5cbad3..0047afe20 100644 --- a/src/sail_sv_backend/sail_plugin_sv.ml +++ b/src/sail_sv_backend/sail_plugin_sv.ml @@ -317,6 +317,7 @@ module Verilog_config (C : JIB_CONFIG) : Jib_compile.CONFIG = struct let track_throw = true let branch_coverage = None let use_real = false + let use_void = false end let register_types cdefs = @@ -424,21 +425,25 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = ^^ space ^^ string "sail_throw_location;" ^^ twice hardline in - let in_doc, out_doc, reg_doc, fn_ctyps, setup_calls = + let spec_info = Jib_sv.collect_spec_info ctx cdefs in + + let doc, fn_ctyps = List.fold_left - (fun (doc_in, doc_out, doc_reg, fn_ctyps, setup_calls) cdef -> - let cdef_doc, fn_ctyps, setup_calls = sv_cdef ctx fn_ctyps setup_calls cdef in - ( doc_in ^^ cdef_doc.inside_module, - doc_out ^^ cdef_doc.outside_module, - doc_reg ^^ cdef_doc.inside_module_prefix, - fn_ctyps, - setup_calls - ) + (fun (doc, fn_ctyps) cdef -> + let svir_defs, fn_ctyps = svir_cdef spec_info ctx fn_ctyps cdef in + (separate_map (twice hardline) pp_def svir_defs ^^ twice hardline ^^ doc, fn_ctyps) ) - (exception_vars, include_doc, empty, Bindings.empty, []) - cdefs + (empty, Bindings.empty) cdefs + in + let doc = + let library_defs = Generate_primop2.get_generated_library_defs () in + let top_doc = + Option.fold ~none:empty ~some:(fun m -> twice hardline ^^ pp_def (SVD_module m)) (SV.toplevel_module spec_info) + in + separate_map (twice hardline) pp_def library_defs ^^ twice hardline ^^ doc ^^ top_doc in + (* let reg_ref_enums, reg_ref_functions = sv_register_references cdefs in let out_doc = out_doc ^^ reg_ref_enums in let in_doc = reg_doc ^^ reg_ref_functions ^^ in_doc in @@ -496,7 +501,7 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = (List.filter_map (function | CDEF_aux (CDEF_register (id, ctyp, _), _) -> - Some (SV.sv_id id ^^ space ^^ equals ^^ space ^^ sv_id id ^^ string "_in" ^^ semi ^^ hardline) + Some (pp_id id ^^ space ^^ equals ^^ space ^^ pp_id id ^^ string "_in" ^^ semi ^^ hardline) | _ -> None ) cdefs @@ -510,7 +515,7 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = (List.filter_map (function | CDEF_aux (CDEF_register (id, ctyp, _), _) -> - Some (sv_id id ^^ string "_out" ^^ space ^^ equals ^^ space ^^ sv_id id ^^ semi ^^ hardline) + Some (pp_id id ^^ string "_out" ^^ space ^^ equals ^^ space ^^ pp_id id ^^ semi ^^ hardline) | _ -> None ) cdefs @@ -519,7 +524,7 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = in let main = - List.find_opt (function CDEF_aux (CDEF_fundef (id, _, _, _), _) -> sv_id_string id = "main" | _ -> false) cdefs + List.find_opt (function CDEF_aux (CDEF_fundef (id, _, _, _), _) -> pp_id_string id = "main" | _ -> false) cdefs in let main_args, main_result, module_main_in_out = main_args main fn_ctyps in @@ -545,7 +550,7 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = List.filter_map (function | CDEF_aux (CDEF_register (id, ctyp, _), _) -> - Some (string "input" ^^ space ^^ wrap_type ctyp (sv_id id ^^ string "_in")) + Some (string "input" ^^ space ^^ wrap_type ctyp (pp_id id ^^ string "_in")) | _ -> None ) cdefs @@ -557,13 +562,12 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = List.filter_map (function | CDEF_aux (CDEF_register (id, ctyp, _), _) -> - Some (string "output" ^^ space ^^ wrap_type ctyp (sv_id id ^^ string "_out")) + Some (string "output" ^^ space ^^ wrap_type ctyp (pp_id id ^^ string "_out")) | _ -> None ) cdefs else [] in - let sv_output = Pretty_print_sail.Document.to_string (wrap_module out_doc ("sail_" ^ out) @@ -571,6 +575,8 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = (in_doc ^^ wire_funs ^^ setup_function ^^ invoke_main) ) in + *) + let sv_output = Pretty_print_sail.Document.to_string doc in make_genlib_file (sprintf "sail_genlib_%s.sv" out); let ((out_chan, _, _, _) as file_info) = Util.open_output_with_check_unformatted !opt_output_dir (out ^ ".sv") in @@ -591,10 +597,15 @@ let verilog_target _ default_sail_dir out_opt ast effect_info env = (verilator_cpp_wrapper out); Util.close_output_with_check file_info; - Reporting.system_checked - (sprintf "verilator --cc --exe --build -j 0 -I%s --Mdir %s_obj_dir sim_%s.cpp %s.sv" sail_sv_libdir out out - out - ); + (* Verilator sometimes just spuriously returns non-zero exit + codes even when it suceeds, so we don't use system_checked + here, and just hope for the best. *) + let _ = + Unix.system + (sprintf "verilator --cc --exe --build -j 0 -I%s --Mdir %s_obj_dir sim_%s.cpp %s.sv" sail_sv_libdir out out + out + ) + in begin match !opt_verilate with | Verilator_run -> Reporting.system_checked (sprintf "%s_obj_dir/V%s" out out) diff --git a/src/sail_sv_backend/sv_ir.ml b/src/sail_sv_backend/sv_ir.ml new file mode 100644 index 000000000..e054f794c --- /dev/null +++ b/src/sail_sv_backend/sv_ir.ml @@ -0,0 +1,341 @@ +(****************************************************************************) +(* Sail *) +(* *) +(* Sail and the Sail architecture models here, comprising all files and *) +(* directories except the ASL-derived Sail code in the aarch64 directory, *) +(* are subject to the BSD two-clause licence below. *) +(* *) +(* The ASL derived parts of the ARMv8.3 specification in *) +(* aarch64/no_vector and aarch64/full are copyright ARM Ltd. *) +(* *) +(* Copyright (c) 2013-2021 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* Alasdair Armstrong *) +(* Brian Campbell *) +(* Thomas Bauereiss *) +(* Anthony Fox *) +(* Jon French *) +(* Dominic Mulligan *) +(* Stephen Kell *) +(* Mark Wassell *) +(* Alastair Reid (Arm Ltd) *) +(* Louis-Emile Ploix *) +(* *) +(* All rights reserved. *) +(* *) +(* This work was partially supported by EPSRC grant EP/K008528/1 REMS: Rigorous *) +(* Engineering for Mainstream Systems, an ARM iCASE award, EPSRC IAA *) +(* KTF funding, and donations from Arm. This project has received *) +(* funding from the European Research Council (ERC) under the European *) +(* Union’s Horizon 2020 research and innovation programme (grant *) +(* agreement No 789108, ELVER). *) +(* *) +(* This software was developed by SRI International and the University of *) +(* Cambridge Computer Laboratory (Department of Computer Science and *) +(* Technology) under DARPA/AFRL contracts FA8650-18-C-7809 ("CIFV") *) +(* and FA8750-10-C-0237 ("CTSRD"). *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(****************************************************************************) + +open Libsail + +open Ast_util +open Jib_util +open Jib_visitor +open PPrint +open Smt_exp + +open Generate_primop + +type sv_name = SVN_id of Ast.id | SVN_string of string + +module SVName = struct + type t = sv_name + let compare n1 n2 = + match (n1, n2) with + | SVN_id id1, SVN_id id2 -> Id.compare id1 id2 + | SVN_string s1, SVN_string s2 -> String.compare s1 s2 + | SVN_id _, _ -> 1 + | _, SVN_id _ -> -1 +end + +let modify_sv_name ?(prefix = "") ?(suffix = "") = function + | SVN_id id -> SVN_id (append_id (prepend_id prefix id) suffix) + | SVN_string str -> SVN_string (prefix ^ str ^ suffix) + +let string_of_sv_name = function SVN_id id -> string_of_id id | SVN_string str -> str + +module SVNameMap = Map.Make (SVName) + +type sv_module_port = { name : Jib.name; external_name : string; typ : Jib.ctyp } + +let mk_port name ctyp = { name; external_name = ""; typ = ctyp } + +type sv_module = { + name : sv_name; + input_ports : sv_module_port list; + output_ports : sv_module_port list; + defs : sv_def list; +} + +and sv_function = { + function_name : sv_name; + return_type : Jib.ctyp option; + params : (Ast.id * Jib.ctyp) list; + body : sv_statement; +} + +and sv_def = + | SVD_type of Jib.ctype_def + | SVD_module of sv_module + | SVD_var of Jib.name * Jib.ctyp + | SVD_fundef of sv_function + | SVD_instantiate of { + module_name : sv_name; + instance_name : string; + input_connections : smt_exp list; + output_connections : sv_place list; + } + | SVD_always_comb of sv_statement + +and sv_place = + | SVP_id of Jib.name + | SVP_index of sv_place * smt_exp + | SVP_field of sv_place * Ast.id + | SVP_multi of sv_place list + | SVP_void + +and sv_statement = SVS_aux of sv_statement_aux * Ast.l + +and sv_statement_aux = + | SVS_comment of string + | SVS_skip + | SVS_var of Jib.name * Jib.ctyp * smt_exp option + | SVS_return of smt_exp + | SVS_assign of sv_place * smt_exp + | SVS_call of sv_place * sv_name * smt_exp list + | SVS_case of { head_exp : smt_exp; cases : (Ast.id list * sv_statement) list; fallthrough : sv_statement option } + | SVS_if of smt_exp * sv_statement option * sv_statement option + | SVS_block of sv_statement list + | SVS_raw of string * Jib.name list * Jib.name list + +let svs_raw ?(inputs = []) ?(outputs = []) s = SVS_raw (s, inputs, outputs) + +let mk_statement ?(loc = Parse_ast.Unknown) aux = SVS_aux (aux, loc) + +class type svir_visitor = object + inherit common_visitor + method vsmt_exp : smt_exp -> smt_exp visit_action + method vplace : sv_place -> sv_place visit_action + method vstatement : sv_statement -> sv_statement visit_action + method vdef : sv_def -> sv_def visit_action +end + +let rec visit_smt_exp (vis : svir_visitor) outer_smt_exp = + let aux (vis : svir_visitor) no_change = + match no_change with + | Var name -> + let name' = visit_name (vis :> common_visitor) name in + if name == name' then no_change else Var name' + | ZeroExtend (n, m, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else ZeroExtend (n, m, exp') + | SignExtend (n, m, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else SignExtend (n, m, exp') + | Extract (n, m, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else Extract (n, m, exp') + | Hd (hd_op, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else Hd (hd_op, exp') + | Tl (tl_op, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else Tl (tl_op, exp') + | Tester (test, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else Tester (test, exp') + | Ite (i, t, e) -> + let i' = visit_smt_exp vis i in + let t' = visit_smt_exp vis t in + let e' = visit_smt_exp vis e in + if i == i' && t == t' && e == e' then no_change else Ite (i', t', e') + | Store (info, store_fn, arr, index, x) -> + let arr' = visit_smt_exp vis arr in + let index' = visit_smt_exp vis index in + let x' = visit_smt_exp vis x in + if arr == arr' && index == index' && x == x' then no_change else Store (info, store_fn, arr', index', x') + | Field (struct_id, field_id, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else Field (struct_id, field_id, exp') + | Fn (f, args) -> + let args' = map_no_copy (visit_smt_exp vis) args in + if args == args' then no_change else Fn (f, args') + | Unwrap (ctor, b, exp) -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else Unwrap (ctor, b, exp') + | Bool_lit _ | Bitvec_lit _ | Real_lit _ | String_lit _ | Enum _ | Empty_list -> no_change + in + do_visit vis (vis#vsmt_exp outer_smt_exp) aux outer_smt_exp + +let rec visit_sv_place (vis : svir_visitor) outer_place = + let aux (vis : svir_visitor) no_change = + match no_change with + | SVP_id name -> + let name' = visit_name (vis :> common_visitor) name in + if name == name' then no_change else SVP_id name' + | SVP_index (place, exp) -> + let place' = visit_sv_place vis place in + let exp' = visit_smt_exp vis exp in + if place == place' && exp == exp' then no_change else SVP_index (place', exp') + | SVP_field (place, field_id) -> + let place' = visit_sv_place vis place in + if place == place' then no_change else SVP_field (place', field_id) + | SVP_multi places -> + let places' = map_no_copy (visit_sv_place vis) places in + if places == places' then no_change else SVP_multi places' + | SVP_void -> no_change + in + do_visit vis (vis#vplace outer_place) aux outer_place + +let rec visit_sv_statement (vis : svir_visitor) outer_statement = + let aux (vis : svir_visitor) (SVS_aux (stmt, l) as no_change) = + match stmt with + | SVS_var (name, ctyp, None) -> + let name' = visit_name (vis :> common_visitor) name in + let ctyp' = visit_ctyp (vis :> common_visitor) ctyp in + if name == name' && ctyp == ctyp' then no_change else SVS_aux (SVS_var (name', ctyp', None), l) + | SVS_var (name, ctyp, Some exp) -> + let name' = visit_name (vis :> common_visitor) name in + let ctyp' = visit_ctyp (vis :> common_visitor) ctyp in + let exp' = visit_smt_exp vis exp in + if name == name' && ctyp == ctyp' && exp == exp' then no_change + else SVS_aux (SVS_var (name', ctyp', Some exp'), l) + | SVS_assign (place, exp) -> + let place' = visit_sv_place vis place in + let exp' = visit_smt_exp vis exp in + if place == place' && exp == exp' then no_change else SVS_aux (SVS_assign (place', exp'), l) + | SVS_block statements -> + let statements' = map_no_copy (visit_sv_statement vis) statements in + if statements == statements' then no_change else SVS_aux (SVS_block statements', l) + | SVS_return exp -> + let exp' = visit_smt_exp vis exp in + if exp == exp' then no_change else SVS_aux (SVS_return exp', l) + | SVS_call (place, f, args) -> + let place' = visit_sv_place vis place in + let args' = map_no_copy (visit_smt_exp vis) args in + if place == place' && args == args' then no_change else SVS_aux (SVS_call (place', f, args'), l) + | SVS_case { head_exp; cases; fallthrough } -> + let head_exp' = visit_smt_exp vis head_exp in + let cases' = + map_no_copy + (function + | (ids, stmt) as no_change -> + let stmt' = visit_sv_statement vis stmt in + if stmt == stmt' then no_change else (ids, stmt') + ) + cases + in + let fallthrough' = map_no_copy_opt (visit_sv_statement vis) fallthrough in + if head_exp == head_exp' && cases == cases' && fallthrough == fallthrough' then no_change + else SVS_aux (SVS_case { head_exp = head_exp'; cases = cases'; fallthrough = fallthrough' }, l) + | SVS_if (exp, then_stmt_opt, else_stmt_opt) -> + let exp' = visit_smt_exp vis exp in + let then_stmt_opt' = map_no_copy_opt (visit_sv_statement vis) then_stmt_opt in + let else_stmt_opt' = map_no_copy_opt (visit_sv_statement vis) else_stmt_opt in + if exp == exp' && then_stmt_opt == then_stmt_opt' && else_stmt_opt == else_stmt_opt' then no_change + else SVS_aux (SVS_if (exp', then_stmt_opt', else_stmt_opt'), l) + | SVS_raw _ | SVS_comment _ | SVS_skip -> no_change + in + do_visit vis (vis#vstatement outer_statement) aux outer_statement + +let rec visit_sv_def (vis : svir_visitor) outer_def = + let aux (vis : svir_visitor) no_change = + match no_change with + | SVD_type _ -> no_change + | SVD_module { name; input_ports; output_ports; defs } -> + let visit_port ({ name; external_name; typ } as no_change) = + let name' = visit_name (vis :> common_visitor) name in + let typ' = visit_ctyp (vis :> common_visitor) typ in + if name == name' && typ == typ' then no_change else { name = name'; external_name; typ = typ' } + in + let input_ports' = map_no_copy visit_port input_ports in + let output_ports' = map_no_copy visit_port output_ports in + let defs' = map_no_copy (visit_sv_def vis) defs in + if input_ports == input_ports' && output_ports == output_ports' && defs == defs' then no_change + else SVD_module { name; input_ports = input_ports'; output_ports = output_ports'; defs = defs' } + | SVD_var (name, ctyp) -> + let name' = visit_name (vis :> common_visitor) name in + let ctyp' = visit_ctyp (vis :> common_visitor) ctyp in + if name == name' && ctyp == ctyp' then no_change else SVD_var (name', ctyp') + | SVD_instantiate { module_name; instance_name; input_connections; output_connections } -> + let input_connections' = map_no_copy (visit_smt_exp vis) input_connections in + let output_connections' = map_no_copy (visit_sv_place vis) output_connections in + if input_connections == input_connections' && output_connections == output_connections' then no_change + else + SVD_instantiate + { + module_name; + instance_name; + input_connections = input_connections'; + output_connections = output_connections'; + } + | SVD_fundef { function_name; return_type; params; body } -> + let return_type' = map_no_copy_opt (visit_ctyp (vis :> common_visitor)) return_type in + let params' = + map_no_copy + (function + | (id, ctyp) as no_change -> + let ctyp' = visit_ctyp (vis :> common_visitor) ctyp in + if ctyp == ctyp' then no_change else (id, ctyp') + ) + params + in + let body' = visit_sv_statement vis body in + if return_type == return_type' && params == params' && body == body' then no_change + else SVD_fundef { function_name; return_type = return_type'; params = params'; body = body' } + | SVD_always_comb statement -> + let statement' = visit_sv_statement vis statement in + if statement == statement' then no_change else SVD_always_comb statement' + in + do_visit vis (vis#vdef outer_def) aux outer_def + +class empty_svir_visitor : svir_visitor = + object + method vid _ = None + method vname _ = None + method vctyp _ = DoChildren + method vsmt_exp _ = DoChildren + method vplace _ = DoChildren + method vstatement _ = DoChildren + method vdef _ = DoChildren + end diff --git a/src/sail_sv_backend/sv_ir.mli b/src/sail_sv_backend/sv_ir.mli new file mode 100644 index 000000000..e89dffff9 --- /dev/null +++ b/src/sail_sv_backend/sv_ir.mli @@ -0,0 +1,171 @@ +(****************************************************************************) +(* Sail *) +(* *) +(* Sail and the Sail architecture models here, comprising all files and *) +(* directories except the ASL-derived Sail code in the aarch64 directory, *) +(* are subject to the BSD two-clause licence below. *) +(* *) +(* The ASL derived parts of the ARMv8.3 specification in *) +(* aarch64/no_vector and aarch64/full are copyright ARM Ltd. *) +(* *) +(* Copyright (c) 2013-2021 *) +(* Kathyrn Gray *) +(* Shaked Flur *) +(* Stephen Kell *) +(* Gabriel Kerneis *) +(* Robert Norton-Wright *) +(* Christopher Pulte *) +(* Peter Sewell *) +(* Alasdair Armstrong *) +(* Brian Campbell *) +(* Thomas Bauereiss *) +(* Anthony Fox *) +(* Jon French *) +(* Dominic Mulligan *) +(* Stephen Kell *) +(* Mark Wassell *) +(* Alastair Reid (Arm Ltd) *) +(* Louis-Emile Ploix *) +(* *) +(* All rights reserved. *) +(* *) +(* This work was partially supported by EPSRC grant EP/K008528/1 REMS: Rigorous *) +(* Engineering for Mainstream Systems, an ARM iCASE award, EPSRC IAA *) +(* KTF funding, and donations from Arm. This project has received *) +(* funding from the European Research Council (ERC) under the European *) +(* Union’s Horizon 2020 research and innovation programme (grant *) +(* agreement No 789108, ELVER). *) +(* *) +(* This software was developed by SRI International and the University of *) +(* Cambridge Computer Laboratory (Department of Computer Science and *) +(* Technology) under DARPA/AFRL contracts FA8650-18-C-7809 ("CIFV") *) +(* and FA8750-10-C-0237 ("CTSRD"). *) +(* *) +(* Redistribution and use in source and binary forms, with or without *) +(* modification, are permitted provided that the following conditions *) +(* are met: *) +(* 1. Redistributions of source code must retain the above copyright *) +(* notice, this list of conditions and the following disclaimer. *) +(* 2. Redistributions in binary form must reproduce the above copyright *) +(* notice, this list of conditions and the following disclaimer in *) +(* the documentation and/or other materials provided with the *) +(* distribution. *) +(* *) +(* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' *) +(* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED *) +(* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A *) +(* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR *) +(* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, *) +(* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT *) +(* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF *) +(* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND *) +(* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, *) +(* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT *) +(* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF *) +(* SUCH DAMAGE. *) +(****************************************************************************) + +(** This file defines an intermediate representation that is roughly + equivalent to the subset of SystemVerilog that we target. This + enables us to perform SystemVerilog to SystemVerilog rewrites - + for this purpose we also define a vistor-pattern rewriter + [svir_visitor], much like it's [jib_visitor] equivalent. *) + +open Libsail + +open Jib_visitor +open Smt_exp + +type sv_name = SVN_id of Ast.id | SVN_string of string + +module SVName : sig + type t = sv_name + val compare : sv_name -> sv_name -> int +end + +module SVNameMap : sig + include Map.S with type key = sv_name +end + +val modify_sv_name : ?prefix:string -> ?suffix:string -> sv_name -> sv_name + +val string_of_sv_name : sv_name -> string + +type sv_module_port = { name : Jib.name; external_name : string; typ : Jib.ctyp } + +val mk_port : Jib.name -> Jib.ctyp -> sv_module_port + +type sv_module = { + name : sv_name; + input_ports : sv_module_port list; + output_ports : sv_module_port list; + defs : sv_def list; +} + +and sv_function = { + function_name : sv_name; + return_type : Jib.ctyp option; + params : (Ast.id * Jib.ctyp) list; + body : sv_statement; +} + +and sv_def = + | SVD_type of Jib.ctype_def + | SVD_module of sv_module + | SVD_var of Jib.name * Jib.ctyp + | SVD_fundef of sv_function + | SVD_instantiate of { + module_name : sv_name; + instance_name : string; + input_connections : smt_exp list; + output_connections : sv_place list; + } + | SVD_always_comb of sv_statement + +and sv_place = + | SVP_id of Jib.name + | SVP_index of sv_place * smt_exp + | SVP_field of sv_place * Ast.id + | SVP_multi of sv_place list + | SVP_void + +and sv_statement = SVS_aux of sv_statement_aux * Ast.l + +and sv_statement_aux = + | SVS_comment of string + | SVS_skip + | SVS_var of Jib.name * Jib.ctyp * smt_exp option + | SVS_return of smt_exp + | SVS_assign of sv_place * smt_exp + | SVS_call of sv_place * sv_name * smt_exp list + | SVS_case of { head_exp : smt_exp; cases : (Ast.id list * sv_statement) list; fallthrough : sv_statement option } + | SVS_if of smt_exp * sv_statement option * sv_statement option + | SVS_block of sv_statement list + | SVS_raw of string * Jib.name list * Jib.name list + +val svs_raw : ?inputs:Jib.name list -> ?outputs:Jib.name list -> string -> sv_statement_aux + +val mk_statement : ?loc:Parse_ast.l -> sv_statement_aux -> sv_statement + +class type svir_visitor = object + (** Note that despite inheriting from common_visitor, we don't use + [vid]. Instead specific types of identifiers should be + re-written by matching on their containing node. *) + inherit common_visitor + + method vsmt_exp : smt_exp -> smt_exp visit_action + method vplace : sv_place -> sv_place visit_action + method vstatement : sv_statement -> sv_statement visit_action + method vdef : sv_def -> sv_def visit_action +end + +class empty_svir_visitor : svir_visitor + +val visit_smt_exp : svir_visitor -> smt_exp -> smt_exp + +val visit_sv_place : svir_visitor -> sv_place -> sv_place + +val visit_sv_statement : svir_visitor -> sv_statement -> sv_statement + +val visit_sv_def : svir_visitor -> sv_def -> sv_def diff --git a/test/smt/issue573_1.sat.sail b/test/smt/issue573_1.sat.sail new file mode 100644 index 000000000..78864512d --- /dev/null +++ b/test/smt/issue573_1.sat.sail @@ -0,0 +1,10 @@ +default Order dec + +$include + +$counterexample +function prop(x : int) -> bool = { + if x == 0 then return true; + if x == 100 then return true; + false +} diff --git a/test/smt/issue573_2.sat.sail b/test/smt/issue573_2.sat.sail new file mode 100644 index 000000000..571fe0f7e --- /dev/null +++ b/test/smt/issue573_2.sat.sail @@ -0,0 +1,10 @@ +default Order dec + +$include + +$counterexample +function prop(x : int) -> bool = { + if x == 0 then return true; + if x == 100 then return true; + return false +} diff --git a/test/smt/linked_int.unsat.sail b/test/smt/linked_int.unsat.sail new file mode 100644 index 000000000..93ba0900e --- /dev/null +++ b/test/smt/linked_int.unsat.sail @@ -0,0 +1,6 @@ +default Order dec + +$include + +$property +function prop forall 'n. (x: int('n), y: int('n)) -> bool = x == y diff --git a/test/smt/linked_int2.unsat.sail b/test/smt/linked_int2.unsat.sail new file mode 100644 index 000000000..2694a72d6 --- /dev/null +++ b/test/smt/linked_int2.unsat.sail @@ -0,0 +1,6 @@ +default Order dec + +$include + +$property +function prop forall 'n 'm, 'n == 'm. (x: int('n), y: int('m)) -> bool = x == y diff --git a/test/smt/lzcnt.unsat.sail b/test/smt/lzcnt.unsat.sail index b627c9e54..822b85863 100644 --- a/test/smt/lzcnt.unsat.sail +++ b/test/smt/lzcnt.unsat.sail @@ -2,7 +2,7 @@ default Order dec $include -val lzcnt = "count_leading_zeros" : forall 'w. bits('w) -> range(0, 'w) +val lzcnt = pure "count_leading_zeros" : forall 'w. bits('w) -> range(0, 'w) $property function prop() -> bool = { diff --git a/test/smt/revrev_endianness2.unsat.sail b/test/smt/revrev_endianness2.unsat.sail index 33ba93a22..812f3cb4d 100644 --- a/test/smt/revrev_endianness2.unsat.sail +++ b/test/smt/revrev_endianness2.unsat.sail @@ -1,5 +1,7 @@ default Order dec +$option -smt_bits_size 128 + $include $property diff --git a/test/smt/revrev_endianness3.unsat.sail b/test/smt/revrev_endianness3.unsat.sail new file mode 100644 index 000000000..a75049596 --- /dev/null +++ b/test/smt/revrev_endianness3.unsat.sail @@ -0,0 +1,22 @@ +default Order dec + +$option -smt_bits_size 128 + +$include +$include + +$property +function prop forall 'n, 'n in {8, 16, 32, 64, 128}. (n: int('n), xs: bits(128)) -> bool = { + let xs = xs[n - 1 .. 0] in + if length(xs) == 8 then { + reverse_endianness(reverse_endianness(xs)) == xs + } else if length(xs) == 16 then { + reverse_endianness(reverse_endianness(xs)) == xs + } else if length(xs) == 32 then { + reverse_endianness(reverse_endianness(xs)) == xs + } else if length(xs) == 64 then { + reverse_endianness(reverse_endianness(xs)) == xs + } else { + reverse_endianness(reverse_endianness(xs)) == xs + } +} diff --git a/test/smt/run_tests.py b/test/smt/run_tests.py index 77f157967..ead1993d6 100755 --- a/test/smt/run_tests.py +++ b/test/smt/run_tests.py @@ -20,7 +20,14 @@ 'assembly_mapping_sat': { 'z3', 'cvc4' }, # This test using unsupported CVC4 features 'arith_unsat': { 'z3', 'cvc4' }, 'arith_LFL_unsat' : { 'z3', 'cvc4' }, - 'revrev_endianness2_unsat' : { 'z3', 'cvc4' }, # There is some bug in this test + 'store_load_sat' : { 'z3', 'cvc4' }, + 'load_store_dep_sat' : { 'z3', 'cvc4' }, + 'store_load_scattered_sat' : { 'z3', 'cvc4' }, + 'mem_builtins_unsat' : { 'z3', 'cvc4' }, + 'rv_add_1_unsat' : { 'z3', 'cvc4' }, + 'rv_add_0_unsat' : { 'z3', 'cvc4' }, + 'rv_add_1_sat' : { 'z3', 'cvc4' }, + 'rv_add_0_sat' : { 'z3', 'cvc4' }, } print("Sail is {}".format(sail)) diff --git a/test/smt/rv_add_1.unsat.sail b/test/smt/rv_add_1.unsat.sail index d923d1d56..ca7904731 100644 --- a/test/smt/rv_add_1.unsat.sail +++ b/test/smt/rv_add_1.unsat.sail @@ -50,6 +50,7 @@ register R31 : xlenbits /* Getters and setters for X registers (special case for zeros register, x0) */ val rX : forall 'n, 0 <= 'n < 32. regno('n) -> xlenbits effect {rreg} +$[jib_debug] function rX(r) = { if r == 0 then sail_zero_extend(0x0, sizeof(xlen)) else if r == 1 then R1 @@ -149,6 +150,7 @@ function clause execute (ITYPE (imm, rs1, rd, RISCV_ADDI)) = function clause decode _ = None() $property +$[jib_debug] function prop(imm: bits(12), rs1: regbits, rd: regbits) -> bool = { let v = X(rs1); match decode(imm @ rs1 @ 0b000 @ rd @ 0b0010011) { diff --git a/test/smt/string.unsat.sail b/test/smt/string.unsat.sail index b91abfadd..13bc259e5 100644 --- a/test/smt/string.unsat.sail +++ b/test/smt/string.unsat.sail @@ -2,10 +2,6 @@ default Order dec $include -val "concat_str" : (string, string) -> string - -val "eq_string" : (string, string) -> bool - overload operator == = {eq_string} $property diff --git a/test/sv/.gitignore b/test/sv/.gitignore new file mode 100644 index 000000000..8df2fcca4 --- /dev/null +++ b/test/sv/.gitignore @@ -0,0 +1,2 @@ +# Ignore generated directories from SV tests +**/* diff --git a/test/sv/run_tests.py b/test/sv/run_tests.py index cf7e0c8d1..c544bbc28 100755 --- a/test/sv/run_tests.py +++ b/test/sv/run_tests.py @@ -67,6 +67,7 @@ def test_sv(name, opts, skip_list): continue tests[filename] = os.fork() if tests[filename] == 0: + step('rm -rf {}_obj_dir'.format(basename)); if basename.startswith('fail'): step('{} -no_warn -sv ../c/{} -o {} -sv_verilate compile{} > {}.out'.format(sail, filename, basename, opts, basename)) else: