Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Predicate support #656

Merged
merged 19 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions SSA/Experimental/Bits/Defs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ inductive Term : Type
| incr : Term → Term
/-- Decrement (i.e., subtract one) -/
| decr : Term → Term
/-- `repeatBit` is an operation that will repeat the infinitely repeat the
least significant `true` bit of the input.

That is `repeatBit t` is all-zeroes iff `t` is all-zeroes.
Otherwise, there is some number `k` s.t. `repeatBit t` is all-ones after
dropping the least significant `k` bits -/
| repeatBit : Term → Term

open Term

Expand All @@ -53,20 +60,21 @@ so eval requires us to give a value for each possible variable.
-/
def Term.eval (t : Term) (vars : Nat → BitStream) : BitStream :=
match t with
| var n => vars n
| zero => BitStream.zero
| one => BitStream.one
| negOne => BitStream.negOne
| and t₁ t₂ => (t₁.eval vars) &&& (t₂.eval vars)
| or t₁ t₂ => (t₁.eval vars) ||| (t₂.eval vars)
| xor t₁ t₂ => (t₁.eval vars) ^^^ (t₂.eval vars)
| not t => ~~~(t.eval vars)
| ls b t => (Term.eval t vars).concat b
| add t₁ t₂ => (Term.eval t₁ vars) + (Term.eval t₂ vars)
| sub t₁ t₂ => (Term.eval t₁ vars) - (Term.eval t₂ vars)
| neg t => -(Term.eval t vars)
| incr t => BitStream.incr (Term.eval t vars)
| decr t => BitStream.decr (Term.eval t vars)
| var n => vars n
| zero => BitStream.zero
| one => BitStream.one
| negOne => BitStream.negOne
| and t₁ t₂ => (t₁.eval vars) &&& (t₂.eval vars)
| or t₁ t₂ => (t₁.eval vars) ||| (t₂.eval vars)
| xor t₁ t₂ => (t₁.eval vars) ^^^ (t₂.eval vars)
| not t => ~~~(t.eval vars)
| ls b t => (Term.eval t vars).concat b
| add t₁ t₂ => (Term.eval t₁ vars) + (Term.eval t₂ vars)
| sub t₁ t₂ => (Term.eval t₁ vars) - (Term.eval t₂ vars)
| neg t => -(Term.eval t vars)
| incr t => BitStream.incr (Term.eval t vars)
| decr t => BitStream.decr (Term.eval t vars)
| repeatBit t => BitStream.repeatBit (Term.eval t vars)

instance : Add Term := ⟨add⟩
instance : Sub Term := ⟨sub⟩
Expand Down Expand Up @@ -94,6 +102,7 @@ a term like `var 10` only has a single free variable, but its arity will be `11`
| neg t => arity t
| incr t => arity t
| decr t => arity t
| repeatBit t => arity t

/--
Evaluate a term `t` to the BitStream it represents.
Expand Down Expand Up @@ -130,6 +139,7 @@ and only require that many bitstream values to be given in `vars`.
let x₁ := t₁.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
let x₂ := t₂.evalFin (fun i => vars (Fin.castLE (by simp [arity]) i))
x₁ - x₂
| neg t => -(Term.evalFin t vars)
| incr t => BitStream.incr (Term.evalFin t vars)
| decr t => BitStream.decr (Term.evalFin t vars)
| neg t => -(Term.evalFin t vars)
| incr t => BitStream.incr (Term.evalFin t vars)
| decr t => BitStream.decr (Term.evalFin t vars)
| repeatBit t => BitStream.repeatBit (Term.evalFin t vars)
19 changes: 19 additions & 0 deletions SSA/Experimental/Bits/Fast/BitStream.lean
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ section Lemmas
theorem ext {x y : BitStream} (h : ∀ i, x i = y i) : x = y := by
funext i; exact h i

theorem corec_eq_corec {a : α} {b : β} {f g}
(R : α → β → Prop)
(h : ∀ a b, R a b →
let x := f a
let y := g b
R x.fst y.fst ∧ x.snd = y.snd) :
corec f a = corec g b := by
sorry
bollu marked this conversation as resolved.
Show resolved Hide resolved

end Lemmas

end Basic
Expand Down Expand Up @@ -276,6 +285,16 @@ instance : Add BitStream := ⟨add⟩
instance : Neg BitStream := ⟨neg⟩
instance : Sub BitStream := ⟨sub⟩

/-- `repeatBit xs` will repeat the first bit of `xs` which is `true`.
That is, it will be all-zeros iff `xs` is all-zeroes,
otherwise, there's some number `k` so that after dropping the `k` least
significant bits, `repeatBit xs` is all-ones. -/
def repeatBit (xs : BitStream) : BitStream :=
corec (b := (false, xs)) fun (carry, xs) =>
Comment on lines +318 to +323
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we wanted to change this name? I like foldOr more than repeatBit

let carry := carry || xs 0
let xs := xs.tail
((carry, xs), carry)

/-!
TODO: We should define addition and `carry` in terms of `mapAccum`.
For example:
Expand Down
146 changes: 146 additions & 0 deletions SSA/Experimental/Bits/Fast/FiniteStateMachine.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import SSA.Experimental.Bits.Fast.Circuit

open Sum

section FSM
variable {α β α' β' : Type} {γ : β → Type}

/-- `FSM n` represents a function `BitStream → ⋯ → BitStream → BitStream`,
Expand Down Expand Up @@ -103,6 +104,15 @@ theorem eval_eq_carry (x : arity → BitStream) (n : ℕ) :
p.eval x n = (p.nextBit (p.carry x n) (fun i => x i n)).2 :=
rfl

theorem eval_eq_eval' :
p.eval x = p.eval' x := by
funext i
simp only [eval, eval']
induction i generalizing p x
case zero => rfl
case succ i ih =>
sorry

/-- `p.changeVars f` changes the arity of an `FSM`.
The function `f` determines how the new input bits map to the input expected by `p` -/
def changeVars {arity2 : Type} (changeVars : arity → arity2) : FSM arity2 :=
Expand Down Expand Up @@ -475,6 +485,22 @@ theorem eval_eq_zero_of_set {arity : Type _} (p : FSM arity)
rw [eval]
exact (evalAux_eq_zero_of_set p R hR hi hr1 x n).1

def repeatBit : FSM Unit where
α := Unit
initCarry := fun () => false
nextBitCirc := fun _ =>
.or (.var true <| .inl ()) (.var true <| .inr ())

@[simp] theorem eval_repeatBit :
repeatBit.eval x = BitStream.repeatBit (x ()) := by
unfold BitStream.repeatBit
rw [eval_eq_eval', eval']
apply BitStream.corec_eq_corec
(R := fun a b => a.1 () = b.2 ∧ (a.2 ()) = b.1)
intro ⟨y, a⟩ ⟨b, x⟩ h
simp only at h
simp [h, nextBit, BitStream.head]

end FSM

structure FSMSolution (t : Term) extends FSM (Fin t.arity) :=
Expand Down Expand Up @@ -504,6 +530,19 @@ def composeBinary
| true => q₁.toFSM
| false => q₂.toFSM)

def composeBinary'
(p : FSM Bool)
{n m : Nat}
(q₁ : FSM (Fin n))
(q₂ : FSM (Fin m)) :
FSM (Fin (max n m)) :=
p.compose (Fin (max n m))
(λ b => Fin (cond b n m))
(λ b i => Fin.castLE (by cases b <;> simp) i)
(λ b => match b with
| true => q₁
| false => q₂)

@[simp] lemma composeUnary_eval
(p : FSM Unit)
{t : Term}
Expand Down Expand Up @@ -589,6 +628,69 @@ def termEvalEqFSM : ∀ (t : Term), FSMSolution t
let q := termEvalEqFSM t
{ toFSM := by dsimp [arity]; exact composeUnary FSM.decr q,
good := by ext; simp }
| repeatBit t =>
let p := termEvalEqFSM t
{ toFSM := by dsimp [arity]; exact composeUnary FSM.repeatBit p,
good := by ext; simp }

/-!
FSM that implement bitwise-and. Since we use `0` as the good state,
we keep the invariant that if both inputs are good and our state is `0`, then we produce a `0`.
If not, we produce an infinite sequence of `1`.
-/
def and : FSM Bool :=
{ α := Unit,
initCarry := fun _ => false,
nextBitCirc := fun a =>
match a with
| some () =>
-- Only if both are `0` we produce a `0`.
(Circuit.var true (inr false) |||
((Circuit.var false (inr true) |||
-- But if we have failed and have value `1`, then we produce a `1` from our state.
(Circuit.var true (inl ())))))
| none => -- must succeed in both arguments, so we are `0` if both are `0`.
Circuit.var true (inr true) |||
Circuit.var true (inr false)
}

/-!
FSM that implement bitwise-or. Since we use `0` as the good state,
we keep the invariant that if either inputs is `0` then our state is `0`.
If not, we produce a `1`.
-/
def or : FSM Bool :=
{ α := Unit,
initCarry := fun _ => false,
nextBitCirc := fun a =>
match a with
| some () =>
-- If either succeeds, then the full thing succeeds
((Circuit.var true (inr false) &&&
((Circuit.var false (inr true)) |||
-- On the other hand, if we have failed, then propagate failure.
(Circuit.var true (inl ())))))
| none => -- can succeed in either argument, so we are `0` if either is `0`.
Circuit.var true (inr true) &&&
Circuit.var true (inr false)
}

/-!
FSM that implement logical not.
we keep the invariant that if the input ever fails and becomes a `1`, then we produce a `0`.
IF not, we produce an infinite sequence of `1`.

EDIT: Aha, this doesn't work!
We need NFA to DFA here (as the presburger book does),
where we must produce an infinite sequence of`0` iff the input can *ever* become a `1`.
But here, since we phrase things directly in terms of producing sequences, it's a bit less clear
what we should do :)

- Alternatively, we need to be able to decide `eventually always zero`.
- Alternatively, we push negations inside, and decide `⬝ ≠ ⬝` and `⬝ ≰ ⬝`.
-/
def lnot : FSM Unit := sorry


inductive Result : Type
| falseAfter (n : ℕ) : Result
Expand Down Expand Up @@ -679,3 +781,47 @@ theorem decideIfZeros_correct {arity : Type _} [DecidableEq arity]
intro x s h
use x
exact h

end FSM

/--
The fragment of predicate logic that we support in `bv_automata`.
Currently, we support equality, conjunction, disjunction, and negation.
This can be expanded to also support arithmetic constraints such as unsigned-less-than.
-/
inductive Predicate : Nat → Type _ where
| eq (t1 t2 : Term) : Predicate ((max t1.arity t2.arity))
| and (p : Predicate n) (q : Predicate m) : Predicate (max n m)
| or (p : Predicate n) (q : Predicate m) : Predicate (max n m)
-- For now, we can't prove `not`, because it needs NFA → DFA conversion
-- the way Sid knows how to build it, or negation normal form,
-- both of which is machinery we lack.
-- | not (p : Predicate n) : Predicate n



/--
denote a reflected `predicate` into a `prop.
-/
def Predicate.denote : Predicate α → Prop
| eq t1 t2 => t1.eval = t2.eval
| and p q => p.denote ∧ q.denote
| or p q => p.denote ∨ q.denote
Comment on lines +807 to +810
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not quite right! We need to be taking in the arguments at the top-level, and pass those around.
Currently, we're denoting eq t u as equality of the functions that are denoted by t and u, meaning that we're universally quantifying each con-/disjunct separately.

def Predicate.denote (xs : α → BitStream) : Predicate α → Prop
| eq t1 t2 => t1.eval xs = t2.eval xs
| and p q => p.denote xs ∧ q.denote xs
| or p q => p.denote xs ∨ q.denote xs

-- | not p => ¬ p.denote

/--
Convert a predicate into a proposition
-/
def Predicate.toFSM : Predicate k → FSM (Fin k)
| .eq t1 t2 => (termEvalEqFSM (Term.repeatBit <| Term.xor t1 t2)).toFSM
| .and p q =>
let p := toFSM p
let q := toFSM q
composeBinary' FSM.and p q
| .or p q =>
let p := toFSM p
let q := toFSM q
composeBinary' FSM.or p q

theorem Predicate.toFsm_correct {k : Nat} (p : Predicate k) :
decideIfZeros p.toFSM = true ↔ p.denote := by sorry
Comment on lines +827 to +828
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, this needs to be:

theorem Predicate.toFsm_correct {k : Nat} (p : Predicate k) :
  decideIfZeros p.toFSM = true ↔ (\all xs, p.denote xs) := by sorry

Loading