Skip to content

Commit

Permalink
Add support for the dilations attribute to Pooling ops (#2105)
Browse files Browse the repository at this point in the history
Introduce dilations attribute to pooling operators reference implementation.
  • Loading branch information
mirza-halilcevic authored Nov 22, 2023
1 parent c7bae54 commit 19bd9c4
Show file tree
Hide file tree
Showing 24 changed files with 1,170 additions and 134 deletions.
51 changes: 40 additions & 11 deletions src/include/migraphx/op/pooling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ struct pooling
// 2 smaller than the input tensor rank (NCHW layout)
std::vector<std::size_t> lengths = {1, 1};

// Dilations are not supported at this time.
// Spacing between the elements of the pooling kernel. Must be the same ndim as lengths.
std::vector<std::size_t> dilations = {1, 1};

// ceiling mode is a flag affecting output size
// or equivalently, placements of the pooling kernel.
Expand Down Expand Up @@ -99,6 +100,7 @@ struct pooling
f(self.padding_mode, "padding_mode"),
f(self.stride, "stride"),
f(self.lengths, "lengths"),
f(self.dilations, "dilations"),
f(self.ceil_mode, "ceil_mode"),
f(self.lp_order, "lp_order"),
f(self.dyn_global, "dyn_global"));
Expand All @@ -112,14 +114,17 @@ struct pooling
return;
if((padding_mode != default_ and padding.size() != stride.size() and
(padding.size()) != stride.size() * 2) or
stride.size() != lengths.size())
stride.size() != lengths.size() or dilations.size() != lengths.size())
{
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
}
if(std::any_of(lengths.begin(), lengths.end(), [&](auto i) { return (i == 0); }) or
std::any_of(stride.begin(), stride.end(), [&](auto i) { return (i == 0); }))

const auto is_zero = [](auto el) { return el == 0; };
if(std::any_of(lengths.begin(), lengths.end(), is_zero) or
std::any_of(stride.begin(), stride.end(), is_zero) or
std::any_of(dilations.begin(), dilations.end(), is_zero))
{
MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride");
MIGRAPHX_THROW("POOLING: size 0 pooling kernel or stride or dilations");
}

// TODO: update lowering to run the reference
Expand All @@ -142,6 +147,11 @@ struct pooling

value attributes() const { return {{"normalize_padding", "padding"}}; }

inline std::size_t dilate_dim(std::size_t dim, std::size_t dilation) const
{
return 1 + dilation * (dim - 1);
}

std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
std::size_t kdims) const
{
Expand All @@ -151,8 +161,9 @@ struct pooling
std::size_t padding_factor = 2 * padding[i];
if(padding.size() == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
std::size_t dilated_length = dilate_dim(lengths[i], dilations[i]);
std::size_t dim_size;
if(input_lens[i + 2] + padding_factor < lengths[i])
if(input_lens[i + 2] + padding_factor < dilated_length)
{
if(padding_mode == default_)
MIGRAPHX_THROW("POOLING: not enough padding for the given kernel size");
Expand All @@ -162,7 +173,7 @@ struct pooling
}
else
{
dim_size = input_lens[i + 2] + padding_factor - lengths[i];
dim_size = input_lens[i + 2] + padding_factor - dilated_length;
}
std::size_t len =
(ceil_mode)
Expand Down Expand Up @@ -331,6 +342,7 @@ struct pooling
int start = static_cast<int>(idx_o[dim] * stride[d_2]) -
static_cast<int>(padding_vals[d_2]);
int end;
std::size_t dilated_kernel_dim = dilate_dim(kernel_dims[d_2], dilations[d_2]);
// NOLINT
if(count_include_pad and ceil_mode and (mode != pooling_mode::max))
{
Expand All @@ -340,15 +352,14 @@ struct pooling
// padding. Clip out-of-bounds indexes but not padding.

// Check if this kernel extends beyond the padding at end of dimension
end = std::min(start + kernel_dims[d_2],
end = std::min(start + dilated_kernel_dim,
in_lens[dim] + static_cast<int>(padding_vals[d_2]));
}
else
{
// In non-ceiling mode, when
// count_include_pad is false, or for max pooling, clip off padding.
end = std::min(start + kernel_dims[d_2], in_lens[dim]);
start = std::max(start, 0);
end = std::min(start + dilated_kernel_dim, in_lens[dim]);
}
win_start.push_back(start);
if(end < start)
Expand All @@ -366,6 +377,16 @@ struct pooling

// for each element in the window...
shape_for_each(win_shape, [&](const auto& idx_w) {
// Skip elements that belong to the dilated area
for(size_t axis = 0; axis < idx_w.size(); ++axis)
{
if(idx_w[axis] % dilations[axis])
{
pool_size -= 1;
return;
}
}

// the coordinates of this element
auto idx = idx_o;

Expand All @@ -390,7 +411,15 @@ struct pooling
// this is a padding element. Padding locations
// don't contribute to average or max pooling total but can play in
// lpnorm pooling.
output_val = op(output_val, 0);
if(mode == pooling_mode::lpnorm)
{
output_val = op(output_val, op.template init<Type>());
}
if(mode == pooling_mode::average)
{
// Ignore padding
pool_size -= 1;
}
}
});
output[i] = Type(op.final(output_val, pool_size));
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/rewrite_pooling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <string>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down
18 changes: 15 additions & 3 deletions src/onnx/parse_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ struct parse_pooling : op_parser<parse_pooling>
kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
}

if(contains(info.attributes, "dilations"))
{
values["dilations"].clear();
copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilations"]));
check_attr_sizes(
kdims, values["dilations"].size(), "PARSE_POOLING: inconsistent dilations");
}

// lp_order attribute
if(contains(info.attributes, "p"))
{
Expand Down Expand Up @@ -169,10 +177,15 @@ struct parse_pooling : op_parser<parse_pooling>
std::fill_n(values["stride"].begin(), kdims, 1);
}

if(values["dilations"].size() != kdims)
{
values["dilations"].resize(kdims);
std::fill_n(values["dilations"].begin(), kdims, 1);
}

// used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings;

// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
Expand All @@ -189,11 +202,10 @@ struct parse_pooling : op_parser<parse_pooling>
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
std::vector<size_t>(in_shape.ndim() - 2, 1),
values["dilations"].to_vector<std::size_t>(),
in_shape.lens(),
paddings);
values["padding"] = paddings;
Expand Down
148 changes: 131 additions & 17 deletions src/rewrite_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,110 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void replace_with_reduce(module& m, instruction_ref ins)
{
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
auto lens = s.lens();
std::vector<std::int64_t> axes(lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2);

// average pooling
if(op.mode == op::pooling_mode::average)
{
m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
}
// max pooling
else
{
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
}
}

static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
// TODO remove this when MIOpen supports dilated pooling
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
// Ignore N, C axes
std::vector<size_t> dims = {s.lens().cbegin() + 2, s.lens().cend()};

bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });

if(not default_padding)
{
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// We need to pad both ends
dims[idx] += op.padding.at(idx) * 2;
}
}
std::vector<size_t> kernels = op.lengths;
std::vector<size_t> strides = op.stride;
std::vector<size_t> dilations = op.dilations;

std::vector<std::vector<int>> axis_indices;
axis_indices.resize(dims.size());

for(auto idx{0}; idx < dims.size(); ++idx)
{
// Only consider if iw fits into the window
for(size_t stride{0}; stride < dims.at(idx) - dilations.at(idx) * (kernels.at(idx) - 1);
stride += strides.at(idx))
{
for(size_t step{0}; step < kernels.at(idx); ++step)
{
axis_indices.at(idx).push_back(stride + dilations.at(idx) * step);
}
}
}

auto elements = ins->inputs().front();
if(not default_padding)
{
// Pad supports asym, we need to provide both ends
std::vector<size_t> padding(2 * s.lens().size(), 0);
// Format will be e.g {N, C, P1, P2, N, C, P1, P2}
for(size_t idx{0}; idx < op.padding.size(); ++idx)
{
// Ignore N, C axes
padding.at(2 + idx) = op.padding.at(idx);
padding.at(2 + idx + s.lens().size()) = op.padding.at(idx);
}

// Default value needed for Max pooling
elements = m.insert_instruction(
ins,
make_op("pad", {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
elements);
}

for(auto idx{0}; idx < axis_indices.size(); ++idx)
{
migraphx::shape s_indices{migraphx::shape::int32_type, {axis_indices.at(idx).size()}};
auto indices = m.add_literal(migraphx::literal{s_indices, axis_indices.at(idx)});
elements = m.insert_instruction(
ins, make_op("gather", {{"axis", idx + 2 /*ignore N,C*/}}), elements, indices);
}

// Ignore padding
std::vector<size_t> new_padding(kernels.size(), 0);
// The kernel window elements are places next to each other. E.g. {x1, y1, x2, y2, ...}
// We need to skip them to not overlap
std::vector<size_t> new_strides(kernels);
// Ignore dilations
std::vector<size_t> new_dilations(kernels.size(), 1);
m.replace_instruction(ins,
make_op("pooling",
{{"mode", op.mode},
{"padding", new_padding},
{"stride", new_strides},
{"lengths", kernels},
{"dilations", new_dilations}}),
elements);
}

void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(m))
Expand All @@ -43,26 +147,36 @@ void rewrite_pooling::apply(module& m) const
continue;
if(ins->inputs().empty())
continue;
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue;
if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
continue;
auto lens = s.lens();
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue;
std::vector<std::int64_t> axes(lens.size() - 2);
std::iota(axes.begin(), axes.end(), 2);
// average pooling
if(op.mode == op::pooling_mode::average)
auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
bool same_kernel_as_shape = std::equal(
s.lens().cbegin() + 2, s.lens().cend(), op.lengths.cbegin(), op.lengths.cend());
bool default_strides =
std::all_of(op.stride.cbegin(), op.stride.cend(), [](auto i) { return i == 1; });
bool default_padding =
std::all_of(op.padding.cbegin(), op.padding.cend(), [](auto i) { return i == 0; });
bool default_dilations =
std::all_of(op.dilations.cbegin(), op.dilations.cend(), [](auto i) { return i == 1; });
if(same_kernel_as_shape and default_strides and default_padding and default_dilations)
{
m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
replace_with_reduce(m, ins);
}
// max pooling
else
else if(not default_dilations)
{
m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
// Dilated AvgPool with padding is not supported
if(not default_padding and op.mode == op::pooling_mode::average)
{
continue;
}
auto size =
std::accumulate(s.lens().cbegin(), s.lens().cend(), 1, std::multiplies<size_t>());
// Can't handle too much size because of literal size
if(size > 100000)
{
continue;
}

replace_dilations_with_gather_pooling(m, ins);
}
}
}
Expand Down
17 changes: 13 additions & 4 deletions src/targets/cpu/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,32 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {

struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_v2_forward, op::pooling>
{
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }

dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
dnnl::pooling_v2_forward::desc
get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
// Note: It is not documented, but the default dilation seems to be 0 instead of 1.
// We need to offset dilations with -1.
std::vector<size_t> dilations;
std::transform(op.dilations.cbegin(),
op.dilations.cend(),
std::back_inserter(dilations),
[](size_t d) { return d - 1; });
return {dnnl::prop_kind::forward_inference,
algo,
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(op.lengths),
to_dnnl_dims(dilations),
to_dnnl_dims(padding_l),
to_dnnl_dims(padding_r)};
}
Expand Down
6 changes: 6 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/miopen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
if(not std::all_of(
op.dilations.cbegin(), op.dilations.cend(), [](std::size_t d) { return d == 1; }))
{
MIGRAPHX_THROW("Unsupported dilations for pooling: [" + to_string_range(op.dilations) +
"]");
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);

int kdims = op.kdims();
Expand Down
Loading

0 comments on commit 19bd9c4

Please sign in to comment.