From d66a7b86292739ab2c6e1cb6a7224689b386c56b Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Tue, 16 Jul 2024 16:56:56 +0100 Subject: [PATCH 1/8] add flags --- SSA/Projects/InstCombine/LLVM/Semantics.lean | 121 ++++++++++++++----- 1 file changed, 93 insertions(+), 28 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index 86bb109e3..04f6f025c 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -34,6 +34,9 @@ def and {w : Nat} (x y : IntW w) : IntW w := do The ‘or’ instruction returns the bitwise logical inclusive or of its two operands. -/ + +structure OrParams where + disjoint : Bool := false @[simp_llvm] def or? {w : Nat} (x y : BitVec w) : IntW w := pure <| x ||| y @@ -42,9 +45,11 @@ def or? {w : Nat} (x y : BitVec w) : IntW w := theorem or?_eq : LLVM.or? a b = .some (BitVec.or a b) := rfl @[simp_llvm_option] -def or {w : Nat} (x y : IntW w) : IntW w := do +def or {w : Nat} (x y : IntW w) (params : OrParams := {}) : IntW w := do let x' ← x let y' ← y + let g ← and x y + guard (¬params.disjoint ∨ g = 0) or? x' y' /-- @@ -70,18 +75,25 @@ The value produced is the integer sum of the two operands. If the sum has unsigned overflow, the result returned is the mathematical result modulo 2n, where n is the bit width of the result. Because LLVM integers use a two’s complement representation, this instruction is appropriate for both signed and unsigned integers. -/ + +structure AddParams where + nuw : Bool := false + nsw : Bool := false + @[simp_llvm] -def add? {w : Nat} (x y : BitVec w) : IntW w := +def add? {w : Nat} (x y : BitVec w) : IntW w := pure <| x + y @[simp_llvm_option] theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl @[simp_llvm_option] -def add {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - add? x' y' +def add {w : Nat} (x y : IntW w) (params : AddParams := {}) : IntW w := do + let x ← x + let y ← y + guard (¬ params.nsw ∨ ¬ ((x.toInt + y.toInt) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w)) + guard (¬ params.nuw ∨ ¬ ((x.toNat + y.toNat) ≥ 2^w)) + add? x y /-- The value produced is the integer difference of the two operands. @@ -91,15 +103,26 @@ Because LLVM integers use a two’s complement representation, this instruction @[simp_llvm] def sub? {w : Nat} (x y : BitVec w) : IntW w := pure <| x - y +structure SubParams where + nuw : Bool := false + nsw : Bool := false @[simp_llvm_option] theorem sub?_eq : LLVM.sub? a b = .some (BitVec.sub a b) := rfl @[simp_llvm_option] -def sub {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - sub? x' y' +def sub {w : Nat} (x y : IntW w) (params : SubParams := {}) : IntW w := do + let x ← x + let y ← y + + -- Check for unsigned overflow if nuw is set + guard (¬params.nuw ∨ x.toNat ≥ y.toNat) + + -- Check for signed overflow if nsw is set + guard (¬params.nsw ∨ + ((x.toInt - y.toInt) ≥ -(2^(w-1)) ∧ (x.toInt - y.toInt) < 2^(w-1))) + + sub? x y /-- The value produced is the integer product of the two operands. @@ -112,6 +135,10 @@ result for both signed and unsigned integers. If a full product (e.g. i32 * i32 -> i64) is needed, the operands should be sign-extended or zero-extended as appropriate to the width of the full product. -/ + +structure MulParams where + nuw : Bool := false + nsw : Bool := false @[simp_llvm] def mul? {w : Nat} (x y : BitVec w) : IntW w := pure <| x * y @@ -120,16 +147,30 @@ def mul? {w : Nat} (x y : BitVec w) : IntW w := theorem mul?_eq : LLVM.mul? a b = .some (BitVec.mul a b) := rfl @[simp_llvm_option] -def mul {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - mul? x' y' +def mul {w : Nat} (x y : IntW w) (params : MulParams := {}) : IntW w := do + let x ← x + let y ← y + + -- Perform the multiplication + let result := x * y + + -- Check for unsigned overflow if nuw is set + guard (¬ params.nuw ∨ ¬((x * y).toNat ≤ (2^w))) + -- Check for signed overflow if nsw is set + guard (¬ params.nsw ∨ ¬((x * y).toNat < (2^(w - 1)) ∧ (x * y).toNat ≥ (2^(w - 1)) )) + + -- Return the result + result /-- The value produced is the unsigned integer quotient of the two operands. Note that unsigned integer division and signed integer division are distinct operations; for signed integer division, use ‘sdiv’. Division by zero is undefined behavior. -/ + +structure UdivParams where + exact : Bool := false + @[simp_llvm] def udiv? {w : Nat} (x y : BitVec w) : IntW w := match y.toNat with @@ -137,9 +178,11 @@ def udiv? {w : Nat} (x y : BitVec w) : IntW w := | _ => pure <| BitVec.ofInt w (x.toNat / y.toNat) @[simp_llvm_option] -def udiv {w : Nat} (x y : IntW w) : IntW w := do +def udiv {w : Nat} (x y : IntW w) (params : UdivParams := {}) : IntW w := do let x' ← x let y' ← y + --- If the exact keyword is present, the result value of the udiv is a poison value if %op1 is not a multiple of %op2 + guard (¬params.exact ∨ (x'.toNat ∣ y'.toNat)) udiv? x' y' def intMin (w : Nat) : BitVec w := @@ -166,25 +209,36 @@ at width 2, -4 / -1 is considered overflow! -- only way overflow can happen is (INT_MIN / -1). -- but we do not consider overflow when `w=1`, because `w=1` only has a sign bit, so there -- is no magniture to overflow. +structure SdivParams where + exact : Bool := false @[simp_llvm] -def sdiv? {w : Nat} (x y : BitVec w) : IntW w := - if y == 0 || (w != 1 && x == (intMin w) && y == -1) - then .none - else pure (BitVec.sdiv x y) +def sdiv? {w : Nat} (x y : BitVec w) : IntW w := do + guard ((y ≠ 0) ∧ (w = 1 ∨ x ≠ (intMin w) ∨ y ≠ -1) ) + BitVec.sdiv x y theorem sdiv?_denom_zero_eq_none {w : Nat} (x : BitVec w) : LLVM.sdiv? x 0 = none := by - simp [LLVM.sdiv?, BitVec.sdiv] + simp + simp [LLVM.sdiv?] + simp [BitVec.sdiv] + simp [guard] + simp [failure] theorem sdiv?_eq_pure_of_neq_allOnes {x y : BitVec w} (hy : y ≠ 0) (hx : LLVM.intMin w ≠ x) : LLVM.sdiv? x y = pure (BitVec.sdiv x y) := by - simp [LLVM.sdiv?] - tauto + simp only [LLVM.sdiv?] + have t : (y ≠ 0 ∧ (w = 1 ∨ ¬x = intMin w ∨ ¬y = -1#w)) := by + tauto + -- simp only [, ne_eq, Option.bind_eq_bind, Option.pure_def] + simp only [ne_eq, t, and_true, Option.bind_eq_bind, Option.pure_def] + simp [t] + simp [guard] @[simp_llvm_option] -def sdiv {w : Nat} (x y : IntW w) : IntW w := do +def sdiv {w : Nat} (x y : IntW w) (params : SdivParams := {}) : IntW w := do let x' ← x let y' ← y + guard (¬ params.exact ∨ (x'.toNat ∣ y'.toNat)) sdiv? x' y' -- Probably not a Mathlib worthy name, not sure how you'd mathlibify the precondition @@ -195,7 +249,13 @@ theorem sdiv?_eq_div_if {w : Nat} {x y : BitVec w} : then none else pure <| BitVec.sdiv x y := by - simp [sdiv?]; split_ifs <;> try tauto + simp [sdiv?, -BitVec.ofNat_eq_ofNat, -ne_eq, guard, Option.bind] + simp + split_ifs + tauto + tauto + tauto + tauto /-- This instruction returns the unsigned integer remainder of a division. This instruction always performs an unsigned division to get the remainder. @@ -282,6 +342,9 @@ The value produced is op1 * 2^op2 mod 2n, where n is the width of the result. If op2 is (statically or dynamically) equal to or larger than the number of bits in op1, this instruction returns a poison value. -/ +structure ShlParams where + nuw : Bool := false + nsw : Bool := false @[simp_llvm] def shl? {n} (op1 : BitVec n) (op2 : BitVec n) : IntW n := let bits := op2.toNat -- should this be toInt? @@ -290,10 +353,12 @@ def shl? {n} (op1 : BitVec n) (op2 : BitVec n) : IntW n := @[simp_llvm_option] -def shl {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - shl? x' y' +def shl {w : Nat} (x y : IntW w) (params : ShlParams := {}): IntW w := do + let x ← x + let y ← y + guard (¬ params.nsw ∨ ¬ ((x.toInt * 2 ^ y.toNat) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w)) + guard (¬ params.nuw ∨ ¬ ((x.toNat * 2 ^ y.toNat) ≥ 2^w)) + shl? x y /-- This instruction always performs a logical shift right operation. From ce2e0d9fa759acb7526e84bddf8f9b5f6af3b2f5 Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Wed, 17 Jul 2024 11:00:58 +0100 Subject: [PATCH 2/8] Make changes --- SSA/Projects/InstCombine/LLVM/Semantics.lean | 95 +++++++++----------- 1 file changed, 41 insertions(+), 54 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index 04f6f025c..1cd03328a 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -34,9 +34,6 @@ def and {w : Nat} (x y : IntW w) : IntW w := do The ‘or’ instruction returns the bitwise logical inclusive or of its two operands. -/ - -structure OrParams where - disjoint : Bool := false @[simp_llvm] def or? {w : Nat} (x y : BitVec w) : IntW w := pure <| x ||| y @@ -45,12 +42,13 @@ def or? {w : Nat} (x y : BitVec w) : IntW w := theorem or?_eq : LLVM.or? a b = .some (BitVec.or a b) := rfl @[simp_llvm_option] -def or {w : Nat} (x y : IntW w) (params : OrParams := {}) : IntW w := do +def or {w : Nat} (x y : IntW w) (disjoint : Bool := false) : IntW w := do let x' ← x let y' ← y - let g ← and x y - guard (¬params.disjoint ∨ g = 0) - or? x' y' + if disjoint ∧ BitVec.toNat ( x' &&& y') = 0 then + .none + else + or? x' y' /-- The ‘xor’ instruction returns the bitwise logical exclusive or of its two @@ -81,7 +79,7 @@ structure AddParams where nsw : Bool := false @[simp_llvm] -def add? {w : Nat} (x y : BitVec w) : IntW w := +def add? {w : Nat} (x y : BitVec w) : IntW w := pure <| x + y @[simp_llvm_option] @@ -91,9 +89,10 @@ theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl def add {w : Nat} (x y : IntW w) (params : AddParams := {}) : IntW w := do let x ← x let y ← y - guard (¬ params.nsw ∨ ¬ ((x.toInt + y.toInt) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w)) - guard (¬ params.nuw ∨ ¬ ((x.toNat + y.toNat) ≥ 2^w)) - add? x y + if (params.nsw ∧ (x.toInt + y.toInt) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w) ∨ ( params.nuw ∧ (x.toNat + y.toNat) ≥ 2^w) then + .none + else + add? x y /-- The value produced is the integer difference of the two operands. @@ -114,15 +113,13 @@ theorem sub?_eq : LLVM.sub? a b = .some (BitVec.sub a b) := rfl def sub {w : Nat} (x y : IntW w) (params : SubParams := {}) : IntW w := do let x ← x let y ← y - -- Check for unsigned overflow if nuw is set - guard (¬params.nuw ∨ x.toNat ≥ y.toNat) - -- Check for signed overflow if nsw is set - guard (¬params.nsw ∨ - ((x.toInt - y.toInt) ≥ -(2^(w-1)) ∧ (x.toInt - y.toInt) < 2^(w-1))) - - sub? x y + if (params.nuw ∧ x.toNat ≥ y.toNat) ∨ (params.nsw ∧ + ((x.toInt - y.toInt) ≥ -(2^(w-1)) ∧ (x.toInt - y.toInt) < 2^(w-1))) then + .none + else + sub? x y /-- The value produced is the integer product of the two operands. @@ -150,17 +147,13 @@ theorem mul?_eq : LLVM.mul? a b = .some (BitVec.mul a b) := rfl def mul {w : Nat} (x y : IntW w) (params : MulParams := {}) : IntW w := do let x ← x let y ← y - - -- Perform the multiplication - let result := x * y - -- Check for unsigned overflow if nuw is set - guard (¬ params.nuw ∨ ¬((x * y).toNat ≤ (2^w))) -- Check for signed overflow if nsw is set - guard (¬ params.nsw ∨ ¬((x * y).toNat < (2^(w - 1)) ∧ (x * y).toNat ≥ (2^(w - 1)) )) - - -- Return the result - result + if (params.nuw ∧ (x * y).toNat ≤ (2^w)) ∨ (params.nsw ∧ (x * y).toNat < (2^(w - 1)) ∧ (x * y).toNat ≥ (2^(w - 1))) then + .none + else + -- Return the result + mul? x y /-- The value produced is the unsigned integer quotient of the two operands. @@ -182,8 +175,10 @@ def udiv {w : Nat} (x y : IntW w) (params : UdivParams := {}) : IntW w := do let x' ← x let y' ← y --- If the exact keyword is present, the result value of the udiv is a poison value if %op1 is not a multiple of %op2 - guard (¬params.exact ∨ (x'.toNat ∣ y'.toNat)) - udiv? x' y' + if params.exact ∧ ¬(x'.toNat ∣ y'.toNat) then + .none + else + udiv? x' y' def intMin (w : Nat) : BitVec w := - BitVec.ofNat w (2^(w - 1)) @@ -212,34 +207,30 @@ at width 2, -4 / -1 is considered overflow! structure SdivParams where exact : Bool := false @[simp_llvm] -def sdiv? {w : Nat} (x y : BitVec w) : IntW w := do - guard ((y ≠ 0) ∧ (w = 1 ∨ x ≠ (intMin w) ∨ y ≠ -1) ) - BitVec.sdiv x y +def sdiv? {w : Nat} (x y : BitVec w) : IntW w := + if y = 0 ∨ (w ≠ 1 ∧ x = (intMin w) ∧ y = -1) + then .none + else pure (BitVec.sdiv x y) theorem sdiv?_denom_zero_eq_none {w : Nat} (x : BitVec w) : LLVM.sdiv? x 0 = none := by simp simp [LLVM.sdiv?] - simp [BitVec.sdiv] - simp [guard] - simp [failure] theorem sdiv?_eq_pure_of_neq_allOnes {x y : BitVec w} (hy : y ≠ 0) (hx : LLVM.intMin w ≠ x) : LLVM.sdiv? x y = pure (BitVec.sdiv x y) := by - simp only [LLVM.sdiv?] - have t : (y ≠ 0 ∧ (w = 1 ∨ ¬x = intMin w ∨ ¬y = -1#w)) := by - tauto - -- simp only [, ne_eq, Option.bind_eq_bind, Option.pure_def] - simp only [ne_eq, t, and_true, Option.bind_eq_bind, Option.pure_def] - simp [t] - simp [guard] + simp [LLVM.sdiv?] + tauto + @[simp_llvm_option] def sdiv {w : Nat} (x y : IntW w) (params : SdivParams := {}) : IntW w := do let x' ← x let y' ← y - guard (¬ params.exact ∨ (x'.toNat ∣ y'.toNat)) - sdiv? x' y' + if (params.exact ∧ ¬ (x'.toNat ∣ y'.toNat)) then + .none + else + sdiv? x' y' -- Probably not a Mathlib worthy name, not sure how you'd mathlibify the precondition @[simp_llvm] @@ -249,13 +240,8 @@ theorem sdiv?_eq_div_if {w : Nat} {x y : BitVec w} : then none else pure <| BitVec.sdiv x y := by - simp [sdiv?, -BitVec.ofNat_eq_ofNat, -ne_eq, guard, Option.bind] - simp - split_ifs - tauto - tauto - tauto - tauto + simp [sdiv?] + /-- This instruction returns the unsigned integer remainder of a division. This instruction always performs an unsigned division to get the remainder. @@ -356,9 +342,10 @@ def shl? {n} (op1 : BitVec n) (op2 : BitVec n) : IntW n := def shl {w : Nat} (x y : IntW w) (params : ShlParams := {}): IntW w := do let x ← x let y ← y - guard (¬ params.nsw ∨ ¬ ((x.toInt * 2 ^ y.toNat) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w)) - guard (¬ params.nuw ∨ ¬ ((x.toNat * 2 ^ y.toNat) ≥ 2^w)) - shl? x y + if (params.nsw ∧ ¬ ((x.toInt * 2 ^ y.toNat) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w)) ∨ (params.nuw ∧ ¬ ((x.toNat * 2 ^ y.toNat) ≥ 2^w)) then + .none + else + shl? x y /-- This instruction always performs a logical shift right operation. From 5730a608dfd419c0875ffcbbc3f391a3203eaf9c Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Wed, 17 Jul 2024 15:45:29 +0100 Subject: [PATCH 3/8] change syntax parser --- SSA/Core/MLIRSyntax/GenericParser.lean | 14 +++++++++++--- SSA/Projects/InstCombine/Base.lean | 18 +++++++++++------- SSA/Projects/InstCombine/LLVM/EDSL.lean | 15 +++++++++++++-- SSA/Projects/InstCombine/LLVM/Semantics.lean | 2 +- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/SSA/Core/MLIRSyntax/GenericParser.lean b/SSA/Core/MLIRSyntax/GenericParser.lean index 3d2a712f5..682c23ac5 100644 --- a/SSA/Core/MLIRSyntax/GenericParser.lean +++ b/SSA/Core/MLIRSyntax/GenericParser.lean @@ -535,8 +535,13 @@ syntax "@" ident : mlir_attr_val_symbol syntax "@" str : mlir_attr_val_symbol syntax "#" ident : mlir_attr_val -- alias syntax "#" strLit : mlir_attr_val -- aliass - -syntax "#" ident "<" strLit ">" : mlir_attr_val -- opaqueAttr +declare_syntax_cat dialect_attribute_contents +syntax mlir_attr_val : dialect_attribute_contents +syntax "(" dialect_attribute_contents + ")" : dialect_attribute_contents +syntax "[" dialect_attribute_contents + "]": dialect_attribute_contents +syntax "{" dialect_attribute_contents + "}": dialect_attribute_contents +-- syntax [^\[<({\]>)}\0]+ : dialect_attribute_contents +syntax "#" ident "<" dialect_attribute_contents ">" : mlir_attr_val -- opaqueAttr syntax "#opaque<" ident "," strLit ">" ":" mlir_type : mlir_attr_val -- opaqueElementsAttr syntax mlir_attr_val_symbol "::" mlir_attr_val_symbol : mlir_attr_val_symbol @@ -582,7 +587,10 @@ macro_rules | `([mlir_attr_val| # $dialect:ident < $opaqueData:str > ]) => do let dialect := Lean.quote dialect.getId.toString `(AttrValue.opaque_ $dialect $opaqueData) - +| `([mlir_attr_val| # $dialect:ident < $opaqueData:ident > ]) => do + let d := Lean.quote dialect.getId.toString + let g : TSyntax `str := Lean.Syntax.mkStrLit (toString opaqueData.getId) + `(AttrValue.opaque_ $d $g) macro_rules | `([mlir_attr_val| #opaque< $dialect:ident, $opaqueData:str> : $t:mlir_type ]) => do let dialect := Lean.quote dialect.getId.toString diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index 7e26aca86..c3e3245b3 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -96,7 +96,7 @@ inductive MOp.BinaryOp : Type | ashr | urem | srem - | add + | add (nsw : Bool) (nuw : Bool) | mul | sub | sdiv @@ -127,7 +127,7 @@ namespace MOp @[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr @[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem @[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem -@[match_pattern] def add (w : Width φ) : MOp φ := .binary w .add +@[match_pattern] def add (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.add nsw nuw) @[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul @[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub @[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv @@ -146,7 +146,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} (ashr : ∀ {φ} {w : Width φ}, motive (ashr w)) (urem : ∀ {φ} {w : Width φ}, motive (urem w)) (srem : ∀ {φ} {w : Width φ}, motive (srem w)) - (add : ∀ {φ} {w : Width φ}, motive (add w)) + (add : ∀ {φ nsw nuw} {w : Width φ}, motive (add nsw nuw w)) (mul : ∀ {φ} {w : Width φ}, motive (mul w)) (sub : ∀ {φ} {w : Width φ}, motive (sub w)) (sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w)) @@ -166,7 +166,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} | _, .ashr _ => ashr | _, .urem _ => urem | _, .srem _ => srem - | _, .add _ => add + | n, .add nsw nuw w => @add n nsw nuw w | _, .mul _ => mul | _, .sub _ => sub | _, .sdiv _ => sdiv @@ -189,7 +189,7 @@ instance : ToString (MOp φ) where | .urem _ => "urem" | .srem _ => "srem" | .select _ => "select" - | .add _ => "add" + | .add _ _ _ => "add" | .mul _ => "mul" | .sub _ => "sub" | .neg _ => "neg" @@ -216,7 +216,7 @@ namespace Op @[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete @[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete @[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete -@[match_pattern] abbrev add : Nat → Op := MOp.add ∘ .concrete +@[match_pattern] abbrev add (nuw : Bool := false) (nsw : Bool := false) : Nat → Op := (MOp.add nsw nuw) ∘ .concrete @[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete @[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete @[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete @@ -275,7 +275,10 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig | Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1) | Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1) | Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1) - | Op.add _ => LLVM.add (op.getN 0) (op.getN 1) + -- | Op.add _ => LLVM.add (op.getN 0) (op.getN 1) + | Op.add a b _ => LLVM.add (op.getN 0) (op.getN 1) {nsw := a, nuw := b} + + -- | (@MOp.binary (ConcreteOrMVar.concrete _) (@MOp.BinaryOp.add true true)), _ => sorry | Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1) | Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1) | Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1) @@ -283,6 +286,7 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig | Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1) | Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1) | Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2) + -- | (@MOp.binary (ConcreteOrMVar.concrete _) (@MOp.BinaryOp.add true false)), _ => sorry instance : DialectDenote LLVM := ⟨ fun o args _ => Op.denote o args diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 7303e242e..9dbcded19 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -97,7 +97,7 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : else let v₂ : Γ.Var (.bitvec w) := (by simpa using hty) ▸ v₂ - let (op : MOp.BinaryOp ⊕ LLVM.IntPredicate) ← match opStx.name with + let (op : (MOp.BinaryOp) ⊕ LLVM.IntPredicate) ← match opStx.name with | "llvm.and" => pure <| Sum.inl .and | "llvm.or" => pure <| Sum.inl .or | "llvm.xor" => pure <| Sum.inl .xor @@ -106,7 +106,18 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : | "llvm.ashr" => pure <| Sum.inl .ashr | "llvm.urem" => pure <| Sum.inl .urem | "llvm.srem" => pure <| Sum.inl .srem - | "llvm.add" => pure <| Sum.inl .add + | "llvm.add" => do + -- sorry + let att := opStx.attrs.getAttr "overflowFlags" + match att with + | .none => pure <| Sum.inl (MOp.BinaryOp.add false false) + | .some y => match y with + | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.add true false) + | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.add false true) + | (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed" + | _ => throw <| .generic s!"flag not allowed" + -- sorry + | "llvm.mul" => pure <| Sum.inl .mul | "llvm.sub" => pure <| Sum.inl .sub | "llvm.sdiv" => pure <| Sum.inl .sdiv diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index 1cd03328a..ceb202165 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -45,7 +45,7 @@ theorem or?_eq : LLVM.or? a b = .some (BitVec.or a b) := rfl def or {w : Nat} (x y : IntW w) (disjoint : Bool := false) : IntW w := do let x' ← x let y' ← y - if disjoint ∧ BitVec.toNat ( x' &&& y') = 0 then + if disjoint ∧ x' &&& y' ≠ 0 then .none else or? x' y' From 3ad683aea68f765b763ec1ae2dc517126c719d44 Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Wed, 17 Jul 2024 17:33:11 +0100 Subject: [PATCH 4/8] add all overflow flags --- .gitattributes | 1 + SSA/Projects/InstCombine/Base.lean | 113 ++++++++++++------------ SSA/Projects/InstCombine/LLVM/EDSL.lean | 36 +++++--- 3 files changed, 82 insertions(+), 68 deletions(-) diff --git a/.gitattributes b/.gitattributes index c20182ffa..290fd67bd 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ SSA/Projects/InstCombine/tests/LLVM/** linguist-generated=true SSA/Projects/InstCombine/all.lean linguist-generated=true +test/LLVMDialect/InstCombine/** linguist-generated=true \ No newline at end of file diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index c3e3245b3..f968bf5d8 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -89,18 +89,18 @@ deriving Repr, DecidableEq, Inhabited /-- Homogeneous, binary operations -/ inductive MOp.BinaryOp : Type | and - | or + | or (disjoint : Bool) | xor - | shl - | lshr - | ashr + | shl (nsw : Bool) (nuw : Bool) + | lshr (exact : Bool) + | ashr (exact : Bool) | urem | srem | add (nsw : Bool) (nuw : Bool) - | mul - | sub - | sdiv - | udiv + | mul (nsw : Bool) (nuw : Bool) + | sub (nsw : Bool) (nuw : Bool) + | sdiv (exact : Bool) + | udiv (exact : Bool) deriving Repr, DecidableEq, Inhabited -- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations @@ -120,18 +120,18 @@ namespace MOp @[match_pattern] def copy (w : Width φ) : MOp φ := .unary w .copy @[match_pattern] def and (w : Width φ) : MOp φ := .binary w .and -@[match_pattern] def or (w : Width φ) : MOp φ := .binary w .or +@[match_pattern] def or (disjoint : Bool) (w : Width φ) : MOp φ := .binary w (.or disjoint) @[match_pattern] def xor (w : Width φ) : MOp φ := .binary w .xor -@[match_pattern] def shl (w : Width φ) : MOp φ := .binary w .shl -@[match_pattern] def lshr (w : Width φ) : MOp φ := .binary w .lshr -@[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr +@[match_pattern] def shl (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.shl nsw nuw) +@[match_pattern] def lshr (exact : Bool) (w : Width φ) : MOp φ := .binary w (.lshr exact) +@[match_pattern] def ashr (exact : Bool) (w : Width φ) : MOp φ := .binary w (.ashr exact) @[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem @[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem @[match_pattern] def add (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.add nsw nuw) -@[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul -@[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub -@[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv -@[match_pattern] def udiv (w : Width φ) : MOp φ := .binary w .udiv +@[match_pattern] def mul (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.mul nsw nuw) +@[match_pattern] def sub (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.sub nsw nuw) +@[match_pattern] def sdiv (exact : Bool) (w : Width φ) : MOp φ := .binary w (.sdiv exact) +@[match_pattern] def udiv (exact : Bool) (w : Width φ) : MOp φ := .binary w (.udiv exact) /-- Recursion principle in terms of individual operations, rather than `unary` or `binary` -/ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} @@ -139,18 +139,18 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} (not : ∀ {φ} {w : Width φ}, motive (not w)) (copy : ∀ {φ} {w : Width φ}, motive (copy w)) (and : ∀ {φ} {w : Width φ}, motive (and w)) - (or : ∀ {φ} {w : Width φ}, motive (or w)) + (or : ∀ {φ disjoint} {w : Width φ}, motive (or disjoint w)) (xor : ∀ {φ} {w : Width φ}, motive (xor w)) - (shl : ∀ {φ} {w : Width φ}, motive (shl w)) - (lshr : ∀ {φ} {w : Width φ}, motive (lshr w)) - (ashr : ∀ {φ} {w : Width φ}, motive (ashr w)) + (shl : ∀ {φ nsw nuw} {w : Width φ}, motive (shl nsw nuw w)) + (lshr : ∀ {φ exact } {w : Width φ}, motive (lshr exact w)) + (ashr : ∀ {φ exact } {w : Width φ}, motive (ashr exact w)) (urem : ∀ {φ} {w : Width φ}, motive (urem w)) (srem : ∀ {φ} {w : Width φ}, motive (srem w)) (add : ∀ {φ nsw nuw} {w : Width φ}, motive (add nsw nuw w)) - (mul : ∀ {φ} {w : Width φ}, motive (mul w)) - (sub : ∀ {φ} {w : Width φ}, motive (sub w)) - (sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w)) - (udiv : ∀ {φ} {w : Width φ}, motive (udiv w)) + (mul : ∀ {φ nsw nuw} {w : Width φ}, motive (mul nsw nuw w)) + (sub : ∀ {φ nsw nuw} {w : Width φ}, motive (sub nsw nuw w)) + (sdiv : ∀ {φ exact } {w : Width φ}, motive (sdiv exact w)) + (udiv : ∀ {φ exact } {w : Width φ}, motive (udiv exact w)) (select : ∀ {φ} {w : Width φ}, motive (select w)) (icmp : ∀ {φ c} {w : Width φ}, motive (icmp c w)) (const : ∀ {φ v} {w : Width φ}, motive (const w v)) : @@ -159,18 +159,18 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} | _, .not _ => not | _, .copy _ => copy | _, .and _ => and - | _, .or _ => or + | _, .or _ _ => or | _, .xor _ => xor - | _, .shl _ => shl - | _, .lshr _ => lshr - | _, .ashr _ => ashr + | _, .shl _ _ _ => shl + | _, .lshr _ _ => lshr + | _, .ashr _ _ => ashr | _, .urem _ => urem | _, .srem _ => srem | n, .add nsw nuw w => @add n nsw nuw w - | _, .mul _ => mul - | _, .sub _ => sub - | _, .sdiv _ => sdiv - | _, .udiv _ => udiv + | _, .mul _ _ _ => mul + | _, .sub _ _ _ => sub + | _, .sdiv _ _ => sdiv + | _, .udiv _ _ => udiv | _, .select _ => select | _, .icmp .. => icmp | _, .const .. => const @@ -180,22 +180,22 @@ end MOp instance : ToString (MOp φ) where toString | .and _ => "and" - | .or _ => "or" + | .or _ _ => "or" | .not _ => "not" | .xor _ => "xor" - | .shl _ => "shl" - | .lshr _ => "lshr" - | .ashr _ => "ashr" + | .shl _ _ _ => "shl" + | .lshr _ _ => "lshr" + | .ashr _ _ => "ashr" | .urem _ => "urem" | .srem _ => "srem" | .select _ => "select" | .add _ _ _ => "add" - | .mul _ => "mul" - | .sub _ => "sub" + | .mul _ _ _ => "mul" + | .sub _ _ _ => "sub" | .neg _ => "neg" | .copy _ => "copy" - | .sdiv _ => "sdiv" - | .udiv _ => "udiv" + | .sdiv _ _ => "sdiv" + | .udiv _ _ => "udiv" | .icmp ty _ => s!"icmp {ty}" | .const _ v => s!"const {v}" @@ -207,22 +207,22 @@ namespace Op @[match_pattern] abbrev binary (w : Nat) (op : MOp.BinaryOp) : Op := MOp.binary (.concrete w) op @[match_pattern] abbrev and : Nat → Op := MOp.and ∘ .concrete -@[match_pattern] abbrev or : Nat → Op := MOp.or ∘ .concrete +@[match_pattern] abbrev or (disjoint : Bool) : Nat → Op := MOp.or disjoint ∘ .concrete @[match_pattern] abbrev not : Nat → Op := MOp.not ∘ .concrete @[match_pattern] abbrev xor : Nat → Op := MOp.xor ∘ .concrete -@[match_pattern] abbrev shl : Nat → Op := MOp.shl ∘ .concrete -@[match_pattern] abbrev lshr : Nat → Op := MOp.lshr ∘ .concrete -@[match_pattern] abbrev ashr : Nat → Op := MOp.ashr ∘ .concrete +@[match_pattern] abbrev shl (nsw nuw : Bool) : Nat → Op := MOp.shl nsw nuw ∘ .concrete +@[match_pattern] abbrev lshr (exact : Bool) : Nat → Op := MOp.lshr exact ∘ .concrete +@[match_pattern] abbrev ashr (exact : Bool) : Nat → Op := MOp.ashr exact ∘ .concrete @[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete @[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete @[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete @[match_pattern] abbrev add (nuw : Bool := false) (nsw : Bool := false) : Nat → Op := (MOp.add nsw nuw) ∘ .concrete -@[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete -@[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete +@[match_pattern] abbrev mul (nsw nuw : Bool) : Nat → Op := MOp.mul nsw nuw ∘ .concrete +@[match_pattern] abbrev sub (nsw nuw : Bool) : Nat → Op := MOp.sub nsw nuw ∘ .concrete @[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete @[match_pattern] abbrev copy : Nat → Op := MOp.copy ∘ .concrete -@[match_pattern] abbrev sdiv : Nat → Op := MOp.sdiv ∘ .concrete -@[match_pattern] abbrev udiv : Nat → Op := MOp.udiv ∘ .concrete +@[match_pattern] abbrev sdiv (exact : Bool) : Nat → Op := MOp.sdiv exact ∘ .concrete +@[match_pattern] abbrev udiv (exact : Bool) : Nat → Op := MOp.udiv exact ∘ .concrete @[match_pattern] abbrev icmp (c : IntPredicate) : Nat → Op := MOp.icmp c ∘ .concrete @[match_pattern] abbrev const (w : Nat) (val : ℤ) : Op := MOp.const (.concrete w) val @@ -269,19 +269,18 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig | Op.not _ => LLVM.not (op.getN 0) | Op.neg _ => LLVM.neg (op.getN 0) | Op.and _ => LLVM.and (op.getN 0) (op.getN 1) - | Op.or _ => LLVM.or (op.getN 0) (op.getN 1) + | Op.or d _ => LLVM.or (op.getN 0) (op.getN 1) d | Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1) - | Op.shl _ => LLVM.shl (op.getN 0) (op.getN 1) - | Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1) - | Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1) - | Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1) - -- | Op.add _ => LLVM.add (op.getN 0) (op.getN 1) + | Op.shl nsw nuw _ => LLVM.shl (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw} + | Op.lshr e _ => LLVM.lshr (op.getN 0) (op.getN 1) + | Op.ashr e _ => LLVM.ashr (op.getN 0) (op.getN 1) + | Op.sub nsw nuw _ => LLVM.sub (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw} | Op.add a b _ => LLVM.add (op.getN 0) (op.getN 1) {nsw := a, nuw := b} -- | (@MOp.binary (ConcreteOrMVar.concrete _) (@MOp.BinaryOp.add true true)), _ => sorry - | Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1) - | Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1) - | Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1) + | Op.mul nsw nuw _ => LLVM.mul (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw} + | Op.sdiv e _ => LLVM.sdiv (op.getN 0) (op.getN 1) + | Op.udiv e _ => LLVM.udiv (op.getN 0) (op.getN 1) | Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1) | Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1) | Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 9dbcded19..89d76ad59 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -99,15 +99,15 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : let (op : (MOp.BinaryOp) ⊕ LLVM.IntPredicate) ← match opStx.name with | "llvm.and" => pure <| Sum.inl .and - | "llvm.or" => pure <| Sum.inl .or + -- we do nothing, MLIR does not support this syntax, I don't think + | "llvm.or" => pure <| Sum.inl (.or false) | "llvm.xor" => pure <| Sum.inl .xor - | "llvm.shl" => pure <| Sum.inl .shl - | "llvm.lshr" => pure <| Sum.inl .lshr - | "llvm.ashr" => pure <| Sum.inl .ashr + | "llvm.shl" => pure <| Sum.inl (.shl false false) + | "llvm.lshr" => pure <| Sum.inl (.lshr false) + | "llvm.ashr" => pure <| Sum.inl (.ashr false) | "llvm.urem" => pure <| Sum.inl .urem | "llvm.srem" => pure <| Sum.inl .srem | "llvm.add" => do - -- sorry let att := opStx.attrs.getAttr "overflowFlags" match att with | .none => pure <| Sum.inl (MOp.BinaryOp.add false false) @@ -116,12 +116,26 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.add false true) | (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed" | _ => throw <| .generic s!"flag not allowed" - -- sorry - - | "llvm.mul" => pure <| Sum.inl .mul - | "llvm.sub" => pure <| Sum.inl .sub - | "llvm.sdiv" => pure <| Sum.inl .sdiv - | "llvm.udiv" => pure <| Sum.inl .udiv + | "llvm.mul" =>do + let att := opStx.attrs.getAttr "overflowFlags" + match att with + | .none => pure <| Sum.inl (MOp.BinaryOp.add false false) + | .some y => match y with + | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.mul true false) + | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.mul false true) + | (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed" + | _ => throw <| .generic s!"flag not allowed" + | "llvm.sub" =>do + let att := opStx.attrs.getAttr "overflowFlags" + match att with + | .none => pure <| Sum.inl (MOp.BinaryOp.add false false) + | .some y => match y with + | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.sub true false) + | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.sub false true) + | (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed" + | _ => throw <| .generic s!"flag not allowed" + | "llvm.sdiv" => pure <| Sum.inl (.sdiv false) + | "llvm.udiv" => pure <| Sum.inl (.udiv false) | "llvm.icmp.eq" => pure <| Sum.inr LLVM.IntPredicate.eq | "llvm.icmp.ne" => pure <| Sum.inr LLVM.IntPredicate.ne | "llvm.icmp.ugt" => pure <| Sum.inr LLVM.IntPredicate.ugt From 7492a71b8b6eb95555726ed4627a2808e5629623 Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Thu, 18 Jul 2024 11:14:39 +0100 Subject: [PATCH 5/8] fix a bug --- .../AliveHandwrittenLargeExamples.lean | 3 +++ SSA/Projects/InstCombine/Base.lean | 20 +++++++++---------- SSA/Projects/InstCombine/ComWrappers.lean | 18 ++++++++--------- SSA/Projects/InstCombine/LLVM/EDSL.lean | 4 ++-- SSA/Projects/InstCombine/LLVM/Semantics.lean | 19 +++++++++++++----- 5 files changed, 38 insertions(+), 26 deletions(-) diff --git a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean index 0cf7a1f99..074108b1e 100644 --- a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean +++ b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean @@ -4,6 +4,7 @@ import SSA.Projects.InstCombine.LLVM.PrettyEDSL import SSA.Projects.InstCombine.Tactic import SSA.Projects.InstCombine.TacticAuto import SSA.Projects.InstCombine.ComWrappers +import SSA.Projects.InstCombine.LLVM.Semantics import Mathlib.Tactic open BitVec @@ -134,6 +135,7 @@ def alive_simplifyMulDivRem805 (w : Nat) : rw [LLVM.sdiv?_denom_zero_eq_none] apply Refinement.none_left case neg => + simp rw [BitVec.ult_toNat] rw [BitVec.toNat_ofNat] cases w' @@ -227,6 +229,7 @@ def alive_simplifyMulDivRem805' (w : Nat) : unfold MulDivRem805_lhs MulDivRem805_rhs simp only [simp_llvm_wrap] simp_alive_ssa + simp only [LLVM.add_reduce] simp_alive_undef simp_alive_case_bash simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq] diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index f968bf5d8..3b4bcfee5 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -120,18 +120,18 @@ namespace MOp @[match_pattern] def copy (w : Width φ) : MOp φ := .unary w .copy @[match_pattern] def and (w : Width φ) : MOp φ := .binary w .and -@[match_pattern] def or (disjoint : Bool) (w : Width φ) : MOp φ := .binary w (.or disjoint) +@[match_pattern] def or (disjoint : Bool := false) (w : Width φ) : MOp φ := .binary w (.or disjoint) @[match_pattern] def xor (w : Width φ) : MOp φ := .binary w .xor @[match_pattern] def shl (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.shl nsw nuw) -@[match_pattern] def lshr (exact : Bool) (w : Width φ) : MOp φ := .binary w (.lshr exact) -@[match_pattern] def ashr (exact : Bool) (w : Width φ) : MOp φ := .binary w (.ashr exact) +@[match_pattern] def lshr (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.lshr exact) +@[match_pattern] def ashr (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.ashr exact) @[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem @[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem -@[match_pattern] def add (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.add nsw nuw) -@[match_pattern] def mul (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.mul nsw nuw) -@[match_pattern] def sub (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.sub nsw nuw) -@[match_pattern] def sdiv (exact : Bool) (w : Width φ) : MOp φ := .binary w (.sdiv exact) -@[match_pattern] def udiv (exact : Bool) (w : Width φ) : MOp φ := .binary w (.udiv exact) +@[match_pattern] def add (nsw : Bool:= false) (nuw : Bool:= false) (w : Width φ) : MOp φ := .binary w (.add nsw nuw) +@[match_pattern] def mul (nsw : Bool:= false) (nuw : Bool:= false) (w : Width φ) : MOp φ := .binary w (.mul nsw nuw) +@[match_pattern] def sub (nsw : Bool:= false) (nuw : Bool:= false) (w : Width φ) : MOp φ := .binary w (.sub nsw nuw) +@[match_pattern] def sdiv (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.sdiv exact) +@[match_pattern] def udiv (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.udiv exact) /-- Recursion principle in terms of individual operations, rather than `unary` or `binary` -/ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} @@ -279,8 +279,8 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig -- | (@MOp.binary (ConcreteOrMVar.concrete _) (@MOp.BinaryOp.add true true)), _ => sorry | Op.mul nsw nuw _ => LLVM.mul (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw} - | Op.sdiv e _ => LLVM.sdiv (op.getN 0) (op.getN 1) - | Op.udiv e _ => LLVM.udiv (op.getN 0) (op.getN 1) + | Op.sdiv e _ => LLVM.sdiv (op.getN 0) (op.getN 1) {exact := e} + | Op.udiv e _ => LLVM.udiv (op.getN 0) (op.getN 1) {exact := e} | Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1) | Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1) | Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1) diff --git a/SSA/Projects/InstCombine/ComWrappers.lean b/SSA/Projects/InstCombine/ComWrappers.lean index 2f6dc4932..f7e5d0eef 100644 --- a/SSA/Projects/InstCombine/ComWrappers.lean +++ b/SSA/Projects/InstCombine/ComWrappers.lean @@ -65,7 +65,7 @@ def or {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.or w) + (op := InstCombine.MOp.or false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -93,7 +93,7 @@ def shl {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.shl w) + (op := InstCombine.MOp.shl false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -107,7 +107,7 @@ def lshr {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.lshr w) + (op := InstCombine.MOp.lshr false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -121,7 +121,7 @@ def ashr {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.ashr w) + (op := InstCombine.MOp.ashr false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -135,7 +135,7 @@ def sub {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.sub w) + (op := InstCombine.MOp.sub false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -149,7 +149,7 @@ def add {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.add w) + (op := InstCombine.MOp.add false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -163,7 +163,7 @@ def mul {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.mul w) + (op := InstCombine.MOp.mul false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -177,7 +177,7 @@ def sdiv {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.sdiv w) + (op := InstCombine.MOp.sdiv false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -191,7 +191,7 @@ def udiv {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.udiv w) + (op := InstCombine.MOp.udiv false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 89d76ad59..e303e65b0 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -119,7 +119,7 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : | "llvm.mul" =>do let att := opStx.attrs.getAttr "overflowFlags" match att with - | .none => pure <| Sum.inl (MOp.BinaryOp.add false false) + | .none => pure <| Sum.inl (MOp.BinaryOp.mul false false) | .some y => match y with | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.mul true false) | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.mul false true) @@ -128,7 +128,7 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : | "llvm.sub" =>do let att := opStx.attrs.getAttr "overflowFlags" match att with - | .none => pure <| Sum.inl (MOp.BinaryOp.add false false) + | .none => pure <| Sum.inl (MOp.BinaryOp.sub false false) | .some y => match y with | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.sub true false) | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.sub false true) diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index ceb202165..6575c369e 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -89,11 +89,20 @@ theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl def add {w : Nat} (x y : IntW w) (params : AddParams := {}) : IntW w := do let x ← x let y ← y - if (params.nsw ∧ (x.toInt + y.toInt) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w) ∨ ( params.nuw ∧ (x.toNat + y.toNat) ≥ 2^w) then + if (params.nsw ∧ (x.toInt + y.toInt) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w) ∨ ( params.nuw ∧ (x.toNat + y.toNat) ≥ 2^w) then .none else add? x y +set_option allowUnsafeReducibility true +@[simp, reducible] +theorem add_reduce (x y : IntW w) : add x y = match x , y with + | .none , _ => .none + | _ , .none => none + | .some a , .some b => .some (a + b) := by + rcases x + all_goals (cases y) + all_goals (try simp ; try rfl) /-- The value produced is the integer difference of the two operands. If the difference has unsigned overflow, the result returned is the mathematical result modulo 2n, where n is the bit width of the result. @@ -227,10 +236,10 @@ theorem sdiv?_eq_pure_of_neq_allOnes {x y : BitVec w} (hy : y ≠ 0) def sdiv {w : Nat} (x y : IntW w) (params : SdivParams := {}) : IntW w := do let x' ← x let y' ← y - if (params.exact ∧ ¬ (x'.toNat ∣ y'.toNat)) then - .none - else - sdiv? x' y' + -- if (params.exact ∧ ¬ (x'.toNat ∣ y'.toNat)) then + -- .none + -- else + sdiv? x' y' -- Probably not a Mathlib worthy name, not sure how you'd mathlibify the precondition @[simp_llvm] From 18b38fcba85208baf6f8ee9bacbd99d12a5dae27 Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Thu, 18 Jul 2024 11:30:27 +0100 Subject: [PATCH 6/8] fix autogenerated messages --- SSA/Projects/InstCombine/Test.lean | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/SSA/Projects/InstCombine/Test.lean b/SSA/Projects/InstCombine/Test.lean index 3bd8c7a53..7c8a52dc2 100644 --- a/SSA/Projects/InstCombine/Test.lean +++ b/SSA/Projects/InstCombine/Test.lean @@ -102,10 +102,9 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.const (ConcreteOrMVa /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.ashr) (%0, %2) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.ashr false) (%0, %2) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 3) op2 ["1", "0", "arg0"] - /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) @@ -113,10 +112,11 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary -/ #guard_msgs in #eval mkExpr (Γn 4) op3 ["2", "1", "0", "arg0"] + /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.add) (%4, %3) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.add false false) (%4, %3) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 5) op4 ["3", "2", "1", "0", "arg0"] @@ -151,7 +151,7 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.const (ConcreteOrMVa /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.ashr) (%0, %2) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.ashr false) (%0, %2) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 3) (ops.get! 2) ["1", "0", "arg0"] @@ -165,7 +165,7 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.add) (%4, %3) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.add false false) (%4, %3) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 5) (ops.get! 4) ["3", "2", "1", "0", "arg0"] @@ -186,13 +186,13 @@ info: ⟨[MTy.bitvec (ConcreteOrMVar.concrete 32)], Com.var (Expr.mk (MOp.const (ConcreteOrMVar.concrete 32) (Int.ofNat 8)) ⋯ ⋯ HVector.nil HVector.nil) (Com.var (Expr.mk (MOp.const (ConcreteOrMVar.concrete 32) (Int.ofNat 31)) ⋯ ⋯ HVector.nil HVector.nil) (Com.var - (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) MOp.BinaryOp.ashr) ⋯ ⋯ (⟨2, ⋯⟩::ₕ(⟨0, ⋯⟩::ₕHVector.nil)) - HVector.nil) + (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) (MOp.BinaryOp.ashr false)) ⋯ ⋯ + (⟨2, ⋯⟩::ₕ(⟨0, ⋯⟩::ₕHVector.nil)) HVector.nil) (Com.var (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) MOp.BinaryOp.and) ⋯ ⋯ (⟨0, ⋯⟩::ₕ(⟨2, ⋯⟩::ₕHVector.nil)) HVector.nil) (Com.var - (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) MOp.BinaryOp.add) ⋯ ⋯ + (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) (MOp.BinaryOp.add false false)) ⋯ ⋯ (⟨0, ⋯⟩::ₕ(⟨1, ⋯⟩::ₕHVector.nil)) HVector.nil) (Com.ret ⟨0, ⋯⟩)))))⟩⟩⟩ -/ From a203c33c56dfdc09fa66057fa8db339db95af494 Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Thu, 18 Jul 2024 11:44:53 +0100 Subject: [PATCH 7/8] please pass the github actions --- .../AliveHandwrittenLargeExamples.lean | 1 + SSA/Projects/InstCombine/LLVM/Semantics.lean | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean index 074108b1e..5b84a8d8f 100644 --- a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean +++ b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean @@ -230,6 +230,7 @@ def alive_simplifyMulDivRem805' (w : Nat) : simp only [simp_llvm_wrap] simp_alive_ssa simp only [LLVM.add_reduce] + simp only [LLVM.sdiv_reduce] simp_alive_undef simp_alive_case_bash simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq] diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index 6575c369e..4fff86b42 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -236,11 +236,18 @@ theorem sdiv?_eq_pure_of_neq_allOnes {x y : BitVec w} (hy : y ≠ 0) def sdiv {w : Nat} (x y : IntW w) (params : SdivParams := {}) : IntW w := do let x' ← x let y' ← y - -- if (params.exact ∧ ¬ (x'.toNat ∣ y'.toNat)) then - -- .none - -- else - sdiv? x' y' - + if (params.exact ∧ ¬ (x'.toNat ∣ y'.toNat)) then + .none + else + sdiv? x' y' +@[simp, reducible] +theorem sdiv_reduce (x y : IntW w) : sdiv x y = match x , y with + | .none , _ => .none + | _ , .none => none + | .some a , .some b => sdiv? a b := by + cases x + all_goals (cases y) + all_goals (try simp ; try rfl) -- Probably not a Mathlib worthy name, not sure how you'd mathlibify the precondition @[simp_llvm] theorem sdiv?_eq_div_if {w : Nat} {x y : BitVec w} : From 65fd1d50437edd263422f65d9559815713c73e87 Mon Sep 17 00:00:00 2001 From: Atticus Kuhn Date: Thu, 18 Jul 2024 12:43:25 +0100 Subject: [PATCH 8/8] please work please work --- SSA/Projects/InstCombine/LLVM/Enumerator.lean | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/SSA/Projects/InstCombine/LLVM/Enumerator.lean b/SSA/Projects/InstCombine/LLVM/Enumerator.lean index 74afaf8bb..d28c3bd1e 100644 --- a/SSA/Projects/InstCombine/LLVM/Enumerator.lean +++ b/SSA/Projects/InstCombine/LLVM/Enumerator.lean @@ -182,18 +182,18 @@ def generateRawSemantics : IO Unit := do rows := rows.append (icmpRows pred) -- rows := rows.append (binopRows "and" (fun w a b => InstCombine.Op.denote (.and w) [a,b]ₕ)) - rows := rows.append (binopRows "or" (fun w a b => InstCombine.Op.denote (.or w) [a,b]ₕ)) + rows := rows.append (binopRows "or" (fun w a b => InstCombine.Op.denote (.or false w) [a,b]ₕ)) rows := rows.append (binopRows "xor" (fun w a b => InstCombine.Op.denote (.xor w) [a,b]ₕ)) - rows := rows.append (binopRows "add" (fun w a b => InstCombine.Op.denote (.add w) [a,b]ₕ)) - rows := rows.append (binopRows "sub" (fun w a b => InstCombine.Op.denote (.sub w) [a,b]ₕ)) - rows := rows.append (binopRows "mul" (fun w a b => InstCombine.Op.denote (.mul w) [a,b]ₕ)) - rows := rows.append (binopRows "udiv" (fun w a b => InstCombine.Op.denote (.udiv w) [a,b]ₕ)) - rows := rows.append (binopRows "sdiv" (fun w a b => InstCombine.Op.denote (.sdiv w) [a,b]ₕ)) + rows := rows.append (binopRows "add" (fun w a b => InstCombine.Op.denote (.add false false w) [a,b]ₕ)) + rows := rows.append (binopRows "sub" (fun w a b => InstCombine.Op.denote (.sub false false w) [a,b]ₕ)) + rows := rows.append (binopRows "mul" (fun w a b => InstCombine.Op.denote (.mul false false w) [a,b]ₕ)) + rows := rows.append (binopRows "udiv" (fun w a b => InstCombine.Op.denote (.udiv false w) [a,b]ₕ)) + rows := rows.append (binopRows "sdiv" (fun w a b => InstCombine.Op.denote (.sdiv false w) [a,b]ₕ)) rows := rows.append (binopRows "urem" (fun w a b => InstCombine.Op.denote (.urem w) [a,b]ₕ)) rows := rows.append (binopRows "srem" (fun w a b => InstCombine.Op.denote (.srem w) [a,b]ₕ)) - rows := rows.append (binopRows "shl" (fun w a b => InstCombine.Op.denote (.shl w) [a,b]ₕ)) - rows := rows.append (binopRows "lshr" (fun w a b => InstCombine.Op.denote (.lshr w) [a,b]ₕ)) - rows := rows.append (binopRows "ashr" (fun w a b => InstCombine.Op.denote (.ashr w) [a,b]ₕ)) + rows := rows.append (binopRows "shl" (fun w a b => InstCombine.Op.denote (.shl false false w) [a,b]ₕ)) + rows := rows.append (binopRows "lshr" (fun w a b => InstCombine.Op.denote (.lshr false w) [a,b]ₕ)) + rows := rows.append (binopRows "ashr" (fun w a b => InstCombine.Op.denote (.ashr false w) [a,b]ₕ)) rows.toList |>.map toString |> "\n".intercalate |> stream.putStr return ()