diff --git a/src/simplify_dyn_ops.cpp b/src/simplify_dyn_ops.cpp index dc200cad8d1..fa7693545e4 100644 --- a/src/simplify_dyn_ops.cpp +++ b/src/simplify_dyn_ops.cpp @@ -24,6 +24,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -131,10 +132,44 @@ struct find_const_4in_slice } }; +/** + * Simplify dimensions_of to a literal when the input arugment has a static shape + */ +struct find_static_dimensions_of +{ + auto matcher() const + { + return match::name("dimensions_of")(match::arg(0)(match::static_shape())); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto input = ins->inputs().at(0); + auto dimensions_of_value = ins->get_operator().to_value(); + auto start = dimensions_of_value.at("start").to(); + auto end = dimensions_of_value.at("end").to(); + std::size_t output_ndim = end - start; + std::vector vec_shape(output_ndim); + migraphx::shape s(migraphx::shape::int64_type, {output_ndim}); + std::vector input_lens = input->get_shape().lens(); + std::transform(input_lens.begin() + start, + input_lens.begin() + end, + vec_shape.begin(), + [](auto i) { return int64_t(i); }); + migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}}; + auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape}); + m.replace_instruction(ins, lit_ins); + } +}; + void simplify_dyn_ops::apply(module& m) const { - match::find_matches( - m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{}); + match::find_matches(m, + find_static_2in_broadcasts{}, + find_const_3in_slice{}, + find_const_4in_slice{}, + find_static_dimensions_of{}); } } // namespace MIGRAPHX_INLINE_NS diff --git a/test/simplify_dyn_ops_test.cpp b/test/simplify_dyn_ops_test.cpp index e529b78f1e1..32e039bd29a 100644 --- a/test/simplify_dyn_ops_test.cpp +++ b/test/simplify_dyn_ops_test.cpp @@ -210,4 +210,32 @@ TEST_CASE(const_slice_4input) EXPECT(m0 == m1); } +TEST_CASE(static_dimensions_of) +{ + migraphx::module m0; + { + migraphx::shape s{migraphx::shape::float_type, {2, 4, 4}}; + m0.add_parameter("data", s); + migraphx::shape lit_shape{migraphx::shape::int64_type, {3}}; + ; + std::vector lit_data = {2, 4, 4}; + auto lit_ins = m0.add_literal(migraphx::literal{lit_shape, lit_data}); + m0.add_return({lit_ins}); + } + + // dead_code_elimination will get rid of atan + migraphx::module m1; + { + migraphx::shape s{migraphx::shape::float_type, {2, 4, 4}}; + auto input = m1.add_parameter("data", s); + auto atan_ins = m1.add_instruction(migraphx::make_op("atan"), input); + auto dimensions_of_ins = + m1.add_instruction(migraphx::make_op("dimensions_of", {{"end", 3}}), atan_ins); + m1.add_return({dimensions_of_ins}); + } + run_pass(m1); + + EXPECT(m0 == m1); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }