Skip to content

Commit

Permalink
Added support for standalone mlir-conv
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Aug 23, 2023
1 parent 5bfc7a2 commit 7f14839
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 30 deletions.
113 changes: 83 additions & 30 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/standalone_mlir.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
Expand Down Expand Up @@ -119,7 +120,36 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op);

namespace {
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
} // namespace

namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
Expand Down Expand Up @@ -163,34 +193,6 @@ struct find_mlir_op
return ins_map;
}

std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}

// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
Expand Down Expand Up @@ -300,10 +302,9 @@ struct find_mlir_op
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
}
};

} // namespace

#endif
#endif // MIGRAPHX_MLIR

void fuse_mlir::apply(module_pass_manager& mpm) const
{
Expand All @@ -314,6 +315,58 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
#endif
}

#ifdef MIGRAPHX_MLIR

namespace {
MIGRAPHX_PRED_MATCHER(is_supported_arch, instruction_ref)
{
// TODO(ravil): debug
static std::unordered_set<std::string> supported_consumer_archs{
"gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"};

// static std::unordered_set<std::string> supported_consumer_archs{"gfx1030"};
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(supported_consumer_archs, device_name))
return true;
return false;
}

struct find_mlir_standalone_convolution
{
auto matcher() const { return match::name("convolution")(is_supported_arch); }

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
// Only fuse with fp32/fp16
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains({shape::type_t::float_type, shape::type_t::half_type},
i->get_shape().type());
}))
return;

static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
}
};
} // namespace

#endif // MIGRAPHX_MLIR

void standalone_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_standalone_convolution{});
#else
(void)mpm;
#endif
}

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
Expand Down
47 changes: 47 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP
#define MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP

#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

namespace gpu {

struct MIGRAPHX_GPU_EXPORT standalone_mlir
{
context* ctx = nullptr;
std::string name() const { return "gpu::standalone_mlir"; }
void apply(module_pass_manager& mpm) const;
};

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP
2 changes: 2 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/standalone_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp>
Expand Down Expand Up @@ -143,6 +144,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
#endif
dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
enable_pass(mlir_enabled(), standalone_mlir{&ctx}),
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
Expand Down

0 comments on commit 7f14839

Please sign in to comment.