diff --git a/tests/filecheck/dialects/llvm/icmp.mlir b/tests/filecheck/dialects/llvm/icmp.mlir new file mode 100644 index 0000000000..dc558e3010 --- /dev/null +++ b/tests/filecheck/dialects/llvm/icmp.mlir @@ -0,0 +1,53 @@ +// RUN: xdsl-opt %s | xdsl-opt --print-op-generic | filecheck %s + +%arg0, %arg1 = "test.op"() : () -> (i32, i32) + +%icmp_eq_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 0 : i64}> : (i32, i32) -> i1 +%icmp_eq = llvm.icmp "eq" %arg0, %arg1 : i32 +// CHECK: %icmp_eq_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 0 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_eq = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 0 : i64}> : (i32, i32) -> i1 + +%icmp_ne_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 1 : i64}> : (i32, i32) -> i1 +%icmp_ne = llvm.icmp "ne" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_ne_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 1 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_ne = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 1 : i64}> : (i32, i32) -> i1 + +%icmp_slt_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 2 : i64}> : (i32, i32) -> i1 +%icmp_slt = llvm.icmp "slt" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_slt_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 2 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_slt = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 2 : i64}> : (i32, i32) -> i1 + +%icmp_sle_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 3 : i64}> : (i32, i32) -> i1 +%icmp_sle = llvm.icmp "sle" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_sle_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 3 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_sle = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 3 : i64}> : (i32, i32) -> i1 + +%icmp_sgt_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 4 : i64}> : (i32, i32) -> i1 +%icmp_sgt = llvm.icmp "sgt" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_sgt_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 4 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_sgt = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 4 : i64}> : (i32, i32) -> i1 + +%icmp_sge_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 5 : i64}> : (i32, i32) -> i1 +%icmp_sge = llvm.icmp "sge" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_sge_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 5 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_sge = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 5 : i64}> : (i32, i32) -> i1 + +%icmp_ult_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 6 : i64}> : (i32, i32) -> i1 +%icmp_ult = llvm.icmp "ult" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_ult_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 6 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_ult = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 6 : i64}> : (i32, i32) -> i1 + +%icmp_ule_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 7 : i64}> : (i32, i32) -> i1 +%icmp_ule = llvm.icmp "ule" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_ule_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 7 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_ule = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 7 : i64}> : (i32, i32) -> i1 + +%icmp_ugt_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 8 : i64}> : (i32, i32) -> i1 +%icmp_ugt = llvm.icmp "ugt" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_ugt_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 8 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_ugt = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 8 : i64}> : (i32, i32) -> i1 + +%icmp_uge_p = "llvm.icmp"(%arg0, %arg1) <{predicate = 9 : i64}> : (i32, i32) -> i1 +%icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32 +// CHECK-NEXT: %icmp_uge_p = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 9 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %icmp_uge = "llvm.icmp"(%arg0, %arg1) <{"predicate" = 9 : i64}> : (i32, i32) -> i1 diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/llvm/icmp.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/llvm/icmp.mlir new file mode 100644 index 0000000000..85982700d7 --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/llvm/icmp.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-opt %s --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s + +%0, %1 = "test.op"() : () -> (i32, i32) + +%2 = "llvm.icmp"(%0, %1) <{predicate = 0 : i64}> : (i32, i32) -> i1 +%3 = llvm.icmp "eq" %0, %1 : i32 +// CHECK: %2 = "llvm.icmp"(%0, %1) <{"predicate" = 0 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %3 = "llvm.icmp"(%0, %1) <{"predicate" = 0 : i64}> : (i32, i32) -> i1 + +%4 = "llvm.icmp"(%0, %1) <{predicate = 1 : i64}> : (i32, i32) -> i1 +%5 = llvm.icmp "ne" %0, %1 : i32 +// CHECK-NEXT: %4 = "llvm.icmp"(%0, %1) <{"predicate" = 1 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %5 = "llvm.icmp"(%0, %1) <{"predicate" = 1 : i64}> : (i32, i32) -> i1 + +%6 = "llvm.icmp"(%0, %1) <{predicate = 2 : i64}> : (i32, i32) -> i1 +%7 = llvm.icmp "slt" %0, %1 : i32 +// CHECK-NEXT: %6 = "llvm.icmp"(%0, %1) <{"predicate" = 2 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %7 = "llvm.icmp"(%0, %1) <{"predicate" = 2 : i64}> : (i32, i32) -> i1 + +%8 = "llvm.icmp"(%0, %1) <{predicate = 3 : i64}> : (i32, i32) -> i1 +%9 = llvm.icmp "sle" %0, %1 : i32 +// CHECK-NEXT: %8 = "llvm.icmp"(%0, %1) <{"predicate" = 3 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %9 = "llvm.icmp"(%0, %1) <{"predicate" = 3 : i64}> : (i32, i32) -> i1 + +%10 = "llvm.icmp"(%0, %1) <{predicate = 4 : i64}> : (i32, i32) -> i1 +%11 = llvm.icmp "sgt" %0, %1 : i32 +// CHECK-NEXT: %10 = "llvm.icmp"(%0, %1) <{"predicate" = 4 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %11 = "llvm.icmp"(%0, %1) <{"predicate" = 4 : i64}> : (i32, i32) -> i1 + +%12 = "llvm.icmp"(%0, %1) <{predicate = 5 : i64}> : (i32, i32) -> i1 +%13 = llvm.icmp "sge" %0, %1 : i32 +// CHECK-NEXT: %12 = "llvm.icmp"(%0, %1) <{"predicate" = 5 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %13 = "llvm.icmp"(%0, %1) <{"predicate" = 5 : i64}> : (i32, i32) -> i1 + +%14 = "llvm.icmp"(%0, %1) <{predicate = 6 : i64}> : (i32, i32) -> i1 +%15 = llvm.icmp "ult" %0, %1 : i32 +// CHECK-NEXT: %14 = "llvm.icmp"(%0, %1) <{"predicate" = 6 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %15 = "llvm.icmp"(%0, %1) <{"predicate" = 6 : i64}> : (i32, i32) -> i1 + +%16 = "llvm.icmp"(%0, %1) <{predicate = 7 : i64}> : (i32, i32) -> i1 +%17 = llvm.icmp "ule" %0, %1 : i32 +// CHECK-NEXT: %16 = "llvm.icmp"(%0, %1) <{"predicate" = 7 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %17 = "llvm.icmp"(%0, %1) <{"predicate" = 7 : i64}> : (i32, i32) -> i1 + +%18 = "llvm.icmp"(%0, %1) <{predicate = 8 : i64}> : (i32, i32) -> i1 +%19 = llvm.icmp "ugt" %0, %1 : i32 +// CHECK-NEXT: %18 = "llvm.icmp"(%0, %1) <{"predicate" = 8 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %19 = "llvm.icmp"(%0, %1) <{"predicate" = 8 : i64}> : (i32, i32) -> i1 + +%20 = "llvm.icmp"(%0, %1) <{predicate = 9 : i64}> : (i32, i32) -> i1 +%21 = llvm.icmp "uge" %0, %1 : i32 +// CHECK-NEXT: %20 = "llvm.icmp"(%0, %1) <{"predicate" = 9 : i64}> : (i32, i32) -> i1 +// CHECK-NEXT: %21 = "llvm.icmp"(%0, %1) <{"predicate" = 9 : i64}> : (i32, i32) -> i1 diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index fbbd6bc204..e6da96bb99 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -20,6 +20,7 @@ StringAttr, SymbolRefAttr, UnitAttr, + i1, i32, i64, ) @@ -31,7 +32,6 @@ Operation, ParametrizedAttribute, Region, - SpacedOpaqueSyntaxAttribute, SSAValue, TypeAttribute, ) @@ -623,15 +623,16 @@ class ICmpPredicateFlag(StrEnum): UGT = "ugt" UGE = "uge" + @staticmethod + def from_int(index: int) -> ICmpPredicateFlag: + return ALL_ICMP_FLAGS[index] -@irdl_attr_definition -class ICmpPredicateAttr(SpacedOpaqueSyntaxAttribute, EnumAttribute[ICmpPredicateFlag]): - name = "llvm.predicate" - ALL_PREDICATES = tuple(ICmpPredicateFlag) + def to_int(self) -> int: + return ICMP_INDEX_BY_FLAG[self] - @classmethod - def from_int(cls, i: int): - return ICmpPredicateAttr(cls.ALL_PREDICATES[i]) + +ALL_ICMP_FLAGS = tuple(ICmpPredicateFlag) +ICMP_INDEX_BY_FLAG = {f: i for (i, f) in enumerate(ALL_ICMP_FLAGS)} @irdl_op_definition @@ -641,8 +642,8 @@ class ICmpOp(IRDLOperation): lhs = operand_def(T) rhs = operand_def(T) - res = result_def(T) - predicate = prop_def(ICmpPredicateAttr) + res = result_def(i1) + predicate = prop_def(IntegerAttr[i64]) traits = frozenset([NoMemoryEffect()]) @@ -650,13 +651,13 @@ def __init__( self, lhs: SSAValue, rhs: SSAValue, - predicate: ICmpPredicateAttr, + predicate: IntegerAttr[IntegerType], attributes: dict[str, Attribute] = {}, ): super().__init__( operands=[lhs, rhs], attributes=attributes, - result_types=[lhs.type], + result_types=[i1], properties={ "predicate": predicate, }, @@ -666,7 +667,8 @@ def __init__( def parse(cls, parser: Parser): predicate_literal = parser.parse_str_literal() predicate_value = ICmpPredicateFlag[predicate_literal.upper()] - predicate = ICmpPredicateAttr(predicate_value) + predicate_int = predicate_value.to_int() + predicate = IntegerAttr(predicate_int, i64) lhs = parser.parse_unresolved_operand() parser.parse_characters(",") rhs = parser.parse_unresolved_operand() @@ -677,7 +679,8 @@ def parse(cls, parser: Parser): return cls(operands[0], operands[1], predicate, attributes) def print_predicate(self, printer: Printer): - self.predicate.print_parameter(printer) + flag = ICmpPredicateFlag.from_int(self.predicate.value.data) + printer.print_string(f"{flag}") def print(self, printer: Printer): printer.print_string(' "') @@ -1609,6 +1612,5 @@ class ZeroOp(IRDLOperation): TailCallKindAttr, FastMathAttr, OverflowAttr, - ICmpPredicateAttr, ], )