Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (csl-stencil) Add coefficients to apply op #3320

Merged
merged 11 commits into from
Oct 22, 2024
47 changes: 47 additions & 0 deletions tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,53 @@ builtin.module {
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }

func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
"dmp.swap"(%a) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 510] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 510] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 510] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 510] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()
%0 = stencil.apply(%1 = %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
%2 = arith.constant dense<1.234500e-01> : tensor<510xf32>
%3 = arith.constant dense<2.345678e-01> : tensor<510xf32>
%4 = arith.constant dense<3.141500e-01> : tensor<510xf32>
%5 = stencil.access %1[1, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
%6 = "tensor.extract_slice"(%5) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%7 = stencil.access %1[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
%8 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%9 = stencil.access %1[0, -1] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
%10 = "tensor.extract_slice"(%9) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%11 = arith.mulf %6, %3 : tensor<510xf32>
%12 = arith.mulf %10, %4 : tensor<510xf32>
%13 = arith.addf %12, %8 : tensor<510xf32>
%14 = arith.addf %13, %11 : tensor<510xf32>
%15 = arith.mulf %14, %2 : tensor<510xf32>
stencil.return %13 : tensor<510xf32>
} to <[0, 0], [1, 1]>
stencil.store %0 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}

// CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>, %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>, %5 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-NEXT: %6 = arith.constant dense<1.234500e-01> : tensor<255xf32>
// CHECK-NEXT: %7 = arith.constant dense<2.345678e-01> : tensor<510xf32>
// CHECK-NEXT: %8 = arith.constant dense<3.141500e-01> : tensor<510xf32>
// CHECK-NEXT: %9 = csl_stencil.access %2[1, 0] : tensor<4x255xf32>
// CHECK-NEXT: %10 = csl_stencil.access %5[0, 0] : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<255xf32>
// CHECK-NEXT: %12 = csl_stencil.access %2[0, -1] : tensor<4x255xf32>
// CHECK-NEXT: %13 = arith.addf %12, %11 : tensor<255xf32>
// CHECK-NEXT: %14 = arith.addf %13, %9 : tensor<255xf32>
// CHECK-NEXT: %15 = arith.mulf %14, %6 : tensor<255xf32>
// CHECK-NEXT: %16 = "tensor.insert_slice"(%15, %4, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %16 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%17 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %18 : tensor<510xf32>):
// CHECK-NEXT: csl_stencil.yield %13 : tensor<255xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: stencil.store %1 to %b(<[0, 0], [1, 1]>) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }

}
Expand Down
66 changes: 38 additions & 28 deletions tests/filecheck/transforms/lower-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ builtin.module {
%arg11 = "csl.load_var"(%24) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
%arg12 = "csl.load_var"(%25) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
%32 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
csl_stencil.apply(%arg11 : memref<511xf32>, %32 : memref<510xf32>, %arg11 : memref<511xf32>, %arg12 : memref<511xf32>, %arg9 : i1) outs (%arg12 : memref<511xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array<i32: 1, 1, 1, 2, 1>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [0, 1]>], "topo" = #dmp.topo<1022x510>}> ({
csl_stencil.apply(%arg11 : memref<511xf32>, %32 : memref<510xf32>, %arg11 : memref<511xf32>, %arg12 : memref<511xf32>, %arg9 : i1) outs (%arg12 : memref<511xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array<i32: 1, 1, 1, 2, 1>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [0, 1]>], "topo" = #dmp.topo<1022x510>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({
^2(%arg13 : memref<2x510xf32>, %arg14 : index, %arg15 : memref<510xf32>, %arg16 : memref<511xf32>):
%33 = arith.constant dense<1.234500e-01> : memref<510xf32>
%34 = csl_stencil.access %arg13[1, 0] : memref<2x510xf32>
Expand Down Expand Up @@ -638,53 +638,63 @@ builtin.module {
// CHECK-NEXT: %arg11_2 = "csl.load_var"(%182) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: %arg12_2 = "csl.load_var"(%183) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: %accumulator_3 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
// CHECK-NEXT: %190 = arith.constant 1 : i16
// CHECK-NEXT: %191 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %192 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %193 = memref.subview %arg11_2[0] [510] [1] : memref<511xf32> to memref<510xf32>
// CHECK-NEXT: "csl.member_call"(%176, %193, %190, %191, %192) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>, !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>) -> ()
// CHECK-NEXT: %north = arith.constant dense<[0.000000e+00, 3.141500e-01]> : memref<2xf32>
// CHECK-NEXT: %south = arith.constant dense<[0.000000e+00, 1.000000e+00]> : memref<2xf32>
// CHECK-NEXT: %east = arith.constant dense<[0.000000e+00, 1.000000e+00]> : memref<2xf32>
// CHECK-NEXT: %west = arith.constant dense<[0.000000e+00, 2.345678e-01]> : memref<2xf32>
// CHECK-NEXT: %190 = "csl.addressof"(%east) : (memref<2xf32>) -> !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %191 = "csl.addressof"(%west) : (memref<2xf32>) -> !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %192 = "csl.addressof"(%south) : (memref<2xf32>) -> !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %193 = "csl.addressof"(%north) : (memref<2xf32>) -> !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %194 = arith.constant false
// CHECK-NEXT: "csl.member_call"(%176, %190, %191, %192, %193, %194) <{"field" = "setCoeffs"}> : (!csl.imported_module, !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>, !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>, !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>, !csl.ptr<memref<2xf32>, #csl<ptr_kind single>, #csl<ptr_const const>>, i1) -> ()
// CHECK-NEXT: %195 = arith.constant 1 : i16
// CHECK-NEXT: %196 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %197 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %198 = memref.subview %arg11_2[0] [510] [1] : memref<511xf32> to memref<510xf32>
// CHECK-NEXT: "csl.member_call"(%176, %198, %195, %196, %197) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl<ptr_kind single>, #csl<ptr_const const>>, !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>) -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.func @receive_chunk_cb1(%offset_6 : i16) {
// CHECK-NEXT: %offset_7 = arith.index_cast %offset_6 : i16 to index
// CHECK-NEXT: %arg11_3 = "csl.load_var"(%182) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: %194 = arith.constant dense<1.234500e-01> : memref<510xf32>
// CHECK-NEXT: %195 = arith.constant 1 : i16
// CHECK-NEXT: %196 = "csl.get_dir"() <{"dir" = #csl<dir_kind west>}> : () -> !csl.direction
// CHECK-NEXT: %197 = "csl.member_call"(%176, %196, %195) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %198 = builtin.unrealized_conversion_cast %197 : !csl<dsd mem1d_dsd> to memref<510xf32>
// CHECK-NEXT: %199 = memref.subview %arg11_3[1] [510] [1] : memref<511xf32> to memref<510xf32, strided<[1], offset: 1>>
// CHECK-NEXT: %199 = arith.constant dense<1.234500e-01> : memref<510xf32>
// CHECK-NEXT: %200 = arith.constant 1 : i16
// CHECK-NEXT: %201 = "csl.get_dir"() <{"dir" = #csl<dir_kind south>}> : () -> !csl.direction
// CHECK-NEXT: %201 = "csl.get_dir"() <{"dir" = #csl<dir_kind west>}> : () -> !csl.direction
// CHECK-NEXT: %202 = "csl.member_call"(%176, %201, %200) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %203 = builtin.unrealized_conversion_cast %202 : !csl<dsd mem1d_dsd> to memref<510xf32>
// CHECK-NEXT: %204 = memref.subview %accumulator_3[%offset_7] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "csl.fadds"(%204, %199, %203) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: 1>>, memref<510xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%204, %204, %198) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> ()
// CHECK-NEXT: %205 = arith.constant 1.234500e-01 : f32
// CHECK-NEXT: "csl.fmuls"(%204, %204, %205) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, f32) -> ()
// CHECK-NEXT: %204 = memref.subview %arg11_3[1] [510] [1] : memref<511xf32> to memref<510xf32, strided<[1], offset: 1>>
// CHECK-NEXT: %205 = arith.constant 1 : i16
// CHECK-NEXT: %206 = "csl.get_dir"() <{"dir" = #csl<dir_kind south>}> : () -> !csl.direction
// CHECK-NEXT: %207 = "csl.member_call"(%176, %206, %205) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %208 = builtin.unrealized_conversion_cast %207 : !csl<dsd mem1d_dsd> to memref<510xf32>
// CHECK-NEXT: %209 = memref.subview %accumulator_3[%offset_7] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "csl.fadds"(%209, %204, %208) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: 1>>, memref<510xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%209, %209, %203) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> ()
// CHECK-NEXT: %210 = arith.constant 1.234500e-01 : f32
// CHECK-NEXT: "csl.fmuls"(%209, %209, %210) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, f32) -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.func @done_exchange_cb1() {
// CHECK-NEXT: %arg12_3 = "csl.load_var"(%183) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: %arg11_4 = "csl.load_var"(%182) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: scf.if %arg9_1 {
// CHECK-NEXT: } else {
// CHECK-NEXT: %206 = memref.subview %arg12_3[0] [510] [1] : memref<511xf32> to memref<510xf32>
// CHECK-NEXT: "memref.copy"(%accumulator_3, %206) : (memref<510xf32>, memref<510xf32>) -> ()
// CHECK-NEXT: %211 = memref.subview %arg12_3[0] [510] [1] : memref<511xf32> to memref<510xf32>
// CHECK-NEXT: "memref.copy"(%accumulator_3, %211) : (memref<510xf32>, memref<510xf32>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: "csl.call"() <{"callee" = @for_inc0}> : () -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.func @for_inc0() {
// CHECK-NEXT: %207 = arith.constant 1 : i16
// CHECK-NEXT: %208 = "csl.load_var"(%181) : (!csl.var<i16>) -> i16
// CHECK-NEXT: %209 = arith.addi %208, %207 : i16
// CHECK-NEXT: "csl.store_var"(%181, %209) : (!csl.var<i16>, i16) -> ()
// CHECK-NEXT: %210 = "csl.load_var"(%182) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: %211 = "csl.load_var"(%183) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: "csl.store_var"(%182, %211) : (!csl.var<memref<511xf32>>, memref<511xf32>) -> ()
// CHECK-NEXT: "csl.store_var"(%183, %210) : (!csl.var<memref<511xf32>>, memref<511xf32>) -> ()
// CHECK-NEXT: %212 = arith.constant 1 : i16
// CHECK-NEXT: %213 = "csl.load_var"(%181) : (!csl.var<i16>) -> i16
// CHECK-NEXT: %214 = arith.addi %213, %212 : i16
// CHECK-NEXT: "csl.store_var"(%181, %214) : (!csl.var<i16>, i16) -> ()
// CHECK-NEXT: %215 = "csl.load_var"(%182) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: %216 = "csl.load_var"(%183) : (!csl.var<memref<511xf32>>) -> memref<511xf32>
// CHECK-NEXT: "csl.store_var"(%182, %216) : (!csl.var<memref<511xf32>>, memref<511xf32>) -> ()
// CHECK-NEXT: "csl.store_var"(%183, %215) : (!csl.var<memref<511xf32>>, memref<511xf32>) -> ()
// CHECK-NEXT: csl.activate local, 1 : i32
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
Expand Down
13 changes: 13 additions & 0 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,16 @@ class PtrType(ParametrizedAttribute, TypeAttribute, ContainerType[Attribute]):
kind: ParameterDef[PtrKindAttr]
constness: ParameterDef[PtrConstAttr]

@staticmethod
def get(typ: Attribute, is_single: bool, is_const: bool):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe typ could be a TypeAttribute, since PtrType.type is?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done as suggested

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing back to Attribute to fix tests

return PtrType(
[
typ,
PtrKindAttr(PtrKind.SINGLE if is_single else PtrKind.MANY),
PtrConstAttr(PtrConst.CONST if is_const else PtrConst.VAR),
]
)

def get_element_type(self) -> Attribute:
return self.type

Expand Down Expand Up @@ -1865,6 +1875,9 @@ class AddressOfOp(IRDLOperation):

traits = frozenset([NoMemoryEffect()])

def __init__(self, value: SSAValue | Operation, result_type: PtrType):
super().__init__(operands=[value], result_types=[result_type])

def _verify_memref_addr(self, val_ty: MemRefType[Attribute], res_ty: PtrType):
"""
Verify that if the address of a memref is taken, the resulting pointer is either:
Expand Down
31 changes: 30 additions & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from collections.abc import Iterable, Sequence
from itertools import pairwise
from typing import cast
from typing import TypeAlias, cast

from xdsl.dialects import builtin, memref, stencil
from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
AnyMemRefType,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
Float16Type,
Float32Type,
FloatAttr,
IndexType,
MemRefType,
TensorType,
Expand Down Expand Up @@ -47,6 +51,8 @@
HasParent,
IsolatedFromAbove,
IsTerminator,
MemoryReadEffect,
MemoryWriteEffect,
Pure,
RecursiveMemoryEffect,
)
Expand Down Expand Up @@ -152,6 +158,19 @@ def __init__(
)


CslFloat: TypeAlias = Float16Type | Float32Type


@irdl_attr_definition
class CoeffAttr(ParametrizedAttribute):
name = "csl_stencil.coeff"
offset: ParameterDef[stencil.IndexAttr]
coeff: ParameterDef[FloatAttr[AnyFloat]]

def __init__(self, offset: stencil.IndexAttr, coeff: FloatAttr[AnyFloat]):
super().__init__([offset, coeff])


class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
Expand Down Expand Up @@ -223,12 +242,16 @@ class ApplyOp(IRDLOperation):

bounds = opt_prop_def(stencil.StencilBoundsAttr)

coeffs = opt_prop_def(builtin.ArrayAttr[CoeffAttr])

res = var_result_def(stencil.StencilTypeConstr)

traits = frozenset(
[
IsolatedFromAbove(),
ApplyOpHasCanonicalizationPatternsTrait(),
MemoryReadEffect(),
MemoryWriteEffect(),
RecursiveMemoryEffect(),
]
)
Expand Down Expand Up @@ -413,6 +436,11 @@ def get_accesses(self) -> Iterable[stencil.AccessPattern]:
accesses.append(offsets)
yield stencil.AccessPattern(tuple(accesses))

def add_coeff(self, offset: stencil.IndexAttr, coeff: FloatAttr[AnyFloat]):
self.coeffs = builtin.ArrayAttr(
list(self.coeffs or []) + [CoeffAttr(offset, coeff)]
)


@irdl_op_definition
class AccessOp(IRDLOperation):
Expand Down Expand Up @@ -626,5 +654,6 @@ class YieldOp(AbstractYieldOperation[Attribute]):
],
[
ExchangeDeclarationAttr,
CoeffAttr,
],
)
Loading
Loading