From 810d653d3f9d47fa3097f5849f42aed3f202d618 Mon Sep 17 00:00:00 2001 From: Maxim Manainen Date: Tue, 15 Oct 2024 15:11:50 +0100 Subject: [PATCH] dialects: (bufferization) add materialize_in_destination (#3301) Currently we need only the command with tensor arguments but we should extend it to memrefs as well at some point. --- .../with-mlir/dialects/bufferization/ops.mlir | 7 +++++ xdsl/dialects/bufferization.py | 26 +++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/bufferization/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/bufferization/ops.mlir index 72b53f6fb9..6a79d63a1e 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/bufferization/ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/bufferization/ops.mlir @@ -5,6 +5,10 @@ module{ %m = "test.op"() : () -> memref<30x20x10xf32> %m_t = bufferization.to_tensor %m restrict writable : memref<30x20x10xf32> %t_m = bufferization.to_memref %m_t read_only : memref<30x20x10xf32> + + %tensor1 = "test.op"() : () -> tensor<2x2xf64> + %tensor2 = "test.op"() : () -> tensor<2x2xf64> + %b = bufferization.materialize_in_destination %tensor1 in %tensor2 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> } // CHECK: builtin.module { @@ -12,4 +16,7 @@ module{ // CHECK-NEXT: %1 = "test.op"() : () -> memref<30x20x10xf32> // CHECK-NEXT: %2 = bufferization.to_tensor %1 restrict writable : memref<30x20x10xf32> // CHECK-NEXT: %3 = bufferization.to_memref %2 read_only : memref<30x20x10xf32> +// CHECK-NEXT: %4 = "test.op"() : () -> tensor<2x2xf64> +// CHECK-NEXT: %5 = "test.op"() : () -> tensor<2x2xf64> +// CHECK-NEXT: %6 = bufferization.materialize_in_destination %4 in %5 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> // CHECK-NEXT: } diff --git a/xdsl/dialects/bufferization.py b/xdsl/dialects/bufferization.py index 5496c6028b..7b4f1dc076 100644 --- a/xdsl/dialects/bufferization.py +++ b/xdsl/dialects/bufferization.py @@ -168,12 +168,38 @@ class ToMemrefOp(IRDLOperation): assembly_format = "$tensor (`read_only` $read_only^)? `:` attr-dict type($memref)" +@irdl_op_definition +class MaterializeInDestination(IRDLOperation): + name = "bufferization.materialize_in_destination" + + source = operand_def( + TensorMemrefInferenceConstraint( + "T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr + ) + ) + dest = operand_def( + TensorMemrefInferenceConstraint( + "T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr + ) + ) + result = result_def( + TensorMemrefInferenceConstraint( + "T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr + ) + ) + restrict = opt_prop_def(UnitAttr) + writable = opt_prop_def(UnitAttr) + + assembly_format = "$source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest attr-dict `:` `(` type($source) `,` type($dest) `)` `->` type($result)" + + Bufferization = Dialect( "bufferization", [ AllocTensorOp, ToTensorOp, ToMemrefOp, + MaterializeInDestination, ], [], )