diff --git a/src/include/migraphx/check_shapes.hpp b/src/include/migraphx/check_shapes.hpp index ced99e5d593..a273d4627d0 100644 --- a/src/include/migraphx/check_shapes.hpp +++ b/src/include/migraphx/check_shapes.hpp @@ -70,13 +70,19 @@ struct check_shapes check_dynamic(); } - template + template {})> check_shapes(const std::vector& s, const Op& op, const bool d = false) : begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d) { check_dynamic(); } + check_shapes(const std::vector& s, const std::string& n, const bool d = false) + : begin(s.begin()), end(s.end()), name(n), dynamic_allowed(d) + { + check_dynamic(); + } + void check_dynamic() const { if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); })) @@ -228,6 +234,16 @@ struct check_shapes return *this; } + /*! + * Check all shapes have the same layout. + */ + const check_shapes& same_layout() const + { + if(not this->same([](const shape& s) { return find_permutation(s); })) + MIGRAPHX_THROW(prefix() + "Layouts do not match"); + return *this; + } + /*! * Check all shapes are standard. */ diff --git a/src/targets/gpu/include/migraphx/gpu/convolution.hpp b/src/targets/gpu/include/migraphx/gpu/convolution.hpp index 0b85075c66d..d6680f17ec8 100644 --- a/src/targets/gpu/include/migraphx/gpu/convolution.hpp +++ b/src/targets/gpu/include/migraphx/gpu/convolution.hpp @@ -84,8 +84,10 @@ struct miopen_convolution { check_shapes{inputs, op}.has(4); std::vector conv_inputs(inputs.begin(), inputs.begin() + 2); - check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts( - {{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}}); + check_shapes{conv_inputs, *this} + .max_ndims(5) + .packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}}) + .same_layout(); return migraphx::compute_shape(op, conv_inputs); } diff --git a/test/check_shapes_test.cpp b/test/check_shapes_test.cpp index 021877ad471..557c9f1b7a8 100644 --- a/test/check_shapes_test.cpp +++ b/test/check_shapes_test.cpp @@ -31,24 +31,39 @@ using migraphx::shape; -bool create_shapes(bool dynamic_allowed) +void create_shapes(bool dynamic_allowed) { - try - { - shape a{shape::int64_type, {3}}; - shape b{shape::float_type, {{3, 6}, {4, 4}}}; - auto op = migraphx::make_op("add"); - migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2); - return true; - } - catch(...) - { - return false; - } + shape a{shape::int64_type, {3}}; + shape b{shape::float_type, {{3, 6}, {4, 4}}}; + migraphx::check_shapes{{a, b}, "", dynamic_allowed}.has(2); } -TEST_CASE(allow_dynamic_shape) { EXPECT(create_shapes(true)); } +TEST_CASE(allow_dynamic_shape) +{ + EXPECT(not test::throws([] { create_shapes(true); })); +} + +TEST_CASE(fail_dynamic_shape) +{ + EXPECT(test::throws([] { create_shapes(false); })); +} -TEST_CASE(fail_dynamic_shape) { EXPECT(not create_shapes(false)); } +TEST_CASE(same_layout_fail) +{ + EXPECT(test::throws([] { + shape a{shape::float_type, {2, 3}}; + shape b{shape::float_type, {2, 3}, {1, 2}}; + migraphx::check_shapes{{a, b}, ""}.same_layout(); + })); +} + +TEST_CASE(same_layout_pass) +{ + EXPECT(not test::throws([] { + shape a{shape::float_type, {2, 3}, {1, 2}}; + shape b{shape::float_type, {2, 3}, {1, 2}}; + migraphx::check_shapes{{a, b}, ""}.same_layout(); + })); +} int main(int argc, const char* argv[]) { test::run(argc, argv); }