Skip to content

Commit

Permalink
Revert "dialects: (stream) simplify constr helper on stream attribute…
Browse files Browse the repository at this point in the history
…s" (#3493)

Reverts #3473
  • Loading branch information
superlopuh authored Nov 20, 2024
1 parent 25adb58 commit afa862d
Show file tree
Hide file tree
Showing 19 changed files with 85 additions and 35 deletions.
2 changes: 1 addition & 1 deletion tests/test_dialect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xdsl.context import MLContext
from xdsl.dialects.builtin import DYNAMIC_INDEX, IndexType, IntegerType, i32
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
parse_dynamic_index_list_with_types,
parse_dynamic_index_list_without_types,
parse_dynamic_index_with_type,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
UnrankedTensorType,
VectorType,
)
from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyOf,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SymbolRefAttr,
TensorType,
)
from xdsl.dialects.utils.format import parse_func_op_like, print_func_op_like
from xdsl.dialects.utils import parse_func_op_like, print_func_op_like
from xdsl.ir import (
Attribute,
Block,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TensorType,
)
from xdsl.dialects.experimental import dmp
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/experimental/air.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
parse_call_op_like,
parse_func_op_like,
print_call_op_like,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
TensorType,
i64,
)
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
AbstractYieldOperation,
)
from xdsl.ir import (
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
i32,
i64,
)
from xdsl.dialects.utils.fast_math import FastMathAttrBase
from xdsl.dialects.utils import FastMathAttrBase
from xdsl.ir import (
Attribute,
BitEnumAttribute,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
i32,
i64,
)
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
parse_dynamic_index_list_without_types,
print_dynamic_index_list,
)
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
IntegerType,
StringAttr,
)
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down Expand Up @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr)
outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr())
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
UnitAttr,
i32,
)
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
UnitAttr,
i32,
)
from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.ir import (
Attribute,
Block,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
parse_call_op_like,
parse_func_op_like,
print_call_op_like,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import Self

from xdsl.dialects.riscv import IntRegisterType, RISCVRegisterType
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
AbstractYieldOperation,
parse_assignment,
print_assignment,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
print_immediate_value,
si12,
)
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
AbstractYieldOperation,
parse_assignment,
print_assignment,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SignlessIntegerConstraint,
i64,
)
from xdsl.dialects.utils.format import (
from xdsl.dialects.utils import (
AbstractYieldOperation,
parse_assignment,
print_assignment,
Expand Down
82 changes: 66 additions & 16 deletions xdsl/dialects/stream.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import abc
from typing import ClassVar, Generic, TypeVar, cast
from typing import ClassVar, Generic, TypeVar, cast, overload

from typing_extensions import Self

Expand Down Expand Up @@ -46,10 +46,30 @@ def __init__(self, element_type: _StreamTypeElement):
def get_element_type(self) -> _StreamTypeElement:
return self.element_type

@overload
@staticmethod
def constr(
*,
element_type: None = None,
) -> BaseAttr[StreamType[Attribute]]: ...

@overload
@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]:
) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ...

@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None,
) -> (
BaseAttr[StreamType[Attribute]]
| ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]
):
if element_type is None:
return BaseAttr[StreamType[Attribute]](StreamType)
return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]](
StreamType, (element_type,)
)
Expand All @@ -59,46 +79,76 @@ def constr(
class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]):
name = "stream.readable"

@overload
@staticmethod
def constr(
*,
element_type: None = None,
) -> BaseAttr[ReadableStreamType[Attribute]]: ...

@overload
@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]:
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ...

@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None,
) -> (
BaseAttr[ReadableStreamType[Attribute]]
| ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]
):
if element_type is None:
return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType)
return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]](
ReadableStreamType, (element_type,)
)


AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]](
ReadableStreamType
)


@irdl_attr_definition
class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]):
name = "stream.writable"

@overload
@staticmethod
def constr(
*,
element_type: None = None,
) -> BaseAttr[WritableStreamType[Attribute]]: ...

@overload
@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]:
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ...

@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None,
) -> (
BaseAttr[WritableStreamType[Attribute]]
| ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]
):
if element_type is None:
return BaseAttr[WritableStreamType[Attribute]](WritableStreamType)
return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]](
WritableStreamType, (element_type,)
)


AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]](
WritableStreamType
)


class ReadOperation(IRDLOperation, abc.ABC):
"""
Abstract base class for operations that read from a stream.
"""

T: ClassVar = VarConstraint("T", AnyAttr())

stream = operand_def(ReadableStreamType.constr(T))
stream = operand_def(ReadableStreamType.constr(element_type=T))
res = result_def(T)

def __init__(self, stream: SSAValue, result_type: Attribute | None = None):
Expand Down Expand Up @@ -132,7 +182,7 @@ class WriteOperation(IRDLOperation, abc.ABC):
T: ClassVar = VarConstraint("T", AnyAttr())

value = operand_def(T)
stream = operand_def(WritableStreamType.constr(T))
stream = operand_def(WritableStreamType.constr(element_type=T))

def __init__(self, value: SSAValue, stream: SSAValue):
super().__init__(operands=[value, stream])
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/arith_add_fastmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Literal

from xdsl.dialects import arith, builtin
from xdsl.dialects.utils.fast_math import FastMathFlag
from xdsl.dialects.utils import FastMathFlag
from xdsl.passes import MLContext, ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/canonicalization_patterns/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xdsl.dialects import riscv, riscv_snitch
from xdsl.dialects.builtin import IntegerAttr
from xdsl.dialects.utils.fast_math import FastMathFlag
from xdsl.dialects.utils import FastMathFlag
from xdsl.ir import OpResult, SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand Down

0 comments on commit afa862d

Please sign in to comment.