From 54c270c4feacf6128694852bb3bc279771455e1e Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 7 Nov 2024 15:08:02 +0100 Subject: [PATCH] transformations: New test-add-timers-to-top-level-funcs pass --- .../test-add-timers-to-top-level-funcs.mlir | 66 +++++++++++++++++ xdsl/transforms/csl_stencil_to_csl_wrapper.py | 17 ++--- xdsl/transforms/function_transformations.py | 71 ++++++++++++++++++- 3 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir diff --git a/tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir b/tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir new file mode 100644 index 0000000000..bf14d55036 --- /dev/null +++ b/tests/filecheck/transforms/test-add-timers-to-top-level-funcs.mlir @@ -0,0 +1,66 @@ +// RUN: xdsl-opt %s -p test-add-timers-to-top-level-funcs --split-input-file | filecheck %s + +builtin.module { + + // CHECK: builtin.module { + // CHECK-NEXT: func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 { + // CHECK-NEXT: %start = func.call @timer_start() : () -> f64 + // CHECK-NEXT: "test.op"() : () -> () + // CHECK-NEXT: %end = func.call @timer_end(%start) : (f64) -> f64 + // CHECK-NEXT: "llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () + // CHECK-NEXT: func.return %arg0 : i32 + // CHECK-NEXT: } + // CHECK-NEXT: func.func private @timer_start() -> f64 + // CHECK-NEXT: func.func private @timer_end(f64) -> f64 + // CHECK-NEXT: } + + func.func @has_timers(%arg0 : i32, %timers : !llvm.ptr) -> i32 { + %start = func.call @timer_start() : () -> f64 + "test.op"() : () -> () + %end = func.call @timer_end(%start) : (f64) -> f64 + "llvm.store"(%end, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () + func.return %arg0 : i32 + } + func.func private @timer_start() -> f64 + func.func private @timer_end(f64) -> f64 +} + +// ----- + +builtin.module { + + // CHECK: builtin.module { + // CHECK-NEXT: func.func @has_no_timers(%arg0 : i32, %arg1 : i32, %timers : !llvm.ptr) -> i32 { + // CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64 + // CHECK-NEXT: %res = arith.addi %arg0, %arg1 : i32 + // CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64 + // CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () + // CHECK-NEXT: func.return %res : i32 + // CHECK-NEXT: } + // CHECK-NEXT: func.func @also_has_no_timers(%timers : !llvm.ptr) { + // CHECK-NEXT: %timestamp = func.call @timer_start() : () -> f64 + // CHECK-NEXT: func.func @nested_should_not_get_timers() { + // CHECK-NEXT: func.return + // CHECK-NEXT: } + // CHECK-NEXT: "test.op"() : () -> () + // CHECK-NEXT: %timediff = func.call @timer_end(%timestamp) : (f64) -> f64 + // CHECK-NEXT: "llvm.store"(%timediff, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () + // CHECK-NEXT: func.return + // CHECK-NEXT: } + // CHECK-NEXT: func.func @timer_start() -> f64 + // CHECK-NEXT: func.func @timer_end(f64) -> f64 + // CHECK-NEXT: } + + func.func @has_no_timers(%arg0 : i32, %arg1 : i32) -> i32 { + %res = arith.addi %arg0, %arg1 : i32 + func.return %res : i32 + } + + func.func @also_has_no_timers() { + func.func @nested_should_not_get_timers() { + func.return + } + "test.op"() : () -> () + func.return + } +} diff --git a/xdsl/transforms/csl_stencil_to_csl_wrapper.py b/xdsl/transforms/csl_stencil_to_csl_wrapper.py index 839fb4f79e..a740510eaa 100644 --- a/xdsl/transforms/csl_stencil_to_csl_wrapper.py +++ b/xdsl/transforms/csl_stencil_to_csl_wrapper.py @@ -30,13 +30,14 @@ ) from xdsl.rewriter import InsertPoint from xdsl.transforms import csl_stencil_bufferize +from xdsl.transforms.function_transformations import ( + TIMER_END, + TIMER_FUNC_NAMES, + TIMER_START, +) from xdsl.utils.hints import isa from xdsl.utils.isattr import isattr -_TIMER_START = "timer_start" -_TIMER_END = "timer_end" -_TIMER_FUNC_NAMES = [_TIMER_START, _TIMER_END] - def _get_module_wrapper(op: Operation) -> csl_wrapper.ModuleOp | None: """ @@ -64,7 +65,7 @@ class ConvertStencilFuncToModuleWrappedPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): # erase timer stubs - if op.is_declaration and op.sym_name.data in _TIMER_FUNC_NAMES: + if op.is_declaration and op.sym_name.data in TIMER_FUNC_NAMES: rewriter.erase_matched_op() return # find csl_stencil.apply ops, abort if there are none @@ -250,7 +251,7 @@ def _translate_function_args( isinstance(u.operation, llvm.StoreOp) and isinstance(u.operation.value, OpResult) and isinstance(u.operation.value.op, func.Call) - and u.operation.value.op.callee.string_value() == _TIMER_END + and u.operation.value.op.callee.string_value() == TIMER_END for u in arg.uses ): start_end_size = 3 @@ -394,9 +395,9 @@ class LowerTimerFuncCall(RewritePattern): def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /): if ( not isinstance(end_call := op.value.owner, func.Call) - or not end_call.callee.string_value() == _TIMER_END + or not end_call.callee.string_value() == TIMER_END or not (isinstance(start_call := end_call.arguments[0].owner, func.Call)) - or not start_call.callee.string_value() == _TIMER_START + or not start_call.callee.string_value() == TIMER_START or not (wrapper := _get_module_wrapper(op)) or not isa(op.ptr.type, AnyMemRefType) ): diff --git a/xdsl/transforms/function_transformations.py b/xdsl/transforms/function_transformations.py index 3b66102b4e..1324c46e63 100644 --- a/xdsl/transforms/function_transformations.py +++ b/xdsl/transforms/function_transformations.py @@ -1,6 +1,9 @@ +from dataclasses import dataclass + from xdsl.context import MLContext -from xdsl.dialects import builtin, func +from xdsl.dialects import builtin, func, llvm from xdsl.dialects.builtin import ArrayAttr, DictionaryAttr, StringAttr +from xdsl.ir import Region from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( PatternRewriter, @@ -8,6 +11,7 @@ RewritePattern, op_type_rewrite_pattern, ) +from xdsl.rewriter import InsertPoint class ArgNamesToArgAttrsPass(RewritePattern): @@ -36,6 +40,71 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): rewriter.has_done_action = True +TIMER_START = "timer_start" +TIMER_END = "timer_end" +TIMER_FUNC_NAMES = [TIMER_START, TIMER_END] + + +@dataclass +class AddBenchTimersPattern(RewritePattern): + start_func_t: func.FunctionType + end_func_t: func.FunctionType + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): + if ( + not (top_level := op.parent_op()) + or not isinstance(top_level, builtin.ModuleOp) + or top_level.parent + ): + return + + ptr = op.body.block.insert_arg(llvm.LLVMPointerType.opaque(), len(op.args)) + start_call = func.Call(TIMER_START, [], tuple(self.start_func_t.outputs)) + end_call = func.Call(TIMER_END, start_call.res, tuple(self.end_func_t.outputs)) + store_time = llvm.StoreOp(end_call.res[0], ptr) + + ptr.name_hint = "timers" + start_call.res[0].name_hint = "timestamp" + end_call.res[0].name_hint = "timediff" + + assert op.body.block.last_op + rewriter.insert_op(start_call, InsertPoint.at_start(op.body.block)) + rewriter.insert_op( + [end_call, store_time], InsertPoint.before(op.body.block.last_op) + ) + op.update_function_type() + + +class TestAddBenchTimersToTopLevelFunctions(ModulePass): + """ + Adds timers to top-level functions, by adding `timer_start() -> f64` and `timer_end(f64) -> f64` + to the start and end of each module-level function. The time is stored in an `llvm.ptr` passed in + as a function arg. + """ + + name = "test-add-timers-to-top-level-funcs" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + all_funcs = [f for f in op.body.block.ops if isinstance(f, func.FuncOp)] + func_names = [f.sym_name.data for f in all_funcs] + if TIMER_START in func_names or TIMER_END in func_names: + return + + start_func_t = func.FunctionType.from_lists([], [builtin.Float64Type()]) + end_func_t = func.FunctionType.from_lists( + [builtin.Float64Type()], [builtin.Float64Type()] + ) + start_func = func.FuncOp(TIMER_START, start_func_t, Region([])) + end_func = func.FuncOp(TIMER_END, end_func_t, Region([])) + + PatternRewriteWalker( + AddBenchTimersPattern(start_func_t, end_func_t), apply_recursively=False + ).rewrite_module(op) + + op.body.block.add_ops((start_func, end_func)) + + class FunctionPersistArgNames(ModulePass): """ Persists func.func arg name hints to arg_attrs.