Skip to content

Commit

Permalink
(dialects): adding csl_wrapper dialect (#2867)
Browse files Browse the repository at this point in the history
The `csl_wrapper` dialect is helps initialise CSL modules and manage
params.

* Properties of the `csl_wrapper.module` are passed as BlockArgs to both
layout_module and program_module
* The layout module takes `%x` and `%y` BlockArgs, as well as all
properties. The layout module offers various simplifications:
* The terminating `csl_wrapper.yield` op is lowered to `@set_tile_code`.
Any values yielded are passed as BlockArgs to the program module
* Any operation using `%x` or `%y` will automatically be placed at the
correct level of nesting for the generated `@set_tile_code` loop
* Structs do not need to be handled manually at this stage and should be
constructed automatically when lowering. At this level, we can simply
provide a list of field names and args to yield.
* The `csl_wrapper.import_module` op and `yield` can both take field
names directly, without the need to handle structs at this level (the
former exists solely for this purpose)
* The program module's BlockArgs are all `csl_wrapper.module` properties
as well as every field yielded by layout's `yield` op.
* The semantics of properties-as-BlockArgs should hopefully keep program
parameters nicely organised - i.e.., that we don't end up with separate,
diverging copies of anything. This should make it really straightforward
to manage program-wide properties before further lowering.

---------

Co-authored-by: n-io <[email protected]>
Co-authored-by: Anton Lydike <[email protected]>
  • Loading branch information
3 people authored Jul 11, 2024
1 parent 987abb5 commit 9f4a53b
Show file tree
Hide file tree
Showing 4 changed files with 528 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/dialects/test_csl_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from xdsl.builder import ImplicitBuilder
from xdsl.dialects import arith
from xdsl.dialects.builtin import IntegerAttr, IntegerType
from xdsl.dialects.csl import csl_wrapper


def test_get_layout_arg():
m_op = csl_wrapper.ModuleOp(
7, 8, params={"one": IntegerAttr(9, 16), "two": IntegerAttr(10, 16)}
)
assert m_op.get_layout_param("x") == m_op.layout_module.block.args[0]
assert m_op.get_layout_param("y") == m_op.layout_module.block.args[1]
assert m_op.get_layout_param("width") == m_op.layout_module.block.args[2]
assert m_op.get_layout_param("height") == m_op.layout_module.block.args[3]
assert m_op.get_layout_param("one") == m_op.layout_module.block.args[4]
assert m_op.get_layout_param("two") == m_op.layout_module.block.args[5]
assert len(m_op.layout_module.block.args) == 6


def test_get_program_arg():
m_op = csl_wrapper.ModuleOp(
7, 8, params={"one": IntegerAttr(9, 16), "two": IntegerAttr(10, 16)}
)
assert m_op.get_program_param("width") == m_op.program_module.block.args[0]
assert m_op.get_program_param("height") == m_op.program_module.block.args[1]
assert m_op.get_program_param("one") == m_op.program_module.block.args[2]
assert m_op.get_program_param("two") == m_op.program_module.block.args[3]
assert len(m_op.program_module.block.args) == 4


def test_update_program_args():
m_op = csl_wrapper.ModuleOp(
7, 8, params={"one": IntegerAttr(9, 16), "two": IntegerAttr(10, 16)}
)
assert len(m_op.program_module.block.args) == 4
with ImplicitBuilder(m_op.layout_module.block):
zero_const = arith.Constant(IntegerAttr(0, 16))
seven_const = arith.Constant(IntegerAttr(7, 32))
csl_wrapper.YieldOp.from_field_name_mapping(
{
"zero_param": zero_const,
"seven_param": seven_const,
}
)

m_op.update_program_block_args_from_layout()

assert len(m_op.program_module.block.args) == 6
assert m_op.get_program_param("width") == m_op.program_module.block.args[0]
assert m_op.get_program_param("height") == m_op.program_module.block.args[1]
assert m_op.get_program_param("one") == m_op.program_module.block.args[2]
assert m_op.get_program_param("two") == m_op.program_module.block.args[3]
assert m_op.get_program_param("zero_param") == m_op.program_module.block.args[4]
assert m_op.get_program_param("seven_param") == m_op.program_module.block.args[5]
assert m_op.program_module.block.args[4].type == IntegerType(16)
assert m_op.program_module.block.args[5].type == IntegerType(32)
101 changes: 101 additions & 0 deletions tests/filecheck/dialects/csl/csl-wrapper-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP

builtin.module {
"csl_wrapper.module"() <{"width"=10 : i16, "height"=10: i16, "params" = [
#csl_wrapper.param<"z_dim" default=4: i16>, #csl_wrapper.param<"pattern" : i16>
]}> ({
^0(%x: i16, %y: i16, %width: i16, %height: i16, %z_dim: i16, %pattern: i16):
%0 = arith.constant 0 : i16
%1 = "csl.get_color"(%0) : (i16) -> !csl.color

%routes = "csl_wrapper.import_module"(%pattern, %width, %height) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module
%memcpy = "csl_wrapper.import_module"(%width, %height, %1) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module

%compute_all_routes = "csl.member_call"(%routes, %x, %y, %height, %width, %pattern) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct
%memcpy_params = "csl.member_call"(%memcpy, %x) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct

%2 = arith.constant 1 : i16
%3 = arith.minsi %pattern, %2 : i16
%4 = arith.minsi %width, %x : i16
%5 = arith.minsi %height, %y : i16
%6 = arith.cmpi slt, %x, %3 : i16
%7 = arith.cmpi slt, %y, %3 : i16
%8 = arith.cmpi slt, %4, %pattern : i16
%9 = arith.cmpi slt, %5, %pattern : i16
%10 = arith.ori %6, %7 : i1
%11 = arith.ori %10, %8 : i1
%is_border_region_pe = arith.ori %11, %9 : i1

"csl_wrapper.yield"(%memcpy_params, %compute_all_routes, %is_border_region_pe) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> ()
}, {
^0(%width: i16, %height: i16, %z_dim: i16, %pattern: i16, %memcpy_params: !csl.comptime_struct, %stencil_comms_params: !csl.comptime_struct, %is_border_region_pe: i1):

func.func @gauss_seidel () {
func.return
}
"csl_wrapper.yield"() <{"fields" = []}> : () -> ()
}) : () -> ()
}


// CHECK: builtin.module {
// CHECK-NEXT: "csl_wrapper.module"() <{"width" = 10 : i16, "height" = 10 : i16, "params" = [#csl_wrapper.param<"z_dim" default=4 : i16>, #csl_wrapper.param<"pattern" : i16>]}> ({
// CHECK-NEXT: ^0(%x : i16, %y : i16, %width : i16, %height : i16, %z_dim : i16, %pattern : i16):
// CHECK-NEXT: %0 = arith.constant 0 : i16
// CHECK-NEXT: %1 = "csl.get_color"(%0) : (i16) -> !csl.color
// CHECK-NEXT: %routes = "csl_wrapper.import_module"(%pattern, %width, %height) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module
// CHECK-NEXT: %memcpy = "csl_wrapper.import_module"(%width, %height, %1) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module
// CHECK-NEXT: %compute_all_routes = "csl.member_call"(%routes, %x, %y, %height, %width, %pattern) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct
// CHECK-NEXT: %memcpy_params = "csl.member_call"(%memcpy, %x) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct
// CHECK-NEXT: %2 = arith.constant 1 : i16
// CHECK-NEXT: %3 = arith.minsi %pattern, %2 : i16
// CHECK-NEXT: %4 = arith.minsi %width, %x : i16
// CHECK-NEXT: %5 = arith.minsi %height, %y : i16
// CHECK-NEXT: %6 = arith.cmpi slt, %x, %3 : i16
// CHECK-NEXT: %7 = arith.cmpi slt, %y, %3 : i16
// CHECK-NEXT: %8 = arith.cmpi slt, %4, %pattern : i16
// CHECK-NEXT: %9 = arith.cmpi slt, %5, %pattern : i16
// CHECK-NEXT: %10 = arith.ori %6, %7 : i1
// CHECK-NEXT: %11 = arith.ori %10, %8 : i1
// CHECK-NEXT: %is_border_region_pe = arith.ori %11, %9 : i1
// CHECK-NEXT: "csl_wrapper.yield"(%memcpy_params, %compute_all_routes, %is_border_region_pe) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%width_1 : i16, %height_1 : i16, %z_dim_1 : i16, %pattern_1 : i16, %memcpy_params_1 : !csl.comptime_struct, %stencil_comms_params : !csl.comptime_struct, %is_border_region_pe_1 : i1):
// CHECK-NEXT: func.func @gauss_seidel() {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> ()
// CHECK-NEXT: }) : () -> ()
// CHECK-NEXT: }


// CHECK-GENERIC: "builtin.module"() ({
// CHECK-GENERIC-NEXT: "csl_wrapper.module"() <{"width" = 10 : i16, "height" = 10 : i16, "params" = [#csl_wrapper.param<"z_dim" default=4 : i16>, #csl_wrapper.param<"pattern" : i16>]}> ({
// CHECK-GENERIC-NEXT: ^0(%x : i16, %y : i16, %width : i16, %height : i16, %z_dim : i16, %pattern : i16):
// CHECK-GENERIC-NEXT: %0 = "arith.constant"() <{"value" = 0 : i16}> : () -> i16
// CHECK-GENERIC-NEXT: %1 = "csl.get_color"(%0) : (i16) -> !csl.color
// CHECK-GENERIC-NEXT: %routes = "csl_wrapper.import_module"(%pattern, %width, %height) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module
// CHECK-GENERIC-NEXT: %memcpy = "csl_wrapper.import_module"(%width, %height, %1) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module
// CHECK-GENERIC-NEXT: %compute_all_routes = "csl.member_call"(%routes, %x, %y, %height, %width, %pattern) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct
// CHECK-GENERIC-NEXT: %memcpy_params = "csl.member_call"(%memcpy, %x) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct
// CHECK-GENERIC-NEXT: %2 = "arith.constant"() <{"value" = 1 : i16}> : () -> i16
// CHECK-GENERIC-NEXT: %3 = "arith.minsi"(%pattern, %2) : (i16, i16) -> i16
// CHECK-GENERIC-NEXT: %4 = "arith.minsi"(%width, %x) : (i16, i16) -> i16
// CHECK-GENERIC-NEXT: %5 = "arith.minsi"(%height, %y) : (i16, i16) -> i16
// CHECK-GENERIC-NEXT: %6 = "arith.cmpi"(%x, %3) <{"predicate" = 2 : i64}> : (i16, i16) -> i1
// CHECK-GENERIC-NEXT: %7 = "arith.cmpi"(%y, %3) <{"predicate" = 2 : i64}> : (i16, i16) -> i1
// CHECK-GENERIC-NEXT: %8 = "arith.cmpi"(%4, %pattern) <{"predicate" = 2 : i64}> : (i16, i16) -> i1
// CHECK-GENERIC-NEXT: %9 = "arith.cmpi"(%5, %pattern) <{"predicate" = 2 : i64}> : (i16, i16) -> i1
// CHECK-GENERIC-NEXT: %10 = "arith.ori"(%6, %7) : (i1, i1) -> i1
// CHECK-GENERIC-NEXT: %11 = "arith.ori"(%10, %8) : (i1, i1) -> i1
// CHECK-GENERIC-NEXT: %is_border_region_pe = "arith.ori"(%11, %9) : (i1, i1) -> i1
// CHECK-GENERIC-NEXT: "csl_wrapper.yield"(%memcpy_params, %compute_all_routes, %is_border_region_pe) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> ()
// CHECK-GENERIC-NEXT: }, {
// CHECK-GENERIC-NEXT: ^1(%width_1 : i16, %height_1 : i16, %z_dim_1 : i16, %pattern_1 : i16, %memcpy_params_1 : !csl.comptime_struct, %stencil_comms_params : !csl.comptime_struct, %is_border_region_pe_1 : i1):
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel", "function_type" = () -> ()}> ({
// CHECK-GENERIC-NEXT: "func.return"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
6 changes: 6 additions & 0 deletions xdsl/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def get_csl_stencil():

return CSL_STENCIL

def get_csl_wrapper():
from xdsl.dialects.csl.csl_wrapper import CSL_WRAPPER

return CSL_WRAPPER

def get_dmp():
from xdsl.dialects.experimental.dmp import DMP

Expand Down Expand Up @@ -294,6 +299,7 @@ def get_transform():
"comb": get_comb,
"csl": get_csl,
"csl_stencil": get_csl_stencil,
"csl_wrapper": get_csl_wrapper,
"dmp": get_dmp,
"fir": get_fir,
"fsm": get_fsm,
Expand Down
Loading

0 comments on commit 9f4a53b

Please sign in to comment.