From 27a519b323ac2d64b50c7f56effc439f947f5dcc Mon Sep 17 00:00:00 2001 From: Nicolai Stawinoga <36768051+n-io@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:16:04 +0200 Subject: [PATCH] dialects: (csl-stencil) Separate varargs for apply regions (#3316) Both regions of the `csl_stencil.apply` op could take args from a single `var_operand_def`. This is now separated out more cleanly into two separate `var_operand_def`s, one for each region. --------- Co-authored-by: n-io --- tests/dialects/test_csl_stencil.py | 2 +- .../dialects/csl/csl-stencil-canonicalize.mlir | 12 ++++++------ tests/filecheck/dialects/csl/csl-stencil-ops.mlir | 10 +++++----- .../transforms/convert-stencil-to-csl-stencil.mlir | 4 ++-- .../transforms/csl-stencil-handle-async-flow.mlir | 12 ++++++------ .../transforms/csl-stencil-materialize-stores.mlir | 4 ++-- .../transforms/csl-stencil-to-csl-wrapper.mlir | 4 ++-- .../filecheck/transforms/csl_stencil_bufferize.mlir | 4 ++-- tests/filecheck/transforms/lower-csl-stencil.mlir | 8 ++++---- tests/filecheck/transforms/lower-csl-wrapper.mlir | 2 +- xdsl/dialects/csl/csl_stencil.py | 13 +++++++++---- xdsl/transforms/convert_stencil_to_csl_stencil.py | 4 ++-- xdsl/transforms/csl_stencil_bufferize.py | 10 ++++++++-- xdsl/transforms/csl_stencil_materialize_stores.py | 8 +++++--- xdsl/transforms/lower_csl_stencil.py | 8 ++++---- 15 files changed, 59 insertions(+), 46 deletions(-) diff --git a/tests/dialects/test_csl_stencil.py b/tests/dialects/test_csl_stencil.py index 21d3c87e24..024cb90f90 100644 --- a/tests/dialects/test_csl_stencil.py +++ b/tests/dialects/test_csl_stencil.py @@ -29,7 +29,7 @@ def region1(args: tuple[SSAValue, ...]): AccessOp(t1, IndexAttr.get(0, 0), tens_t) apply = ApplyOp( - operands=[temp, mref, [], []], + operands=[temp, mref, [], [], []], properties={ "swaps": None, "topo": None, diff --git a/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir b/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir index 6a95b40014..cbefbb2493 100644 --- a/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir +++ b/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir @@ -6,7 +6,7 @@ builtin.module { %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> %1 = tensor.empty() : tensor<510xf32> - %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ + %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ ^0(%3 : tensor<4x255xf32>, %4 : index, %5 : tensor<510xf32>): %6 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> %7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> @@ -18,7 +18,7 @@ builtin.module { stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> %10 = tensor.empty() : tensor<510xf32> - %11 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ + %11 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ ^0(%12 : tensor<4x255xf32>, %13 : index, %14 : tensor<510xf32>): %15 = csl_stencil.access %12[1, 0] : tensor<4x255xf32> %16 = "tensor.insert_slice"(%15, %14, %13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> @@ -30,7 +30,7 @@ builtin.module { stencil.store %11 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> %19 = tensor.empty() : tensor<510xf32> - %20 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %19 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ + %20 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %19 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ ^0(%21 : tensor<4x255xf32>, %22 : index, %23 : tensor<510xf32>): %24 = csl_stencil.access %21[1, 0] : tensor<4x255xf32> %25 = "tensor.insert_slice"(%24, %23, %22) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> @@ -49,7 +49,7 @@ builtin.module { // CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> // CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%3 : tensor<4x255xf32>, %4 : index, %5 : tensor<510xf32>): // CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> // CHECK-NEXT: %7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> @@ -59,7 +59,7 @@ builtin.module { // CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32> // CHECK-NEXT: }) // CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %3 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: %3 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%4 : tensor<4x255xf32>, %5 : index, %6 : tensor<510xf32>): // CHECK-NEXT: %7 = csl_stencil.access %4[1, 0] : tensor<4x255xf32> // CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %6, %5) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> @@ -69,7 +69,7 @@ builtin.module { // CHECK-NEXT: csl_stencil.yield %10 : tensor<510xf32> // CHECK-NEXT: }) // CHECK-NEXT: stencil.store %3 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %4 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: %4 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%5 : tensor<4x255xf32>, %6 : index, %7 : tensor<510xf32>): // CHECK-NEXT: %8 = csl_stencil.access %5[1, 0] : tensor<4x255xf32> // CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %7, %6) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> diff --git a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir index c7fe8e65d9..662b86b8ab 100644 --- a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir +++ b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir @@ -142,7 +142,7 @@ builtin.module { // CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> // CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%recv : tensor<4x255xf32>, %offset : index, %iter_arg : tensor<510xf32>): // CHECK-NEXT: %3 = csl_stencil.access %recv[1, 0] : tensor<4x255xf32> // CHECK-NEXT: %4 = csl_stencil.access %recv[-1, 0] : tensor<4x255xf32> @@ -177,7 +177,7 @@ builtin.module { // CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>): // CHECK-GENERIC-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> // CHECK-GENERIC-NEXT: %1 = "tensor.empty"() : () -> tensor<510xf32> -// CHECK-GENERIC-NEXT: %2 = "csl_stencil.apply"(%0, %1) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: %2 = "csl_stencil.apply"(%0, %1) <{"num_chunks" = 2 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^1(%recv : tensor<4x255xf32>, %offset : index, %iter_arg : tensor<510xf32>): // CHECK-GENERIC-NEXT: %3 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index<[1, 0]>, "offset_mapping" = #stencil.index<[0, 1]>}> : (tensor<4x255xf32>) -> tensor<255xf32> // CHECK-GENERIC-NEXT: %4 = "csl_stencil.access"(%recv) <{"offset" = #stencil.index<[-1, 0]>, "offset_mapping" = #stencil.index<[0, 1]>}> : (tensor<4x255xf32>) -> tensor<255xf32> @@ -216,7 +216,7 @@ builtin.module { builtin.module { func.func @bufferized_stencil(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { %0 = tensor.empty() : tensor<510xf32> - csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) outs (%b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ + csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) outs (%b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ ^0(%1 : tensor<4x255xf32>, %2 : index, %3 : tensor<510xf32>): %4 = csl_stencil.access %1[1, 0] : tensor<4x255xf32> %5 = csl_stencil.access %1[-1, 0] : tensor<4x255xf32> @@ -245,7 +245,7 @@ builtin.module { //CHECK: builtin.module { //CHECK-NEXT: func.func @bufferized_stencil(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { //CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> -//CHECK-NEXT: csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) outs (%b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +//CHECK-NEXT: csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) outs (%b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ //CHECK-NEXT: ^0(%1 : tensor<4x255xf32>, %2 : index, %3 : tensor<510xf32>): //CHECK-NEXT: %4 = csl_stencil.access %1[1, 0] : tensor<4x255xf32> //CHECK-NEXT: %5 = csl_stencil.access %1[-1, 0] : tensor<4x255xf32> @@ -275,7 +275,7 @@ builtin.module { // CHECK-GENERIC-NEXT: "func.func"() <{"sym_name" = "bufferized_stencil", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({ // CHECK-GENERIC-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>): // CHECK-GENERIC-NEXT: %0 = "tensor.empty"() : () -> tensor<510xf32> -// CHECK-GENERIC-NEXT: "csl_stencil.apply"(%a, %0, %b) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "csl_stencil.apply"(%a, %0, %b) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^1(%1 : tensor<4x255xf32>, %2 : index, %3 : tensor<510xf32>): // CHECK-GENERIC-NEXT: %4 = "csl_stencil.access"(%1) <{"offset" = #stencil.index<[1, 0]>, "offset_mapping" = #stencil.index<[0, 1]>}> : (tensor<4x255xf32>) -> tensor<255xf32> // CHECK-GENERIC-NEXT: %5 = "csl_stencil.access"(%1) <{"offset" = #stencil.index<[-1, 0]>, "offset_mapping" = #stencil.index<[0, 1]>}> : (tensor<4x255xf32>) -> tensor<255xf32> diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index d7b53644a0..de5bb37172 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -37,7 +37,7 @@ builtin.module { // CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> // CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%3 : tensor<4x255xf32>, %4 : index, %5 : tensor<510xf32>): // CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> // CHECK-NEXT: %7 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32> @@ -88,7 +88,7 @@ builtin.module { // CHECK-NEXT: func.func @bufferized(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): // CHECK-NEXT: %5 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> // CHECK-NEXT: %6 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> diff --git a/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir b/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir index 5b34e8d4a5..65c2620302 100644 --- a/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir +++ b/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir @@ -37,7 +37,7 @@ %39 = arith.constant 1000 : index %40 = arith.constant 1 : index %41, %42 = scf.for %arg2 = %38 to %39 step %40 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (memref<512xf32>, memref<512xf32>) { - csl_stencil.apply(%arg3 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg4 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%arg3 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg4 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^2(%arg5 : memref<4x255xf32>, %arg6 : index, %arg7 : memref<510xf32>): %43 = csl_stencil.access %arg5[1, 0] : memref<4x255xf32> %44 = csl_stencil.access %arg5[-1, 0] : memref<4x255xf32> @@ -126,7 +126,7 @@ // CHECK-NEXT: %arg2 = "csl.load_var"(%38) : (!csl.var) -> i16 // CHECK-NEXT: %arg3 = "csl.load_var"(%39) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %arg4 = "csl.load_var"(%40) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: csl_stencil.apply(%arg3 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg4 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg3 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg4 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ // CHECK-NEXT: ^2(%arg5 : memref<4x255xf32>, %arg6 : index, %arg7 : memref<510xf32>): // CHECK-NEXT: %47 = csl_stencil.access %arg5[1, 0] : memref<4x255xf32> // CHECK-NEXT: %48 = csl_stencil.access %arg5[-1, 0] : memref<4x255xf32> @@ -184,7 +184,7 @@ "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () %37 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> csl.func @sequential_kernels_func() { - csl_stencil.apply(%arg0 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%arg0 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^2(%arg5 : memref<4x255xf32>, %arg6 : index, %arg7 : memref<510xf32>): csl_stencil.yield %arg7 : memref<510xf32> }, { @@ -192,7 +192,7 @@ %50 = arith.constant 1.666600e-01 : f32 csl_stencil.yield %arg6_1 : memref<510xf32> }) to <[0, 0], [1, 1]> - csl_stencil.apply(%arg1 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg0 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%arg1 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg0 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^4(%arg5 : memref<4x255xf32>, %arg6 : index, %arg7 : memref<510xf32>): csl_stencil.yield %arg7 : memref<510xf32> }, { @@ -222,7 +222,7 @@ // CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () // CHECK-NEXT: %61 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> // CHECK-NEXT: csl.func @sequential_kernels_func() { -// CHECK-NEXT: csl_stencil.apply(%arg0_1 : memref<512xf32>, %61 : memref<510xf32>) outs (%arg1_1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg0_1 : memref<512xf32>, %61 : memref<510xf32>) outs (%arg1_1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ // CHECK-NEXT: ^4(%arg5 : memref<4x255xf32>, %arg6 : index, %arg7 : memref<510xf32>): // CHECK-NEXT: csl_stencil.yield %arg7 : memref<510xf32> // CHECK-NEXT: }, { @@ -234,7 +234,7 @@ // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @step0() { -// CHECK-NEXT: csl_stencil.apply(%arg1_1 : memref<512xf32>, %61 : memref<510xf32>) outs (%arg0_1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg1_1 : memref<512xf32>, %61 : memref<510xf32>) outs (%arg0_1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ // CHECK-NEXT: ^4(%arg5 : memref<4x255xf32>, %arg6 : index, %arg7 : memref<510xf32>): // CHECK-NEXT: csl_stencil.yield %arg7 : memref<510xf32> // CHECK-NEXT: }, { diff --git a/tests/filecheck/transforms/csl-stencil-materialize-stores.mlir b/tests/filecheck/transforms/csl-stencil-materialize-stores.mlir index 2a7a7ae88e..d0065b95fa 100644 --- a/tests/filecheck/transforms/csl-stencil-materialize-stores.mlir +++ b/tests/filecheck/transforms/csl-stencil-materialize-stores.mlir @@ -34,7 +34,7 @@ builtin.module { "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel}> : () -> () csl.func @gauss_seidel() { %23 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> - csl_stencil.apply(%19 : memref<512xf32>, %23 : memref<510xf32>) outs (%20 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%19 : memref<512xf32>, %23 : memref<510xf32>) outs (%20 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^2(%arg10 : memref<4x510xf32>, %arg11 : index, %arg12 : memref<510xf32>): %24 = csl_stencil.access %arg10[1, 0] : memref<4x510xf32> %25 = csl_stencil.access %arg10[-1, 0] : memref<4x510xf32> @@ -98,7 +98,7 @@ builtin.module { // CHECK-NEXT: "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel}> : () -> () // CHECK-NEXT: csl.func @gauss_seidel() { // CHECK-NEXT: %23 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: csl_stencil.apply(%19 : memref<512xf32>, %23 : memref<510xf32>, %20 : memref<512xf32>, %arg9 : i1) outs (%20 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ +// CHECK-NEXT: csl_stencil.apply(%19 : memref<512xf32>, %23 : memref<510xf32>, %20 : memref<512xf32>, %arg9 : i1) outs (%20 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ // CHECK-NEXT: ^2(%arg10 : memref<4x510xf32>, %arg11 : index, %arg12 : memref<510xf32>): // CHECK-NEXT: %24 = csl_stencil.access %arg10[1, 0] : memref<4x510xf32> // CHECK-NEXT: %25 = csl_stencil.access %arg10[-1, 0] : memref<4x510xf32> diff --git a/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir b/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir index 4527d0daed..9de211db70 100644 --- a/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir +++ b/tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir @@ -72,7 +72,7 @@ func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, // CHECK-NEXT: csl.func @gauss_seidel() { // CHECK-NEXT: %40 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> // CHECK-NEXT: %41 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %42 = csl_stencil.apply(%40 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %41 : tensor<510xf32>, %c : memref<255xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: %42 = csl_stencil.apply(%40 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %41 : tensor<510xf32>, %c : memref<255xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^2(%43 : tensor<4x255xf32>, %44 : index, %45 : tensor<510xf32>, %46 : memref<255xf32>): // CHECK-NEXT: %47 = csl_stencil.access %43[1, 0] : tensor<4x255xf32> // CHECK-NEXT: %48 = csl_stencil.access %43[-1, 0] : tensor<4x255xf32> @@ -170,7 +170,7 @@ func.func private @timer_end(f64) -> f64 // CHECK-NEXT: "csl.member_call"(%81) <{"field" = "enable_tsc"}> : (!csl.imported_module) -> () // CHECK-NEXT: "csl.member_call"(%81, %83) <{"field" = "get_timestamp"}> : (!csl.imported_module, !csl.ptr, #csl, #csl>) -> () // CHECK-NEXT: %84 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: csl_stencil.apply(%arg0 : memref<512xf32>, %84 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg0 : memref<512xf32>, %84 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ // CHECK-NEXT: ^4(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>): // CHECK-NEXT: %85 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32> // CHECK-NEXT: %86 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> diff --git a/tests/filecheck/transforms/csl_stencil_bufferize.mlir b/tests/filecheck/transforms/csl_stencil_bufferize.mlir index d9d4a9d81d..7d7d846d07 100644 --- a/tests/filecheck/transforms/csl_stencil_bufferize.mlir +++ b/tests/filecheck/transforms/csl_stencil_bufferize.mlir @@ -3,7 +3,7 @@ builtin.module { func.func @bufferized_stencil(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { %0 = tensor.empty() : tensor<510xf32> - csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) outs (%b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ + csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) outs (%b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ ^0(%1 : tensor<4x255xf32>, %2 : index, %3 : tensor<510xf32>): %4 = csl_stencil.access %1[1, 0] : tensor<4x255xf32> %5 = csl_stencil.access %1[-1, 0] : tensor<4x255xf32> @@ -34,7 +34,7 @@ builtin.module { // CHECK-NEXT: func.func @bufferized_stencil(%a : memref<512xf32>, %b : memref<512xf32>) { // CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> // CHECK-NEXT: %1 = bufferization.to_memref %0 : memref<510xf32> -// CHECK-NEXT: csl_stencil.apply(%a : memref<512xf32>, %1 : memref<510xf32>) outs (%b : memref<512xf32>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: csl_stencil.apply(%a : memref<512xf32>, %1 : memref<510xf32>) outs (%b : memref<512xf32>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%2 : memref<4x255xf32>, %3 : index, %4 : memref<510xf32>): // CHECK-NEXT: %5 = bufferization.to_tensor %4 restrict writable : memref<510xf32> // CHECK-NEXT: %6 = csl_stencil.access %2[1, 0] : memref<4x255xf32> diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index b720f030c2..fa0b1c16e0 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -36,7 +36,7 @@ builtin.module { "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () csl.func @gauss_seidel_func() { %37 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> - csl_stencil.apply(%arg0 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%arg0 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^2(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>): %38 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32> %39 = csl_stencil.access %arg2[-1, 0] : memref<4x255xf32> @@ -185,7 +185,7 @@ builtin.module { %arg11 = "csl.load_var"(%24) : (!csl.var>) -> memref<512xf32> %arg12 = "csl.load_var"(%25) : (!csl.var>) -> memref<512xf32> %32 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> - csl_stencil.apply(%arg11 : memref<512xf32>, %32 : memref<510xf32>, %arg12 : memref<512xf32>, %arg9 : i1) outs (%arg12 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%arg11 : memref<512xf32>, %32 : memref<510xf32>, %arg12 : memref<512xf32>, %arg9 : i1) outs (%arg12 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^2(%arg13 : memref<4x510xf32>, %arg14 : index, %arg15 : memref<510xf32>): %33 = csl_stencil.access %arg13[1, 0] : memref<4x510xf32> %34 = csl_stencil.access %arg13[-1, 0] : memref<4x510xf32> @@ -382,7 +382,7 @@ builtin.module { "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () csl.func @partial_access() { %37 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> - csl_stencil.apply(%arg0 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%arg0 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^2(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>): %38 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32> %39 = csl_stencil.access %arg2[-1, 0] : memref<4x255xf32> @@ -536,7 +536,7 @@ builtin.module { %arg11 = "csl.load_var"(%24) : (!csl.var>) -> memref<511xf32> %arg12 = "csl.load_var"(%25) : (!csl.var>) -> memref<511xf32> %32 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> - csl_stencil.apply(%arg11 : memref<511xf32>, %32 : memref<510xf32>, %arg11 : memref<511xf32>, %arg12 : memref<511xf32>, %arg9 : i1) outs (%arg12 : memref<511xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + csl_stencil.apply(%arg11 : memref<511xf32>, %32 : memref<510xf32>, %arg11 : memref<511xf32>, %arg12 : memref<511xf32>, %arg9 : i1) outs (%arg12 : memref<511xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ ^2(%arg13 : memref<2x510xf32>, %arg14 : index, %arg15 : memref<510xf32>, %arg16 : memref<511xf32>): %33 = arith.constant dense<1.234500e-01> : memref<510xf32> %34 = csl_stencil.access %arg13[1, 0] : memref<2x510xf32> diff --git a/tests/filecheck/transforms/lower-csl-wrapper.mlir b/tests/filecheck/transforms/lower-csl-wrapper.mlir index 2c1e6cdfd0..c6795ecda9 100644 --- a/tests/filecheck/transforms/lower-csl-wrapper.mlir +++ b/tests/filecheck/transforms/lower-csl-wrapper.mlir @@ -151,7 +151,7 @@ builtin.module { // CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () // CHECK-NEXT: csl.func @gauss_seidel_func() { // CHECK-NEXT: %scratchBuffer = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: csl_stencil.apply(%inputArr : memref<512xf32>, %scratchBuffer : memref<510xf32>) outs (%outputArr : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ +// CHECK-NEXT: csl_stencil.apply(%inputArr : memref<512xf32>, %scratchBuffer : memref<510xf32>) outs (%outputArr : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ // CHECK-NEXT: ^0(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>): // CHECK-NEXT: %5 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32> // CHECK-NEXT: %6 = csl_stencil.access %arg2[-1, 0] : memref<4x255xf32> diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index 7a0bb98b00..2c8a93fc57 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -178,6 +178,8 @@ class ApplyOp(IRDLOperation): Further fields: - `field` - the stencil field to communicate (send and receive) + - `args_rchunk` - arguments passed to the `receive_chunk` region, may include other prefetched buffers + - `args_dexchng` - arguments passed to the `done_exchange` region, may include other prefetched buffers - `args` - arguments to the stencil computation, may include other prefetched buffers - `topo` - as received from `csl_stencil.prefetch`/`dmp.swap` - `num_chunks` - number of chunks into which to slice the communication @@ -206,7 +208,8 @@ class ApplyOp(IRDLOperation): accumulator = operand_def(AnyTensorTypeConstr | AnyMemRefTypeConstr) - args = var_operand_def(Attribute) + args_rchunk = var_operand_def(Attribute) + args_dexchng = var_operand_def(Attribute) dest = var_operand_def(stencil.FieldTypeConstr | AnyMemRefTypeConstr) receive_chunk = region_def() @@ -241,7 +244,7 @@ def print_arg(arg: SSAValue): printer.print("(") # args required by function signature, plus optional args for regions - args = [self.field, self.accumulator, *self.args] + args = [self.field, self.accumulator, *self.args_rchunk, *self.args_dexchng] printer.print_list(args, print_arg) if self.dest: @@ -274,7 +277,7 @@ def parse_args(): value = parser.resolve_operand(value, type) return value - operands = parser.parse_comma_separated_list(parser.Delimiter.PAREN, parse_args) + ops = parser.parse_comma_separated_list(parser.Delimiter.PAREN, parse_args) if parser.parse_optional_punctuation("->"): parser.parse_punctuation("(") @@ -304,8 +307,10 @@ def parse_args(): props["bounds"] = stencil.StencilBoundsAttr.new( stencil.StencilBoundsAttr.parse_parameters(parser) ) + # `-3` fixed block args, `+2` offset for operands with fixed use + split = len(receive_chunk.block.args) - 3 + 2 return cls( - operands=[operands[0], operands[1], operands[2:], destinations], + operands=[ops[0], ops[1], ops[2:split], ops[split:], destinations], result_types=[result_types], regions=[receive_chunk, done_exchange], properties=props, diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 32d259311d..fd0980a88f 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -513,8 +513,8 @@ def get_prefetch_overhead(o: OpResult): operands=[ field_op_arg, accumulator, - [op.operands[a.index] for a in chunk_region_used_block_args] - + [op.operands[a.index] for a in done_exchange_used_block_args], + [op.operands[a.index] for a in chunk_region_used_block_args], + [op.operands[a.index] for a in done_exchange_used_block_args], op.dest, ], properties={ diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py index ad252c1a96..a06ec2192c 100644 --- a/xdsl/transforms/csl_stencil_bufferize.py +++ b/xdsl/transforms/csl_stencil_bufferize.py @@ -84,7 +84,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # convert args buf_args: list[SSAValue] = [] to_memrefs: list[Operation] = [buf_iter_arg := to_memref_op(op.accumulator)] - for arg in op.args: + for arg in [*op.args_rchunk, *op.args_dexchng]: if isa(arg.type, TensorType[Attribute]): to_memrefs.append(new_arg := to_memref_op(arg)) buf_args.append(new_arg.memref) @@ -93,7 +93,13 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # create new op buf_apply_op = csl_stencil.ApplyOp( - operands=[op.field, buf_iter_arg.memref, op.args, op.dest], + operands=[ + op.field, + buf_iter_arg.memref, + op.args_rchunk, + op.args_dexchng, + op.dest, + ], result_types=op.res.types or [[]], regions=[ self._get_empty_bufferized_region(op.receive_chunk.block.args), diff --git a/xdsl/transforms/csl_stencil_materialize_stores.py b/xdsl/transforms/csl_stencil_materialize_stores.py index e549bf1d04..a27dcacffc 100644 --- a/xdsl/transforms/csl_stencil_materialize_stores.py +++ b/xdsl/transforms/csl_stencil_materialize_stores.py @@ -66,7 +66,8 @@ def match_and_rewrite(self, op: csl_stencil.YieldOp, rewriter: PatternRewriter, operands=[ apply.field, apply.accumulator, - [*apply.args, *add_args], + apply.args_rchunk, + [*apply.args_dexchng, *add_args], apply.dest, ], regions=[apply.detach_region(r) for r in apply.regions], @@ -92,7 +93,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, return cond = wrapper_op.get_program_param("isBorderRegionPE") - if cond in op.args: + if cond in op.args_dexchng: return op.done_exchange.block.insert_arg(cond.type, len(op.done_exchange.block.args)) @@ -125,7 +126,8 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, operands=[ op.field, op.accumulator, - [*op.args, cond], + op.args_rchunk, + [*op.args_dexchng, cond], op.dest, ], regions=[op.detach_region(r) for r in op.regions], diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 79476b45c5..21d88e3aad 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -157,12 +157,12 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, op.field, # buffer - this is a placeholder and should not be used after lowering AccessOp index_op.result, op.accumulator, - *op.args[: len(op.receive_chunk.block.args) - 3], + *op.args_rchunk, ] done_arg_m = [ op.field, op.accumulator, - *op.args[len(chunk_arg_m) - 3 :], + *op.args_dexchng, ] index_op.result.name_hint = "offset" op.accumulator.name_hint = "accumulator" @@ -243,7 +243,7 @@ class InlineApplyOpArgs(RewritePattern): def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /): arg_mapping = zip( op.done_exchange.block.args[2:], - op.args[-(len(op.done_exchange.block.args) - 2) :], + op.args_dexchng, ) for block_arg, arg in [ (op.done_exchange.block.args[0], op.field), @@ -252,7 +252,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, self._replace_block_arg(block_arg, arg, op.done_exchange, op, rewriter) for block_arg, arg in zip( op.receive_chunk.block.args[3:], - op.args[: len(op.receive_chunk.block.args) - 3], + op.args_rchunk, ): self._replace_block_arg(block_arg, arg, op.receive_chunk, op, rewriter)