From 187eb3178383497125e345ac75ac24d73bdcc1d0 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Thu, 10 Oct 2024 22:28:01 +0100 Subject: [PATCH] misc: use VarConstraint instead of Annotated[ConstraintVar] in operation definitions (#3264) Using Annotated in the way that we do is not well defined by the Python type system, this fixes the current misuse to get the number of Pylance errors to 0. The current proposal is to keep both implementations working, and to deprecate the `Annotated` API at some point in the future. --- docs/irdl.ipynb | 5 +- .../irdl/test_declarative_assembly_format.py | 14 +-- tests/irdl/test_operation_definition.py | 115 ++++++++++++++++-- xdsl/dialects/affine.py | 15 +-- xdsl/dialects/arith.py | 12 +- xdsl/dialects/comb.py | 23 ++-- xdsl/dialects/csl/csl.py | 34 +++--- xdsl/dialects/eqsat.py | 7 +- xdsl/dialects/llvm.py | 7 +- xdsl/dialects/ltl.py | 14 +-- xdsl/dialects/memref.py | 26 ++-- xdsl/dialects/memref_stream.py | 9 +- xdsl/dialects/mod_arith.py | 9 +- xdsl/dialects/riscv_snitch.py | 24 ++-- xdsl/dialects/scf.py | 10 +- xdsl/dialects/seq.py | 15 +-- xdsl/dialects/stablehlo.py | 12 +- xdsl/dialects/stencil.py | 2 - xdsl/dialects/stream.py | 12 +- xdsl/dialects/transform.py | 3 - xdsl/dialects/varith.py | 6 +- 21 files changed, 247 insertions(+), 127 deletions(-) diff --git a/docs/irdl.ipynb b/docs/irdl.ipynb index f92f71739e..144601f39a 100644 --- a/docs/irdl.ipynb +++ b/docs/irdl.ipynb @@ -1058,14 +1058,15 @@ } ], "source": [ - "from xdsl.irdl import ConstraintVar\n", + "from typing import ClassVar\n", + "from xdsl.irdl import base, VarConstraint\n", "\n", "\n", "@irdl_op_definition\n", "class BinaryOp(IRDLOperation):\n", " name = \"binary_op\"\n", "\n", - " T = Annotated[IntegerType, ConstraintVar(\"T\")]\n", + " T: ClassVar[VarConstraint[IntegerType]] = VarConstraint(\"T\", base(IntegerType))\n", "\n", " lhs = operand_def(T)\n", " rhs = operand_def(T)\n", diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index d22853edae..c49f99cac2 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -3,7 +3,7 @@ import textwrap from collections.abc import Callable from io import StringIO -from typing import Annotated, Generic, TypeVar +from typing import ClassVar, Generic, TypeVar import pytest @@ -23,13 +23,13 @@ AttrSizedRegionSegments, AttrSizedResultSegments, BaseAttr, - ConstraintVar, EqAttrConstraint, GenericAttrConstraint, IRDLOperation, ParamAttrConstraint, ParameterDef, ParsePropInAttrDict, + VarConstraint, VarOperand, VarOpResult, attr_def, @@ -1578,7 +1578,7 @@ def test_basic_inference(format: str): @irdl_op_definition class TwoOperandsOneResultWithVarOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "test.two_operands_one_result_with_var" res = result_def(T) @@ -1678,11 +1678,11 @@ def constr( @irdl_op_definition class TwoOperandsNestedVarOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "test.two_operands_one_result_with_var" res = result_def(T) - lhs = operand_def(ParamOne[T]) + lhs = operand_def(ParamOne[Attribute].constr(p=T)) rhs = operand_def(T) assembly_format = "$lhs $rhs attr-dict `:` type($lhs)" @@ -1722,11 +1722,11 @@ def constr( @irdl_op_definition class OneOperandOneResultNestedOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "test.one_operand_one_result_nested" res = result_def(T) - lhs = operand_def(ParamOne[T]) + lhs = operand_def(ParamOne[Attribute].constr(p=T)) assembly_format = "$lhs attr-dict `:` type($lhs)" diff --git a/tests/irdl/test_operation_definition.py b/tests/irdl/test_operation_definition.py index 30ee723f13..b55208e00f 100644 --- a/tests/irdl/test_operation_definition.py +++ b/tests/irdl/test_operation_definition.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, Generic, TypeVar +from typing import Annotated, ClassVar, Generic, TypeVar import pytest @@ -31,7 +31,9 @@ RangeOf, RegionDef, ResultDef, + VarConstraint, attr_def, + base, irdl_op_definition, operand_def, opt_attr_def, @@ -157,15 +159,16 @@ def test_attr_verify(): op.verify() +# TODO: remove this test once the Annotated API is deprecated @irdl_op_definition class ConstraintVarOp(IRDLOperation): name = "test.constraint_var_op" T = Annotated[IntegerType | IndexType, ConstraintVar("T")] - operand = operand_def(T) - result = result_def(T) - attribute = attr_def(T) + operand = operand_def(T) # pyright: ignore[reportArgumentType] + result = result_def(T) # pyright: ignore[reportArgumentType] + attribute = attr_def(T) # pyright: ignore[reportArgumentType, reportUnknownVariableType] def test_constraint_var(): @@ -193,7 +196,10 @@ def test_constraint_var_fail_non_equal(): op = ConstraintVarOp.create( operands=[index_operand], result_types=[i32], attributes={"attribute": i32} ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): op.verify() # Fail because of result @@ -202,7 +208,10 @@ def test_constraint_var_fail_non_equal(): result_types=[IndexType()], attributes={"attribute": i32}, ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): op2.verify() # Fail because of attribute @@ -211,7 +220,10 @@ def test_constraint_var_fail_non_equal(): result_types=[i32], attributes={"attribute": IndexType()}, ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: attribute i32 expected from variable 'T', but got index", + ): op3.verify() @@ -223,7 +235,94 @@ def test_constraint_var_fail_not_satisfy_constraint(): result_types=[TestType("foo")], attributes={"attribute": TestType("foo")}, ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: operand at position 0 does not verify", + ): + op.verify() + + +@irdl_op_definition +class GenericConstraintVarOp(IRDLOperation): + name = "test.constraint_var_op" + + T: ClassVar[VarConstraint[IntegerType | IndexType]] = VarConstraint( + "T", base(IntegerType) | base(IndexType) + ) + + operand = operand_def(T) + result = result_def(T) + attribute = attr_def(T) + + +def test_generic_constraint_var(): + i32_operand = TestSSAValue(i32) + index_operand = TestSSAValue(IndexType()) + op = GenericConstraintVarOp.create( + operands=[i32_operand], result_types=[i32], attributes={"attribute": i32} + ) + op.verify() + + op2 = GenericConstraintVarOp.create( + operands=[index_operand], + result_types=[IndexType()], + attributes={"attribute": IndexType()}, + ) + op2.verify() + + +def test_generic_constraint_var_fail_non_equal(): + """Check that all uses of a constraint variable are of the same attribute.""" + i32_operand = TestSSAValue(i32) + index_operand = TestSSAValue(IndexType()) + + # Fail because of operand + op = GenericConstraintVarOp.create( + operands=[index_operand], result_types=[i32], attributes={"attribute": i32} + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): + op.verify() + + # Fail because of result + op2 = GenericConstraintVarOp.create( + operands=[i32_operand], + result_types=[IndexType()], + attributes={"attribute": i32}, + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): + op2.verify() + + # Fail because of attribute + op3 = GenericConstraintVarOp.create( + operands=[i32_operand], + result_types=[i32], + attributes={"attribute": IndexType()}, + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: attribute i32 expected from variable 'T', but got index", + ): + op3.verify() + + +def test_generic_constraint_var_fail_not_satisfy_constraint(): + """Check that all uses of a constraint variable are satisfying the constraint.""" + test_operand = TestSSAValue(TestType("foo")) + op = GenericConstraintVarOp.create( + operands=[test_operand], + result_types=[TestType("foo")], + attributes={"attribute": TestType("foo")}, + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: operand at position 0 does not verify", + ): op.verify() diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index 752cfabae6..951409fbfa 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Annotated, Any, cast +from typing import Any, ClassVar, cast from xdsl.dialects.builtin import ( AffineMapAttr, @@ -20,9 +20,10 @@ from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, + VarConstraint, attr_def, irdl_op_definition, operand_def, @@ -258,10 +259,10 @@ def verify_(self) -> None: class Store(IRDLOperation): name = "affine.store" - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) value = operand_def(T) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType) map = opt_prop_def(AffineMapAttr) @@ -291,9 +292,9 @@ def __init__( class Load(IRDLOperation): name = "affine.load" - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType) result = result_def(T) @@ -305,7 +306,7 @@ def __init__( memref: SSAValue, indices: Sequence[SSAValue], map: AffineMapAttr | None = None, - result_type: T | None = None, + result_type: Attribute | None = None, ): if map is None: # Create identity map for memrefs with at least one dimension or () -> () diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index e196e09b9c..b5281171aa 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -2,7 +2,7 @@ import abc from collections.abc import Mapping, Sequence -from typing import Annotated, Literal, TypeVar, cast, overload +from typing import ClassVar, Literal, TypeVar, cast, overload from xdsl.dialects.builtin import ( AnyFloat, @@ -26,8 +26,8 @@ from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, - ConstraintVar, IRDLOperation, + VarConstraint, base, irdl_attr_definition, irdl_op_definition, @@ -175,7 +175,7 @@ def parse(cls: type[Constant], parser: Parser) -> Constant: class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC): """A generic base class for arith's binary operations on signless integers.""" - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) lhs = operand_def(T) rhs = operand_def(T) @@ -216,7 +216,7 @@ def __hash__(self) -> int: class FloatingPointLikeBinaryOperation(IRDLOperation, abc.ABC): """A generic base class for arith's binary operations on floats.""" - T = Annotated[Attribute, ConstraintVar("T"), floatingPointLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", floatingPointLike) lhs = operand_def(T) rhs = operand_def(T) @@ -290,7 +290,7 @@ class AddUIExtended(IRDLOperation): traits = frozenset([Pure()]) - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) lhs = operand_def(T) rhs = operand_def(T) @@ -353,7 +353,7 @@ class Muli(SignlessIntegerBinaryOperation): class MulExtendedBase(IRDLOperation): """Base class for extended multiplication operations.""" - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) lhs = operand_def(T) rhs = operand_def(T) diff --git a/xdsl/dialects/comb.py b/xdsl/dialects/comb.py index 4655f09667..a170642213 100644 --- a/xdsl/dialects/comb.py +++ b/xdsl/dialects/comb.py @@ -11,14 +11,21 @@ from abc import ABC from collections.abc import Sequence -from typing import Annotated - -from xdsl.dialects.builtin import I32, I64, IntegerAttr, IntegerType, UnitAttr +from typing import ClassVar + +from xdsl.dialects.builtin import ( + I32, + I64, + IntegerAttr, + IntegerType, + UnitAttr, +) from xdsl.ir import Attribute, Dialect, Operation, SSAValue, TypeAttribute from xdsl.irdl import ( - ConstraintVar, IRDLOperation, + VarConstraint, attr_def, + base, irdl_op_definition, operand_def, opt_attr_def, @@ -49,7 +56,7 @@ class BinCombOperation(IRDLOperation, ABC): result, all of the same integer type. """ - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", base(IntegerType)) lhs = operand_def(T) rhs = operand_def(T) @@ -96,7 +103,7 @@ class VariadicCombOperation(IRDLOperation, ABC): result, all of the same integer type. """ - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", base(IntegerType)) inputs = var_operand_def(T) result = result_def(T) @@ -253,7 +260,7 @@ class ICmpOp(IRDLOperation, ABC): name = "comb.icmp" - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", base(IntegerType)) predicate = attr_def(IntegerAttr[I64]) lhs = operand_def(T) @@ -555,7 +562,7 @@ class MuxOp(IRDLOperation): name = "comb.mux" - T = Annotated[TypeAttribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[TypeAttribute]] = VarConstraint("T", base(TypeAttribute)) cond = operand_def(IntegerType(1)) true_value = operand_def(T) diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index de3dc3686e..a1f0bf30f4 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -52,10 +52,10 @@ from xdsl.irdl import ( AnyOf, BaseAttr, - ConstraintVar, IRDLOperation, ParameterDef, ParametrizedAttribute, + VarConstraint, attr_def, base, eq, @@ -597,17 +597,19 @@ class ZerosOp(IRDLOperation): name = "csl.zeros" - T = Annotated[IntegerType | Float32Type | Float16Type, ConstraintVar("T")] + T: ClassVar[VarConstraint[ZerosOpAttr]] = VarConstraint("T", ZerosOpAttrConstr) size = opt_operand_def(T) - result = result_def(MemRefType[T]) + result = result_def( + MemRefType[IntegerType | Float32Type | Float16Type].constr(element_type=T) + ) is_const = opt_prop_def(builtin.UnitAttr) def __init__( self, - memref: MemRefType[T], + memref: MemRefType[IntegerType | Float32Type | Float16Type], dynamic_size: SSAValue | Operation | None = None, is_const: builtin.UnitAttr | None = None, ): @@ -630,13 +632,17 @@ class ConstantsOp(IRDLOperation): name = "csl.constants" - T = Annotated[IntegerType | Float32Type | Float16Type, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType | Float32Type | Float16Type]] = VarConstraint( + "T", BaseAttr(IntegerType) | BaseAttr(Float32Type) | BaseAttr(Float16Type) + ) size = operand_def(IntegerType) value = operand_def(T) - result = result_def(MemRefType[T]) + result = result_def( + MemRefType[IntegerType | Float32Type | Float16Type].constr(element_type=T) + ) is_const = opt_prop_def(builtin.UnitAttr) @@ -1942,16 +1948,7 @@ class ParamOp(IRDLOperation): command line by passing params to the compiler. """ - T = Annotated[ - Float16Type - | Float32Type - | IntegerType - | ColorType - | FunctionType - | ImportedModuleType - | ComptimeStructType, - ConstraintVar("T"), - ] + T: ClassVar[VarConstraint[ParamOpAttr]] = VarConstraint("T", ParamOpAttrConstr) name = "csl.param" @@ -1963,7 +1960,10 @@ class ParamOp(IRDLOperation): res = result_def(T) def __init__( - self, name: str, result_type: T, init_value: SSAValue | Operation | None = None + self, + name: str, + result_type: ParamOpAttr, + init_value: SSAValue | Operation | None = None, ): super().__init__( operands=[init_value], diff --git a/xdsl/dialects/eqsat.py b/xdsl/dialects/eqsat.py index 637a2d9d79..a92727f437 100644 --- a/xdsl/dialects/eqsat.py +++ b/xdsl/dialects/eqsat.py @@ -11,12 +11,13 @@ from __future__ import annotations -from typing import Annotated +from typing import ClassVar from xdsl.ir import Attribute, Dialect, SSAValue from xdsl.irdl import ( - ConstraintVar, + AnyAttr, IRDLOperation, + VarConstraint, irdl_op_definition, result_def, var_operand_def, @@ -26,7 +27,7 @@ @irdl_op_definition class EClassOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "eqsat.eclass" arguments = var_operand_def(T) diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index d1b316dd9d..82914929bb 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from types import EllipsisType -from typing import Annotated +from typing import ClassVar from xdsl.dialects.builtin import ( I64, @@ -34,9 +34,10 @@ TypeAttribute, ) from xdsl.irdl import ( - ConstraintVar, + BaseAttr, IRDLOperation, ParameterDef, + VarConstraint, base, irdl_attr_definition, irdl_op_definition, @@ -355,7 +356,7 @@ def verify(self): class ArithmeticBinOperation(IRDLOperation, ABC): """Class for arithmetic binary operations.""" - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", BaseAttr(IntegerType)) lhs = operand_def(T) rhs = operand_def(T) diff --git a/xdsl/dialects/ltl.py b/xdsl/dialects/ltl.py index 8e03e8246c..0a9ef5b4eb 100644 --- a/xdsl/dialects/ltl.py +++ b/xdsl/dialects/ltl.py @@ -6,14 +6,14 @@ from __future__ import annotations -from typing import Annotated +from typing import ClassVar -from xdsl.dialects.builtin import IntegerType, Signedness +from xdsl.dialects.builtin import IntegerType from xdsl.ir import Attribute, Dialect, ParametrizedAttribute, SSAValue, TypeAttribute from xdsl.irdl import ( AnyOf, - ConstraintVar, IRDLOperation, + VarConstraint, irdl_attr_definition, irdl_op_definition, result_def, @@ -51,11 +51,9 @@ class AndOp(IRDLOperation): name = "ltl.and" - T = Annotated[ - Attribute, - AnyOf([Sequence, Property, IntegerType(1, signedness=Signedness.SIGNLESS)]), - ConstraintVar("T"), - ] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint( + "T", AnyOf([Sequence, Property, IntegerType(1)]) + ) input = var_operand_def(T) diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 0b6acb460d..90e773cbae 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -8,6 +8,7 @@ from xdsl.dialects.builtin import ( I64, AnyFloat, + AnyFloatConstr, AnyIntegerAttr, AnySignlessIntegerType, ArrayAttr, @@ -21,6 +22,7 @@ MemrefLayoutAttr, MemRefType, NoneAttr, + SignlessIntegerConstraint, StridedLayoutAttr, StringAttr, SymbolRefAttr, @@ -35,11 +37,12 @@ ) from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParsePropInAttrDict, SameVariadicResultSize, + VarConstraint, base, irdl_op_definition, operand_def, @@ -67,13 +70,13 @@ @irdl_op_definition class Load(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] - name = "memref.load" + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) + nontemporal = opt_prop_def(BoolAttr) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType()) res = result_def(T) @@ -106,14 +109,14 @@ def get( @irdl_op_definition class Store(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "memref.store" nontemporal = opt_prop_def(BoolAttr) value = operand_def(T) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType()) irdl_options = [ParsePropInAttrDict()] @@ -357,13 +360,14 @@ def verify_(self) -> None: class AtomicRMWOp(IRDLOperation): name = "memref.atomic_rmw" - T = Annotated[ - AnyFloat | AnySignlessIntegerType, - ConstraintVar("T"), - ] + T: ClassVar[VarConstraint[AnyFloat | AnySignlessIntegerType]] = VarConstraint( + "T", AnyFloatConstr | SignlessIntegerConstraint + ) value = operand_def(T) - memref = operand_def(MemRefType[T]) + memref = operand_def( + MemRefType[AnyFloat | AnySignlessIntegerType].constr(element_type=T) + ) indices = var_operand_def(IndexType) kind = prop_def(IntegerAttr[I64]) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 56a2ba107b..6b14be9bd5 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Sequence from enum import auto from itertools import product -from typing import Annotated, Any, cast +from typing import Any, ClassVar, cast from typing_extensions import Self @@ -35,10 +35,11 @@ SSAValue, ) from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParameterDef, + VarConstraint, base, irdl_attr_definition, irdl_op_definition, @@ -848,9 +849,9 @@ class YieldOp(AbstractYieldOperation[Attribute]): class FillOp(IRDLOperation): name = "memref_stream.fill" - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) - memref = operand_def(memref.MemRefType[T]) + memref = operand_def(memref.MemRefType[Attribute].constr(element_type=T)) value = operand_def(T) assembly_format = "$memref `with` $value attr-dict `:` type($memref)" diff --git a/xdsl/dialects/mod_arith.py b/xdsl/dialects/mod_arith.py index 8c2cd11f24..3c8765192c 100644 --- a/xdsl/dialects/mod_arith.py +++ b/xdsl/dialects/mod_arith.py @@ -4,15 +4,15 @@ """ from abc import ABC -from typing import Annotated +from typing import ClassVar from xdsl.dialects.arith import signlessIntegerLike from xdsl.dialects.builtin import AnyIntegerAttr from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( - ConstraintVar, IRDLOperation, ParsePropInAttrDict, + VarConstraint, irdl_op_definition, operand_def, prop_def, @@ -26,11 +26,12 @@ class BinaryOp(IRDLOperation, ABC): Simple binary operation """ - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] - modulus = prop_def(AnyIntegerAttr) + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) + lhs = operand_def(T) rhs = operand_def(T) output = result_def(T) + modulus = prop_def(AnyIntegerAttr) irdl_options = [ParsePropInAttrDict()] diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 56ac11fc39..0c58b4f2b9 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -2,7 +2,7 @@ from abc import ABC from collections.abc import Sequence -from typing import Annotated, TypeAlias, TypeVar, cast +from typing import ClassVar, cast from typing_extensions import Self @@ -39,8 +39,9 @@ ) from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( - ConstraintVar, + VarConstraint, attr_def, + base, irdl_op_definition, operand_def, opt_attr_def, @@ -801,9 +802,6 @@ class VFMaxSOp(riscv.RdRsRsFloatOperationWithFastMath): traits = frozenset((Pure(),)) -RdRsFloatInvT = TypeVar("RdRsFloatInvT", bound=FloatRegisterType) - - class RdRsRsAccumulatingFloatOperationWithFastMath(RISCVInstruction, ABC): """ A base class for RISC-V operations that have one destination floating-point register, @@ -811,10 +809,12 @@ class RdRsRsAccumulatingFloatOperationWithFastMath(RISCVInstruction, ABC): be annotated with fastmath flags. """ - SameFloatRegisterType: TypeAlias = Annotated[RdRsFloatInvT, ConstraintVar("RdRs")] + SAME_FLOAT_REGISTER_TYPE: ClassVar[VarConstraint[FloatRegisterType]] = ( + VarConstraint("SAME_FLOAT_REGISTER_TYPE", base(FloatRegisterType)) + ) - rd_out = result_def(SameFloatRegisterType) - rd_in = operand_def(SameFloatRegisterType) + rd_out = result_def(SAME_FLOAT_REGISTER_TYPE) + rd_in = operand_def(SAME_FLOAT_REGISTER_TYPE) rs1 = operand_def(FloatRegisterType) rs2 = operand_def(FloatRegisterType) @@ -873,10 +873,12 @@ class RdRsAccumulatingFloatOperation(RISCVInstruction, ABC): that also acts as a source register, and a source floating-point register. """ - SameFloatRegisterType: TypeAlias = Annotated[RdRsFloatInvT, ConstraintVar("RdRs")] + SAME_FLOAT_REGISTER_TYPE: ClassVar[VarConstraint[FloatRegisterType]] = ( + VarConstraint("SAME_FLOAT_REGISTER_TYPE", base(FloatRegisterType)) + ) - rd_out = result_def(SameFloatRegisterType) - rd_in = operand_def(SameFloatRegisterType) + rd_out = result_def(SAME_FLOAT_REGISTER_TYPE) + rd_in = operand_def(SAME_FLOAT_REGISTER_TYPE) rs = operand_def(FloatRegisterType) def __init__( diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index e9b80d99fe..edef7369cb 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Annotated +from typing import ClassVar from typing_extensions import Self @@ -10,6 +10,7 @@ DenseArrayBase, IndexType, IntegerType, + SignlessIntegerConstraint, i64, ) from xdsl.dialects.utils import ( @@ -20,8 +21,9 @@ from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, + VarConstraint, + base, irdl_op_definition, operand_def, prop_def, @@ -284,7 +286,9 @@ def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: class For(IRDLOperation): name = "scf.for" - T = Annotated[AnySignlessIntegerOrIndexType, ConstraintVar("T")] + T: ClassVar[VarConstraint[AnySignlessIntegerOrIndexType]] = VarConstraint( + "T", base(IndexType) | SignlessIntegerConstraint + ) lb = operand_def(T) ub = operand_def(T) diff --git a/xdsl/dialects/seq.py b/xdsl/dialects/seq.py index a861221be9..29f74304bc 100644 --- a/xdsl/dialects/seq.py +++ b/xdsl/dialects/seq.py @@ -5,7 +5,7 @@ """ from enum import Enum -from typing import Annotated +from typing import ClassVar from xdsl.dialects.builtin import ( AnyIntegerAttr, @@ -17,10 +17,11 @@ from xdsl.dialects.hw import InnerSymAttr from xdsl.ir import Attribute, Data, Dialect, Operation, SSAValue from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParametrizedAttribute, + VarConstraint, attr_def, irdl_attr_definition, irdl_op_definition, @@ -93,15 +94,15 @@ class CompRegOp(IRDLOperation): name = "seq.compreg" - DataType = Annotated[Attribute, ConstraintVar("DataType")] + DATA_TYPE: ClassVar[VarConstraint[Attribute]] = VarConstraint("DataType", AnyAttr()) inner_sym = opt_attr_def(InnerSymAttr) - input = operand_def(DataType) + input = operand_def(DATA_TYPE) clk = operand_def(clock) reset = opt_operand_def(i1) - reset_value = opt_operand_def(DataType) - power_on_value = opt_operand_def(DataType) - data = result_def(DataType) + reset_value = opt_operand_def(DATA_TYPE) + power_on_value = opt_operand_def(DATA_TYPE) + data = result_def(DATA_TYPE) irdl_options = [AttrSizedOperandSegments()] diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index 33d1fbe472..1cffaa38ea 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -8,7 +8,7 @@ import abc from collections.abc import Sequence -from typing import Annotated, TypeAlias, cast +from typing import Annotated, ClassVar, TypeAlias, cast from xdsl.dialects.builtin import ( I32, @@ -38,7 +38,9 @@ ConstraintVar, IRDLOperation, ParameterDef, + VarConstraint, attr_def, + base, irdl_attr_definition, irdl_op_definition, operand_def, @@ -57,7 +59,7 @@ class ElementwiseBinaryOperation(IRDLOperation, abc.ABC): # TODO: Remove this constraint for complex types. - T = Annotated[AnyTensorType, ConstraintVar("T")] + T: ClassVar[VarConstraint[AnyTensorType]] = VarConstraint("T", base(AnyTensorType)) lhs = operand_def(T) rhs = operand_def(T) @@ -221,7 +223,7 @@ class AbsOp(IRDLOperation): name = "stablehlo.abs" # TODO: Remove this constraint for complex types. - T = Annotated[AnyTensorType, ConstraintVar("T")] + T: ClassVar[VarConstraint[AnyTensorType]] = VarConstraint("T", base(AnyTensorType)) operand = operand_def(T) result = result_def(T) @@ -284,7 +286,9 @@ class AndOp(IRDLOperation): name = "stablehlo.and" - T = Annotated[IntegerTensorType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerTensorType]] = VarConstraint( + "T", base(IntegerTensorType) + ) lhs = operand_def(T) rhs = operand_def(T) diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index c6e6deffad..f0cef4d084 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -479,8 +479,6 @@ class ApplyOp(IRDLOperation): name = "stencil.apply" - B = Annotated[Attribute, ConstraintVar("B")] - args = var_operand_def(Attribute) dest = var_operand_def(FieldType) region = region_def() diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 794eaddb6a..2a70ca5c37 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import Annotated, Generic, TypeAlias, TypeVar, cast +from typing import ClassVar, Generic, TypeAlias, TypeVar, cast from typing_extensions import Self @@ -18,11 +18,11 @@ from xdsl.irdl import ( AnyAttr, BaseAttr, - ConstraintVar, GenericAttrConstraint, IRDLOperation, ParamAttrConstraint, ParameterDef, + VarConstraint, irdl_attr_definition, irdl_op_definition, operand_def, @@ -80,9 +80,9 @@ class ReadOperation(IRDLOperation, abc.ABC): Abstract base class for operations that read from a stream. """ - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType[T]) + stream = operand_def(ReadableStreamType[Attribute].constr(element_type=T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -113,10 +113,10 @@ class WriteOperation(IRDLOperation, abc.ABC): Abstract base class for operations that write to a stream. """ - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType[T]) + stream = operand_def(WritableStreamType[Attribute].constr(element_type=T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) diff --git a/xdsl/dialects/transform.py b/xdsl/dialects/transform.py index 43801a6aa3..4dea325fc8 100644 --- a/xdsl/dialects/transform.py +++ b/xdsl/dialects/transform.py @@ -28,7 +28,6 @@ from xdsl.irdl import ( AnyOf, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParameterDef, attr_def, @@ -525,8 +524,6 @@ class SequenceOp(IRDLOperation): name = "transform.sequence" - T = Annotated[AnyIntegerOrFailurePropagationModeAttr, ConstraintVar("T")] - body = region_def("single_block") failure_propagation_mode = prop_def(Attribute) root = var_operand_def(AnyOpType) diff --git a/xdsl/dialects/varith.py b/xdsl/dialects/varith.py index a10dc28996..47ed605c12 100644 --- a/xdsl/dialects/varith.py +++ b/xdsl/dialects/varith.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import ClassVar from xdsl.dialects.builtin import ( BFloat16Type, @@ -14,8 +14,8 @@ from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, - ConstraintVar, IRDLOperation, + VarConstraint, irdl_op_definition, result_def, var_operand_def, @@ -43,7 +43,7 @@ class VarithOp(IRDLOperation): Variadic arithmetic operation """ - T = Annotated[Attribute, ConstraintVar("T"), integerOrFloatLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", integerOrFloatLike) args = var_operand_def(T) res = result_def(T)