Commit

Add support for axes input to reduce operators (#2120)
music-dino authored Mar 6, 2024
1 parent effedcd commit 11f6abc
Showing 23 changed files with 1,293 additions and 119 deletions.
141 changes: 100 additions & 41 deletions src/include/migraphx/op/reduce_op.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -94,16 +94,69 @@ struct reduce_op : op_name<Derived>
return {{"normalize_axes", normalize}, {"reduce", true}};
}

std::vector<int64_t> tune_axes(std::size_t n_dim) const
shape collapse_reduced_axes(const shape& original_shape,
const std::vector<int64_t>& reduce_axes) const
{
auto tuned_axes = axes;
if(tuned_axes.empty())
auto lens = original_shape.lens();
for(const auto a : reduce_axes)
{
tuned_axes.resize(n_dim);
std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
lens[a] = 1;
}

return tuned_axes;
return original_shape.with_lens(lens);
}

// Compute the output shape for cases when the input tensor has a dynamic shape.
//
// If the axes are passed as a variable input (indicated by an empty axes attribute), we cannot
// determine which axes must be collapsed until we see the actual input values, so we must treat
// each axis as potentially collapsible and set its minimum dimension to 1.
shape compute_dynamic_shape(const std::vector<shape>& inputs) const
{
const auto& data_shape = inputs[0];
auto dims = data_shape.dyn_dims();
if(axes.empty())
{
for(auto& dim : dims)
{
dim = {1, dim.max};
}
}
else
{
for(auto a : axes)
{
dims[a] = {1, 1};
}
}

return {data_shape.type(), dims};
}

// Compute the output shape for cases when the input tensor has a static shape.
// Depending on how axes is passed to the operator the output shape can be either dynamic or
// static.
//
// If the axes are passed as a variable input (indicated by an empty axes attribute), we cannot
// determine which axes must be collapsed until we see the actual input values, so we must treat
// each axis as potentially collapsible, producing a dynamic output shape.
shape compute_static_shape(const std::vector<shape>& inputs) const
{
const auto& data_shape = inputs[0];
if(axes.empty())
{
std::vector<shape::dynamic_dimension> dims(data_shape.ndim());
auto lens = data_shape.lens();
std::transform(lens.begin(), lens.end(), dims.begin(), [](auto len) {
return shape::dynamic_dimension{1, len};
});

return {data_shape.type(), std::move(dims)};
}
else
{
return collapse_reduced_axes(data_shape, axes);
}
}

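To make the shape rules above concrete, here is a small standalone sketch of what the new shape computation produces for a static {2, 3, 4} input. It uses plain std:: containers rather than migraphx::shape, so the types are only an analogy: the attribute path collapses just the named axes, while the runtime-axes path must widen every dimension to a {1, len} range.

```cpp
// Standalone illustration of the shape rules above; plain std:: types stand in
// for migraphx::shape, so this is an analogy rather than the real API.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main()
{
    const std::vector<std::size_t> lens{2, 3, 4};

    // Axes known from the attribute (e.g. axes = {1}): static output {2, 1, 4}.
    const std::vector<int64_t> attr_axes{1};
    auto static_out = lens;
    for(auto a : attr_axes)
        static_out[a] = 1;

    // Axes supplied as a runtime input: any axis may end up reduced, so every
    // dimension becomes a dynamic range {1, len} -> {{1,2}, {1,3}, {1,4}}.
    std::vector<std::pair<std::size_t, std::size_t>> dynamic_out;
    for(auto len : lens)
        dynamic_out.emplace_back(1, len);

    for(auto d : static_out)
        std::cout << d << ' ';
    std::cout << '\n';
    for(const auto& d : dynamic_out)
        std::cout << '{' << d.first << ',' << d.second << "} ";
    std::cout << '\n';
    return 0;
}
```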
/**
@@ -115,29 +168,16 @@ struct reduce_op : op_name<Derived>
*/
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto s = inputs.at(0);
if(s.dynamic())
{
auto output_dyn_dims = s.dyn_dims();
auto tuned_axes = tune_axes(output_dyn_dims.size());
for(const auto& axis : tuned_axes)
{
output_dyn_dims[axis] = {1, 1};
}
auto expected_arg_count = axes.empty() ? 2 : 1;
check_shapes{inputs, *this, true}.has(expected_arg_count);

return shape{s.type(), output_dyn_dims};
if(inputs[0].dynamic())
{
return compute_dynamic_shape(inputs);
}
else
{
auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size());
for(const auto& axis : tuned_axes)
{
lens[axis] = 1;
}

return inputs[0].with_lens(lens);
return compute_static_shape(inputs);
}
}

@@ -153,10 +193,10 @@ struct reduce_op : op_name<Derived>
}

template <class T>
void reduce(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
void reduce(const tensor_view<T>& input,
const shape& batch_shape,
const std::vector<int64_t>& tuned_axes,
const std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
using accumulator = accumulator_type<T>;
@@ -173,24 +213,43 @@ struct reduce_op : op_name<Derived>
static_cast<const Derived&>(*this).output(batch_shape)(val);
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
argument reduce(const shape& computed_shape,
const std::vector<int64_t>& reduce_axes,
argument& data_arg) const
{
argument result{dyn_out.computed_shape};
auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(dyn_out.computed_shape.lens().size(), 1);
tune_dims(tuned_axes, arg_lens, batch_lens);
shape batch_shape{dyn_out.computed_shape.type(), batch_lens};
visit_all(result, args[0])([&](auto output, auto input) {
par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto out_idx = dyn_out.computed_shape.multi(i);
this->reduce(input, batch_shape, tuned_axes, out_idx, output);
std::vector<std::size_t> batch_lens(computed_shape.ndim(), 1);
auto arg_lens = data_arg.get_shape().lens();
tune_dims(reduce_axes, arg_lens, batch_lens);
shape batch_shape{computed_shape.type(), batch_lens};
argument result{computed_shape};

visit_all(result, data_arg)([&](auto output, auto input) {
par_for(computed_shape.elements(), [&](auto i) {
auto out_idx = computed_shape.multi(i);
this->reduce(input, batch_shape, reduce_axes, out_idx, output);
});
});

return result;
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
auto&& data_arg = args[0];
// cppcheck-suppress knownConditionTrueFalse
if(not axes.empty())
return reduce(dyn_out.computed_shape, axes, data_arg);

if(args[1].get_shape().elements() == 0)
return args[0];

std::vector<int64_t> reduce_axes;
args[1].visit([&](auto&& s) { reduce_axes.assign(s.begin(), s.end()); });
const auto result_shape = collapse_reduced_axes(data_arg.get_shape(), reduce_axes);

return reduce(result_shape, reduce_axes, data_arg);
}

auto init() const { return zero(); }

auto input() const
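Taken together, the new compute() above amounts to a three-way dispatch on where the reduction axes come from: a non-empty axes attribute wins, an empty runtime axes tensor turns the operator into a pass-through, and otherwise the axes are read from the second argument. Below is a minimal sketch of that ordering, using std::vector stand-ins for migraphx::argument and computing only the collapsed output lengths; the function names are illustrative, not part of the MIGraphX API.

```cpp
// Illustrative sketch of the dispatch order in the new compute(); not the
// MIGraphX API. Only the collapsed output lengths are computed here.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Collapse the reduced axes to length 1, mirroring collapse_reduced_axes.
std::vector<std::size_t> collapse(std::vector<std::size_t> lens,
                                  const std::vector<int64_t>& axes)
{
    for(auto a : axes)
        lens[a] = 1;
    return lens;
}

std::vector<std::size_t> output_lens(const std::vector<int64_t>& attr_axes,  // axes attribute
                                     const std::vector<int64_t>& input_axes, // optional 2nd input
                                     const std::vector<std::size_t>& data_lens)
{
    // 1. A non-empty axes attribute takes priority; the operator has one input.
    if(not attr_axes.empty())
        return collapse(data_lens, attr_axes);
    // 2. An empty runtime axes tensor reduces nothing: the data passes through.
    if(input_axes.empty())
        return data_lens;
    // 3. Otherwise the axes are read from the second argument at run time.
    return collapse(data_lens, input_axes);
}

int main()
{
    for(auto d : output_lens({}, {0, 2}, {2, 3, 4}))
        std::cout << d << ' '; // prints: 1 3 1
    std::cout << '\n';
    for(auto d : output_lens({}, {}, {2, 3, 4}))
        std::cout << d << ' '; // prints: 2 3 4 (pass-through)
    std::cout << '\n';
    return 0;
}
```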
22 changes: 17 additions & 5 deletions src/include/migraphx/op/where.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -42,7 +42,13 @@ struct where

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(3).same_dims();
check_shapes shape_checker{inputs, *this, true};
shape_checker.has(3);
if(auto s = inputs[0]; not s.dynamic() and s.elements() == 1)
check_shapes{std::next(inputs.begin()), inputs.end(), *this, true}.same_dims();
else
shape_checker.same_dims();

auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s1.dynamic() or s2.dynamic())
@@ -71,12 +77,18 @@ struct where
}
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{dyn_out.computed_shape};
if(auto s = args[0].get_shape(); not s.dynamic() and s.elements() == 1)
return args[args[0].at<bool>() ? 1 : 2].copy();

if(output_shape.dynamic())
output_shape = compute_shape(to_shapes(args));
argument result{output_shape};

visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) {
par_for(dyn_out.computed_shape.elements(),
par_for(output_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
});
});
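The where change above adds a fast path for a single-element condition: the whole x or y tensor is returned outright, which is why the shape check only requires inputs 1 and 2 to match each other rather than the condition. Here is a hedged sketch of that behaviour with std::vector in place of migraphx::argument; the helper name is illustrative only.

```cpp
// Sketch of the scalar-condition fast path; std::vector stands in for
// migraphx::argument and where_select is an illustrative name only.
#include <iostream>
#include <vector>

std::vector<float> where_select(bool condition,              // one-element condition tensor
                                const std::vector<float>& x, // picked when condition is true
                                const std::vector<float>& y) // picked when condition is false
{
    // With a single-element condition the chosen tensor is copied wholesale,
    // so x and y only need to match each other, not the condition's shape.
    return condition ? x : y;
}

int main()
{
    const std::vector<float> x{1.0f, 2.0f, 3.0f};
    const std::vector<float> y{4.0f, 5.0f, 6.0f};
    for(auto v : where_select(false, x, y))
        std::cout << v << ' '; // prints: 4 5 6
    std::cout << '\n';
    return 0;
}
```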