diff --git a/tests/filecheck/transforms/convert_memref_to_ptr.mlir b/tests/filecheck/transforms/convert_memref_to_ptr.mlir new file mode 100644 index 0000000000..b7691da29a --- /dev/null +++ b/tests/filecheck/transforms/convert_memref_to_ptr.mlir @@ -0,0 +1,67 @@ +// RUN: xdsl-opt -p convert-memref-to-ptr --split-input-file --verify-diagnostics %s | filecheck %s + +%v, %idx, %arr = "test.op"() : () -> (i32, index, memref<10xi32>) +memref.store %v, %arr[%idx] {"nontemporal" = false} : memref<10xi32> + +// CHECK: %bytes_per_element = ptr_xdsl.type_offset i32 : index +// CHECK-NEXT: %scaled_pointer_offset = arith.muli %idx, %bytes_per_element : index +// CHECK-NEXT: %0 = ptr_xdsl.to_ptr %arr : memref<10xi32> -> !ptr_xdsl.ptr +// CHECK-NEXT: %offset_pointer = ptr_xdsl.ptradd %0, %scaled_pointer_offset : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr +// CHECK-NEXT: ptr_xdsl.store %v, %offset_pointer : i32, !ptr_xdsl.ptr + +%idx1, %idx2, %arr2 = "test.op"() : () -> (index, index, memref<10x10xi32>) +memref.store %v, %arr2[%idx1, %idx2] {"nontemporal" = false} : memref<10x10xi32> + +// CHECK-NEXT: %idx1, %idx2, %arr2 = "test.op"() : () -> (index, index, memref<10x10xi32>) +// CHECK-NEXT: %pointer_dim_stride = arith.constant 10 : index +// CHECK-NEXT: %pointer_dim_offset = arith.muli %idx1, %pointer_dim_stride : index +// CHECK-NEXT: %pointer_dim_stride_1 = arith.addi %pointer_dim_offset, %idx2 : index +// CHECK-NEXT: %bytes_per_element_1 = ptr_xdsl.type_offset i32 : index +// CHECK-NEXT: %scaled_pointer_offset_1 = arith.muli %pointer_dim_stride_1, %bytes_per_element_1 : index +// CHECK-NEXT: %1 = ptr_xdsl.to_ptr %arr2 : memref<10x10xi32> -> !ptr_xdsl.ptr +// CHECK-NEXT: %offset_pointer_1 = ptr_xdsl.ptradd %1, %scaled_pointer_offset_1 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr +// CHECK-NEXT: ptr_xdsl.store %v, %offset_pointer_1 : i32, !ptr_xdsl.ptr + +%lv = memref.load %arr[%idx] {"nontemporal" = false} : memref<10xi32> + +// CHECK-NEXT: %bytes_per_element_2 = ptr_xdsl.type_offset i32 : index +// CHECK-NEXT: %scaled_pointer_offset_2 = arith.muli %idx, %bytes_per_element_2 : index +// CHECK-NEXT: %lv = ptr_xdsl.to_ptr %arr : memref<10xi32> -> !ptr_xdsl.ptr +// CHECK-NEXT: %offset_pointer_2 = ptr_xdsl.ptradd %lv, %scaled_pointer_offset_2 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr +// CHECK-NEXT: %lv_1 = ptr_xdsl.load %offset_pointer_2 : !ptr_xdsl.ptr -> i32 + +%lv2 = memref.load %arr2[%idx1, %idx2] {"nontemporal" = false} : memref<10x10xi32> + +// CHECK-NEXT: %pointer_dim_stride_2 = arith.constant 10 : index +// CHECK-NEXT: %pointer_dim_offset_1 = arith.muli %idx1, %pointer_dim_stride_2 : index +// CHECK-NEXT: %pointer_dim_stride_3 = arith.addi %pointer_dim_offset_1, %idx2 : index +// CHECK-NEXT: %bytes_per_element_3 = ptr_xdsl.type_offset i32 : index +// CHECK-NEXT: %scaled_pointer_offset_3 = arith.muli %pointer_dim_stride_3, %bytes_per_element_3 : index +// CHECK-NEXT: %lv2 = ptr_xdsl.to_ptr %arr2 : memref<10x10xi32> -> !ptr_xdsl.ptr +// CHECK-NEXT: %offset_pointer_3 = ptr_xdsl.ptradd %lv2, %scaled_pointer_offset_3 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr +// CHECK-NEXT: %lv2_1 = ptr_xdsl.load %offset_pointer_3 : !ptr_xdsl.ptr -> i32 + +%fv, %farr = "test.op"() : () -> (f64, memref<10xf64>) +memref.store %fv, %farr[%idx] {"nontemporal" = false} : memref<10xf64> + +// CHECK-NEXT: %fv, %farr = "test.op"() : () -> (f64, memref<10xf64>) +// CHECK-NEXT: %bytes_per_element_4 = ptr_xdsl.type_offset f64 : index +// CHECK-NEXT: %scaled_pointer_offset_4 = arith.muli %idx, %bytes_per_element_4 : index +// CHECK-NEXT: %2 = ptr_xdsl.to_ptr %farr : memref<10xf64> -> !ptr_xdsl.ptr +// CHECK-NEXT: %offset_pointer_4 = ptr_xdsl.ptradd %2, %scaled_pointer_offset_4 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr +// CHECK-NEXT: ptr_xdsl.store %fv, %offset_pointer_4 : f64, !ptr_xdsl.ptr + +%flv = memref.load %farr[%idx] {"nontemporal" = false} : memref<10xf64> + +// CHECK-NEXT: %bytes_per_element_5 = ptr_xdsl.type_offset f64 : index +// CHECK-NEXT: %scaled_pointer_offset_5 = arith.muli %idx, %bytes_per_element_5 : index +// CHECK-NEXT: %flv = ptr_xdsl.to_ptr %farr : memref<10xf64> -> !ptr_xdsl.ptr +// CHECK-NEXT: %offset_pointer_5 = ptr_xdsl.ptradd %flv, %scaled_pointer_offset_5 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr +// CHECK-NEXT: %flv_1 = ptr_xdsl.load %offset_pointer_5 : !ptr_xdsl.ptr -> f64 + +%fmem = "test.op"() : () -> (memref) +%flv2 = memref.load %fmem[] {"nontemporal" = false} : memref + +// CHECK-NEXT: %fmem = "test.op"() : () -> memref +// CHECK-NEXT: %flv2 = ptr_xdsl.to_ptr %fmem : memref -> !ptr_xdsl.ptr +// CHECK-NEXT: %flv2_1 = ptr_xdsl.load %flv2 : !ptr_xdsl.ptr -> f64 diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py index a687df0d0d..a4f5d5b6fb 100644 --- a/xdsl/transforms/__init__.py +++ b/xdsl/transforms/__init__.py @@ -236,6 +236,11 @@ def get_memref_to_dsd(): return memref_to_dsd.MemrefToDsdPass + def get_memref_to_ptr(): + from xdsl.transforms import convert_memref_to_ptr + + return convert_memref_to_ptr.ConvertMemrefToPtr + def get_mlir_opt(): from xdsl.transforms import mlir_opt @@ -501,6 +506,7 @@ def get_varith_fuse_repeated_operands(): "memref-stream-tile-outer-loops": get_memref_stream_tile_outer_loops, "memref-stream-legalize": get_memref_stream_legalize, "memref-to-dsd": get_memref_to_dsd, + "convert-memref-to-ptr": get_memref_to_ptr, "mlir-opt": get_mlir_opt, "printf-to-llvm": get_printf_to_llvm, "printf-to-putchar": get_printf_to_putchar, diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py new file mode 100644 index 0000000000..239e0bdcd9 --- /dev/null +++ b/xdsl/transforms/convert_memref_to_ptr.py @@ -0,0 +1,164 @@ +from collections.abc import Iterable +from dataclasses import dataclass +from typing import cast + +from xdsl.context import MLContext +from xdsl.dialects import arith, builtin, memref, ptr +from xdsl.ir import Operation, SSAValue +from xdsl.irdl import Any +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.utils.exceptions import DiagnosticException + + +def offset_calculations( + memref_type: memref.MemRefType[Any], indices: Iterable[SSAValue] +) -> tuple[list[Operation], SSAValue]: + """Get operations calculating an offset which needs to be added to memref's base pointer to access an element referenced by indices.""" + + assert isinstance(memref_type.element_type, builtin.FixedBitwidthType) + + match memref_type.layout: + case builtin.NoneAttr(): + strides = builtin.ShapedType.strides_for_shape(memref_type.get_shape()) + case builtin.StridedLayoutAttr(): + strides = memref_type.layout.get_strides() + case _: + raise DiagnosticException(f"Unsupported layout type {memref_type.layout}") + + ops: list[Operation] = [] + + head: SSAValue | None = None + + for index, stride in zip(indices, strides, strict=True): + # Calculate the offset that needs to be added through the index of the current + # dimension. + increment = index + match stride: + case None: + raise DiagnosticException( + f"MemRef {memref_type} with dynamic stride is not yet implemented" + ) + case 1: + # Stride 1 is a noop making the index equal to the offset. + pass + case _: + # Otherwise, multiply the stride (which by definition is the number of + # elements required to be skipped when incrementing that dimension). + ops.extend( + ( + stride_op := arith.Constant.from_int_and_width( + stride, builtin.IndexType() + ), + offset_op := arith.Muli(increment, stride_op), + ) + ) + stride_op.result.name_hint = "pointer_dim_stride" + offset_op.result.name_hint = "pointer_dim_offset" + + increment = offset_op.result + + if head is None: + # First iteration. + head = increment + continue + + # Otherwise sum up the products. + add_op = arith.Addi(head, increment) + add_op.result.name_hint = "pointer_dim_stride" + ops.append(add_op) + head = add_op.result + + if head is None: + raise DiagnosticException("Got empty indices for offset calculations.") + + ops.extend( + [ + bytes_per_element_op := ptr.TypeOffsetOp( + operands=[], + result_types=[builtin.IndexType()], + properties={"elem_type": memref_type.element_type}, + ), + final_offset := arith.Muli(head, bytes_per_element_op), + ] + ) + + bytes_per_element_op.offset.name_hint = "bytes_per_element" + final_offset.result.name_hint = "scaled_pointer_offset" + + return ops, final_offset.result + + +def get_target_ptr( + target_memref: SSAValue, + memref_type: memref.MemRefType[Any], + indices: Iterable[SSAValue], +) -> tuple[list[Operation], SSAValue]: + """Get operations returning a pointer to an element of a memref referenced by indices.""" + + ops: list[Operation] = [ + memref_ptr := ptr.ToPtrOp( + operands=[target_memref], result_types=[ptr.PtrType()] + ) + ] + + if not indices: + return ops, memref_ptr.res + + offset_ops, offset = offset_calculations(memref_type, indices) + ops = offset_ops + ops + ops.append( + target_ptr := ptr.PtrAddOp( + operands=[memref_ptr, offset], result_types=[ptr.PtrType()] + ) + ) + + target_ptr.result.name_hint = "offset_pointer" + return ops, target_ptr.result + + +@dataclass +class ConvertStoreOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: memref.Store, rewriter: PatternRewriter, /): + assert isinstance(op_memref_type := op.memref.type, memref.MemRefType) + memref_type = cast(memref.MemRefType[Any], op_memref_type) + + ops, target_ptr = get_target_ptr(op.memref, memref_type, op.indices) + ops.append(ptr.StoreOp(operands=[target_ptr, op.value])) + + rewriter.replace_matched_op(ops) + + +@dataclass +class ConvertLoadOp(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: memref.Load, rewriter: PatternRewriter, /): + assert isinstance(op_memref_type := op.memref.type, memref.MemRefType) + memref_type = cast(memref.MemRefType[Any], op_memref_type) + + ops, target_ptr = get_target_ptr(op.memref, memref_type, op.indices) + ops.append( + load_result := ptr.LoadOp( + operands=[target_ptr], result_types=[memref_type.element_type] + ) + ) + + rewriter.replace_matched_op(ops, new_results=[load_result.res]) + + +@dataclass(frozen=True) +class ConvertMemrefToPtr(ModulePass): + name = "convert-memref-to-ptr" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + the_one_pass = PatternRewriteWalker( + GreedyRewritePatternApplier([ConvertStoreOp(), ConvertLoadOp()]) + ) + the_one_pass.rewrite_module(op)