Skip to content

Commit

Permalink
misc: use VarConstraint instead of Annotated[ConstraintVar] in operat…
Browse files Browse the repository at this point in the history
…ion definitions (#3264)

Using Annotated in the way that we do is not well defined by the Python
type system, this fixes the current misuse to get the number of Pylance
errors to 0.

The current proposal is to keep both implementations working, and to
deprecate the `Annotated` API at some point in the future.
  • Loading branch information
superlopuh authored Oct 10, 2024
1 parent 7f0f3e6 commit 187eb31
Show file tree
Hide file tree
Showing 21 changed files with 247 additions and 127 deletions.
5 changes: 3 additions & 2 deletions docs/irdl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1058,14 +1058,15 @@
}
],
"source": [
"from xdsl.irdl import ConstraintVar\n",
"from typing import ClassVar\n",
"from xdsl.irdl import base, VarConstraint\n",
"\n",
"\n",
"@irdl_op_definition\n",
"class BinaryOp(IRDLOperation):\n",
" name = \"binary_op\"\n",
"\n",
" T = Annotated[IntegerType, ConstraintVar(\"T\")]\n",
" T: ClassVar[VarConstraint[IntegerType]] = VarConstraint(\"T\", base(IntegerType))\n",
"\n",
" lhs = operand_def(T)\n",
" rhs = operand_def(T)\n",
Expand Down
14 changes: 7 additions & 7 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import textwrap
from collections.abc import Callable
from io import StringIO
from typing import Annotated, Generic, TypeVar
from typing import ClassVar, Generic, TypeVar

import pytest

Expand All @@ -23,13 +23,13 @@
AttrSizedRegionSegments,
AttrSizedResultSegments,
BaseAttr,
ConstraintVar,
EqAttrConstraint,
GenericAttrConstraint,
IRDLOperation,
ParamAttrConstraint,
ParameterDef,
ParsePropInAttrDict,
VarConstraint,
VarOperand,
VarOpResult,
attr_def,
Expand Down Expand Up @@ -1578,7 +1578,7 @@ def test_basic_inference(format: str):

@irdl_op_definition
class TwoOperandsOneResultWithVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

name = "test.two_operands_one_result_with_var"
res = result_def(T)
Expand Down Expand Up @@ -1678,11 +1678,11 @@ def constr(

@irdl_op_definition
class TwoOperandsNestedVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

name = "test.two_operands_one_result_with_var"
res = result_def(T)
lhs = operand_def(ParamOne[T])
lhs = operand_def(ParamOne[Attribute].constr(p=T))
rhs = operand_def(T)

assembly_format = "$lhs $rhs attr-dict `:` type($lhs)"
Expand Down Expand Up @@ -1722,11 +1722,11 @@ def constr(

@irdl_op_definition
class OneOperandOneResultNestedOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

name = "test.one_operand_one_result_nested"
res = result_def(T)
lhs = operand_def(ParamOne[T])
lhs = operand_def(ParamOne[Attribute].constr(p=T))

assembly_format = "$lhs attr-dict `:` type($lhs)"

Expand Down
115 changes: 107 additions & 8 deletions tests/irdl/test_operation_definition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Annotated, Generic, TypeVar
from typing import Annotated, ClassVar, Generic, TypeVar

import pytest

Expand Down Expand Up @@ -31,7 +31,9 @@
RangeOf,
RegionDef,
ResultDef,
VarConstraint,
attr_def,
base,
irdl_op_definition,
operand_def,
opt_attr_def,
Expand Down Expand Up @@ -157,15 +159,16 @@ def test_attr_verify():
op.verify()


# TODO: remove this test once the Annotated API is deprecated
@irdl_op_definition
class ConstraintVarOp(IRDLOperation):
name = "test.constraint_var_op"

T = Annotated[IntegerType | IndexType, ConstraintVar("T")]

operand = operand_def(T)
result = result_def(T)
attribute = attr_def(T)
operand = operand_def(T) # pyright: ignore[reportArgumentType]
result = result_def(T) # pyright: ignore[reportArgumentType]
attribute = attr_def(T) # pyright: ignore[reportArgumentType, reportUnknownVariableType]


def test_constraint_var():
Expand Down Expand Up @@ -193,7 +196,10 @@ def test_constraint_var_fail_non_equal():
op = ConstraintVarOp.create(
operands=[index_operand], result_types=[i32], attributes={"attribute": i32}
)
with pytest.raises(DiagnosticException):
with pytest.raises(
DiagnosticException,
match="Operation does not verify: result at position 0 does not verify",
):
op.verify()

# Fail because of result
Expand All @@ -202,7 +208,10 @@ def test_constraint_var_fail_non_equal():
result_types=[IndexType()],
attributes={"attribute": i32},
)
with pytest.raises(DiagnosticException):
with pytest.raises(
DiagnosticException,
match="Operation does not verify: result at position 0 does not verify",
):
op2.verify()

# Fail because of attribute
Expand All @@ -211,7 +220,10 @@ def test_constraint_var_fail_non_equal():
result_types=[i32],
attributes={"attribute": IndexType()},
)
with pytest.raises(DiagnosticException):
with pytest.raises(
DiagnosticException,
match="Operation does not verify: attribute i32 expected from variable 'T', but got index",
):
op3.verify()


Expand All @@ -223,7 +235,94 @@ def test_constraint_var_fail_not_satisfy_constraint():
result_types=[TestType("foo")],
attributes={"attribute": TestType("foo")},
)
with pytest.raises(DiagnosticException):
with pytest.raises(
DiagnosticException,
match="Operation does not verify: operand at position 0 does not verify",
):
op.verify()


@irdl_op_definition
class GenericConstraintVarOp(IRDLOperation):
name = "test.constraint_var_op"

T: ClassVar[VarConstraint[IntegerType | IndexType]] = VarConstraint(
"T", base(IntegerType) | base(IndexType)
)

operand = operand_def(T)
result = result_def(T)
attribute = attr_def(T)


def test_generic_constraint_var():
i32_operand = TestSSAValue(i32)
index_operand = TestSSAValue(IndexType())
op = GenericConstraintVarOp.create(
operands=[i32_operand], result_types=[i32], attributes={"attribute": i32}
)
op.verify()

op2 = GenericConstraintVarOp.create(
operands=[index_operand],
result_types=[IndexType()],
attributes={"attribute": IndexType()},
)
op2.verify()


def test_generic_constraint_var_fail_non_equal():
"""Check that all uses of a constraint variable are of the same attribute."""
i32_operand = TestSSAValue(i32)
index_operand = TestSSAValue(IndexType())

# Fail because of operand
op = GenericConstraintVarOp.create(
operands=[index_operand], result_types=[i32], attributes={"attribute": i32}
)
with pytest.raises(
DiagnosticException,
match="Operation does not verify: result at position 0 does not verify",
):
op.verify()

# Fail because of result
op2 = GenericConstraintVarOp.create(
operands=[i32_operand],
result_types=[IndexType()],
attributes={"attribute": i32},
)
with pytest.raises(
DiagnosticException,
match="Operation does not verify: result at position 0 does not verify",
):
op2.verify()

# Fail because of attribute
op3 = GenericConstraintVarOp.create(
operands=[i32_operand],
result_types=[i32],
attributes={"attribute": IndexType()},
)
with pytest.raises(
DiagnosticException,
match="Operation does not verify: attribute i32 expected from variable 'T', but got index",
):
op3.verify()


def test_generic_constraint_var_fail_not_satisfy_constraint():
"""Check that all uses of a constraint variable are satisfying the constraint."""
test_operand = TestSSAValue(TestType("foo"))
op = GenericConstraintVarOp.create(
operands=[test_operand],
result_types=[TestType("foo")],
attributes={"attribute": TestType("foo")},
)
with pytest.raises(
DiagnosticException,
match="Operation does not verify: operand at position 0 does not verify",
):
op.verify()


Expand Down
15 changes: 8 additions & 7 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Annotated, Any, cast
from typing import Any, ClassVar, cast

from xdsl.dialects.builtin import (
AffineMapAttr,
Expand All @@ -20,9 +20,10 @@
from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue
from xdsl.ir.affine import AffineExpr, AffineMap
from xdsl.irdl import (
AnyAttr,
AttrSizedOperandSegments,
ConstraintVar,
IRDLOperation,
VarConstraint,
attr_def,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -258,10 +259,10 @@ def verify_(self) -> None:
class Store(IRDLOperation):
name = "affine.store"

T = Annotated[Attribute, ConstraintVar("T")]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

value = operand_def(T)
memref = operand_def(MemRefType[T])
memref = operand_def(MemRefType[Attribute].constr(element_type=T))
indices = var_operand_def(IndexType)
map = opt_prop_def(AffineMapAttr)

Expand Down Expand Up @@ -291,9 +292,9 @@ def __init__(
class Load(IRDLOperation):
name = "affine.load"

T = Annotated[Attribute, ConstraintVar("T")]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

memref = operand_def(MemRefType[T])
memref = operand_def(MemRefType[Attribute].constr(element_type=T))
indices = var_operand_def(IndexType)

result = result_def(T)
Expand All @@ -305,7 +306,7 @@ def __init__(
memref: SSAValue,
indices: Sequence[SSAValue],
map: AffineMapAttr | None = None,
result_type: T | None = None,
result_type: Attribute | None = None,
):
if map is None:
# Create identity map for memrefs with at least one dimension or () -> ()
Expand Down
12 changes: 6 additions & 6 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import abc
from collections.abc import Mapping, Sequence
from typing import Annotated, Literal, TypeVar, cast, overload
from typing import ClassVar, Literal, TypeVar, cast, overload

from xdsl.dialects.builtin import (
AnyFloat,
Expand All @@ -26,8 +26,8 @@
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyOf,
ConstraintVar,
IRDLOperation,
VarConstraint,
base,
irdl_attr_definition,
irdl_op_definition,
Expand Down Expand Up @@ -175,7 +175,7 @@ def parse(cls: type[Constant], parser: Parser) -> Constant:
class SignlessIntegerBinaryOperation(IRDLOperation, abc.ABC):
"""A generic base class for arith's binary operations on signless integers."""

T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike)

lhs = operand_def(T)
rhs = operand_def(T)
Expand Down Expand Up @@ -216,7 +216,7 @@ def __hash__(self) -> int:
class FloatingPointLikeBinaryOperation(IRDLOperation, abc.ABC):
"""A generic base class for arith's binary operations on floats."""

T = Annotated[Attribute, ConstraintVar("T"), floatingPointLike]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", floatingPointLike)

lhs = operand_def(T)
rhs = operand_def(T)
Expand Down Expand Up @@ -290,7 +290,7 @@ class AddUIExtended(IRDLOperation):

traits = frozenset([Pure()])

T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike)

lhs = operand_def(T)
rhs = operand_def(T)
Expand Down Expand Up @@ -353,7 +353,7 @@ class Muli(SignlessIntegerBinaryOperation):
class MulExtendedBase(IRDLOperation):
"""Base class for extended multiplication operations."""

T = Annotated[Attribute, ConstraintVar("T"), signlessIntegerLike]
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", signlessIntegerLike)

lhs = operand_def(T)
rhs = operand_def(T)
Expand Down
Loading

0 comments on commit 187eb31

Please sign in to comment.