-
Notifications
You must be signed in to change notification settings - Fork 73
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
(dialects): adding csl_wrapper dialect (#2867)
The `csl_wrapper` dialect is helps initialise CSL modules and manage params. * Properties of the `csl_wrapper.module` are passed as BlockArgs to both layout_module and program_module * The layout module takes `%x` and `%y` BlockArgs, as well as all properties. The layout module offers various simplifications: * The terminating `csl_wrapper.yield` op is lowered to `@set_tile_code`. Any values yielded are passed as BlockArgs to the program module * Any operation using `%x` or `%y` will automatically be placed at the correct level of nesting for the generated `@set_tile_code` loop * Structs do not need to be handled manually at this stage and should be constructed automatically when lowering. At this level, we can simply provide a list of field names and args to yield. * The `csl_wrapper.import_module` op and `yield` can both take field names directly, without the need to handle structs at this level (the former exists solely for this purpose) * The program module's BlockArgs are all `csl_wrapper.module` properties as well as every field yielded by layout's `yield` op. * The semantics of properties-as-BlockArgs should hopefully keep program parameters nicely organised - i.e.., that we don't end up with separate, diverging copies of anything. This should make it really straightforward to manage program-wide properties before further lowering. --------- Co-authored-by: n-io <[email protected]> Co-authored-by: Anton Lydike <[email protected]>
- Loading branch information
1 parent
987abb5
commit 9f4a53b
Showing
4 changed files
with
528 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from xdsl.builder import ImplicitBuilder | ||
from xdsl.dialects import arith | ||
from xdsl.dialects.builtin import IntegerAttr, IntegerType | ||
from xdsl.dialects.csl import csl_wrapper | ||
|
||
|
||
def test_get_layout_arg(): | ||
m_op = csl_wrapper.ModuleOp( | ||
7, 8, params={"one": IntegerAttr(9, 16), "two": IntegerAttr(10, 16)} | ||
) | ||
assert m_op.get_layout_param("x") == m_op.layout_module.block.args[0] | ||
assert m_op.get_layout_param("y") == m_op.layout_module.block.args[1] | ||
assert m_op.get_layout_param("width") == m_op.layout_module.block.args[2] | ||
assert m_op.get_layout_param("height") == m_op.layout_module.block.args[3] | ||
assert m_op.get_layout_param("one") == m_op.layout_module.block.args[4] | ||
assert m_op.get_layout_param("two") == m_op.layout_module.block.args[5] | ||
assert len(m_op.layout_module.block.args) == 6 | ||
|
||
|
||
def test_get_program_arg(): | ||
m_op = csl_wrapper.ModuleOp( | ||
7, 8, params={"one": IntegerAttr(9, 16), "two": IntegerAttr(10, 16)} | ||
) | ||
assert m_op.get_program_param("width") == m_op.program_module.block.args[0] | ||
assert m_op.get_program_param("height") == m_op.program_module.block.args[1] | ||
assert m_op.get_program_param("one") == m_op.program_module.block.args[2] | ||
assert m_op.get_program_param("two") == m_op.program_module.block.args[3] | ||
assert len(m_op.program_module.block.args) == 4 | ||
|
||
|
||
def test_update_program_args(): | ||
m_op = csl_wrapper.ModuleOp( | ||
7, 8, params={"one": IntegerAttr(9, 16), "two": IntegerAttr(10, 16)} | ||
) | ||
assert len(m_op.program_module.block.args) == 4 | ||
with ImplicitBuilder(m_op.layout_module.block): | ||
zero_const = arith.Constant(IntegerAttr(0, 16)) | ||
seven_const = arith.Constant(IntegerAttr(7, 32)) | ||
csl_wrapper.YieldOp.from_field_name_mapping( | ||
{ | ||
"zero_param": zero_const, | ||
"seven_param": seven_const, | ||
} | ||
) | ||
|
||
m_op.update_program_block_args_from_layout() | ||
|
||
assert len(m_op.program_module.block.args) == 6 | ||
assert m_op.get_program_param("width") == m_op.program_module.block.args[0] | ||
assert m_op.get_program_param("height") == m_op.program_module.block.args[1] | ||
assert m_op.get_program_param("one") == m_op.program_module.block.args[2] | ||
assert m_op.get_program_param("two") == m_op.program_module.block.args[3] | ||
assert m_op.get_program_param("zero_param") == m_op.program_module.block.args[4] | ||
assert m_op.get_program_param("seven_param") == m_op.program_module.block.args[5] | ||
assert m_op.program_module.block.args[4].type == IntegerType(16) | ||
assert m_op.program_module.block.args[5].type == IntegerType(32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// RUN: XDSL_ROUNDTRIP | ||
// RUN: XDSL_GENERIC_ROUNDTRIP | ||
|
||
builtin.module { | ||
"csl_wrapper.module"() <{"width"=10 : i16, "height"=10: i16, "params" = [ | ||
#csl_wrapper.param<"z_dim" default=4: i16>, #csl_wrapper.param<"pattern" : i16> | ||
]}> ({ | ||
^0(%x: i16, %y: i16, %width: i16, %height: i16, %z_dim: i16, %pattern: i16): | ||
%0 = arith.constant 0 : i16 | ||
%1 = "csl.get_color"(%0) : (i16) -> !csl.color | ||
|
||
%routes = "csl_wrapper.import_module"(%pattern, %width, %height) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module | ||
%memcpy = "csl_wrapper.import_module"(%width, %height, %1) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module | ||
|
||
%compute_all_routes = "csl.member_call"(%routes, %x, %y, %height, %width, %pattern) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct | ||
%memcpy_params = "csl.member_call"(%memcpy, %x) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct | ||
|
||
%2 = arith.constant 1 : i16 | ||
%3 = arith.minsi %pattern, %2 : i16 | ||
%4 = arith.minsi %width, %x : i16 | ||
%5 = arith.minsi %height, %y : i16 | ||
%6 = arith.cmpi slt, %x, %3 : i16 | ||
%7 = arith.cmpi slt, %y, %3 : i16 | ||
%8 = arith.cmpi slt, %4, %pattern : i16 | ||
%9 = arith.cmpi slt, %5, %pattern : i16 | ||
%10 = arith.ori %6, %7 : i1 | ||
%11 = arith.ori %10, %8 : i1 | ||
%is_border_region_pe = arith.ori %11, %9 : i1 | ||
|
||
"csl_wrapper.yield"(%memcpy_params, %compute_all_routes, %is_border_region_pe) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () | ||
}, { | ||
^0(%width: i16, %height: i16, %z_dim: i16, %pattern: i16, %memcpy_params: !csl.comptime_struct, %stencil_comms_params: !csl.comptime_struct, %is_border_region_pe: i1): | ||
|
||
func.func @gauss_seidel () { | ||
func.return | ||
} | ||
"csl_wrapper.yield"() <{"fields" = []}> : () -> () | ||
}) : () -> () | ||
} | ||
|
||
|
||
// CHECK: builtin.module { | ||
// CHECK-NEXT: "csl_wrapper.module"() <{"width" = 10 : i16, "height" = 10 : i16, "params" = [#csl_wrapper.param<"z_dim" default=4 : i16>, #csl_wrapper.param<"pattern" : i16>]}> ({ | ||
// CHECK-NEXT: ^0(%x : i16, %y : i16, %width : i16, %height : i16, %z_dim : i16, %pattern : i16): | ||
// CHECK-NEXT: %0 = arith.constant 0 : i16 | ||
// CHECK-NEXT: %1 = "csl.get_color"(%0) : (i16) -> !csl.color | ||
// CHECK-NEXT: %routes = "csl_wrapper.import_module"(%pattern, %width, %height) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module | ||
// CHECK-NEXT: %memcpy = "csl_wrapper.import_module"(%width, %height, %1) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module | ||
// CHECK-NEXT: %compute_all_routes = "csl.member_call"(%routes, %x, %y, %height, %width, %pattern) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct | ||
// CHECK-NEXT: %memcpy_params = "csl.member_call"(%memcpy, %x) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct | ||
// CHECK-NEXT: %2 = arith.constant 1 : i16 | ||
// CHECK-NEXT: %3 = arith.minsi %pattern, %2 : i16 | ||
// CHECK-NEXT: %4 = arith.minsi %width, %x : i16 | ||
// CHECK-NEXT: %5 = arith.minsi %height, %y : i16 | ||
// CHECK-NEXT: %6 = arith.cmpi slt, %x, %3 : i16 | ||
// CHECK-NEXT: %7 = arith.cmpi slt, %y, %3 : i16 | ||
// CHECK-NEXT: %8 = arith.cmpi slt, %4, %pattern : i16 | ||
// CHECK-NEXT: %9 = arith.cmpi slt, %5, %pattern : i16 | ||
// CHECK-NEXT: %10 = arith.ori %6, %7 : i1 | ||
// CHECK-NEXT: %11 = arith.ori %10, %8 : i1 | ||
// CHECK-NEXT: %is_border_region_pe = arith.ori %11, %9 : i1 | ||
// CHECK-NEXT: "csl_wrapper.yield"(%memcpy_params, %compute_all_routes, %is_border_region_pe) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () | ||
// CHECK-NEXT: }, { | ||
// CHECK-NEXT: ^1(%width_1 : i16, %height_1 : i16, %z_dim_1 : i16, %pattern_1 : i16, %memcpy_params_1 : !csl.comptime_struct, %stencil_comms_params : !csl.comptime_struct, %is_border_region_pe_1 : i1): | ||
// CHECK-NEXT: func.func @gauss_seidel() { | ||
// CHECK-NEXT: func.return | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () | ||
// CHECK-NEXT: }) : () -> () | ||
// CHECK-NEXT: } | ||
|
||
|
||
// CHECK-GENERIC: "builtin.module"() ({ | ||
// CHECK-GENERIC-NEXT: "csl_wrapper.module"() <{"width" = 10 : i16, "height" = 10 : i16, "params" = [#csl_wrapper.param<"z_dim" default=4 : i16>, #csl_wrapper.param<"pattern" : i16>]}> ({ | ||
// CHECK-GENERIC-NEXT: ^0(%x : i16, %y : i16, %width : i16, %height : i16, %z_dim : i16, %pattern : i16): | ||
// CHECK-GENERIC-NEXT: %0 = "arith.constant"() <{"value" = 0 : i16}> : () -> i16 | ||
// CHECK-GENERIC-NEXT: %1 = "csl.get_color"(%0) : (i16) -> !csl.color | ||
// CHECK-GENERIC-NEXT: %routes = "csl_wrapper.import_module"(%pattern, %width, %height) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module | ||
// CHECK-GENERIC-NEXT: %memcpy = "csl_wrapper.import_module"(%width, %height, %1) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module | ||
// CHECK-GENERIC-NEXT: %compute_all_routes = "csl.member_call"(%routes, %x, %y, %height, %width, %pattern) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct | ||
// CHECK-GENERIC-NEXT: %memcpy_params = "csl.member_call"(%memcpy, %x) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct | ||
// CHECK-GENERIC-NEXT: %2 = "arith.constant"() <{"value" = 1 : i16}> : () -> i16 | ||
// CHECK-GENERIC-NEXT: %3 = "arith.minsi"(%pattern, %2) : (i16, i16) -> i16 | ||
// CHECK-GENERIC-NEXT: %4 = "arith.minsi"(%width, %x) : (i16, i16) -> i16 | ||
// CHECK-GENERIC-NEXT: %5 = "arith.minsi"(%height, %y) : (i16, i16) -> i16 | ||
// CHECK-GENERIC-NEXT: %6 = "arith.cmpi"(%x, %3) <{"predicate" = 2 : i64}> : (i16, i16) -> i1 | ||
// CHECK-GENERIC-NEXT: %7 = "arith.cmpi"(%y, %3) <{"predicate" = 2 : i64}> : (i16, i16) -> i1 | ||
// CHECK-GENERIC-NEXT: %8 = "arith.cmpi"(%4, %pattern) <{"predicate" = 2 : i64}> : (i16, i16) -> i1 | ||
// CHECK-GENERIC-NEXT: %9 = "arith.cmpi"(%5, %pattern) <{"predicate" = 2 : i64}> : (i16, i16) -> i1 | ||
// CHECK-GENERIC-NEXT: %10 = "arith.ori"(%6, %7) : (i1, i1) -> i1 | ||
// CHECK-GENERIC-NEXT: %11 = "arith.ori"(%10, %8) : (i1, i1) -> i1 | ||
// CHECK-GENERIC-NEXT: %is_border_region_pe = "arith.ori"(%11, %9) : (i1, i1) -> i1 | ||
// CHECK-GENERIC-NEXT: "csl_wrapper.yield"(%memcpy_params, %compute_all_routes, %is_border_region_pe) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () | ||
// CHECK-GENERIC-NEXT: }, { | ||
// CHECK-GENERIC-NEXT: ^1(%width_1 : i16, %height_1 : i16, %z_dim_1 : i16, %pattern_1 : i16, %memcpy_params_1 : !csl.comptime_struct, %stencil_comms_params : !csl.comptime_struct, %is_border_region_pe_1 : i1): | ||
// CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "gauss_seidel", "function_type" = () -> ()}> ({ | ||
// CHECK-GENERIC-NEXT: "func.return"() : () -> () | ||
// CHECK-GENERIC-NEXT: }) : () -> () | ||
// CHECK-GENERIC-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () | ||
// CHECK-GENERIC-NEXT: }) : () -> () | ||
// CHECK-GENERIC-NEXT: }) : () -> () |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.