diff --git a/tests/filecheck/backend/riscv/convert_riscv_scf_for_to_frep.mlir b/tests/filecheck/backend/riscv/convert_riscv_scf_for_to_frep.mlir index f5e7dbc61b..dbe7d5faf1 100644 --- a/tests/filecheck/backend/riscv/convert_riscv_scf_for_to_frep.mlir +++ b/tests/filecheck/backend/riscv/convert_riscv_scf_for_to_frep.mlir @@ -7,8 +7,8 @@ %c1 = riscv.li 1 : !riscv.reg %c2 = riscv.li 2 : !riscv.reg -%readable = riscv_snitch.get_stream : !stream.readable> -%writable = riscv_snitch.get_stream : !stream.writable> +%readable = riscv_snitch.get_stream : !snitch.readable> +%writable = riscv_snitch.get_stream : !snitch.writable> %f0 = riscv.get_float_register : !riscv.freg %f1 = riscv.get_float_register : !riscv.freg diff --git a/tests/filecheck/backend/riscv/register-allocation/exclude_snitch.mlir b/tests/filecheck/backend/riscv/register-allocation/exclude_snitch.mlir index e74d0cb0f8..2f0e79dcf6 100644 --- a/tests/filecheck/backend/riscv/register-allocation/exclude_snitch.mlir +++ b/tests/filecheck/backend/riscv/register-allocation/exclude_snitch.mlir @@ -2,7 +2,7 @@ // RUN: xdsl-opt --split-input-file -p "riscv-allocate-registers{allocation_strategy=LivenessBlockNaive exclude_snitch_reserved=false}" %s | filecheck %s --check-prefix=CHECK-SNITCH-UNRESERVED riscv_func.func @main() { - %stream = "test.op"() : () -> (!stream.readable>) + %stream = "test.op"() : () -> (!snitch.readable>) %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg, !riscv.freg, !riscv.freg) %read = riscv_snitch.read from %stream : !riscv.freg %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg, !riscv.freg) -> !riscv.freg @@ -11,7 +11,7 @@ riscv_func.func @main() { // CHECK: builtin.module { // CHECK-NEXT: riscv_func.func @main() { -// CHECK-NEXT: %stream = "test.op"() : () -> !stream.readable> +// CHECK-NEXT: %stream = "test.op"() : () -> !snitch.readable> // CHECK-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg, !riscv.freg, !riscv.freg) // CHECK-NEXT: %read = riscv_snitch.read from %stream : !riscv.freg // CHECK-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg, !riscv.freg) -> !riscv.freg @@ -21,7 +21,7 @@ riscv_func.func @main() { // CHECK-SNITCH-UNRESERVED: builtin.module { // CHECK-SNITCH-UNRESERVED-NEXT: riscv_func.func @main() { -// CHECK-SNITCH-UNRESERVED-NEXT: %stream = "test.op"() : () -> !stream.readable> +// CHECK-SNITCH-UNRESERVED-NEXT: %stream = "test.op"() : () -> !snitch.readable> // CHECK-SNITCH-UNRESERVED-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg, !riscv.freg, !riscv.freg) // CHECK-SNITCH-UNRESERVED-NEXT: %read = riscv_snitch.read from %stream : !riscv.freg // CHECK-SNITCH-UNRESERVED-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg, !riscv.freg) -> !riscv.freg @@ -32,7 +32,7 @@ riscv_func.func @main() { // ----- riscv_func.func @main() { - %stream, %val = "test.op"() : () -> (!stream.writable>, !riscv.freg) + %stream, %val = "test.op"() : () -> (!snitch.writable>, !riscv.freg) %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg, !riscv.freg, !riscv.freg) riscv_snitch.write %val to %stream : !riscv.freg %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg, !riscv.freg) -> !riscv.freg @@ -41,7 +41,7 @@ riscv_func.func @main() { // CHECK: builtin.module { // CHECK-NEXT: riscv_func.func @main() { -// CHECK-NEXT: %stream, %val = "test.op"() : () -> (!stream.writable>, !riscv.freg) +// CHECK-NEXT: %stream, %val = "test.op"() : () -> (!snitch.writable>, !riscv.freg) // CHECK-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg, !riscv.freg, !riscv.freg) // CHECK-NEXT: riscv_snitch.write %val to %stream : !riscv.freg // CHECK-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg, !riscv.freg) -> !riscv.freg @@ -51,7 +51,7 @@ riscv_func.func @main() { // CHECK-SNITCH-UNRESERVED: builtin.module { // CHECK-SNITCH-UNRESERVED-NEXT: riscv_func.func @main() { -// CHECK-SNITCH-UNRESERVED-NEXT: %stream, %val = "test.op"() : () -> (!stream.writable>, !riscv.freg) +// CHECK-SNITCH-UNRESERVED-NEXT: %stream, %val = "test.op"() : () -> (!snitch.writable>, !riscv.freg) // CHECK-SNITCH-UNRESERVED-NEXT: %v0, %v1, %v2 = "test.op"() : () -> (!riscv.freg, !riscv.freg, !riscv.freg) // CHECK-SNITCH-UNRESERVED-NEXT: riscv_snitch.write %val to %stream : !riscv.freg // CHECK-SNITCH-UNRESERVED-NEXT: %sum1 = riscv.fadd.s %v0, %v1 : (!riscv.freg, !riscv.freg) -> !riscv.freg diff --git a/tests/filecheck/dialects/memref_stream/ops.mlir b/tests/filecheck/dialects/memref_stream/ops.mlir index 4f9fa17322..3815da1f75 100644 --- a/tests/filecheck/dialects/memref_stream/ops.mlir +++ b/tests/filecheck/dialects/memref_stream/ops.mlir @@ -4,18 +4,18 @@ // CHECK:builtin.module { // CHECK-GENERIC: "builtin.module"() ({ -%readable, %writable = "test.op"() : () -> (!stream.readable, !stream.writable) +%readable, %writable = "test.op"() : () -> (!memref_stream.readable, !memref_stream.writable) %val = memref_stream.read from %readable : f32 memref_stream.write %val to %writable : f32 -// CHECK-NEXT: %readable, %writable = "test.op"() : () -> (!stream.readable, !stream.writable) +// CHECK-NEXT: %readable, %writable = "test.op"() : () -> (!memref_stream.readable, !memref_stream.writable) // CHECK-NEXT: %val = memref_stream.read from %readable : f32 // CHECK-NEXT: memref_stream.write %val to %writable : f32 -// CHECK-GENERIC-NEXT: %readable, %writable = "test.op"() : () -> (!stream.readable, !stream.writable) -// CHECK-GENERIC-NEXT: %val = "memref_stream.read"(%readable) : (!stream.readable) -> f32 -// CHECK-GENERIC-NEXT: "memref_stream.write"(%val, %writable) : (f32, !stream.writable) -> () +// CHECK-GENERIC-NEXT: %readable, %writable = "test.op"() : () -> (!memref_stream.readable, !memref_stream.writable) +// CHECK-GENERIC-NEXT: %val = "memref_stream.read"(%readable) : (!memref_stream.readable) -> f32 +// CHECK-GENERIC-NEXT: "memref_stream.write"(%val, %writable) : (f32, !memref_stream.writable) -> () %A, %B, %C, %D = "test.op"() : () -> (memref<2xf32>, memref<3xf32>, memref<3x2xf64>, f64) @@ -26,8 +26,8 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d0, d1)> ] } ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {hello = "world"} { -^bb0(%a: !stream.readable, %b: !stream.readable, %c: !stream.writable): - "test.op"(%a, %b, %c) : (!stream.readable, !stream.readable, !stream.writable) -> () +^bb0(%a: !memref_stream.readable, %b: !memref_stream.readable, %c: !memref_stream.writable): + "test.op"(%a, %b, %c) : (!memref_stream.readable, !memref_stream.readable, !memref_stream.writable) -> () } // CHECK-NEXT: %A, %B, %C, %D = "test.op"() : () -> (memref<2xf32>, memref<3xf32>, memref<3x2xf64>, f64) @@ -38,14 +38,14 @@ memref_stream.streaming_region { // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {"hello" = "world"} { -// CHECK-NEXT: ^0(%a : !stream.readable, %b : !stream.readable, %c : !stream.writable): -// CHECK-NEXT: "test.op"(%a, %b, %c) : (!stream.readable, !stream.readable, !stream.writable) -> () +// CHECK-NEXT: ^0(%a : !memref_stream.readable, %b : !memref_stream.readable, %c : !memref_stream.writable): +// CHECK-NEXT: "test.op"(%a, %b, %c) : (!memref_stream.readable, !memref_stream.readable, !memref_stream.writable) -> () // CHECK-NEXT: } // CHECK-GENERIC-NEXT: %A, %B, %C, %D = "test.op"() : () -> (memref<2xf32>, memref<3xf32>, memref<3x2xf64>, f64) // CHECK-GENERIC-NEXT: "memref_stream.streaming_region"(%A, %B, %C) <{"patterns" = [#memref_stream.stride_pattern (d0)>, #memref_stream.stride_pattern (d1)>, #memref_stream.stride_pattern (d0, d1)>], "operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^0(%a : !stream.readable, %b : !stream.readable, %c : !stream.writable): -// CHECK-GENERIC-NEXT: "test.op"(%a, %b, %c) : (!stream.readable, !stream.readable, !stream.writable) -> () +// CHECK-GENERIC-NEXT: ^0(%a : !memref_stream.readable, %b : !memref_stream.readable, %c : !memref_stream.writable): +// CHECK-GENERIC-NEXT: "test.op"(%a, %b, %c) : (!memref_stream.readable, !memref_stream.readable, !memref_stream.writable) -> () // CHECK-GENERIC-NEXT: }) {"hello" = "world"} : (memref<2xf32>, memref<3xf32>, memref<3x2xf64>) -> () diff --git a/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir b/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir index ee1f5f0f9f..26ee763976 100644 --- a/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir +++ b/tests/filecheck/dialects/riscv_snitch/assembly_emission.mlir @@ -5,8 +5,8 @@ riscv_func.func @main() { %0 = riscv.get_register : !riscv.reg %1 = riscv.get_register : !riscv.reg - %readable = riscv_snitch.get_stream : !stream.readable> - %writable = riscv_snitch.get_stream : !stream.writable> + %readable = riscv_snitch.get_stream : !snitch.readable> + %writable = riscv_snitch.get_stream : !snitch.writable> riscv_snitch.frep_outer %0 { %val0 = riscv_snitch.read from %readable : !riscv.freg %val1 = riscv.fmv.d %val0 : (!riscv.freg) -> !riscv.freg diff --git a/tests/filecheck/dialects/riscv_snitch/ops.mlir b/tests/filecheck/dialects/riscv_snitch/ops.mlir index 54216e14dd..af3392eb0a 100644 --- a/tests/filecheck/dialects/riscv_snitch/ops.mlir +++ b/tests/filecheck/dialects/riscv_snitch/ops.mlir @@ -25,15 +25,15 @@ riscv_func.func @xfrep() { // CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg, !riscv.reg) -> !riscv.reg // CHECK-NEXT: } - %readable = riscv_snitch.get_stream : !stream.readable> - %writable = riscv_snitch.get_stream : !stream.writable> + %readable = riscv_snitch.get_stream : !snitch.readable> + %writable = riscv_snitch.get_stream : !snitch.writable> riscv_snitch.frep_outer %0 { %val0 = riscv_snitch.read from %readable : !riscv.freg %val1 = riscv.fmv.d %val0 : (!riscv.freg) -> !riscv.freg riscv_snitch.write %val1 to %writable : !riscv.freg } - // CHECK-NEXT: %readable = riscv_snitch.get_stream : !stream.readable> - // CHECK-NEXT: %writable = riscv_snitch.get_stream : !stream.writable> + // CHECK-NEXT: %readable = riscv_snitch.get_stream : !snitch.readable> + // CHECK-NEXT: %writable = riscv_snitch.get_stream : !snitch.writable> // CHECK-NEXT: riscv_snitch.frep_outer %0 { // CHECK-NEXT: %val0 = riscv_snitch.read from %readable : !riscv.freg // CHECK-NEXT: %val1 = riscv.fmv.d %val0 : (!riscv.freg) -> !riscv.freg @@ -129,12 +129,12 @@ riscv_func.func @simd() { // CHECK-GENERIC-NEXT: %{{.*}} = "riscv.add"(%{{.*}}, %{{.*}}) : (!riscv.reg, !riscv.reg) -> !riscv.reg // CHECK-GENERIC-NEXT: "riscv_snitch.frep_yield"() : () -> () // CHECK-GENERIC-NEXT: }) {"stagger_mask" = #builtin.int<0>, "stagger_count" = #builtin.int<0>} : (!riscv.reg) -> () -// CHECK-GENERIC-NEXT: %readable = "riscv_snitch.get_stream"() : () -> !stream.readable> -// CHECK-GENERIC-NEXT: %writable = "riscv_snitch.get_stream"() : () -> !stream.writable> +// CHECK-GENERIC-NEXT: %readable = "riscv_snitch.get_stream"() : () -> !snitch.readable> +// CHECK-GENERIC-NEXT: %writable = "riscv_snitch.get_stream"() : () -> !snitch.writable> // CHECK-GENERIC-NEXT: "riscv_snitch.frep_outer"(%0) ({ -// CHECK-GENERIC-NEXT: %val0 = "riscv_snitch.read"(%readable) : (!stream.readable>) -> !riscv.freg +// CHECK-GENERIC-NEXT: %val0 = "riscv_snitch.read"(%readable) : (!snitch.readable>) -> !riscv.freg // CHECK-GENERIC-NEXT: %val1 = "riscv.fmv.d"(%val0) : (!riscv.freg) -> !riscv.freg -// CHECK-GENERIC-NEXT: "riscv_snitch.write"(%val1, %writable) : (!riscv.freg, !stream.writable>) -> () +// CHECK-GENERIC-NEXT: "riscv_snitch.write"(%val1, %writable) : (!riscv.freg, !snitch.writable>) -> () // CHECK-GENERIC-NEXT: "riscv_snitch.frep_yield"() : () -> () // CHECK-GENERIC-NEXT: }) {"stagger_mask" = #builtin.int<0>, "stagger_count" = #builtin.int<0>} : (!riscv.reg) -> () // CHECK-GENERIC-NEXT: %init = "test.op"() : () -> !riscv.freg diff --git a/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir b/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir index aad4be8d6a..cfcd304d50 100644 --- a/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir +++ b/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir @@ -13,7 +13,7 @@ ], "operandSegmentSizes" = array }> ({ -^0(%a_stream : !stream.readable>, %b_stream : !stream.readable>, %c_stream : !stream.writable>): +^0(%a_stream : !snitch.readable>, %b_stream : !snitch.readable>, %c_stream : !snitch.writable>): "test.op"() : () -> () }) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () // CHECK-NEXT: %{{.*}} = riscv.li 2 : !riscv.reg @@ -75,7 +75,7 @@ // CHECK-NEXT: "snitch.ssr_set_dimension_source"(%A) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg) -> () // CHECK-NEXT: "snitch.ssr_set_dimension_source"(%B) {"dm" = #builtin.int<1>, "dimension" = #builtin.int<1>} : (!riscv.reg) -> () // CHECK-NEXT: "snitch.ssr_set_dimension_destination"(%C) {"dm" = #builtin.int<2>, "dimension" = #builtin.int<3>} : (!riscv.reg) -> () -// CHECK-NEXT: %a_stream, %b_stream, %c_stream = "snitch.ssr_enable"() : () -> (!stream.readable>, !stream.readable>, !stream.writable>) +// CHECK-NEXT: %a_stream, %b_stream, %c_stream = "snitch.ssr_enable"() : () -> (!snitch.readable>, !snitch.readable>, !snitch.writable>) // CHECK-NEXT: "test.op"() : () -> () // CHECK-NEXT: "snitch.ssr_disable"() : () -> () @@ -85,7 +85,7 @@ ], "operandSegmentSizes" = array }> ({ -^0(%a_stream : !stream.readable>, %b_stream : !stream.writable>): +^0(%a_stream : !snitch.readable>, %b_stream : !snitch.writable>): "test.op"() : () -> () }) : (!riscv.reg, !riscv.reg) -> () @@ -99,7 +99,7 @@ // CHECK-NEXT: "snitch.ssr_set_stream_repetition"(%{{.*}}) {"dm" = #builtin.int<31>} : (!riscv.reg) -> () // CHECK-NEXT: "snitch.ssr_set_dimension_source"(%A) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg) -> () // CHECK-NEXT: "snitch.ssr_set_dimension_destination"(%B) {"dm" = #builtin.int<1>, "dimension" = #builtin.int<0>} : (!riscv.reg) -> () -// CHECK-NEXT: %{{.*}}, %{{.*}} = "snitch.ssr_enable"() : () -> (!stream.readable>, !stream.writable>) +// CHECK-NEXT: %{{.*}}, %{{.*}} = "snitch.ssr_enable"() : () -> (!snitch.readable>, !snitch.writable>) // CHECK-NEXT: "test.op"() : () -> () // CHECK-NEXT: "snitch.ssr_disable"() : () -> () diff --git a/tests/filecheck/dialects/snitch_stream/ops.mlir b/tests/filecheck/dialects/snitch_stream/ops.mlir index 803e56fb50..4989bd8d4a 100644 --- a/tests/filecheck/dialects/snitch_stream/ops.mlir +++ b/tests/filecheck/dialects/snitch_stream/ops.mlir @@ -8,7 +8,7 @@ snitch_stream.streaming_region { #snitch_stream.stride_pattern ] } ins(%X, %Y : !riscv.reg, !riscv.reg) outs(%Z : !riscv.reg) { -^0(%a_stream : !stream.readable>, %b_stream : !stream.readable>, %c_stream : !stream.writable>): +^0(%a_stream : !snitch.readable>, %b_stream : !snitch.readable>, %c_stream : !snitch.writable>): %c5 = riscv.li 5 : !riscv.reg riscv_snitch.frep_outer %c5 { %a = riscv_snitch.read from %a_stream : !riscv.freg @@ -26,7 +26,7 @@ snitch_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%X, %Y : !riscv.reg, !riscv.reg) outs(%Z : !riscv.reg) { -// CHECK-NEXT: ^0(%a_stream : !stream.readable>, %b_stream : !stream.readable>, %c_stream : !stream.writable>): +// CHECK-NEXT: ^0(%a_stream : !snitch.readable>, %b_stream : !snitch.readable>, %c_stream : !snitch.writable>): // CHECK-NEXT: %c5 = riscv.li 5 : !riscv.reg // CHECK-NEXT: riscv_snitch.frep_outer %c5 { // CHECK-NEXT: %a = riscv_snitch.read from %a_stream : !riscv.freg @@ -38,13 +38,13 @@ snitch_stream.streaming_region { // CHECK-GENERIC: %X, %Y, %Z = "test.op"() : () -> (!riscv.reg, !riscv.reg, !riscv.reg) // CHECK-GENERIC-NEXT: "snitch_stream.streaming_region"(%X, %Y, %Z) <{"stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^0(%a_stream : !stream.readable>, %b_stream : !stream.readable>, %c_stream : !stream.writable>): +// CHECK-GENERIC-NEXT: ^0(%a_stream : !snitch.readable>, %b_stream : !snitch.readable>, %c_stream : !snitch.writable>): // CHECK-GENERIC-NEXT: %c5 = "riscv.li"() {"immediate" = 5 : i32} : () -> !riscv.reg // CHECK-GENERIC-NEXT: "riscv_snitch.frep_outer"(%c5) ({ -// CHECK-GENERIC-NEXT: %a = "riscv_snitch.read"(%a_stream) : (!stream.readable>) -> !riscv.freg -// CHECK-GENERIC-NEXT: %b = "riscv_snitch.read"(%b_stream) : (!stream.readable>) -> !riscv.freg +// CHECK-GENERIC-NEXT: %a = "riscv_snitch.read"(%a_stream) : (!snitch.readable>) -> !riscv.freg +// CHECK-GENERIC-NEXT: %b = "riscv_snitch.read"(%b_stream) : (!snitch.readable>) -> !riscv.freg // CHECK-GENERIC-NEXT: %c = "riscv.fadd.d"(%a, %b) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.freg -// CHECK-GENERIC-NEXT: "riscv_snitch.write"(%c, %c_stream) : (!riscv.freg, !stream.writable>) -> () +// CHECK-GENERIC-NEXT: "riscv_snitch.write"(%c, %c_stream) : (!riscv.freg, !snitch.writable>) -> () // CHECK-GENERIC-NEXT: "riscv_snitch.frep_yield"() : () -> () // CHECK-GENERIC-NEXT: }) {"stagger_mask" = #builtin.int<0>, "stagger_count" = #builtin.int<0>} : (!riscv.reg) -> () // CHECK-GENERIC-NEXT: }) : (!riscv.reg, !riscv.reg, !riscv.reg) -> () diff --git a/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir b/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir index e8690b5798..459512fae3 100644 --- a/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir @@ -20,7 +20,7 @@ builtin.module { "stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array }> ({ - ^0(%a_stream : !stream.readable>, %b_stream : !stream.readable>, %c_stream : !stream.writable>): + ^0(%a_stream : !snitch.readable>, %b_stream : !snitch.readable>, %c_stream : !snitch.writable>): %c5 = riscv.li 5 : !riscv.reg riscv_snitch.frep_outer %c5 { %a = riscv_snitch.read from %a_stream : !riscv.freg diff --git a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir index c21d6408ad..5aba84d4a7 100644 --- a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir @@ -181,7 +181,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( #memref_stream.stride_pattern (d0)> ] } ins(%X, %Y : memref<128xf64>, memref<128xf64>) { - ^0(%x_stream : !stream.readable, %y_stream : !stream.readable): + ^0(%x_stream : !memref_stream.readable, %y_stream : !memref_stream.readable): %zero_float = arith.constant 0.0 : f64 %c0 = arith.constant 0 : i32 @@ -593,7 +593,7 @@ func.func public @pooling_nchw_max_d1_s2_3x3( #snitch_stream.stride_pattern ] } ins(%X_1 : !riscv.reg) outs(%Y_1 : !riscv.reg) { - ^0(%x : !stream.readable>, %0 : !stream.writable>): + ^0(%x : !snitch.readable>, %0 : !snitch.writable>): %c128 = riscv.li 128 : !riscv.reg %c0 = riscv.li 0 : !riscv.reg %c1 = riscv.li 1 : !riscv.reg diff --git a/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir b/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir index e814425f54..73d2648cea 100644 --- a/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir @@ -25,7 +25,7 @@ builtin.module { "stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array }> ({ - ^0(%a_stream : !stream.readable>, %b_stream : !stream.writable>): + ^0(%a_stream : !snitch.readable>, %b_stream : !snitch.writable>): %c5 = riscv.li 5 : !riscv.reg riscv_snitch.frep_outer %c5 { %a = riscv_snitch.read from %a_stream : !riscv.freg diff --git a/tests/filecheck/projects/riscv-backend-paper/source.mlir b/tests/filecheck/projects/riscv-backend-paper/source.mlir index ec03baa125..64d7af2fc2 100644 --- a/tests/filecheck/projects/riscv-backend-paper/source.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/source.mlir @@ -14,8 +14,8 @@ func.func public @dsum(%arg0: memref<8x16xf64>, %arg1: memref<8x16xf64>, %arg2: } // CHECK-NEXT: func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 : memref<8x16xf64>) -> memref<8x16xf64> { // CHECK-NEXT: memref_stream.streaming_region {bounds = [8, 16], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]} ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { -// CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [[affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { +// CHECK-NEXT: ^0(%0 : !memref_stream.readable, %1 : !memref_stream.readable, %2 : !memref_stream.writable): +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [[affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !memref_stream.readable, !memref_stream.readable) outs(%2 : !memref_stream.writable) { // CHECK-NEXT: ^1(%in : f64, %in_0 : f64, %out : f64): // CHECK-NEXT: %3 = arith.addf %in, %in_0 : f64 // CHECK-NEXT: memref_stream.yield %3 : f64 @@ -37,8 +37,8 @@ func.func public @relu(%arg0: memref<16x16xf64>, %arg1: memref<16x16xf64>) -> me // CHECK-NEXT: func.func public @relu(%arg0_1 : memref<16x16xf64>, %arg1_1 : memref<16x16xf64>) -> memref<16x16xf64> { // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: memref_stream.streaming_region {bounds = [16, 16], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]} ins(%arg0_1 : memref<16x16xf64>) outs(%arg1_1 : memref<16x16xf64>) { -// CHECK-NEXT: ^2(%4 : !stream.readable, %5 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [[affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { +// CHECK-NEXT: ^2(%4 : !memref_stream.readable, %5 : !memref_stream.writable): +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [[affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : !memref_stream.readable) outs(%5 : !memref_stream.writable) { // CHECK-NEXT: ^3(%in_1 : f64, %out_1 : f64): // CHECK-NEXT: %6 = arith.maximumf %in_1, %cst : f64 // CHECK-NEXT: memref_stream.yield %6 : f64 diff --git a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir index a9de1257d4..fea98495db 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir @@ -10,8 +10,8 @@ #memref_stream.stride_pattern (d0, d1)> ] } ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { - ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): - memref_stream.generic {bounds = [8, 16], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { + ^0(%0 : !memref_stream.readable, %1 : !memref_stream.readable, %2 : !memref_stream.writable): + memref_stream.generic {bounds = [8, 16], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !memref_stream.readable, !memref_stream.readable) outs(%2 : !memref_stream.writable) { ^1(%in : f64, %in_0 : f64, %out : f64): %3 = arith.addf %in, %in_0 : f64 memref_stream.yield %3 : f64 @@ -27,7 +27,7 @@ // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : memref<8x16xf64>, memref<8x16xf64>) outs(%{{.*}} : memref<8x16xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.writable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.writable): // CHECK-NEXT: %{{.*}} = arith.constant 8 : index // CHECK-NEXT: %{{.*}} = arith.constant 16 : index // CHECK-NEXT: %{{.*}} = arith.constant 0 : index @@ -53,8 +53,8 @@ #memref_stream.stride_pattern (d0, d1)> ] } ins(%arg0_1 : memref<16x16xf64>) outs(%arg1_1 : memref<16x16xf64>) { - ^2(%4 : !stream.readable, %5 : !stream.writable): - memref_stream.generic {bounds = [16, 16], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { + ^2(%4 : !memref_stream.readable, %5 : !memref_stream.writable): + memref_stream.generic {bounds = [16, 16], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : !memref_stream.readable) outs(%5 : !memref_stream.writable) { ^3(%in_1 : f64, %out_1 : f64): %6 = arith.maximumf %in_1, %cst : f64 memref_stream.yield %6 : f64 @@ -71,7 +71,7 @@ // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}} : memref<16x16xf64>) outs(%{{.*}} : memref<16x16xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.writable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.writable): // CHECK-NEXT: %{{.*}} = arith.constant 16 : index // CHECK-NEXT: %{{.*}} = arith.constant 16 : index // CHECK-NEXT: %{{.*}} = arith.constant 0 : index @@ -95,7 +95,7 @@ func.func public @fill(%arg0 : memref<16x16xf64>) -> memref<16x16xf64> { #memref_stream.stride_pattern (d0, d1)> ] } outs(%arg0 : memref<16x16xf64>) { - ^3(%7 : !stream.writable): + ^3(%7 : !memref_stream.writable): memref_stream.generic { bounds = [16, 16], indexing_maps = [ @@ -103,7 +103,7 @@ func.func public @fill(%arg0 : memref<16x16xf64>) -> memref<16x16xf64> { affine_map<(d0, d1) -> (d0, d1)> ], iterator_types = ["parallel", "parallel"] - } ins(%zero : f64) outs(%7 : !stream.writable) { + } ins(%zero : f64) outs(%7 : !memref_stream.writable) { ^4(%in: f64, %out: f64): memref_stream.yield %in : f64 } @@ -118,7 +118,7 @@ func.func public @fill(%arg0 : memref<16x16xf64>) -> memref<16x16xf64> { // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } outs(%{{.*}} : memref<16x16xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.writable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.writable): // CHECK-NEXT: %{{.*}} = arith.constant 16 : index // CHECK-NEXT: %{{.*}} = arith.constant 16 : index // CHECK-NEXT: %{{.*}} = arith.constant 0 : index @@ -139,7 +139,7 @@ func.func @main(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64> #memref_stream.stride_pattern (d2, d1)> ] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) { - ^0(%0 : !stream.readable, %1 : !stream.readable): + ^0(%0 : !memref_stream.readable, %1 : !memref_stream.readable): memref_stream.generic { bounds = [4, 3, 2], indexing_maps = [ @@ -148,7 +148,7 @@ func.func @main(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64> affine_map<(d0, d1) -> (d0, d1)> ], iterator_types = ["parallel", "parallel", "reduction"] - } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) { + } ins(%0, %1 : !memref_stream.readable, !memref_stream.readable) outs(%C : memref<4x3xf64>) { ^1(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 %acc_new = arith.addf %acc_old, %prod : f64 @@ -164,7 +164,7 @@ func.func @main(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64> // CHECK-NEXT: #memref_stream.stride_pattern (d2, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.readable): // CHECK-NEXT: %{{.*}} = arith.constant 4 : index // CHECK-NEXT: %{{.*}} = arith.constant 3 : index // CHECK-NEXT: %{{.*}} = arith.constant 2 : index @@ -193,7 +193,7 @@ func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { #memref_stream.stride_pattern (d0 * 3 + d1)> ] } ins(%A : memref<6xf64>) { - ^0(%0 : !stream.readable): + ^0(%0 : !memref_stream.readable): memref_stream.generic { bounds = [2, 3], indexing_maps = [ @@ -201,7 +201,7 @@ func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { affine_map<(d0, d1) -> ()> ], iterator_types = ["parallel", "reduction"] - } ins(%0 : !stream.readable) outs(%B : memref) { + } ins(%0 : !memref_stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 memref_stream.yield %acc_new : f64 @@ -215,7 +215,7 @@ func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { // CHECK-NEXT: #memref_stream.stride_pattern (((d0 * 3) + d1))> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}} : memref<6xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable): // CHECK-NEXT: %{{.*}} = arith.constant 2 : index // CHECK-NEXT: %{{.*}} = arith.constant 3 : index // CHECK-NEXT: %{{.*}} = arith.constant 0 : index @@ -238,7 +238,7 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< #memref_stream.stride_pattern (d0, d1, d2)> ] } ins(%A : memref<2x3x4xf64>) { - ^0(%0 : !stream.readable): + ^0(%0 : !memref_stream.readable): memref_stream.generic { bounds = [2, 3, 4], indexing_maps = [ @@ -246,7 +246,7 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< affine_map<() -> ()> ], iterator_types = ["reduction", "reduction", "reduction"] - } ins(%0 : !stream.readable) outs(%B : memref) { + } ins(%0 : !memref_stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 memref_stream.yield %acc_new : f64 @@ -261,7 +261,7 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1, d2)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}} : memref<2x3x4xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable): // CHECK-NEXT: %{{.*}} = arith.constant 2 : index // CHECK-NEXT: %{{.*}} = arith.constant 3 : index // CHECK-NEXT: %{{.*}} = arith.constant 4 : index @@ -292,7 +292,7 @@ func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x #memref_stream.stride_pattern (d2, d1)> ] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) { - ^0(%0 : !stream.readable, %1 : !stream.readable): + ^0(%0 : !memref_stream.readable, %1 : !memref_stream.readable): memref_stream.generic { bounds = [4, 3, 2], indexing_maps = [ @@ -301,7 +301,7 @@ func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x affine_map<(d0, d1) -> (d0, d1)> ], iterator_types = ["parallel", "parallel", "reduction"] - } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) inits(%zero_float : f64) { + } ins(%0, %1 : !memref_stream.readable, !memref_stream.readable) outs(%C : memref<4x3xf64>) inits(%zero_float : f64) { ^1(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 %acc_new = arith.addf %acc_old, %prod : f64 @@ -318,7 +318,7 @@ func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x // CHECK-NEXT: #memref_stream.stride_pattern (d2, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.readable): // CHECK-NEXT: %{{.*}} = arith.constant 4 : index // CHECK-NEXT: %{{.*}} = arith.constant 3 : index // CHECK-NEXT: %{{.*}} = arith.constant 2 : index diff --git a/tests/filecheck/transforms/convert_memref_stream_to_snitch_stream.mlir b/tests/filecheck/transforms/convert_memref_stream_to_snitch_stream.mlir index 25d185c528..899312bd2f 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_snitch_stream.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_snitch_stream.mlir @@ -2,13 +2,13 @@ // CHECK: builtin.module { -// CHECK-NEXT: %f64_readable, %f64_writable = "test.op"() : () -> (!stream.readable, !stream.writable) -%f64_readable, %f64_writable = "test.op"() : () -> (!stream.readable, !stream.writable) +// CHECK-NEXT: %f64_readable, %f64_writable = "test.op"() : () -> (!memref_stream.readable, !memref_stream.writable) +%f64_readable, %f64_writable = "test.op"() : () -> (!memref_stream.readable, !memref_stream.writable) -// CHECK-NEXT: %val_f64 = builtin.unrealized_conversion_cast %f64_readable : !stream.readable to !stream.readable +// CHECK-NEXT: %val_f64 = builtin.unrealized_conversion_cast %f64_readable : !memref_stream.readable to !snitch.readable // CHECK-NEXT: %{{.*}} = riscv_snitch.read from %val_f64 : !riscv.freg // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.freg to f64 -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %f64_writable : !stream.writable to !stream.writable +// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %f64_writable : !memref_stream.writable to !snitch.writable %val_f64 = memref_stream.read from %f64_readable : f64 // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : f64 to !riscv.freg @@ -17,13 +17,13 @@ memref_stream.write %val_f64 to %f64_writable : f64 -// CHECK-NEXT: %vf64_readable, %vf64_writable = "test.op"() : () -> (!stream.readable>, !stream.writable>) -%vf64_readable, %vf64_writable = "test.op"() : () -> (!stream.readable>, !stream.writable>) +// CHECK-NEXT: %vf64_readable, %vf64_writable = "test.op"() : () -> (!memref_stream.readable>, !memref_stream.writable>) +%vf64_readable, %vf64_writable = "test.op"() : () -> (!memref_stream.readable>, !memref_stream.writable>) -// CHECK-NEXT: %val_vf64 = builtin.unrealized_conversion_cast %vf64_readable : !stream.readable> to !stream.readable +// CHECK-NEXT: %val_vf64 = builtin.unrealized_conversion_cast %vf64_readable : !memref_stream.readable> to !snitch.readable // CHECK-NEXT: %{{.*}} = riscv_snitch.read from %val_vf64 : !riscv.freg // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.freg to vector<1xf64> -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %vf64_writable : !stream.writable> to !stream.writable +// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %vf64_writable : !memref_stream.writable> to !snitch.writable %val_vf64 = memref_stream.read from %vf64_readable : vector<1xf64> // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : vector<1xf64> to !riscv.freg @@ -31,13 +31,13 @@ memref_stream.write %val_f64 to %f64_writable : f64 // CHECK-NEXT: riscv_snitch.write %{{.*}} to %{{.*}} : !riscv.freg memref_stream.write %val_vf64 to %vf64_writable : vector<1xf64> -// CHECK-NEXT: %vf32_readable, %vf32_writable = "test.op"() : () -> (!stream.readable>, !stream.writable>) -%vf32_readable, %vf32_writable = "test.op"() : () -> (!stream.readable>, !stream.writable>) +// CHECK-NEXT: %vf32_readable, %vf32_writable = "test.op"() : () -> (!memref_stream.readable>, !memref_stream.writable>) +%vf32_readable, %vf32_writable = "test.op"() : () -> (!memref_stream.readable>, !memref_stream.writable>) -// CHECK-NEXT: %val_vf32 = builtin.unrealized_conversion_cast %vf32_readable : !stream.readable> to !stream.readable +// CHECK-NEXT: %val_vf32 = builtin.unrealized_conversion_cast %vf32_readable : !memref_stream.readable> to !snitch.readable // CHECK-NEXT: %{{.*}} = riscv_snitch.read from %val_vf32 : !riscv.freg // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.freg to vector<2xf32> -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %vf32_writable : !stream.writable> to !stream.writable +// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %vf32_writable : !memref_stream.writable> to !snitch.writable %val_vf32 = memref_stream.read from %vf32_readable : vector<2xf32> // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf32> to !riscv.freg @@ -46,13 +46,13 @@ memref_stream.write %val_vf64 to %vf64_writable : vector<1xf64> memref_stream.write %val_vf32 to %vf32_writable : vector<2xf32> -// CHECK-NEXT: %vf16_readable, %vf16_writable = "test.op"() : () -> (!stream.readable>, !stream.writable>) -%vf16_readable, %vf16_writable = "test.op"() : () -> (!stream.readable>, !stream.writable>) +// CHECK-NEXT: %vf16_readable, %vf16_writable = "test.op"() : () -> (!memref_stream.readable>, !memref_stream.writable>) +%vf16_readable, %vf16_writable = "test.op"() : () -> (!memref_stream.readable>, !memref_stream.writable>) -// CHECK-NEXT: %val_vf16 = builtin.unrealized_conversion_cast %vf16_readable : !stream.readable> to !stream.readable +// CHECK-NEXT: %val_vf16 = builtin.unrealized_conversion_cast %vf16_readable : !memref_stream.readable> to !snitch.readable // CHECK-NEXT: %{{.*}} = riscv_snitch.read from %val_vf16 : !riscv.freg // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.freg to vector<4xf16> -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %vf16_writable : !stream.writable> to !stream.writable +// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %vf16_writable : !memref_stream.writable> to !snitch.writable %val_vf16 = memref_stream.read from %vf16_readable : vector<4xf16> // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : vector<4xf16> to !riscv.freg @@ -70,8 +70,8 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d0, d1)> ] } ins(%A, %B : memref<2xf64>, memref<3xf64>) outs(%C : memref<3x2xf64>) attrs = {hello = "world"} { -^bb0(%a: !stream.readable, %b: !stream.readable, %c: !stream.writable): - "test.op"(%a, %b, %c) : (!stream.readable, !stream.readable, !stream.writable) -> () +^bb0(%a: !memref_stream.readable, %b: !memref_stream.readable, %c: !memref_stream.writable): + "test.op"(%a, %b, %c) : (!memref_stream.readable, !memref_stream.readable, !memref_stream.writable) -> () } // CHECK-NEXT: %A, %B, %C = "test.op"() : () -> (memref<2xf64>, memref<3xf64>, memref<3x2xf64>) @@ -85,11 +85,11 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%A_1, %B_1 : !riscv.reg, !riscv.reg) outs(%C_1 : !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%a : !stream.readable, %b : !stream.readable, %c : !stream.writable): -// CHECK-NEXT: %a_1 = builtin.unrealized_conversion_cast %a : !stream.readable to !stream.readable -// CHECK-NEXT: %b_1 = builtin.unrealized_conversion_cast %b : !stream.readable to !stream.readable -// CHECK-NEXT: %c_1 = builtin.unrealized_conversion_cast %c : !stream.writable to !stream.writable -// CHECK-NEXT: "test.op"(%a_1, %b_1, %c_1) : (!stream.readable, !stream.readable, !stream.writable) -> () +// CHECK-NEXT: ^{{.*}}(%a : !snitch.readable, %b : !snitch.readable, %c : !snitch.writable): +// CHECK-NEXT: %a_1 = builtin.unrealized_conversion_cast %a : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %b_1 = builtin.unrealized_conversion_cast %b : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %c_1 = builtin.unrealized_conversion_cast %c : !snitch.writable to !memref_stream.writable +// CHECK-NEXT: "test.op"(%a_1, %b_1, %c_1) : (!memref_stream.readable, !memref_stream.readable, !memref_stream.writable) -> () // CHECK-NEXT: } memref_stream.streaming_region { @@ -98,8 +98,8 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d0, d1)> ] } ins(%C, %C : memref<3x2xf64>, memref<3x2xf64>) { -^bb0(%c0: !stream.readable, %c1: !stream.readable): - "test.op"(%c0, %c1) : (!stream.readable, !stream.readable) -> () +^bb0(%c0: !memref_stream.readable, %c1: !memref_stream.readable): + "test.op"(%c0, %c1) : (!memref_stream.readable, !memref_stream.readable) -> () } // CHECK-NEXT: %C_2 = builtin.unrealized_conversion_cast %C : memref<3x2xf64> to !riscv.reg @@ -109,10 +109,10 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%C_2, %C_3 : !riscv.reg, !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%c0 : !stream.readable, %c1 : !stream.readable): -// CHECK-NEXT: %c0_1 = builtin.unrealized_conversion_cast %c0 : !stream.readable to !stream.readable -// CHECK-NEXT: %c1_1 = builtin.unrealized_conversion_cast %c1 : !stream.readable to !stream.readable -// CHECK-NEXT: "test.op"(%c0_1, %c1_1) : (!stream.readable, !stream.readable) -> () +// CHECK-NEXT: ^{{.*}}(%c0 : !snitch.readable, %c1 : !snitch.readable): +// CHECK-NEXT: %c0_1 = builtin.unrealized_conversion_cast %c0 : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %c1_1 = builtin.unrealized_conversion_cast %c1 : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: "test.op"(%c0_1, %c1_1) : (!memref_stream.readable, !memref_stream.readable) -> () // CHECK-NEXT: } %D, %E = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) @@ -124,8 +124,8 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d1, d4, d5, d6)> ] } ins(%D, %E : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) { -^0(%d_stream : !stream.readable, %e_stream : !stream.readable): - "test.op"(%d_stream, %e_stream) : (!stream.readable, !stream.readable) -> () +^0(%d_stream : !memref_stream.readable, %e_stream : !memref_stream.readable): + "test.op"(%d_stream, %e_stream) : (!memref_stream.readable, !memref_stream.readable) -> () } // CHECK-NEXT: %D_1 = builtin.unrealized_conversion_cast %D : memref<1x1x8x8xf64> to !riscv.reg @@ -136,10 +136,10 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%D_1, %E_1 : !riscv.reg, !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%d_stream : !stream.readable, %e_stream : !stream.readable): -// CHECK-NEXT: %d_stream_1 = builtin.unrealized_conversion_cast %d_stream : !stream.readable to !stream.readable -// CHECK-NEXT: %e_stream_1 = builtin.unrealized_conversion_cast %e_stream : !stream.readable to !stream.readable -// CHECK-NEXT: "test.op"(%d_stream_1, %e_stream_1) : (!stream.readable, !stream.readable) -> () +// CHECK-NEXT: ^{{.*}}(%d_stream : !snitch.readable, %e_stream : !snitch.readable): +// CHECK-NEXT: %d_stream_1 = builtin.unrealized_conversion_cast %d_stream : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %e_stream_1 = builtin.unrealized_conversion_cast %e_stream : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: "test.op"(%d_stream_1, %e_stream_1) : (!memref_stream.readable, !memref_stream.readable) -> () // CHECK-NEXT: } %F = "test.op"() : () -> memref<8x8xf64> @@ -152,8 +152,8 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (m, n)> ] } ins(%F, %F, %F : memref<8x8xf64>, memref<8x8xf64>, memref<8x8xf64>) { -^0(%x_stream : !stream.readable, %w_stream : !stream.readable, %b_stream : !stream.readable): - "test.op"(%x_stream, %w_stream, %b_stream) : (!stream.readable, !stream.readable, !stream.readable) -> () +^0(%x_stream : !memref_stream.readable, %w_stream : !memref_stream.readable, %b_stream : !memref_stream.readable): + "test.op"(%x_stream, %w_stream, %b_stream) : (!memref_stream.readable, !memref_stream.readable, !memref_stream.readable) -> () } // CHECK-NEXT: %F_1 = builtin.unrealized_conversion_cast %F : memref<8x8xf64> to !riscv.reg @@ -166,11 +166,11 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%F_1, %F_2, %F_3 : !riscv.reg, !riscv.reg, !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%x_stream : !stream.readable, %w_stream : !stream.readable, %b_stream : !stream.readable): -// CHECK-NEXT: %x_stream_1 = builtin.unrealized_conversion_cast %x_stream : !stream.readable to !stream.readable -// CHECK-NEXT: %w_stream_1 = builtin.unrealized_conversion_cast %w_stream : !stream.readable to !stream.readable -// CHECK-NEXT: %b_stream_1 = builtin.unrealized_conversion_cast %b_stream : !stream.readable to !stream.readable -// CHECK-NEXT: "test.op"(%x_stream_1, %w_stream_1, %b_stream_1) : (!stream.readable, !stream.readable, !stream.readable) -> () +// CHECK-NEXT: ^{{.*}}(%x_stream : !snitch.readable, %w_stream : !snitch.readable, %b_stream : !snitch.readable): +// CHECK-NEXT: %x_stream_1 = builtin.unrealized_conversion_cast %x_stream : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %w_stream_1 = builtin.unrealized_conversion_cast %w_stream : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %b_stream_1 = builtin.unrealized_conversion_cast %b_stream : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: "test.op"(%x_stream_1, %w_stream_1, %b_stream_1) : (!memref_stream.readable, !memref_stream.readable, !memref_stream.readable) -> () // CHECK-NEXT: } %G, %H = "test.op"() : () -> (f64, memref<16x16xf64>) @@ -181,7 +181,7 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d0, d1)> ] } outs(%H : memref<16x16xf64>) { -^0(%h_stream : !stream.writable): +^0(%h_stream : !memref_stream.writable): %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 %c256 = arith.constant 256 : i32 @@ -196,13 +196,13 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } outs(%H_1 : !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%h_stream : !stream.writable): -// CHECK-NEXT: %h_stream_1 = builtin.unrealized_conversion_cast %h_stream : !stream.writable to !stream.writable +// CHECK-NEXT: ^{{.*}}(%h_stream : !snitch.writable): +// CHECK-NEXT: %h_stream_1 = builtin.unrealized_conversion_cast %h_stream : !snitch.writable to !memref_stream.writable // CHECK-NEXT: %c0_2 = arith.constant 0 : i32 // CHECK-NEXT: %c1_2 = arith.constant 1 : i32 // CHECK-NEXT: %c256 = arith.constant 256 : i32 // CHECK-NEXT: scf.for %i = %c0_2 to %c256 step %c1_2 : i32 { -// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %h_stream_1 : !stream.writable to !stream.writable +// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %h_stream_1 : !memref_stream.writable to !snitch.writable // CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %G : f64 to !riscv.freg // CHECK-NEXT: %{{.*}} = riscv.fmv.d %{{.*}} : (!riscv.freg) -> !riscv.freg // CHECK-NEXT: riscv_snitch.write %{{.*}} to %{{.*}} : !riscv.freg @@ -223,7 +223,7 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d0, ((d1 * 4) + d2))> ] } ins(%I, %J : memref<3x5xf64>, memref<5x8xf64>) outs(%K : memref<3x8xf64>) { -^0(%i : !stream.readable, %j : !stream.readable, %k : !stream.writable): +^0(%i : !memref_stream.readable, %j : !memref_stream.readable, %k : !memref_stream.writable): %res = "test.op"() : () -> f64 memref_stream.yield %res : f64 } @@ -234,10 +234,10 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%I_1, %J_1 : !riscv.reg, !riscv.reg) outs(%K_1 : !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%i_1 : !stream.readable, %j : !stream.readable, %k : !stream.writable): -// CHECK-NEXT: %i_2 = builtin.unrealized_conversion_cast %i_1 : !stream.readable to !stream.readable -// CHECK-NEXT: %j_1 = builtin.unrealized_conversion_cast %j : !stream.readable to !stream.readable -// CHECK-NEXT: %k_1 = builtin.unrealized_conversion_cast %k : !stream.writable to !stream.writable +// CHECK-NEXT: ^{{.*}}(%i_1 : !snitch.readable, %j : !snitch.readable, %k : !snitch.writable): +// CHECK-NEXT: %i_2 = builtin.unrealized_conversion_cast %i_1 : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %j_1 = builtin.unrealized_conversion_cast %j : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: %k_1 = builtin.unrealized_conversion_cast %k : !snitch.writable to !memref_stream.writable // CHECK-NEXT: %res = "test.op"() : () -> f64 // CHECK-NEXT: memref_stream.yield %res : f64 // CHECK-NEXT: } @@ -252,8 +252,8 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d0, d1)> ] } ins(%A_strided : memref<3x2xf64, strided<[4, 1]>>) { -^bb0(%a_strided: !stream.readable): - "test.op"(%a_strided) : (!stream.readable) -> () +^bb0(%a_strided: !memref_stream.readable): + "test.op"(%a_strided) : (!memref_stream.readable) -> () } // CHECK-NEXT: %A_strided_1 = builtin.unrealized_conversion_cast %A_strided : memref<3x2xf64, strided<[4, 1]>> to !riscv.reg @@ -262,9 +262,9 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%A_strided_1 : !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%a_strided : !stream.readable): -// CHECK-NEXT: %a_strided_1 = builtin.unrealized_conversion_cast %a_strided : !stream.readable to !stream.readable -// CHECK-NEXT: "test.op"(%a_strided_1) : (!stream.readable) -> () +// CHECK-NEXT: ^{{.*}}(%a_strided : !snitch.readable): +// CHECK-NEXT: %a_strided_1 = builtin.unrealized_conversion_cast %a_strided : !snitch.readable to !memref_stream.readable +// CHECK-NEXT: "test.op"(%a_strided_1) : (!memref_stream.readable) -> () // CHECK-NEXT: } %X_f32, %Y_f32, %Z_f32 = "test.op"() : () -> (memref<8x16xf32>, memref<8x16xf32>, memref<8x16xf32>) @@ -277,7 +277,7 @@ memref_stream.streaming_region { #memref_stream.stride_pattern (d0, 2 * d1)> ] } ins(%X_f32, %Y_f32 : memref<8x16xf32>, memref<8x16xf32>) outs(%Z_f32 : memref<8x16xf32>) { -^0(%x_stream : !stream.readable>, %y_stream : !stream.readable>, %z_stream : !stream.writable>): +^0(%x_stream : !memref_stream.readable>, %y_stream : !memref_stream.readable>, %z_stream : !memref_stream.writable>): memref_stream.generic { bounds = [8, 8], indexing_maps = [ @@ -286,7 +286,7 @@ memref_stream.streaming_region { affine_map<(d0, d1) -> (d0, 2 * d1)> ], iterator_types = ["parallel", "parallel"] - } ins(%x_stream, %y_stream : !stream.readable>, !stream.readable>) outs(%z_stream : !stream.writable>) { + } ins(%x_stream, %y_stream : !memref_stream.readable>, !memref_stream.readable>) outs(%z_stream : !memref_stream.writable>) { ^1(%in : vector<2xf32>, %in_1 : vector<2xf32>, %out : vector<2xf32>): %3 = arith.addf %in, %in_1 : vector<2xf32> memref_stream.yield %3 : vector<2xf32> @@ -300,10 +300,10 @@ memref_stream.streaming_region { // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%X_f32_1, %Y_f32_1 : !riscv.reg, !riscv.reg) outs(%Z_f32_1 : !riscv.reg) { -// CHECK-NEXT: ^7(%x_stream_2 : !stream.readable, %y_stream : !stream.readable, %z_stream : !stream.writable): -// CHECK-NEXT: %x_stream_3 = builtin.unrealized_conversion_cast %x_stream_2 : !stream.readable to !stream.readable> -// CHECK-NEXT: %y_stream_1 = builtin.unrealized_conversion_cast %y_stream : !stream.readable to !stream.readable> -// CHECK-NEXT: %z_stream_1 = builtin.unrealized_conversion_cast %z_stream : !stream.writable to !stream.writable> +// CHECK-NEXT: ^7(%x_stream_2 : !snitch.readable, %y_stream : !snitch.readable, %z_stream : !snitch.writable): +// CHECK-NEXT: %x_stream_3 = builtin.unrealized_conversion_cast %x_stream_2 : !snitch.readable to !memref_stream.readable> +// CHECK-NEXT: %y_stream_1 = builtin.unrealized_conversion_cast %y_stream : !snitch.readable to !memref_stream.readable> +// CHECK-NEXT: %z_stream_1 = builtin.unrealized_conversion_cast %z_stream : !snitch.writable to !memref_stream.writable> // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [8, 8], // CHECK-NEXT: indexing_maps = [ @@ -312,7 +312,7 @@ memref_stream.streaming_region { // CHECK-NEXT: affine_map<(d0, d1) -> (d0, (d1 * 2))> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel"] -// CHECK-NEXT: } ins(%x_stream_3, %y_stream_1 : !stream.readable>, !stream.readable>) outs(%z_stream_1 : !stream.writable>) { +// CHECK-NEXT: } ins(%x_stream_3, %y_stream_1 : !memref_stream.readable>, !memref_stream.readable>) outs(%z_stream_1 : !memref_stream.writable>) { // CHECK-NEXT: ^8(%in : vector<2xf32>, %in_1 : vector<2xf32>, %out : vector<2xf32>): // CHECK-NEXT: %15 = arith.addf %in, %in_1 : vector<2xf32> // CHECK-NEXT: memref_stream.yield %15 : vector<2xf32> diff --git a/tests/filecheck/transforms/memref_streamify.mlir b/tests/filecheck/transforms/memref_streamify.mlir index a9952e1a10..817701331c 100644 --- a/tests/filecheck/transforms/memref_streamify.mlir +++ b/tests/filecheck/transforms/memref_streamify.mlir @@ -27,7 +27,7 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : memref<8x16xf64>, memref<8x16xf64>) outs(%{{.*}} : memref<8x16xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.writable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.writable): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [8, 16], // CHECK-NEXT: indexing_maps = [ @@ -36,7 +36,7 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 // CHECK-NEXT: affine_map<(d0, d1) -> (d0, d1)> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel"] -// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !stream.readable, !stream.readable) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !memref_stream.readable, !memref_stream.readable) outs(%{{.*}} : !memref_stream.writable) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: memref_stream.yield %{{.*}} : f64 @@ -67,7 +67,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%arg0 : memref<16x16xf64>) outs(%arg1 : memref<16x16xf64>) { -// CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.writable): +// CHECK-NEXT: ^0(%0 : !memref_stream.readable, %1 : !memref_stream.writable): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [16, 16], // CHECK-NEXT: indexing_maps = [ @@ -75,7 +75,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: affine_map<(d0, d1) -> (d0, d1)> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel"] -// CHECK-NEXT: } ins(%0 : !stream.readable) outs(%1 : !stream.writable) { +// CHECK-NEXT: } ins(%0 : !memref_stream.readable) outs(%1 : !memref_stream.writable) { // CHECK-NEXT: ^1(%in : f64, %out : f64): // CHECK-NEXT: %2 = arith.maximumf %in, %cst : f64 // CHECK-NEXT: memref_stream.yield %2 : f64 @@ -109,7 +109,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1)> // CHECK-NEXT: ] // CHECK-NEXT: } outs(%{{.*}} : memref<16x16xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.writable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.writable): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [16, 16], // CHECK-NEXT: indexing_maps = [ @@ -117,7 +117,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: affine_map<(d0, d1) -> (d0, d1)> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel"] -// CHECK-NEXT: } ins(%{{.*}} : f64) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: } ins(%{{.*}} : f64) outs(%{{.*}} : !memref_stream.writable) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: memref_stream.yield %{{.*}} : f64 // CHECK-NEXT: } @@ -158,7 +158,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: #memref_stream.stride_pattern (d0, d1, d2, d3)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { -// CHECK-NEXT: ^0(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.writable): +// CHECK-NEXT: ^0(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.writable): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [1, 1, 6, 6, 1, 3, 3], // CHECK-NEXT: indexing_maps = [ @@ -167,7 +167,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] -// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !stream.readable, !stream.readable) outs(%{{.*}} : !stream.writable) inits(%zero_float : f64) { +// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !memref_stream.readable, !memref_stream.readable) outs(%{{.*}} : !memref_stream.writable) inits(%zero_float : f64) { // CHECK-NEXT: ^{{\d+}}(%x : f64, %y : f64, %acc : f64): // CHECK-NEXT: %prod = arith.mulf %x, %y fastmath : f64 // CHECK-NEXT: %res = arith.addf %prod, %acc fastmath : f64 @@ -205,7 +205,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: #memref_stream.stride_pattern (d0)> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}} : memref<2xf64>) outs(%{{.*}} : memref<2xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.writable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.writable): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [2], // CHECK-NEXT: indexing_maps = [ @@ -214,7 +214,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: affine_map<(d0) -> (d0)> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel"] -// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !stream.readable, memref<2xf64>) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !memref_stream.readable, memref<2xf64>) outs(%{{.*}} : !memref_stream.writable) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: memref_stream.yield %{{.*}} : f64 // CHECK-NEXT: } @@ -260,7 +260,7 @@ func.func @interleaved_no_init(%A : memref<3x5xf64>, %B : memref<5x8xf64>, %C : // CHECK-NEXT: #memref_stream.stride_pattern (d2, ((d1 * 4) + d3))> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : memref<3x5xf64>, memref<5x8xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.readable): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [3, 2, 5, 4], // CHECK-NEXT: indexing_maps = [ @@ -269,7 +269,7 @@ func.func @interleaved_no_init(%A : memref<3x5xf64>, %B : memref<5x8xf64>, %C : // CHECK-NEXT: affine_map<(d0, d1, d2) -> (d0, ((d1 * 4) + d2))> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel", "reduction", "interleaved"] -// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !stream.readable, !stream.readable) outs(%{{.*}} : memref<3x8xf64>) { +// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !memref_stream.readable, !memref_stream.readable) outs(%{{.*}} : memref<3x8xf64>) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} fastmath : f64 // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} fastmath : f64 @@ -325,7 +325,7 @@ func.func @interleaved_init(%A : memref<3x5xf64>, %B : memref<5x8xf64>, %C : mem // CHECK-NEXT: #memref_stream.stride_pattern (d0, ((d1 * 4) + d2))> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : memref<3x5xf64>, memref<5x8xf64>) outs(%{{.*}} : memref<3x8xf64>) { -// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.writable): +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.readable, %{{.*}} : !memref_stream.writable): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [3, 2, 5, 4], // CHECK-NEXT: indexing_maps = [ @@ -334,7 +334,7 @@ func.func @interleaved_init(%A : memref<3x5xf64>, %B : memref<5x8xf64>, %C : mem // CHECK-NEXT: affine_map<(d0, d1, d2) -> (d0, ((d1 * 4) + d2))> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel", "reduction", "interleaved"] -// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !stream.readable, !stream.readable) outs(%{{.*}} : !stream.writable) inits(%{{.*}} : f64) { +// CHECK-NEXT: } ins(%{{.*}}, %{{.*}} : !memref_stream.readable, !memref_stream.readable) outs(%{{.*}} : !memref_stream.writable) inits(%{{.*}} : f64) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} fastmath : f64 // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} fastmath : f64 @@ -379,7 +379,7 @@ func.func public @ssum( // CHECK-NEXT: #memref_stream.stride_pattern (d0, (d1 * 2))> // CHECK-NEXT: ] // CHECK-NEXT: } ins(%X, %Y : memref<8x16xf32>, memref<8x16xf32>) outs(%Z : memref<8x16xf32>) { -// CHECK-NEXT: ^0(%0 : !stream.readable>, %1 : !stream.readable>, %2 : !stream.writable>): +// CHECK-NEXT: ^0(%0 : !memref_stream.readable>, %1 : !memref_stream.readable>, %2 : !memref_stream.writable>): // CHECK-NEXT: memref_stream.generic { // CHECK-NEXT: bounds = [8, 8], // CHECK-NEXT: indexing_maps = [ @@ -388,7 +388,7 @@ func.func public @ssum( // CHECK-NEXT: affine_map<(d0, d1) -> (d0, (d1 * 2))> // CHECK-NEXT: ], // CHECK-NEXT: iterator_types = ["parallel", "parallel"] -// CHECK-NEXT: } ins(%0, %1 : !stream.readable>, !stream.readable>) outs(%2 : !stream.writable>) { +// CHECK-NEXT: } ins(%0, %1 : !memref_stream.readable>, !memref_stream.readable>) outs(%2 : !memref_stream.writable>) { // CHECK-NEXT: ^1(%in : vector<2xf32>, %in_1 : vector<2xf32>, %out : vector<2xf32>): // CHECK-NEXT: %3 = arith.addf %in, %in_1 : vector<2xf32> // CHECK-NEXT: memref_stream.yield %3 : vector<2xf32> diff --git a/tests/filecheck/transforms/riscv_cse.mlir b/tests/filecheck/transforms/riscv_cse.mlir index 1810081a34..2ea52c0f6c 100644 --- a/tests/filecheck/transforms/riscv_cse.mlir +++ b/tests/filecheck/transforms/riscv_cse.mlir @@ -30,7 +30,7 @@ riscv.assembly_section ".text" { // ----- -%0, %1, %2 = "test.op"() : () -> (!stream.readable, !stream.readable, !riscv.reg) +%0, %1, %2 = "test.op"() : () -> (!snitch.readable, !snitch.readable, !riscv.reg) %8 = riscv.li 8 : !riscv.reg %9 = riscv.li 8 : !riscv.reg @@ -65,7 +65,7 @@ riscv_scf.for %13 : !riscv.reg = %11 to %8 step %12 { } // CHECK: builtin.module { -// CHECK-NEXT: %{{.*}}, %{{.*}}, %{{.*}} = "test.op"() : () -> (!stream.readable, !stream.readable, !riscv.reg) +// CHECK-NEXT: %{{.*}}, %{{.*}}, %{{.*}} = "test.op"() : () -> (!snitch.readable, !snitch.readable, !riscv.reg) // CHECK-NEXT: %{{.*}} = riscv.li 8 : !riscv.reg // CHECK-NEXT: %{{.*}} = riscv.li 0 : !riscv.reg // CHECK-NEXT: %{{.*}} = riscv.li 1 : !riscv.reg diff --git a/tests/filecheck/transforms/snitch_register_allocation.mlir b/tests/filecheck/transforms/snitch_register_allocation.mlir index beeed82309..7a7084b9db 100644 --- a/tests/filecheck/transforms/snitch_register_allocation.mlir +++ b/tests/filecheck/transforms/snitch_register_allocation.mlir @@ -6,7 +6,7 @@ "stride_patterns" = [#snitch_stream.stride_pattern], "operandSegmentSizes" = array }> ({ -^0(%s0 : !stream.readable, %s1 : !stream.readable, %s2 : !stream.writable): +^0(%s0 : !snitch.readable, %s1 : !snitch.readable, %s2 : !snitch.writable): %c5 = riscv.li 5 : !riscv.reg riscv_snitch.frep_outer %c5 { %x = riscv_snitch.read from %s0 : !riscv.freg @@ -24,7 +24,7 @@ // CHECK-NEXT: #snitch_stream.stride_pattern // CHECK-NEXT: ] // CHECK-NEXT: } ins(%ptr0, %ptr1 : !riscv.reg, !riscv.reg) outs(%ptr2 : !riscv.reg) { -// CHECK-NEXT: ^{{.*}}(%s0 : !stream.readable>, %s1 : !stream.readable>, %s2 : !stream.writable>): +// CHECK-NEXT: ^{{.*}}(%s0 : !snitch.readable>, %s1 : !snitch.readable>, %s2 : !snitch.writable>): // CHECK-NEXT: %c5 = riscv.li 5 : !riscv.reg // CHECK-NEXT: riscv_snitch.frep_outer %c5 { // CHECK-NEXT: %x = riscv_snitch.read from %s0 : !riscv.freg diff --git a/tests/interpreters/test_riscv_snitch_interpreter.py b/tests/interpreters/test_riscv_snitch_interpreter.py index 9f6166500d..cbb89368eb 100644 --- a/tests/interpreters/test_riscv_snitch_interpreter.py +++ b/tests/interpreters/test_riscv_snitch_interpreter.py @@ -1,5 +1,5 @@ from xdsl.builder import Builder, ImplicitBuilder -from xdsl.dialects import func, riscv, riscv_snitch, stream +from xdsl.dialects import func, riscv, riscv_snitch, snitch from xdsl.dialects.builtin import ModuleOp from xdsl.interpreter import Interpreter from xdsl.interpreters.func import FuncFunctions @@ -21,18 +21,18 @@ def test_read_write(): output_stream = Acc() assert interpreter.run_op( - riscv_snitch.ReadOp(TestSSAValue(stream.ReadableStreamType(a0)), a0), + riscv_snitch.ReadOp(TestSSAValue(snitch.ReadableStreamType(a0)), a0), (input_stream,), ) == (1,) assert interpreter.run_op( - riscv_snitch.ReadOp(TestSSAValue(stream.ReadableStreamType(a1)), a1), + riscv_snitch.ReadOp(TestSSAValue(snitch.ReadableStreamType(a1)), a1), (input_stream,), ) == (2,) assert ( interpreter.run_op( riscv_snitch.WriteOp( - TestSSAValue(a0), TestSSAValue(stream.ReadableStreamType(a0)) + TestSSAValue(a0), TestSSAValue(snitch.ReadableStreamType(a0)) ), ( 1, @@ -45,7 +45,7 @@ def test_read_write(): assert ( interpreter.run_op( riscv_snitch.WriteOp( - TestSSAValue(a1), TestSSAValue(stream.ReadableStreamType(a1)) + TestSSAValue(a1), TestSSAValue(snitch.ReadableStreamType(a1)) ), ( 2, diff --git a/tests/interpreters/test_snitch_stream_interpreter.py b/tests/interpreters/test_snitch_stream_interpreter.py index d1fd1dcd7a..1d689643d4 100644 --- a/tests/interpreters/test_snitch_stream_interpreter.py +++ b/tests/interpreters/test_snitch_stream_interpreter.py @@ -1,7 +1,7 @@ import pytest from xdsl.builder import ImplicitBuilder -from xdsl.dialects import riscv, riscv_snitch, snitch_stream, stream +from xdsl.dialects import riscv, riscv_snitch, snitch, snitch_stream from xdsl.dialects.builtin import ArrayAttr, ModuleOp from xdsl.interpreter import Interpreter from xdsl.interpreters.riscv import RiscvFunctions @@ -71,9 +71,9 @@ def test_snitch_stream_interpreter(): streaming_region_body = Region( Block( arg_types=( - stream.ReadableStreamType(riscv.Registers.FT0), - stream.ReadableStreamType(riscv.Registers.FT1), - stream.WritableStreamType(riscv.Registers.FT2), + snitch.ReadableStreamType(riscv.Registers.FT0), + snitch.ReadableStreamType(riscv.Registers.FT1), + snitch.WritableStreamType(riscv.Registers.FT2), ) ) ) diff --git a/xdsl/dialects/__init__.py b/xdsl/dialects/__init__.py index f9a342f151..3fe02e45b0 100644 --- a/xdsl/dialects/__init__.py +++ b/xdsl/dialects/__init__.py @@ -283,11 +283,6 @@ def get_stim(): return Stim - def get_stream(): - from xdsl.dialects.stream import Stream - - return Stream - def get_symref(): from xdsl.frontend.symref import Symref @@ -389,7 +384,6 @@ def get_transform(): "stablehlo": get_stablehlo, "stencil": get_stencil, "stim": get_stim, - "stream": get_stream, "symref": get_symref, "tensor": get_tensor, "test": get_test, diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 81af928429..2872d469b0 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -10,15 +10,16 @@ from collections.abc import Iterator, Sequence from enum import auto from itertools import product -from typing import Any, ClassVar, cast +from typing import Any, ClassVar, Generic, TypeVar, cast from typing_extensions import Self -from xdsl.dialects import memref, stream +from xdsl.dialects import memref from xdsl.dialects.builtin import ( AffineMapAttr, AnyMemRefTypeConstr, ArrayAttr, + ContainerType, IndexType, IntAttr, IntegerAttr, @@ -33,11 +34,15 @@ ParametrizedAttribute, Region, SSAValue, + TypeAttribute, ) from xdsl.irdl import ( AnyAttr, AttrSizedOperandSegments, + BaseAttr, + GenericAttrConstraint, IRDLOperation, + ParamAttrConstraint, ParameterDef, VarConstraint, irdl_attr_definition, @@ -61,6 +66,71 @@ from xdsl.utils.hints import isa from xdsl.utils.str_enum import StrEnum +_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) +_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) + + +@irdl_attr_definition +class ReadableStreamType( + Generic[_StreamTypeElement], + ParametrizedAttribute, + TypeAttribute, + ContainerType[_StreamTypeElement], +): + name = "memref_stream.readable" + + element_type: ParameterDef[_StreamTypeElement] + + def get_element_type(self) -> _StreamTypeElement: + return self.element_type + + def __init__(self, element_type: _StreamTypeElement): + super().__init__([element_type]) + + @staticmethod + def constr( + element_type: GenericAttrConstraint[_StreamTypeElementConstrT], + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: + return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( + ReadableStreamType, (element_type,) + ) + + +AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( + ReadableStreamType +) + + +@irdl_attr_definition +class WritableStreamType( + Generic[_StreamTypeElement], + ParametrizedAttribute, + TypeAttribute, + ContainerType[_StreamTypeElement], +): + name = "memref_stream.writable" + + element_type: ParameterDef[_StreamTypeElement] + + def get_element_type(self) -> _StreamTypeElement: + return self.element_type + + def __init__(self, element_type: _StreamTypeElement): + super().__init__([element_type]) + + @staticmethod + def constr( + element_type: GenericAttrConstraint[_StreamTypeElementConstrT], + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: + return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( + WritableStreamType, (element_type,) + ) + + +AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( + WritableStreamType +) + class IteratorType(StrEnum): "Iterator type for memref_stream Attribute" @@ -193,15 +263,15 @@ class ReadOp(IRDLOperation): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(stream.ReadableStreamType.constr(T)) + stream = operand_def(ReadableStreamType.constr(T)) res = result_def(T) assembly_format = "`from` $stream attr-dict `:` type($res)" def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): if result_type is None: - assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) - stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + assert isinstance(stream_type := stream_val.type, ReadableStreamType) + stream_type = cast(ReadableStreamType[Attribute], stream_type) result_type = stream_type.element_type super().__init__(operands=[stream_val], result_types=[result_type]) @@ -216,7 +286,7 @@ class WriteOp(IRDLOperation): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(stream.WritableStreamType.constr(T)) + stream = operand_def(WritableStreamType.constr(T)) assembly_format = "$value `to` $stream attr-dict `:` type($value)" @@ -401,7 +471,7 @@ class GenericOp(IRDLOperation): Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be read. """ - outputs = var_operand_def(AnyMemRefTypeConstr | stream.AnyWritableStreamTypeConstr) + outputs = var_operand_def(AnyMemRefTypeConstr | AnyWritableStreamTypeConstr) """ Pointers to memory buffers or streams to be operated on. The corresponding stride pattern defines the order in which the elements of the input buffers will be written @@ -902,6 +972,8 @@ def __init__(self, memref: SSAValue, value: SSAValue): FillOp, ], [ + ReadableStreamType, + WritableStreamType, IteratorTypeAttr, StridePattern, ], diff --git a/xdsl/dialects/riscv_snitch.py b/xdsl/dialects/riscv_snitch.py index 1bcbaaa670..caa8c27ea4 100644 --- a/xdsl/dialects/riscv_snitch.py +++ b/xdsl/dialects/riscv_snitch.py @@ -8,7 +8,7 @@ from xdsl.backend.register_allocatable import RegisterConstraints from xdsl.backend.riscv.traits import StaticInsnRepresentation -from xdsl.dialects import riscv, stream +from xdsl.dialects import riscv, snitch from xdsl.dialects.builtin import ( IntAttr, IntegerAttr, @@ -41,6 +41,7 @@ from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( AnyAttr, + BaseAttr, VarConstraint, attr_def, base, @@ -161,15 +162,15 @@ class ReadOp(RISCVAsmOperation): T: ClassVar = VarConstraint("T", AnyAttr()) - stream = operand_def(stream.ReadableStreamType.constr(T)) + stream = operand_def(snitch.ReadableStreamType.constr(T)) res = result_def(T) assembly_format = "`from` $stream attr-dict `:` type($res)" def __init__(self, stream_val: SSAValue, result_type: Attribute | None = None): if result_type is None: - assert isinstance(stream_type := stream_val.type, stream.ReadableStreamType) - stream_type = cast(stream.ReadableStreamType[Attribute], stream_type) + assert isinstance(stream_type := stream_val.type, snitch.ReadableStreamType) + stream_type = cast(snitch.ReadableStreamType[Attribute], stream_type) result_type = stream_type.element_type super().__init__(operands=[stream_val], result_types=[result_type]) @@ -184,7 +185,7 @@ class WriteOp(RISCVAsmOperation): T: ClassVar = VarConstraint("T", AnyAttr()) value = operand_def(T) - stream = operand_def(stream.WritableStreamType.constr(T)) + stream = operand_def(snitch.WritableStreamType.constr(T)) assembly_format = "$value `to` $stream attr-dict `:` type($value)" @@ -482,7 +483,10 @@ def assembly_instruction_name(self) -> str: class GetStreamOp(RISCVAsmOperation): name = "riscv_snitch.get_stream" - stream = result_def(stream.StreamType[riscv.FloatRegisterType]) + stream = result_def( + snitch.ReadableStreamType.constr(BaseAttr(riscv.FloatRegisterType)) + | snitch.WritableStreamType.constr(BaseAttr(riscv.FloatRegisterType)) + ) def __init__(self, result_type: Attribute): super().__init__(result_types=[result_type]) diff --git a/xdsl/dialects/snitch.py b/xdsl/dialects/snitch.py index 984b4f7e9b..bb596a4c01 100644 --- a/xdsl/dialects/snitch.py +++ b/xdsl/dialects/snitch.py @@ -8,23 +8,102 @@ [1] https://pulp-platform.github.io/snitch/publications """ +from __future__ import annotations + from abc import ABC from collections.abc import Sequence from dataclasses import dataclass +from typing import Generic, TypeVar -from xdsl.dialects import stream -from xdsl.dialects.builtin import IntAttr +from xdsl.dialects.builtin import ContainerType, IntAttr from xdsl.dialects.riscv import IntRegisterType -from xdsl.ir import Attribute, Dialect, Operation, SSAValue +from xdsl.ir import ( + Attribute, + Dialect, + Operation, + ParametrizedAttribute, + SSAValue, + TypeAttribute, +) from xdsl.irdl import ( + BaseAttr, + GenericAttrConstraint, IRDLOperation, + ParamAttrConstraint, + ParameterDef, attr_def, + irdl_attr_definition, irdl_op_definition, operand_def, var_result_def, ) from xdsl.utils.exceptions import VerifyException +_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) +_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) + + +@irdl_attr_definition +class ReadableStreamType( + Generic[_StreamTypeElement], + ParametrizedAttribute, + TypeAttribute, + ContainerType[_StreamTypeElement], +): + name = "snitch.readable" + + element_type: ParameterDef[_StreamTypeElement] + + def get_element_type(self) -> _StreamTypeElement: + return self.element_type + + def __init__(self, element_type: _StreamTypeElement): + super().__init__([element_type]) + + @staticmethod + def constr( + element_type: GenericAttrConstraint[_StreamTypeElementConstrT], + ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: + return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( + ReadableStreamType, (element_type,) + ) + + +AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( + ReadableStreamType +) + + +@irdl_attr_definition +class WritableStreamType( + Generic[_StreamTypeElement], + ParametrizedAttribute, + TypeAttribute, + ContainerType[_StreamTypeElement], +): + name = "snitch.writable" + + element_type: ParameterDef[_StreamTypeElement] + + def get_element_type(self) -> _StreamTypeElement: + return self.element_type + + def __init__(self, element_type: _StreamTypeElement): + super().__init__([element_type]) + + @staticmethod + def constr( + element_type: GenericAttrConstraint[_StreamTypeElementConstrT], + ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: + return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( + WritableStreamType, (element_type,) + ) + + +AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( + WritableStreamType +) + @dataclass(frozen=True) class SnitchResources: @@ -143,7 +222,7 @@ class SsrEnable(IRDLOperation): name = "snitch.ssr_enable" - streams = var_result_def(stream.StreamType) + streams = var_result_def(AnyReadableStreamTypeConstr | AnyWritableStreamTypeConstr) def __init__(self, stream_types: Sequence[Attribute]): super().__init__(result_types=[stream_types]) @@ -172,5 +251,8 @@ def __init__(self): SsrEnable, SsrDisable, ], - [], + [ + ReadableStreamType, + WritableStreamType, + ], ) diff --git a/xdsl/dialects/stream.py b/xdsl/dialects/stream.py deleted file mode 100644 index 7ed748f393..0000000000 --- a/xdsl/dialects/stream.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -from typing import Generic, TypeVar - -from xdsl.dialects.builtin import ContainerType -from xdsl.ir import ( - Attribute, - Dialect, - ParametrizedAttribute, - TypeAttribute, -) -from xdsl.irdl import ( - BaseAttr, - GenericAttrConstraint, - ParamAttrConstraint, - ParameterDef, - irdl_attr_definition, -) - -_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True) -_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute) - - -class StreamType( - Generic[_StreamTypeElement], - ParametrizedAttribute, - TypeAttribute, - ContainerType[_StreamTypeElement], -): - element_type: ParameterDef[_StreamTypeElement] - - def __init__(self, element_type: _StreamTypeElement): - super().__init__([element_type]) - - def get_element_type(self) -> _StreamTypeElement: - return self.element_type - - @staticmethod - def constr( - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]: - return ParamAttrConstraint[StreamType[_StreamTypeElementConstrT]]( - StreamType, (element_type,) - ) - - -@irdl_attr_definition -class ReadableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): - name = "stream.readable" - - @staticmethod - def constr( - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]: - return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]( - ReadableStreamType, (element_type,) - ) - - -AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]]( - ReadableStreamType -) - - -@irdl_attr_definition -class WritableStreamType(Generic[_StreamTypeElement], StreamType[_StreamTypeElement]): - name = "stream.writable" - - @staticmethod - def constr( - element_type: GenericAttrConstraint[_StreamTypeElementConstrT], - ) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]: - return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]( - WritableStreamType, (element_type,) - ) - - -AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]]( - WritableStreamType -) - - -Stream = Dialect( - "stream", - [], - [ - ReadableStreamType, - WritableStreamType, - ], -) diff --git a/xdsl/transforms/convert_memref_stream_to_loops.py b/xdsl/transforms/convert_memref_stream_to_loops.py index 50bf4d3cf6..f13101725b 100644 --- a/xdsl/transforms/convert_memref_stream_to_loops.py +++ b/xdsl/transforms/convert_memref_stream_to_loops.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from xdsl.context import MLContext -from xdsl.dialects import arith, memref, memref_stream, stream +from xdsl.dialects import arith, memref, memref_stream from xdsl.dialects.builtin import AffineMapAttr, IntegerAttr, ModuleOp, UnitAttr from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass @@ -33,7 +33,7 @@ def _insert_load( rewriter, insertion_point, affine_map_attr.data, ind_vars ) op = memref.Load.get(source, indices) - elif isinstance(source.type, stream.ReadableStreamType): + elif isinstance(source.type, memref_stream.ReadableStreamType): op = memref_stream.ReadOp(source) else: return source diff --git a/xdsl/transforms/convert_memref_stream_to_snitch_stream.py b/xdsl/transforms/convert_memref_stream_to_snitch_stream.py index 090e35c9c7..a3277aed47 100644 --- a/xdsl/transforms/convert_memref_stream_to_snitch_stream.py +++ b/xdsl/transforms/convert_memref_stream_to_snitch_stream.py @@ -11,8 +11,8 @@ memref_stream, riscv, riscv_snitch, + snitch, snitch_stream, - stream, ) from xdsl.dialects.builtin import ( ArrayAttr, @@ -65,9 +65,9 @@ def match_and_rewrite( self, op: memref_stream.ReadOp, rewriter: PatternRewriter ) -> None: stream_type = op.stream.type - assert isinstance(stream_type, stream.ReadableStreamType) + assert isinstance(stream_type, memref_stream.ReadableStreamType) value_type = cast( - stream.ReadableStreamType[Attribute], stream_type + memref_stream.ReadableStreamType[Attribute], stream_type ).element_type if not snitch_stream_element_type_is_valid(value_type): raise DiagnosticException( @@ -76,7 +76,7 @@ def match_and_rewrite( register_type = riscv.Registers.UNALLOCATED_FLOAT new_stream = UnrealizedConversionCastOp.get( - (op.stream,), (stream.ReadableStreamType(register_type),) + (op.stream,), (snitch.ReadableStreamType(register_type),) ) new_op = riscv_snitch.ReadOp(new_stream.results[0]) if len(op.res.uses) == 1: @@ -103,9 +103,9 @@ def match_and_rewrite( self, op: memref_stream.WriteOp, rewriter: PatternRewriter ) -> None: stream_type = op.stream.type - assert isinstance(stream_type, stream.WritableStreamType) + assert isinstance(stream_type, memref_stream.WritableStreamType) value_type = cast( - stream.WritableStreamType[Attribute], stream_type + memref_stream.WritableStreamType[Attribute], stream_type ).element_type if not snitch_stream_element_type_is_valid(value_type): raise DiagnosticException( @@ -114,7 +114,7 @@ def match_and_rewrite( register_type = riscv.Registers.UNALLOCATED_FLOAT new_stream = UnrealizedConversionCastOp.get( - (op.stream,), (stream.WritableStreamType(register_type),) + (op.stream,), (snitch.WritableStreamType(register_type),) ) cast_op = UnrealizedConversionCastOp.get((op.value,), (register_type,)) if isinstance(defining_op := op.value.owner, Operation) and ( @@ -175,8 +175,8 @@ def match_and_rewrite( new_body = new_op.body.block - input_stream_types = (stream.ReadableStreamType(freg),) * len(op.inputs) - output_stream_types = (stream.WritableStreamType(freg),) * len(op.outputs) + input_stream_types = (snitch.ReadableStreamType(freg),) * len(op.inputs) + output_stream_types = (snitch.WritableStreamType(freg),) * len(op.outputs) stream_types = input_stream_types + output_stream_types for i in reversed(range(len(stream_types))): arg = new_body.args[i] diff --git a/xdsl/transforms/convert_riscv_scf_for_to_frep.py b/xdsl/transforms/convert_riscv_scf_for_to_frep.py index a49b9a4349..407f300a78 100644 --- a/xdsl/transforms/convert_riscv_scf_for_to_frep.py +++ b/xdsl/transforms/convert_riscv_scf_for_to_frep.py @@ -1,7 +1,7 @@ from itertools import chain from xdsl.context import MLContext -from xdsl.dialects import builtin, riscv, riscv_scf, riscv_snitch, stream +from xdsl.dialects import builtin, riscv, riscv_scf, riscv_snitch, snitch from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -36,7 +36,12 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> N return if not all( - isinstance(value.type, riscv.FloatRegisterType | stream.StreamType) + isinstance( + value.type, + riscv.FloatRegisterType + | snitch.ReadableStreamType + | snitch.WritableStreamType, + ) for o in body_block.ops for value in chain(o.operands, o.results) ): diff --git a/xdsl/transforms/memref_streamify.py b/xdsl/transforms/memref_streamify.py index c7dc71415c..fb6a51be68 100644 --- a/xdsl/transforms/memref_streamify.py +++ b/xdsl/transforms/memref_streamify.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from xdsl.context import MLContext -from xdsl.dialects import memref, memref_stream, stream +from xdsl.dialects import memref, memref_stream from xdsl.dialects.builtin import ArrayAttr, ModuleOp from xdsl.ir import Block, Region from xdsl.passes import ModulePass @@ -22,7 +22,13 @@ class StreamifyGenericOpPattern(RewritePattern): def match_and_rewrite( self, op: memref_stream.GenericOp, rewriter: PatternRewriter ) -> None: - if any(isinstance(operand.type, stream.StreamType) for operand in op.operands): + if any( + isinstance( + operand.type, + memref_stream.ReadableStreamType | memref_stream.WritableStreamType, + ) + for operand in op.operands + ): # Already streamified return @@ -59,10 +65,10 @@ def match_and_rewrite( input_el_types = tuple(el_type for _, el_type in streamed_input_indices) output_el_types = tuple(el_type for _, el_type in streamed_output_indices) input_stream_types = tuple( - stream.ReadableStreamType(el_type) for el_type in input_el_types + memref_stream.ReadableStreamType(el_type) for el_type in input_el_types ) output_stream_types = tuple( - stream.WritableStreamType(el_type) for el_type in output_el_types + memref_stream.WritableStreamType(el_type) for el_type in output_el_types ) # input patterns are never unnested diff --git a/xdsl/transforms/snitch_register_allocation.py b/xdsl/transforms/snitch_register_allocation.py index 6ac09c68b0..0b124ae78c 100644 --- a/xdsl/transforms/snitch_register_allocation.py +++ b/xdsl/transforms/snitch_register_allocation.py @@ -2,7 +2,7 @@ from typing import cast from xdsl.context import MLContext -from xdsl.dialects import riscv, riscv_snitch, snitch_stream, stream +from xdsl.dialects import riscv, riscv_snitch, snitch, snitch_stream from xdsl.dialects.builtin import ModuleOp from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -40,12 +40,12 @@ def match_and_rewrite( block = op.body.block for index, input_stream in enumerate(block.args): - input_stream.type = stream.ReadableStreamType(riscv.Registers.FT[index]) + input_stream.type = snitch.ReadableStreamType(riscv.Registers.FT[index]) input_count = len(op.inputs) for index, output_stream in enumerate(block.args[input_count:]): - output_stream.type = stream.WritableStreamType( + output_stream.type = snitch.WritableStreamType( riscv.Registers.FT[index + input_count] ) @@ -59,7 +59,7 @@ class AllocateRiscvSnitchReadRegisters(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: riscv_snitch.ReadOp, rewriter: PatternRewriter, /): stream_type = cast( - stream.ReadableStreamType[riscv.FloatRegisterType], op.stream.type + snitch.ReadableStreamType[riscv.FloatRegisterType], op.stream.type ) op.res.type = stream_type.element_type @@ -73,7 +73,7 @@ class AllocateRiscvSnitchWriteRegisters(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: riscv_snitch.WriteOp, rewriter: PatternRewriter, /): stream_type = cast( - stream.WritableStreamType[riscv.FloatRegisterType], op.stream.type + snitch.WritableStreamType[riscv.FloatRegisterType], op.stream.type ) op.value.type = stream_type.element_type