Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (stream) simplify constr helper on stream attributes #3473

Merged
merged 12 commits into from
Nov 20, 2024
Merged
2 changes: 1 addition & 1 deletion docs/Toy/toy/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
impl_terminator,
register_impls,
)
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr

from .dialects import toy as toy

Expand Down
2 changes: 1 addition & 1 deletion docs/marimo/linalg_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __():
from xdsl.dialects import arith, func, linalg
from xdsl.dialects.builtin import AffineMap, AffineMapAttr, MemRefType, ModuleOp, f64
from xdsl.dialects.riscv import riscv_code
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Attribute, Block, Region, SSAValue
from xdsl.passes import PipelinePass
from xdsl.tools.command_line_tool import get_all_dialects
Expand Down
29 changes: 0 additions & 29 deletions tests/filecheck/dialects/stream/ops.mlir

This file was deleted.

2 changes: 1 addition & 1 deletion tests/interpreters/test_affine_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.func import FuncFunctions
from xdsl.interpreters.memref import MemrefFunctions
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir.affine import AffineMap
from xdsl.utils.test_value import TestSSAValue

Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_builtin_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
i64,
)
from xdsl.interpreter import Interpreter
from xdsl.interpreters import ptr
from xdsl.interpreters.builtin import BuiltinFunctions
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils import ptr

interpreter = Interpreter(ModuleOp([]))
interpreter.register_implementations(BuiltinFunctions())
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_linalg_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from xdsl.interpreter import Interpreter
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.linalg import LinalgFunctions
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Block, Region
from xdsl.ir.affine import AffineExpr, AffineMap
from xdsl.utils.test_value import TestSSAValue
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_memref_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from xdsl.interpreter import Interpreter
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.memref import MemrefFunctions
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr

interpreter = Interpreter(ModuleOp([]), index_bitwidth=32)
interpreter.register_implementations(ArithFunctions())
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_memref_stream_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from xdsl.interpreter import Interpreter
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.memref_stream import MemrefStreamFunctions
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Block, Region
from xdsl.ir.affine import AffineExpr, AffineMap
from xdsl.utils.test_value import TestSSAValue
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_ml_program_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
)
from xdsl.interpreter import Interpreter
from xdsl.interpreters.ml_program import MLProgramFunctions
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr


def test_ml_program_global_load_constant():
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_onnx_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
)
from xdsl.interpreter import Interpreter
from xdsl.interpreters.builtin import BuiltinFunctions
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.utils.exceptions import InterpretationError
from xdsl.utils.test_value import TestSSAValue

Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_ptr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from xdsl.interpreters import ptr
from xdsl.interpreters.utils import ptr


def test_ptr():
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_riscv_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
i32,
)
from xdsl.interpreter import Interpreter, PythonValues
from xdsl.interpreters.ptr import RawPtr, TypedPtr
from xdsl.interpreters.riscv import RiscvFunctions
from xdsl.interpreters.utils.ptr import RawPtr, TypedPtr
from xdsl.ir import Block, Region
from xdsl.utils.bitwise_casts import convert_f32_to_u32
from xdsl.utils.exceptions import InterpretationError
Expand Down
3 changes: 1 addition & 2 deletions tests/interpreters/test_riscv_snitch_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from xdsl.interpreters.func import FuncFunctions
from xdsl.interpreters.riscv import RiscvFunctions
from xdsl.interpreters.riscv_snitch import RiscvSnitchFunctions
from xdsl.interpreters.utils.stream import Acc, Nats
from xdsl.ir import BlockArgument
from xdsl.utils.test_value import TestSSAValue

from .test_stream_interpreter import Acc, Nats


def test_read_write():
interpreter = Interpreter(ModuleOp([]))
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_shaped_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr


def test_shaped_array_offset():
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_snitch_stream_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from xdsl.dialects import riscv, riscv_snitch, snitch_stream, stream
from xdsl.dialects.builtin import ArrayAttr, ModuleOp
from xdsl.interpreter import Interpreter
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.riscv import RiscvFunctions
from xdsl.interpreters.riscv_snitch import RiscvSnitchFunctions
from xdsl.interpreters.snitch_stream import SnitchStreamFunctions
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Block, Region
from xdsl.utils.test_value import TestSSAValue

Expand Down
70 changes: 0 additions & 70 deletions tests/interpreters/test_stream_interpreter.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/interpreters/test_tensor_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from xdsl.dialects import tensor
from xdsl.dialects.builtin import ModuleOp, TensorType, f32, i32
from xdsl.interpreter import Interpreter
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.tensor import TensorFunctions
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.utils.exceptions import InterpretationError
from xdsl.utils.test_value import TestSSAValue

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dialect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from xdsl.context import MLContext
from xdsl.dialects.builtin import DYNAMIC_INDEX, IndexType, IntegerType, i32
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_dynamic_index_list_with_types,
parse_dynamic_index_list_without_types,
parse_dynamic_index_with_type,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/backend/riscv/lowering/convert_memref_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SymbolRefAttr,
UnrealizedConversionCastOp,
)
from xdsl.interpreters.ptr import TypedPtr
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Attribute, Operation, Region, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
UnrankedTensorType,
VectorType,
)
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag
from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyOf,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SymbolRefAttr,
TensorType,
)
from xdsl.dialects.utils import parse_func_op_like, print_func_op_like
from xdsl.dialects.utils.format import parse_func_op_like, print_func_op_like
from xdsl.ir import (
Attribute,
Block,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
TensorType,
)
from xdsl.dialects.experimental import dmp
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/experimental/air.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_call_op_like,
parse_func_op_like,
print_call_op_like,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
TensorType,
i64,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
AbstractYieldOperation,
)
from xdsl.ir import (
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
i32,
i64,
)
from xdsl.dialects.utils import FastMathAttrBase
from xdsl.dialects.utils.fast_math import FastMathAttrBase
from xdsl.ir import (
Attribute,
BitEnumAttribute,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
i32,
i64,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_dynamic_index_list_without_types,
print_dynamic_index_list,
)
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
IntegerType,
StringAttr,
)
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down Expand Up @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr())
outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr)
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
UnitAttr,
i32,
)
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.dialects.utils.format import AbstractYieldOperation
from xdsl.ir import (
Attribute,
Dialect,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
UnitAttr,
i32,
)
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag
from xdsl.ir import (
Attribute,
Block,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
StringAttr,
SymbolRefAttr,
)
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
parse_call_op_like,
parse_func_op_like,
print_call_op_like,
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/riscv_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import Self

from xdsl.dialects.riscv import IntRegisterType, RISCVRegisterType
from xdsl.dialects.utils import (
from xdsl.dialects.utils.format import (
AbstractYieldOperation,
parse_assignment,
print_assignment,
Expand Down
Loading
Loading