Skip to content

Commit

Permalink
transformations: Convert memref to ptrdialect (#3383)
Browse files Browse the repository at this point in the history
Introducing lowering of some memref operations to a ptr dialect. This
will be used to translate riscv compilation to ptr dialect in the future
(this is the second pr in the series).
  • Loading branch information
mamanain authored Nov 3, 2024
1 parent 1cf1222 commit 152aa06
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 0 deletions.
67 changes: 67 additions & 0 deletions tests/filecheck/transforms/convert_memref_to_ptr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: xdsl-opt -p convert-memref-to-ptr --split-input-file --verify-diagnostics %s | filecheck %s

%v, %idx, %arr = "test.op"() : () -> (i32, index, memref<10xi32>)
memref.store %v, %arr[%idx] {"nontemporal" = false} : memref<10xi32>

// CHECK: %bytes_per_element = ptr_xdsl.type_offset i32 : index
// CHECK-NEXT: %scaled_pointer_offset = arith.muli %idx, %bytes_per_element : index
// CHECK-NEXT: %0 = ptr_xdsl.to_ptr %arr : memref<10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer = ptr_xdsl.ptradd %0, %scaled_pointer_offset : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %v, %offset_pointer : i32, !ptr_xdsl.ptr

%idx1, %idx2, %arr2 = "test.op"() : () -> (index, index, memref<10x10xi32>)
memref.store %v, %arr2[%idx1, %idx2] {"nontemporal" = false} : memref<10x10xi32>

// CHECK-NEXT: %idx1, %idx2, %arr2 = "test.op"() : () -> (index, index, memref<10x10xi32>)
// CHECK-NEXT: %pointer_dim_stride = arith.constant 10 : index
// CHECK-NEXT: %pointer_dim_offset = arith.muli %idx1, %pointer_dim_stride : index
// CHECK-NEXT: %pointer_dim_stride_1 = arith.addi %pointer_dim_offset, %idx2 : index
// CHECK-NEXT: %bytes_per_element_1 = ptr_xdsl.type_offset i32 : index
// CHECK-NEXT: %scaled_pointer_offset_1 = arith.muli %pointer_dim_stride_1, %bytes_per_element_1 : index
// CHECK-NEXT: %1 = ptr_xdsl.to_ptr %arr2 : memref<10x10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer_1 = ptr_xdsl.ptradd %1, %scaled_pointer_offset_1 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %v, %offset_pointer_1 : i32, !ptr_xdsl.ptr

%lv = memref.load %arr[%idx] {"nontemporal" = false} : memref<10xi32>

// CHECK-NEXT: %bytes_per_element_2 = ptr_xdsl.type_offset i32 : index
// CHECK-NEXT: %scaled_pointer_offset_2 = arith.muli %idx, %bytes_per_element_2 : index
// CHECK-NEXT: %lv = ptr_xdsl.to_ptr %arr : memref<10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer_2 = ptr_xdsl.ptradd %lv, %scaled_pointer_offset_2 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: %lv_1 = ptr_xdsl.load %offset_pointer_2 : !ptr_xdsl.ptr -> i32

%lv2 = memref.load %arr2[%idx1, %idx2] {"nontemporal" = false} : memref<10x10xi32>

// CHECK-NEXT: %pointer_dim_stride_2 = arith.constant 10 : index
// CHECK-NEXT: %pointer_dim_offset_1 = arith.muli %idx1, %pointer_dim_stride_2 : index
// CHECK-NEXT: %pointer_dim_stride_3 = arith.addi %pointer_dim_offset_1, %idx2 : index
// CHECK-NEXT: %bytes_per_element_3 = ptr_xdsl.type_offset i32 : index
// CHECK-NEXT: %scaled_pointer_offset_3 = arith.muli %pointer_dim_stride_3, %bytes_per_element_3 : index
// CHECK-NEXT: %lv2 = ptr_xdsl.to_ptr %arr2 : memref<10x10xi32> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer_3 = ptr_xdsl.ptradd %lv2, %scaled_pointer_offset_3 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: %lv2_1 = ptr_xdsl.load %offset_pointer_3 : !ptr_xdsl.ptr -> i32

%fv, %farr = "test.op"() : () -> (f64, memref<10xf64>)
memref.store %fv, %farr[%idx] {"nontemporal" = false} : memref<10xf64>

// CHECK-NEXT: %fv, %farr = "test.op"() : () -> (f64, memref<10xf64>)
// CHECK-NEXT: %bytes_per_element_4 = ptr_xdsl.type_offset f64 : index
// CHECK-NEXT: %scaled_pointer_offset_4 = arith.muli %idx, %bytes_per_element_4 : index
// CHECK-NEXT: %2 = ptr_xdsl.to_ptr %farr : memref<10xf64> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer_4 = ptr_xdsl.ptradd %2, %scaled_pointer_offset_4 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: ptr_xdsl.store %fv, %offset_pointer_4 : f64, !ptr_xdsl.ptr

%flv = memref.load %farr[%idx] {"nontemporal" = false} : memref<10xf64>

// CHECK-NEXT: %bytes_per_element_5 = ptr_xdsl.type_offset f64 : index
// CHECK-NEXT: %scaled_pointer_offset_5 = arith.muli %idx, %bytes_per_element_5 : index
// CHECK-NEXT: %flv = ptr_xdsl.to_ptr %farr : memref<10xf64> -> !ptr_xdsl.ptr
// CHECK-NEXT: %offset_pointer_5 = ptr_xdsl.ptradd %flv, %scaled_pointer_offset_5 : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: %flv_1 = ptr_xdsl.load %offset_pointer_5 : !ptr_xdsl.ptr -> f64

%fmem = "test.op"() : () -> (memref<f64>)
%flv2 = memref.load %fmem[] {"nontemporal" = false} : memref<f64>

// CHECK-NEXT: %fmem = "test.op"() : () -> memref<f64>
// CHECK-NEXT: %flv2 = ptr_xdsl.to_ptr %fmem : memref<f64> -> !ptr_xdsl.ptr
// CHECK-NEXT: %flv2_1 = ptr_xdsl.load %flv2 : !ptr_xdsl.ptr -> f64
6 changes: 6 additions & 0 deletions xdsl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ def get_memref_to_dsd():

return memref_to_dsd.MemrefToDsdPass

def get_memref_to_ptr():
from xdsl.transforms import convert_memref_to_ptr

return convert_memref_to_ptr.ConvertMemrefToPtr

def get_mlir_opt():
from xdsl.transforms import mlir_opt

Expand Down Expand Up @@ -501,6 +506,7 @@ def get_varith_fuse_repeated_operands():
"memref-stream-tile-outer-loops": get_memref_stream_tile_outer_loops,
"memref-stream-legalize": get_memref_stream_legalize,
"memref-to-dsd": get_memref_to_dsd,
"convert-memref-to-ptr": get_memref_to_ptr,
"mlir-opt": get_mlir_opt,
"printf-to-llvm": get_printf_to_llvm,
"printf-to-putchar": get_printf_to_putchar,
Expand Down
164 changes: 164 additions & 0 deletions xdsl/transforms/convert_memref_to_ptr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, memref, ptr
from xdsl.ir import Operation, SSAValue
from xdsl.irdl import Any
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.utils.exceptions import DiagnosticException


def offset_calculations(
memref_type: memref.MemRefType[Any], indices: Iterable[SSAValue]
) -> tuple[list[Operation], SSAValue]:
"""Get operations calculating an offset which needs to be added to memref's base pointer to access an element referenced by indices."""

assert isinstance(memref_type.element_type, builtin.FixedBitwidthType)

match memref_type.layout:
case builtin.NoneAttr():
strides = builtin.ShapedType.strides_for_shape(memref_type.get_shape())
case builtin.StridedLayoutAttr():
strides = memref_type.layout.get_strides()
case _:
raise DiagnosticException(f"Unsupported layout type {memref_type.layout}")

ops: list[Operation] = []

head: SSAValue | None = None

for index, stride in zip(indices, strides, strict=True):
# Calculate the offset that needs to be added through the index of the current
# dimension.
increment = index
match stride:
case None:
raise DiagnosticException(
f"MemRef {memref_type} with dynamic stride is not yet implemented"
)
case 1:
# Stride 1 is a noop making the index equal to the offset.
pass
case _:
# Otherwise, multiply the stride (which by definition is the number of
# elements required to be skipped when incrementing that dimension).
ops.extend(
(
stride_op := arith.Constant.from_int_and_width(
stride, builtin.IndexType()
),
offset_op := arith.Muli(increment, stride_op),
)
)
stride_op.result.name_hint = "pointer_dim_stride"
offset_op.result.name_hint = "pointer_dim_offset"

increment = offset_op.result

if head is None:
# First iteration.
head = increment
continue

# Otherwise sum up the products.
add_op = arith.Addi(head, increment)
add_op.result.name_hint = "pointer_dim_stride"
ops.append(add_op)
head = add_op.result

if head is None:
raise DiagnosticException("Got empty indices for offset calculations.")

ops.extend(
[
bytes_per_element_op := ptr.TypeOffsetOp(
operands=[],
result_types=[builtin.IndexType()],
properties={"elem_type": memref_type.element_type},
),
final_offset := arith.Muli(head, bytes_per_element_op),
]
)

bytes_per_element_op.offset.name_hint = "bytes_per_element"
final_offset.result.name_hint = "scaled_pointer_offset"

return ops, final_offset.result


def get_target_ptr(
target_memref: SSAValue,
memref_type: memref.MemRefType[Any],
indices: Iterable[SSAValue],
) -> tuple[list[Operation], SSAValue]:
"""Get operations returning a pointer to an element of a memref referenced by indices."""

ops: list[Operation] = [
memref_ptr := ptr.ToPtrOp(
operands=[target_memref], result_types=[ptr.PtrType()]
)
]

if not indices:
return ops, memref_ptr.res

offset_ops, offset = offset_calculations(memref_type, indices)
ops = offset_ops + ops
ops.append(
target_ptr := ptr.PtrAddOp(
operands=[memref_ptr, offset], result_types=[ptr.PtrType()]
)
)

target_ptr.result.name_hint = "offset_pointer"
return ops, target_ptr.result


@dataclass
class ConvertStoreOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.Store, rewriter: PatternRewriter, /):
assert isinstance(op_memref_type := op.memref.type, memref.MemRefType)
memref_type = cast(memref.MemRefType[Any], op_memref_type)

ops, target_ptr = get_target_ptr(op.memref, memref_type, op.indices)
ops.append(ptr.StoreOp(operands=[target_ptr, op.value]))

rewriter.replace_matched_op(ops)


@dataclass
class ConvertLoadOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.Load, rewriter: PatternRewriter, /):
assert isinstance(op_memref_type := op.memref.type, memref.MemRefType)
memref_type = cast(memref.MemRefType[Any], op_memref_type)

ops, target_ptr = get_target_ptr(op.memref, memref_type, op.indices)
ops.append(
load_result := ptr.LoadOp(
operands=[target_ptr], result_types=[memref_type.element_type]
)
)

rewriter.replace_matched_op(ops, new_results=[load_result.res])


@dataclass(frozen=True)
class ConvertMemrefToPtr(ModulePass):
name = "convert-memref-to-ptr"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
the_one_pass = PatternRewriteWalker(
GreedyRewritePatternApplier([ConvertStoreOp(), ConvertLoadOp()])
)
the_one_pass.rewrite_module(op)

0 comments on commit 152aa06

Please sign in to comment.