diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 5f13eda6901..7965d2b9400 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -326,20 +326,52 @@ struct find_mlir_standalone_convolution_op } }; +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_ENABLE_OPS); +bool is_self_decide() { return env(MIGRAPHX_MLIR_ENABLE_OPS::value()).empty(); } + +bool is_requested(std::string_view option) +{ + assert(enabled(MIGRAPHX_MLIR_ENABLE_OPS{})); + auto string_value = string_value_of(MIGRAPHX_MLIR_ENABLE_OPS{}, ""); + static const char delim{','}; + string_value.push_back(delim); + + const auto options = split_string(string_value, delim); + return contains(options, option); +} + +bool is_fusion_enabled() +{ + if(is_self_decide()) + { + return true; + } + return is_requested("fused"); +} + +bool is_standalone_convs_enabled(const std::string& gfx_name) +{ + const std::string navi_family{"gfx110"}; + if(is_self_decide() and starts_with(gfx_name, navi_family)) + { + return true; + } + return is_requested("conv"); +} } // namespace #endif // MIGRAPHX_MLIR -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_STANDALONE_CONVS); - void mlir_offload::apply(module_pass_manager& mpm) const { #ifdef MIGRAPHX_MLIR - match::find_matches(mpm, find_mlir_fused_ops{}); + if(is_fusion_enabled()) + { + match::find_matches(mpm, find_mlir_fused_ops{}); + } const auto& device = this->ctx->get_current_device(); - const std::string navi_family{"gfx110"}; - if(starts_with(device.get_gfx_name(), navi_family) or enabled(MIGRAPHX_MLIR_STANDALONE_CONVS{})) + if(is_standalone_convs_enabled(device.get_gfx_name())) { match::find_matches(mpm, find_mlir_standalone_convolution_op{}); }