diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 9d3a1484f6..a58de17fae 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -3,7 +3,7 @@ import textwrap from collections.abc import Callable from io import StringIO -from typing import ClassVar, Generic, TypeVar +from typing import Annotated, ClassVar, Generic, TypeVar import pytest @@ -12,6 +12,8 @@ from xdsl.dialects.builtin import ( I32, BoolAttr, + Float64Type, + FloatAttr, IntegerAttr, MemRefType, ModuleOp, @@ -603,20 +605,21 @@ class OptionalAttributeOp(IRDLOperation): "program, generic_program", [ ( - "test.typed_attr 3", - '"test.typed_attr"() {"attr" = 3 : i32} : () -> ()', + "test.typed_attr 3 3.000000e+00", + '"test.typed_attr"() {"attr" = 3 : i32, "float_attr" = 3.000000e+00 : f64} : () -> ()', ), ], ) def test_typed_attribute_variable(program: str, generic_program: str): - """Test the parsing of optional operands""" + """Test the parsing of typed attributes""" @irdl_op_definition class TypedAttributeOp(IRDLOperation): name = "test.typed_attr" attr = attr_def(IntegerAttr[I32]) + float_attr = attr_def(FloatAttr[Annotated[Float64Type, Float64Type()]]) - assembly_format = "$attr attr-dict" + assembly_format = "$attr $float_attr attr-dict" ctx = MLContext() ctx.load_op(TypedAttributeOp) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index df0e3d89cb..3cbccf68bc 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -634,7 +634,7 @@ def __hash__(self): @irdl_attr_definition -class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute): +class FloatAttr(Generic[_FloatAttrType], TypedAttribute[_FloatAttrType]): name = "float" value: ParameterDef[FloatData] @@ -668,6 +668,17 @@ def __init__( raise ValueError(f"Invalid bitwidth: {type}") super().__init__([data_attr, type]) + @staticmethod + def parse_with_type( + parser: AttrParser, + type: AttributeInvT, + ) -> TypedAttribute[AttributeInvT]: + assert isinstance(type, AnyFloat) + return FloatAttr(parser.parse_float(), type) + + def print_without_type(self, printer: Printer): + return printer.print_float(self) + AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat] AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr) diff --git a/xdsl/parser/base_parser.py b/xdsl/parser/base_parser.py index 792f87e9e8..3586c7b6cc 100644 --- a/xdsl/parser/base_parser.py +++ b/xdsl/parser/base_parser.py @@ -359,6 +359,36 @@ def parse_integer( "Expected integer literal" + context_msg, ) + def parse_optional_float( + self, + *, + allow_negative: bool = True, + ) -> float | None: + """ + Parse a (possibly negative) float, if present. + """ + is_negative = False + if allow_negative: + is_negative = self._parse_optional_token(Token.Kind.MINUS) is not None + + if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: + value = value.get_float_value() + return -value if is_negative else value + + def parse_float( + self, + *, + allow_negative: bool = True, + ) -> float: + """ + Parse a (possibly negative) float. + """ + + return self.expect( + lambda: self.parse_optional_float(allow_negative=allow_negative), + "Expected float literal", + ) + def parse_optional_number( self, *, allow_boolean: bool = False ) -> int | float | None: @@ -376,8 +406,7 @@ def parse_optional_number( ) is not None: return -value if is_negative else value - if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: - value = value.get_float_value() + if (value := self.parse_optional_float(allow_negative=False)) is not None: return -value if is_negative else value if is_negative: diff --git a/xdsl/printer.py b/xdsl/printer.py index 792e08c3b7..f9b8bb4ab4 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -561,13 +561,6 @@ def print_attribute(self, attribute: Attribute) -> None: self.print_identifier_or_string_literal(ref.data) return - if isinstance(attribute, FloatAttr): - attr = cast(AnyFloatAttr, attribute) - self.print_float(attr) - self.print_string(" : ") - self.print_attribute(attr.type) - return - # Complex types have MLIR shorthands but XDSL does not. if isinstance(attribute, ComplexType): self.print_string("complex<")