Skip to content

Commit

Permalink
Lean: Add bitvector function definitions for the lean backend (#788)
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot authored Nov 27, 2024
1 parent 34a62cf commit d395f09
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 5 deletions.
30 changes: 26 additions & 4 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ val eq_bits = pure {
interpreter: "eq_list",
lem: "eq_vec",
coq: "eq_vec",
lean: "Eq",
_: "eq_bits"
} : forall 'n. (bits('n), bits('n)) -> bool

Expand All @@ -68,7 +69,8 @@ overload operator == = {eq_bit, eq_bits}
val neq_bits = pure {
lem: "neq_vec",
coq: "neq_vec",
c: "neq_bits"
c: "neq_bits",
lean: "Ne"
} : forall 'n. (bits('n), bits('n)) -> bool

function neq_bits(x, y) = not_bool(eq_bits(x, y))
Expand Down Expand Up @@ -130,7 +132,14 @@ function sail_mask(len, v) = if len <= length(v) then truncate(v, len) else sail

overload operator ^ = {sail_mask}

val bitvector_concat = pure {ocaml: "append", interpreter: "append", lem: "concat_vec", coq: "concat_vec", _: "append"} : forall 'n 'm.
val bitvector_concat = pure {
ocaml: "append",
interpreter: "append",
lem: "concat_vec",
coq: "concat_vec",
lean: "BitVec.append",
_: "append"
} : forall 'n 'm.
(bits('n), bits('m)) -> bits('n + 'm)

overload append = {bitvector_concat}
Expand Down Expand Up @@ -199,7 +208,7 @@ val add_bits = pure {
interpreter: "add_vec",
lem: "add_vec",
coq: "add_vec",
lean: "Add.add",
lean: "HAdd.hAdd",
_: "add_bits"
} : forall 'n. (bits('n), bits('n)) -> bits('n)

Expand All @@ -218,16 +227,25 @@ val sub_bits = pure {
interpreter: "sub_vec",
lem: "sub_vec",
coq: "sub_vec",
lean: "HSub.hSub",
_: "sub_bits"
} : forall 'n. (bits('n), bits('n)) -> bits('n)

val not_vec = pure {ocaml: "not_vec", lem: "not_vec", coq: "not_vec", interpreter: "not_vec", _: "not_bits"} : forall 'n. bits('n) -> bits('n)
val not_vec = pure {
ocaml: "not_vec",
lem: "not_vec",
coq: "not_vec",
interpreter: "not_vec",
lean: "Complement.complement",
_: "not_bits"
} : forall 'n. bits('n) -> bits('n)

val and_vec = pure {
lem: "and_vec",
coq: "and_vec",
ocaml: "and_vec",
interpreter: "and_vec",
lean: "HAnd.hAnd",
_: "and_bits"
} : forall 'n. (bits('n), bits('n)) -> bits('n)

Expand All @@ -238,6 +256,7 @@ val or_vec = pure {
coq: "or_vec",
ocaml: "or_vec",
interpreter: "or_vec",
lean: "HOr.hOr",
_: "or_bits"
} : forall 'n. (bits('n), bits('n)) -> bits('n)

Expand All @@ -248,6 +267,7 @@ val xor_vec = pure {
coq: "xor_vec",
ocaml: "xor_vec",
interpreter: "xor_vec",
lean: "HXor.hXor",
_: "xor_bits"
} : forall 'n. (bits('n), bits('n)) -> bits('n)

Expand Down Expand Up @@ -346,6 +366,7 @@ val unsigned = pure {
lem: "uint",
interpreter: "uint",
coq: "uint",
lean: "BitVec.toNat",
_: "sail_unsigned"
} : forall 'n. bits('n) -> range(0, 2 ^ 'n - 1)

Expand All @@ -358,6 +379,7 @@ val signed = pure {
lem: "sint",
interpreter: "sint",
coq: "sint",
lean: "BitVec.toInt",
_: "sail_signed"
} : forall 'n, 'n > 0. bits('n) -> range(- (2 ^ ('n - 1)), 2 ^ ('n - 1) - 1)

Expand Down
1 change: 1 addition & 0 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ let rec doc_typ (Typ_aux (t, _) as typ) =
match t with
| Typ_id (Id_aux (Id "unit", _)) -> string "Unit"
| Typ_id (Id_aux (Id "int", _)) -> string "Int"
| Typ_id (Id_aux (Id "bool", _)) -> string "Bool"
| Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _)]) -> string "BitVec " ^^ doc_nexp m
| Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) doc_typ ts)
| Typ_id (Id_aux (Id id, _)) -> string id
Expand Down
30 changes: 30 additions & 0 deletions test/lean/bitvec_operation.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
def bitvector_eq (x : BitVec 16) (y : BitVec 16) : Bool :=
(Eq x y)

def bitvector_neq (x : BitVec 16) (y : BitVec 16) : Bool :=
(Ne x y)

def bitvector_append (x : BitVec 16) (y : BitVec 16) : BitVec 32 :=
(BitVec.append x y)

def bitvector_add (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
(HAdd.hAdd x y)

def bitvector_sub (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
(HSub.hSub x y)

def bitvector_not (x : BitVec 16) : BitVec 16 :=
(Complement.complement x)

def bitvector_and (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
(HAnd.hAnd x y)

def bitvector_or (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
(HOr.hOr x y)

def bitvector_xor (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
(HXor.hXor x y)

def initialize_registers : Unit :=
()

49 changes: 49 additions & 0 deletions test/lean/bitvec_operation.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
default Order dec

$include <prelude.sail>

val bitvector_eq : (bits(16), bits(16)) -> bool
function bitvector_eq(x, y) = {
x == y
}

val bitvector_neq : (bits(16), bits(16)) -> bool
function bitvector_neq(x, y) = {
x != y
}

val bitvector_append : (bits(16), bits(16)) -> bits(32)
function bitvector_append(x, y) = {
append (x, y)
}

val bitvector_add : (bits(16), bits(16)) -> bits(16)
function bitvector_add(x, y) = {
x + y
}

val bitvector_sub : (bits(16), bits(16)) -> bits(16)
function bitvector_sub(x, y) = {
sub_bits (x, y)
}

val bitvector_not : bits(16) -> bits(16)
function bitvector_not(x) = {
not_vec (x)
}

val bitvector_and : (bits(16), bits(16)) -> bits(16)
function bitvector_and(x, y) = {
x & y
}

val bitvector_or : (bits(16), bits(16)) -> bits(16)
function bitvector_or(x, y) = {
x | y
}

val bitvector_xor : (bits(16), bits(16)) -> bits(16)
function bitvector_xor(x, y) = {
xor_vec (x, y)
}

2 changes: 1 addition & 1 deletion test/lean/extern_bitvec.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ def extern_const : BitVec 64 :=
(0xFFFF000012340000 : BitVec 64)

def extern_add : BitVec 16 :=
(Add.add (0xFFFF : BitVec 16) (0x1234 : BitVec 16))
(HAdd.hAdd (0xFFFF : BitVec 16) (0x1234 : BitVec 16))

def initialize_registers : Unit :=
()
Expand Down

0 comments on commit d395f09

Please sign in to comment.