From 4b08fcdbd53ae5c586934bc40b3b75f7be568126 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 16:50:02 +0000 Subject: [PATCH 01/15] misc: split dialects.utils into multiple files (#3472) It feels like a logical split, and leaves space for some upcoming shared dialects helpers. --- pyproject.toml | 16 +++++----- xdsl/dialects/utils/__init__.py | 4 +++ xdsl/dialects/utils/fast_math.py | 29 ++++++++++++++++++ xdsl/dialects/{utils.py => utils/format.py} | 34 --------------------- 4 files changed, 42 insertions(+), 41 deletions(-) create mode 100644 xdsl/dialects/utils/__init__.py create mode 100644 xdsl/dialects/utils/fast_math.py rename xdsl/dialects/{utils.py => utils/format.py} (93%) diff --git a/pyproject.toml b/pyproject.toml index 85f01a6fda..2a677edd06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,17 +109,19 @@ ignore = [ max-line-length = 300 [tool.ruff.lint.flake8-tidy-imports.banned-api] -"xdsl.parser.core".msg = "Use xdsl.parser instead." -"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead." -"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead." +"xdsl.dialects.utils.fast_math".msg = "Use xdsl.dialects.utils instead" +"xdsl.dialects.utils.format".msg = "Use xdsl.dialects.utils instead" +"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead" +"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead" +"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead" "xdsl.ir.core".msg = "Use xdsl.ir instead." +"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead" "xdsl.irdl.common".msg = "Use xdsl.irdl instead" "xdsl.irdl.constraints".msg = "Use xdsl.irdl instead" -"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead" "xdsl.irdl.operations".msg = "Use xdsl.irdl instead" -"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead" -"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead" -"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead" +"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead." +"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead." +"xdsl.parser.core".msg = "Use xdsl.parser instead." [tool.ruff.lint.per-file-ignores] diff --git a/xdsl/dialects/utils/__init__.py b/xdsl/dialects/utils/__init__.py new file mode 100644 index 0000000000..97cc7a2b6f --- /dev/null +++ b/xdsl/dialects/utils/__init__.py @@ -0,0 +1,4 @@ +# TID 251 enforces to not import from those +# We need to skip it here to allow importing from here instead. +from .fast_math import * # noqa: TID251 +from .format import * # noqa: TID251 diff --git a/xdsl/dialects/utils/fast_math.py b/xdsl/dialects/utils/fast_math.py new file mode 100644 index 0000000000..c12b8b5878 --- /dev/null +++ b/xdsl/dialects/utils/fast_math.py @@ -0,0 +1,29 @@ +from abc import ABC +from dataclasses import dataclass + +from xdsl.ir import BitEnumAttribute +from xdsl.utils.str_enum import StrEnum + + +class FastMathFlag(StrEnum): + """ + Values specifying fast math behaviour of an arithmetic operation. + """ + + REASSOC = "reassoc" + NO_NANS = "nnan" + NO_INFS = "ninf" + NO_SIGNED_ZEROS = "nsz" + ALLOW_RECIP = "arcp" + ALLOW_CONTRACT = "contract" + APPROX_FUNC = "afn" + + +@dataclass(frozen=True, init=False) +class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC): + """ + Base class for attributes defining fast math behavior of arithmetic operations. + """ + + none_value = "none" + all_value = "fast" diff --git a/xdsl/dialects/utils.py b/xdsl/dialects/utils/format.py similarity index 93% rename from xdsl/dialects/utils.py rename to xdsl/dialects/utils/format.py index 5b5fb8b0f9..ae7965d213 100644 --- a/xdsl/dialects/utils.py +++ b/xdsl/dialects/utils/format.py @@ -1,6 +1,4 @@ -from abc import ABC from collections.abc import Iterable, Sequence -from dataclasses import dataclass from typing import Generic from xdsl.dialects.builtin import ( @@ -14,7 +12,6 @@ from xdsl.ir import ( Attribute, AttributeInvT, - BitEnumAttribute, BlockArgument, Operation, Region, @@ -23,7 +20,6 @@ from xdsl.irdl import IRDLOperation, var_operand_def from xdsl.parser import Parser, UnresolvedOperand from xdsl.printer import Printer -from xdsl.utils.str_enum import StrEnum def print_call_op_like( @@ -345,33 +341,3 @@ def parse_dynamic_index_list_without_types( values.append(value_or_index) return values, indices - - -# region Fast Math Flags - - -class FastMathFlag(StrEnum): - """ - Values specifying fast math behaviour of an arithmetic operation. - """ - - REASSOC = "reassoc" - NO_NANS = "nnan" - NO_INFS = "ninf" - NO_SIGNED_ZEROS = "nsz" - ALLOW_RECIP = "arcp" - ALLOW_CONTRACT = "contract" - APPROX_FUNC = "afn" - - -@dataclass(frozen=True, init=False) -class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC): - """ - Base class for attributes defining fast math behavior of arithmetic operations. - """ - - none_value = "none" - all_value = "fast" - - -# endregion From ee4094f7c9d46db50b4a40267584f9d257659c3e Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Wed, 20 Nov 2024 16:50:59 +0000 Subject: [PATCH 02/15] dialects: (builtin) make FloatAttr a TypedAttribute (#3488) Co-authored-by: Sasha Lopoukhine --- .../irdl/test_declarative_assembly_format.py | 13 +++++--- xdsl/dialects/builtin.py | 13 +++++++- xdsl/parser/base_parser.py | 33 +++++++++++++++++-- xdsl/printer.py | 7 ---- 4 files changed, 51 insertions(+), 15 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 9d3a1484f6..a58de17fae 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -3,7 +3,7 @@ import textwrap from collections.abc import Callable from io import StringIO -from typing import ClassVar, Generic, TypeVar +from typing import Annotated, ClassVar, Generic, TypeVar import pytest @@ -12,6 +12,8 @@ from xdsl.dialects.builtin import ( I32, BoolAttr, + Float64Type, + FloatAttr, IntegerAttr, MemRefType, ModuleOp, @@ -603,20 +605,21 @@ class OptionalAttributeOp(IRDLOperation): "program, generic_program", [ ( - "test.typed_attr 3", - '"test.typed_attr"() {"attr" = 3 : i32} : () -> ()', + "test.typed_attr 3 3.000000e+00", + '"test.typed_attr"() {"attr" = 3 : i32, "float_attr" = 3.000000e+00 : f64} : () -> ()', ), ], ) def test_typed_attribute_variable(program: str, generic_program: str): - """Test the parsing of optional operands""" + """Test the parsing of typed attributes""" @irdl_op_definition class TypedAttributeOp(IRDLOperation): name = "test.typed_attr" attr = attr_def(IntegerAttr[I32]) + float_attr = attr_def(FloatAttr[Annotated[Float64Type, Float64Type()]]) - assembly_format = "$attr attr-dict" + assembly_format = "$attr $float_attr attr-dict" ctx = MLContext() ctx.load_op(TypedAttributeOp) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index df0e3d89cb..3cbccf68bc 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -634,7 +634,7 @@ def __hash__(self): @irdl_attr_definition -class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute): +class FloatAttr(Generic[_FloatAttrType], TypedAttribute[_FloatAttrType]): name = "float" value: ParameterDef[FloatData] @@ -668,6 +668,17 @@ def __init__( raise ValueError(f"Invalid bitwidth: {type}") super().__init__([data_attr, type]) + @staticmethod + def parse_with_type( + parser: AttrParser, + type: AttributeInvT, + ) -> TypedAttribute[AttributeInvT]: + assert isinstance(type, AnyFloat) + return FloatAttr(parser.parse_float(), type) + + def print_without_type(self, printer: Printer): + return printer.print_float(self) + AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat] AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr) diff --git a/xdsl/parser/base_parser.py b/xdsl/parser/base_parser.py index 792f87e9e8..3586c7b6cc 100644 --- a/xdsl/parser/base_parser.py +++ b/xdsl/parser/base_parser.py @@ -359,6 +359,36 @@ def parse_integer( "Expected integer literal" + context_msg, ) + def parse_optional_float( + self, + *, + allow_negative: bool = True, + ) -> float | None: + """ + Parse a (possibly negative) float, if present. + """ + is_negative = False + if allow_negative: + is_negative = self._parse_optional_token(Token.Kind.MINUS) is not None + + if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: + value = value.get_float_value() + return -value if is_negative else value + + def parse_float( + self, + *, + allow_negative: bool = True, + ) -> float: + """ + Parse a (possibly negative) float. + """ + + return self.expect( + lambda: self.parse_optional_float(allow_negative=allow_negative), + "Expected float literal", + ) + def parse_optional_number( self, *, allow_boolean: bool = False ) -> int | float | None: @@ -376,8 +406,7 @@ def parse_optional_number( ) is not None: return -value if is_negative else value - if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: - value = value.get_float_value() + if (value := self.parse_optional_float(allow_negative=False)) is not None: return -value if is_negative else value if is_negative: diff --git a/xdsl/printer.py b/xdsl/printer.py index 792e08c3b7..f9b8bb4ab4 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -561,13 +561,6 @@ def print_attribute(self, attribute: Attribute) -> None: self.print_identifier_or_string_literal(ref.data) return - if isinstance(attribute, FloatAttr): - attr = cast(AnyFloatAttr, attribute) - self.print_float(attr) - self.print_string(" : ") - self.print_attribute(attr.type) - return - # Complex types have MLIR shorthands but XDSL does not. if isinstance(attribute, ComplexType): self.print_string("complex<") From 25adb58c5a13520df76601e9e7fe9deb9d1505a0 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 17:00:11 +0000 Subject: [PATCH 03/15] dialects: (stream) simplify constr helper on stream attributes (#3473) I'm working towards removing the stream dialect, and would like to simplify it as much as possible before copying to both memref_stream and snitch_stream. I don't think the constr overload is pulling its weight, it's much simpler just to provide a shorthand for the BaseAttr constraint, and only return a ParamAttrConstraint from constr. --- tests/test_dialect_utils.py | 2 +- xdsl/dialects/arith.py | 2 +- xdsl/dialects/csl/csl.py | 2 +- xdsl/dialects/csl/csl_stencil.py | 2 +- xdsl/dialects/experimental/air.py | 2 +- xdsl/dialects/func.py | 2 +- xdsl/dialects/linalg.py | 2 +- xdsl/dialects/llvm.py | 2 +- xdsl/dialects/memref.py | 2 +- xdsl/dialects/memref_stream.py | 4 +- xdsl/dialects/omp.py | 2 +- xdsl/dialects/riscv.py | 2 +- xdsl/dialects/riscv_func.py | 2 +- xdsl/dialects/riscv_scf.py | 2 +- xdsl/dialects/riscv_snitch.py | 2 +- xdsl/dialects/scf.py | 2 +- xdsl/dialects/stream.py | 82 ++++--------------- xdsl/transforms/arith_add_fastmath.py | 2 +- .../canonicalization_patterns/riscv.py | 2 +- 19 files changed, 35 insertions(+), 85 deletions(-) diff --git a/tests/test_dialect_utils.py b/tests/test_dialect_utils.py index 171d95373f..3c80c06622 100644 --- a/tests/test_dialect_utils.py +++ b/tests/test_dialect_utils.py @@ -4,7 +4,7 @@ from xdsl.context import MLContext from xdsl.dialects.builtin import DYNAMIC_INDEX, IndexType, IntegerType, i32 -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_dynamic_index_list_with_types, parse_dynamic_index_list_without_types, parse_dynamic_index_with_type, diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 34e09a4241..7300fc2f53 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -22,7 +22,7 @@ UnrankedTensorType, VectorType, ) -from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index f17bb8be7c..d5257a6e08 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -37,7 +37,7 @@ SymbolRefAttr, TensorType, ) -from xdsl.dialects.utils import parse_func_op_like, print_func_op_like +from xdsl.dialects.utils.format import parse_func_op_like, print_func_op_like from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index eca5bc9cad..c9956ae69b 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -17,7 +17,7 @@ TensorType, ) from xdsl.dialects.experimental import dmp -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/experimental/air.py b/xdsl/dialects/experimental/air.py index f87b2b1795..2e592257b1 100644 --- a/xdsl/dialects/experimental/air.py +++ b/xdsl/dialects/experimental/air.py @@ -20,7 +20,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 8131cdf6df..480676f996 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -10,7 +10,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 854ed50c86..5cdc7e08d3 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -24,7 +24,7 @@ TensorType, i64, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, ) from xdsl.ir import ( diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index dd3752ff43..10d34717a5 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -24,7 +24,7 @@ i32, i64, ) -from xdsl.dialects.utils import FastMathAttrBase +from xdsl.dialects.utils.fast_math import FastMathAttrBase from xdsl.ir import ( Attribute, BitEnumAttribute, diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index bfc9bf8d8c..eb50974e89 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -30,7 +30,7 @@ i32, i64, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_dynamic_index_list_without_types, print_dynamic_index_list, ) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 9335d48c2a..13ef7a7d30 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -25,7 +25,7 @@ IntegerType, StringAttr, ) -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr()) + outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/omp.py b/xdsl/dialects/omp.py index c83058e145..2eec70a374 100644 --- a/xdsl/dialects/omp.py +++ b/xdsl/dialects/omp.py @@ -9,7 +9,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index b9cdfb34fd..84e1cc631e 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -24,7 +24,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py index d9e8a63704..610ceee50a 100644 --- a/xdsl/dialects/riscv_func.py +++ b/xdsl/dialects/riscv_func.py @@ -12,7 +12,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/riscv_scf.py b/xdsl/dialects/riscv_scf.py index cea09406c8..67e16f80e9 100644 --- a/xdsl/dialects/riscv_scf.py +++ b/xdsl/dialects/riscv_scf.py @@ -10,7 +10,7 @@ from typing_extensions import Self from xdsl.dialects.riscv import IntRegisterType, RISCVRegisterType -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 25d0cafcaa..94ce963e95 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -33,7 +33,7 @@ print_immediate_value, si12, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index b29cfbb42c..0c2ac49230 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -12,7 +12,7 @@ SignlessIntegerConstraint, i64, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 3c2d5ea2d6..8486b303cd 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Generic, TypeVar, cast, overload +from typing import ClassVar, Generic, TypeVar, cast from typing_extensions import Self @@ -46,30 +46,10 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type - @overload @staticmethod def constr( - *, - element_type: None = None, - ) -> BaseAttr[StreamType[Attribute]]: ... - - @overload - @staticmethod - def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[StreamType[Attribute]] - | ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[StreamType[Attribute]](StreamType) + ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( StreamType, (element_type,) ) @@ -79,68 +59,38 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" - @overload @staticmethod def constr( - *, - element_type: None = None, - ) -> BaseAttr[ReadableStreamType[Attribute]]: ... - - @overload - @staticmethod - def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[ReadableStreamType[Attribute]] - | ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType) + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( ReadableStreamType, (element_type,) ) +AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( + ReadableStreamType +) + + @irdl_attr_definition class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[WritableStreamType[Attribute]]: ... - - @overload @staticmethod def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[WritableStreamType[Attribute]] - | ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[WritableStreamType[Attribute]](WritableStreamType) + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( WritableStreamType, (element_type,) ) +AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( + WritableStreamType +) + + class ReadOperation(IRDLOperation, abc.ABC): """ Abstract base class for operations that read from a stream. @@ -148,7 +98,7 @@ class ReadOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType.constr(element_type=T)) + stream = operand_def(ReadableStreamType.constr(T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -182,7 +132,7 @@ class WriteOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType.constr(element_type=T)) + stream = operand_def(WritableStreamType.constr(T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) diff --git a/xdsl/transforms/arith_add_fastmath.py b/xdsl/transforms/arith_add_fastmath.py index c60f54bd12..e9221be204 100644 --- a/xdsl/transforms/arith_add_fastmath.py +++ b/xdsl/transforms/arith_add_fastmath.py @@ -4,7 +4,7 @@ from typing import Literal from xdsl.dialects import arith, builtin -from xdsl.dialects.utils import FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathFlag from xdsl.passes import MLContext, ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, diff --git a/xdsl/transforms/canonicalization_patterns/riscv.py b/xdsl/transforms/canonicalization_patterns/riscv.py index 519c5b0813..6588f84cf9 100644 --- a/xdsl/transforms/canonicalization_patterns/riscv.py +++ b/xdsl/transforms/canonicalization_patterns/riscv.py @@ -2,7 +2,7 @@ from xdsl.dialects import riscv, riscv_snitch from xdsl.dialects.builtin import IntegerAttr -from xdsl.dialects.utils import FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathFlag from xdsl.ir import OpResult, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, From afa862db1087fc0625658f0cc62ae8587eff8bca Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 17:10:12 +0000 Subject: [PATCH 04/15] Revert "dialects: (stream) simplify constr helper on stream attributes" (#3493) Reverts xdslproject/xdsl#3473 --- tests/test_dialect_utils.py | 2 +- xdsl/dialects/arith.py | 2 +- xdsl/dialects/csl/csl.py | 2 +- xdsl/dialects/csl/csl_stencil.py | 2 +- xdsl/dialects/experimental/air.py | 2 +- xdsl/dialects/func.py | 2 +- xdsl/dialects/linalg.py | 2 +- xdsl/dialects/llvm.py | 2 +- xdsl/dialects/memref.py | 2 +- xdsl/dialects/memref_stream.py | 4 +- xdsl/dialects/omp.py | 2 +- xdsl/dialects/riscv.py | 2 +- xdsl/dialects/riscv_func.py | 2 +- xdsl/dialects/riscv_scf.py | 2 +- xdsl/dialects/riscv_snitch.py | 2 +- xdsl/dialects/scf.py | 2 +- xdsl/dialects/stream.py | 82 +++++++++++++++---- xdsl/transforms/arith_add_fastmath.py | 2 +- .../canonicalization_patterns/riscv.py | 2 +- 19 files changed, 85 insertions(+), 35 deletions(-) diff --git a/tests/test_dialect_utils.py b/tests/test_dialect_utils.py index 3c80c06622..171d95373f 100644 --- a/tests/test_dialect_utils.py +++ b/tests/test_dialect_utils.py @@ -4,7 +4,7 @@ from xdsl.context import MLContext from xdsl.dialects.builtin import DYNAMIC_INDEX, IndexType, IntegerType, i32 -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_dynamic_index_list_with_types, parse_dynamic_index_list_without_types, parse_dynamic_index_with_type, diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 7300fc2f53..34e09a4241 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -22,7 +22,7 @@ UnrankedTensorType, VectorType, ) -from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index d5257a6e08..f17bb8be7c 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -37,7 +37,7 @@ SymbolRefAttr, TensorType, ) -from xdsl.dialects.utils.format import parse_func_op_like, print_func_op_like +from xdsl.dialects.utils import parse_func_op_like, print_func_op_like from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index c9956ae69b..eca5bc9cad 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -17,7 +17,7 @@ TensorType, ) from xdsl.dialects.experimental import dmp -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/experimental/air.py b/xdsl/dialects/experimental/air.py index 2e592257b1..f87b2b1795 100644 --- a/xdsl/dialects/experimental/air.py +++ b/xdsl/dialects/experimental/air.py @@ -20,7 +20,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 480676f996..8131cdf6df 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -10,7 +10,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 5cdc7e08d3..854ed50c86 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -24,7 +24,7 @@ TensorType, i64, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, ) from xdsl.ir import ( diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index 10d34717a5..dd3752ff43 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -24,7 +24,7 @@ i32, i64, ) -from xdsl.dialects.utils.fast_math import FastMathAttrBase +from xdsl.dialects.utils import FastMathAttrBase from xdsl.ir import ( Attribute, BitEnumAttribute, diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index eb50974e89..bfc9bf8d8c 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -30,7 +30,7 @@ i32, i64, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_dynamic_index_list_without_types, print_dynamic_index_list, ) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 13ef7a7d30..9335d48c2a 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -25,7 +25,7 @@ IntegerType, StringAttr, ) -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) + outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr()) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/omp.py b/xdsl/dialects/omp.py index 2eec70a374..c83058e145 100644 --- a/xdsl/dialects/omp.py +++ b/xdsl/dialects/omp.py @@ -9,7 +9,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index 84e1cc631e..b9cdfb34fd 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -24,7 +24,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py index 610ceee50a..d9e8a63704 100644 --- a/xdsl/dialects/riscv_func.py +++ b/xdsl/dialects/riscv_func.py @@ -12,7 +12,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/riscv_scf.py b/xdsl/dialects/riscv_scf.py index 67e16f80e9..cea09406c8 100644 --- a/xdsl/dialects/riscv_scf.py +++ b/xdsl/dialects/riscv_scf.py @@ -10,7 +10,7 @@ from typing_extensions import Self from xdsl.dialects.riscv import IntRegisterType, RISCVRegisterType -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 94ce963e95..25d0cafcaa 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -33,7 +33,7 @@ print_immediate_value, si12, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index 0c2ac49230..b29cfbb42c 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -12,7 +12,7 @@ SignlessIntegerConstraint, i64, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 8486b303cd..3c2d5ea2d6 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Generic, TypeVar, cast +from typing import ClassVar, Generic, TypeVar, cast, overload from typing_extensions import Self @@ -46,10 +46,30 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type + @overload @staticmethod def constr( + *, + element_type: None = None, + ) -> BaseAttr[StreamType[Attribute]]: ... + + @overload + @staticmethod + def constr( + *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: + ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ... + + @staticmethod + def constr( + *, + element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, + ) -> ( + BaseAttr[StreamType[Attribute]] + | ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]] + ): + if element_type is None: + return BaseAttr[StreamType[Attribute]](StreamType) return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( StreamType, (element_type,) ) @@ -59,38 +79,68 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" + @overload @staticmethod def constr( + *, + element_type: None = None, + ) -> BaseAttr[ReadableStreamType[Attribute]]: ... + + @overload + @staticmethod + def constr( + *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ... + + @staticmethod + def constr( + *, + element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, + ) -> ( + BaseAttr[ReadableStreamType[Attribute]] + | ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]] + ): + if element_type is None: + return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType) return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( ReadableStreamType, (element_type,) ) -AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( - ReadableStreamType -) - - @irdl_attr_definition class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" + @overload + @staticmethod + def constr( + *, + element_type: None = None, + ) -> BaseAttr[WritableStreamType[Attribute]]: ... + + @overload @staticmethod def constr( + *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ... + + @staticmethod + def constr( + *, + element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, + ) -> ( + BaseAttr[WritableStreamType[Attribute]] + | ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]] + ): + if element_type is None: + return BaseAttr[WritableStreamType[Attribute]](WritableStreamType) return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( WritableStreamType, (element_type,) ) -AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( - WritableStreamType -) - - class ReadOperation(IRDLOperation, abc.ABC): """ Abstract base class for operations that read from a stream. @@ -98,7 +148,7 @@ class ReadOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType.constr(T)) + stream = operand_def(ReadableStreamType.constr(element_type=T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -132,7 +182,7 @@ class WriteOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType.constr(T)) + stream = operand_def(WritableStreamType.constr(element_type=T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) diff --git a/xdsl/transforms/arith_add_fastmath.py b/xdsl/transforms/arith_add_fastmath.py index e9221be204..c60f54bd12 100644 --- a/xdsl/transforms/arith_add_fastmath.py +++ b/xdsl/transforms/arith_add_fastmath.py @@ -4,7 +4,7 @@ from typing import Literal from xdsl.dialects import arith, builtin -from xdsl.dialects.utils.fast_math import FastMathFlag +from xdsl.dialects.utils import FastMathFlag from xdsl.passes import MLContext, ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, diff --git a/xdsl/transforms/canonicalization_patterns/riscv.py b/xdsl/transforms/canonicalization_patterns/riscv.py index 6588f84cf9..519c5b0813 100644 --- a/xdsl/transforms/canonicalization_patterns/riscv.py +++ b/xdsl/transforms/canonicalization_patterns/riscv.py @@ -2,7 +2,7 @@ from xdsl.dialects import riscv, riscv_snitch from xdsl.dialects.builtin import IntegerAttr -from xdsl.dialects.utils.fast_math import FastMathFlag +from xdsl.dialects.utils import FastMathFlag from xdsl.ir import OpResult, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, From 680b5563040942a1138a36787a0aea64aeac70eb Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 17:26:19 +0000 Subject: [PATCH 05/15] dialects: (stream) simplify constr helper on stream attributes (#3494) Trying #3473 again --- xdsl/dialects/memref_stream.py | 2 +- xdsl/dialects/stream.py | 82 +++++++--------------------------- 2 files changed, 17 insertions(+), 67 deletions(-) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 9335d48c2a..f97bd909df 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr()) + outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 3c2d5ea2d6..8486b303cd 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Generic, TypeVar, cast, overload +from typing import ClassVar, Generic, TypeVar, cast from typing_extensions import Self @@ -46,30 +46,10 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type - @overload @staticmethod def constr( - *, - element_type: None = None, - ) -> BaseAttr[StreamType[Attribute]]: ... - - @overload - @staticmethod - def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[StreamType[Attribute]] - | ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[StreamType[Attribute]](StreamType) + ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( StreamType, (element_type,) ) @@ -79,68 +59,38 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" - @overload @staticmethod def constr( - *, - element_type: None = None, - ) -> BaseAttr[ReadableStreamType[Attribute]]: ... - - @overload - @staticmethod - def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[ReadableStreamType[Attribute]] - | ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType) + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( ReadableStreamType, (element_type,) ) +AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( + ReadableStreamType +) + + @irdl_attr_definition class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[WritableStreamType[Attribute]]: ... - - @overload @staticmethod def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[WritableStreamType[Attribute]] - | ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[WritableStreamType[Attribute]](WritableStreamType) + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( WritableStreamType, (element_type,) ) +AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( + WritableStreamType +) + + class ReadOperation(IRDLOperation, abc.ABC): """ Abstract base class for operations that read from a stream. @@ -148,7 +98,7 @@ class ReadOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType.constr(element_type=T)) + stream = operand_def(ReadableStreamType.constr(T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -182,7 +132,7 @@ class WriteOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType.constr(element_type=T)) + stream = operand_def(WritableStreamType.constr(T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) From f7088a0e37e1dd3783ef10ba1132113ec3fc84ce Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 17:51:27 +0000 Subject: [PATCH 06/15] dialects: (stream) use assembly format for stream read and write ops (#3484) --- xdsl/dialects/stream.py | 44 ++++------------------------------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 8486b303cd..e95e051fc8 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -3,8 +3,6 @@ import abc from typing import ClassVar, Generic, TypeVar, cast -from typing_extensions import Self - from xdsl.dialects.builtin import ContainerType from xdsl.ir import ( Attribute, @@ -25,8 +23,6 @@ operand_def, result_def, ) -from xdsl.parser import Parser -from xdsl.printer import Printer _StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) _StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) @@ -101,6 +97,8 @@ class ReadOperation(IRDLOperation, abc.ABC): stream = operand_def(ReadableStreamType.constr(T)) res = result_def(T) + assembly_format = "`from` $stream attr-dict `:` type($res)" + def __init__(self, stream: SSAValue, result_type: Attribute | None = None): if result_type is None: assert isinstance(stream_type := stream.type, ReadableStreamType) @@ -108,21 +106,6 @@ def __init__(self, stream: SSAValue, result_type: Attribute | None = None): result_type = stream_type.element_type super().__init__(operands=[stream], result_types=[result_type]) - @classmethod - def parse(cls, parser: Parser) -> Self: - parser.parse_characters("from") - unresolved = parser.parse_unresolved_operand() - parser.parse_punctuation(":") - result_type = parser.parse_attribute() - resolved = parser.resolve_operand(unresolved, ReadableStreamType(result_type)) - return cls(resolved, result_type) - - def print(self, printer: Printer): - printer.print_string(" from ") - printer.print(self.stream) - printer.print_string(" : ") - printer.print_attribute(self.res.type) - class WriteOperation(IRDLOperation, abc.ABC): """ @@ -134,30 +117,11 @@ class WriteOperation(IRDLOperation, abc.ABC): value = operand_def(T) stream = operand_def(WritableStreamType.constr(T)) + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) - @classmethod - def parse(cls, parser: Parser) -> Self: - unresolved_value = parser.parse_unresolved_operand() - parser.parse_characters("to") - unresolved_stream = parser.parse_unresolved_operand() - parser.parse_punctuation(":") - result_type = parser.parse_attribute() - resolved_value = parser.resolve_operand(unresolved_value, result_type) - resolved_stream = parser.resolve_operand( - unresolved_stream, WritableStreamType(result_type) - ) - return cls(resolved_value, resolved_stream) - - def print(self, printer: Printer): - printer.print_string(" ") - printer.print_ssa_value(self.value) - printer.print_string(" to ") - printer.print_ssa_value(self.stream) - printer.print_string(" : ") - printer.print_attribute(self.value.type) - Stream = Dialect( "stream", From 560390414f812dd534e6be3ee8c098c18ec05525 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 17:58:18 +0000 Subject: [PATCH 07/15] dialects: (stream) remove abstract stream read and write operations (#3486) Another incremental change towards getting rid of the stream dialect, now that the definitions are small enough it feels like the right move is to duplicate the definitions, as there's no particular reason for them to be coupled. --- xdsl/dialects/memref_stream.py | 35 ++++++++++++++++++++++++-- xdsl/dialects/riscv_snitch.py | 29 ++++++++++++++++++++-- xdsl/dialects/stream.py | 45 +--------------------------------- 3 files changed, 61 insertions(+), 48 deletions(-) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index f97bd909df..81af928429 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -46,6 +46,7 @@ opt_prop_def, prop_def, region_def, + result_def, traits_def, var_operand_def, ) @@ -187,14 +188,44 @@ def offsets(self) -> tuple[tuple[int, ...], ...]: @irdl_op_definition -class ReadOp(stream.ReadOperation): +class ReadOp(IRDLOperation): name = "memref_stream.read" + T: ClassVar = VarConstraint("T", AnyAttr()) + + stream = operand_def(stream.ReadableStreamType.constr(T)) + res = result_def(T) + + assembly_format = "`from` $stream attr-dict `:` type($res)" + + def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): + if result_type is None: + assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) + stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + result_type = stream_type.element_type + super().__init__(operands=[stream_val], result_types=[result_type]) + + def assembly_line(self) -> str | None: + return None + @irdl_op_definition -class WriteOp(stream.WriteOperation): +class WriteOp(IRDLOperation): name = "memref_stream.write" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + stream = operand_def(stream.WritableStreamType.constr(T)) + + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + + def __init__(self, value: SSAValue, stream: SSAValue): + super().__init__(operands=[value, stream]) + + def assembly_line(self) -> str | None: + return None + @irdl_op_definition class StreamingRegionOp(IRDLOperation): diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 25d0cafcaa..1bcbaaa670 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -40,6 +40,7 @@ ) from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( + AnyAttr, VarConstraint, attr_def, base, @@ -155,17 +156,41 @@ def assembly_line(self) -> str | None: @irdl_op_definition -class ReadOp(stream.ReadOperation, RISCVAsmOperation): +class ReadOp(RISCVAsmOperation): name = "riscv_snitch.read" + T: ClassVar = VarConstraint("T", AnyAttr()) + + stream = operand_def(stream.ReadableStreamType.constr(T)) + res = result_def(T) + + assembly_format = "`from` $stream attr-dict `:` type($res)" + + def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): + if result_type is None: + assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) + stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + result_type = stream_type.element_type + super().__init__(operands=[stream_val], result_types=[result_type]) + def assembly_line(self) -> str | None: return None @irdl_op_definition -class WriteOp(stream.WriteOperation, RISCVAsmOperation): +class WriteOp(RISCVAsmOperation): name = "riscv_snitch.write" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + stream = operand_def(stream.WritableStreamType.constr(T)) + + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + + def __init__(self, value: SSAValue, stream: SSAValue): + super().__init__(operands=[value, stream]) + def assembly_line(self) -> str | None: return None diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index e95e051fc8..7ed748f393 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,27 +1,20 @@ from __future__ import annotations -import abc -from typing import ClassVar, Generic, TypeVar, cast +from typing import Generic, TypeVar from xdsl.dialects.builtin import ContainerType from xdsl.ir import ( Attribute, Dialect, ParametrizedAttribute, - SSAValue, TypeAttribute, ) from xdsl.irdl import ( - AnyAttr, BaseAttr, GenericAttrConstraint, - IRDLOperation, ParamAttrConstraint, ParameterDef, - VarConstraint, irdl_attr_definition, - operand_def, - result_def, ) _StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) @@ -87,42 +80,6 @@ def constr( ) -class ReadOperation(IRDLOperation, abc.ABC): - """ - Abstract base class for operations that read from a stream. - """ - - T: ClassVar = VarConstraint("T", AnyAttr()) - - stream = operand_def(ReadableStreamType.constr(T)) - res = result_def(T) - - assembly_format = "`from` $stream attr-dict `:` type($res)" - - def __init__(self, stream: SSAValue, result_type: Attribute | None = None): - if result_type is None: - assert isinstance(stream_type := stream.type, ReadableStreamType) - stream_type = cast(ReadableStreamType[Attribute], stream_type) - result_type = stream_type.element_type - super().__init__(operands=[stream], result_types=[result_type]) - - -class WriteOperation(IRDLOperation, abc.ABC): - """ - Abstract base class for operations that write to a stream. - """ - - T: ClassVar = VarConstraint("T", AnyAttr()) - - value = operand_def(T) - stream = operand_def(WritableStreamType.constr(T)) - - assembly_format = "$value `to` $stream attr-dict `:` type($value)" - - def __init__(self, value: SSAValue, stream: SSAValue): - super().__init__(operands=[value, stream]) - - Stream = Dialect( "stream", [], From cffac36af17d225c32e71fcfbacca19e2172841a Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 22 Nov 2024 08:32:07 +0000 Subject: [PATCH 08/15] misc: add get helper to ScopedDict (#3499) --- tests/utils/test_scoped_dict.py | 13 +++++++++++++ xdsl/utils/scoped_dict.py | 19 +++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_scoped_dict.py b/tests/utils/test_scoped_dict.py index cc536dd9cc..a92b18ac4b 100644 --- a/tests/utils/test_scoped_dict.py +++ b/tests/utils/test_scoped_dict.py @@ -32,3 +32,16 @@ def test_simple(): assert 3 not in table assert 3 in inner assert 4 not in inner + + +def test_get(): + parent = ScopedDict(local_scope={"a": 1, "b": 2}) + child = ScopedDict(parent, local_scope={"a": 3, "c": 4}) + + assert child.get("a") == 3 + assert child.get("b") == 2 + assert child.get("c") == 4 + assert child.get("d") is None + + assert child.get("a", 5) == 3 + assert child.get("d", 5) == 5 diff --git a/xdsl/utils/scoped_dict.py b/xdsl/utils/scoped_dict.py index ab95061fc8..6fab8ba667 100644 --- a/xdsl/utils/scoped_dict.py +++ b/xdsl/utils/scoped_dict.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, TypeVar +from typing import Generic, TypeVar, overload _Key = TypeVar("_Key") _Value = TypeVar("_Value") @@ -28,11 +28,26 @@ def __init__( parent: ScopedDict[_Key, _Value] | None = None, *, name: str | None = None, + local_scope: dict[_Key, _Value] | None = None, ) -> None: - self._local_scope = {} + self._local_scope = {} if local_scope is None else local_scope self.parent = parent self.name = name + @overload + def get(self, key: _Key, default: None = None) -> _Value | None: ... + + @overload + def get(self, key: _Key, default: _Value) -> _Value: ... + + def get(self, key: _Key, default: _Value | None = None) -> _Value | None: + local = self._local_scope.get(key) + if local is not None: + return local + if self.parent is None: + return default + return self.parent.get(key, default) + def __getitem__(self, key: _Key) -> _Value: """ Fetch key from environment. Attempts to first fetch from current scope, From 4e09b2d80dc1e25c99ab0082301bbb6b9bc9aaa1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 08:47:26 +0000 Subject: [PATCH 09/15] pip prod(deps): bump marimo from 0.9.20 to 0.9.21 (#3502) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [marimo](https://github.com/marimo-team/marimo) from 0.9.20 to 0.9.21.
Release notes

Sourced from marimo's releases.

0.9.21

What's Changed

Full Changelog: https://github.com/marimo-team/marimo/compare/0.9.20...0.9.21

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=marimo&package-manager=pip&previous-version=0.9.20&new-version=0.9.21)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2a677edd06..f79ac3006b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dev = [ "nbval<0.12", "filecheck==1.0.1", "lit<19.0.0", - "marimo==0.9.20", + "marimo==0.9.21", "pre-commit==4.0.1", "ruff==0.7.4", "asv<0.7", From 150b84bd6b908af2d6378ad7fb69ff2ea24ed3c2 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 22 Nov 2024 08:48:00 +0000 Subject: [PATCH 10/15] API: add InferenceContext (#3500) A wrapper class to encapsulate the inference context. I have a feeling that we might want to split out the inference context into range and non-range variable assignments, this will hide the change from the signature of the `infer` method. --- tests/dialects/test_bufferization.py | 5 +- tests/irdl/test_attr_constraint.py | 7 ++- xdsl/dialects/bufferization.py | 6 +- xdsl/irdl/constraints.py | 59 +++++++++---------- xdsl/irdl/declarative_assembly_format.py | 9 ++- .../declarative_assembly_format_parser.py | 3 +- 6 files changed, 46 insertions(+), 43 deletions(-) diff --git a/tests/dialects/test_bufferization.py b/tests/dialects/test_bufferization.py index 169612fcce..19490a89e8 100644 --- a/tests/dialects/test_bufferization.py +++ b/tests/dialects/test_bufferization.py @@ -23,6 +23,7 @@ from xdsl.ir import Attribute from xdsl.irdl import ( EqAttrConstraint, + InferenceContext, IRDLOperation, VarConstraint, irdl_op_definition, @@ -39,13 +40,13 @@ def test_tensor_from_memref_inference(): EqAttrConstraint(MemRefType(f64, [10, 20, 30])) ) assert constr2.can_infer(set()) - assert constr2.infer({}) == TensorType(f64, [10, 20, 30]) + assert constr2.infer(InferenceContext()) == TensorType(f64, [10, 20, 30]) constr3 = TensorFromMemrefConstraint( EqAttrConstraint(UnrankedMemrefType.from_type(f64)) ) assert constr3.can_infer(set()) - assert constr3.infer({}) == UnrankedTensorType(f64) + assert constr3.infer(InferenceContext()) == UnrankedTensorType(f64) @irdl_op_definition diff --git a/tests/irdl/test_attr_constraint.py b/tests/irdl/test_attr_constraint.py index f37748bc87..e18b6f6773 100644 --- a/tests/irdl/test_attr_constraint.py +++ b/tests/irdl/test_attr_constraint.py @@ -10,6 +10,7 @@ AttrConstraint, BaseAttr, EqAttrConstraint, + InferenceContext, ParamAttrConstraint, ParameterDef, VarConstraint, @@ -77,7 +78,7 @@ class WrapAttr(BaseWrapAttr): ... ) assert constr.can_infer(set()) - assert constr.infer({}) == WrapAttr((StringAttr("Hello"),)) + assert constr.infer(InferenceContext()) == WrapAttr((StringAttr("Hello"),)) var_constr = ParamAttrConstraint( WrapAttr, @@ -92,7 +93,7 @@ class WrapAttr(BaseWrapAttr): ... ) assert var_constr.can_infer({"T"}) - assert var_constr.infer({"T": StringAttr("Hello")}) == WrapAttr( + assert var_constr.infer(InferenceContext({"T": StringAttr("Hello")})) == WrapAttr( (StringAttr("Hello"),) ) @@ -127,7 +128,7 @@ class NoParamAttr(BaseNoParamAttr): ... constr = BaseAttr(NoParamAttr) assert constr.can_infer(set()) - assert constr.infer({}) == NoParamAttr() + assert constr.infer(InferenceContext()) == NoParamAttr() base_constr = BaseAttr(BaseNoParamAttr) assert not base_constr.can_infer(set()) diff --git a/xdsl/dialects/bufferization.py b/xdsl/dialects/bufferization.py index 3cd932f379..3793b9d7da 100644 --- a/xdsl/dialects/bufferization.py +++ b/xdsl/dialects/bufferization.py @@ -20,8 +20,8 @@ from xdsl.irdl import ( AttrSizedOperandSegments, ConstraintContext, - ConstraintVariableType, GenericAttrConstraint, + InferenceContext, IRDLOperation, VarConstraint, irdl_op_definition, @@ -53,9 +53,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.memref_constraint.can_infer(var_constraint_names) def infer( - self, variables: dict[str, ConstraintVariableType] + self, context: InferenceContext ) -> TensorType[Attribute] | UnrankedTensorType[Attribute]: - memref_type = self.memref_constraint.infer(variables) + memref_type = self.memref_constraint.infer(context) if isinstance(memref_type, MemRefType): return TensorType(memref_type.element_type, memref_type.shape) return UnrankedTensorType(memref_type.element_type) diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index bf21d83992..488961d7df 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -61,6 +61,15 @@ def update(self, other: ConstraintContext): Possible types that a constraint variable can have. """ + +@dataclass +class InferenceContext: + variables: dict[str, ConstraintVariableType] = field(default_factory=dict) + """ + A mapping from variable names to the inferred attribute or attribute sequence. + """ + + _T = TypeVar("_T") @@ -156,7 +165,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: # By default, we cannot infer anything. return False - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: """ Infer the attribute given the the values for all variables. @@ -228,8 +237,8 @@ def verify( def get_variable_extractors(self) -> dict[str, VarExtractor[AttributeCovT]]: return {self.name: IdExtractor()} - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: - v = variables[self.name] + def infer(self, context: InferenceContext) -> AttributeCovT: + v = context.variables[self.name] return cast(AttributeCovT, v) def can_infer(self, var_constraint_names: Set[str]) -> bool: @@ -272,7 +281,7 @@ def verify( def can_infer(self, var_constraint_names: Set[str]) -> bool: return True - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: return self.attr def get_unique_base(self) -> type[Attribute] | None: @@ -303,7 +312,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: and not self.attr.get_irdl_definition().parameters ) - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: assert issubclass(self.attr, ParametrizedAttribute) attr = self.attr.new(()) return attr @@ -439,10 +448,10 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: constr.can_infer(var_constraint_names) for constr in self.attr_constrs ) - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: for constr in self.attr_constrs: - if constr.can_infer(variables.keys()): - return constr.infer(variables) + if constr.can_infer(context.variables.keys()): + return constr.infer(context) raise ValueError("Cannot infer attribute from constraint") def get_unique_base(self) -> type[Attribute] | None: @@ -535,10 +544,8 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: constr.can_infer(var_constraint_names) for constr in self.param_constrs ) - def infer( - self, variables: dict[str, ConstraintVariableType] - ) -> ParametrizedAttributeCovT: - params = tuple(constr.infer(variables) for constr in self.param_constrs) + def infer(self, context: InferenceContext) -> ParametrizedAttributeCovT: + params = tuple(constr.infer(context) for constr in self.param_constrs) attr = self.base_attr.new(params) return attr @@ -590,8 +597,8 @@ def get_unique_base(self) -> type[Attribute] | None: def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.constr.can_infer(var_constraint_names) - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: - return self.constr.infer(variables) + def infer(self, context: InferenceContext) -> AttributeCovT: + return self.constr.infer(context) @dataclass(frozen=True) @@ -628,9 +635,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: # By default, we cannot infer anything. return False - def infer( - self, length: int, variables: dict[str, ConstraintVariableType] - ) -> Sequence[AttributeCovT]: + def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]: """ Infer the attribute given the the values for all variables. @@ -684,12 +689,8 @@ def get_variable_extractors( def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.name in var_constraint_names - def infer( - self, - length: int, - variables: dict[str, ConstraintVariableType], - ) -> Sequence[AttributeCovT]: - v = variables[self.name] + def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]: + v = context.variables[self.name] return cast(Sequence[AttributeCovT], v) @@ -715,9 +716,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: def infer( self, length: int, - variables: dict[str, ConstraintVariableType], + context: InferenceContext, ) -> Sequence[AttributeCovT]: - attr = self.constr.infer(variables) + attr = self.constr.infer(context) return (attr,) * length @@ -756,12 +757,8 @@ def get_variable_extractors( def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.constr.can_infer(var_constraint_names) - def infer( - self, - length: int, - variables: dict[str, ConstraintVariableType], - ) -> Sequence[AttributeCovT]: - return (self.constr.infer(variables),) + def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]: + return (self.constr.infer(context),) def range_constr_coercion( diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 40d8a6841d..e6e6c1113a 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -22,6 +22,7 @@ ) from xdsl.irdl import ( ConstraintVariableType, + InferenceContext, IRDLOperation, IRDLOperationInvT, OpDef, @@ -186,7 +187,7 @@ def resolve_operand_types(self, state: ParsingState, op_def: OpDef) -> None: range_length = len(operand) if isinstance(operand, Sequence) else 1 operand_type = operand_def.constr.infer( range_length, - state.variables, + InferenceContext(state.variables), ) resolved_operand_type: Attribute | Sequence[Attribute] if isinstance(operand_def, OptionalDef): @@ -220,7 +221,7 @@ def resolve_result_types(self, state: ParsingState, op_def: OpDef) -> None: range_length = 1 inferred_result_types = result_def.constr.infer( range_length, - state.variables, + InferenceContext(state.variables), ) resolved_result_type = inferred_result_types[0] state.result_types[i] = resolved_result_type @@ -885,7 +886,9 @@ def parse(self, parser: Parser, state: ParsingState) -> None: ): attr = unique_base.new(unique_base.parse_parameters(parser)) elif issubclass(unique_base, Data): - attr = unique_base.new(unique_base.parse_parameter(parser)) # pyright: ignore[reportUnknownVariableType] + attr = unique_base.new( # pyright: ignore[reportUnknownVariableType] + unique_base.parse_parameter(parser) + ) else: raise ValueError("Attributes must be Data or ParameterizedAttribute.") if self.is_property: diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 4977026c81..d4e413a2fd 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -19,6 +19,7 @@ AttrSizedOperandSegments, AttrSizedSegments, ConstraintVariableType, + InferenceContext, OpDef, OptionalDef, OptOperandDef, @@ -528,7 +529,7 @@ def parse_optional_variable( unique_base.get_type_index() ] if type_constraint.can_infer(set()): - unique_type = type_constraint.infer({}) + unique_type = type_constraint.infer(InferenceContext()) if ( unique_base is not None and unique_base in Builtin.attributes From 8d8feafb41e7f7adf59a0c5a32bfb7774efc150b Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 22 Nov 2024 11:42:52 +0000 Subject: [PATCH 11/15] core: refactor assembly format directives (#3501) Refactors the assembly format directives to: - Remove some unnecessary classes that made the system rigid - Separate directives that go in a `type` clause from format directives --- .../irdl/test_declarative_assembly_format.py | 2 +- xdsl/irdl/declarative_assembly_format.py | 489 +++++++++--------- .../declarative_assembly_format_parser.py | 207 ++++---- 3 files changed, 333 insertions(+), 365 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index a58de17fae..dded57875f 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -696,7 +696,7 @@ def test_unknown_variable(): """Test that variables should refer to an element in the operation.""" with pytest.raises( PyRDLOpDefinitionError, - match="expected variable to refer to an operand, attribute, region, result, or successor", + match="expected variable to refer to an operand, attribute, region, or successor", ): @irdl_op_definition diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index e6e6c1113a..977f47e7a3 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Literal +from typing import Literal, TypeAlias from xdsl.dialects.builtin import UnitAttr from xdsl.ir import ( @@ -238,19 +238,11 @@ def print(self, printer: Printer, op: IRDLOperation) -> None: @dataclass(frozen=True) -class FormatDirective(ABC): - """A format directive for operation format.""" - - @abstractmethod - def parse(self, parser: Parser, state: ParsingState) -> None: ... +class Directive(ABC): + """An assembly format directive""" - @abstractmethod - def print( - self, printer: Printer, state: PrintingState, op: IRDLOperation - ) -> None: ... - -class AnchorableDirective(FormatDirective, ABC): +class AnchorableDirective(Directive, ABC): """ Base class for Directive usable as anchors to optional groups. """ @@ -263,6 +255,18 @@ def is_present(self, op: IRDLOperation) -> bool: ... +class FormatDirective(Directive, ABC): + """A format directive for operation format.""" + + @abstractmethod + def parse(self, parser: Parser, state: ParsingState) -> None: ... + + @abstractmethod + def print( + self, printer: Printer, state: PrintingState, op: IRDLOperation + ) -> None: ... + + class OptionallyParsableDirective(FormatDirective, ABC): """ Base class for Directive that can be optionally parsed. @@ -272,7 +276,7 @@ class OptionallyParsableDirective(FormatDirective, ABC): @abstractmethod def parse_optional(self, parser: Parser, state: ParsingState) -> bool: """ - Try parsing the directive and return if it was present. + Try parsing the directive and return True if it was present. """ ... @@ -280,71 +284,175 @@ def parse(self, parser: Parser, state: ParsingState) -> None: self.parse_optional(parser, state) -class VariadicLikeFormatDirective(AnchorableDirective, ABC): +class VariadicLikeFormatDirective( + OptionallyParsableDirective, AnchorableDirective, ABC +): """ - Baseclass to help keep typechecking simple. - VariadicLike is mostly Variadic or Optional: Whatever directive that can accept - having nothing to parse. + A directive which parses/prints multiple objects separated by commas. + Such directives can not be followed by comma literals. """ - pass + def set_empty(self, state: ParsingState): + """ + Set the appropriate field of the parsing state to be empty. + Used when a variable appears in an optional group which is not parsed. + """ + return -@dataclass(frozen=True) -class VariableDirective(FormatDirective, ABC): +class TypeableDirective(Directive, ABC): """ - A variable directive, with the following format: - variable-directive ::= dollar-ident - The directive will request a space to be printed after. + Directives which can be used to set or get types. """ - name: str - """The variable name. This is only used for error message reporting.""" - index: int - """Index of the variable(operand or result) definition.""" + @abstractmethod + def set_type(self, type: Attribute, state: ParsingState) -> None: ... + + @abstractmethod + def get_type(self, op: IRDLOperation) -> Attribute: ... -class TypeDirective(VariableDirective, ABC): +class VariadicTypeableDirective(AnchorableDirective, ABC): """ - Base class for Directive meant to parse types. + Directives which can set or get multiple types. """ - pass + @abstractmethod + def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None: ... + @abstractmethod + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: ... -class RegionDirective(OptionallyParsableDirective, ABC): + +class OptionalTypeableDirective(AnchorableDirective, ABC): """ - Baseclass to help keep typechecking simple. - RegionDirective is for any RegionVariable, which are all OptionallyParsable. + Directives which can optionally set or get a single type. + """ + + @abstractmethod + def set_type(self, type: Attribute | None, state: ParsingState) -> None: ... + + @abstractmethod + def get_type(self, op: IRDLOperation) -> Attribute | None: ... + + +AnyTypeableDirective: TypeAlias = ( + TypeableDirective | VariadicTypeableDirective | OptionalTypeableDirective +) + + +@dataclass(frozen=True) +class TypeDirective(FormatDirective): + """ + A directive which parses the type of a typeable directive, with format: + type-directive ::= type(typeable-directive) + """ + + inner: TypeableDirective + + def parse(self, parser: Parser, state: ParsingState) -> None: + ty = parser.parse_type() + self.inner.set_type(ty, state) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + printer.print_attribute(self.inner.get_type(op)) + state.last_was_punctuation = False + state.should_emit_space = True + + +class VariadicLikeTypeDirective(VariadicLikeFormatDirective): + """ + Base class for type checking. + A variadic-like type directive can not be followed by a variadic-like type directive. """ - pass +@dataclass(frozen=True) +class VariadicTypeDirective(VariadicLikeTypeDirective): + """ + A directive which parses the type of a variadic typeable directive, with format: + type-directive ::= type(typeable-directive) + """ -class VariadicLikeVariable(VariadicLikeFormatDirective, VariableDirective, ABC): - pass + inner: VariadicTypeableDirective + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + types = parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_type, parser.parse_type + ) + if types is None: + types = () + self.inner.set_types(types, state) + return bool(types) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + printer.print_list(self.inner.get_types(op), printer.print_attribute) + state.last_was_punctuation = False + state.should_emit_space = True -class VariadicVariable(VariadicLikeVariable, ABC): def is_present(self, op: IRDLOperation) -> bool: - return len(getattr(op, self.name)) > 0 + return self.inner.is_present(op) + + def set_empty(self, state: ParsingState): + self.inner.set_types((), state) + + +@dataclass(frozen=True) +class OptionalTypeDirective(VariadicLikeTypeDirective): + """ + A directive which parses the type of a optional typeable directive, with format: + type-directive ::= type(typeable-directive) + """ + + inner: OptionalTypeableDirective + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + type = parser.parse_optional_type() + self.inner.set_type(type, state) + return bool(type) + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + type = self.inner.get_type(op) + if type: + printer.print_attribute(type) + state.last_was_punctuation = False + state.should_emit_space = True -class OptionalVariable(VariadicLikeVariable, ABC): def is_present(self, op: IRDLOperation) -> bool: - return getattr(op, self.name) is not None + return self.inner.is_present(op) + def set_empty(self, state: ParsingState): + self.inner.set_type(None, state) -class VariadicLikeTypeDirective(VariadicLikeFormatDirective, VariableDirective, ABC): - pass + +@dataclass(frozen=True) +class VariableDirective(Directive, ABC): + """ + A variable directive, with the following format: + variable-directive ::= dollar-ident + The directive will request a space to be printed after. + """ + + name: str + """The variable name. This is only used for error message reporting.""" + index: int + """Index of the variable(operand or result) definition.""" -class VariadicTypeDirective(VariadicLikeTypeDirective, VariadicVariable, ABC): - pass +class VariadicVariable(VariableDirective, AnchorableDirective, ABC): + def is_present(self, op: IRDLOperation) -> bool: + return bool(getattr(op, self.name)) -class OptionalTypeDirective(VariadicLikeTypeDirective, OptionalVariable, ABC): - pass +class OptionalVariable(VariableDirective, AnchorableDirective, ABC): + def is_present(self, op: IRDLOperation) -> bool: + return getattr(op, self.name) is not None @dataclass(frozen=True) @@ -423,7 +531,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No @dataclass(frozen=True) -class OperandVariable(VariableDirective): +class OperandVariable(VariableDirective, FormatDirective, TypeableDirective): """ An operand variable, with the following format: operand-directive ::= dollar-ident @@ -434,6 +542,9 @@ def parse(self, parser: Parser, state: ParsingState) -> None: operand = parser.parse_unresolved_operand() state.operands[self.index] = operand + def set_type(self, type: Attribute, state: ParsingState) -> None: + state.operand_types[self.index] = type + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") @@ -441,10 +552,20 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def get_type(self, op: IRDLOperation) -> Attribute: + return getattr(op, self.name).type + + +class VariadicOperandDirective(VariadicLikeFormatDirective, ABC): + """ + Base class for typechecking. + A variadic operand directive cannot follow another variadic operand directive. + """ + @dataclass(frozen=True) class VariadicOperandVariable( - VariadicVariable, VariableDirective, OptionallyParsableDirective + VariadicVariable, VariadicOperandDirective, VariadicTypeableDirective ): """ A variadic operand variable, with the following format: @@ -461,6 +582,9 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: state.operands[self.index] = operands return bool(operands) + def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None: + state.operand_types[self.index] = types + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") @@ -470,8 +594,16 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: + return getattr(op, self.name).types + + def set_empty(self, state: ParsingState): + state.operands[self.index] = () -class OptionalOperandVariable(OptionalVariable, OptionallyParsableDirective): + +class OptionalOperandVariable( + OptionalVariable, VariadicOperandDirective, OptionalTypeableDirective +): """ An optional operand variable, with the following format: operand-directive ::= ( percent-ident )? @@ -485,6 +617,9 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: state.operands[self.index] = operand return bool(operand) + def set_type(self, type: Attribute | None, state: ParsingState) -> None: + state.operand_types[self.index] = type or () + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") @@ -494,80 +629,18 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True - -@dataclass(frozen=True) -class OperandTypeDirective(TypeDirective): - """ - An operand variable type directive, with the following format: - operand-type-directive ::= type(dollar-ident) - The directive will request a space to be printed right after. - """ - - def parse(self, parser: Parser, state: ParsingState) -> None: - type = parser.parse_type() - state.operand_types[self.index] = type - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_attribute(getattr(op, self.name).type) - state.last_was_punctuation = False - state.should_emit_space = True - - -@dataclass(frozen=True) -class VariadicOperandTypeDirective( - TypeDirective, VariadicTypeDirective, OptionallyParsableDirective -): - """ - A variadic operand variable, with the following format: - operand-directive ::= ( percent-ident ( `,` percent-id )* )? - The directive will request a space to be printed after. - """ - - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - operand_types = parser.parse_optional_undelimited_comma_separated_list( - parser.parse_optional_type, parser.parse_type - ) - if operand_types is None: - operand_types = () - state.operand_types[self.index] = operand_types - return bool(operand_types) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_list(getattr(op, self.name).types, printer.print_attribute) - state.last_was_punctuation = False - state.should_emit_space = True - - -class OptionalOperandTypeDirective(OptionalTypeDirective, OptionallyParsableDirective): - """ - An optional operand variable type directive, with the following format: - operand-type-directive ::= ( type(dollar-ident) )? - The directive will request a space to be printed after. - """ - - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - type = parser.parse_optional_type() - if type is None: - type = () - state.operand_types[self.index] = type - return bool(type) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") + def get_type(self, op: IRDLOperation) -> Attribute | None: operand = getattr(op, self.name) if operand: - printer.print_attribute(operand.type) - state.last_was_punctuation = False - state.should_emit_space = True + return operand.type + return None + + def set_empty(self, state: ParsingState): + state.operands[self.index] = () @dataclass(frozen=True) -class ResultVariable(VariableDirective): +class ResultVariable(VariableDirective, TypeableDirective): """ An result variable, with the following format: result-directive ::= dollar-ident @@ -575,23 +648,15 @@ class ResultVariable(VariableDirective): parsing is not handled by the custom operation parser. """ - def parse(self, parser: Parser, state: ParsingState) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) + def set_type(self, type: Attribute, state: ParsingState) -> None: + state.result_types[self.index] = type - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) + def get_type(self, op: IRDLOperation) -> Attribute: + return getattr(op, self.name).type @dataclass(frozen=True) -class VariadicResultVariable( - ResultVariable, VariadicVariable, OptionallyParsableDirective -): +class VariadicResultVariable(VariadicVariable, VariadicTypeableDirective): """ A variadic result variable, with the following format: result-directive ::= percent-ident (( `,` percent-id )* )? @@ -599,21 +664,14 @@ class VariadicResultVariable( parsing is not handled by the custom operation parser. """ - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) - return False + def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None: + state.result_types[self.index] = types - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: + return getattr(op, self.name).types -class OptionalResultVariable(OptionalVariable, OptionallyParsableDirective): +class OptionalResultVariable(OptionalVariable, OptionalTypeableDirective): """ An optional result variable, with the following format: result-directive ::= ( percent-ident )? @@ -621,92 +679,29 @@ class OptionalResultVariable(OptionalVariable, OptionallyParsableDirective): parsing is not handled by the custom operation parser. """ - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) - return False + def set_type(self, type: Attribute | None, state: ParsingState) -> None: + state.result_types[self.index] = type or () - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) + def get_type(self, op: IRDLOperation) -> Attribute | None: + res = getattr(op, self.name) + if res: + return res.type + return None -@dataclass(frozen=True) -class ResultTypeDirective(TypeDirective): +class RegionDirective(OptionallyParsableDirective, ABC): """ - A result variable type directive, with the following format: - result-type-directive ::= type(dollar-ident) - The directive will request a space to be printed right after. + Baseclass to help keep typechecking simple. + RegionDirective is for any RegionVariable, which are all OptionallyParsable. """ - def parse(self, parser: Parser, state: ParsingState) -> None: - type = parser.parse_type() - state.result_types[self.index] = type - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_attribute(getattr(op, self.name).type) - state.last_was_punctuation = False - state.should_emit_space = True - -@dataclass(frozen=True) -class VariadicResultTypeDirective( - TypeDirective, VariadicTypeDirective, OptionallyParsableDirective -): +class VariadicRegionDirective(RegionDirective, VariadicLikeFormatDirective, ABC): """ - A variadic result variable type directive, with the following format: - variadic-result-type-directive ::= ( percent-ident ( `,` percent-id )* )? - The directive will request a space to be printed after. + Base class for typechecking. + A variadic region directive cannot follow another variadic region directive. """ - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - result_types = parser.parse_optional_undelimited_comma_separated_list( - parser.parse_optional_type, parser.parse_type - ) - if result_types is None: - result_types = () - state.result_types[self.index] = result_types - return bool(result_types) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_list(getattr(op, self.name).types, printer.print_attribute) - state.last_was_punctuation = False - state.should_emit_space = True - - -class OptionalResultTypeDirective( - TypeDirective, OptionalTypeDirective, OptionallyParsableDirective -): - """ - An optional result variable type directive, with the following format: - result-type-directive ::= ( type(dollar-ident) )? - The directive will request a space to be printed after. - """ - - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - type = parser.parse_optional_type() - if type is None: - type = () - state.result_types[self.index] = type - return bool(type) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - result = getattr(op, self.name) - if result: - printer.print_attribute(result.type) - state.last_was_punctuation = False - state.should_emit_space = True - @dataclass(frozen=True) class RegionVariable(RegionDirective, VariableDirective): @@ -730,7 +725,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No @dataclass(frozen=True) -class VariadicRegionVariable(RegionDirective, VariadicVariable): +class VariadicRegionVariable(VariadicRegionDirective, VariadicVariable): """ A variadic region variable, with the following format: region-directive ::= dollar-ident @@ -757,8 +752,11 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.regions[self.index] = () + -class OptionalRegionVariable(RegionDirective, OptionalVariable): +class OptionalRegionVariable(VariadicRegionDirective, OptionalVariable): """ An optional region variable, with the following format: region-directive ::= dollar-ident @@ -781,6 +779,16 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.regions[self.index] = () + + +class VariadicSuccessorDirective(VariadicLikeFormatDirective, ABC): + """ + Base class for type checking. + A variadic successor directive cannot follow another variadic successor directive. + """ + class SuccessorVariable(VariableDirective, OptionallyParsableDirective): """ @@ -804,7 +812,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.should_emit_space = True -class VariadicSuccessorVariable(VariadicVariable, OptionallyParsableDirective): +class VariadicSuccessorVariable(VariadicSuccessorDirective, VariadicVariable): """ A variadic successor variable, with the following format: successor-directive ::= dollar-ident @@ -831,8 +839,11 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.successors[self.index] = () + -class OptionalSuccessorVariable(OptionalVariable, OptionallyParsableDirective): +class OptionalSuccessorVariable(VariadicSuccessorDirective, OptionalVariable): """ An optional successor variable, with the following format: successor-directive ::= dollar-ident @@ -855,6 +866,9 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.successors[self.index] = () + @dataclass(frozen=True) class AttributeVariable(FormatDirective): @@ -1075,33 +1089,8 @@ def parse(self, parser: Parser, state: ParsingState) -> None: # type to empty else: for element in self.then_elements: - match element: - case ( - OperandVariable(_, index) - | VariadicOperandVariable(_, index) - | OptionalOperandVariable(_, index) - ): - state.operands[index] = () - case ( - OperandTypeDirective(_, index) - | VariadicOperandTypeDirective(_, index) - | OptionalOperandTypeDirective(_, index) - ): - state.operand_types[index] = () - case ( - RegionVariable(_, index) - | VariadicRegionVariable(_, index) - | OptionalRegionVariable(_, index) - ): - state.regions[index] = () - case ( - ResultTypeDirective(_, index) - | VariadicResultTypeDirective(_, index) - | OptionalResultTypeDirective(_, index) - ): - state.result_types[index] = () - case _: - pass + if isinstance(element, VariadicLikeFormatDirective): + element.set_empty(state) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if self.anchor.is_present(op): diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index d4e413a2fd..7e62f3f0b6 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -40,6 +40,7 @@ ) from xdsl.irdl.declarative_assembly_format import ( AnchorableDirective, + AnyTypeableDirective, AttrDictDirective, AttributeVariable, DefaultValuedAttributeVariable, @@ -47,35 +48,35 @@ FormatProgram, KeywordDirective, OperandOrResult, - OperandTypeDirective, OperandVariable, OptionalAttributeVariable, OptionalGroupDirective, OptionallyParsableDirective, - OptionalOperandTypeDirective, OptionalOperandVariable, OptionalRegionVariable, - OptionalResultTypeDirective, OptionalResultVariable, OptionalSuccessorVariable, + OptionalTypeableDirective, + OptionalTypeDirective, OptionalUnitAttrVariable, ParsingState, PunctuationDirective, RegionDirective, RegionVariable, - ResultTypeDirective, ResultVariable, SuccessorVariable, - VariableDirective, + TypeDirective, VariadicLikeFormatDirective, VariadicLikeTypeDirective, - VariadicLikeVariable, - VariadicOperandTypeDirective, + VariadicOperandDirective, VariadicOperandVariable, + VariadicRegionDirective, VariadicRegionVariable, - VariadicResultTypeDirective, VariadicResultVariable, + VariadicSuccessorDirective, VariadicSuccessorVariable, + VariadicTypeableDirective, + VariadicTypeDirective, WhitespaceDirective, ) from xdsl.parser import BaseParser, ParserState @@ -148,8 +149,6 @@ class FormatParser(BaseParser): """The successor variables that are already parsed.""" has_attr_dict: bool = field(default=False) """True if the attribute dictionary has already been parsed.""" - context: ParsingContext = field(default=ParsingContext.TopLevel) - """Indicates if the parser is nested in a particular directive.""" type_resolutions: dict[ tuple[OperandOrResult, int], tuple[Callable[[Attribute], Attribute], OperandOrResult, int], @@ -177,7 +176,7 @@ def parse_format(self) -> FormatProgram: """ elements: list[FormatDirective] = [] while self._current_token.kind != Token.Kind.EOF: - elements.append(self.parse_directive()) + elements.append(self.parse_format_directive()) self.add_reserved_attrs_to_directive(elements) extractors = self.extractors_by_name() @@ -205,25 +204,19 @@ def verify_directives(self, elements: list[FormatDirective]): self.raise_error( "A variadic type directive cannot be followed by another variadic type directive." ) - case VariadicLikeVariable(), VariadicLikeVariable(): - if not ( - isinstance(a, RegionDirective | VariadicLikeTypeDirective) - or isinstance(b, RegionDirective | VariadicLikeTypeDirective) - ): - self.raise_error( - "A variadic operand variable cannot be followed by another variadic operand variable." - ) - elif isinstance(a, RegionDirective) and isinstance( - b, RegionDirective - ): - self.raise_error( - "A variadic region variable cannot be followed by another variadic region variable." - ) - case AttrDictDirective(), RegionDirective() if not (a.with_keyword): + case VariadicOperandDirective(), VariadicOperandDirective(): self.raise_error( - "An `attr-dict' directive without keyword cannot be directly followed by a region variable as it is ambiguous." + "A variadic operand variable cannot be followed by another variadic operand variable." + ) + case VariadicRegionDirective(), VariadicRegionDirective(): + self.raise_error( + "A variadic region variable cannot be followed by another variadic region variable." + ) + case VariadicSuccessorDirective(), VariadicSuccessorDirective(): + self.raise_error( + "A variadic successor variable cannot be followed by another variadic successor variable." ) - case AttrDictDirective(), RegionVariable() if not (a.with_keyword): + case AttrDictDirective(), RegionDirective() if not (a.with_keyword): self.raise_error( "An `attr-dict' directive without keyword cannot be directly followed by a region variable as it is ambiguous." ) @@ -403,30 +396,22 @@ def verify_successors(self): "directive to the custom assembly format." ) - def parse_optional_variable( - self, - ) -> VariableDirective | AttributeVariable | None: - """ - Parse a variable, if present, with the following format: - variable ::= `$` bare-ident - The variable should refer to an operand, attribute, region, result, - or successor. - """ - if self._current_token.text[0] != "$": - return None - self._consume_token() - variable_name = self.parse_identifier(" after '$'") - - # Check if the variable is an operand + def _parse_optional_operand( + self, variable_name: str, top_level: bool + ) -> OptionalOperandVariable | VariadicOperandVariable | OperandVariable | None: for idx, (operand_name, operand_def) in enumerate(self.op_def.operands): if variable_name != operand_name: continue - if self.context == ParsingContext.TopLevel: + if top_level: if self.seen_operands[idx]: self.raise_error(f"operand '{variable_name}' is already bound") self.seen_operands[idx] = True if isinstance(operand_def, VariadicDef | OptionalDef): self.seen_attributes.add(AttrSizedOperandSegments.attribute_name) + else: + if self.seen_operand_types[idx]: + self.raise_error(f"type of '{variable_name}' is already bound") + self.seen_operand_types[idx] = True match operand_def: case OptOperandDef(): return OptionalOperandVariable(variable_name, idx) @@ -435,15 +420,28 @@ def parse_optional_variable( case _: return OperandVariable(variable_name, idx) + def parse_optional_typeable_variable(self) -> AnyTypeableDirective | None: + """ + Parse a variable, if present, with the following format: + variable ::= `$` bare-ident + The variable should refer to an operand or result. + """ + if self._current_token.text[0] != "$": + return None + self._consume_token() + variable_name = self.parse_identifier(" after '$'") + + # Check if the variable is an operand + if (variable := self._parse_optional_operand(variable_name, False)) is not None: + return variable + # Check if the variable is a result for idx, (result_name, result_def) in enumerate(self.op_def.results): if variable_name != result_name: continue - if self.context == ParsingContext.TopLevel: - self.raise_error( - "result variable cannot be in a toplevel directive. " - f"Consider using 'type({variable_name})' instead." - ) + if self.seen_result_types[idx]: + self.raise_error(f"type of '{variable_name}' is already bound") + self.seen_result_types[idx] = True match result_def: case OptResultDef(): return OptionalResultVariable(variable_name, idx) @@ -451,10 +449,25 @@ def parse_optional_variable( return VariadicResultVariable(variable_name, idx) case _: return ResultVariable(variable_name, idx) - if isinstance(result_def, VariadicDef): - return VariadicResultVariable(variable_name, idx) - else: - return ResultVariable(variable_name, idx) + + self.raise_error("expected typeable variable to refer to an operand or result") + + def parse_optional_variable( + self, + ) -> FormatDirective | None: + """ + Parse a variable, if present, with the following format: + variable ::= `$` bare-ident + The variable should refer to an operand, attribute, region, or successor. + """ + if self._current_token.text[0] != "$": + return None + self._consume_token() + variable_name = self.parse_identifier(" after '$'") + + # Check if the variable is an operand + if (variable := self._parse_optional_operand(variable_name, True)) is not None: + return variable # Check if the variable is a region for idx, (region_name, region_def) in enumerate(self.op_def.regions): @@ -492,17 +505,14 @@ def parse_optional_variable( attr_name = variable_name attr_or_prop = attr_or_prop_by_name[attr_name] is_property = attr_or_prop == "property" - if self.context == ParsingContext.TopLevel: - if is_property: - if attr_name in self.seen_properties: - self.raise_error(f"property '{variable_name}' is already bound") - self.seen_properties.add(attr_name) - else: - if attr_name in self.seen_attributes: - self.raise_error( - f"attribute '{variable_name}' is already bound" - ) - self.seen_attributes.add(attr_name) + if is_property: + if attr_name in self.seen_properties: + self.raise_error(f"property '{variable_name}' is already bound") + self.seen_properties.add(attr_name) + else: + if attr_name in self.seen_attributes: + self.raise_error(f"attribute '{variable_name}' is already bound") + self.seen_attributes.add(attr_name) attr_def = ( self.op_def.properties.get(attr_name) @@ -560,63 +570,23 @@ def parse_optional_variable( self.raise_error( "expected variable to refer to an operand, " - "attribute, region, result, or successor" + "attribute, region, or successor" ) def parse_type_directive(self) -> FormatDirective: """ Parse a type directive with the following format: - type-directive ::= `type` `(` variable `)` + type-directive ::= `type` `(` typeable-directive `)` `type` is expected to have already been parsed """ self.parse_punctuation("(") - - # Update the current context, since we are now in a type directive - previous_context = self.context - self.context = ParsingContext.TypeDirective - - variable = self.parse_optional_variable() - match variable: - case None: - self.raise_error("'type' directive expects a variable argument") - case OptionalOperandVariable(name, index): - if self.seen_operand_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_operand_types[index] = True - res = OptionalOperandTypeDirective(name, index) - case VariadicOperandVariable(name, index): - if self.seen_operand_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_operand_types[index] = True - res = VariadicOperandTypeDirective(name, index) - case OperandVariable(name, index): - if self.seen_operand_types[index]: - self.raise_error(f"type of '{name}' is already bound") - self.seen_operand_types[index] = True - res = OperandTypeDirective(name, index) - case OptionalResultVariable(name, index): - if self.seen_result_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_result_types[index] = True - res = OptionalResultTypeDirective(name, index) - case VariadicResultVariable(name, index): - if self.seen_result_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_result_types[index] = True - res = VariadicResultTypeDirective(name, index) - case ResultVariable(name, index): - if self.seen_result_types[index]: - self.raise_error(f"type of '{name}' is already bound") - self.seen_result_types[index] = True - res = ResultTypeDirective(name, index) - case AttributeVariable(): - self.raise_error("can only take the type of an operand or result") - case _: - raise ValueError(f"Unexpected variable type {type(variable)}") - + inner = self.parse_typeable_directive() self.parse_punctuation(")") - self.context = previous_context - return res + if isinstance(inner, VariadicTypeableDirective): + return VariadicTypeDirective(inner) + if isinstance(inner, OptionalTypeableDirective): + return OptionalTypeDirective(inner) + return TypeDirective(inner) def parse_optional_group(self) -> FormatDirective: """ @@ -627,7 +597,7 @@ def parse_optional_group(self) -> FormatDirective: anchor: FormatDirective | None = None while not self.parse_optional_punctuation(")"): - then_elements += (self.parse_directive(),) + then_elements += (self.parse_format_directive(),) if self.parse_optional_keyword("^"): if anchor is not None: self.raise_error("An optional group can only have one anchor.") @@ -710,7 +680,16 @@ def parse_keyword_or_punctuation(self) -> FormatDirective: self.parse_characters("`") return KeywordDirective(ident) - def parse_directive(self) -> FormatDirective: + def parse_typeable_directive(self) -> AnyTypeableDirective: + """ + Parse a typeable directive, with the following format: + directive ::= variable + """ + if variable := self.parse_optional_typeable_variable(): + return variable + self.raise_error(f"unexpected token '{self._current_token.text}'") + + def parse_format_directive(self) -> FormatDirective: """ Parse a format directive, with the following format: directive ::= `attr-dict` From 421792ae67be4f9873410898b26275efd9a56b8c Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Fri, 22 Nov 2024 12:33:45 +0000 Subject: [PATCH 12/15] dialects: (linalg) linalg.fill attribute positioning (#3496) --- .../with-mlir/dialects/linalg/ops.mlir | 13 +++++++++---- xdsl/dialects/linalg.py | 2 ++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir index 3a62737b2c..9d942b95d9 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir @@ -59,6 +59,9 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>) %18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) %20 = "test.op"() : () -> (memref<64x4096xf32>) +%zero = arith.constant 0: f32 +linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>) + linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>) @@ -99,17 +102,19 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>) // CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1] // CHECK-NEXT: %{{.*}} = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): -// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32 +// CHECK-NEXT: ^bb0(%in: f32, %in_2: f32, %out: f32): +// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_2 : f32 // CHECK-NEXT: linalg.yield %{{.*}} : f32 // CHECK-NEXT: } -> tensor<2x3xf32> // CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) // CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32> +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: linalg.fill {id} ins(%cst_0 : f32) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>) // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 +// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 // CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32> -// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> +// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> // CHECK-NEXT: } diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 854ed50c86..dd1097d67e 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -642,6 +642,8 @@ class FillOp(NamedOpBase): name = "linalg.fill" + PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True + def __init__( self, inputs: Sequence[SSAValue], From 020cae763737d7162f6774504087d773dedeca66 Mon Sep 17 00:00:00 2001 From: Hugo Pompougnac Date: Fri, 22 Nov 2024 12:36:57 +0000 Subject: [PATCH 13/15] interpreter: (pdl) add the implementation of type interpretation (#3498) --- .../apply-pdl/apply_pdl_build_type.mlir | 27 +++++++++++++++++++ xdsl/interpreters/pdl.py | 7 +++++ 2 files changed, 34 insertions(+) create mode 100644 tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir diff --git a/tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir b/tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir new file mode 100644 index 0000000000..15468e3446 --- /dev/null +++ b/tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir @@ -0,0 +1,27 @@ +// RUN: xdsl-opt %s -p apply-pdl | filecheck %s + +%x = "test.op"() : () -> (i32) + +pdl.pattern : benefit(1) { + %in_type = pdl.type: i32 + %root = pdl.operation "test.op" -> (%in_type: !pdl.type) + pdl.rewrite %root { + %out_type = pdl.type: i64 + %new_op = pdl.operation "test.op" -> (%out_type: !pdl.type) + pdl.replace %root with %new_op + } +} + +// CHECK: builtin.module { +// CHECK-NEXT: %x = "test.op"() : () -> i64 +// CHECK-NEXT: pdl.pattern : benefit(1) { +// CHECK-NEXT: %in_type = pdl.type : i32 +// CHECK-NEXT: %root = pdl.operation "test.op" -> (%in_type : !pdl.type) +// CHECK-NEXT: pdl.rewrite %root { +// CHECK-NEXT: %out_type = pdl.type : i64 +// CHECK-NEXT: %new_op = pdl.operation "test.op" -> (%out_type : !pdl.type) +// CHECK-NEXT: pdl.replace %root with %new_op +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: diff --git a/xdsl/interpreters/pdl.py b/xdsl/interpreters/pdl.py index ff1bb1317e..4fd5d4232a 100644 --- a/xdsl/interpreters/pdl.py +++ b/xdsl/interpreters/pdl.py @@ -386,3 +386,10 @@ def run_erase( (old,) = interpreter.get_values((op.op_value,)) self.rewriter.erase_op(old) return () + + @impl(pdl.TypeOp) + def run_type( + self, interpreter: Interpreter, op: pdl.TypeOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + assert isinstance(op.constantType, Attribute) + return (op.constantType,) From f11ca74c31e863bf4e800f1b3b8493b98d0a1b79 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 22 Nov 2024 13:14:05 +0000 Subject: [PATCH 14/15] core: make TypedAttribute not generic --- tests/irdl/test_attribute_definition.py | 2 +- xdsl/dialects/builtin.py | 6 +++--- xdsl/ir/core.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/irdl/test_attribute_definition.py b/tests/irdl/test_attribute_definition.py index a96bc462c6..8f403ca4e6 100644 --- a/tests/irdl/test_attribute_definition.py +++ b/tests/irdl/test_attribute_definition.py @@ -251,7 +251,7 @@ def test_typed_attribute(): @irdl_attr_definition class TypedAttr( # pyright: ignore[reportUnusedClass] - TypedAttribute[Attribute] + TypedAttribute ): name = "test.typed" diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index df0e3d89cb..99931f4507 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -449,7 +449,7 @@ class IndexType(ParametrizedAttribute): @irdl_attr_definition class IntegerAttr( Generic[_IntegerAttrType], - TypedAttribute[_IntegerAttrType], + TypedAttribute, ): name = "integer" value: ParameterDef[IntAttr] @@ -504,8 +504,8 @@ def verify(self) -> None: @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: assert isinstance(type, IntegerType | IndexType) return IntegerAttr(parser.parse_integer(allow_boolean=(type == i1)), type) diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index b4bece71de..698f52e25d 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -606,7 +606,7 @@ def _verify(self): super()._verify() -class TypedAttribute(ParametrizedAttribute, Generic[AttributeCovT], ABC): +class TypedAttribute(ParametrizedAttribute, ABC): """ An attribute with a type. """ @@ -617,8 +617,8 @@ def get_type_index(cls) -> int: ... @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: """ Parse the attribute with the given type. """ From 1dfa2d652dd8defa459b960eaa04f5ffc49ef940 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 22 Nov 2024 17:11:51 +0000 Subject: [PATCH 15/15] fix floatattr --- xdsl/dialects/builtin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index baf87a8984..fe3c33384e 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -634,7 +634,7 @@ def __hash__(self): @irdl_attr_definition -class FloatAttr(Generic[_FloatAttrType], TypedAttribute[_FloatAttrType]): +class FloatAttr(Generic[_FloatAttrType], TypedAttribute): name = "float" value: ParameterDef[FloatData] @@ -671,8 +671,8 @@ def __init__( @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: assert isinstance(type, AnyFloat) return FloatAttr(parser.parse_float(), type)