Skip to content

Commit

Permalink
translate trivial type quantifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
javra authored and Alasdair committed Dec 16, 2024
1 parent 766b31e commit 1b03b22
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 30 deletions.
66 changes: 37 additions & 29 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ open Rewriter
open PPrint
open Pretty_print_common

let implicit_parens x = enclose (string "{") (string "}") x
let doc_id_ctor (Id_aux (i, _)) =
match i with Id i -> string i | Operator x -> string (Util.zencode_string ("op " ^ x))
let doc_kid (Kid_aux (Var x, _)) = string ("k_" ^ String.sub x 1 (String.length x - 1))
(* TODO do a proper renaming and keep track of it *)

let is_enum env id = match Env.lookup_id id env with Enum _ -> true | _ -> false

Expand Down Expand Up @@ -83,6 +86,7 @@ let string_of_nexp_con (Nexp_aux (n, l)) =
let doc_nexp (Nexp_aux (n, l) as nexp) =
match n with
| Nexp_constant i -> string (Big_int.to_string i)
| Nexp_var ki -> doc_kid ki
| _ -> failwith ("NExp " ^ string_of_nexp_con nexp ^ " " ^ string_of_nexp nexp ^ " not translatable yet.")

let string_of_typ_con (Typ_aux (t, _)) =
Expand All @@ -104,18 +108,25 @@ let rec doc_typ (Typ_aux (t, _) as typ) =
| Typ_id (Id_aux (Id "bit", _)) -> parens (string "BitVec 1")
| Typ_id (Id_aux (Id "nat", _)) -> string "Nat"
| Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _)]) -> string "BitVec " ^^ doc_nexp m
| Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) -> string "Int"
| Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) doc_typ ts)
| Typ_id (Id_aux (Id id, _)) -> string id
| _ -> failwith ("Type " ^ string_of_typ_con typ ^ " " ^ string_of_typ typ ^ " not translatable yet.")

let doc_typ_id (typ, fid) = concat [doc_id_ctor fid; space; colon; space; doc_typ typ; hardline]

let doc_typ_quant tq =
match tq with
| TypQ_tq qs -> (
match qs with [] -> string "" | _ -> failwith "Type quantifier not translatable yet."
)
| TypQ_no_forall -> string ""
let doc_kind (K_aux (k, _)) =
match k with
| K_int -> string "Int"
| K_bool -> string "Bool"
| _ -> failwith ("Kind " ^ string_of_kind_aux k ^ " not translatable yet.")

let doc_quant_item (QI_aux (qi, _)) =
match qi with
| QI_id (KOpt_aux (KOpt_kind (k, ki), _)) -> implicit_parens (flow (break 1) [doc_kid ki; colon; doc_kind k])
| QI_constraint _ -> failwith "Constraints not supported yet."

let doc_typ_quant tq = match tq with TypQ_tq qs -> List.map doc_quant_item qs | TypQ_no_forall -> []

let lean_escape_string s = Str.global_replace (Str.regexp "\"") "\"\"" s

Expand Down Expand Up @@ -210,28 +221,24 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
| Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> (arg_typs, ret_typ, no_effect)
| _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type")
in
match tq with
| TypQ_tq [] | TypQ_no_forall ->
();
let pat, _, _, _ = destruct_pexp pexp in
let pats, _ = untuple_args_pat arg_typs pat in
let binders : (id * typ) list =
pats
|> List.map (fun (pat, typ) ->
match pat_is_plain_binder env pat with
| Some (Some id) -> (id, typ)
| Some None -> (Id_aux (Id "x", l), typ) (* TODO fresh name or wildcard instead of x *)
| _ -> failwith "Argument pattern not translatable yet."
)
in
let binders : document list =
binders |> List.map (fun (i, t) -> separate space [string (string_of_id i); colon; doc_typ t] |> parens)
in
separate space ([string "def"; string (string_of_id id)] @ binders @ [colon; doc_typ ret_typ; coloneq])
| TypQ_tq qs ->
let qs = List.map string_of_quant_item qs in
let foo = String.concat ";" qs in
failwith ("Type quantifiers (" ^ foo ^ ") not translatable yet")
let pat, _, _, _ = destruct_pexp pexp in
let pats, _ = untuple_args_pat arg_typs pat in
let binders : (id * typ) list =
pats
|> List.map (fun (pat, typ) ->
match pat_is_plain_binder env pat with
| Some (Some id) -> (id, typ)
| Some None -> (Id_aux (Id "x", l), typ) (* TODO fresh name or wildcard instead of x *)
| _ -> failwith "Argument pattern not translatable yet."
)
in
let binders : document list =
binders |> List.map (fun (i, t) -> separate space [string (string_of_id i); colon; doc_typ t] |> parens)
in
(* let binders = doc_typ_quant tq @ binders in *)
(* Use auto-implicits for type quanitifiers for now and see if this works *)
let doc_ret_typ = doc_typ ret_typ in
separate space ([string "def"; string (string_of_id id)] @ binders @ [colon; doc_ret_typ; coloneq])

let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) =
let _, _, exp, _ = destruct_pexp pexp in
Expand Down Expand Up @@ -270,7 +277,8 @@ let doc_typdef (TD_aux (td, tannot) as full_typdef) =
let fields = List.map doc_typ_id fields in
let enums_doc = concat fields in
let rectyp = doc_typ_quant tq in
nest 2 (flow (break 1) [string "structure"; string id; rectyp; string "where"] ^^ hardline ^^ enums_doc)
(* TODO don't ignore type quantifiers *)
nest 2 (flow (break 1) [string "structure"; string id; string "where"] ^^ hardline ^^ enums_doc)
| _ -> failwith ("Type definition " ^ string_of_type_def_con full_typdef ^ " not translatable yet.")

let doc_def (DEF_aux (aux, def_annot) as def) =
Expand Down
2 changes: 1 addition & 1 deletion test/lean/struct.expected.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Sail.sail

structure My_struct where
structure My_struct where
field1 : Int
field2 : Int

Expand Down
11 changes: 11 additions & 0 deletions test/lean/typquant.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import Sail.sail

def foo (n : Int) : BitVec 4 :=
(0xF : BitVec 4)

def bar (x : BitVec k_n) : BitVec k_n :=
x

def initialize_registers : Unit :=
()

15 changes: 15 additions & 0 deletions test/lean/typquant.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
default Order dec

$include <prelude.sail>

val foo : forall 'n. int('n) -> bits(4)

function foo(n) = {
0xF
}

val bar : forall 'n. bits('n) -> bits('n)

function bar(x) = {
x
}

0 comments on commit 1b03b22

Please sign in to comment.