From af3e36c0ab99d1489b65ca5a5d36e98645bda802 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sat, 11 Nov 2023 17:37:49 +0000 Subject: [PATCH] Remove additional contiguous around reshape from fuse_pointwise and simplify_reshapes. Update tests to reflect change accordingly --- src/fuse_pointwise.cpp | 3 +-- src/simplify_reshapes.cpp | 8 ++------ test/fuse_pointwise.cpp | 15 +++++---------- test/simplify_reshapes_test.cpp | 12 ++++-------- 4 files changed, 12 insertions(+), 26 deletions(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 54f362c57fa..491f68a146e 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -219,9 +219,8 @@ struct find_pointwise_reshape_pointwise auto reshape_input = [&](const auto& ins_to_insert) { return [&](auto input) { - auto c = m.insert_instruction(ins_to_insert, make_op("contiguous"), input); return m.insert_instruction( - ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), c); + ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), input); }; }; auto x_inputs = x_ins->inputs(); diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 9528b0fe2e0..0dc093026a3 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -103,8 +103,6 @@ struct find_reshaper auto input = mr.instructions["x"]; auto dims = ins->get_shape().lens(); - if(not input->get_shape().standard()) - input = m.insert_instruction(ins, make_op("contiguous"), input); m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input); } }; @@ -475,9 +473,8 @@ struct find_resize ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); auto mb_rsp = m.insert_instruction( ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); - auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp); std::vector rsp_dims(out_lens.begin(), out_lens.end()); - m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); + m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp); } }; @@ -626,9 +623,8 @@ struct find_transpose_contiguous_reshaper_unary auto cont_ins = r.instructions["cont_ins"]; auto unary_op_name = ins->get_operator().name(); auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins); - auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins); // older cont and reshape are removed by deadcode elimination - m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins); + m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins); } }; diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 39827eb2507..80a79b673f0 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -414,8 +414,7 @@ TEST_CASE(add_reshape_add_nonstandard) auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - auto c = mm->add_instruction(migraphx::make_op("contiguous"), add1); - auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), c); + auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add1); auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); mm->add_return({add2}); } @@ -426,10 +425,8 @@ TEST_CASE(add_reshape_add_nonstandard) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); - auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); - auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y); - auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cx); - auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cy); + auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x); + auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y); auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); auto fadd = add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) { @@ -466,10 +463,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); - auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); - auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y); - auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cx); - auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cy); + auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); + auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y); auto fadd = add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 2429b93ed85..abfd97d334b 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -888,9 +888,8 @@ TEST_CASE(optimize_resize) std::vector mb_dims = {1, 2, 2, 2, 2, 3}; auto mbx = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_dims}}), rspx); - auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx); std::vector orig_dims = {1, 2, 4, 6}; - auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb); + auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), mbx); auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), rmb); m.add_return({r}); @@ -1301,9 +1300,8 @@ TEST_CASE(transpose_contiguous_reshape_unary) auto transpose_ins = m2.add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins); - auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), relu); auto reshape_ins2 = - m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins); + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), relu); m2.add_instruction(pass_op{}, reshape_ins2); } EXPECT(m1 == m2); @@ -1328,8 +1326,7 @@ TEST_CASE(transpose_contiguous_squeeze_unary) auto transpose_ins = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins); - auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), rsqrt); - auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins); + auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt); m2.add_instruction(pass_op{}, sq_ins); } EXPECT(m1 == m2); @@ -1355,9 +1352,8 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary) auto transpose_ins = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins); - auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round); auto unsq_ins = - m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins); + m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), round); m2.add_instruction(pass_op{}, unsq_ins); } EXPECT(m1 == m2);