Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into add_ci
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasgrosser committed Dec 14, 2024
2 parents 751d290 + a4d4fe1 commit 5f52c5e
Showing 1 changed file with 69 additions and 21 deletions.
90 changes: 69 additions & 21 deletions SSA/Experimental/Bits/Fast/Reflect.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ TODO:
+ I currently add support for BitVec.ofInt, with the knowledge that I can remove it
if I'm unable to prove soundness.
- [x] leftShift
- [ ] Break down numeral multiplication into left shift:
- [x] Break down numeral multiplication into left shift:
10 * z
= z <<< 1 + 5 * z
= z <<< 1 + (z + 4 * z)
Expand Down Expand Up @@ -222,34 +222,42 @@ with BitVec.ofNat -/
x * (BitVec.ofNat w n) = BitVec.ofNat w n * x := by rw [BitVec.mul_comm]


/-! Normal form for shifts
See that `x <<< (n : Nat)` is strictly more expression than `x <<< BitVec.ofNat w n`,
because in the former case, we can shift by arbitrary amounts, while in the latter case,
we can only shift by numbers upto `2^w`. Therefore, we choose `x <<< (n : Nat)` as our simp
and preprocessing normal form for the tactic.
-/

@[simp] theorem BitVec.shiftLeft_ofNat_eq (x : BitVec w) (n : Nat) :
x <<< BitVec.ofNat w n = x <<< (n % 2^w) := by simp

/--
Multiplying by an even number `e` is the same as shifting by `1`,
followed by multiplying by half of `e` (the number `n`).
This is used to simplify multiplications into shifts.
-/
theorem BitVec.even_mul_eq_shiftLeft_mul_of_eq_mul_two (x : BitVec w) (n e : Nat) (he : e = n * 2) :
(BitVec.ofNat w e) * x = (BitVec.ofNat w n) * (x <<< 1) := by
theorem BitVec.even_mul_eq_shiftLeft_mul_of_eq_mul_two (w : Nat) (x : BitVec w) (n e : Nat) (he : e = n * 2) :
(BitVec.ofNat w e) * x = (BitVec.ofNat w n) * (x <<< (1 : Nat)) := by
apply BitVec.eq_of_toNat_eq
simp [Nat.shiftLeft_eq, he]
rcases w with rfl | w
· simp [Nat.mod_one]
· simp
congr 1
· congr 1
rw [Nat.mul_comm x.toNat 2, ← Nat.mul_assoc n]

/--
Multiplying by an odd number `o` is the same as adding `x`, followed by multiplying by `(o - 1) / 2`.
This is used to simplify multiplications into shifts.
-/
theorem BitVec.odd_mul_eq_shiftLeft_mul_of_eq_mul_two_add_one (x : BitVec w) (n o : Nat)
(ho : o = n * 2 + 1) : (BitVec.ofNat w o) * x = x + (BitVec.ofNat w n) * (x <<< 1) := by
theorem BitVec.odd_mul_eq_shiftLeft_mul_of_eq_mul_two_add_one (w : Nat) (x : BitVec w) (n o : Nat)
(ho : o = n * 2 + 1) : (BitVec.ofNat w o) * x = x + (BitVec.ofNat w n) * (x <<< (1 : Nat)) := by
apply BitVec.eq_of_toNat_eq
simp [Nat.shiftLeft_eq, ho]
rcases w with rfl | w
· simp [Nat.mod_one]
· simp only [lt_add_iff_pos_left, add_pos_iff, zero_lt_one, or_true, Nat.one_mod_two_pow,
pow_one]
congr 1
· congr 1
rw [Nat.add_mul]
simp only [one_mul]
rw [Nat.mul_assoc, Nat.mul_comm 2]
Expand All @@ -272,6 +280,47 @@ theorem BitVec.odd_mul_eq_shiftLeft_mul_of_eq_mul_two_add_one (x : BitVec w) (n

@[bv_circuit_preprocess] theorem BitVec.zero_mul (x : BitVec w) : 0#w * x = 0#w := by simp


open Lean Meta Elab in

/--
Given an equality proof with `lhs = rhs`, return the `rhs`,
and bail out if we are unable to determine it precisely (i.e. no loose metavars).
-/
def getEqRhs (eq : Expr) : MetaM Expr := do
check eq
let eq ← whnf <| ← inferType eq
let some (_ty, _lhs, rhs) := eq.eq? | throwError "unable to infer RHS for equality {eq}"
let rhs ← instantiateMVars rhs
rhs.ensureHasNoMVars
return rhs

open Lean Meta Elab in
/--
This needs to be a pre-simproc, because we want to rewrite `k * x`
repeatedly into smaller multiplications:
+ rewrite into `x + ((k/2) * (x <<< 1))` if `k` odd.
+ rewrite into `(k/2) * (x <<< 1) if k even.
Since we get a smaller multiplication with `k/2`, we need it to be a pre-simproc so we recurse
into the RHS expression.
-/
simproc↓ [bv_circuit_preprocess] shiftLeft_break_down ((BitVec.ofNat _ _) * (_ : BitVec _)) := fun x => do
match_expr x with
| HMul.hMul _bv _bv _bv _inst kbv x =>
let_expr BitVec.ofNat _w k := kbv | return .continue
let some kVal ← Meta.getNatValue? k | return .continue
/- base cases, will be taken care of by rewrite theorems -/
if kVal == 0 || kVal == 1 then return .continue
let thmName := if kVal % 2 == 0 then
mkConst ``BitVec.even_mul_eq_shiftLeft_mul_of_eq_mul_two
else
mkConst ``BitVec.odd_mul_eq_shiftLeft_mul_of_eq_mul_two_add_one
let eqProof := mkAppN thmName
#[_w, x, mkNatLit <| Nat.div2 kVal, mkNatLit kVal, (← mkEqRefl k)]
return .visit { proof? := eqProof, expr := ← getEqRhs eqProof }
| _ => return .continue

open Lean Elab Meta
def runPreprocessing (g : MVarId) : MetaM (Option MVarId) := do
let some ext ← (getSimpExtension? `bv_circuit_preprocess)
Expand Down Expand Up @@ -594,9 +643,10 @@ partial def reflectTermUnchecked (map : ReflectMap) (w : Expr) (e : Expr) : Meta
return { b with e := out }
| HShiftLeft.hShiftLeft _bv _nat _bv _inst a n =>
let a ← reflectTermUnchecked map w a
let some natVal ← Lean.Meta.getNatValue? n
| throwError "Only shift left by natural numbers are allowed, but found shift by expression '{n}' at {indentD e}"
return { a with e := Term.shiftL a.e natVal }
let some n ← getNatValue? n
| throwError "expected shiftLeft by natural number, found symbolic shift amount '{n}' at '{indentD e}'"
return { a with e := Term.shiftL a.e n }

| HSub.hSub _bv _bv _bv _inst a b =>
let a ← reflectTermUnchecked map w a
let b ← reflectTermUnchecked a.bvToIxMap w b
Expand Down Expand Up @@ -1155,6 +1205,8 @@ example : ∀ (w : Nat) (x : BitVec w), (BitVec.ofInt w (-1)) &&& x = x := by
example : ∀ (w : Nat) (x : BitVec w), x <<< (0 : Nat) = x := by intros; bv_automata_circuit
example : ∀ (w : Nat) (x : BitVec w), x <<< (1 : Nat) = x + x := by intros; bv_automata_circuit
example : ∀ (w : Nat) (x : BitVec w), x <<< (2 : Nat) = x + x + x + x := by
intros w n
-- rw [BitVec.ofNat_eq_ofNat (n := w) (i := 2)]
intros; bv_automata_circuit


Expand Down Expand Up @@ -1187,16 +1239,12 @@ theorem add_eq_xor_add_mul_and_nt (x y : BitVec w) :
bv_automata_circuit

/-- Check that we correctly process an even numeral multiplication. -/
theorem mul_four (x : BitVec w) :
4 * x = x + x + x + x := by
fail_if_success bv_automata_circuit
sorry
theorem mul_four (x : BitVec w) : 4 * x = x + x + x + x := by
bv_automata_circuit

/-- Check that we correctly process an odd numeral multiplication. -/
theorem mul_five (x : BitVec w) :
5 * x = x + x + x + x + 5 := by
fail_if_success bv_automata_circuit
sorry
theorem mul_five (x : BitVec w) : 5 * x = x + x + x + x + x := by
bv_automata_circuit (config := { circuitSizeThreshold := 150 })

open BitVec in
/-- Check that we support sign extension. -/
Expand Down

0 comments on commit 5f52c5e

Please sign in to comment.