Skip to content

Commit

Permalink
Chore: Change Automata Tactic from Equality to Congruence on w Bits (#…
Browse files Browse the repository at this point in the history
…507)

In a previous version of the Automata tactic, I had relied on statements
for re-writing that were not true, for example

```lean
@[simp] theorem ofBitVec_add : ofBitVec (x + y) = (ofBitVec x) + (ofBitVec y) := sorry
```

But in this patch, I use a Simproc to only require a weaker statement:

```lean
@[simp] theorem ofBitVec_add : EqualUpTo w (ofBitVec (x + y))  ((ofBitVec x) + (ofBitVec y)) := sorry
```

So the code has become uglier and more complicated, but at least all the
sorries are provable now.

The main new point of complication is that I introduced a new Simproc
(because I no longer rely on the congruence properties of equality), and
this makes the code more complicated.

The reason why is that for equality, congruence comes for free, i.e. if
a=b, then f(a) = f(b). For other equivalence relations, congruence is
not automatic, so we have to prove and apply congruence manually.

There are still three sorries in the tactic, namely

```lean

@[simp] theorem ofBitVec_sub : ofBitVec (x - y) ≈ʷ (ofBitVec x) - (ofBitVec y)  := by
  sorry

@[simp] theorem ofBitVec_add : ofBitVec (x + y) ≈ʷ (ofBitVec x) + (ofBitVec y)  := by
  sorry

@[simp] theorem ofBitVec_neg : ofBitVec (- x) ≈ʷ  - (ofBitVec x) := by
  sorry
```

But as I find these statements difficult to prove, I will prove them in
a later PR.

cc: @Equilibris

---------

Co-authored-by: Atticus Kuhn <[email protected]>
Co-authored-by: Atticus Kuhn <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent 6bcf937 commit cf6d655
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 68 deletions.
90 changes: 85 additions & 5 deletions SSA/Experimental/Bits/Fast/BitStream.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import Mathlib.Tactic.NormNum

import Mathlib.Logic.Function.Iterate

-- TODO: upstream the following section
section UpStream

Expand Down Expand Up @@ -264,6 +263,12 @@ variable (x y : BitVec (w+1))
simp only [ofBitVec, BitVec.getLsb_xor, xor_eq]
split <;> simp_all

@[simp] theorem ofBitVec_not : ofBitVec (~~~ x) = ~~~ (ofBitVec x) := by
funext i
simp only [ofBitVec, BitVec.getLsb_not, BitVec.msb_not, lt_add_iff_pos_left, add_pos_iff,
zero_lt_one, or_true, decide_True, Bool.true_and, not_eq]
split <;> simp_all

end Lemmas

end BitwiseOps
Expand Down Expand Up @@ -363,10 +368,85 @@ Crucially, our decision procedure works by considering which equalities hold for
-- (∀ w, (x w + y w) = z w) ↔ (∀ w, (ofBitVec (x w)) + (ofBitVec (y w)) ) := by
-- have ⟨h₁, h₂⟩ : True ∧ True := sorry
-- sorry
@[simp] theorem ofBitVec_sub : ofBitVec (x - y) = (ofBitVec x) - (ofBitVec y) := sorry
@[simp] theorem ofBitVec_add : ofBitVec (x + y) = (ofBitVec x) + (ofBitVec y) := sorry
@[simp] theorem ofBitVec_neg : ofBitVec (-x) = -(ofBitVec x) := sorry
@[simp] theorem ofBitVec_not : ofBitVec (~~~ x) = ~~~ (ofBitVec x) := sorry

variable {w : Nat} {x y : BitVec w} {a b a' b' : BitStream}

local infix:20 " ≈ʷ " => EqualUpTo w

-- TODO: These sorries are difficult, and will be proven in a later Pull Request.
@[simp] theorem ofBitVec_sub : ofBitVec (x - y) ≈ʷ (ofBitVec x) - (ofBitVec y) := by
sorry

@[simp] theorem ofBitVec_add : ofBitVec (x + y) ≈ʷ (ofBitVec x) + (ofBitVec y) := by
sorry

@[simp] theorem ofBitVec_neg : ofBitVec (- x) ≈ʷ - (ofBitVec x) := by
sorry

theorem equal_up_to_refl : a ≈ʷ a := by
intros j _
rfl

theorem equal_up_to_symm (e : a ≈ʷ b) : b ≈ʷ a := by
intros j h
symm
exact e j h

theorem equal_up_to_trans (e1 : a ≈ʷ b) (e2 : b ≈ʷ c) : a ≈ʷ c := by
intros j h
trans b j
exact e1 j h
exact e2 j h

instance congr_equiv : Equivalence (EqualUpTo w) := {
refl := fun _ => equal_up_to_refl,
symm := equal_up_to_symm,
trans := equal_up_to_trans
}

theorem sub_congr (e1 : a ≈ʷ b) (e2 : c ≈ʷ d) : (a - c) ≈ʷ (b - d) := by
intros n h
have sub_congr_lemma : a.subAux c n = b.subAux d n := by
induction n
<;> simp only [subAux, Prod.mk.injEq, e1 _ h, e2 _ h, and_self]
rename_i _ ih
simp only [ih (by omega), and_self]
simp only [HSub.hSub, Sub.sub, BitStream.sub, sub_congr_lemma]

theorem add_congr (e1 : a ≈ʷ b) (e2 : c ≈ʷ d) : (a + c) ≈ʷ (b + d) := by
intros n h
have add_congr_lemma : a.addAux c n = b.addAux d n := by
induction n
<;> simp only [addAux, Prod.mk.injEq, e1 _ h, e2 _ h]
rename_i _ ih
simp only [ih (by omega), Bool.bne_right_inj]
simp only [HAdd.hAdd, Add.add, BitStream.add, add_congr_lemma]

theorem neg_congr (e1 : a ≈ʷ b) : (-a) ≈ʷ -b := by
intros n h
have neg_congr_lemma : a.negAux n = b.negAux n := by
induction n
<;> simp only [negAux, Prod.mk.injEq, (e1 _ h)]
rename_i _ ih
simp only [ih (by omega), Bool.bne_right_inj, and_self]
simp only [Neg.neg, BitStream.neg, neg_congr_lemma]

theorem not_congr (e1 : a ≈ʷ b) : (~~~a) ≈ʷ ~~~b := by
intros g h
simp only [not_eq, e1 g h]

theorem equal_trans (e1 : a ≈ʷ b) (e2 : c ≈ʷ d) : (a ≈ʷ c) = (b ≈ʷ d) := by
apply propext
constructor
<;> intros h
· apply equal_up_to_trans _ e2
apply equal_up_to_trans _ h
apply equal_up_to_symm
assumption
· apply equal_up_to_trans _
apply (equal_up_to_symm e2)
apply equal_up_to_trans _ h
assumption

end Lemmas

Expand Down
175 changes: 112 additions & 63 deletions SSA/Experimental/Bits/Fast/Tactic.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import Lean.Meta.Tactic.Simp.BuiltinSimprocs
import SSA.Experimental.Bits.Fast.FiniteStateMachine
import SSA.Experimental.Bits.Fast.BitStream
import SSA.Experimental.Bits.Fast.Decide
import SSA.Experimental.Bits.Lemmas
Expand All @@ -9,6 +8,7 @@ open Lean Elab Tactic
open Lean Meta
open scoped Qq


/-!
# BitVec Automata Tactic
There are two ways of expressing BitVec expressions. One is:
Expand All @@ -30,32 +30,33 @@ we have a decision procedure that decides equality on expressions of the second

section EvalLemmas
variable {x y : _root_.Term} {vars : Nat → BitStream}
def sub_eval :
(Term.sub x y).eval vars = x.eval vars - y.eval vars := by

lemma eval_sub :
(x.sub y).eval vars = x.eval vars - y.eval vars := by
simp only [Term.eval]

def add_eval :
(Term.add x y).eval vars = x.eval vars + y.eval vars := by
lemma eval_add :
(x.add y).eval vars = x.eval vars + y.eval vars := by
simp only [Term.eval]

def neg_eval :
(Term.neg x).eval vars = - x.eval vars := by
lemma eval_neg :
(x.neg).eval vars = - x.eval vars := by
simp only [Term.eval]

def not_eval :
(Term.not x).eval vars = ~~~ x.eval vars := by
lemma eval_not :
(x.not).eval vars = ~~~ x.eval vars := by
simp only [Term.eval]

def and_eval :
(Term.and x y).eval vars = x.eval vars &&& y.eval vars := by
lemma eval_and :
(x.and y).eval vars = x.eval vars &&& y.eval vars := by
simp only [Term.eval]

def xor_eval :
(Term.xor x y).eval vars = x.eval vars ^^^ y.eval vars := by
lemma eval_xor :
(x.xor y).eval vars = x.eval vars ^^^ y.eval vars := by
simp only [Term.eval]

def or_eval :
(Term.or x y).eval vars = x.eval vars ||| y.eval vars := by
lemma eval_or :
(x.or y).eval vars = x.eval vars ||| y.eval vars := by
simp only [Term.eval]
end EvalLemmas

Expand All @@ -78,7 +79,7 @@ def termNatCorrect (f : Nat → BitStream) (w n : Nat) : BitStream.ofBitVec (Bi
exact incrBit w n


def quoteThm (qMapIndexToFVar : Q(Nat → BitStream)) (w : Q(Nat)) (nat: Nat) : Q(@Eq (BitStream) (BitStream.ofBitVec (@BitVec.ofNat $w $nat)) (@Term.eval (termNat $(nat)) $qMapIndexToFVar)) := q(termNatCorrect $qMapIndexToFVar $w $nat)
def quoteThm (qMapIndexToFVar : Q(Nat → BitStream)) (w : Q(Nat)) (nat: Nat) : Q(@Eq BitStream (BitStream.ofBitVec (@BitVec.ofNat $w $nat)) (@Term.eval (termNat $nat) $qMapIndexToFVar)) := q(termNatCorrect $qMapIndexToFVar $w $nat)

/--
Simplify BitStream.ofBitVec x where x is an FVar.
Expand All @@ -102,7 +103,7 @@ simproc reduce_bitvec (BitStream.ofBitVec _) := fun e => do
| .fvar x => do
let p : Q(Nat) := quoteFVar x
return .done { expr := q(Term.eval (Term.var $p) $qMapIndexToFVar)}
| x =>
| x =>
match_expr x with
| BitVec.ofNat a b =>
let nat := b.nat?
Expand All @@ -116,36 +117,72 @@ simproc reduce_bitvec (BitStream.ofBitVec _) := fun e => do
}
| _ => throwError m!"reduce_bitvec: Expression {x} is not a nat literal"

/-!
# Helper function to construct Exprs
-/


/--
Helper function to construct an equality expression
Given an Expr e, return a pair e', p where e' is an expression and p is a proof that e and e' are equal on the fist w bits
-/
def eqE (left : Q(Nat)) (right : Q(Nat)) : Q(Prop) :=
q($left = $right)
partial def first_rep (w : Q(Nat)) (e : Q( BitStream)) : SimpM (Σ (x : Q(BitStream)) , Q(@BitStream.EqualUpTo $w $e $x)) :=
match e with
| ~q(@HSub.hSub BitStream BitStream BitStream _ $a $b) => do
let ⟨ anext, aproof ⟩ ← first_rep w a
let ⟨ bnext, bproof ⟩ ← first_rep w b
return
q(@HSub.hSub BitStream BitStream BitStream _ $anext $bnext),
q(@BitStream.sub_congr $w $a $anext $b $bnext $aproof $bproof)
| ~q(@BitStream.ofBitVec $w ($a - $b)) =>
return
q((@BitStream.ofBitVec $w $a) - (@BitStream.ofBitVec $w $b)),
.app (.app (.app (.const ``BitStream.ofBitVec_sub []) w) a ) b
| ~q(@HAdd.hAdd BitStream BitStream BitStream _ $a $b) => do
let ⟨ anext, aproof ⟩ ← first_rep w a
let ⟨ bnext, bproof ⟩ ← first_rep w b
return
q($anext + $bnext),
.app (.app (.app (.app (.app (.app (.app (.const ``BitStream.add_congr []) w) a) anext) b) bnext) aproof) bproof
| ~q(@BitStream.ofBitVec $w ($a + $b)) =>
return
q(@HAdd.hAdd BitStream BitStream BitStream _ (@BitStream.ofBitVec $w $a) (@BitStream.ofBitVec $w $b)),
.app (.app (.app (.const ``BitStream.ofBitVec_add []) w) a ) b
| ~q(@Neg.neg BitStream _ $a)=> do
let ⟨ anext, aproof ⟩ ← first_rep w a
return
q(-$anext),
(.app (.app (.app (.app (.const ``BitStream.neg_congr []) w) a) anext) aproof)
| ~q(@BitStream.ofBitVec $w (@Neg.neg (BitVec $w) _ $a)) => do
return
q(@Neg.neg BitStream _ (@BitStream.ofBitVec $w $a)),
.app (.app (.const ``BitStream.ofBitVec_neg []) w) a
| ~q(@Complement.complement BitStream _ $a) => do
let ⟨ anext, aproof ⟩ ← first_rep w a
return
q(~~~ $anext),
(.app (.app (.app (.app (.const ``BitStream.not_congr []) w) a) anext) aproof)
| e =>
return
e,
.app (.app (.const ``BitStream.equal_up_to_refl []) w) e

/--
Helper function to construct an if then else expression
Push all ofBitVecs down to the lowest level
-/
def iteE (length : Q(Nat)) (left : Q(Nat)) (right : Q(Nat)) (ifTrue : Expr) (ifFalse : Expr) : Expr :=
((((((Expr.const `ite [Level.zero.succ]).app (.app (.const ``BitVec []) length)).app
(eqE left right)).app
(((Expr.const `instDecidableEqNat []).app left).app right)).app
ifTrue).app
ifFalse)

/--
Helper function to construct a function expression
-/
def funE (length : Q(Nat)) (body : Q(BitStream)) : Q(Nat → BitStream):=
(Expr.lam `n (Expr.const `Nat [])
(((Expr.const `BitStream.ofBitVec []).app
length).app
body)
BinderInfo.default)
simproc reduce_bitvec2 (BitStream.EqualUpTo (_ : Nat) _ _) := fun e => do
match (e : Q(Prop)) with
| .app (.app (.app (.const ``BitStream.EqualUpTo []) w) l ) r => do
let ⟨ lterm, lproof ⟩ ← first_rep w l
let ⟨ rterm, rproof ⟩ ← first_rep w r
return .done {
expr := .app (.app (.app (.const ``BitStream.EqualUpTo []) w) lterm) rterm
proof? :=
some (.app (.app (.app (.app (.app (.app (.app (.const ``BitStream.equal_trans []) w) l) lterm) r) rterm) lproof) rproof)
}
| _ => throwError m!"Expression {e} is not of the expected form. Expected something of the form BitStream.EqualUpTo (w : Nat) (lhs : BitStream) (rhs : BitStream) : Prop"

/--
Introduce vars which maps variable ids to the variable values.
Expand Down Expand Up @@ -174,10 +211,20 @@ def introduceMapIndexToFVar : TacticM Unit := do withMainContext <| do
let length : Expr := a
let hypValue : Expr := fVars.foldl (fun (accumulator : Expr) (currentFVar : FVarId) =>
let quotedCurrentFVar : Expr := .fvar currentFVar
let fVarId : Expr := quoteFVar currentFVar
iteE length lastBVar fVarId quotedCurrentFVar accumulator
let fVarId : Q(Nat) := quoteFVar currentFVar
let eqE : Q(Prop) := q($lastBVar = $fVarId);
((((((Expr.const `ite [Level.zero.succ]).app (.app (.const ``BitVec []) length)).app
eqE).app
(((Expr.const ``instDecidableEqNat []).app lastBVar).app fVarId)).app
quotedCurrentFVar).app
accumulator)
) (.fvar last)
let mapIndexToFVar : Expr := funE length hypValue
let mapIndexToFVar : Q(Nat → BitStream):=
(Expr.lam `n (Expr.const `Nat [])
(((Expr.const `BitStream.ofBitVec []).app
length).app
hypValue)
BinderInfo.default)
let newGoal : MVarId ← goal.define `vars mapIndexToFVarType mapIndexToFVar
replaceMainGoal [newGoal]
| _ => throwError "Goal is not of the expected form"
Expand All @@ -190,42 +237,45 @@ Create bv_automata tactic which solves equalities on bitvectors.
macro "bv_automata" : tactic =>
`(tactic| (
apply BitStream.eq_of_ofBitVec_eq
simp only [
BitStream.ofBitVec_sub,
BitStream.ofBitVec_or,
repeat simp only [
reduce_bitvec2,
BitStream.ofBitVec_not,
BitStream.ofBitVec_xor,
BitStream.ofBitVec_and,
BitStream.ofBitVec_not,
BitStream.ofBitVec_add,
BitStream.ofBitVec_neg
BitStream.ofBitVec_or,
]
introduceMapIndexToFVar
intro mapIndexToFVar
simp only [
sub_eval,
add_eval,
neg_eval,
and_eval,
xor_eval,
or_eval,
not_eval,
eval_sub,
eval_add,
eval_neg,
eval_and,
eval_xor,
eval_or,
eval_not,
reduce_bitvec,
Nat.reduceAdd,
BitVec.ofNat_eq_ofNat
]
intros _ _
repeat (apply congrFun)
apply congrFun
apply congrFun
native_decide
))


/-!
# Test Cases
-/


def test0 {w : Nat} (x y : BitVec (w + 1)) : x + 0 = x := by
bv_automata

def test_simple2 {w : Nat} (x y : BitVec (w + 1)) : x = x := by
bv_automata

def test1 {w : Nat} (x y : BitVec (w + 1)) : (x ||| y) - (x ^^^ y) = x &&& y := by
bv_automata

Expand All @@ -236,9 +286,9 @@ def test3 (x y : BitVec 300) : ((x ||| y) - (x ^^^ y)) = (x &&& y) := by
bv_automata

def test4 (x y : BitVec 2) : (x + -y) = (x - y) := by
bv_automata
bv_automata

def test5 (x y : BitVec 2) : (x + y) = (y + x) := by
def test5 (x y z : BitVec 2) : (x + y + z) = (z + y + x) := by
bv_automata

def test6 (x y z : BitVec 2) : (x + (y + z)) = (x + y + z) := by
Expand Down Expand Up @@ -280,7 +330,6 @@ def test27 (x y : BitVec 5) : 2 + x = 1 + x + 1 := by
def test28 {w : Nat} (x y : BitVec (w + 1)) : x &&& x &&& x &&& x &&& x &&& x = x := by
bv_automata


-- This test is commented out because it takes over a minute to run
-- def broken_test (x y : BitVec 5) : 2 + x + 2 = x + 4 := by
-- bv_automata

0 comments on commit cf6d655

Please sign in to comment.