Skip to content

Commit

Permalink
add all overflow flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Atticus Kuhn committed Jul 17, 2024
1 parent 5730a60 commit 3ad683a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 68 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
SSA/Projects/InstCombine/tests/LLVM/** linguist-generated=true
SSA/Projects/InstCombine/all.lean linguist-generated=true
test/LLVMDialect/InstCombine/** linguist-generated=true
113 changes: 56 additions & 57 deletions SSA/Projects/InstCombine/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,18 @@ deriving Repr, DecidableEq, Inhabited
/-- Homogeneous, binary operations -/
inductive MOp.BinaryOp : Type
| and
| or
| or (disjoint : Bool)
| xor
| shl
| lshr
| ashr
| shl (nsw : Bool) (nuw : Bool)
| lshr (exact : Bool)
| ashr (exact : Bool)
| urem
| srem
| add (nsw : Bool) (nuw : Bool)
| mul
| sub
| sdiv
| udiv
| mul (nsw : Bool) (nuw : Bool)
| sub (nsw : Bool) (nuw : Bool)
| sdiv (exact : Bool)
| udiv (exact : Bool)
deriving Repr, DecidableEq, Inhabited

-- See: https://releases.llvm.org/14.0.0/docs/LangRef.html#bitwise-binary-operations
Expand All @@ -120,37 +120,37 @@ namespace MOp
@[match_pattern] def copy (w : Width φ) : MOp φ := .unary w .copy

@[match_pattern] def and (w : Width φ) : MOp φ := .binary w .and
@[match_pattern] def or (w : Width φ) : MOp φ := .binary w .or
@[match_pattern] def or (disjoint : Bool) (w : Width φ) : MOp φ := .binary w (.or disjoint)
@[match_pattern] def xor (w : Width φ) : MOp φ := .binary w .xor
@[match_pattern] def shl (w : Width φ) : MOp φ := .binary w .shl
@[match_pattern] def lshr (w : Width φ) : MOp φ := .binary w .lshr
@[match_pattern] def ashr (w : Width φ) : MOp φ := .binary w .ashr
@[match_pattern] def shl (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.shl nsw nuw)
@[match_pattern] def lshr (exact : Bool) (w : Width φ) : MOp φ := .binary w (.lshr exact)
@[match_pattern] def ashr (exact : Bool) (w : Width φ) : MOp φ := .binary w (.ashr exact)
@[match_pattern] def urem (w : Width φ) : MOp φ := .binary w .urem
@[match_pattern] def srem (w : Width φ) : MOp φ := .binary w .srem
@[match_pattern] def add (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.add nsw nuw)
@[match_pattern] def mul (w : Width φ) : MOp φ := .binary w .mul
@[match_pattern] def sub (w : Width φ) : MOp φ := .binary w .sub
@[match_pattern] def sdiv (w : Width φ) : MOp φ := .binary w .sdiv
@[match_pattern] def udiv (w : Width φ) : MOp φ := .binary w .udiv
@[match_pattern] def mul (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.mul nsw nuw)
@[match_pattern] def sub (nsw : Bool) (nuw : Bool) (w : Width φ) : MOp φ := .binary w (.sub nsw nuw)
@[match_pattern] def sdiv (exact : Bool) (w : Width φ) : MOp φ := .binary w (.sdiv exact)
@[match_pattern] def udiv (exact : Bool) (w : Width φ) : MOp φ := .binary w (.udiv exact)

/-- Recursion principle in terms of individual operations, rather than `unary` or `binary` -/
def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
(neg : ∀ {φ} {w : Width φ}, motive (neg w))
(not : ∀ {φ} {w : Width φ}, motive (not w))
(copy : ∀ {φ} {w : Width φ}, motive (copy w))
(and : ∀ {φ} {w : Width φ}, motive (and w))
(or : ∀ {φ} {w : Width φ}, motive (or w))
(or : ∀ {φ disjoint} {w : Width φ}, motive (or disjoint w))
(xor : ∀ {φ} {w : Width φ}, motive (xor w))
(shl : ∀ {φ} {w : Width φ}, motive (shl w))
(lshr : ∀ {φ} {w : Width φ}, motive (lshr w))
(ashr : ∀ {φ} {w : Width φ}, motive (ashr w))
(shl : ∀ {φ nsw nuw} {w : Width φ}, motive (shl nsw nuw w))
(lshr : ∀ {φ exact } {w : Width φ}, motive (lshr exact w))
(ashr : ∀ {φ exact } {w : Width φ}, motive (ashr exact w))
(urem : ∀ {φ} {w : Width φ}, motive (urem w))
(srem : ∀ {φ} {w : Width φ}, motive (srem w))
(add : ∀ {φ nsw nuw} {w : Width φ}, motive (add nsw nuw w))
(mul : ∀ {φ} {w : Width φ}, motive (mul w))
(sub : ∀ {φ} {w : Width φ}, motive (sub w))
(sdiv : ∀ {φ} {w : Width φ}, motive (sdiv w))
(udiv : ∀ {φ} {w : Width φ}, motive (udiv w))
(mul : ∀ {φ nsw nuw} {w : Width φ}, motive (mul nsw nuw w))
(sub : ∀ {φ nsw nuw} {w : Width φ}, motive (sub nsw nuw w))
(sdiv : ∀ {φ exact } {w : Width φ}, motive (sdiv exact w))
(udiv : ∀ {φ exact } {w : Width φ}, motive (udiv exact w))
(select : ∀ {φ} {w : Width φ}, motive (select w))
(icmp : ∀ {φ c} {w : Width φ}, motive (icmp c w))
(const : ∀ {φ v} {w : Width φ}, motive (const w v)) :
Expand All @@ -159,18 +159,18 @@ def deepCasesOn {motive : ∀ {φ}, MOp φ → Sort*}
| _, .not _ => not
| _, .copy _ => copy
| _, .and _ => and
| _, .or _ => or
| _, .or _ _ => or
| _, .xor _ => xor
| _, .shl _ => shl
| _, .lshr _ => lshr
| _, .ashr _ => ashr
| _, .shl _ _ _ => shl
| _, .lshr _ _ => lshr
| _, .ashr _ _ => ashr
| _, .urem _ => urem
| _, .srem _ => srem
| n, .add nsw nuw w => @add n nsw nuw w
| _, .mul _ => mul
| _, .sub _ => sub
| _, .sdiv _ => sdiv
| _, .udiv _ => udiv
| _, .mul _ _ _ => mul
| _, .sub _ _ _ => sub
| _, .sdiv _ _ => sdiv
| _, .udiv _ _ => udiv
| _, .select _ => select
| _, .icmp .. => icmp
| _, .const .. => const
Expand All @@ -180,22 +180,22 @@ end MOp
instance : ToString (MOp φ) where
toString
| .and _ => "and"
| .or _ => "or"
| .or _ _ => "or"
| .not _ => "not"
| .xor _ => "xor"
| .shl _ => "shl"
| .lshr _ => "lshr"
| .ashr _ => "ashr"
| .shl _ _ _ => "shl"
| .lshr _ _ => "lshr"
| .ashr _ _ => "ashr"
| .urem _ => "urem"
| .srem _ => "srem"
| .select _ => "select"
| .add _ _ _ => "add"
| .mul _ => "mul"
| .sub _ => "sub"
| .mul _ _ _ => "mul"
| .sub _ _ _ => "sub"
| .neg _ => "neg"
| .copy _ => "copy"
| .sdiv _ => "sdiv"
| .udiv _ => "udiv"
| .sdiv _ _ => "sdiv"
| .udiv _ _ => "udiv"
| .icmp ty _ => s!"icmp {ty}"
| .const _ v => s!"const {v}"

Expand All @@ -207,22 +207,22 @@ namespace Op
@[match_pattern] abbrev binary (w : Nat) (op : MOp.BinaryOp) : Op := MOp.binary (.concrete w) op

@[match_pattern] abbrev and : Nat → Op := MOp.and ∘ .concrete
@[match_pattern] abbrev or : Nat → Op := MOp.or ∘ .concrete
@[match_pattern] abbrev or (disjoint : Bool) : Nat → Op := MOp.or disjoint ∘ .concrete
@[match_pattern] abbrev not : Nat → Op := MOp.not ∘ .concrete
@[match_pattern] abbrev xor : Nat → Op := MOp.xor ∘ .concrete
@[match_pattern] abbrev shl : Nat → Op := MOp.shl ∘ .concrete
@[match_pattern] abbrev lshr : Nat → Op := MOp.lshr ∘ .concrete
@[match_pattern] abbrev ashr : Nat → Op := MOp.ashr ∘ .concrete
@[match_pattern] abbrev shl (nsw nuw : Bool) : Nat → Op := MOp.shl nsw nuw ∘ .concrete
@[match_pattern] abbrev lshr (exact : Bool) : Nat → Op := MOp.lshr exact ∘ .concrete
@[match_pattern] abbrev ashr (exact : Bool) : Nat → Op := MOp.ashr exact ∘ .concrete
@[match_pattern] abbrev urem : Nat → Op := MOp.urem ∘ .concrete
@[match_pattern] abbrev srem : Nat → Op := MOp.srem ∘ .concrete
@[match_pattern] abbrev select : Nat → Op := MOp.select ∘ .concrete
@[match_pattern] abbrev add (nuw : Bool := false) (nsw : Bool := false) : Nat → Op := (MOp.add nsw nuw) ∘ .concrete
@[match_pattern] abbrev mul : Nat → Op := MOp.mul ∘ .concrete
@[match_pattern] abbrev sub : Nat → Op := MOp.sub ∘ .concrete
@[match_pattern] abbrev mul (nsw nuw : Bool) : Nat → Op := MOp.mul nsw nuw ∘ .concrete
@[match_pattern] abbrev sub (nsw nuw : Bool) : Nat → Op := MOp.sub nsw nuw ∘ .concrete
@[match_pattern] abbrev neg : Nat → Op := MOp.neg ∘ .concrete
@[match_pattern] abbrev copy : Nat → Op := MOp.copy ∘ .concrete
@[match_pattern] abbrev sdiv : Nat → Op := MOp.sdiv ∘ .concrete
@[match_pattern] abbrev udiv : Nat → Op := MOp.udiv ∘ .concrete
@[match_pattern] abbrev sdiv (exact : Bool) : Nat → Op := MOp.sdiv exact ∘ .concrete
@[match_pattern] abbrev udiv (exact : Bool) : Nat → Op := MOp.udiv exact ∘ .concrete

@[match_pattern] abbrev icmp (c : IntPredicate) : Nat → Op := MOp.icmp c ∘ .concrete
@[match_pattern] abbrev const (w : Nat) (val : ℤ) : Op := MOp.const (.concrete w) val
Expand Down Expand Up @@ -269,19 +269,18 @@ def Op.denote (o : LLVM.Op) (op : HVector TyDenote.toType (DialectSignature.sig
| Op.not _ => LLVM.not (op.getN 0)
| Op.neg _ => LLVM.neg (op.getN 0)
| Op.and _ => LLVM.and (op.getN 0) (op.getN 1)
| Op.or _ => LLVM.or (op.getN 0) (op.getN 1)
| Op.or d _ => LLVM.or (op.getN 0) (op.getN 1) d
| Op.xor _ => LLVM.xor (op.getN 0) (op.getN 1)
| Op.shl _ => LLVM.shl (op.getN 0) (op.getN 1)
| Op.lshr _ => LLVM.lshr (op.getN 0) (op.getN 1)
| Op.ashr _ => LLVM.ashr (op.getN 0) (op.getN 1)
| Op.sub _ => LLVM.sub (op.getN 0) (op.getN 1)
-- | Op.add _ => LLVM.add (op.getN 0) (op.getN 1)
| Op.shl nsw nuw _ => LLVM.shl (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw}
| Op.lshr e _ => LLVM.lshr (op.getN 0) (op.getN 1)
| Op.ashr e _ => LLVM.ashr (op.getN 0) (op.getN 1)
| Op.sub nsw nuw _ => LLVM.sub (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw}
| Op.add a b _ => LLVM.add (op.getN 0) (op.getN 1) {nsw := a, nuw := b}

-- | (@MOp.binary (ConcreteOrMVar.concrete _) (@MOp.BinaryOp.add true true)), _ => sorry
| Op.mul _ => LLVM.mul (op.getN 0) (op.getN 1)
| Op.sdiv _ => LLVM.sdiv (op.getN 0) (op.getN 1)
| Op.udiv _ => LLVM.udiv (op.getN 0) (op.getN 1)
| Op.mul nsw nuw _ => LLVM.mul (op.getN 0) (op.getN 1) { nsw := nsw , nuw := nuw}
| Op.sdiv e _ => LLVM.sdiv (op.getN 0) (op.getN 1)
| Op.udiv e _ => LLVM.udiv (op.getN 0) (op.getN 1)
| Op.urem _ => LLVM.urem (op.getN 0) (op.getN 1)
| Op.srem _ => LLVM.srem (op.getN 0) (op.getN 1)
| Op.icmp c _ => LLVM.icmp c (op.getN 0) (op.getN 1)
Expand Down
36 changes: 25 additions & 11 deletions SSA/Projects/InstCombine/LLVM/EDSL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :

let (op : (MOp.BinaryOp) ⊕ LLVM.IntPredicate) ← match opStx.name with
| "llvm.and" => pure <| Sum.inl .and
| "llvm.or" => pure <| Sum.inl .or
-- we do nothing, MLIR does not support this syntax, I don't think
| "llvm.or" => pure <| Sum.inl (.or false)
| "llvm.xor" => pure <| Sum.inl .xor
| "llvm.shl" => pure <| Sum.inl .shl
| "llvm.lshr" => pure <| Sum.inl .lshr
| "llvm.ashr" => pure <| Sum.inl .ashr
| "llvm.shl" => pure <| Sum.inl (.shl false false)
| "llvm.lshr" => pure <| Sum.inl (.lshr false)
| "llvm.ashr" => pure <| Sum.inl (.ashr false)
| "llvm.urem" => pure <| Sum.inl .urem
| "llvm.srem" => pure <| Sum.inl .srem
| "llvm.add" => do
-- sorry
let att := opStx.attrs.getAttr "overflowFlags"
match att with
| .none => pure <| Sum.inl (MOp.BinaryOp.add false false)
Expand All @@ -116,12 +116,26 @@ def mkExpr (Γ : Ctxt (MetaLLVM φ).Ty) (opStx : MLIR.AST.Op φ) :
| (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.add false true)
| (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed"
| _ => throw <| .generic s!"flag not allowed"
-- sorry

| "llvm.mul" => pure <| Sum.inl .mul
| "llvm.sub" => pure <| Sum.inl .sub
| "llvm.sdiv" => pure <| Sum.inl .sdiv
| "llvm.udiv" => pure <| Sum.inl .udiv
| "llvm.mul" =>do
let att := opStx.attrs.getAttr "overflowFlags"
match att with
| .none => pure <| Sum.inl (MOp.BinaryOp.add false false)
| .some y => match y with
| (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.mul true false)
| (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.mul false true)
| (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed"
| _ => throw <| .generic s!"flag not allowed"
| "llvm.sub" =>do
let att := opStx.attrs.getAttr "overflowFlags"
match att with
| .none => pure <| Sum.inl (MOp.BinaryOp.add false false)
| .some y => match y with
| (.opaque_ "llvm.overflow" "nsw") => pure <| Sum.inl (MOp.BinaryOp.sub true false)
| (.opaque_ "llvm.overflow" "nuw") => pure <| Sum.inl (MOp.BinaryOp.sub false true)
| (.opaque_ "llvm.overflow" s ) =>throw <| .generic s!"flag {s} not allowed"
| _ => throw <| .generic s!"flag not allowed"
| "llvm.sdiv" => pure <| Sum.inl (.sdiv false)
| "llvm.udiv" => pure <| Sum.inl (.udiv false)
| "llvm.icmp.eq" => pure <| Sum.inr LLVM.IntPredicate.eq
| "llvm.icmp.ne" => pure <| Sum.inr LLVM.IntPredicate.ne
| "llvm.icmp.ugt" => pure <| Sum.inr LLVM.IntPredicate.ugt
Expand Down

0 comments on commit 3ad683a

Please sign in to comment.