Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: ptr -> riscv conversion #3393

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/filecheck/backend/riscv/convert_ptr_to_riscv.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: xdsl-opt %s -p convert-ptr-to-riscv --split-input-file --verify-diagnostics | filecheck %s

%m, %idx, %v = "test.op"() : () -> (memref<3x2xi32>, index, i32)

%p = ptr_xdsl.to_ptr %m : memref<3x2xi32> -> !ptr_xdsl.ptr
// CHECK: %p = builtin.unrealized_conversion_cast %m : memref<3x2xi32> to !riscv.reg

%r0 = ptr_xdsl.ptradd %p, %idx : (!ptr_xdsl.ptr, index) -> !ptr_xdsl.ptr
// CHECK-NEXT: %idx_1 = builtin.unrealized_conversion_cast %idx : index to !riscv.reg
// CHECK-NEXT: %r0 = riscv.add %p, %idx_1 : (!riscv.reg, !riscv.reg) -> !riscv.reg

ptr_xdsl.store %v, %p : i32, !ptr_xdsl.ptr
// CHECK-NEXT: %v_1 = builtin.unrealized_conversion_cast %v : i32 to !riscv.reg
// CHECK-NEXT: riscv.sw %p, %v_1, 0 {"comment" = "store int value to pointer"} : (!riscv.reg, !riscv.reg) -> ()

%r3 = ptr_xdsl.load %p : !ptr_xdsl.ptr -> i32
// CHECK-NEXT: %r3 = riscv.lw %p, 0 {"comment" = "load word from pointer"} : (!riscv.reg) -> !riscv.reg
// CHECK-NEXT: %r3_1 = builtin.unrealized_conversion_cast %r3 : !riscv.reg to i32

// -----

%m2 = "test.op"() : () -> (memref<3x2xf128>)
%p2 = ptr_xdsl.to_ptr %m2 : memref<3x2xf128> -> !ptr_xdsl.ptr
%v1 = ptr_xdsl.load %p2 : !ptr_xdsl.ptr -> f128
// CHECK: Unexpected floating point type f128
6 changes: 6 additions & 0 deletions xdsl/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def get_convert_print_format_to_riscv_debug():

return convert_print_format_to_riscv_debug.ConvertPrintFormatToRiscvDebugPass

def get_convert_ptr_to_riscv():
from xdsl.transforms import convert_ptr_to_riscv

return convert_ptr_to_riscv.ConvertPtrToRiscvPass

def get_convert_qref_to_qssa():
from xdsl.transforms import convert_qref_to_qssa

Expand Down Expand Up @@ -469,6 +474,7 @@ def get_varith_fuse_repeated_operands():
"convert-scf-to-riscv-scf": get_convert_scf_to_riscv_scf,
"convert-snitch-stream-to-snitch": get_convert_snitch_stream_to_snitch,
"convert-stencil-to-csl-stencil": get_convert_stencil_to_csl_stencil,
"convert-ptr-to-riscv": get_convert_ptr_to_riscv,
"inline-snrt": get_convert_snrt_to_riscv,
"convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir,
"cse": get_cse,
Expand Down
142 changes: 142 additions & 0 deletions xdsl/transforms/convert_ptr_to_riscv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from dataclasses import dataclass
from typing import cast

from xdsl.backend.riscv.lowering.utils import (
cast_operands_to_regs,
register_type_for_type,
)
from xdsl.context import MLContext
from xdsl.dialects import ptr, riscv
from xdsl.dialects.builtin import (
AnyFloat,
Float32Type,
Float64Type,
ModuleOp,
UnrealizedConversionCastOp,
)
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
TypeConversionPattern,
attr_type_rewrite_pattern,
op_type_rewrite_pattern,
)
from xdsl.utils.exceptions import DiagnosticException


class PtrTypeConversion(TypeConversionPattern):
@attr_type_rewrite_pattern
def convert_type(self, typ: ptr.PtrType) -> riscv.IntRegisterType:
return riscv.IntRegisterType.unallocated()


@dataclass
class ConvertPtrAddOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.PtrAddOp, rewriter: PatternRewriter, /):
oper1, oper2 = cast_operands_to_regs(rewriter)
rewriter.replace_matched_op(
riscv.AddOp(oper1, oper2, rd=riscv.IntRegisterType.unallocated())
)


@dataclass
class ConvertStoreOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.StoreOp, rewriter: PatternRewriter, /):
addr, value = cast_operands_to_regs(rewriter)

match value.type:
case riscv.IntRegisterType():
new_op = riscv.SwOp(
addr, value, 0, comment="store int value to pointer"
)
case riscv.FloatRegisterType():
float_type = cast(AnyFloat, op.value.type)
match float_type:
case Float32Type():
new_op = riscv.FSwOp(
addr,
value,
0,
comment="store float value to pointer",
)
case Float64Type():
new_op = riscv.FSdOp(
addr,
value,
0,
comment="store double value to pointer",
)
case _:
raise DiagnosticException(
f"Unexpected floating point type {float_type}"
)

case _:
assert False, f"Unexpected register type {op.value.type}"

rewriter.replace_matched_op(new_op)


@dataclass
class ConvertLoadOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.LoadOp, rewriter: PatternRewriter, /):
casted = cast_operands_to_regs(rewriter)
addr = casted[0]

result_register_type = register_type_for_type(op.res.type)

match result_register_type:
case riscv.IntRegisterType:
lw_op = riscv.LwOp(addr, 0, comment="load word from pointer")
case riscv.FloatRegisterType:
float_type = cast(AnyFloat, op.res.type)
match float_type:
case Float32Type():
lw_op = riscv.FLwOp(addr, 0, comment="load float from pointer")
case Float64Type():
lw_op = riscv.FLdOp(addr, 0, comment="load double from pointer")
case _:
raise DiagnosticException(
f"Unexpected floating point type {float_type}"
)

case _:
assert False, f"Unexpected register type {result_register_type}"

rewriter.replace_matched_op(
(lw := lw_op, UnrealizedConversionCastOp.get(lw.results, (op.res.type,)))
)


@dataclass
class ConvertMemrefToPtrOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.ToPtrOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(
UnrealizedConversionCastOp.get(
[op.source], [riscv.IntRegisterType.unallocated()]
)
)


class ConvertPtrToRiscvPass(ModulePass):
name = "convert-ptr-to-riscv"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
PtrTypeConversion(),
ConvertPtrAddOp(),
ConvertStoreOp(),
ConvertLoadOp(),
ConvertMemrefToPtrOp(),
]
),
).rewrite_module(op)
Loading