Skip to content

Commit

Permalink
Add axes (optional) input to Pad (#2178)
Browse files Browse the repository at this point in the history
  • Loading branch information
attila-dusnoki-htec authored Oct 17, 2023
1 parent 52c74f0 commit 94bda24
Show file tree
Hide file tree
Showing 9 changed files with 396 additions and 33 deletions.
147 changes: 114 additions & 33 deletions src/onnx/parse_pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,34 +115,9 @@ struct parse_pad : op_parser<parse_pad>
{
// The single ONNX operator this parser handles.
std::vector<op_desc> operators() const
{
    return {{"Pad"}};
}

instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
std::string parse_mode(const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
std::vector<int64_t> pads{};
if(args.size() >= 2)
{
auto pad_arg = args.at(1)->eval();
check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
}
else if(contains(info.attributes, "pads"))
{
auto&& pad_vals = info.attributes["pads"].ints();
pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
}
else
{
MIGRAPHX_THROW("PARSE_PAD: pad must be available");
}

// check if padding is actually being done (at least one value is nonzero)
if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
{
return info.add_instruction(make_op("identity"), args.front());
}

if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
Expand All @@ -152,35 +127,141 @@ struct parse_pad : op_parser<parse_pad>
{
MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported");
}
return reflect_pad(info, pads, args.front());
}
if(mode != "constant")
else if(mode != "constant")
{
MIGRAPHX_THROW(
"PARSE_PAD: migraphx currently only supports constant and reflect padding");
}
return mode;
}
else
{
// default mode
return "constant";
}
}

// Reads the padding sizes, preferring the second input (which must be a
// compile-time constant) over the legacy `pads` attribute.
// Throws when neither source is present.
std::vector<int64_t> parse_pads(const onnx_parser::node_info& info,
                                const std::vector<instruction_ref>& args) const
{
    std::vector<int64_t> result;
    if(args.size() >= 2)
    {
        // `pads` supplied as an input tensor; it must be evaluatable now
        auto pad_arg = args.at(1)->eval();
        check_arg_empty(pad_arg, "PARSE_PAD: `pads` input must be constant");
        pad_arg.visit([&](auto v) { result.assign(v.begin(), v.end()); });
        return result;
    }
    if(contains(info.attributes, "pads"))
    {
        // `pads` supplied as a node attribute (older opsets)
        const auto& pad_vals = info.attributes.at("pads").ints();
        result.assign(pad_vals.begin(), pad_vals.end());
        return result;
    }
    MIGRAPHX_THROW("PARSE_PAD: `pads` must be available");
}

// Determines the constant fill value for constant-mode padding.
// Preference order: scalar third input, then the legacy `value` attribute,
// otherwise 0.0f.
float parse_constant_value(const onnx_parser& parser,
                           const onnx_parser::node_info& info,
                           const std::vector<instruction_ref>& args) const
{
    // `value` supplied as the (scalar) third input
    if(args.size() >= 3 and args.at(2)->get_shape().scalar())
    {
        auto val_ins = args.at(2);
        if(not val_ins->can_eval())
        {
            MIGRAPHX_THROW("PARSE_PAD: input `value` must be constant");
        }
        auto val_arg = val_ins->eval();
        if(val_arg.get_shape().elements() != 1)
        {
            MIGRAPHX_THROW("PARSE_PAD: `value` should contain only one element");
        }
        return val_arg.at<float>();
    }
    // `value` supplied as a node attribute (older opsets)
    if(contains(info.attributes, "value"))
    {
        return parser.parse_value(info.attributes.at("value")).at<float>();
    }
    // ONNX default fill value
    return 0.0f;
}

// Reads the optional `axes` input, which restricts padding to a subset of
// dimensions. Returns an empty vector when `axes` is absent.
// In constant mode the inputs are (data, pads, value, axes), so `axes` is the
// 4th input; otherwise (data, pads, axes), so it is the 3rd.
std::vector<int64_t> parse_axes(const std::vector<instruction_ref>& args,
                                bool is_constant_mode) const
{
    std::vector<int64_t> axes{};
    // use size_t so the comparison with args.size() is not signed/unsigned mixed
    const std::size_t pos = is_constant_mode ? 4 : 3;
    if(args.size() >= pos)
    {
        // `axes` must be evaluatable at parse time
        auto axes_arg = args.at(pos - 1)->eval();
        check_arg_empty(axes_arg, "PARSE_PAD: variable `axes` input not supported");
        axes_arg.visit([&](auto v) { axes.assign(v.begin(), v.end()); });
    }
    return axes;
}

// Expands per-axis pads into the full ONNX layout
// (x1_begin, ..., xn_begin, x1_end, ..., xn_end) over all input dimensions;
// dimensions not listed in `axes` keep a zero pad.
// Throws when pads/axes sizes disagree or an axis is out of range.
std::vector<int64_t> calculate_pads_with_axes(const std::vector<int64_t>& pads,
                                              const std::vector<int64_t>& axes,
                                              size_t input_rank) const
{
    size_t num_axes = axes.size();
    if(num_axes * 2 != pads.size())
    {
        MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * "
                       "number of elements of axes");
    }

    const auto rank = static_cast<int64_t>(input_rank);
    std::vector<int64_t> new_pads(input_rank * 2);
    for(size_t idx{0}; idx < num_axes; ++idx)
    {
        // axis can be negative (counting from the back)
        int64_t axis = axes[idx] < 0 ? rank + axes[idx] : axes[idx];
        // validate before indexing: an unchecked axis would write out of bounds
        if(axis < 0 or axis >= rank)
        {
            MIGRAPHX_THROW("PARSE_PAD: `axes` value out of range for input rank");
        }
        // pad format is x1_begin, x2_begin, ... , x3_end, x4_end
        new_pads[axis]              = pads[idx];
        new_pads[axis + input_rank] = pads[idx + num_axes];
    }
    return new_pads;
}

// Parses an ONNX `Pad` node into migraphx instructions.
// Handles the optional `value` and `axes` inputs, constant and reflect modes,
// and collapses all-zero padding into an identity.
instruction_ref parse(const op_desc& /*opd*/,
                      const onnx_parser& parser,
                      const onnx_parser::node_info& info,
                      const std::vector<instruction_ref>& args) const
{
    std::vector<int64_t> pads = parse_pads(info, args);

    // check if padding is actually being done (at least one value is nonzero)
    // NOTE(review): the lambda parameter is `const int&` while `pads` holds
    // int64_t, so each element is narrowed to int before the zero test —
    // harmless for detecting nonzero pads in practice, but `const int64_t&`
    // would be exact; confirm and tighten.
    if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
    {
        return info.add_instruction(make_op("identity"), args.front());
    }

    std::string mode = parse_mode(info, args);
    bool is_constant_mode = mode == "constant";
    // fill value only applies to constant mode
    float value = is_constant_mode ? parse_constant_value(parser, info, args) : 0.0f;
    std::vector<int64_t> axes = parse_axes(args, is_constant_mode);
    size_t input_rank = args.front()->get_shape().ndim();

    // when `axes` is given, pads only cover those axes and must be expanded
    // to the full per-dimension layout
    if(not axes.empty())
    {
        pads = calculate_pads_with_axes(pads, axes, input_rank);
    }

    // after any expansion, pads must list a begin and an end for every dim
    if(pads.size() != input_rank * 2)
    {
        MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * "
                       "input rank");
    }

    if(mode == "reflect")
    {
        return reflect_pad(info, pads, args.front());
    }

    // constant mode: lower directly to the migraphx `pad` op
    return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}),
                                args.front());
Expand Down
182 changes: 182 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5107,6 +5107,32 @@ def pad_test():
return ([node], [x], [y])


@onnx_test()
def pad_asym_test():
    # Asymmetric padding via the legacy `pads` attribute: dim 1 gets (1, 3)
    # before/after and dim 3 gets (2, 4) before/after.
    in_info = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                            [1, 3, 4, 5])
    out_info = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                             [1, 6, 4, 12])

    pad_node = onnx.helper.make_node('Pad',
                                     inputs=['0'],
                                     outputs=['1'],
                                     pads=[0, 1, 0, 3, 0, 2, 0, 4])

    return ([pad_node], [in_info], [out_info])


@onnx_test()
def pad_asym_invalid_pads_error_test():
    # Deliberately malformed: only 6 pad values for a rank-4 input (8 needed),
    # so the parser is expected to reject this model.
    in_info = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                            [1, 3, 4, 5])
    out_info = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                             [1, 6, 4, 12])

    pad_node = onnx.helper.make_node('Pad',
                                     inputs=['0'],
                                     outputs=['1'],
                                     pads=[0, 1, 0, 3, 0, 2])

    return ([pad_node], [in_info], [out_info])


@onnx_test()
def pad_3arg_test():
values = np.array([1])
Expand Down Expand Up @@ -5139,6 +5165,129 @@ def pad_3arg_test():
return ([arg_val, arg_pad, node], [x], [y])


@onnx_test()
def pad_4arg_axes_test():
    # Pad with all four inputs: data, pads, value and the optional axes input
    # restricting padding to dims 1 and 3.

    # scalar fill value of 1.0 supplied through the `value` input
    fill = np.array([1])
    value_tensor = helper.make_tensor(name='val',
                                      data_type=TensorProto.FLOAT,
                                      dims=fill.reshape(()).shape,
                                      vals=fill.astype(float))
    arg_val = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_val'],
                                    value=value_tensor)

    # per-axis pads: (begin, begin, end, end) for the two listed axes
    pad_sizes = np.array([1, 3, 2, 4])
    pad_tensor = helper.make_tensor(name='pad_size',
                                    data_type=TensorProto.INT32,
                                    dims=pad_sizes.shape,
                                    vals=pad_sizes.astype(int))
    arg_pad = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_pad'],
                                    value=pad_tensor)

    # only dims 1 and 3 are padded
    pad_axes = np.array([1, 3])
    axes_tensor = helper.make_tensor(name='pad_axes',
                                     data_type=TensorProto.INT32,
                                     dims=pad_axes.shape,
                                     vals=pad_axes.astype(int))
    arg_axes = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=['arg_axes'],
                                     value=axes_tensor)

    in_info = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                            [1, 3, 4, 5])
    out_info = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                             [1, 6, 4, 12])

    pad_node = onnx.helper.make_node(
        'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1'])

    return ([arg_axes, arg_val, arg_pad, pad_node], [in_info], [out_info])


@onnx_test()
def pad_4arg_invalid_axes_error_test():
    # Deliberately malformed: 3 axes but only 4 pad values (2 * 3 = 6 needed),
    # so the parser is expected to reject this model.

    # scalar fill value of 1.0 supplied through the `value` input
    fill = np.array([1])
    value_tensor = helper.make_tensor(name='val',
                                      data_type=TensorProto.FLOAT,
                                      dims=fill.reshape(()).shape,
                                      vals=fill.astype(float))
    arg_val = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_val'],
                                    value=value_tensor)

    pad_sizes = np.array([1, 3, 2, 4])
    pad_tensor = helper.make_tensor(name='pad_size',
                                    data_type=TensorProto.INT32,
                                    dims=pad_sizes.shape,
                                    vals=pad_sizes.astype(int))
    arg_pad = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_pad'],
                                    value=pad_tensor)

    # one axis too many for the pads above
    pad_axes = np.array([1, 2, 3])
    axes_tensor = helper.make_tensor(name='pad_axes',
                                     data_type=TensorProto.INT32,
                                     dims=pad_axes.shape,
                                     vals=pad_axes.astype(int))
    arg_axes = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=['arg_axes'],
                                     value=axes_tensor)

    in_info = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                            [1, 3, 4, 5])
    out_info = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                             [1, 6, 4, 12])

    pad_node = onnx.helper.make_node(
        'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1'])

    return ([arg_axes, arg_val, arg_pad, pad_node], [in_info], [out_info])


@onnx_test()
def pad_4arg_neg_axes_test():
    # Same as pad_4arg_axes_test but with negative axes (-3, -1), which for a
    # rank-4 input resolve to dims 1 and 3.

    # scalar fill value of 1.0 supplied through the `value` input
    fill = np.array([1])
    value_tensor = helper.make_tensor(name='val',
                                      data_type=TensorProto.FLOAT,
                                      dims=fill.reshape(()).shape,
                                      vals=fill.astype(float))
    arg_val = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_val'],
                                    value=value_tensor)

    pad_sizes = np.array([1, 3, 2, 4])
    pad_tensor = helper.make_tensor(name='pad_size',
                                    data_type=TensorProto.INT32,
                                    dims=pad_sizes.shape,
                                    vals=pad_sizes.astype(int))
    arg_pad = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_pad'],
                                    value=pad_tensor)

    # negative axes count from the back
    pad_axes = np.array([-3, -1])
    axes_tensor = helper.make_tensor(name='pad_axes',
                                     data_type=TensorProto.INT32,
                                     dims=pad_axes.shape,
                                     vals=pad_axes.astype(int))
    arg_axes = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=['arg_axes'],
                                     value=axes_tensor)

    in_info = helper.make_tensor_value_info('0', TensorProto.FLOAT,
                                            [1, 3, 4, 5])
    out_info = helper.make_tensor_value_info('1', TensorProto.FLOAT,
                                             [1, 6, 4, 12])

    pad_node = onnx.helper.make_node(
        'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1'])

    return ([arg_axes, arg_val, arg_pad, pad_node], [in_info], [out_info])


@onnx_test()
def pad_reflect_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
Expand All @@ -5162,6 +5311,39 @@ def pad_reflect_test():
return ([arg_pad, node], [x], [y])


@onnx_test()
def pad_reflect_with_axes_test():
    # Reflect-mode padding restricted to axis 1 via the optional `axes` input.
    # Reflect mode has no `value` input, so `axes` is the third input.
    in_info = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2])
    out_info = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5])

    # (2 before, 1 after) on the single listed axis
    pad_sizes = np.array([2, 1])
    pad_tensor = helper.make_tensor(name='pad_size',
                                    data_type=TensorProto.INT32,
                                    dims=pad_sizes.shape,
                                    vals=pad_sizes.astype(int))
    arg_pad = onnx.helper.make_node('Constant',
                                    inputs=[],
                                    outputs=['arg_pad'],
                                    value=pad_tensor)

    pad_axes = np.array([1])
    axes_tensor = helper.make_tensor(name='pad_axes',
                                     data_type=TensorProto.INT32,
                                     dims=pad_axes.shape,
                                     vals=pad_axes.astype(int))
    arg_axes = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=['arg_axes'],
                                     value=axes_tensor)

    pad_node = onnx.helper.make_node('Pad',
                                     mode='reflect',
                                     inputs=['0', 'arg_pad', 'arg_axes'],
                                     outputs=['1'])

    return ([arg_axes, arg_pad, pad_node], [in_info], [out_info])


@onnx_test()
def pad_reflect_multiaxis_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3])
Expand Down
Loading

0 comments on commit 94bda24

Please sign in to comment.