diff --git a/docs/irdl.ipynb b/docs/irdl.ipynb index f92f71739e..144601f39a 100644 --- a/docs/irdl.ipynb +++ b/docs/irdl.ipynb @@ -1058,14 +1058,15 @@ } ], "source": [ - "from xdsl.irdl import ConstraintVar\n", + "from typing import ClassVar\n", + "from xdsl.irdl import base, VarConstraint\n", "\n", "\n", "@irdl_op_definition\n", "class BinaryOp(IRDLOperation):\n", " name = \"binary_op\"\n", "\n", - " T = Annotated[IntegerType, ConstraintVar(\"T\")]\n", + " T: ClassVar[VarConstraint[IntegerType]] = VarConstraint(\"T\", base(IntegerType))\n", "\n", " lhs = operand_def(T)\n", " rhs = operand_def(T)\n", diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index d22853edae..c49f99cac2 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -3,7 +3,7 @@ import textwrap from collections.abc import Callable from io import StringIO -from typing import Annotated, Generic, TypeVar +from typing import ClassVar, Generic, TypeVar import pytest @@ -23,13 +23,13 @@ AttrSizedRegionSegments, AttrSizedResultSegments, BaseAttr, - ConstraintVar, EqAttrConstraint, GenericAttrConstraint, IRDLOperation, ParamAttrConstraint, ParameterDef, ParsePropInAttrDict, + VarConstraint, VarOperand, VarOpResult, attr_def, @@ -1578,7 +1578,7 @@ def test_basic_inference(format: str): @irdl_op_definition class TwoOperandsOneResultWithVarOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "test.two_operands_one_result_with_var" res = result_def(T) @@ -1678,11 +1678,11 @@ def constr( @irdl_op_definition class TwoOperandsNestedVarOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "test.two_operands_one_result_with_var" res = result_def(T) - lhs = operand_def(ParamOne[T]) + lhs = operand_def(ParamOne[Attribute].constr(p=T)) rhs = operand_def(T) assembly_format = "$lhs $rhs attr-dict `:` type($lhs)" @@ -1722,11 +1722,11 @@ def constr( @irdl_op_definition class OneOperandOneResultNestedOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "test.one_operand_one_result_nested" res = result_def(T) - lhs = operand_def(ParamOne[T]) + lhs = operand_def(ParamOne[Attribute].constr(p=T)) assembly_format = "$lhs attr-dict `:` type($lhs)" diff --git a/tests/irdl/test_operation_definition.py b/tests/irdl/test_operation_definition.py index 30ee723f13..b55208e00f 100644 --- a/tests/irdl/test_operation_definition.py +++ b/tests/irdl/test_operation_definition.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Annotated, Generic, TypeVar +from typing import Annotated, ClassVar, Generic, TypeVar import pytest @@ -31,7 +31,9 @@ RangeOf, RegionDef, ResultDef, + VarConstraint, attr_def, + base, irdl_op_definition, operand_def, opt_attr_def, @@ -157,15 +159,16 @@ def test_attr_verify(): op.verify() +# TODO: remove this test once the Annotated API is deprecated @irdl_op_definition class ConstraintVarOp(IRDLOperation): name = "test.constraint_var_op" T = Annotated[IntegerType | IndexType, ConstraintVar("T")] - operand = operand_def(T) - result = result_def(T) - attribute = attr_def(T) + operand = operand_def(T) # pyright: ignore[reportArgumentType] + result = result_def(T) # pyright: ignore[reportArgumentType] + attribute = attr_def(T) # pyright: ignore[reportArgumentType, reportUnknownVariableType] def test_constraint_var(): @@ -193,7 +196,10 @@ def test_constraint_var_fail_non_equal(): op = ConstraintVarOp.create( operands=[index_operand], result_types=[i32], attributes={"attribute": i32} ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): op.verify() # Fail because of result @@ -202,7 +208,10 @@ def test_constraint_var_fail_non_equal(): result_types=[IndexType()], attributes={"attribute": i32}, ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): op2.verify() # Fail because of attribute @@ -211,7 +220,10 @@ def test_constraint_var_fail_non_equal(): result_types=[i32], attributes={"attribute": IndexType()}, ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: attribute i32 expected from variable 'T', but got index", + ): op3.verify() @@ -223,7 +235,94 @@ def test_constraint_var_fail_not_satisfy_constraint(): result_types=[TestType("foo")], attributes={"attribute": TestType("foo")}, ) - with pytest.raises(DiagnosticException): + with pytest.raises( + DiagnosticException, + match="Operation does not verify: operand at position 0 does not verify", + ): + op.verify() + + +@irdl_op_definition +class GenericConstraintVarOp(IRDLOperation): + name = "test.constraint_var_op" + + T: ClassVar[VarConstraint[IntegerType | IndexType]] = VarConstraint( + "T", base(IntegerType) | base(IndexType) + ) + + operand = operand_def(T) + result = result_def(T) + attribute = attr_def(T) + + +def test_generic_constraint_var(): + i32_operand = TestSSAValue(i32) + index_operand = TestSSAValue(IndexType()) + op = GenericConstraintVarOp.create( + operands=[i32_operand], result_types=[i32], attributes={"attribute": i32} + ) + op.verify() + + op2 = GenericConstraintVarOp.create( + operands=[index_operand], + result_types=[IndexType()], + attributes={"attribute": IndexType()}, + ) + op2.verify() + + +def test_generic_constraint_var_fail_non_equal(): + """Check that all uses of a constraint variable are of the same attribute.""" + i32_operand = TestSSAValue(i32) + index_operand = TestSSAValue(IndexType()) + + # Fail because of operand + op = GenericConstraintVarOp.create( + operands=[index_operand], result_types=[i32], attributes={"attribute": i32} + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): + op.verify() + + # Fail because of result + op2 = GenericConstraintVarOp.create( + operands=[i32_operand], + result_types=[IndexType()], + attributes={"attribute": i32}, + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: result at position 0 does not verify", + ): + op2.verify() + + # Fail because of attribute + op3 = GenericConstraintVarOp.create( + operands=[i32_operand], + result_types=[i32], + attributes={"attribute": IndexType()}, + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: attribute i32 expected from variable 'T', but got index", + ): + op3.verify() + + +def test_generic_constraint_var_fail_not_satisfy_constraint(): + """Check that all uses of a constraint variable are satisfying the constraint.""" + test_operand = TestSSAValue(TestType("foo")) + op = GenericConstraintVarOp.create( + operands=[test_operand], + result_types=[TestType("foo")], + attributes={"attribute": TestType("foo")}, + ) + with pytest.raises( + DiagnosticException, + match="Operation does not verify: operand at position 0 does not verify", + ): op.verify() diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index 752cfabae6..951409fbfa 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Annotated, Any, cast +from typing import Any, ClassVar, cast from xdsl.dialects.builtin import ( AffineMapAttr, @@ -20,9 +20,10 @@ from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, + VarConstraint, attr_def, irdl_op_definition, operand_def, @@ -258,10 +259,10 @@ def verify_(self) -> None: class Store(IRDLOperation): name = "affine.store" - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) value = operand_def(T) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType) map = opt_prop_def(AffineMapAttr) @@ -291,9 +292,9 @@ def __init__( class Load(IRDLOperation): name = "affine.load" - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType) result = result_def(T) @@ -305,7 +306,7 @@ def __init__( memref: SSAValue, indices: Sequence[SSAValue], map: AffineMapAttr | None = None, - result_type: T | None = None, + result_type: Attribute | None = None, ): if map is None: # Create identity map for memrefs with at least one dimension or () -> () diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index e196e09b9c..b5281171aa 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -2,7 +2,7 @@ import abc from collections.abc import Mapping, Sequence -from typing import Annotated, Literal, TypeVar, cast, overload +from typing import ClassVar, Literal, TypeVar, cast, overload from xdsl.dialects.builtin import ( AnyFloat, @@ -26,8 +26,8 @@ from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, - ConstraintVar, IRDLOperation, + VarConstraint, base, irdl_attr_definition, irdl_op_definition, @@ -175,7 +175,7 @@ def parse(cls: type[Constant], parser: Parser) -> Constant: class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC): """A generic base class for arith's binary operations on signless integers.""" - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) lhs = operand_def(T) rhs = operand_def(T) @@ -216,7 +216,7 @@ def __hash__(self) -> int: class FloatingPointLikeBinaryOperation(IRDLOperation, abc.ABC): """A generic base class for arith's binary operations on floats.""" - T = Annotated[Attribute, ConstraintVar("T"), floatingPointLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", floatingPointLike) lhs = operand_def(T) rhs = operand_def(T) @@ -290,7 +290,7 @@ class AddUIExtended(IRDLOperation): traits = frozenset([Pure()]) - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) lhs = operand_def(T) rhs = operand_def(T) @@ -353,7 +353,7 @@ class Muli(SignlessIntegerBinaryOperation): class MulExtendedBase(IRDLOperation): """Base class for extended multiplication operations.""" - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) lhs = operand_def(T) rhs = operand_def(T) diff --git a/xdsl/dialects/comb.py b/xdsl/dialects/comb.py index 4655f09667..a170642213 100644 --- a/xdsl/dialects/comb.py +++ b/xdsl/dialects/comb.py @@ -11,14 +11,21 @@ from abc import ABC from collections.abc import Sequence -from typing import Annotated - -from xdsl.dialects.builtin import I32, I64, IntegerAttr, IntegerType, UnitAttr +from typing import ClassVar + +from xdsl.dialects.builtin import ( + I32, + I64, + IntegerAttr, + IntegerType, + UnitAttr, +) from xdsl.ir import Attribute, Dialect, Operation, SSAValue, TypeAttribute from xdsl.irdl import ( - ConstraintVar, IRDLOperation, + VarConstraint, attr_def, + base, irdl_op_definition, operand_def, opt_attr_def, @@ -49,7 +56,7 @@ class BinCombOperation(IRDLOperation, ABC): result, all of the same integer type. """ - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", base(IntegerType)) lhs = operand_def(T) rhs = operand_def(T) @@ -96,7 +103,7 @@ class VariadicCombOperation(IRDLOperation, ABC): result, all of the same integer type. """ - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", base(IntegerType)) inputs = var_operand_def(T) result = result_def(T) @@ -253,7 +260,7 @@ class ICmpOp(IRDLOperation, ABC): name = "comb.icmp" - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", base(IntegerType)) predicate = attr_def(IntegerAttr[I64]) lhs = operand_def(T) @@ -555,7 +562,7 @@ class MuxOp(IRDLOperation): name = "comb.mux" - T = Annotated[TypeAttribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[TypeAttribute]] = VarConstraint("T", base(TypeAttribute)) cond = operand_def(IntegerType(1)) true_value = operand_def(T) diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index de3dc3686e..a1f0bf30f4 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -52,10 +52,10 @@ from xdsl.irdl import ( AnyOf, BaseAttr, - ConstraintVar, IRDLOperation, ParameterDef, ParametrizedAttribute, + VarConstraint, attr_def, base, eq, @@ -597,17 +597,19 @@ class ZerosOp(IRDLOperation): name = "csl.zeros" - T = Annotated[IntegerType | Float32Type | Float16Type, ConstraintVar("T")] + T: ClassVar[VarConstraint[ZerosOpAttr]] = VarConstraint("T", ZerosOpAttrConstr) size = opt_operand_def(T) - result = result_def(MemRefType[T]) + result = result_def( + MemRefType[IntegerType | Float32Type | Float16Type].constr(element_type=T) + ) is_const = opt_prop_def(builtin.UnitAttr) def __init__( self, - memref: MemRefType[T], + memref: MemRefType[IntegerType | Float32Type | Float16Type], dynamic_size: SSAValue | Operation | None = None, is_const: builtin.UnitAttr | None = None, ): @@ -630,13 +632,17 @@ class ConstantsOp(IRDLOperation): name = "csl.constants" - T = Annotated[IntegerType | Float32Type | Float16Type, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType | Float32Type | Float16Type]] = VarConstraint( + "T", BaseAttr(IntegerType) | BaseAttr(Float32Type) | BaseAttr(Float16Type) + ) size = operand_def(IntegerType) value = operand_def(T) - result = result_def(MemRefType[T]) + result = result_def( + MemRefType[IntegerType | Float32Type | Float16Type].constr(element_type=T) + ) is_const = opt_prop_def(builtin.UnitAttr) @@ -1942,16 +1948,7 @@ class ParamOp(IRDLOperation): command line by passing params to the compiler. """ - T = Annotated[ - Float16Type - | Float32Type - | IntegerType - | ColorType - | FunctionType - | ImportedModuleType - | ComptimeStructType, - ConstraintVar("T"), - ] + T: ClassVar[VarConstraint[ParamOpAttr]] = VarConstraint("T", ParamOpAttrConstr) name = "csl.param" @@ -1963,7 +1960,10 @@ class ParamOp(IRDLOperation): res = result_def(T) def __init__( - self, name: str, result_type: T, init_value: SSAValue | Operation | None = None + self, + name: str, + result_type: ParamOpAttr, + init_value: SSAValue | Operation | None = None, ): super().__init__( operands=[init_value], diff --git a/xdsl/dialects/eqsat.py b/xdsl/dialects/eqsat.py index 637a2d9d79..a92727f437 100644 --- a/xdsl/dialects/eqsat.py +++ b/xdsl/dialects/eqsat.py @@ -11,12 +11,13 @@ from __future__ import annotations -from typing import Annotated +from typing import ClassVar from xdsl.ir import Attribute, Dialect, SSAValue from xdsl.irdl import ( - ConstraintVar, + AnyAttr, IRDLOperation, + VarConstraint, irdl_op_definition, result_def, var_operand_def, @@ -26,7 +27,7 @@ @irdl_op_definition class EClassOp(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "eqsat.eclass" arguments = var_operand_def(T) diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index d1b316dd9d..82914929bb 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from types import EllipsisType -from typing import Annotated +from typing import ClassVar from xdsl.dialects.builtin import ( I64, @@ -34,9 +34,10 @@ TypeAttribute, ) from xdsl.irdl import ( - ConstraintVar, + BaseAttr, IRDLOperation, ParameterDef, + VarConstraint, base, irdl_attr_definition, irdl_op_definition, @@ -355,7 +356,7 @@ def verify(self): class ArithmeticBinOperation(IRDLOperation, ABC): """Class for arithmetic binary operations.""" - T = Annotated[IntegerType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerType]] = VarConstraint("T", BaseAttr(IntegerType)) lhs = operand_def(T) rhs = operand_def(T) diff --git a/xdsl/dialects/ltl.py b/xdsl/dialects/ltl.py index 8e03e8246c..0a9ef5b4eb 100644 --- a/xdsl/dialects/ltl.py +++ b/xdsl/dialects/ltl.py @@ -6,14 +6,14 @@ from __future__ import annotations -from typing import Annotated +from typing import ClassVar -from xdsl.dialects.builtin import IntegerType, Signedness +from xdsl.dialects.builtin import IntegerType from xdsl.ir import Attribute, Dialect, ParametrizedAttribute, SSAValue, TypeAttribute from xdsl.irdl import ( AnyOf, - ConstraintVar, IRDLOperation, + VarConstraint, irdl_attr_definition, irdl_op_definition, result_def, @@ -51,11 +51,9 @@ class AndOp(IRDLOperation): name = "ltl.and" - T = Annotated[ - Attribute, - AnyOf([Sequence, Property, IntegerType(1, signedness=Signedness.SIGNLESS)]), - ConstraintVar("T"), - ] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint( + "T", AnyOf([Sequence, Property, IntegerType(1)]) + ) input = var_operand_def(T) diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 0b6acb460d..90e773cbae 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -8,6 +8,7 @@ from xdsl.dialects.builtin import ( I64, AnyFloat, + AnyFloatConstr, AnyIntegerAttr, AnySignlessIntegerType, ArrayAttr, @@ -21,6 +22,7 @@ MemrefLayoutAttr, MemRefType, NoneAttr, + SignlessIntegerConstraint, StridedLayoutAttr, StringAttr, SymbolRefAttr, @@ -35,11 +37,12 @@ ) from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParsePropInAttrDict, SameVariadicResultSize, + VarConstraint, base, irdl_op_definition, operand_def, @@ -67,13 +70,13 @@ @irdl_op_definition class Load(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] - name = "memref.load" + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) + nontemporal = opt_prop_def(BoolAttr) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType()) res = result_def(T) @@ -106,14 +109,14 @@ def get( @irdl_op_definition class Store(IRDLOperation): - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) name = "memref.store" nontemporal = opt_prop_def(BoolAttr) value = operand_def(T) - memref = operand_def(MemRefType[T]) + memref = operand_def(MemRefType[Attribute].constr(element_type=T)) indices = var_operand_def(IndexType()) irdl_options = [ParsePropInAttrDict()] @@ -357,13 +360,14 @@ def verify_(self) -> None: class AtomicRMWOp(IRDLOperation): name = "memref.atomic_rmw" - T = Annotated[ - AnyFloat | AnySignlessIntegerType, - ConstraintVar("T"), - ] + T: ClassVar[VarConstraint[AnyFloat | AnySignlessIntegerType]] = VarConstraint( + "T", AnyFloatConstr | SignlessIntegerConstraint + ) value = operand_def(T) - memref = operand_def(MemRefType[T]) + memref = operand_def( + MemRefType[AnyFloat | AnySignlessIntegerType].constr(element_type=T) + ) indices = var_operand_def(IndexType) kind = prop_def(IntegerAttr[I64]) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 56a2ba107b..6b14be9bd5 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Sequence from enum import auto from itertools import product -from typing import Annotated, Any, cast +from typing import Any, ClassVar, cast from typing_extensions import Self @@ -35,10 +35,11 @@ SSAValue, ) from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParameterDef, + VarConstraint, base, irdl_attr_definition, irdl_op_definition, @@ -848,9 +849,9 @@ class YieldOp(AbstractYieldOperation[Attribute]): class FillOp(IRDLOperation): name = "memref_stream.fill" - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) - memref = operand_def(memref.MemRefType[T]) + memref = operand_def(memref.MemRefType[Attribute].constr(element_type=T)) value = operand_def(T) assembly_format = "$memref `with` $value attr-dict `:` type($memref)" diff --git a/xdsl/dialects/mod_arith.py b/xdsl/dialects/mod_arith.py index 8c2cd11f24..3c8765192c 100644 --- a/xdsl/dialects/mod_arith.py +++ b/xdsl/dialects/mod_arith.py @@ -4,15 +4,15 @@ """ from abc import ABC -from typing import Annotated +from typing import ClassVar from xdsl.dialects.arith import signlessIntegerLike from xdsl.dialects.builtin import AnyIntegerAttr from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( - ConstraintVar, IRDLOperation, ParsePropInAttrDict, + VarConstraint, irdl_op_definition, operand_def, prop_def, @@ -26,11 +26,12 @@ class BinaryOp(IRDLOperation, ABC): Simple binary operation """ - T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike] - modulus = prop_def(AnyIntegerAttr) + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike) + lhs = operand_def(T) rhs = operand_def(T) output = result_def(T) + modulus = prop_def(AnyIntegerAttr) irdl_options = [ParsePropInAttrDict()] diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 56ac11fc39..0c58b4f2b9 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -2,7 +2,7 @@ from abc import ABC from collections.abc import Sequence -from typing import Annotated, TypeAlias, TypeVar, cast +from typing import ClassVar, cast from typing_extensions import Self @@ -39,8 +39,9 @@ ) from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( - ConstraintVar, + VarConstraint, attr_def, + base, irdl_op_definition, operand_def, opt_attr_def, @@ -801,9 +802,6 @@ class VFMaxSOp(riscv.RdRsRsFloatOperationWithFastMath): traits = frozenset((Pure(),)) -RdRsFloatInvT = TypeVar("RdRsFloatInvT", bound=FloatRegisterType) - - class RdRsRsAccumulatingFloatOperationWithFastMath(RISCVInstruction, ABC): """ A base class for RISC-V operations that have one destination floating-point register, @@ -811,10 +809,12 @@ class RdRsRsAccumulatingFloatOperationWithFastMath(RISCVInstruction, ABC): be annotated with fastmath flags. """ - SameFloatRegisterType: TypeAlias = Annotated[RdRsFloatInvT, ConstraintVar("RdRs")] + SAME_FLOAT_REGISTER_TYPE: ClassVar[VarConstraint[FloatRegisterType]] = ( + VarConstraint("SAME_FLOAT_REGISTER_TYPE", base(FloatRegisterType)) + ) - rd_out = result_def(SameFloatRegisterType) - rd_in = operand_def(SameFloatRegisterType) + rd_out = result_def(SAME_FLOAT_REGISTER_TYPE) + rd_in = operand_def(SAME_FLOAT_REGISTER_TYPE) rs1 = operand_def(FloatRegisterType) rs2 = operand_def(FloatRegisterType) @@ -873,10 +873,12 @@ class RdRsAccumulatingFloatOperation(RISCVInstruction, ABC): that also acts as a source register, and a source floating-point register. """ - SameFloatRegisterType: TypeAlias = Annotated[RdRsFloatInvT, ConstraintVar("RdRs")] + SAME_FLOAT_REGISTER_TYPE: ClassVar[VarConstraint[FloatRegisterType]] = ( + VarConstraint("SAME_FLOAT_REGISTER_TYPE", base(FloatRegisterType)) + ) - rd_out = result_def(SameFloatRegisterType) - rd_in = operand_def(SameFloatRegisterType) + rd_out = result_def(SAME_FLOAT_REGISTER_TYPE) + rd_in = operand_def(SAME_FLOAT_REGISTER_TYPE) rs = operand_def(FloatRegisterType) def __init__( diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index e9b80d99fe..edef7369cb 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Annotated +from typing import ClassVar from typing_extensions import Self @@ -10,6 +10,7 @@ DenseArrayBase, IndexType, IntegerType, + SignlessIntegerConstraint, i64, ) from xdsl.dialects.utils import ( @@ -20,8 +21,9 @@ from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, + VarConstraint, + base, irdl_op_definition, operand_def, prop_def, @@ -284,7 +286,9 @@ def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: class For(IRDLOperation): name = "scf.for" - T = Annotated[AnySignlessIntegerOrIndexType, ConstraintVar("T")] + T: ClassVar[VarConstraint[AnySignlessIntegerOrIndexType]] = VarConstraint( + "T", base(IndexType) | SignlessIntegerConstraint + ) lb = operand_def(T) ub = operand_def(T) diff --git a/xdsl/dialects/seq.py b/xdsl/dialects/seq.py index a861221be9..29f74304bc 100644 --- a/xdsl/dialects/seq.py +++ b/xdsl/dialects/seq.py @@ -5,7 +5,7 @@ """ from enum import Enum -from typing import Annotated +from typing import ClassVar from xdsl.dialects.builtin import ( AnyIntegerAttr, @@ -17,10 +17,11 @@ from xdsl.dialects.hw import InnerSymAttr from xdsl.ir import Attribute, Data, Dialect, Operation, SSAValue from xdsl.irdl import ( + AnyAttr, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParametrizedAttribute, + VarConstraint, attr_def, irdl_attr_definition, irdl_op_definition, @@ -93,15 +94,15 @@ class CompRegOp(IRDLOperation): name = "seq.compreg" - DataType = Annotated[Attribute, ConstraintVar("DataType")] + DATA_TYPE: ClassVar[VarConstraint[Attribute]] = VarConstraint("DataType", AnyAttr()) inner_sym = opt_attr_def(InnerSymAttr) - input = operand_def(DataType) + input = operand_def(DATA_TYPE) clk = operand_def(clock) reset = opt_operand_def(i1) - reset_value = opt_operand_def(DataType) - power_on_value = opt_operand_def(DataType) - data = result_def(DataType) + reset_value = opt_operand_def(DATA_TYPE) + power_on_value = opt_operand_def(DATA_TYPE) + data = result_def(DATA_TYPE) irdl_options = [AttrSizedOperandSegments()] diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index 33d1fbe472..1cffaa38ea 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -8,7 +8,7 @@ import abc from collections.abc import Sequence -from typing import Annotated, TypeAlias, cast +from typing import Annotated, ClassVar, TypeAlias, cast from xdsl.dialects.builtin import ( I32, @@ -38,7 +38,9 @@ ConstraintVar, IRDLOperation, ParameterDef, + VarConstraint, attr_def, + base, irdl_attr_definition, irdl_op_definition, operand_def, @@ -57,7 +59,7 @@ class ElementwiseBinaryOperation(IRDLOperation, abc.ABC): # TODO: Remove this constraint for complex types. - T = Annotated[AnyTensorType, ConstraintVar("T")] + T: ClassVar[VarConstraint[AnyTensorType]] = VarConstraint("T", base(AnyTensorType)) lhs = operand_def(T) rhs = operand_def(T) @@ -221,7 +223,7 @@ class AbsOp(IRDLOperation): name = "stablehlo.abs" # TODO: Remove this constraint for complex types. - T = Annotated[AnyTensorType, ConstraintVar("T")] + T: ClassVar[VarConstraint[AnyTensorType]] = VarConstraint("T", base(AnyTensorType)) operand = operand_def(T) result = result_def(T) @@ -284,7 +286,9 @@ class AndOp(IRDLOperation): name = "stablehlo.and" - T = Annotated[IntegerTensorType, ConstraintVar("T")] + T: ClassVar[VarConstraint[IntegerTensorType]] = VarConstraint( + "T", base(IntegerTensorType) + ) lhs = operand_def(T) rhs = operand_def(T) diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index c6e6deffad..f0cef4d084 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -479,8 +479,6 @@ class ApplyOp(IRDLOperation): name = "stencil.apply" - B = Annotated[Attribute, ConstraintVar("B")] - args = var_operand_def(Attribute) dest = var_operand_def(FieldType) region = region_def() diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 794eaddb6a..2a70ca5c37 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import Annotated, Generic, TypeAlias, TypeVar, cast +from typing import ClassVar, Generic, TypeAlias, TypeVar, cast from typing_extensions import Self @@ -18,11 +18,11 @@ from xdsl.irdl import ( AnyAttr, BaseAttr, - ConstraintVar, GenericAttrConstraint, IRDLOperation, ParamAttrConstraint, ParameterDef, + VarConstraint, irdl_attr_definition, irdl_op_definition, operand_def, @@ -80,9 +80,9 @@ class ReadOperation(IRDLOperation, abc.ABC): Abstract base class for operations that read from a stream. """ - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType[T]) + stream = operand_def(ReadableStreamType[Attribute].constr(element_type=T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -113,10 +113,10 @@ class WriteOperation(IRDLOperation, abc.ABC): Abstract base class for operations that write to a stream. """ - T = Annotated[Attribute, ConstraintVar("T")] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType[T]) + stream = operand_def(WritableStreamType[Attribute].constr(element_type=T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) diff --git a/xdsl/dialects/transform.py b/xdsl/dialects/transform.py index 43801a6aa3..4dea325fc8 100644 --- a/xdsl/dialects/transform.py +++ b/xdsl/dialects/transform.py @@ -28,7 +28,6 @@ from xdsl.irdl import ( AnyOf, AttrSizedOperandSegments, - ConstraintVar, IRDLOperation, ParameterDef, attr_def, @@ -525,8 +524,6 @@ class SequenceOp(IRDLOperation): name = "transform.sequence" - T = Annotated[AnyIntegerOrFailurePropagationModeAttr, ConstraintVar("T")] - body = region_def("single_block") failure_propagation_mode = prop_def(Attribute) root = var_operand_def(AnyOpType) diff --git a/xdsl/dialects/varith.py b/xdsl/dialects/varith.py index a10dc28996..47ed605c12 100644 --- a/xdsl/dialects/varith.py +++ b/xdsl/dialects/varith.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import ClassVar from xdsl.dialects.builtin import ( BFloat16Type, @@ -14,8 +14,8 @@ from xdsl.ir import Attribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, - ConstraintVar, IRDLOperation, + VarConstraint, irdl_op_definition, result_def, var_operand_def, @@ -43,7 +43,7 @@ class VarithOp(IRDLOperation): Variadic arithmetic operation """ - T = Annotated[Attribute, ConstraintVar("T"), integerOrFloatLike] + T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", integerOrFloatLike) args = var_operand_def(T) res = result_def(T)