diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 9335d48c2a..f97bd909df 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr()) + outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 3c2d5ea2d6..8486b303cd 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Generic, TypeVar, cast, overload +from typing import ClassVar, Generic, TypeVar, cast from typing_extensions import Self @@ -46,30 +46,10 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type - @overload @staticmethod def constr( - *, - element_type: None = None, - ) -> BaseAttr[StreamType[Attribute]]: ... - - @overload - @staticmethod - def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[StreamType[Attribute]] - | ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[StreamType[Attribute]](StreamType) + ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( StreamType, (element_type,) ) @@ -79,68 +59,38 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" - @overload @staticmethod def constr( - *, - element_type: None = None, - ) -> BaseAttr[ReadableStreamType[Attribute]]: ... - - @overload - @staticmethod - def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[ReadableStreamType[Attribute]] - | ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType) + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( ReadableStreamType, (element_type,) ) +AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( + ReadableStreamType +) + + @irdl_attr_definition class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[WritableStreamType[Attribute]]: ... - - @overload @staticmethod def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[WritableStreamType[Attribute]] - | ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[WritableStreamType[Attribute]](WritableStreamType) + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( WritableStreamType, (element_type,) ) +AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( + WritableStreamType +) + + class ReadOperation(IRDLOperation, abc.ABC): """ Abstract base class for operations that read from a stream. @@ -148,7 +98,7 @@ class ReadOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType.constr(element_type=T)) + stream = operand_def(ReadableStreamType.constr(T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -182,7 +132,7 @@ class WriteOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType.constr(element_type=T)) + stream = operand_def(WritableStreamType.constr(T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream])