Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add Overflow Semantics to Addition in LLVM #480

Merged
merged 24 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions SSA/Core/MLIRSyntax/GenericParser.lean
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,37 @@ declare_syntax_cat mlir_attr_val_symbol
syntax "@" ident : mlir_attr_val_symbol
syntax "@" str : mlir_attr_val_symbol
syntax "#" ident : mlir_attr_val -- alias
syntax "#" strLit : mlir_attr_val -- aliass
syntax "#" strLit : mlir_attr_val -- alias

syntax "#" ident "<" strLit ">" : mlir_attr_val -- opaqueAttr
declare_syntax_cat dialect_attribute_contents
syntax mlir_attr_val : dialect_attribute_contents
/--
I got this from https://mlir.llvm.org/docs/LangRef/
alexkeizer marked this conversation as resolved.
Show resolved Hide resolved

```bnf
dialect-namespace ::= bare-id

dialect-attribute ::= `#` (opaque-dialect-attribute | pretty-dialect-attribute)
opaque-dialect-attribute ::= dialect-namespace dialect-attribute-body
pretty-dialect-attribute ::= dialect-namespace `.` pretty-dialect-attribute-lead-ident
dialect-attribute-body?
pretty-dialect-attribute-lead-ident ::= `[A-Za-z][A-Za-z0-9._]*`

dialect-attribute-body ::= `<` dialect-attribute-contents+ `>`
dialect-attribute-contents ::= dialect-attribute-body
| `(` dialect-attribute-contents+ `)`
| `[` dialect-attribute-contents+ `]`
| `{` dialect-attribute-contents+ `}`
| [^\[<({\]>)}\0]+
```
-/
syntax "(" dialect_attribute_contents + ")" : dialect_attribute_contents
syntax "[" dialect_attribute_contents + "]": dialect_attribute_contents
syntax "{" dialect_attribute_contents + "}": dialect_attribute_contents
syntax "#" ident "<" sepBy(mlir_attr_val, ",") ">" : mlir_attr_val
AtticusKuhn marked this conversation as resolved.
Show resolved Hide resolved
-- If I un-comment this line, it causes an error. I don't know why. Oh well.
-- syntax "#" ident "<" ident ">" : mlir_attr_val
-- syntax "#" ident "<" strLit ">" : mlir_attr_val
syntax "#opaque<" ident "," strLit ">" ":" mlir_type : mlir_attr_val -- opaqueElementsAttr
syntax mlir_attr_val_symbol "::" mlir_attr_val_symbol : mlir_attr_val_symbol

Expand Down Expand Up @@ -596,11 +624,21 @@ macro_rules
| `([mlir_attr_val| false ]) => `(AttrValue.bool False)


macro_rules
| `([mlir_attr_val| # $dialect:ident <$xs ,* > ]) => do
let initList : TSyntax `term <- `([])
let vals : TSyntax `term <- xs.getElems.foldlM (init := initList) fun (xs : TSyntax `term) (x : TSyntax `mlir_attr_val) =>
`([mlir_attr_val| #$dialect<$x>] :: $xs)
`(AttrValue.list $vals)

macro_rules
| `([mlir_attr_val| # $dialect:ident < $opaqueData:str > ]) => do
let dialect := Lean.quote dialect.getId.toString
`(AttrValue.opaque_ $dialect $opaqueData)

| `([mlir_attr_val| # $dialect:ident < $opaqueData:ident > ]) => do
let d := Lean.quote dialect.getId.toString
let g : TSyntax `str := Lean.Syntax.mkStrLit (toString opaqueData.getId)
`(AttrValue.opaque_ $d $g)
macro_rules
| `([mlir_attr_val| #opaque< $dialect:ident, $opaqueData:str> : $t:mlir_type ]) => do
let dialect := Lean.quote dialect.getId.toString
Expand Down
54 changes: 48 additions & 6 deletions SSA/Projects/InstCombine/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ deriving Repr, DecidableEq, Inhabited

/-- Homogeneous, binary operations -/
inductive MOp.BinaryOp : Type
| and
| or
| xor
| shl
| lshr
| ashr
| urem
| srem
| add (nswnuw : AdditionFlags := {nsw := false, nuw := false} )
| mul
| sub
| sdiv
| udiv
deriving DecidableEq, Inhabited
/-- Homogeneous, binary operations without flags -/
inductive BinaryOpWithoutFlags : Type
| and
| or
| xor
Expand All @@ -103,6 +119,28 @@ inductive MOp.BinaryOp : Type
| udiv
deriving Repr, DecidableEq, Inhabited

/--
The reason that I am using the admittedly hacky and ad-hoc method is that I want to preserve the guard_msgs statements, otherwise the build will fail.
But the default Repr instance has some fancy behavior where depending on the indentation it will sometimes wrap in parentheses.
I think the only way to replicate this behavior is to have another class and piggy-back off its default Repr class
-/
def BinaryOpRemoveFlags : MOp.BinaryOp → BinaryOpWithoutFlags
| .and => BinaryOpWithoutFlags.and
| .or => BinaryOpWithoutFlags.or
| .xor => BinaryOpWithoutFlags.xor
| .shl => BinaryOpWithoutFlags.shl
| .lshr => BinaryOpWithoutFlags.lshr
| .ashr => BinaryOpWithoutFlags.ashr
| .urem => BinaryOpWithoutFlags.urem
| .srem => BinaryOpWithoutFlags.srem
| .add _ => BinaryOpWithoutFlags.add
| .mul => BinaryOpWithoutFlags.mul
| .sub => BinaryOpWithoutFlags.sub
| .sdiv => BinaryOpWithoutFlags.sdiv
| .udiv => BinaryOpWithoutFlags.udiv

instance : Repr (MOp.BinaryOp) where
reprPrec op w := ((toString (reprPrec (BinaryOpRemoveFlags op) w)).replace "InstCombine.BinaryOpWithoutFlags" "InstCombine.MOp.BinaryOp").replace "false" ""
-- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations
inductive MOp (φ : Nat) : Type
| unary (w : Width φ) (op : MOp.UnaryOp) : MOp φ
Expand All @@ -127,7 +165,7 @@ namespace MOp
@[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr
@[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem
@[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem
@[match_pattern] def add (w : Width φ) : MOp φ := .binary w .add
@[match_pattern] def add (w : Width φ) (additionFlags: AdditionFlags := {nsw := false , nuw := false}) : MOp φ := .binary w (.add additionFlags )
AtticusKuhn marked this conversation as resolved.
Show resolved Hide resolved
@[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul
@[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub
@[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv
Expand All @@ -146,7 +184,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(ashr : ∀ {φ} {w : Width φ}, motive (ashr w))
(urem : ∀ {φ} {w : Width φ}, motive (urem w))
(srem : ∀ {φ} {w : Width φ}, motive (srem w))
(add : ∀ {φ} {w : Width φ}, motive (add w))
(add : ∀ {φ additionFlags} {w : Width φ}, motive (add w additionFlags))
(mul : ∀ {φ} {w : Width φ}, motive (mul w))
(sub : ∀ {φ} {w : Width φ}, motive (sub w))
(sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w))
Expand All @@ -166,7 +204,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
| _, .ashr _ => ashr
| _, .urem _ => urem
| _, .srem _ => srem
| _, .add _ => add
| _, .add _ _ => add
| _, .mul _ => mul
| _, .sub _ => sub
| _, .sdiv _ => sdiv
Expand All @@ -189,7 +227,7 @@ instance : ToString (MOp φ) where
| .urem _ => "urem"
| .srem _ => "srem"
| .select _ => "select"
| .add _ => "add"
| .add _ _ => "add"
| .mul _ => "mul"
| .sub _ => "sub"
| .neg _ => "neg"
Expand All @@ -216,7 +254,6 @@ namespace Op
@[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete
@[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete
@[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete
@[match_pattern] abbrev add : Nat → Op := MOp.add ∘ .concrete
@[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete
@[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete
@[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete
Expand All @@ -227,6 +264,11 @@ namespace Op
@[match_pattern] abbrev icmp (c : IntPredicate) : Nat → Op := MOp.icmp c ∘ .concrete
@[match_pattern] abbrev const (w : Nat) (val : ℤ) : Op := MOp.const (.concrete w) val


/-- Add is separate from the other operations because it takes in 2 flags: nuw and nsw.-/
@[match_pattern] abbrev add (w : Nat) (flags: AdditionFlags :=
{nsw := false , nuw := false}) : Op:= MOp.add (.concrete w) flags

end Op

instance : ToString Op where
Expand Down Expand Up @@ -275,7 +317,7 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig
| Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1)
| Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1)
| Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1)
| Op.add _ => LLVM.add (op.getN 0) (op.getN 1)
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1)
| Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1)
| Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1)
Expand Down
12 changes: 11 additions & 1 deletion SSA/Projects/InstCombine/LLVM/EDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,17 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
| "llvm.ashr" => pure <| Sum.inl .ashr
| "llvm.urem" => pure <| Sum.inl .urem
| "llvm.srem" => pure <| Sum.inl .srem
| "llvm.add" => pure <| Sum.inl .add
| "llvm.add" => do
let attr? := opStx.attrs.getAttr "overflowFlags"
match attr? with
| .none => pure <| Sum.inl (MOp.BinaryOp.add ⟨ false , false ⟩ )
| .some y => match y with
| .opaque_ "llvm.overflow" "nsw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨ true ,false⟩ )
| .opaque_ "llvm.overflow" "nuw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨ false, true ⟩ )
| .list [.opaque_ "llvm.overflow" "nuw", .opaque_ "llvm.overflow" "nsw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨ true, true ⟩ )
| .list [.opaque_ "llvm.overflow" "nsw", .opaque_ "llvm.overflow" "nuw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨ true, true ⟩ )
| .opaque_ "llvm.overflow" s => throw <| .generic s!"The overflow flag {s} not allowed. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"
| _ => throw <| .generic s!"Unrecognised overflow flag found: {MLIR.AST.docAttrVal y}. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"
| "llvm.mul" => pure <| Sum.inl .mul
| "llvm.sub" => pure <| Sum.inl .sub
| "llvm.sdiv" => pure <| Sum.inl .sdiv
Expand Down
26 changes: 24 additions & 2 deletions SSA/Projects/InstCombine/LLVM/Semantics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,33 @@ def add? {w : Nat} (x y : BitVec w) : IntW w :=
@[simp_llvm_option]
theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl

structure AdditionFlags where
nsw : Bool := false
nuw : Bool := false
deriving DecidableEq

/-- Does the signed addition x + y overflow? -/
def AddSignedWraps? (x y : BitVec w) : Bool := (x.toInt + y.toInt) < -(2^(w-1)) ∨ (x.toInt + y.toInt) ≥ 2^w
AtticusKuhn marked this conversation as resolved.
Show resolved Hide resolved

/-- Does the unsigned addition x + y overflow? -/
def AddUnSignedWraps? (x y : BitVec w) : Bool := (x.toNat + y.toNat) ≥ 2^w

@[simp_llvm_option]
def add {w : Nat} (x y : IntW w) : IntW w := do
def add {w : Nat} (x y : IntW w) (flags : AdditionFlags := {nsw := false , nuw := false}) : IntW w := do
let x' ← x
let y' ← y
add? x' y'
let nsw := flags.nsw
let nuw := flags.nuw
if (nsw ∧ AddSignedWraps? x' y') ∨ (nuw ∧ AddUnSignedWraps? x' y') then
AtticusKuhn marked this conversation as resolved.
Show resolved Hide resolved
none
else
add? x' y'

@[simp]
theorem add_reduce (x y : IntW w) : add x y = match x, y with
| none , _ => none
| _ , none => none
| some a , some b => .some (a + b) := by cases x <;> cases y <;> (simp ; rfl)

/--
The value produced is the integer difference of the two operands.
Expand Down
5 changes: 5 additions & 0 deletions SSA/Projects/InstCombine/TacticAuto.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import SSA.Projects.InstCombine.ForLean

import SSA.Projects.InstCombine.LLVM.EDSL
import Batteries.Data.BitVec
import SSA.Projects.InstCombine.LLVM.Semantics

attribute [simp_llvm_case_bash]
BitVec.Refinement.refl BitVec.Refinement.some_some BitVec.Refinement.none_left
Expand Down Expand Up @@ -59,6 +60,10 @@ macro_rules
macro "simp_alive_undef" : tactic =>
`(tactic|
(
-- Since I introduced the NSW and NUW flags to add, I need to remove them in the case where they are both
-- the default value of false.
-- I don't know why Lean isn't smart enough to see that False ∧ x = False
try simp only [LLVM.add_reduce]
AtticusKuhn marked this conversation as resolved.
Show resolved Hide resolved
simp (config := {failIfUnchanged := false}) only [
simp_llvm_option,
BitVec.Refinement, bind_assoc,
Expand Down
Loading