From cdeecca052fc3f232c5b276a76f33eb9df30aab4 Mon Sep 17 00:00:00 2001 From: n-io Date: Wed, 16 Oct 2024 14:16:43 +0200 Subject: [PATCH 1/8] transformations: (convert-stencil-to-csl-stencil) Support varith --- .../convert-stencil-to-csl-stencil.mlir | 251 +++++++++++++++--- .../convert_stencil_to_csl_stencil.py | 142 ++++++++-- .../stencil_tensorize_z_dimension.py | 14 +- 3 files changed, 348 insertions(+), 59 deletions(-) diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index d7b53644a0..d880b3fdd6 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -39,28 +39,29 @@ builtin.module { // CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32> // CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%3 : tensor<4x255xf32>, %4 : index, %5 : tensor<510xf32>): -// CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %7 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %8 = csl_stencil.access %3[0, 1] : tensor<4x255xf32> -// CHECK-NEXT: %9 = csl_stencil.access %3[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %10 = arith.addf %9, %8 : tensor<255xf32> -// CHECK-NEXT: %11 = arith.addf %10, %7 : tensor<255xf32> -// CHECK-NEXT: %12 = arith.addf %11, %6 : tensor<255xf32> -// CHECK-NEXT: %13 = "tensor.insert_slice"(%12, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %13 : tensor<510xf32> +// CHECK-NEXT: %6 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %7 = csl_stencil.access %3[1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %8 = csl_stencil.access %3[-1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %9 = csl_stencil.access %3[0, 1] : tensor<4x255xf32> +// CHECK-NEXT: %10 = csl_stencil.access %3[0, -1] : tensor<4x255xf32> +// CHECK-NEXT: %11 = arith.addf %10, %9 : tensor<255xf32> +// CHECK-NEXT: %12 = arith.addf %11, %8 : tensor<255xf32> +// CHECK-NEXT: %13 = arith.addf %12, %7 : tensor<255xf32> +// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %14 : tensor<510xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%14 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %15 : tensor<510xf32>): -// CHECK-NEXT: %16 = csl_stencil.access %14[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %17 = csl_stencil.access %14[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> -// CHECK-NEXT: %18 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: %19 = "tensor.extract_slice"(%16) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %20 = "tensor.extract_slice"(%17) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> 
tensor<510xf32> -// CHECK-NEXT: %21 = arith.addf %15, %20 : tensor<510xf32> -// CHECK-NEXT: %22 = arith.addf %21, %19 : tensor<510xf32> -// CHECK-NEXT: %23 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %24 = linalg.fill ins(%18 : f32) outs(%23 : tensor<510xf32>) -> tensor<510xf32> -// CHECK-NEXT: %25 = arith.mulf %22, %24 : tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %25 : tensor<510xf32> +// CHECK-NEXT: ^1(%15 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %16 : tensor<510xf32>): +// CHECK-NEXT: %17 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> +// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %20 = csl_stencil.access %15[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> +// CHECK-NEXT: %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %22 = arith.addf %16, %21 : tensor<510xf32> +// CHECK-NEXT: %23 = arith.addf %22, %19 : tensor<510xf32> +// CHECK-NEXT: %24 = tensor.empty() : tensor<510xf32> +// CHECK-NEXT: %25 = linalg.fill ins(%17 : f32) outs(%24 : tensor<510xf32>) -> tensor<510xf32> +// CHECK-NEXT: %26 = arith.mulf %23, %25 : tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %26 : tensor<510xf32> // CHECK-NEXT: }) // CHECK-NEXT: stencil.store %2 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return @@ -90,23 +91,209 @@ builtin.module { // CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> // CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): -// CHECK-NEXT: %5 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %6 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %7 = arith.addf %6, %5 : tensor<255xf32> -// CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %8 : tensor<510xf32> +// CHECK-NEXT: %5 = arith.constant dense<1.666600e-01> : tensor<510xf32> +// CHECK-NEXT: %6 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %7 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> +// CHECK-NEXT: %8 = arith.addf %7, %6 : tensor<255xf32> +// CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%9 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %10 : tensor<510xf32>): -// CHECK-NEXT: %11 = csl_stencil.access %9[0, 0] : 
!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: ^1(%10 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %11 : tensor<510xf32>): // CHECK-NEXT: %12 = arith.constant dense<1.666600e-01> : tensor<510xf32> -// CHECK-NEXT: %13 = "tensor.extract_slice"(%11) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %14 = arith.addf %10, %13 : tensor<510xf32> -// CHECK-NEXT: %15 = arith.mulf %14, %12 : tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %15 : tensor<510xf32> +// CHECK-NEXT: %13 = csl_stencil.access %10[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %15 = arith.addf %11, %14 : tensor<510xf32> +// CHECK-NEXT: %16 = arith.mulf %15, %12 : tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32> // CHECK-NEXT: }) to <[0, 0], [1, 1]> // CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return // CHECK-NEXT: } + +func.func @xDSLDiffusionOperator(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + %0 = arith.constant 41 : index + %1 = arith.constant 0 : index + %2 = arith.constant 1 : index + %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + "dmp.swap"(%arg3) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<600x600>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) -> () + stencil.apply(%arg5 = %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + %5 = arith.constant dense<1.287158e+09> : tensor<600xf32> + %6 = arith.constant dense<1.196003e+05> : tensor<600xf32> + %7 = stencil.access %arg5[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %9 = arith.mulf %8, %5 : tensor<600xf32> + %10 = stencil.access %arg5[-1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %12 = arith.mulf %11, %6 : tensor<600xf32> + %13 = stencil.access %arg5[1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %15 = arith.mulf %14, %6 : tensor<600xf32> + %16 = arith.addf %12, %9 : tensor<600xf32> + %17 = arith.addf %16, %15 : tensor<600xf32> + stencil.return %17 : tensor<600xf32> + } to <[0, 0], [1, 1]> + scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + } + func.return +} + +// CHECK-NEXT: func.func 
@xDSLDiffusionOperator(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +// CHECK-NEXT: %0 = arith.constant 41 : index +// CHECK-NEXT: %1 = arith.constant 0 : index +// CHECK-NEXT: %2 = arith.constant 1 : index +// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> +// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): +// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<300xf32> +// CHECK-NEXT: %11 = csl_stencil.access %6[-1, 0] : tensor<8x300xf32> +// CHECK-NEXT: %12 = arith.mulf %11, %10 : tensor<300xf32> +// CHECK-NEXT: %13 = csl_stencil.access %6[1, 0] : tensor<8x300xf32> +// CHECK-NEXT: %14 = arith.mulf %13, %10 : tensor<300xf32> +// CHECK-NEXT: %15 = arith.addf %12, %14 : tensor<300xf32> +// CHECK-NEXT: %16 = "tensor.insert_slice"(%15, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %16 : tensor<600xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%17 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %18 : tensor<600xf32>): +// CHECK-NEXT: %19 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %20 = csl_stencil.access %17[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %22 = arith.mulf %21, %19 : tensor<600xf32> +// CHECK-NEXT: %23 = arith.addf %18, %22 : tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %23 : tensor<600xf32> +// CHECK-NEXT: }) to <[0, 0], [1, 1]> +// CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: } +// CHECK-NEXT: func.return +// CHECK-NEXT: } + + + func.func @diffusion(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + %0 = arith.constant 41 : index + %1 = arith.constant 0 : index + %2 = arith.constant 1 : index + %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + "dmp.swap"(%arg3) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<600x600>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) -> () + stencil.apply(%arg5 = %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) outs (%arg4 : 
!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { + %5 = arith.constant dense<1.287158e+09> : tensor<600xf32> + %6 = arith.constant dense<1.196003e+05> : tensor<600xf32> + %7 = arith.constant dense<-2.242506e+05> : tensor<600xf32> + %8 = arith.constant dense<-7.475020e+03> : tensor<600xf32> + %9 = arith.constant dense<9.000000e-01> : tensor<600xf32> + %10 = arith.constant dense<1.033968e-08> : tensor<600xf32> + %11 = stencil.access %arg5[-1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %12 = "tensor.extract_slice"(%11) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %13 = arith.mulf %12, %6 : tensor<600xf32> + %14 = stencil.access %arg5[1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %15 = "tensor.extract_slice"(%14) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %16 = arith.mulf %15, %6 : tensor<600xf32> + %17 = stencil.access %arg5[-2, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %18 = "tensor.extract_slice"(%17) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %19 = arith.mulf %18, %8 : tensor<600xf32> + %20 = stencil.access %arg5[2, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %22 = arith.mulf %21, %8 : tensor<600xf32> + %23 = stencil.access %arg5[0, -1] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %24 = "tensor.extract_slice"(%23) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %25 = arith.mulf %24, %6 : tensor<600xf32> + %26 = stencil.access %arg5[0, 1] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %27 = "tensor.extract_slice"(%26) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %28 = arith.mulf %27, %6 : tensor<600xf32> + %29 = stencil.access %arg5[0, -2] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %30 = "tensor.extract_slice"(%29) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %31 = arith.mulf %30, %8 : tensor<600xf32> + %32 = stencil.access %arg5[0, 2] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %33 = "tensor.extract_slice"(%32) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %34 = arith.mulf %33, %8 : tensor<600xf32> + %35 = stencil.access %arg5[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + %36 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %37 = arith.mulf %36, %7 : tensor<600xf32> + %38 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %39 = arith.mulf %38, %6 : tensor<600xf32> + %40 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, 
"static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %41 = arith.mulf %40, %6 : tensor<600xf32> + %42 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %43 = arith.mulf %42, %8 : tensor<600xf32> + %44 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> + %45 = arith.mulf %44, %8 : tensor<600xf32> + %46 = varith.add %45, %37, %39, %41, %43, %22, %37, %13, %16, %19, %34, %37, %25, %28, %31 : tensor<600xf32> + %47 = arith.mulf %46, %9 : tensor<600xf32> + %48 = arith.mulf %36, %5 : tensor<600xf32> + %49 = arith.addf %48, %47 : tensor<600xf32> + %50 = arith.mulf %49, %10 : tensor<600xf32> + stencil.return %50 : tensor<600xf32> + } to <[0, 0], [1, 1]> + scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> + } + func.return + } + +// CHECK-NEXT: func.func @diffusion(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +// CHECK-NEXT: %0 = arith.constant 41 : index +// CHECK-NEXT: %1 = arith.constant 0 : index +// CHECK-NEXT: %2 = arith.constant 1 : index +// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> +// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): +// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<300xf32> +// CHECK-NEXT: %11 = arith.constant dense<-2.242506e+05> : tensor<600xf32> +// CHECK-NEXT: %12 = arith.constant dense<-7.475020e+03> : tensor<300xf32> +// CHECK-NEXT: %13 = arith.constant dense<9.000000e-01> : tensor<600xf32> +// CHECK-NEXT: %14 = arith.constant dense<1.033968e-08> : tensor<600xf32> +// CHECK-NEXT: %15 = csl_stencil.access %6[-1, 0] : tensor<8x300xf32> +// CHECK-NEXT: %16 = arith.mulf %15, %10 : tensor<300xf32> +// CHECK-NEXT: %17 = csl_stencil.access %6[1, 0] : tensor<8x300xf32> +// CHECK-NEXT: %18 = arith.mulf %17, %10 : tensor<300xf32> +// CHECK-NEXT: %19 = csl_stencil.access %6[-2, 0] : tensor<8x300xf32> +// CHECK-NEXT: %20 = arith.mulf %19, %12 : tensor<300xf32> +// CHECK-NEXT: %21 = csl_stencil.access %6[2, 0] : tensor<8x300xf32> +// CHECK-NEXT: %22 = arith.mulf %21, %12 : tensor<300xf32> +// CHECK-NEXT: %23 = csl_stencil.access %6[0, -1] : tensor<8x300xf32> +// CHECK-NEXT: %24 = arith.mulf %23, %10 : tensor<300xf32> +// CHECK-NEXT: %25 = csl_stencil.access %6[0, 1] : tensor<8x300xf32> +// CHECK-NEXT: %26 = arith.mulf %25, %10 : tensor<300xf32> +// CHECK-NEXT: %27 = csl_stencil.access %6[0, -2] : tensor<8x300xf32> +// CHECK-NEXT: %28 = 
arith.mulf %27, %12 : tensor<300xf32> +// CHECK-NEXT: %29 = csl_stencil.access %6[0, 2] : tensor<8x300xf32> +// CHECK-NEXT: %30 = arith.mulf %29, %12 : tensor<300xf32> +// CHECK-NEXT: %31 = varith.add %22, %16, %18, %20, %30, %24, %26, %28 : tensor<300xf32> +// CHECK-NEXT: %32 = "tensor.insert_slice"(%31, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %32 : tensor<600xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%33 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %34 : tensor<600xf32>): +// CHECK-NEXT: %35 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %36 = arith.constant dense<1.196003e+05> : tensor<600xf32> +// CHECK-NEXT: %37 = arith.constant dense<-2.242506e+05> : tensor<600xf32> +// CHECK-NEXT: %38 = arith.constant dense<-7.475020e+03> : tensor<600xf32> +// CHECK-NEXT: %39 = arith.constant dense<9.000000e-01> : tensor<600xf32> +// CHECK-NEXT: %40 = arith.constant dense<1.033968e-08> : tensor<600xf32> +// CHECK-NEXT: %41 = csl_stencil.access %33[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: %42 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %43 = arith.mulf %42, %37 : tensor<600xf32> +// CHECK-NEXT: %44 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %45 = arith.mulf %44, %36 : tensor<600xf32> +// CHECK-NEXT: %46 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %47 = arith.mulf %46, %36 : tensor<600xf32> +// CHECK-NEXT: %48 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %49 = arith.mulf %48, %38 : tensor<600xf32> +// CHECK-NEXT: %50 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %51 = arith.mulf %50, %38 : tensor<600xf32> +// CHECK-NEXT: %52 = varith.add %34, %51, %43, %45, %47, %49, %43, %43 : tensor<600xf32> +// CHECK-NEXT: %53 = arith.mulf %52, %39 : tensor<600xf32> +// CHECK-NEXT: %54 = arith.mulf %42, %35 : tensor<600xf32> +// CHECK-NEXT: %55 = arith.addf %54, %53 : tensor<600xf32> +// CHECK-NEXT: %56 = arith.mulf %55, %40 : tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %56 : tensor<600xf32> +// CHECK-NEXT: }) to <[0, 0], [1, 1]> +// CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: } +// CHECK-NEXT: func.return +// CHECK-NEXT: } + + } // CHECK-NEXT: } diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 32d259311d..7b134d5669 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -3,7 +3,7 @@ from math import prod from xdsl.context import MLContext -from xdsl.dialects import arith, stencil, tensor +from 
xdsl.dialects import arith, stencil, tensor, varith from xdsl.dialects.builtin import ( AnyMemRefTypeConstr, AnyTensorType, @@ -15,7 +15,15 @@ ) from xdsl.dialects.csl import csl_stencil from xdsl.dialects.experimental import dmp -from xdsl.ir import Attribute, Block, BlockArgument, Operation, OpResult, Region +from xdsl.ir import ( + Attribute, + Block, + BlockArgument, + Operation, + OpResult, + Region, + SSAValue, +) from xdsl.irdl import Operand, base from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -53,7 +61,40 @@ def get_stencil_access_operands(op: Operand) -> set[Operand]: return res -@dataclass(frozen=True) +def _get_prefetch_buf(op: stencil.ApplyOp) -> int | None: + # calculate memory cost of all prefetch operands + def get_prefetch_overhead(o: OpResult): + assert isa(o.type, TensorType[Attribute]) + buf_count = o.type.get_shape()[0] + buf_size = prod(o.type.get_shape()[1:]) + return buf_count * buf_size + + candidate_prefetches = [ + (get_prefetch_overhead(o), o) + for o in op.operands + if isinstance(o, OpResult) and isinstance(o.op, csl_stencil.PrefetchOp) + ] + if len(candidate_prefetches) == 0: + return + + # select the prefetch with the biggest communication overhead to be fused with matched stencil.apply + prefetch = max(candidate_prefetches)[1] + return op.operands.index(prefetch) + + +def _get_apply_op(op: Operation) -> stencil.ApplyOp | None: + """ + Return the enclosing csl_wrapper.module + """ + parent_op = op.parent_op() + while parent_op: + if isinstance(parent_op, stencil.ApplyOp): + return parent_op + parent_op = parent_op.parent_op() + return None + + +@dataclass(frozen=False) class RestructureSymmetricReductionPattern(RewritePattern): """ Consume data first where that data comes from stencil accesses to `buf`. @@ -80,6 +121,10 @@ class RestructureSymmetricReductionPattern(RewritePattern): def match_and_rewrite( self, op: arith.Addf | arith.Mulf, rewriter: PatternRewriter, / ): + # if not (apply := _get_apply_op(op)) or not (buf_idx := _get_prefetch_buf(apply)): + # return + # self.buf = apply.region.block.args[buf_idx] + # this rewrite requires exactly 1 use which is the same type of operation if len(op.result.uses) != 1 or not isinstance( use := list(op.result.uses)[0].operation, type(op) @@ -296,6 +341,8 @@ def get_op_split( for op in ops: if isinstance(op, csl_stencil.AccessOp): (b, a)[op.op == buf].append(op) + elif isinstance(op, arith.Constant): + a.append(op) else: rem.append(op) @@ -329,14 +376,61 @@ def get_op_split( a.append(use.operation) rem.remove(use.operation) - if len(a_exports) == 1: - return a, b + rem + cnst_exports = [cnst for cnst in a_exports if isinstance(cnst, arith.Constant)] + if len(a_exports) == 1 + len(cnst_exports): + recv_chunk_ops, done_exch_ops = list[Operation](), list[Operation]() + for op in ops: + if op in a: + recv_chunk_ops.append(op) + if op in cnst_exports: + done_exch_ops.append(cln := op.clone()) + op.result.replace_by_if( + cln.result, + lambda use: use.operation in b or use.operation in rem, + ) + else: + done_exch_ops.append(op) + + return recv_chunk_ops, done_exch_ops # fallback # always place `stencil.return` in second block return ops[:-1], [ops[-1]] +@dataclass(frozen=True) +class SplitVarithOpPattern(RewritePattern): + """ + Splits a varith op into two, depending on whether the operands holds stencil accesses to `buf` (only) + or any other accesses. + + This pass is intended to be run with `buf` set to the block arg indicating data received from neighbours. 
+ """ + + buf: BlockArgument + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /): + if not (apply := _get_apply_op(op)) or not ( + buf_idx := _get_prefetch_buf(apply) + ): + return + buf = apply.region.block.args[buf_idx] + buf_accesses, others = list[SSAValue](), list[SSAValue]() + + for arg in op.args: + accs = get_stencil_access_operands(arg) + (others, buf_accesses)[buf in accs and len(accs) == 1].append(arg) + + if len(others) > 1 and len(buf_accesses) > 1: + rewriter.replace_matched_op( + [ + n_op := type(op)(*buf_accesses), + type(op)(n_op, *others), + ] + ) + + @dataclass(frozen=True) class ConvertApplyOpPattern(RewritePattern): """ @@ -357,24 +451,12 @@ class ConvertApplyOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): - # calculate memory cost of all prefetch operands - def get_prefetch_overhead(o: OpResult): - assert isa(o.type, TensorType[Attribute]) - buf_count = o.type.get_shape()[0] - buf_size = prod(o.type.get_shape()[1:]) - return buf_count * buf_size - - candidate_prefetches = [ - (get_prefetch_overhead(o), o) - for o in op.operands - if isinstance(o, OpResult) and isinstance(o.op, csl_stencil.PrefetchOp) - ] - if len(candidate_prefetches) == 0: + if not (prefetch_idx := _get_prefetch_buf(op)): return # select the prefetch with the biggest communication overhead to be fused with matched stencil.apply - prefetch = max(candidate_prefetches)[1] - prefetch_idx = op.operands.index(prefetch) + prefetch = op.operands[prefetch_idx] + assert isinstance(prefetch, OpResult) assert isinstance(prefetch.op, csl_stencil.PrefetchOp) field_idx = op.operands.index(prefetch.op.input_stencil) assert isinstance(prefetch.op, csl_stencil.PrefetchOp) @@ -389,12 +471,20 @@ def get_prefetch_overhead(o: OpResult): ) rewriter.insert_op(accumulator, InsertPoint.before(op)) + # # find varith ops and split according to neighbour data + has_varith = PatternRewriteWalker( + SplitVarithOpPattern(op.region.block.args[prefetch_idx]), + apply_recursively=False, + ).rewrite_op(op) # run pass (on this apply's region only) to consume data from `prefetch` accesses first - nested_rewriter = PatternRewriteWalker( - RestructureSymmetricReductionPattern(op.region.block.args[prefetch_idx]), - walk_reverse=True, - ) - nested_rewriter.rewrite_op(op) + if not has_varith: + nested_rewriter = PatternRewriteWalker( + RestructureSymmetricReductionPattern( + op.region.block.args[prefetch_idx] + ), + walk_reverse=True, + ) + nested_rewriter.rewrite_op(op) # determine how ops should be split across the two regions chunk_region_ops, done_exchange_ops = get_op_split( @@ -551,7 +641,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: ] ), walk_reverse=False, - apply_recursively=True, + # apply_recursively=True, ) module_pass.rewrite_module(op) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index dd2567d391..c5cd8dd2c6 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -3,7 +3,7 @@ from typing import TypeGuard, cast from xdsl.context import MLContext -from xdsl.dialects import builtin +from xdsl.dialects import builtin, varith from xdsl.dialects.arith import ( Addf, Constant, @@ -388,6 +388,17 @@ def match_and_rewrite( arithBinaryOpUpdateShape(op, rewriter) +class VarithOpUpdateShape(RewritePattern): 
+ @op_type_rewrite_pattern + def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /): + type_constructor = type(op) + if typ := get_required_result_type(op): + if needs_update_shape(op.result_types[0], typ): + rewriter.replace_matched_op( + type_constructor.build(operands=[op.args], result_types=[typ]) + ) + + class EmptyOpUpdateShape(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: EmptyOp, rewriter: PatternRewriter, /): @@ -436,6 +447,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: EmptyOpUpdateShape(), FillOpUpdateShape(), ArithOpUpdateShape(), + VarithOpUpdateShape(), ConstOpUpdateShape(), ] ), From 7a30c4ed86da3b4fcf065e90dac5fc039b16c4a8 Mon Sep 17 00:00:00 2001 From: n-io Date: Tue, 22 Oct 2024 14:36:27 +0200 Subject: [PATCH 2/8] transformations: (arith-to-varith) Support more cases --- .../transforms/convert-arith-to-varith.mlir | 46 +++++++++++++++-- xdsl/transforms/varith_transformations.py | 50 ++++++++----------- 2 files changed, 64 insertions(+), 32 deletions(-) diff --git a/tests/filecheck/transforms/convert-arith-to-varith.mlir b/tests/filecheck/transforms/convert-arith-to-varith.mlir index 0801ec2c5d..8583a9d30c 100644 --- a/tests/filecheck/transforms/convert-arith-to-varith.mlir +++ b/tests/filecheck/transforms/convert-arith-to-varith.mlir @@ -20,7 +20,7 @@ func.func @test_addi() { // CHECK-NEXT: %a, %b, %c = "test.op"() : () -> (i32, i32, i32) // CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (i32, i32, i32) // CHECK-NEXT: %x2 = arith.addi %0, %1 : i32 - // CHECK-NEXT: %r = varith.add %c, %a, %b, %2, %0, %1 : i32 + // CHECK-NEXT: %r = varith.add %a, %b, %c, %0, %1, %2 : i32 // CHECK-NEXT: "test.op"(%r, %x2) : (i32, i32) -> () } @@ -45,7 +45,7 @@ func.func @test_addf() { // CHECK-NEXT: %a, %b, %c = "test.op"() : () -> (f32, f32, f32) // CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (f32, f32, f32) // CHECK-NEXT: %x2 = arith.addf %0, %1 : f32 - // CHECK-NEXT: %r = varith.add %c, %a, %b, %2, %0, %1 : f32 + // CHECK-NEXT: %r = varith.add %a, %b, %c, %0, %1, %2 : f32 // CHECK-NEXT: "test.op"(%r, %x2) : (f32, f32) -> () } @@ -69,6 +69,46 @@ func.func @test_mulf() { // CHECK-NEXT: %a, %b, %c = "test.op"() : () -> (f32, f32, f32) // CHECK-NEXT: %0, %1, %2 = "test.op"() : () -> (f32, f32, f32) // CHECK-NEXT: %x2 = arith.mulf %0, %1 : f32 - // CHECK-NEXT: %r = varith.mul %c, %a, %b, %2, %0, %1 : f32 + // CHECK-NEXT: %r = varith.mul %a, %b, %c, %0, %1, %2 : f32 // CHECK-NEXT: "test.op"(%r, %x2) : (f32, f32) -> () } + +func.func @test() { + %0, %1, %2, %3, %4, %5 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %6 = arith.constant dense<1.234500e-01> : tensor<8xf32> + %a = arith.addf %5, %4 : tensor<8xf32> + %b = arith.addf %a, %3 : tensor<8xf32> + %c = arith.addf %b, %2 : tensor<8xf32> + %d = arith.addf %c, %1 : tensor<8xf32> + %e = arith.addf %d, %0 : tensor<8xf32> + %12 = arith.mulf %e, %6 : tensor<8xf32> + "test.op"(%12) : (tensor<8xf32>) -> () + func.return + + // CHECK-LABEL: @test + // CHECK-NEXT: %0, %1, %2, %3, %4, %5 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %6 = arith.constant dense<1.234500e-01> : tensor<8xf32> + // CHECK-NEXT: %e = varith.add %5, %4, %3, %2, %1, %0 : tensor<8xf32> + // CHECK-NEXT: %7 = arith.mulf %e, %6 : tensor<8xf32> + // CHECK-NEXT: "test.op"(%7) : (tensor<8xf32>) -> () +} + +func.func @test2() { + %0, %1, %2, %3, %4, %5 = "test.op"() : 
() -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + %6 = arith.constant dense<1.234500e-01> : tensor<8xf32> + %a = arith.addf %5, %4 : tensor<8xf32> + %b = arith.addf %3, %a : tensor<8xf32> + %c = arith.addf %2, %b : tensor<8xf32> + %d = arith.addf %1, %c : tensor<8xf32> + %e = arith.addf %0, %d : tensor<8xf32> + %12 = arith.mulf %e, %6 : tensor<8xf32> + "test.op"(%12) : (tensor<8xf32>) -> () + func.return + + // CHECK-LABEL: @test + // CHECK-NEXT: %0, %1, %2, %3, %4, %5 = "test.op"() : () -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %6 = arith.constant dense<1.234500e-01> : tensor<8xf32> + // CHECK-NEXT: %e = varith.add %1, %2, %3, %5, %4, %0 : tensor<8xf32> + // CHECK-NEXT: %7 = arith.mulf %e, %6 : tensor<8xf32> + // CHECK-NEXT: "test.op"(%7) : (tensor<8xf32>) -> () +} diff --git a/xdsl/transforms/varith_transformations.py b/xdsl/transforms/varith_transformations.py index ebdc063010..9837b07dc5 100644 --- a/xdsl/transforms/varith_transformations.py +++ b/xdsl/transforms/varith_transformations.py @@ -33,38 +33,29 @@ class ArithToVarithPattern(RewritePattern): Merges two arith operations into a varith operation. """ - def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /): - # check that the op is of a type that we can convert to varith - if type(op) not in ARITH_TO_VARITH_TYPE_MAP: + @op_type_rewrite_pattern + def match_and_rewrite( + self, + op: arith.Addi | arith.Addf | arith.Muli | arith.Mulf, + rewriter: PatternRewriter, + /, + ): + dest_type = ARITH_TO_VARITH_TYPE_MAP[type(op)] + + if len(op.result.uses) != 1: + return + if type(use_op := list(op.result.uses)[0].operation) not in ( + type(op), + dest_type, + ): return - # this must be true, as all keys of ARITH_TO_VARITH_TYPE_MAP are binary ops - op = cast( - arith.SignlessIntegerBinaryOperation - | arith.FloatingPointLikeBinaryOperation, - op, + other_operands = [o for o in use_op.operands if o != op.result] + rewriter.replace_op( + use_op, + dest_type(op.lhs, op.rhs, *other_operands), ) - - dest_type = ARITH_TO_VARITH_TYPE_MAP[type(op)] - - # check LHS and the RHS to see if they can be folded - # but abort after one is merged - for other in (op.rhs.owner, op.lhs.owner): - # if me and the other op are the same op - # (they must necessarily operate on the same data type) - if type(op) is type(other): - other_op = cast( - arith.SignlessIntegerBinaryOperation - | arith.FloatingPointLikeBinaryOperation, - other, - ) - # instantiate a varith op with three operands - rewriter.replace_matched_op( - dest_type(op.rhs, other_op.lhs, other_op.rhs) - ) - if len(other_op.result.uses) == 0: - rewriter.erase_op(other_op) - return + rewriter.erase_matched_op() class VarithToArithPattern(RewritePattern): @@ -209,6 +200,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: MergeVarithOpsPattern(), ] ), + walk_reverse=True, ).rewrite_op(op) From c5a56102b57ad38875cab5b18b0ebf38e11ef590 Mon Sep 17 00:00:00 2001 From: n-io Date: Tue, 22 Oct 2024 14:41:46 +0200 Subject: [PATCH 3/8] remove unused code --- xdsl/transforms/varith_transformations.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xdsl/transforms/varith_transformations.py b/xdsl/transforms/varith_transformations.py index 9837b07dc5..0c6cbaa980 100644 --- a/xdsl/transforms/varith_transformations.py +++ b/xdsl/transforms/varith_transformations.py @@ -24,9 +24,6 @@ arith.Mulf: varith.VarithMulOp, } -# map the arith operation to the right varith op: 
-VARITH_TYPES = [varith.VarithAddOp, varith.VarithMulOp] - class ArithToVarithPattern(RewritePattern): """ From 49e80885e6a05326e58d38e1c67d350bea97273d Mon Sep 17 00:00:00 2001 From: n-io Date: Tue, 22 Oct 2024 15:14:02 +0200 Subject: [PATCH 4/8] use varith --- .../convert-stencil-to-csl-stencil.mlir | 74 ++++++++------- .../convert_stencil_to_csl_stencil.py | 91 ++----------------- 2 files changed, 53 insertions(+), 112 deletions(-) diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index dc8d9417cc..448d3970d5 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -111,7 +111,7 @@ builtin.module { // CHECK-NEXT: } -func.func @xDSLDiffusionOperator(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { %0 = arith.constant 41 : index %1 = arith.constant 0 : index %2 = arith.constant 1 : index @@ -138,13 +138,13 @@ func.func @xDSLDiffusionOperator(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604 func.return } -// CHECK-NEXT: func.func @xDSLDiffusionOperator(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { +// CHECK-NEXT: func.func @xdiff(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { // CHECK-NEXT: %0 = arith.constant 41 : index // CHECK-NEXT: %1 = arith.constant 0 : index // CHECK-NEXT: %2 = arith.constant 1 : index // CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { // CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> -// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): // CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> // CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<300xf32> @@ -236,7 +236,7 @@ func.func @xDSLDiffusionOperator(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604 // CHECK-NEXT: %2 = arith.constant 1 : index // CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { 
// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> -// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): // CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> // CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<300xf32> @@ -260,34 +260,46 @@ func.func @xDSLDiffusionOperator(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604 // CHECK-NEXT: %28 = arith.mulf %27, %12 : tensor<300xf32> // CHECK-NEXT: %29 = csl_stencil.access %6[0, 2] : tensor<8x300xf32> // CHECK-NEXT: %30 = arith.mulf %29, %12 : tensor<300xf32> -// CHECK-NEXT: %31 = varith.add %22, %16, %18, %20, %30, %24, %26, %28 : tensor<300xf32> -// CHECK-NEXT: %32 = "tensor.insert_slice"(%31, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %32 : tensor<600xf32> +// CHECK-NEXT: %31 = arith.addf %22, %16 : tensor<300xf32> +// CHECK-NEXT: %32 = arith.addf %31, %18 : tensor<300xf32> +// CHECK-NEXT: %33 = arith.addf %32, %20 : tensor<300xf32> +// CHECK-NEXT: %34 = arith.addf %33, %30 : tensor<300xf32> +// CHECK-NEXT: %35 = arith.addf %34, %24 : tensor<300xf32> +// CHECK-NEXT: %36 = arith.addf %35, %26 : tensor<300xf32> +// CHECK-NEXT: %37 = arith.addf %36, %28 : tensor<300xf32> +// CHECK-NEXT: %38 = "tensor.insert_slice"(%37, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %38 : tensor<600xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%33 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %34 : tensor<600xf32>): -// CHECK-NEXT: %35 = arith.constant dense<1.287158e+09> : tensor<600xf32> -// CHECK-NEXT: %36 = arith.constant dense<1.196003e+05> : tensor<600xf32> -// CHECK-NEXT: %37 = arith.constant dense<-2.242506e+05> : tensor<600xf32> -// CHECK-NEXT: %38 = arith.constant dense<-7.475020e+03> : tensor<600xf32> -// CHECK-NEXT: %39 = arith.constant dense<9.000000e-01> : tensor<600xf32> -// CHECK-NEXT: %40 = arith.constant dense<1.033968e-08> : tensor<600xf32> -// CHECK-NEXT: %41 = csl_stencil.access %33[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> -// CHECK-NEXT: %42 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %43 = arith.mulf 
%42, %37 : tensor<600xf32> -// CHECK-NEXT: %44 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %45 = arith.mulf %44, %36 : tensor<600xf32> -// CHECK-NEXT: %46 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %47 = arith.mulf %46, %36 : tensor<600xf32> -// CHECK-NEXT: %48 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %49 = arith.mulf %48, %38 : tensor<600xf32> -// CHECK-NEXT: %50 = "tensor.extract_slice"(%41) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %51 = arith.mulf %50, %38 : tensor<600xf32> -// CHECK-NEXT: %52 = varith.add %34, %51, %43, %45, %47, %49, %43, %43 : tensor<600xf32> -// CHECK-NEXT: %53 = arith.mulf %52, %39 : tensor<600xf32> -// CHECK-NEXT: %54 = arith.mulf %42, %35 : tensor<600xf32> -// CHECK-NEXT: %55 = arith.addf %54, %53 : tensor<600xf32> -// CHECK-NEXT: %56 = arith.mulf %55, %40 : tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %56 : tensor<600xf32> +// CHECK-NEXT: ^1(%39 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %40 : tensor<600xf32>): +// CHECK-NEXT: %41 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %42 = arith.constant dense<1.196003e+05> : tensor<600xf32> +// CHECK-NEXT: %43 = arith.constant dense<-2.242506e+05> : tensor<600xf32> +// CHECK-NEXT: %44 = arith.constant dense<-7.475020e+03> : tensor<600xf32> +// CHECK-NEXT: %45 = arith.constant dense<9.000000e-01> : tensor<600xf32> +// CHECK-NEXT: %46 = arith.constant dense<1.033968e-08> : tensor<600xf32> +// CHECK-NEXT: %47 = csl_stencil.access %39[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: %48 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %49 = arith.mulf %48, %43 : tensor<600xf32> +// CHECK-NEXT: %50 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %51 = arith.mulf %50, %42 : tensor<600xf32> +// CHECK-NEXT: %52 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %53 = arith.mulf %52, %42 : tensor<600xf32> +// CHECK-NEXT: %54 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %55 = arith.mulf %54, %44 : tensor<600xf32> +// CHECK-NEXT: %56 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %57 = arith.mulf %56, %44 : tensor<600xf32> +// CHECK-NEXT: %58 = arith.addf %40, %57 : tensor<600xf32> +// CHECK-NEXT: %59 = arith.addf %58, %49 : tensor<600xf32> +// CHECK-NEXT: %60 = arith.addf %59, %51 : 
tensor<600xf32> +// CHECK-NEXT: %61 = arith.addf %60, %53 : tensor<600xf32> +// CHECK-NEXT: %62 = arith.addf %61, %55 : tensor<600xf32> +// CHECK-NEXT: %63 = arith.addf %62, %49 : tensor<600xf32> +// CHECK-NEXT: %64 = arith.addf %63, %49 : tensor<600xf32> +// CHECK-NEXT: %65 = arith.mulf %64, %45 : tensor<600xf32> +// CHECK-NEXT: %66 = arith.mulf %48, %41 : tensor<600xf32> +// CHECK-NEXT: %67 = arith.addf %66, %65 : tensor<600xf32> +// CHECK-NEXT: %68 = arith.mulf %67, %46 : tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %68 : tensor<600xf32> // CHECK-NEXT: }) to <[0, 0], [1, 1]> // CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> // CHECK-NEXT: } diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 760282afed..97255e295c 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -37,6 +37,10 @@ from xdsl.transforms.experimental.stencil_tensorize_z_dimension import ( BackpropagateStencilShapes, ) +from xdsl.transforms.varith_transformations import ( + ConvertArithToVarithPass, + ConvertVarithToArithPass, +) from xdsl.utils.hints import isa from xdsl.utils.isattr import isattr @@ -94,75 +98,6 @@ def _get_apply_op(op: Operation) -> stencil.ApplyOp | None: return None -@dataclass(frozen=False) -class RestructureSymmetricReductionPattern(RewritePattern): - """ - Consume data first where that data comes from stencil accesses to `buf`. - - Identifies a pattern of 2 connected binary ops with 3 args, e.g. of the form `(a+b)+c` with different ops and - bracketings supported, and attempts to re-structure the order of computation. - - Being in principle similarly to constant folding, the difference is that args are *not* required to be stencil - accesses, but could have further compute applied before being passed to the reduction function. - Uses helper function `get_stencil_accessed_symbols` to check which bufs are stencil-accessed in each of these args, - and to distinguish the following three cases: - - (1) all accesses in an arg tree are to `buf` - arg should be moved forward in the computation - (2) no accesses are to `buf` - arg should be moved backward in the computation - (3) there's a mix - unknown, take any or no action - - If two args are identified that should be moved forward, or two args are identified that should be moved backwards, - the computation is restructured accordingly. - """ - - buf: BlockArgument - - @op_type_rewrite_pattern - def match_and_rewrite( - self, op: arith.Addf | arith.Mulf, rewriter: PatternRewriter, / - ): - # if not (apply := _get_apply_op(op)) or not (buf_idx := _get_prefetch_buf(apply)): - # return - # self.buf = apply.region.block.args[buf_idx] - - # this rewrite requires exactly 1 use which is the same type of operation - if len(op.result.uses) != 1 or not isinstance( - use := list(op.result.uses)[0].operation, type(op) - ): - return - c_op = use.operands[0] if use.operands[1] == op.result else use.operands[1] - - def rewrite(one: Operand, two: Operand, three: Operand): - """Builds `(one+two)+three` where `'+' == type(op)`""" - - first_compute = type(op)(one, two) - second_compute = type(op)(first_compute, three) - - # Both ops are inserted at the later point to ensure all dependencies are present when moving compute around. - # Moving the replacement of `op` backwards is safe because we previously asserted at `op` only has one use (ie. 
in `use`) - rewriter.replace_op(op, [], [first_compute.results[0]]) - rewriter.replace_op( - use, [first_compute, second_compute], [second_compute.results[0]] - ) - - a = get_stencil_access_operands(a_op := op.lhs) - b = get_stencil_access_operands(b_op := op.rhs) - c = get_stencil_access_operands(c_op) - - if self.move_fwd(a) and self.move_fwd(b): - return - elif self.move_back(a) and self.move_back(b): - return - elif self.move_fwd(c) and self.move_back(b): - rewrite(a_op, c_op, b_op) - - def move_fwd(self, accs: set[Operand]) -> bool: - return self.buf in accs and len(accs) == 1 - - def move_back(self, accs: set[Operand]) -> bool: - return self.buf not in accs - - @dataclass(frozen=True) class ConvertAccessOpFromPrefetchPattern(RewritePattern): """ @@ -422,7 +357,7 @@ def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /): accs = get_stencil_access_operands(arg) (others, buf_accesses)[buf in accs and len(accs) == 1].append(arg) - if len(others) > 1 and len(buf_accesses) > 1: + if len(others) > 0 and len(buf_accesses) > 0: rewriter.replace_matched_op( [ n_op := type(op)(*buf_accesses), @@ -471,20 +406,12 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): ) rewriter.insert_op(accumulator, InsertPoint.before(op)) - # # find varith ops and split according to neighbour data - has_varith = PatternRewriteWalker( + # run pass (on this apply's region only) to consume data from `prefetch` accesses first + # find varith ops and split according to neighbour data + PatternRewriteWalker( SplitVarithOpPattern(op.region.block.args[prefetch_idx]), apply_recursively=False, ).rewrite_op(op) - # run pass (on this apply's region only) to consume data from `prefetch` accesses first - if not has_varith: - nested_rewriter = PatternRewriteWalker( - RestructureSymmetricReductionPattern( - op.region.block.args[prefetch_idx] - ), - walk_reverse=True, - ) - nested_rewriter.rewrite_op(op) # determine how ops should be split across the two regions chunk_region_ops, done_exchange_ops = get_op_split( @@ -633,6 +560,7 @@ class ConvertStencilToCslStencilPass(ModulePass): num_chunks: int = 1 def apply(self, ctx: MLContext, op: ModuleOp) -> None: + ConvertArithToVarithPass().apply(ctx, op) module_pass = PatternRewriteWalker( GreedyRewritePatternApplier( [ @@ -644,6 +572,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: # apply_recursively=True, ) module_pass.rewrite_module(op) + ConvertVarithToArithPass().apply(ctx, op) if self.num_chunks > 1: BackpropagateStencilShapes().apply(ctx, op) From 87081305e1672d17c2063af4e7357a19eb550b8d Mon Sep 17 00:00:00 2001 From: n-io Date: Tue, 22 Oct 2024 15:15:24 +0200 Subject: [PATCH 5/8] remove default value flag --- xdsl/transforms/convert_stencil_to_csl_stencil.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 97255e295c..0f28018c92 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -569,7 +569,6 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: ] ), walk_reverse=False, - # apply_recursively=True, ) module_pass.rewrite_module(op) ConvertVarithToArithPass().apply(ctx, op) From e7e23c978a81e7bc8c91806b04143286190b57a8 Mon Sep 17 00:00:00 2001 From: n-io Date: Wed, 23 Oct 2024 14:44:39 +0200 Subject: [PATCH 6/8] fix filecheck --- .../convert-stencil-to-csl-stencil.mlir | 165 ++---------------- 1 file changed, 13 insertions(+), 152 
deletions(-) diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index 2d4ec283ab..1996a67e10 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -190,25 +190,23 @@ builtin.module { // CHECK-NEXT: %2 = arith.constant 1 : index // CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { // CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> -// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({ // CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): // CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> -// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<300xf32> +// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32> // CHECK-NEXT: %11 = csl_stencil.access %6[-1, 0] : tensor<8x300xf32> -// CHECK-NEXT: %12 = arith.mulf %11, %10 : tensor<300xf32> -// CHECK-NEXT: %13 = csl_stencil.access %6[1, 0] : tensor<8x300xf32> -// CHECK-NEXT: %14 = arith.mulf %13, %10 : tensor<300xf32> -// CHECK-NEXT: %15 = arith.addf %12, %14 : tensor<300xf32> -// CHECK-NEXT: %16 = "tensor.insert_slice"(%15, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %16 : tensor<600xf32> +// CHECK-NEXT: %12 = csl_stencil.access %6[1, 0] : tensor<8x300xf32> +// CHECK-NEXT: %13 = arith.addf %11, %12 : tensor<300xf32> +// CHECK-NEXT: %14 = "tensor.insert_slice"(%13, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %14 : tensor<600xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%17 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %18 : tensor<600xf32>): -// CHECK-NEXT: %19 = arith.constant dense<1.287158e+09> : tensor<600xf32> -// CHECK-NEXT: %20 = csl_stencil.access %17[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> -// CHECK-NEXT: %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, 
"operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %22 = arith.mulf %21, %19 : tensor<600xf32> -// CHECK-NEXT: %23 = arith.addf %18, %22 : tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %23 : tensor<600xf32> +// CHECK-NEXT: ^1(%15 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %16 : tensor<600xf32>): +// CHECK-NEXT: %17 = arith.constant dense<1.287158e+09> : tensor<600xf32> +// CHECK-NEXT: %18 = csl_stencil.access %15[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> +// CHECK-NEXT: %19 = "tensor.extract_slice"(%18) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> +// CHECK-NEXT: %20 = arith.mulf %19, %17 : tensor<600xf32> +// CHECK-NEXT: %21 = arith.addf %16, %20 : tensor<600xf32> +// CHECK-NEXT: csl_stencil.yield %21 : tensor<600xf32> // CHECK-NEXT: }) to <[0, 0], [1, 1]> // CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> // CHECK-NEXT: } @@ -216,142 +214,5 @@ builtin.module { // CHECK-NEXT: } - func.func @diffusion(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { - %0 = arith.constant 41 : index - %1 = arith.constant 0 : index - %2 = arith.constant 1 : index - %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { - "dmp.swap"(%arg3) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<600x600>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) -> () - stencil.apply(%arg5 = %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { - %5 = arith.constant dense<1.287158e+09> : tensor<600xf32> - %6 = arith.constant dense<1.196003e+05> : tensor<600xf32> - %7 = arith.constant dense<-2.242506e+05> : tensor<600xf32> - %8 = arith.constant dense<-7.475020e+03> : tensor<600xf32> - %9 = arith.constant dense<9.000000e-01> : tensor<600xf32> - %10 = arith.constant dense<1.033968e-08> : tensor<600xf32> - %11 = stencil.access %arg5[-1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %12 = "tensor.extract_slice"(%11) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %13 = arith.mulf %12, %6 : tensor<600xf32> - %14 = stencil.access %arg5[1, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %15 = "tensor.extract_slice"(%14) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %16 = arith.mulf %15, %6 : tensor<600xf32> - %17 = stencil.access %arg5[-2, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %18 = "tensor.extract_slice"(%17) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %19 = arith.mulf %18, %8 : tensor<600xf32> - %20 = stencil.access %arg5[2, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %21 = "tensor.extract_slice"(%20) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %22 = arith.mulf 
%21, %8 : tensor<600xf32> - %23 = stencil.access %arg5[0, -1] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %24 = "tensor.extract_slice"(%23) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %25 = arith.mulf %24, %6 : tensor<600xf32> - %26 = stencil.access %arg5[0, 1] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %27 = "tensor.extract_slice"(%26) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %28 = arith.mulf %27, %6 : tensor<600xf32> - %29 = stencil.access %arg5[0, -2] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %30 = "tensor.extract_slice"(%29) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %31 = arith.mulf %30, %8 : tensor<600xf32> - %32 = stencil.access %arg5[0, 2] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %33 = "tensor.extract_slice"(%32) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %34 = arith.mulf %33, %8 : tensor<600xf32> - %35 = stencil.access %arg5[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - %36 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %37 = arith.mulf %36, %7 : tensor<600xf32> - %38 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %39 = arith.mulf %38, %6 : tensor<600xf32> - %40 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %41 = arith.mulf %40, %6 : tensor<600xf32> - %42 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %43 = arith.mulf %42, %8 : tensor<600xf32> - %44 = "tensor.extract_slice"(%35) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> - %45 = arith.mulf %44, %8 : tensor<600xf32> - %46 = varith.add %45, %37, %39, %41, %43, %22, %37, %13, %16, %19, %34, %37, %25, %28, %31 : tensor<600xf32> - %47 = arith.mulf %46, %9 : tensor<600xf32> - %48 = arith.mulf %36, %5 : tensor<600xf32> - %49 = arith.addf %48, %47 : tensor<600xf32> - %50 = arith.mulf %49, %10 : tensor<600xf32> - stencil.return %50 : tensor<600xf32> - } to <[0, 0], [1, 1]> - scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> - } - func.return - } - -// CHECK-NEXT: func.func @diffusion(%arg0 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %arg1 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { -// CHECK-NEXT: %0 = arith.constant 41 : index -// CHECK-NEXT: %1 = arith.constant 0 : index -// CHECK-NEXT: %2 = arith.constant 1 : index -// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) { -// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32> 
-// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>): -// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32> -// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<300xf32> -// CHECK-NEXT: %11 = arith.constant dense<-2.242506e+05> : tensor<600xf32> -// CHECK-NEXT: %12 = arith.constant dense<-7.475020e+03> : tensor<300xf32> -// CHECK-NEXT: %13 = arith.constant dense<9.000000e-01> : tensor<600xf32> -// CHECK-NEXT: %14 = arith.constant dense<1.033968e-08> : tensor<600xf32> -// CHECK-NEXT: %15 = csl_stencil.access %6[-1, 0] : tensor<8x300xf32> -// CHECK-NEXT: %16 = arith.mulf %15, %10 : tensor<300xf32> -// CHECK-NEXT: %17 = csl_stencil.access %6[1, 0] : tensor<8x300xf32> -// CHECK-NEXT: %18 = arith.mulf %17, %10 : tensor<300xf32> -// CHECK-NEXT: %19 = csl_stencil.access %6[-2, 0] : tensor<8x300xf32> -// CHECK-NEXT: %20 = arith.mulf %19, %12 : tensor<300xf32> -// CHECK-NEXT: %21 = csl_stencil.access %6[2, 0] : tensor<8x300xf32> -// CHECK-NEXT: %22 = arith.mulf %21, %12 : tensor<300xf32> -// CHECK-NEXT: %23 = csl_stencil.access %6[0, -1] : tensor<8x300xf32> -// CHECK-NEXT: %24 = arith.mulf %23, %10 : tensor<300xf32> -// CHECK-NEXT: %25 = csl_stencil.access %6[0, 1] : tensor<8x300xf32> -// CHECK-NEXT: %26 = arith.mulf %25, %10 : tensor<300xf32> -// CHECK-NEXT: %27 = csl_stencil.access %6[0, -2] : tensor<8x300xf32> -// CHECK-NEXT: %28 = arith.mulf %27, %12 : tensor<300xf32> -// CHECK-NEXT: %29 = csl_stencil.access %6[0, 2] : tensor<8x300xf32> -// CHECK-NEXT: %30 = arith.mulf %29, %12 : tensor<300xf32> -// CHECK-NEXT: %31 = arith.addf %22, %16 : tensor<300xf32> -// CHECK-NEXT: %32 = arith.addf %31, %18 : tensor<300xf32> -// CHECK-NEXT: %33 = arith.addf %32, %20 : tensor<300xf32> -// CHECK-NEXT: %34 = arith.addf %33, %30 : tensor<300xf32> -// CHECK-NEXT: %35 = arith.addf %34, %24 : tensor<300xf32> -// CHECK-NEXT: %36 = arith.addf %35, %26 : tensor<300xf32> -// CHECK-NEXT: %37 = arith.addf %36, %28 : tensor<300xf32> -// CHECK-NEXT: %38 = "tensor.insert_slice"(%37, %8, %7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<300xf32>, tensor<600xf32>, index) -> tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %38 : tensor<600xf32> -// CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%39 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %40 : tensor<600xf32>): -// CHECK-NEXT: %41 = arith.constant dense<1.287158e+09> : tensor<600xf32> -// CHECK-NEXT: %42 = arith.constant dense<1.196003e+05> : tensor<600xf32> -// CHECK-NEXT: %43 = arith.constant dense<-2.242506e+05> : tensor<600xf32> -// CHECK-NEXT: %44 = arith.constant dense<-7.475020e+03> : tensor<600xf32> -// CHECK-NEXT: %45 = arith.constant dense<9.000000e-01> : tensor<600xf32> -// CHECK-NEXT: %46 = arith.constant dense<1.033968e-08> : tensor<600xf32> -// CHECK-NEXT: %47 = csl_stencil.access %39[0, 0] : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> -// CHECK-NEXT: %48 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, 
"static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %49 = arith.mulf %48, %43 : tensor<600xf32> -// CHECK-NEXT: %50 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %51 = arith.mulf %50, %42 : tensor<600xf32> -// CHECK-NEXT: %52 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %53 = arith.mulf %52, %42 : tensor<600xf32> -// CHECK-NEXT: %54 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %55 = arith.mulf %54, %44 : tensor<600xf32> -// CHECK-NEXT: %56 = "tensor.extract_slice"(%47) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<604xf32>) -> tensor<600xf32> -// CHECK-NEXT: %57 = arith.mulf %56, %44 : tensor<600xf32> -// CHECK-NEXT: %58 = arith.addf %40, %57 : tensor<600xf32> -// CHECK-NEXT: %59 = arith.addf %58, %49 : tensor<600xf32> -// CHECK-NEXT: %60 = arith.addf %59, %51 : tensor<600xf32> -// CHECK-NEXT: %61 = arith.addf %60, %53 : tensor<600xf32> -// CHECK-NEXT: %62 = arith.addf %61, %55 : tensor<600xf32> -// CHECK-NEXT: %63 = arith.addf %62, %49 : tensor<600xf32> -// CHECK-NEXT: %64 = arith.addf %63, %49 : tensor<600xf32> -// CHECK-NEXT: %65 = arith.mulf %64, %45 : tensor<600xf32> -// CHECK-NEXT: %66 = arith.mulf %48, %41 : tensor<600xf32> -// CHECK-NEXT: %67 = arith.addf %66, %65 : tensor<600xf32> -// CHECK-NEXT: %68 = arith.mulf %67, %46 : tensor<600xf32> -// CHECK-NEXT: csl_stencil.yield %68 : tensor<600xf32> -// CHECK-NEXT: }) to <[0, 0], [1, 1]> -// CHECK-NEXT: scf.yield %arg4, %arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>> -// CHECK-NEXT: } -// CHECK-NEXT: func.return -// CHECK-NEXT: } - - } // CHECK-NEXT: } From f9a0f0d31552b9c12994b390a50de0bf0876c7ce Mon Sep 17 00:00:00 2001 From: n-io Date: Wed, 23 Oct 2024 15:01:17 +0200 Subject: [PATCH 7/8] fix filecheck --- .../convert-stencil-to-csl-stencil.mlir | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir index 1996a67e10..1f43e07c93 100644 --- a/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir @@ -127,7 +127,7 @@ builtin.module { %13 = arith.addf %12, %8 : tensor<510xf32> %14 = arith.addf %13, %11 : tensor<510xf32> %15 = arith.mulf %14, %2 : tensor<510xf32> - stencil.return %13 : tensor<510xf32> + stencil.return %15 : tensor<510xf32> } to <[0, 0], [1, 1]> stencil.store %0 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> func.return @@ -135,23 +135,24 @@ builtin.module { // CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>, %a : 
!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({ -// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>, %5 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>): -// CHECK-NEXT: %6 = arith.constant dense<1.234500e-01> : tensor<255xf32> -// CHECK-NEXT: %7 = arith.constant dense<2.345678e-01> : tensor<510xf32> -// CHECK-NEXT: %8 = arith.constant dense<3.141500e-01> : tensor<510xf32> -// CHECK-NEXT: %9 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> -// CHECK-NEXT: %10 = csl_stencil.access %5[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<255xf32> -// CHECK-NEXT: %12 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> -// CHECK-NEXT: %13 = arith.addf %12, %11 : tensor<255xf32> -// CHECK-NEXT: %14 = arith.addf %13, %9 : tensor<255xf32> -// CHECK-NEXT: %15 = arith.mulf %14, %6 : tensor<255xf32> -// CHECK-NEXT: %16 = "tensor.insert_slice"(%15, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> -// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32> +// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({ +// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>): +// CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32> +// CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32> +// CHECK-NEXT: %7 = arith.constant dense<3.141500e-01> : tensor<510xf32> +// CHECK-NEXT: %8 = csl_stencil.access %2[1, 0] : tensor<4x255xf32> +// CHECK-NEXT: %9 = csl_stencil.access %2[0, -1] : tensor<4x255xf32> +// CHECK-NEXT: %10 = arith.addf %9, %8 : tensor<255xf32> +// CHECK-NEXT: %11 = "tensor.insert_slice"(%10, %4, %3) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32> // CHECK-NEXT: }, { -// CHECK-NEXT: ^1(%17 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %18 : tensor<510xf32>): -// CHECK-NEXT: csl_stencil.yield %13 : tensor<255xf32> +// CHECK-NEXT: ^1(%12 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %13 : tensor<510xf32>): +// CHECK-NEXT: %14 = arith.constant dense<1.234500e-01> : tensor<510xf32> +// CHECK-NEXT: %15 = csl_stencil.access %12[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %16 = 
"tensor.extract_slice"(%15) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %17 = arith.addf %13, %16 : tensor<510xf32> +// CHECK-NEXT: %18 = arith.mulf %17, %14 : tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %18 : tensor<510xf32> // CHECK-NEXT: }) to <[0, 0], [1, 1]> // CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return From b83551cdde8e3d07094736c529fd4b6172b0311e Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 24 Oct 2024 10:47:08 +0200 Subject: [PATCH 8/8] fixes --- .../convert_stencil_to_csl_stencil.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py index 9e5d025e76..351ee2669f 100644 --- a/xdsl/transforms/convert_stencil_to_csl_stencil.py +++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py @@ -67,13 +67,11 @@ def get_stencil_access_operands(op: Operand) -> set[Operand]: return res -def _get_prefetch_buf(op: stencil.ApplyOp) -> int | None: +def _get_prefetch_buf_idx(op: stencil.ApplyOp) -> int | None: # calculate memory cost of all prefetch operands def get_prefetch_overhead(o: OpResult): assert isa(o.type, TensorType[Attribute]) - buf_count = o.type.get_shape()[0] - buf_size = prod(o.type.get_shape()[1:]) - return buf_count * buf_size + return prod(o.type.get_shape()) candidate_prefetches = [ (get_prefetch_overhead(o), o) @@ -258,7 +256,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): nested_rewriter.rewrite_op(new_apply_op) -def get_op_split( +def split_ops( ops: Sequence[Operation], buf: BlockArgument ) -> tuple[Sequence[Operation], Sequence[Operation]]: """ @@ -313,14 +311,19 @@ def get_op_split( a.append(use.operation) rem.remove(use.operation) + # find constants in `a` needed outside of `a` cnst_exports = [cnst for cnst in a_exports if isinstance(cnst, arith.Constant)] + + # `a` exports one value plus any number of constants - duplicate exported constants and return op split if len(a_exports) == 1 + len(cnst_exports): recv_chunk_ops, done_exch_ops = list[Operation](), list[Operation]() for op in ops: if op in a: recv_chunk_ops.append(op) if op in cnst_exports: + # create a copy of the constant in the second region done_exch_ops.append(cln := op.clone()) + # rewire ops of the second region to use the copied constant op.result.replace_by_if( cln.result, lambda use: use.operation in b or use.operation in rem, @@ -349,7 +352,7 @@ class SplitVarithOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: varith.VarithOp, rewriter: PatternRewriter, /): if not (apply := _get_apply_op(op)) or not ( - buf_idx := _get_prefetch_buf(apply) + buf_idx := _get_prefetch_buf_idx(apply) ): return buf = apply.region.block.args[buf_idx] @@ -388,7 +391,7 @@ class ConvertApplyOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): - if not (prefetch_idx := _get_prefetch_buf(op)): + if not (prefetch_idx := _get_prefetch_buf_idx(op)): return # select the prefetch with the biggest communication overhead to be fused with matched stencil.apply @@ -416,7 +419,7 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /): ).rewrite_op(op) # determine how ops should be 
split across the two regions
-        chunk_region_ops, done_exchange_ops = get_op_split(
+        chunk_region_ops, done_exchange_ops = split_ops(
             list(op.region.block.ops), op.region.block.args[prefetch_idx]
         )
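
A note on the op split performed by `split_ops` (renamed from `get_op_split` in the last patch): operations whose stencil accesses touch only the communicated buffer are moved into the chunk-receiving callback, at most one computed value is handed over to the done-exchange callback, and constants needed on both sides are duplicated instead of being exported. The sketch below is a simplified, self-contained illustration of that idea; `Op` and `split` are hypothetical stand-ins, not the xDSL API, and the membership test is a simplification of the real reachability analysis.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Op:
    """Hypothetical stand-in for an SSA operation (illustration only)."""
    name: str
    operands: list["Op"] = field(default_factory=list)
    is_constant: bool = False
    buffer_only: bool = False  # all stencil accesses read the communicated buffer


def split(ops: list[Op]) -> tuple[list[Op], list[Op]]:
    """Partition ops into (chunk callback, done-exchange callback),
    cloning constants that both sides need."""
    chunk = [o for o in ops if o.buffer_only or o.is_constant]
    done = [o for o in ops if o not in chunk]
    for o in chunk:
        if o.is_constant and any(o in d.operands for d in done):
            # give the second callback its own copy instead of exporting the value
            clone = Op(o.name + "_clone", is_constant=True)
            for d in done:
                d.operands = [clone if x is o else x for x in d.operands]
            done.insert(0, clone)
    return chunk, done


coeff = Op("coeff", is_constant=True)
acc = Op("access_neighbour", buffer_only=True)
mul = Op("mulf", [acc, coeff], buffer_only=True)
combine = Op("addf_with_own_data", [mul, coeff])
chunk, done = split([coeff, acc, mul, combine])
print([o.name for o in chunk])  # ['coeff', 'access_neighbour', 'mulf']
print([o.name for o in done])   # ['coeff_clone', 'addf_with_own_data']
```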
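
`SplitVarithOpPattern` applies the same partitioning to n-ary `varith` ops: arguments that access only the prefetch buffer are summed separately from the rest (after this patch series the split fires whenever both sides are non-empty), and the module is bracketed by `ConvertArithToVarithPass` before and `ConvertVarithToArithPass` after the rewrites. Below is a plain-Python stand-in for the splitting step, not the actual rewrite pattern:

```python
def split_nary_add(args: list[tuple[str, bool]]) -> str:
    """Each argument is (name, derived_from_buffer). Returns the rewritten expression."""
    buf_args = [name for name, from_buf in args if from_buf]
    others = [name for name, from_buf in args if not from_buf]
    if not buf_args or not others:
        # nothing to split: all arguments already live on one side
        return f"add({', '.join(name for name, _ in args)})"
    # one n-ary add per side, plus a single combining add
    return f"add(add({', '.join(buf_args)}), add({', '.join(others)}))"


print(split_nary_add([("n_up", True), ("n_down", True), ("own", False), ("coeff", False)]))
# add(add(n_up, n_down), add(own, coeff))
```

Splitting this way lets the buffer-derived summands be reduced chunk by chunk as neighbour data arrives, while locally available data is added only once in the done-exchange callback.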
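
The filecheck updates also show neighbour multiplications by splat constants being hoisted into a `coeffs` attribute on `csl_stencil.apply` (e.g. `#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>`), so the corresponding `arith.mulf` disappears from the chunk callback. A toy illustration of collecting such a coefficient table; the helper name and data layout are assumptions for illustration only:

```python
def extract_coeffs(terms: list[tuple[tuple[int, int], float | None]]):
    """terms: (neighbour offset, splat coefficient or None if the access is unscaled)."""
    coeffs = {off: c for off, c in terms if c is not None}   # folded into the coeffs attribute
    plain = [off for off, c in terms if c is None]           # accesses left untouched
    return coeffs, plain


coeffs, plain = extract_coeffs([((-1, 0), 1.196003e+05), ((1, 0), 1.196003e+05), ((0, -1), None)])
print(coeffs)  # {(-1, 0): 119600.3, (1, 0): 119600.3}
print(plain)   # [(0, -1)]
```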