Commit

Add dilated pooling rewrite

attila-dusnoki-htec committed Sep 8, 2023
1 parent 8ddbfa5 commit f361d5c
Showing 4 changed files with 620 additions and 27 deletions.
5 changes: 5 additions & 0 deletions src/include/migraphx/rewrite_pooling.hpp
@@ -26,6 +26,7 @@

#include <string>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -39,6 +40,10 @@ struct MIGRAPHX_EXPORT rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(module& m) const;

private:
void replace_with_reduce(module& m, instruction_ref ins) const;
void replace_dilations_with_gather_pooling(module& m, instruction_ref ins) const;
};

} // namespace MIGRAPHX_INLINE_NS
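A usage note (not part of this commit): the pass exposes only name() and apply(), so it can be run directly on a module. A minimal sketch, assuming nothing beyond the interface declared above and the standard module header:

#include <migraphx/module.hpp>
#include <migraphx/rewrite_pooling.hpp>

// Rewrites eligible pooling instructions of m in place; the private
// helpers declared above are implementation details called from apply().
void run_rewrite(migraphx::module& m)
{
    migraphx::rewrite_pooling{}.apply(m);
}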
162 changes: 136 additions & 26 deletions src/rewrite_pooling.cpp
@@ -46,38 +46,148 @@ void rewrite_pooling::apply(module& m) const
auto&& s = ins->inputs().front()->get_shape();
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator());
bool same_kernel_as_shape = std::equal(
s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
bool default_strides =
std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
bool default_dilations =
std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
{
replace_with_reduce(m, ins);
}
else if(not default_dilations)
{
// Dilated AvgPool with padding is not supported
if(not default_padding and op.mode == op::pooling_mode::average)
{
continue;
}
auto size = std::accumulate(
s.lens().cbegin(), s.lens().cend(), std::size_t{1}, std::multiplies<size_t>());
// Skip large inputs: the gather indices below are emitted as literals, which would grow too large
if(size > 100000)
{
continue;
}

replace_dilations_with_gather_pooling(m, ins);
}
}
}
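To make the dispatch above concrete, here is a sketch of an instruction the dilation branch rewrites. The shapes and attribute values are assumptions chosen for illustration; the make_op spelling mirrors the calls in this diff. It is a 1-D max pooling with kernel 3, stride 1, dilation 2, and no padding:

migraphx::module m;
auto x = m.add_parameter(
    "x", migraphx::shape{migraphx::shape::float_type, {1, 1, 7}});
// default_dilations is false and there is no padding, so apply() routes
// this instruction to replace_dilations_with_gather_pooling
m.add_instruction(migraphx::make_op("pooling",
                                    {{"mode", migraphx::op::pooling_mode::max},
                                     {"padding", {0}},
                                     {"stride", {1}},
                                     {"lengths", {3}},
                                     {"dilations", {2}}}),
                  x);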

void rewrite_pooling::replace_with_reduce(module& m, instruction_ref ins) const
{
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};

// average pooling
if(op.mode == op::pooling_mode::average)
{
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
}
// max pooling
else
{
pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
}

std::vector<int64_t> rsp_lens(s.lens().size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
}
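A worked trace of replace_with_reduce, with assumed sizes: an average pooling over a {2, 3, 4, 4} input whose 4x4 kernel covers the whole spatial extent.

// x : {2, 3, 4, 4}, kernel {4, 4}       (assumed values)
// reshape {6, -1}         -> {6, 16}    one row per (n, c) feature map
// reduce_mean axes = {1}  -> {6, 1}     mean over the 16 window elements
// reshape {2, 3, 1, 1}    -> restores N, C and the singleton spatial dims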

void rewrite_pooling::replace_dilations_with_gather_pooling(module& m, instruction_ref ins) const
{
// TODO remove this when MIOpen supports dilated pooling
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
// Ignore N, C axes
std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};

bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });

if(not default_padding)
{
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
// We need to pad both ends
dims[idx] += op.padding.at(idx) * 2;
}
}
std::vector<size_t> kernels = op.lengths;
std::vector<size_t> strides = op.stride;
std::vector<size_t> dilations = op.dilations;

std::vector<std::vector<int>> axis_indices;
axis_indices.resize(dims.size());

for(size_t idx{0}; idx < dims.size(); ++idx)
{
// Only consider window positions where the dilated kernel still fits
for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
stride += strides.at(idx))
{
for(size_t step{0}; step < kernels.at(idx); ++step)
{
axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
}
}
}

auto elements = ins->inputs().front();
if(not default_padding)
{
// The pad op supports asymmetric padding, so both ends are given explicitly
std::vector<size_t> padding(2 * s.lens().size(), 0);
// Format: begin pads then end pads, e.g. {N, C, P1, P2, N, C, P1, P2} for a 4-d input
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// Ignore N, C axes
padding.at(2 + idx) = op.padding.at(idx);
padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
}

// Pad with the lowest value so padded elements never win the max;
// dilated average pooling with padding was already rejected in apply()
elements = m.insert_instruction(
ins,
make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
elements);
}
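// Example of the pads layout (assumed values): for a 4-d input with
// op.padding = {1, 2}, the vector becomes {0, 0, 1, 2, 0, 0, 1, 2},
// i.e. begin pads for {N, C, H, W} first, then the matching end pads.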

for(size_t idx{0}; idx < axis_indices.size(); ++idx)
{
migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
elements = m.insert_instruction(
ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
}

// Padding was already applied above, so the new pooling uses none
std::vector<size_t> new_padding(kernels.size(), 0);
// The gathered elements of each window sit next to each other, e.g.
// {x1, x2, y1, y2, ...} for windows x and y; stepping by the kernel size
// visits each window exactly once without overlap
std::vector<size_t> new_strides(kernels);
// Dilation is already baked into the gathered indices
std::vector<size_t> new_dilations(kernels.size(), 1);
m.replace_instruction(ins,
make_op("pooling",
{{"mode", op.mode},
{"padding", new_padding},
{"stride", new_strides},
{"lengths", kernels},
{"dilations", new_dilations}}),
elements);
}
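A worked trace of the gather construction above, using assumed sizes: one spatial axis of width 7, kernel 3, stride 1, dilation 2, no padding.

// Window starts: stride < 7 - 2 * (3 - 1) = 3  ->  starts {0, 1, 2}
// Indices for start s: {s, s + 2, s + 4}
// axis_indices.at(0)  : {0, 2, 4,  1, 3, 5,  2, 4, 6}
// After the gather the axis holds 9 elements with each window contiguous,
// so the final pooling with lengths {3}, stride {3}, dilations {1} reduces
// every window exactly once and reproduces the dilated result.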

} // namespace MIGRAPHX_INLINE_NS
Expand Down
3 changes: 2 additions & 1 deletion test/py/onnx_backend_test.py
@@ -289,7 +289,8 @@ def create_backend_test(testname=None, target_device=None):
r'test_argmin_no_keepdims_example_select_last_index_cpu')
backend_test.exclude(r'test_lrn_cpu')
backend_test.exclude(r'test_lrn_default_cpu')
# MaxPool dilation is partially supported on GPU by a workaround,
# but these tests require too-large allocations to work properly
backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu')
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
backend_test.exclude(
Expand Down