Skip to content

Commit

Permalink
transformations: Split varith into neighbour and own data across csl_…
Browse files Browse the repository at this point in the history
…stencil regions (xdslproject#3307)

Background: The `csl_stencil.apply` op does communicate-and-compute on a
given stencil buffer. It holds two regions, one for processing chunks of
neighbour data (of this one buffer only), and one region for processing
everything else after the exchange is done. The
`convert-stencil-to-csl-stencil` pass splits the computation of the
`stencil.apply` op across these two regions.

The split was done in two steps, first re-ordering arith ops in the
`RestructureSymmetricReductionPattern`, and then calling the
`get_ops_split` function on the re-shuffled arith ops. Intuitively, the
re-order pass would identify chained reductions (`arith.addf`,
`arith.mulf`) and restructure them such that all neighbour data which
should end up in the first region is consumed first, and the chained
arith ops become easily splittable.

This PR replaces this logic by converting arith to varith, splitting the
varith op into neighbour/other data in the `SplitVarithOpPattern`
rewrite, and then proceeding with `get_ops_split` and everything else as
before. At the end, varith is converted back to arith.

Minor improvements:
* Constants are now always duplicated and appear on both regions, which
`dce` can clean up

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
2 people authored and EdmundGoodman committed Dec 6, 2024
1 parent 4d255c2 commit 1d7e6b5
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 135 deletions.
158 changes: 109 additions & 49 deletions tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,29 @@ builtin.module {
// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>}> ({
// CHECK-NEXT: ^0(%3 : tensor<4x255xf32>, %4 : index, %5 : tensor<510xf32>):
// CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %7 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %8 = csl_stencil.access %3[0, 1] : tensor<4x255xf32>
// CHECK-NEXT: %9 = csl_stencil.access %3[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %10 = arith.addf %9, %8 : tensor<255xf32>
// CHECK-NEXT: %11 = arith.addf %10, %7 : tensor<255xf32>
// CHECK-NEXT: %12 = arith.addf %11, %6 : tensor<255xf32>
// CHECK-NEXT: %13 = "tensor.insert_slice"(%12, %5, %4) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %13 : tensor<510xf32>
// CHECK-NEXT: %6 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %7 = csl_stencil.access %3[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %8 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %9 = csl_stencil.access %3[0, 1] : tensor<4x255xf32>
// CHECK-NEXT: %10 = csl_stencil.access %3[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %11 = arith.addf %10, %9 : tensor<255xf32>
// CHECK-NEXT: %12 = arith.addf %11, %8 : tensor<255xf32>
// CHECK-NEXT: %13 = arith.addf %12, %7 : tensor<255xf32>
// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %5, %4) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %14 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%14 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %15 : tensor<510xf32>):
// CHECK-NEXT: %16 = csl_stencil.access %14[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %17 = csl_stencil.access %14[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %18 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %19 = "tensor.extract_slice"(%16) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %20 = "tensor.extract_slice"(%17) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %21 = arith.addf %15, %20 : tensor<510xf32>
// CHECK-NEXT: %22 = arith.addf %21, %19 : tensor<510xf32>
// CHECK-NEXT: %23 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %24 = linalg.fill ins(%18 : f32) outs(%23 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %25 = arith.mulf %22, %24 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %25 : tensor<510xf32>
// CHECK-NEXT: ^1(%15 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %16 : tensor<510xf32>):
// CHECK-NEXT: %17 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %20 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %22 = arith.addf %16, %21 : tensor<510xf32>
// CHECK-NEXT: %23 = arith.addf %22, %19 : tensor<510xf32>
// CHECK-NEXT: %24 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %25 = linalg.fill ins(%17 : f32) outs(%24 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %26 = arith.mulf %23, %25 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %26 : tensor<510xf32>
// CHECK-NEXT: })
// CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
Expand Down Expand Up @@ -90,19 +91,20 @@ builtin.module {
// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>):
// CHECK-NEXT: %5 = csl_stencil.access %2[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %6 = csl_stencil.access %2[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %7 = arith.addf %6, %5 : tensor<255xf32>
// CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %4, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %8 : tensor<510xf32>
// CHECK-NEXT: %5 = arith.constant dense<1.666600e-01> : tensor<510xf32>
// CHECK-NEXT: %6 = csl_stencil.access %2[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %7 = csl_stencil.access %2[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %8 = arith.addf %7, %6 : tensor<255xf32>
// CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %4, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%9 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %10 : tensor<510xf32>):
// CHECK-NEXT: %11 = csl_stencil.access %9[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: ^1(%10 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %11 : tensor<510xf32>):
// CHECK-NEXT: %12 = arith.constant dense<1.666600e-01> : tensor<510xf32>
// CHECK-NEXT: %13 = "tensor.extract_slice"(%11) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %14 = arith.addf %10, %13 : tensor<510xf32>
// CHECK-NEXT: %15 = arith.mulf %14, %12 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %15 : tensor<510xf32>
// CHECK-NEXT: %13 = csl_stencil.access %10[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %15 = arith.addf %11, %14 : tensor<510xf32>
// CHECK-NEXT: %16 = arith.mulf %15, %12 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
Expand All @@ -125,35 +127,93 @@ builtin.module {
%13 = arith.addf %12, %8 : tensor<510xf32>
%14 = arith.addf %13, %11 : tensor<510xf32>
%15 = arith.mulf %14, %2 : tensor<510xf32>
stencil.return %13 : tensor<510xf32>
stencil.return %15 : tensor<510xf32>
} to <[0, 0], [1, 1]>
stencil.store %0 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}

// CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>, %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>, %5 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-NEXT: %6 = arith.constant dense<1.234500e-01> : tensor<255xf32>
// CHECK-NEXT: %7 = arith.constant dense<2.345678e-01> : tensor<510xf32>
// CHECK-NEXT: %8 = arith.constant dense<3.141500e-01> : tensor<510xf32>
// CHECK-NEXT: %9 = csl_stencil.access %2[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %10 = csl_stencil.access %5[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<255xf32>
// CHECK-NEXT: %12 = csl_stencil.access %2[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %13 = arith.addf %12, %11 : tensor<255xf32>
// CHECK-NEXT: %14 = arith.addf %13, %9 : tensor<255xf32>
// CHECK-NEXT: %15 = arith.mulf %14, %6 : tensor<255xf32>
// CHECK-NEXT: %16 = "tensor.insert_slice"(%15, %4, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32>
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>):
// CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32>
// CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32>
// CHECK-NEXT: %7 = arith.constant dense<3.141500e-01> : tensor<510xf32>
// CHECK-NEXT: %8 = csl_stencil.access %2[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %9 = csl_stencil.access %2[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %10 = arith.addf %9, %8 : tensor<255xf32>
// CHECK-NEXT: %11 = "tensor.insert_slice"(%10, %4, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%17 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %18 : tensor<510xf32>):
// CHECK-NEXT: csl_stencil.yield %13 : tensor<255xf32>
// CHECK-NEXT: ^1(%12 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %13 : tensor<510xf32>):
// CHECK-NEXT: %14 = arith.constant dense<1.234500e-01> : tensor<510xf32>
// CHECK-NEXT: %15 = csl_stencil.access %12[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %16 = "tensor.extract_slice"(%15) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %17 = arith.addf %13, %16 : tensor<510xf32>
// CHECK-NEXT: %18 = arith.mulf %17, %14 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %18 : tensor<510xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }

func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
%0 = arith.constant 41 : index
%1 = arith.constant 0 : index
%2 = arith.constant 1 : index
%3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
"dmp.swap"(%arg3) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<600x600>, false>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 600] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [2, 0, 0] size [1, 1, 600] source offset [-2, 0, 0] to [2, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 600] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [-2, 0, 0] size [1, 1, 600] source offset [2, 0, 0] to [-2, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 600] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, 2, 0] size [1, 1, 600] source offset [0, -2, 0] to [0, 2, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 600] source offset [0, 1, 0] to [0, -1, 0]>, #dmp.exchange<at [0, -2, 0] size [1, 1, 600] source offset [0, 2, 0] to [0, -2, 0]>]} : (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) -> ()
stencil.apply(%arg5 = %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
%5 = arith.constant dense<1.287158e+09> : tensor<600xf32>
%6 = arith.constant dense<1.196003e+05> : tensor<600xf32>
%7 = stencil.access %arg5[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>
%8 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 600>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<604xf32>) -> tensor<600xf32>
%9 = arith.mulf %8, %5 : tensor<600xf32>
%10 = stencil.access %arg5[-1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>
%11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 600>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<604xf32>) -> tensor<600xf32>
%12 = arith.mulf %11, %6 : tensor<600xf32>
%13 = stencil.access %arg5[1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>
%14 = "tensor.extract_slice"(%13) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 600>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<604xf32>) -> tensor<600xf32>
%15 = arith.mulf %14, %6 : tensor<600xf32>
%16 = arith.addf %12, %9 : tensor<600xf32>
%17 = arith.addf %16, %15 : tensor<600xf32>
stencil.return %17 : tensor<600xf32>
} to <[0, 0], [1, 1]>
scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>
}
func.return
}

// CHECK-NEXT: func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
// CHECK-NEXT: %0 = arith.constant 41 : index
// CHECK-NEXT: %1 = arith.constant 0 : index
// CHECK-NEXT: %2 = arith.constant 1 : index
// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32>
// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [2, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [-2, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, 2]>, #csl_stencil.exchange<to [0, -1]>, #csl_stencil.exchange<to [0, -2]>], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({
// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>):
// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32>
// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32>
// CHECK-NEXT: %11 = csl_stencil.access %6[-1, 0] : tensor<8x300xf32>
// CHECK-NEXT: %12 = csl_stencil.access %6[1, 0] : tensor<8x300xf32>
// CHECK-NEXT: %13 = arith.addf %11, %12 : tensor<300xf32>
// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %8, %7) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 300>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32>
// CHECK-NEXT: csl_stencil.yield %14 : tensor<600xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%15 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %16 : tensor<600xf32>):
// CHECK-NEXT: %17 = arith.constant dense<1.287158e+09> : tensor<600xf32>
// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>
// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 600>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<604xf32>) -> tensor<600xf32>
// CHECK-NEXT: %20 = arith.mulf %19, %17 : tensor<600xf32>
// CHECK-NEXT: %21 = arith.addf %16, %20 : tensor<600xf32>
// CHECK-NEXT: csl_stencil.yield %21 : tensor<600xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }


}
// CHECK-NEXT: }
Loading

0 comments on commit 1d7e6b5

Please sign in to comment.