From cf1b395d7f527d8411d1b9903add00b3fa43508d Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 17:03:26 +0000 Subject: [PATCH] Revert "dialects: (stream) simplify constr helper on stream attributes (#3473)" This reverts commit 25adb58c5a13520df76601e9e7fe9deb9d1505a0. --- tests/test_dialect_utils.py | 2 +- xdsl/dialects/arith.py | 2 +- xdsl/dialects/csl/csl.py | 2 +- xdsl/dialects/csl/csl_stencil.py | 2 +- xdsl/dialects/experimental/air.py | 2 +- xdsl/dialects/func.py | 2 +- xdsl/dialects/linalg.py | 2 +- xdsl/dialects/llvm.py | 2 +- xdsl/dialects/memref.py | 2 +- xdsl/dialects/memref_stream.py | 4 +- xdsl/dialects/omp.py | 2 +- xdsl/dialects/riscv.py | 2 +- xdsl/dialects/riscv_func.py | 2 +- xdsl/dialects/riscv_scf.py | 2 +- xdsl/dialects/riscv_snitch.py | 2 +- xdsl/dialects/scf.py | 2 +- xdsl/dialects/stream.py | 82 +++++++++++++++---- xdsl/transforms/arith_add_fastmath.py | 2 +- .../canonicalization_patterns/riscv.py | 2 +- 19 files changed, 85 insertions(+), 35 deletions(-) diff --git a/tests/test_dialect_utils.py b/tests/test_dialect_utils.py index 3c80c06622..171d95373f 100644 --- a/tests/test_dialect_utils.py +++ b/tests/test_dialect_utils.py @@ -4,7 +4,7 @@ from xdsl.context import MLContext from xdsl.dialects.builtin import DYNAMIC_INDEX, IndexType, IntegerType, i32 -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_dynamic_index_list_with_types, parse_dynamic_index_list_without_types, parse_dynamic_index_with_type, diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 7300fc2f53..34e09a4241 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -22,7 +22,7 @@ UnrankedTensorType, VectorType, ) -from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index d5257a6e08..f17bb8be7c 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -37,7 +37,7 @@ SymbolRefAttr, TensorType, ) -from xdsl.dialects.utils.format import parse_func_op_like, print_func_op_like +from xdsl.dialects.utils import parse_func_op_like, print_func_op_like from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index c9956ae69b..eca5bc9cad 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -17,7 +17,7 @@ TensorType, ) from xdsl.dialects.experimental import dmp -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/experimental/air.py b/xdsl/dialects/experimental/air.py index 2e592257b1..f87b2b1795 100644 --- a/xdsl/dialects/experimental/air.py +++ b/xdsl/dialects/experimental/air.py @@ -20,7 +20,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 480676f996..8131cdf6df 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -10,7 +10,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 5cdc7e08d3..854ed50c86 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -24,7 +24,7 @@ TensorType, i64, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, ) from xdsl.ir import ( diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index 10d34717a5..dd3752ff43 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -24,7 +24,7 @@ i32, i64, ) -from xdsl.dialects.utils.fast_math import FastMathAttrBase +from xdsl.dialects.utils import FastMathAttrBase from xdsl.ir import ( Attribute, BitEnumAttribute, diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index eb50974e89..bfc9bf8d8c 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -30,7 +30,7 @@ i32, i64, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_dynamic_index_list_without_types, print_dynamic_index_list, ) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 13ef7a7d30..9335d48c2a 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -25,7 +25,7 @@ IntegerType, StringAttr, ) -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) + outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr()) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/omp.py b/xdsl/dialects/omp.py index 2eec70a374..c83058e145 100644 --- a/xdsl/dialects/omp.py +++ b/xdsl/dialects/omp.py @@ -9,7 +9,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils.format import AbstractYieldOperation +from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index 84e1cc631e..b9cdfb34fd 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -24,7 +24,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py index 610ceee50a..d9e8a63704 100644 --- a/xdsl/dialects/riscv_func.py +++ b/xdsl/dialects/riscv_func.py @@ -12,7 +12,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/riscv_scf.py b/xdsl/dialects/riscv_scf.py index 67e16f80e9..cea09406c8 100644 --- a/xdsl/dialects/riscv_scf.py +++ b/xdsl/dialects/riscv_scf.py @@ -10,7 +10,7 @@ from typing_extensions import Self from xdsl.dialects.riscv import IntRegisterType, RISCVRegisterType -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 94ce963e95..25d0cafcaa 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -33,7 +33,7 @@ print_immediate_value, si12, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index 0c2ac49230..b29cfbb42c 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -12,7 +12,7 @@ SignlessIntegerConstraint, i64, ) -from xdsl.dialects.utils.format import ( +from xdsl.dialects.utils import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 8486b303cd..3c2d5ea2d6 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Generic, TypeVar, cast +from typing import ClassVar, Generic, TypeVar, cast, overload from typing_extensions import Self @@ -46,10 +46,30 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type + @overload @staticmethod def constr( + *, + element_type: None = None, + ) -> BaseAttr[StreamType[Attribute]]: ... + + @overload + @staticmethod + def constr( + *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: + ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ... + + @staticmethod + def constr( + *, + element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, + ) -> ( + BaseAttr[StreamType[Attribute]] + | ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]] + ): + if element_type is None: + return BaseAttr[StreamType[Attribute]](StreamType) return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( StreamType, (element_type,) ) @@ -59,38 +79,68 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" + @overload @staticmethod def constr( + *, + element_type: None = None, + ) -> BaseAttr[ReadableStreamType[Attribute]]: ... + + @overload + @staticmethod + def constr( + *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ... + + @staticmethod + def constr( + *, + element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, + ) -> ( + BaseAttr[ReadableStreamType[Attribute]] + | ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]] + ): + if element_type is None: + return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType) return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( ReadableStreamType, (element_type,) ) -AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( - ReadableStreamType -) - - @irdl_attr_definition class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" + @overload + @staticmethod + def constr( + *, + element_type: None = None, + ) -> BaseAttr[WritableStreamType[Attribute]]: ... + + @overload @staticmethod def constr( + *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ... + + @staticmethod + def constr( + *, + element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, + ) -> ( + BaseAttr[WritableStreamType[Attribute]] + | ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]] + ): + if element_type is None: + return BaseAttr[WritableStreamType[Attribute]](WritableStreamType) return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( WritableStreamType, (element_type,) ) -AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( - WritableStreamType -) - - class ReadOperation(IRDLOperation, abc.ABC): """ Abstract base class for operations that read from a stream. @@ -98,7 +148,7 @@ class ReadOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType.constr(T)) + stream = operand_def(ReadableStreamType.constr(element_type=T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -132,7 +182,7 @@ class WriteOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType.constr(T)) + stream = operand_def(WritableStreamType.constr(element_type=T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) diff --git a/xdsl/transforms/arith_add_fastmath.py b/xdsl/transforms/arith_add_fastmath.py index e9221be204..c60f54bd12 100644 --- a/xdsl/transforms/arith_add_fastmath.py +++ b/xdsl/transforms/arith_add_fastmath.py @@ -4,7 +4,7 @@ from typing import Literal from xdsl.dialects import arith, builtin -from xdsl.dialects.utils.fast_math import FastMathFlag +from xdsl.dialects.utils import FastMathFlag from xdsl.passes import MLContext, ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, diff --git a/xdsl/transforms/canonicalization_patterns/riscv.py b/xdsl/transforms/canonicalization_patterns/riscv.py index 6588f84cf9..519c5b0813 100644 --- a/xdsl/transforms/canonicalization_patterns/riscv.py +++ b/xdsl/transforms/canonicalization_patterns/riscv.py @@ -2,7 +2,7 @@ from xdsl.dialects import riscv, riscv_snitch from xdsl.dialects.builtin import IntegerAttr -from xdsl.dialects.utils.fast_math import FastMathFlag +from xdsl.dialects.utils import FastMathFlag from xdsl.ir import OpResult, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter,