Skip to content

Commit

Permalink
Merge branch 'main' into sasha/eqsat/init
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 24, 2024
2 parents 8c4fdf4 + d9e2fa1 commit 694dd7c
Show file tree
Hide file tree
Showing 77 changed files with 2,042 additions and 1,296 deletions.
2 changes: 1 addition & 1 deletion docs/Toy/toy/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
impl_terminator,
register_impls,
)
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr

from .dialects import toy as toy

Expand Down
2 changes: 1 addition & 1 deletion docs/marimo/linalg_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __():
from xdsl.dialects import arith, func, linalg
from xdsl.dialects.builtin import AffineMap, AffineMapAttr, MemRefType, ModuleOp, f64
from xdsl.dialects.riscv import riscv_code
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Attribute, Block, Region, SSAValue
from xdsl.passes import PipelinePass
from xdsl.tools.command_line_tool import get_all_dialects
Expand Down
26 changes: 14 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ dev = [
"nbval<0.12",
"filecheck==1.0.1",
"lit<19.0.0",
"marimo==0.9.20",
"marimo==0.9.21",
"pre-commit==4.0.1",
"ruff==0.7.4",
"asv<0.7",
"nbconvert>=7.7.2,<8.0.0",
"textual-dev==1.6.1",
"textual-dev==1.7.0",
"pytest-asyncio==0.24.0",
"pyright==1.1.389",
]
gui = ["textual==0.86.1", "pyclip==0.7"]
gui = ["textual==0.86.3", "pyclip==0.7"]
jax = ["jax==0.4.35", "numpy==2.1.3"]
onnx = ["onnx==1.17.0", "numpy==2.1.3"]
riscv = ["riscemu==2.2.7"]
wgpu = ["wgpu==0.19.0"]
wgpu = ["wgpu==0.19.1"]

[project.urls]
Homepage = "https://xdsl.dev/"
Expand Down Expand Up @@ -77,7 +77,7 @@ typeCheckingMode = "strict"
"tests/test_frontend_python_code_check.py",
"xdsl/frontend/onnx/ir_builder.py",
"xdsl/frontend/onnx/type.py",
"xdsl/interpreters/experimental/wgpu.py",
"xdsl/interpreters/wgpu.py",
]
"ignore" = [
"docs/marimo",
Expand Down Expand Up @@ -109,17 +109,19 @@ ignore = [
max-line-length = 300

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"xdsl.parser.core".msg = "Use xdsl.parser instead."
"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead."
"xdsl.dialects.utils.fast_math".msg = "Use xdsl.dialects.utils instead"
"xdsl.dialects.utils.format".msg = "Use xdsl.dialects.utils instead"
"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.core".msg = "Use xdsl.ir instead."
"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead"
"xdsl.irdl.common".msg = "Use xdsl.irdl instead"
"xdsl.irdl.constraints".msg = "Use xdsl.irdl instead"
"xdsl.irdl.attributes".msg = "Use xdsl.irdl instead"
"xdsl.irdl.operations".msg = "Use xdsl.irdl instead"
"xdsl.ir.affine.affine_expr".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_map".msg = "Use xdsl.ir.affine instead"
"xdsl.ir.affine.affine_set".msg = "Use xdsl.ir.affine instead"
"xdsl.parser.affine_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.attribute_parser".msg = "Use xdsl.parser instead."
"xdsl.parser.core".msg = "Use xdsl.parser instead."


[tool.ruff.lint.per-file-ignores]
Expand Down
137 changes: 135 additions & 2 deletions tests/dialects/test_bufferization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,139 @@
from xdsl.dialects.bufferization import AllocTensorOp, ToTensorOp
from xdsl.dialects.builtin import MemRefType, TensorType, UnitAttr, f64
from typing import ClassVar

import pytest

from xdsl.dialects.bufferization import (
AllocTensorOp,
TensorFromMemrefConstraint,
ToTensorOp,
)
from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
AnyUnrankedMemrefTypeConstr,
IndexType,
IntegerType,
MemRefType,
TensorType,
UnitAttr,
UnrankedMemrefType,
UnrankedTensorType,
f64,
)
from xdsl.dialects.test import TestOp
from xdsl.ir import Attribute
from xdsl.irdl import (
EqAttrConstraint,
InferenceContext,
IRDLOperation,
VarConstraint,
irdl_op_definition,
operand_def,
)
from xdsl.utils.exceptions import VerifyException


def test_tensor_from_memref_inference():
constr = TensorFromMemrefConstraint(AnyMemRefTypeConstr)
assert not constr.can_infer(set())

constr2 = TensorFromMemrefConstraint(
EqAttrConstraint(MemRefType(f64, [10, 20, 30]))
)
assert constr2.can_infer(set())
assert constr2.infer(InferenceContext()) == TensorType(f64, [10, 20, 30])

constr3 = TensorFromMemrefConstraint(
EqAttrConstraint(UnrankedMemrefType.from_type(f64))
)
assert constr3.can_infer(set())
assert constr3.infer(InferenceContext()) == UnrankedTensorType(f64)


@irdl_op_definition
class TensorFromMemref(IRDLOperation):
name = "test.tensor_from_memref"
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)

in_tensor = operand_def(
TensorFromMemrefConstraint(
MemRefType.constr(element_type=EqAttrConstraint(IndexType()))
)
)

in_var_memref = operand_def(T)

in_var_tensor = operand_def(TensorFromMemrefConstraint(T))


def test_tensor_from_memref_constraint():
[v_memref, v_tensor] = TestOp(
result_types=[
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
]
).res
op1 = TensorFromMemref(operands=(v_tensor, v_memref, v_tensor))
op1.verify()

[v_unranked_memref, v_unranked_tensor] = TestOp(
result_types=[
UnrankedMemrefType.from_type(IndexType()),
UnrankedTensorType(IndexType()),
]
).res
op2 = TensorFromMemref(operands=(v_tensor, v_unranked_memref, v_unranked_tensor))
op2.verify()


@pytest.mark.parametrize(
"type1, type2, type3, error",
[
(
IndexType(),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"Expected tensor or unranked tensor type, got index",
),
(
TensorType(IntegerType(32), [10, 10, 10]),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"Expected attribute index but got i32",
),
(
UnrankedTensorType(IndexType()),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"memref<\\*xindex> should be of base attribute memref",
),
(
TensorType(IndexType(), [10, 10, 10]),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 20]),
"attribute memref<10x20x30xindex> expected from variable 'T', but got memref<10x20x20xindex>",
),
(
TensorType(IndexType(), [10, 10, 10]),
MemRefType(IntegerType(32), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"attribute memref<10x20x30xi32> expected from variable 'T', but got memref<10x20x30xindex>",
),
],
)
def test_tensor_from_memref_constraint_failure(
type1: Attribute, type2: Attribute, type3: Attribute, error: str
):
[v1, v2, v3] = TestOp(
result_types=[
type1,
type2,
type3,
]
).res

op1 = TensorFromMemref(operands=(v1, v2, v3))
with pytest.raises(VerifyException, match=error):
op1.verify()


def test_to_tensor():
Expand Down
12 changes: 6 additions & 6 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,15 @@ def test_vector_rank_constraint_verify():
vector_type = VectorType(i32, [1, 2])
constraint = VectorRankConstraint(2)

constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())


def test_vector_rank_constraint_rank_mismatch():
vector_type = VectorType(i32, [1, 2])
constraint = VectorRankConstraint(3)

with pytest.raises(VerifyException) as e:
constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())
assert e.value.args[0] == "Expected vector rank to be 3, got 2."


Expand All @@ -180,23 +180,23 @@ def test_vector_rank_constraint_attr_mismatch():
constraint = VectorRankConstraint(3)

with pytest.raises(VerifyException) as e:
constraint.verify(memref_type)
constraint.verify(memref_type, ConstraintContext())
assert e.value.args[0] == "memref<1x2xi32> should be of type VectorType."


def test_vector_base_type_constraint_verify():
vector_type = VectorType(i32, [1, 2])
constraint = VectorBaseTypeConstraint(i32)

constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())


def test_vector_base_type_constraint_type_mismatch():
vector_type = VectorType(i32, [1, 2])
constraint = VectorBaseTypeConstraint(i64)

with pytest.raises(VerifyException) as e:
constraint.verify(vector_type)
constraint.verify(vector_type, ConstraintContext())
assert e.value.args[0] == "Expected vector type to be i64, got i32."


Expand All @@ -205,7 +205,7 @@ def test_vector_base_type_constraint_attr_mismatch():
constraint = VectorBaseTypeConstraint(i32)

with pytest.raises(VerifyException) as e:
constraint.verify(memref_type)
constraint.verify(memref_type, ConstraintContext())
assert e.value.args[0] == "memref<1x2xi32> should be of type VectorType."


Expand Down
9 changes: 9 additions & 0 deletions tests/filecheck/dialects/arm/test_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP


// CHECK: %x1 = arm.get_register : !arm.reg<x1>
%x1 = arm.get_register : !arm.reg<x1>


// CHECK-GENERIC: %x1 = "arm.get_register"() : () -> !arm.reg<x1>
34 changes: 17 additions & 17 deletions tests/filecheck/dialects/fsm/fsm_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@
"fsm.transition"() ({
"fsm.return"(%arg0) : (i1) -> ()
}, {

}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()
}) {function_type = (i1) -> (i1), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> ()

Expand All @@ -160,7 +160,7 @@
"fsm.output"() : () -> ()
}, {
"fsm.transition"() ({
^bb1(%arg3: i1):
^bb1(%arg3: i1):
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
"fsm.output"() : () -> ()
}, {
Expand All @@ -179,16 +179,16 @@

"fsm.state"() ({
"fsm.output"() : () -> ()

}, {
"fsm.transition"() ({

}, {
^bb1(%arg3: i1):
^bb1(%arg3: i1):
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()
}) {function_type = () -> (), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> ()

Expand All @@ -203,10 +203,10 @@

}, {
"fsm.transition"() ({

}, {
}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()

}) {function_type = (i16) -> (i16) , initialState = "A", sym_name = "foo"} : () -> ()
Expand Down Expand Up @@ -257,7 +257,7 @@
}, {
}) {nextState = @A} : () -> ()
}) {sym_name = "A"} : () -> ()

}) {function_type = (i16) -> (i1), initialState = "A", sym_name = "foo"} : () -> ()
%arg1 = "arith.constant"() {value = 0 : i16} : () -> i16
%arg2 = "arith.constant"() {value = 0 : i16} : () -> i16
Expand Down Expand Up @@ -328,12 +328,12 @@
}) {nextState = @C} : () -> ()
}) {sym_name = "C"} : () -> ()
}) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> ()

"func.func"() ({
%3 = "arith.constant"() {value = 16: i16} : () -> i16

%4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i16
%1 = "arith.constant"() {value = 0 : i16} : () -> i16
%2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "qux"} : () -> ()
Expand Down Expand Up @@ -371,10 +371,10 @@
}) {nextState = @C} : () -> ()
}) {sym_name = "C"} : () -> ()
}) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> ()

"func.func"() ({
%3 = "arith.constant"() {value = 16: i16} : () -> i16

%4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i1
%2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i16
Expand All @@ -391,8 +391,8 @@
%0 = "fsm.variable"() {initValue = 0 : i16, name = "cnt"} : () -> i16
"fsm.machine"() ({
%4 = "test.op"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i16
%2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1
%1 = "arith.constant"() {value = true} : () -> i1
%2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i1
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "qux"} : () -> ()

Expand Down
Loading

0 comments on commit 694dd7c

Please sign in to comment.