Skip to content

Commit

Permalink
misc: add more .constr helpers and constraints (#3273)
Browse files Browse the repository at this point in the history
Pylance errors: 84 -> 74.

Part of #3264
  • Loading branch information
superlopuh authored Oct 9, 2024
1 parent b205332 commit 170ce50
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 34 deletions.
27 changes: 26 additions & 1 deletion tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
AttrSizedOperandSegments,
AttrSizedRegionSegments,
AttrSizedResultSegments,
BaseAttr,
ConstraintVar,
EqAttrConstraint,
GenericAttrConstraint,
IRDLOperation,
ParamAttrConstraint,
ParameterDef,
ParsePropInAttrDict,
VarOperand,
Expand Down Expand Up @@ -1559,7 +1562,7 @@ class OptSuccessorOp(IRDLOperation):
# Inference #
################################################################################

_T = TypeVar("_T", bound=Attribute)
_T = TypeVar("_T", bound=Attribute, covariant=True)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1661,6 +1664,18 @@ class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
p: ParameterDef[_T]
q: ParameterDef[Attribute]

@classmethod
def constr(
cls,
*,
n: GenericAttrConstraint[Attribute] | None = None,
p: GenericAttrConstraint[_T] | None = None,
q: GenericAttrConstraint[Attribute] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if n is None and p is None and q is None:
return BaseAttr(cls)
return ParamAttrConstraint(cls, (n, p, q))

@irdl_op_definition
class TwoOperandsNestedVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]
Expand Down Expand Up @@ -1695,6 +1710,16 @@ class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
name = "test.param_one"
p: ParameterDef[_T]

@classmethod
def constr(
cls,
*,
p: GenericAttrConstraint[_T] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if p is None:
return BaseAttr(cls)
return ParamAttrConstraint(cls, (p,))

@irdl_op_definition
class OneOperandOneResultNestedOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]
Expand Down
8 changes: 1 addition & 7 deletions tests/tblgen_to_py/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,7 @@ class Test_AttributesOp(IRDLOperation):
name = "test.attributes"

int_attr = prop_def(
ParamAttrConstraint(
IntegerAttr,
(
AnyAttr(),
EqAttrConstraint(IntegerType(16)),
),
)
IntegerAttr[IntegerType].constr(type=EqAttrConstraint(IntegerType(16)))
)

in_ = prop_def(BaseAttr(Test_TestAttr), prop_name="in")
Expand Down
44 changes: 43 additions & 1 deletion xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
AttrConstraint,
BaseAttr,
ConstraintContext,
GenericAttrConstraint,
GenericData,
IRDLOperation,
MessageConstraint,
Expand Down Expand Up @@ -439,6 +440,7 @@ class IndexType(ParametrizedAttribute):
_IntegerAttrType = TypeVar(
"_IntegerAttrType", bound=IntegerType | IndexType, covariant=True
)
IntegerAttrTypeConstr = IndexTypeConstr | BaseAttr(IntegerType)
AnySignlessIntegerOrIndexType: TypeAlias = Annotated[
Attribute, AnyOf([IndexType, SignlessIntegerConstraint])
]
Expand Down Expand Up @@ -511,6 +513,24 @@ def parse_with_type(
def print_without_type(self, printer: Printer):
return printer.print(self.value.data)

@classmethod
def constr(
cls,
*,
# pyright needs updating, with the new one it works fine
value: AttrConstraint | None = None,
type: GenericAttrConstraint[_IntegerAttrType] = IntegerAttrTypeConstr, # pyright: ignore[reportGeneralTypeIssues]
) -> GenericAttrConstraint[IntegerAttr[_IntegerAttrType]]:
if value is None and type == AnyAttr():
return BaseAttr[IntegerAttr[_IntegerAttrType]](IntegerAttr)
return ParamAttrConstraint[IntegerAttr[_IntegerAttrType]](
IntegerAttr,
(
value,
type,
),
)


AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType]
AnyIntegerAttrConstr: BaseAttr[AnyIntegerAttr] = BaseAttr(IntegerAttr)
Expand Down Expand Up @@ -644,6 +664,7 @@ def __init__(


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)


@irdl_attr_definition
Expand Down Expand Up @@ -1436,7 +1457,7 @@ def print(self, printer: Printer) -> None:
f128 = Float128Type()


_MemRefTypeElement = TypeVar("_MemRefTypeElement", bound=Attribute)
_MemRefTypeElement = TypeVar("_MemRefTypeElement", bound=Attribute, covariant=True)
_UnrankedMemrefTypeElems = TypeVar(
"_UnrankedMemrefTypeElems", bound=Attribute, covariant=True
)
Expand Down Expand Up @@ -1565,6 +1586,27 @@ def get_strides(self) -> Sequence[int | None] | None:
case _:
return self.layout.get_strides()

@classmethod
def constr(
cls,
*,
shape: GenericAttrConstraint[Attribute] | None = None,
# pyright needs updating, with the new one it works fine
element_type: GenericAttrConstraint[_MemRefTypeElement] = AnyAttr(), # pyright: ignore[reportGeneralTypeIssues]
layout: GenericAttrConstraint[Attribute] | None = None,
memory_space: GenericAttrConstraint[Attribute] | None = None,
) -> GenericAttrConstraint[MemRefType[_MemRefTypeElement]]:
if (
shape is None
and element_type == AnyAttr()
and layout is None
and memory_space is None
):
return BaseAttr[MemRefType[_MemRefTypeElement]](MemRefType)
return ParamAttrConstraint[MemRefType[_MemRefTypeElement]](
MemRefType, (shape, element_type, layout, memory_space)
)


AnyMemRefType: TypeAlias = MemRefType[Attribute]
AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType)
Expand Down
34 changes: 33 additions & 1 deletion xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
AnyFloatAttr,
AnyFloatAttrConstr,
AnyIntegerAttr,
AnyIntegerAttrConstr,
AnyMemRefType,
ArrayAttr,
BoolAttr,
Expand Down Expand Up @@ -48,6 +50,8 @@
TypeAttribute,
)
from xdsl.irdl import (
AnyOf,
BaseAttr,
ConstraintVar,
IRDLOperation,
ParameterDef,
Expand Down Expand Up @@ -380,6 +384,7 @@ def get_element_type(self) -> TypeAttribute:
QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]

ParamAttr: TypeAlias = AnyFloatAttr | AnyIntegerAttr
ParamAttrConstr = AnyFloatAttrConstr | AnyIntegerAttrConstr


@irdl_op_definition
Expand All @@ -395,7 +400,7 @@ class VariableOp(IRDLOperation):

name = "csl.variable"

default = opt_prop_def(ParamAttr)
default = opt_prop_def(ParamAttrConstr)
res = result_def(VarType)

def get_element_type(self):
Expand Down Expand Up @@ -578,6 +583,12 @@ def verify_(self) -> None:
)


ZerosOpAttr: TypeAlias = IntegerType | Float32Type | Float16Type
ZerosOpAttrConstr: AnyOf[ZerosOpAttr] = (
BaseAttr(IntegerType) | BaseAttr(Float32Type) | BaseAttr(Float16Type)
)


@irdl_op_definition
class ZerosOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1898,6 +1909,27 @@ class RpcOp(IRDLOperation):
id = operand_def(ColorType)


ParamOpAttr: TypeAlias = (
Float16Type
| Float32Type
| IntegerType
| ColorType
| FunctionType
| ImportedModuleType
| ComptimeStructType
)

ParamOpAttrConstr = (
BaseAttr(Float16Type)
| BaseAttr(Float32Type)
| BaseAttr(IntegerType)
| BaseAttr(ColorType)
| BaseAttr(FunctionType)
| BaseAttr(ImportedModuleType)
| BaseAttr(ComptimeStructType)
)


@irdl_op_definition
class ParamOp(IRDLOperation):
"""
Expand Down
17 changes: 17 additions & 0 deletions xdsl/dialects/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
TypeAttribute,
)
from xdsl.irdl import (
AnyAttr,
BaseAttr,
ConstraintVar,
GenericAttrConstraint,
IRDLOperation,
ParamAttrConstraint,
ParameterDef,
irdl_attr_definition,
irdl_op_definition,
Expand All @@ -44,6 +48,19 @@ def __init__(self, element_type: _StreamTypeElement):
def get_element_type(self) -> _StreamTypeElement:
return self.element_type

@classmethod
def constr(
cls,
*,
# pyright needs updating, with the new one it works fine
element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(), # pyright: ignore[reportGeneralTypeIssues]
) -> GenericAttrConstraint[StreamType[_StreamTypeElement]]:
if element_type == AnyAttr():
return BaseAttr[StreamType[_StreamTypeElement]](StreamType)
return ParamAttrConstraint[StreamType[_StreamTypeElement]](
StreamType, (element_type,)
)


@irdl_attr_definition
class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]):
Expand Down
16 changes: 3 additions & 13 deletions xdsl/tools/tblgen_to_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,12 @@ def _resolve_prop_constraint(self, rec: TblgenRecord | str) -> str:
return "BaseAttr(BoolAttr)"
case "IndexAttr":
return textwrap.dedent("""
ParamAttrConstraint(
IntegerAttr, (AnyAttr(), EqAttrConstraint(IndexType()))
)
IntegerAttr[IndexType].constr(type=IndexTypeConstr)
""")

case "APIntAttr":
return textwrap.dedent("""
ParamAttrConstraint(
IntegerAttr, (AnyAttr(), AnyAttr())
)
IntegerAttr[Attribute].constr()
""") # TODO can't represent APInt properly

case "StrAttr":
Expand All @@ -355,13 +351,7 @@ def _resolve_prop_constraint(self, rec: TblgenRecord | str) -> str:
or "UnsignedIntegerAttrBase" in rec.superclasses
):
return textwrap.dedent(f"""
ParamAttrConstraint(
IntegerAttr,
(
AnyAttr(),
{self._resolve_type_constraint(rec["valueType"]["def"])},
),
)
IntegerAttr[IntegerType].constr(type={self._resolve_type_constraint(rec["valueType"]["def"])})
""")

if "FloatAttrBase" in rec.superclasses:
Expand Down
3 changes: 2 additions & 1 deletion xdsl/transforms/csl_stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, bufferization, func, linalg, memref, stencil, tensor
from xdsl.dialects.builtin import (
AnyTensorTypeConstr,
DenseArrayBase,
DenseIntOrFPElementsAttr,
FunctionType,
Expand Down Expand Up @@ -163,7 +164,7 @@ def _get_empty_bufferized_region(args: Sequence[BlockArgument]) -> Region:
arg_types=[
(
tensor_to_memref_type(arg.type)
if isattr(arg.type, TensorType)
if isattr(arg.type, AnyTensorTypeConstr)
else arg.type
)
for arg in args
Expand Down
5 changes: 2 additions & 3 deletions xdsl/transforms/lower_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from xdsl.dialects import arith, builtin, scf
from xdsl.dialects.csl import csl, csl_wrapper
from xdsl.ir import Block, Operation, Region, SSAValue
from xdsl.irdl import base
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down Expand Up @@ -43,7 +42,7 @@ def _collect_params(op: csl_wrapper.ModuleOp) -> list[SSAValue]:
"""
params = list[SSAValue]()
for param in op.params:
if isa(param.value, builtin.IntegerAttr):
if isattr(param.value, builtin.AnyIntegerAttrConstr):
value = arith.Constant(param.value)
else:
value = None
Expand Down Expand Up @@ -177,7 +176,7 @@ def lower_layout_module(
def _collect_yield_args(yield_op: csl_wrapper.YieldOp) -> list[csl.ParamOp]:
params = list[csl.ParamOp]()
for s, v in yield_op.items():
assert isattr(ty := v.type, base(csl.ParamOp.T))
assert isattr(ty := v.type, csl.ParamOpAttrConstr)
params.append(csl.ParamOp(s, ty))
return params

Expand Down
18 changes: 11 additions & 7 deletions xdsl/transforms/memref_to_dsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,32 @@
op_type_rewrite_pattern,
)
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr


class LowerAllocOpPass(RewritePattern):
"""Lowers `memref.alloc` to `csl.zeros`."""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.Alloc, rewriter: PatternRewriter, /):
assert isa(op.memref.type, MemRefType[csl.ZerosOp.T])
zeros_op = csl.ZerosOp(op.memref.type)
assert isattr(
memref_type := op.memref.type,
MemRefType[csl.ZerosOpAttr].constr(element_type=csl.ZerosOpAttrConstr),
)
zeros_op = csl.ZerosOp(memref_type)

dsd_t = csl.DsdType(
csl.DsdKind.mem1d_dsd
if len(op.memref.type.shape) == 1
if len(memref_type.shape) == 1
else csl.DsdKind.mem4d_dsd
)
offsets = None
if isinstance(op.memref.type.layout, StridedLayoutAttr) and isinstance(
op.memref.type.layout.offset, IntAttr
if isinstance(memref_type.layout, StridedLayoutAttr) and isinstance(
memref_type.layout.offset, IntAttr
):
offsets = ArrayAttr([IntegerAttr(op.memref.type.layout.offset, 16)])
offsets = ArrayAttr([IntegerAttr(memref_type.layout.offset, 16)])

shape = [arith.Constant(IntegerAttr(d, 16)) for d in op.memref.type.shape]
shape = [arith.Constant(IntegerAttr(d, 16)) for d in memref_type.shape]
dsd_op = csl.GetMemDsdOp.build(
operands=[zeros_op, shape],
result_types=[dsd_t],
Expand Down

0 comments on commit 170ce50

Please sign in to comment.