Skip to content

Commit

Permalink
Rework split_reshape simplification to handle multiple reshapes from …
Browse files Browse the repository at this point in the history
…same slice (#2146)

This resolves an edge case found in multiple huggingface models: in some cases the find_split_reshape matcher will match with reshape2, but vec_rsp will consist of reshape1 dims, causing a shape mismatch error. The solution is to include rsp when checking that all reshapes are the same.
  • Loading branch information
shivadbhavsar authored Sep 19, 2023
1 parent c2e01b1 commit 0bb8508
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 35 deletions.
124 changes: 92 additions & 32 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1325,48 +1325,59 @@ struct find_split_reshape

void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front();

// Only apply simplification when slices are on a single axis
auto axes = any_cast<op::slice>(slc->get_operator()).axes;
if(axes.size() > 1)
{
return;
}

auto input = slc->inputs().front();
auto split_outputs = get_splits(input);
if(split_outputs.empty())
{
return;
}

// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
if(i->outputs().size() == 1)
{
auto cont = i->outputs().front();
return cont->outputs().size() != 1;
}
return false;
}))
// Find all the reshapes (similar to rsp) that can be simplified
std::vector<instruction_ref> conts;
std::vector<instruction_ref> vec_rsp;

// Iterate through slice and contiguous outputs to allow simplifications when
// slice is followed by multiple reshapes
for(auto& i : split_outputs)
{
return;
std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(conts),
[](auto j) { return j->name() == "contiguous"; });
}

std::vector<instruction_ref> vec_rsp(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
auto cont = i->outputs().front();
return cont->outputs().front();
});
for(auto& i : conts)
{
std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(vec_rsp),
[&](auto j) { return j->get_operator() == rsp->get_operator(); });
}

// all outputs are reshape and of the same shape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
if(not same_ops(vec_rsp))
// No simplification needed if there is only one slice -> cont -> reshape
if(vec_rsp.size() <= 1)
{
return;
}

// ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto axis = axes[0];
auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
auto input_lens = input->get_shape().lens();
auto input_size = input->get_shape().elements();
auto slc_axis_len = input_lens[axis];

// search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size
Expand All @@ -1393,16 +1404,67 @@ struct find_split_reshape
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());

std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis];
});
// Calculate reshape output shape
// Need to find a reshape such that data represented by instructions in vec_rsp can be
// written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)

// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}

// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}

std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = 1;
auto rsp_fixed_size = std::accumulate(
rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies<std::size_t>());

rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
return;
}
auto rsp_axis_len = input_size / rsp_fixed_size;
rsp_out_lens[rsp_axis] = rsp_axis_len;

// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:

// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}

// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}

std::vector<int64_t> new_starts(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).starts[0] * rsp_axis_len /
slc_axis_len;
});

std::vector<int64_t> new_ends(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).ends[0] * rsp_axis_len /
slc_axis_len;
});

// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
Expand All @@ -1413,16 +1475,14 @@ struct find_split_reshape
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);

// replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
m.replace_instruction(
vec_rsp[i],
make_op(
"slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}),
{{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}),
rsp_ins);
start += vec_dims[i];
}
}
};
Expand Down
198 changes: 195 additions & 3 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2910,6 +2910,179 @@ TEST_CASE(reorder_reshape_slice_not_apply)
EXPECT(m1.sort() == m2.sort());
}

// Checks the split/reshape simplification when the same slice feeds more than
// one contiguous -> reshape chain, and the chains use different reshape dims.
// m1 is the graph handed to the pass; m2 is the graph the pass is expected to
// produce. The two are compared after sorting.
TEST_CASE(reorder_reshape_slice_multi_rsp)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {4, 128, 3, 32, 80}};
auto input = m1.add_parameter("input", s);
// Permutation {2, 0, 3, 1, 4} moves the length-3 dimension to axis 0, which
// is the axis the three slices below cut along.
auto t1 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), t1);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), t1);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), t1);

// slc1 and slc2 each get two contiguous consumers: c1_1/c2_1 here and c1/c2
// below. Each consumer is followed by a reshape with different dims.
auto c1_1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2_1 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);

// First reshape group: dims {4, 32, 128, 80} (r1, r2).
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto r1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c1);

auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
auto r2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 32, 128, 80}}}), c2);

// Second reshape group: dims {128, 128, 80} (r1_1, r2_1, and r0 below).
auto r1_1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c1_1);
auto r2_1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c2_1);

// slc0 only has a single contiguous -> reshape chain.
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {128, 128, 80}}}), c0);

// Consume r0 and a transposed r1_1 in a dot so the reshapes have a
// non-trivial downstream user.
auto t2 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), r1_1);
auto c_t2 = m1.add_instruction(migraphx::make_op("contiguous"), t2);

auto dot = m1.add_instruction(migraphx::make_op("dot"), r0, c_t2);

m1.add_return({r1, r2, dot, r2_1});
};

// Expected result: each reshape group is replaced by one reshape of the whole
// (contiguous) transposed input followed by slices of that reshape.
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {4, 128, 3, 32, 80}};
auto input = m2.add_parameter("input", s);
auto t1 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input);
// Replacement for the {128, 128, 80} group: reshape to {384, 128, 80} and
// slice axis 0 in steps of 128.
auto c_t1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {384, 128, 80}}}), c_t1);
// NOTE: the local names in m2 do not follow m1's slice numbering. Judging by
// the return lists, this 256:384 slice takes the place of m1's r2_1.
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {256}}, {"ends", {384}}}), rsp1);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {128}}, {"ends", {256}}}), rsp1);

auto t_slc1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), slc1);
auto c_t_slc1 = m2.add_instruction(migraphx::make_op("contiguous"), t_slc1);

auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {128}}}), rsp1);

auto dot = m2.add_instruction(migraphx::make_op("dot"), slc2, c_t_slc1);

// Replacement for the {4, 32, 128, 80} group: reshape to {12, 32, 128, 80}
// and slice axis 0 in steps of 4. Only the 4:8 and 8:12 slices are expected,
// matching r1 and r2 in m1 (slc0 had no reshape with these dims).
auto c_t1_1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 32, 128, 80}}}), c_t1_1);

auto slc2_1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), rsp2);

auto slc2_2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {12}}}), rsp2);

m2.add_return({slc2_1, slc2_2, dot, slc0});
};

// run_pass is defined outside this excerpt; presumably it applies the
// simplify_algebra pass under test -- confirm against the test file header.
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

// Checks the split/reshape simplification when only some of the slices of an
// input are followed by contiguous -> reshape. Three of the four slices are
// reshaped; the fourth is returned as-is and must be left untouched.
TEST_CASE(reorder_reshape_slice_partial)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m1.add_parameter("input", s);
// Four slices along axis 0: three of length 8 and a remainder 24:128.
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {16}}, {"ends", {24}}}), input);
auto slc3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {24}}, {"ends", {128}}}), input);

// Only slc0..slc2 go through contiguous; slc3 has no contiguous/reshape.
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);

// Each {8, 96} slice is reshaped to {2, 4, 96}.
std::vector<int64_t> lens = {2, 4, 96};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);

auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret, slc3});
};

// Expected result: one reshape of the whole input to {32, 4, 96}
// (128 * 96 / (4 * 96) = 32 along axis 0), with the three reshapes replaced
// by slices of it. Original starts/ends 0, 8, 16, 24 scale by 32 / 128 to
// 0, 2, 4, 6.
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m2.add_parameter("input", s);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {32, 4, 96}}}), input);
// The un-reshaped remainder slice still reads directly from the input.
auto slc3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {24}}, {"ends", {128}}}), input);

auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), rsp);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), rsp);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {6}}}), rsp);

auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret, slc3});
};

// run_pass is defined outside this excerpt; presumably it applies the
// simplify_algebra pass under test -- confirm against the test file header.
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

// Negative test: slices of length 31 (plus a 93:128 remainder) over an axis of
// length 128, each reshaped to {1, 31, 96}. The module is expected to come out
// of the pass unchanged.
// NOTE(review): presumably the pass bails out because no single reshape of the
// {128, 96} input can hold the {*, 31, 96} layout -- confirm against
// find_split_reshape.
TEST_CASE(reorder_reshape_slice_uneven_slice)
{
// Builds the module once so the "before" and "expected" graphs are identical.
auto create_p = [] {
migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m.add_parameter("input", s);
auto slc0 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {31}}}), input);
auto slc1 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {31}}, {"ends", {62}}}), input);
auto slc2 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {62}}, {"ends", {93}}}), input);
auto slc3 = m.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {93}}, {"ends", {128}}}), input);

// Only the three length-31 slices are made contiguous and reshaped.
auto c0 = m.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m.add_instruction(migraphx::make_op("contiguous"), slc2);

std::vector<int64_t> lens = {1, 31, 96};
auto r0 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);

auto sum = m.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m.add_instruction(migraphx::make_op("mul"), sum, r2);
m.add_return({ret, slc3});

return m;
};

// m2 is a copy taken before the pass runs, i.e. the untouched graph.
auto m1 = create_p();
auto m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}

template <std::size_t BS>
void reorder_reshape_slice_diff_dims()
{
Expand All @@ -2931,13 +3104,32 @@ void reorder_reshape_slice_diff_dims()
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);

m1.add_return({r0, r1, r2});
};

auto m2 = m1;
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 96, 96}};
auto input = m2.add_parameter("input", s);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {32}}, {"ends", {64}}}), input);
auto c1 = m2.add_instruction(migraphx::make_op("contiguous"), slc1);
std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32};
auto r1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c1);

std::vector<int64_t> lens = {static_cast<int64_t>(BS), 32, 3, 96};
auto r_new = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), r_new);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), r_new);

m2.add_return({slc0, r1, slc2});
};

run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
Expand Down

0 comments on commit 0bb8508

Please sign in to comment.