diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index c33adc560d3..3d297652148 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. */ #include +#include #include #include #include @@ -119,7 +120,36 @@ struct mlir_op MIGRAPHX_REGISTER_OP(mlir_op); namespace { +std::tuple> +fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) +{ + std::vector top_inputs; + std::vector imm_inputs; + size_t input_cnt = 0; + for(instruction_ref input : gemm_based_op->inputs()) + { + std::vector op_stream; + while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name())) + { + op_stream.push_back(input->get_operator()); + input = input->inputs().at(0); + } + top_inputs.push_back(input); + instruction_ref prev_input = + mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape()); + for(const auto& op : reverse(op_stream)) + { + prev_input = mm->add_instruction(op, {prev_input}); + } + imm_inputs.push_back(prev_input); + } + instruction_ref new_gemm_based_op = + mm->add_instruction(gemm_based_op->get_operator(), imm_inputs); + return {new_gemm_based_op, top_inputs}; +} +} // namespace +namespace { MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) { if(ins->name() != "convolution" and ins->name() != "quant_convolution") @@ -163,34 +193,6 @@ struct find_mlir_op return ins_map; } - std::tuple> - fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const - { - std::vector top_inputs; - std::vector imm_inputs; - size_t input_cnt = 0; - for(instruction_ref input : gemm_based_op->inputs()) - { - std::vector op_stream; - while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name())) - { - op_stream.push_back(input->get_operator()); - input = input->inputs().at(0); - } - top_inputs.push_back(input); - instruction_ref prev_input = - mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape()); - for(const auto& op : reverse(op_stream)) - { - prev_input = mm->add_instruction(op, {prev_input}); - } - imm_inputs.push_back(prev_input); - } - instruction_ref new_gemm_based_op = - mm->add_instruction(gemm_based_op->get_operator(), imm_inputs); - return {new_gemm_based_op, top_inputs}; - } - // Whitelist supported fusion options, including imposing type constraints // for cases where MLIR only supports an operation (usually a pointwise function) // on particular types. @@ -300,10 +302,9 @@ struct find_mlir_op ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm}); } }; - } // namespace -#endif +#endif // MIGRAPHX_MLIR void fuse_mlir::apply(module_pass_manager& mpm) const { @@ -314,6 +315,58 @@ void fuse_mlir::apply(module_pass_manager& mpm) const #endif } +#ifdef MIGRAPHX_MLIR + +namespace { +MIGRAPHX_PRED_MATCHER(is_supported_arch, instruction_ref) +{ + // TODO(ravil): debug + static std::unordered_set supported_consumer_archs{ + "gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"}; + + // static std::unordered_set supported_consumer_archs{"gfx1030"}; + const auto device_name = trim(split_string(get_device_name(), ':').front()); + if(contains(supported_consumer_archs, device_name)) + return true; + return false; +} + +struct find_mlir_standalone_convolution +{ + auto matcher() const { return match::name("convolution")(is_supported_arch); } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto conv_based_op = r.result; + // Only fuse with fp32/fp16 + if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) { + return not contains({shape::type_t::float_type, shape::type_t::half_type}, + i->get_shape().type()); + })) + return; + + static size_t counter = 0; + module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++)); + mm->set_bypass(); + auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op); + mm->add_return({anchor_op}); + mpm.get_module().replace_instruction( + conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm}); + } +}; +} // namespace + +#endif // MIGRAPHX_MLIR + +void standalone_mlir::apply(module_pass_manager& mpm) const +{ +#ifdef MIGRAPHX_MLIR + match::find_matches(mpm, find_mlir_standalone_convolution{}); +#else + (void)mpm; +#endif +} + } // namespace gpu } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp new file mode 100644 index 00000000000..430bfcfbd3f --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp @@ -0,0 +1,47 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP +#define MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP + +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module_pass_manager; + +namespace gpu { + +struct MIGRAPHX_GPU_EXPORT standalone_mlir +{ + context* ctx = nullptr; + std::string name() const { return "gpu::standalone_mlir"; } + void apply(module_pass_manager& mpm) const; +}; + +} // namespace gpu + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 082bc5fa949..9e16bdc87eb 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -59,6 +59,7 @@ #include #include #include +#include #include #include #include @@ -143,6 +144,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti #endif dead_code_elimination{}, enable_pass(mlir_enabled(), fuse_mlir{&ctx}), + enable_pass(mlir_enabled(), standalone_mlir{&ctx}), dead_code_elimination{}, lowering{&ctx, options.offload_copy}, eliminate_contiguous{"gpu::contiguous"},