From e0224a8457b249ee5660cbdf19a273fc5c1e08bc Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 22 Nov 2024 17:13:24 +0000 Subject: [PATCH] core: make TypedAttribute not generic (#3505) The bit that used the generic wasn't type-sound anyway. --- pyproject.toml | 18 +- tests/dialects/test_bufferization.py | 5 +- .../with-mlir/dialects/linalg/ops.mlir | 13 +- .../apply-pdl/apply_pdl_build_type.mlir | 27 + tests/irdl/test_attr_constraint.py | 7 +- tests/irdl/test_attribute_definition.py | 2 +- .../irdl/test_declarative_assembly_format.py | 15 +- tests/utils/test_scoped_dict.py | 13 + xdsl/dialects/bufferization.py | 6 +- xdsl/dialects/builtin.py | 19 +- xdsl/dialects/linalg.py | 2 + xdsl/dialects/memref_stream.py | 37 +- xdsl/dialects/riscv_snitch.py | 29 +- xdsl/dialects/stream.py | 153 +----- xdsl/dialects/utils/__init__.py | 4 + xdsl/dialects/utils/fast_math.py | 29 + xdsl/dialects/{utils.py => utils/format.py} | 34 -- xdsl/interpreters/pdl.py | 7 + xdsl/ir/core.py | 6 +- xdsl/irdl/constraints.py | 59 +-- xdsl/irdl/declarative_assembly_format.py | 498 +++++++++--------- .../declarative_assembly_format_parser.py | 210 ++++---- xdsl/parser/base_parser.py | 33 +- xdsl/printer.py | 7 - xdsl/utils/scoped_dict.py | 19 +- 25 files changed, 628 insertions(+), 624 deletions(-) create mode 100644 tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir create mode 100644 xdsl/dialects/utils/__init__.py create mode 100644 xdsl/dialects/utils/fast_math.py rename xdsl/dialects/{utils.py => utils/format.py} (93%) diff --git a/pyproject.toml b/pyproject.toml index 85f01a6fda..f79ac3006b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dev = [ "nbval<0.12", "filecheck==1.0.1", "lit<19.0.0", - "marimo==0.9.20", + "marimo==0.9.21", "pre-commit==4.0.1", "ruff==0.7.4", "asv<0.7", @@ -109,17 +109,19 @@ ignore = [ max-line-length = 300 [tool.ruff.lint.flake8-tidy-imports.banned-api] -"xdsl.parser.core".msg = "Use xdsl.parser instead." -"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead." -"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead." +"xdsl.dialects.utils.fast_math".msg = "Use xdsl.dialects.utils instead" +"xdsl.dialects.utils.format".msg = "Use xdsl.dialects.utils instead" +"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead" +"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead" +"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead" "xdsl.ir.core".msg = "Use xdsl.ir instead." +"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead" "xdsl.irdl.common".msg = "Use xdsl.irdl instead" "xdsl.irdl.constraints".msg = "Use xdsl.irdl instead" -"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead" "xdsl.irdl.operations".msg = "Use xdsl.irdl instead" -"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead" -"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead" -"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead" +"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead." +"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead." +"xdsl.parser.core".msg = "Use xdsl.parser instead." [tool.ruff.lint.per-file-ignores] diff --git a/tests/dialects/test_bufferization.py b/tests/dialects/test_bufferization.py index 169612fcce..19490a89e8 100644 --- a/tests/dialects/test_bufferization.py +++ b/tests/dialects/test_bufferization.py @@ -23,6 +23,7 @@ from xdsl.ir import Attribute from xdsl.irdl import ( EqAttrConstraint, + InferenceContext, IRDLOperation, VarConstraint, irdl_op_definition, @@ -39,13 +40,13 @@ def test_tensor_from_memref_inference(): EqAttrConstraint(MemRefType(f64, [10, 20, 30])) ) assert constr2.can_infer(set()) - assert constr2.infer({}) == TensorType(f64, [10, 20, 30]) + assert constr2.infer(InferenceContext()) == TensorType(f64, [10, 20, 30]) constr3 = TensorFromMemrefConstraint( EqAttrConstraint(UnrankedMemrefType.from_type(f64)) ) assert constr3.can_infer(set()) - assert constr3.infer({}) == UnrankedTensorType(f64) + assert constr3.infer(InferenceContext()) == UnrankedTensorType(f64) @irdl_op_definition diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir index 3a62737b2c..9d942b95d9 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir @@ -59,6 +59,9 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>) %18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) %20 = "test.op"() : () -> (memref<64x4096xf32>) +%zero = arith.constant 0: f32 +linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>) + linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>) @@ -99,17 +102,19 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou // CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>) // CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1] // CHECK-NEXT: %{{.*}} = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) { -// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32): -// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32 +// CHECK-NEXT: ^bb0(%in: f32, %in_2: f32, %out: f32): +// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_2 : f32 // CHECK-NEXT: linalg.yield %{{.*}} : f32 // CHECK-NEXT: } -> tensor<2x3xf32> // CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) // CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32> +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: linalg.fill {id} ins(%cst_0 : f32) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>) // CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>) // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 +// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 // CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32> -// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> +// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32> // CHECK-NEXT: } diff --git a/tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir b/tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir new file mode 100644 index 0000000000..15468e3446 --- /dev/null +++ b/tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir @@ -0,0 +1,27 @@ +// RUN: xdsl-opt %s -p apply-pdl | filecheck %s + +%x = "test.op"() : () -> (i32) + +pdl.pattern : benefit(1) { + %in_type = pdl.type: i32 + %root = pdl.operation "test.op" -> (%in_type: !pdl.type) + pdl.rewrite %root { + %out_type = pdl.type: i64 + %new_op = pdl.operation "test.op" -> (%out_type: !pdl.type) + pdl.replace %root with %new_op + } +} + +// CHECK: builtin.module { +// CHECK-NEXT: %x = "test.op"() : () -> i64 +// CHECK-NEXT: pdl.pattern : benefit(1) { +// CHECK-NEXT: %in_type = pdl.type : i32 +// CHECK-NEXT: %root = pdl.operation "test.op" -> (%in_type : !pdl.type) +// CHECK-NEXT: pdl.rewrite %root { +// CHECK-NEXT: %out_type = pdl.type : i64 +// CHECK-NEXT: %new_op = pdl.operation "test.op" -> (%out_type : !pdl.type) +// CHECK-NEXT: pdl.replace %root with %new_op +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: diff --git a/tests/irdl/test_attr_constraint.py b/tests/irdl/test_attr_constraint.py index f37748bc87..e18b6f6773 100644 --- a/tests/irdl/test_attr_constraint.py +++ b/tests/irdl/test_attr_constraint.py @@ -10,6 +10,7 @@ AttrConstraint, BaseAttr, EqAttrConstraint, + InferenceContext, ParamAttrConstraint, ParameterDef, VarConstraint, @@ -77,7 +78,7 @@ class WrapAttr(BaseWrapAttr): ... ) assert constr.can_infer(set()) - assert constr.infer({}) == WrapAttr((StringAttr("Hello"),)) + assert constr.infer(InferenceContext()) == WrapAttr((StringAttr("Hello"),)) var_constr = ParamAttrConstraint( WrapAttr, @@ -92,7 +93,7 @@ class WrapAttr(BaseWrapAttr): ... ) assert var_constr.can_infer({"T"}) - assert var_constr.infer({"T": StringAttr("Hello")}) == WrapAttr( + assert var_constr.infer(InferenceContext({"T": StringAttr("Hello")})) == WrapAttr( (StringAttr("Hello"),) ) @@ -127,7 +128,7 @@ class NoParamAttr(BaseNoParamAttr): ... constr = BaseAttr(NoParamAttr) assert constr.can_infer(set()) - assert constr.infer({}) == NoParamAttr() + assert constr.infer(InferenceContext()) == NoParamAttr() base_constr = BaseAttr(BaseNoParamAttr) assert not base_constr.can_infer(set()) diff --git a/tests/irdl/test_attribute_definition.py b/tests/irdl/test_attribute_definition.py index a96bc462c6..8f403ca4e6 100644 --- a/tests/irdl/test_attribute_definition.py +++ b/tests/irdl/test_attribute_definition.py @@ -251,7 +251,7 @@ def test_typed_attribute(): @irdl_attr_definition class TypedAttr( # pyright: ignore[reportUnusedClass] - TypedAttribute[Attribute] + TypedAttribute ): name = "test.typed" diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 9d3a1484f6..dded57875f 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -3,7 +3,7 @@ import textwrap from collections.abc import Callable from io import StringIO -from typing import ClassVar, Generic, TypeVar +from typing import Annotated, ClassVar, Generic, TypeVar import pytest @@ -12,6 +12,8 @@ from xdsl.dialects.builtin import ( I32, BoolAttr, + Float64Type, + FloatAttr, IntegerAttr, MemRefType, ModuleOp, @@ -603,20 +605,21 @@ class OptionalAttributeOp(IRDLOperation): "program, generic_program", [ ( - "test.typed_attr 3", - '"test.typed_attr"() {"attr" = 3 : i32} : () -> ()', + "test.typed_attr 3 3.000000e+00", + '"test.typed_attr"() {"attr" = 3 : i32, "float_attr" = 3.000000e+00 : f64} : () -> ()', ), ], ) def test_typed_attribute_variable(program: str, generic_program: str): - """Test the parsing of optional operands""" + """Test the parsing of typed attributes""" @irdl_op_definition class TypedAttributeOp(IRDLOperation): name = "test.typed_attr" attr = attr_def(IntegerAttr[I32]) + float_attr = attr_def(FloatAttr[Annotated[Float64Type, Float64Type()]]) - assembly_format = "$attr attr-dict" + assembly_format = "$attr $float_attr attr-dict" ctx = MLContext() ctx.load_op(TypedAttributeOp) @@ -693,7 +696,7 @@ def test_unknown_variable(): """Test that variables should refer to an element in the operation.""" with pytest.raises( PyRDLOpDefinitionError, - match="expected variable to refer to an operand, attribute, region, result, or successor", + match="expected variable to refer to an operand, attribute, region, or successor", ): @irdl_op_definition diff --git a/tests/utils/test_scoped_dict.py b/tests/utils/test_scoped_dict.py index cc536dd9cc..a92b18ac4b 100644 --- a/tests/utils/test_scoped_dict.py +++ b/tests/utils/test_scoped_dict.py @@ -32,3 +32,16 @@ def test_simple(): assert 3 not in table assert 3 in inner assert 4 not in inner + + +def test_get(): + parent = ScopedDict(local_scope={"a": 1, "b": 2}) + child = ScopedDict(parent, local_scope={"a": 3, "c": 4}) + + assert child.get("a") == 3 + assert child.get("b") == 2 + assert child.get("c") == 4 + assert child.get("d") is None + + assert child.get("a", 5) == 3 + assert child.get("d", 5) == 5 diff --git a/xdsl/dialects/bufferization.py b/xdsl/dialects/bufferization.py index 3cd932f379..3793b9d7da 100644 --- a/xdsl/dialects/bufferization.py +++ b/xdsl/dialects/bufferization.py @@ -20,8 +20,8 @@ from xdsl.irdl import ( AttrSizedOperandSegments, ConstraintContext, - ConstraintVariableType, GenericAttrConstraint, + InferenceContext, IRDLOperation, VarConstraint, irdl_op_definition, @@ -53,9 +53,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.memref_constraint.can_infer(var_constraint_names) def infer( - self, variables: dict[str, ConstraintVariableType] + self, context: InferenceContext ) -> TensorType[Attribute] | UnrankedTensorType[Attribute]: - memref_type = self.memref_constraint.infer(variables) + memref_type = self.memref_constraint.infer(context) if isinstance(memref_type, MemRefType): return TensorType(memref_type.element_type, memref_type.shape) return UnrankedTensorType(memref_type.element_type) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index df0e3d89cb..fe3c33384e 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -449,7 +449,7 @@ class IndexType(ParametrizedAttribute): @irdl_attr_definition class IntegerAttr( Generic[_IntegerAttrType], - TypedAttribute[_IntegerAttrType], + TypedAttribute, ): name = "integer" value: ParameterDef[IntAttr] @@ -504,8 +504,8 @@ def verify(self) -> None: @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: assert isinstance(type, IntegerType | IndexType) return IntegerAttr(parser.parse_integer(allow_boolean=(type == i1)), type) @@ -634,7 +634,7 @@ def __hash__(self): @irdl_attr_definition -class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute): +class FloatAttr(Generic[_FloatAttrType], TypedAttribute): name = "float" value: ParameterDef[FloatData] @@ -668,6 +668,17 @@ def __init__( raise ValueError(f"Invalid bitwidth: {type}") super().__init__([data_attr, type]) + @staticmethod + def parse_with_type( + parser: AttrParser, + type: Attribute, + ) -> TypedAttribute: + assert isinstance(type, AnyFloat) + return FloatAttr(parser.parse_float(), type) + + def print_without_type(self, printer: Printer): + return printer.print_float(self) + AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat] AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr) diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 854ed50c86..dd1097d67e 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -642,6 +642,8 @@ class FillOp(NamedOpBase): name = "linalg.fill" + PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True + def __init__( self, inputs: Sequence[SSAValue], diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 9335d48c2a..81af928429 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -46,6 +46,7 @@ opt_prop_def, prop_def, region_def, + result_def, traits_def, var_operand_def, ) @@ -187,14 +188,44 @@ def offsets(self) -> tuple[tuple[int, ...], ...]: @irdl_op_definition -class ReadOp(stream.ReadOperation): +class ReadOp(IRDLOperation): name = "memref_stream.read" + T: ClassVar = VarConstraint("T", AnyAttr()) + + stream = operand_def(stream.ReadableStreamType.constr(T)) + res = result_def(T) + + assembly_format = "`from` $stream attr-dict `:` type($res)" + + def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): + if result_type is None: + assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) + stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + result_type = stream_type.element_type + super().__init__(operands=[stream_val], result_types=[result_type]) + + def assembly_line(self) -> str | None: + return None + @irdl_op_definition -class WriteOp(stream.WriteOperation): +class WriteOp(IRDLOperation): name = "memref_stream.write" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + stream = operand_def(stream.WritableStreamType.constr(T)) + + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + + def __init__(self, value: SSAValue, stream: SSAValue): + super().__init__(operands=[value, stream]) + + def assembly_line(self) -> str | None: + return None + @irdl_op_definition class StreamingRegionOp(IRDLOperation): @@ -370,7 +401,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr()) + outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 25d0cafcaa..1bcbaaa670 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -40,6 +40,7 @@ ) from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( + AnyAttr, VarConstraint, attr_def, base, @@ -155,17 +156,41 @@ def assembly_line(self) -> str | None: @irdl_op_definition -class ReadOp(stream.ReadOperation, RISCVAsmOperation): +class ReadOp(RISCVAsmOperation): name = "riscv_snitch.read" + T: ClassVar = VarConstraint("T", AnyAttr()) + + stream = operand_def(stream.ReadableStreamType.constr(T)) + res = result_def(T) + + assembly_format = "`from` $stream attr-dict `:` type($res)" + + def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): + if result_type is None: + assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) + stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + result_type = stream_type.element_type + super().__init__(operands=[stream_val], result_types=[result_type]) + def assembly_line(self) -> str | None: return None @irdl_op_definition -class WriteOp(stream.WriteOperation, RISCVAsmOperation): +class WriteOp(RISCVAsmOperation): name = "riscv_snitch.write" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + stream = operand_def(stream.WritableStreamType.constr(T)) + + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + + def __init__(self, value: SSAValue, stream: SSAValue): + super().__init__(operands=[value, stream]) + def assembly_line(self) -> str | None: return None diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 3c2d5ea2d6..7ed748f393 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,32 +1,21 @@ from __future__ import annotations -import abc -from typing import ClassVar, Generic, TypeVar, cast, overload - -from typing_extensions import Self +from typing import Generic, TypeVar from xdsl.dialects.builtin import ContainerType from xdsl.ir import ( Attribute, Dialect, ParametrizedAttribute, - SSAValue, TypeAttribute, ) from xdsl.irdl import ( - AnyAttr, BaseAttr, GenericAttrConstraint, - IRDLOperation, ParamAttrConstraint, ParameterDef, - VarConstraint, irdl_attr_definition, - operand_def, - result_def, ) -from xdsl.parser import Parser -from xdsl.printer import Printer _StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) _StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) @@ -46,30 +35,10 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[StreamType[Attribute]]: ... - - @overload @staticmethod def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[StreamType[Attribute]] - | ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[StreamType[Attribute]](StreamType) + ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( StreamType, (element_type,) ) @@ -79,134 +48,36 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[ReadableStreamType[Attribute]]: ... - - @overload @staticmethod def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[ReadableStreamType[Attribute]] - | ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType) + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( ReadableStreamType, (element_type,) ) +AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( + ReadableStreamType +) + + @irdl_attr_definition class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[WritableStreamType[Attribute]]: ... - - @overload @staticmethod def constr( - *, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[WritableStreamType[Attribute]] - | ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[WritableStreamType[Attribute]](WritableStreamType) + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( WritableStreamType, (element_type,) ) -class ReadOperation(IRDLOperation, abc.ABC): - """ - Abstract base class for operations that read from a stream. - """ - - T: ClassVar = VarConstraint("T", AnyAttr()) - - stream = operand_def(ReadableStreamType.constr(element_type=T)) - res = result_def(T) - - def __init__(self, stream: SSAValue, result_type: Attribute | None = None): - if result_type is None: - assert isinstance(stream_type := stream.type, ReadableStreamType) - stream_type = cast(ReadableStreamType[Attribute], stream_type) - result_type = stream_type.element_type - super().__init__(operands=[stream], result_types=[result_type]) - - @classmethod - def parse(cls, parser: Parser) -> Self: - parser.parse_characters("from") - unresolved = parser.parse_unresolved_operand() - parser.parse_punctuation(":") - result_type = parser.parse_attribute() - resolved = parser.resolve_operand(unresolved, ReadableStreamType(result_type)) - return cls(resolved, result_type) - - def print(self, printer: Printer): - printer.print_string(" from ") - printer.print(self.stream) - printer.print_string(" : ") - printer.print_attribute(self.res.type) - - -class WriteOperation(IRDLOperation, abc.ABC): - """ - Abstract base class for operations that write to a stream. - """ - - T: ClassVar = VarConstraint("T", AnyAttr()) - - value = operand_def(T) - stream = operand_def(WritableStreamType.constr(element_type=T)) - - def __init__(self, value: SSAValue, stream: SSAValue): - super().__init__(operands=[value, stream]) - - @classmethod - def parse(cls, parser: Parser) -> Self: - unresolved_value = parser.parse_unresolved_operand() - parser.parse_characters("to") - unresolved_stream = parser.parse_unresolved_operand() - parser.parse_punctuation(":") - result_type = parser.parse_attribute() - resolved_value = parser.resolve_operand(unresolved_value, result_type) - resolved_stream = parser.resolve_operand( - unresolved_stream, WritableStreamType(result_type) - ) - return cls(resolved_value, resolved_stream) - - def print(self, printer: Printer): - printer.print_string(" ") - printer.print_ssa_value(self.value) - printer.print_string(" to ") - printer.print_ssa_value(self.stream) - printer.print_string(" : ") - printer.print_attribute(self.value.type) +AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( + WritableStreamType +) Stream = Dialect( diff --git a/xdsl/dialects/utils/__init__.py b/xdsl/dialects/utils/__init__.py new file mode 100644 index 0000000000..97cc7a2b6f --- /dev/null +++ b/xdsl/dialects/utils/__init__.py @@ -0,0 +1,4 @@ +# TID 251 enforces to not import from those +# We need to skip it here to allow importing from here instead. +from .fast_math import * # noqa: TID251 +from .format import * # noqa: TID251 diff --git a/xdsl/dialects/utils/fast_math.py b/xdsl/dialects/utils/fast_math.py new file mode 100644 index 0000000000..c12b8b5878 --- /dev/null +++ b/xdsl/dialects/utils/fast_math.py @@ -0,0 +1,29 @@ +from abc import ABC +from dataclasses import dataclass + +from xdsl.ir import BitEnumAttribute +from xdsl.utils.str_enum import StrEnum + + +class FastMathFlag(StrEnum): + """ + Values specifying fast math behaviour of an arithmetic operation. + """ + + REASSOC = "reassoc" + NO_NANS = "nnan" + NO_INFS = "ninf" + NO_SIGNED_ZEROS = "nsz" + ALLOW_RECIP = "arcp" + ALLOW_CONTRACT = "contract" + APPROX_FUNC = "afn" + + +@dataclass(frozen=True, init=False) +class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC): + """ + Base class for attributes defining fast math behavior of arithmetic operations. + """ + + none_value = "none" + all_value = "fast" diff --git a/xdsl/dialects/utils.py b/xdsl/dialects/utils/format.py similarity index 93% rename from xdsl/dialects/utils.py rename to xdsl/dialects/utils/format.py index 5b5fb8b0f9..ae7965d213 100644 --- a/xdsl/dialects/utils.py +++ b/xdsl/dialects/utils/format.py @@ -1,6 +1,4 @@ -from abc import ABC from collections.abc import Iterable, Sequence -from dataclasses import dataclass from typing import Generic from xdsl.dialects.builtin import ( @@ -14,7 +12,6 @@ from xdsl.ir import ( Attribute, AttributeInvT, - BitEnumAttribute, BlockArgument, Operation, Region, @@ -23,7 +20,6 @@ from xdsl.irdl import IRDLOperation, var_operand_def from xdsl.parser import Parser, UnresolvedOperand from xdsl.printer import Printer -from xdsl.utils.str_enum import StrEnum def print_call_op_like( @@ -345,33 +341,3 @@ def parse_dynamic_index_list_without_types( values.append(value_or_index) return values, indices - - -# region Fast Math Flags - - -class FastMathFlag(StrEnum): - """ - Values specifying fast math behaviour of an arithmetic operation. - """ - - REASSOC = "reassoc" - NO_NANS = "nnan" - NO_INFS = "ninf" - NO_SIGNED_ZEROS = "nsz" - ALLOW_RECIP = "arcp" - ALLOW_CONTRACT = "contract" - APPROX_FUNC = "afn" - - -@dataclass(frozen=True, init=False) -class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC): - """ - Base class for attributes defining fast math behavior of arithmetic operations. - """ - - none_value = "none" - all_value = "fast" - - -# endregion diff --git a/xdsl/interpreters/pdl.py b/xdsl/interpreters/pdl.py index ff1bb1317e..4fd5d4232a 100644 --- a/xdsl/interpreters/pdl.py +++ b/xdsl/interpreters/pdl.py @@ -386,3 +386,10 @@ def run_erase( (old,) = interpreter.get_values((op.op_value,)) self.rewriter.erase_op(old) return () + + @impl(pdl.TypeOp) + def run_type( + self, interpreter: Interpreter, op: pdl.TypeOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + assert isinstance(op.constantType, Attribute) + return (op.constantType,) diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index b4bece71de..698f52e25d 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -606,7 +606,7 @@ def _verify(self): super()._verify() -class TypedAttribute(ParametrizedAttribute, Generic[AttributeCovT], ABC): +class TypedAttribute(ParametrizedAttribute, ABC): """ An attribute with a type. """ @@ -617,8 +617,8 @@ def get_type_index(cls) -> int: ... @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: """ Parse the attribute with the given type. """ diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index bf21d83992..488961d7df 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -61,6 +61,15 @@ def update(self, other: ConstraintContext): Possible types that a constraint variable can have. """ + +@dataclass +class InferenceContext: + variables: dict[str, ConstraintVariableType] = field(default_factory=dict) + """ + A mapping from variable names to the inferred attribute or attribute sequence. + """ + + _T = TypeVar("_T") @@ -156,7 +165,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: # By default, we cannot infer anything. return False - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: """ Infer the attribute given the the values for all variables. @@ -228,8 +237,8 @@ def verify( def get_variable_extractors(self) -> dict[str, VarExtractor[AttributeCovT]]: return {self.name: IdExtractor()} - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: - v = variables[self.name] + def infer(self, context: InferenceContext) -> AttributeCovT: + v = context.variables[self.name] return cast(AttributeCovT, v) def can_infer(self, var_constraint_names: Set[str]) -> bool: @@ -272,7 +281,7 @@ def verify( def can_infer(self, var_constraint_names: Set[str]) -> bool: return True - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: return self.attr def get_unique_base(self) -> type[Attribute] | None: @@ -303,7 +312,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: and not self.attr.get_irdl_definition().parameters ) - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: assert issubclass(self.attr, ParametrizedAttribute) attr = self.attr.new(()) return attr @@ -439,10 +448,10 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: constr.can_infer(var_constraint_names) for constr in self.attr_constrs ) - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: + def infer(self, context: InferenceContext) -> AttributeCovT: for constr in self.attr_constrs: - if constr.can_infer(variables.keys()): - return constr.infer(variables) + if constr.can_infer(context.variables.keys()): + return constr.infer(context) raise ValueError("Cannot infer attribute from constraint") def get_unique_base(self) -> type[Attribute] | None: @@ -535,10 +544,8 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: constr.can_infer(var_constraint_names) for constr in self.param_constrs ) - def infer( - self, variables: dict[str, ConstraintVariableType] - ) -> ParametrizedAttributeCovT: - params = tuple(constr.infer(variables) for constr in self.param_constrs) + def infer(self, context: InferenceContext) -> ParametrizedAttributeCovT: + params = tuple(constr.infer(context) for constr in self.param_constrs) attr = self.base_attr.new(params) return attr @@ -590,8 +597,8 @@ def get_unique_base(self) -> type[Attribute] | None: def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.constr.can_infer(var_constraint_names) - def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT: - return self.constr.infer(variables) + def infer(self, context: InferenceContext) -> AttributeCovT: + return self.constr.infer(context) @dataclass(frozen=True) @@ -628,9 +635,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: # By default, we cannot infer anything. return False - def infer( - self, length: int, variables: dict[str, ConstraintVariableType] - ) -> Sequence[AttributeCovT]: + def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]: """ Infer the attribute given the the values for all variables. @@ -684,12 +689,8 @@ def get_variable_extractors( def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.name in var_constraint_names - def infer( - self, - length: int, - variables: dict[str, ConstraintVariableType], - ) -> Sequence[AttributeCovT]: - v = variables[self.name] + def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]: + v = context.variables[self.name] return cast(Sequence[AttributeCovT], v) @@ -715,9 +716,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool: def infer( self, length: int, - variables: dict[str, ConstraintVariableType], + context: InferenceContext, ) -> Sequence[AttributeCovT]: - attr = self.constr.infer(variables) + attr = self.constr.infer(context) return (attr,) * length @@ -756,12 +757,8 @@ def get_variable_extractors( def can_infer(self, var_constraint_names: Set[str]) -> bool: return self.constr.can_infer(var_constraint_names) - def infer( - self, - length: int, - variables: dict[str, ConstraintVariableType], - ) -> Sequence[AttributeCovT]: - return (self.constr.infer(variables),) + def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]: + return (self.constr.infer(context),) def range_constr_coercion( diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 40d8a6841d..977f47e7a3 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Literal +from typing import Literal, TypeAlias from xdsl.dialects.builtin import UnitAttr from xdsl.ir import ( @@ -22,6 +22,7 @@ ) from xdsl.irdl import ( ConstraintVariableType, + InferenceContext, IRDLOperation, IRDLOperationInvT, OpDef, @@ -186,7 +187,7 @@ def resolve_operand_types(self, state: ParsingState, op_def: OpDef) -> None: range_length = len(operand) if isinstance(operand, Sequence) else 1 operand_type = operand_def.constr.infer( range_length, - state.variables, + InferenceContext(state.variables), ) resolved_operand_type: Attribute | Sequence[Attribute] if isinstance(operand_def, OptionalDef): @@ -220,7 +221,7 @@ def resolve_result_types(self, state: ParsingState, op_def: OpDef) -> None: range_length = 1 inferred_result_types = result_def.constr.infer( range_length, - state.variables, + InferenceContext(state.variables), ) resolved_result_type = inferred_result_types[0] state.result_types[i] = resolved_result_type @@ -237,19 +238,11 @@ def print(self, printer: Printer, op: IRDLOperation) -> None: @dataclass(frozen=True) -class FormatDirective(ABC): - """A format directive for operation format.""" - - @abstractmethod - def parse(self, parser: Parser, state: ParsingState) -> None: ... +class Directive(ABC): + """An assembly format directive""" - @abstractmethod - def print( - self, printer: Printer, state: PrintingState, op: IRDLOperation - ) -> None: ... - -class AnchorableDirective(FormatDirective, ABC): +class AnchorableDirective(Directive, ABC): """ Base class for Directive usable as anchors to optional groups. """ @@ -262,6 +255,18 @@ def is_present(self, op: IRDLOperation) -> bool: ... +class FormatDirective(Directive, ABC): + """A format directive for operation format.""" + + @abstractmethod + def parse(self, parser: Parser, state: ParsingState) -> None: ... + + @abstractmethod + def print( + self, printer: Printer, state: PrintingState, op: IRDLOperation + ) -> None: ... + + class OptionallyParsableDirective(FormatDirective, ABC): """ Base class for Directive that can be optionally parsed. @@ -271,7 +276,7 @@ class OptionallyParsableDirective(FormatDirective, ABC): @abstractmethod def parse_optional(self, parser: Parser, state: ParsingState) -> bool: """ - Try parsing the directive and return if it was present. + Try parsing the directive and return True if it was present. """ ... @@ -279,71 +284,175 @@ def parse(self, parser: Parser, state: ParsingState) -> None: self.parse_optional(parser, state) -class VariadicLikeFormatDirective(AnchorableDirective, ABC): +class VariadicLikeFormatDirective( + OptionallyParsableDirective, AnchorableDirective, ABC +): """ - Baseclass to help keep typechecking simple. - VariadicLike is mostly Variadic or Optional: Whatever directive that can accept - having nothing to parse. + A directive which parses/prints multiple objects separated by commas. + Such directives can not be followed by comma literals. """ - pass + def set_empty(self, state: ParsingState): + """ + Set the appropriate field of the parsing state to be empty. + Used when a variable appears in an optional group which is not parsed. + """ + return -@dataclass(frozen=True) -class VariableDirective(FormatDirective, ABC): +class TypeableDirective(Directive, ABC): """ - A variable directive, with the following format: - variable-directive ::= dollar-ident - The directive will request a space to be printed after. + Directives which can be used to set or get types. """ - name: str - """The variable name. This is only used for error message reporting.""" - index: int - """Index of the variable(operand or result) definition.""" + @abstractmethod + def set_type(self, type: Attribute, state: ParsingState) -> None: ... + + @abstractmethod + def get_type(self, op: IRDLOperation) -> Attribute: ... -class TypeDirective(VariableDirective, ABC): +class VariadicTypeableDirective(AnchorableDirective, ABC): """ - Base class for Directive meant to parse types. + Directives which can set or get multiple types. """ - pass + @abstractmethod + def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None: ... + @abstractmethod + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: ... -class RegionDirective(OptionallyParsableDirective, ABC): + +class OptionalTypeableDirective(AnchorableDirective, ABC): """ - Baseclass to help keep typechecking simple. - RegionDirective is for any RegionVariable, which are all OptionallyParsable. + Directives which can optionally set or get a single type. + """ + + @abstractmethod + def set_type(self, type: Attribute | None, state: ParsingState) -> None: ... + + @abstractmethod + def get_type(self, op: IRDLOperation) -> Attribute | None: ... + + +AnyTypeableDirective: TypeAlias = ( + TypeableDirective | VariadicTypeableDirective | OptionalTypeableDirective +) + + +@dataclass(frozen=True) +class TypeDirective(FormatDirective): + """ + A directive which parses the type of a typeable directive, with format: + type-directive ::= type(typeable-directive) + """ + + inner: TypeableDirective + + def parse(self, parser: Parser, state: ParsingState) -> None: + ty = parser.parse_type() + self.inner.set_type(ty, state) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + printer.print_attribute(self.inner.get_type(op)) + state.last_was_punctuation = False + state.should_emit_space = True + + +class VariadicLikeTypeDirective(VariadicLikeFormatDirective): + """ + Base class for type checking. + A variadic-like type directive can not be followed by a variadic-like type directive. """ - pass +@dataclass(frozen=True) +class VariadicTypeDirective(VariadicLikeTypeDirective): + """ + A directive which parses the type of a variadic typeable directive, with format: + type-directive ::= type(typeable-directive) + """ -class VariadicLikeVariable(VariadicLikeFormatDirective, VariableDirective, ABC): - pass + inner: VariadicTypeableDirective + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + types = parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_type, parser.parse_type + ) + if types is None: + types = () + self.inner.set_types(types, state) + return bool(types) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + printer.print_list(self.inner.get_types(op), printer.print_attribute) + state.last_was_punctuation = False + state.should_emit_space = True -class VariadicVariable(VariadicLikeVariable, ABC): def is_present(self, op: IRDLOperation) -> bool: - return len(getattr(op, self.name)) > 0 + return self.inner.is_present(op) + + def set_empty(self, state: ParsingState): + self.inner.set_types((), state) + + +@dataclass(frozen=True) +class OptionalTypeDirective(VariadicLikeTypeDirective): + """ + A directive which parses the type of a optional typeable directive, with format: + type-directive ::= type(typeable-directive) + """ + + inner: OptionalTypeableDirective + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + type = parser.parse_optional_type() + self.inner.set_type(type, state) + return bool(type) + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + type = self.inner.get_type(op) + if type: + printer.print_attribute(type) + state.last_was_punctuation = False + state.should_emit_space = True -class OptionalVariable(VariadicLikeVariable, ABC): def is_present(self, op: IRDLOperation) -> bool: - return getattr(op, self.name) is not None + return self.inner.is_present(op) + def set_empty(self, state: ParsingState): + self.inner.set_type(None, state) -class VariadicLikeTypeDirective(VariadicLikeFormatDirective, VariableDirective, ABC): - pass + +@dataclass(frozen=True) +class VariableDirective(Directive, ABC): + """ + A variable directive, with the following format: + variable-directive ::= dollar-ident + The directive will request a space to be printed after. + """ + + name: str + """The variable name. This is only used for error message reporting.""" + index: int + """Index of the variable(operand or result) definition.""" -class VariadicTypeDirective(VariadicLikeTypeDirective, VariadicVariable, ABC): - pass +class VariadicVariable(VariableDirective, AnchorableDirective, ABC): + def is_present(self, op: IRDLOperation) -> bool: + return bool(getattr(op, self.name)) -class OptionalTypeDirective(VariadicLikeTypeDirective, OptionalVariable, ABC): - pass +class OptionalVariable(VariableDirective, AnchorableDirective, ABC): + def is_present(self, op: IRDLOperation) -> bool: + return getattr(op, self.name) is not None @dataclass(frozen=True) @@ -422,7 +531,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No @dataclass(frozen=True) -class OperandVariable(VariableDirective): +class OperandVariable(VariableDirective, FormatDirective, TypeableDirective): """ An operand variable, with the following format: operand-directive ::= dollar-ident @@ -433,6 +542,9 @@ def parse(self, parser: Parser, state: ParsingState) -> None: operand = parser.parse_unresolved_operand() state.operands[self.index] = operand + def set_type(self, type: Attribute, state: ParsingState) -> None: + state.operand_types[self.index] = type + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") @@ -440,10 +552,20 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def get_type(self, op: IRDLOperation) -> Attribute: + return getattr(op, self.name).type + + +class VariadicOperandDirective(VariadicLikeFormatDirective, ABC): + """ + Base class for typechecking. + A variadic operand directive cannot follow another variadic operand directive. + """ + @dataclass(frozen=True) class VariadicOperandVariable( - VariadicVariable, VariableDirective, OptionallyParsableDirective + VariadicVariable, VariadicOperandDirective, VariadicTypeableDirective ): """ A variadic operand variable, with the following format: @@ -460,6 +582,9 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: state.operands[self.index] = operands return bool(operands) + def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None: + state.operand_types[self.index] = types + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") @@ -469,8 +594,16 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: + return getattr(op, self.name).types + + def set_empty(self, state: ParsingState): + state.operands[self.index] = () -class OptionalOperandVariable(OptionalVariable, OptionallyParsableDirective): + +class OptionalOperandVariable( + OptionalVariable, VariadicOperandDirective, OptionalTypeableDirective +): """ An optional operand variable, with the following format: operand-directive ::= ( percent-ident )? @@ -484,6 +617,9 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: state.operands[self.index] = operand return bool(operand) + def set_type(self, type: Attribute | None, state: ParsingState) -> None: + state.operand_types[self.index] = type or () + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") @@ -493,80 +629,18 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True - -@dataclass(frozen=True) -class OperandTypeDirective(TypeDirective): - """ - An operand variable type directive, with the following format: - operand-type-directive ::= type(dollar-ident) - The directive will request a space to be printed right after. - """ - - def parse(self, parser: Parser, state: ParsingState) -> None: - type = parser.parse_type() - state.operand_types[self.index] = type - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_attribute(getattr(op, self.name).type) - state.last_was_punctuation = False - state.should_emit_space = True - - -@dataclass(frozen=True) -class VariadicOperandTypeDirective( - TypeDirective, VariadicTypeDirective, OptionallyParsableDirective -): - """ - A variadic operand variable, with the following format: - operand-directive ::= ( percent-ident ( `,` percent-id )* )? - The directive will request a space to be printed after. - """ - - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - operand_types = parser.parse_optional_undelimited_comma_separated_list( - parser.parse_optional_type, parser.parse_type - ) - if operand_types is None: - operand_types = () - state.operand_types[self.index] = operand_types - return bool(operand_types) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_list(getattr(op, self.name).types, printer.print_attribute) - state.last_was_punctuation = False - state.should_emit_space = True - - -class OptionalOperandTypeDirective(OptionalTypeDirective, OptionallyParsableDirective): - """ - An optional operand variable type directive, with the following format: - operand-type-directive ::= ( type(dollar-ident) )? - The directive will request a space to be printed after. - """ - - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - type = parser.parse_optional_type() - if type is None: - type = () - state.operand_types[self.index] = type - return bool(type) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") + def get_type(self, op: IRDLOperation) -> Attribute | None: operand = getattr(op, self.name) if operand: - printer.print_attribute(operand.type) - state.last_was_punctuation = False - state.should_emit_space = True + return operand.type + return None + + def set_empty(self, state: ParsingState): + state.operands[self.index] = () @dataclass(frozen=True) -class ResultVariable(VariableDirective): +class ResultVariable(VariableDirective, TypeableDirective): """ An result variable, with the following format: result-directive ::= dollar-ident @@ -574,23 +648,15 @@ class ResultVariable(VariableDirective): parsing is not handled by the custom operation parser. """ - def parse(self, parser: Parser, state: ParsingState) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) + def set_type(self, type: Attribute, state: ParsingState) -> None: + state.result_types[self.index] = type - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) + def get_type(self, op: IRDLOperation) -> Attribute: + return getattr(op, self.name).type @dataclass(frozen=True) -class VariadicResultVariable( - ResultVariable, VariadicVariable, OptionallyParsableDirective -): +class VariadicResultVariable(VariadicVariable, VariadicTypeableDirective): """ A variadic result variable, with the following format: result-directive ::= percent-ident (( `,` percent-id )* )? @@ -598,21 +664,14 @@ class VariadicResultVariable( parsing is not handled by the custom operation parser. """ - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) - return False + def set_types(self, types: Sequence[Attribute], state: ParsingState) -> None: + state.result_types[self.index] = types - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: + return getattr(op, self.name).types -class OptionalResultVariable(OptionalVariable, OptionallyParsableDirective): +class OptionalResultVariable(OptionalVariable, OptionalTypeableDirective): """ An optional result variable, with the following format: result-directive ::= ( percent-ident )? @@ -620,92 +679,29 @@ class OptionalResultVariable(OptionalVariable, OptionallyParsableDirective): parsing is not handled by the custom operation parser. """ - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) - return False - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - assert ( - "Result variables cannot be used directly to parse/print in " - "declarative formats." - ) - - -@dataclass(frozen=True) -class ResultTypeDirective(TypeDirective): - """ - A result variable type directive, with the following format: - result-type-directive ::= type(dollar-ident) - The directive will request a space to be printed right after. - """ - - def parse(self, parser: Parser, state: ParsingState) -> None: - type = parser.parse_type() - state.result_types[self.index] = type + def set_type(self, type: Attribute | None, state: ParsingState) -> None: + state.result_types[self.index] = type or () - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_attribute(getattr(op, self.name).type) - state.last_was_punctuation = False - state.should_emit_space = True + def get_type(self, op: IRDLOperation) -> Attribute | None: + res = getattr(op, self.name) + if res: + return res.type + return None -@dataclass(frozen=True) -class VariadicResultTypeDirective( - TypeDirective, VariadicTypeDirective, OptionallyParsableDirective -): +class RegionDirective(OptionallyParsableDirective, ABC): """ - A variadic result variable type directive, with the following format: - variadic-result-type-directive ::= ( percent-ident ( `,` percent-id )* )? - The directive will request a space to be printed after. + Baseclass to help keep typechecking simple. + RegionDirective is for any RegionVariable, which are all OptionallyParsable. """ - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - result_types = parser.parse_optional_undelimited_comma_separated_list( - parser.parse_optional_type, parser.parse_type - ) - if result_types is None: - result_types = () - state.result_types[self.index] = result_types - return bool(result_types) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - printer.print_list(getattr(op, self.name).types, printer.print_attribute) - state.last_was_punctuation = False - state.should_emit_space = True - -class OptionalResultTypeDirective( - TypeDirective, OptionalTypeDirective, OptionallyParsableDirective -): +class VariadicRegionDirective(RegionDirective, VariadicLikeFormatDirective, ABC): """ - An optional result variable type directive, with the following format: - result-type-directive ::= ( type(dollar-ident) )? - The directive will request a space to be printed after. + Base class for typechecking. + A variadic region directive cannot follow another variadic region directive. """ - def parse_optional(self, parser: Parser, state: ParsingState) -> bool: - type = parser.parse_optional_type() - if type is None: - type = () - state.result_types[self.index] = type - return bool(type) - - def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: - if state.should_emit_space or not state.last_was_punctuation: - printer.print(" ") - result = getattr(op, self.name) - if result: - printer.print_attribute(result.type) - state.last_was_punctuation = False - state.should_emit_space = True - @dataclass(frozen=True) class RegionVariable(RegionDirective, VariableDirective): @@ -729,7 +725,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No @dataclass(frozen=True) -class VariadicRegionVariable(RegionDirective, VariadicVariable): +class VariadicRegionVariable(VariadicRegionDirective, VariadicVariable): """ A variadic region variable, with the following format: region-directive ::= dollar-ident @@ -756,8 +752,11 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.regions[self.index] = () -class OptionalRegionVariable(RegionDirective, OptionalVariable): + +class OptionalRegionVariable(VariadicRegionDirective, OptionalVariable): """ An optional region variable, with the following format: region-directive ::= dollar-ident @@ -780,6 +779,16 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.regions[self.index] = () + + +class VariadicSuccessorDirective(VariadicLikeFormatDirective, ABC): + """ + Base class for type checking. + A variadic successor directive cannot follow another variadic successor directive. + """ + class SuccessorVariable(VariableDirective, OptionallyParsableDirective): """ @@ -803,7 +812,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.should_emit_space = True -class VariadicSuccessorVariable(VariadicVariable, OptionallyParsableDirective): +class VariadicSuccessorVariable(VariadicSuccessorDirective, VariadicVariable): """ A variadic successor variable, with the following format: successor-directive ::= dollar-ident @@ -830,8 +839,11 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.successors[self.index] = () -class OptionalSuccessorVariable(OptionalVariable, OptionallyParsableDirective): + +class OptionalSuccessorVariable(VariadicSuccessorDirective, OptionalVariable): """ An optional successor variable, with the following format: successor-directive ::= dollar-ident @@ -854,6 +866,9 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True + def set_empty(self, state: ParsingState): + state.successors[self.index] = () + @dataclass(frozen=True) class AttributeVariable(FormatDirective): @@ -885,7 +900,9 @@ def parse(self, parser: Parser, state: ParsingState) -> None: ): attr = unique_base.new(unique_base.parse_parameters(parser)) elif issubclass(unique_base, Data): - attr = unique_base.new(unique_base.parse_parameter(parser)) # pyright: ignore[reportUnknownVariableType] + attr = unique_base.new( # pyright: ignore[reportUnknownVariableType] + unique_base.parse_parameter(parser) + ) else: raise ValueError("Attributes must be Data or ParameterizedAttribute.") if self.is_property: @@ -1072,33 +1089,8 @@ def parse(self, parser: Parser, state: ParsingState) -> None: # type to empty else: for element in self.then_elements: - match element: - case ( - OperandVariable(_, index) - | VariadicOperandVariable(_, index) - | OptionalOperandVariable(_, index) - ): - state.operands[index] = () - case ( - OperandTypeDirective(_, index) - | VariadicOperandTypeDirective(_, index) - | OptionalOperandTypeDirective(_, index) - ): - state.operand_types[index] = () - case ( - RegionVariable(_, index) - | VariadicRegionVariable(_, index) - | OptionalRegionVariable(_, index) - ): - state.regions[index] = () - case ( - ResultTypeDirective(_, index) - | VariadicResultTypeDirective(_, index) - | OptionalResultTypeDirective(_, index) - ): - state.result_types[index] = () - case _: - pass + if isinstance(element, VariadicLikeFormatDirective): + element.set_empty(state) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if self.anchor.is_present(op): diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 4977026c81..7e62f3f0b6 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -19,6 +19,7 @@ AttrSizedOperandSegments, AttrSizedSegments, ConstraintVariableType, + InferenceContext, OpDef, OptionalDef, OptOperandDef, @@ -39,6 +40,7 @@ ) from xdsl.irdl.declarative_assembly_format import ( AnchorableDirective, + AnyTypeableDirective, AttrDictDirective, AttributeVariable, DefaultValuedAttributeVariable, @@ -46,35 +48,35 @@ FormatProgram, KeywordDirective, OperandOrResult, - OperandTypeDirective, OperandVariable, OptionalAttributeVariable, OptionalGroupDirective, OptionallyParsableDirective, - OptionalOperandTypeDirective, OptionalOperandVariable, OptionalRegionVariable, - OptionalResultTypeDirective, OptionalResultVariable, OptionalSuccessorVariable, + OptionalTypeableDirective, + OptionalTypeDirective, OptionalUnitAttrVariable, ParsingState, PunctuationDirective, RegionDirective, RegionVariable, - ResultTypeDirective, ResultVariable, SuccessorVariable, - VariableDirective, + TypeDirective, VariadicLikeFormatDirective, VariadicLikeTypeDirective, - VariadicLikeVariable, - VariadicOperandTypeDirective, + VariadicOperandDirective, VariadicOperandVariable, + VariadicRegionDirective, VariadicRegionVariable, - VariadicResultTypeDirective, VariadicResultVariable, + VariadicSuccessorDirective, VariadicSuccessorVariable, + VariadicTypeableDirective, + VariadicTypeDirective, WhitespaceDirective, ) from xdsl.parser import BaseParser, ParserState @@ -147,8 +149,6 @@ class FormatParser(BaseParser): """The successor variables that are already parsed.""" has_attr_dict: bool = field(default=False) """True if the attribute dictionary has already been parsed.""" - context: ParsingContext = field(default=ParsingContext.TopLevel) - """Indicates if the parser is nested in a particular directive.""" type_resolutions: dict[ tuple[OperandOrResult, int], tuple[Callable[[Attribute], Attribute], OperandOrResult, int], @@ -176,7 +176,7 @@ def parse_format(self) -> FormatProgram: """ elements: list[FormatDirective] = [] while self._current_token.kind != Token.Kind.EOF: - elements.append(self.parse_directive()) + elements.append(self.parse_format_directive()) self.add_reserved_attrs_to_directive(elements) extractors = self.extractors_by_name() @@ -204,25 +204,19 @@ def verify_directives(self, elements: list[FormatDirective]): self.raise_error( "A variadic type directive cannot be followed by another variadic type directive." ) - case VariadicLikeVariable(), VariadicLikeVariable(): - if not ( - isinstance(a, RegionDirective | VariadicLikeTypeDirective) - or isinstance(b, RegionDirective | VariadicLikeTypeDirective) - ): - self.raise_error( - "A variadic operand variable cannot be followed by another variadic operand variable." - ) - elif isinstance(a, RegionDirective) and isinstance( - b, RegionDirective - ): - self.raise_error( - "A variadic region variable cannot be followed by another variadic region variable." - ) - case AttrDictDirective(), RegionDirective() if not (a.with_keyword): + case VariadicOperandDirective(), VariadicOperandDirective(): self.raise_error( - "An `attr-dict' directive without keyword cannot be directly followed by a region variable as it is ambiguous." + "A variadic operand variable cannot be followed by another variadic operand variable." + ) + case VariadicRegionDirective(), VariadicRegionDirective(): + self.raise_error( + "A variadic region variable cannot be followed by another variadic region variable." + ) + case VariadicSuccessorDirective(), VariadicSuccessorDirective(): + self.raise_error( + "A variadic successor variable cannot be followed by another variadic successor variable." ) - case AttrDictDirective(), RegionVariable() if not (a.with_keyword): + case AttrDictDirective(), RegionDirective() if not (a.with_keyword): self.raise_error( "An `attr-dict' directive without keyword cannot be directly followed by a region variable as it is ambiguous." ) @@ -402,30 +396,22 @@ def verify_successors(self): "directive to the custom assembly format." ) - def parse_optional_variable( - self, - ) -> VariableDirective | AttributeVariable | None: - """ - Parse a variable, if present, with the following format: - variable ::= `$` bare-ident - The variable should refer to an operand, attribute, region, result, - or successor. - """ - if self._current_token.text[0] != "$": - return None - self._consume_token() - variable_name = self.parse_identifier(" after '$'") - - # Check if the variable is an operand + def _parse_optional_operand( + self, variable_name: str, top_level: bool + ) -> OptionalOperandVariable | VariadicOperandVariable | OperandVariable | None: for idx, (operand_name, operand_def) in enumerate(self.op_def.operands): if variable_name != operand_name: continue - if self.context == ParsingContext.TopLevel: + if top_level: if self.seen_operands[idx]: self.raise_error(f"operand '{variable_name}' is already bound") self.seen_operands[idx] = True if isinstance(operand_def, VariadicDef | OptionalDef): self.seen_attributes.add(AttrSizedOperandSegments.attribute_name) + else: + if self.seen_operand_types[idx]: + self.raise_error(f"type of '{variable_name}' is already bound") + self.seen_operand_types[idx] = True match operand_def: case OptOperandDef(): return OptionalOperandVariable(variable_name, idx) @@ -434,15 +420,28 @@ def parse_optional_variable( case _: return OperandVariable(variable_name, idx) + def parse_optional_typeable_variable(self) -> AnyTypeableDirective | None: + """ + Parse a variable, if present, with the following format: + variable ::= `$` bare-ident + The variable should refer to an operand or result. + """ + if self._current_token.text[0] != "$": + return None + self._consume_token() + variable_name = self.parse_identifier(" after '$'") + + # Check if the variable is an operand + if (variable := self._parse_optional_operand(variable_name, False)) is not None: + return variable + # Check if the variable is a result for idx, (result_name, result_def) in enumerate(self.op_def.results): if variable_name != result_name: continue - if self.context == ParsingContext.TopLevel: - self.raise_error( - "result variable cannot be in a toplevel directive. " - f"Consider using 'type({variable_name})' instead." - ) + if self.seen_result_types[idx]: + self.raise_error(f"type of '{variable_name}' is already bound") + self.seen_result_types[idx] = True match result_def: case OptResultDef(): return OptionalResultVariable(variable_name, idx) @@ -450,10 +449,25 @@ def parse_optional_variable( return VariadicResultVariable(variable_name, idx) case _: return ResultVariable(variable_name, idx) - if isinstance(result_def, VariadicDef): - return VariadicResultVariable(variable_name, idx) - else: - return ResultVariable(variable_name, idx) + + self.raise_error("expected typeable variable to refer to an operand or result") + + def parse_optional_variable( + self, + ) -> FormatDirective | None: + """ + Parse a variable, if present, with the following format: + variable ::= `$` bare-ident + The variable should refer to an operand, attribute, region, or successor. + """ + if self._current_token.text[0] != "$": + return None + self._consume_token() + variable_name = self.parse_identifier(" after '$'") + + # Check if the variable is an operand + if (variable := self._parse_optional_operand(variable_name, True)) is not None: + return variable # Check if the variable is a region for idx, (region_name, region_def) in enumerate(self.op_def.regions): @@ -491,17 +505,14 @@ def parse_optional_variable( attr_name = variable_name attr_or_prop = attr_or_prop_by_name[attr_name] is_property = attr_or_prop == "property" - if self.context == ParsingContext.TopLevel: - if is_property: - if attr_name in self.seen_properties: - self.raise_error(f"property '{variable_name}' is already bound") - self.seen_properties.add(attr_name) - else: - if attr_name in self.seen_attributes: - self.raise_error( - f"attribute '{variable_name}' is already bound" - ) - self.seen_attributes.add(attr_name) + if is_property: + if attr_name in self.seen_properties: + self.raise_error(f"property '{variable_name}' is already bound") + self.seen_properties.add(attr_name) + else: + if attr_name in self.seen_attributes: + self.raise_error(f"attribute '{variable_name}' is already bound") + self.seen_attributes.add(attr_name) attr_def = ( self.op_def.properties.get(attr_name) @@ -528,7 +539,7 @@ def parse_optional_variable( unique_base.get_type_index() ] if type_constraint.can_infer(set()): - unique_type = type_constraint.infer({}) + unique_type = type_constraint.infer(InferenceContext()) if ( unique_base is not None and unique_base in Builtin.attributes @@ -559,63 +570,23 @@ def parse_optional_variable( self.raise_error( "expected variable to refer to an operand, " - "attribute, region, result, or successor" + "attribute, region, or successor" ) def parse_type_directive(self) -> FormatDirective: """ Parse a type directive with the following format: - type-directive ::= `type` `(` variable `)` + type-directive ::= `type` `(` typeable-directive `)` `type` is expected to have already been parsed """ self.parse_punctuation("(") - - # Update the current context, since we are now in a type directive - previous_context = self.context - self.context = ParsingContext.TypeDirective - - variable = self.parse_optional_variable() - match variable: - case None: - self.raise_error("'type' directive expects a variable argument") - case OptionalOperandVariable(name, index): - if self.seen_operand_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_operand_types[index] = True - res = OptionalOperandTypeDirective(name, index) - case VariadicOperandVariable(name, index): - if self.seen_operand_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_operand_types[index] = True - res = VariadicOperandTypeDirective(name, index) - case OperandVariable(name, index): - if self.seen_operand_types[index]: - self.raise_error(f"type of '{name}' is already bound") - self.seen_operand_types[index] = True - res = OperandTypeDirective(name, index) - case OptionalResultVariable(name, index): - if self.seen_result_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_result_types[index] = True - res = OptionalResultTypeDirective(name, index) - case VariadicResultVariable(name, index): - if self.seen_result_types[index]: - self.raise_error(f"types of '{name}' is already bound") - self.seen_result_types[index] = True - res = VariadicResultTypeDirective(name, index) - case ResultVariable(name, index): - if self.seen_result_types[index]: - self.raise_error(f"type of '{name}' is already bound") - self.seen_result_types[index] = True - res = ResultTypeDirective(name, index) - case AttributeVariable(): - self.raise_error("can only take the type of an operand or result") - case _: - raise ValueError(f"Unexpected variable type {type(variable)}") - + inner = self.parse_typeable_directive() self.parse_punctuation(")") - self.context = previous_context - return res + if isinstance(inner, VariadicTypeableDirective): + return VariadicTypeDirective(inner) + if isinstance(inner, OptionalTypeableDirective): + return OptionalTypeDirective(inner) + return TypeDirective(inner) def parse_optional_group(self) -> FormatDirective: """ @@ -626,7 +597,7 @@ def parse_optional_group(self) -> FormatDirective: anchor: FormatDirective | None = None while not self.parse_optional_punctuation(")"): - then_elements += (self.parse_directive(),) + then_elements += (self.parse_format_directive(),) if self.parse_optional_keyword("^"): if anchor is not None: self.raise_error("An optional group can only have one anchor.") @@ -709,7 +680,16 @@ def parse_keyword_or_punctuation(self) -> FormatDirective: self.parse_characters("`") return KeywordDirective(ident) - def parse_directive(self) -> FormatDirective: + def parse_typeable_directive(self) -> AnyTypeableDirective: + """ + Parse a typeable directive, with the following format: + directive ::= variable + """ + if variable := self.parse_optional_typeable_variable(): + return variable + self.raise_error(f"unexpected token '{self._current_token.text}'") + + def parse_format_directive(self) -> FormatDirective: """ Parse a format directive, with the following format: directive ::= `attr-dict` diff --git a/xdsl/parser/base_parser.py b/xdsl/parser/base_parser.py index 792f87e9e8..3586c7b6cc 100644 --- a/xdsl/parser/base_parser.py +++ b/xdsl/parser/base_parser.py @@ -359,6 +359,36 @@ def parse_integer( "Expected integer literal" + context_msg, ) + def parse_optional_float( + self, + *, + allow_negative: bool = True, + ) -> float | None: + """ + Parse a (possibly negative) float, if present. + """ + is_negative = False + if allow_negative: + is_negative = self._parse_optional_token(Token.Kind.MINUS) is not None + + if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: + value = value.get_float_value() + return -value if is_negative else value + + def parse_float( + self, + *, + allow_negative: bool = True, + ) -> float: + """ + Parse a (possibly negative) float. + """ + + return self.expect( + lambda: self.parse_optional_float(allow_negative=allow_negative), + "Expected float literal", + ) + def parse_optional_number( self, *, allow_boolean: bool = False ) -> int | float | None: @@ -376,8 +406,7 @@ def parse_optional_number( ) is not None: return -value if is_negative else value - if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: - value = value.get_float_value() + if (value := self.parse_optional_float(allow_negative=False)) is not None: return -value if is_negative else value if is_negative: diff --git a/xdsl/printer.py b/xdsl/printer.py index 792e08c3b7..f9b8bb4ab4 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -561,13 +561,6 @@ def print_attribute(self, attribute: Attribute) -> None: self.print_identifier_or_string_literal(ref.data) return - if isinstance(attribute, FloatAttr): - attr = cast(AnyFloatAttr, attribute) - self.print_float(attr) - self.print_string(" : ") - self.print_attribute(attr.type) - return - # Complex types have MLIR shorthands but XDSL does not. if isinstance(attribute, ComplexType): self.print_string("complex<") diff --git a/xdsl/utils/scoped_dict.py b/xdsl/utils/scoped_dict.py index ab95061fc8..6fab8ba667 100644 --- a/xdsl/utils/scoped_dict.py +++ b/xdsl/utils/scoped_dict.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, TypeVar +from typing import Generic, TypeVar, overload _Key = TypeVar("_Key") _Value = TypeVar("_Value") @@ -28,11 +28,26 @@ def __init__( parent: ScopedDict[_Key, _Value] | None = None, *, name: str | None = None, + local_scope: dict[_Key, _Value] | None = None, ) -> None: - self._local_scope = {} + self._local_scope = {} if local_scope is None else local_scope self.parent = parent self.name = name + @overload + def get(self, key: _Key, default: None = None) -> _Value | None: ... + + @overload + def get(self, key: _Key, default: _Value) -> _Value: ... + + def get(self, key: _Key, default: _Value | None = None) -> _Value | None: + local = self._local_scope.get(key) + if local is not None: + return local + if self.parent is None: + return default + return self.parent.get(key, default) + def __getitem__(self, key: _Key) -> _Value: """ Fetch key from environment. Attempts to first fetch from current scope,