Skip to content

Commit

Permalink
core: new inference system for constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 15, 2024
1 parent b7ca27c commit 1fab125
Show file tree
Hide file tree
Showing 14 changed files with 311 additions and 315 deletions.
12 changes: 6 additions & 6 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,15 @@ def test_vector_rank_constraint_verify():
vector_type = VectorType(i32, [1, 2])
constraint = VectorRankConstraint(2)

constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())


def test_vector_rank_constraint_rank_mismatch():
vector_type = VectorType(i32, [1, 2])
constraint = VectorRankConstraint(3)

with pytest.raises(VerifyException) as e:
constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())
assert e.value.args[0] == "Expected vector rank to be 3, got 2."


Expand All @@ -180,23 +180,23 @@ def test_vector_rank_constraint_attr_mismatch():
constraint = VectorRankConstraint(3)

with pytest.raises(VerifyException) as e:
constraint.verify(memref_type)
constraint.verify(memref_type, ConstraintContext())
assert e.value.args[0] == "memref<1x2xi32> should be of type VectorType."


def test_vector_base_type_constraint_verify():
vector_type = VectorType(i32, [1, 2])
constraint = VectorBaseTypeConstraint(i32)

constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())


def test_vector_base_type_constraint_type_mismatch():
vector_type = VectorType(i32, [1, 2])
constraint = VectorBaseTypeConstraint(i64)

with pytest.raises(VerifyException) as e:
constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())
assert e.value.args[0] == "Expected vector type to be i64, got i32."


Expand All @@ -205,7 +205,7 @@ def test_vector_base_type_constraint_attr_mismatch():
constraint = VectorBaseTypeConstraint(i32)

with pytest.raises(VerifyException) as e:
constraint.verify(memref_type)
constraint.verify(memref_type, ConstraintContext())
assert e.value.args[0] == "memref<1x2xi32> should be of type VectorType."


Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/onnx/onnx_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ builtin.module {
%t0 = "test.op"(): () -> (f32)

// CHECK: operand at position 0 does not verify:
// CHECK: Unexpected attribute f32
// CHECK: Expected tensor or memref type, got f32
%res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {onnx_node_name = "/MaxPoolSingleOut"} : (f32) -> tensor<5x5x32x32xf32>
}

Expand All @@ -379,7 +379,7 @@ builtin.module {
%t0= "test.op"(): () -> (tensor<5x5x32x32xf32>)

// CHECK: result at position 0 does not verify:
// CHECK: Unexpected attribute tensor<5x5x32x32xi32>
// CHECK: attribute f32 expected from variable 'T', but got i32
%res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut"} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xi32>

}
Expand Down Expand Up @@ -565,7 +565,7 @@ builtin.module {

builtin.module {
%t0 = "test.op"() : () -> (tensor<3x4xf32>)

// CHECK: Operation does not verify: tensor input shape (3, 4) is not equal to tensor output shape (7, 3)
%res_sigmoid = "onnx.Sigmoid"(%t0) {onnx_node_name = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<7x3xf32>
}
8 changes: 2 additions & 6 deletions tests/irdl/test_attribute_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,7 @@ def test_union_constraint_fail():


class PositiveIntConstr(AttrConstraint):
def verify(
self,
attr: Attribute,
constraint_context: ConstraintContext | None = None,
) -> None:
def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if not isinstance(attr, IntData):
raise VerifyException(
f"Expected {IntData.name} attribute, but got {attr.name}."
Expand Down Expand Up @@ -602,7 +598,7 @@ def test_informative_constraint():
match="User-enlightening message.\nUnderlying verification failure: Expected attribute #none but got #builtin.int<1>",
):
constr.verify(IntAttr(1), ConstraintContext())
assert constr.get_resolved_variables() == set()
assert constr.can_infer(set())
assert constr.get_unique_base() == NoneAttr


Expand Down
10 changes: 6 additions & 4 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.utils.exceptions import ParseError, PyRDLOpDefinitionError
from xdsl.utils.exceptions import ParseError, PyRDLOpDefinitionError, VerifyException

################################################################################
# Utils for this test file #
Expand Down Expand Up @@ -1748,10 +1748,12 @@ class OneOperandOneResultNestedOp(IRDLOperation):
%1 = test.one_operand_one_result_nested %0 : i32"""
)
with pytest.raises(
ParseError,
match="Verification error while inferring operation type: ",
VerifyException,
match="i32 should be of base attribute test.param_one",
):
check_roundtrip(program, ctx)
parser = Parser(ctx, program)
while (op := parser.parse_optional_operation()) is not None:
op.verify()


def test_variadic_length_inference():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pyrdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class LessThan(AttrConstraint):
def verify(
self,
attr: Attribute,
constraint_context: ConstraintContext | None = None,
constraint_context: ConstraintContext,
) -> None:
if not isinstance(attr, IntData):
raise VerifyException(f"{attr} should be of base attribute {IntData.name}")
Expand Down
113 changes: 42 additions & 71 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any
from collections.abc import Set
from dataclasses import dataclass
from typing import Any, ClassVar

from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
Expand All @@ -19,7 +21,9 @@
AnyOf,
AttrSizedOperandSegments,
ConstraintContext,
GenericAttrConstraint,
IRDLOperation,
ResolveType,
VarConstraint,
irdl_op_definition,
operand_def,
Expand All @@ -32,7 +36,10 @@
from xdsl.utils.hints import isa


class TensorMemrefInferenceConstraint(VarConstraint[Attribute]):
@dataclass(frozen=True)
class TensorFromMemrefConstraint(
GenericAttrConstraint[TensorType[Attribute] | UnrankedTensorType[Attribute]]
):
"""
Constraint to infer tensor shapes from memref shapes, inferring ranked tensor from ranked memref
(and unranked from unranked, respectively).
Expand All @@ -41,42 +48,29 @@ class TensorMemrefInferenceConstraint(VarConstraint[Attribute]):
and checks for matching element type, shape (ranked only), as well as verifying sub constraints.
"""

def infer(self, constraint_context: ConstraintContext) -> Attribute:
if self.name in constraint_context.variables:
m_type = constraint_context.get_variable(self.name)
if isa(m_type, MemRefType[Attribute]):
return TensorType(m_type.get_element_type(), m_type.get_shape())
if isa(m_type, UnrankedMemrefType[Attribute]):
return UnrankedTensorType(m_type.element_type)
raise ValueError(f"Unexpected {self.name} - cannot infer attribute")
memref_var_constr: VarConstraint[
MemRefType[Attribute] | UnrankedMemrefType[Attribute]
]

def can_infer(self, variables: Set[str]) -> bool:
return self.memref_var_constr.can_infer(variables)

def infer(
self, variables: dict[str, ResolveType]
) -> TensorType[Attribute] | UnrankedTensorType[Attribute]:
memref_type = self.memref_var_constr.infer(variables)
if isinstance(memref_type, MemRefType):
return TensorType(memref_type.element_type, memref_type.shape)
return UnrankedTensorType(memref_type.element_type)

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if self.name in constraint_context.variables:
seen = constraint_context.get_variable(self.name)
if not (
isinstance(attr, ContainerType)
and isinstance(seen, ContainerType)
and attr.get_element_type() == seen.get_element_type()
):
raise VerifyException(
f"Unexpected {self.name} - cannot verify element type of attribute {attr}"
)
if (
isinstance(attr, ShapedType) != isinstance(seen, ShapedType)
or isinstance(attr, ShapedType)
and isinstance(seen, ShapedType)
and attr.get_shape() != seen.get_shape()
):
raise VerifyException(
f"Unexpected {self.name} - cannot verify shape of attribute {attr}"
)
elif isinstance(attr, ContainerType):
self.constraint.verify(attr, constraint_context)
constraint_context.set_variable(self.name, attr)
else:
raise VerifyException(
f"Unexpected {self.name} - attribute must be ContainerType"
)
if isa(attr, TensorType[Attribute]):
memref_type = MemRefType(attr.element_type, attr.shape)
return self.memref_var_constr.verify(memref_type, constraint_context)
if isa(attr, UnrankedTensorType[Attribute]):
memref_type = UnrankedMemrefType.from_type(attr.element_type)

raise VerifyException(f"Expected TensorType or UnrankedTensorType, got {attr}")


@irdl_op_definition
Expand Down Expand Up @@ -108,16 +102,10 @@ def __init__(
class ToTensorOp(IRDLOperation):
name = "bufferization.to_tensor"

memref = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr])
)
)
tensor = result_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr])
)
)
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)

memref = operand_def(T)
tensor = result_def(TensorFromMemrefConstraint(T))
writable = opt_prop_def(UnitAttr)
restrict = opt_prop_def(UnitAttr)

Expand Down Expand Up @@ -153,16 +141,10 @@ def __init__(
class ToMemrefOp(IRDLOperation):
name = "bufferization.to_memref"

tensor = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr])
)
)
memref = result_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr])
)
)
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)

tensor = operand_def(TensorFromMemrefConstraint(T))
memref = result_def(T)
read_only = opt_prop_def(UnitAttr)

assembly_format = "$tensor (`read_only` $read_only^)? `:` attr-dict type($memref)"
Expand All @@ -172,21 +154,10 @@ class ToMemrefOp(IRDLOperation):
class MaterializeInDestination(IRDLOperation):
name = "bufferization.materialize_in_destination"

source = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
dest = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
result = result_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
T: ClassVar = VarConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)
source = operand_def(T)
dest = operand_def(T)
result = result_def(T)
restrict = opt_prop_def(UnitAttr)
writable = opt_prop_def(UnitAttr)

Expand Down
54 changes: 44 additions & 10 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass
from enum import Enum
from math import prod
Expand Down Expand Up @@ -52,6 +52,7 @@
MessageConstraint,
ParamAttrConstraint,
ParameterDef,
ResolveType,
attr_constr_coercion,
base,
irdl_attr_definition,
Expand Down Expand Up @@ -83,7 +84,7 @@
"""


class ShapedType(ABC):
class ShapedType(Attribute, ABC):
@abstractmethod
def get_num_dims(self) -> int: ...

Expand All @@ -101,14 +102,6 @@ def strides_for_shape(shape: Sequence[int], factor: int = 1) -> tuple[int, ...]:
return tuple(accumulate(reversed(shape), operator.mul, initial=factor))[-2::-1]


class AnyShapedType(AttrConstraint):
def verify(
self, attr: Attribute, constraint_context: ConstraintContext | None = None
) -> None:
if not isinstance(attr, ShapedType):
raise Exception(f"expected type ShapedType but got {attr}")


_ContainerElementTypeT = TypeVar(
"_ContainerElementTypeT", bound=Attribute | None, covariant=True
)
Expand Down Expand Up @@ -1637,6 +1630,47 @@ def constr(
AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType)


@dataclass(frozen=True, init=False)
class TensorOrMemrefOf(
GenericAttrConstraint[TensorType[AttributeCovT] | MemRefType[AttributeCovT]]
):
"""A type constraint that can be nested once in a vector or a tensor."""

elem_constr: GenericAttrConstraint[AttributeCovT]

def __init__(
self,
elem_constr: AttributeCovT
| type[AttributeCovT]
| GenericAttrConstraint[AttributeCovT],
) -> None:
object.__setattr__(self, "elem_constr", attr_constr_coercion(elem_constr))

@staticmethod
def _wrap_resolver(
resolver: Callable[[AttributeCovT], ResolveType],
) -> Callable[[TensorType[AttributeCovT] | MemRefType[AttributeCovT]], ResolveType]:
return lambda attr: resolver(attr.element_type)

def get_resolvers(
self,
) -> dict[
str,
Callable[[TensorType[AttributeCovT] | MemRefType[AttributeCovT]], ResolveType],
]:
return {
v: self._wrap_resolver(r)
for v, r in self.elem_constr.get_resolvers().items()
}

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if isinstance(attr, VectorType) or isinstance(attr, TensorType):
attr = cast(VectorType[Attribute] | TensorType[Attribute], attr)
self.elem_constr.verify(attr.element_type, constraint_context)
else:
raise VerifyException(f"Expected tensor or memref type, got {attr}")


@irdl_attr_definition
class UnrankedMemrefType(
Generic[_UnrankedMemrefTypeElems],
Expand Down
Loading

0 comments on commit 1fab125

Please sign in to comment.