Skip to content

Commit

Permalink
Merge branch 'develop' into bump-migraphx-commit-ptr-tiling-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Sep 12, 2023
2 parents 4c7e9cd + 64b306a commit 711e307
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 18 deletions.
18 changes: 17 additions & 1 deletion src/include/migraphx/check_shapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ struct check_shapes
check_dynamic();
}

template <class Op>
template <class Op, MIGRAPHX_REQUIRES(not std::is_convertible<Op, std::string>{})>
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
{
check_dynamic();
}

check_shapes(const std::vector<shape>& s, const std::string& n, const bool d = false)
: begin(s.begin()), end(s.end()), name(n), dynamic_allowed(d)
{
check_dynamic();
}

void check_dynamic() const
{
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
Expand Down Expand Up @@ -228,6 +234,16 @@ struct check_shapes
return *this;
}

/*!
* Check all shapes have the same layout.
*/
const check_shapes& same_layout() const
{
if(not this->same([](const shape& s) { return find_permutation(s); }))
MIGRAPHX_THROW(prefix() + "Layouts do not match");
return *this;
}

/*!
* Check all shapes are standard.
*/
Expand Down
6 changes: 4 additions & 2 deletions src/targets/gpu/include/migraphx/gpu/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ struct miopen_convolution
{
check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts(
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
check_shapes{conv_inputs, *this}
.max_ndims(5)
.packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}})
.same_layout();
return migraphx::compute_shape<Op>(op, conv_inputs);
}

Expand Down
45 changes: 30 additions & 15 deletions test/check_shapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,39 @@

using migraphx::shape;

bool create_shapes(bool dynamic_allowed)
void create_shapes(bool dynamic_allowed)
{
try
{
shape a{shape::int64_type, {3}};
shape b{shape::float_type, {{3, 6}, {4, 4}}};
auto op = migraphx::make_op("add");
migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2);
return true;
}
catch(...)
{
return false;
}
shape a{shape::int64_type, {3}};
shape b{shape::float_type, {{3, 6}, {4, 4}}};
migraphx::check_shapes{{a, b}, "", dynamic_allowed}.has(2);
}

TEST_CASE(allow_dynamic_shape) { EXPECT(create_shapes(true)); }
TEST_CASE(allow_dynamic_shape)
{
EXPECT(not test::throws([] { create_shapes(true); }));
}

TEST_CASE(fail_dynamic_shape)
{
EXPECT(test::throws([] { create_shapes(false); }));
}

TEST_CASE(fail_dynamic_shape) { EXPECT(not create_shapes(false)); }
TEST_CASE(same_layout_fail)
{
EXPECT(test::throws([] {
shape a{shape::float_type, {2, 3}};
shape b{shape::float_type, {2, 3}, {1, 2}};
migraphx::check_shapes{{a, b}, ""}.same_layout();
}));
}

TEST_CASE(same_layout_pass)
{
EXPECT(not test::throws([] {
shape a{shape::float_type, {2, 3}, {1, 2}};
shape b{shape::float_type, {2, 3}, {1, 2}};
migraphx::check_shapes{{a, b}, ""}.same_layout();
}));
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit 711e307

Please sign in to comment.