From 533a526c42e222aee863bb3195102e25f490dd82 Mon Sep 17 00:00:00 2001 From: kirk Date: Thu, 10 Oct 2024 22:37:17 +0000 Subject: [PATCH] dialects (arith) update generic print format for fastmath flag in Cmpf, and assign a default value --- tests/filecheck/dialects/arith/arith_ops.mlir | 2 +- xdsl/dialects/arith.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/filecheck/dialects/arith/arith_ops.mlir b/tests/filecheck/dialects/arith/arith_ops.mlir index 1e629d68b8..1bbe7d4a59 100644 --- a/tests/filecheck/dialects/arith/arith_ops.mlir +++ b/tests/filecheck/dialects/arith/arith_ops.mlir @@ -169,7 +169,7 @@ %cmpf_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 2 : i64, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 - // CHECK-NEXT: %cmpf_fm = arith.cmpf ogt, %lhsf32, %rhsf32 fastmath : f32 + // CHECK-NEXT: %cmpf_fm = arith.cmpf ogt, %lhsf32, %rhsf32 {"fastmath" = #arith.fastmath} : f32 %selecti = "arith.select"(%lhsi1, %lhsi32, %rhsi32) : (i1, i32, i32) -> i32 %selectf = "arith.select"(%lhsi1, %lhsf32, %rhsf32) : (i1, f32, f32) -> f32 diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 52e57ee9c4..890260b921 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -720,7 +720,7 @@ class Cmpf(ComparisonOperation): predicate = prop_def(AnyIntegerAttr) lhs = operand_def(floatingPointLike) rhs = operand_def(floatingPointLike) - fastmath = opt_prop_def(FastMathFlagsAttr) + fastmath = prop_def(FastMathFlagsAttr, default_value=FastMathFlagsAttr("none")) result = result_def(IntegerType(1)) def __init__( @@ -728,7 +728,7 @@ def __init__( operand1: SSAValue | Operation, operand2: SSAValue | Operation, arg: int | str, - fastmath: FastMathFlagsAttr | None = None, + fastmath: FastMathFlagsAttr = FastMathFlagsAttr("none"), ): operand1 = SSAValue.get(operand1) operand2 = SSAValue.get(operand2) @@ -759,8 +759,10 @@ def __init__( super().__init__( operands=[operand1, operand2], result_types=[IntegerType(1)], - properties={"predicate": IntegerAttr.from_int_and_width(arg, 64)}, - attributes={"fastmath": fastmath}, + properties={ + "predicate": IntegerAttr.from_int_and_width(arg, 64), + "fastmath": fastmath, + }, ) @classmethod @@ -770,7 +772,7 @@ def parse(cls, parser: Parser): operand1 = parser.parse_unresolved_operand() parser.parse_punctuation(",") operand2 = parser.parse_unresolved_operand() - fastmath = None + fastmath = FastMathFlagsAttr("none") if parser.parse_optional_keyword("fastmath") is not None: fastmath = FastMathFlagsAttr(FastMathFlagsAttr.parse_parameter(parser)) parser.parse_punctuation(":") @@ -788,9 +790,9 @@ def print(self, printer: Printer): printer.print_operand(self.lhs) printer.print(", ") printer.print_operand(self.rhs) - if self.fastmath is not None and self.fastmath != FastMathFlagsAttr("none"): - printer.print(" fastmath") - self.fastmath.print_parameter(printer) + if self.fastmath != FastMathFlagsAttr("none"): + printer.print(" ") + printer.print_attr_dict({"fastmath": self.fastmath}) printer.print(" : ") printer.print_attribute(self.lhs.type)