From 877b70c8c005111c3b85aa4d0b9c9908034b182e Mon Sep 17 00:00:00 2001 From: Nicolai Stawinoga <36768051+n-io@users.noreply.github.com> Date: Tue, 15 Oct 2024 16:00:49 +0200 Subject: [PATCH] transformations: (lower-csl-stencil) Add iter args to first region (#3304) Translate block args for both regions alike in `csl_stencil.apply` lowering. Co-authored-by: n-io --- .../transforms/lower-csl-stencil.mlir | 232 +++++++++++++++++- xdsl/transforms/lower_csl_stencil.py | 45 +++- 2 files changed, 257 insertions(+), 20 deletions(-) diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index 552e935005..b720f030c2 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -96,13 +96,13 @@ builtin.module { // CHECK-NEXT: csl.func @gauss_seidel_func() { // CHECK-NEXT: %accumulator = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> // CHECK-NEXT: %37 = arith.constant 2 : i16 -// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb2}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb2}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-NEXT: %40 = memref.subview %arg0[1] [510] [1] : memref<512xf32> to memref<510xf32> // CHECK-NEXT: "csl.member_call"(%34, %40, %37, %38, %39) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @receive_chunk_cb1(%offset : i16) { +// CHECK-NEXT: csl.func @receive_chunk_cb2(%offset : i16) { // CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index // CHECK-NEXT: %41 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> // CHECK-NEXT: %42 = arith.constant 4 : i16 @@ -113,7 +113,7 @@ builtin.module { // CHECK-NEXT: "csl.fadds"(%45, %45, %46) : (!csl, !csl, !csl) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @done_exchange_cb1() { +// CHECK-NEXT: csl.func @done_exchange_cb2() { // CHECK-NEXT: %47 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> // CHECK-NEXT: %48 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> // CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %48) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () @@ -294,13 +294,13 @@ builtin.module { // CHECK-NEXT: %arg12 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %accumulator_1 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> // CHECK-NEXT: %82 = arith.constant 1 : i16 -// CHECK-NEXT: %83 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb2}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %84 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb2}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %83 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb3}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %84 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb3}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-NEXT: %85 = memref.subview %arg11[1] [510] [1] : memref<512xf32> to memref<510xf32> // CHECK-NEXT: "csl.member_call"(%68, %85, %82, %83, %84) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @receive_chunk_cb2(%offset_2 : i16) { +// CHECK-NEXT: csl.func @receive_chunk_cb3(%offset_2 : i16) { // CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index // CHECK-NEXT: %86 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> // CHECK-NEXT: %87 = arith.constant 4 : i16 @@ -312,7 +312,7 @@ builtin.module { // CHECK-NEXT: "memref.copy"(%86, %86) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @done_exchange_cb2() { +// CHECK-NEXT: csl.func @done_exchange_cb3() { // CHECK-NEXT: %arg12_1 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %arg11_1 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: scf.if %arg9 { @@ -477,7 +477,223 @@ builtin.module { // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () // CHECK-NEXT: }) : () -> () + "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=511 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "chunk_reduce_only", "width" = 1024 : i16}> ({ + ^0(%arg0 : i16, %arg1 : i16, %arg2 : i16, %arg3 : i16, %arg4 : i16, %arg5 : i16, %arg6 : i16, %arg7 : i16, %arg8 : i16): + %0 = arith.constant 1 : i16 + %1 = arith.constant 0 : i16 + %2 = "csl.get_color"(%1) : (i16) -> !csl.color + %3 = "csl_wrapper.import"(%arg2, %arg3, %2) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module + %4 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module + %5 = "csl.member_call"(%4, %arg0, %arg1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct + %6 = "csl.member_call"(%3, %arg0) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct + %7 = arith.subi %arg5, %0 : i16 + %8 = arith.subi %arg2, %arg0 : i16 + %9 = arith.subi %arg3, %arg1 : i16 + %10 = arith.cmpi slt, %arg0, %7 : i16 + %11 = arith.cmpi slt, %arg1, %7 : i16 + %12 = arith.cmpi slt, %8, %arg5 : i16 + %13 = arith.cmpi slt, %9, %arg5 : i16 + %14 = arith.ori %10, %11 : i1 + %15 = arith.ori %14, %12 : i1 + %16 = arith.ori %15, %13 : i1 + "csl_wrapper.yield"(%6, %5, %16) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () + }, { + ^1(%arg0_1 : i16, %arg1_1 : i16, %arg2_1 : i16, %arg3_1 : i16, %arg4_1 : i16, %arg5_1 : i16, %arg6_1 : i16, %arg7_1 : !csl.comptime_struct, %arg8_1 : !csl.comptime_struct, %arg9 : i1): + %17 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module + %18 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module + %19 = memref.alloc() : memref<511xf32> + %20 = memref.alloc() : memref<511xf32> + %21 = "csl.addressof"(%19) : (memref<511xf32>) -> !csl.ptr, #csl> + %22 = "csl.addressof"(%20) : (memref<511xf32>) -> !csl.ptr, #csl> + "csl.export"(%21) <{"type" = !csl.ptr, #csl>, "var_name" = "arg0"}> : (!csl.ptr, #csl>) -> () + "csl.export"(%22) <{"type" = !csl.ptr, #csl>, "var_name" = "arg1"}> : (!csl.ptr, #csl>) -> () + "csl.export"() <{"type" = () -> (), "var_name" = @chunk_reduce_only}> : () -> () + %23 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var + %24 = "csl.variable"() : () -> !csl.var> + %25 = "csl.variable"() : () -> !csl.var> + csl.func @chunk_reduce_only() { + %26 = arith.constant 0 : index + %27 = arith.constant 1000 : index + %28 = arith.constant 1 : index + "csl.store_var"(%24, %19) : (!csl.var>, memref<511xf32>) -> () + "csl.store_var"(%25, %20) : (!csl.var>, memref<511xf32>) -> () + csl.activate local, 1 : i32 + csl.return + } + csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ + %29 = arith.constant 1000 : i16 + %30 = "csl.load_var"(%23) : (!csl.var) -> i16 + %31 = arith.cmpi slt, %30, %29 : i16 + scf.if %31 { + "csl.call"() <{"callee" = @for_body0}> : () -> () + } else { + "csl.call"() <{"callee" = @for_post0}> : () -> () + } + csl.return + } + csl.func @for_body0() { + %arg10 = "csl.load_var"(%23) : (!csl.var) -> i16 + %arg11 = "csl.load_var"(%24) : (!csl.var>) -> memref<511xf32> + %arg12 = "csl.load_var"(%25) : (!csl.var>) -> memref<511xf32> + %32 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> + csl_stencil.apply(%arg11 : memref<511xf32>, %32 : memref<510xf32>, %arg11 : memref<511xf32>, %arg12 : memref<511xf32>, %arg9 : i1) outs (%arg12 : memref<511xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + ^2(%arg13 : memref<2x510xf32>, %arg14 : index, %arg15 : memref<510xf32>, %arg16 : memref<511xf32>): + %33 = arith.constant dense<1.234500e-01> : memref<510xf32> + %34 = csl_stencil.access %arg13[1, 0] : memref<2x510xf32> + %35 = memref.subview %arg16[1] [510] [1] : memref<511xf32> to memref<510xf32, strided<[1], offset: 1>> + %36 = csl_stencil.access %arg13[0, 1] : memref<2x510xf32> + %37 = memref.subview %arg15[%arg14] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> + "csl.fadds"(%37, %35, %36) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: 1>>, memref<510xf32>) -> () + "csl.fadds"(%37, %37, %34) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () + %38 = arith.constant 1.234500e-01 : f32 + "csl.fmuls"(%37, %37, %38) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, f32) -> () + csl_stencil.yield %arg15 : memref<510xf32> + }, { + ^3(%arg13_1 : memref<511xf32>, %arg14_1 : memref<510xf32>, %39 : memref<511xf32>, %40 : i1): + scf.if %40 { + } else { + %41 = memref.subview %39[0] [510] [1] : memref<511xf32> to memref<510xf32> + "memref.copy"(%arg14_1, %41) : (memref<510xf32>, memref<510xf32>) -> () + } + "csl.call"() <{"callee" = @for_inc0}> : () -> () + csl_stencil.yield + }) to <[0, 0], [1, 1]> + csl.return + } + csl.func @for_inc0() { + %33 = arith.constant 1 : i16 + %34 = "csl.load_var"(%23) : (!csl.var) -> i16 + %35 = arith.addi %34, %33 : i16 + "csl.store_var"(%23, %35) : (!csl.var, i16) -> () + %36 = "csl.load_var"(%24) : (!csl.var>) -> memref<511xf32> + %37 = "csl.load_var"(%25) : (!csl.var>) -> memref<511xf32> + "csl.store_var"(%24, %37) : (!csl.var>, memref<511xf32>) -> () + "csl.store_var"(%25, %36) : (!csl.var>, memref<511xf32>) -> () + csl.activate local, 1 : i32 + csl.return + } + csl.func @for_post0() { + "csl.member_call"(%17) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () + csl.return + } + "csl_wrapper.yield"() <{"fields" = []}> : () -> () + }) : () -> () +// CHECK-NEXT: "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=511 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "chunk_reduce_only", "width" = 1024 : i16}> ({ +// CHECK-NEXT: ^6(%arg0_4 : i16, %arg1_4 : i16, %arg2_2 : i16, %arg3_2 : i16, %arg4_2 : i16, %arg5_2 : i16, %arg6_2 : i16, %arg7_2 : i16, %arg8_2 : i16): +// CHECK-NEXT: %158 = arith.constant 1 : i16 +// CHECK-NEXT: %159 = arith.constant 0 : i16 +// CHECK-NEXT: %160 = "csl.get_color"(%159) : (i16) -> !csl.color +// CHECK-NEXT: %161 = "csl_wrapper.import"(%arg2_2, %arg3_2, %160) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %162 = "csl_wrapper.import"(%arg5_2, %arg2_2, %arg3_2) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %163 = "csl.member_call"(%162, %arg0_4, %arg1_4, %arg2_2, %arg3_2, %arg5_2) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %164 = "csl.member_call"(%161, %arg0_4) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %165 = arith.subi %arg5_2, %158 : i16 +// CHECK-NEXT: %166 = arith.subi %arg2_2, %arg0_4 : i16 +// CHECK-NEXT: %167 = arith.subi %arg3_2, %arg1_4 : i16 +// CHECK-NEXT: %168 = arith.cmpi slt, %arg0_4, %165 : i16 +// CHECK-NEXT: %169 = arith.cmpi slt, %arg1_4, %165 : i16 +// CHECK-NEXT: %170 = arith.cmpi slt, %166, %arg5_2 : i16 +// CHECK-NEXT: %171 = arith.cmpi slt, %167, %arg5_2 : i16 +// CHECK-NEXT: %172 = arith.ori %168, %169 : i1 +// CHECK-NEXT: %173 = arith.ori %172, %170 : i1 +// CHECK-NEXT: %174 = arith.ori %173, %171 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%164, %163, %174) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: }, { +// CHECK-NEXT: ^7(%arg0_5 : i16, %arg1_5 : i16, %arg2_3 : i16, %arg3_3 : i16, %arg4_3 : i16, %arg5_3 : i16, %arg6_3 : i16, %arg7_3 : !csl.comptime_struct, %arg8_3 : !csl.comptime_struct, %arg9_1 : i1): +// CHECK-NEXT: %175 = "csl_wrapper.import"(%arg7_3) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %176 = "csl_wrapper.import"(%arg3_3, %arg5_3, %arg8_3) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %177 = memref.alloc() : memref<511xf32> +// CHECK-NEXT: %178 = memref.alloc() : memref<511xf32> +// CHECK-NEXT: %179 = "csl.addressof"(%177) : (memref<511xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %180 = "csl.addressof"(%178) : (memref<511xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%179) <{"type" = !csl.ptr, #csl>, "var_name" = "arg0"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%180) <{"type" = !csl.ptr, #csl>, "var_name" = "arg1"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"() <{"type" = () -> (), "var_name" = @chunk_reduce_only}> : () -> () +// CHECK-NEXT: %181 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var +// CHECK-NEXT: %182 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %183 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: csl.func @chunk_reduce_only() { +// CHECK-NEXT: %184 = arith.constant 0 : index +// CHECK-NEXT: %185 = arith.constant 1000 : index +// CHECK-NEXT: %186 = arith.constant 1 : index +// CHECK-NEXT: "csl.store_var"(%182, %177) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%183, %178) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: csl.activate local, 1 : i32 +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ +// CHECK-NEXT: %187 = arith.constant 1000 : i16 +// CHECK-NEXT: %188 = "csl.load_var"(%181) : (!csl.var) -> i16 +// CHECK-NEXT: %189 = arith.cmpi slt, %188, %187 : i16 +// CHECK-NEXT: scf.if %189 { +// CHECK-NEXT: "csl.call"() <{"callee" = @for_body0}> : () -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: "csl.call"() <{"callee" = @for_post0}> : () -> () +// CHECK-NEXT: } +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @for_body0() { +// CHECK-NEXT: %arg10_1 = "csl.load_var"(%181) : (!csl.var) -> i16 +// CHECK-NEXT: %arg11_2 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %arg12_2 = "csl.load_var"(%183) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %accumulator_3 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> +// CHECK-NEXT: %190 = arith.constant 1 : i16 +// CHECK-NEXT: %191 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %192 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %193 = memref.subview %arg11_2[0] [510] [1] : memref<511xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%176, %193, %190, %191, %192) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @receive_chunk_cb1(%offset_6 : i16) { +// CHECK-NEXT: %offset_7 = arith.index_cast %offset_6 : i16 to index +// CHECK-NEXT: %arg11_3 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %194 = arith.constant dense<1.234500e-01> : memref<510xf32> +// CHECK-NEXT: %195 = arith.constant 1 : i16 +// CHECK-NEXT: %196 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %197 = "csl.member_call"(%176, %196, %195) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %198 = builtin.unrealized_conversion_cast %197 : !csl to memref<510xf32> +// CHECK-NEXT: %199 = memref.subview %arg11_3[1] [510] [1] : memref<511xf32> to memref<510xf32, strided<[1], offset: 1>> +// CHECK-NEXT: %200 = arith.constant 1 : i16 +// CHECK-NEXT: %201 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %202 = "csl.member_call"(%176, %201, %200) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %203 = builtin.unrealized_conversion_cast %202 : !csl to memref<510xf32> +// CHECK-NEXT: %204 = memref.subview %accumulator_3[%offset_7] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: "csl.fadds"(%204, %199, %203) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: 1>>, memref<510xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%204, %204, %198) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () +// CHECK-NEXT: %205 = arith.constant 1.234500e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%204, %204, %205) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, f32) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @done_exchange_cb1() { +// CHECK-NEXT: %arg12_3 = "csl.load_var"(%183) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %arg11_4 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: scf.if %arg9_1 { +// CHECK-NEXT: } else { +// CHECK-NEXT: %206 = memref.subview %arg12_3[0] [510] [1] : memref<511xf32> to memref<510xf32> +// CHECK-NEXT: "memref.copy"(%accumulator_3, %206) : (memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: } +// CHECK-NEXT: "csl.call"() <{"callee" = @for_inc0}> : () -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @for_inc0() { +// CHECK-NEXT: %207 = arith.constant 1 : i16 +// CHECK-NEXT: %208 = "csl.load_var"(%181) : (!csl.var) -> i16 +// CHECK-NEXT: %209 = arith.addi %208, %207 : i16 +// CHECK-NEXT: "csl.store_var"(%181, %209) : (!csl.var, i16) -> () +// CHECK-NEXT: %210 = "csl.load_var"(%182) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: %211 = "csl.load_var"(%183) : (!csl.var>) -> memref<511xf32> +// CHECK-NEXT: "csl.store_var"(%182, %211) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%183, %210) : (!csl.var>, memref<511xf32>) -> () +// CHECK-NEXT: csl.activate local, 1 : i32 +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @for_post0() { +// CHECK-NEXT: "csl.member_call"(%175) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () +// CHECK-NEXT: }) : () -> () } // CHECK-NEXT: } diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 231a318152..79476b45c5 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -13,7 +13,15 @@ i16, ) from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper -from xdsl.ir import Attribute, Block, Operation, OpResult, Region, SSAValue +from xdsl.ir import ( + Attribute, + Block, + BlockArgument, + Operation, + OpResult, + Region, + SSAValue, +) from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -241,18 +249,31 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, (op.done_exchange.block.args[0], op.field), *arg_mapping, ]: - if isinstance(arg, OpResult) and arg.op.parent == op.parent: - if not ( - isinstance(arg.op, csl.LoadVarOp) or is_side_effect_free(arg.op) - ): - raise ValueError( - "Can only promote csl.LoadVarOp or side_effect_free op" - ) - rewriter.insert_op( - new_arg := arg.op.clone(), - InsertPoint.at_start(op.done_exchange.block), + self._replace_block_arg(block_arg, arg, op.done_exchange, op, rewriter) + for block_arg, arg in zip( + op.receive_chunk.block.args[3:], + op.args[: len(op.receive_chunk.block.args) - 3], + ): + self._replace_block_arg(block_arg, arg, op.receive_chunk, op, rewriter) + + @staticmethod + def _replace_block_arg( + block_arg: BlockArgument, + arg: SSAValue, + region: Region, + apply: csl_stencil.ApplyOp, + rewriter: PatternRewriter, + ): + if isinstance(arg, OpResult) and arg.op.parent == apply.parent: + if not (isinstance(arg.op, csl.LoadVarOp) or is_side_effect_free(arg.op)): + raise ValueError( + "Can only promote csl.LoadVarOp or side_effect_free op" ) - block_arg.replace_by(SSAValue.get(new_arg)) + rewriter.insert_op( + new_arg := arg.op.clone(), + InsertPoint.at_start(region.block), + ) + block_arg.replace_by(SSAValue.get(new_arg)) @dataclass(frozen=True)