Skip to content

Commit

Permalink
dialects: (builtin) make FloatAttr a TypedAttribute (#3488)
Browse files Browse the repository at this point in the history
Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
alexarice and superlopuh authored Nov 20, 2024
1 parent 4b08fcd commit ee4094f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 15 deletions.
13 changes: 8 additions & 5 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import textwrap
from collections.abc import Callable
from io import StringIO
from typing import ClassVar, Generic, TypeVar
from typing import Annotated, ClassVar, Generic, TypeVar

import pytest

Expand All @@ -12,6 +12,8 @@
from xdsl.dialects.builtin import (
I32,
BoolAttr,
Float64Type,
FloatAttr,
IntegerAttr,
MemRefType,
ModuleOp,
Expand Down Expand Up @@ -603,20 +605,21 @@ class OptionalAttributeOp(IRDLOperation):
"program, generic_program",
[
(
"test.typed_attr 3",
'"test.typed_attr"() {"attr" = 3 : i32} : () -> ()',
"test.typed_attr 3 3.000000e+00",
'"test.typed_attr"() {"attr" = 3 : i32, "float_attr" = 3.000000e+00 : f64} : () -> ()',
),
],
)
def test_typed_attribute_variable(program: str, generic_program: str):
"""Test the parsing of optional operands"""
"""Test the parsing of typed attributes"""

@irdl_op_definition
class TypedAttributeOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])
float_attr = attr_def(FloatAttr[Annotated[Float64Type, Float64Type()]])

assembly_format = "$attr attr-dict"
assembly_format = "$attr $float_attr attr-dict"

ctx = MLContext()
ctx.load_op(TypedAttributeOp)
Expand Down
13 changes: 12 additions & 1 deletion xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def __hash__(self):


@irdl_attr_definition
class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute):
class FloatAttr(Generic[_FloatAttrType], TypedAttribute[_FloatAttrType]):
name = "float"

value: ParameterDef[FloatData]
Expand Down Expand Up @@ -668,6 +668,17 @@ def __init__(
raise ValueError(f"Invalid bitwidth: {type}")
super().__init__([data_attr, type])

@staticmethod
def parse_with_type(
parser: AttrParser,
type: AttributeInvT,
) -> TypedAttribute[AttributeInvT]:
assert isinstance(type, AnyFloat)
return FloatAttr(parser.parse_float(), type)

def print_without_type(self, printer: Printer):
return printer.print_float(self)


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)
Expand Down
33 changes: 31 additions & 2 deletions xdsl/parser/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,36 @@ def parse_integer(
"Expected integer literal" + context_msg,
)

def parse_optional_float(
self,
*,
allow_negative: bool = True,
) -> float | None:
"""
Parse a (possibly negative) float, if present.
"""
is_negative = False
if allow_negative:
is_negative = self._parse_optional_token(Token.Kind.MINUS) is not None

if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None:
value = value.get_float_value()
return -value if is_negative else value

def parse_float(
self,
*,
allow_negative: bool = True,
) -> float:
"""
Parse a (possibly negative) float.
"""

return self.expect(
lambda: self.parse_optional_float(allow_negative=allow_negative),
"Expected float literal",
)

def parse_optional_number(
self, *, allow_boolean: bool = False
) -> int | float | None:
Expand All @@ -376,8 +406,7 @@ def parse_optional_number(
) is not None:
return -value if is_negative else value

if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None:
value = value.get_float_value()
if (value := self.parse_optional_float(allow_negative=False)) is not None:
return -value if is_negative else value

if is_negative:
Expand Down
7 changes: 0 additions & 7 deletions xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,6 @@ def print_attribute(self, attribute: Attribute) -> None:
self.print_identifier_or_string_literal(ref.data)
return

if isinstance(attribute, FloatAttr):
attr = cast(AnyFloatAttr, attribute)
self.print_float(attr)
self.print_string(" : ")
self.print_attribute(attr.type)
return

# Complex types have MLIR shorthands but XDSL does not.
if isinstance(attribute, ComplexType):
self.print_string("complex<")
Expand Down

0 comments on commit ee4094f

Please sign in to comment.