Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove contiguous from passes for reshapes #2319

Merged
merged 10 commits into from
Nov 15, 2023
3 changes: 1 addition & 2 deletions src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ struct find_pointwise_reshape_pointwise

auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) {
auto c = m.insert_instruction(ins_to_insert, make_op("contiguous"), input);
return m.insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), c);
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), input);
};
};
auto x_inputs = x_ins->inputs();
Expand Down
21 changes: 0 additions & 21 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,15 +941,6 @@ struct find_splits
{
auto split = i->inputs()[split_idx];
assert(split->name() == "slice");
// Insert contiguous for reshapes
auto outputs = i->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}

m.replace_instruction(i, split->get_operator(), c);
}
Expand Down Expand Up @@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion
for(auto arg : range(start, last))
{
auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}

int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
Expand Down Expand Up @@ -1487,11 +1471,6 @@ struct find_split_reshape
slc_axis_len;
});

// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);

Expand Down
8 changes: 2 additions & 6 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ struct find_reshaper
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();

if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
Expand Down Expand Up @@ -475,9 +473,8 @@ struct find_resize
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp);
}
};

Expand Down Expand Up @@ -626,9 +623,8 @@ struct find_transpose_contiguous_reshaper_unary
auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name();
auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
}
};

Expand Down
15 changes: 5 additions & 10 deletions test/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,7 @@ TEST_CASE(add_reshape_add_nonstandard)
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), add1);
auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), c);
auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z);
mm->add_return({add2});
}
Expand All @@ -426,10 +425,8 @@ TEST_CASE(add_reshape_add_nonstandard)
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cy);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y);
auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) {
Expand Down Expand Up @@ -466,10 +463,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard)
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cy);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
Expand Down
29 changes: 13 additions & 16 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1897,12 +1897,17 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto concatb = m2.add_instruction(b, concat);
auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = m2.add_instruction(migraphx::make_op("relu"), sum);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 8}}}), relu);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering what change caused this reshape to split into two.

auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), rsp);
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);

auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4}}}), slc1);

auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {8}}}), rsp);
auto add = m2.add_instruction(migraphx::make_op("add"), slc1, slc2);
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);

auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4}}}), slc2);

auto add = m2.add_instruction(migraphx::make_op("add"), rsp1, rsp2);
m2.add_instruction(pass_op{}, add);
}
EXPECT(m1.sort() == m2.sort());
Expand Down Expand Up @@ -2323,9 +2328,8 @@ TEST_CASE(simplify_dot_horiz_reshape)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {4}}}), dot);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot);
auto x_cont = m2.add_instruction(migraphx::make_op("contiguous"), x);
auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x_cont);
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x);
auto y_rsp =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
Expand Down Expand Up @@ -2690,10 +2694,6 @@ void reorder_reshape_slice()
}
auto input = m2.add_parameter("input", s);
auto rsp_input = input;
if(TransposeInput)
{
rsp_input = m2.add_instruction(migraphx::make_op("contiguous"), {input});
}
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 128, 30, 64};
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), rsp_input);

Expand Down Expand Up @@ -2976,9 +2976,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp)
auto input = m2.add_parameter("input", s);
auto t1 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input);
auto c_t1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {384, 128, 80}}}), c_t1);
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {384, 128, 80}}}), t1);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {256}}, {"ends", {384}}}), rsp1);
auto slc1 = m2.add_instruction(
Expand All @@ -2993,9 +2992,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp)

auto dot = m2.add_instruction(migraphx::make_op("dot"), slc2, c_t_slc1);

auto c_t1_1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 32, 128, 80}}}), c_t1_1);
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 32, 128, 80}}}), t1);

auto slc2_1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), rsp2);
Expand Down Expand Up @@ -3372,9 +3370,8 @@ TEST_CASE(dot_fusion_reshape)
auto s1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {320}}, {"ends", {640}}}), d);

auto cont0 = m2.add_instruction(migraphx::make_op("contiguous"), s0);
auto r0 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 8, 40}}}), cont0);
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 8, 40}}}), s0);

m2.add_return({r0, s1});
};
Expand Down
12 changes: 4 additions & 8 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,9 +888,8 @@ TEST_CASE(optimize_resize)
std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3};
auto mbx =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_dims}}), rspx);
auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx);
std::vector<int64_t> orig_dims = {1, 2, 4, 6};
auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb);
auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), mbx);
auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), rmb);
m.add_return({r});

Expand Down Expand Up @@ -1301,9 +1300,8 @@ TEST_CASE(transpose_contiguous_reshape_unary)
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), relu);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins2);
}
EXPECT(m1 == m2);
Expand All @@ -1328,8 +1326,7 @@ TEST_CASE(transpose_contiguous_squeeze_unary)
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), rsqrt);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt);
m2.add_instruction(pass_op{}, sq_ins);
}
EXPECT(m1 == m2);
Expand All @@ -1355,9 +1352,8 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round);
auto unsq_ins =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), round);
m2.add_instruction(pass_op{}, unsq_ins);
}
EXPECT(m1 == m2);
Expand Down
Loading