Skip to content

Commit

Permalink
transformations: New test-add-timers-to-top-level-funcs pass
Browse files Browse the repository at this point in the history
  • Loading branch information
n-io committed Nov 7, 2024
1 parent f40c920 commit 54c270c
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 9 deletions.
66 changes: 66 additions & 0 deletions tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: xdsl-opt %s -p test-add-timers-to-top-level-funcs --split-input-file | filecheck %s

builtin.module {

// CHECK: builtin.module {
// CHECK-NEXT: func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 {
// CHECK-NEXT: %start = func.call @timer_start() : () -> f64
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: %end = func.call @timer_end(%start) : (f64) -> f64
// CHECK-NEXT: "llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return %arg0 : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func private @timer_start() -> f64
// CHECK-NEXT: func.func private @timer_end(f64) -> f64
// CHECK-NEXT: }

func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 {
%start = func.call @timer_start() : () -> f64
"test.op"() : () -> ()
%end = func.call @timer_end(%start) : (f64) -> f64
"llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
func.return %arg0 : i32
}
func.func private @timer_start() -> f64
func.func private @timer_end(f64) -> f64
}

// -----

builtin.module {

// CHECK: builtin.module {
// CHECK-NEXT: func.func @has_no_timers(%arg0 : i32, %arg1 : i32, %timers : !llvm.ptr) -> i32 {
// CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64
// CHECK-NEXT: %res = arith.addi %arg0, %arg1 : i32
// CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return %res : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func @also_has_no_timers(%timers : !llvm.ptr) {
// CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64
// CHECK-NEXT: func.func @nested_should_not_get_timers() {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64
// CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: func.func @timer_start() -> f64
// CHECK-NEXT: func.func @timer_end(f64) -> f64
// CHECK-NEXT: }

func.func @has_no_timers(%arg0 : i32, %arg1 : i32) -> i32 {
%res = arith.addi %arg0, %arg1 : i32
func.return %res : i32
}

func.func @also_has_no_timers() {
func.func @nested_should_not_get_timers() {
func.return
}
"test.op"() : () -> ()
func.return
}
}
17 changes: 9 additions & 8 deletions xdsl/transforms/csl_stencil_to_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
)
from xdsl.rewriter import InsertPoint
from xdsl.transforms import csl_stencil_bufferize
from xdsl.transforms.function_transformations import (
TIMER_END,
TIMER_FUNC_NAMES,
TIMER_START,
)
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

_TIMER_START = "timer_start"
_TIMER_END = "timer_end"
_TIMER_FUNC_NAMES = [_TIMER_START, _TIMER_END]


def _get_module_wrapper(op: Operation) -> csl_wrapper.ModuleOp | None:
"""
Expand Down Expand Up @@ -64,7 +65,7 @@ class ConvertStencilFuncToModuleWrappedPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
# erase timer stubs
if op.is_declaration and op.sym_name.data in _TIMER_FUNC_NAMES:
if op.is_declaration and op.sym_name.data in TIMER_FUNC_NAMES:
rewriter.erase_matched_op()
return
# find csl_stencil.apply ops, abort if there are none
Expand Down Expand Up @@ -250,7 +251,7 @@ def _translate_function_args(
isinstance(u.operation, llvm.StoreOp)
and isinstance(u.operation.value, OpResult)
and isinstance(u.operation.value.op, func.Call)
and u.operation.value.op.callee.string_value() == _TIMER_END
and u.operation.value.op.callee.string_value() == TIMER_END
for u in arg.uses
):
start_end_size = 3
Expand Down Expand Up @@ -394,9 +395,9 @@ class LowerTimerFuncCall(RewritePattern):
def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /):
if (
not isinstance(end_call := op.value.owner, func.Call)
or not end_call.callee.string_value() == _TIMER_END
or not end_call.callee.string_value() == TIMER_END
or not (isinstance(start_call := end_call.arguments[0].owner, func.Call))
or not start_call.callee.string_value() == _TIMER_START
or not start_call.callee.string_value() == TIMER_START
or not (wrapper := _get_module_wrapper(op))
or not isa(op.ptr.type, AnyMemRefType)
):
Expand Down
71 changes: 70 additions & 1 deletion xdsl/transforms/function_transformations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import builtin, func
from xdsl.dialects import builtin, func, llvm
from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, StringAttr
from xdsl.ir import Region
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint


class ArgNamesToArgAttrsPass(RewritePattern):
Expand Down Expand Up @@ -36,6 +40,71 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
rewriter.has_done_action = True


TIMER_START = "timer_start"
TIMER_END = "timer_end"
TIMER_FUNC_NAMES = [TIMER_START, TIMER_END]


@dataclass
class AddBenchTimersPattern(RewritePattern):
start_func_t: func.FunctionType
end_func_t: func.FunctionType

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
if (
not (top_level := op.parent_op())
or not isinstance(top_level, builtin.ModuleOp)
or top_level.parent
):
return

ptr = op.body.block.insert_arg(llvm.LLVMPointerType.opaque(), len(op.args))
start_call = func.Call(TIMER_START, [], tuple(self.start_func_t.outputs))
end_call = func.Call(TIMER_END, start_call.res, tuple(self.end_func_t.outputs))
store_time = llvm.StoreOp(end_call.res[0], ptr)

ptr.name_hint = "timers"
start_call.res[0].name_hint = "timestamp"
end_call.res[0].name_hint = "timediff"

assert op.body.block.last_op
rewriter.insert_op(start_call, InsertPoint.at_start(op.body.block))
rewriter.insert_op(
[end_call, store_time], InsertPoint.before(op.body.block.last_op)
)
op.update_function_type()


class TestAddBenchTimersToTopLevelFunctions(ModulePass):
"""
Adds timers to top-level functions, by adding `timer_start() -> f64` and `timer_end(f64) -> f64`
to the start and end of each module-level function. The time is stored in an `llvm.ptr` passed in
as a function arg.
"""

name = "test-add-timers-to-top-level-funcs"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
all_funcs = [f for f in op.body.block.ops if isinstance(f, func.FuncOp)]
func_names = [f.sym_name.data for f in all_funcs]
if TIMER_START in func_names or TIMER_END in func_names:
return

start_func_t = func.FunctionType.from_lists([], [builtin.Float64Type()])
end_func_t = func.FunctionType.from_lists(
[builtin.Float64Type()], [builtin.Float64Type()]
)
start_func = func.FuncOp(TIMER_START, start_func_t, Region([]))
end_func = func.FuncOp(TIMER_END, end_func_t, Region([]))

PatternRewriteWalker(
AddBenchTimersPattern(start_func_t, end_func_t), apply_recursively=False
).rewrite_module(op)

op.body.block.add_ops((start_func, end_func))


class FunctionPersistArgNames(ModulePass):
"""
Persists func.func arg name hints to arg_attrs.
Expand Down

0 comments on commit 54c270c

Please sign in to comment.