Skip to content

Commit

Permalink
chore: update to nightly-testing-2024-12-21 (#907)
Browse files Browse the repository at this point in the history
Co-authored-by: luisacicolini <[email protected]>
  • Loading branch information
tobiasgrosser and luisacicolini authored Dec 25, 2024
1 parent 19a92c4 commit f2b36bf
Show file tree
Hide file tree
Showing 13 changed files with 117 additions and 97 deletions.
4 changes: 3 additions & 1 deletion SSA/Core/HVector.lean
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ theorem eq_of_type_eq_nil {A : α → Type*} {l : List α}
syntax "[" withoutPosition(term,*) "]ₕ" : term

@[simp]
theorem cons_nil_get : (HVector.cons x .nil).get (0 : Fin 1) = x := rfl
theorem cons_get_zero {A : α → Type*} {a: α} {as : List α} {e : A a} {vec : HVector A as} :
(HVector.cons e vec).get (@OfNat.ofNat (Fin (as.length + 1)) 0 Fin.instOfNat) = e := by
rfl

-- Copied from core for List
macro_rules
Expand Down
17 changes: 12 additions & 5 deletions SSA/Experimental/Bits/AutoStructs/FiniteStateMachine.lean
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def or : FSM Bool :=
(Circuit.var true (inr false))) Empty.elim }

@[simp] lemma eval_or (x : Bool → BitStream) : or.eval x = (x true) ||| (x false) := by
ext n; cases n <;> simp [and, eval, nextBit]
ext n; cases n <;> simp [or, eval, nextBit]

def xor : FSM Bool :=
{ α := Empty,
Expand All @@ -264,7 +264,7 @@ def xor : FSM Bool :=
(Circuit.var true (inr false))) Empty.elim }

@[simp] lemma eval_xor (x : Bool → BitStream) : xor.eval x = (x true) ^^^ (x false) := by
ext n; cases n <;> simp [and, eval, nextBit]
ext n; cases n <;> simp [xor, eval, nextBit]

def add : FSM Bool :=
{ α := Unit,
Expand All @@ -279,6 +279,13 @@ def add : FSM Bool :=
Circuit.var true (inr false) ^^^
Circuit.var true (inl ()) }

private theorem add_nextBitCirc_some_eval :
(add.nextBitCirc (some ())).eval =
fun x => x (inr true) && x (inr false) || x (inr true)
&& x (inl ()) || x (inr false) && x (inl ()) := by
ext x
simp [add]

/-- The internal carry state of the `add` FSM agrees with
the carry bit of addition as implemented on bitstreams -/
theorem carry_add_succ (x : Bool → BitStream) (n : ℕ) :
Expand All @@ -290,7 +297,7 @@ theorem carry_add_succ (x : Bool → BitStream) (n : ℕ) :
simp [carry, BitStream.addAux, nextBit, add, BitVec.adcb]
| succ n ih =>
unfold carry
simp [nextBit, ih, Circuit.eval, BitStream.addAux, BitVec.adcb]
simp [add_nextBitCirc_some_eval, nextBit, ih, Circuit.eval, BitStream.addAux, BitVec.adcb, nextBitCirc, Sum.elim]

@[simp] theorem carry_zero (x : ar → BitStream) : carry p x 0 = p.initCarry := rfl
@[simp] theorem initCarry_add : add.initCarry = (fun _ => false) := rfl
Expand Down Expand Up @@ -395,7 +402,7 @@ def one : FSM (Fin 0) :=
ext n
cases n
· rfl
· simp [eval, carry_one, nextBit]
· simp! [eval, carry_one, nextBit, one, mk]

def negOne : FSM (Fin 0) :=
{ α := Empty,
Expand Down Expand Up @@ -427,7 +434,7 @@ theorem carry_ls (b : Bool) (x : Unit → BitStream) : ∀ (n : ℕ),
ext n
cases n
· rfl
· simp [carry_ls, eval, nextBit, BitStream.concat]
· simp [ls, carry, carry_ls, eval, nextBit, BitStream.concat]

def var (n : ℕ) : FSM (Fin (n+1)) :=
{ α := Empty,
Expand Down
28 changes: 16 additions & 12 deletions SSA/Experimental/Bits/Fast/Circuit.lean
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,12 @@ instance : AndOp (Circuit α) := ⟨Circuit.simplifyAnd⟩
@[simp] lemma eval_and : ∀ (c₁ c₂ : Circuit α) (f : α → Bool),
(eval (c₁ &&& c₂) f) = ((eval c₁ f) && (eval c₂ f)) := by
intros c₁ c₂ f
cases c₁ <;> cases c₂ <;> simp [eval, simplifyAnd]
cases c₁ <;> cases c₂ <;> simp [simplifyAnd, AndOp.and, HAnd.hAnd]

theorem varsFinset_and [DecidableEq α] (c₁ c₂ : Circuit α) :
(varsFinset (c₁ &&& c₂)) ⊆ (varsFinset c₁ ∪ varsFinset c₂) := by
cases c₁ <;> cases c₂ <;> simp [vars, simplifyAnd, varsFinset, Finset.subset_iff]
cases c₁ <;> cases c₂ <;> simp [vars, simplifyAnd, varsFinset, Finset.subset_iff,
AndOp.and, HAnd.hAnd]

def simplifyOr : Circuit α → Circuit α → Circuit α
| tru, _ => tru
Expand All @@ -138,15 +139,18 @@ instance : OrOp (Circuit α) := ⟨Circuit.simplifyOr⟩
@[simp] lemma eval_or : ∀ (c₁ c₂ : Circuit α) (f : α → Bool),
(eval (c₁ ||| c₂) f) = ((eval c₁ f) || (eval c₂ f)) := by
intros c₁ c₂ f
cases c₁ <;> cases c₂ <;> simp [Circuit.simplifyOr, eval]
cases c₁ <;> cases c₂ <;> simp [Circuit.simplifyOr, eval,
OrOp.or, HOr.hOr]

theorem vars_or [DecidableEq α] (c₁ c₂ : Circuit α) :
(vars (c₁ ||| c₂)) ⊆ (vars c₁ ++ vars c₂).dedup := by
cases c₁ <;> cases c₂ <;> simp [vars, simplifyOr]
cases c₁ <;> cases c₂ <;> simp [vars, simplifyOr,
OrOp.or, HOr.hOr]

theorem varsFinset_or [DecidableEq α] (c₁ c₂ : Circuit α) :
(varsFinset (c₁ ||| c₂)) ⊆ (varsFinset c₁ ∪ varsFinset c₂) := by
cases c₁ <;> cases c₂ <;> simp [vars, simplifyOr, varsFinset, Finset.subset_iff]
cases c₁ <;> cases c₂ <;> simp [vars, simplifyOr, varsFinset, Finset.subset_iff,
OrOp.or, HOr.hOr]

def simplifyNot : Circuit α → Circuit α
| tru => fals
Expand All @@ -173,13 +177,14 @@ theorem simplifyNot_eq_complement (c : Circuit α) :
erw [eval, eval, eval_complement a, eval_complement b, Bool.not_and]
| or a b, f => by
erw [eval, eval, eval_complement a, eval_complement b, Bool.not_or]
| var true a, f => by simp [eval]
| var false a, f => by simp [eval]
| var true a, f => by simp [eval, ←simplifyNot_eq_complement, simplifyNot]
| var false a, f => by simp [eval, ←simplifyNot_eq_complement, simplifyNot]

theorem varsFinset_complement [DecidableEq α] (c : Circuit α) :
(varsFinset (~~~ c)) ⊆ varsFinset c := by
intro x
induction c <;> simp [simplifyNot, vars, mem_varsFinset] <;> aesop
induction c <;> simp [simplifyNot, ←simplifyNot_eq_complement, vars, mem_varsFinset]
<;> aesop

@[simp]
def simplifyXor : Circuit α → Circuit α → Circuit α
Expand All @@ -198,16 +203,15 @@ instance : Xor (Circuit α) := ⟨Circuit.simplifyXor⟩
@[simp] lemma eval_xor : ∀ (c₁ c₂ : Circuit α) (f : α → Bool),
eval (c₁ ^^^ c₂) f = Bool.xor (eval c₁ f) (eval c₂ f) := by
intros c₁ c₂ f
cases c₁ <;> cases c₂ <;> simp [simplifyXor, Bool.xor_not_left'] <;>
split_ifs <;> simp [*] at *
cases c₁ <;> cases c₂ <;> simp [simplifyXor, Bool.xor_not_left', HXor.hXor, Xor.xor]

set_option maxHeartbeats 1000000
theorem vars_simplifyXor [DecidableEq α] (c₁ c₂ : Circuit α) :
(vars (simplifyXor c₁ c₂)) ⊆ (vars c₁ ++ vars c₂).dedup := by
intro x
simp only [List.mem_dedup, List.mem_append]
simp only [List.mem_dedup, List.mem_append, ←simplifyNot_eq_complement]
induction c₁ <;> induction c₂ <;> simp only [simplifyXor, vars,
← simplifyNot_eq_complement] at * <;> aesop
← simplifyNot_eq_complement, simplifyNot] at * <;> aesop

theorem varsFinset_simplifyXor [DecidableEq α] (c₁ c₂ : Circuit α) :
(varsFinset (simplifyXor c₁ c₂)) ⊆ (varsFinset c₁ ∪ varsFinset c₂) := by
Expand Down
17 changes: 12 additions & 5 deletions SSA/Experimental/Bits/Fast/FiniteStateMachine.lean
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def or : FSM Bool :=
(Circuit.var true (inr false))) Empty.elim }

@[simp] lemma eval_or (x : Bool → BitStream) : or.eval x = (x true) ||| (x false) := by
ext n; cases n <;> simp [and, eval, nextBit]
ext n; cases n <;> simp [or, eval, nextBit]

def xor : FSM Bool :=
{ α := Empty,
Expand Down Expand Up @@ -334,7 +334,7 @@ lemma eval_scanAnd_succ (x : Unit → BitStream) (n : Nat) :
exact h j (by omega)

@[simp] lemma eval_xor (x : Bool → BitStream) : xor.eval x = (x true) ^^^ (x false) := by
ext n; cases n <;> simp [and, eval, nextBit]
ext n; cases n <;> simp [xor, eval, nextBit]

@[simp] lemma eval_nxor (x : Bool → BitStream) : nxor.eval x = ((x true).nxor (x false)) := by
ext n; cases n
Expand Down Expand Up @@ -407,6 +407,13 @@ def add : FSM Bool :=
Circuit.var true (inr false) ^^^
Circuit.var true (inl ()) }

theorem add_nextBitCirc_some_eval :
(add.nextBitCirc (some ())).eval =
fun x => x (inr true) && x (inr false) || x (inr true)
&& x (inl ()) || x (inr false) && x (inl ()) := by
ext x
simp +ground [eval, add, Circuit.simplifyAnd, Circuit.simplifyOr]

/-- The internal carry state of the `add` FSM agrees with
the carry bit of addition as implemented on bitstreams -/
theorem carry_add_succ (x : Bool → BitStream) (n : ℕ) :
Expand All @@ -418,7 +425,7 @@ theorem carry_add_succ (x : Bool → BitStream) (n : ℕ) :
simp [carry, BitStream.addAux, nextBit, add, BitVec.adcb]
| succ n ih =>
unfold carry
simp [nextBit, ih, Circuit.eval, BitStream.addAux, BitVec.adcb]
simp [add_nextBitCirc_some_eval, nextBit, ih, Circuit.eval, BitStream.addAux, BitVec.adcb]

@[simp] theorem carry_zero (x : arity → BitStream) : carry p x 0 = p.initCarry := rfl
@[simp] theorem initCarry_add : add.initCarry = (fun _ => false) := rfl
Expand Down Expand Up @@ -558,7 +565,7 @@ def one : FSM (Fin 0) :=
ext n
cases n
· rfl
· simp [eval, carry_one, nextBit]
· simp! [eval, carry_one, nextBit, one]

def negOne : FSM (Fin 0) :=
{ α := Empty,
Expand Down Expand Up @@ -593,7 +600,7 @@ theorem carry_ls (b : Bool) (x : Unit → BitStream) : ∀ (n : ℕ),
ext n
cases n
· rfl
· simp [carry_ls, eval, nextBit, BitStream.concat]
· simp [ls, carry_ls, eval, nextBit, BitStream.concat, carry]

def var (n : ℕ) : FSM (Fin (n+1)) :=
{ α := Empty,
Expand Down
26 changes: 13 additions & 13 deletions SSA/Projects/CIRCT/DC/DC.lean
Original file line number Diff line number Diff line change
Expand Up @@ -186,19 +186,19 @@ instance : DialectSignature DC := ⟨Op.signature⟩
@[simp]
instance : DialectDenote (DC) where
denote
| .fst, arg, _ => (arg.getN 0).fst
| .snd, arg, _ => (arg.getN 0).snd
| .fstVal _, arg, _ => (arg.getN 0).fst
| .sndVal _, arg, _ => (arg.getN 0).snd
| .pair _, arg, _ => (arg.getN 0, arg.getN 1)
| .unpack _, arg, _ => CIRCTStream.DC.unpack (arg.getN 0)
| .pack _, arg, _ => CIRCTStream.DC.pack (arg.getN 0) (arg.getN 1)
| .branch, arg, _ => CIRCTStream.DC.branch (arg.getN 0)
| .fork, arg, _ => CIRCTStream.DC.fork (arg.getN 0)
| .join, arg, _ => CIRCTStream.DC.join (arg.getN 0) (arg.getN 1)
| .merge, arg, _ => CIRCTStream.DC.merge (arg.getN 0) (arg.getN 1)
| .select, arg, _ => CIRCTStream.DC.select (arg.getN 0) (arg.getN 1) (arg.getN 2)
| .sink, arg, _ => CIRCTStream.DC.sink (arg.getN 0)
| .fst, arg, _ => (arg.getN 0 (by simp [DialectSignature.sig, signature])).fst
| .snd, arg, _ => (arg.getN 0 (by simp [DialectSignature.sig, signature])).snd
| .fstVal _, arg, _ => (arg.getN 0 (by simp [DialectSignature.sig, signature])).fst
| .sndVal _, arg, _ => (arg.getN 0 (by simp [DialectSignature.sig, signature]) ).snd
| .pair _, arg, _ => (arg.getN 0 (by simp [DialectSignature.sig, signature]), arg.getN 1 (by simp [DialectSignature.sig, signature]))
| .unpack _, arg, _ => CIRCTStream.DC.unpack (arg.getN 0 (by simp [DialectSignature.sig, signature]))
| .pack _, arg, _ => CIRCTStream.DC.pack (arg.getN 0 (by simp [DialectSignature.sig, signature])) (arg.getN 1 (by simp [DialectSignature.sig, signature]))
| .branch, arg, _ => CIRCTStream.DC.branch (arg.getN 0 (by simp [DialectSignature.sig, signature]))
| .fork, arg, _ => CIRCTStream.DC.fork (arg.getN 0 (by simp [DialectSignature.sig, signature]))
| .join, arg, _ => CIRCTStream.DC.join (arg.getN 0 (by simp [DialectSignature.sig, signature])) (arg.getN 1 (by simp [DialectSignature.sig, signature]))
| .merge, arg, _ => CIRCTStream.DC.merge (arg.getN 0 (by simp [DialectSignature.sig, signature])) (arg.getN 1 (by simp [DialectSignature.sig, signature]))
| .select, arg, _ => CIRCTStream.DC.select (arg.getN 0 (by simp [DialectSignature.sig, signature])) (arg.getN 1 (by simp [DialectSignature.sig, signature])) (arg.getN 2 (by simp [DialectSignature.sig, signature]))
| .sink, arg, _ => CIRCTStream.DC.sink (arg.getN 0 (by simp [DialectSignature.sig, signature]))
| .source, _, _ => CIRCTStream.DC.source

end Dialect
Expand Down
8 changes: 4 additions & 4 deletions SSA/Projects/CIRCT/Handshake/Handshake.lean
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,10 @@ instance : DialectSignature Handshake := ⟨Op.signature⟩
@[simp]
instance : DialectDenote (Handshake) where
denote
| .branch _, arg, _ => CIRCTStream.Handshake.branch (arg.getN 0) (arg.getN 1)
| .merge _, arg, _ => CIRCTStream.Handshake.merge (arg.getN 0) (arg.getN 1)
| .fst _, arg, _ => (arg.getN 0).fst
| .snd _, arg, _ => (arg.getN 0).snd
| .branch _, arg, _ => CIRCTStream.Handshake.branch (arg.getN 0 (by simp [DialectSignature.sig, signature])) (arg.getN 1 (by simp [DialectSignature.sig, signature]))
| .merge _, arg, _ => CIRCTStream.Handshake.merge (arg.getN 0 (by simp [DialectSignature.sig, signature])) (arg.getN 1 (by simp [DialectSignature.sig, signature]))
| .fst _, arg, _ => (arg.getN 0 (by simp [DialectSignature.sig, signature])).fst
| .snd _, arg, _ => (arg.getN 0 (by simp [DialectSignature.sig, signature])).snd

end Dialect

Expand Down
42 changes: 21 additions & 21 deletions SSA/Projects/InstCombine/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -381,27 +381,27 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig
(TyDenote.toType <| DialectSignature.outTy o) :=
match o with
| Op.const _ val => const? _ val
| Op.copy _ => (op.getN 0)
| Op.not _ => LLVM.not (op.getN 0)
| Op.neg _ => LLVM.neg (op.getN 0)
| Op.trunc w w' flags => LLVM.trunc w' (op.getN 0) flags
| Op.zext w w' flag => LLVM.zext w' (op.getN 0) flag
| Op.sext w w' => LLVM.sext w' (op.getN 0)
| Op.and _ => LLVM.and (op.getN 0) (op.getN 1)
| Op.or _ flag => LLVM.or (op.getN 0) (op.getN 1) flag
| Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1)
| Op.shl _ flags => LLVM.shl (op.getN 0) (op.getN 1) flags
| Op.lshr _ flag => LLVM.lshr (op.getN 0) (op.getN 1) flag
| Op.ashr _ flag => LLVM.ashr (op.getN 0) (op.getN 1) flag
| Op.sub _ flags => LLVM.sub (op.getN 0) (op.getN 1) flags
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ flags => LLVM.mul (op.getN 0) (op.getN 1) flags
| Op.sdiv _ flag => LLVM.sdiv (op.getN 0) (op.getN 1) flag
| Op.udiv _ flag => LLVM.udiv (op.getN 0) (op.getN 1) flag
| Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1)
| Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1)
| Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1)
| Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2)
| Op.copy _ => (op.getN 0 (by simp [DialectSignature.sig, signature]))
| Op.not _ => LLVM.not (op.getN 0 (by simp [DialectSignature.sig, signature]))
| Op.neg _ => LLVM.neg (op.getN 0 (by simp [DialectSignature.sig, signature]))
| Op.trunc w w' flags => LLVM.trunc w' (op.getN 0 (by simp [DialectSignature.sig, signature])) flags
| Op.zext w w' flag => LLVM.zext w' (op.getN 0 (by simp [DialectSignature.sig, signature])) flag
| Op.sext w w' => LLVM.sext w' (op.getN 0 (by simp [DialectSignature.sig, signature]))
| Op.and _ => LLVM.and (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature]))
| Op.or _ flag => LLVM.or (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flag
| Op.xor _ => LLVM.xor (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature]))
| Op.shl _ flags => LLVM.shl (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flags
| Op.lshr _ flag => LLVM.lshr (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flag
| Op.ashr _ flag => LLVM.ashr (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flag
| Op.sub _ flags => LLVM.sub (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flags
| Op.add _ flags => LLVM.add (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flags
| Op.mul _ flags => LLVM.mul (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flags
| Op.sdiv _ flag => LLVM.sdiv (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flag
| Op.udiv _ flag => LLVM.udiv (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) flag
| Op.urem _ => LLVM.urem (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature]))
| Op.srem _ => LLVM.srem (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature]))
| Op.icmp c _ => LLVM.icmp c (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature]))
| Op.select _ => LLVM.select (op.getN 0 (by simp [DialectSignature.sig, signature])) (op.getN 1 (by simp [DialectSignature.sig, signature])) (op.getN 2 (by simp [DialectSignature.sig, signature]))

instance : DialectDenote LLVM := ⟨
fun o args _ => Op.denote o args
Expand Down
2 changes: 1 addition & 1 deletion SSA/Projects/InstCombine/LLVM/PrettyEDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ private def pretty_test_exact :=
example : pretty_test = prettier_test_generic 32 := by
unfold pretty_test prettier_test_generic
simp_alive_meta
simp
rfl

example : pretty_test_generic = prettier_test_generic := rfl

Expand Down
Loading

0 comments on commit f2b36bf

Please sign in to comment.