diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index 11d6a54e99..1f43e07c93 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -39,28 +39,29 @@ builtin.module { // CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32> // CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%3 : tensor<4x255xf32>, %4 : index, %5 : tensor<510xf32>): -// CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %7 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %8 = csl_stencil.access %3[0, 1] : tensor<4x255xf32> -// CHECK-NEXT: %9 = csl_stencil.access %3[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %10 = arith.addf %9, %8 : tensor<255xf32> -// CHECK-NEXT: %11 = arith.addf %10, %7 : tensor<255xf32> -// CHECK-NEXT: %12 = arith.addf %11, %6 : tensor<255xf32> -// CHECK-NEXT: %13 = "tensor.insert_slice"(%12, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %13 : tensor<510xf32> +// CHECK-NEXT: %6 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %7 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %8 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %9 = csl_stencil.access %3[0, 1] : tensor<4x255xf32> +// CHECK-NEXT: %10 = csl_stencil.access %3[0, -1] : tensor<4x255xf32> +// CHECK-NEXT: %11 = arith.addf %10, %9 : tensor<255xf32> +// CHECK-NEXT: %12 = arith.addf %11, %8 : tensor<255xf32> +// CHECK-NEXT: %13 = arith.addf %12, %7 : tensor<255xf32> +// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %14 : tensor<510xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%14 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %15 : tensor<510xf32>): -// CHECK-NEXT: %16 = csl_stencil.access %14[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %17 = csl_stencil.access %14[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %18 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: %19 = "tensor.extract_slice"(%16) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %20 = "tensor.extract_slice"(%17) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %21 = arith.addf %15, %20 : tensor<510xf32> -// CHECK-NEXT: %22 = arith.addf %21, %19 : tensor<510xf32> -// CHECK-NEXT: %23 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %24 = linalg.fill ins(%18 : f32) outs(%23 : tensor<510xf32>) -> tensor<510xf32> -// CHECK-NEXT: %25 = arith.mulf %22, %24 : tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %25 : tensor<510xf32> +// CHECK-NEXT: ^1(%15 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %16 : tensor<510xf32>): +// CHECK-NEXT: %17 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> +// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %20 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> +// CHECK-NEXT: %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %22 = arith.addf %16, %21 : tensor<510xf32> +// CHECK-NEXT: %23 = arith.addf %22, %19 : tensor<510xf32> +// CHECK-NEXT: %24 = tensor.empty() : tensor<510xf32> +// CHECK-NEXT: %25 = linalg.fill ins(%17 : f32) outs(%24 : tensor<510xf32>) -> tensor<510xf32> +// CHECK-NEXT: %26 = arith.mulf %23, %25 : tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %26 : tensor<510xf32> // CHECK-NEXT: }) // CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return @@ -90,19 +91,20 @@ builtin.module { // CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> // CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): -// CHECK-NEXT: %5 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %6 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %7 = arith.addf %6, %5 : tensor<255xf32> -// CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %8 : tensor<510xf32> +// CHECK-NEXT: %5 = arith.constant dense<1.666600e-01> : tensor<510xf32> +// CHECK-NEXT: %6 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %7 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> +// CHECK-NEXT: %8 = arith.addf %7, %6 : tensor<255xf32> +// CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%9 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %10 : tensor<510xf32>): -// CHECK-NEXT: %11 = csl_stencil.access %9[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: ^1(%10 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %11 : tensor<510xf32>): // CHECK-NEXT: %12 = arith.constant dense<1.666600e-01> : tensor<510xf32> -// CHECK-NEXT: %13 = "tensor.extract_slice"(%11) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %14 = arith.addf %10, %13 : tensor<510xf32> -// CHECK-NEXT: %15 = arith.mulf %14, %12 : tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %15 : tensor<510xf32> +// CHECK-NEXT: %13 = csl_stencil.access %10[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %15 = arith.addf %11, %14 : tensor<510xf32> +// CHECK-NEXT: %16 = arith.mulf %15, %12 : tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32> // CHECK-NEXT: }) to <[0, 0], [1, 1]> // CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return @@ -125,7 +127,7 @@ builtin.module { %13 = arith.addf %12, %8 : tensor<510xf32> %14 = arith.addf %13, %11 : tensor<510xf32> %15 = arith.mulf %14, %2 : tensor<510xf32> - stencil.return %13 : tensor<510xf32> + stencil.return %15 : tensor<510xf32> } to <[0, 0], [1, 1]> stencil.store %0 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> func.return @@ -133,27 +135,85 @@ builtin.module { // CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>, %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({ -// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>, %5 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>): -// CHECK-NEXT: %6 = arith.constant dense<1.234500e-01> : tensor<255xf32> -// CHECK-NEXT: %7 = arith.constant dense<2.345678e-01> : tensor<510xf32> -// CHECK-NEXT: %8 = arith.constant dense<3.141500e-01> : tensor<510xf32> -// CHECK-NEXT: %9 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %10 = csl_stencil.access %5[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<255xf32> -// CHECK-NEXT: %12 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %13 = arith.addf %12, %11 : tensor<255xf32> -// CHECK-NEXT: %14 = arith.addf %13, %9 : tensor<255xf32> -// CHECK-NEXT: %15 = arith.mulf %14, %6 : tensor<255xf32> -// CHECK-NEXT: %16 = "tensor.insert_slice"(%15, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32> +// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({ +// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): +// CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32> +// CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32> +// CHECK-NEXT: %7 = arith.constant dense<3.141500e-01> : tensor<510xf32> +// CHECK-NEXT: %8 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %9 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> +// CHECK-NEXT: %10 = arith.addf %9, %8 : tensor<255xf32> +// CHECK-NEXT: %11 = "tensor.insert_slice"(%10, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%17 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %18 : tensor<510xf32>): -// CHECK-NEXT: csl_stencil.yield %13 : tensor<255xf32> +// CHECK-NEXT: ^1(%12 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %13 : tensor<510xf32>): +// CHECK-NEXT: %14 = arith.constant dense<1.234500e-01> : tensor<510xf32> +// CHECK-NEXT: %15 = csl_stencil.access %12[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %16 = "tensor.extract_slice"(%15) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %17 = arith.addf %13, %16 : tensor<510xf32> +// CHECK-NEXT: %18 = arith.mulf %17, %14 : tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %18 : tensor<510xf32> // CHECK-NEXT: }) to <[0, 0], [1, 1]> // CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } + func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + %0 = arith.constant 41 : index + %1 = arith.constant 0 : index + %2 = arith.constant 1 : index + %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + "dmp.swap"(%arg3) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<600x600>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) -> () + stencil.apply(%arg5 = %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + %5 = arith.constant dense<1.287158e+09> : tensor<600xf32> + %6 = arith.constant dense<1.196003e+05> : tensor<600xf32> + %7 = stencil.access %arg5[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %9 = arith.mulf %8, %5 : tensor<600xf32> + %10 = stencil.access %arg5[-1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %12 = arith.mulf %11, %6 : tensor<600xf32> + %13 = stencil.access %arg5[1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %15 = arith.mulf %14, %6 : tensor<600xf32> + %16 = arith.addf %12, %9 : tensor<600xf32> + %17 = arith.addf %16, %15 : tensor<600xf32> + stencil.return %17 : tensor<600xf32> + } to <[0, 0], [1, 1]> + scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + } + func.return +} + +// CHECK-NEXT: func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +// CHECK-NEXT: %0 = arith.constant 41 : index +// CHECK-NEXT: %1 = arith.constant 0 : index +// CHECK-NEXT: %2 = arith.constant 1 : index +// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> +// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({ +// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): +// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32> +// CHECK-NEXT: %11 = csl_stencil.access %6[-1, 0] : tensor<8x300xf32> +// CHECK-NEXT: %12 = csl_stencil.access %6[1, 0] : tensor<8x300xf32> +// CHECK-NEXT: %13 = arith.addf %11, %12 : tensor<300xf32> +// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %14 : tensor<600xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%15 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %16 : tensor<600xf32>): +// CHECK-NEXT: %17 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %20 = arith.mulf %19, %17 : tensor<600xf32> +// CHECK-NEXT: %21 = arith.addf %16, %20 : tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %21 : tensor<600xf32> +// CHECK-NEXT: }) to <[0, 0], [1, 1]> +// CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: } +// CHECK-NEXT: func.return +// CHECK-NEXT: } + + } // CHECK-NEXT: } diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index d2c59c9757..351ee2669f 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -3,7 +3,7 @@ from math import prod from xdsl.context import MLContext -from xdsl.dialects import arith, stencil, tensor +from xdsl.dialects import arith, stencil, tensor, varith from xdsl.dialects.builtin import ( AnyFloatAttr, AnyMemRefTypeConstr, @@ -17,7 +17,15 @@ ) from xdsl.dialects.csl import csl_stencil from xdsl.dialects.experimental import dmp -from xdsl.ir import Attribute, Block, BlockArgument, Operation, OpResult, Region +from xdsl.ir import ( + Attribute, + Block, + BlockArgument, + Operation, + OpResult, + Region, + SSAValue, +) from xdsl.irdl import Operand, base from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -31,6 +39,10 @@ from xdsl.transforms.experimental.stencil_tensorize_z_dimension import ( BackpropagateStencilShapes, ) +from xdsl.transforms.varith_transformations import ( + ConvertArithToVarithPass, + ConvertVarithToArithPass, +) from xdsl.utils.hints import isa from xdsl.utils.isattr import isattr @@ -55,69 +67,35 @@ def get_stencil_access_operands(op: Operand) -> set[Operand]: return res -@dataclass(frozen=True) -class RestructureSymmetricReductionPattern(RewritePattern): - """ - Consume data first where that data comes from stencil accesses to `buf`. +def _get_prefetch_buf_idx(op: stencil.ApplyOp) -> int | None: + # calculate memory cost of all prefetch operands + def get_prefetch_overhead(o: OpResult): + assert isa(o.type, TensorType[Attribute]) + return prod(o.type.get_shape()) - Identifies a pattern of 2 connected binary ops with 3 args, e.g. of the form `(a+b)+c` with different ops and - bracketings supported, and attempts to re-structure the order of computation. + candidate_prefetches = [ + (get_prefetch_overhead(o), o) + for o in op.operands + if isinstance(o, OpResult) and isinstance(o.op, csl_stencil.PrefetchOp) + ] + if len(candidate_prefetches) == 0: + return - Being in principle similarly to constant folding, the difference is that args are *not* required to be stencil - accesses, but could have further compute applied before being passed to the reduction function. - Uses helper function `get_stencil_accessed_symbols` to check which bufs are stencil-accessed in each of these args, - and to distinguish the following three cases: + # select the prefetch with the biggest communication overhead to be fused with matched stencil.apply + prefetch = max(candidate_prefetches)[1] + return op.operands.index(prefetch) - (1) all accesses in an arg tree are to `buf` - arg should be moved forward in the computation - (2) no accesses are to `buf` - arg should be moved backward in the computation - (3) there's a mix - unknown, take any or no action - If two args are identified that should be moved forward, or two args are identified that should be moved backwards, - the computation is restructured accordingly. +def _get_apply_op(op: Operation) -> stencil.ApplyOp | None: """ - - buf: BlockArgument - - @op_type_rewrite_pattern - def match_and_rewrite( - self, op: arith.Addf | arith.Mulf, rewriter: PatternRewriter, / - ): - # this rewrite requires exactly 1 use which is the same type of operation - if len(op.result.uses) != 1 or not isinstance( - use := list(op.result.uses)[0].operation, type(op) - ): - return - c_op = use.operands[0] if use.operands[1] == op.result else use.operands[1] - - def rewrite(one: Operand, two: Operand, three: Operand): - """Builds `(one+two)+three` where `'+' == type(op)`""" - - first_compute = type(op)(one, two) - second_compute = type(op)(first_compute, three) - - # Both ops are inserted at the later point to ensure all dependencies are present when moving compute around. - # Moving the replacement of `op` backwards is safe because we previously asserted at `op` only has one use (ie. in `use`) - rewriter.replace_op(op, [], [first_compute.results[0]]) - rewriter.replace_op( - use, [first_compute, second_compute], [second_compute.results[0]] - ) - - a = get_stencil_access_operands(a_op := op.lhs) - b = get_stencil_access_operands(b_op := op.rhs) - c = get_stencil_access_operands(c_op) - - if self.move_fwd(a) and self.move_fwd(b): - return - elif self.move_back(a) and self.move_back(b): - return - elif self.move_fwd(c) and self.move_back(b): - rewrite(a_op, c_op, b_op) - - def move_fwd(self, accs: set[Operand]) -> bool: - return self.buf in accs and len(accs) == 1 - - def move_back(self, accs: set[Operand]) -> bool: - return self.buf not in accs + Return the enclosing csl_wrapper.module + """ + parent_op = op.parent_op() + while parent_op: + if isinstance(parent_op, stencil.ApplyOp): + return parent_op + parent_op = parent_op.parent_op() + return None @dataclass(frozen=True) @@ -278,7 +256,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): nested_rewriter.rewrite_op(new_apply_op) -def get_op_split( +def split_ops( ops: Sequence[Operation], buf: BlockArgument ) -> tuple[Sequence[Operation], Sequence[Operation]]: """ @@ -298,6 +276,8 @@ def get_op_split( for op in ops: if isinstance(op, csl_stencil.AccessOp): (b, a)[op.op == buf].append(op) + elif isinstance(op, arith.Constant): + a.append(op) else: rem.append(op) @@ -331,14 +311,66 @@ def get_op_split( a.append(use.operation) rem.remove(use.operation) - if len(a_exports) == 1: - return a, b + rem + # find constants in `a` needed outside of `a` + cnst_exports = [cnst for cnst in a_exports if isinstance(cnst, arith.Constant)] + + # `a` exports one value plus any number of constants - duplicate exported constants and return op split + if len(a_exports) == 1 + len(cnst_exports): + recv_chunk_ops, done_exch_ops = list[Operation](), list[Operation]() + for op in ops: + if op in a: + recv_chunk_ops.append(op) + if op in cnst_exports: + # create a copy of the constant in the second region + done_exch_ops.append(cln := op.clone()) + # rewire ops of the second region to use the copied constant + op.result.replace_by_if( + cln.result, + lambda use: use.operation in b or use.operation in rem, + ) + else: + done_exch_ops.append(op) + + return recv_chunk_ops, done_exch_ops # fallback # always place `stencil.return` in second block return ops[:-1], [ops[-1]] +@dataclass(frozen=True) +class SplitVarithOpPattern(RewritePattern): + """ + Splits a varith op into two, depending on whether the operands holds stencil accesses to `buf` (only) + or any other accesses. + + This pass is intended to be run with `buf` set to the block arg indicating data received from neighbours. + """ + + buf: BlockArgument + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /): + if not (apply := _get_apply_op(op)) or not ( + buf_idx := _get_prefetch_buf_idx(apply) + ): + return + buf = apply.region.block.args[buf_idx] + buf_accesses, others = list[SSAValue](), list[SSAValue]() + + for arg in op.args: + accs = get_stencil_access_operands(arg) + (others, buf_accesses)[buf in accs and len(accs) == 1].append(arg) + + if len(others) > 0 and len(buf_accesses) > 0: + rewriter.replace_matched_op( + [ + n_op := type(op)(*buf_accesses), + type(op)(n_op, *others), + ] + ) + + @dataclass(frozen=True) class ConvertApplyOpPattern(RewritePattern): """ @@ -359,24 +391,12 @@ class ConvertApplyOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): - # calculate memory cost of all prefetch operands - def get_prefetch_overhead(o: OpResult): - assert isa(o.type, TensorType[Attribute]) - buf_count = o.type.get_shape()[0] - buf_size = prod(o.type.get_shape()[1:]) - return buf_count * buf_size - - candidate_prefetches = [ - (get_prefetch_overhead(o), o) - for o in op.operands - if isinstance(o, OpResult) and isinstance(o.op, csl_stencil.PrefetchOp) - ] - if len(candidate_prefetches) == 0: + if not (prefetch_idx := _get_prefetch_buf_idx(op)): return # select the prefetch with the biggest communication overhead to be fused with matched stencil.apply - prefetch = max(candidate_prefetches)[1] - prefetch_idx = op.operands.index(prefetch) + prefetch = op.operands[prefetch_idx] + assert isinstance(prefetch, OpResult) assert isinstance(prefetch.op, csl_stencil.PrefetchOp) field_idx = op.operands.index(prefetch.op.input_stencil) assert isinstance(prefetch.op, csl_stencil.PrefetchOp) @@ -392,14 +412,14 @@ def get_prefetch_overhead(o: OpResult): rewriter.insert_op(accumulator, InsertPoint.before(op)) # run pass (on this apply's region only) to consume data from `prefetch` accesses first - nested_rewriter = PatternRewriteWalker( - RestructureSymmetricReductionPattern(op.region.block.args[prefetch_idx]), - walk_reverse=True, - ) - nested_rewriter.rewrite_op(op) + # find varith ops and split according to neighbour data + PatternRewriteWalker( + SplitVarithOpPattern(op.region.block.args[prefetch_idx]), + apply_recursively=False, + ).rewrite_op(op) # determine how ops should be split across the two regions - chunk_region_ops, done_exchange_ops = get_op_split( + chunk_region_ops, done_exchange_ops = split_ops( list(op.region.block.ops), op.region.block.args[prefetch_idx] ) @@ -576,6 +596,7 @@ class ConvertStencilToCslStencilPass(ModulePass): num_chunks: int = 1 def apply(self, ctx: MLContext, op: ModuleOp) -> None: + ConvertArithToVarithPass().apply(ctx, op) module_pass = PatternRewriteWalker( GreedyRewritePatternApplier( [ @@ -585,9 +606,9 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: ] ), walk_reverse=False, - apply_recursively=True, ) module_pass.rewrite_module(op) + ConvertVarithToArithPass().apply(ctx, op) if self.num_chunks > 1: BackpropagateStencilShapes().apply(ctx, op) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 6f6a92c0ac..f22e842ff9 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -3,7 +3,7 @@ from typing import TypeGuard, cast from xdsl.context import MLContext -from xdsl.dialects import builtin +from xdsl.dialects import builtin, varith from xdsl.dialects.arith import ( Addf, Constant, @@ -389,6 +389,17 @@ def match_and_rewrite( arithBinaryOpUpdateShape(op, rewriter) +class VarithOpUpdateShape(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /): + type_constructor = type(op) + if typ := get_required_result_type(op): + if needs_update_shape(op.result_types[0], typ): + rewriter.replace_matched_op( + type_constructor.build(operands=[op.args], result_types=[typ]) + ) + + class EmptyOpUpdateShape(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: EmptyOp, rewriter: PatternRewriter, /): @@ -437,6 +448,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: EmptyOpUpdateShape(), FillOpUpdateShape(), ArithOpUpdateShape(), + VarithOpUpdateShape(), ConstOpUpdateShape(), ] ),