From 1fab125846e50490b0c94205d5f7677db18a6f6e Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 15 Nov 2024 14:33:48 +0000 Subject: [PATCH] core: new inference system for constraints --- tests/dialects/test_builtin.py | 12 +- .../filecheck/dialects/onnx/onnx_invalid.mlir | 6 +- tests/irdl/test_attribute_definition.py | 8 +- .../irdl/test_declarative_assembly_format.py | 10 +- tests/test_pyrdl.py | 2 +- xdsl/dialects/bufferization.py | 113 ++++----- xdsl/dialects/builtin.py | 54 ++++- xdsl/dialects/linalg.py | 9 +- xdsl/dialects/onnx.py | 11 +- xdsl/dialects/stencil.py | 5 +- xdsl/irdl/constraints.py | 223 ++++++++++-------- xdsl/irdl/declarative_assembly_format.py | 51 ++-- .../declarative_assembly_format_parser.py | 61 +++-- xdsl/irdl/operations.py | 61 ++--- 14 files changed, 311 insertions(+), 315 deletions(-) diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index b49c52d746..3b461b6a61 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -163,7 +163,7 @@ def test_vector_rank_constraint_verify(): vector_type = VectorType(i32, [1, 2]) constraint = VectorRankConstraint(2) - constraint.verify(vector_type) + constraint.verify(vector_type, ConstraintContext()) def test_vector_rank_constraint_rank_mismatch(): @@ -171,7 +171,7 @@ def test_vector_rank_constraint_rank_mismatch(): constraint = VectorRankConstraint(3) with pytest.raises(VerifyException) as e: - constraint.verify(vector_type) + constraint.verify(vector_type, ConstraintContext()) assert e.value.args[0] == "Expected vector rank to be 3, got 2." @@ -180,7 +180,7 @@ def test_vector_rank_constraint_attr_mismatch(): constraint = VectorRankConstraint(3) with pytest.raises(VerifyException) as e: - constraint.verify(memref_type) + constraint.verify(memref_type, ConstraintContext()) assert e.value.args[0] == "memref<1x2xi32> should be of type VectorType." @@ -188,7 +188,7 @@ def test_vector_base_type_constraint_verify(): vector_type = VectorType(i32, [1, 2]) constraint = VectorBaseTypeConstraint(i32) - constraint.verify(vector_type) + constraint.verify(vector_type, ConstraintContext()) def test_vector_base_type_constraint_type_mismatch(): @@ -196,7 +196,7 @@ def test_vector_base_type_constraint_type_mismatch(): constraint = VectorBaseTypeConstraint(i64) with pytest.raises(VerifyException) as e: - constraint.verify(vector_type) + constraint.verify(vector_type, ConstraintContext()) assert e.value.args[0] == "Expected vector type to be i64, got i32." @@ -205,7 +205,7 @@ def test_vector_base_type_constraint_attr_mismatch(): constraint = VectorBaseTypeConstraint(i32) with pytest.raises(VerifyException) as e: - constraint.verify(memref_type) + constraint.verify(memref_type, ConstraintContext()) assert e.value.args[0] == "memref<1x2xi32> should be of type VectorType." diff --git a/tests/filecheck/dialects/onnx/onnx_invalid.mlir b/tests/filecheck/dialects/onnx/onnx_invalid.mlir index e05725c26d..b15ba25761 100644 --- a/tests/filecheck/dialects/onnx/onnx_invalid.mlir +++ b/tests/filecheck/dialects/onnx/onnx_invalid.mlir @@ -369,7 +369,7 @@ builtin.module { %t0 = "test.op"(): () -> (f32) // CHECK: operand at position 0 does not verify: - // CHECK: Unexpected attribute f32 + // CHECK: Expected tensor or memref type, got f32 %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {onnx_node_name = "/MaxPoolSingleOut"} : (f32) -> tensor<5x5x32x32xf32> } @@ -379,7 +379,7 @@ builtin.module { %t0= "test.op"(): () -> (tensor<5x5x32x32xf32>) // CHECK: result at position 0 does not verify: - // CHECK: Unexpected attribute tensor<5x5x32x32xi32> + // CHECK: attribute f32 expected from variable 'T', but got i32 %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t0) {"onnx_node_name" = "/MaxPoolSingleOut"} : (tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xi32> } @@ -565,7 +565,7 @@ builtin.module { builtin.module { %t0 = "test.op"() : () -> (tensor<3x4xf32>) - + // CHECK: Operation does not verify: tensor input shape (3, 4) is not equal to tensor output shape (7, 3) %res_sigmoid = "onnx.Sigmoid"(%t0) {onnx_node_name = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<7x3xf32> } diff --git a/tests/irdl/test_attribute_definition.py b/tests/irdl/test_attribute_definition.py index 518064c4bc..f13f9d9b88 100644 --- a/tests/irdl/test_attribute_definition.py +++ b/tests/irdl/test_attribute_definition.py @@ -381,11 +381,7 @@ def test_union_constraint_fail(): class PositiveIntConstr(AttrConstraint): - def verify( - self, - attr: Attribute, - constraint_context: ConstraintContext | None = None, - ) -> None: + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: if not isinstance(attr, IntData): raise VerifyException( f"Expected {IntData.name} attribute, but got {attr.name}." @@ -602,7 +598,7 @@ def test_informative_constraint(): match="User-enlightening message.\nUnderlying verification failure: Expected attribute #none but got #builtin.int<1>", ): constr.verify(IntAttr(1), ConstraintContext()) - assert constr.get_resolved_variables() == set() + assert constr.can_infer(set()) assert constr.get_unique_base() == NoneAttr diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 7371f1601c..9faaf31eae 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -62,7 +62,7 @@ ) from xdsl.parser import Parser from xdsl.printer import Printer -from xdsl.utils.exceptions import ParseError, PyRDLOpDefinitionError +from xdsl.utils.exceptions import ParseError, PyRDLOpDefinitionError, VerifyException ################################################################################ # Utils for this test file # @@ -1748,10 +1748,12 @@ class OneOperandOneResultNestedOp(IRDLOperation): %1 = test.one_operand_one_result_nested %0 : i32""" ) with pytest.raises( - ParseError, - match="Verification error while inferring operation type: ", + VerifyException, + match="i32 should be of base attribute test.param_one", ): - check_roundtrip(program, ctx) + parser = Parser(ctx, program) + while (op := parser.parse_optional_operation()) is not None: + op.verify() def test_variadic_length_inference(): diff --git a/tests/test_pyrdl.py b/tests/test_pyrdl.py index aecdc24e0f..e448b9d141 100644 --- a/tests/test_pyrdl.py +++ b/tests/test_pyrdl.py @@ -135,7 +135,7 @@ class LessThan(AttrConstraint): def verify( self, attr: Attribute, - constraint_context: ConstraintContext | None = None, + constraint_context: ConstraintContext, ) -> None: if not isinstance(attr, IntData): raise VerifyException(f"{attr} should be of base attribute {IntData.name}") diff --git a/xdsl/dialects/bufferization.py b/xdsl/dialects/bufferization.py index 7b4f1dc076..a2af19b12f 100644 --- a/xdsl/dialects/bufferization.py +++ b/xdsl/dialects/bufferization.py @@ -1,4 +1,6 @@ -from typing import Any +from collections.abc import Set +from dataclasses import dataclass +from typing import Any, ClassVar from xdsl.dialects.builtin import ( AnyMemRefTypeConstr, @@ -19,7 +21,9 @@ AnyOf, AttrSizedOperandSegments, ConstraintContext, + GenericAttrConstraint, IRDLOperation, + ResolveType, VarConstraint, irdl_op_definition, operand_def, @@ -32,7 +36,10 @@ from xdsl.utils.hints import isa -class TensorMemrefInferenceConstraint(VarConstraint[Attribute]): +@dataclass(frozen=True) +class TensorFromMemrefConstraint( + GenericAttrConstraint[TensorType[Attribute] | UnrankedTensorType[Attribute]] +): """ Constraint to infer tensor shapes from memref shapes, inferring ranked tensor from ranked memref (and unranked from unranked, respectively). @@ -41,42 +48,29 @@ class TensorMemrefInferenceConstraint(VarConstraint[Attribute]): and checks for matching element type, shape (ranked only), as well as verifying sub constraints. """ - def infer(self, constraint_context: ConstraintContext) -> Attribute: - if self.name in constraint_context.variables: - m_type = constraint_context.get_variable(self.name) - if isa(m_type, MemRefType[Attribute]): - return TensorType(m_type.get_element_type(), m_type.get_shape()) - if isa(m_type, UnrankedMemrefType[Attribute]): - return UnrankedTensorType(m_type.element_type) - raise ValueError(f"Unexpected {self.name} - cannot infer attribute") + memref_var_constr: VarConstraint[ + MemRefType[Attribute] | UnrankedMemrefType[Attribute] + ] + + def can_infer(self, variables: Set[str]) -> bool: + return self.memref_var_constr.can_infer(variables) + + def infer( + self, variables: dict[str, ResolveType] + ) -> TensorType[Attribute] | UnrankedTensorType[Attribute]: + memref_type = self.memref_var_constr.infer(variables) + if isinstance(memref_type, MemRefType): + return TensorType(memref_type.element_type, memref_type.shape) + return UnrankedTensorType(memref_type.element_type) def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: - if self.name in constraint_context.variables: - seen = constraint_context.get_variable(self.name) - if not ( - isinstance(attr, ContainerType) - and isinstance(seen, ContainerType) - and attr.get_element_type() == seen.get_element_type() - ): - raise VerifyException( - f"Unexpected {self.name} - cannot verify element type of attribute {attr}" - ) - if ( - isinstance(attr, ShapedType) != isinstance(seen, ShapedType) - or isinstance(attr, ShapedType) - and isinstance(seen, ShapedType) - and attr.get_shape() != seen.get_shape() - ): - raise VerifyException( - f"Unexpected {self.name} - cannot verify shape of attribute {attr}" - ) - elif isinstance(attr, ContainerType): - self.constraint.verify(attr, constraint_context) - constraint_context.set_variable(self.name, attr) - else: - raise VerifyException( - f"Unexpected {self.name} - attribute must be ContainerType" - ) + if isa(attr, TensorType[Attribute]): + memref_type = MemRefType(attr.element_type, attr.shape) + return self.memref_var_constr.verify(memref_type, constraint_context) + if isa(attr, UnrankedTensorType[Attribute]): + memref_type = UnrankedMemrefType.from_type(attr.element_type) + + raise VerifyException(f"Expected TensorType or UnrankedTensorType, got {attr}") @irdl_op_definition @@ -108,16 +102,10 @@ def __init__( class ToTensorOp(IRDLOperation): name = "bufferization.to_tensor" - memref = operand_def( - TensorMemrefInferenceConstraint( - "T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr]) - ) - ) - tensor = result_def( - TensorMemrefInferenceConstraint( - "T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr]) - ) - ) + T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr) + + memref = operand_def(T) + tensor = result_def(TensorFromMemrefConstraint(T)) writable = opt_prop_def(UnitAttr) restrict = opt_prop_def(UnitAttr) @@ -153,16 +141,10 @@ def __init__( class ToMemrefOp(IRDLOperation): name = "bufferization.to_memref" - tensor = operand_def( - TensorMemrefInferenceConstraint( - "T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr]) - ) - ) - memref = result_def( - TensorMemrefInferenceConstraint( - "T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr]) - ) - ) + T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr) + + tensor = operand_def(TensorFromMemrefConstraint(T)) + memref = result_def(T) read_only = opt_prop_def(UnitAttr) assembly_format = "$tensor (`read_only` $read_only^)? `:` attr-dict type($memref)" @@ -172,21 +154,10 @@ class ToMemrefOp(IRDLOperation): class MaterializeInDestination(IRDLOperation): name = "bufferization.materialize_in_destination" - source = operand_def( - TensorMemrefInferenceConstraint( - "T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr - ) - ) - dest = operand_def( - TensorMemrefInferenceConstraint( - "T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr - ) - ) - result = result_def( - TensorMemrefInferenceConstraint( - "T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr - ) - ) + T: ClassVar = VarConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr) + source = operand_def(T) + dest = operand_def(T) + result = result_def(T) restrict = opt_prop_def(UnitAttr) writable = opt_prop_def(UnitAttr) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 220da9dad7..9bdf8b6909 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -2,7 +2,7 @@ import math from abc import ABC, abstractmethod -from collections.abc import Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass from enum import Enum from math import prod @@ -52,6 +52,7 @@ MessageConstraint, ParamAttrConstraint, ParameterDef, + ResolveType, attr_constr_coercion, base, irdl_attr_definition, @@ -83,7 +84,7 @@ """ -class ShapedType(ABC): +class ShapedType(Attribute, ABC): @abstractmethod def get_num_dims(self) -> int: ... @@ -101,14 +102,6 @@ def strides_for_shape(shape: Sequence[int], factor: int = 1) -> tuple[int, ...]: return tuple(accumulate(reversed(shape), operator.mul, initial=factor))[-2::-1] -class AnyShapedType(AttrConstraint): - def verify( - self, attr: Attribute, constraint_context: ConstraintContext | None = None - ) -> None: - if not isinstance(attr, ShapedType): - raise Exception(f"expected type ShapedType but got {attr}") - - _ContainerElementTypeT = TypeVar( "_ContainerElementTypeT", bound=Attribute | None, covariant=True ) @@ -1637,6 +1630,47 @@ def constr( AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType) +@dataclass(frozen=True, init=False) +class TensorOrMemrefOf( + GenericAttrConstraint[TensorType[AttributeCovT] | MemRefType[AttributeCovT]] +): + """A type constraint that can be nested once in a vector or a tensor.""" + + elem_constr: GenericAttrConstraint[AttributeCovT] + + def __init__( + self, + elem_constr: AttributeCovT + | type[AttributeCovT] + | GenericAttrConstraint[AttributeCovT], + ) -> None: + object.__setattr__(self, "elem_constr", attr_constr_coercion(elem_constr)) + + @staticmethod + def _wrap_resolver( + resolver: Callable[[AttributeCovT], ResolveType], + ) -> Callable[[TensorType[AttributeCovT] | MemRefType[AttributeCovT]], ResolveType]: + return lambda attr: resolver(attr.element_type) + + def get_resolvers( + self, + ) -> dict[ + str, + Callable[[TensorType[AttributeCovT] | MemRefType[AttributeCovT]], ResolveType], + ]: + return { + v: self._wrap_resolver(r) + for v, r in self.elem_constr.get_resolvers().items() + } + + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + if isinstance(attr, VectorType) or isinstance(attr, TensorType): + attr = cast(VectorType[Attribute] | TensorType[Attribute], attr) + self.elem_constr.verify(attr.element_type, constraint_context) + else: + raise VerifyException(f"Expected tensor or memref type, got {attr}") + + @irdl_attr_definition class UnrankedMemrefType( Generic[_UnrankedMemrefTypeElems], diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 466e6a3b41..854ed50c86 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -13,7 +13,6 @@ AffineMapAttr, AnyFloat, AnyMemRefType, - AnyShapedType, AnyTensorType, ArrayAttr, DenseArrayBase, @@ -100,7 +99,7 @@ class Generic(IRDLOperation): name = "linalg.generic" inputs = var_operand_def() - outputs = var_operand_def(AnyShapedType()) + outputs = var_operand_def(base(ShapedType)) res = var_result_def(AnyTensorType) @@ -396,7 +395,7 @@ class NamedOpBase(IRDLOperation, ABC): """ inputs = var_operand_def() - outputs = var_operand_def(AnyShapedType()) + outputs = var_operand_def(base(ShapedType)) res = var_result_def(AnyTensorType) @@ -951,7 +950,7 @@ class PoolingOpsBase(IRDLOperation, ABC): """Base class for linalg pooling operations.""" inputs = var_operand_def() - outputs = var_operand_def(AnyShapedType()) + outputs = var_operand_def(base(ShapedType)) res = var_result_def(AnyTensorType) @@ -1002,7 +1001,7 @@ class ConvOpsBase(IRDLOperation, ABC): """Base class for linalg convolution operations.""" inputs = var_operand_def() - outputs = var_operand_def(AnyShapedType()) + outputs = var_operand_def(base(ShapedType)) res = var_result_def(AnyTensorType) diff --git a/xdsl/dialects/onnx.py b/xdsl/dialects/onnx.py index 2123f72ba6..077dae99b7 100644 --- a/xdsl/dialects/onnx.py +++ b/xdsl/dialects/onnx.py @@ -2,13 +2,14 @@ import math from abc import ABC -from typing import Annotated, cast +from typing import Annotated, ClassVar, cast from typing_extensions import Self from xdsl.dialects.builtin import ( Any, AnyFloat, + AnyFloatConstr, AnyIntegerAttr, AnyTensorType, ArrayAttr, @@ -22,6 +23,7 @@ SSAValue, StringAttr, SymbolRefAttr, + TensorOrMemrefOf, TensorType, ) from xdsl.ir import ( @@ -31,6 +33,7 @@ from xdsl.irdl import ( ConstraintVar, IRDLOperation, + VarConstraint, attr_def, base, irdl_op_definition, @@ -703,9 +706,9 @@ class MaxPoolSingleOut(IRDLOperation): name = "onnx.MaxPoolSingleOut" - T = Annotated[AnyFloat | IntegerType, ConstraintVar("T")] - data = operand_def(base(TensorType[T]) | base(MemRefType[T])) - output = result_def(base(TensorType[T]) | base(MemRefType[T])) + T: ClassVar = VarConstraint("T", AnyFloatConstr | base(IntegerType)) + data = operand_def(TensorOrMemrefOf(T)) + output = result_def(TensorOrMemrefOf(T)) auto_pad = attr_def(StringAttr) ceil_mode = attr_def(AnyIntegerAttr) diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index fd157496cc..44d6787a75 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -1203,10 +1203,7 @@ def matches(attr: TensorType[Attribute], other: Attribute) -> bool: and attr.get_element_type() == other.get_element_type() ) - def verify( - self, attr: Attribute, constraint_context: ConstraintContext | None = None - ) -> None: - constraint_context = constraint_context or ConstraintContext() + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: if self.name in constraint_context.variables: if isa(attr, TensorType[Attribute]) and TensorIgnoreSizeConstraint.matches( attr, constraint_context.get_variable(self.name) diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index 15956c3c78..7d01a93d49 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -1,10 +1,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence, Set from dataclasses import dataclass, field from inspect import isclass -from typing import Generic, TypeAlias, TypeVar +from typing import Generic, TypeAlias, TypeVar, cast from typing_extensions import assert_never @@ -12,6 +12,18 @@ from xdsl.utils.exceptions import VerifyException from xdsl.utils.runtime_final import is_runtime_final +_AttributeCovT = TypeVar("_AttributeCovT", bound=Attribute, covariant=True) + +ResolveType: TypeAlias = Attribute | Sequence[Attribute] + + +def print_resolve_type(r: ResolveType | None) -> str: + if isinstance(r, Attribute): + return str(r) + elif isinstance(r, Sequence): + return str(tuple(str(x) for x in r)) + return "None" + @dataclass class ConstraintContext: @@ -53,9 +65,6 @@ def update(self, other: ConstraintContext): self._range_variables.update(other._range_variables) -_AttributeCovT = TypeVar("_AttributeCovT", bound=Attribute, covariant=True) - - @dataclass(frozen=True) class GenericAttrConstraint(Generic[AttributeCovT], ABC): """Constrain an attribute to a certain value.""" @@ -72,14 +81,13 @@ def verify( """ ... - def get_resolved_variables(self) -> set[str]: + def get_resolvers(self) -> dict[str, Callable[[AttributeCovT], ResolveType]]: """ - Get the set of type variables that are always resolved when verifying - the constraint. + Get a dictionary of variables that can be solved from this attribute. """ - return set() + return dict() - def can_infer(self, constraint_names: set[str]) -> bool: + def can_infer(self, variables: Set[str]) -> bool: """ Check if there is enough information to infer the attribute given the constraint variables that are already set. @@ -87,9 +95,9 @@ def can_infer(self, constraint_names: set[str]) -> bool: # By default, we cannot infer anything. return False - def infer(self, constraint_context: ConstraintContext) -> Attribute: + def infer(self, variables: dict[str, ResolveType]) -> AttributeCovT: """ - Infer the attribute given the constraint variables that are already set. + Infer the attribute given the the values for all variables. Raises an exception if the attribute cannot be inferred. If `can_infer` returns `True` with the given constraint variables, this method should @@ -143,17 +151,15 @@ def verify( self.constraint.verify(attr, constraint_context) constraint_context.set_variable(self.name, attr) - def get_resolved_variables(self) -> set[str]: - return {self.name, *self.constraint.get_resolved_variables()} + def get_resolvers(self) -> dict[str, Callable[[AttributeCovT], ResolveType]]: + return {self.name: lambda attr: attr} - def can_infer(self, constraint_names: set[str]) -> bool: - return self.name in constraint_names + def infer(self, variables: dict[str, ResolveType]) -> AttributeCovT: + v = variables[self.name] + return cast(AttributeCovT, v) - def infer(self, constraint_context: ConstraintContext) -> Attribute: - constraint_context = constraint_context or ConstraintContext() - if self.name not in constraint_context.variables: - raise ValueError(f"Cannot infer attribute from constraint {self}") - return constraint_context.get_variable(self.name) + def can_infer(self, variables: Set[str]) -> bool: + return self.name in variables def get_unique_base(self) -> type[Attribute] | None: return self.constraint.get_unique_base() @@ -189,10 +195,10 @@ def verify( if attr != self.attr: raise VerifyException(f"Expected attribute {self.attr} but got {attr}") - def can_infer(self, constraint_names: set[str]) -> bool: + def can_infer(self, variables: Set[str]) -> bool: return True - def infer(self, constraint_context: ConstraintContext) -> Attribute: + def infer(self, variables: dict[str, ResolveType]) -> AttributeCovT: return self.attr def get_unique_base(self) -> type[Attribute] | None: @@ -296,12 +302,10 @@ def __or__( ) -> AnyOf[AttributeCovT | _AttributeCovT]: return AnyOf((*self.attr_constrs, value)) - def get_resolved_variables(self) -> set[str]: - if len(self.attr_constrs) == 0: - return set() - return set[str].intersection( - *(constr.get_resolved_variables() for constr in self.attr_constrs) - ) + def get_resolvers(self) -> dict[str, Callable[[AttributeCovT], ResolveType]]: + if len(self.attr_constrs) == 1: + return self.attr_constrs[0].get_resolvers() + return dict() def get_unique_base(self) -> type[Attribute] | None: bases = [constr.get_unique_base() for constr in self.attr_constrs] @@ -339,21 +343,19 @@ def verify( exc_msg += "\n".join([str(e) for e in exc_bucket]) raise VerifyException(exc_msg) - def get_resolved_variables(self) -> set[str]: - if len(self.attr_constrs) == 0: - return set() - return set[str].union( - *[constr.get_resolved_variables() for constr in self.attr_constrs] - ) + def get_resolvers(self) -> dict[str, Callable[[AttributeCovT], ResolveType]]: + d: dict[str, Callable[[AttributeCovT], ResolveType]] = dict() + for constr in self.attr_constrs: + d |= constr.get_resolvers() + return d - def can_infer(self, constraint_names: set[str]) -> bool: - return any(constr.can_infer(constraint_names) for constr in self.attr_constrs) + def can_infer(self, variables: Set[str]) -> bool: + return any(constr.can_infer(variables) for constr in self.attr_constrs) - def infer(self, constraint_context: ConstraintContext | None = None) -> Attribute: - constraint_context = constraint_context or ConstraintContext() + def infer(self, variables: dict[str, ResolveType]) -> AttributeCovT: for constr in self.attr_constrs: - if constr.can_infer(set(constraint_context.variables)): - return constr.infer(constraint_context) + if constr.can_infer(variables.keys()): + return constr.infer(variables) raise ValueError("Cannot infer attribute from constraint") def get_unique_base(self) -> type[Attribute] | None: @@ -388,15 +390,13 @@ class ParamAttrConstraint( base_attr: type[ParametrizedAttributeCovT] """The base attribute type.""" - param_constrs: tuple[GenericAttrConstraint[Attribute], ...] + param_constrs: tuple[AttrConstraint, ...] """The attribute parameter constraints""" def __init__( self, base_attr: type[ParametrizedAttributeCovT], - param_constrs: Sequence[ - (Attribute | type[Attribute] | GenericAttrConstraint[Attribute] | None) - ], + param_constrs: Sequence[(Attribute | type[Attribute] | AttrConstraint | None)], ): constrs = tuple( attr_constr_coercion(constr) if constr is not None else AnyAttr() @@ -422,14 +422,22 @@ def verify( for idx, param_constr in enumerate(self.param_constrs): param_constr.verify(attr.parameters[idx], constraint_context) - def get_resolved_variables(self) -> set[str]: - if not self.param_constrs: - return set() - return { - var - for constr in self.param_constrs - for var in constr.get_resolved_variables() - } + @staticmethod + def _wrap_resolver( + i: int, resolver: Callable[[Attribute], ResolveType] + ) -> Callable[[ParametrizedAttributeCovT], ResolveType]: + return lambda a: resolver(a.parameters[i]) + + def get_resolvers( + self, + ) -> dict[str, Callable[[ParametrizedAttributeCovT], ResolveType]]: + resolvers: dict[str, Callable[[ParametrizedAttributeCovT], ResolveType]] = ( + dict() + ) + for i, param_constr in enumerate(self.param_constrs): + for v, r in param_constr.get_resolvers().items(): + resolvers[v] = self._wrap_resolver(i, r) + return resolvers def get_unique_base(self) -> type[Attribute] | None: if is_runtime_final(self.base_attr): @@ -470,23 +478,22 @@ def verify( *e.args[1:], ) - def get_resolved_variables(self) -> set[str]: - return self.constr.get_resolved_variables() + def get_resolvers(self) -> dict[str, Callable[[AttributeCovT], ResolveType]]: + return self.constr.get_resolvers() def get_unique_base(self) -> type[Attribute] | None: return self.constr.get_unique_base() - def can_infer(self, constraint_names: set[str]) -> bool: - return self.constr.can_infer(constraint_names) + def can_infer(self, variables: Set[str]) -> bool: + return self.constr.can_infer(variables) - def infer(self, constraint_context: ConstraintContext) -> Attribute: - return self.constr.infer(constraint_context) + def infer(self, variables: dict[str, ResolveType]) -> AttributeCovT: + return self.constr.infer(variables) +@dataclass(frozen=True) class GenericRangeConstraint(Generic[AttributeCovT], ABC): - """ - Constrain a range of attributes to a certain value. - """ + """Constrain an range of attributes to certain values.""" @abstractmethod def verify( @@ -500,32 +507,37 @@ def verify( """ ... - def get_resolved_variables(self) -> set[str]: + def get_resolvers( + self, + ) -> dict[str, Callable[[Sequence[AttributeCovT]], ResolveType]]: """ - Get the set of type variables that are always resolved when verifying - the constraint. + Get a dictionary of variables that can be solved from this attribute. """ - return set() + return dict() - def can_infer(self, constraint_names: set[str]) -> bool: + def can_infer(self, variables: Set[str]) -> bool: """ - Check if there is enough information to infer the range given the + Check if there is enough information to infer the attribute given the constraint variables that are already set. """ # By default, we cannot infer anything. return False def infer( - self, length: int, constraint_context: ConstraintContext - ) -> Sequence[Attribute]: + self, length: int, variables: dict[str, ResolveType] + ) -> Sequence[AttributeCovT]: """ - Infer the range given the constraint variables that are already set. + Infer the attribute given the the values for all variables. - Raises an exception if the range cannot be inferred. If `can_infer` + Raises an exception if the attribute cannot be inferred. If `can_infer` returns `True` with the given constraint variables, this method should not raise an exception. """ - raise ValueError("Cannot infer range from constraint") + raise ValueError("Cannot infer attribute from constraint") + + def get_unique_base(self) -> type[Attribute] | None: + """Get the unique base type that can satisfy the constraint, if any.""" + return None RangeConstraint: TypeAlias = GenericRangeConstraint[Attribute] @@ -547,9 +559,8 @@ class RangeVarConstraint(GenericRangeConstraint[AttributeCovT]): def verify( self, attrs: Sequence[Attribute], - constraint_context: ConstraintContext | None = None, + constraint_context: ConstraintContext, ) -> None: - constraint_context = constraint_context or ConstraintContext() if self.name in constraint_context.range_variables: if tuple(attrs) != constraint_context.get_range_variable(self.name): raise VerifyException( @@ -560,22 +571,24 @@ def verify( self.constraint.verify(attrs, constraint_context) constraint_context.set_range_variable(self.name, tuple(attrs)) - def get_resolved_variables(self) -> set[str]: - return {self.name, *self.constraint.get_resolved_variables()} + def get_resolvers( + self, + ) -> dict[str, Callable[[Sequence[AttributeCovT]], ResolveType]]: + return {self.name: lambda attr_tuple: attr_tuple} - def can_infer(self, constraint_names: set[str]) -> bool: - return self.name in constraint_names + def can_infer(self, variables: Set[str]) -> bool: + return self.name in variables def infer( - self, length: int, constraint_context: ConstraintContext - ) -> Sequence[Attribute]: - constraint_context = constraint_context or ConstraintContext() - if self.name not in constraint_context.range_variables: - raise ValueError(f"Cannot infer attribute from constraint {self}") - return constraint_context.get_range_variable(self.name) + self, + length: int, + variables: dict[str, ResolveType], + ) -> Sequence[AttributeCovT]: + v = variables[self.name] + return cast(Sequence[AttributeCovT], v) -@dataclass +@dataclass(frozen=True) class RangeOf(GenericRangeConstraint[AttributeCovT]): """ Constrain each element in a range to satisfy a given constraint. @@ -591,19 +604,19 @@ def verify( for a in attrs: self.constr.verify(a, constraint_context) - def get_resolved_variables(self) -> set[str]: - return self.constr.get_resolved_variables() - - def can_infer(self, constraint_names: set[str]) -> bool: - return self.constr.can_infer(constraint_names) + def can_infer(self, variables: Set[str]) -> bool: + return self.constr.can_infer(variables) def infer( - self, length: int, constraint_context: ConstraintContext - ) -> list[Attribute]: - return [self.constr.infer(constraint_context)] * length + self, + length: int, + variables: dict[str, ResolveType], + ) -> Sequence[AttributeCovT]: + attr = self.constr.infer(variables) + return tuple(attr for _ in range(length)) -@dataclass +@dataclass(frozen=True) class SingleOf(GenericRangeConstraint[AttributeCovT]): """ Constrain a range to only contain a single element, which should satisfy a given constraint. @@ -620,16 +633,22 @@ def verify( raise VerifyException(f"Expected a single attribute, got {len(attrs)}") self.constr.verify(attrs[0], constraint_context) - def get_resolved_variables(self) -> set[str]: - return self.constr.get_resolved_variables() + def get_resolvers( + self, + ) -> dict[str, Callable[[Sequence[AttributeCovT]], ResolveType]]: + return { + v: lambda attrs: r(attrs[0]) for v, r in self.constr.get_resolvers().items() + } - def can_infer(self, constraint_names: set[str]) -> bool: - return self.constr.can_infer(constraint_names) + def can_infer(self, variables: Set[str]) -> bool: + return self.constr.can_infer(variables) def infer( - self, length: int, constraint_context: ConstraintContext - ) -> list[Attribute]: - return [self.constr.infer(constraint_context)] + self, + length: int, + variables: dict[str, ResolveType], + ) -> Sequence[AttributeCovT]: + return (self.constr.infer(variables),) def range_constr_coercion( diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 99006f5ee0..c5bdae7c3e 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -7,7 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass, field from typing import Literal @@ -21,18 +21,17 @@ TypedAttribute, ) from xdsl.irdl import ( - ConstraintContext, IRDLOperation, IRDLOperationInvT, OpDef, OptionalDef, + ResolveType, Successor, VariadicDef, VarIRConstruct, ) from xdsl.parser import Parser, UnresolvedOperand from xdsl.printer import Printer -from xdsl.utils.exceptions import VerifyException from xdsl.utils.lexer import PunctuationSpelling OperandOrResult = Literal[VarIRConstruct.OPERAND, VarIRConstruct.RESULT] @@ -53,7 +52,7 @@ class ParsingState: successors: list[Successor | None | Sequence[Successor]] attributes: dict[str, Attribute] properties: dict[str, Attribute] - constraint_context: ConstraintContext + variables: dict[str, ResolveType] def __init__(self, op_def: OpDef): self.operands = [None] * len(op_def.operands) @@ -63,7 +62,7 @@ def __init__(self, op_def: OpDef): self.successors = [None] * len(op_def.successors) self.attributes = {} self.properties = {} - self.constraint_context = ConstraintContext() + self.variables = {} @dataclass @@ -94,6 +93,9 @@ class FormatProgram: stmts: tuple[FormatDirective, ...] """The statements composing the program. They are executed in order.""" + resolvers: dict[str, Callable[[ParsingState], ResolveType]] + """Resolvers for all type variables.""" + @staticmethod def from_str(input: str, op_def: OpDef) -> FormatProgram: """ @@ -120,7 +122,7 @@ def parse( stmt.parse(parser, state) # Get constraint variables from the parsed operand and result types - self.assign_constraint_variables(parser, state, op_def) + self.resolve_constraint_variables(state) # Infer operand types that should be inferred unresolved_operands = state.operands @@ -167,35 +169,8 @@ def parse( successors=state.successors, ) - def assign_constraint_variables( - self, parser: Parser, state: ParsingState, op_def: OpDef - ): - """ - Assign constraint variables with values got from the - parsed operand and result types. - """ - if any(type is None for type in (*state.operand_types, *state.result_types)): - try: - for (_, operand_def), operand_type in zip( - op_def.operands, state.operand_types, strict=True - ): - if operand_type is None: - continue - if isinstance(operand_type, Attribute): - operand_type = (operand_type,) - operand_def.constr.verify(operand_type, state.constraint_context) - for (_, result_def), result_type in zip( - op_def.results, state.result_types, strict=True - ): - if result_type is None: - continue - if isinstance(result_type, Attribute): - result_type = (result_type,) - result_def.constr.verify(result_type, state.constraint_context) - except VerifyException as e: - parser.raise_error( - "Verification error while inferring operation type: " + str(e) - ) + def resolve_constraint_variables(self, state: ParsingState): + state.variables = {v: r(state) for v, r in self.resolvers.items()} def resolve_operand_types(self, state: ParsingState, op_def: OpDef) -> None: """ @@ -209,7 +184,8 @@ def resolve_operand_types(self, state: ParsingState, op_def: OpDef) -> None: operand = state.operands[i] range_length = len(operand) if isinstance(operand, Sequence) else 1 operand_type = operand_def.constr.infer( - range_length, state.constraint_context + range_length, + state.variables, ) resolved_operand_type: Attribute | Sequence[Attribute] if isinstance(operand_def, OptionalDef): @@ -242,7 +218,8 @@ def resolve_result_types(self, state: ParsingState, op_def: OpDef) -> None: ) range_length = 1 inferred_result_types = result_def.constr.infer( - range_length, state.constraint_context + range_length, + state.variables, ) resolved_result_type = inferred_result_types[0] state.result_types[i] = resolved_result_type diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index d9c617905f..54cc5b97dd 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -6,7 +6,7 @@ from __future__ import annotations import re -from collections.abc import Callable +from collections.abc import Callable, Sequence, Set from dataclasses import dataclass, field from enum import Enum, auto from itertools import pairwise @@ -18,7 +18,6 @@ AttrOrPropDef, AttrSizedOperandSegments, AttrSizedSegments, - ConstraintContext, OpDef, OptionalDef, OptOperandDef, @@ -28,6 +27,7 @@ OptSuccessorDef, ParamAttrConstraint, ParsePropInAttrDict, + ResolveType, VariadicDef, VarOperandDef, VarRegionDef, @@ -56,6 +56,7 @@ OptionalResultVariable, OptionalSuccessorVariable, OptionalUnitAttrVariable, + ParsingState, PunctuationDirective, RegionDirective, RegionVariable, @@ -176,15 +177,15 @@ def parse_format(self) -> FormatProgram: elements.append(self.parse_directive()) self.add_reserved_attrs_to_directive(elements) - seen_variables = self.resolve_types() + resolvers = self.resolve_types() self.verify_directives(elements) self.verify_attr_dict() self.verify_properties() - self.verify_operands(seen_variables) - self.verify_results(seen_variables) + self.verify_operands(resolvers.keys()) + self.verify_results(resolvers.keys()) self.verify_regions() self.verify_successors() - return FormatProgram(tuple(elements)) + return FormatProgram(tuple(elements), resolvers) def verify_directives(self, elements: list[FormatDirective]): """ @@ -241,20 +242,48 @@ def add_reserved_attrs_to_directive(self, elements: list[FormatDirective]): ) return - def resolve_types(self) -> set[str]: + @staticmethod + def wrap_resolver( + i: int, + is_operands: bool, + resolver: Callable[[Sequence[Attribute]], ResolveType], + ) -> Callable[[ParsingState], ResolveType]: + def wrapped(state: ParsingState) -> ResolveType: + if is_operands: + types = state.operand_types[i] + else: + types = state.result_types[i] + assert types is not None + if isinstance(types, Attribute): + types = (types,) + return resolver(types) + + return wrapped + + def resolve_types(self) -> dict[str, Callable[[ParsingState], ResolveType]]: """ Find out which constraint variables can be inferred from the parsed attributes. """ - seen_variables = set[str]() + resolvers = dict[str, Callable[[ParsingState], ResolveType]]() for i, (_, operand_def) in enumerate(self.op_def.operands): if self.seen_operand_types[i]: - seen_variables |= operand_def.constr.get_resolved_variables() + resolvers.update( + { + v: self.wrap_resolver(i, True, r) + for v, r in operand_def.constr.get_resolvers().items() + } + ) for i, (_, result_def) in enumerate(self.op_def.results): if self.seen_result_types[i]: - seen_variables |= result_def.constr.get_resolved_variables() - return seen_variables + resolvers.update( + { + v: self.wrap_resolver(i, False, r) + for v, r in result_def.constr.get_resolvers().items() + } + ) + return resolvers - def verify_operands(self, seen_variables: set[str]): + def verify_operands(self, variables: Set[str]): """ Check that all operands and operand types are refered at least once, or inferred from another construct. @@ -276,14 +305,14 @@ def verify_operands(self, seen_variables: set[str]): "directive to the custom assembly format" ) if not seen_operand_type: - if not operand_def.constr.can_infer(seen_variables): + if not operand_def.constr.can_infer(variables): self.raise_error( f"type of operand '{operand_name}' cannot be inferred, " f"consider adding a 'type(${operand_name})' directive to the " "custom assembly format" ) - def verify_results(self, seen_variables: set[str]): + def verify_results(self, variables: Set[str]): """Check that all result types are refered at least once, or inferred from another construct.""" @@ -291,7 +320,7 @@ def verify_results(self, seen_variables: set[str]): self.seen_result_types, self.op_def.results, strict=True ): if not result_type: - if not result_def.constr.can_infer(seen_variables): + if not result_def.constr.can_infer(variables): self.raise_error( f"type of result '{result_name}' cannot be inferred, " f"consider adding a 'type(${result_name})' directive to the " @@ -499,7 +528,7 @@ def parse_optional_variable( unique_base.get_type_index() ] if type_constraint.can_infer(set()): - unique_type = type_constraint.infer(ConstraintContext()) + unique_type = type_constraint.infer(dict()) if ( unique_base is not None and unique_base in Builtin.attributes diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index 33da6a9298..8e71a21ba8 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -356,10 +356,7 @@ class VarOperandDef(OperandDef, VariadicDef): def __init__( self, - attr: Attribute - | type[Attribute] - | AttrConstraint - | GenericRangeConstraint[Attribute], + attr: Attribute | type[Attribute] | AttrConstraint | RangeConstraint, ): self.constr = range_constr_coercion(attr) @@ -386,13 +383,10 @@ class ResultDef(OperandOrResultDef): """The result constraint.""" def __init__( - self, - attr: Attribute - | type[Attribute] - | AttrConstraint - | GenericRangeConstraint[Attribute], + self, attr: Attribute | type[Attribute] | AttrConstraint | RangeConstraint ): - self.constr = range_constr_coercion(attr) + assert not isinstance(attr, GenericRangeConstraint) + self.constr = single_range_constr_coercion(attr) @dataclass(init=False) @@ -400,11 +394,7 @@ class VarResultDef(ResultDef, VariadicDef): """An IRDL variadic result definition.""" def __init__( - self, - attr: Attribute - | type[Attribute] - | AttrConstraint - | GenericRangeConstraint[Attribute], + self, attr: Attribute | type[Attribute] | AttrConstraint | RangeConstraint ): self.constr = range_constr_coercion(attr) @@ -530,36 +520,24 @@ def __init__(self, cls: type[_ClsT]): class _RangeConstrainedOpDefField(Generic[_ClsT], _OpDefField[_ClsT]): - param: ( - RangeConstraint - | AttrConstraint - | Attribute - | type[Attribute] - | TypeVar - | ConstraintVar - ) + param: RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar def __init__( self, cls: type[_ClsT], - param: RangeConstraint - | AttrConstraint - | Attribute - | type[Attribute] - | TypeVar - | ConstraintVar, + param: RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar, ): super().__init__(cls) self.param = param class _ConstrainedOpDefField(Generic[_ClsT], _OpDefField[_ClsT]): - param: AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar + param: AttrConstraint | Attribute | type[Attribute] | TypeVar def __init__( self, cls: type[_ClsT], - param: AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar, + param: AttrConstraint | Attribute | type[Attribute] | TypeVar, ): super().__init__(cls) self.param = param @@ -589,7 +567,7 @@ class _AttrOrPropFieldDef( def __init__( self, cls: type[AttrOrPropInvT], - param: AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar, + param: AttrConstraint | Attribute | type[Attribute] | TypeVar, ir_name: str | None = None, default_value: Attribute | None = None, ): @@ -637,9 +615,7 @@ class _SuccessorFieldDef(_OpDefField[SuccessorDef]): def result_def( - constraint: ( - AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar - ) = Attribute, + constraint: (AttrConstraint | Attribute | type[Attribute] | TypeVar) = Attribute, *, default: None = None, resolver: None = None, @@ -738,12 +714,7 @@ def opt_prop_def( def attr_def( - constraint: ( - type[AttributeInvT] - | TypeVar - | GenericAttrConstraint[AttributeInvT] - | ConstraintVar - ), + constraint: (type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT]), default_value: Attribute | None = None, *, attr_name: str | None = None, @@ -803,9 +774,7 @@ def opt_attr_def( def operand_def( - constraint: ( - AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar - ) = Attribute, + constraint: (AttrConstraint | Attribute | type[Attribute] | TypeVar) = Attribute, *, default: None = None, resolver: None = None, @@ -1160,14 +1129,14 @@ def get_constraint( # Get attribute constraints from a list of pyrdl constraints def get_range_constraint( pyrdl_constr: ( - GenericRangeConstraint[Attribute] + RangeConstraint | AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar ), - ) -> GenericRangeConstraint[Attribute]: + ) -> RangeConstraint: if isinstance(pyrdl_constr, GenericRangeConstraint): return pyrdl_constr return RangeOf(get_constraint(pyrdl_constr))