Skip to content

Commit

Permalink
Renamed fuse_mlir.hpp/cpp to mlir_offload.hpp/cpp
Browse files Browse the repository at this point in the history
* renamed the corresponding struct
* addressed suggestions of PR ROCm#2110
  • Loading branch information
ravil-mobile committed Aug 23, 2023
1 parent 63bdf8f commit 50b1321
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 96 deletions.
2 changes: 1 addition & 1 deletion src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ add_library(migraphx_gpu
compiler.cpp
device_name.cpp
fuse_ck.cpp
fuse_mlir.cpp
mlir_offload.cpp
fuse_ops.cpp
gather.cpp
gemm_impl.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#ifndef MIGRAPHX_GUARD_GPU_MLIR_OFFLOAD_HPP
#define MIGRAPHX_GUARD_GPU_MLIR_OFFLOAD_HPP

#include <migraphx/gpu/context.hpp>

Expand All @@ -35,15 +35,15 @@ namespace gpu {

MIGRAPHX_GPU_EXPORT bool mlir_enabled();

struct MIGRAPHX_GPU_EXPORT fuse_mlir
struct MIGRAPHX_GPU_EXPORT mlir_offload
{
context* ctx = nullptr;
std::string name() const { return "gpu::fuse_mlir"; }
std::string name() const { return "gpu::mlir_offload"; }
void apply(module_pass_manager& mpm) const;
};

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#endif // MIGRAPHX_GUARD_GPU_MLIR_OFFLOAD_HPP
47 changes: 0 additions & 47 deletions src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp

This file was deleted.

48 changes: 12 additions & 36 deletions src/targets/gpu/fuse_mlir.cpp → src/targets/gpu/mlir_offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/standalone_mlir.hpp>
#include <migraphx/gpu/mlir_offload.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
Expand Down Expand Up @@ -147,9 +146,7 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
} // namespace

namespace {
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
Expand All @@ -164,7 +161,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return true;
}

struct find_mlir_op
struct find_mlir_fused_ops
{
auto matcher() const
{
Expand Down Expand Up @@ -302,38 +299,10 @@ struct find_mlir_op
ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
}
};
} // namespace

#endif // MIGRAPHX_MLIR

void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_op{});
#else
(void)mpm;
#endif
}

#ifdef MIGRAPHX_MLIR

namespace {
MIGRAPHX_PRED_MATCHER(is_supported_arch, instruction_ref)
{
// TODO(ravil): debug
static std::unordered_set<std::string> supported_consumer_archs{
"gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"};

// static std::unordered_set<std::string> supported_consumer_archs{"gfx1030"};
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(supported_consumer_archs, device_name))
return true;
return false;
}

struct find_mlir_standalone_convolution_op
{
auto matcher() const { return match::name("convolution")(is_supported_arch); }
auto matcher() const { return match::name("convolution"); }

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
Expand All @@ -354,14 +323,21 @@ struct find_mlir_standalone_convolution_op
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
}
};

} // namespace

#endif // MIGRAPHX_MLIR

void standalone_mlir::apply(module_pass_manager& mpm) const
void mlir_offload::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
match::find_matches(mpm, find_mlir_fused_ops{});

const auto& device = this->ctx->get_current_device();
if(starts_with(device.get_gfx_name(), "gfx110"))
{
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
#else
(void)mpm;
#endif
Expand Down
6 changes: 2 additions & 4 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/standalone_mlir.hpp>
#include <migraphx/gpu/mlir_offload.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp>
Expand Down Expand Up @@ -143,8 +142,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif
dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
enable_pass(mlir_enabled(), standalone_mlir{&ctx}),
enable_pass(mlir_enabled(), mlir_offload{&ctx}),
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
Expand Down
4 changes: 2 additions & 2 deletions test/gpu/fuse_mlir.cpp → test/gpu/mlir_offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/mlir_offload.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
Expand All @@ -34,7 +34,7 @@

void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(p, {migraphx::gpu::mlir_offload{}, migraphx::dead_code_elimination{}});
}

template <class F>
Expand Down
2 changes: 1 addition & 1 deletion test/gpu/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
#include <iostream>
#include <vector>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/mlir_offload.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
Expand Down

0 comments on commit 50b1321

Please sign in to comment.