Skip to content

Commit

Permalink
transformations: Support devito timers in the csl pipeline (#3312)
Browse files Browse the repository at this point in the history
Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io committed Oct 17, 2024
1 parent 017abd1 commit 9be95f3
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 7 deletions.
148 changes: 145 additions & 3 deletions xdsl/transforms/csl_stencil_to_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@

from xdsl.builder import ImplicitBuilder
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, func, memref, stencil
from xdsl.dialects import arith, builtin, func, llvm, memref, stencil
from xdsl.dialects.builtin import (
AnyMemRefType,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
IndexType,
IntegerAttr,
IntegerType,
ShapedType,
Signedness,
TensorType,
)
from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper
from xdsl.ir import Attribute, BlockArgument, Operation, SSAValue
from xdsl.ir import Attribute, BlockArgument, Operation, OpResult, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand All @@ -26,6 +30,22 @@
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

_TIMER_START = "timer_start"
_TIMER_END = "timer_end"
_TIMER_FUNC_NAMES = [_TIMER_START, _TIMER_END]


def _get_module_wrapper(op: Operation) -> csl_wrapper.ModuleOp | None:
"""
Return the enclosing csl_wrapper.module
"""
parent_op = op.parent_op()
while parent_op:
if isinstance(parent_op, csl_wrapper.ModuleOp):
return parent_op
parent_op = parent_op.parent_op()
return None


@dataclass(frozen=True)
class ConvertStencilFuncToModuleWrappedPattern(RewritePattern):
Expand All @@ -40,6 +60,10 @@ class ConvertStencilFuncToModuleWrappedPattern(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
# erase timer stubs
if op.is_declaration and op.sym_name.data in _TIMER_FUNC_NAMES:
rewriter.erase_matched_op()
return
# find csl_stencil.apply ops, abort if there are none
apply_ops = self.get_csl_stencil_apply_ops(op)
if len(apply_ops) == 0:
Expand Down Expand Up @@ -176,6 +200,7 @@ def _translate_function_args(
ptr_converts: list[Operation] = []
export_ops: list[Operation] = []
cast_ops: list[Operation] = []
import_ops: list[Operation] = []

for arg in args:
arg_name = arg.name_hint or ("arg" + str(args.index(arg)))
Expand Down Expand Up @@ -215,8 +240,49 @@ def _translate_function_args(
arg_op_mapping.append(cast_op.outputs[0])
else:
arg_op_mapping.append(alloc.memref)
# check if this looks like a timer
elif isinstance(arg.type, llvm.LLVMPointerType) and all(
isinstance(u.operation, llvm.StoreOp)
and isinstance(u.operation.value, OpResult)
and isinstance(u.operation.value.op, func.Call)
and u.operation.value.op.callee.string_value() == _TIMER_END
for u in arg.uses
):
start_end_size = 3
arg_t = memref.MemRefType(
IntegerType(16, Signedness.UNSIGNED), (2 * start_end_size,)
)
arg_ops.append(alloc := memref.Alloc([], [], arg_t))
ptr_converts.append(
address := csl.AddressOfOp(
operands=[alloc],
result_types=[
csl.PtrType(
[
arg_t.get_element_type(),
csl.PtrKindAttr(csl.PtrKind.MANY),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)
],
)
)
export_ops.append(csl.SymbolExportOp(arg_name, SSAValue.get(address)))
arg_op_mapping.append(alloc.memref)
import_ops.append(
csl_wrapper.ImportOp(
"<time>",
field_name_mapping={},
)
)

return [*arg_ops, *cast_ops, *ptr_converts, *export_ops], arg_op_mapping
return [
*arg_ops,
*cast_ops,
*ptr_converts,
*export_ops,
*import_ops,
], arg_op_mapping

def initialise_layout_module(self, module_op: csl_wrapper.ModuleOp):
"""Initialises the layout_module (wrapper block) by setting up (esp. stencil-related) program params"""
Expand Down Expand Up @@ -319,6 +385,81 @@ def initialise_program_module(
module_op.program_module.block.add_op(csl_wrapper.YieldOp([], []))


@dataclass(frozen=True)
class LowerTimerFuncCall(RewritePattern):
"""
Lowers calls to the start and end timer to csl API calls.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: llvm.StoreOp, rewriter: PatternRewriter, /):
if (
not isinstance(end_call := op.value.owner, func.Call)
or not end_call.callee.string_value() == _TIMER_END
or not (isinstance(start_call := end_call.arguments[0].owner, func.Call))
or not start_call.callee.string_value() == _TIMER_START
or not (wrapper := _get_module_wrapper(op))
or not isa(op.ptr.type, AnyMemRefType)
):
return

time_lib = wrapper.get_program_import("<time>")

three_elem_ptr_type = csl.PtrType(
[
memref.MemRefType(op.ptr.type.get_element_type(), (3,)),
csl.PtrKindAttr(csl.PtrKind.SINGLE),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)

rewriter.insert_op(
[
three := arith.Constant.from_int_and_width(3, IndexType()),
load_three := memref.Load.get(op.ptr, [three]),
addr_of := csl.AddressOfOp(
operands=[load_three],
result_types=[
csl.PtrType(
[
op.ptr.type.get_element_type(),
csl.PtrKindAttr(csl.PtrKind.SINGLE),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)
],
),
ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
csl.MemberCallOp("disable_tsc", None, time_lib, []),
],
InsertPoint.before(end_call),
)
rewriter.insert_op(
[
addr_of := csl.AddressOfOp(
operands=[op.ptr],
result_types=[
csl.PtrType(
[
op.ptr.type.get_element_type(),
csl.PtrKindAttr(csl.PtrKind.MANY),
csl.PtrConstAttr(csl.PtrConst.VAR),
]
)
],
),
ptrcast := csl.PtrCastOp(addr_of, three_elem_ptr_type),
csl.MemberCallOp("enable_tsc", None, time_lib, []),
csl.MemberCallOp("get_timestamp", None, time_lib, [ptrcast]),
],
InsertPoint.before(start_call),
)
rewriter.erase_op(op)
rewriter.erase_op(end_call)
rewriter.erase_op(start_call)


@dataclass(frozen=True)
class CslStencilToCslWrapperPass(ModulePass):
"""
Expand All @@ -333,6 +474,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
GreedyRewritePatternApplier(
[
ConvertStencilFuncToModuleWrappedPattern(),
LowerTimerFuncCall(),
]
),
apply_recursively=False,
Expand Down
7 changes: 4 additions & 3 deletions xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
class FuncOpTensorize(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /):
for arg in op.args:
if isa(arg.type, FieldType[Attribute]):
op.replace_argument_type(arg, stencil_field_to_tensor(arg.type))
if not op.is_declaration:
for arg in op.args:
if isa(arg.type, FieldType[Attribute]):
op.replace_argument_type(arg, stencil_field_to_tensor(arg.type))


def is_tensorized(
Expand Down
3 changes: 2 additions & 1 deletion xdsl/transforms/stencil_shape_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def convert_type(self, typ: stencil.FieldType[Attribute], /) -> Attribute | None
class FuncOpShapeUpdate(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
op.update_function_type()
if not op.is_declaration:
op.update_function_type()


@dataclass(frozen=True)
Expand Down

0 comments on commit 9be95f3

Please sign in to comment.