From 24b76a6e83ebf7f5839d29715ab1591485bca2d2 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Mon, 30 Sep 2024 01:23:23 -0500 Subject: [PATCH] WIP: FSM abstractions --- SSA/Experimental/Bits/Fast/BitStream.lean | 4 +- SSA/Experimental/Bits/Fast/Circuit.lean | 2 +- .../Bits/Fast/FiniteStateMachine.lean | 717 +++++++++++------- 3 files changed, 434 insertions(+), 289 deletions(-) diff --git a/SSA/Experimental/Bits/Fast/BitStream.lean b/SSA/Experimental/Bits/Fast/BitStream.lean index 5805ad5fc..1132fdaa9 100644 --- a/SSA/Experimental/Bits/Fast/BitStream.lean +++ b/SSA/Experimental/Bits/Fast/BitStream.lean @@ -85,8 +85,10 @@ abbrev map₂ (f : Bool → Bool → Bool) : BitStream → BitStream → BitStre def corec {β} (f : β → β × Bool) (b : β) : BitStream := fun i => f ((Prod.fst ∘ f)^[i] b) |>.snd +@[simp] theorem corec_zero {β} (f : β → β × Bool) (b : β) : + corec f b 0 = (f b).2 := rfl -theorem corec_succ {β} (f : β → β × Bool) (b : β) (i : Nat) : +@[simp] theorem corec_succ {β} (f : β → β × Bool) (b : β) (i : Nat) : corec f b (i + 1) = (corec f (f b).1) i := by induction' i with i ih · simp [corec] diff --git a/SSA/Experimental/Bits/Fast/Circuit.lean b/SSA/Experimental/Bits/Fast/Circuit.lean index a384957a1..7e52da361 100644 --- a/SSA/Experimental/Bits/Fast/Circuit.lean +++ b/SSA/Experimental/Bits/Fast/Circuit.lean @@ -221,7 +221,7 @@ def map : ∀ (_c : Circuit α) (_f : α → β), Circuit β | or c₁ c₂, f => (map c₁ f) ||| (map c₂ f) | xor c₁ c₂, f => (map c₁ f) ^^^ (map c₂ f) -lemma eval_map {c : Circuit α} {f : α → β} {g : β → Bool} : +@[simp] lemma eval_map {c : Circuit α} {f : α → β} {g : β → Bool} : eval (map c f) g = eval c (λ x => g (f x)) := by induction c <;> simp [*, Circuit.map, eval] at * diff --git a/SSA/Experimental/Bits/Fast/FiniteStateMachine.lean b/SSA/Experimental/Bits/Fast/FiniteStateMachine.lean index 5650b124c..9e7fe384f 100644 --- a/SSA/Experimental/Bits/Fast/FiniteStateMachine.lean +++ b/SSA/Experimental/Bits/Fast/FiniteStateMachine.lean @@ -13,278 +13,398 @@ open Sum section FSM variable {σ β σ' β' : Type} {γ : β → Type} -/-- An `arity` indexed product of Booleans. -/ -def BoolProd (arity : Type) : Type := arity → Bool - -def BoolProd.append (x : BoolProd σ) (y : BoolProd β) : BoolProd (Sum σ β) := - fun a => match a with - | inl a => x a - | inr b => y b +namespace BitVec @[simp] -theorem BoolProd.append_inl (x : BoolProd σ) (y : BoolProd β) (a : σ) : - (x.append y) (inl a) = x a := rfl +theorem getLsbD_append_left (x : BitVec v) (y : BitVec w) (i : Nat) + (h : ¬(i < w)) : + getLsbD (x ++ y) i = x.getLsbD (i - w) := by + simp [getLsbD_append, h] @[simp] -theorem BoolProd.append_inr (x : BoolProd σ) (y : BoolProd β) (b : β) : - (x.append y) (inr b) = y b := rfl +theorem getLsbD_append_right (x : BitVec v) (y : BitVec w) (i : Nat) + (h : i < w) : + getLsbD (x ++ y) i = y.getLsbD i := by + simp [getLsbD_append, h] -/-- -Morally, a function from (BoolProd σ) to (BoolProd β), but the function is represented -by mapping each output `(b : β)` to a circuit that computes the output bit `b` from the input bits `(a : σ)`. +/-- `appendVector` appends a family of `n` bitvectors, each of which might have +a different width, together into a bitvector whose length is the sum of lengths -/ -def CircuitProd (σ : Type) (β : Type) : Type := β → Circuit σ +def appendVector {ws : Fin n → Nat} (xs : (i : Fin n) → BitVec (ws i)) : + BitVec (∑ i, ws i) := + match n with + | 0 => 0#_ + | n+1 => + let x := xs 0 + (x ++ (appendVector (fun i => xs i.succ))).cast <| by + simp [Finset.sum] + +/-- Construct a bitvector from a function that maps `i : Fin w` to the +`i`-th least significant bit -/ +def ofFnLsb (f : Fin w → Bool) : BitVec w := + match w with + | 0 => 0#0 + | _+1 => concat (ofFnLsb (f ∘ Fin.succ)) (f 0) + +@[simp] lemma getElem_ofFnLsb (f : Fin w → Bool) (i : Fin w) : + (ofFnLsb f)[i.val] = f i := by + sorry + +end BitVec + +namespace Fin + +def addToSum (i : Fin (x + y)) : Fin x ⊕ Fin y := + if h : i < y then + .inr ⟨i, h⟩ + else + .inl (i.subNat y <| by simp_all) -/-- A circuitProd can be evaluated to map σ product of booleans to a β product of booleans. -/ -def CircuitProd.eval (x : CircuitProd σ β) (y : BoolProd σ) : BoolProd β := - fun a => (x a).eval y +@[simp] abbrev addInl (i : Fin x) : Fin (x + y) := castAdd y i +@[simp] abbrev addInr (i : Fin y) : Fin (x + y) := natAdd x i -/-- A 'arity' indexed product of `Bitstream`s. -/ -def BitStreamProd (arity : Type) : Type := arity → BitStream +def addElim (f : Fin x → α) (g : Fin y → α) : Fin (x + y) → α := + fun i => Sum.elim f g (addToSum i) -def BitStreamProd.nthStream (x : BitStreamProd arity) (i : arity) : BitStream := x i -def BitStreamProd.nthBits (x : BitStreamProd arity) (n : Nat) : BoolProd arity := - fun i => (x i) n +def sumOfSigma {f : α → Nat} [Fintype α] (i : Σ a, Fin (f a)) : Fin (∑ a, f a) := + sorry -def BitStreamProd.head (x : BitStreamProd arity) : BoolProd arity := - fun i => (x i).head +def sumToSigma {f : α → Nat} [Fintype α] (i : Fin (∑ a, f a)) : Σ a, Fin (f a) := + sorry -def BitStreamProd.tail (x : BitStreamProd arity) : BitStreamProd arity := - fun a => (x a).tail +end Fin -/-- `FSM n` represents a function `BitStream → ⋯ → BitStream → BitStream`, -where `n` is the number of `BitStream` arguments, -as a finite state machine. --/ -structure FSM (arity : Type) : Type 1 := - /-- - The arity of the (finite) type `σ` (for `s`tate) determines how many bits the - internal carry state of this FSM has -/ - ( σ : Type ) -- Why is σ also not an index? - [ i : Fintype σ ] - [ dec_eq : DecidableEq σ ] - /-- - `initCarry` is the value of the initial internal carry state. - It maps each `σ` to a bit, thus it is morally a bitvector where the width is the arity of `σ` - -/ - ( initCarry : σ → Bool ) - /-- - `nextBitCirc` is a family of Boolean circuits, - which may refer to the current input bits *and* the current state bits - as free variables in the circuit. +/-- An `n`-ary product of `Bitstream`s. -/ +def BitStreamProd (n : Nat) : Type := Fin n → BitStream - `nextBitCirc none` computes the current output bit. - `nextBitCirc (some a)`, computes the *one* bit of the new state that corresponds to `a : σ`. -/ - (currentOutCircuit : Circuit (σ ⊕ arity) ) -- given the current state, and the arguments, compute the output - (nextStateFnCircuit : CircuitProd (σ ⊕ arity) σ ) -- For the new state, to compute the bit `(a : σ)`, given the circuit that computes this bit from the current state and the arguments. +namespace BitStreamProd -attribute [instance] FSM.i FSM.dec_eq +/-- Return the `i`-th stream of `x` -/ +def nthStream (x : BitStreamProd n) (i : Fin n) : BitStream := x i -namespace FSM +/-- Get the `i`th least significant bit of each constituent stream -/ +def getLsbs (xs : BitStreamProd n) (i : Nat) : BitVec n := + BitVec.ofFnLsb fun j => xs j i +/-- Get the least significant bit of each constituent stream -/ +def heads (xs : BitStreamProd n) : BitVec n := + BitVec.ofFnLsb fun i => (xs i).head -/-- The state of FSM `p` is given by a function from `p.σ` to `Bool`. +/-- Drop the least significant bit from each constituent stream, +returning an n-ary product of each streams tail -/ +def tails (xs : BitStreamProd n) : BitStreamProd n := + fun i => (xs i).tail -Note that `p.σ` is assumed to be a finite type, so `p.State` is morally -a finite bitvector whose width is given by the arity of `p.σ` -/ -abbrev State (p : FSM arity): Type := BoolProd p.σ +def castLE (h : n ≤ m) : BitStreamProd m → BitStreamProd n := + (· ∘ Fin.castLE h) -def State.appendInput {p : FSM arity} (s : p.State) (x : BoolProd arity) : - BoolProd (p.σ ⊕ arity) := - fun a => match a with - | inl a => s a - | inr b => x b +section Lemmas -/-- the current otuput of the FSM-/ -def outBit {p : FSM arity} (state : p.State) (input : BoolProd arity) : Bool := - (p.currentOutCircuit).eval (state.appendInput input) +@[simp] lemma getElem_heads (xs : BitStreamProd n) (i : Fin n) : + xs.heads[i.val] = (xs i).head := by + simp [heads] +@[simp] lemma getElem_getLsbs (xs : BitStreamProd n) (i : Nat) (j : Fin n) : + (xs.getLsbs i)[j.val] = xs j i := by + simp [getLsbs] -end FSM +end Lemmas + +end BitStreamProd + +/-- +`CircuitProd vars n` is a collection of `n` Boolean Circuits, each of which can +refer to at most `vars` variables. + +This morally represents a function from `BitVec vars` +(i.e., an assignment of a single bit per variable), +to a `BitVec n` (where each circuit computes a single bit of the output). +See `CircuitProd.eval`. +-/ +def CircuitProd (vars n : Nat) : Type := Fin n → Circuit (Fin vars) + +namespace CircuitProd + + +/-- Evaluate a `CircuitProd vars n` to the function `BitVec vars → BitVec n` +it represents. + +By convention, we use Little Endian order, which is to say, the `i`th circuit +will compute the `i`-th least significant bit of the output, and the variable +with index `i` derives it's assignment from the `i`-th least signicant bit of +the input. +-/ +def eval {vars n : Nat} + (circuit : CircuitProd vars n) (assignment : BitVec vars) : + BitVec n := + BitVec.ofFnLsb fun i => + (circuit i).eval assignment.getLsb' + +@[simp] lemma getLsbD_eval (c : CircuitProd vars n) (assignment : BitVec vars) + (i : Fin n) : + (c.eval assignment)[i.val] + = (c i).eval assignment.getLsb' := by + simp [eval] + +instance : Subsingleton (CircuitProd n 0) := + inferInstanceAs (Subsingleton (Fin 0 → _)) + +end CircuitProd + +/-- `FSM arity` represents a function `BitStream → ⋯ → BitStream → BitStream`, +where `arity` is the number of `BitStream` arguments, +as a finite state machine. +-/ +structure FSM (arity : Nat) : Type 1 := + /-- + `stateWidth` is the number of bits the state has + -/ + (stateWidth : Nat) + /-- + `initialState` is the initial state. + -/ + (initialState : BitVec stateWidth) + /-- + `outCircuit` is a single Boolean circuit, + which will compute the output bit of the current state, + given the current state and input bits. + -/ + (outCircuit : Circuit (Fin <| stateWidth + arity)) + /-- + `nextStateCircuit` is a uniform family of `stateWidth` Boolean circuits, + where each circuit computes one bit of the next state, + given the current state and input bits. + -/ + (nextStateCircuits : CircuitProd (stateWidth + arity) stateWidth) namespace FSM -variable {arity : Type} (p : FSM arity) +/-- A `State` of FSM `p` is just a bitvector with `p.stateWidth` bits -/ +abbrev State (p : FSM arity) : Type := BitVec p.stateWidth + +@[deprecated BitVec.append] +def appendInput {p : FSM arity} (s : BitVec p.stateWidth) (x : BitVec arity) : + BitVec (p.stateWidth + arity) := + s ++ x + +variable {arity : Nat} (p : FSM arity) -/-- The next state function, whieh evalutes the next state, given the current state and the input. -/ -def nextState (s : p.State) (input : BoolProd arity) : p.State := - p.nextStateFnCircuit.eval (s.appendInput input) +/-- Return the output bit of FSM `p`, given the current state and input bits. -/ +@[simp] +def outBit (state : p.State) (input : BitVec arity) : Bool := + (p.outCircuit).eval (state ++ input).getLsb' + +/-- Return the next state of FSM `p`, given the current state and input bits. -/ +@[simp] +def nextState (s : p.State) (input : BitVec arity) : p.State := + p.nextStateCircuits.eval (s ++ input) /-- `p.next state in` computes both the next state bits and the output bit, where `state` are the *current* state bits, and `in` are the current input bits. -/ -def next (carry : p.State) (inputBits : BoolProd arity) : p.State × Bool := - let newState : p.State := p.nextState carry inputBits - let outBit : Bool := p.outBit carry inputBits +@[simp] +def next (state : p.State) (inputBits : BitVec arity) : p.State × Bool := + let newState := p.nextState state inputBits + let outBit := p.outBit state inputBits (newState, outBit) - +-- TODO: document this def outputStreamAux (s₀ : p.State) (inputStream : BitStreamProd arity) : BitStream := fun n => match n with - | 0 => outBit s₀ inputStream.head - | n+1 => outputStreamAux (nextState p s₀ (inputStream.head)) inputStream.tail n + | 0 => p.outBit s₀ inputStream.heads + | n+1 => outputStreamAux (nextState p s₀ (inputStream.heads)) inputStream.tails n @[simp] theorem outputStreamAux_zero (s₀ : p.State) (inputStream : BitStreamProd arity) : - outputStreamAux p s₀ inputStream 0 = outBit s₀ (inputStream.nthBits 0) := rfl + outputStreamAux p s₀ inputStream 0 = p.outBit s₀ (inputStream.getLsbs 0) := rfl @[simp] theorem outputStreamAux_succ (s₀ : p.State) (inputStream : BitStreamProd arity) (n : ℕ) : outputStreamAux p s₀ inputStream (n+1) = - outputStreamAux p (p.nextState s₀ (inputStream.head)) inputStream.tail n := by rfl - + outputStreamAux p (p.nextState s₀ (inputStream.heads)) inputStream.tails n := by rfl +/-- +A `StateStream` w.r.t. FSM `p` is an infinite stream of `p.State`s +-/ def StateStream (p : FSM arity) := ℕ → p.State -/-- `p.carryStream inputStream` computes the stream of carries, from the stream of inputs -/ -def carryStream (inputStream : BitStreamProd arity) : p.StateStream := fun n => - match n with - | 0 => p.initCarry - | n+1 => (p.nextState (carryStream inputStream n) (inputStream.nthBits n)) +/-- `p.stateStream` is the stream of states of FSM `p`, +for a given product of input streams. -@[simp] -theorem carryStream_zero (inputStream : BitStreamProd arity) : p.carryStream inputStream 0 = p.initCarry := rfl +That is, it is the stream that starts with `p.initialState`, +and evolves according to `p.nextState`. -/ +def stateStream (p : FSM arity) (xs : BitStreamProd arity) : p.StateStream + | 0 => p.initialState + | n+1 => (p.nextState (p.stateStream xs n) (xs.getLsbs n)) @[simp] -theorem carryStream_succ (inputStream : BitStreamProd arity) (n : Nat) : - p.carryStream inputStream (n+1) = - p.nextState (p.carryStream inputStream n) (inputStream.nthBits n) := rfl - +theorem stateStream_zero (xs : BitStreamProd arity) : + p.stateStream xs 0 = p.initialState := rfl -/-- `eval p` morally gives the function `BitStream → ... → BitStream` represented by FSM `p` -/ -def eval (x : BitStreamProd arity) : BitStream := - p.outputStreamAux p.initCarry x - -def eval'Corec (input : BitStreamProd arity × p.State) : - (BitStreamProd arity × p.State) × Bool := -- (fun ⟨x, (carry : p.State)⟩ => - let x := input.1 - let carry := input.2 - let x_head := x.head -- (x · |>.head) - let next := p.next carry x_head - let x_tail := x.tail -- (x · |>.tail) +@[simp] +theorem stateStream_succ (inputStream : BitStreamProd arity) (n : Nat) : + p.stateStream inputStream (n+1) + = p.nextState (p.stateStream inputStream n) (inputStream.getLsbs n) := + rfl + +-- /-- `eval p` morally gives the function `BitStream → ... → BitStream` represented by FSM `p` -/ +-- def eval (xs : BitStreamProd arity) : BitStream := +-- p.outputStreamAux p.initialState xs + +def eval.next (xs : BitStreamProd arity × p.State) : + (BitStreamProd arity × p.State) × Bool := -- (fun ⟨x, (state : p.State)⟩ => + let x := xs.1 + let state := xs.2 + let x_head := x.heads + let next := p.next state x_head + let x_tail := x.tails ((x_tail, next.fst), next.snd) /-- `eval'` is an alternative definition of `eval`, written in terms of corecursion. -/ -def eval' (x : BitStreamProd arity) : BitStream := - BitStream.corec (eval'Corec p) (x, (p.initCarry : p.State)) - -/-- -Generalized hypothesis that shows how the output stream and -its corecursive definition evolve with an arbitrary input state. --/ -theorem eval_eq_eval'_aux (i : Nat) : - (p.outputStreamAux state x) i = (BitStream.corec (eval'Corec p) (x, state)) i := by - induction i generalizing state x - case zero => rfl - case succ i ih => - simp [outputStreamAux, eval'Corec, BitStream.corec_succ] - rw [← ih] - rfl - -/-- Show that the two definitions of evaluation are equivalent. -/ -theorem eval_eq_eval' : p.eval x = p.eval' x := by - funext i - apply eval_eq_eval'_aux - -/-- `p.changeInitCarry c` yields an FSM with `c` as the initial state -/ -def changeInitCarry (p : FSM arity) (c : BoolProd p.σ) : FSM arity := - { p with initCarry := c } - -theorem eval_changeInitCarry_succ - (p : FSM arity) (c : p.σ → Bool) (x : BitStreamProd arity) (n : ℕ) : - (p.changeInitCarry c).eval x (n+1) = - (p.changeInitCarry (p.next c (fun a => x a 0)).1).eval - (fun a i => x a (i+1)) n := by - rw [eval, carry_changeInitCarry_succ] - simp [eval, changeInitCarry, next] - -/-- `p.changeVars f` changes the arity of an `FSM`. -The function `f` determines how the new input bits map to the input expected by `p` -/ -def changeVars {arity2 : Type} (changeVars : arity → arity2) : FSM arity2 := - { p with nextBitCirc := fun a => (p.nextBitCirc a).map (Sum.map id changeVars) } - +def eval (x : BitStreamProd arity) : BitStream := + BitStream.corec (eval.next p) (x, p.initialState) + +-- /-- +-- Generalized hypothesis that shows how the output stream and +-- its corecursive definition evolve with an arbitrary input state. +-- -/ +-- theorem eval_eq_eval'_aux (i : Nat) : +-- (p.outputStreamAux state x) i = (BitStream.corec (eval'Corec p) (x, state)) i := by +-- induction i generalizing state x +-- case zero => rfl +-- case succ i ih => +-- simp [outputStreamAux, eval'Corec, BitStream.corec_succ] +-- rw [← ih] +-- rfl + +-- /-- Show that the two definitions of evaluation are equivalent. -/ +-- theorem eval_eq_eval' : p.eval x = p.eval' x := by +-- funext i +-- apply eval_eq_eval'_aux + +/-- `p.withInitialState s` yields an FSM with `s` as the initial state -/ +def withInitialState (p : FSM arity) (s : p.State) : FSM arity := + { p with initialState := s } + +theorem eval_withInitialState_succ + (p : FSM arity) (c : p.State) (xs : BitStreamProd arity) (n : ℕ) : + (p.withInitialState c).eval xs (n+1) = + (p.withInitialState (p.nextState c xs.heads)).eval (xs.tails) n := by + simp [eval, withInitialState, next]; rfl + + +-- /-- `p.changeVars f` changes the arity of an `FSM`. +-- The function `f` determines how the new input bits map to the input expected by `p` -/ +-- def changeVars {newArity : Nat} (changeVars : Fin arity → Fin newArity) : +-- FSM newArity := +-- let map (x : BitVec newArity) : BitVec arity := +-- BitVec.ofFnLsb (fun j => x[changeVars j]) +-- { p with +-- outCircuit := p.outCircuit.map _ +-- -- nextBitCirc := fun a => (p.nextBitCirc a).map (Sum.map id changeVars) } + +-- open Fin in +-- def composeUnary (p : FSM 1) (q : FSM n) : FSM n where +-- stateWidth := p.stateWidth + q.stateWidth +-- initialState := p.initialState ++ q.initialState +-- outCircuit := p.outCircuit.bind <| addCases +-- (fun i => Circuit.var true (addInl <| addInl i)) +-- (fun _ => q.outCircuit.map fun j => +-- j.natAdd p.stateWidth |>.cast (by ac_rfl) +-- ) +-- nextStateCircuits := +-- addCases +-- (fun i => (p.nextStateCircuits i).bind <| ) +-- _ +-- -- p.nextStateCircuits + +open Fin in /-- -Given an FSM `p` of arity `n`, -a family of `n` FSMs `qᵢ` of posibly different arities `mᵢ`, -and given yet another arity `m` such that `mᵢ ≤ m` for all `i`, -we can compose `p` with `qᵢ` yielding a single FSM of arity `m`, -such that each FSM `qᵢ` computes the `i`th bit that is fed to the FSM `p`. -/ -def compose [Fintype arity] [DecidableEq arity] - (new_arity : Type) -- `new_arity` is the resulting arity - (q_arity : arity → Type) -- `q_arityₐ` is the arity of FSM `qₐ` - (vars : ∀ (a : arity), q_arity a → new_arity) - -- ^^ `vars` is the function that tells us, for each FSM `qₐ`, - -- which bits of the final `new_arity` corresponds to the `q_arityₐ` bits expected by `qₐ` - (q : ∀ (a : arity), FSM (q_arity a)) : -- `q` gives the FSMs to be composed with `p` - FSM new_arity := - { σ := p.σ ⊕ (Σ a, (q a).σ), - i := by letI := p.i; infer_instance, - dec_eq := by - letI := p.dec_eq - letI := fun a => (q a).dec_eq - infer_instance, - initCarry := Sum.elim p.initCarry (λ x => (q x.1).initCarry x.2), - nextBitCirc := λ a => - match a with - | none => (p.nextBitCirc none).bind - (Sum.elim - (fun a => Circuit.var true (inl (inl a))) - (fun a => ((q a).nextBitCirc none).map - (Sum.elim (fun d => (inl (inr ⟨a, d⟩))) (fun q => inr (vars a q))))) - | some (inl a) => - (p.nextBitCirc (some a)).bind - (Sum.elim - (fun a => Circuit.var true (inl (inl a))) - (fun a => ((q a).nextBitCirc none).map - (Sum.elim (fun d => (inl (inr ⟨a, d⟩))) (fun q => inr (vars a q))))) - | some (inr ⟨x, y⟩) => - ((q x).nextBitCirc (some y)).map - (Sum.elim - (fun a => inl (inr ⟨_, a⟩)) - (fun a => inr (vars x a))) } - -lemma carry_compose [Fintype arity] [DecidableEq arity] - (new_arity : Type) - (q_arity : arity → Type) - (vars : ∀ (a : arity), q_arity a → new_arity) - (q : ∀ (a : arity), FSM (q_arity a)) - (x : new_BitStreamProd arity) : ∀ (n : ℕ), - (p.compose new_arity q_arity vars q).carry x n = - let z := p.carry (λ a => (q a).eval (fun i => x (vars _ i))) n - Sum.elim z (fun a => (q a.1).carry (fun t => x (vars _ t)) n a.2) - | 0 => by simp [carry, compose] - | n+1 => by - rw [carry, carry_compose _ _ _ _ _ n] - ext y - cases y - · simp [carry, next, compose, Circuit.eval_bind, eval] +Given an FSM `p` of some `arity`, +and a family of `arity` FSMs `qᵢ`, +whose (possibly differing) arities are bounded by `newArity`, +we can compose `p` with `qᵢ` yielding a single FSM of arity `newArity`. + +The input of the composed FSM is given to the FSMs `qᵢ`, each of which computes +a single bit of the input that is then given to `p`. -/ +def compose {newArity : Nat} {qArity : Fin arity → Nat} + (arityLE : ∀ (a : Fin arity), qArity a ≤ newArity) + (q : (i : Fin arity) → FSM (qArity i)) : + FSM newArity := + let qOutCircuit (i : Fin arity) := + (q i).outCircuit.map <| Fin.addElim + (fun j => (addInl (addInr (sumOfSigma ⟨i, j⟩)))) + (fun j => addInr (j.castLE (arityLE i))) + { stateWidth := p.stateWidth + (∑ i, (q i).stateWidth), + initialState := p.initialState ++ (BitVec.appendVector (q · |>.initialState)) + outCircuit := + open Fin in + p.outCircuit.bind <| addCases + (fun i => Circuit.var true <| (addInl <| addInl i)) + qOutCircuit + nextStateCircuits := + open Fin in + addCases + (fun i => (p.nextStateCircuits i).bind <| addCases + (fun j => Circuit.var true (addInl (addInl j))) + qOutCircuit + ) + (fun i => + let ⟨i, j⟩ := i.sumToSigma + ((q i).nextStateCircuits j).map <| addCases + (fun k => addInl (addInr (sumOfSigma ⟨_, k⟩))) + (fun k => addInr (k.castLE <| arityLE i)) + ) + } + +#check FSM.stateStream + +lemma stateStream_compose {newArity : Nat} {qArity : Fin arity → Nat} + (arityLE : ∀ (i : Fin arity), qArity i ≤ newArity) + (q : ∀ (i : Fin arity), FSM (qArity i)) + (xs : BitStreamProd newArity) + (n : Nat) : + (p.compose arityLE q).stateStream xs n = + let pState := p.stateStream (fun i => + (q i).eval (fun j => xs <| j.castLE (arityLE _))) n + let qState : BitVec (∑ i, (q i).stateWidth) := + BitVec.appendVector fun i => + ((q i).stateStream (xs.castLE <| arityLE _) n) + pState ++ qState := by + induction n with + | zero => simp [stateStream, compose] + | succ n ih => + rw [stateStream, ih] + ext (y : Fin (_ + _)) + cases y using Fin.addCases + · simp [stateStream, next, compose, Circuit.eval_bind, eval] congr - ext z - cases z - · simp - · simp [Circuit.eval_map, carry] + ext (z : Fin (_ + _)) + cases z using Fin.addCases + · simp; + sorry + · simp [Circuit.eval_map, stateStream] congr - ext s - cases s - · simp - · simp - · simp [Circuit.eval_map, carry, compose, eval, carry, next] + ext (s : Fin (_ + _)) + cases s using Fin.addCases + · simp; sorry + · simp; sorry + · simp [Circuit.eval_map, stateStream, compose, eval, next] congr - ext z - cases z - · simp - · simp + ext (z : Fin (_ + _)) + cases z using Fin.addCases + · simp; sorry + · simp; sorry /-- Evaluating a composed fsm is equivalent to composing the evaluations of the constituent FSMs -/ -lemma eval_compose [Fintype arity] [DecidableEq arity] - (new_arity : Type) - (q_arity : arity → Type) - (vars : ∀ (a : arity), q_arity a → new_arity) - (q : ∀ (a : arity), FSM (q_arity a)) - (x : new_BitStreamProd arity) : - (p.compose new_arity q_arity vars q).eval x = - p.eval (λ a => (q a).eval (fun i => x (vars _ i))) := by +lemma eval_compose {newArity : Nat} {qArity : Fin arity → Nat} + (arityLE : ∀ (i : Fin arity), qArity i ≤ newArity) + (q : ∀ (i : Fin arity), FSM (qArity i)) + (x : BitStreamProd newArity) : + (p.compose arityLE q).eval x = + p.eval (λ a => (q a).eval (fun i => x (i.castLE <| arityLE _))) := by ext n - rw [eval, carry_compose, eval] + stop + rw [eval, stateStream_compose, eval] simp [compose, next, Circuit.eval_bind] congr ext a @@ -297,72 +417,95 @@ lemma eval_compose [Fintype arity] [DecidableEq arity] simp simp -def v0 : Fin 2 → BitStream - | ⟨0, _⟩ => fun n => n % 2 = 0 - | ⟨1, _⟩ => fun n => n % 2 = 1 +/-! +## Concrete FSMs +From here on out, we start to implement various operations as concrete FSMs +-/ -def and : FSM Bool := - { σ := Empty, - initCarry := Empty.elim, - nextStateFnCircuit := fun a => a.elim, - currentOutCircuit := - (Circuit.and - (Circuit.var true (inr true)) - (Circuit.var true (inr false))) } - -@[simp] lemma eval_and (x : Bool → BitStream) : and.eval x = (x true) &&& (x false) := by - ext n; cases n <;> simp [and, eval, next] +/-! ### Bitwise operations -/ -def or : FSM Bool := - { σ := Empty, - initCarry := Empty.elim, - nextBitCirc := fun a => a.elim - (Circuit.or - (Circuit.var true (inr true)) - (Circuit.var true (inr false))) Empty.elim } +/-- `mapCircuit` lifts a Boolean circuit into a stateless FSM -/ +def mapCircuit (c : Circuit (Fin n)) : FSM n where + stateWidth := 0 + initialState := 0#0 + outCircuit := c.map (Fin.cast <| by ac_rfl) + nextStateCircuits := Fin.elim0 -@[simp] lemma eval_or (x : Bool → BitStream) : or.eval x = (x true) ||| (x false) := by - ext n; cases n <;> simp [and, eval, next] +@[simp] lemma eval_mapCircuit (c : Circuit (Fin n)) (xs : BitStreamProd n) : + (mapCircuit c).eval xs = (fun n => c.eval fun j => (xs.getLsbs n)[j.val]) := by + funext m + simp only [eval, mapCircuit] + induction m generalizing xs + case zero => + simp [eval.next, next, BitVec.zero_width_append _, BitStream.head] + case succ m ih => -def xor : FSM Bool := - { σ := Empty, - initCarry := Empty.elim, - nextBitCirc := fun a => a.elim - (Circuit.xor - (Circuit.var true (inr true)) - (Circuit.var true (inr false))) Empty.elim } + simp [eval.next] + specialize ih xs.tails -@[simp] lemma eval_xor (x : Bool → BitStream) : xor.eval x = (x true) ^^^ (x false) := by - ext n; cases n <;> simp [and, eval, next] + simp at ih + rw [ih (xs.tails)] -def add : FSM Bool := - { σ := Unit, - initCarry := λ _ => false, - nextBitCirc := fun a => - match a with - | some () => - (Circuit.var true (inr true) &&& Circuit.var true (inr false)) - ||| (Circuit.var true (inr true) &&& Circuit.var true (inl ())) - ||| (Circuit.var true (inr false) &&& Circuit.var true (inl ())) - | none => Circuit.var true (inr true) ^^^ - Circuit.var true (inr false) ^^^ - Circuit.var true (inl ()) } -/-- The internal carry state of the `add` FSM agrees with + +def and : FSM 2 := + mapCircuit (Circuit.and + (Circuit.var true 1) + (Circuit.var true 0)) + +@[simp] lemma eval_and (xs : BitStreamProd 2) : + and.eval x = (xs 1) &&& (xs 0) := by + ext n; + -- stop + cases n <;> simp [mapCircuit, and, eval, next] + +def or : FSM 2 := + mapCircuit (Circuit.or + (Circuit.var true 1) + (Circuit.var true 0)) + +@[simp] lemma eval_or (x : Bool → BitStream) : or.eval x = (x true) ||| (x false) := by + ext n; cases n <;> simp [and, eval, next] + +def xor : FSM 2 := + mapCircuit (Circuit.xor + (Circuit.var true 1) + (Circuit.var true 0)) + +@[simp] lemma eval_xor (xs : BitStreamProd 2) : + xor.eval xs = (xs 1) ^^^ (xs 0) := by + ext n; stop cases n <;> simp [and, eval, next] + +/-! ### Arithmetic -/ + +def add : FSM 2 where + stateWidth := 1 + initialState := 0#1 + outCircuit := + Circuit.var true 0 + ^^^ Circuit.var true 1 + ^^^ Circuit.var true 2 + nextStateCircuits _ := + (Circuit.var true 2 &&& Circuit.var true 1) + ||| (Circuit.var true 2 &&& Circuit.var true 0) + ||| (Circuit.var true 1 &&& Circuit.var true 0) + +/-- The internal state of the `add` FSM agrees with the carry bit of addition as implemented on bitstreams -/ -theorem carry_add_succ (x : Bool → BitStream) (n : ℕ) : - add.carry x (n+1) = - fun _ => (BitStream.addAux (x true) (x false) n).2 := by - ext a; obtain rfl : a = () := rfl +theorem carry_add_succ (xs : BitStreamProd 2) (n : ℕ) : + add.stateStream xs (n+1) + = BitVec.ofBool ((BitStream.addAux (xs 1) (xs 0) n).2) := by + ext (a : Fin 1) + obtain rfl : a = (0 : Fin 1) := Fin.fin_one_eq_zero a induction n with | zero => - simp [carry, BitStream.addAux, next, add, BitVec.adcb] + simp [stateStream, BitStream.addAux, next, add, BitVec.adcb] | succ n ih => unfold carry simp [next, ih, Circuit.eval, BitStream.addAux, BitVec.adcb] -@[simp] theorem carry_zero (x : BitStreamProd arity) : carry p x 0 = p.initCarry := rfl -@[simp] theorem initCarry_add : add.initCarry = (fun _ => false) := rfl +@[simp] theorem carry_zero (x : BitStreamProd arity) : carry p x 0 = p.initialState := rfl +@[simp] theorem initialState_add : add.initialState = (fun _ => false) := rfl @[simp] lemma eval_add (x : Bool → BitStream) : add.eval x = (x true) + (x false) := by ext n @@ -378,7 +521,7 @@ given that we can reduce both those operations to just addition and bitwise comp def sub : FSM Bool := { σ := Unit, - initCarry := fun _ => false, + initialState := fun _ => false, nextBitCirc := fun a => match a with | some () => @@ -409,7 +552,7 @@ theorem eval_sub (x : Bool → BitStream) : sub.eval x = (x true) - (x false) := def neg : FSM Unit := { σ := Unit, i := by infer_instance, - initCarry := λ _ => true, + initialState := λ _ => true, nextBitCirc := fun a => match a with | some () => Circuit.var false (inr ()) &&& Circuit.var true (inl ()) @@ -433,7 +576,7 @@ theorem carry_neg (x : Unit → BitStream) : ∀ (n : ℕ), neg.carry x (n+1) = def not : FSM Unit := { σ := Empty, - initCarry := Empty.elim, + initialState := Empty.elim, nextBitCirc := fun _ => Circuit.var false (inr ()) } @[simp] lemma eval_not (x : Unit → BitStream) : not.eval x = ~~~(x ()) := by @@ -441,7 +584,7 @@ def not : FSM Unit := def zero : FSM (Fin 0) := { σ := Empty, - initCarry := Empty.elim, + initialState := Empty.elim, nextBitCirc := fun _ => Circuit.fals } @[simp] lemma eval_zero (x : Fin 0 → BitStream) : zero.eval x = BitStream.zero := by @@ -450,7 +593,7 @@ def zero : FSM (Fin 0) := def one : FSM (Fin 0) := { σ := Unit, i := by infer_instance, - initCarry := λ _ => true, + initialState := λ _ => true, nextBitCirc := fun a => match a with | some () => Circuit.fals @@ -469,7 +612,7 @@ def one : FSM (Fin 0) := def negOne : FSM (Fin 0) := { σ := Empty, i := by infer_instance, - initCarry := Empty.elim, + initialState := Empty.elim, nextBitCirc := fun _ => Circuit.tru } @[simp] lemma eval_negOne (x : Fin 0 → BitStream) : negOne.eval x = BitStream.negOne := by @@ -477,7 +620,7 @@ def negOne : FSM (Fin 0) := def ls (b : Bool) : FSM Unit := { σ := Unit, - initCarry := fun _ => b, + initialState := fun _ => b, nextBitCirc := fun x => match x with | none => Circuit.var true (inl ()) @@ -501,7 +644,7 @@ theorem carry_ls (b : Bool) (x : Unit → BitStream) : ∀ (n : ℕ), def var (n : ℕ) : FSM (Fin (n+1)) := { σ := Empty, i := by infer_instance, - initCarry := Empty.elim, + initialState := Empty.elim, nextBitCirc := λ _ => Circuit.var true (inr (Fin.last _)) } @[simp] lemma eval_var (n : ℕ) (x : Fin (n+1) → BitStream) : (var n).eval x = x (Fin.last n) := by @@ -509,7 +652,7 @@ def var (n : ℕ) : FSM (Fin (n+1)) := def incr : FSM Unit := { σ := Unit, - initCarry := fun _ => true, + initialState := fun _ => true, nextBitCirc := fun x => match x with | none => (Circuit.var true (inr ())) ^^^ (Circuit.var true (inl ())) @@ -532,7 +675,7 @@ theorem carry_incr (x : Unit → BitStream) : ∀ (n : ℕ), def decr : FSM Unit := { σ := Unit, i := by infer_instance, - initCarry := λ _ => true, + initialState := λ _ => true, nextBitCirc := fun x => match x with | none => (Circuit.var true (inr ())) ^^^ (Circuit.var true (inl ())) @@ -554,7 +697,7 @@ theorem carry_decr (x : Unit → BitStream) : ∀ (n : ℕ), decr.carry x (n+1) theorem evalAux_eq_zero_of_set {arity : Type _} (p : FSM arity) (R : Set (p.σ → Bool)) (hR : ∀ x s, (p.next s x).1 ∈ R → s ∈ R) - (hi : p.initCarry ∉ R) (hr1 : ∀ x s, (p.next s x).2 = true → s ∈ R) + (hi : p.initialState ∉ R) (hr1 : ∀ x s, (p.next s x).2 = true → s ∈ R) (x : BitStreamProd arity) (n : ℕ) : p.eval x n = false ∧ p.carry x n ∉ R := by simp (config := {singlePass := true}) only [← not_imp_not] at hR hr1 simp only [Bool.not_eq_true] at hR hr1 @@ -568,7 +711,7 @@ theorem evalAux_eq_zero_of_set {arity : Type _} (p : FSM arity) theorem eval_eq_zero_of_set {arity : Type _} (p : FSM arity) (R : Set (p.σ → Bool)) (hR : ∀ x s, (p.next s x).1 ∈ R → s ∈ R) - (hi : p.initCarry ∉ R) (hr1 : ∀ x s, (p.next s x).2 = true → s ∈ R) : + (hi : p.initialState ∉ R) (hr1 : ∀ x s, (p.next s x).2 = true → s ∈ R) : p.eval = fun _ _ => false := by ext x n rw [eval] @@ -576,7 +719,7 @@ theorem eval_eq_zero_of_set {arity : Type _} (p : FSM arity) def repeatBit : FSM Unit where σ := Unit - initCarry := fun () => false + initialState := fun () => false nextBitCirc := fun _ => .or (.var true <| .inl ()) (.var true <| .inr ()) @@ -730,7 +873,7 @@ If not, we produce an infinite sequence of `1`. -/ def and : FSM Bool := { σ := Unit, - initCarry := fun _ => false, + initialState := fun _ => false, nextBitCirc := fun a => match a with | some () => @@ -751,7 +894,7 @@ If not, we produce a `1`. -/ def or : FSM Bool := { σ := Unit, - initCarry := fun _ => false, + initialState := fun _ => false, nextBitCirc := fun a => match a with | some () => @@ -800,7 +943,7 @@ theorem decideIfZeroAux_wf {σ : Type _} [Fintype σ] [DecidableEq σ] def decideIfZerosAux {arity : Type _} [DecidableEq arity] (p : FSM arity) (c : Circuit p.σ) : Bool := - if c.eval p.initCarry + if c.eval p.initialState then false else have c' := (c.bind (p.nextBitCirc ∘ some)).fst @@ -818,14 +961,14 @@ def decideIfZeros {arity : Type _} [DecidableEq arity] theorem decideIfZerosAux_correct {arity : Type _} [DecidableEq arity] (p : FSM arity) (c : Circuit p.σ) (hc : ∀ s, c.eval s = true → - ∃ m y, (p.changeInitCarry s).eval y m = true) + ∃ m y, (p.withInitialState s).eval y m = true) (hc₂ : ∀ (x : arity → Bool) (s : p.σ → Bool), (FSM.next p s x).snd = true → Circuit.eval c s = true) : decideIfZerosAux p c = true ↔ ∀ n x, p.eval x n = false := by rw [decideIfZerosAux] split_ifs with h · simp - exact hc p.initCarry h + exact hc p.initialState h · dsimp split_ifs with h' · simp only [true_iff] @@ -847,7 +990,7 @@ theorem decideIfZerosAux_correct {arity : Type _} [DecidableEq arity] · rcases hc _ hx with ⟨m, y, hmy⟩ use (m+1) use fun a i => Nat.casesOn i x (fun i a => y a i) a - rw [FSM.eval_changeInitCarry_succ] + rw [FSM.eval_withInitialState_succ] rw [← hmy] simp only [FSM.next, Nat.rec_zero, Nat.rec_add_one] · exact hc _ h @@ -864,7 +1007,7 @@ theorem decideIfZeros_correct {arity : Type _} [DecidableEq arity] intro s x h use 0 use (fun a _ => x a) - simpa [FSM.eval, FSM.changeInitCarry, FSM.next, FSM.carry] + simpa [FSM.eval, FSM.withInitialState, FSM.next, FSM.carry] · simp only [Circuit.eval_fst] intro x s h use x