diff --git a/src/include/migraphx/rewrite_pooling.hpp b/src/include/migraphx/rewrite_pooling.hpp index a250e255f78..7b715c02f3d 100644 --- a/src/include/migraphx/rewrite_pooling.hpp +++ b/src/include/migraphx/rewrite_pooling.hpp @@ -26,6 +26,7 @@ #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -39,6 +40,10 @@ struct MIGRAPHX_EXPORT rewrite_pooling { std::string name() const { return "rewrite_pooling"; } void apply(module& m) const; + + private: + void replace_with_reduce(module& m, instruction_ref ins) const; + void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) const; }; } // namespace MIGRAPHX_INLINE_NS diff --git a/src/rewrite_pooling.cpp b/src/rewrite_pooling.cpp index 0ecb07c9dae..a2944900897 100644 --- a/src/rewrite_pooling.cpp +++ b/src/rewrite_pooling.cpp @@ -46,38 +46,156 @@ void rewrite_pooling::apply(module& m) const auto&& s = ins->inputs().front()->get_shape(); if(not s.standard()) continue; - auto&& op = any_cast(ins->get_operator()); - if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; })) - continue; - if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; })) - continue; - if(not std::all_of(op.dilations.begin(), op.dilations.end(), [](auto i) { return i == 1; })) - continue; - auto lens = s.lens(); - if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end())) - continue; - std::int64_t n = s.lens()[0]; - std::int64_t c = s.lens()[1]; - auto reshape = m.insert_instruction( - ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front()); - instruction_ref pooling{}; - - // average pooling - if(op.mode == op::pooling_mode::average) + auto&& op = any_cast(ins->get_operator()); + bool same_kernel_as_shape = std::equal( + s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend()); + bool same_kernel_values = std::all_of(op.lengths.cbegin() + 1, + op.lengths.cend(), + [&op](auto i) { return i == op.lengths.at(0); }); + bool default_strides = + std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; }); + bool default_padding = + std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; }); + bool default_dilations = + std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; }); + if(same_kernel_as_shape and default_strides and default_padding and default_dilations) + { + replace_with_reduce(m, ins); + } + else if(not default_dilations) + { + // Dilated AvgPool with padding is not supported + if(not default_padding and op.mode == op::pooling_mode::average) + { + continue; + } + // Asym kernels not supported + if(not same_kernel_values) + { + continue; + } + auto size = + std::accumulate(s.lens().cbegin(), s.lens().cend(), 1, std::multiplies()); + // Can't handle too much size because of literal size + if(size > 100000) + { + continue; + } + + replace_dilations_with_gather_pooling(m, ins); + } + } +} + +void rewrite_pooling::replace_with_reduce(module& m, instruction_ref ins) const +{ + auto&& s = ins->inputs().front()->get_shape(); + auto&& op = any_cast(ins->get_operator()); + std::int64_t n = s.lens()[0]; + std::int64_t c = s.lens()[1]; + auto reshape = m.insert_instruction( + ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front()); + instruction_ref pooling{}; + + // average pooling + if(op.mode == op::pooling_mode::average) + { + pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); + } + // max pooling + else + { + pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); + } + + std::vector rsp_lens(s.lens().size(), 1); + rsp_lens[0] = n; + rsp_lens[1] = c; + m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling); +} + +void rewrite_pooling::replace_dilations_with_gather_pooling(module& m, instruction_ref ins) const +{ + // TODO remove this when MIOpen supports dilated pooling + auto&& s = ins->inputs().front()->get_shape(); + auto&& op = any_cast(ins->get_operator()); + // Ignore N, C axes + std::vector dims = {s.lens().cbegin() + 2, s.lens().cend()}; + + bool default_padding = + std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; }); + + if(not default_padding) + { + for(size_t idx{0}; idx < op.padding.size(); ++idx) { - pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); + // We need to pad both ends + dims[idx] += op.padding.at(idx) * 2; } - // max pooling - else + } + std::vector kernels = op.lengths; + std::vector strides = op.stride; + std::vector dilations = op.dilations; + + std::vector> axis_indices; + axis_indices.resize(dims.size()); + + for(auto idx{0}; idx < dims.size(); ++idx) + { + // Only consider if iw fits into the window + for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1); + stride += strides.at(idx)) { - pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); + for(size_t step{0}; step < kernels.at(idx); ++step) + { + axis_indices.at(idx).push_back(stride + dilations.at(idx) * step); + } + } + } + + auto elements = ins->inputs().front(); + if(not default_padding) + { + // Pad supports asym, we need to provide both ends + std::vector padding(2 * s.lens().size(), 0); + // Format will be e.g {N, C, P1, P2, N, C, P1, P2} + for(size_t idx{0}; idx < op.padding.size(); ++idx) + { + // Ignore N, C axes + padding.at(2 + idx) = op.padding.at(idx); + padding.at(2 + idx + s.lens().size()) = op.padding.at(idx); } - std::vector rsp_lens(lens.size(), 1); - rsp_lens[0] = n; - rsp_lens[1] = c; - m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling); + // Default value needed for Max pooling + elements = m.insert_instruction( + ins, + make_op("pad", {{"pads", padding}, {"value", std::numeric_limits::lowest()}}), + elements); } + + for(auto idx{0}; idx < axis_indices.size(); ++idx) + { + migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}}; + auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)}); + elements = m.insert_instruction( + ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices); + } + + // Ignore padding + std::vector new_padding(kernels.size(), 0); + // The kernel window elements (pairs) are places next to each other. E.g. {x1, y1, x2, y2, ...} + // We need to skip them to not overlap + std::vector new_strides(kernels.size(), kernels.at(0)); + // Ignore dilations + std::vector new_dilations(kernels.size(), 1); + m.replace_instruction(ins, + make_op("pooling", + {{"mode", op.mode}, + {"padding", new_padding}, + {"stride", new_strides}, + {"lengths", kernels}, + {"dilations", new_dilations}}), + elements); } } // namespace MIGRAPHX_INLINE_NS diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 4ae7203b6da..2e558c17ee0 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -289,7 +289,8 @@ def create_backend_test(testname=None, target_device=None): r'test_argmin_no_keepdims_example_select_last_index_cpu') backend_test.exclude(r'test_lrn_cpu') backend_test.exclude(r'test_lrn_default_cpu') - backend_test.exclude(r'test_maxpool_2d_dilations_cpu') + # MaxPool dialtion is partially supported on GPU by a workaround + # But these tests require too large allocations to work properly backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu') backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu') backend_test.exclude( diff --git a/test/rewrite_pooling_test.cpp b/test/rewrite_pooling_test.cpp index f8047c33e2c..136ffc7bb51 100644 --- a/test/rewrite_pooling_test.cpp +++ b/test/rewrite_pooling_test.cpp @@ -83,6 +83,328 @@ TEST_CASE(rewrite_pooling_test) test_rewrite(migraphx::op::pooling_mode::max, migraphx::make_op("reduce_max", {{"axes", {1}}})); } +TEST_CASE(rewrite_pooling_dialtions_test) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 5, 5}}; + auto pooling_program = [&](const migraphx::op::pooling_mode mode) { + migraphx::module m; + auto input = m.add_parameter("x", s); + auto ret = m.add_instruction(migraphx::make_op("pooling", + {{"mode", mode}, + {"padding", {0, 0}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}, + {"dilations", {2, 2}}}), + input); + m.add_return({ret}); + return m; + }; + + auto opt_program = [&](const migraphx::op::pooling_mode mode) { + migraphx::module m; + auto input = m.add_parameter("x", s); + std::vector indices{0, 2, 1, 3, 2, 4}; + migraphx::shape s_indices{migraphx::shape::int32_type, {indices.size()}}; + auto i1 = m.add_literal(migraphx::literal{s_indices, indices}); + auto g1 = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), input, i1); + auto i2 = m.add_literal(migraphx::literal{s_indices, indices}); + auto g2 = m.add_instruction(migraphx::make_op("gather", {{"axis", 3}}), g1, i2); + auto ret = m.add_instruction(migraphx::make_op("pooling", + {{"mode", mode}, + {"padding", {0, 0}}, + {"stride", {2, 2}}, + {"lengths", {2, 2}}, + {"dilations", {1, 1}}}), + g2); + m.add_return({ret}); + return m; + }; + + auto test_rewrite = [&](const migraphx::op::pooling_mode mode) { + migraphx::module m1 = pooling_program(mode); + migraphx::module m2 = opt_program(mode); + opt_pooling(m1); + EXPECT(m1 == m2); + }; + + test_rewrite(migraphx::op::pooling_mode::average); + test_rewrite(migraphx::op::pooling_mode::max); +} + +TEST_CASE(rewrite_pooling_dialtions_test2) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 1, 5, 5, 5}}; + auto pooling_program = [&](const migraphx::op::pooling_mode mode) { + migraphx::module m; + auto input = m.add_parameter("x", s); + auto ret = m.add_instruction(migraphx::make_op("pooling", + {{"mode", mode}, + {"padding", {0, 0, 0}}, + {"stride", {1, 1, 1}}, + {"lengths", {2, 2, 2}}, + {"dilations", {2, 2, 2}}}), + input); + m.add_return({ret}); + return m; + }; + + auto opt_program = [&](const migraphx::op::pooling_mode mode) { + migraphx::module m; + auto input = m.add_parameter("x", s); + std::vector indices{0, 2, 1, 3, 2, 4}; + migraphx::shape s_indices{migraphx::shape::int32_type, {indices.size()}}; + auto i1 = m.add_literal(migraphx::literal{s_indices, indices}); + auto g1 = m.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), input, i1); + auto i2 = m.add_literal(migraphx::literal{s_indices, indices}); + auto g2 = m.add_instruction(migraphx::make_op("gather", {{"axis", 3}}), g1, i2); + auto i3 = m.add_literal(migraphx::literal{s_indices, indices}); + auto g3 = m.add_instruction(migraphx::make_op("gather", {{"axis", 4}}), g2, i3); + auto ret = m.add_instruction(migraphx::make_op("pooling", + {{"mode", mode}, + {"padding", {0, 0, 0}}, + {"stride", {2, 2, 2}}, + {"lengths", {2, 2, 2}}, + {"dilations", {1, 1, 1}}}), + g3); + m.add_return({ret}); + return m; + }; + + auto test_rewrite = [&](const migraphx::op::pooling_mode mode) { + migraphx::module m1 = pooling_program(mode); + migraphx::module m2 = opt_program(mode); + opt_pooling(m1); + EXPECT(m1 == m2); + }; + + test_rewrite(migraphx::op::pooling_mode::average); + test_rewrite(migraphx::op::pooling_mode::max); +} + +TEST_CASE(rewrite_pooling_dialtions_test3) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 5}}; + auto pooling_program = [&]() { + migraphx::module m; + + auto input = m.add_parameter("x", s); + auto ret = + m.add_instruction(migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"padding", {1}}, + {"stride", {1}}, + {"lengths", {3}}, + {"dilations", {2}}}), + input); + m.add_return({ret}); + return m; + }; + + migraphx::module m1 = pooling_program(); + migraphx::module m2 = m1; + + opt_pooling(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(rewrite_avgpool_rank3_dil_test) +{ + // 1D case 1, input is 3D + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {1}; + op.dilations = {2}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + opt_pooling(*mm); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.35, 0.15, 0.85, 0.3, 0.1, 0.65}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(rewrite_avgpool_rank3_dil_test2) +{ + // 1D case 1, input is 3D + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {1}; + op.dilations = {3}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + opt_pooling(*mm); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.2, 0.45, 0.35}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(rewrite_maxpool_rank3_test) +{ + // 1D case 1, input is 3D + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2}; + op.padding = {0}; + op.stride = {1}; + op.dilations = {2}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + opt_pooling(*mm); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.4, 0.2, 0.9, 0.5, 0.1, 0.7}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(rewrite_maxpool_rank3_test2) +{ + // 1D case 1, input is 3D + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2}; + op.padding = {1}; + op.stride = {1}; + op.dilations = {3}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + opt_pooling(*mm); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.4, 0.3, 0.2, 0.9, 0.8, 0.5, 0.1, 0.6, 0.7}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(rewrite_maxpool_rank3_test3) +{ + // 1D case 1, input is 3D + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 4}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {3}; + op.padding = {2}; + op.stride = {2}; + op.dilations = {3}; + + std::vector data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + opt_pooling(*mm); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.2, 0.5, 0.7}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(maxpool_rank5_test) +{ + // 3D, input is 5D + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2, 2, 2}; + op.padding = {0, 0, 0}; + op.stride = {1, 1, 1}; + op.dilations = {2, 2, 2}; + + std::vector data{ + -2.8029, 0.5861, 0.7015, 0.1297, -1.44, -1.9472, 0.7812, 2.408, -0.3145, 0.3405, + -0.9146, 0.0624, 1.5064, -0.8345, 1.7977, 1.8949, 1.0073, -0.2102, -0.042, -0.7146, + 0.6227, -0.5263, -2.2598, 0.1713, 0.449, 0.5303, -0.8622, -0.5691, 0.907, -0.0569, + -1.5348, -0.4109, -0.1461, -0.5445, 0.4266, 0.2282, 1.3655, -2.1519, 0.6068, -0.2001, + -0.4702, 0.3864, 1.7083, 0.9096, 0.4286, -1.8866, 0.7034, 0.0293, 1.4587, 0.7672, + -2.8614, 0.8124, -0.053, 1.0449, 0.845, -0.0131, 0.1139, -0.859, -1.2681, -0.6337, + -0.4644, 0.1938, 0.2889, 0.9035, 0.7118, -0.5767, 0.4577, -0.0549, 0.2237, 0.5756, + 0.0677, -0.0223, -0.329, 0.2364, 2.7666, -0.7417, -1.3196, -0.2655, 0.1698, -0.1777, + -0.9427, 2.6859, -0.7501, 0.5175, 1.0029, -2.6436, -0.4388, -1.2348, -0.1539, -0.6229, + -0.4136, 0.5085, 0.4136, -0.6439, -1.1953, -0.406, -0.0195, 0.1869, -0.8664, 1.1364, + 0.5041, 0.0647, 0.1941, -1.0819, -0.4629, -0.5107, 0.3612, -0.3583}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + opt_pooling(*mm); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{0.7812, 1.0449, 2.7666, 2.6859}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(maxpool_rank5_test2) +{ + // 3D, input is 5D + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {2, 2, 3, 3, 3}}; + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + op.lengths = {2, 2, 2}; + op.padding = {2, 2, 2}; + op.stride = {2, 2, 2}; + op.dilations = {3, 3, 3}; + + std::vector data{ + -2.8029, 0.5861, 0.7015, 0.1297, -1.44, -1.9472, 0.7812, 2.408, -0.3145, 0.3405, + -0.9146, 0.0624, 1.5064, -0.8345, 1.7977, 1.8949, 1.0073, -0.2102, -0.042, -0.7146, + 0.6227, -0.5263, -2.2598, 0.1713, 0.449, 0.5303, -0.8622, -0.5691, 0.907, -0.0569, + -1.5348, -0.4109, -0.1461, -0.5445, 0.4266, 0.2282, 1.3655, -2.1519, 0.6068, -0.2001, + -0.4702, 0.3864, 1.7083, 0.9096, 0.4286, -1.8866, 0.7034, 0.0293, 1.4587, 0.7672, + -2.8614, 0.8124, -0.053, 1.0449, 0.845, -0.0131, 0.1139, -0.859, -1.2681, -0.6337, + -0.4644, 0.1938, 0.2889, 0.9035, 0.7118, -0.5767, 0.4577, -0.0549, 0.2237, 0.5756, + 0.0677, -0.0223, -0.329, 0.2364, 2.7666, -0.7417, -1.3196, -0.2655, 0.1698, -0.1777, + -0.9427, 2.6859, -0.7501, 0.5175, 1.0029, -2.6436, -0.4388, -1.2348, -0.1539, -0.6229, + -0.4136, 0.5085, 0.4136, -0.6439, -1.1953, -0.406, -0.0195, 0.1869, -0.8664, 1.1364, + 0.5041, 0.0647, 0.1941, -1.0819, -0.4629, -0.5107, 0.3612, -0.3583}; + auto l0 = mm->add_literal(migraphx::literal{s, data}); + mm->add_instruction(op, l0); + opt_pooling(*mm); + p.compile(migraphx::make_target("ref")); + auto result = p.eval({}).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + std::vector gold{-0.8345, 1.5064, -0.9146, 0.3405, -1.44, 0.1297, 0.5861, -2.8029, + -0.4702, -0.2001, -2.1519, 1.3655, -0.4109, -1.5348, 0.907, -0.5691, + -0.0549, 0.4577, 0.7118, 0.9035, -1.2681, -0.859, -0.0131, 0.845, + -1.1953, -0.6439, 0.5085, -0.4136, -2.6436, 1.0029, -0.7501, 2.6859}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + TEST_CASE(rewrite_avepooling_na1_test) { migraphx::shape s{migraphx::shape::float_type, {2, 2, 3, 4, 5}};