Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot committed Nov 4, 2024
1 parent 87e1a46 commit 77505b5
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 5 deletions.
22 changes: 21 additions & 1 deletion tests/dialects/test_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from xdsl.dialects import arith, builtin, llvm, test
from xdsl.dialects.builtin import UnitAttr, i32
from xdsl.dialects.builtin import UnitAttr, i1, i32
from xdsl.ir import Attribute
from xdsl.printer import Printer
from xdsl.utils.exceptions import VerifyException
Expand Down Expand Up @@ -57,6 +57,26 @@ def test_llvm_overflow_arithmetic_ops(
)


@pytest.mark.parametrize(
"op_type, attributes, exact",
[
(llvm.UDivOp, {}, llvm.IntegerAttr(0, i1)),
(llvm.SDivOp, {}, llvm.IntegerAttr(0, i1)),
(llvm.LShrOp, {}, llvm.IntegerAttr(0, i1)),
(llvm.AShrOp, {}, llvm.IntegerAttr(0, i1)),
],
)
def test_llvm_exact_arithmetic_ops(
op_type: type[llvm.ArithmeticBinOpExact],
attributes: dict[str, Attribute],
exact: llvm.IntegerAttr[llvm.IntegerType],
):
op1, op2 = test.TestOp(result_types=[i32, i32]).results
assert op_type(op1, op2, attributes, exact).is_structurally_equivalent(
op_type(lhs=op1, rhs=op2, attributes=attributes, exact=exact)
)


def test_llvm_pointer_ops():
module = builtin.ModuleOp(
[
Expand Down
12 changes: 12 additions & 0 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
%sdiv = llvm.sdiv %arg0, %arg1 : i32
// CHECK: %sdiv = llvm.sdiv %arg0, %arg1 : i32

%udiv_exact = llvm.udiv exact %arg0, %arg1 : i32
// CHECK: %udiv_exact = llvm.udiv exact %arg0, %arg1 : i32

%sdiv_exact = llvm.sdiv exact %arg0, %arg1 : i32
// CHECK: %sdiv_exact = llvm.sdiv exact %arg0, %arg1 : i32

%urem = llvm.urem %arg0, %arg1 : i32
// CHECK: %urem = llvm.urem %arg0, %arg1 : i32

Expand Down Expand Up @@ -53,6 +59,12 @@
%ashr = llvm.ashr %arg0, %arg1 : i32
// CHECK: %ashr = llvm.ashr %arg0, %arg1 : i32

%lshr_exact = llvm.lshr exact %arg0, %arg1 : i32
// CHECK: %lshr_exact = llvm.lshr exact %arg0, %arg1 : i32

%ashr_exact = llvm.ashr exact %arg0, %arg1 : i32
// CHECK: %ashr_exact = llvm.ashr exact %arg0, %arg1 : i32

%trunc = llvm.trunc %arg0 : i32 to i16
// CHECK: %trunc = llvm.trunc %arg0 : i32 to i16

Expand Down
67 changes: 63 additions & 4 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
I64,
AnyIntegerAttr,
ArrayAttr,
BoolAttr,
ContainerType,
DenseArrayBase,
IndexType,
Expand Down Expand Up @@ -474,6 +475,64 @@ def print(self, printer: Printer) -> None:
printer.print(self.lhs.type)


class ArithmeticBinOpExact(IRDLOperation, ABC):
"""Class for arithmetic binary operations that use an exact flag."""

T: ClassVar = VarConstraint("T", BaseAttr(IntegerType))

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)
exact = opt_prop_def(BoolAttr)

traits = traits_def(NoMemoryEffect())

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
attributes: dict[str, Attribute] = {},
exact: BoolAttr = IntegerAttr(0, i1),
):
super().__init__(
operands=[lhs, rhs],
attributes=attributes,
result_types=[lhs.type],
properties={
"isExact": exact,
},
)

@classmethod
def parse_exact(cls, parser: Parser):
if parser.parse_optional_keyword("exact") is not None:
return IntegerAttr(1, i1)
return IntegerAttr(0, i1)

def print_exact(self, printer: Printer) -> None:
if self.exact and self.exact.value.data:
printer.print(" exact ")

@classmethod
def parse(cls, parser: Parser):
exact = cls.parse_exact(parser)
lhs = parser.parse_unresolved_operand()
parser.parse_characters(",")
rhs = parser.parse_unresolved_operand()
attributes = parser.parse_optional_attr_dict()
parser.parse_characters(":")
type = parser.parse_type()
operands = parser.resolve_operands([lhs, rhs], [type, type], parser.pos)
return cls(operands[0], operands[1], attributes, exact)

def print(self, printer: Printer) -> None:
self.print_exact(printer)
printer.print(" ", self.lhs, ", ", self.rhs)
printer.print_op_attributes(self.attributes)
printer.print(" : ")
printer.print(self.lhs.type)


class IntegerConversionOp(IRDLOperation, ABC):
arg = operand_def(IntegerType)

Expand Down Expand Up @@ -523,12 +582,12 @@ class MulOp(ArithmeticBinOpOverflow):


@irdl_op_definition
class UDivOp(ArithmeticBinOperation):
class UDivOp(ArithmeticBinOpExact):
name = "llvm.udiv"


@irdl_op_definition
class SDivOp(ArithmeticBinOperation):
class SDivOp(ArithmeticBinOpExact):
name = "llvm.sdiv"


Expand Down Expand Up @@ -563,12 +622,12 @@ class ShlOp(ArithmeticBinOpOverflow):


@irdl_op_definition
class LShrOp(ArithmeticBinOperation):
class LShrOp(ArithmeticBinOpExact):
name = "llvm.lshr"


@irdl_op_definition
class AShrOp(ArithmeticBinOperation):
class AShrOp(ArithmeticBinOpExact):
name = "llvm.ashr"


Expand Down

0 comments on commit 77505b5

Please sign in to comment.