diff --git a/src/include/migraphx/op/reduce_op.hpp b/src/include/migraphx/op/reduce_op.hpp index d9de9939610..c1fc5e56731 100644 --- a/src/include/migraphx/op/reduce_op.hpp +++ b/src/include/migraphx/op/reduce_op.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -94,16 +94,69 @@ struct reduce_op : op_name return {{"normalize_axes", normalize}, {"reduce", true}}; } - std::vector tune_axes(std::size_t n_dim) const + shape collapse_reduced_axes(const shape& original_shape, + const std::vector& reduce_axes) const { - auto tuned_axes = axes; - if(tuned_axes.empty()) + auto lens = original_shape.lens(); + for(const auto a : reduce_axes) { - tuned_axes.resize(n_dim); - std::iota(tuned_axes.begin(), tuned_axes.end(), 0); + lens[a] = 1; } - return tuned_axes; + return original_shape.with_lens(lens); + } + + // Compute the output shape for cases when the input tensor has a dynamic shape. + // + // If the axes are passed as a variable input(indicated by an empty axes attribute), we cannot + // determine which axes must be collapsed until we see the actual input values, so we must treat + // each axis as potentially collapsable and set its minimum dimension to 1. + shape compute_dynamic_shape(const std::vector& inputs) const + { + const auto& data_shape = inputs[0]; + auto dims = data_shape.dyn_dims(); + if(axes.empty()) + { + for(auto& dim : dims) + { + dim = {1, dim.max}; + } + } + else + { + for(auto a : axes) + { + dims[a] = {1, 1}; + } + } + + return {data_shape.type(), dims}; + } + + // Compute the output shape for cases when the input tensor has a static shape. + // Depending on how axes is passed to the operator the output shape can be either dynamic or + // static. + // + // If the axes are passed as a variable input(indicated by an empty axes attribute), we cannot + // determine which axes must be collapsed until we see the actual input values, so we must treat + // each axis as potentially collapsable, producing a dynamic output shape. + shape compute_static_shape(const std::vector& inputs) const + { + const auto& data_shape = inputs[0]; + if(axes.empty()) + { + std::vector dims(data_shape.ndim()); + auto lens = data_shape.lens(); + std::transform(lens.begin(), lens.end(), dims.begin(), [](auto len) { + return shape::dynamic_dimension{1, len}; + }); + + return {data_shape.type(), std::move(dims)}; + } + else + { + return collapse_reduced_axes(data_shape, axes); + } } /** @@ -115,29 +168,16 @@ struct reduce_op : op_name */ shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this, true}.has(1); - auto s = inputs.at(0); - if(s.dynamic()) - { - auto output_dyn_dims = s.dyn_dims(); - auto tuned_axes = tune_axes(output_dyn_dims.size()); - for(const auto& axis : tuned_axes) - { - output_dyn_dims[axis] = {1, 1}; - } + auto expected_arg_count = axes.empty() ? 2 : 1; + check_shapes{inputs, *this, true}.has(expected_arg_count); - return shape{s.type(), output_dyn_dims}; + if(inputs[0].dynamic()) + { + return compute_dynamic_shape(inputs); } else { - auto lens = s.lens(); - auto tuned_axes = tune_axes(lens.size()); - for(const auto& axis : tuned_axes) - { - lens[axis] = 1; - } - - return inputs[0].with_lens(lens); + return compute_static_shape(inputs); } } @@ -153,10 +193,10 @@ struct reduce_op : op_name } template - void reduce(tensor_view& input, - shape& batch_shape, - std::vector& tuned_axes, - std::vector& out_idx, + void reduce(const tensor_view& input, + const shape& batch_shape, + const std::vector& tuned_axes, + const std::vector& out_idx, tensor_view& output) const { using accumulator = accumulator_type; @@ -173,24 +213,43 @@ struct reduce_op : op_name static_cast(*this).output(batch_shape)(val); } - argument compute(const dyn_output& dyn_out, std::vector args) const + argument reduce(const shape& computed_shape, + const std::vector& reduce_axes, + argument& data_arg) const { - argument result{dyn_out.computed_shape}; - auto arg_lens = args.front().get_shape().lens(); - auto tuned_axes = tune_axes(arg_lens.size()); - std::vector batch_lens(dyn_out.computed_shape.lens().size(), 1); - tune_dims(tuned_axes, arg_lens, batch_lens); - shape batch_shape{dyn_out.computed_shape.type(), batch_lens}; - visit_all(result, args[0])([&](auto output, auto input) { - par_for(dyn_out.computed_shape.elements(), [&](auto i) { - auto out_idx = dyn_out.computed_shape.multi(i); - this->reduce(input, batch_shape, tuned_axes, out_idx, output); + std::vector batch_lens(computed_shape.ndim(), 1); + auto arg_lens = data_arg.get_shape().lens(); + tune_dims(reduce_axes, arg_lens, batch_lens); + shape batch_shape{computed_shape.type(), batch_lens}; + argument result{computed_shape}; + + visit_all(result, data_arg)([&](auto output, auto input) { + par_for(computed_shape.elements(), [&](auto i) { + auto out_idx = computed_shape.multi(i); + this->reduce(input, batch_shape, reduce_axes, out_idx, output); }); }); return result; } + argument compute(const dyn_output& dyn_out, std::vector args) const + { + auto&& data_arg = args[0]; + // cppcheck-suppress knownConditionTrueFalse + if(not axes.empty()) + return reduce(dyn_out.computed_shape, axes, data_arg); + + if(args[1].get_shape().elements() == 0) + return args[0]; + + std::vector reduce_axes; + args[1].visit([&](auto&& s) { reduce_axes.assign(s.begin(), s.end()); }); + const auto result_shape = collapse_reduced_axes(data_arg.get_shape(), reduce_axes); + + return reduce(result_shape, reduce_axes, data_arg); + } + auto init() const { return zero(); } auto input() const diff --git a/src/include/migraphx/op/where.hpp b/src/include/migraphx/op/where.hpp index 8c58deaa207..d0f94c312db 100644 --- a/src/include/migraphx/op/where.hpp +++ b/src/include/migraphx/op/where.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -42,7 +42,13 @@ struct where shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this, true}.has(3).same_dims(); + check_shapes shape_checker{inputs, *this, true}; + shape_checker.has(3); + if(auto s = inputs[0]; not s.dynamic() and s.elements() == 1) + check_shapes{std::next(inputs.begin()), inputs.end(), *this, true}.same_dims(); + else + shape_checker.same_dims(); + auto s1 = inputs.at(1); auto s2 = inputs.at(2); if(s1.dynamic() or s2.dynamic()) @@ -71,12 +77,18 @@ struct where } } - argument compute(const dyn_output& dyn_out, std::vector args) const + argument compute(shape output_shape, std::vector args) const { - argument result{dyn_out.computed_shape}; + if(auto s = args[0].get_shape(); not s.dynamic() and s.elements() == 1) + return args[args[0].at() ? 1 : 2].copy(); + + if(output_shape.dynamic()) + output_shape = compute_shape(to_shapes(args)); + argument result{output_shape}; + visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) { args[0].visit([&](const auto condition) { - par_for(dyn_out.computed_shape.elements(), + par_for(output_shape.elements(), [&](auto i) { output[i] = condition[i] ? x[i] : y[i]; }); }); }); diff --git a/src/onnx/parse_reduce_op.cpp b/src/onnx/parse_reduce_op.cpp index 0024635ca2f..9d69c1b15f4 100644 --- a/src/onnx/parse_reduce_op.cpp +++ b/src/onnx/parse_reduce_op.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -31,66 +31,112 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace onnx { -instruction_ref parse_reduce_oper(const std::string& op_name, - const onnx_parser& parser, - onnx_parser::node_info info, - std::vector args) +template +struct reduce_parser : op_parser { - // default to reduce over all dimensions - std::vector axes; - if(args.size() == 2) + instruction_ref parse_reduce_oper(const std::string& op_name, + const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const { - auto arg_axes = args.at(1)->eval(); - check_arg_empty(arg_axes, "PARSE_" + op_name + ": cannot handle variable axes!"); - axes.clear(); - arg_axes.visit([&](auto s) { axes.assign(s.begin(), s.end()); }); - } - else if(contains(info.attributes, "axes")) - { - axes.clear(); - auto&& attr_axes = info.attributes["axes"].ints(); - axes = std::vector(attr_axes.begin(), attr_axes.end()); - } + auto constant_axes = parse_constant_axes(args, info); - bool noop_with_empty_axes = false; - if(contains(info.attributes, "noop_with_empty_axes")) - { - noop_with_empty_axes = static_cast( - parser.parse_value(info.attributes.at("noop_with_empty_axes")).at()); - } + int noop_with_empty_axes = + parse_attribute("noop_with_empty_axes", parser, info).value_or(0); - // empty axes behavior - if(axes.empty()) - { - if(noop_with_empty_axes) + int keep_dims = parse_attribute("keepdims", parser, info).value_or(1); + + std::vector all_axes(args.front()->get_shape().ndim()); + std::iota(all_axes.begin(), all_axes.end(), 0); + + // Handle axes attribute, constant input axes, and missing both attribute and input cases + if(constant_axes.has_value()) { - return args.at(0); + if(noop_with_empty_axes != 0 and constant_axes->empty()) + return args[0]; + + if(noop_with_empty_axes == 0 and constant_axes->empty()) + constant_axes = all_axes; + + auto reduce = + info.add_instruction(make_op(op_name, {{"axes", *constant_axes}}), args[0]); + + if(keep_dims == 0) + return info.add_instruction(make_op("squeeze", {{"axes", *constant_axes}}), reduce); + + return reduce; + } + + // Handle variable input axes + if(keep_dims == 0) + MIGRAPHX_THROW("Keepdims not supported with runtime provided axes"); + + // Empty axes attribute indicates to the operator to look for axes in the inputs + // If the input axes are empty, the default behavior of reduce_op is to be an + // identity operator + auto reduce_op = make_op(op_name, {{"axes", {}}}); + + if(noop_with_empty_axes != 0) + return info.add_instruction(reduce_op, args); + + if(args[1]->get_shape().dynamic()) + { + auto reduce_input_axes = info.add_instruction(reduce_op, args); + auto all_axes_lit = info.add_literal( + literal{shape{shape::type_t::int64_type, {all_axes.size()}}, all_axes}); + auto reduce_all_axes = info.add_instruction(reduce_op, args[0], all_axes_lit); + auto zero = info.add_literal(literal{shape{shape::type_t::int64_type}, {0u}}); + auto axes_size = info.add_instruction(make_op("dimensions_of", {{"end", 1}}), args[1]); + auto is_axes_empty = info.add_instruction(make_op("equal"), axes_size, zero); + + return info.add_instruction( + make_op("where"), is_axes_empty, reduce_all_axes, reduce_input_axes); + } + else if(args[1]->get_shape().elements() == 0) + { + auto all_axes_lit = info.add_literal( + literal{shape{shape::type_t::int64_type, {all_axes.size()}}, all_axes}); + return info.add_instruction(reduce_op, args[0], all_axes_lit); } else { - axes.resize(args.front()->get_shape().ndim()); - std::iota(axes.begin(), axes.end(), 0); + return info.add_instruction(reduce_op, args); } } - int keep_dims = 1; - if(contains(info.attributes, "keepdims")) + private: + template + std::optional parse_attribute(const std::string& attribute_name, + const onnx_parser& parser, + onnx_parser::node_info& info) const { - keep_dims = parser.parse_value(info.attributes.at("keepdims")).at(); - } + if(not contains(info.attributes, attribute_name)) + return std::nullopt; - if(keep_dims == 1) - { - return info.add_instruction(make_op(op_name, {{"axes", axes}}), args.front()); + return parser.parse_value(info.attributes[attribute_name]).at(); } - else + + std::optional> parse_constant_axes(std::vector& args, + onnx_parser::node_info& info) const { - auto ins = info.add_instruction(make_op(op_name, {{"axes", axes}}), args.front()); - return info.add_instruction(make_op("squeeze", {{"axes", axes}}), ins); + std::vector axes; + if(args.size() == 2) + { + if(not args[1]->can_eval()) + return std::nullopt; + args[1]->eval().visit([&](auto s) { axes.assign(s.begin(), s.end()); }); + } + else if(contains(info.attributes, "axes")) + { + auto&& attr_axes = info.attributes["axes"].ints(); + axes.assign(attr_axes.begin(), attr_axes.end()); + } + + return axes; } -} +}; -struct parse_reduce_op : op_parser +struct parse_reduce_op : reduce_parser { std::vector operators() const { @@ -110,7 +156,7 @@ struct parse_reduce_op : op_parser } }; -struct parse_reduce_l1 : op_parser +struct parse_reduce_l1 : reduce_parser { std::vector operators() const { return {{"ReduceL1"}}; } @@ -119,12 +165,12 @@ struct parse_reduce_l1 : op_parser onnx_parser::node_info info, std::vector args) const { - auto abs_ins = info.add_instruction(make_op("abs"), args[0]); - return parse_reduce_oper("reduce_sum", parser, std::move(info), {abs_ins}); + args[0] = info.add_instruction(make_op("abs"), args[0]); + return parse_reduce_oper("reduce_sum", parser, std::move(info), std::move(args)); } }; -struct parse_reduce_l2 : op_parser +struct parse_reduce_l2 : reduce_parser { std::vector operators() const { return {{"ReduceL2"}}; } @@ -133,13 +179,13 @@ struct parse_reduce_l2 : op_parser const onnx_parser::node_info& info, std::vector args) const { - auto square_ins = info.add_instruction(make_op("mul"), args[0], args[0]); - auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, {square_ins}); + args[0] = info.add_instruction(make_op("mul"), args[0], args[0]); + auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, std::move(args)); return info.add_instruction(make_op("sqrt"), sum_ins); } }; -struct parse_reduce_log_sum : op_parser +struct parse_reduce_log_sum : reduce_parser { std::vector operators() const { return {{"ReduceLogSum"}}; } @@ -153,7 +199,7 @@ struct parse_reduce_log_sum : op_parser } }; -struct parse_reduce_log_sum_exp : op_parser +struct parse_reduce_log_sum_exp : reduce_parser { std::vector operators() const { return {{"ReduceLogSumExp"}}; } @@ -162,13 +208,13 @@ struct parse_reduce_log_sum_exp : op_parser const onnx_parser::node_info& info, std::vector args) const { - auto exp_ins = info.add_instruction(make_op("exp"), args[0]); - auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, {exp_ins}); + args[0] = info.add_instruction(make_op("exp"), args[0]); + auto sum_ins = parse_reduce_oper("reduce_sum", parser, info, std::move(args)); return info.add_instruction(make_op("log"), sum_ins); } }; -struct parse_reduce_sum_square : op_parser +struct parse_reduce_sum_square : reduce_parser { std::vector operators() const { return {{"ReduceSumSquare"}}; } @@ -177,8 +223,8 @@ struct parse_reduce_sum_square : op_parser onnx_parser::node_info info, std::vector args) const { - auto square_ins = info.add_instruction(make_op("mul"), args[0], args[0]); - return parse_reduce_oper("reduce_sum", parser, std::move(info), {square_ins}); + args[0] = info.add_instruction(make_op("mul"), args[0], args[0]); + return parse_reduce_oper("reduce_sum", parser, std::move(info), std::move(args)); } }; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index b385091221a..66605a91d35 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -7440,6 +7440,23 @@ def recip_test(): return ([node], [x], [y]) +def reduceop_variable_axes_test(op_name, + axes_len=1, + keepdims=1, + noop_with_empty_axes=0): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) + axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [axes_len]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) + + node = onnx.helper.make_node(op_name, + inputs=['x', 'axes'], + outputs=['y'], + keepdims=keepdims, + noop_with_empty_axes=noop_with_empty_axes) + + return ([node], [x, axes], [y]) + + @onnx_test() def reducel1_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) @@ -7623,12 +7640,11 @@ def reduceprod_test(): def reducesum_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4, 1, 6]) - axes = [2] node = onnx.helper.make_node('ReduceSum', inputs=['x'], outputs=['y'], - axes=axes, + axes=[2], keepdims=0) return ([node], [x], [y]) @@ -7702,6 +7718,53 @@ def reducesum_multiaxis_test(): return ([node], [x], [y]) +@onnx_test() +def reducesum_variable_axes_test(): + return reduceop_variable_axes_test('ReduceSum') + + +@onnx_test() +def reducesum_variable_axes_noop_test(): + return reduceop_variable_axes_test('ReduceSum', noop_with_empty_axes=1) + + +@onnx_test() +def reducesum_variable_axes_keepdims_clear_test(): + return reduceop_variable_axes_test('ReduceSum', keepdims=0) + + +@onnx_test() +def reducesum_variable_dynamic_axes_test(): + return reduceop_variable_axes_test('ReduceSum', None) + + +@onnx_test() +def reducesum_variable_dynamic_axes_verify_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2]) + axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [None]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None]) + + node = onnx.helper.make_node('ReduceSum', + inputs=['x', 'axes'], + outputs=['y']) + + return ([node], [x, axes], [y]) + + +@onnx_test() +def reducesum_variable_dynamic_axes_noop_set_verify_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2]) + axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [None]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None]) + + node = onnx.helper.make_node('ReduceSum', + inputs=['x', 'axes'], + outputs=['y'], + noop_with_empty_axes=1) + + return ([node], [x, axes], [y]) + + @onnx_test() def reducesum_square_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6]) diff --git a/test/onnx/parse/reducesum_variable_axes_keepdims_clear_test.cpp b/test/onnx/parse/reducesum_variable_axes_keepdims_clear_test.cpp new file mode 100644 index 00000000000..d7ca00f188a --- /dev/null +++ b/test/onnx/parse/reducesum_variable_axes_keepdims_clear_test.cpp @@ -0,0 +1,30 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include + +TEST_CASE(reducesum_variable_axes_keepdims_clear_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("reducesum_variable_axes_keepdims_clear_test.onnx"); })); +} diff --git a/test/onnx/parse/reducesum_variable_axes_noop_test.cpp b/test/onnx/parse/reducesum_variable_axes_noop_test.cpp new file mode 100644 index 00000000000..b6fa81f03b7 --- /dev/null +++ b/test/onnx/parse/reducesum_variable_axes_noop_test.cpp @@ -0,0 +1,36 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include + +TEST_CASE(reducesum_variable_axes_noop_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {1}}); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {}}}), x, axes); + + auto prog = optimize_onnx("reducesum_variable_axes_noop_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/reducesum_variable_axes_test.cpp b/test/onnx/parse/reducesum_variable_axes_test.cpp new file mode 100644 index 00000000000..af2dfcf9c93 --- /dev/null +++ b/test/onnx/parse/reducesum_variable_axes_test.cpp @@ -0,0 +1,36 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include + +TEST_CASE(reducesum_variable_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {1}}); + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {}}}), x, axes); + + auto prog = optimize_onnx("reducesum_variable_axes_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/reducesum_variable_dynamic_axes_test.cpp b/test/onnx/parse/reducesum_variable_dynamic_axes_test.cpp new file mode 100644 index 00000000000..63c5200ff71 --- /dev/null +++ b/test/onnx/parse/reducesum_variable_dynamic_axes_test.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include + +TEST_CASE(reducesum_variable_dynamic_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + const std::vector axes_dims{{0, 3}}; + auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, axes_dims}); + + auto reduce_input_axes = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {}}}), x, axes); + std::vector all_axes(x->get_shape().ndim()); + std::iota(all_axes.begin(), all_axes.end(), 0); + auto all_axes_lit = mm->add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::type_t::int64_type, {all_axes.size()}}, all_axes}); + auto reduce_all_axes = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {}}}), x, all_axes_lit); + + auto zero_lit = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int64_type}, {0u}}); + auto axes_size = mm->add_instruction(migraphx::make_op("dimensions_of", {{"end", 1}}), axes); + auto is_axes_empty = mm->add_instruction(migraphx::make_op("equal"), axes_size, zero_lit); + auto where = mm->add_instruction( + migraphx::make_op("where"), is_axes_empty, reduce_all_axes, reduce_input_axes); + mm->add_return({where}); + + migraphx::onnx_options options; + options.map_dyn_input_dims["axes"] = axes->get_shape().dyn_dims(); + auto prog = parse_onnx("reducesum_variable_dynamic_axes_test.onnx", options); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/reducesum_variable_empty_axes_test.cpp b/test/onnx/parse/reducesum_variable_empty_axes_test.cpp new file mode 100644 index 00000000000..cb660001b97 --- /dev/null +++ b/test/onnx/parse/reducesum_variable_empty_axes_test.cpp @@ -0,0 +1,45 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include + +TEST_CASE(reducesum_variable_empty_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {0}}); + + std::vector all_axes(x->get_shape().ndim()); + std::iota(all_axes.begin(), all_axes.end(), 0); + auto all_axes_lit = mm->add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::int64_type, {all_axes.size()}}, all_axes}); + auto reduce_all_axes = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {}}}), x, all_axes_lit); + mm->add_return({reduce_all_axes}); + + migraphx::onnx_options options; + options.map_input_dims["axes"] = axes->get_shape().lens(); + auto prog = parse_onnx("reducesum_variable_axes_test.onnx", options); + EXPECT(p == prog); +} diff --git a/test/onnx/reducesum_variable_axes_keepdims_clear_test.onnx b/test/onnx/reducesum_variable_axes_keepdims_clear_test.onnx new file mode 100644 index 00000000000..f0eb3590646 Binary files /dev/null and b/test/onnx/reducesum_variable_axes_keepdims_clear_test.onnx differ diff --git a/test/onnx/reducesum_variable_axes_noop_test.onnx b/test/onnx/reducesum_variable_axes_noop_test.onnx new file mode 100644 index 00000000000..e3e9e8fd408 --- /dev/null +++ b/test/onnx/reducesum_variable_axes_noop_test.onnx @@ -0,0 +1,22 @@ + !reducesum_variable_axes_noop_test:¸ +E +x +axesy" ReduceSum* +keepdims * +noop_with_empty_axes !reducesum_variable_axes_noop_testZ +x + + + + +Z +axes + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/reducesum_variable_axes_test.onnx b/test/onnx/reducesum_variable_axes_test.onnx new file mode 100644 index 00000000000..d2a2bc4a2c4 Binary files /dev/null and b/test/onnx/reducesum_variable_axes_test.onnx differ diff --git a/test/onnx/reducesum_variable_dynamic_axes_noop_set_verify_test.onnx b/test/onnx/reducesum_variable_dynamic_axes_noop_set_verify_test.onnx new file mode 100644 index 00000000000..6304cc83ca0 Binary files /dev/null and b/test/onnx/reducesum_variable_dynamic_axes_noop_set_verify_test.onnx differ diff --git a/test/onnx/reducesum_variable_dynamic_axes_test.onnx b/test/onnx/reducesum_variable_dynamic_axes_test.onnx new file mode 100644 index 00000000000..c956ac17e95 Binary files /dev/null and b/test/onnx/reducesum_variable_dynamic_axes_test.onnx differ diff --git a/test/onnx/reducesum_variable_dynamic_axes_verify_test.onnx b/test/onnx/reducesum_variable_dynamic_axes_verify_test.onnx new file mode 100644 index 00000000000..b37ab32e730 Binary files /dev/null and b/test/onnx/reducesum_variable_dynamic_axes_verify_test.onnx differ diff --git a/test/onnx/verify/reducesum_variable_axes_test.cpp b/test/onnx/verify/reducesum_variable_axes_test.cpp new file mode 100644 index 00000000000..e747c4db702 --- /dev/null +++ b/test/onnx/verify/reducesum_variable_axes_test.cpp @@ -0,0 +1,104 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +auto reducesum_variable_axes_test_base(const std::string& file, size_t axes_size) +{ + std::pair, migraphx::shape> ret; + + migraphx::onnx_options options; + options.map_input_dims["axes"] = std::vector{axes_size}; + migraphx::program p = migraphx::parse_onnx(file, options); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + migraphx::shape x_shape{migraphx::shape::float_type, {3, 4, 5, 6}}; + std::vector x(x_shape.elements()); + std::iota(x.begin(), x.end(), 0); + pm["x"] = migraphx::argument(x_shape, x.data()); + auto axes_data = axes_size == 0 ? std::vector{} : std::vector{2}; + pm["axes"] = migraphx::argument(migraphx::shape{migraphx::shape::int64_type, {axes_size}}, + axes_data.data()); + + auto result = p.eval(pm).back(); + std::vector result_vector; + result.visit([&](auto output) { ret.first.assign(output.begin(), output.end()); }); + ret.second = result.get_shape(); + + return ret; +} + +TEST_CASE(bla) +{ + auto [result_vector, shape] = + reducesum_variable_axes_test_base("reducesum_variable_axes_test.onnx", 1); + std::vector gold{60, 65, 70, 75, 80, 85, 210, 215, 220, 225, 230, 235, + 360, 365, 370, 375, 380, 385, 510, 515, 520, 525, 530, 535, + 660, 665, 670, 675, 680, 685, 810, 815, 820, 825, 830, 835, + 960, 965, 970, 975, 980, 985, 1110, 1115, 1120, 1125, 1130, 1135, + 1260, 1265, 1270, 1275, 1280, 1285, 1410, 1415, 1420, 1425, 1430, 1435, + 1560, 1565, 1570, 1575, 1580, 1585, 1710, 1715, 1720, 1725, 1730, 1735}; + + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 6}}); + EXPECT(result_vector == gold); +} + +TEST_CASE(bla2) +{ + auto [result_vector, shape] = + reducesum_variable_axes_test_base("reducesum_variable_axes_test.onnx", 0); + std::vector gold{64620}; + + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}); + EXPECT(result_vector == gold); +} + +TEST_CASE(bla3) +{ + auto [result_vector, shape] = + reducesum_variable_axes_test_base("reducesum_variable_axes_noop_test.onnx", 1); + std::vector gold{60, 65, 70, 75, 80, 85, 210, 215, 220, 225, 230, 235, + 360, 365, 370, 375, 380, 385, 510, 515, 520, 525, 530, 535, + 660, 665, 670, 675, 680, 685, 810, 815, 820, 825, 830, 835, + 960, 965, 970, 975, 980, 985, 1110, 1115, 1120, 1125, 1130, 1135, + 1260, 1265, 1270, 1275, 1280, 1285, 1410, 1415, 1420, 1425, 1430, 1435, + 1560, 1565, 1570, 1575, 1580, 1585, 1710, 1715, 1720, 1725, 1730, 1735}; + + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 6}}); + EXPECT(result_vector == gold); +} + +TEST_CASE(bla4) +{ + auto [result_vector, shape] = + reducesum_variable_axes_test_base("reducesum_variable_axes_noop_test.onnx", 0); + std::vector gold(3 * 4 * 5 * 6); + std::iota(gold.begin(), gold.end(), 0); + + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); + EXPECT(result_vector == gold); +} diff --git a/test/onnx/verify/reducesum_variable_dynamic_axes_test.cpp b/test/onnx/verify/reducesum_variable_dynamic_axes_test.cpp new file mode 100644 index 00000000000..aceb49cd180 --- /dev/null +++ b/test/onnx/verify/reducesum_variable_dynamic_axes_test.cpp @@ -0,0 +1,99 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +auto reducesum_variable_dynamic_axes_test_base(migraphx::shape axes_shape, + std::vector axes_data, + const std::string& file) +{ + std::pair, migraphx::shape> ret; + + migraphx::onnx_options options; + const std::vector axes_dims{{0, 3}}; + options.map_dyn_input_dims["axes"] = axes_dims; + migraphx::program p = parse_onnx(file, options); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + migraphx::shape x_shape{migraphx::shape::float_type, {2, 2, 2}}; + std::vector x(x_shape.elements()); + std::iota(x.begin(), x.end(), 0); + pm["x"] = migraphx::argument(x_shape, x.data()); + + std::vector axes{1}; + pm["axes"] = migraphx::argument(axes_shape, axes_data.data()); + + auto result = p.eval(pm).back(); + ret.second = result.get_shape(); + result.visit([&](auto output) { ret.first.assign(output.begin(), output.end()); }); + return ret; +} + +TEST_CASE(reducesum_variable_dynamic_axes_test) +{ + auto [result, shape] = reducesum_variable_dynamic_axes_test_base( + {migraphx::shape::int64_type, {1}}, + std::vector{1}, + "reducesum_variable_dynamic_axes_verify_test.onnx"); + std::vector gold{2, 4, 10, 12}; + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); + EXPECT(result == gold); +} + +TEST_CASE(reducesum_variable_dynamic_axes_empty_test) +{ + auto [result, shape] = reducesum_variable_dynamic_axes_test_base( + {migraphx::shape::int64_type, {0}}, + std::vector{}, + "reducesum_variable_dynamic_axes_verify_test.onnx"); + std::vector gold{28}; + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {1, 1, 1}}); + EXPECT(result == gold); +} + +TEST_CASE(reducesum_variable_dynamic_axes_noop_set_test) +{ + auto [result, shape] = reducesum_variable_dynamic_axes_test_base( + {migraphx::shape::int64_type, {1}}, + std::vector{1}, + "reducesum_variable_dynamic_axes_noop_set_verify_test.onnx"); + std::vector gold{2, 4, 10, 12}; + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {2, 1, 2}}); + EXPECT(result == gold); +} + +TEST_CASE(reducesum_variable_dynamic_axes_empty_noop_set_test) +{ + auto [result, shape] = reducesum_variable_dynamic_axes_test_base( + {migraphx::shape::int64_type, {0}}, + std::vector{}, + "reducesum_variable_dynamic_axes_noop_set_verify_test.onnx"); + std::vector gold(8); + std::iota(gold.begin(), gold.end(), 0); + EXPECT(shape == migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + EXPECT(result == gold); +} diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index f1b4bf6ec91..6e7e00925be 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -2719,13 +2719,6 @@ TEST_CASE(dqlinear_mismatch_type) void test_reduce_ops(const std::string& name) { - { - migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, - migraphx::make_op(name), - input); - } - { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, @@ -2776,12 +2769,11 @@ void test_dyn_reduce_ops(const std::string& name) input); } { - // Empty axis argument reduces all axes migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}}; expect_shape( migraphx::shape{migraphx::shape::float_type, std::vector({{1, 1}, {1, 1}})}, - migraphx::make_op(name), + migraphx::make_op(name, {{"axes", {0, 1}}}), input); } { @@ -2790,16 +2782,41 @@ void test_dyn_reduce_ops(const std::string& name) } } +void test_reduce_ops_variable_axes(const std::string& name) +{ + { + migraphx::shape input_shape{migraphx::shape::float_type, {2, 3, 4}}; + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + migraphx::shape expected_shape{migraphx::shape::float_type, {{1, 2}, {1, 3}, {1, 4}}}; + expect_shape(expected_shape, migraphx::make_op(name), input_shape, axes_shape); + } + + { + migraphx::shape input_shape{migraphx::shape::float_type, {{2, 3}, {3, 4}}}; + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + migraphx::shape expected_shape{migraphx::shape::float_type, {{1, 3}, {1, 4}}}; + expect_shape(expected_shape, migraphx::make_op(name), input_shape, axes_shape); + } +} + TEST_CASE(reduce_max) { test_reduce_ops("reduce_max"); } +TEST_CASE(reduce_min) { test_reduce_ops("reduce_min"); } TEST_CASE(reduce_mean) { test_reduce_ops("reduce_mean"); } TEST_CASE(reduce_prod) { test_reduce_ops("reduce_prod"); } TEST_CASE(reduce_sum) { test_reduce_ops("reduce_sum"); } TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops("reduce_max"); } +TEST_CASE(reduce_min_dyn) { test_dyn_reduce_ops("reduce_min"); } TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops("reduce_mean"); } TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops("reduce_prod"); } TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops("reduce_sum"); } +TEST_CASE(reduce_max_variable_axes) { test_reduce_ops_variable_axes("reduce_max"); } +TEST_CASE(reduce_min_variable_axes) { test_reduce_ops_variable_axes("reduce_min"); } +TEST_CASE(reduce_mean_variable_axes) { test_reduce_ops_variable_axes("reduce_mean"); } +TEST_CASE(reduce_prod_variable_axes) { test_reduce_ops_variable_axes("reduce_prod"); } +TEST_CASE(reduce_sum_variable_axes) { test_reduce_ops_variable_axes("reduce_sum"); } + TEST_CASE(reshape_shape) { migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}}; diff --git a/test/ref/reduce_max.cpp b/test/ref/reduce_max.cpp index 2fcfcaec729..bc353743b29 100644 --- a/test/ref/reduce_max.cpp +++ b/test/ref/reduce_max.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -46,6 +46,31 @@ TEST_CASE(reduce_max_axis0) EXPECT(results_vector == gold); } +TEST_CASE(reduce_max_variable_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_max"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{9, 10, 11, 12}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_max_dynamic_axis0) { migraphx::program p; @@ -67,6 +92,31 @@ TEST_CASE(reduce_max_dynamic_axis0) EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); } +TEST_CASE(reduce_max_dynamic_variable_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_shape{migraphx::shape::float_type, {{2, 4, {2}}, {3, 5, {3}}}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_max"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + migraphx::shape x_fixed_shape{migraphx::shape::float_type, {2, 5}}; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + pm["x"] = migraphx::argument(x_fixed_shape, x_arg.data()); + std::vector axes_arg{0}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {6, 7, 8, 9, 10}; + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} + TEST_CASE(reduce_max_axis01) { migraphx::program p; @@ -83,6 +133,31 @@ TEST_CASE(reduce_max_axis01) EXPECT(results_vector == gold); } +TEST_CASE(reduce_max_variable_axes01) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_max"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0, 1}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{11, 12}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_max_axis02) { migraphx::program p; @@ -98,3 +173,28 @@ TEST_CASE(reduce_max_axis02) std::vector gold{10, 12}; EXPECT(results_vector == gold); } + +TEST_CASE(reduce_max_variable_axes02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_max"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{10, 12}; + EXPECT(results_vector == gold); +} diff --git a/test/ref/reduce_mean.cpp b/test/ref/reduce_mean.cpp index fb7c2b421c3..97806f2c999 100644 --- a/test/ref/reduce_mean.cpp +++ b/test/ref/reduce_mean.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -46,6 +46,31 @@ TEST_CASE(reduce_mean_axis02) EXPECT(results_vector == gold); } +TEST_CASE(reduce_mean_variable_axes02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_mean"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{5.5, 7.5}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_mean_axis1) { migraphx::program p; @@ -62,6 +87,31 @@ TEST_CASE(reduce_mean_axis1) EXPECT(results_vector == gold); } +TEST_CASE(reduce_mean_variable_axis1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_mean"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{1}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{2, 3, 6, 7, 10, 11}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_mean_axis12) { migraphx::program p; @@ -78,6 +128,31 @@ TEST_CASE(reduce_mean_axis12) EXPECT(results_vector == gold); } +TEST_CASE(reduce_mean_variable_axes12) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_mean"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{1, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{2.5, 6.5, 10.5}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_mean_axis2) { migraphx::program p; @@ -94,6 +169,31 @@ TEST_CASE(reduce_mean_axis2) EXPECT(results_vector == gold); } +TEST_CASE(reduce_mean_variable_axis2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_mean"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_mean_int) { migraphx::program p; @@ -109,3 +209,28 @@ TEST_CASE(reduce_mean_int) std::vector gold{2, 6, 10}; EXPECT(results_vector == gold); } + +TEST_CASE(reduce_mean_variable_axes12_int) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::int32_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_mean"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{1, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{2, 6, 10}; + EXPECT(results_vector == gold); +} diff --git a/test/ref/reduce_min.cpp b/test/ref/reduce_min.cpp index e1f1796bf03..a007a9a341e 100644 --- a/test/ref/reduce_min.cpp +++ b/test/ref/reduce_min.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -46,6 +46,31 @@ TEST_CASE(reduce_min_axis02) EXPECT(results_vector == gold); } +TEST_CASE(reduce_min_variable_axes02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_min"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{1, 3}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_min_axis1) { migraphx::program p; @@ -62,6 +87,31 @@ TEST_CASE(reduce_min_axis1) EXPECT(results_vector == gold); } +TEST_CASE(reduce_min_variable_axis1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_min"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{1}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{1, 2, 5, 6, 9, 10}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_min_axis12) { migraphx::program p; @@ -77,3 +127,53 @@ TEST_CASE(reduce_min_axis12) std::vector gold{1, 5, 9}; EXPECT(results_vector == gold); } + +TEST_CASE(reduce_min_variable_axes12) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_min"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{1, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{1, 5, 9}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_min_dynamic_variable_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_shape{migraphx::shape::float_type, {{2, 4, {2}}, {3, 5, {3}}}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_min"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + migraphx::shape x_fixed_shape{migraphx::shape::float_type, {2, 5}}; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + pm["x"] = migraphx::argument(x_fixed_shape, x_arg.data()); + std::vector axes_arg{0}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {1, 2, 3, 4, 5}; + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} diff --git a/test/ref/reduce_prod.cpp b/test/ref/reduce_prod.cpp index 9c2a9b180b5..ce034341b4f 100644 --- a/test/ref/reduce_prod.cpp +++ b/test/ref/reduce_prod.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -45,3 +45,28 @@ TEST_CASE(reduce_prod_axis0) std::vector gold{6, 18, 12, 18}; EXPECT(results_vector == gold); } + +TEST_CASE(reduce_prod_variable_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {4, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_prod"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 3, 2, 3}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{6, 18, 12, 18}; + EXPECT(results_vector == gold); +} diff --git a/test/ref/reduce_sum.cpp b/test/ref/reduce_sum.cpp index 51514a719c2..e48dbaf6c31 100644 --- a/test/ref/reduce_sum.cpp +++ b/test/ref/reduce_sum.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -46,6 +46,31 @@ TEST_CASE(reduce_sum_axis0) EXPECT(results_vector == gold); } +TEST_CASE(reduce_sum_variable_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_sum"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{15, 18, 21, 24}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_sum_axis02) { migraphx::program p; @@ -62,6 +87,31 @@ TEST_CASE(reduce_sum_axis02) EXPECT(results_vector == gold); } +TEST_CASE(reduce_sum_variable_axes02) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_sum"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{0, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{33, 45}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_sum_axis1) { migraphx::program p; @@ -78,6 +128,31 @@ TEST_CASE(reduce_sum_axis1) EXPECT(results_vector == gold); } +TEST_CASE(reduce_sum_variable_axis1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_sum"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{1}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{4, 6, 12, 14, 20, 22}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_sum_axis12) { migraphx::program p; @@ -94,6 +169,31 @@ TEST_CASE(reduce_sum_axis12) EXPECT(results_vector == gold); } +TEST_CASE(reduce_sum_variable_axes12) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {2}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_sum"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{1, 2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{10, 26, 42}; + EXPECT(results_vector == gold); +} + TEST_CASE(reduce_sum_axis2) { migraphx::program p; @@ -109,3 +209,103 @@ TEST_CASE(reduce_sum_axis2) std::vector gold{3, 7, 11, 15, 19, 23}; EXPECT(results_vector == gold); } + +TEST_CASE(reduce_sum_variable_axis2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape x_shape{migraphx::shape::float_type, {3, 2, 2}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_sum"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + pm["x"] = migraphx::argument(x_shape, x_arg.data()); + std::vector axes_arg{2}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold{3, 7, 11, 15, 19, 23}; + EXPECT(results_vector == gold); +} + +TEST_CASE(reduce_sum_dynamic_variable_axis0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape x_shape{migraphx::shape::float_type, {{2, 4, {2}}, {3, 5, {3}}}}; + auto x = mm->add_parameter("x", x_shape); + migraphx::shape axes_shape{migraphx::shape::int64_type, {1}}; + auto axes = mm->add_parameter("axes", axes_shape); + mm->add_instruction(migraphx::make_op("reduce_sum"), x, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map pm; + migraphx::shape x_fixed_shape{migraphx::shape::float_type, {2, 5}}; + std::vector x_arg{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + pm["x"] = migraphx::argument(x_fixed_shape, x_arg.data()); + std::vector axes_arg{0}; + pm["axes"] = migraphx::argument(axes_shape, axes_arg.data()); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {7, 9, 11, 13, 15}; + EXPECT(migraphx::verify::verify_rms_range(results_vector, gold)); +} + +TEST_CASE(reduce_sum_variable_dynamic_empty_axes) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + std::vector input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + auto input = migraphx::literal{s, input_data}; + auto l0 = mm->add_literal(input); + const std::vector axes_dynamic_dims{{0, 3}}; + migraphx::shape axes_dynamic_shape{migraphx::shape::int64_type, axes_dynamic_dims}; + auto axes = mm->add_parameter("axes", axes_dynamic_shape); + + migraphx::parameter_map pm; + migraphx::shape axes_shape{migraphx::shape::int64_type, {0}}; + std::vector axes_data; + pm["axes"] = migraphx::argument(axes_shape, axes_data.data()); + + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {}}}), l0, axes); + p.compile(migraphx::make_target("ref")); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + EXPECT(results_vector == input_data); +} + +TEST_CASE(reduce_sum_variable_empty_axes) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; + std::vector input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + auto input = migraphx::literal{s, input_data}; + auto l0 = mm->add_literal(input); + migraphx::shape axes_shape{migraphx::shape::int64_type, {0}}; + auto axes = mm->add_parameter("axes", axes_shape); + + migraphx::parameter_map pm; + std::vector axes_data; + pm["axes"] = migraphx::argument(axes_shape, axes_data.data()); + + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {}}}), l0, axes); + p.compile(migraphx::make_target("ref")); + auto result = p.eval(pm).back(); + std::vector results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + EXPECT(results_vector == input_data); +}