Skip to content

Commit

Permalink
2 Input Reshape ref implementation (#2304)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Oct 17, 2023
1 parent a720061 commit f25606f
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 48 deletions.
58 changes: 49 additions & 9 deletions src/include/migraphx/op/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/**
* 1 input version:
* reshape(input_data)
* this.dims = output_dims
* Makes a copy of input_data to the output shape.
*
* 2 input version:
* reshape(input_data, output_buffer)
* this.dims = unset
* Copies input_data to output_buffer; output_buffer already has the output shape.
* This version will not fail gracefully if the input shape and output_buffer shape are
* incompatible. There's a throw that will catch when the number of elements do not match at
* runtime. This version should only be used for dynamic reshapes (output dimensions only known at
* runtime). If output_buffer has a static shape during compile/parse, you can use the 1 input
* version.
*/
struct reshape
{
std::vector<int64_t> dims;
Expand Down Expand Up @@ -215,32 +231,56 @@ struct reshape

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
check_shapes{inputs, *this, true}.has(1, 2);

auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");

auto s0 = inputs.front();
if(s0.dynamic())
if(inputs.size() == 1)
{
return dyn_compute_shape(s0);
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
else
{
return static_compute_shape(inputs, n_neg_dims);
return inputs.back();
}
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};
if(args.size() == 1)
{
argument result{dyn_out.computed_shape};

visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
else
{
// 2 arg
if(args[0].get_shape().elements() != args[1].get_shape().elements())
{
MIGRAPHX_THROW("Reshape: Number of elements must match at runtime. Input: " +
std::to_string(args[0].get_shape().elements()) +
" Output buffer: " + std::to_string(args[1].get_shape().elements()));
}
visit_all(args[1], args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return args[1];
}
}
};

Expand Down
22 changes: 16 additions & 6 deletions src/onnx/parse_reshape.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -45,15 +45,25 @@ struct parse_reshape : op_parser<parse_reshape>
{
literal s = parser.parse_value(info.attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
}
if(args.size() == 2)
else
{
// 2 inputs
auto s = args[1]->eval();
check_arg_empty(s, "Reshape: non-constant shape input is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
if(s.empty())
{
// arg[1] not eval-able
auto alloc_ins = info.add_instruction(
make_op("allocate", {{"buf_type", args[0]->get_shape().type()}}), args[1]);
return info.add_instruction(make_op("reshape"), args[0], alloc_ins);
}
else
{
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
}
}

return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
}
};

Expand Down
18 changes: 18 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6065,6 +6065,24 @@ def reshape_non_standard_test():
return ([trans, res], [x], [y])


@onnx_test()
def reshape_variable_input_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [4, 2, 3])
x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 8])
node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2'])
return ([node], [x, x_shape], [y])


@onnx_test()
def reshape_variable_input_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 2, 3])
x_shape = helper.make_tensor_value_info('1', TensorProto.INT64, [2])
y = helper.make_tensor_value_info('2', TensorProto.FLOAT, [None, 6])
node = onnx.helper.make_node('Reshape', inputs=['0', '1'], outputs=['2'])
return ([node], [x, x_shape], [y])


@onnx_test()
def resize_downsample_f_test():
scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32)
Expand Down
94 changes: 63 additions & 31 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ TEST_CASE(averagepool_notset_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {2, 2, 2, 2}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {2, 2, 2, 2}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
input);
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {2, 2}}}), ins);
Expand All @@ -382,11 +382,11 @@ TEST_CASE(averagepool_nt_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 0, 0, 0, 0, 1, 1};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {2, 2}},
{"lengths", {6, 6}}}),
ins_pad);
mm->add_return({ret});

Expand Down Expand Up @@ -426,11 +426,11 @@ TEST_CASE(averagepool_sl_cip_test)
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
std::vector<int64_t> pads = {0, 0, 1, 1, 0, 0, 0, 0};
auto ins_pad = mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}}), input);
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
auto ret = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
ins_pad);
mm->add_return({ret});
auto prog = migraphx::parse_onnx("averagepool_sl_cip_test.onnx");
Expand All @@ -444,10 +444,10 @@ TEST_CASE(averagepool_same_upper_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}});
auto ins = mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
{{"mode", migraphx::op::pooling_mode::average},
{"padding", {1, 1, 1, 1}},
{"stride", {1, 1}},
{"lengths", {2, 2}}}),
input);
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {1, 1}}, {"ends", {6, 6}}}), ins);
Expand Down Expand Up @@ -1634,7 +1634,7 @@ TEST_CASE(conv_transpose_input_pads_asymm_1d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3}});
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {0}}, {"stride", {2}}, {"dilation", {1}}}),
{{"padding", {0}}, {"stride", {2}}, {"dilation", {1}}}),
l0,
l1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {6}}}),
Expand Down Expand Up @@ -1668,7 +1668,7 @@ TEST_CASE(conv_transpose_output_padding_3d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
l0,
l1);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
Expand Down Expand Up @@ -1701,7 +1701,7 @@ TEST_CASE(conv_transpose_output_shape_3d_test)
auto l1 = mm->add_parameter("w", {migraphx::shape::float_type, {1, 2, 3, 3, 3}});
auto l2 = mm->add_instruction(
migraphx::make_op("convolution_backwards",
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
{{"padding", {0, 0, 0}}, {"stride", {3, 2, 2}}, {"dilation", {1, 1, 1}}}),
l0,
l1);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1, 1, 1}}}), l2);
Expand Down Expand Up @@ -1996,7 +1996,7 @@ TEST_CASE(equal_test)
auto eq = mm->add_instruction(migraphx::make_op("equal"), input1, input2);
auto ret = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
eq);
mm->add_return({ret});

Expand All @@ -2016,7 +2016,7 @@ TEST_CASE(equal_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
auto ret = mm->add_instruction(migraphx::make_op("equal"), cin1, input2);
mm->add_return({ret});
Expand Down Expand Up @@ -2726,7 +2726,7 @@ TEST_CASE(greater_test)
auto gr = mm->add_instruction(migraphx::make_op("greater"), input1, input2);
auto ret = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
gr);
mm->add_return({ret});

Expand All @@ -2745,7 +2745,7 @@ TEST_CASE(greater_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
auto ret = mm->add_instruction(migraphx::make_op("greater"), cin1, input2);
mm->add_return({ret});
Expand Down Expand Up @@ -3602,7 +3602,7 @@ TEST_CASE(less_test)
auto le = mm->add_instruction(migraphx::make_op("less"), input1, input2);
auto ret = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
le);
mm->add_return({ret});

Expand All @@ -3621,7 +3621,7 @@ TEST_CASE(less_bool_test)
auto input2 = mm->add_parameter("x2", sb);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::bool_type)}}),
input1);
auto ret = mm->add_instruction(migraphx::make_op("less"), cin1, input2);
mm->add_return({ret});
Expand Down Expand Up @@ -5463,7 +5463,7 @@ TEST_CASE(reducel1_dyn_test)
// a shape with 4 dynamic dimensions
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {-2}}}), abs_ins);
Expand All @@ -5483,7 +5483,7 @@ TEST_CASE(reducel1_dyn_test)
// No axes given in the onnx file. Parser should default to all axes.
auto l0 = mm->add_parameter("x",
migraphx::shape{migraphx::shape::float_type,
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
{{3, 3}, {3, 5}, {4, 6, {5}}, {5, 7, {6}}}});
auto abs_ins = mm->add_instruction(migraphx::make_op("abs"), l0);
auto sum_ins =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1, 2, 3}}}), abs_ins);
Expand Down Expand Up @@ -5719,6 +5719,38 @@ TEST_CASE(reshape_non_standard_test)
EXPECT(p == prog);
}

TEST_CASE(reshape_variable_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);

auto prog = optimize_onnx("reshape_variable_input_test.onnx");
EXPECT(p == prog);
}

TEST_CASE(reshape_variable_input_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 2}, {3, 3}}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int64_type, {2}});
auto alloc = mm->add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), p1);
auto reshape = mm->add_instruction(migraphx::make_op("reshape"), p0, alloc);
mm->add_return({reshape});

migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4};
auto prog = parse_onnx("reshape_variable_input_dyn_test.onnx", options);
EXPECT(p == prog);
}

TEST_CASE(resize_downsample_c_test)
{
migraphx::program p;
Expand Down Expand Up @@ -7169,7 +7201,7 @@ TEST_CASE(squeeze_unsqueeze_dyn_test)
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}});
{{1, 1}, {1, 4}, {1, 1}, {1, 1}, {1, 4}, {1, 1}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
Expand Down Expand Up @@ -7249,7 +7281,7 @@ TEST_CASE(sum_int_test)
auto input2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint32_type, {3}});
auto cin0 = mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}),
{{"target_type", migraphx::to_value(migraphx::shape::uint32_type)}}),
input0);
auto cin1 = mm->add_instruction(
migraphx::make_op("convert",
Expand Down
Binary file added test/onnx/reshape_variable_input_dyn_test.onnx
Binary file not shown.
17 changes: 17 additions & 0 deletions test/onnx/reshape_variable_input_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
reshape_variable_input_test:p

0
12"Reshapereshape_variable_input_testZ
0



Z
1


b
2


B
Loading

0 comments on commit f25606f

Please sign in to comment.