Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: make TypedAttribute not generic #3505

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dev = [
"nbval<0.12",
"filecheck==1.0.1",
"lit<19.0.0",
"marimo==0.9.20",
"marimo==0.9.21",
"pre-commit==4.0.1",
"ruff==0.7.4",
"asv<0.7",
Expand Down Expand Up @@ -109,17 +109,19 @@ ignore = [
max-line-length = 300

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"xdsl.parser.core".msg = "Use xdsl.parser instead."
"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead."
"xdsl.dialects.utils.fast_math".msg = "Use xdsl.dialects.utils instead"
"xdsl.dialects.utils.format".msg = "Use xdsl.dialects.utils instead"
"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.core".msg = "Use xdsl.ir instead."
"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead"
"xdsl.irdl.common".msg = "Use xdsl.irdl instead"
"xdsl.irdl.constraints".msg = "Use xdsl.irdl instead"
"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead"
"xdsl.irdl.operations".msg = "Use xdsl.irdl instead"
"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead"
"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.core".msg = "Use xdsl.parser instead."


[tool.ruff.lint.per-file-ignores]
Expand Down
5 changes: 3 additions & 2 deletions tests/dialects/test_bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from xdsl.ir import Attribute
from xdsl.irdl import (
EqAttrConstraint,
InferenceContext,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand All @@ -39,13 +40,13 @@ def test_tensor_from_memref_inference():
EqAttrConstraint(MemRefType(f64, [10, 20, 30]))
)
assert constr2.can_infer(set())
assert constr2.infer({}) == TensorType(f64, [10, 20, 30])
assert constr2.infer(InferenceContext()) == TensorType(f64, [10, 20, 30])

constr3 = TensorFromMemrefConstraint(
EqAttrConstraint(UnrankedMemrefType.from_type(f64))
)
assert constr3.can_infer(set())
assert constr3.infer({}) == UnrankedTensorType(f64)
assert constr3.infer(InferenceContext()) == UnrankedTensorType(f64)


@irdl_op_definition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>)
%18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
%20 = "test.op"() : () -> (memref<64x4096xf32>)

%zero = arith.constant 0: f32
linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>)

linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>)


Expand Down Expand Up @@ -99,17 +102,19 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou
// CHECK-NEXT: %13:2 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>)
// CHECK-NEXT: %broadcasted = linalg.broadcast ins(%13#0 : tensor<16xf32>) outs(%13#1 : tensor<16x64xf32>) dimensions = [1]
// CHECK-NEXT: %{{.*}} = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%1#0, %1#0 : tensor<2x3xf32>, tensor<2x3xf32>) outs(%1#0 : tensor<2x3xf32>) {
// CHECK-NEXT: ^bb0(%in: f32, %in_1: f32, %out: f32):
// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_1 : f32
// CHECK-NEXT: ^bb0(%in: f32, %in_2: f32, %out: f32):
// CHECK-NEXT: %{{.*}} = arith.addf %in, %in_2 : f32
// CHECK-NEXT: linalg.yield %{{.*}} : f32
// CHECK-NEXT: } -> tensor<2x3xf32>
// CHECK-NEXT: %{{.*}} = linalg.sub ins(%{{.*}}, %{{.*}} : tensor<2x3xf32>, tensor<2x3xf32>) outs(%{{.*}} : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %16:2 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
// CHECK-NEXT: %17 = "test.op"() : () -> memref<64x4096xf32>
// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: linalg.fill {id} ins(%cst_0 : f32) outs(%17 : memref<64x4096xf32>)
// CHECK-NEXT: linalg.matmul {id} ins(%16#0, %16#1 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%17 : memref<64x4096xf32>)
// CHECK-NEXT: %18:2 = "test.op"() : () -> (tensor<64x9216xi8>, tensor<9216x4096xi8>)
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32
// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32
// CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32>
// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_0 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
// CHECK-NEXT: }
27 changes: 27 additions & 0 deletions tests/filecheck/transforms/apply-pdl/apply_pdl_build_type.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: xdsl-opt %s -p apply-pdl | filecheck %s

%x = "test.op"() : () -> (i32)

pdl.pattern : benefit(1) {
%in_type = pdl.type: i32
%root = pdl.operation "test.op" -> (%in_type: !pdl.type)
pdl.rewrite %root {
%out_type = pdl.type: i64
%new_op = pdl.operation "test.op" -> (%out_type: !pdl.type)
pdl.replace %root with %new_op
}
}

// CHECK: builtin.module {
// CHECK-NEXT: %x = "test.op"() : () -> i64
// CHECK-NEXT: pdl.pattern : benefit(1) {
// CHECK-NEXT: %in_type = pdl.type : i32
// CHECK-NEXT: %root = pdl.operation "test.op" -> (%in_type : !pdl.type)
// CHECK-NEXT: pdl.rewrite %root {
// CHECK-NEXT: %out_type = pdl.type : i64
// CHECK-NEXT: %new_op = pdl.operation "test.op" -> (%out_type : !pdl.type)
// CHECK-NEXT: pdl.replace %root with %new_op
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:
7 changes: 4 additions & 3 deletions tests/irdl/test_attr_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AttrConstraint,
BaseAttr,
EqAttrConstraint,
InferenceContext,
ParamAttrConstraint,
ParameterDef,
VarConstraint,
Expand Down Expand Up @@ -77,7 +78,7 @@ class WrapAttr(BaseWrapAttr): ...
)

assert constr.can_infer(set())
assert constr.infer({}) == WrapAttr((StringAttr("Hello"),))
assert constr.infer(InferenceContext()) == WrapAttr((StringAttr("Hello"),))

var_constr = ParamAttrConstraint(
WrapAttr,
Expand All @@ -92,7 +93,7 @@ class WrapAttr(BaseWrapAttr): ...
)

assert var_constr.can_infer({"T"})
assert var_constr.infer({"T": StringAttr("Hello")}) == WrapAttr(
assert var_constr.infer(InferenceContext({"T": StringAttr("Hello")})) == WrapAttr(
(StringAttr("Hello"),)
)

Expand Down Expand Up @@ -127,7 +128,7 @@ class NoParamAttr(BaseNoParamAttr): ...
constr = BaseAttr(NoParamAttr)

assert constr.can_infer(set())
assert constr.infer({}) == NoParamAttr()
assert constr.infer(InferenceContext()) == NoParamAttr()

base_constr = BaseAttr(BaseNoParamAttr)
assert not base_constr.can_infer(set())
Expand Down
2 changes: 1 addition & 1 deletion tests/irdl/test_attribute_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_typed_attribute():

@irdl_attr_definition
class TypedAttr( # pyright: ignore[reportUnusedClass]
TypedAttribute[Attribute]
TypedAttribute
):
name = "test.typed"

Expand Down
15 changes: 9 additions & 6 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import textwrap
from collections.abc import Callable
from io import StringIO
from typing import ClassVar, Generic, TypeVar
from typing import Annotated, ClassVar, Generic, TypeVar

import pytest

Expand All @@ -12,6 +12,8 @@
from xdsl.dialects.builtin import (
I32,
BoolAttr,
Float64Type,
FloatAttr,
IntegerAttr,
MemRefType,
ModuleOp,
Expand Down Expand Up @@ -603,20 +605,21 @@ class OptionalAttributeOp(IRDLOperation):
"program, generic_program",
[
(
"test.typed_attr 3",
'"test.typed_attr"() {"attr" = 3 : i32} : () -> ()',
"test.typed_attr 3 3.000000e+00",
'"test.typed_attr"() {"attr" = 3 : i32, "float_attr" = 3.000000e+00 : f64} : () -> ()',
),
],
)
def test_typed_attribute_variable(program: str, generic_program: str):
"""Test the parsing of optional operands"""
"""Test the parsing of typed attributes"""

@irdl_op_definition
class TypedAttributeOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])
float_attr = attr_def(FloatAttr[Annotated[Float64Type, Float64Type()]])

assembly_format = "$attr attr-dict"
assembly_format = "$attr $float_attr attr-dict"

ctx = MLContext()
ctx.load_op(TypedAttributeOp)
Expand Down Expand Up @@ -693,7 +696,7 @@ def test_unknown_variable():
"""Test that variables should refer to an element in the operation."""
with pytest.raises(
PyRDLOpDefinitionError,
match="expected variable to refer to an operand, attribute, region, result, or successor",
match="expected variable to refer to an operand, attribute, region, or successor",
):

@irdl_op_definition
Expand Down
13 changes: 13 additions & 0 deletions tests/utils/test_scoped_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,16 @@ def test_simple():
assert 3 not in table
assert 3 in inner
assert 4 not in inner


def test_get():
parent = ScopedDict(local_scope={"a": 1, "b": 2})
child = ScopedDict(parent, local_scope={"a": 3, "c": 4})

assert child.get("a") == 3
assert child.get("b") == 2
assert child.get("c") == 4
assert child.get("d") is None

assert child.get("a", 5) == 3
assert child.get("d", 5) == 5
6 changes: 3 additions & 3 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from xdsl.irdl import (
AttrSizedOperandSegments,
ConstraintContext,
ConstraintVariableType,
GenericAttrConstraint,
InferenceContext,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand Down Expand Up @@ -53,9 +53,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.memref_constraint.can_infer(var_constraint_names)

def infer(
self, variables: dict[str, ConstraintVariableType]
self, context: InferenceContext
) -> TensorType[Attribute] | UnrankedTensorType[Attribute]:
memref_type = self.memref_constraint.infer(variables)
memref_type = self.memref_constraint.infer(context)
if isinstance(memref_type, MemRefType):
return TensorType(memref_type.element_type, memref_type.shape)
return UnrankedTensorType(memref_type.element_type)
Expand Down
19 changes: 15 additions & 4 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ class IndexType(ParametrizedAttribute):
@irdl_attr_definition
class IntegerAttr(
Generic[_IntegerAttrType],
TypedAttribute[_IntegerAttrType],
TypedAttribute,
):
name = "integer"
value: ParameterDef[IntAttr]
Expand Down Expand Up @@ -504,8 +504,8 @@ def verify(self) -> None:
@staticmethod
def parse_with_type(
parser: AttrParser,
type: AttributeInvT,
) -> TypedAttribute[AttributeInvT]:
type: Attribute,
) -> TypedAttribute:
assert isinstance(type, IntegerType | IndexType)
return IntegerAttr(parser.parse_integer(allow_boolean=(type == i1)), type)

Expand Down Expand Up @@ -634,7 +634,7 @@ def __hash__(self):


@irdl_attr_definition
class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute):
class FloatAttr(Generic[_FloatAttrType], TypedAttribute):
name = "float"

value: ParameterDef[FloatData]
Expand Down Expand Up @@ -668,6 +668,17 @@ def __init__(
raise ValueError(f"Invalid bitwidth: {type}")
super().__init__([data_attr, type])

@staticmethod
def parse_with_type(
parser: AttrParser,
type: Attribute,
) -> TypedAttribute:
assert isinstance(type, AnyFloat)
return FloatAttr(parser.parse_float(), type)

def print_without_type(self, printer: Printer):
return printer.print_float(self)


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)
Expand Down
2 changes: 2 additions & 0 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,8 @@ class FillOp(NamedOpBase):

name = "linalg.fill"

PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True

def __init__(
self,
inputs: Sequence[SSAValue],
Expand Down
37 changes: 34 additions & 3 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
opt_prop_def,
prop_def,
region_def,
result_def,
traits_def,
var_operand_def,
)
Expand Down Expand Up @@ -187,14 +188,44 @@ def offsets(self) -> tuple[tuple[int, ...], ...]:


@irdl_op_definition
class ReadOp(stream.ReadOperation):
class ReadOp(IRDLOperation):
name = "memref_stream.read"

T: ClassVar = VarConstraint("T", AnyAttr())

stream = operand_def(stream.ReadableStreamType.constr(T))
res = result_def(T)

assembly_format = "`from` $stream attr-dict `:` type($res)"

def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None):
if result_type is None:
assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType)
stream_type = cast(stream.ReadableStreamType[Attribute], stream_type)
result_type = stream_type.element_type
super().__init__(operands=[stream_val], result_types=[result_type])

def assembly_line(self) -> str | None:
return None


@irdl_op_definition
class WriteOp(stream.WriteOperation):
class WriteOp(IRDLOperation):
name = "memref_stream.write"

T: ClassVar = VarConstraint("T", AnyAttr())

value = operand_def(T)
stream = operand_def(stream.WritableStreamType.constr(T))

assembly_format = "$value `to` $stream attr-dict `:` type($value)"

def __init__(self, value: SSAValue, stream: SSAValue):
super().__init__(operands=[value, stream])

def assembly_line(self) -> str | None:
return None


@irdl_op_definition
class StreamingRegionOp(IRDLOperation):
Expand Down Expand Up @@ -370,7 +401,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr())
outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr)
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
Loading
Loading