Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (llvm) Registering llvm.icmp #3356

Merged
merged 12 commits into from
Nov 1, 2024
30 changes: 30 additions & 0 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,33 @@

%cst32 = llvm.mlir.constant(25 : i32) : i32
// CHECK: %cst32 = llvm.mlir.constant(25 : i32) : i32

%icmp_eq = llvm.icmp "eq" %arg0, %arg1 : i32
// CHECK: %icmp_eq = llvm.icmp "eq" %arg0, %arg1 : i32
tobiasgrosser marked this conversation as resolved.
Show resolved Hide resolved

%icmp_ne = llvm.icmp "ne" %arg0, %arg1 : i32
// CHECK: %icmp_ne = llvm.icmp "ne" %arg0, %arg1 : i32

%icmp_slt = llvm.icmp "slt" %arg0, %arg1 : i32
// CHECK: %icmp_slt = llvm.icmp "slt" %arg0, %arg1 : i32

%icmp_sle = llvm.icmp "sle" %arg0, %arg1 : i32
// CHECK: %icmp_sle = llvm.icmp "sle" %arg0, %arg1 : i32

%icmp_sgt = llvm.icmp "sgt" %arg0, %arg1 : i32
// CHECK: %icmp_sgt = llvm.icmp "sgt" %arg0, %arg1 : i32

%icmp_sge = llvm.icmp "sge" %arg0, %arg1 : i32
// CHECK: %icmp_sge = llvm.icmp "sge" %arg0, %arg1 : i32

%icmp_ult = llvm.icmp "ult" %arg0, %arg1 : i32
// CHECK: %icmp_ult = llvm.icmp "ult" %arg0, %arg1 : i32

%icmp_ule = llvm.icmp "ule" %arg0, %arg1 : i32
// CHECK: %icmp_ule = llvm.icmp "ule" %arg0, %arg1 : i32

%icmp_ugt = llvm.icmp "ugt" %arg0, %arg1 : i32
// CHECK: %icmp_ugt = llvm.icmp "ugt" %arg0, %arg1 : i32

%icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32
// CHECK: %icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32
80 changes: 80 additions & 0 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Operation,
ParametrizedAttribute,
Region,
SpacedOpaqueSyntaxAttribute,
SSAValue,
TypeAttribute,
)
Expand Down Expand Up @@ -610,6 +611,83 @@ def verify(self, verify_nested_ops: bool = True):
super().verify(verify_nested_ops)


class ICmpPredicateFlag(StrEnum):
EQ = "eq"
NE = "ne"
SLT = "slt"
SLE = "sle"
SGT = "sgt"
SGE = "sge"
ULT = "ult"
ULE = "ule"
UGT = "ugt"
UGE = "uge"


@irdl_attr_definition
class ICmpPredicateAttr(SpacedOpaqueSyntaxAttribute, EnumAttribute[ICmpPredicateFlag]):
name = "llvm.predicate"
ALL_PREDICATES = tuple(ICmpPredicateFlag)

@classmethod
def from_int(cls, i: int):
return ICmpPredicateAttr(cls.ALL_PREDICATES[i])


@irdl_op_definition
class ICmpOp(IRDLOperation):
name = "llvm.icmp"
T: ClassVar = VarConstraint("T", BaseAttr(IntegerType))

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)
predicate = prop_def(ICmpPredicateAttr)

traits = frozenset([NoMemoryEffect()])

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
predicate: ICmpPredicateAttr,
attributes: dict[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
attributes=attributes,
result_types=[lhs.type],
properties={
"predicate": predicate,
},
)

@classmethod
def parse(cls, parser: Parser):
predicate_literal = parser.parse_str_literal()
predicate_value = ICmpPredicateFlag[predicate_literal.upper()]
predicate = ICmpPredicateAttr(predicate_value)
lhs = parser.parse_unresolved_operand()
parser.parse_characters(",")
rhs = parser.parse_unresolved_operand()
attributes = parser.parse_optional_attr_dict()
parser.parse_characters(":")
type = parser.parse_type()
operands = parser.resolve_operands([lhs, rhs], [type, type], parser.pos)
return cls(operands[0], operands[1], predicate, attributes)

def print_predicate(self, printer: Printer):
self.predicate.print_parameter(printer)

def print(self, printer: Printer):
printer.print_string(' "')
self.print_predicate(printer)
printer.print('" ', self.lhs, ", ", self.rhs)
printer.print_op_attributes(self.attributes)
printer.print_string(" : ")
printer.print(self.lhs.type)


@irdl_op_definition
class GEPOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1500,6 +1578,7 @@ class ZeroOp(IRDLOperation):
TruncOp,
ZExtOp,
SExtOp,
ICmpOp,
ExtractValueOp,
InsertValueOp,
InlineAsmOp,
Expand Down Expand Up @@ -1530,5 +1609,6 @@ class ZeroOp(IRDLOperation):
TailCallKindAttr,
FastMathAttr,
OverflowAttr,
ICmpPredicateAttr,
],
)
Loading