Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (stream) simplify constr helper on stream attributes #3473

Merged
merged 12 commits into from
Nov 20, 2024
Merged
2 changes: 1 addition & 1 deletion tests/test_dialect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xdsl.context import MLContext
from xdsl.dialects.builtin import DYNAMIC_INDEX, IndexType, IntegerType, i32
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_dynamic_index_list_with_types,
parse_dynamic_index_list_without_types,
parse_dynamic_index_with_type,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
UnrankedTensorType,
VectorType,
)
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag
from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyOf,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SymbolRefAttr,
TensorType,
)
from xdsl.dialects.utils import parse_func_op_like, print_func_op_like
from xdsl.dialects.utils.format import parse_func_op_like, print_func_op_like
from xdsl.ir import (
Attribute,
Block,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TensorType,
)
from xdsl.dialects.experimental import dmp
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/experimental/air.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_call_op_like,
parse_func_op_like,
print_call_op_like,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
TensorType,
i64,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
AbstractYieldOperation,
)
from xdsl.ir import (
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
i32,
i64,
)
from xdsl.dialects.utils import FastMathAttrBase
from xdsl.dialects.utils.fast_math import FastMathAttrBase
from xdsl.ir import (
Attribute,
BitEnumAttribute,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
i32,
i64,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_dynamic_index_list_without_types,
print_dynamic_index_list,
)
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
IntegerType,
StringAttr,
)
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down Expand Up @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr())
outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr)
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
UnitAttr,
i32,
)
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
UnitAttr,
i32,
)
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag
from xdsl.ir import (
Attribute,
Block,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_call_op_like,
parse_func_op_like,
print_call_op_like,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import Self

from xdsl.dialects.riscv import IntRegisterType, RISCVRegisterType
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
AbstractYieldOperation,
parse_assignment,
print_assignment,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
print_immediate_value,
si12,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
AbstractYieldOperation,
parse_assignment,
print_assignment,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SignlessIntegerConstraint,
i64,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
AbstractYieldOperation,
parse_assignment,
print_assignment,
Expand Down
82 changes: 16 additions & 66 deletions xdsl/dialects/stream.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import abc
from typing import ClassVar, Generic, TypeVar, cast, overload
from typing import ClassVar, Generic, TypeVar, cast

from typing_extensions import Self

Expand Down Expand Up @@ -46,30 +46,10 @@ def __init__(self, element_type: _StreamTypeElement):
def get_element_type(self) -> _StreamTypeElement:
return self.element_type

@overload
@staticmethod
def constr(
*,
element_type: None = None,
) -> BaseAttr[StreamType[Attribute]]: ...

@overload
@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ...

@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None,
) -> (
BaseAttr[StreamType[Attribute]]
| ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]
):
if element_type is None:
return BaseAttr[StreamType[Attribute]](StreamType)
) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]:
return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]](
StreamType, (element_type,)
)
Expand All @@ -79,76 +59,46 @@ def constr(
class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]):
name = "stream.readable"

@overload
@staticmethod
def constr(
*,
element_type: None = None,
) -> BaseAttr[ReadableStreamType[Attribute]]: ...

@overload
@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ...

@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None,
) -> (
BaseAttr[ReadableStreamType[Attribute]]
| ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]
):
if element_type is None:
return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType)
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]:
return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]](
ReadableStreamType, (element_type,)
)


AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]](
ReadableStreamType
)


@irdl_attr_definition
class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]):
name = "stream.writable"

@overload
@staticmethod
def constr(
*,
element_type: None = None,
) -> BaseAttr[WritableStreamType[Attribute]]: ...

@overload
@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ...

@staticmethod
def constr(
*,
element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None,
) -> (
BaseAttr[WritableStreamType[Attribute]]
| ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]
):
if element_type is None:
return BaseAttr[WritableStreamType[Attribute]](WritableStreamType)
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]:
return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]](
WritableStreamType, (element_type,)
)


AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]](
WritableStreamType
)
Comment on lines +89 to +91
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a feeling that this will be the pattern to use until we have defaults for TypeVars



class ReadOperation(IRDLOperation, abc.ABC):
"""
Abstract base class for operations that read from a stream.
"""

T: ClassVar = VarConstraint("T", AnyAttr())

stream = operand_def(ReadableStreamType.constr(element_type=T))
stream = operand_def(ReadableStreamType.constr(T))
res = result_def(T)

def __init__(self, stream: SSAValue, result_type: Attribute | None = None):
Expand Down Expand Up @@ -182,7 +132,7 @@ class WriteOperation(IRDLOperation, abc.ABC):
T: ClassVar = VarConstraint("T", AnyAttr())

value = operand_def(T)
stream = operand_def(WritableStreamType.constr(element_type=T))
stream = operand_def(WritableStreamType.constr(T))

def __init__(self, value: SSAValue, stream: SSAValue):
super().__init__(operands=[value, stream])
Expand Down
Empty file added xdsl/dialects/utils/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions xdsl/dialects/utils/fast_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC
from dataclasses import dataclass

from xdsl.ir import BitEnumAttribute
from xdsl.utils.str_enum import StrEnum


class FastMathFlag(StrEnum):
"""
Values specifying fast math behaviour of an arithmetic operation.
"""

REASSOC = "reassoc"
NO_NANS = "nnan"
NO_INFS = "ninf"
NO_SIGNED_ZEROS = "nsz"
ALLOW_RECIP = "arcp"
ALLOW_CONTRACT = "contract"
APPROX_FUNC = "afn"


@dataclass(frozen=True, init=False)
class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC):
"""
Base class for attributes defining fast math behavior of arithmetic operations.
"""

none_value = "none"
all_value = "fast"
34 changes: 0 additions & 34 deletions xdsl/dialects/utils.py → xdsl/dialects/utils/format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from abc import ABC
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Generic

from xdsl.dialects.builtin import (
Expand All @@ -14,7 +12,6 @@
from xdsl.ir import (
Attribute,
AttributeInvT,
BitEnumAttribute,
BlockArgument,
Operation,
Region,
Expand All @@ -23,7 +20,6 @@
from xdsl.irdl import IRDLOperation, var_operand_def
from xdsl.parser import Parser, UnresolvedOperand
from xdsl.printer import Printer
from xdsl.utils.str_enum import StrEnum


def print_call_op_like(
Expand Down Expand Up @@ -345,33 +341,3 @@ def parse_dynamic_index_list_without_types(
values.append(value_or_index)

return values, indices


# region Fast Math Flags


class FastMathFlag(StrEnum):
"""
Values specifying fast math behaviour of an arithmetic operation.
"""

REASSOC = "reassoc"
NO_NANS = "nnan"
NO_INFS = "ninf"
NO_SIGNED_ZEROS = "nsz"
ALLOW_RECIP = "arcp"
ALLOW_CONTRACT = "contract"
APPROX_FUNC = "afn"


@dataclass(frozen=True, init=False)
class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC):
"""
Base class for attributes defining fast math behavior of arithmetic operations.
"""

none_value = "none"
all_value = "fast"


# endregion
2 changes: 1 addition & 1 deletion xdsl/transforms/arith_add_fastmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Literal

from xdsl.dialects import arith, builtin
from xdsl.dialects.utils import FastMathFlag
from xdsl.dialects.utils.fast_math import FastMathFlag
from xdsl.passes import MLContext, ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/canonicalization_patterns/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from xdsl.dialects import riscv, riscv_snitch
from xdsl.dialects.builtin import IntegerAttr
from xdsl.dialects.utils import FastMathFlag
from xdsl.dialects.utils.fast_math import FastMathFlag
from xdsl.ir import OpResult, SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand Down
Loading