Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Constant parsing to support more attributes #2141 #2216

Merged
merged 14 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/onnx/parse_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
onnx_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const
{
literal v = parser.parse_value(info.attributes.at("value"));
static const std::vector<std::string> attributes = {
"value", "value_float", "value_floats", "value_int", "value_ints"};

std::vector<std::string> present_attributes;
std::copy_if(attributes.begin(),
attributes.end(),
std::back_inserter(present_attributes),
[&](const std::string& a) { return contains(info.attributes, a); });

if(present_attributes.empty())
{
MIGRAPHX_THROW("Constant node does not contain any supported attribute");
}

if(present_attributes.size() > 1)
{
MIGRAPHX_THROW("Constant contains multiple attributes: " +
join_strings(std::move(present_attributes), ", "));
}

// cppcheck-suppress accessMoved
auto&& attr = info.attributes[present_attributes[0]];
literal v = parser.parse_value(attr);

// return empty literal
if(v.get_shape().elements() == 0)
{
return info.add_literal(literal{v.get_shape().type()});
}

auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
if(attr.has_t() and attr.t().dims_size() == 0)
{
migraphx::shape scalar_shape{v.get_shape().type()};
return info.add_literal(migraphx::literal{scalar_shape, v.data()});
Expand Down
Binary file added test/onnx/constant_multiple_attributes_test.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions test/onnx/constant_no_attributes_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
constant_no_attributes_test:)

"Constantconstant_no_attributes_testB
Binary file added test/onnx/constant_value_float_test.onnx
Binary file not shown.
Binary file added test/onnx/constant_value_floats_test.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions test/onnx/constant_value_int_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
constant_value_int_test:7
"Constant*
value_int@ constant_value_int_testB
4 changes: 4 additions & 0 deletions test/onnx/constant_value_ints_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
constant_value_ints_test:=
!"Constant*

value_ints@@@ constant_value_ints_testB
70 changes: 70 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,76 @@ def constant_test():
return ([node], [], [y])


@onnx_test()
def constant_value_float_test():
    # Constant node carrying only the 'value_float' attribute; the graph
    # declares no inputs or outputs, so the parser must read the attribute
    # payload directly.
    constant = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=[],
                                     value_float=[1.0])

    return ([constant], [], [])


@onnx_test()
def constant_value_floats_test():
    # Constant node carrying only the 'value_floats' attribute
    # (a three-element float list).
    constant = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=[],
                                     value_floats=[1.0, 2.0, 3.0])

    return ([constant], [], [])


@onnx_test()
def constant_value_int_test():
    # Constant node whose payload is the 'value_int' attribute.
    constant = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=[],
                                     value_int=[1])

    return ([constant], [], [])


@onnx_test()
def constant_value_ints_test():
    # Constant node whose payload is the 'value_ints' attribute
    # (a three-element integer list).
    constant = onnx.helper.make_node('Constant',
                                     inputs=[],
                                     outputs=[],
                                     value_ints=[1, 2, 3])

    return ([constant], [], [])


@onnx_test()
def constant_no_attributes_test():
    # Constant node with no value* attribute at all; the parser is expected
    # to reject this model.
    constant = onnx.helper.make_node('Constant', inputs=[], outputs=[])

    return ([constant], [], [])


@onnx_test()
def constant_multiple_attributes_test():
    # A Constant node supplying several value* attributes at once
    # (value_floats, value_ints and value); the parser is expected to
    # reject the ambiguity.
    data = np.array([0, 1, 2])
    tensor = onnx.helper.make_tensor(name='const_tensor',
                                     data_type=TensorProto.FLOAT,
                                     dims=data.shape,
                                     vals=data.flatten().astype(float))

    node = onnx.helper.make_node('Constant',
                                 inputs=[],
                                 outputs=[],
                                 value_floats=[1.0, 2.0],
                                 value_ints=[1, 2],
                                 value=tensor)

    return ([node], [], [])


@onnx_test()
def constant_fill_test():
value = helper.make_tensor_value_info('value', TensorProto.FLOAT, [2, 3])
Expand Down
52 changes: 52 additions & 0 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,58 @@ TEST_CASE(constant_test)
EXPECT(p == prog);
}

TEST_CASE(constant_value_float_test)
{
    // The 'value_float' attribute should parse into a one-element float literal.
    migraphx::program expected;
    auto* mm = expected.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {1}};
    mm->add_literal(migraphx::literal{s, {1.0f}});

    EXPECT(expected == optimize_onnx("constant_value_float_test.onnx"));
}

TEST_CASE(constant_value_floats_test)
{
    // The 'value_floats' attribute should parse into a rank-1 float literal
    // holding all three elements.
    migraphx::program expected;
    auto* mm = expected.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {3}};
    mm->add_literal(migraphx::literal{s, {1.0f, 2.0f, 3.0f}});

    EXPECT(expected == optimize_onnx("constant_value_floats_test.onnx"));
}

TEST_CASE(constant_value_int_test)
{
    // The 'value_int' attribute should parse into a one-element int64 literal.
    migraphx::program expected;
    auto* mm = expected.get_main_module();
    migraphx::shape s{migraphx::shape::int64_type, {1}};
    mm->add_literal(migraphx::literal{s, {1}});

    EXPECT(expected == optimize_onnx("constant_value_int_test.onnx"));
}

TEST_CASE(constant_value_ints_test)
{
    // The 'value_ints' attribute should parse into a rank-1 int64 literal
    // holding all three elements.
    migraphx::program expected;
    auto* mm = expected.get_main_module();
    migraphx::shape s{migraphx::shape::int64_type, {3}};
    mm->add_literal(migraphx::literal{s, {1, 2, 3}});

    EXPECT(expected == optimize_onnx("constant_value_ints_test.onnx"));
}

TEST_CASE(constant_no_attributes_test)
{
    // A Constant node without any value attribute must fail to parse.
    EXPECT(test::throws([] { optimize_onnx("constant_no_attributes_test.onnx"); }));
}

TEST_CASE(constant_multiple_attributes_test)
{
    // A Constant node with more than one value attribute must fail to parse.
    EXPECT(test::throws([] { optimize_onnx("constant_multiple_attributes_test.onnx"); }));
}

TEST_CASE(constant_fill_test)
{
migraphx::program p;
Expand Down
34 changes: 34 additions & 0 deletions test/py/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,34 @@ def disabled_tests_onnx_1_12_0(backend_test):
backend_test.exclude(r'test_scatter_elements_with_duplicate_indices_cpu')


def disabled_tests_onnx_1_13_0(backend_test):
    # These onnx 1.13.0 backend tests fail because the CastLike operator
    # is unsupported; exclude them in the order listed.
    castlike_dependent = (
        'test_elu_default_expanded_ver18_cpu',
        'test_elu_example_expanded_ver18_cpu',
        'test_elu_expanded_ver18_cpu',
        'test_hardsigmoid_default_expanded_ver18_cpu',
        'test_hardsigmoid_example_expanded_ver18_cpu',
        'test_hardsigmoid_expanded_ver18_cpu',
        'test_leakyrelu_default_expanded_cpu',
        'test_leakyrelu_example_expanded_cpu',
        'test_leakyrelu_expanded_cpu',
        'test_selu_default_expanded_ver18_cpu',
        'test_selu_example_expanded_ver18_cpu',
        'test_selu_expanded_ver18_cpu',
        'test_thresholdedrelu_default_expanded_ver18_cpu',
        'test_thresholdedrelu_example_expanded_ver18_cpu',
        'test_thresholdedrelu_expanded_ver18_cpu',
        'test_relu_expanded_ver18_cpu',
        'test_softsign_example_expanded_ver18_cpu',
        'test_softsign_expanded_ver18_cpu',
    )
    for test_name in castlike_dependent:
        backend_test.exclude(test_name)


def disabled_tests_onnx_1_14_0(backend_test):
    # These onnx 1.14.0 backend tests fail because the CastLike operator
    # is unsupported; exclude them in the order listed.
    castlike_dependent = (
        'test_softplus_example_expanded_ver18_cpu',
        'test_softplus_expanded_ver18_cpu',
    )
    for test_name in castlike_dependent:
        backend_test.exclude(test_name)


def create_backend_test(testname=None, target_device=None):
if target_device is not None:
c2.set_device(target_device)
Expand Down Expand Up @@ -334,6 +362,12 @@ def create_backend_test(testname=None, target_device=None):
if version.parse(onnx.__version__) >= version.parse("1.12.0"):
disabled_tests_onnx_1_12_0(backend_test)

if version.parse(onnx.__version__) >= version.parse("1.13.0"):
disabled_tests_onnx_1_13_0(backend_test)

if version.parse(onnx.__version__) >= version.parse("1.14.0"):
disabled_tests_onnx_1_14_0(backend_test)


# import all test cases at global scope to make
# them visible to python.unittest.
Expand Down
Loading