Skip to content

Commit

Permalink
Feat: Add Overflow Semantics to Addition in LLVM (#480)
Browse files Browse the repository at this point in the history
This is a smaller, scaled back version of 
#471

```
I am wondering if we can scope it a bit smaller and just add overflow flags for addition. I feel this would allow us to iterate quickly on the right implementation and then expand it to all ops.
```



https://grosser.zulipchat.com/#narrow/stream/446584-Project---Lean4---BitVectors/topic/Overflow.20Flags.20in.20LLVM/near/453334239

---------

Co-authored-by: Atticus Kuhn <[email protected]>
Co-authored-by: Tobias Grosser <[email protected]>
Co-authored-by: Alex Keizer <[email protected]>
Co-authored-by: Atticus Kuhn <[email protected]>
  • Loading branch information
5 people authored Aug 8, 2024
1 parent deb5982 commit 4b13fba
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 16 deletions.
45 changes: 42 additions & 3 deletions SSA/Core/MLIRSyntax/GenericParser.lean
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,38 @@ declare_syntax_cat mlir_attr_val_symbol
syntax "@" ident : mlir_attr_val_symbol
syntax "@" str : mlir_attr_val_symbol
syntax "#" ident : mlir_attr_val -- alias
syntax "#" strLit : mlir_attr_val -- aliass
syntax "#" strLit : mlir_attr_val -- alias

syntax "#" ident "<" strLit ">" : mlir_attr_val -- opaqueAttr
-- Syntax category for the contents of a dialect attribute body.
declare_syntax_cat dialect_attribute_contents
-- Base case: any `mlir_attr_val` is accepted as dialect attribute contents.
syntax mlir_attr_val : dialect_attribute_contents
/--
Following https://mlir.llvm.org/docs/LangRef/, we define a `dialect-attribute`,
which is a particular case of an `mlir-attr-val` that is namespaced to a particular dialect.

```bnf
dialect-namespace ::= bare-id
dialect-attribute ::= `#` (opaque-dialect-attribute | pretty-dialect-attribute)
opaque-dialect-attribute ::= dialect-namespace dialect-attribute-body
pretty-dialect-attribute ::= dialect-namespace `.` pretty-dialect-attribute-lead-ident
                             dialect-attribute-body?
pretty-dialect-attribute-lead-ident ::= `[A-Za-z][A-Za-z0-9._]*`
dialect-attribute-body ::= `<` dialect-attribute-contents+ `>`
dialect-attribute-contents ::= dialect-attribute-body
                             | `(` dialect-attribute-contents+ `)`
                             | `[` dialect-attribute-contents+ `]`
                             | `{` dialect-attribute-contents+ `}`
                             | [^\[<({\]>)}\0]+
```

This rule is the parenthesised alternative, `(` dialect-attribute-contents+ `)`.
The two rules that follow it cover the `[` … `]` and `{` … `}` alternatives.
NOTE(review): no rule for the raw character class `[^\[<({\]>)}\0]+` is visible here;
`mlir_attr_val` appears to stand in for it — confirm.
-/
syntax "(" dialect_attribute_contents + ")" : dialect_attribute_contents
-- Square-bracketed alternative: `[` dialect-attribute-contents+ `]`.
syntax "[" dialect_attribute_contents + "]": dialect_attribute_contents
-- Braced alternative: `{` dialect-attribute-contents+ `}`.
syntax "{" dialect_attribute_contents + "}": dialect_attribute_contents
-- Dialect attribute: `#`, a dialect identifier, then a comma-separated (possibly empty)
-- list of attribute values between angle brackets.
syntax "#" ident "<" mlir_attr_val,* ">" : mlir_attr_val
-- The two more specific rules below are deliberately left disabled: un-commenting them
-- causes an error whose cause has not been investigated.
-- NOTE(review): presumably they overlap with the general rule above and make the parse
-- ambiguous — TODO confirm.
-- syntax "#" ident "<" ident ">" : mlir_attr_val
-- syntax "#" ident "<" strLit ">" : mlir_attr_val
syntax "#opaque<" ident "," strLit ">" ":" mlir_type : mlir_attr_val -- opaqueElementsAttr
syntax mlir_attr_val_symbol "::" mlir_attr_val_symbol : mlir_attr_val_symbol

Expand Down Expand Up @@ -596,11 +625,21 @@ macro_rules
| `([mlir_attr_val| false ]) => `(AttrValue.bool False)


/-
Expand a dialect attribute that carries a comma-separated list, `#dialect<x₁, …, xₙ>`,
into `AttrValue.list [#dialect<x₁>, …, #dialect<xₙ>]`: every element is re-wrapped as a
single-element attribute of the same dialect and expanded on its own.

The elements are emitted in source order. The previous implementation consed onto an
accumulator inside a left fold, which produced the list reversed.
-/
macro_rules
  | `([mlir_attr_val| # $dialect:ident <$xs ,* > ]) => do
    -- One term per list element, in the order the elements were written.
    let elems : Array (TSyntax `term) ← xs.getElems.mapM
      fun (x : TSyntax `mlir_attr_val) => `([mlir_attr_val| #$dialect<$x>])
    `(AttrValue.list [$elems,*])

/-
Opaque dialect attributes with a single payload.

* `#dialect<"data">` expands to `AttrValue.opaque_ "dialect" "data"`.
* `#dialect<data>` (identifier payload) expands the same way, with the identifier
  rendered as a string literal.
-/
macro_rules
  | `([mlir_attr_val| # $dialect:ident < $opaqueData:str > ]) => do
    -- The dialect identifier becomes a quoted string term.
    let dialectName := Lean.quote dialect.getId.toString
    `(AttrValue.opaque_ $dialectName $opaqueData)
  | `([mlir_attr_val| # $dialect:ident < $opaqueData:ident > ]) => do
    -- Turn the identifier payload into a string literal first.
    let payload : TSyntax `str := Lean.Syntax.mkStrLit (toString opaqueData.getId)
    let dialectName := Lean.quote dialect.getId.toString
    `(AttrValue.opaque_ $dialectName $payload)
macro_rules
| `([mlir_attr_val| #opaque< $dialect:ident, $opaqueData:str> : $t:mlir_type ]) => do
let dialect := Lean.quote dialect.getId.toString
Expand Down
4 changes: 2 additions & 2 deletions SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def alive_simplifyMulDivRem805 (w : Nat) :
simp_alive_ssa
simp_alive_undef
simp_alive_case_bash
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq]
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq, false_and, false_or, ite_false, Option.some_bind]
cases w
case zero =>
intros x
Expand Down Expand Up @@ -231,7 +231,7 @@ def alive_simplifyMulDivRem805' (w : Nat) :
simp_alive_ssa
simp_alive_undef
simp_alive_case_bash
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq]
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq, false_and, false_or, ite_false, Option.some_bind]
intros a
simp_alive_ops
simp only [ofNat_eq_ofNat, Bool.or_eq_true, beq_iff_eq, Bool.and_eq_true, bne_iff_ne, ne_eq,
Expand Down
54 changes: 46 additions & 8 deletions SSA/Projects/InstCombine/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,41 @@ inductive MOp.BinaryOp : Type
| ashr
| urem
| srem
| add
| add (nswnuw : AdditionFlags := {nsw := false, nuw := false} )
| mul
| sub
| sdiv
| udiv
deriving Repr, DecidableEq, Inhabited
deriving DecidableEq, Inhabited

open Std (Format) in
/--
`Repr` printer for `MOp.BinaryOp`.

Every operation is printed as `InstCombine.MOp.BinaryOp.<name>`. For `add`, the two
flags are appended (`add <nsw> <nuw>`) unless both are `false`, in which case plain
`add` is printed. The result is wrapped with `Repr.addAppParen` at precedence `prec`.
-/
def reprWithoutFlags (op : MOp.BinaryOp) (prec : Nat) : Format :=
  -- Constructor name, plus the flags for a flagged `add`.
  let ctorName : String :=
    match op with
    | .and => "and"
    | .or => "or"
    | .xor => "xor"
    | .shl => "shl"
    | .lshr => "lshr"
    | .ashr => "ashr"
    | .urem => "urem"
    | .srem => "srem"
    | .add ⟨false, false⟩ => "add"
    | .add ⟨nsw, nuw⟩ => toString f!"add {nsw} {nuw}"
    | .mul => "mul"
    | .sub => "sub"
    | .sdiv => "sdiv"
    | .udiv => "udiv"
  -- Nesting depth depends on whether we are printed in argument position.
  let indent : Int := if prec >= max_prec then 1 else 2
  let body : Format :=
    Format.group (Format.nest indent f!"InstCombine.MOp.BinaryOp.{ctorName}")
  Repr.addAppParen body prec

/-- `Repr` instance for `MOp.BinaryOp`, delegating to `reprWithoutFlags`. -/
instance : Repr MOp.BinaryOp :=
  ⟨reprWithoutFlags⟩

-- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations
inductive MOp (φ : Nat) : Type
Expand All @@ -127,12 +156,17 @@ namespace MOp
@[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr
@[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem
@[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem
@[match_pattern] def add (w : Width φ) : MOp φ := .binary w .add

@[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul
@[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub
@[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv
@[match_pattern] def udiv (w : Width φ) : MOp φ := .binary w .udiv

/--
Addition at width `w`, built as `.binary w (.add additionFlags)`.

`add` is defined apart from the other binary operations because it takes an extra
`AdditionFlags` argument; by default both `nsw` and `nuw` are `false`.
-/
@[match_pattern] def add (w : Width φ)
(additionFlags: AdditionFlags := {nsw := false , nuw := false}) : MOp φ
:= .binary w (.add additionFlags )

/-- Recursion principle in terms of individual operations, rather than `unary` or `binary` -/
def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(neg : ∀ {φ} {w : Width φ}, motive (neg w))
Expand All @@ -146,7 +180,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(ashr : ∀ {φ} {w : Width φ}, motive (ashr w))
(urem : ∀ {φ} {w : Width φ}, motive (urem w))
(srem : ∀ {φ} {w : Width φ}, motive (srem w))
(add : ∀ {φ} {w : Width φ}, motive (add w))
(add : ∀ {φ additionFlags} {w : Width φ}, motive (add w additionFlags))
(mul : ∀ {φ} {w : Width φ}, motive (mul w))
(sub : ∀ {φ} {w : Width φ}, motive (sub w))
(sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w))
Expand All @@ -166,7 +200,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
| _, .ashr _ => ashr
| _, .urem _ => urem
| _, .srem _ => srem
| _, .add _ => add
| _, .add _ _ => add
| _, .mul _ => mul
| _, .sub _ => sub
| _, .sdiv _ => sdiv
Expand All @@ -189,7 +223,7 @@ instance : ToString (MOp φ) where
| .urem _ => "urem"
| .srem _ => "srem"
| .select _ => "select"
| .add _ => "add"
| .add _ _ => "add"
| .mul _ => "mul"
| .sub _ => "sub"
| .neg _ => "neg"
Expand All @@ -216,7 +250,6 @@ namespace Op
@[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete
@[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete
@[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete
@[match_pattern] abbrev add : Nat → Op := MOp.add ∘ .concrete
@[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete
@[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete
@[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete
Expand All @@ -227,6 +260,11 @@ namespace Op
@[match_pattern] abbrev icmp (c : IntPredicate) : Nat → Op := MOp.icmp c ∘ .concrete
@[match_pattern] abbrev const (w : Nat) (val : ℤ) : Op := MOp.const (.concrete w) val


/--
Addition at concrete width `w`: `MOp.add (.concrete w) flags`.

`add` is kept separate from the other operations because it additionally takes the
two overflow flags `nsw` and `nuw` (both `false` by default).
-/
@[match_pattern] abbrev add (w : Nat) (flags: AdditionFlags :=
{nsw := false , nuw := false}) : Op:= MOp.add (.concrete w) flags

end Op

instance : ToString Op where
Expand Down Expand Up @@ -275,7 +313,7 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig
| Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1)
| Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1)
| Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1)
| Op.add _ => LLVM.add (op.getN 0) (op.getN 1)
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1)
| Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1)
| Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1)
Expand Down
12 changes: 11 additions & 1 deletion SSA/Projects/InstCombine/LLVM/EDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,17 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
| "llvm.ashr" => pure <| Sum.inl .ashr
| "llvm.urem" => pure <| Sum.inl .urem
| "llvm.srem" => pure <| Sum.inl .srem
| "llvm.add" => pure <| Sum.inl .add
| "llvm.add" => do
let attr? := opStx.attrs.getAttr "overflowFlags"
match attr? with
| .none => pure <| Sum.inl (MOp.BinaryOp.add)
| .some y => match y with
| .opaque_ "llvm.overflow" "nsw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, false⟩)
| .opaque_ "llvm.overflow" "nuw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨false, true⟩)
| .list [.opaque_ "llvm.overflow" "nuw", .opaque_ "llvm.overflow" "nsw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, true⟩)
| .list [.opaque_ "llvm.overflow" "nsw", .opaque_ "llvm.overflow" "nuw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, true⟩)
| .opaque_ "llvm.overflow" s => throw <| .generic s!"The overflow flag {s} not allowed. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"
| _ => throw <| .generic s!"Unrecognised overflow flag found: {MLIR.AST.docAttrVal y}. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"
| "llvm.mul" => pure <| Sum.inl .mul
| "llvm.sub" => pure <| Sum.inl .sub
| "llvm.sdiv" => pure <| Sum.inl .sdiv
Expand Down
17 changes: 15 additions & 2 deletions SSA/Projects/InstCombine/LLVM/Semantics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,24 @@ def add? {w : Nat} (x y : BitVec w) : IntW w :=
/-- `LLVM.add? a b` is always `some (BitVec.add a b)`; this holds by `rfl`. -/
@[simp_llvm_option]
theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl

/--
Overflow flags that can be attached to an addition.
Both flags default to `false`.
-/
structure AdditionFlags where
/-- "No signed wrap" flag. -/
nsw : Bool := false
/-- "No unsigned wrap" flag. -/
nuw : Bool := false
deriving Repr, DecidableEq

/--
Addition of two `IntW w` values with optional overflow flags.

If either operand is `none`, the result is `none`. Otherwise the operands are added
with `add?`, unless a requested overflow check fails, in which case the result is `none`
(the LLVM LangRef specifies a poison value when `nsw`/`nuw` is violated):

* `flags.nsw`: the mathematical sum of the signed interpretations must lie in the
  signed range `[-2^(w-1), 2^(w-1) - 1]`.
* `flags.nuw`: the mathematical sum of the unsigned interpretations must be `< 2^w`.

With the default flags (both `false`) no check is performed.
-/
@[simp_llvm_option]
def add {w : Nat} (x y : IntW w) (flags : AdditionFlags := {nsw := false , nuw := false}) : IntW w := do
  let x' ← x
  let y' ← y
  let nsw := flags.nsw
  let nuw := flags.nuw
  -- Signed overflow: the upper bound is `2^(w-1)`, not `2^w`; the largest signed
  -- `w`-bit value is `2^(w-1) - 1`.
  let AddSignedWraps? : Prop := nsw ∧
    ((x'.toInt + y'.toInt) < -(2^(w-1)) ∨ (x'.toInt + y'.toInt) ≥ 2^(w-1))
  -- Unsigned overflow: the sum no longer fits in `w` bits.
  let AddUnsignedWraps? : Prop := nuw ∧ ((x'.toNat + y'.toNat) ≥ 2^w)
  if (AddSignedWraps? ∨ AddUnsignedWraps?) then
    none
  else
    add? x' y'

/--
The value produced is the integer difference of the two operands.
Expand Down

0 comments on commit 4b13fba

Please sign in to comment.