From 0c871d9b1a18737e5734dac2d56d38d5411566b5 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Wed, 20 Nov 2024 14:30:04 +0000 Subject: [PATCH] core: make TypedAttribute printing generic (#3490) This currently doesn't really do anything because `IntegerAttr` is the only `TypedAttribute`. But after future PRs this will unify the way in which `TypedAttribute` is printed. --- xdsl/ir/core.py | 4 ++-- xdsl/irdl/attributes.py | 7 ++++++- xdsl/printer.py | 38 ++++++++++++++++++++------------------ 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index 859019a6be..b4bece71de 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -611,8 +611,8 @@ class TypedAttribute(ParametrizedAttribute, Generic[AttributeCovT], ABC): An attribute with a type. """ - @staticmethod - def get_type_index() -> int: ... + @classmethod + def get_type_index(cls) -> int: ... @staticmethod def parse_with_type( diff --git a/xdsl/irdl/attributes.py b/xdsl/irdl/attributes.py index 7a9969e751..08db888feb 100644 --- a/xdsl/irdl/attributes.py +++ b/xdsl/irdl/attributes.py @@ -235,7 +235,12 @@ def irdl_param_attr_definition(cls: _PAttrTT) -> _PAttrTT: if issubclass(cls, TypedAttribute): parameter_names: tuple[str] = tuple(zip(*attr_def.parameters))[0] type_index = parameter_names.index("type") - new_fields["get_type_index"] = lambda: type_index + + @classmethod + def get_type_index(cls: Any) -> int: + return type_index + + new_fields["get_type_index"] = get_type_index return runtime_final( dataclass(frozen=True, init=False)( diff --git a/xdsl/printer.py b/xdsl/printer.py index 0723e23af9..792e08c3b7 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -68,6 +68,7 @@ SpacedOpaqueSyntaxAttribute, SSAValue, TypeAttribute, + TypedAttribute, ) from xdsl.traits import IsolatedFromAbove, IsTerminator from xdsl.utils.bitwise_casts import ( @@ -525,6 +526,25 @@ def print_attribute(self, attribute: Attribute) -> None: self.print_string("f128") return + if isinstance(attribute, IntegerAttr): + # boolean shorthands + if ( + isinstance( + (ty := attribute.parameters[attribute.get_type_index()]), + IntegerType, + ) + and ty.width.data == 1 + ): + self.print_string("true" if attribute.value.data else "false") + return + # Otherwise we fall through to TypedAttribute case + + if isinstance(attribute, TypedAttribute): + attribute.print_without_type(self) + self.print_string(" : ") + self.print_attribute(attribute.parameters[attribute.get_type_index()]) + return + if isinstance(attribute, StringAttr): self.print_string_literal(attribute.data) return @@ -541,24 +561,6 @@ def print_attribute(self, attribute: Attribute) -> None: self.print_identifier_or_string_literal(ref.data) return - if isinstance(attribute, IntegerAttr): - attribute = cast(AnyIntegerAttr, attribute) - - # boolean shorthands - if ( - isinstance((attr_type := attribute.type), IntegerType) - and attr_type.width.data == 1 - ): - self.print_string("false" if attribute.value.data == 0 else "true") - return - - width = attribute.value - attr_type = attribute.type - assert isinstance(width, IntAttr) - self.print_string(f"{width.data} : ") - self.print_attribute(attr_type) - return - if isinstance(attribute, FloatAttr): attr = cast(AnyFloatAttr, attribute) self.print_float(attr)