Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (llvm) Registering llvm.icmp #3356

Merged
merged 12 commits into from
Nov 1, 2024
30 changes: 30 additions & 0 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,33 @@

%cst32 = llvm.mlir.constant(25 : i32) : i32
// CHECK: %cst32 = llvm.mlir.constant(25 : i32) : i32

%icmp_eq = llvm.icmp "eq" %arg0, %arg1 : i32
// CHECK: %icmp_eq = llvm.icmp "eq" %arg0, %arg1 : i32
tobiasgrosser marked this conversation as resolved.
Show resolved Hide resolved

%icmp_ne = llvm.icmp "ne" %arg0, %arg1 : i32
// CHECK: %icmp_ne = llvm.icmp "ne" %arg0, %arg1 : i32

%icmp_slt = llvm.icmp "slt" %arg0, %arg1 : i32
// CHECK: %icmp_slt = llvm.icmp "slt" %arg0, %arg1 : i32

%icmp_sle = llvm.icmp "sle" %arg0, %arg1 : i32
// CHECK: %icmp_sle = llvm.icmp "sle" %arg0, %arg1 : i32

%icmp_sgt = llvm.icmp "sgt" %arg0, %arg1 : i32
// CHECK: %icmp_sgt = llvm.icmp "sgt" %arg0, %arg1 : i32

%icmp_sge = llvm.icmp "sge" %arg0, %arg1 : i32
// CHECK: %icmp_sge = llvm.icmp "sge" %arg0, %arg1 : i32

%icmp_ult = llvm.icmp "ult" %arg0, %arg1 : i32
// CHECK: %icmp_ult = llvm.icmp "ult" %arg0, %arg1 : i32

%icmp_ule = llvm.icmp "ule" %arg0, %arg1 : i32
// CHECK: %icmp_ule = llvm.icmp "ule" %arg0, %arg1 : i32

%icmp_ugt = llvm.icmp "ugt" %arg0, %arg1 : i32
// CHECK: %icmp_ugt = llvm.icmp "ugt" %arg0, %arg1 : i32

%icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32
// CHECK: %icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32
101 changes: 101 additions & 0 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,106 @@ def verify(self, verify_nested_ops: bool = True):
super().verify(verify_nested_ops)


class ICmpPredicateFlag(StrEnum):
EQ = 0
NE = 1
SLT = 2
SLE = 3
SGT = 4
SGE = 5
ULT = 6
ULE = 7
UGT = 8
UGE = 9

@classmethod
def from_int(cls, i: int) -> ICmpPredicateFlag:
if i == 0:
return ICmpPredicateFlag.EQ
elif i == 1:
return ICmpPredicateFlag.NE
elif i == 2:
return ICmpPredicateFlag.SLT
elif i == 3:
return ICmpPredicateFlag.SLE
elif i == 4:
return ICmpPredicateFlag.SGT
elif i == 5:
return ICmpPredicateFlag.SGE
elif i == 6:
return ICmpPredicateFlag.ULT
elif i == 7:
return ICmpPredicateFlag.ULE
elif i == 8:
return ICmpPredicateFlag.UGT
elif i == 9:
return ICmpPredicateFlag.UGE
raise VerifyException(f"invalide predicate value {i}")


@irdl_op_definition
class ICmpOp(IRDLOperation, ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this marked as an abstract class?

name = "llvm.icmp"
T: ClassVar = VarConstraint("T", BaseAttr(IntegerType))

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)
predicate = prop_def(Attribute)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change? it was better before as a strenum, no?


traits = frozenset([NoMemoryEffect()])

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
predicate: Attribute,
attributes: dict[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
attributes=attributes,
result_types=[lhs.type],
properties={
"predicate": predicate,
},
)

@classmethod
def parse(cls, parser: Parser):
predicate_literal = parser.parse_str_literal()
predicate_value = ICmpPredicateFlag[predicate_literal.upper()].value
predicate = IntAttr(predicate_value)
lhs = parser.parse_unresolved_operand()
parser.parse_characters(",")
rhs = parser.parse_unresolved_operand()
attributes = parser.parse_optional_attr_dict()
parser.parse_characters(":")
type = parser.parse_type()
operands = parser.resolve_operands([lhs, rhs], [type, type], parser.pos)
return cls(operands[0], operands[1], predicate, attributes)

def print_predicate(self, printer: Printer):
i = None
if isattr(self.predicate, IntAttr):
i = self.predicate.data
elif isattr(self.predicate, AnyIntegerAttr):
i = self.predicate.value.data
if i is None:
raise VerifyException(
f"Predicate is a {type(self.predicate)} when it should be an IntAttr or IntegerAttr"
)
printer.print(ICmpPredicateFlag.from_int(i).name.lower())

def print(self, printer: Printer):
printer.print(' "')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in general, print_string is preferred to print but this is a super nit

self.print_predicate(printer)
printer.print('" ', self.lhs, ", ", self.rhs)
printer.print_op_attributes(self.attributes)
printer.print(" : ")
printer.print(self.lhs.type)


@irdl_op_definition
class GEPOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1500,6 +1600,7 @@ class ZeroOp(IRDLOperation):
TruncOp,
ZExtOp,
SExtOp,
ICmpOp,
ExtractValueOp,
InsertValueOp,
InlineAsmOp,
Expand Down
Loading