From 7b683243566fdcb5cd3cf0e6eb9db14a326110cd Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 31 Aug 2023 22:05:54 +0000 Subject: [PATCH 1/8] split_reshape bug fix --- src/simplify_algebra.cpp | 7 +-- test/simplify_algebra_test.cpp | 93 ++++++++++++++++++++++++++++++++-- 2 files changed, 93 insertions(+), 7 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 1c79158ebb2..ad2a01d78fc 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1355,9 +1355,10 @@ struct find_split_reshape return cont->outputs().front(); }); - // all outputs are reshape and of the same shape - auto dims = any_cast(rsp->get_operator()).dims; - if(not same_ops(vec_rsp)) + // all outputs are reshape and of the same shape as matched reshape op + std::vector all_rsp = vec_rsp; + all_rsp.push_back(rsp); + if(not same_ops(all_rsp)) { return; } diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index b6425974d78..2e7630d781d 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -603,8 +603,8 @@ TEST_CASE(simplify_inner_broadcast_scalar) migraphx::module m2; { - auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); - auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto yb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); @@ -630,8 +630,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) migraphx::module m2; { - auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); - auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); + auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); + auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); auto yb = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); @@ -2910,6 +2910,91 @@ TEST_CASE(reorder_reshape_slice_not_apply) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reorder_reshape_slice_multi_rsp) +{ + migraphx::module m1; + { + migraphx::shape s{migraphx::shape::float_type, {4, 128, 3, 32, 80}}; + auto input = m1.add_parameter("input", s); + auto t1 = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), t1); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), t1); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), t1); + + auto c1_1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2_1 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); + + auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); + auto r1 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c1); + + auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); + auto r2 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c2); + + auto r1_1 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c1_1); + auto r2_1 = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c2_1); + + auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0); + auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c0); + + auto t2 = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), r1_1); + auto c_t2 = m1.add_instruction(migraphx::make_op("contiguous"), t2); + + auto dot = m1.add_instruction(migraphx::make_op("dot"), r0, c_t2); + + m1.add_return({r1, r2, dot, r2_1}); + }; + + migraphx::module m2; + { + migraphx::shape s{migraphx::shape::float_type, {4, 128, 3, 32, 80}}; + auto input = m2.add_parameter("input", s); + auto t1 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input); + auto c_t1 = m2.add_instruction(migraphx::make_op("contiguous"), t1); + auto rsp1 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {384, 128, 80}}}), c_t1); + auto slc0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {256}}, {"ends", {384}}}), rsp1); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), rsp1); + + auto t_slc1 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), slc1); + auto c_t_slc1 = m2.add_instruction(migraphx::make_op("contiguous"), t_slc1); + + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), rsp1); + + auto dot = m2.add_instruction(migraphx::make_op("dot"), slc2, c_t_slc1); + + auto slc4 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), t1); + auto c4 = m2.add_instruction(migraphx::make_op("contiguous"), slc4); + auto r4 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c4); + + auto slc5 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), t1); + auto c5 = m2.add_instruction(migraphx::make_op("contiguous"), slc5); + auto r5 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c5); + + m2.add_return({r5, r4, dot, slc0}); + }; + + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + template void reorder_reshape_slice_diff_dims() { From d103134601bf7eabce68d7b9e114ef4fb397d3bd Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 6 Sep 2023 00:40:31 +0000 Subject: [PATCH 2/8] allow slice reshape simplification for multiple reshapes from single slice --- src/simplify_algebra.cpp | 81 +++++++++++++++++++++------------- test/simplify_algebra_test.cpp | 45 +++++++++++++------ 2 files changed, 82 insertions(+), 44 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index ad2a01d78fc..88bd55ee127 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1330,35 +1330,36 @@ struct find_split_reshape auto input = slc->inputs().front(); auto split_outputs = get_splits(input); + if(split_outputs.empty()) { return; } - // Only want to apply this optimization if each split output is followed by - // a contiguous op and a reshape - if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) { - if(i->outputs().size() == 1) - { - auto cont = i->outputs().front(); - return cont->outputs().size() != 1; - } - return false; - })) + // Find all the reshapes (similar to rsp) that can be simplified + std::vector conts; + std::vector vec_rsp; + + // Iterate through slice and contiguous outputs to allow simplifications when + // slice is followed by multiple reshapes + for(auto& i : split_outputs) { - return; + std::copy_if(i->outputs().begin(), + i->outputs().end(), + std::back_inserter(conts), + [](auto j) { return j->name() == "contiguous"; }); } - std::vector vec_rsp(split_outputs.size()); - std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) { - auto cont = i->outputs().front(); - return cont->outputs().front(); - }); + for(auto& i : conts) + { + std::copy_if(i->outputs().begin(), + i->outputs().end(), + std::back_inserter(vec_rsp), + [&](auto j) { return j->get_operator() == rsp->get_operator(); }); + } - // all outputs are reshape and of the same shape as matched reshape op - std::vector all_rsp = vec_rsp; - all_rsp.push_back(rsp); - if(not same_ops(all_rsp)) + // No simplification needed if there is only one slice -> cont -> reshape + if(vec_rsp.size() <= 1) { return; } @@ -1368,6 +1369,10 @@ struct find_split_reshape auto slc_lens = slc->get_shape().lens(); auto slc_dim_size = std::accumulate( slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies()); + auto input_lens = input->get_shape().lens(); + auto input_size = std::accumulate( + input_lens.begin(), input_lens.end(), 1, std::multiplies()); + auto slc_axis_len = input_lens[axis]; // search the reshape output (standard shape) to decide which axis are // in its output corresponding to the slc_dim_size @@ -1394,16 +1399,34 @@ struct find_split_reshape { rsp_axis = std::distance(rsp_strides.begin(), ait); } - // calculate reshape output shape - std::vector vec_dims(vec_rsp.size()); - - std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { - return is->get_shape().lens()[rsp_axis]; - }); + // calculate reshape output shape std::vector rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); + rsp_out_lens[rsp_axis] = 1; + auto rsp_fixed_size = std::accumulate( + rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies()); + auto rsp_axis_len = input_size / rsp_fixed_size; + rsp_out_lens[rsp_axis] = rsp_axis_len; + + // calculate new slice start and end indices + // indices are scaled using the new reshape axis and the original slice axis + // eg. slice with size [1, 2, 4, 30] can be reshaped to [8, 30] and vice-versa + + std::vector new_starts(vec_rsp.size()); + std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) { + auto cont = is->inputs().front(); + auto og_slc = cont->inputs().front(); + return any_cast(og_slc->get_operator()).starts[0] * rsp_axis_len / + slc_axis_len; + }); - rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); + std::vector new_ends(vec_rsp.size()); + std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) { + auto cont = is->inputs().front(); + auto og_slc = cont->inputs().front(); + return any_cast(og_slc->get_operator()).ends[0] * rsp_axis_len / + slc_axis_len; + }); // insert the reshape instruction and add contiguous if needed if(not input->get_shape().standard()) @@ -1414,16 +1437,14 @@ struct find_split_reshape std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); // replace the original reshape with slice - int64_t start = 0; for(std::size_t i = 0; i < vec_rsp.size(); ++i) { m.replace_instruction( vec_rsp[i], make_op( "slice", - {{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}), + {{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}), rsp_ins); - start += vec_dims[i]; } } }; diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 2e7630d781d..f8e4f537cb4 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -2976,19 +2976,17 @@ TEST_CASE(reorder_reshape_slice_multi_rsp) auto dot = m2.add_instruction(migraphx::make_op("dot"), slc2, c_t_slc1); - auto slc4 = m2.add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), t1); - auto c4 = m2.add_instruction(migraphx::make_op("contiguous"), slc4); - auto r4 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c4); + auto c_t1_1 = m2.add_instruction(migraphx::make_op("contiguous"), t1); + auto rsp2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 32, 128, 80}}}), c_t1_1); - auto slc5 = m2.add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), t1); - auto c5 = m2.add_instruction(migraphx::make_op("contiguous"), slc5); - auto r5 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c5); + auto slc2_1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), rsp2); + + auto slc2_2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {12}}}), rsp2); - m2.add_return({r5, r4, dot, slc0}); + m2.add_return({slc2_1, slc2_2, dot, slc0}); }; run_pass(m1); @@ -3016,13 +3014,32 @@ void reorder_reshape_slice_diff_dims() std::vector lens = {static_cast(BS), 32, 3, 32}; std::vector lens1 = {static_cast(BS), 48, 2, 32}; auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); - auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); - auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2); + auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c1); + auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); m1.add_return({r0, r1, r2}); }; - auto m2 = m1; + migraphx::module m2; + { + auto s = migraphx::shape{migraphx::shape::float_type, {BS, 96, 96}}; + auto input = m2.add_parameter("input", s); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input); + auto c1 = m2.add_instruction(migraphx::make_op("contiguous"), slc1); + std::vector lens1 = {static_cast(BS), 48, 2, 32}; + auto r1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c1); + + std::vector lens = {static_cast(BS), 32, 3, 96}; + auto r_new = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); + auto slc0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), r_new); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), r_new); + + m2.add_return({slc0, r1, slc2}); + }; + run_pass(m1); EXPECT(m1.sort() == m2.sort()); } From 4b286e5e6a948edf5daabfab52522744014da5ab Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 6 Sep 2023 19:46:26 +0000 Subject: [PATCH 3/8] use elements() to get arg size --- src/simplify_algebra.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 88bd55ee127..9d5473dbeb0 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1369,9 +1369,8 @@ struct find_split_reshape auto slc_lens = slc->get_shape().lens(); auto slc_dim_size = std::accumulate( slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies()); - auto input_lens = input->get_shape().lens(); - auto input_size = std::accumulate( - input_lens.begin(), input_lens.end(), 1, std::multiplies()); + auto input_lens = input->get_shape().lens(); + auto input_size = input->get_shape().elements(); auto slc_axis_len = input_lens[axis]; // search the reshape output (standard shape) to decide which axis are From a349cf03d04cb3e46ccc23dbf044edc7d1ff5a9b Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 7 Sep 2023 23:02:01 +0000 Subject: [PATCH 4/8] handle edge case where dimension is not a multiple of the split sizes --- src/simplify_algebra.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index e9c4f6055b0..343aedbeb71 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1404,6 +1404,11 @@ struct find_split_reshape rsp_out_lens[rsp_axis] = 1; auto rsp_fixed_size = std::accumulate( rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies()); + // cannot create a valid reshape for simplification + if(input_size % rsp_fixed_size != 0) + { + return; + } auto rsp_axis_len = input_size / rsp_fixed_size; rsp_out_lens[rsp_axis] = rsp_axis_len; From dfe9465db4e3bf94dbb84c91b0c3853957a6585e Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 8 Sep 2023 19:28:20 +0000 Subject: [PATCH 5/8] add test case for invalid reshape axis len --- test/simplify_algebra_test.cpp | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 4f68c2cf625..34b6abf17b3 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -2993,6 +2993,43 @@ TEST_CASE(reorder_reshape_slice_multi_rsp) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reorder_reshape_slice_uneven_slice) +{ + auto create_p = [] { + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {128, 96}}; + auto input = m.add_parameter("input", s); + auto slc0 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {31}}}), input); + auto slc1 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {31}}, {"ends", {62}}}), input); + auto slc2 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {62}}, {"ends", {93}}}), input); + auto slc3 = m.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {93}}, {"ends", {128}}}), input); + + auto c0 = m.add_instruction(migraphx::make_op("contiguous"), slc0); + auto c1 = m.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2 = m.add_instruction(migraphx::make_op("contiguous"), slc2); + + std::vector lens = {1, 31, 96}; + auto r0 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); + auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); + auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); + + auto sum = m.add_instruction(migraphx::make_op("add"), r0, r1); + auto ret = m.add_instruction(migraphx::make_op("mul"), sum, r2); + m.add_return({ret, slc3}); + + return m; + }; + + auto m1 = create_p(); + auto m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + template void reorder_reshape_slice_diff_dims() { From 0849f485412824c4a6057ee037f4cbe2710f0c01 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 18 Sep 2023 19:22:20 +0000 Subject: [PATCH 6/8] add check for single axis when performing split reshape simplification --- src/simplify_algebra.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 343aedbeb71..c5013a79386 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1325,12 +1325,18 @@ struct find_split_reshape void apply(module& m, const match::matcher_result& r) const { - auto slc = r.instructions["slice"]; - auto rsp = r.instructions["reshape"]; + auto slc = r.instructions["slice"]; + auto rsp = r.instructions["reshape"]; + auto input = slc->inputs().front(); - auto input = slc->inputs().front(); - auto split_outputs = get_splits(input); + // Only apply simplification when slices are on a single axis + auto axes = any_cast(slc->get_operator()).axes; + if(axes.size() > 1) + { + return; + } + auto split_outputs = get_splits(input); if(split_outputs.empty()) { return; @@ -1365,7 +1371,7 @@ struct find_split_reshape } // ensure reshape happens after the axis dimension - auto axis = any_cast(slc->get_operator()).axes[0]; + auto axis = axes[0]; auto slc_lens = slc->get_shape().lens(); auto slc_dim_size = std::accumulate( slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies()); From 2ebb7112961d5346307a3ae1953b721646b61224 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 19 Sep 2023 00:56:06 +0000 Subject: [PATCH 7/8] add comments with examples for clarity --- src/simplify_algebra.cpp | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index c5013a79386..43fb8ec17b9 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1405,11 +1405,28 @@ struct find_split_reshape rsp_axis = std::distance(rsp_strides.begin(), ait); } - // calculate reshape output shape + // Calculate reshape output shape + // Need to find a reshape such that data represented by instructions in vec_rsp can be + // written as slices of this new reshape. This is done by holding all the dims constant in + // rsp_lens to compute the required dim for rsp_axis (axis that will be sliced) + + // ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12), + // Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4} + // rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16 + // rsp_axis_len = 2*12*4 / 16 = 6 + // rsp_out_lens (final) = {2, 6, 2, 4} + + // ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12), + // Reshape Outputs: {2, 16}, {2, 16}, {2, 16} + // rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2 + // rsp_axis_len = 2*12*4 / 2 = 48 + // rsp_out_lens (final) = {2, 48} + std::vector rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); rsp_out_lens[rsp_axis] = 1; auto rsp_fixed_size = std::accumulate( rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies()); + // cannot create a valid reshape for simplification if(input_size % rsp_fixed_size != 0) { @@ -1418,9 +1435,20 @@ struct find_split_reshape auto rsp_axis_len = input_size / rsp_fixed_size; rsp_out_lens[rsp_axis] = rsp_axis_len; - // calculate new slice start and end indices - // indices are scaled using the new reshape axis and the original slice axis - // eg. slice with size [1, 2, 4, 30] can be reshaped to [8, 30] and vice-versa + // Calculate new slice start and end indices. Indices are scaled using the new reshape axis + // and the original slice axis. See examples: + + // ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12), + // Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4} + // slc_axis_len = 12, rsp_axis_len = 6 + // New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4} + // New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6} + + // ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12), + // Reshape Outputs: {2, 16}, {2, 16}, {2, 16} + // slc_axis_len = 12, rsp_axis_len = 48 + // New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32} + // New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48} std::vector new_starts(vec_rsp.size()); std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) { From 01f35e544614fdd937ab8a0b50de35e8d54044a0 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 19 Sep 2023 17:16:29 +0000 Subject: [PATCH 8/8] add additional test case for slice_reshape --- test/simplify_algebra_test.cpp | 53 ++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 34b6abf17b3..cea2fd54a56 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -2993,6 +2993,59 @@ TEST_CASE(reorder_reshape_slice_multi_rsp) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reorder_reshape_slice_partial) +{ + migraphx::module m1; + { + migraphx::shape s{migraphx::shape::float_type, {128, 96}}; + auto input = m1.add_parameter("input", s); + auto slc0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), input); + auto slc1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), input); + auto slc2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {16}}, {"ends", {24}}}), input); + auto slc3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {24}}, {"ends", {128}}}), input); + + auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0); + auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); + auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); + + std::vector lens = {2, 4, 96}; + auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); + auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); + auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); + + auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1); + auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2); + m1.add_return({ret, slc3}); + }; + + migraphx::module m2; + { + migraphx::shape s{migraphx::shape::float_type, {128, 96}}; + auto input = m2.add_parameter("input", s); + auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {32, 4, 96}}}), input); + auto slc3 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {24}}, {"ends", {128}}}), input); + + auto slc0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), rsp); + auto slc1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), rsp); + auto slc2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {6}}}), rsp); + + auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1); + auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2); + m2.add_return({ret, slc3}); + }; + + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(reorder_reshape_slice_uneven_slice) { auto create_p = [] {