diff --git a/.gitattributes b/.gitattributes index c20182ffa..290fd67bd 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ SSA/Projects/InstCombine/tests/LLVM/** linguist-generated=true SSA/Projects/InstCombine/all.lean linguist-generated=true +test/LLVMDialect/InstCombine/** linguist-generated=true \ No newline at end of file diff --git a/SSA/Core/MLIRSyntax/GenericParser.lean b/SSA/Core/MLIRSyntax/GenericParser.lean index 2dbdaffcd..e6edc4e59 100644 --- a/SSA/Core/MLIRSyntax/GenericParser.lean +++ b/SSA/Core/MLIRSyntax/GenericParser.lean @@ -548,8 +548,13 @@ syntax "@" ident : mlir_attr_val_symbol syntax "@" str : mlir_attr_val_symbol syntax "#" ident : mlir_attr_val -- alias syntax "#" strLit : mlir_attr_val -- aliass - -syntax "#" ident "<" strLit ">" : mlir_attr_val -- opaqueAttr +declare_syntax_cat dialect_attribute_contents +syntax mlir_attr_val : dialect_attribute_contents +syntax "(" dialect_attribute_contents + ")" : dialect_attribute_contents +syntax "[" dialect_attribute_contents + "]": dialect_attribute_contents +syntax "{" dialect_attribute_contents + "}": dialect_attribute_contents +-- syntax [^\[<({\]>)}\0]+ : dialect_attribute_contents +syntax "#" ident "<" dialect_attribute_contents ">" : mlir_attr_val -- opaqueAttr syntax "#opaque<" ident "," strLit ">" ":" mlir_type : mlir_attr_val -- opaqueElementsAttr syntax mlir_attr_val_symbol "::" mlir_attr_val_symbol : mlir_attr_val_symbol @@ -595,7 +600,10 @@ macro_rules | `([mlir_attr_val| # $dialect:ident < $opaqueData:str > ]) => do let dialect := Lean.quote dialect.getId.toString `(AttrValue.opaque_ $dialect $opaqueData) - +| `([mlir_attr_val| # $dialect:ident < $opaqueData:ident > ]) => do + let d := Lean.quote dialect.getId.toString + let g : TSyntax `str := Lean.Syntax.mkStrLit (toString opaqueData.getId) + `(AttrValue.opaque_ $d $g) macro_rules | `([mlir_attr_val| #opaque< $dialect:ident, $opaqueData:str> : $t:mlir_type ]) => do let dialect := Lean.quote dialect.getId.toString diff --git a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean index 381a932a3..b6c2fd1a7 100644 --- a/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean +++ b/SSA/Projects/InstCombine/AliveHandwrittenLargeExamples.lean @@ -4,6 +4,7 @@ import SSA.Projects.InstCombine.LLVM.PrettyEDSL import SSA.Projects.InstCombine.Tactic import SSA.Projects.InstCombine.TacticAuto import SSA.Projects.InstCombine.ComWrappers +import SSA.Projects.InstCombine.LLVM.Semantics import Mathlib.Tactic open BitVec @@ -134,6 +135,7 @@ def alive_simplifyMulDivRem805 (w : Nat) : rw [LLVM.sdiv?_denom_zero_eq_none] apply Refinement.none_left case neg => + simp rw [BitVec.ult_toNat] rw [BitVec.toNat_ofNat] cases w' @@ -227,6 +229,8 @@ def alive_simplifyMulDivRem805' (w : Nat) : unfold MulDivRem805_lhs MulDivRem805_rhs simp only [simp_llvm_wrap] simp_alive_ssa + simp only [LLVM.add_reduce] + simp only [LLVM.sdiv_reduce] simp_alive_undef simp_alive_case_bash simp only [ofInt_ofNat, add_eq, LLVM.icmp?_ult_eq] diff --git a/SSA/Projects/InstCombine/Base.lean b/SSA/Projects/InstCombine/Base.lean index 7e26aca86..3b4bcfee5 100644 --- a/SSA/Projects/InstCombine/Base.lean +++ b/SSA/Projects/InstCombine/Base.lean @@ -89,18 +89,18 @@ deriving Repr, DecidableEq, Inhabited /-- Homogeneous, binary operations -/ inductive MOp.BinaryOp : Type | and - | or + | or (disjoint : Bool) | xor - | shl - | lshr - | ashr + | shl (nsw : Bool) (nuw : Bool) + | lshr (exact : Bool) + | ashr (exact : Bool) | urem | srem - | add - | mul - | sub - | sdiv - | udiv + | add (nsw : Bool) (nuw : Bool) + | mul (nsw : Bool) (nuw : Bool) + | sub (nsw : Bool) (nuw : Bool) + | sdiv (exact : Bool) + | udiv (exact : Bool) deriving Repr, DecidableEq, Inhabited -- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations @@ -120,18 +120,18 @@ namespace MOp @[match_pattern] def copy (w : Width φ) : MOp φ := .unary w .copy @[match_pattern] def and (w : Width φ) : MOp φ := .binary w .and -@[match_pattern] def or (w : Width φ) : MOp φ := .binary w .or +@[match_pattern] def or (disjoint : Bool := false) (w : Width φ) : MOp φ := .binary w (.or disjoint) @[match_pattern] def xor (w : Width φ) : MOp φ := .binary w .xor -@[match_pattern] def shl (w : Width φ) : MOp φ := .binary w .shl -@[match_pattern] def lshr (w : Width φ) : MOp φ := .binary w .lshr -@[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr +@[match_pattern] def shl (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.shl nsw nuw) +@[match_pattern] def lshr (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.lshr exact) +@[match_pattern] def ashr (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.ashr exact) @[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem @[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem -@[match_pattern] def add (w : Width φ) : MOp φ := .binary w .add -@[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul -@[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub -@[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv -@[match_pattern] def udiv (w : Width φ) : MOp φ := .binary w .udiv +@[match_pattern] def add (nsw : Bool:= false) (nuw : Bool:= false) (w : Width φ) : MOp φ := .binary w (.add nsw nuw) +@[match_pattern] def mul (nsw : Bool:= false) (nuw : Bool:= false) (w : Width φ) : MOp φ := .binary w (.mul nsw nuw) +@[match_pattern] def sub (nsw : Bool:= false) (nuw : Bool:= false) (w : Width φ) : MOp φ := .binary w (.sub nsw nuw) +@[match_pattern] def sdiv (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.sdiv exact) +@[match_pattern] def udiv (exact : Bool:= false) (w : Width φ) : MOp φ := .binary w (.udiv exact) /-- Recursion principle in terms of individual operations, rather than `unary` or `binary` -/ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} @@ -139,18 +139,18 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} (not : ∀ {φ} {w : Width φ}, motive (not w)) (copy : ∀ {φ} {w : Width φ}, motive (copy w)) (and : ∀ {φ} {w : Width φ}, motive (and w)) - (or : ∀ {φ} {w : Width φ}, motive (or w)) + (or : ∀ {φ disjoint} {w : Width φ}, motive (or disjoint w)) (xor : ∀ {φ} {w : Width φ}, motive (xor w)) - (shl : ∀ {φ} {w : Width φ}, motive (shl w)) - (lshr : ∀ {φ} {w : Width φ}, motive (lshr w)) - (ashr : ∀ {φ} {w : Width φ}, motive (ashr w)) + (shl : ∀ {φ nsw nuw} {w : Width φ}, motive (shl nsw nuw w)) + (lshr : ∀ {φ exact } {w : Width φ}, motive (lshr exact w)) + (ashr : ∀ {φ exact } {w : Width φ}, motive (ashr exact w)) (urem : ∀ {φ} {w : Width φ}, motive (urem w)) (srem : ∀ {φ} {w : Width φ}, motive (srem w)) - (add : ∀ {φ} {w : Width φ}, motive (add w)) - (mul : ∀ {φ} {w : Width φ}, motive (mul w)) - (sub : ∀ {φ} {w : Width φ}, motive (sub w)) - (sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w)) - (udiv : ∀ {φ} {w : Width φ}, motive (udiv w)) + (add : ∀ {φ nsw nuw} {w : Width φ}, motive (add nsw nuw w)) + (mul : ∀ {φ nsw nuw} {w : Width φ}, motive (mul nsw nuw w)) + (sub : ∀ {φ nsw nuw} {w : Width φ}, motive (sub nsw nuw w)) + (sdiv : ∀ {φ exact } {w : Width φ}, motive (sdiv exact w)) + (udiv : ∀ {φ exact } {w : Width φ}, motive (udiv exact w)) (select : ∀ {φ} {w : Width φ}, motive (select w)) (icmp : ∀ {φ c} {w : Width φ}, motive (icmp c w)) (const : ∀ {φ v} {w : Width φ}, motive (const w v)) : @@ -159,18 +159,18 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*} | _, .not _ => not | _, .copy _ => copy | _, .and _ => and - | _, .or _ => or + | _, .or _ _ => or | _, .xor _ => xor - | _, .shl _ => shl - | _, .lshr _ => lshr - | _, .ashr _ => ashr + | _, .shl _ _ _ => shl + | _, .lshr _ _ => lshr + | _, .ashr _ _ => ashr | _, .urem _ => urem | _, .srem _ => srem - | _, .add _ => add - | _, .mul _ => mul - | _, .sub _ => sub - | _, .sdiv _ => sdiv - | _, .udiv _ => udiv + | n, .add nsw nuw w => @add n nsw nuw w + | _, .mul _ _ _ => mul + | _, .sub _ _ _ => sub + | _, .sdiv _ _ => sdiv + | _, .udiv _ _ => udiv | _, .select _ => select | _, .icmp .. => icmp | _, .const .. => const @@ -180,22 +180,22 @@ end MOp instance : ToString (MOp φ) where toString | .and _ => "and" - | .or _ => "or" + | .or _ _ => "or" | .not _ => "not" | .xor _ => "xor" - | .shl _ => "shl" - | .lshr _ => "lshr" - | .ashr _ => "ashr" + | .shl _ _ _ => "shl" + | .lshr _ _ => "lshr" + | .ashr _ _ => "ashr" | .urem _ => "urem" | .srem _ => "srem" | .select _ => "select" - | .add _ => "add" - | .mul _ => "mul" - | .sub _ => "sub" + | .add _ _ _ => "add" + | .mul _ _ _ => "mul" + | .sub _ _ _ => "sub" | .neg _ => "neg" | .copy _ => "copy" - | .sdiv _ => "sdiv" - | .udiv _ => "udiv" + | .sdiv _ _ => "sdiv" + | .udiv _ _ => "udiv" | .icmp ty _ => s!"icmp {ty}" | .const _ v => s!"const {v}" @@ -207,22 +207,22 @@ namespace Op @[match_pattern] abbrev binary (w : Nat) (op : MOp.BinaryOp) : Op := MOp.binary (.concrete w) op @[match_pattern] abbrev and : Nat → Op := MOp.and ∘ .concrete -@[match_pattern] abbrev or : Nat → Op := MOp.or ∘ .concrete +@[match_pattern] abbrev or (disjoint : Bool) : Nat → Op := MOp.or disjoint ∘ .concrete @[match_pattern] abbrev not : Nat → Op := MOp.not ∘ .concrete @[match_pattern] abbrev xor : Nat → Op := MOp.xor ∘ .concrete -@[match_pattern] abbrev shl : Nat → Op := MOp.shl ∘ .concrete -@[match_pattern] abbrev lshr : Nat → Op := MOp.lshr ∘ .concrete -@[match_pattern] abbrev ashr : Nat → Op := MOp.ashr ∘ .concrete +@[match_pattern] abbrev shl (nsw nuw : Bool) : Nat → Op := MOp.shl nsw nuw ∘ .concrete +@[match_pattern] abbrev lshr (exact : Bool) : Nat → Op := MOp.lshr exact ∘ .concrete +@[match_pattern] abbrev ashr (exact : Bool) : Nat → Op := MOp.ashr exact ∘ .concrete @[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete @[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete @[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete -@[match_pattern] abbrev add : Nat → Op := MOp.add ∘ .concrete -@[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete -@[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete +@[match_pattern] abbrev add (nuw : Bool := false) (nsw : Bool := false) : Nat → Op := (MOp.add nsw nuw) ∘ .concrete +@[match_pattern] abbrev mul (nsw nuw : Bool) : Nat → Op := MOp.mul nsw nuw ∘ .concrete +@[match_pattern] abbrev sub (nsw nuw : Bool) : Nat → Op := MOp.sub nsw nuw ∘ .concrete @[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete @[match_pattern] abbrev copy : Nat → Op := MOp.copy ∘ .concrete -@[match_pattern] abbrev sdiv : Nat → Op := MOp.sdiv ∘ .concrete -@[match_pattern] abbrev udiv : Nat → Op := MOp.udiv ∘ .concrete +@[match_pattern] abbrev sdiv (exact : Bool) : Nat → Op := MOp.sdiv exact ∘ .concrete +@[match_pattern] abbrev udiv (exact : Bool) : Nat → Op := MOp.udiv exact ∘ .concrete @[match_pattern] abbrev icmp (c : IntPredicate) : Nat → Op := MOp.icmp c ∘ .concrete @[match_pattern] abbrev const (w : Nat) (val : ℤ) : Op := MOp.const (.concrete w) val @@ -269,20 +269,23 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig | Op.not _ => LLVM.not (op.getN 0) | Op.neg _ => LLVM.neg (op.getN 0) | Op.and _ => LLVM.and (op.getN 0) (op.getN 1) - | Op.or _ => LLVM.or (op.getN 0) (op.getN 1) + | Op.or d _ => LLVM.or (op.getN 0) (op.getN 1) d | Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1) - | Op.shl _ => LLVM.shl (op.getN 0) (op.getN 1) - | Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1) - | Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1) - | Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1) - | Op.add _ => LLVM.add (op.getN 0) (op.getN 1) - | Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1) - | Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1) - | Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1) + | Op.shl nsw nuw _ => LLVM.shl (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw} + | Op.lshr e _ => LLVM.lshr (op.getN 0) (op.getN 1) + | Op.ashr e _ => LLVM.ashr (op.getN 0) (op.getN 1) + | Op.sub nsw nuw _ => LLVM.sub (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw} + | Op.add a b _ => LLVM.add (op.getN 0) (op.getN 1) {nsw := a, nuw := b} + + -- | (@MOp.binary (ConcreteOrMVar.concrete _) (@MOp.BinaryOp.add true true)), _ => sorry + | Op.mul nsw nuw _ => LLVM.mul (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw} + | Op.sdiv e _ => LLVM.sdiv (op.getN 0) (op.getN 1) {exact := e} + | Op.udiv e _ => LLVM.udiv (op.getN 0) (op.getN 1) {exact := e} | Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1) | Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1) | Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1) | Op.select _ => LLVM.select (op.getN 0) (op.getN 1) (op.getN 2) + -- | (@MOp.binary (ConcreteOrMVar.concrete _) (@MOp.BinaryOp.add true false)), _ => sorry instance : DialectDenote LLVM := ⟨ fun o args _ => Op.denote o args diff --git a/SSA/Projects/InstCombine/ComWrappers.lean b/SSA/Projects/InstCombine/ComWrappers.lean index 2f6dc4932..f7e5d0eef 100644 --- a/SSA/Projects/InstCombine/ComWrappers.lean +++ b/SSA/Projects/InstCombine/ComWrappers.lean @@ -65,7 +65,7 @@ def or {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.or w) + (op := InstCombine.MOp.or false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -93,7 +93,7 @@ def shl {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.shl w) + (op := InstCombine.MOp.shl false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -107,7 +107,7 @@ def lshr {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.lshr w) + (op := InstCombine.MOp.lshr false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -121,7 +121,7 @@ def ashr {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.ashr w) + (op := InstCombine.MOp.ashr false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -135,7 +135,7 @@ def sub {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.sub w) + (op := InstCombine.MOp.sub false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -149,7 +149,7 @@ def add {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.add w) + (op := InstCombine.MOp.add false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -163,7 +163,7 @@ def mul {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.mul w) + (op := InstCombine.MOp.mul false false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -177,7 +177,7 @@ def sdiv {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.sdiv w) + (op := InstCombine.MOp.sdiv false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) @@ -191,7 +191,7 @@ def udiv {Γ : Ctxt _} (w : ℕ) (l r : Nat) := by get_elem_tactic) : Expr InstCombine.LLVM Γ .pure (InstCombine.Ty.bitvec w) := Expr.mk - (op := InstCombine.MOp.udiv w) + (op := InstCombine.MOp.udiv false w) (eff_le := by constructor) (ty_eq := rfl) (args := .cons ⟨l, lp⟩ <| .cons ⟨r, rp⟩ .nil) diff --git a/SSA/Projects/InstCombine/LLVM/EDSL.lean b/SSA/Projects/InstCombine/LLVM/EDSL.lean index 7303e242e..e303e65b0 100644 --- a/SSA/Projects/InstCombine/LLVM/EDSL.lean +++ b/SSA/Projects/InstCombine/LLVM/EDSL.lean @@ -97,20 +97,45 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) : else let v₂ : Γ.Var (.bitvec w) := (by simpa using hty) ▸ v₂ - let (op : MOp.BinaryOp ⊕ LLVM.IntPredicate) ← match opStx.name with + let (op : (MOp.BinaryOp) ⊕ LLVM.IntPredicate) ← match opStx.name with | "llvm.and" => pure <| Sum.inl .and - | "llvm.or" => pure <| Sum.inl .or + -- we do nothing, MLIR does not support this syntax, I don't think + | "llvm.or" => pure <| Sum.inl (.or false) | "llvm.xor" => pure <| Sum.inl .xor - | "llvm.shl" => pure <| Sum.inl .shl - | "llvm.lshr" => pure <| Sum.inl .lshr - | "llvm.ashr" => pure <| Sum.inl .ashr + | "llvm.shl" => pure <| Sum.inl (.shl false false) + | "llvm.lshr" => pure <| Sum.inl (.lshr false) + | "llvm.ashr" => pure <| Sum.inl (.ashr false) | "llvm.urem" => pure <| Sum.inl .urem | "llvm.srem" => pure <| Sum.inl .srem - | "llvm.add" => pure <| Sum.inl .add - | "llvm.mul" => pure <| Sum.inl .mul - | "llvm.sub" => pure <| Sum.inl .sub - | "llvm.sdiv" => pure <| Sum.inl .sdiv - | "llvm.udiv" => pure <| Sum.inl .udiv + | "llvm.add" => do + let att := opStx.attrs.getAttr "overflowFlags" + match att with + | .none => pure <| Sum.inl (MOp.BinaryOp.add false false) + | .some y => match y with + | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.add true false) + | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.add false true) + | (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed" + | _ => throw <| .generic s!"flag not allowed" + | "llvm.mul" =>do + let att := opStx.attrs.getAttr "overflowFlags" + match att with + | .none => pure <| Sum.inl (MOp.BinaryOp.mul false false) + | .some y => match y with + | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.mul true false) + | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.mul false true) + | (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed" + | _ => throw <| .generic s!"flag not allowed" + | "llvm.sub" =>do + let att := opStx.attrs.getAttr "overflowFlags" + match att with + | .none => pure <| Sum.inl (MOp.BinaryOp.sub false false) + | .some y => match y with + | (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.sub true false) + | (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.sub false true) + | (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed" + | _ => throw <| .generic s!"flag not allowed" + | "llvm.sdiv" => pure <| Sum.inl (.sdiv false) + | "llvm.udiv" => pure <| Sum.inl (.udiv false) | "llvm.icmp.eq" => pure <| Sum.inr LLVM.IntPredicate.eq | "llvm.icmp.ne" => pure <| Sum.inr LLVM.IntPredicate.ne | "llvm.icmp.ugt" => pure <| Sum.inr LLVM.IntPredicate.ugt diff --git a/SSA/Projects/InstCombine/LLVM/Enumerator.lean b/SSA/Projects/InstCombine/LLVM/Enumerator.lean index 74afaf8bb..d28c3bd1e 100644 --- a/SSA/Projects/InstCombine/LLVM/Enumerator.lean +++ b/SSA/Projects/InstCombine/LLVM/Enumerator.lean @@ -182,18 +182,18 @@ def generateRawSemantics : IO Unit := do rows := rows.append (icmpRows pred) -- rows := rows.append (binopRows "and" (fun w a b => InstCombine.Op.denote (.and w) [a,b]ₕ)) - rows := rows.append (binopRows "or" (fun w a b => InstCombine.Op.denote (.or w) [a,b]ₕ)) + rows := rows.append (binopRows "or" (fun w a b => InstCombine.Op.denote (.or false w) [a,b]ₕ)) rows := rows.append (binopRows "xor" (fun w a b => InstCombine.Op.denote (.xor w) [a,b]ₕ)) - rows := rows.append (binopRows "add" (fun w a b => InstCombine.Op.denote (.add w) [a,b]ₕ)) - rows := rows.append (binopRows "sub" (fun w a b => InstCombine.Op.denote (.sub w) [a,b]ₕ)) - rows := rows.append (binopRows "mul" (fun w a b => InstCombine.Op.denote (.mul w) [a,b]ₕ)) - rows := rows.append (binopRows "udiv" (fun w a b => InstCombine.Op.denote (.udiv w) [a,b]ₕ)) - rows := rows.append (binopRows "sdiv" (fun w a b => InstCombine.Op.denote (.sdiv w) [a,b]ₕ)) + rows := rows.append (binopRows "add" (fun w a b => InstCombine.Op.denote (.add false false w) [a,b]ₕ)) + rows := rows.append (binopRows "sub" (fun w a b => InstCombine.Op.denote (.sub false false w) [a,b]ₕ)) + rows := rows.append (binopRows "mul" (fun w a b => InstCombine.Op.denote (.mul false false w) [a,b]ₕ)) + rows := rows.append (binopRows "udiv" (fun w a b => InstCombine.Op.denote (.udiv false w) [a,b]ₕ)) + rows := rows.append (binopRows "sdiv" (fun w a b => InstCombine.Op.denote (.sdiv false w) [a,b]ₕ)) rows := rows.append (binopRows "urem" (fun w a b => InstCombine.Op.denote (.urem w) [a,b]ₕ)) rows := rows.append (binopRows "srem" (fun w a b => InstCombine.Op.denote (.srem w) [a,b]ₕ)) - rows := rows.append (binopRows "shl" (fun w a b => InstCombine.Op.denote (.shl w) [a,b]ₕ)) - rows := rows.append (binopRows "lshr" (fun w a b => InstCombine.Op.denote (.lshr w) [a,b]ₕ)) - rows := rows.append (binopRows "ashr" (fun w a b => InstCombine.Op.denote (.ashr w) [a,b]ₕ)) + rows := rows.append (binopRows "shl" (fun w a b => InstCombine.Op.denote (.shl false false w) [a,b]ₕ)) + rows := rows.append (binopRows "lshr" (fun w a b => InstCombine.Op.denote (.lshr false w) [a,b]ₕ)) + rows := rows.append (binopRows "ashr" (fun w a b => InstCombine.Op.denote (.ashr false w) [a,b]ₕ)) rows.toList |>.map toString |> "\n".intercalate |> stream.putStr return () diff --git a/SSA/Projects/InstCombine/LLVM/Semantics.lean b/SSA/Projects/InstCombine/LLVM/Semantics.lean index 86bb109e3..4fff86b42 100644 --- a/SSA/Projects/InstCombine/LLVM/Semantics.lean +++ b/SSA/Projects/InstCombine/LLVM/Semantics.lean @@ -42,10 +42,13 @@ def or? {w : Nat} (x y : BitVec w) : IntW w := theorem or?_eq : LLVM.or? a b = .some (BitVec.or a b) := rfl @[simp_llvm_option] -def or {w : Nat} (x y : IntW w) : IntW w := do +def or {w : Nat} (x y : IntW w) (disjoint : Bool := false) : IntW w := do let x' ← x let y' ← y - or? x' y' + if disjoint ∧ x' &&& y' ≠ 0 then + .none + else + or? x' y' /-- The ‘xor’ instruction returns the bitwise logical exclusive or of its two @@ -70,6 +73,11 @@ The value produced is the integer sum of the two operands. If the sum has unsigned overflow, the result returned is the mathematical result modulo 2n, where n is the bit width of the result. Because LLVM integers use a two’s complement representation, this instruction is appropriate for both signed and unsigned integers. -/ + +structure AddParams where + nuw : Bool := false + nsw : Bool := false + @[simp_llvm] def add? {w : Nat} (x y : BitVec w) : IntW w := pure <| x + y @@ -78,11 +86,23 @@ def add? {w : Nat} (x y : BitVec w) : IntW w := theorem add?_eq : LLVM.add? a b = .some (BitVec.add a b) := rfl @[simp_llvm_option] -def add {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - add? x' y' - +def add {w : Nat} (x y : IntW w) (params : AddParams := {}) : IntW w := do + let x ← x + let y ← y + if (params.nsw ∧ (x.toInt + y.toInt) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w) ∨ ( params.nuw ∧ (x.toNat + y.toNat) ≥ 2^w) then + .none + else + add? x y + +set_option allowUnsafeReducibility true +@[simp, reducible] +theorem add_reduce (x y : IntW w) : add x y = match x , y with + | .none , _ => .none + | _ , .none => none + | .some a , .some b => .some (a + b) := by + rcases x + all_goals (cases y) + all_goals (try simp ; try rfl) /-- The value produced is the integer difference of the two operands. If the difference has unsigned overflow, the result returned is the mathematical result modulo 2n, where n is the bit width of the result. @@ -91,15 +111,24 @@ Because LLVM integers use a two’s complement representation, this instruction @[simp_llvm] def sub? {w : Nat} (x y : BitVec w) : IntW w := pure <| x - y +structure SubParams where + nuw : Bool := false + nsw : Bool := false @[simp_llvm_option] theorem sub?_eq : LLVM.sub? a b = .some (BitVec.sub a b) := rfl @[simp_llvm_option] -def sub {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - sub? x' y' +def sub {w : Nat} (x y : IntW w) (params : SubParams := {}) : IntW w := do + let x ← x + let y ← y + -- Check for unsigned overflow if nuw is set + -- Check for signed overflow if nsw is set + if (params.nuw ∧ x.toNat ≥ y.toNat) ∨ (params.nsw ∧ + ((x.toInt - y.toInt) ≥ -(2^(w-1)) ∧ (x.toInt - y.toInt) < 2^(w-1))) then + .none + else + sub? x y /-- The value produced is the integer product of the two operands. @@ -112,6 +141,10 @@ result for both signed and unsigned integers. If a full product (e.g. i32 * i32 -> i64) is needed, the operands should be sign-extended or zero-extended as appropriate to the width of the full product. -/ + +structure MulParams where + nuw : Bool := false + nsw : Bool := false @[simp_llvm] def mul? {w : Nat} (x y : BitVec w) : IntW w := pure <| x * y @@ -120,16 +153,26 @@ def mul? {w : Nat} (x y : BitVec w) : IntW w := theorem mul?_eq : LLVM.mul? a b = .some (BitVec.mul a b) := rfl @[simp_llvm_option] -def mul {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - mul? x' y' +def mul {w : Nat} (x y : IntW w) (params : MulParams := {}) : IntW w := do + let x ← x + let y ← y + -- Check for unsigned overflow if nuw is set + -- Check for signed overflow if nsw is set + if (params.nuw ∧ (x * y).toNat ≤ (2^w)) ∨ (params.nsw ∧ (x * y).toNat < (2^(w - 1)) ∧ (x * y).toNat ≥ (2^(w - 1))) then + .none + else + -- Return the result + mul? x y /-- The value produced is the unsigned integer quotient of the two operands. Note that unsigned integer division and signed integer division are distinct operations; for signed integer division, use ‘sdiv’. Division by zero is undefined behavior. -/ + +structure UdivParams where + exact : Bool := false + @[simp_llvm] def udiv? {w : Nat} (x y : BitVec w) : IntW w := match y.toNat with @@ -137,10 +180,14 @@ def udiv? {w : Nat} (x y : BitVec w) : IntW w := | _ => pure <| BitVec.ofInt w (x.toNat / y.toNat) @[simp_llvm_option] -def udiv {w : Nat} (x y : IntW w) : IntW w := do +def udiv {w : Nat} (x y : IntW w) (params : UdivParams := {}) : IntW w := do let x' ← x let y' ← y - udiv? x' y' + --- If the exact keyword is present, the result value of the udiv is a poison value if %op1 is not a multiple of %op2 + if params.exact ∧ ¬(x'.toNat ∣ y'.toNat) then + .none + else + udiv? x' y' def intMin (w : Nat) : BitVec w := - BitVec.ofNat w (2^(w - 1)) @@ -166,27 +213,41 @@ at width 2, -4 / -1 is considered overflow! -- only way overflow can happen is (INT_MIN / -1). -- but we do not consider overflow when `w=1`, because `w=1` only has a sign bit, so there -- is no magniture to overflow. +structure SdivParams where + exact : Bool := false @[simp_llvm] def sdiv? {w : Nat} (x y : BitVec w) : IntW w := - if y == 0 || (w != 1 && x == (intMin w) && y == -1) + if y = 0 ∨ (w ≠ 1 ∧ x = (intMin w) ∧ y = -1) then .none else pure (BitVec.sdiv x y) theorem sdiv?_denom_zero_eq_none {w : Nat} (x : BitVec w) : LLVM.sdiv? x 0 = none := by - simp [LLVM.sdiv?, BitVec.sdiv] + simp + simp [LLVM.sdiv?] theorem sdiv?_eq_pure_of_neq_allOnes {x y : BitVec w} (hy : y ≠ 0) (hx : LLVM.intMin w ≠ x) : LLVM.sdiv? x y = pure (BitVec.sdiv x y) := by simp [LLVM.sdiv?] tauto + @[simp_llvm_option] -def sdiv {w : Nat} (x y : IntW w) : IntW w := do +def sdiv {w : Nat} (x y : IntW w) (params : SdivParams := {}) : IntW w := do let x' ← x let y' ← y - sdiv? x' y' - + if (params.exact ∧ ¬ (x'.toNat ∣ y'.toNat)) then + .none + else + sdiv? x' y' +@[simp, reducible] +theorem sdiv_reduce (x y : IntW w) : sdiv x y = match x , y with + | .none , _ => .none + | _ , .none => none + | .some a , .some b => sdiv? a b := by + cases x + all_goals (cases y) + all_goals (try simp ; try rfl) -- Probably not a Mathlib worthy name, not sure how you'd mathlibify the precondition @[simp_llvm] theorem sdiv?_eq_div_if {w : Nat} {x y : BitVec w} : @@ -195,7 +256,8 @@ theorem sdiv?_eq_div_if {w : Nat} {x y : BitVec w} : then none else pure <| BitVec.sdiv x y := by - simp [sdiv?]; split_ifs <;> try tauto + simp [sdiv?] + /-- This instruction returns the unsigned integer remainder of a division. This instruction always performs an unsigned division to get the remainder. @@ -282,6 +344,9 @@ The value produced is op1 * 2^op2 mod 2n, where n is the width of the result. If op2 is (statically or dynamically) equal to or larger than the number of bits in op1, this instruction returns a poison value. -/ +structure ShlParams where + nuw : Bool := false + nsw : Bool := false @[simp_llvm] def shl? {n} (op1 : BitVec n) (op2 : BitVec n) : IntW n := let bits := op2.toNat -- should this be toInt? @@ -290,10 +355,13 @@ def shl? {n} (op1 : BitVec n) (op2 : BitVec n) : IntW n := @[simp_llvm_option] -def shl {w : Nat} (x y : IntW w) : IntW w := do - let x' ← x - let y' ← y - shl? x' y' +def shl {w : Nat} (x y : IntW w) (params : ShlParams := {}): IntW w := do + let x ← x + let y ← y + if (params.nsw ∧ ¬ ((x.toInt * 2 ^ y.toNat) < -(2^(w-1)) ∧ (x.toInt + y.toInt) ≥ 2^w)) ∨ (params.nuw ∧ ¬ ((x.toNat * 2 ^ y.toNat) ≥ 2^w)) then + .none + else + shl? x y /-- This instruction always performs a logical shift right operation. diff --git a/SSA/Projects/InstCombine/Test.lean b/SSA/Projects/InstCombine/Test.lean index 009889122..ccf4c26d2 100644 --- a/SSA/Projects/InstCombine/Test.lean +++ b/SSA/Projects/InstCombine/Test.lean @@ -102,10 +102,9 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.const (ConcreteOrMVa /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.ashr) (%0, %2) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.ashr false) (%0, %2) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 3) op2 ["1", "0", "arg0"] - /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) @@ -113,10 +112,11 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary -/ #guard_msgs in #eval mkExpr (Γn 4) op3 ["2", "1", "0", "arg0"] + /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.add) (%4, %3) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.add false false) (%4, %3) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 5) op4 ["3", "2", "1", "0", "arg0"] @@ -151,7 +151,7 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.const (ConcreteOrMVa /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.ashr) (%0, %2) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.ashr false) (%0, %2) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 3) (ops.get! 2) ["1", "0", "arg0"] @@ -165,7 +165,7 @@ info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary /-- info: Except.ok ⟨EffectKind.pure, ⟨i32, InstCombine.MOp.binary (ConcreteOrMVar.concrete 32) - (InstCombine.MOp.BinaryOp.add) (%4, %3) : (i32, i32) → (i32)⟩⟩ + (InstCombine.MOp.BinaryOp.add false false) (%4, %3) : (i32, i32) → (i32)⟩⟩ -/ #guard_msgs in #eval mkExpr (Γn 5) (ops.get! 4) ["3", "2", "1", "0", "arg0"] @@ -186,13 +186,13 @@ info: ⟨[MTy.bitvec (ConcreteOrMVar.concrete 32)], Com.var (Expr.mk (MOp.const (ConcreteOrMVar.concrete 32) (Int.ofNat 8)) ⋯ ⋯ HVector.nil HVector.nil) (Com.var (Expr.mk (MOp.const (ConcreteOrMVar.concrete 32) (Int.ofNat 31)) ⋯ ⋯ HVector.nil HVector.nil) (Com.var - (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) MOp.BinaryOp.ashr) ⋯ ⋯ (⟨2, ⋯⟩::ₕ(⟨0, ⋯⟩::ₕHVector.nil)) - HVector.nil) + (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) (MOp.BinaryOp.ashr false)) ⋯ ⋯ + (⟨2, ⋯⟩::ₕ(⟨0, ⋯⟩::ₕHVector.nil)) HVector.nil) (Com.var (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) MOp.BinaryOp.and) ⋯ ⋯ (⟨0, ⋯⟩::ₕ(⟨2, ⋯⟩::ₕHVector.nil)) HVector.nil) (Com.var - (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) MOp.BinaryOp.add) ⋯ ⋯ + (Expr.mk (MOp.binary (ConcreteOrMVar.concrete 32) (MOp.BinaryOp.add false false)) ⋯ ⋯ (⟨0, ⋯⟩::ₕ(⟨1, ⋯⟩::ₕHVector.nil)) HVector.nil) (Com.ret ⟨0, ⋯⟩)))))⟩⟩⟩ -/