From 5b1582493d2c0340a0d9776b43c10e6f66444cc7 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 09:25:36 +0000 Subject: [PATCH 1/9] dialects: (stream) remove unused stream.read and stream.write ops --- tests/filecheck/dialects/stream/ops.mlir | 29 -------- .../test_riscv_snitch_interpreter.py | 3 +- tests/interpreters/test_stream_interpreter.py | 70 ------------------- xdsl/dialects/stream.py | 16 +---- xdsl/interpreters/stream.py | 46 ------------ xdsl/interpreters/{ => utils}/ptr.py | 0 xdsl/interpreters/utils/stream.py | 52 ++++++++++++++ 7 files changed, 54 insertions(+), 162 deletions(-) delete mode 100644 tests/filecheck/dialects/stream/ops.mlir delete mode 100644 tests/interpreters/test_stream_interpreter.py delete mode 100644 xdsl/interpreters/stream.py rename xdsl/interpreters/{ => utils}/ptr.py (100%) create mode 100644 xdsl/interpreters/utils/stream.py diff --git a/tests/filecheck/dialects/stream/ops.mlir b/tests/filecheck/dialects/stream/ops.mlir deleted file mode 100644 index 1eba0c6694..0000000000 --- a/tests/filecheck/dialects/stream/ops.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: XDSL_ROUNDTRIP -// RUN: XDSL_GENERIC_ROUNDTRIP - -// CHECK: builtin.module { - - -%readable_stream = "test.op"() : () -> !stream.readable -// CHECK-NEXT: %readable_stream = "test.op"() : () -> !stream.readable - -%writable_stream = "test.op"() : () -> !stream.writable -// CHECK-NEXT: %writable_stream = "test.op"() : () -> !stream.writable - -%value = stream.read from %readable_stream : index -// CHECK-NEXT: %value = stream.read from %readable_stream : index - -stream.write %value to %writable_stream : index -// CHECK-NEXT: stream.write %value to %writable_stream : index - - -// CHECK-NEXT: } - - - -// CHECK-GENERIC: "builtin.module"() ({ -// CHECK-NEXT-GENERIC: %readable_stream = "test.op"() : () -> !stream.readable -// CHECK-NEXT-GENERIC: %writable_stream = "test.op"() : () -> !stream.writable -// CHECK-NEXT-GENERIC: "stream.read"(%readable_stream) : (!stream.readable) -> index -// CHECK-NEXT-GENERIC: "stream.write"(%value, %writable_stream) : (index, !stream.writable) -> () -// CHECK-NEXT-GENERIC: }) : () -> () diff --git a/tests/interpreters/test_riscv_snitch_interpreter.py b/tests/interpreters/test_riscv_snitch_interpreter.py index 286c2ff2cc..e629178c94 100644 --- a/tests/interpreters/test_riscv_snitch_interpreter.py +++ b/tests/interpreters/test_riscv_snitch_interpreter.py @@ -5,11 +5,10 @@ from xdsl.interpreters.func import FuncFunctions from xdsl.interpreters.riscv import RiscvFunctions from xdsl.interpreters.riscv_snitch import RiscvSnitchFunctions +from xdsl.interpreters.stream import Acc, Nats from xdsl.ir import BlockArgument from xdsl.utils.test_value import TestSSAValue -from .test_stream_interpreter import Acc, Nats - def test_read_write(): interpreter = Interpreter(ModuleOp([])) diff --git a/tests/interpreters/test_stream_interpreter.py b/tests/interpreters/test_stream_interpreter.py deleted file mode 100644 index 02449d654a..0000000000 --- a/tests/interpreters/test_stream_interpreter.py +++ /dev/null @@ -1,70 +0,0 @@ -from dataclasses import dataclass, field - -from xdsl.dialects import stream -from xdsl.dialects.builtin import IndexType, ModuleOp -from xdsl.interpreter import Interpreter -from xdsl.interpreters.stream import ( - ReadableStream, - StreamFunctions, - WritableStream, -) -from xdsl.utils.test_value import TestSSAValue - - -@dataclass -class Nats(ReadableStream[int]): - index = 0 - - def read(self) -> int: - self.index += 1 - return self.index - - -@dataclass -class Acc(WritableStream[int]): - values: list[int] = field(default_factory=list) - - def write(self, value: int) -> None: - return self.values.append(value) - - -def test_read_write(): - interpreter = Interpreter(ModuleOp([])) - interpreter.register_implementations(StreamFunctions()) - - input_stream = Nats() - output_stream = Acc() - - index = IndexType() - - (value,) = interpreter.run_op( - stream.ReadOp(TestSSAValue(stream.ReadableStreamType(index))), (input_stream,) - ) - assert value == 1 - - (value,) = interpreter.run_op( - stream.ReadOp(TestSSAValue(stream.ReadableStreamType(index))), (input_stream,) - ) - assert value == 2 - - interpreter.run_op( - stream.WriteOp( - TestSSAValue(index), TestSSAValue(stream.ReadableStreamType(index)) - ), - ( - 1, - output_stream, - ), - ) - assert output_stream.values == [1] - - interpreter.run_op( - stream.WriteOp( - TestSSAValue(index), TestSSAValue(stream.ReadableStreamType(index)) - ), - ( - 2, - output_stream, - ), - ) - assert output_stream.values == [1, 2] diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index a44d9dddb0..3c2d5ea2d6 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -22,7 +22,6 @@ ParameterDef, VarConstraint, irdl_attr_definition, - irdl_op_definition, operand_def, result_def, ) @@ -210,22 +209,9 @@ def print(self, printer: Printer): printer.print_attribute(self.value.type) -@irdl_op_definition -class ReadOp(ReadOperation): - name = "stream.read" - - -@irdl_op_definition -class WriteOp(WriteOperation): - name = "stream.write" - - Stream = Dialect( "stream", - [ - ReadOp, - WriteOp, - ], + [], [ ReadableStreamType, WritableStreamType, diff --git a/xdsl/interpreters/stream.py b/xdsl/interpreters/stream.py deleted file mode 100644 index 8d38e458c0..0000000000 --- a/xdsl/interpreters/stream.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any, TypeVar - -from typing_extensions import Protocol - -from xdsl.dialects import stream -from xdsl.interpreter import ( - Interpreter, - InterpreterFunctions, - PythonValues, - impl, - register_impls, -) - -T = TypeVar("T") -TCov = TypeVar("TCov", covariant=True) -TCon = TypeVar("TCon", contravariant=True) - - -class ReadableStream(Protocol[TCov]): - def read(self) -> TCov: - raise NotImplementedError() - - -class WritableStream(Protocol[TCon]): - def write(self, value: TCon) -> None: - raise NotImplementedError() - - -@register_impls -class StreamFunctions(InterpreterFunctions): - @impl(stream.ReadOp) - def run_read( - self, interpreter: Interpreter, op: stream.ReadOp, args: tuple[Any, ...] - ) -> PythonValues: - (stream,) = args - stream: ReadableStream[Any] = stream - return (stream.read(),) - - @impl(stream.WriteOp) - def run_write( - self, interpreter: Interpreter, op: stream.WriteOp, args: tuple[Any, ...] - ) -> PythonValues: - (value, stream) = args - stream: WritableStream[Any] = stream - stream.write(value) - return () diff --git a/xdsl/interpreters/ptr.py b/xdsl/interpreters/utils/ptr.py similarity index 100% rename from xdsl/interpreters/ptr.py rename to xdsl/interpreters/utils/ptr.py diff --git a/xdsl/interpreters/utils/stream.py b/xdsl/interpreters/utils/stream.py new file mode 100644 index 0000000000..b0fdbf766d --- /dev/null +++ b/xdsl/interpreters/utils/stream.py @@ -0,0 +1,52 @@ +import abc +from dataclasses import dataclass, field +from typing import Generic, TypeVar + +T = TypeVar("T") +TCov = TypeVar("TCov", covariant=True) +TCon = TypeVar("TCon", contravariant=True) + + +class ReadableStream(Generic[TCov], abc.ABC): + """ + Abstract base class for readable stream interpreter model objects. + """ + + @abc.abstractmethod + def read(self) -> TCov: + raise NotImplementedError() + + +class WritableStream(Generic[TCon], abc.ABC): + """ + Abstract base class for readable stream interpreter model objects. + """ + + @abc.abstractmethod + def write(self, value: TCon) -> None: + raise NotImplementedError() + + +@dataclass +class Nats(ReadableStream[int]): + """ + A stream designed for testing, outputs the next natural number each time it's read. + """ + + index = 0 + + def read(self) -> int: + self.index += 1 + return self.index + + +@dataclass +class Acc(WritableStream[int]): + """ + A stream designed for testing, appends the next natural number written. + """ + + values: list[int] = field(default_factory=list) + + def write(self, value: int) -> None: + return self.values.append(value) From a55afb1fe166fc2c54fc8e6b6b1526fdd691885d Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 09:32:56 +0000 Subject: [PATCH 2/9] fix imports --- docs/Toy/toy/interpreter.py | 2 +- docs/marimo/linalg_snitch.py | 2 +- tests/interpreters/test_affine_interpreter.py | 2 +- tests/interpreters/test_builtin_interpreter.py | 2 +- tests/interpreters/test_linalg_interpreter.py | 2 +- tests/interpreters/test_memref_interpreter.py | 2 +- tests/interpreters/test_memref_stream_interpreter.py | 2 +- tests/interpreters/test_ml_program_interpreter.py | 2 +- tests/interpreters/test_onnx_interpreter.py | 2 +- tests/interpreters/test_ptr.py | 2 +- tests/interpreters/test_riscv_interpreter.py | 2 +- tests/interpreters/test_riscv_snitch_interpreter.py | 2 +- tests/interpreters/test_shaped_array.py | 2 +- tests/interpreters/test_snitch_stream_interpreter.py | 2 +- tests/interpreters/test_tensor_interpreter.py | 2 +- xdsl/backend/riscv/lowering/convert_memref_to_riscv.py | 2 +- xdsl/interpreters/builtin.py | 2 +- xdsl/interpreters/memref.py | 2 +- xdsl/interpreters/ml_program.py | 2 +- xdsl/interpreters/onnx.py | 2 +- xdsl/interpreters/riscv.py | 2 +- xdsl/interpreters/riscv_libc.py | 2 +- xdsl/interpreters/shaped_array.py | 2 +- xdsl/interpreters/snitch_stream.py | 7 ++----- xdsl/interpreters/tensor.py | 2 +- 25 files changed, 26 insertions(+), 29 deletions(-) diff --git a/docs/Toy/toy/interpreter.py b/docs/Toy/toy/interpreter.py index 9b7c335fd4..f4322f4a01 100644 --- a/docs/Toy/toy/interpreter.py +++ b/docs/Toy/toy/interpreter.py @@ -14,8 +14,8 @@ impl_terminator, register_impls, ) -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from .dialects import toy as toy diff --git a/docs/marimo/linalg_snitch.py b/docs/marimo/linalg_snitch.py index efb5efcb86..b13cf85e43 100644 --- a/docs/marimo/linalg_snitch.py +++ b/docs/marimo/linalg_snitch.py @@ -47,7 +47,7 @@ def __(): from xdsl.dialects import arith, func, linalg from xdsl.dialects.builtin import AffineMap, AffineMapAttr, MemRefType, ModuleOp, f64 from xdsl.dialects.riscv import riscv_code - from xdsl.interpreters.ptr import TypedPtr + from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute, Block, Region, SSAValue from xdsl.passes import PipelinePass from xdsl.tools.command_line_tool import get_all_dialects diff --git a/tests/interpreters/test_affine_interpreter.py b/tests/interpreters/test_affine_interpreter.py index 74ab1c6c04..a246a43ddd 100644 --- a/tests/interpreters/test_affine_interpreter.py +++ b/tests/interpreters/test_affine_interpreter.py @@ -8,8 +8,8 @@ from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.func import FuncFunctions from xdsl.interpreters.memref import MemrefFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir.affine import AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_builtin_interpreter.py b/tests/interpreters/test_builtin_interpreter.py index e0109f654a..30acf485af 100644 --- a/tests/interpreters/test_builtin_interpreter.py +++ b/tests/interpreters/test_builtin_interpreter.py @@ -10,9 +10,9 @@ i64, ) from xdsl.interpreter import Interpreter -from xdsl.interpreters import ptr from xdsl.interpreters.builtin import BuiltinFunctions from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils import ptr interpreter = Interpreter(ModuleOp([])) interpreter.register_implementations(BuiltinFunctions()) diff --git a/tests/interpreters/test_linalg_interpreter.py b/tests/interpreters/test_linalg_interpreter.py index ebdeda1089..70376ab67f 100644 --- a/tests/interpreters/test_linalg_interpreter.py +++ b/tests/interpreters/test_linalg_interpreter.py @@ -18,8 +18,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.linalg import LinalgFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_memref_interpreter.py b/tests/interpreters/test_memref_interpreter.py index 72d0120c32..41dd0d5a81 100644 --- a/tests/interpreters/test_memref_interpreter.py +++ b/tests/interpreters/test_memref_interpreter.py @@ -11,8 +11,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.memref import MemrefFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr interpreter = Interpreter(ModuleOp([]), index_bitwidth=32) interpreter.register_implementations(ArithFunctions()) diff --git a/tests/interpreters/test_memref_stream_interpreter.py b/tests/interpreters/test_memref_stream_interpreter.py index 462bf6ac25..a19df81e42 100644 --- a/tests/interpreters/test_memref_stream_interpreter.py +++ b/tests/interpreters/test_memref_stream_interpreter.py @@ -16,8 +16,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.memref_stream import MemrefStreamFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_ml_program_interpreter.py b/tests/interpreters/test_ml_program_interpreter.py index c21bcf5a54..f7ec07d991 100644 --- a/tests/interpreters/test_ml_program_interpreter.py +++ b/tests/interpreters/test_ml_program_interpreter.py @@ -11,8 +11,8 @@ ) from xdsl.interpreter import Interpreter from xdsl.interpreters.ml_program import MLProgramFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr def test_ml_program_global_load_constant(): diff --git a/tests/interpreters/test_onnx_interpreter.py b/tests/interpreters/test_onnx_interpreter.py index 4396126b73..1a337a4374 100644 --- a/tests/interpreters/test_onnx_interpreter.py +++ b/tests/interpreters/test_onnx_interpreter.py @@ -15,8 +15,8 @@ ) from xdsl.interpreter import Interpreter from xdsl.interpreters.builtin import BuiltinFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.utils.exceptions import InterpretationError from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_ptr.py b/tests/interpreters/test_ptr.py index ca04a9faa7..94ae1926e8 100644 --- a/tests/interpreters/test_ptr.py +++ b/tests/interpreters/test_ptr.py @@ -1,4 +1,4 @@ -from xdsl.interpreters import ptr +from xdsl.interpreters.utils import ptr def test_ptr(): diff --git a/tests/interpreters/test_riscv_interpreter.py b/tests/interpreters/test_riscv_interpreter.py index dca1e77180..e3eafe532a 100644 --- a/tests/interpreters/test_riscv_interpreter.py +++ b/tests/interpreters/test_riscv_interpreter.py @@ -14,8 +14,8 @@ i32, ) from xdsl.interpreter import Interpreter, PythonValues -from xdsl.interpreters.ptr import RawPtr, TypedPtr from xdsl.interpreters.riscv import RiscvFunctions +from xdsl.interpreters.utils.ptr import RawPtr, TypedPtr from xdsl.ir import Block, Region from xdsl.utils.bitwise_casts import convert_f32_to_u32 from xdsl.utils.exceptions import InterpretationError diff --git a/tests/interpreters/test_riscv_snitch_interpreter.py b/tests/interpreters/test_riscv_snitch_interpreter.py index e629178c94..9f6166500d 100644 --- a/tests/interpreters/test_riscv_snitch_interpreter.py +++ b/tests/interpreters/test_riscv_snitch_interpreter.py @@ -5,7 +5,7 @@ from xdsl.interpreters.func import FuncFunctions from xdsl.interpreters.riscv import RiscvFunctions from xdsl.interpreters.riscv_snitch import RiscvSnitchFunctions -from xdsl.interpreters.stream import Acc, Nats +from xdsl.interpreters.utils.stream import Acc, Nats from xdsl.ir import BlockArgument from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_shaped_array.py b/tests/interpreters/test_shaped_array.py index 755e7b1d82..afbe453191 100644 --- a/tests/interpreters/test_shaped_array.py +++ b/tests/interpreters/test_shaped_array.py @@ -1,5 +1,5 @@ -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr def test_shaped_array_offset(): diff --git a/tests/interpreters/test_snitch_stream_interpreter.py b/tests/interpreters/test_snitch_stream_interpreter.py index de1711eb0b..d1fd1dcd7a 100644 --- a/tests/interpreters/test_snitch_stream_interpreter.py +++ b/tests/interpreters/test_snitch_stream_interpreter.py @@ -4,10 +4,10 @@ from xdsl.dialects import riscv, riscv_snitch, snitch_stream, stream from xdsl.dialects.builtin import ArrayAttr, ModuleOp from xdsl.interpreter import Interpreter -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.riscv import RiscvFunctions from xdsl.interpreters.riscv_snitch import RiscvSnitchFunctions from xdsl.interpreters.snitch_stream import SnitchStreamFunctions +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_tensor_interpreter.py b/tests/interpreters/test_tensor_interpreter.py index 5b981dd4c1..129bcab325 100644 --- a/tests/interpreters/test_tensor_interpreter.py +++ b/tests/interpreters/test_tensor_interpreter.py @@ -3,9 +3,9 @@ from xdsl.dialects import tensor from xdsl.dialects.builtin import ModuleOp, TensorType, f32, i32 from xdsl.interpreter import Interpreter -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray from xdsl.interpreters.tensor import TensorFunctions +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.utils.exceptions import InterpretationError from xdsl.utils.test_value import TestSSAValue diff --git a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py index 9de6b21a44..ef13fa597c 100644 --- a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py +++ b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py @@ -24,7 +24,7 @@ SymbolRefAttr, UnrealizedConversionCastOp, ) -from xdsl.interpreters.ptr import TypedPtr +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute, Operation, Region, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( diff --git a/xdsl/interpreters/builtin.py b/xdsl/interpreters/builtin.py index da453e61fc..5128205a8d 100644 --- a/xdsl/interpreters/builtin.py +++ b/xdsl/interpreters/builtin.py @@ -16,8 +16,8 @@ impl_attr, register_impls, ) -from xdsl.interpreters import ptr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils import ptr from xdsl.ir import Attribute from xdsl.utils.hints import isa diff --git a/xdsl/interpreters/memref.py b/xdsl/interpreters/memref.py index edacc913d3..34ea29be80 100644 --- a/xdsl/interpreters/memref.py +++ b/xdsl/interpreters/memref.py @@ -10,8 +10,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute from xdsl.traits import SymbolTable diff --git a/xdsl/interpreters/ml_program.py b/xdsl/interpreters/ml_program.py index 4cced0a800..67d02e1375 100644 --- a/xdsl/interpreters/ml_program.py +++ b/xdsl/interpreters/ml_program.py @@ -9,8 +9,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.traits import SymbolTable diff --git a/xdsl/interpreters/onnx.py b/xdsl/interpreters/onnx.py index b85549fa62..eb87701f98 100644 --- a/xdsl/interpreters/onnx.py +++ b/xdsl/interpreters/onnx.py @@ -12,9 +12,9 @@ impl, register_impls, ) -from xdsl.interpreters import ptr from xdsl.interpreters.builtin import xtype_for_el_type from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils import ptr from xdsl.utils.exceptions import InterpretationError diff --git a/xdsl/interpreters/riscv.py b/xdsl/interpreters/riscv.py index 1412a23ea0..13cb1e0a5b 100644 --- a/xdsl/interpreters/riscv.py +++ b/xdsl/interpreters/riscv.py @@ -20,8 +20,8 @@ impl_cast, register_impls, ) -from xdsl.interpreters import ptr from xdsl.interpreters.builtin import xtype_for_el_type +from xdsl.interpreters.utils import ptr from xdsl.ir import Attribute, SSAValue from xdsl.utils.bitwise_casts import convert_u32_to_f32 from xdsl.utils.comparisons import to_signed, to_unsigned diff --git a/xdsl/interpreters/riscv_libc.py b/xdsl/interpreters/riscv_libc.py index 19d6cb5652..5b43994e7a 100644 --- a/xdsl/interpreters/riscv_libc.py +++ b/xdsl/interpreters/riscv_libc.py @@ -7,7 +7,7 @@ impl_external, register_impls, ) -from xdsl.interpreters import ptr +from xdsl.interpreters.utils import ptr from xdsl.ir import Operation diff --git a/xdsl/interpreters/shaped_array.py b/xdsl/interpreters/shaped_array.py index 4b1c65acc5..fbbe8dff51 100644 --- a/xdsl/interpreters/shaped_array.py +++ b/xdsl/interpreters/shaped_array.py @@ -10,7 +10,7 @@ from typing_extensions import Self from xdsl.dialects.builtin import ShapedType -from xdsl.interpreters.ptr import TypedPtr +from xdsl.interpreters.utils.ptr import TypedPtr _T = TypeVar("_T") diff --git a/xdsl/interpreters/snitch_stream.py b/xdsl/interpreters/snitch_stream.py index f53e79b54f..1edec09e89 100644 --- a/xdsl/interpreters/snitch_stream.py +++ b/xdsl/interpreters/snitch_stream.py @@ -10,11 +10,8 @@ impl, register_impls, ) -from xdsl.interpreters import ptr -from xdsl.interpreters.stream import ( - ReadableStream, - WritableStream, -) +from xdsl.interpreters.utils import ptr +from xdsl.interpreters.utils.stream import ReadableStream, WritableStream @dataclass diff --git a/xdsl/interpreters/tensor.py b/xdsl/interpreters/tensor.py index cd66b484d6..58bbcfd7ca 100644 --- a/xdsl/interpreters/tensor.py +++ b/xdsl/interpreters/tensor.py @@ -10,8 +10,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute from xdsl.utils.exceptions import InterpretationError From 7b48fc86014c9361c9fbdc30ea2d2bc6620b2d75 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 09:37:38 +0000 Subject: [PATCH 3/9] revert ptr move --- docs/Toy/toy/interpreter.py | 2 +- docs/marimo/linalg_snitch.py | 2 +- tests/interpreters/test_affine_interpreter.py | 2 +- tests/interpreters/test_builtin_interpreter.py | 2 +- tests/interpreters/test_linalg_interpreter.py | 2 +- tests/interpreters/test_memref_interpreter.py | 2 +- tests/interpreters/test_memref_stream_interpreter.py | 2 +- tests/interpreters/test_ml_program_interpreter.py | 2 +- tests/interpreters/test_onnx_interpreter.py | 2 +- tests/interpreters/test_ptr.py | 2 +- tests/interpreters/test_riscv_interpreter.py | 2 +- tests/interpreters/test_shaped_array.py | 2 +- tests/interpreters/test_snitch_stream_interpreter.py | 2 +- tests/interpreters/test_tensor_interpreter.py | 2 +- xdsl/backend/riscv/lowering/convert_memref_to_riscv.py | 2 +- xdsl/interpreters/builtin.py | 2 +- xdsl/interpreters/memref.py | 2 +- xdsl/interpreters/ml_program.py | 2 +- xdsl/interpreters/onnx.py | 2 +- xdsl/interpreters/{utils => }/ptr.py | 0 xdsl/interpreters/riscv.py | 2 +- xdsl/interpreters/riscv_libc.py | 2 +- xdsl/interpreters/shaped_array.py | 2 +- xdsl/interpreters/snitch_stream.py | 2 +- xdsl/interpreters/tensor.py | 2 +- 25 files changed, 24 insertions(+), 24 deletions(-) rename xdsl/interpreters/{utils => }/ptr.py (100%) diff --git a/docs/Toy/toy/interpreter.py b/docs/Toy/toy/interpreter.py index f4322f4a01..9b7c335fd4 100644 --- a/docs/Toy/toy/interpreter.py +++ b/docs/Toy/toy/interpreter.py @@ -14,8 +14,8 @@ impl_terminator, register_impls, ) +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from .dialects import toy as toy diff --git a/docs/marimo/linalg_snitch.py b/docs/marimo/linalg_snitch.py index b13cf85e43..efb5efcb86 100644 --- a/docs/marimo/linalg_snitch.py +++ b/docs/marimo/linalg_snitch.py @@ -47,7 +47,7 @@ def __(): from xdsl.dialects import arith, func, linalg from xdsl.dialects.builtin import AffineMap, AffineMapAttr, MemRefType, ModuleOp, f64 from xdsl.dialects.riscv import riscv_code - from xdsl.interpreters.utils.ptr import TypedPtr + from xdsl.interpreters.ptr import TypedPtr from xdsl.ir import Attribute, Block, Region, SSAValue from xdsl.passes import PipelinePass from xdsl.tools.command_line_tool import get_all_dialects diff --git a/tests/interpreters/test_affine_interpreter.py b/tests/interpreters/test_affine_interpreter.py index a246a43ddd..74ab1c6c04 100644 --- a/tests/interpreters/test_affine_interpreter.py +++ b/tests/interpreters/test_affine_interpreter.py @@ -8,8 +8,8 @@ from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.func import FuncFunctions from xdsl.interpreters.memref import MemrefFunctions +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir.affine import AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_builtin_interpreter.py b/tests/interpreters/test_builtin_interpreter.py index 30acf485af..e0109f654a 100644 --- a/tests/interpreters/test_builtin_interpreter.py +++ b/tests/interpreters/test_builtin_interpreter.py @@ -10,9 +10,9 @@ i64, ) from xdsl.interpreter import Interpreter +from xdsl.interpreters import ptr from xdsl.interpreters.builtin import BuiltinFunctions from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils import ptr interpreter = Interpreter(ModuleOp([])) interpreter.register_implementations(BuiltinFunctions()) diff --git a/tests/interpreters/test_linalg_interpreter.py b/tests/interpreters/test_linalg_interpreter.py index 70376ab67f..ebdeda1089 100644 --- a/tests/interpreters/test_linalg_interpreter.py +++ b/tests/interpreters/test_linalg_interpreter.py @@ -18,8 +18,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.linalg import LinalgFunctions +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_memref_interpreter.py b/tests/interpreters/test_memref_interpreter.py index 41dd0d5a81..72d0120c32 100644 --- a/tests/interpreters/test_memref_interpreter.py +++ b/tests/interpreters/test_memref_interpreter.py @@ -11,8 +11,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.memref import MemrefFunctions +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr interpreter = Interpreter(ModuleOp([]), index_bitwidth=32) interpreter.register_implementations(ArithFunctions()) diff --git a/tests/interpreters/test_memref_stream_interpreter.py b/tests/interpreters/test_memref_stream_interpreter.py index a19df81e42..462bf6ac25 100644 --- a/tests/interpreters/test_memref_stream_interpreter.py +++ b/tests/interpreters/test_memref_stream_interpreter.py @@ -16,8 +16,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.memref_stream import MemrefStreamFunctions +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_ml_program_interpreter.py b/tests/interpreters/test_ml_program_interpreter.py index f7ec07d991..c21bcf5a54 100644 --- a/tests/interpreters/test_ml_program_interpreter.py +++ b/tests/interpreters/test_ml_program_interpreter.py @@ -11,8 +11,8 @@ ) from xdsl.interpreter import Interpreter from xdsl.interpreters.ml_program import MLProgramFunctions +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr def test_ml_program_global_load_constant(): diff --git a/tests/interpreters/test_onnx_interpreter.py b/tests/interpreters/test_onnx_interpreter.py index 1a337a4374..4396126b73 100644 --- a/tests/interpreters/test_onnx_interpreter.py +++ b/tests/interpreters/test_onnx_interpreter.py @@ -15,8 +15,8 @@ ) from xdsl.interpreter import Interpreter from xdsl.interpreters.builtin import BuiltinFunctions +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.utils.exceptions import InterpretationError from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_ptr.py b/tests/interpreters/test_ptr.py index 94ae1926e8..ca04a9faa7 100644 --- a/tests/interpreters/test_ptr.py +++ b/tests/interpreters/test_ptr.py @@ -1,4 +1,4 @@ -from xdsl.interpreters.utils import ptr +from xdsl.interpreters import ptr def test_ptr(): diff --git a/tests/interpreters/test_riscv_interpreter.py b/tests/interpreters/test_riscv_interpreter.py index e3eafe532a..dca1e77180 100644 --- a/tests/interpreters/test_riscv_interpreter.py +++ b/tests/interpreters/test_riscv_interpreter.py @@ -14,8 +14,8 @@ i32, ) from xdsl.interpreter import Interpreter, PythonValues +from xdsl.interpreters.ptr import RawPtr, TypedPtr from xdsl.interpreters.riscv import RiscvFunctions -from xdsl.interpreters.utils.ptr import RawPtr, TypedPtr from xdsl.ir import Block, Region from xdsl.utils.bitwise_casts import convert_f32_to_u32 from xdsl.utils.exceptions import InterpretationError diff --git a/tests/interpreters/test_shaped_array.py b/tests/interpreters/test_shaped_array.py index afbe453191..755e7b1d82 100644 --- a/tests/interpreters/test_shaped_array.py +++ b/tests/interpreters/test_shaped_array.py @@ -1,5 +1,5 @@ +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr def test_shaped_array_offset(): diff --git a/tests/interpreters/test_snitch_stream_interpreter.py b/tests/interpreters/test_snitch_stream_interpreter.py index d1fd1dcd7a..de1711eb0b 100644 --- a/tests/interpreters/test_snitch_stream_interpreter.py +++ b/tests/interpreters/test_snitch_stream_interpreter.py @@ -4,10 +4,10 @@ from xdsl.dialects import riscv, riscv_snitch, snitch_stream, stream from xdsl.dialects.builtin import ArrayAttr, ModuleOp from xdsl.interpreter import Interpreter +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.riscv import RiscvFunctions from xdsl.interpreters.riscv_snitch import RiscvSnitchFunctions from xdsl.interpreters.snitch_stream import SnitchStreamFunctions -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_tensor_interpreter.py b/tests/interpreters/test_tensor_interpreter.py index 129bcab325..5b981dd4c1 100644 --- a/tests/interpreters/test_tensor_interpreter.py +++ b/tests/interpreters/test_tensor_interpreter.py @@ -3,9 +3,9 @@ from xdsl.dialects import tensor from xdsl.dialects.builtin import ModuleOp, TensorType, f32, i32 from xdsl.interpreter import Interpreter +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray from xdsl.interpreters.tensor import TensorFunctions -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.utils.exceptions import InterpretationError from xdsl.utils.test_value import TestSSAValue diff --git a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py index ef13fa597c..9de6b21a44 100644 --- a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py +++ b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py @@ -24,7 +24,7 @@ SymbolRefAttr, UnrealizedConversionCastOp, ) -from xdsl.interpreters.utils.ptr import TypedPtr +from xdsl.interpreters.ptr import TypedPtr from xdsl.ir import Attribute, Operation, Region, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( diff --git a/xdsl/interpreters/builtin.py b/xdsl/interpreters/builtin.py index 5128205a8d..da453e61fc 100644 --- a/xdsl/interpreters/builtin.py +++ b/xdsl/interpreters/builtin.py @@ -16,8 +16,8 @@ impl_attr, register_impls, ) +from xdsl.interpreters import ptr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils import ptr from xdsl.ir import Attribute from xdsl.utils.hints import isa diff --git a/xdsl/interpreters/memref.py b/xdsl/interpreters/memref.py index 34ea29be80..edacc913d3 100644 --- a/xdsl/interpreters/memref.py +++ b/xdsl/interpreters/memref.py @@ -10,8 +10,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute from xdsl.traits import SymbolTable diff --git a/xdsl/interpreters/ml_program.py b/xdsl/interpreters/ml_program.py index 67d02e1375..4cced0a800 100644 --- a/xdsl/interpreters/ml_program.py +++ b/xdsl/interpreters/ml_program.py @@ -9,8 +9,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.traits import SymbolTable diff --git a/xdsl/interpreters/onnx.py b/xdsl/interpreters/onnx.py index eb87701f98..b85549fa62 100644 --- a/xdsl/interpreters/onnx.py +++ b/xdsl/interpreters/onnx.py @@ -12,9 +12,9 @@ impl, register_impls, ) +from xdsl.interpreters import ptr from xdsl.interpreters.builtin import xtype_for_el_type from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils import ptr from xdsl.utils.exceptions import InterpretationError diff --git a/xdsl/interpreters/utils/ptr.py b/xdsl/interpreters/ptr.py similarity index 100% rename from xdsl/interpreters/utils/ptr.py rename to xdsl/interpreters/ptr.py diff --git a/xdsl/interpreters/riscv.py b/xdsl/interpreters/riscv.py index 13cb1e0a5b..1412a23ea0 100644 --- a/xdsl/interpreters/riscv.py +++ b/xdsl/interpreters/riscv.py @@ -20,8 +20,8 @@ impl_cast, register_impls, ) +from xdsl.interpreters import ptr from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.utils import ptr from xdsl.ir import Attribute, SSAValue from xdsl.utils.bitwise_casts import convert_u32_to_f32 from xdsl.utils.comparisons import to_signed, to_unsigned diff --git a/xdsl/interpreters/riscv_libc.py b/xdsl/interpreters/riscv_libc.py index 5b43994e7a..19d6cb5652 100644 --- a/xdsl/interpreters/riscv_libc.py +++ b/xdsl/interpreters/riscv_libc.py @@ -7,7 +7,7 @@ impl_external, register_impls, ) -from xdsl.interpreters.utils import ptr +from xdsl.interpreters import ptr from xdsl.ir import Operation diff --git a/xdsl/interpreters/shaped_array.py b/xdsl/interpreters/shaped_array.py index fbbe8dff51..4b1c65acc5 100644 --- a/xdsl/interpreters/shaped_array.py +++ b/xdsl/interpreters/shaped_array.py @@ -10,7 +10,7 @@ from typing_extensions import Self from xdsl.dialects.builtin import ShapedType -from xdsl.interpreters.utils.ptr import TypedPtr +from xdsl.interpreters.ptr import TypedPtr _T = TypeVar("_T") diff --git a/xdsl/interpreters/snitch_stream.py b/xdsl/interpreters/snitch_stream.py index 1edec09e89..cdefa22738 100644 --- a/xdsl/interpreters/snitch_stream.py +++ b/xdsl/interpreters/snitch_stream.py @@ -10,7 +10,7 @@ impl, register_impls, ) -from xdsl.interpreters.utils import ptr +from xdsl.interpreters import ptr from xdsl.interpreters.utils.stream import ReadableStream, WritableStream diff --git a/xdsl/interpreters/tensor.py b/xdsl/interpreters/tensor.py index 58bbcfd7ca..cd66b484d6 100644 --- a/xdsl/interpreters/tensor.py +++ b/xdsl/interpreters/tensor.py @@ -10,8 +10,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type +from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute from xdsl.utils.exceptions import InterpretationError From 4ed292c7f20a9767f6f7906432330a6d79282ba5 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 09:38:40 +0000 Subject: [PATCH 4/9] interpreter: move ptr.py to utils folder --- docs/Toy/toy/interpreter.py | 2 +- docs/marimo/linalg_snitch.py | 2 +- tests/interpreters/test_affine_interpreter.py | 2 +- tests/interpreters/test_builtin_interpreter.py | 2 +- tests/interpreters/test_linalg_interpreter.py | 2 +- tests/interpreters/test_memref_interpreter.py | 2 +- tests/interpreters/test_memref_stream_interpreter.py | 2 +- tests/interpreters/test_ml_program_interpreter.py | 2 +- tests/interpreters/test_onnx_interpreter.py | 2 +- tests/interpreters/test_ptr.py | 2 +- tests/interpreters/test_riscv_interpreter.py | 2 +- tests/interpreters/test_shaped_array.py | 2 +- tests/interpreters/test_snitch_stream_interpreter.py | 2 +- tests/interpreters/test_tensor_interpreter.py | 2 +- xdsl/backend/riscv/lowering/convert_memref_to_riscv.py | 2 +- xdsl/interpreters/builtin.py | 2 +- xdsl/interpreters/memref.py | 2 +- xdsl/interpreters/ml_program.py | 2 +- xdsl/interpreters/onnx.py | 2 +- xdsl/interpreters/riscv.py | 2 +- xdsl/interpreters/riscv_libc.py | 2 +- xdsl/interpreters/shaped_array.py | 2 +- xdsl/interpreters/snitch_stream.py | 2 +- xdsl/interpreters/tensor.py | 2 +- xdsl/interpreters/{ => utils}/ptr.py | 0 25 files changed, 24 insertions(+), 24 deletions(-) rename xdsl/interpreters/{ => utils}/ptr.py (100%) diff --git a/docs/Toy/toy/interpreter.py b/docs/Toy/toy/interpreter.py index 9b7c335fd4..f4322f4a01 100644 --- a/docs/Toy/toy/interpreter.py +++ b/docs/Toy/toy/interpreter.py @@ -14,8 +14,8 @@ impl_terminator, register_impls, ) -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from .dialects import toy as toy diff --git a/docs/marimo/linalg_snitch.py b/docs/marimo/linalg_snitch.py index efb5efcb86..b13cf85e43 100644 --- a/docs/marimo/linalg_snitch.py +++ b/docs/marimo/linalg_snitch.py @@ -47,7 +47,7 @@ def __(): from xdsl.dialects import arith, func, linalg from xdsl.dialects.builtin import AffineMap, AffineMapAttr, MemRefType, ModuleOp, f64 from xdsl.dialects.riscv import riscv_code - from xdsl.interpreters.ptr import TypedPtr + from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute, Block, Region, SSAValue from xdsl.passes import PipelinePass from xdsl.tools.command_line_tool import get_all_dialects diff --git a/tests/interpreters/test_affine_interpreter.py b/tests/interpreters/test_affine_interpreter.py index 74ab1c6c04..a246a43ddd 100644 --- a/tests/interpreters/test_affine_interpreter.py +++ b/tests/interpreters/test_affine_interpreter.py @@ -8,8 +8,8 @@ from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.func import FuncFunctions from xdsl.interpreters.memref import MemrefFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir.affine import AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_builtin_interpreter.py b/tests/interpreters/test_builtin_interpreter.py index e0109f654a..30acf485af 100644 --- a/tests/interpreters/test_builtin_interpreter.py +++ b/tests/interpreters/test_builtin_interpreter.py @@ -10,9 +10,9 @@ i64, ) from xdsl.interpreter import Interpreter -from xdsl.interpreters import ptr from xdsl.interpreters.builtin import BuiltinFunctions from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils import ptr interpreter = Interpreter(ModuleOp([])) interpreter.register_implementations(BuiltinFunctions()) diff --git a/tests/interpreters/test_linalg_interpreter.py b/tests/interpreters/test_linalg_interpreter.py index ebdeda1089..70376ab67f 100644 --- a/tests/interpreters/test_linalg_interpreter.py +++ b/tests/interpreters/test_linalg_interpreter.py @@ -18,8 +18,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.linalg import LinalgFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_memref_interpreter.py b/tests/interpreters/test_memref_interpreter.py index 72d0120c32..41dd0d5a81 100644 --- a/tests/interpreters/test_memref_interpreter.py +++ b/tests/interpreters/test_memref_interpreter.py @@ -11,8 +11,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.memref import MemrefFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr interpreter = Interpreter(ModuleOp([]), index_bitwidth=32) interpreter.register_implementations(ArithFunctions()) diff --git a/tests/interpreters/test_memref_stream_interpreter.py b/tests/interpreters/test_memref_stream_interpreter.py index 462bf6ac25..a19df81e42 100644 --- a/tests/interpreters/test_memref_stream_interpreter.py +++ b/tests/interpreters/test_memref_stream_interpreter.py @@ -16,8 +16,8 @@ from xdsl.interpreter import Interpreter from xdsl.interpreters.arith import ArithFunctions from xdsl.interpreters.memref_stream import MemrefStreamFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_ml_program_interpreter.py b/tests/interpreters/test_ml_program_interpreter.py index c21bcf5a54..f7ec07d991 100644 --- a/tests/interpreters/test_ml_program_interpreter.py +++ b/tests/interpreters/test_ml_program_interpreter.py @@ -11,8 +11,8 @@ ) from xdsl.interpreter import Interpreter from xdsl.interpreters.ml_program import MLProgramFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr def test_ml_program_global_load_constant(): diff --git a/tests/interpreters/test_onnx_interpreter.py b/tests/interpreters/test_onnx_interpreter.py index 4396126b73..1a337a4374 100644 --- a/tests/interpreters/test_onnx_interpreter.py +++ b/tests/interpreters/test_onnx_interpreter.py @@ -15,8 +15,8 @@ ) from xdsl.interpreter import Interpreter from xdsl.interpreters.builtin import BuiltinFunctions -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.utils.exceptions import InterpretationError from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_ptr.py b/tests/interpreters/test_ptr.py index ca04a9faa7..94ae1926e8 100644 --- a/tests/interpreters/test_ptr.py +++ b/tests/interpreters/test_ptr.py @@ -1,4 +1,4 @@ -from xdsl.interpreters import ptr +from xdsl.interpreters.utils import ptr def test_ptr(): diff --git a/tests/interpreters/test_riscv_interpreter.py b/tests/interpreters/test_riscv_interpreter.py index dca1e77180..e3eafe532a 100644 --- a/tests/interpreters/test_riscv_interpreter.py +++ b/tests/interpreters/test_riscv_interpreter.py @@ -14,8 +14,8 @@ i32, ) from xdsl.interpreter import Interpreter, PythonValues -from xdsl.interpreters.ptr import RawPtr, TypedPtr from xdsl.interpreters.riscv import RiscvFunctions +from xdsl.interpreters.utils.ptr import RawPtr, TypedPtr from xdsl.ir import Block, Region from xdsl.utils.bitwise_casts import convert_f32_to_u32 from xdsl.utils.exceptions import InterpretationError diff --git a/tests/interpreters/test_shaped_array.py b/tests/interpreters/test_shaped_array.py index 755e7b1d82..afbe453191 100644 --- a/tests/interpreters/test_shaped_array.py +++ b/tests/interpreters/test_shaped_array.py @@ -1,5 +1,5 @@ -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr def test_shaped_array_offset(): diff --git a/tests/interpreters/test_snitch_stream_interpreter.py b/tests/interpreters/test_snitch_stream_interpreter.py index de1711eb0b..d1fd1dcd7a 100644 --- a/tests/interpreters/test_snitch_stream_interpreter.py +++ b/tests/interpreters/test_snitch_stream_interpreter.py @@ -4,10 +4,10 @@ from xdsl.dialects import riscv, riscv_snitch, snitch_stream, stream from xdsl.dialects.builtin import ArrayAttr, ModuleOp from xdsl.interpreter import Interpreter -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.riscv import RiscvFunctions from xdsl.interpreters.riscv_snitch import RiscvSnitchFunctions from xdsl.interpreters.snitch_stream import SnitchStreamFunctions +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Block, Region from xdsl.utils.test_value import TestSSAValue diff --git a/tests/interpreters/test_tensor_interpreter.py b/tests/interpreters/test_tensor_interpreter.py index 5b981dd4c1..129bcab325 100644 --- a/tests/interpreters/test_tensor_interpreter.py +++ b/tests/interpreters/test_tensor_interpreter.py @@ -3,9 +3,9 @@ from xdsl.dialects import tensor from xdsl.dialects.builtin import ModuleOp, TensorType, f32, i32 from xdsl.interpreter import Interpreter -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray from xdsl.interpreters.tensor import TensorFunctions +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.utils.exceptions import InterpretationError from xdsl.utils.test_value import TestSSAValue diff --git a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py index 9de6b21a44..ef13fa597c 100644 --- a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py +++ b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py @@ -24,7 +24,7 @@ SymbolRefAttr, UnrealizedConversionCastOp, ) -from xdsl.interpreters.ptr import TypedPtr +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute, Operation, Region, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( diff --git a/xdsl/interpreters/builtin.py b/xdsl/interpreters/builtin.py index da453e61fc..5128205a8d 100644 --- a/xdsl/interpreters/builtin.py +++ b/xdsl/interpreters/builtin.py @@ -16,8 +16,8 @@ impl_attr, register_impls, ) -from xdsl.interpreters import ptr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils import ptr from xdsl.ir import Attribute from xdsl.utils.hints import isa diff --git a/xdsl/interpreters/memref.py b/xdsl/interpreters/memref.py index edacc913d3..34ea29be80 100644 --- a/xdsl/interpreters/memref.py +++ b/xdsl/interpreters/memref.py @@ -10,8 +10,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute from xdsl.traits import SymbolTable diff --git a/xdsl/interpreters/ml_program.py b/xdsl/interpreters/ml_program.py index 4cced0a800..67d02e1375 100644 --- a/xdsl/interpreters/ml_program.py +++ b/xdsl/interpreters/ml_program.py @@ -9,8 +9,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.traits import SymbolTable diff --git a/xdsl/interpreters/onnx.py b/xdsl/interpreters/onnx.py index b85549fa62..eb87701f98 100644 --- a/xdsl/interpreters/onnx.py +++ b/xdsl/interpreters/onnx.py @@ -12,9 +12,9 @@ impl, register_impls, ) -from xdsl.interpreters import ptr from xdsl.interpreters.builtin import xtype_for_el_type from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils import ptr from xdsl.utils.exceptions import InterpretationError diff --git a/xdsl/interpreters/riscv.py b/xdsl/interpreters/riscv.py index 1412a23ea0..13cb1e0a5b 100644 --- a/xdsl/interpreters/riscv.py +++ b/xdsl/interpreters/riscv.py @@ -20,8 +20,8 @@ impl_cast, register_impls, ) -from xdsl.interpreters import ptr from xdsl.interpreters.builtin import xtype_for_el_type +from xdsl.interpreters.utils import ptr from xdsl.ir import Attribute, SSAValue from xdsl.utils.bitwise_casts import convert_u32_to_f32 from xdsl.utils.comparisons import to_signed, to_unsigned diff --git a/xdsl/interpreters/riscv_libc.py b/xdsl/interpreters/riscv_libc.py index 19d6cb5652..5b43994e7a 100644 --- a/xdsl/interpreters/riscv_libc.py +++ b/xdsl/interpreters/riscv_libc.py @@ -7,7 +7,7 @@ impl_external, register_impls, ) -from xdsl.interpreters import ptr +from xdsl.interpreters.utils import ptr from xdsl.ir import Operation diff --git a/xdsl/interpreters/shaped_array.py b/xdsl/interpreters/shaped_array.py index 4b1c65acc5..fbbe8dff51 100644 --- a/xdsl/interpreters/shaped_array.py +++ b/xdsl/interpreters/shaped_array.py @@ -10,7 +10,7 @@ from typing_extensions import Self from xdsl.dialects.builtin import ShapedType -from xdsl.interpreters.ptr import TypedPtr +from xdsl.interpreters.utils.ptr import TypedPtr _T = TypeVar("_T") diff --git a/xdsl/interpreters/snitch_stream.py b/xdsl/interpreters/snitch_stream.py index cdefa22738..1edec09e89 100644 --- a/xdsl/interpreters/snitch_stream.py +++ b/xdsl/interpreters/snitch_stream.py @@ -10,7 +10,7 @@ impl, register_impls, ) -from xdsl.interpreters import ptr +from xdsl.interpreters.utils import ptr from xdsl.interpreters.utils.stream import ReadableStream, WritableStream diff --git a/xdsl/interpreters/tensor.py b/xdsl/interpreters/tensor.py index cd66b484d6..58bbcfd7ca 100644 --- a/xdsl/interpreters/tensor.py +++ b/xdsl/interpreters/tensor.py @@ -10,8 +10,8 @@ register_impls, ) from xdsl.interpreters.builtin import xtype_for_el_type -from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.interpreters.utils.ptr import TypedPtr from xdsl.ir import Attribute from xdsl.utils.exceptions import InterpretationError diff --git a/xdsl/interpreters/ptr.py b/xdsl/interpreters/utils/ptr.py similarity index 100% rename from xdsl/interpreters/ptr.py rename to xdsl/interpreters/utils/ptr.py From 8968ff3c01e86dbbe6edc359c71c8dd96c33d013 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 09:53:37 +0000 Subject: [PATCH 5/9] misc: split dialects.utils into multiple files --- tests/test_dialect_utils.py | 2 +- xdsl/dialects/arith.py | 2 +- xdsl/dialects/csl/csl.py | 2 +- xdsl/dialects/csl/csl_stencil.py | 2 +- xdsl/dialects/experimental/air.py | 2 +- xdsl/dialects/func.py | 2 +- xdsl/dialects/linalg.py | 2 +- xdsl/dialects/llvm.py | 2 +- xdsl/dialects/memref.py | 2 +- xdsl/dialects/memref_stream.py | 2 +- xdsl/dialects/omp.py | 2 +- xdsl/dialects/riscv.py | 2 +- xdsl/dialects/riscv_func.py | 2 +- xdsl/dialects/riscv_scf.py | 2 +- xdsl/dialects/riscv_snitch.py | 2 +- xdsl/dialects/scf.py | 2 +- xdsl/dialects/utils/__init__.py | 0 xdsl/dialects/utils/fast_math.py | 29 ++++++++++++++++ xdsl/dialects/{utils.py => utils/format.py} | 34 ------------------- xdsl/transforms/arith_add_fastmath.py | 2 +- .../canonicalization_patterns/riscv.py | 2 +- 21 files changed, 47 insertions(+), 52 deletions(-) create mode 100644 xdsl/dialects/utils/__init__.py create mode 100644 xdsl/dialects/utils/fast_math.py rename xdsl/dialects/{utils.py => utils/format.py} (93%) diff --git a/tests/test_dialect_utils.py b/tests/test_dialect_utils.py index 171d95373f..3c80c06622 100644 --- a/tests/test_dialect_utils.py +++ b/tests/test_dialect_utils.py @@ -4,7 +4,7 @@ from xdsl.context import MLContext from xdsl.dialects.builtin import DYNAMIC_INDEX, IndexType, IntegerType, i32 -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_dynamic_index_list_with_types, parse_dynamic_index_list_without_types, parse_dynamic_index_with_type, diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 34e09a4241..7300fc2f53 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -22,7 +22,7 @@ UnrankedTensorType, VectorType, ) -from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue from xdsl.irdl import ( AnyOf, diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index f17bb8be7c..d5257a6e08 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -37,7 +37,7 @@ SymbolRefAttr, TensorType, ) -from xdsl.dialects.utils import parse_func_op_like, print_func_op_like +from xdsl.dialects.utils.format import parse_func_op_like, print_func_op_like from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index eca5bc9cad..c9956ae69b 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -17,7 +17,7 @@ TensorType, ) from xdsl.dialects.experimental import dmp -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/experimental/air.py b/xdsl/dialects/experimental/air.py index f87b2b1795..2e592257b1 100644 --- a/xdsl/dialects/experimental/air.py +++ b/xdsl/dialects/experimental/air.py @@ -20,7 +20,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 8131cdf6df..480676f996 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -10,7 +10,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 466e6a3b41..484ef20e29 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -25,7 +25,7 @@ TensorType, i64, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, ) from xdsl.ir import ( diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index dd3752ff43..10d34717a5 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -24,7 +24,7 @@ i32, i64, ) -from xdsl.dialects.utils import FastMathAttrBase +from xdsl.dialects.utils.fast_math import FastMathAttrBase from xdsl.ir import ( Attribute, BitEnumAttribute, diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index bfc9bf8d8c..eb50974e89 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -30,7 +30,7 @@ i32, i64, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_dynamic_index_list_without_types, print_dynamic_index_list, ) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 9335d48c2a..c94ae99d17 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -25,7 +25,7 @@ IntegerType, StringAttr, ) -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/omp.py b/xdsl/dialects/omp.py index c83058e145..2eec70a374 100644 --- a/xdsl/dialects/omp.py +++ b/xdsl/dialects/omp.py @@ -9,7 +9,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils import AbstractYieldOperation +from xdsl.dialects.utils.format import AbstractYieldOperation from xdsl.ir import ( Attribute, Dialect, diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index 90b80271c4..1b0e280c25 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -24,7 +24,7 @@ UnitAttr, i32, ) -from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathAttrBase, FastMathFlag from xdsl.ir import ( Attribute, Block, diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py index d9e8a63704..610ceee50a 100644 --- a/xdsl/dialects/riscv_func.py +++ b/xdsl/dialects/riscv_func.py @@ -12,7 +12,7 @@ StringAttr, SymbolRefAttr, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( parse_call_op_like, parse_func_op_like, print_call_op_like, diff --git a/xdsl/dialects/riscv_scf.py b/xdsl/dialects/riscv_scf.py index cea09406c8..67e16f80e9 100644 --- a/xdsl/dialects/riscv_scf.py +++ b/xdsl/dialects/riscv_scf.py @@ -10,7 +10,7 @@ from typing_extensions import Self from xdsl.dialects.riscv import IntRegisterType, RISCVRegisterType -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 25d0cafcaa..94ce963e95 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -33,7 +33,7 @@ print_immediate_value, si12, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index b29cfbb42c..0c2ac49230 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -12,7 +12,7 @@ SignlessIntegerConstraint, i64, ) -from xdsl.dialects.utils import ( +from xdsl.dialects.utils.format import ( AbstractYieldOperation, parse_assignment, print_assignment, diff --git a/xdsl/dialects/utils/__init__.py b/xdsl/dialects/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/xdsl/dialects/utils/fast_math.py b/xdsl/dialects/utils/fast_math.py new file mode 100644 index 0000000000..c12b8b5878 --- /dev/null +++ b/xdsl/dialects/utils/fast_math.py @@ -0,0 +1,29 @@ +from abc import ABC +from dataclasses import dataclass + +from xdsl.ir import BitEnumAttribute +from xdsl.utils.str_enum import StrEnum + + +class FastMathFlag(StrEnum): + """ + Values specifying fast math behaviour of an arithmetic operation. + """ + + REASSOC = "reassoc" + NO_NANS = "nnan" + NO_INFS = "ninf" + NO_SIGNED_ZEROS = "nsz" + ALLOW_RECIP = "arcp" + ALLOW_CONTRACT = "contract" + APPROX_FUNC = "afn" + + +@dataclass(frozen=True, init=False) +class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC): + """ + Base class for attributes defining fast math behavior of arithmetic operations. + """ + + none_value = "none" + all_value = "fast" diff --git a/xdsl/dialects/utils.py b/xdsl/dialects/utils/format.py similarity index 93% rename from xdsl/dialects/utils.py rename to xdsl/dialects/utils/format.py index 5b5fb8b0f9..ae7965d213 100644 --- a/xdsl/dialects/utils.py +++ b/xdsl/dialects/utils/format.py @@ -1,6 +1,4 @@ -from abc import ABC from collections.abc import Iterable, Sequence -from dataclasses import dataclass from typing import Generic from xdsl.dialects.builtin import ( @@ -14,7 +12,6 @@ from xdsl.ir import ( Attribute, AttributeInvT, - BitEnumAttribute, BlockArgument, Operation, Region, @@ -23,7 +20,6 @@ from xdsl.irdl import IRDLOperation, var_operand_def from xdsl.parser import Parser, UnresolvedOperand from xdsl.printer import Printer -from xdsl.utils.str_enum import StrEnum def print_call_op_like( @@ -345,33 +341,3 @@ def parse_dynamic_index_list_without_types( values.append(value_or_index) return values, indices - - -# region Fast Math Flags - - -class FastMathFlag(StrEnum): - """ - Values specifying fast math behaviour of an arithmetic operation. - """ - - REASSOC = "reassoc" - NO_NANS = "nnan" - NO_INFS = "ninf" - NO_SIGNED_ZEROS = "nsz" - ALLOW_RECIP = "arcp" - ALLOW_CONTRACT = "contract" - APPROX_FUNC = "afn" - - -@dataclass(frozen=True, init=False) -class FastMathAttrBase(BitEnumAttribute[FastMathFlag], ABC): - """ - Base class for attributes defining fast math behavior of arithmetic operations. - """ - - none_value = "none" - all_value = "fast" - - -# endregion diff --git a/xdsl/transforms/arith_add_fastmath.py b/xdsl/transforms/arith_add_fastmath.py index c60f54bd12..e9221be204 100644 --- a/xdsl/transforms/arith_add_fastmath.py +++ b/xdsl/transforms/arith_add_fastmath.py @@ -4,7 +4,7 @@ from typing import Literal from xdsl.dialects import arith, builtin -from xdsl.dialects.utils import FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathFlag from xdsl.passes import MLContext, ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, diff --git a/xdsl/transforms/canonicalization_patterns/riscv.py b/xdsl/transforms/canonicalization_patterns/riscv.py index 519c5b0813..6588f84cf9 100644 --- a/xdsl/transforms/canonicalization_patterns/riscv.py +++ b/xdsl/transforms/canonicalization_patterns/riscv.py @@ -2,7 +2,7 @@ from xdsl.dialects import riscv, riscv_snitch from xdsl.dialects.builtin import IntegerAttr -from xdsl.dialects.utils import FastMathFlag +from xdsl.dialects.utils.fast_math import FastMathFlag from xdsl.ir import OpResult, SSAValue from xdsl.pattern_rewriter import ( PatternRewriter, From 05f4d5cd478e5386b34a9a9251096347be26ea68 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 10:17:38 +0000 Subject: [PATCH 6/9] dialects: (stream) simplify constr helper on stream attributes --- xdsl/dialects/memref_stream.py | 2 +- xdsl/dialects/stream.py | 89 +++++----------------------------- 2 files changed, 13 insertions(+), 78 deletions(-) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index c94ae99d17..13ef7a7d30 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -370,7 +370,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.WritableStreamType.constr()) + outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 3c2d5ea2d6..2243197121 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Generic, TypeVar, cast, overload +from typing import ClassVar, Generic, TypeVar, cast from typing_extensions import Self @@ -46,30 +46,11 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[StreamType[Attribute]]: ... - - @overload - @staticmethod + @classmethod def constr( - *, + cls, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[StreamType[Attribute]] - | ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[StreamType[Attribute]](StreamType) + ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( StreamType, (element_type,) ) @@ -79,66 +60,20 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[ReadableStreamType[Attribute]]: ... - @overload - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[ReadableStreamType[Attribute]] - | ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[ReadableStreamType[Attribute]](ReadableStreamType) - return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( - ReadableStreamType, (element_type,) - ) +AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( + ReadableStreamType +) @irdl_attr_definition class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" - @overload - @staticmethod - def constr( - *, - element_type: None = None, - ) -> BaseAttr[WritableStreamType[Attribute]]: ... - @overload - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: ... - - @staticmethod - def constr( - *, - element_type: GenericAttrConstraint[_StreamTypeElementConstrT] | None = None, - ) -> ( - BaseAttr[WritableStreamType[Attribute]] - | ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]] - ): - if element_type is None: - return BaseAttr[WritableStreamType[Attribute]](WritableStreamType) - return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( - WritableStreamType, (element_type,) - ) +AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( + WritableStreamType +) class ReadOperation(IRDLOperation, abc.ABC): @@ -148,7 +83,7 @@ class ReadOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(ReadableStreamType.constr(element_type=T)) + stream = operand_def(ReadableStreamType.constr(T)) res = result_def(T) def __init__(self, stream: SSAValue, result_type: Attribute | None = None): @@ -182,7 +117,7 @@ class WriteOperation(IRDLOperation, abc.ABC): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(WritableStreamType.constr(element_type=T)) + stream = operand_def(WritableStreamType.constr(T)) def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) From a1b4afb58b808a8582848e110bf0cbe3fc65f25b Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 14:17:16 +0000 Subject: [PATCH 7/9] back to static --- xdsl/dialects/stream.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 2243197121..8486b303cd 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -46,9 +46,8 @@ def __init__(self, element_type: _StreamTypeElement): def get_element_type(self) -> _StreamTypeElement: return self.element_type - @classmethod + @staticmethod def constr( - cls, element_type: GenericAttrConstraint[_StreamTypeElementConstrT], ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( @@ -60,6 +59,14 @@ def constr( class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.readable" + @staticmethod + def constr( + element_type: GenericAttrConstraint[_StreamTypeElementConstrT], + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: + return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( + ReadableStreamType, (element_type,) + ) + AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( ReadableStreamType @@ -70,6 +77,14 @@ class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElem class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): name = "stream.writable" + @staticmethod + def constr( + element_type: GenericAttrConstraint[_StreamTypeElementConstrT], + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: + return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( + WritableStreamType, (element_type,) + ) + AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( WritableStreamType From 284e642bd7713bd756cc44ae592b22eebd3ee1cb Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 08:31:24 +0000 Subject: [PATCH 8/9] dialects: (stream) use assembly format for stream read and write ops --- xdsl/dialects/stream.py | 44 ++++------------------------------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index 8486b303cd..e95e051fc8 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -3,8 +3,6 @@ import abc from typing import ClassVar, Generic, TypeVar, cast -from typing_extensions import Self - from xdsl.dialects.builtin import ContainerType from xdsl.ir import ( Attribute, @@ -25,8 +23,6 @@ operand_def, result_def, ) -from xdsl.parser import Parser -from xdsl.printer import Printer _StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) _StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) @@ -101,6 +97,8 @@ class ReadOperation(IRDLOperation, abc.ABC): stream = operand_def(ReadableStreamType.constr(T)) res = result_def(T) + assembly_format = "`from` $stream attr-dict `:` type($res)" + def __init__(self, stream: SSAValue, result_type: Attribute | None = None): if result_type is None: assert isinstance(stream_type := stream.type, ReadableStreamType) @@ -108,21 +106,6 @@ def __init__(self, stream: SSAValue, result_type: Attribute | None = None): result_type = stream_type.element_type super().__init__(operands=[stream], result_types=[result_type]) - @classmethod - def parse(cls, parser: Parser) -> Self: - parser.parse_characters("from") - unresolved = parser.parse_unresolved_operand() - parser.parse_punctuation(":") - result_type = parser.parse_attribute() - resolved = parser.resolve_operand(unresolved, ReadableStreamType(result_type)) - return cls(resolved, result_type) - - def print(self, printer: Printer): - printer.print_string(" from ") - printer.print(self.stream) - printer.print_string(" : ") - printer.print_attribute(self.res.type) - class WriteOperation(IRDLOperation, abc.ABC): """ @@ -134,30 +117,11 @@ class WriteOperation(IRDLOperation, abc.ABC): value = operand_def(T) stream = operand_def(WritableStreamType.constr(T)) + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + def __init__(self, value: SSAValue, stream: SSAValue): super().__init__(operands=[value, stream]) - @classmethod - def parse(cls, parser: Parser) -> Self: - unresolved_value = parser.parse_unresolved_operand() - parser.parse_characters("to") - unresolved_stream = parser.parse_unresolved_operand() - parser.parse_punctuation(":") - result_type = parser.parse_attribute() - resolved_value = parser.resolve_operand(unresolved_value, result_type) - resolved_stream = parser.resolve_operand( - unresolved_stream, WritableStreamType(result_type) - ) - return cls(resolved_value, resolved_stream) - - def print(self, printer: Printer): - printer.print_string(" ") - printer.print_ssa_value(self.value) - printer.print_string(" to ") - printer.print_ssa_value(self.stream) - printer.print_string(" : ") - printer.print_attribute(self.value.type) - Stream = Dialect( "stream", From 9d6032af7a1bffb090ffbf1a008c77d58523bdcd Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Wed, 20 Nov 2024 10:01:50 +0000 Subject: [PATCH 9/9] dialects: (stream) remove abstract stream read and write operations --- xdsl/dialects/memref_stream.py | 35 ++++++++++++++++++++++++-- xdsl/dialects/riscv_snitch.py | 29 ++++++++++++++++++++-- xdsl/dialects/stream.py | 45 +--------------------------------- 3 files changed, 61 insertions(+), 48 deletions(-) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 13ef7a7d30..c08f9dae2c 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -46,6 +46,7 @@ opt_prop_def, prop_def, region_def, + result_def, traits_def, var_operand_def, ) @@ -187,14 +188,44 @@ def offsets(self) -> tuple[tuple[int, ...], ...]: @irdl_op_definition -class ReadOp(stream.ReadOperation): +class ReadOp(IRDLOperation): name = "memref_stream.read" + T: ClassVar = VarConstraint("T", AnyAttr()) + + stream = operand_def(stream.ReadableStreamType.constr(T)) + res = result_def(T) + + assembly_format = "`from` $stream attr-dict `:` type($res)" + + def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): + if result_type is None: + assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) + stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + result_type = stream_type.element_type + super().__init__(operands=[stream_val], result_types=[result_type]) + + def assembly_line(self) -> str | None: + return None + @irdl_op_definition -class WriteOp(stream.WriteOperation): +class WriteOp(IRDLOperation): name = "memref_stream.write" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + stream = operand_def(stream.WritableStreamType.constr(T)) + + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + + def __init__(self, value: SSAValue, stream: SSAValue): + super().__init__(operands=[value, stream]) + + def assembly_line(self) -> str | None: + return None + @irdl_op_definition class StreamingRegionOp(IRDLOperation): diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 94ce963e95..f1179342fd 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -40,6 +40,7 @@ ) from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( + AnyAttr, VarConstraint, attr_def, base, @@ -155,17 +156,41 @@ def assembly_line(self) -> str | None: @irdl_op_definition -class ReadOp(stream.ReadOperation, RISCVAsmOperation): +class ReadOp(RISCVAsmOperation): name = "riscv_snitch.read" + T: ClassVar = VarConstraint("T", AnyAttr()) + + stream = operand_def(stream.ReadableStreamType.constr(T)) + res = result_def(T) + + assembly_format = "`from` $stream attr-dict `:` type($res)" + + def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): + if result_type is None: + assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) + stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + result_type = stream_type.element_type + super().__init__(operands=[stream_val], result_types=[result_type]) + def assembly_line(self) -> str | None: return None @irdl_op_definition -class WriteOp(stream.WriteOperation, RISCVAsmOperation): +class WriteOp(RISCVAsmOperation): name = "riscv_snitch.write" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = operand_def(T) + stream = operand_def(stream.WritableStreamType.constr(T)) + + assembly_format = "$value `to` $stream attr-dict `:` type($value)" + + def __init__(self, value: SSAValue, stream: SSAValue): + super().__init__(operands=[value, stream]) + def assembly_line(self) -> str | None: return None diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py index e95e051fc8..7ed748f393 100644 --- a/xdsl/dialects/stream.py +++ b/xdsl/dialects/stream.py @@ -1,27 +1,20 @@ from __future__ import annotations -import abc -from typing import ClassVar, Generic, TypeVar, cast +from typing import Generic, TypeVar from xdsl.dialects.builtin import ContainerType from xdsl.ir import ( Attribute, Dialect, ParametrizedAttribute, - SSAValue, TypeAttribute, ) from xdsl.irdl import ( - AnyAttr, BaseAttr, GenericAttrConstraint, - IRDLOperation, ParamAttrConstraint, ParameterDef, - VarConstraint, irdl_attr_definition, - operand_def, - result_def, ) _StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) @@ -87,42 +80,6 @@ def constr( ) -class ReadOperation(IRDLOperation, abc.ABC): - """ - Abstract base class for operations that read from a stream. - """ - - T: ClassVar = VarConstraint("T", AnyAttr()) - - stream = operand_def(ReadableStreamType.constr(T)) - res = result_def(T) - - assembly_format = "`from` $stream attr-dict `:` type($res)" - - def __init__(self, stream: SSAValue, result_type: Attribute | None = None): - if result_type is None: - assert isinstance(stream_type := stream.type, ReadableStreamType) - stream_type = cast(ReadableStreamType[Attribute], stream_type) - result_type = stream_type.element_type - super().__init__(operands=[stream], result_types=[result_type]) - - -class WriteOperation(IRDLOperation, abc.ABC): - """ - Abstract base class for operations that write to a stream. - """ - - T: ClassVar = VarConstraint("T", AnyAttr()) - - value = operand_def(T) - stream = operand_def(WritableStreamType.constr(T)) - - assembly_format = "$value `to` $stream attr-dict `:` type($value)" - - def __init__(self, value: SSAValue, stream: SSAValue): - super().__init__(operands=[value, stream]) - - Stream = Dialect( "stream", [],