Skip to content

Commit

Permalink
Fix dyn broadcasting (#2821)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Mar 1, 2024
1 parent d9b81bd commit 84fc9f0
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 95 deletions.
19 changes: 8 additions & 11 deletions src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,13 @@ compute_broadcasted_dyn_dims(std::vector<shape::dynamic_dimension> dds0,
{
return b;
}
else if(a.within_range(b))
{
return shape::dynamic_dimension{a.min, a.max};
}
else if(b.within_range(a))
{
return shape::dynamic_dimension{b.min, b.max};
}
else
{
auto intersect = a.intersection(b);
if(intersect.has_value())
{
return intersect.value();
}
MIGRAPHX_THROW("COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {" +
migraphx::to_string_range(dds0) + "} and {" +
migraphx::to_string_range(dds1) + "} mismatch!");
Expand Down Expand Up @@ -225,10 +222,10 @@ instruction_ref add_common_op(module& m, const operation& op, std::vector<instru
return insert_common_op(m, m.end(), op, std::move(inputs));
}

shape make_bcast_shape(const shape& input_shape,
const std::vector<std::size_t>& bcast_lens,
const std::size_t& offset)
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens)
{
assert(not input_shape.dynamic());
auto offset = bcast_lens.size() - input_shape.ndim();
std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
for(std::ptrdiff_t i : reverse(range(input_shape.ndim())))
{
Expand Down
29 changes: 14 additions & 15 deletions src/include/migraphx/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,16 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
* Compares `dynamic_dimension` objects from the trailing (right-most) dimension and working
* leftwards.
*
* Rules for broadcasting:
* Rules for broadcasting dynamic_dimension:
* If the same `dynamic_dimension`, return either.
* If one of the `dynamic_dimension`s is 1, return the other one.
* If one `dynamic_dimension` can fit within the range of the other,
* return the `dynamic_dimension` with the smaller range.
* Else, throw an error.
*
* For the within_range() cases the broadcasting works out to outputting the smaller ranger because
* for the shape to be broadcastable at runtime (when the dimensions are constant) the dimensions
* must be the same. The only way for the dimensions to be the same is if the output dimension is
* the intersection of the ranges. The current code only handles if one range is within the other,
* but it can be extended to do the intersection of the ranges.
* This case is mainly for handling unknown dynamic_dimensions like {0, max_int}.
* If the `dynamic_dimension`s have an intersection, return the intersection.
* Explanation:
* For the shape to be broadcastable at runtime (when the dimensions are constant) the dimensions
* must be the same. The only way for the dimensions to be the same is if the output dimension is
* the intersection of the ranges.
* In practice, we will mostly see this case for handling unknown dynamic_dimensions like {0,
* max_int}. Else, throw an error.
*
* There is a contrived edge case for ranges that include 1 but are not a fixed {1, 1}.
* That case is not supported.
Expand Down Expand Up @@ -134,12 +131,14 @@ MIGRAPHX_EXPORT
instruction_ref add_common_op(module& m, const operation& op, std::vector<instruction_ref> inputs);

/**
* Calculates the broadcasted strides from broadcast_for_dot or multibroadcast.
* Calculates the broadcasted shape with the given input_shape and broadcasted dimensions.
*
* @param input_shape static shape to broadcast
* @param bcast_lens dimensions to broadcast to
* @return broadcasted shape with calculated strides
*/
MIGRAPHX_EXPORT
shape make_bcast_shape(const shape& input_shape,
const std::vector<std::size_t>& bcast_lens,
const std::size_t& offset);
shape make_bcast_shape(const shape& input_shape, const std::vector<std::size_t>& bcast_lens);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
3 changes: 1 addition & 2 deletions src/include/migraphx/op/broadcast_for_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ struct broadcast_for_dot
std::vector<std::size_t> l1_broadcasted_lens(s1.lens().begin(), l1_it);
auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
output_lens.insert(output_lens.end(), l0_it, s0.lens().end());
auto offset = output_lens.size() - s0.ndim();
return make_bcast_shape(s0, output_lens, offset);
return make_bcast_shape(s0, output_lens);
}
}

Expand Down
43 changes: 18 additions & 25 deletions src/include/migraphx/op/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,30 +57,23 @@ struct dot
auto s1 = b.to_dynamic();
std::vector<shape::dynamic_dimension> out_dyn_dims;

// check outer dimensions are within range
// put within range dynamic_dimensions into the out_dyn_dims
bool outers_within_range = std::equal(s0.dyn_dims().begin(),
s0.dyn_dims().end() - 2,
s1.dyn_dims().begin(),
s1.dyn_dims().end() - 2,
[&](auto x, auto y) {
if(x.within_range(y))
{
out_dyn_dims.push_back(x);
return true;
}
if(y.within_range(x))
{
out_dyn_dims.push_back(y);
return true;
}
return false;
});
// check outer dynamic dimensions are the same
bool same_outers = std::equal(s0.dyn_dims().begin(),
s0.dyn_dims().end() - 2,
s1.dyn_dims().begin(),
s1.dyn_dims().end() - 2,
[&](auto x, auto y) {
if(x == y)
{
out_dyn_dims.push_back(x);
return true;
}
return false;
});

if(not outers_within_range)
if(not same_outers)
{
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch or not within "
"dynamic_dimension range: {" +
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B are not compatible: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
Expand All @@ -89,10 +82,10 @@ struct dot
auto x = s0.dyn_dims()[dim_j];
auto y = s1.dyn_dims()[dim_i];

// check inner dimensions are within range
if(not x.within_range(y) and not y.within_range(x))
// check inner dimensions are compatible
if(not x.intersection(y).has_value())
{
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
MIGRAPHX_THROW("DOT: dynamic inner dimensions are not compatible: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
Expand Down
5 changes: 2 additions & 3 deletions src/include/migraphx/op/multibroadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct multibroadcast
}
}

return make_bcast_shape(s0, output_lens, offset);
return make_bcast_shape(s0, output_lens);
}
else
{
Expand All @@ -106,8 +106,7 @@ struct multibroadcast
{
// output_lens will not be set for 2+ input version
auto bcast_lens = compute_common_lens(inputs);
auto offset = bcast_lens.size() - s0.ndim();
return make_bcast_shape(s0, bcast_lens, offset);
return make_bcast_shape(s0, bcast_lens);
}
}
}
Expand Down
14 changes: 12 additions & 2 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,19 @@ struct MIGRAPHX_EXPORT shape
bool is_fixed() const;
bool has_optimal() const;

bool within_range(const dynamic_dimension& other) const
/**
* Return a dynamic_dimension with the intersection of two dynamic_dimension ranges if
* possible.
*/
std::optional<dynamic_dimension> intersection(const dynamic_dimension& other) const
{
return ((this->min >= other.min) and (this->max <= other.max));
auto left = std::max(this->min, other.min);
auto right = std::min(this->max, other.max);
if(left <= right)
{
return dynamic_dimension{left, right};
}
return nullopt;
}

MIGRAPHX_EXPORT friend bool operator==(const dynamic_dimension& x,
Expand Down
76 changes: 52 additions & 24 deletions test/op_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -802,24 +802,14 @@ TEST_CASE(dot_dyn_static_test0)

TEST_CASE(dot_dyn_static_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 3}, {5, 5}, {5, 5}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{3, 3}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {3, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}

TEST_CASE(dot_dyn_static_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{2, 2}, {3, 3}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}

TEST_CASE(dot_dyn_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
Expand All @@ -842,7 +832,7 @@ TEST_CASE(dot_dyn_test1)

TEST_CASE(dot_dyn_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 20}, {5, 5}, {5, 5}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 1}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
Expand All @@ -863,20 +853,27 @@ TEST_CASE(dot_dyn_test3)

TEST_CASE(dot_dyn_test4)
{
std::size_t max_val = std::numeric_limits<std::size_t>::max();
migraphx::shape s_m1{migraphx::shape::float_type, {{0, max_val}, {5, 5}, {0, max_val}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 8}, {5, 5}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{4, 8}, {5, 5}, {8, 8}}},
// Note how the inner dimensions have an intersection in range
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}, {4, 8}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{1, 4}, {5, 9}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}

TEST_CASE(dot_dyn_mismatcher_outer)
TEST_CASE(dot_dyn_inner_mismatch)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}, {4, 8}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{1, 4}, {10, 20}, {8, 8}}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}

TEST_CASE(dot_dyn_test_outer_mismatch)
{

migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {1, 4}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{3, 8}, {5, 5}, {6, 8, {8}}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 8}, {5, 5}, {6, 8, {8}}}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}

Expand Down Expand Up @@ -929,6 +926,22 @@ TEST_CASE(broadcast_for_dot_dyn1)
s0);
}

TEST_CASE(broadcast_for_dot_dyn2)
{
migraphx::shape s0{migraphx::shape::float_type, {{6, 12}, {4, 4}, {8, 8}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {2, 10}, {8, 8}, {4, 4}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {4, 4}, {8, 8}}},
migraphx::make_op("broadcast_for_dot"),
s0,
s1);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, {1, 2, 4}}, {6, 10}, {8, 8}, {4, 4}}},
migraphx::make_op("broadcast_for_dot"),
s1,
s0);
}

TEST_CASE(flatten_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
Expand Down Expand Up @@ -1800,9 +1813,9 @@ TEST_CASE(multibroadcast_2in_static_dyn2)
a_shape);
}

TEST_CASE(multibroadcast_2in_static_dyn_within0)
TEST_CASE(multibroadcast_2in_static_dyn_intersection0)
{
// dynamic_dimension.within_range for first dimension
// dynamic_dimension.intersection for first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
Expand All @@ -1816,9 +1829,24 @@ TEST_CASE(multibroadcast_2in_static_dyn_within0)
a_shape);
}

TEST_CASE(multibroadcast_2in_static_dyn_within1)
TEST_CASE(multibroadcast_2in_static_dyn_intersection1)
{
std::vector<migraphx::shape::dynamic_dimension> a_dds{{5, 10}, {1, 6}};
migraphx::shape a_shape{migraphx::shape::float_type, a_dds};
std::vector<migraphx::shape::dynamic_dimension> b_dds{{3, 8}, {3, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b_dds};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{5, 8}, {3, 6}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{5, 8}, {3, 6}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}

TEST_CASE(multibroadcast_2in_static_dyn_intersection2)
{
// dynamic_dimension.within_range for first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
auto max_val = std::numeric_limits<std::size_t>::max();
std::vector<migraphx::shape::dynamic_dimension> b{{0, max_val}, {6, 6}};
Expand All @@ -1833,9 +1861,9 @@ TEST_CASE(multibroadcast_2in_static_dyn_within1)
a_shape);
}

TEST_CASE(multibroadcast_2in_static_dyn_within_error)
TEST_CASE(multibroadcast_2in_static_dyn_intersection_error)
{
// not dynamic_dimension.within_range for first dimension
// not compatible for first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
Expand Down
36 changes: 23 additions & 13 deletions test/shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,33 @@ TEST_CASE(dynamic_dimension_add_sub_fixed)
EXPECT((2 + e) == d);
}

TEST_CASE(dynamic_dimension_within_range)
TEST_CASE(dynamic_dimension_intersection)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, {2, 5}};
auto b = shape::dynamic_dimension{3, 4};
EXPECT(b.within_range(a));
EXPECT(not a.within_range(b));

auto c = shape::dynamic_dimension{3, 4};
EXPECT(c.within_range(b));
EXPECT(b.within_range(c));

auto d = shape::dynamic_dimension{0, std::numeric_limits<std::size_t>::max()};
EXPECT(a.within_range(d));
EXPECT(b.within_range(d));
EXPECT(not d.within_range(a));
EXPECT(not d.within_range(b));
auto aib = a.intersection(b);
auto bia = b.intersection(a);
EXPECT(aib.has_value());
EXPECT(bia.has_value());
EXPECT(aib.value() == shape::dynamic_dimension{3, 4});
EXPECT(aib.value() == bia.value());

auto c = shape::dynamic_dimension{3, 8};
auto cia = c.intersection(a);
EXPECT(cia.value() == shape::dynamic_dimension{3, 5});

auto d = shape::dynamic_dimension{8, 10};
auto dib = d.intersection(b);
EXPECT(not dib.has_value());

auto e = shape::dynamic_dimension{4, 10};
auto eib = e.intersection(b);
EXPECT(eib.value() == shape::dynamic_dimension{4, 4});

auto f = shape::dynamic_dimension{0, std::numeric_limits<std::size_t>::max()};
auto fib = f.intersection(b);
EXPECT(fib.value() == shape::dynamic_dimension{3, 4});
}

TEST_CASE(dynamic_dimension_serialize)
Expand Down

0 comments on commit 84fc9f0

Please sign in to comment.