Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for axes input to reduce operators #2120

Merged
merged 36 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e7243d9
Implement parsing support for onnx opset 18 versions of reduce operators
music-dino Aug 18, 2023
714b10e
Implement codepath for handling a static data shape for reduce operators
music-dino Aug 19, 2023
ae024f2
Implement dynamic shape computation and refactor existing code
music-dino Aug 19, 2023
44be9e5
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Aug 21, 2023
904de2a
Add reference operator tests for ReduceMin, ReduceMax, and ReduceMean
music-dino Aug 23, 2023
a0f8e52
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Aug 24, 2023
4316e48
Disable failing onnx backend tests
music-dino Aug 31, 2023
b024436
Remove noop_with_empty_axes_attribute, fix code style issues
music-dino Sep 6, 2023
5b2f4c3
Restore ref_ops_test.cpp
music-dino Sep 6, 2023
8b91ef0
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Sep 6, 2023
4b200ca
Restore noop_with_empty_axes attribute
music-dino Sep 24, 2023
844e358
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Sep 25, 2023
075f663
Add op shape tests for variable axes case
music-dino Sep 26, 2023
2919f1d
Implement onnx parsing tests for variable axes
music-dino Sep 26, 2023
2fbd708
Add reference operator tests
music-dino Sep 26, 2023
5762d6c
Merge branch 'develop' into reduce_op_opset_18_compat
causten Oct 1, 2023
af06175
Update verify_range uses to verify_rms_range in reduce op reference t…
music-dino Oct 2, 2023
7d8bede
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Nov 29, 2023
53877e8
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Jan 24, 2024
0cd9b14
Rejig reduce operator parsing to remove complexity from the operator …
music-dino Jan 25, 2024
0bcca87
Implement parsing tests for ReduceL1
music-dino Jan 29, 2024
71d00e6
Implement parse tests for ReduceSum
music-dino Jan 29, 2024
2e324f8
Delete ReduceL1 onnx test files
music-dino Jan 29, 2024
05338d4
Update test name
music-dino Jan 29, 2024
399cde5
Update Where operator to handle scalar predicate
music-dino Feb 7, 2024
c6122d8
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Feb 7, 2024
95fa52e
Implement additional tests, fix parsing and licensing
music-dino Feb 7, 2024
5f7b702
Fix formatting, cppcheck, and tidy issues
music-dino Feb 7, 2024
2784e31
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Feb 26, 2024
2e6e5e8
Remove superfluous onnx test files, and add additional test cases for…
music-dino Feb 27, 2024
c350c5f
Add some documentation for reduce_op, and implement additional onnx_v…
music-dino Feb 28, 2024
9efb952
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Feb 28, 2024
13364ec
Fix formatting issue in gen_onnx.py
music-dino Feb 28, 2024
d92a664
Fixed formatting issue
music-dino Feb 28, 2024
766ecfe
Update code comments
music-dino Mar 5, 2024
bbdc7ac
Merge remote-tracking branch 'upstream/develop' into reduce_op_opset_…
music-dino Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 89 additions & 40 deletions src/include/migraphx/op/reduce_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,57 @@
return {{"normalize_axes", normalize}, {"reduce", true}};
}

std::vector<int64_t> tune_axes(std::size_t n_dim) const
shape collapse_reduced_axes(const shape& original_shape,
const std::vector<int64_t>& reduce_axes) const
{
auto tuned_axes = axes;
if(tuned_axes.empty())
auto lens = original_shape.lens();
for(const auto a : reduce_axes)
{
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
lens[a] = 1;
}

return tuned_axes;
return original_shape.with_lens(lens);
}

// Computes the output shape for a dynamic input data shape.
// inputs[0] is the data shape; the result keeps the input's rank and type.
// - axes attribute set: each listed axis is collapsed to exactly {1, 1}.
// - axes attribute empty: axes arrive as a runtime input instead, so any
//   subset of dimensions may end up reduced; every output dimension is
//   therefore bounded as {1, dim.max}.
// NOTE(review): assumes entries of `axes` are already normalized to valid
// non-negative indices into dims — presumably guaranteed by the
// "normalize_axes" op attribute; confirm against the normalization pass.
shape compute_dynamic_shape(const std::vector<shape>& inputs) const
{
    const auto& data_shape = inputs[0];
    auto dims = data_shape.dyn_dims();
    if(axes.empty())
    {
        // Runtime axes: any dimension may shrink to 1, none can grow.
        for(auto& dim : dims)
        {
            dim = {1, dim.max};
        }
    }
    else
    {
        // Static axes: reduced dimensions are known to become exactly 1.
        for(auto a : axes)
        {
            dims[a] = {1, 1};
        }
    }

    return {data_shape.type(), dims};
}

// Computes the output shape for a static (fully known) input data shape.
// inputs[0] is the data shape; rank and type are preserved.
// - axes attribute set: the listed dimensions collapse to 1, yielding a
//   static output shape.
// - axes attribute empty: the axes come in as a runtime input, so which
//   dimensions get reduced is unknown at compile time; the result is a
//   dynamic shape where each dimension ranges from 1 to the input length.
shape compute_static_shape(const std::vector<shape>& inputs) const
{
    const auto& data_shape = inputs[0];
    if(not axes.empty())
    {
        return collapse_reduced_axes(data_shape, axes);
    }

    // Variable-axes case: bound every output dimension by {1, input_len}.
    std::vector<shape::dynamic_dimension> dims;
    dims.reserve(data_shape.ndim());
    for(auto len : data_shape.lens())
    {
        dims.push_back(shape::dynamic_dimension{1, len});
    }

    return {data_shape.type(), std::move(dims)};
}

/**
Expand All @@ -115,29 +156,16 @@
*/
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto s = inputs.at(0);
if(s.dynamic())
{
auto output_dyn_dims = s.dyn_dims();
auto tuned_axes = tune_axes(output_dyn_dims.size());
for(const auto& axis : tuned_axes)
{
output_dyn_dims[axis] = {1, 1};
}
auto expected_arg_count = axes.empty() ? 2 : 1;
check_shapes{inputs, *this, true}.has(expected_arg_count);

return shape{s.type(), output_dyn_dims};
if(inputs[0].dynamic())
{
return compute_dynamic_shape(inputs);
}
else
{
auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size());
for(const auto& axis : tuned_axes)
{
lens[axis] = 1;
}

return inputs[0].with_lens(lens);
return compute_static_shape(inputs);
}
}

Expand All @@ -153,10 +181,10 @@
}

template <class T>
void reduce(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
void reduce(const tensor_view<T>& input,
const shape& batch_shape,
const std::vector<int64_t>& tuned_axes,
const std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
using accumulator = accumulator_type<T>;
Expand All @@ -173,24 +201,45 @@
static_cast<const Derived&>(*this).output(batch_shape)(val);
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
argument reduce(const shape& computed_shape,
const std::vector<int64_t>& reduce_axes,
argument& data_arg) const
{
argument result{dyn_out.computed_shape};
auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(dyn_out.computed_shape.lens().size(), 1);
tune_dims(tuned_axes, arg_lens, batch_lens);
shape batch_shape{dyn_out.computed_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto out_idx = dyn_out.computed_shape.multi(i);
this->reduce(input, batch_shape, tuned_axes, out_idx, output);
std::vector<std::size_t> batch_lens(computed_shape.ndim(), 1);
auto arg_lens = data_arg.get_shape().lens();
tune_dims(reduce_axes, arg_lens, batch_lens);
shape batch_shape{computed_shape.type(), batch_lens};
argument result{computed_shape};

visit_all(result, data_arg)([&](auto output, auto input) {
par_for(computed_shape.elements(), [&](auto i) {
auto out_idx = computed_shape.multi(i);
this->reduce(input, batch_shape, reduce_axes, out_idx, output);
});
});

return result;
}

// Evaluates the reduction.
// args[0] is the data tensor; when the axes attribute is empty a second
// argument carries the reduction axes at runtime.
// Fix: the second guard previously read `axes.empty() and args[1].empty()`,
// but `axes.empty()` is always true there — the first guard already returned
// whenever axes was non-empty (cppcheck knownConditionTrueFalse). The
// redundant conjunct is dropped.
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
    auto&& data_arg = args[0];
    // Static axes from the attribute: output shape was fully determined at
    // shape-computation time.
    if(not axes.empty())
    {
        return reduce(dyn_out.computed_shape, axes, data_arg);
    }

    // axes attribute is empty here. An empty runtime axes tensor means
    // "reduce nothing": the input passes through unchanged.
    if(args[1].empty())
    {
        return data_arg;
    }

    // Runtime axes: read them from the second argument and derive the
    // output shape by collapsing those dimensions of the data shape.
    std::vector<int64_t> reduce_axes;
    args[1].visit([&](auto&& s) { reduce_axes.assign(s.begin(), s.end()); });
    const auto result_shape = collapse_reduced_axes(data_arg.get_shape(), reduce_axes);

    return reduce(result_shape, reduce_axes, data_arg);
}

auto init() const { return zero(); }

auto input() const
Expand Down
Loading
Loading