diff --git a/.github/workflows/sync-onnxrt-main.yaml b/.github/workflows/sync-onnxrt-main.yaml index 701d4ac4b64..f14128fb703 100644 --- a/.github/workflows/sync-onnxrt-main.yaml +++ b/.github/workflows/sync-onnxrt-main.yaml @@ -47,6 +47,7 @@ jobs: onnxruntime dependencies automated + skip bot checks assignees: TedThemistokleous reviewers: TedThemistokleous causten draft: false diff --git a/CHANGELOG.md b/CHANGELOG.md index dedcac00d05..dc183f127a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,49 @@ Full documentation for MIGraphX is available at [MIGraphX Documentation](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/). -## MIGraphX 2.5 for ROCm 5.5.0 +## MIGraphX 2.7 for ROCm 5.7.0 +### Added +- Enabled hipRTC so the MIGraphX runtime no longer requires dev packages, and allowed the ROCm install to be in a different directory than it was at build time +- Added support for multi-target execution +- Added dynamic batch support with C++/Python APIs +- Added migraphx.create_argument to the Python API +- Added a Dockerfile example for Ubuntu 22.04 +- Added TensorFlow supported ops to the driver, similar to the existing ONNX operator list +- Added a MIGRAPHX_TRACE_MATCHES_FOR env variable to filter the matcher trace +- Improved debugging by printing max, min, mean, and stddev values for TRACE_EVAL = 2 +- Used the fast_math flag instead of an ENV flag for GELU +- Printed a message from the driver if offload copy is set for a compiled program +### Optimizations +- Optimized for ONNX Runtime 1.14.0 +- Improved compile times by only building for the GPU on the system +- Improved performance of pointwise/reduction kernels when using NHWC layouts +- Loaded a specific version of the migraphx_py library +- Annotated functions with the block size so the compiler can do a better job of optimizing +- Enabled reshape on nonstandard shapes +- Used half HIP APIs to compute max and min +- Added support for broadcasted scalars to the unsqueeze operator +- Improved multiplies with the dot operator +- Handled broadcasts across dot and concat +- Added a verify namespace for better symbol resolution +### Fixed +- Resolved accuracy issues with FP16 resnet50 +- Updated the cpp generator to handle inf from float +- Fixed an assertion error during verify and made DCE work with tuples +- Fixed the convert operation for NaNs +- Fixed a shape typo in an API test +- Fixed compile warnings about shadowed variable names +- Added a missing specialization of the hash function for `nullptr` +### Changed +- Bumped the version of the half library to 5.6.0 +- Bumped CI to support ROCm 5.6 +- Made building tests optional +- Replaced np.bool with bool, as requested by NumPy +### Removed +- Removed int8x4 rocBLAS calls due to deprecation +- Removed std::reduce usage since not all operating systems support it + +## MIGraphX 2.5 for ROCm 5.5.0 ### Added - Y-Model feature to store tuning information with the optimized model - Added Python 3.10 bindings - Accuracy checker tool based on ONNX Runtime - ONNX operators parse_split and Trilu - Build support for ROCm MLIR - Added migraphx-driver flag to print optimizations in Python (--python) - Added JIT implementation of the Gather and Pad operators, which results in better handling of larger tensor sizes. - - ### Optimizations - Improved performance of Transformer-based models - Improved performance of the Pad, Concat, Gather, and Pointwise operators - Improved onnx/pb file loading speed - Added a general optimize pass which runs several passes, such as simplify_reshapes/algebra and DCE, in a loop.
- - ### Fixed - Improved parsing Tensorflow Protobuf files - Resolved various accuracy issues with some onnx models @@ -29,6 +66,5 @@ Full documentation for MIGraphX is available at [MIGraphX Documentation](https:/ - Use --offload-arch instead of --cuda-gpu-arch for the HIP compiler - Changes inside JIT to use float accumulator for large reduce ops of half type to avoid overflow. - Changes inside JIT to temporarily use cosine to compute sine function. - ### Changed - Changed version/location of 3rd party build dependencies to pick up fixes diff --git a/Jenkinsfile b/Jenkinsfile index 4e36a4eea2e..1c402e96b1a 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -114,10 +114,10 @@ rocmtest clang_debug: rocmnode('cdna') { cmake_build -> cmake_build(flags: "-DCMAKE_BUILD_TYPE=release") stash includes: 'build/*.deb', name: 'migraphx-package' } -}, hidden_symbols: rocmnode('cdna') { cmake_build -> - stage('Hidden symbols') { - cmake_build(flags: "-DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_VISIBILITY_PRESET=hidden -DCMAKE_C_VISIBILITY_PRESET=hidden") - } +// }, hidden_symbols: rocmnode('cdna') { cmake_build -> +// stage('Hidden symbols') { +// cmake_build(flags: "-DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_VISIBILITY_PRESET=hidden -DCMAKE_C_VISIBILITY_PRESET=hidden") +// } }, all_targets_debug : rocmnode('cdna') { cmake_build -> stage('All targets Release') { cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DMIGRAPHX_ENABLE_FPGA=On") diff --git a/docs/.sphinx/requirements.in b/docs/.sphinx/requirements.in index b8366edf9ec..781cd3ac310 100644 --- a/docs/.sphinx/requirements.in +++ b/docs/.sphinx/requirements.in @@ -1 +1 @@ -rocm-docs-core==0.11.0 +rocm-docs-core>=0.20.0 diff --git a/docs/.sphinx/requirements.txt b/docs/.sphinx/requirements.txt index 8a597fc88c3..55d6c9742f0 100644 --- a/docs/.sphinx/requirements.txt +++ b/docs/.sphinx/requirements.txt @@ -1,29 +1,4 @@ # Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -# -# This file is autogenerated by pip-compile with Python 3.8 -# by the following command: -# -# pip-compile requirements.in -# accessible-pygments==0.0.4 # via pydata-sphinx-theme alabaster==0.7.13 @@ -46,7 +21,7 @@ charset-normalizer==3.1.0 # via requests click==8.1.3 # via sphinx-external-toc -cryptography==40.0.2 +cryptography==41.0.3 # via pyjwt deprecated==1.2.13 # via pygithub @@ -60,22 +35,16 @@ fastjsonschema==2.16.3 # via rocm-docs-core gitdb==4.0.10 # via gitpython -gitpython==3.1.31 +gitpython==3.1.32 # via rocm-docs-core idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.4.1 - # via sphinx -importlib-resources==5.12.0 - # via rocm-docs-core jinja2==3.1.2 # via # myst-parser # sphinx -linkify-it-py==1.0.3 - # via myst-parser markdown-it-py==2.2.0 # via # mdit-py-plugins @@ -86,7 +55,7 @@ mdit-py-plugins==0.3.5 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-parser[linkify]==1.0.0 +myst-parser==1.0.0 # via rocm-docs-core packaging==23.1 # via @@ -109,8 +78,6 @@ pyjwt[crypto]==2.6.0 # via pygithub pynacl==1.5.0 # via pygithub -pytz==2023.3 - # via babel pyyaml==6.0 # via # myst-parser @@ -120,7 +87,7 @@ requests==2.28.2 # via # pygithub # sphinx -rocm-docs-core==0.11.0 +rocm-docs-core>=0.20.0 # via -r requirements.in smmap==5.0.0 # via gitdb @@ -163,13 +130,7 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx typing-extensions==4.5.0 # via pydata-sphinx-theme -uc-micro-py==1.0.1 - # via linkify-it-py urllib3==1.26.15 # via requests wrapt==1.15.0 # via deprecated -zipp==3.15.0 - # via - # importlib-metadata - # importlib-resources diff --git a/mlir-requirements.txt b/mlir-requirements.txt index 00907ac14b3..2e3f001355f 100644 --- a/mlir-requirements.txt +++ b/mlir-requirements.txt @@ -21,4 +21,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. ##################################################################################### -ROCmSoftwarePlatform/rocMLIR@ea15b3597ce55b9088621818228595dd48fb6ec0 -DBUILD_FAT_LIBROCKCOMPILER=On +ROCmSoftwarePlatform/rocMLIR@3657f509bfed86bb79d5c6e24aa237e48f09f9f3 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/eliminate_contiguous.cpp b/src/eliminate_contiguous.cpp index abfd6f5f019..ac1189761e1 100644 --- a/src/eliminate_contiguous.cpp +++ b/src/eliminate_contiguous.cpp @@ -35,6 +35,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS) + static bool try_compute_shape(instruction_ref ins, const std::vector& inputs, const std::vector& mods) @@ -78,14 +80,26 @@ static bool try_compute_shape(instruction_ref ins, return (arg == ins) ? new_shape : arg->get_shape(); }); - if(not try_compute_shape(output, input_shapes, mods)) + if(not try_compute_shape(output, input_shapes, output->module_inputs())) { return false; } } } + catch(const std::exception& e) + { + if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{})) + { + std::cout << "Exception: " << e.what() << std::endl; + } + return false; + } catch(...) 
{ + if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{})) + { + std::cout << "Unknown exception" << std::endl; + } return false; } @@ -127,6 +141,11 @@ static void remove_contiguous(const std::string& op_name, module& m, F f) { if(arg->name() != op_name) continue; + if(enabled(MIGRAPHX_TRACE_ELIMINATE_CONTIGUOUS{})) + { + std::cout << "eliminate_contiguous: "; + m.debug_print(ins); + } auto prev = arg->inputs().front(); replace(new_args, arg, prev); if(try_compute_shape(ins, new_args, mod_args)) diff --git a/src/include/migraphx/check_shapes.hpp b/src/include/migraphx/check_shapes.hpp index f0799fe80d1..ced99e5d593 100644 --- a/src/include/migraphx/check_shapes.hpp +++ b/src/include/migraphx/check_shapes.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index c8428860c9e..9dce5397672 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -381,22 +381,24 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); - const bool trace_for = not trace_filter.empty() and - (contains(std::string{location.file_name()}, trace_filter) or - contains(std::string{location.function_name()}, trace_filter)); - bool match = false; + bool match = false; each_args( [&](auto&& m) { + const auto& matcher_name = get_type_name(m); + const bool trace_for = not trace_filter.empty() and + (contains(std::string{location.file_name()}, trace_filter) or + contains(std::string{location.function_name()}, trace_filter) or + contains(matcher_name, trace_filter)); if(match) return; - if(trace > 1 or trace_for) - std::cout << "Match: " << get_type_name(m) << std::endl; + if(trace > 1 and trace_for) + std::cout << "Match: " << matcher_name << std::endl; auto r = match_instruction(get_module(mod), ins, m.matcher()); if(r.result == get_module(mod).end()) return; if(trace > 0 or trace_for) { - std::cout << "Matched by " << get_type_name(m) << std::endl; + std::cout << "Matched by " << matcher_name << std::endl; get_module(mod).debug_print(ins); } // If its already invalid dont validate it again @@ -407,7 +409,7 @@ void find_matches_for(source_location location, Mod& mod, instruction_ref ins, M auto invalid = get_module(mod).validate(); if(invalid != get_module(mod).end()) { - std::cout << "Invalid program from match: " << get_type_name(m) << std::endl; + std::cout << "Invalid program from match: " << matcher_name << std::endl; std::cout << "Invalid instructions: " << std::endl; get_module(mod).debug_print(invalid->inputs()); get_module(mod).debug_print(invalid); diff --git a/src/include/migraphx/normalize_attributes.hpp b/src/include/migraphx/normalize_attributes.hpp index 40c1bdda9b0..e88003e6d85 100644 --- a/src/include/migraphx/normalize_attributes.hpp +++ b/src/include/migraphx/normalize_attributes.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -42,6 +43,36 @@ struct select_dependent_type template using dependent_type = typename select_dependent_type::type; +/** + * Used to normalize variable input axes at model runtime. + * Example: the axes inputs of the slice operator. + * + * \param axes the axes to normalize + * \param input_shape shape of the input tensor + * \param attr_val the normalize_axes attributes from the operator + * \param prefix error message prefix + */ +std::vector normalize_axes(const std::vector& axes, + const shape& input_shape, + const value& attr_val, + const std::string& prefix = ""); + +/** + * Used to normalize variable input indices at model runtime. + * Example: the starts and ends inputs of the slice operator. + * + * \param indices the indices to normalize + * \param axes which axes the indices apply over + * \param input_shape shape of the input tensor + * \param attr_val the normalize_axes attributes from the operator + * \param prefix error message prefix + */ +std::vector normalize_indices(const std::vector& indices, + const std::vector& axes, + const shape& input_shape, + const value& attr_val, + const std::string& prefix = ""); + MIGRAPHX_EXPORT bool normalize_attributes(operation& op, const shape& input_shape); diff --git a/src/include/migraphx/op/common.hpp b/src/include/migraphx/op/common.hpp index cb28b41ff24..e6b85f19e23 100644 --- a/src/include/migraphx/op/common.hpp +++ b/src/include/migraphx/op/common.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,8 +33,12 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { +// Specifies where to add the "extra" cell of padding if the +// calculated padding is an odd number. // Padding mode is default_ for fixed shape padding. -// same_lower and same_upper used for dynamic padding. +// same_lower and same_upper specify dynamic padding. +// The odd cell goes at the beginning of the dimension +// (same_lower) or end (same_upper). enum padding_mode_t { default_, // NOLINT diff --git a/src/include/migraphx/op/convolution.hpp b/src/include/migraphx/op/convolution.hpp index daa7d055169..ce2f157eabd 100644 --- a/src/include/migraphx/op/convolution.hpp +++ b/src/include/migraphx/op/convolution.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -206,6 +206,7 @@ struct convolution std::vector new_padding; if(padding_mode != op::padding_mode_t::default_) { + // Auto-calculate the padding sizes with calc_dyn_auto_pad auto input_lens = args[0].get_shape().lens(); auto weights_lens = args[1].get_shape().lens(); new_padding = @@ -217,6 +218,7 @@ struct convolution } else { + // Use the padding that was given new_padding = padding; if(output_shape.dynamic()) { diff --git a/src/include/migraphx/op/pooling.hpp b/src/include/migraphx/op/pooling.hpp index 3d31e5d9181..684d539e32b 100644 --- a/src/include/migraphx/op/pooling.hpp +++ b/src/include/migraphx/op/pooling.hpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -40,10 +41,20 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { +// The Pooling operator mostly follows the specifications for the ONNX pooling op. +// It assumes an NCHW layout, extended to support any number of spatial dimensions +// from 1 on up; dimensions are ordered as batch, channels, then the spatial dimensions. +// struct pooling { + // Class members mode, ceil_mode, padding_mode have similar names but refer to separate + // concepts. pooling_mode mode = {pooling_mode::average}; + // If the input has rank other than 4 then padding, stride, lengths must all be specified + // since the defaults are 2-dimensional. Exception: padding not required if + // padding_mode != default_ + // Padding along each spatial input dimension // Can be ndim or 2*ndim values where ndim is size of lengths // ndim values means pad the same before and after each dimension @@ -63,13 +74,14 @@ struct pooling // ceiling mode is a flag affecting output size // or equivalently, placements of the pooling kernel. - // When true, round the size upwards, possibly - // including partial placements where the kernel extends beyond the edge - // of input and even padding. When false, round down so that all + // When true, round the size upwards. When false, round down so that all // kernel placements fit but some input values may be dropped. bool ceil_mode = false; int lp_order = 2; + // Mode for auto padding. default_ indicates no auto padding.
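// Worked example of the SAME_* split (illustrative arithmetic, not part of the
// patch): for input extent 5, kernel 3, stride 2, auto padding needs
//   pad_total = (ceil(5/2) - 1)*2 + 3 - 5 = 2, split as 1 before / 1 after;
// were pad_total odd, say 3, same_lower would place 2 before / 1 after while
// same_upper places 1 before / 2 after, matching the padding_mode_t comment in
// op/common.hpp above.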
+ padding_mode_t padding_mode = padding_mode_t::default_; + // Global pooling with dynamic shape input bool dyn_global = false; @@ -84,6 +96,7 @@ struct pooling { return pack(f(self.mode, "mode"), f(self.padding, "padding"), + f(self.padding_mode, "padding_mode"), f(self.stride, "stride"), f(self.lengths, "lengths"), f(self.ceil_mode, "ceil_mode"), @@ -97,7 +110,8 @@ struct pooling { if(dyn_global) return; - if((padding.size() != stride.size() and (padding.size()) != stride.size() * 2) or + if((padding_mode != default_ and padding.size() != stride.size() and + (padding.size()) != stride.size() * 2) or stride.size() != lengths.size()) { MIGRAPHX_THROW("POOLING: inconsistent attribute sizes"); @@ -137,8 +151,19 @@ struct pooling std::size_t padding_factor = 2 * padding[i]; if(padding.size() == 2 * kdims) padding_factor = padding[i] + padding[i + kdims]; - assert(input_lens[i + 2] + padding_factor >= lengths[i]); - std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i]; + std::size_t dim_size; + if(input_lens[i + 2] + padding_factor < lengths[i]) + { + if(padding_mode == default_) + MIGRAPHX_THROW("POOLING: not enough padding for the given kernel size"); + // lengths can be legitimately larger only if we're doing auto padding + // with a dynamic shape, in which case given padding is ignored. Set a dummy value. + dim_size = 2; + } + else + { + dim_size = input_lens[i + 2] + padding_factor - lengths[i]; + } std::size_t len = (ceil_mode) ? dim_size / stride[i] + @@ -151,17 +176,13 @@ struct pooling shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this, true}.has(1); + check_shapes{inputs, *this, true}.has(1).min_ndims(3); check_attribute_size(); const shape& input = inputs.at(0); - auto padding_size = padding.size(); + auto stride_size = stride.size(); size_t kdims = input.ndim() - 2; - if(input.ndim() < 3) - { - MIGRAPHX_THROW("POOLING: input must have 3 or more dimensions and be nonempty"); - } - if(input.ndim() * 2 != padding_size + 4 and input.ndim() != padding_size + 2) + if(input.ndim() != stride_size + 2) { MIGRAPHX_THROW("POOLING: input and attribute size mismatch!"); } @@ -179,6 +200,28 @@ struct pooling } return {input.type(), output_dyn_dims}; } + else if(padding_mode != default_) + { + const size_t num_spatial_dims = inputs[0].ndim() - 2; + const shape& x_shape = inputs[0]; + // same as convolution::dynamic_compute_shape() + + for(std::size_t i = 0; i < num_spatial_dims; ++i) + { + auto ceil_div = [](std::size_t x, std::size_t y) { return (x + y - 1) / y; }; + auto s = stride[i]; + + auto x = x_shape.dyn_dims()[i + 2]; + std::set optimals{}; + std::transform(x.optimals.begin(), + x.optimals.end(), + std::inserter(optimals, optimals.begin()), + [&](auto o) { return ceil_div(o, s); }); + output_dyn_dims.push_back( + shape::dynamic_dimension{ceil_div(x.min, s), ceil_div(x.max, s), optimals}); + } + return {input.type(), output_dyn_dims}; + } else { // does not compute optimals @@ -267,6 +310,7 @@ struct pooling Out& output, const In& input, const std::vector& kernel_dims, + const std::vector& padding_vals, Op op) const { auto in_s = input.get_shape(); @@ -283,9 +327,9 @@ struct pooling // For each spatial dimension, find starting and ending index of pooling kernel for(std::size_t dim = 2; dim < n_dim; ++dim) { - auto d_2 = dim - 2; - int start = - static_cast(idx_o[dim] * stride[d_2]) - static_cast(padding[d_2]); + auto d_2 = dim - 2; + int start = static_cast(idx_o[dim] * stride[d_2]) - + static_cast(padding_vals[d_2]); int end; // 
NOLINT if(count_include_pad and ceil_mode and (mode != pooling_mode::max)) @@ -297,7 +341,7 @@ struct pooling // Check if this kernel extends beyond the padding at end of dimension end = std::min(start + kernel_dims[d_2], - in_lens[dim] + static_cast(padding[d_2])); + in_lens[dim] + static_cast(padding_vals[d_2])); } else { @@ -316,6 +360,7 @@ struct pooling } shape win_shape{output_shape.type(), win_size}; + auto pool_size = win_shape.elements(); double output_val = op.template init(); @@ -354,30 +399,65 @@ struct pooling argument compute(const dyn_output& dyn_out, std::vector args) const { - argument result{dyn_out.computed_shape}; + argument result; auto input_lens = args[0].get_shape().lens(); std::vector kernel_dims; + shape output_shape; + // If we have to auto-calculate padding, it will be passed to calc_pooling() as an argument + // instead of the member variable padding. + std::vector temp_padding(padding); if(dyn_global) { + // for dynamic GlobalPooling, there's no padding kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end()); + output_shape = dyn_out.computed_shape; + result = dyn_out.computed_shape; } - else + else if((padding_mode != op::padding_mode_t::default_)) { + // if padding_mode is set, input was a dynamic size. Calculate padded size now. + + // kernel_lens is the same as kernel_dims, but prepended with the 2 non- + // spatial dimensions. For size computations, it's used like the weights + // tensor for convolutions. + std::vector kernel_lens; + kernel_lens.insert(kernel_lens.end(), input_lens.begin(), input_lens.begin() + 2); + kernel_lens.insert(kernel_lens.end(), lengths.begin(), lengths.end()); kernel_dims = this->lengths; + + auto type = args[0].get_shape().type(); + // dilation not currently supported for pooling, so default to all 1's + temp_padding = calc_dyn_auto_pad( + input_lens, kernel_lens, stride, {1, 1}, bool(padding_mode == op::same_upper)); + + output_shape = compute_padded_pool_shape( + args[0].get_shape(), shape(type, kernel_dims), temp_padding, stride, {1, 1}); + + result = argument(output_shape); + } + else // fixed/static input + { + kernel_dims = this->lengths; + output_shape = dyn_out.computed_shape; + result = dyn_out.computed_shape; } + + // Perform the computation and populate result visit_all(result, args[0])([&](auto output, auto input) { using type = typename decltype(output)::value_type; switch(mode) { case migraphx::op::pooling_mode::average: - calc_pooling(dyn_out.computed_shape, output, input, kernel_dims, avg_pool{}); + calc_pooling( + output_shape, output, input, kernel_dims, temp_padding, avg_pool{}); break; case migraphx::op::pooling_mode::max: - calc_pooling(dyn_out.computed_shape, output, input, kernel_dims, max_pool{}); + calc_pooling( + output_shape, output, input, kernel_dims, temp_padding, max_pool{}); break; case migraphx::op::pooling_mode::lpnorm: calc_pooling( - dyn_out.computed_shape, output, input, kernel_dims, lpnorm_pool{lp_order}); + output_shape, output, input, kernel_dims, temp_padding, lpnorm_pool{lp_order}); break; } }); diff --git a/src/include/migraphx/op/slice.hpp b/src/include/migraphx/op/slice.hpp index 7b77f333657..49db0012afc 100644 --- a/src/include/migraphx/op/slice.hpp +++ b/src/include/migraphx/op/slice.hpp @@ -27,19 +27,34 @@ #include #include #include -#include #include +#include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { +/** + * Slice operator that accepts variable axes, starts and ends. 
+ * + * Attributes: + * axes: constant axes to slice over (optional) + * starts: constant slice starting indices (optional) + * ends: constant slice ending indices (optional) + * + * Parameters: + * data: the input tensor to slice (dynamic or static shape) + * input_starts: starting indices of slice (optional, static shape) + * input_ends: ending indices of slice (optional, static shape) + * input_axes: axes to slice over (optional, static shape) + */ struct slice { - std::vector axes; - std::vector starts; - std::vector ends; + std::vector axes{}; + std::vector starts{}; + std::vector ends{}; template static auto reflect(Self& self, F f) @@ -48,8 +63,8 @@ struct slice } /** - * Ensure that attribute vectors axes, starts, and ends are all the same size and values are in - * limits. + * Ensure that attribute vectors axes, starts, and ends are all the same size and values are + * within limits. */ value attributes() const { @@ -70,6 +85,90 @@ struct slice std::string name() const { return "slice"; } + /** + * Computes the slice output shape dimensions for given starts, ends, and axes. + * Templated to also handle tensor views. + * Possibly a different type between [in_starts, in_ends] and [in_axes] if in_axes is this + * object's axes attribute. Assumes in_starts and in_ends are normalized; in_axes are valid. + */ + template + std::vector + lens_calc(const std::vector& lengths, A in_starts, A in_ends, B in_axes) const + { + auto new_lens = lengths; + for(std::size_t i = 0; i < in_axes.size(); ++i) + { + auto axis = in_axes[i]; + new_lens[axis] = in_ends[i] - in_starts[i]; + } + return new_lens; + } + + shape normalize_compute_shape(std::vector inputs) const + { + check_shapes{inputs, *this, true}.has(1, 3, 4); + auto input_shape = inputs[0]; + if(inputs.size() == 1) + { + auto t = input_shape.type(); + if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) { + return not input_shape.dyn_dims()[axis].is_fixed(); + })) + { + MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis "); + } + if(input_shape.dynamic()) + { + return shape{t, + lens_calc(input_shape.min_lens(), starts, ends, axes), + lens_calc(input_shape.max_lens(), starts, ends, axes), + {}}; + } + else + { + return shape{ + t, lens_calc(input_shape.lens(), starts, ends, axes), input_shape.strides()}; + } + } + else + { + // check that starts, ends, and optionally input_axes are all 1D, have the same + // dimension, and are static + check_shapes{inputs.begin() + 1, + inputs.end(), + std::string("SLICE: inputs (starts, ends, and input_axes)"), + false} + .only_dims(1) + .same_dims(); + auto dds = input_shape.to_dynamic().dyn_dims(); + if(inputs.size() == 3) + { + if(inputs[1].lens().at(0) != axes.size()) + { + MIGRAPHX_THROW("SLICE: inputs starts and ends do not have the same dimension " + "as the axes attribute"); + } + std::for_each(axes.cbegin(), axes.cend(), [&](const auto& axis) { + dds.at(axis) = {0, dds.at(axis).max}; + }); + } + else + { + // if axes is an input, then all the output dimensions could be 0 to the max value + std::transform(dds.begin(), dds.end(), dds.begin(), [](auto dd) { + return shape::dynamic_dimension{0, dd.max}; + }); + } + return shape{input_shape.type(), dds}; + } + } + + /** + * Calculates the starting offset for the sliced tensor. + * Used in compute when only the data input is given and all other information is in the attributes.
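 * Example (illustrative): a static {3, 2} row-major shape has strides {2, 1};
 * with axes = {0, 1} and starts = {1, 0} the element offset is 1*2 + 0*1 = 2,
 * which is then scaled by s.type_size() into a byte offset.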
+ * + * \param s static input shape + */ auto compute_offset(const shape& s) const { const std::vector& lens = s.lens(); @@ -90,80 +189,131 @@ struct slice offset += starts[axis] * strides[axis]; } } - return offset; + return offset * s.type_size(); } - shape normalize_compute_shape(std::vector inputs) const + /** + * Calculates the starting offset for the sliced tensor (for aliasing). + * Used when the starts and/or the axes are inputs. + * + * \param s static input shape + * \param input_starts starting indices of slice + * \param ax_vec axes to slice on + */ + template + auto compute_offset(const shape& s, const IndView& input_starts, const Axes& ax_vec) const { - check_shapes{inputs, *this, true}.has(1); - auto input_shape = inputs[0]; - auto t = input_shape.type(); - - // TODO: When support for dynamic shapes is added to normalize_attributes, - // remove this restriction. - if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) { - return not input_shape.dyn_dims()[axis].is_fixed(); - })) + auto ret = 0; + for(std::size_t i = 0; i < ax_vec.size(); ++i) { - MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis "); + auto axis = ax_vec[i]; + ret += input_starts[i] * s.strides().at(axis); } + return ret * s.type_size(); + } + + std::unordered_map> + normalize_inputs(const shape& input_shape, + const std::vector& input_starts, + const std::vector& input_ends) const + { + auto attrs = this->attributes().at("normalize_axes"); + return {{"input_starts", + normalize_indices(input_starts, + this->axes, + input_shape, + attrs.at("starts"), + "Slice variable input_starts")}, + {"input_ends", + normalize_indices(input_ends, + this->axes, + input_shape, + attrs.at("ends"), + "Slice variable input_ends")}}; + } + + /** + * Three input version of the normalize_inputs. + * This one also checks that the input_axes are valid. + */ + std::unordered_map> + normalize_inputs(shape input_shape, + const std::vector& input_starts, + const std::vector& input_ends, + const std::vector& input_axes) const + { + auto attrs = this->attributes().at("normalize_axes"); + auto norm_axes = + normalize_axes(input_axes, input_shape, attrs.at("axes"), "Slice variable input_axes"); + return {{"input_starts", + normalize_indices(input_starts, + norm_axes, + input_shape, + attrs.at("starts"), + "Slice variable input_starts")}, + {"input_ends", + normalize_indices(input_ends, + norm_axes, + input_shape, + attrs.at("ends"), + "Slice variable input ends")}, + {"input_axes", norm_axes}}; + } - // For a static shape, old_lens will be adjusted to a new size - // for those axes that are sliced. - // For dynamic shape, the adjusted old_lens become the new max values, - // while updating the old mins and optimals if possible. - std::vector new_mins; - std::vector old_lens; - std::vector old_strides; - // Doesn't handle optimals - if(input_shape.dynamic()) + argument compute(const dyn_output& dyn_out, std::vector args) const + { + auto input = args[0]; + auto input_shape = input.get_shape(); + switch(args.size()) { - old_lens = input_shape.max_lens(); - new_mins = input_shape.min_lens(); + case 1: { + std::size_t offset = compute_offset(input_shape); + return {dyn_out.computed_shape, [=] { return input.data() + offset; }}; } - else - { - old_lens = input_shape.lens(); - // For static shape (including during eval step after a dynamic input) the strides are - // indexed into the pre-slice array, so they are larger than the apparent size of the - // resulting shape. 
- old_strides = input_shape.strides(); + case 3: { + shape calc_shape; + std::size_t offset = 0; + visit_all(args[1], args[2])([&](auto input_starts, auto input_ends) { + auto norm_inputs = normalize_inputs(input_shape, + input_starts.template to_vector(), + input_ends.template to_vector()); + offset = compute_offset(input_shape, norm_inputs.at("input_starts"), this->axes); + calc_shape = {input_shape.type(), + lens_calc(input_shape.lens(), + norm_inputs.at("input_starts"), + norm_inputs.at("input_ends"), + this->axes), + input_shape.strides()}; + }); + return {calc_shape, [=] { return input.data() + offset; }}; } - - std::vector new_lens = old_lens; - for(std::size_t i = 0; i < axes.size(); i++) - { - auto axis = axes[i]; - size_t sliced_length = ends[i] - starts[i]; - // A Numpy indexing convention: a slice size larger than the actual dimension - // is legal and the "ends" value is clipped to the axis size - new_lens[axis] = std::min(new_lens[axis], sliced_length); - if(input_shape.dynamic()) - { - // TODO: when non-fixed shape slicing is allowed, this will be different than - // sliced_length, making use of TBD start/end values. - std::size_t sliced_min_length = ends[i] - starts[i]; - // if the slice size is smaller than maxes but larger than mins - new_mins[axis] = std::min(sliced_min_length, new_mins[axis]); - } + case 4: { + shape calc_shape; + std::size_t offset = 0; + visit_all(args[1], args[2], args[3])( + [&](auto input_starts, auto input_ends, auto input_axes) { + auto norm_inputs = normalize_inputs(input_shape, + input_starts.template to_vector(), + input_ends.template to_vector(), + input_axes.template to_vector()); + offset = compute_offset( + input_shape, norm_inputs.at("input_starts"), norm_inputs.at("input_axes")); + calc_shape = shape{input_shape.type(), + lens_calc(input_shape.lens(), + norm_inputs.at("input_starts"), + norm_inputs.at("input_ends"), + norm_inputs.at("input_axes")), + input_shape.strides()}; + }); + return {calc_shape, [=] { return input.data() + offset; }}; } - if(input_shape.dynamic()) - { - return shape{t, new_mins, new_lens, {}}; + default: { + // Should never get here; defensive in case the input checks change + MIGRAPHX_THROW("SLICE: invalid number of inputs"); } - else - { - return shape{t, new_lens, old_strides}; } } - argument compute(const dyn_output& dyn_out, std::vector args) const - { - auto input = args[0]; - - auto offset = compute_offset(input.get_shape()) * dyn_out.computed_shape.type_size(); - return {dyn_out.computed_shape, [=] { return input.data() + offset; }}; - } std::ptrdiff_t output_alias(const std::vector&) const { return 0; } }; diff --git a/src/include/migraphx/pad_calc.hpp b/src/include/migraphx/pad_calc.hpp index 06c209f6073..a17c0bc3028 100644 --- a/src/include/migraphx/pad_calc.hpp +++ b/src/include/migraphx/pad_calc.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -62,6 +62,14 @@ shape compute_padded_shape(const shape& input, const std::vector& stride, const std::vector& dilation); +// Used for dynamic auto padding of pooling operators where padding needs to be computed at +// evaluation time.
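// Worked example (illustrative): with the usual pooling size formula
//   out = (in + pad_total - dilation*(kernel - 1) - 1) / stride + 1,
// an input extent of 5 with kernel 3, stride 2, dilation 1 and SAME padding of
// 1 + 1 gives (5 + 2 - 2 - 1)/2 + 1 = 3, which is the same 5 -> 3 spatial
// reduction exercised by averagepool_dyn_autopad_test further down this patch.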
+shape compute_padded_pool_shape(const shape& input, + const shape& kernel, + const std::vector& padding, + const std::vector& stride, + const std::vector& dilation); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/memory_coloring.cpp b/src/memory_coloring.cpp index 548356da4f6..e6583184dae 100644 --- a/src/memory_coloring.cpp +++ b/src/memory_coloring.cpp @@ -23,9 +23,9 @@ */ #include #include -#include #include #include +#include #include #include #include @@ -382,7 +382,8 @@ void memory_coloring::apply(module& m) const auto s = ins->get_shape(); std::size_t offset = seg.first * alignment; assert(offset < n); - m.replace_instruction(ins, op::load{s, offset}, mem); + m.replace_instruction( + ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem); } // Replace zero allocation @@ -391,7 +392,8 @@ void memory_coloring::apply(module& m) const if(ins->name() != allocation_op) continue; assert(ins->get_shape().bytes() == 0); - m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem); + m.replace_instruction( + ins, make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", 0}}), mem); } // Remove scratch parameter if its not used diff --git a/src/normalize_attributes.cpp b/src/normalize_attributes.cpp index 36f6b1be17f..6402d4ee2ec 100644 --- a/src/normalize_attributes.cpp +++ b/src/normalize_attributes.cpp @@ -26,7 +26,7 @@ #include #include #include - +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -49,6 +49,10 @@ auto tune_attribute(const std::vector& vec, Message m) { std::vector result(vec); + if(result.empty()) + { + return result; + }; int64_t n_rank = input_shape.ndim(); std::vector vec_attrs = val.to_vector(); if(contains(vec_attrs, op::normalize_attribute::use_output)) @@ -188,20 +192,27 @@ bool normalize_attributes(operation& op, const shape& input_shape) auto val = op.to_value(); if(attrs.contains("normalize_padding")) { - auto padding = val.at(attrs.at("normalize_padding").to()); - auto padding_size = padding.size(); - auto padding_start = 2; - - if(padding_size == 2 * (input_shape.ndim() - padding_start)) - tuned = true; - else if(padding_size != (input_shape.ndim() - padding_start)) - MIGRAPHX_THROW("inconsistent padding size"); - else + bool use_auto_padding = + (val.contains("padding_mode") and + (val.at("padding_mode").to() != migraphx::op::padding_mode_t::default_)); + if(not use_auto_padding) { - auto result = tune_pad_attribute(padding); - val["padding"] = result; - op.from_value(val); - tuned = true; + auto padding = val.at(attrs.at("normalize_padding").to()); + auto padding_size = padding.size(); + auto padding_start = 2; + if(padding_size == 2 * (input_shape.ndim() - padding_start)) + tuned = true; + else if(padding_size != (input_shape.ndim() - padding_start)) + { + MIGRAPHX_THROW("normalize_attributes: inconsistent padding vector size "); + } + else + { + auto result = tune_pad_attribute(padding); + val["padding"] = result; + op.from_value(val); + tuned = true; + } } } if(not attrs.contains("normalize_axes")) @@ -251,5 +262,22 @@ bool normalize_attributes(operation& op, const shape& input_shape) return tuned; } +std::vector normalize_axes(const std::vector& axes, + const shape& input_shape, + const value& attr_val, + const std::string& prefix) +{ + return tune_attribute(axes, {}, attr_val, input_shape, [&] { return prefix; }); +} + +std::vector normalize_indices(const std::vector& indices, + const std::vector& axes, + const shape& input_shape, + const value& attr_val, + const 
std::string& prefix) +{ + return tune_attribute(indices, axes, attr_val, input_shape, [&] { return prefix; }); +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/onnx/parse_pooling.cpp b/src/onnx/parse_pooling.cpp index 556d3297061..4a9cb35c875 100644 --- a/src/onnx/parse_pooling.cpp +++ b/src/onnx/parse_pooling.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -151,26 +151,6 @@ struct parse_pooling : op_parser kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings"); } - if(contains(info.attributes, "auto_pad")) - { - if(in_shape.dynamic()) - { - MIGRAPHX_THROW( - "PARSE_POOLING: Auto padding pooling with dynamic input shape not supported"); - } - else - { - values["padding"].clear(); - // return paddings could be empty, then setting to 0 for no padding - cal_auto_padding_size(info, - values, - values["lengths"].to_vector(), - {1, 1}, - in_shape.lens(), - paddings); - } - } - if(paddings.size() != 2 * kdims) { paddings.resize(kdims * 2); @@ -192,6 +172,36 @@ struct parse_pooling : op_parser // used to calculate the supposed output shape std::vector orig_padding = paddings; + // TODO: add parsing for dilations + if(contains(info.attributes, "auto_pad") and + to_upper(info.attributes["auto_pad"].s()) != "NOTSET") + { + auto auto_pad = to_upper(info.attributes["auto_pad"].s()); + // don't use the given padding sizes, if any + // values["padding"].clear(); + if(in_shape.dynamic()) + { + // set padding_mode to trigger auto padding at runtime + bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos); + values["padding_mode"] = is_same_upper ? 
to_value(op::padding_mode_t::same_upper) + : to_value(op::padding_mode_t::same_lower); + } + else + { + // Calculate auto padding + // dilations (argument 4) not supported; default to all 1's + cal_auto_padding_size(info, + values, + values["lengths"].to_vector(), + std::vector(in_shape.ndim() - 2, 1), + in_shape.lens(), + paddings); + values["padding"] = paddings; + // default padding_mode indicates that padding sizes are not calculated dynamically + values["padding_mode"] = migraphx::op::padding_mode_t::default_; + } + } + std::vector slice_start; std::vector slice_end; tune_padding_size(values, paddings, count_include_pad, slice_start); @@ -208,8 +218,9 @@ struct parse_pooling : op_parser orig_padding.insert(orig_padding.begin(), 2, 0); op::pad pad{orig_padding, 0.0f}; shape padded_shape = pad.compute_shape({l0->get_shape()}); - auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens(); + // make an op just to get its output shape + auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens(); // compute slice_end information slice_end.resize(slice_start.size()); std::transform(out_lens.begin() + 2, diff --git a/src/onnx/parse_slice.cpp b/src/onnx/parse_slice.cpp index 2bae22eef97..7287f6479be 100644 --- a/src/onnx/parse_slice.cpp +++ b/src/onnx/parse_slice.cpp @@ -34,16 +34,65 @@ namespace onnx { struct parse_slice : op_parser { + std::vector operators() const { return {{"Slice"}}; } + struct slice_desc + { + op::slice op; + std::vector op_args; + std::vector steps; + std::vector raxes; + + void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); } + + std::vector insert(instruction_ref arg) + { + std::vector result; + migraphx::argument arg_value = arg->eval(); + if(arg_value.empty()) + { + op_args.insert(op_args.begin(), arg); + } + else + { + arg_value.visit([&](auto s) { result.assign(s.begin(), s.end()); }); + } + return result; + } + }; + instruction_ref parse(const op_desc& /*opd*/, const onnx_parser& parser, - onnx_parser::node_info info, - std::vector args) const + const onnx_parser::node_info& info, + const std::vector& args) const { - op::slice op; + auto sd = construct_slice_desc(parser, info, args); + auto ins = info.add_instruction(sd.op, sd.op_args); + if(not sd.raxes.empty()) + { + ins = info.add_instruction(make_op("reverse", {{"axes", sd.raxes}}), ins); + } + // If any steps are other than default 1, add a "steps" op + if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return std::abs(s) != 1; })) + { + std::vector nsteps; + std::transform(sd.steps.begin(), + sd.steps.end(), + std::back_inserter(nsteps), + [](auto s) { return std::abs(s); }); + return ins = info.add_instruction( + make_op("step", {{"axes", sd.op.axes}, {"steps", nsteps}}), ins); + } + else + return ins; + } - std::vector steps; + slice_desc construct_slice_desc(const onnx_parser& parser, + onnx_parser::node_info info, + std::vector args) const + { + slice_desc sd; // slice can have up to 5 inputs, we first check the 5th one // to decide whether MIGRAPHX can handle this slice. 
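// The slice_desc::insert helper above is the core of the new constant-vs-variable
// handling: ONNX inputs that evaluate at parse time are folded into slice
// attributes, while everything else stays a runtime operand. A minimal
// standalone sketch of the same pattern (fold_or_keep is a hypothetical name;
// headers assumed from the MIGraphX source tree):
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>
#include <cstdint>
#include <vector>

std::vector<int64_t> fold_or_keep(migraphx::instruction_ref arg,
                                  std::vector<migraphx::instruction_ref>& runtime_args)
{
    std::vector<int64_t> folded;
    migraphx::argument v = arg->eval(); // empty when not computable at parse time
    if(v.empty())
        runtime_args.insert(runtime_args.begin(), arg); // keep as a runtime input
    else
        v.visit([&](auto s) { folded.assign(s.begin(), s.end()); }); // fold constant
    return folded;
}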
@@ -51,89 +100,73 @@ struct parse_slice : op_parser { migraphx::argument step_arg = args.back()->eval(); check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice"); - step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); }); + step_arg.visit([&](auto s) { sd.steps.assign(s.begin(), s.end()); }); } if(args.size() >= 4) { - migraphx::argument axes_arg = args.at(3)->eval(); - check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice"); - axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); }); + sd.op.axes = sd.insert(args.at(3)); } else if(contains(info.attributes, "axes")) { literal s = parser.parse_value(info.attributes.at("axes")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); + s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.axes)); }); } if(args.size() >= 3) { - migraphx::argument end_arg = args.at(2)->eval(); - check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice"); - end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); }); + sd.op.ends = sd.insert(args.at(2)); } else if(contains(info.attributes, "ends")) { literal s = parser.parse_value(info.attributes.at("ends")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); }); + s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.ends)); }); } if(args.size() >= 2) { - migraphx::argument start_arg = args.at(1)->eval(); - check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice"); - start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); }); + sd.op.starts = sd.insert(args.at(1)); } else if(contains(info.attributes, "starts")) { literal s = parser.parse_value(info.attributes.at("starts")); - s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); + s.visit([&](auto v) { copy(v, std::back_inserter(sd.op.starts)); }); } + // data input argument + sd.always_insert(args.at(0)); + // If axes arg is not given, the default is all of them. 
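// Worked example of the negative-step rewrite below (illustrative): on an axis
// of length 5, starts=[4], ends=[0], steps=[-2] becomes
//   slice(starts=[1], ends=[5]) -> reverse(axes=[0]) -> step(steps=[2]),
// i.e. indices {1,2,3,4} -> {4,3,2,1} -> {4,2}, exactly the elements a direct
// negative-step slice would select.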
- if(op.axes.empty()) + if(sd.op.axes.empty() and sd.op_args.size() < 3) { std::vector axes(args[0]->get_shape().ndim()); std::iota(axes.begin(), axes.end(), int64_t{0}); - op.axes = axes; + sd.op.axes = axes; } - std::vector raxes; + if(not sd.steps.empty()) + { + if(sd.op.starts.empty() or sd.op.ends.empty()) + MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported"); + if(sd.op.axes.empty()) + MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported"); + } - assert(steps.empty() or steps.size() == op.axes.size()); - assert(op.axes.size() == op.starts.size()); - assert(op.axes.size() == op.ends.size()); + assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size()); // If any axes have negative step, prepare to add a "reverse" op - for(auto i : range(steps.size())) + for(auto i : range(sd.steps.size())) { - if(steps[i] >= 0) + if(sd.steps[i] >= 0) continue; - op.starts[i] += 1; - if(op.starts[i] == 0) - op.starts[i] = INT_MAX; - op.ends[i] += 1; - raxes.push_back(op.axes[i]); - std::swap(op.starts[i], op.ends[i]); - } - - auto ins = info.add_instruction(op, args[0]); - if(not raxes.empty()) - { - ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins); + sd.op.starts[i] += 1; + if(sd.op.starts[i] == 0) + sd.op.starts[i] = INT_MAX; + sd.op.ends[i] += 1; + sd.raxes.push_back(sd.op.axes[i]); + std::swap(sd.op.starts[i], sd.op.ends[i]); } - // If any steps are other than default 1, add a "steps" op - if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; })) - { - std::vector nsteps; - std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) { - return std::abs(s); - }); - return ins = info.add_instruction( - make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins); - } - else - return ins; + return sd; } }; diff --git a/src/pad_calc.cpp b/src/pad_calc.cpp index 5662dfb4000..3fe9603aa45 100644 --- a/src/pad_calc.cpp +++ b/src/pad_calc.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,6 +52,11 @@ void calculate_padding(int64_t idx, } } +/** + * Given the input array dimensions; kernel (wei_lens); strides; and dilations, + * calculate the padding value in each dimension. + * + */ std::vector calc_dyn_auto_pad(const std::vector& input_lens, const std::vector& wei_lens, const std::vector& strides, @@ -60,6 +65,7 @@ std::vector calc_dyn_auto_pad(const std::vector& input { std::vector padding; assert(input_lens.size() >= 3); + assert(input_lens.size() == wei_lens.size()); std::size_t num_spatial_dims = input_lens.size() - 2; padding.resize(2 * num_spatial_dims); for(std::size_t i = 0; i < num_spatial_dims; i++) @@ -88,6 +94,11 @@ std::vector calc_dyn_auto_pad(const std::vector& input return padding; } +/** + * Calculate the correct output shape for a convolution with + * a given input size and other parameters. + * + */ shape compute_padded_shape(const shape& input, const shape& weights, const std::vector& padding, @@ -111,5 +122,33 @@ shape compute_padded_shape(const shape& input, return input.with_lens(output_lens); } +/** + * Calculate the correct output shape for a pooling with + * a given input size and other parameters. 
This uses + * the same formula for pooling that compute_padded_shape() uses + * for convolutions, but takes slightly different inputs. + * + */ +shape compute_padded_pool_shape(const shape& input, + const shape& kernel, + const std::vector& padding, + const std::vector& stride, + const std::vector& dilation) +{ + const size_t num_spatial_dims = input.lens().size() - 2; + + std::vector output_lens{input.lens()[0], input.lens()[1]}; + // calculate the output shape of the pooling: ((W - K + 2P) / S) + 1 + for(size_t i = 0; i < num_spatial_dims; ++i) + { + auto padding_factor = padding[i] + padding[i + num_spatial_dims]; + output_lens.push_back(std::size_t(std::max( + 1, + (input.lens()[i + 2] - (1 + dilation[i] * (kernel.lens()[i] - 1)) + padding_factor) / + stride[i] + + 1))); + } + return input.with_lens(output_lens); +} } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 32265887a98..0492f0680b2 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -123,6 +123,7 @@ add_library(migraphx_gpu lrn.cpp mlir.cpp multinomial.cpp + no_device.cpp nonzero.cpp pack_args.cpp pack_int8_args.cpp diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp index a49f4ff7ff9..8db24d51c50 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp @@ -41,7 +41,7 @@ struct index __device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT - __device__ index_int nlocal() const { return blockDim.x; } // NOLINT + __device__ index_int nlocal() const { return blockDim.x; } // NOLINT template __device__ void global_stride(index_int n, F f) const @@ -81,6 +81,12 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) dim3 nthreads(local); // cppcheck-suppress UseDeviceLaunch hipLaunchKernelGGL((launcher), nblocks, nthreads, 0, stream, f); + hipError_t kernel_launch_status = hipGetLastError(); + if(kernel_launch_status != hipSuccess) + { + MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " + + std::string(hipGetErrorString(kernel_launch_status))); + } }; } diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 68a05bb2ce8..1bf3348162b 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -93,6 +93,8 @@ struct mlir_handle friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); } friend bool operator!=(ptr x, ptr y) { return not(x == y); } + + explicit operator bool() const noexcept { return obj != ptr(); } T obj{}; }; @@ -645,8 +647,8 @@ struct mlir_program void set_gpu_properties(const context& migraphx_ctx) { const auto& device = migraphx_ctx.get_current_device(); - target_arch = device.get_device_name(); - num_cu = device.get_cu_count(); + target_arch = device.get_device_name(); + num_cu = device.get_cu_count(); } std::pair get_launch_params() const @@ -867,15 +869,22 @@ code_object_op compile_mlir(const context& migraphx_ctx, adjust_param_shapes(m, to_shapes(inputs)); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); + static std::mutex mutex; if(trace) + { + const std::lock_guard lock(mutex); std::cout << m << std::endl; + } mlir_program mp; mp.set_gpu_properties(migraphx_ctx); mp.parse(m); auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); if(trace) + { + const std::lock_guard lock(mutex); std::cout << 
mlir_print(&mlirOperationPrint, mod_op) << std::endl; + } auto co = mp.compile(solution); co.expected_inputs = to_shapes(inputs); co.output = m.get_output_shapes().front(); diff --git a/src/targets/gpu/no_device.cpp b/src/targets/gpu/no_device.cpp new file mode 100644 index 00000000000..a02d5254cb2 --- /dev/null +++ b/src/targets/gpu/no_device.cpp @@ -0,0 +1,28 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifdef __HIP_DEVICE_COMPILE__ +#error \ + "Device compilation not allowed for migraphx_gpu. Do not link with hip::device. Device code should go into migraphx_device or migraphx_kernels" +#endif diff --git a/test/eliminate_contiguous_test.cpp b/test/eliminate_contiguous_test.cpp index 5b5e4831039..7a1f4668134 100644 --- a/test/eliminate_contiguous_test.cpp +++ b/test/eliminate_contiguous_test.cpp @@ -196,15 +196,47 @@ TEST_CASE(contiguous_pointwise) migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y); auto yc = mm->add_instruction(migraphx::make_op("contiguous"), yb); auto add = add_pointwise(p, "main:pointwise0", {x, yc}, single_pointwise("add")); - mm->add_instruction(pass_op{}, add); + auto cadd = mm->add_instruction(migraphx::make_op("contiguous"), add); + mm->add_instruction(pass_op{}, cadd); } auto count = std::distance(mm->begin(), mm->end()); run_pass(*mm); - EXPECT(std::distance(mm->begin(), mm->end()) == (count - 1)); + EXPECT(std::distance(mm->begin(), mm->end()) == (count - 2)); EXPECT(std::none_of( mm->begin(), mm->end(), [](auto&& ins) { return ins.name() == "contiguous"; })); } +TEST_CASE(contiguous_nhwc_pointwise) +{ + auto s = + migraphx::shape::from_permutation(migraphx::shape::float_type, {2, 3, 8, 8}, {0, 2, 3, 1}); + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}}); + auto yb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y); + auto yc = mm->add_instruction(migraphx::make_op("contiguous"), yb); + auto add = add_pointwise(p1, "main:pointwise0", {x, yc}, single_pointwise("add")); + auto cadd = mm->add_instruction(migraphx::make_op("contiguous"), add); + mm->add_instruction(pass_op{}, cadd); + } + run_pass(*p1.get_main_module()); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = 
mm->add_parameter("x", s); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3}}); + auto yb = mm->add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 3, 8, 8}}}), y); + auto add = add_pointwise(p2, "main:pointwise0", {x, yb}, single_pointwise("add")); + auto cadd = mm->add_instruction(migraphx::make_op("contiguous"), add); + mm->add_instruction(pass_op{}, cadd); + } + EXPECT(p1 == p2); +} + TEST_CASE(slice_contiguous) { migraphx::module m; diff --git a/test/eliminate_pad_test.cpp b/test/eliminate_pad_test.cpp index 452a204920d..ec2fd94b8a6 100644 --- a/test/eliminate_pad_test.cpp +++ b/test/eliminate_pad_test.cpp @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include @@ -58,9 +58,8 @@ create_conv(migraphx::instruction_ref& l_img, migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; std::vector weights(4 * channels * 3 * 3); auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); - migraphx::op::convolution op; - op.padding_mode = padding_mode; - return m.add_instruction(op, l_img, l_weights); + return m.add_instruction( + migraphx::make_op("convolution", {{"padding_mode", padding_mode}}), l_img, l_weights); } TEST_CASE(rewrite_pad) diff --git a/test/gpu/quantization.cpp b/test/gpu/quantization.cpp index a1a08f43a68..b048197eb8d 100644 --- a/test/gpu/quantization.cpp +++ b/test/gpu/quantization.cpp @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -90,7 +90,7 @@ TEST_CASE(int8_quantization) migraphx::shape sc{migraphx::shape::float_type, {5, 8}}; auto pa = mm->add_parameter("a", sa); auto pb = mm->add_parameter("b", sb); - mm->add_instruction(migraphx::op::dot{}, pa, pb); + mm->add_instruction(migraphx::make_op("dot"), pa, pb); return p; }; diff --git a/test/inline_module_test.cpp b/test/inline_module_test.cpp index ef3fadbf240..6e05aceed02 100644 --- a/test/inline_module_test.cpp +++ b/test/inline_module_test.cpp @@ -26,7 +26,6 @@ #include #include #include -#include #include #include diff --git a/test/insert_pad_test.cpp b/test/insert_pad_test.cpp index 0c0662af82f..6954a217008 100644 --- a/test/insert_pad_test.cpp +++ b/test/insert_pad_test.cpp @@ -26,8 +26,8 @@ #include #include #include +#include #include -#include #include #include @@ -58,10 +58,11 @@ create_conv(migraphx::instruction_ref& l_img, migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}}; std::vector weights(4 * channels * 3 * 3); auto l_weights = m.add_literal(migraphx::literal{s_weights, weights}); - migraphx::op::convolution op; - op.padding_mode = padding_mode; - op.padding = {0, 0, 1, 1}; - return m.add_instruction(op, l_img, l_weights); + return m.add_instruction( + migraphx::make_op("convolution", + {{"padding_mode", padding_mode}, {"padding", {0, 0, 1, 1}}}), + l_img, + l_weights); } TEST_CASE(rewrite_pad) diff --git a/test/layout_nhwc.cpp b/test/layout_nhwc.cpp index 997024fb974..453d75cedb7 100644 --- a/test/layout_nhwc.cpp +++ b/test/layout_nhwc.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit index 5f57948825d..16c7320f3b9 100644 --- a/test/onnx/.onnxrt-commit +++ b/test/onnx/.onnxrt-commit @@ -1 +1 @@ -e5bb7aba502f5a8783de945258d226c092c14386 +a476dbf430ac8315550474a78d47bf182f202d7c diff --git a/test/onnx/averagepool_dyn_autopad_error_test.onnx b/test/onnx/averagepool_dyn_autopad_error_test.onnx deleted file mode 
100644 index 524c0896399..00000000000 Binary files a/test/onnx/averagepool_dyn_autopad_error_test.onnx and /dev/null differ diff --git a/test/onnx/averagepool_dyn_autopad_test.onnx b/test/onnx/averagepool_dyn_autopad_test.onnx new file mode 100644 index 00000000000..248ae610114 Binary files /dev/null and b/test/onnx/averagepool_dyn_autopad_test.onnx differ diff --git a/test/onnx/averagepool_dyn_test.onnx b/test/onnx/averagepool_dyn_test.onnx index cb12c8ebe12..5bc54615e99 100644 Binary files a/test/onnx/averagepool_dyn_test.onnx and b/test/onnx/averagepool_dyn_test.onnx differ diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 54b0963ad18..73c910c4af6 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -270,23 +270,26 @@ def averagepool_dyn_test(): node = onnx.helper.make_node('AveragePool', inputs=['0'], outputs=['1'], - kernel_shape=[3, 3, 3]) - + kernel_shape=[3, 3, 3], + strides=[2, 2, 2], + pads=[1, 1, 1, 1, 1, 1]) return ([node], [x], [out]) @onnx_test() -def averagepool_dyn_autopad_error_test(): - x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 1, 5, 5]) - y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, 1, 5, 5]) +def averagepool_dyn_autopad_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, + [None, 3, 5, 5, 5]) + out = helper.make_tensor_value_info('1', TensorProto.FLOAT, + [None, 3, 3, 3, 3]) node = onnx.helper.make_node('AveragePool', - inputs=['x'], - outputs=['y'], - kernel_shape=[2, 2], - auto_pad='SAME_LOWER') - - return ([node], [x], [y]) + inputs=['0'], + outputs=['1'], + kernel_shape=[3, 3, 3], + strides=[2, 2, 2], + auto_pad='SAME_UPPER') + return ([node], [x], [out]) @onnx_test() @@ -3456,7 +3459,6 @@ def instance_norm_dyn_batch_test(): outputs=['3']) return ([node], [x, scale, bias], [y]) - return ([node], [x, scale, bias], [y]) @onnx_test() @@ -6414,6 +6416,30 @@ def slice_test(): return ([node], [x], [y]) +@onnx_test() +def slice_constant_test(): + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2]) + + x_tensor = helper.make_tensor(name='x_tensor', + data_type=TensorProto.FLOAT, + dims=[3, 2], + vals=[0, 1, 2, 3, 4, 5]) + + x = onnx.helper.make_node('Constant', + inputs=[], + outputs=['x'], + value=x_tensor) + + node = onnx.helper.make_node('Slice', + inputs=['x'], + axes=[0, 1], + starts=[1, 0], + ends=[2, 2], + outputs=['1']) + + return ([x, node], [], [y]) + + @onnx_test() def slice_dyn_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, None, 2]) @@ -6746,6 +6772,92 @@ def slice_max_end_test(): return ([node], [x], [y]) +@onnx_test() +def slice_var_input_static0(): + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2]) + starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2]) + ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2]) + + node = onnx.helper.make_node('Slice', + inputs=['data', 'starts', 'ends'], + axes=[0, 1], + outputs=['output']) + + return ([node], [data, starts, ends], [output]) + + +@onnx_test() +def slice_var_input_static1(): + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2]) + starts = helper.make_tensor_value_info('starts', TensorProto.INT64, [2]) + ends = helper.make_tensor_value_info('ends', TensorProto.INT64, [2]) + axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2]) + + node = 
onnx.helper.make_node('Slice', + inputs=['data', 'starts', 'ends', 'axes'], + outputs=['output']) + + return ([node], [data, starts, ends, axes], [output]) + + +@onnx_test() +def slice_var_input_dyn0(): + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2]) + starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2]) + ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2]) + + node = onnx.helper.make_node('Slice', + inputs=['data', 'starts', 'ends'], + axes=[0, 1], + outputs=['output']) + + return ([node], [data, starts, ends], [output]) + + +@onnx_test() +def slice_var_input_dyn1(): + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2]) + starts = helper.make_tensor_value_info('starts', TensorProto.INT32, [2]) + ends = helper.make_tensor_value_info('ends', TensorProto.INT32, [2]) + axes = helper.make_tensor_value_info('axes', TensorProto.INT32, [2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2]) + + node = onnx.helper.make_node('Slice', + inputs=['data', 'starts', 'ends', 'axes'], + outputs=['output']) + + return ([node], [data, starts, ends, axes], [output]) + + +@onnx_test() +def slice_var_input_steps_error(): + step = np.array([2, 1]) + step_tensor = helper.make_tensor(name="step", + data_type=TensorProto.INT32, + dims=step.shape, + vals=step.astype(int)) + arg_step = helper.make_node("Constant", + inputs=[], + outputs=['arg_step'], + value=step_tensor) + + data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2]) + starts = helper.make_tensor_value_info('starts', TensorProto.FLOAT, [2]) + ends = helper.make_tensor_value_info('ends', TensorProto.FLOAT, [2]) + axes = helper.make_tensor_value_info('axes', TensorProto.FLOAT, [2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2]) + + node = onnx.helper.make_node( + 'Slice', + inputs=['data', 'starts', 'ends', 'axes', 'arg_step'], + outputs=['output']) + + return ([arg_step, node], [data, starts, ends, axes], [output]) + + @onnx_test() def softmax_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3]) diff --git a/test/onnx/onnx_rnn_test.cpp b/test/onnx/onnx_rnn_test.cpp index 7a8f3e855d4..5ba978ae617 100644 --- a/test/onnx/onnx_rnn_test.cpp +++ b/test/onnx/onnx_rnn_test.cpp @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 91b8d53c27d..b23ace35073 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -292,16 +292,21 @@ TEST_CASE(averagepool_3d_test) TEST_CASE(averagepool_dyn_test) { + // Pooling with dynamic input and no auto padding migraphx::program p; auto* mm = p.get_main_module(); auto l0 = mm->add_parameter( "0", {migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}, {5, 5}}}); - auto ret = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::average}, - {"padding", {0, 0, 0, 0, 0, 0}}, - {"stride", {1, 1, 1}}, - {"lengths", {3, 3, 3}}}), - l0); + auto ret = + mm->add_instruction(migraphx::make_op("pooling", + { + {"mode", migraphx::op::pooling_mode::average}, + {"stride", {2, 2, 2}}, + {"lengths", {3, 3, 3}}, + {"padding", {1, 1, 1, 1, 1, 1}}, + {"padding_mode", 0}, + }), + l0); mm->add_return({ret}); migraphx::onnx_options options; @@ -310,12 +315,29 @@ TEST_CASE(averagepool_dyn_test) EXPECT(p == prog); } 
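The next hunk replaces the old auto-pad error case with a real SAME_UPPER test. As background (an illustrative standalone snippet, not part of the patch): ONNX SAME_* auto padding keeps each output extent at ceil(W / S) and derives the total padding from that target, which is why the parser can leave the concrete padding values to be computed later. The variable names below are local to the example.

```cpp
// Sketch of the conventional ONNX SAME_* padding arithmetic for one axis.
#include <algorithm>
#include <iostream>

int main()
{
    const long w = 5, k = 3, s = 2;         // extent, kernel, stride
    const long out = (w + s - 1) / s;       // SAME_* keeps ceil(W / S)
    const long total_pad = std::max(0L, (out - 1) * s + k - w);
    // SAME_UPPER places the odd element of padding at the end of the axis
    std::cout << "out=" << out << " pad=" << total_pad / 2 << "+"
              << (total_pad - total_pad / 2) << '\n'; // prints "out=3 pad=1+1"
    return 0;
}
```

For the 5-wide dims in the test below this gives output 3 with padding 1 on each side, consistent with the comment that the default padding values get overridden.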
-TEST_CASE(averagepool_dyn_autopad_error_test) +TEST_CASE(averagepool_dyn_autopad_test) { + // Pooling with dynamic input and auto padding. Default padding values will be overridden. + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter( + "0", {migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}, {5, 5}}}); + auto ret = mm->add_instruction( + migraphx::make_op("pooling", + { + {"mode", migraphx::op::pooling_mode::average}, + {"stride", {2, 2, 2}}, + {"lengths", {3, 3, 3}}, + {"padding", {0, 0, 0, 0, 0, 0}}, + {"padding_mode", migraphx::op::padding_mode_t::same_upper}, + }), + l0); + mm->add_return({ret}); + migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - EXPECT(test::throws( - [&] { migraphx::parse_onnx("averagepool_dyn_autopad_error_test.onnx", options); })); + auto prog = migraphx::parse_onnx("averagepool_dyn_autopad_test.onnx", options); + EXPECT(p == prog); } TEST_CASE(averagepool_dyn_asym_padding_error_test) @@ -374,16 +396,22 @@ TEST_CASE(averagepool_nt_cip_test) TEST_CASE(averagepool_same_lower_test) { + // auto_pad mode of SAME_LOWER with a static input shape is handled in parsing and + // padding_mode is set to default_ when the operation is created migraphx::program p; auto* mm = p.get_main_module(); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 5}}); - auto ins = mm->add_instruction(migraphx::make_op("pooling", - {{"mode", migraphx::op::pooling_mode::average}, - {"padding", {1, 1, 1, 1}}, - {"stride", {1, 1}}, - {"lengths", {2, 2}}}), - input); - auto ret = mm->add_instruction( + auto ins = mm->add_instruction( + migraphx::make_op("pooling", + { + {"mode", migraphx::op::pooling_mode::average}, + {"padding", {1, 1, 1, 1}}, + {"stride", {1, 1}}, + {"lengths", {2, 2}}, + {"padding_mode", migraphx::op::padding_mode_t::default_}, + }), + input); + auto ret = mm->add_instruction( migraphx::make_op("slice", {{"axes", {2, 3}}, {"starts", {0, 0}}, {"ends", {5, 5}}}), ins); mm->add_return({ret}); auto prog = migraphx::parse_onnx("averagepool_same_lower_test.onnx"); @@ -6294,6 +6322,19 @@ TEST_CASE(slice_test) EXPECT(p == prog); } +TEST_CASE(slice_constant_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_literal(migraphx::literal{ + migraphx::shape{migraphx::shape::float_type, {3, 2}}, {0, 1, 2, 3, 4, 5}}); + mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {2, 2}}}), l0); + auto prog = optimize_onnx("slice_constant_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(slice_dyn_test) { migraphx::program p; @@ -6426,6 +6467,74 @@ TEST_CASE(slice_max_end_test) EXPECT(p == prog); } +TEST_CASE(slice_var_input_static0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto data = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 2}}); + auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int32_type, {2}}); + auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int32_type, {2}}); + mm->add_instruction(migraphx::make_op("slice", {{"axes", {0, 1}}}), data, starts, ends); + auto prog = optimize_onnx("slice_var_input_static0.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(slice_var_input_static1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto data = mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 2}}); + auto starts = mm->add_parameter("starts", 
migraphx::shape{migraphx::shape::int64_type, {2}}); + auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int64_type, {2}}); + auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {2}}); + mm->add_instruction(migraphx::make_op("slice"), data, starts, ends, axes); + auto prog = optimize_onnx("slice_var_input_static1.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(slice_var_input_dyn0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto data = + mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {{3, 8}, {2, 2}}}); + auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int32_type, {2}}); + auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int32_type, {2}}); + auto ret = + mm->add_instruction(migraphx::make_op("slice", {{"axes", {0, 1}}}), data, starts, ends); + mm->add_return({ret}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {3, 8}; + auto prog = parse_onnx("slice_var_input_dyn0.onnx", options); + EXPECT(p == prog); +} + +TEST_CASE(slice_var_input_dyn1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto data = + mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {{3, 8}, {2, 2}}}); + auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int32_type, {2}}); + auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int32_type, {2}}); + auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int32_type, {2}}); + auto ret = mm->add_instruction(migraphx::make_op("slice"), data, starts, ends, axes); + mm->add_return({ret}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {3, 8}; + auto prog = parse_onnx("slice_var_input_dyn1.onnx", options); + EXPECT(p == prog); +} + +TEST_CASE(slice_var_input_steps_error) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); })); +} + TEST_CASE(softmax_test) { migraphx::program p; diff --git a/test/onnx/slice_constant_test.onnx b/test/onnx/slice_constant_test.onnx new file mode 100644 index 00000000000..7c28d7d9b0f Binary files /dev/null and b/test/onnx/slice_constant_test.onnx differ diff --git a/test/onnx/slice_var_input_dyn0.onnx b/test/onnx/slice_var_input_dyn0.onnx new file mode 100644 index 00000000000..431d672e310 Binary files /dev/null and b/test/onnx/slice_var_input_dyn0.onnx differ diff --git a/test/onnx/slice_var_input_dyn1.onnx b/test/onnx/slice_var_input_dyn1.onnx new file mode 100644 index 00000000000..ebfe77bffd1 Binary files /dev/null and b/test/onnx/slice_var_input_dyn1.onnx differ diff --git a/test/onnx/slice_var_input_static0.onnx b/test/onnx/slice_var_input_static0.onnx new file mode 100644 index 00000000000..e587b2ef48c Binary files /dev/null and b/test/onnx/slice_var_input_static0.onnx differ diff --git a/test/onnx/slice_var_input_static1.onnx b/test/onnx/slice_var_input_static1.onnx new file mode 100644 index 00000000000..9b102a778bd Binary files /dev/null and b/test/onnx/slice_var_input_static1.onnx differ diff --git a/test/onnx/slice_var_input_steps_error.onnx b/test/onnx/slice_var_input_steps_error.onnx new file mode 100644 index 00000000000..62166ec1f60 Binary files /dev/null and b/test/onnx/slice_var_input_steps_error.onnx differ
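A note on the slice tests above (a sketch, not part of the patch): when starts and ends arrive as graph inputs rather than attributes, the parser can no longer fold the slice bounds at parse time, so every sliced axis becomes a dynamic dimension that may shrink to zero. Assuming the same public C++ API these tests use:

```cpp
// Sketch: slice with runtime starts/ends; only the axes attribute is static.
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto data =
        mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 4}});
    auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int64_type, {2}});
    auto ends   = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int64_type, {2}});
    // Axes 1 and 2 are sliced by runtime values, so their dims become {0, 4};
    // axis 0 is untouched and stays fixed at {3, 3}.
    auto sl =
        mm->add_instruction(migraphx::make_op("slice", {{"axes", {1, 2}}}), data, starts, ends);
    // sl->get_shape() is {{3, 3}, {0, 4}, {0, 4}}, matching the
    // slice_var_inputs_static_shape0 expectation in op_shape_test.cpp below.
    return 0;
}
```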
diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp index 4a5eef9eccb..f491c3e5b45 100644 --- a/test/onnx/verify_onnx.cpp +++ b/test/onnx/verify_onnx.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index fb34a1206b9..0457f1bb3b8 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -24,7 +24,8 @@ #include #include #include -#include +#include +#include #include #include @@ -156,13 +157,13 @@ TEST_CASE(broadcast) { std::vector<std::size_t> lens{1, 1}; migraphx::shape input{migraphx::shape::float_type, {2}}; - throws_shape(migraphx::op::broadcast{1, lens}, input); + throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input); } { std::vector<std::size_t> lens{2, 2}; migraphx::shape input{migraphx::shape::float_type, {1, 2}}; - throws_shape(migraphx::op::broadcast{1, lens}, input); + throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input); } { @@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape) input); } -template <class T> -void test_softmax_variations() +void test_softmax_variations(const std::string& name) { { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, + migraphx::make_op(name, {{"axis", 0}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, + migraphx::make_op(name, {{"axis", 1}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, + migraphx::make_op(name, {{"axis", 2}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, + migraphx::make_op(name, {{"axis", 3}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; int axis = 4; - throws_shape(T{axis}, input); + throws_shape(migraphx::make_op(name, {{"axis", axis}}), input); } } -TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); } +TEST_CASE(logsoftmax) { test_softmax_variations("logsoftmax"); } + +TEST_CASE(softmax) { test_softmax_variations("softmax"); }
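This hunk shows the refactoring pattern applied across these test diffs (a sketch only, using calls that appear throughout this patch): `make_op` builds an operation from its registered name plus an attribute value, so the tests no longer need the per-operator headers they previously included.

```cpp
// Sketch: name-based op construction instead of instantiating the op struct.
#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <iostream>

int main()
{
    // Old style required the operator's header: migraphx::op::softmax{1}
    // New style looks the operator up by name and sets attributes from a value:
    migraphx::operation op = migraphx::make_op("softmax", {{"axis", 1}});
    std::cout << op.name() << '\n'; // prints "softmax"
    return 0;
}
```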
TEST_CASE(lstm) { @@ -2106,6 +2116,13 @@ TEST_CASE(pooling_shape3) input); } +TEST_CASE(pooling_shape4) +{ + migraphx::shape tiny_input{migraphx::shape::float_type, {4, 1}}; + throws_shape(migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), + tiny_input); +} + TEST_CASE(pooling_dyn_shape0) { migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3, {3}}, {3, 3, {3}}, {3, 3}}}; @@ -2328,47 +2345,54 @@ TEST_CASE(dqlinear_mismatch_type) throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros); } -template <class T> -void test_reduce_ops() +void test_reduce_ops(const std::string& name) { { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, + migraphx::make_op(name), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape( - migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, + migraphx::make_op(name, {{"axes", {0, 1, 2, 3}}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, + migraphx::make_op(name, {{"axes", {2, 3}}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, + migraphx::make_op(name, {{"axes", {0}}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input); + expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, + migraphx::make_op(name, {{"axes", {-1}}}), + input); } { migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; - throws_shape(T{{4}}, input); + throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input); } } // dynamic shape -template <class T> -void test_dyn_reduce_ops() +void test_dyn_reduce_ops(const std::string& name) { { migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}}; expect_shape( migraphx::shape{migraphx::shape::float_type, std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})}, - T{{-1}}, + migraphx::make_op(name, {{"axes", {-1}}}), input); } { @@ -2376,7 +2400,7 @@ void test_dyn_reduce_ops() expect_shape( migraphx::shape{migraphx::shape::float_type, std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})}, - T{{0}}, + migraphx::make_op(name, {{"axes", {0}}}), input); } { @@ -2385,24 +2409,24 @@ expect_shape( migraphx::shape{migraphx::shape::float_type, std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})}, - T{{}}, + migraphx::make_op(name), input); } { migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}}; - throws_shape(T{{4}}, input); + throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input); } } -TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); } -TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); } -TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); } -TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); } +TEST_CASE(reduce_max) { test_reduce_ops("reduce_max"); } +TEST_CASE(reduce_mean) { test_reduce_ops("reduce_mean"); } +TEST_CASE(reduce_prod) { test_reduce_ops("reduce_prod"); } +TEST_CASE(reduce_sum) { test_reduce_ops("reduce_sum"); } -TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); } -TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); } -TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); } -TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); } +TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops("reduce_max"); } +TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops("reduce_mean"); } +TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops("reduce_prod"); } +TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops("reduce_sum"); } TEST_CASE(reshape_shape) { @@ -2822,7 +2846,7 @@ 
TEST_CASE(select_module_dyn) input); } -TEST_CASE(slice_shape) +TEST_CASE(slice_static_shape) { migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}}; expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, @@ -2840,6 +2864,67 @@ TEST_CASE(slice_shape) input); } +TEST_CASE(slice_var_inputs_static_shape0) +{ + migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}}; + migraphx::shape starts{migraphx::shape::int64_type, {2}}; + migraphx::shape ends{migraphx::shape::int64_type, {2}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}}, + migraphx::make_op("slice", {{"axes", {1, 2}}}), + input, + starts, + ends); +} + +TEST_CASE(slice_var_inputs_static_shape1) +{ + migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}}; + migraphx::shape starts{migraphx::shape::int64_type, {2}}; + migraphx::shape ends{migraphx::shape::int64_type, {2}}; + migraphx::shape axes{migraphx::shape::int64_type, {2}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}}, + migraphx::make_op("slice"), + input, + starts, + ends, + axes); +} + +TEST_CASE(slice_var_inputs_static_error0) +{ + migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}}; + migraphx::shape starts{migraphx::shape::int64_type, {2}}; + migraphx::shape ends{migraphx::shape::int64_type, {2}}; + migraphx::shape axes{migraphx::shape::int64_type, {3}}; + throws_shape(migraphx::make_op("slice"), input, starts, ends, axes); +} + +TEST_CASE(slice_var_inputs_dyn_shape0) +{ + migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}}; + migraphx::shape starts{migraphx::shape::int64_type, {2}}; + migraphx::shape ends{migraphx::shape::int64_type, {2}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 4}, {0, 4}}}, + migraphx::make_op("slice", {{"axes", {1, 2}}}), + input, + starts, + ends); +} + +TEST_CASE(slice_var_inputs_dyn_shape1) +{ + migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}}; + migraphx::shape starts{migraphx::shape::int64_type, {2}}; + migraphx::shape ends{migraphx::shape::int64_type, {2}}; + migraphx::shape axes{migraphx::shape::int64_type, {2}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 4}, {0, 4}}}, + migraphx::make_op("slice"), + input, + starts, + ends, + axes); +} + TEST_CASE(slice_dyn_shape0) { migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}}; @@ -2870,7 +2955,7 @@ TEST_CASE(slice_dyn_shape2) TEST_CASE(slice_dyn_shape3) { - // TODO: When variable dimension slicing is allowed, Slice to a size smaller than min. + // TODO: When non-fixed dimension slicing is allowed, Slice to a size smaller than min. // Until then, this action is an error. migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 8}, {2, 3}}}; throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), @@ -2901,8 +2986,6 @@ TEST_CASE(slice_dyn_shape5) input); } -TEST_CASE(softmax) { test_softmax_variations(); } - TEST_CASE(softmax_dyn0) { migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}}; diff --git a/test/pad_calc_test.cpp b/test/pad_calc_test.cpp index 61554a41b26..7e21a9967f1 100644 --- a/test/pad_calc_test.cpp +++ b/test/pad_calc_test.cpp @@ -22,7 +22,6 @@ * THE SOFTWARE. 
*/ #include -#include #include #include "test.hpp" diff --git a/test/quantization.cpp b/test/quantization.cpp index d9f94e32ee0..bb76152a6cb 100644 --- a/test/quantization.cpp +++ b/test/quantization.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include diff --git a/test/ref_ops_test.cpp b/test/ref_ops_test.cpp index 07ad39cbde5..f1e67fc7d5d 100644 --- a/test/ref_ops_test.cpp +++ b/test/ref_ops_test.cpp @@ -613,6 +613,7 @@ TEST_CASE(avgpool_rank3_test) TEST_CASE(avgpool_dyn_test) { + // Dynamic input, no padding migraphx::program p; auto* mm = p.get_main_module(); auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}}; @@ -638,34 +639,99 @@ TEST_CASE(avgpool_dyn_test) TEST_CASE(avgpool_dyn_pad_test) { - // pooling with dynamic input and padding, ceiling mode for output size + // Dynamic input with explicit padding migraphx::program p; auto* mm = p.get_main_module(); - auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {1, 3}, {2, 4}, {2, 4}}}; + auto s = migraphx::shape{migraphx::shape::float_type, {{1, 3}, {3, 3}, {4, 4}}}; auto x = mm->add_parameter("X", s); mm->add_instruction(migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::average}, - {"lengths", {2, 2}}, - {"padding", {1, 0}}, - {"ceil_mode", true}, - {"stride", {2, 2}}}), + {"lengths", {2}}, + {"padding", {1}}, + {"stride", {1}}}), x); p.compile(migraphx::make_target("ref")); - std::vector<float> data{1, 2, 3, 4, 5, 6}; + std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}}; + migraphx::parameter_map params; + params["X"] = migraphx::argument(input_fixed_shape, data.data()); + auto result = p.eval(params).back(); + std::vector<float> results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + std::vector<float> gold{ + 0.3, 0.25, 0.3, 0.25, 0.1, 0.8, 0.65, 0.7, 0.5, 0.1, 0.1, 0.4, 0.4, 0.35, 0.6}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} - // * * * - // 1 2 3 padding will look like this - // 4 5 6 The * are used when tiling the kernel - // * * * but are ignored in averaging +TEST_CASE(avgpool_dyn_auto_pad_test) +{ + // Pooling with dynamic input, multidimensional kernel and auto-padding + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = + migraphx::shape{migraphx::shape::float_type, {{1, 1}, {1, 3}, {2, 6, {2}}, {2, 6, {2}}}}; + auto x = mm->add_parameter("X", s); + mm->add_instruction( + migraphx::make_op("pooling", + { + {"mode", migraphx::op::pooling_mode::average}, + {"dyn_global", false}, + // non-default auto padding + {"padding_mode", migraphx::op::padding_mode_t::same_upper}, + {"lengths", {2, 3}}, + }), + x); + p.compile(migraphx::make_target("ref")); - migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 1, 2, 3}}; + std::vector<float> data{1, 2, 3, 4}; + + // * 1 2 * auto padding should look like this + // * 3 4 * + // * * * * + + migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 1, 2, 2}}; migraphx::parameter_map params; params["X"] = migraphx::argument(input_fixed_shape, data.data()); auto result = p.eval(params).back(); std::vector<float> results_vector(12); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); - std::vector<float> gold{1.5, 3.0, 4.5, 6.0}; + std::vector<float> gold{2.5, 2.5, 3.5, 3.5}; + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} +
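To make the gold values above easy to verify (a standalone snippet, not MIGraphX code): with SAME_UPPER, the 2x2 input gets no padding on top, one row on the bottom, and one column on each side, exactly as the ASCII diagram in the test shows, and padded cells are excluded from each window's average.

```cpp
// Standalone check of gold{2.5, 2.5, 3.5, 3.5}: 2x3 average pool over a 2x2
// input, stride 1, SAME_UPPER padding (top 0, bottom 1, left 1, right 1),
// with padded cells excluded from the average.
#include <iostream>

int main()
{
    const float in[2][2] = {{1, 2}, {3, 4}};
    for(int r = 0; r < 2; ++r)
        for(int c = 0; c < 2; ++c)
        {
            float sum = 0;
            int count = 0;
            for(int kr = 0; kr < 2; ++kr)     // no top padding; row 2 is padding
                for(int kc = -1; kc < 2; ++kc) // column window shifted left by pad 1
                {
                    const int rr = r + kr, cc = c + kc;
                    if(rr < 0 or rr > 1 or cc < 0 or cc > 1)
                        continue; // padded cell: not counted
                    sum += in[rr][cc];
                    ++count;
                }
            std::cout << sum / count << ' '; // prints 2.5 2.5 3.5 3.5
        }
    std::cout << '\n';
    return 0;
}
```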
+TEST_CASE(avgpool_dyn_auto_pad_1d_test) +{ + // Dynamic input with auto padding (i.e., a padding_mode is specified) + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::float_type, {{1, 3}, {3, 3}, {4, 4}}}; + auto x = mm->add_parameter("X", s); + mm->add_instruction( + migraphx::make_op("pooling", + {{"mode", migraphx::op::pooling_mode::average}, + {"lengths", {2}}, + // padding added will be {1, 0} to make output + // the same size as input + {"padding_mode", migraphx::op::padding_mode_t::same_lower}, + {"stride", {1}}}), + x); + p.compile(migraphx::make_target("ref")); + + std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6}; + migraphx::shape input_fixed_shape{migraphx::shape::float_type, {1, 3, 4}}; + migraphx::parameter_map params; + params["X"] = migraphx::argument(input_fixed_shape, data.data()); + auto result = p.eval(params).back(); + std::vector<float> results_vector; + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + + // clang-format off + std::vector<float> gold{0.3, 0.25, 0.3, 0.25, + 0.8, 0.65, 0.7, 0.5, + 0.1, 0.4, 0.4, 0.35}; + // clang-format on EXPECT(migraphx::verify::verify_range(results_vector, gold)); }
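As a standalone cross-check of the first output row above (channel {0.3, 0.2, 0.4, 0.1}): SAME_LOWER pads one element on the left, and the padded slot is excluded from the average, giving 0.3, 0.25, 0.3, 0.25.

```cpp
// Standalone check: 1-D average pool, kernel 2, stride 1, SAME_LOWER
// (one padded element on the left, excluded from the average).
#include <iostream>

int main()
{
    const float in[4] = {0.3f, 0.2f, 0.4f, 0.1f};
    for(int o = 0; o < 4; ++o)
    {
        float sum = 0;
        int count = 0;
        for(int k = -1; k <= 0; ++k) // window covers positions {o-1, o}
        {
            const int i = o + k;
            if(i < 0)
                continue; // left padding, not counted
            sum += in[i];
            ++count;
        }
        std::cout << sum / count << ' '; // prints 0.3 0.25 0.3 0.25
    }
    std::cout << '\n';
    return 0;
}
```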
@@ -1157,7 +1223,11 @@ TEST_CASE(conv_dyn_batch_test) auto input = mm->add_parameter("X", input_dyn_shape); auto weights = mm->add_parameter("W", weights_shape); - mm->add_instruction(migraphx::make_op("convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}), + mm->add_instruction(migraphx::make_op("convolution", + { + {"padding", {1, 1}}, + {"stride", {2, 2}}, + }), input, weights); @@ -8153,6 +8223,115 @@ TEST_CASE(slice_test) } } +TEST_CASE(slice_var_inputs_static0) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector<int> data(2 * 2 * 3); + std::iota(data.begin(), data.end(), 0); + migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data}); + migraphx::shape s1{migraphx::shape::int32_type, {1}}; + auto starts = mm->add_parameter("starts", s1); + auto ends = mm->add_parameter("ends", s1); + mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + std::vector<int> start_data = {1}; + std::vector<int> end_data = {3}; + params["starts"] = migraphx::argument(s1, start_data.data()); + params["ends"] = migraphx::argument(s1, end_data.data()); + auto result = p.eval(params).back(); + std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; + std::vector<int> results_vector(2 * 2 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(slice_var_inputs_static1) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector<int> data(2 * 2 * 3); + std::iota(data.begin(), data.end(), 0); + migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data}); + migraphx::shape s1{migraphx::shape::int32_type, {1}}; + auto starts = mm->add_parameter("starts", s1); + auto ends = mm->add_parameter("ends", s1); + mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + std::vector<int> start_data = {-2}; + std::vector<int> end_data = {2831}; + params["starts"] = migraphx::argument(s1, start_data.data()); + params["ends"] = migraphx::argument(s1, end_data.data()); + auto result = p.eval(params).back(); + std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; + std::vector<int> results_vector(2 * 2 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(slice_var_inputs_static2) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector<float> data(2 * 2 * 3); + std::iota(data.begin(), data.end(), 0); + migraphx::shape s0{migraphx::shape::float_type, {2, 2, 3}}; + auto l0 = mm->add_literal(migraphx::literal{s0, data}); + migraphx::shape s1{migraphx::shape::int64_type, {3}}; + auto starts = mm->add_parameter("starts", s1); + auto ends = mm->add_parameter("ends", s1); + auto axes = mm->add_parameter("axes", s1); + mm->add_instruction(migraphx::make_op("slice"), l0, starts, ends, axes); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + std::vector<int64_t> start_data = {0, 0, 0}; + std::vector<int64_t> end_data = {2, 2, 2}; + std::vector<int64_t> axes_data = {0, 1, 2}; + params["starts"] = migraphx::argument(s1, start_data.data()); + params["ends"] = migraphx::argument(s1, end_data.data()); + params["axes"] = migraphx::argument(s1, axes_data.data()); + auto result = p.eval(params).back(); + std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10}; + std::vector<float> results_vector(2 * 2 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + +TEST_CASE(slice_var_inputs_dyn) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}}; + auto input = mm->add_parameter("input", s0); + migraphx::shape s1{migraphx::shape::int32_type, {1}}; + auto starts = mm->add_parameter("starts", s1); + auto ends = mm->add_parameter("ends", s1); + mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends); + p.compile(migraphx::make_target("ref")); + + migraphx::parameter_map params; + migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}}; + std::vector<int> input_data(2 * 2 * 3); + std::iota(input_data.begin(), input_data.end(), 0); + std::vector<int> start_data = {1}; + std::vector<int> end_data = {3}; + params["input"] = migraphx::argument(s2, input_data.data()); + params["starts"] = migraphx::argument(s1, start_data.data()); + params["ends"] = migraphx::argument(s1, end_data.data()); + auto result = p.eval(params).back(); + std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11}; + std::vector<int> results_vector(2 * 2 * 2); + result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); + EXPECT(migraphx::verify::verify_range(results_vector, gold)); +} + TEST_CASE(slice_dyn_test0) { // Slice a single dynamic dimension. 
ax1 slice limits are smaller than min; ax2 "ends" is diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index e202bcc9b2b..b6425974d78 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1) { migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; - migraphx::op::broadcast b{1, {1, 2, 3, 3}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}); migraphx::module m1; { auto x = m1.add_parameter("x", outer); @@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2) { migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; - migraphx::op::broadcast b{1, {1, 2, 3, 3}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}}); auto create_program = [&] { migraphx::module m; auto x = m.add_parameter("x", outer); @@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add) TEST_CASE(simplify_inner_broadcast1) { - auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}}); migraphx::module m1; { auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); @@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1) TEST_CASE(simplify_inner_broadcast2) { - auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}}; + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}}); migraphx::module m1; { auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}}); @@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2) TEST_CASE(simplify_inner_broadcast_scalar) { - auto b = migraphx::op::multibroadcast{{32, 384}}; + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {32, 384}}}); migraphx::module m1; { auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); @@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar) { auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); - auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y); + auto yb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sumb = m2.add_instruction(b, sum); m2.add_instruction(pass_op{}, sumb); @@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar) TEST_CASE(simplify_inner_broadcast_different_dims) { - auto b = migraphx::op::multibroadcast{{2, 384, 768}}; + auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}}); migraphx::module m1; { auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); @@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) { auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); - auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y); + auto yb = + m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sumb = m2.add_instruction(b, sum); m2.add_instruction(pass_op{}, sumb); @@ -641,8 +643,8 @@ 
TEST_CASE(simplify_inner_broadcast_different_dims) TEST_CASE(simplify_inner_broadcast_different_broadcasts) { - auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}}; - auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}}); + auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 24, 112, 112}}}); migraphx::module m1; { auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}}); @@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}}); auto x = m1.add_parameter("x", s); auto y = m1.add_parameter("y", s); auto one = m1.add_literal(1); @@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}}); auto x = m2.add_parameter("x", s); auto y = m2.add_parameter("y", s); auto one = m2.add_literal(1); @@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}}); auto x = m1.add_parameter("x", s); auto y = m1.add_parameter("y", s); auto one = m1.add_literal(1); @@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}}); auto x = m2.add_parameter("x", s); auto y = m2.add_parameter("y", s); auto one = m2.add_literal(1); @@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}}); auto x = m1.add_parameter("x", s); auto y = m1.add_parameter("y", s); auto one = m1.add_literal(1); @@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}}); auto x = m2.add_parameter("x", s); auto y = m2.add_parameter("y", s); auto one = m2.add_literal(1); @@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}}); auto input = m2.add_parameter("input", s); auto one = m2.add_literal(1); auto two = m2.add_literal(2); @@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape) 
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; - auto r = migraphx::op::reshape{{3, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); + auto r = migraphx::make_op("reshape", {{"dims", {3, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}}); auto input = m2.add_parameter("input", s); auto one = m2.add_literal(1); auto two = m2.add_literal(2); @@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}}; migraphx::module m1; { - auto r = migraphx::op::reshape{{3, 2, 4}}; + auto r = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); @@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); @@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}}); auto input = m2.add_parameter("input", s); auto one = m2.add_literal(1); auto two = m2.add_literal(2); @@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, 
{3, 1, 4, 3}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 3}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}), @@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}}); auto input = m2.add_parameter("input", s); auto slice = m2.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; migraphx::module m1; { - auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}}); auto input = m1.add_parameter("input", s); auto x = m1.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); @@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) migraphx::module m2; { - auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}}); auto input = m2.add_parameter("input", s); auto slice = m2.add_instruction( migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 0b16b46616a..3672ab85532 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1) std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); auto md = std::distance(m.begin(), new_mb); EXPECT(cd == md - 1); - EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 1); + EXPECT(new_concat->get_operator().to_value()["axis"].to() == 1); } TEST_CASE(concat_multibroadcasts2) @@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2) std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); auto md = std::distance(m.begin(), new_mb); EXPECT(cd == md - 1); - EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 0); + EXPECT(new_concat->get_operator().to_value()["axis"].to() == 0); } TEST_CASE(concat_multibroadcasts3) @@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3) std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); auto md = std::distance(m.begin(), new_mb); EXPECT(cd == md - 1); - EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 2); + EXPECT(new_concat->get_operator().to_value()["axis"].to() == 2); } TEST_CASE(concat_multibroadcasts4) @@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1) auto new_concat = std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); 
EXPECT(bool{new_concat != m.end()}); - EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 3); + EXPECT(new_concat->get_operator().to_value()["axis"].to() == 3); } TEST_CASE(concat_transpose2) @@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2) auto new_concat = std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); EXPECT(bool{new_concat != m.end()}); - EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 1); + EXPECT(new_concat->get_operator().to_value()["axis"].to() == 1); } TEST_CASE(concat_transpose3) @@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3) auto new_concat = std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); EXPECT(bool{new_concat != m.end()}); - EXPECT(migraphx::any_cast(new_concat->get_operator()).axis == 1); + EXPECT(new_concat->get_operator().to_value()["axis"].to() == 1); } TEST_CASE(concat_transpose4) diff --git a/test/tf/tf_test.cpp b/test/tf/tf_test.cpp index 4730e4a1080..82d5cb95e96 100644 --- a/test/tf/tf_test.cpp +++ b/test/tf/tf_test.cpp @@ -37,7 +37,6 @@ #include #include #include -#include #include @@ -840,12 +839,8 @@ TEST_CASE(slice_test) mm->add_literal(migraphx::literal{s0, {1, 0}}); mm->add_literal(migraphx::literal{s0, {2, -1}}); - migraphx::op::slice op; - op.starts = {1, 0}; - op.ends = {3, 10}; - op.axes = std::vector(num_axes); - std::iota(op.axes.begin(), op.axes.end(), 0); - mm->add_instruction(op, l0); + mm->add_instruction( + migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0); auto prog = optimize_tf("slice_test.pb", false); EXPECT(p == prog); @@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test) auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); - std::size_t num_axes = 4; - migraphx::op::slice op; - op.starts = {0, 0, 0, 0}; - op.ends = {1, 1, 1, 5}; - op.axes = std::vector(num_axes); - std::iota(op.axes.begin(), op.axes.end(), 0); - auto l2 = mm->add_instruction(op, l1); + auto l2 = mm->add_instruction( + migraphx::make_op( + "slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}), + l1); auto shrink_axis = 1; mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2); auto prog = optimize_tf("stridedslice_test.pb", true); @@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test) auto* mm = p.get_main_module(); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}}); - std::size_t num_axes = 4; - migraphx::op::slice op; - op.starts = {0, 1, 1, 0}; - op.ends = {1, 3, 3, 10}; - op.axes = std::vector(num_axes); - std::iota(op.axes.begin(), op.axes.end(), 0); // add literals for starts, ends, and strides in tf (NHWC format) mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector{0, 1, 1, 0}); @@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test) auto l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); - auto l2 = mm->add_instruction(op, l1); + auto l2 = mm->add_instruction( + migraphx::make_op( + "slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}), + l1); auto l3 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2); mm->add_return({l3}); diff --git a/test/verify/gemm_literal.cpp b/test/verify/gemm_literal.cpp index fc2195e21cd..3ea52af477b 100644 --- 
a/test/verify/gemm_literal.cpp +++ b/test/verify/gemm_literal.cpp @@ -25,7 +25,7 @@ #include "verify_program.hpp" #include #include -#include +#include struct gemm_literal : verify_program<gemm_literal> { @@ -38,7 +38,7 @@ struct gemm_literal auto a = mm->add_literal(migraphx::generate_literal(a_shape)); auto b = mm->add_parameter("b", b_shape); - mm->add_instruction(migraphx::op::dot{}, a, b); + mm->add_instruction(migraphx::make_op("dot"), a, b); return p; } diff --git a/tools/build_and_test_onnxrt.sh b/tools/build_and_test_onnxrt.sh index 5a8b1d8b5f2..75915b15f6e 100755 --- a/tools/build_and_test_onnxrt.sh +++ b/tools/build_and_test_onnxrt.sh @@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt # Add newer cmake to the path export PATH="/opt/cmake/bin:$PATH" export CXXFLAGS="-D__HIP_PLATFORM_AMD__=1 -w" -./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root +./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --build_wheel --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root cd build/Linux/Release #Add test launcher for onnxrt tests diff --git a/tools/docker/sles.docker b/tools/docker/sles.docker index 9885c132a2f..1cc4657a13d 100644 --- a/tools/docker/sles.docker +++ b/tools/docker/sles.docker @@ -3,7 +3,7 @@ FROM registry.suse.com/suse/sle15:15.4 RUN sh -c 'echo -e "\ [rocm]\n\ name=rocm\n\ -baseurl=https://repo.radeon.com/rocm/zyp/5.5/main\n\ +baseurl=https://repo.radeon.com/rocm/zyp/5.6/main\n\ enabled=1\n\ gpgcheck=1\n\ gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\ @@ -19,7 +19,12 @@ RUN zypper --gpg-auto-import-keys install -y \ gcc-c++ \ gdb \ git \ - python3-pip + python3-pip \ + rpm-build + +# Add repos for extra packages +RUN OPENSUSE_REPO=https://download.opensuse.org/repositories && \ + zypper addrepo ${OPENSUSE_REPO}/devel:/languages:/perl/SLE_15_SP4/devel:languages:perl.repo # Workaround broken rocm packages RUN ln -s /opt/rocm-* /opt/rocm