Skip to content

Commit

Permalink
core: make TypedAttribute not generic (#3505)
Browse files Browse the repository at this point in the history
The bit that used the generic wasn't type-sound anyway.
  • Loading branch information
alexarice authored Nov 22, 2024
1 parent ae6383b commit e0224a8
Show file tree
Hide file tree
Showing 25 changed files with 628 additions and 624 deletions.
18 changes: 10 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dev = [
"nbval<0.12",
"filecheck==1.0.1",
"lit<19.0.0",
"marimo==0.9.20",
"marimo==0.9.21",
"pre-commit==4.0.1",
"ruff==0.7.4",
"asv<0.7",
Expand Down Expand Up @@ -109,17 +109,19 @@ ignore = [
max-line-length = 300

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"xdsl.parser.core".msg = "Use xdsl.parser instead."
"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead."
"xdsl.dialects.utils.fast_math".msg = "Use xdsl.dialects.utils instead"
"xdsl.dialects.utils.format".msg = "Use xdsl.dialects.utils instead"
"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.core".msg = "Use xdsl.ir instead."
"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead"
"xdsl.irdl.common".msg = "Use xdsl.irdl instead"
"xdsl.irdl.constraints".msg = "Use xdsl.irdl instead"
"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead"
"xdsl.irdl.operations".msg = "Use xdsl.irdl instead"
"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead"
"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.core".msg = "Use xdsl.parser instead."


[tool.ruff.lint.per-file-ignores]
Expand Down
5 changes: 3 additions & 2 deletions tests/dialects/test_bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from xdsl.ir import Attribute
from xdsl.irdl import (
EqAttrConstraint,
InferenceContext,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand All @@ -39,13 +40,13 @@ def test_tensor_from_memref_inference():
EqAttrConstraint(MemRefType(f64, [10, 20, 30]))
)
assert constr2.can_infer(set())
assert constr2.infer({}) == TensorType(f64, [10, 20, 30])
assert constr2.infer(InferenceContext()) == TensorType(f64, [10, 20, 30])

constr3 = TensorFromMemrefConstraint(
EqAttrConstraint(UnrankedMemrefType.from_type(f64))
)
assert constr3.can_infer(set())
assert constr3.infer({}) == UnrankedTensorType(f64)
assert constr3.infer(InferenceContext()) == UnrankedTensorType(f64)


@irdl_op_definition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>)
%18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
%20 = "test.op"() : () -> (memref<64x4096xf32>)

%zero = arith.constant 0: f32
linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>)

linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>)


Expand Down Expand Up @@ -99,17 +102,19 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou
// CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>)
// CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1]
// CHECK-NEXT: %{{.*}} = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) {
// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32):
// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32
// CHECK-NEXT: ^bb0(%in: f32, %in_2: f32, %out: f32):
// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_2 : f32
// CHECK-NEXT: linalg.yield %{{.*}} : f32
// CHECK-NEXT: } -> tensor<2x3xf32>
// CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
// CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32>
// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: linalg.fill {id} ins(%cst_0 : f32) outs(%17 : memref<64x4096xf32>)
// CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>)
// CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>)
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32
// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32
// CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32>
// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
// CHECK-NEXT: }
27 changes: 27 additions & 0 deletions tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: xdsl-opt %s -p apply-pdl | filecheck %s

%x = "test.op"() : () -> (i32)

pdl.pattern : benefit(1) {
%in_type = pdl.type: i32
%root = pdl.operation "test.op" -> (%in_type: !pdl.type)
pdl.rewrite %root {
%out_type = pdl.type: i64
%new_op = pdl.operation "test.op" -> (%out_type: !pdl.type)
pdl.replace %root with %new_op
}
}

// CHECK: builtin.module {
// CHECK-NEXT: %x = "test.op"() : () -> i64
// CHECK-NEXT: pdl.pattern : benefit(1) {
// CHECK-NEXT: %in_type = pdl.type : i32
// CHECK-NEXT: %root = pdl.operation "test.op" -> (%in_type : !pdl.type)
// CHECK-NEXT: pdl.rewrite %root {
// CHECK-NEXT: %out_type = pdl.type : i64
// CHECK-NEXT: %new_op = pdl.operation "test.op" -> (%out_type : !pdl.type)
// CHECK-NEXT: pdl.replace %root with %new_op
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:
7 changes: 4 additions & 3 deletions tests/irdl/test_attr_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AttrConstraint,
BaseAttr,
EqAttrConstraint,
InferenceContext,
ParamAttrConstraint,
ParameterDef,
VarConstraint,
Expand Down Expand Up @@ -77,7 +78,7 @@ class WrapAttr(BaseWrapAttr): ...
)

assert constr.can_infer(set())
assert constr.infer({}) == WrapAttr((StringAttr("Hello"),))
assert constr.infer(InferenceContext()) == WrapAttr((StringAttr("Hello"),))

var_constr = ParamAttrConstraint(
WrapAttr,
Expand All @@ -92,7 +93,7 @@ class WrapAttr(BaseWrapAttr): ...
)

assert var_constr.can_infer({"T"})
assert var_constr.infer({"T": StringAttr("Hello")}) == WrapAttr(
assert var_constr.infer(InferenceContext({"T": StringAttr("Hello")})) == WrapAttr(
(StringAttr("Hello"),)
)

Expand Down Expand Up @@ -127,7 +128,7 @@ class NoParamAttr(BaseNoParamAttr): ...
constr = BaseAttr(NoParamAttr)

assert constr.can_infer(set())
assert constr.infer({}) == NoParamAttr()
assert constr.infer(InferenceContext()) == NoParamAttr()

base_constr = BaseAttr(BaseNoParamAttr)
assert not base_constr.can_infer(set())
Expand Down
2 changes: 1 addition & 1 deletion tests/irdl/test_attribute_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_typed_attribute():

@irdl_attr_definition
class TypedAttr( # pyright: ignore[reportUnusedClass]
TypedAttribute[Attribute]
TypedAttribute
):
name = "test.typed"

Expand Down
15 changes: 9 additions & 6 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import textwrap
from collections.abc import Callable
from io import StringIO
from typing import ClassVar, Generic, TypeVar
from typing import Annotated, ClassVar, Generic, TypeVar

import pytest

Expand All @@ -12,6 +12,8 @@
from xdsl.dialects.builtin import (
I32,
BoolAttr,
Float64Type,
FloatAttr,
IntegerAttr,
MemRefType,
ModuleOp,
Expand Down Expand Up @@ -603,20 +605,21 @@ class OptionalAttributeOp(IRDLOperation):
"program, generic_program",
[
(
"test.typed_attr 3",
'"test.typed_attr"() {"attr" = 3 : i32} : () -> ()',
"test.typed_attr 3 3.000000e+00",
'"test.typed_attr"() {"attr" = 3 : i32, "float_attr" = 3.000000e+00 : f64} : () -> ()',
),
],
)
def test_typed_attribute_variable(program: str, generic_program: str):
"""Test the parsing of optional operands"""
"""Test the parsing of typed attributes"""

@irdl_op_definition
class TypedAttributeOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])
float_attr = attr_def(FloatAttr[Annotated[Float64Type, Float64Type()]])

assembly_format = "$attr attr-dict"
assembly_format = "$attr $float_attr attr-dict"

ctx = MLContext()
ctx.load_op(TypedAttributeOp)
Expand Down Expand Up @@ -693,7 +696,7 @@ def test_unknown_variable():
"""Test that variables should refer to an element in the operation."""
with pytest.raises(
PyRDLOpDefinitionError,
match="expected variable to refer to an operand, attribute, region, result, or successor",
match="expected variable to refer to an operand, attribute, region, or successor",
):

@irdl_op_definition
Expand Down
13 changes: 13 additions & 0 deletions tests/utils/test_scoped_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,16 @@ def test_simple():
assert 3 not in table
assert 3 in inner
assert 4 not in inner


def test_get():
parent = ScopedDict(local_scope={"a": 1, "b": 2})
child = ScopedDict(parent, local_scope={"a": 3, "c": 4})

assert child.get("a") == 3
assert child.get("b") == 2
assert child.get("c") == 4
assert child.get("d") is None

assert child.get("a", 5) == 3
assert child.get("d", 5) == 5
6 changes: 3 additions & 3 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from xdsl.irdl import (
AttrSizedOperandSegments,
ConstraintContext,
ConstraintVariableType,
GenericAttrConstraint,
InferenceContext,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand Down Expand Up @@ -53,9 +53,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.memref_constraint.can_infer(var_constraint_names)

def infer(
self, variables: dict[str, ConstraintVariableType]
self, context: InferenceContext
) -> TensorType[Attribute] | UnrankedTensorType[Attribute]:
memref_type = self.memref_constraint.infer(variables)
memref_type = self.memref_constraint.infer(context)
if isinstance(memref_type, MemRefType):
return TensorType(memref_type.element_type, memref_type.shape)
return UnrankedTensorType(memref_type.element_type)
Expand Down
19 changes: 15 additions & 4 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ class IndexType(ParametrizedAttribute):
@irdl_attr_definition
class IntegerAttr(
Generic[_IntegerAttrType],
TypedAttribute[_IntegerAttrType],
TypedAttribute,
):
name = "integer"
value: ParameterDef[IntAttr]
Expand Down Expand Up @@ -504,8 +504,8 @@ def verify(self) -> None:
@staticmethod
def parse_with_type(
parser: AttrParser,
type: AttributeInvT,
) -> TypedAttribute[AttributeInvT]:
type: Attribute,
) -> TypedAttribute:
assert isinstance(type, IntegerType | IndexType)
return IntegerAttr(parser.parse_integer(allow_boolean=(type == i1)), type)

Expand Down Expand Up @@ -634,7 +634,7 @@ def __hash__(self):


@irdl_attr_definition
class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute):
class FloatAttr(Generic[_FloatAttrType], TypedAttribute):
name = "float"

value: ParameterDef[FloatData]
Expand Down Expand Up @@ -668,6 +668,17 @@ def __init__(
raise ValueError(f"Invalid bitwidth: {type}")
super().__init__([data_attr, type])

@staticmethod
def parse_with_type(
parser: AttrParser,
type: Attribute,
) -> TypedAttribute:
assert isinstance(type, AnyFloat)
return FloatAttr(parser.parse_float(), type)

def print_without_type(self, printer: Printer):
return printer.print_float(self)


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)
Expand Down
2 changes: 2 additions & 0 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ class FillOp(NamedOpBase):

name = "linalg.fill"

PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True

def __init__(
self,
inputs: Sequence[SSAValue],
Expand Down
37 changes: 34 additions & 3 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
opt_prop_def,
prop_def,
region_def,
result_def,
traits_def,
var_operand_def,
)
Expand Down Expand Up @@ -187,14 +188,44 @@ def offsets(self) -> tuple[tuple[int, ...], ...]:


@irdl_op_definition
class ReadOp(stream.ReadOperation):
class ReadOp(IRDLOperation):
name = "memref_stream.read"

T: ClassVar = VarConstraint("T", AnyAttr())

stream = operand_def(stream.ReadableStreamType.constr(T))
res = result_def(T)

assembly_format = "`from` $stream attr-dict `:` type($res)"

def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None):
if result_type is None:
assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType)
stream_type = cast(stream.ReadableStreamType[Attribute], stream_type)
result_type = stream_type.element_type
super().__init__(operands=[stream_val], result_types=[result_type])

def assembly_line(self) -> str | None:
return None


@irdl_op_definition
class WriteOp(stream.WriteOperation):
class WriteOp(IRDLOperation):
name = "memref_stream.write"

T: ClassVar = VarConstraint("T", AnyAttr())

value = operand_def(T)
stream = operand_def(stream.WritableStreamType.constr(T))

assembly_format = "$value `to` $stream attr-dict `:` type($value)"

def __init__(self, value: SSAValue, stream: SSAValue):
super().__init__(operands=[value, stream])

def assembly_line(self) -> str | None:
return None


@irdl_op_definition
class StreamingRegionOp(IRDLOperation):
Expand Down Expand Up @@ -370,7 +401,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr())
outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr)
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
Loading

0 comments on commit e0224a8

Please sign in to comment.