From 9be95f38284d41221bd768d726a61cff6c6afe23 Mon Sep 17 00:00:00 2001 From: Nicolai Stawinoga <36768051+n-io@users.noreply.github.com> Date: Thu, 17 Oct 2024 13:37:40 +0200 Subject: [PATCH] transformations: Support devito timers in the csl pipeline (#3312) Co-authored-by: n-io --- xdsl/transforms/csl_stencil_to_csl_wrapper.py | 148 +++++++++++++++++- .../stencil_tensorize_z_dimension.py | 7 +- xdsl/transforms/stencil_shape_minimize.py | 3 +- 3 files changed, 151 insertions(+), 7 deletions(-) diff --git a/xdsl/transforms/csl_stencil_to_csl_wrapper.py b/xdsl/transforms/csl_stencil_to_csl_wrapper.py index 695b9cd9c6..59ccc95225 100644 --- a/xdsl/transforms/csl_stencil_to_csl_wrapper.py +++ b/xdsl/transforms/csl_stencil_to_csl_wrapper.py @@ -3,16 +3,20 @@ from xdsl.builder import ImplicitBuilder from xdsl.context import MLContext -from xdsl.dialects import arith, builtin, func, memref, stencil +from xdsl.dialects import arith, builtin, func, llvm, memref, stencil from xdsl.dialects.builtin import ( + AnyMemRefType, AnyMemRefTypeConstr, AnyTensorTypeConstr, + IndexType, IntegerAttr, + IntegerType, ShapedType, + Signedness, TensorType, ) from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper -from xdsl.ir import Attribute, BlockArgument, Operation, SSAValue +from xdsl.ir import Attribute, BlockArgument, Operation, OpResult, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -26,6 +30,22 @@ from xdsl.utils.hints import isa from xdsl.utils.isattr import isattr +_TIMER_START = "timer_start" +_TIMER_END = "timer_end" +_TIMER_FUNC_NAMES = [_TIMER_START, _TIMER_END] + + +def _get_module_wrapper(op: Operation) -> csl_wrapper.ModuleOp | None: + """ + Return the enclosing csl_wrapper.module + """ + parent_op = op.parent_op() + while parent_op: + if isinstance(parent_op, csl_wrapper.ModuleOp): + return parent_op + parent_op = parent_op.parent_op() + return None + @dataclass(frozen=True) class ConvertStencilFuncToModuleWrappedPattern(RewritePattern): @@ -40,6 +60,10 @@ class ConvertStencilFuncToModuleWrappedPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): + # erase timer stubs + if op.is_declaration and op.sym_name.data in _TIMER_FUNC_NAMES: + rewriter.erase_matched_op() + return # find csl_stencil.apply ops, abort if there are none apply_ops = self.get_csl_stencil_apply_ops(op) if len(apply_ops) == 0: @@ -176,6 +200,7 @@ def _translate_function_args( ptr_converts: list[Operation] = [] export_ops: list[Operation] = [] cast_ops: list[Operation] = [] + import_ops: list[Operation] = [] for arg in args: arg_name = arg.name_hint or ("arg" + str(args.index(arg))) @@ -215,8 +240,49 @@ def _translate_function_args( arg_op_mapping.append(cast_op.outputs[0]) else: arg_op_mapping.append(alloc.memref) + # check if this looks like a timer + elif isinstance(arg.type, llvm.LLVMPointerType) and all( + isinstance(u.operation, llvm.StoreOp) + and isinstance(u.operation.value, OpResult) + and isinstance(u.operation.value.op, func.Call) + and u.operation.value.op.callee.string_value() == _TIMER_END + for u in arg.uses + ): + start_end_size = 3 + arg_t = memref.MemRefType( + IntegerType(16, Signedness.UNSIGNED), (2 * start_end_size,) + ) + arg_ops.append(alloc := memref.Alloc([], [], arg_t)) + ptr_converts.append( + address := csl.AddressOfOp( + operands=[alloc], + result_types=[ + csl.PtrType( + [ + arg_t.get_element_type(), + csl.PtrKindAttr(csl.PtrKind.MANY), + csl.PtrConstAttr(csl.PtrConst.VAR), + ] + ) + ], + ) + ) + export_ops.append(csl.SymbolExportOp(arg_name, SSAValue.get(address))) + arg_op_mapping.append(alloc.memref) + import_ops.append( + csl_wrapper.ImportOp( + "