Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (llvm) Registering llvm.icmp #3356

Merged
merged 12 commits into from
Nov 1, 2024
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,6 @@

%cst32 = llvm.mlir.constant(25 : i32) : i32
// CHECK: %cst32 = llvm.mlir.constant(25 : i32) : i32

%icmp_eq = llvm.icmp "eq" %arg0, %arg1 : i32
// CHECK: %icmp_eq = llvm.icmp "eq" %arg0, %arg1 : i32
tobiasgrosser marked this conversation as resolved.
Show resolved Hide resolved
101 changes: 101 additions & 0 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,106 @@ def verify(self, verify_nested_ops: bool = True):
super().verify(verify_nested_ops)


class ICmpPredicateFlag(StrEnum):
EQ = "eq"
NE = "ne"
SLT = "slt"
SLE = "sle"
SGT = "sgt"
SGE = "sge"
ULT = "ult"
ULE = "ule"
UGT = "ugt"
UGE = "uge"

@classmethod
def from_int(cls, i: int) -> ICmpPredicateFlag:
if i == 0:
return ICmpPredicateFlag.EQ
elif i == 1:
return ICmpPredicateFlag.NE
elif i == 2:
return ICmpPredicateFlag.SLT
elif i == 3:
return ICmpPredicateFlag.SLE
elif i == 4:
return ICmpPredicateFlag.SGE
elif i == 5:
return ICmpPredicateFlag.SGT
elif i == 6:
return ICmpPredicateFlag.ULT
elif i == 7:
return ICmpPredicateFlag.ULE
elif i == 8:
return ICmpPredicateFlag.UGE
elif i == 9:
return ICmpPredicateFlag.UGT
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There has to be a better way of doing this, I'd expect some helper to exist in the Python enum implementation, otherwise we should cache this mapping.

raise VerifyException(f"invalide predicate value {i}")


@irdl_attr_definition
class ICmpPredicateAttr(EnumAttribute[ICmpPredicateFlag]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you want OpaqueSyntaxAttribute or SpacedOpaqueSyntaxAttribute here to get it to print properly

name = "llvm.predicate"


@irdl_op_definition
class ICmpOp(IRDLOperation, ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this marked as an abstract class?

name = "llvm.icmp"
T: ClassVar = VarConstraint("T", BaseAttr(IntegerType))

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)
predicate = prop_def(ICmpPredicateAttr)

traits = frozenset([NoMemoryEffect()])

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
predicate: ICmpPredicateAttr,
attributes: dict[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
attributes=attributes,
result_types=[lhs.type],
properties={
"predicate": predicate,
},
)

@classmethod
def parse(cls, parser: Parser):
predicate_literal = parser.parse_str_literal()
predicate = ICmpPredicateAttr(ICmpPredicateFlag(predicate_literal))
lhs = parser.parse_unresolved_operand()
parser.parse_characters(",")
rhs = parser.parse_unresolved_operand()
attributes = parser.parse_optional_attr_dict()
parser.parse_characters(":")
type = parser.parse_type()
operands = parser.resolve_operands([lhs, rhs], [type, type], parser.pos)
return cls(operands[0], operands[1], predicate, attributes)

def print_predicate(self, printer: Printer):
if isattr(self.predicate, AnyIntegerAttr):
ICmpPredicateAttr(
ICmpPredicateFlag.from_int(self.predicate.value.data)
).print_parameter(printer)
else:
self.predicate.print_parameter(printer)

def print(self, printer: Printer):
printer.print(' "')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in general, print_string is preferred to print but this is a super nit

self.print_predicate(printer)
printer.print('" ', self.lhs, ", ", self.rhs)
printer.print_op_attributes(self.attributes)
printer.print(" : ")
printer.print(self.lhs.type)


@irdl_op_definition
class GEPOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1500,6 +1600,7 @@ class ZeroOp(IRDLOperation):
TruncOp,
ZExtOp,
SExtOp,
ICmpOp,
ExtractValueOp,
InsertValueOp,
InlineAsmOp,
Expand Down
Loading