From 4b13fbab524ca7ed3b5772727ac4f4eccef980e9 Mon Sep 17 00:00:00 2001 From: Atticus Kuhn <52258164+AtticusKuhn@users.noreply.github.com> Date: Thu, 8 Aug 2024 09:37:24 +0100 Subject: [PATCH] Feat: Add Overflow Semantics to Addition in LLVM (#480) This is a smaller, scaled back version of https://github.com/opencompl/lean-mlir/pull/471 ``` I am wondering if we can scope it a bit smaller and just add overflow flags for addition. I feel this would allow us to iterate quickly on the right implementation and then expand it to all ops. ``` https://grosser.zulipchat.com/#narrow/stream/446584-Project---Lean4---BitVectors/topic/Overflow.20Flags.20in.20LLVM/near/453334239 --------- Co-authored-by: Atticus Kuhn Co-authored-by: Tobias Grosser Co-authored-by: Alex Keizer Co-authored-by: Atticus Kuhn --- SSA/Core/MLIRSyntax/GenericParser.lean | 45 ++++++++++++++-- .../AliveHandwrittenLargeExamples.lean | 4 +- SSA/Projects/InstCombine/Base.lean | 54 ++++++++++++++++--- SSA/Projects/InstCombine/LLVM/EDSL.lean | 12 ++++- SSA/Projects/InstCombine/LLVM/Semantics.lean | 17 +++++- 5 files changed, 116 insertions(+), 16 deletions(-) diff --git a/SSA/Core/MLIRSyntax/GenericParser.lean b/SSA/Core/MLIRSyntax/GenericParser.lean index 523ed15c5..980e64d31 100644 --- a/SSA/Core/MLIRSyntax/GenericParser.lean +++ b/SSA/Core/MLIRSyntax/GenericParser.lean @@ -551,9 +551,38 @@ declare_syntax_cat mlir_attr_val_symbol syntax "@" ident : mlir_attr_val_symbol syntax "@" str : mlir_attr_val_symbol syntax "#" ident : mlir_attr_val -- alias -syntax "#" strLit : mlir_attr_val -- aliass +syntax "#" strLit : mlir_attr_val -- alias -syntax "#" ident "<" strLit ">" : mlir_attr_val -- opaqueAttr +declare_syntax_cat dialect_attribute_contents +syntax mlir_attr_val : dialect_attribute_contents +/-- +Following https://mlir.llvm.org/docs/LangRef/, we define a `dialect-attribute`, +which is a particular case of an `mlir-attr-val` that is namespaced to a particular dialect + +```bnf +dialect-namespace ::= bare-id + +dialect-attribute ::= `#` (opaque-dialect-attribute | pretty-dialect-attribute) +opaque-dialect-attribute ::= dialect-namespace dialect-attribute-body +pretty-dialect-attribute ::= dialect-namespace `.` pretty-dialect-attribute-lead-ident + dialect-attribute-body? +pretty-dialect-attribute-lead-ident ::= `[A-Za-z][A-Za-z0-9._]*` + +dialect-attribute-body ::= `<` dialect-attribute-contents+ `>` +dialect-attribute-contents ::= dialect-attribute-body + | `(` dialect-attribute-contents+ `)` + | `[` dialect-attribute-contents+ `]` + | `{` dialect-attribute-contents+ `}` + | [^\[<({\]>)}\0]+ +``` +-/ +syntax "(" dialect_attribute_contents + ")" : dialect_attribute_contents +syntax "[" dialect_attribute_contents + "]": dialect_attribute_contents +syntax "{" dialect_attribute_contents + "}": dialect_attribute_contents +syntax "#" ident "<" mlir_attr_val,* ">" : mlir_attr_val +-- If I un-comment this line, it causes an error. I don't know why. Oh well. +-- syntax "#" ident "<" ident ">" : mlir_attr_val +-- syntax "#" ident "<" strLit ">" : mlir_attr_val syntax "#opaque<" ident "," strLit ">" ":" mlir_type : mlir_attr_val -- opaqueElementsAttr syntax mlir_attr_val_symbol "::" mlir_attr_val_symbol : mlir_attr_val_symbol @@ -596,11 +625,21 @@ macro_rules | `([mlir_attr_val| false ]) => `(AttrValue.bool False) +macro_rules +| `([mlir_attr_val| # $dialect:ident <$xs ,* > ]) => do + let initList : TSyntax `term <- `([]) + let vals : TSyntax `term <- xs.getElems.foldlM (init := initList) fun (xs : TSyntax `term) (x : TSyntax `mlir_attr_val) => + `([mlir_attr_val| #$dialect<$x>] :: $xs) + `(AttrValue.list $vals) + macro_rules | `([mlir_attr_val| # $dialect:ident < $opaqueData:str > ]) => do let dialect := Lean.quote dialect.getId.toString `(AttrValue.opaque_ $dialect $opaqueData) - +| `([mlir_attr_val| # $dialect:ident < $opaqueData:ident > ]) => do + let d := Lean.quote dialect.getId.toString + let g : TSyntax `str := Lean.Syntax.mkStrLit (toString opaqueData.getId) + `(AttrValue.opaque_ $d $g) macro_rules | `([mlir_attr_val| #opaque< $dialect:ident, $opaqueData:str> : $t:mlir_type ]) => do let dialect := Lean.quote dialect.getId.toString diff --git a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean index 48dfab288..1c122e58b 100644 --- a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean +++ b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean @@ -122,7 +122,7 @@ def alive_simplifyMulDivRem805 (w : Nat) : simp_alive_ssa simp_alive_undef simp_alive_case_bash - simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq] + simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq, false_and, false_or, ite_false, Option.some_bind] cases w case zero => intros x @@ -231,7 +231,7 @@ def alive_simplifyMulDivRem805' (w : Nat) : simp_alive_ssa simp_alive_undef simp_alive_case_bash - simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq] + simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq, false_and, false_or, ite_false, Option.some_bind] intros a simp_alive_ops simp only [ofNat_eq_ofNat, Bool.or_eq_true, beq_iff_eq, Bool.and_eq_true, bne_iff_ne, ne_eq, diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index 7e26aca86..c7111806f 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -96,12 +96,41 @@ inductive MOp.BinaryOp : Type | ashr | urem | srem - | add + | add (nswnuw : AdditionFlags := {nsw := false, nuw := false} ) | mul | sub | sdiv | udiv -deriving Repr, DecidableEq, Inhabited +deriving DecidableEq, Inhabited + +open Std (Format) in +/-- +If both the nuw and nsw flags are the default value (false,false), +then we should not print them. This should be the default +behavior in Lean, but it isn't +-/ +def reprWithoutFlags (op : MOp.BinaryOp) (prec : Nat) : Format := + let op : String := match op with + | .and => "and" + | .or => "or" + | .xor => "xor" + | .shl => "shl" + | .lshr => "lshr" + | .ashr => "ashr" + | .urem => "urem" + | .srem => "srem" + | .add ⟨false, false⟩ => "add" + | .add ⟨nsw, nuw⟩ => toString f!"add {nsw} {nuw}" + | .mul => "mul" + | .sub => "sub" + | .sdiv => "sdiv" + | .udiv => "udiv" + Repr.addAppParen (Format.group (Format.nest + (if prec >= max_prec then 1 else 2) f!"InstCombine.MOp.BinaryOp.{op}")) + prec + +instance : Repr (MOp.BinaryOp) where + reprPrec := reprWithoutFlags -- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations inductive MOp (φ : Nat) : Type @@ -127,12 +156,17 @@ namespace MOp @[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr @[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem @[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem -@[match_pattern] def add (w : Width φ) : MOp φ := .binary w .add + @[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul @[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub @[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv @[match_pattern] def udiv (w : Width φ) : MOp φ := .binary w .udiv +/- This definition is off by itself because it is different-/ +@[match_pattern] def add (w : Width φ) + (additionFlags: AdditionFlags := {nsw := false , nuw := false}) : MOp φ + := .binary w (.add additionFlags ) + /-- Recursion principle in terms of individual operations, rather than `unary` or `binary` -/ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} (neg : ∀ {φ} {w : Width φ}, motive (neg w)) @@ -146,7 +180,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} (ashr : ∀ {φ} {w : Width φ}, motive (ashr w)) (urem : ∀ {φ} {w : Width φ}, motive (urem w)) (srem : ∀ {φ} {w : Width φ}, motive (srem w)) - (add : ∀ {φ} {w : Width φ}, motive (add w)) + (add : ∀ {φ additionFlags} {w : Width φ}, motive (add w additionFlags)) (mul : ∀ {φ} {w : Width φ}, motive (mul w)) (sub : ∀ {φ} {w : Width φ}, motive (sub w)) (sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w)) @@ -166,7 +200,7 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} | _, .ashr _ => ashr | _, .urem _ => urem | _, .srem _ => srem - | _, .add _ => add + | _, .add _ _ => add | _, .mul _ => mul | _, .sub _ => sub | _, .sdiv _ => sdiv @@ -189,7 +223,7 @@ instance : ToString (MOp φ) where | .urem _ => "urem" | .srem _ => "srem" | .select _ => "select" - | .add _ => "add" + | .add _ _ => "add" | .mul _ => "mul" | .sub _ => "sub" | .neg _ => "neg" @@ -216,7 +250,6 @@ namespace Op @[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete @[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete @[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete -@[match_pattern] abbrev add : Nat → Op := MOp.add ∘ .concrete @[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete @[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete @[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete @@ -227,6 +260,11 @@ namespace Op @[match_pattern] abbrev icmp (c : IntPredicate) : Nat → Op := MOp.icmp c ∘ .concrete @[match_pattern] abbrev const (w : Nat) (val : ℤ) : Op := MOp.const (.concrete w) val + +/- Add is separate from the other operations because it takes in 2 flags: nuw and nsw.-/ +@[match_pattern] abbrev add (w : Nat) (flags: AdditionFlags := + {nsw := false , nuw := false}) : Op:= MOp.add (.concrete w) flags + end Op instance : ToString Op where @@ -275,7 +313,7 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig | Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1) | Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1) | Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1) - | Op.add _ => LLVM.add (op.getN 0) (op.getN 1) + | Op.add _ flags => LLVM.add (op.getN 0) (op.getN 1) flags | Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1) | Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1) | Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 00af2e575..3cd90f322 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -107,7 +107,17 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : | "llvm.ashr" => pure <| Sum.inl .ashr | "llvm.urem" => pure <| Sum.inl .urem | "llvm.srem" => pure <| Sum.inl .srem - | "llvm.add" => pure <| Sum.inl .add + | "llvm.add" => do + let attr? := opStx.attrs.getAttr "overflowFlags" + match attr? with + | .none => pure <| Sum.inl (MOp.BinaryOp.add) + | .some y => match y with + | .opaque_ "llvm.overflow" "nsw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, false⟩) + | .opaque_ "llvm.overflow" "nuw" => pure <| Sum.inl (MOp.BinaryOp.add ⟨false, true⟩) + | .list [.opaque_ "llvm.overflow" "nuw", .opaque_ "llvm.overflow" "nsw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, true⟩) + | .list [.opaque_ "llvm.overflow" "nsw", .opaque_ "llvm.overflow" "nuw"] => pure <| Sum.inl (MOp.BinaryOp.add ⟨true, true⟩) + | .opaque_ "llvm.overflow" s => throw <| .generic s!"The overflow flag {s} not allowed. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)" + | _ => throw <| .generic s!"Unrecognised overflow flag found: {MLIR.AST.docAttrVal y}. We currently support nsw (no signed wrap) and nuw (no unsigned wrap)" | "llvm.mul" => pure <| Sum.inl .mul | "llvm.sub" => pure <| Sum.inl .sub | "llvm.sdiv" => pure <| Sum.inl .sdiv diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index 86bb109e3..3fd7ad188 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -77,11 +77,24 @@ def add? {w : Nat} (x y : BitVec w) : IntW w := @[simp_llvm_option] theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl +structure AdditionFlags where + nsw : Bool := false + nuw : Bool := false + deriving Repr, DecidableEq + @[simp_llvm_option] -def add {w : Nat} (x y : IntW w) : IntW w := do +def add {w : Nat} (x y : IntW w) (flags : AdditionFlags := {nsw := false , nuw := false}) : IntW w := do let x' ← x let y' ← y - add? x' y' + let nsw := flags.nsw + let nuw := flags.nuw + let AddSignedWraps? : Prop := nsw ∧ + ((x'.toInt + y'.toInt) < -(2^(w-1)) ∨ (x'.toInt + y'.toInt) ≥ 2^w) + let AddUnsignedWraps? : Prop := nuw ∧ ((x'.toNat + y'.toNat) ≥ 2^w) + if (AddSignedWraps? ∨ AddUnsignedWraps?) then + none + else + add? x' y' /-- The value produced is the integer difference of the two operands.