diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 0492f0680b2..94c79f06976 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -109,7 +109,7 @@ add_library(migraphx_gpu compiler.cpp device_name.cpp fuse_ck.cpp - fuse_mlir.cpp + mlir_offload.cpp fuse_ops.cpp gather.cpp gemm_impl.cpp diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/mlir_offload.hpp similarity index 86% rename from src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp rename to src/targets/gpu/include/migraphx/gpu/mlir_offload.hpp index 22dcc4b6c58..c118f404fb9 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/mlir_offload.hpp @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#ifndef MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP -#define MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP +#ifndef MIGRAPHX_GUARD_GPU_MLIR_OFFLOAD_HPP +#define MIGRAPHX_GUARD_GPU_MLIR_OFFLOAD_HPP #include @@ -35,10 +35,10 @@ namespace gpu { MIGRAPHX_GPU_EXPORT bool mlir_enabled(); -struct MIGRAPHX_GPU_EXPORT fuse_mlir +struct MIGRAPHX_GPU_EXPORT mlir_offload { context* ctx = nullptr; - std::string name() const { return "gpu::fuse_mlir"; } + std::string name() const { return "gpu::mlir_offload"; } void apply(module_pass_manager& mpm) const; }; @@ -46,4 +46,4 @@ struct MIGRAPHX_GPU_EXPORT fuse_mlir } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx -#endif // MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP +#endif // MIGRAPHX_GUARD_GPU_MLIR_OFFLOAD_HPP diff --git a/src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp deleted file mode 100644 index 430bfcfbd3f..00000000000 --- a/src/targets/gpu/include/migraphx/gpu/standalone_mlir.hpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#ifndef MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP -#define MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP - -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -struct module_pass_manager; - -namespace gpu { - -struct MIGRAPHX_GPU_EXPORT standalone_mlir -{ - context* ctx = nullptr; - std::string name() const { return "gpu::standalone_mlir"; } - void apply(module_pass_manager& mpm) const; -}; - -} // namespace gpu - -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx -#endif // MIGRAPHX_GUARD_GPU_STANDALONE_MLIR_HPP diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/mlir_offload.cpp similarity index 90% rename from src/targets/gpu/fuse_mlir.cpp rename to src/targets/gpu/mlir_offload.cpp index c397f4c0813..d22e680664c 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/mlir_offload.cpp @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include -#include +#include "migraphx/shape.hpp" +#include #include #include #include @@ -147,9 +147,7 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) mm->add_instruction(gemm_based_op->get_operator(), imm_inputs); return {new_gemm_based_op, top_inputs}; } -} // namespace -namespace { MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) { if(ins->name() != "convolution" and ins->name() != "quant_convolution") @@ -164,7 +162,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) return true; } -struct find_mlir_op +struct find_mlir_fused_ops { auto matcher() const { @@ -302,46 +300,19 @@ struct find_mlir_op ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm}); } }; -} // namespace - -#endif // MIGRAPHX_MLIR - -void fuse_mlir::apply(module_pass_manager& mpm) const -{ -#ifdef MIGRAPHX_MLIR - match::find_matches(mpm, find_mlir_op{}); -#else - (void)mpm; -#endif -} - -#ifdef MIGRAPHX_MLIR - -namespace { -MIGRAPHX_PRED_MATCHER(is_supported_arch, instruction_ref) -{ - // TODO(ravil): debug - static std::unordered_set supported_consumer_archs{ - "gfx900", "gfx906", "gfx908", "gfx1030", "gfx940"}; - - // static std::unordered_set supported_consumer_archs{"gfx1030"}; - const auto device_name = trim(split_string(get_device_name(), ':').front()); - if(contains(supported_consumer_archs, device_name)) - return true; - return false; -} struct find_mlir_standalone_convolution_op { - auto matcher() const { return match::name("convolution")(is_supported_arch); } + auto matcher() const { return match::name("convolution"); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto conv_based_op = r.result; - // Only fuse with fp32/fp16 + // enable only for fp32/fp16/i8 types if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) { - return not contains({shape::type_t::float_type, shape::type_t::half_type}, - i->get_shape().type()); + return not contains( + {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, + i->get_shape().type()); })) return; @@ -354,14 +325,21 @@ struct find_mlir_standalone_convolution_op conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm}); } }; + } // namespace #endif // MIGRAPHX_MLIR -void standalone_mlir::apply(module_pass_manager& mpm) const +void mlir_offload::apply(module_pass_manager& mpm) const { #ifdef MIGRAPHX_MLIR - match::find_matches(mpm, find_mlir_standalone_convolution_op{}); + match::find_matches(mpm, find_mlir_fused_ops{}); + + const auto& device = this->ctx->get_current_device(); + if(starts_with(device.get_gfx_name(), "gfx110")) + { + match::find_matches(mpm, find_mlir_standalone_convolution_op{}); + } #else (void)mpm; #endif diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 9e16bdc87eb..41b589bc4f4 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -58,8 +58,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -143,8 +142,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), #endif dead_code_elimination{}, - enable_pass(mlir_enabled(), fuse_mlir{&ctx}), - enable_pass(mlir_enabled(), standalone_mlir{&ctx}), + enable_pass(mlir_enabled(), mlir_offload{&ctx}), dead_code_elimination{}, lowering{&ctx, options.offload_copy}, eliminate_contiguous{"gpu::contiguous"}, diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/mlir_offload.cpp similarity index 97% rename from test/gpu/fuse_mlir.cpp rename to test/gpu/mlir_offload.cpp index 7ce14d6a036..7ddbb025376 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/mlir_offload.cpp @@ -22,7 +22,7 @@ * THE SOFTWARE. */ #include -#include +#include #include #include #include @@ -34,7 +34,7 @@ void run_pass(migraphx::program& p) { - migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(p, {migraphx::gpu::mlir_offload{}, migraphx::dead_code_elimination{}}); } template diff --git a/test/gpu/quantization.cpp b/test/gpu/quantization.cpp index b048197eb8d..92fd3de9cf3 100644 --- a/test/gpu/quantization.cpp +++ b/test/gpu/quantization.cpp @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include