Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add Overflow Semantics to Addition in LLVM #480

Merged
merged 24 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions SSA/Core/MLIRSyntax/GenericParser.lean
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,38 @@ declare_syntax_cat mlir_attr_val_symbol
syntax "@" ident : mlir_attr_val_symbol
syntax "@" str : mlir_attr_val_symbol
syntax "#" ident : mlir_attr_val -- alias
syntax "#" strLit : mlir_attr_val -- aliass
syntax "#" strLit : mlir_attr_val -- alias

syntax "#" ident "<" strLit ">" : mlir_attr_val -- opaqueAttr
declare_syntax_cat dialect_attribute_contents
syntax mlir_attr_val : dialect_attribute_contents
/--
Following https://mlir.llvm.org/docs/LangRef/, we define a `dialect-attribute`,
which is a particular case of an `mlir-attr-val` that is namespaced to a particular dialect.

```bnf
dialect-namespace ::= bare-id

dialect-attribute ::= `#` (opaque-dialect-attribute | pretty-dialect-attribute)
opaque-dialect-attribute ::= dialect-namespace dialect-attribute-body
pretty-dialect-attribute ::= dialect-namespace `.` pretty-dialect-attribute-lead-ident
dialect-attribute-body?
pretty-dialect-attribute-lead-ident ::= `[A-Za-z][A-Za-z0-9._]*`

dialect-attribute-body ::= `<` dialect-attribute-contents+ `>`
dialect-attribute-contents ::= dialect-attribute-body
| `(` dialect-attribute-contents+ `)`
| `[` dialect-attribute-contents+ `]`
| `{` dialect-attribute-contents+ `}`
| [^\[<({\]>)}\0]+
```
-/
syntax "(" dialect_attribute_contents + ")" : dialect_attribute_contents
syntax "[" dialect_attribute_contents + "]": dialect_attribute_contents
syntax "{" dialect_attribute_contents + "}": dialect_attribute_contents
syntax "#" ident "<" mlir_attr_val,* ">" : mlir_attr_val
-- If I un-comment this line, it causes an error. I don't know why. Oh well.
-- syntax "#" ident "<" ident ">" : mlir_attr_val
-- syntax "#" ident "<" strLit ">" : mlir_attr_val
syntax "#opaque<" ident "," strLit ">" ":" mlir_type : mlir_attr_val -- opaqueElementsAttr
syntax mlir_attr_val_symbol "::" mlir_attr_val_symbol : mlir_attr_val_symbol

Expand Down Expand Up @@ -596,11 +625,21 @@ macro_rules
| `([mlir_attr_val| false ]) => `(AttrValue.bool False)


-- Expands a dialect attribute with a comma-separated body, `#dialect<v₁, …, vₙ>`, into
-- `AttrValue.list [...]`. Each element `vᵢ` is re-wrapped as `[mlir_attr_val| #dialect<vᵢ>]`
-- so that it is expanded in turn by the `mlir_attr_val` macros.
-- NOTE(review): the fold walks the elements left to right and prepends with `::`, so the
-- generated list is in reverse source order — confirm that no consumer depends on the order.
macro_rules
| `([mlir_attr_val| # $dialect:ident <$xs ,* > ]) => do
-- Accumulator starts as the empty list term.
let initList : TSyntax `term <- `([])
let vals : TSyntax `term <- xs.getElems.foldlM (init := initList) fun (xs : TSyntax `term) (x : TSyntax `mlir_attr_val) =>
`([mlir_attr_val| #$dialect<$x>] :: $xs)
`(AttrValue.list $vals)

-- Opaque dialect attributes: `#dialect<payload>` expands to `AttrValue.opaque_ dialect payload`,
-- where the dialect name is the identifier rendered as a string.
macro_rules
-- Payload written as a string literal: passed through unchanged.
| `([mlir_attr_val| # $dialect:ident < $opaqueData:str > ]) => do
let dialect := Lean.quote dialect.getId.toString
`(AttrValue.opaque_ $dialect $opaqueData)

-- Payload written as a bare identifier: its name is first turned into a string literal.
| `([mlir_attr_val| # $dialect:ident < $opaqueData:ident > ]) => do
let d := Lean.quote dialect.getId.toString
let g : TSyntax `str := Lean.Syntax.mkStrLit (toString opaqueData.getId)
`(AttrValue.opaque_ $d $g)
macro_rules
| `([mlir_attr_val| #opaque< $dialect:ident, $opaqueData:str> : $t:mlir_type ]) => do
let dialect := Lean.quote dialect.getId.toString
Expand Down
4 changes: 2 additions & 2 deletions SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def alive_simplifyMulDivRem805 (w : Nat) :
simp_alive_ssa
simp_alive_undef
simp_alive_case_bash
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq]
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq, false_and, false_or, ite_false, Option.some_bind]
cases w
case zero =>
intros x
Expand Down Expand Up @@ -231,7 +231,7 @@ def alive_simplifyMulDivRem805' (w : Nat) :
simp_alive_ssa
simp_alive_undef
simp_alive_case_bash
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq]
simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq, false_and, false_or, ite_false, Option.some_bind]
intros a
simp_alive_ops
simp only [ofNat_eq_ofNat, Bool.or_eq_true, beq_iff_eq, Bool.and_eq_true, bne_iff_ne, ne_eq,
Expand Down
54 changes: 46 additions & 8 deletions SSA/Projects/InstCombine/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,41 @@ inductive MOp.BinaryOp : Type
| ashr
| urem
| srem
| add
| add (nswnuw : AdditionFlags := {nsw := false, nuw := false} )
| mul
| sub
| sdiv
| udiv
deriving Repr, DecidableEq, Inhabited
deriving DecidableEq, Inhabited

open Std (Format) in
/--
Pretty-printer used as the `Repr` implementation for `MOp.BinaryOp`.

Every constructor is rendered as `InstCombine.MOp.BinaryOp.<name>`. For `add`, the
`nsw`/`nuw` flags are omitted when both are `false` (their default value), so a plain
addition prints as `add`; otherwise the two flags are appended as `add <nsw> <nuw>`.

* `op`   — the operation to print.
* `prec` — the precedence of the surrounding context; it selects the nesting indent
  and is forwarded to `Repr.addAppParen`.
-/
def reprWithoutFlags (op : MOp.BinaryOp) (prec : Nat) : Format :=
  -- Constructor name, plus the flags for a non-default `add`.
  let ctorText : String :=
    match op with
    | .and => "and"
    | .or => "or"
    | .xor => "xor"
    | .shl => "shl"
    | .lshr => "lshr"
    | .ashr => "ashr"
    | .urem => "urem"
    | .srem => "srem"
    | .add ⟨false, false⟩ => "add"
    | .add ⟨nsw, nuw⟩ => toString f!"add {nsw} {nuw}"
    | .mul => "mul"
    | .sub => "sub"
    | .sdiv => "sdiv"
    | .udiv => "udiv"
  let indent : Int := if prec >= max_prec then 1 else 2
  let body : Format :=
    Format.group (Format.nest indent f!"InstCombine.MOp.BinaryOp.{ctorText}")
  Repr.addAppParen body prec

/-- `Repr` for `MOp.BinaryOp`, delegating to `reprWithoutFlags`. -/
instance : Repr MOp.BinaryOp := ⟨reprWithoutFlags⟩

-- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations
inductive MOp (φ : Nat) : Type
Expand All @@ -127,12 +156,17 @@ namespace MOp
@[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr
@[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem
@[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem
@[match_pattern] def add (w : Width φ) : MOp φ := .binary w .add

@[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul
@[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub
@[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv
@[match_pattern] def udiv (w : Width φ) : MOp φ := .binary w .udiv

/--
Pattern synonym for binary addition at width `w`.
`add` is defined apart from the other binary operations because it takes an extra argument:
the `AdditionFlags` (`nsw`/`nuw`), which default to both being `false`.
-/
@[match_pattern] def add (w : Width φ)
(additionFlags: AdditionFlags := {nsw := false , nuw := false}) : MOp φ
:= .binary w (.add additionFlags )

/-- Recursion principle in terms of individual operations, rather than `unary` or `binary` -/
def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(neg : ∀ {φ} {w : Width φ}, motive (neg w))
Expand All @@ -146,7 +180,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(ashr : ∀ {φ} {w : Width φ}, motive (ashr w))
(urem : ∀ {φ} {w : Width φ}, motive (urem w))
(srem : ∀ {φ} {w : Width φ}, motive (srem w))
(add : ∀ {φ} {w : Width φ}, motive (add w))
(add : ∀ {φ additionFlags} {w : Width φ}, motive (add w additionFlags))
(mul : ∀ {φ} {w : Width φ}, motive (mul w))
(sub : ∀ {φ} {w : Width φ}, motive (sub w))
(sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w))
Expand All @@ -166,7 +200,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
| _, .ashr _ => ashr
| _, .urem _ => urem
| _, .srem _ => srem
| _, .add _ => add
| _, .add _ _ => add
| _, .mul _ => mul
| _, .sub _ => sub
| _, .sdiv _ => sdiv
Expand All @@ -189,7 +223,7 @@ instance : ToString (MOp φ) where
| .urem _ => "urem"
| .srem _ => "srem"
| .select _ => "select"
| .add _ => "add"
| .add _ _ => "add"
| .mul _ => "mul"
| .sub _ => "sub"
| .neg _ => "neg"
Expand All @@ -216,7 +250,6 @@ namespace Op
@[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete
@[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete
@[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete
@[match_pattern] abbrev add : Nat → Op := MOp.add ∘ .concrete
@[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete
@[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete
@[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete
Expand All @@ -227,6 +260,11 @@ namespace Op
@[match_pattern] abbrev icmp (c : IntPredicate) : Nat → Op := MOp.icmp c ∘ .concrete
@[match_pattern] abbrev const (w : Nat) (val : ℤ) : Op := MOp.const (.concrete w) val


/--
Concrete-width addition. `add` is declared separately from the other operations because it
takes two extra flags, `nsw` and `nuw`, bundled as `AdditionFlags` (both default to `false`).
-/
@[match_pattern] abbrev add (w : Nat) (flags: AdditionFlags :=
{nsw := false , nuw := false}) : Op:= MOp.add (.concrete w) flags

end Op

instance : ToString Op where
Expand Down Expand Up @@ -275,7 +313,7 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig
| Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1)
| Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1)
| Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1)
| Op.add _ => LLVM.add (op.getN 0) (op.getN 1)
| Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags
| Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1)
| Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1)
| Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1)
Expand Down
12 changes: 11 additions & 1 deletion SSA/Projects/InstCombine/LLVM/EDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,17 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
| "llvm.ashr" => pure <| Sum.inl .ashr
| "llvm.urem" => pure <| Sum.inl .urem
| "llvm.srem" => pure <| Sum.inl .srem
| "llvm.add" => pure <| Sum.inl .add
| "llvm.add" => do
let attr? := opStx.attrs.getAttr "overflowFlags"
match attr? with
| .none => pure <| Sum.inl (MOp.BinaryOp.add)
| .some y => match y with
| .opaque_ "llvm.overflow" "nsw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, false⟩)
| .opaque_ "llvm.overflow" "nuw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨false, true⟩)
| .list [.opaque_ "llvm.overflow" "nuw", .opaque_ "llvm.overflow" "nsw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, true⟩)
| .list [.opaque_ "llvm.overflow" "nsw", .opaque_ "llvm.overflow" "nuw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, true⟩)
| .opaque_ "llvm.overflow" s => throw <| .generic s!"The overflow flag {s} not allowed. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"
| _ => throw <| .generic s!"Unrecognised overflow flag found: {MLIR.AST.docAttrVal y}. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)"
| "llvm.mul" => pure <| Sum.inl .mul
| "llvm.sub" => pure <| Sum.inl .sub
| "llvm.sdiv" => pure <| Sum.inl .sdiv
Expand Down
17 changes: 15 additions & 2 deletions SSA/Projects/InstCombine/LLVM/Semantics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,24 @@ def add? {w : Nat} (x y : BitVec w) : IntW w :=
/-- `LLVM.add?` always returns `.some (BitVec.add a b)`; the equation holds by `rfl`. -/
@[simp_llvm_option]
theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl

/--
Overflow flags that can be attached to an addition. Both flags default to `false`,
in which case `add` performs no overflow check.
-/
structure AdditionFlags where
-- "no signed wrap": when set, `add` yields `none` if its signed-overflow check fires.
nsw : Bool := false
-- "no unsigned wrap": when set, `add` yields `none` if its unsigned-overflow check fires.
nuw : Bool := false
deriving Repr, DecidableEq

/--
Addition on `IntW w` with optional overflow flags.

If either operand is `none`, the result is `none`. Otherwise:

* with `flags.nsw` set, the result is `none` when the exact signed sum lies outside the
  `w`-bit signed range `[-2^(w-1), 2^(w-1) - 1]`;
* with `flags.nuw` set, the result is `none` when the exact unsigned sum is at least `2^w`;
* in every other case the result is `add? x' y'`.

With the default flags (both `false`) no overflow check is performed.
-/
@[simp_llvm_option]
def add {w : Nat} (x y : IntW w) (flags : AdditionFlags := {nsw := false , nuw := false}) : IntW w := do
  let x' ← x
  let y' ← y
  let nsw := flags.nsw
  let nuw := flags.nuw
  -- The largest signed `w`-bit value is `2^(w-1) - 1`, so positive signed overflow is
  -- `sum ≥ 2^(w-1)`. (A bound of `2^w` could never be reached by two `w`-bit signed values.)
  let AddSignedWraps? : Prop := nsw ∧
    ((x'.toInt + y'.toInt) < -(2^(w-1)) ∨ (x'.toInt + y'.toInt) ≥ 2^(w-1))
  let AddUnsignedWraps? : Prop := nuw ∧ ((x'.toNat + y'.toNat) ≥ 2^w)
  if (AddSignedWraps? ∨ AddUnsignedWraps?) then
    none
  else
    add? x' y'

/--
The value produced is the integer difference of the two operands.
Expand Down
Loading