diff --git a/Jenkinsfile b/Jenkinsfile index de6fa059a0b..7cb184f51d1 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> } }, mlir_debug: rocmnode('mi100+') { cmake_build -> stage('MLIR Debug') { - withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot']) { + withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1']) { def sanitizers = "undefined" // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" diff --git a/docs/dev/env_vars.rst b/docs/dev/env_vars.rst index 36c830383a4..ad65dc83c6a 100644 --- a/docs/dev/env_vars.rst +++ b/docs/dev/env_vars.rst @@ -273,6 +273,11 @@ Performs exhaustive tuning for MLIR. Set to an integer greater than 1. Limits the number of solutions available to MLIR for tuning. +.. envvar:: MIGRAPHX_ENABLE_MLIR_INPUT_FUSION + +Set to "1", "enable", "enabled", "yes", or "true" to use. +Enable input fusions in MLIR. + CK vars ----------- diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 37dcfc0e028..515f051778c 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -91,93 +92,14 @@ MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins) return input_shape.ndim() == output_shape.ndim(); } -static void insert_params(module_ref sm, - const std::vector& inputs, - std::unordered_map& map_ins) -{ - auto n = sm->get_parameter_shapes().size(); - for(auto input : inputs) - { - if(contains(map_ins, input)) - continue; - map_ins[input] = - sm->add_parameter("x" + std::to_string(n++), input->get_shape().as_standard()); - } -} - -static auto insert_ins_in_submodule(module_ref sm, - instruction_ref ins, - std::unordered_map& map_ins) -{ - insert_params(sm, ins->inputs(), map_ins); - return sm->add_instructions({ins}, &map_ins); -} - -static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins) -{ - std::unordered_map map_ins; - return insert_ins_in_submodule(sm, ins, map_ins); -} - -static auto -insert_module_in_submodule(module_ref sm, - const std::vector& inputs, - module_ref m, - std::unordered_map& map_ins, - module::inserter insert = nullptr) -{ - insert_params(sm, inputs, map_ins); - auto param_map = m->get_ins_param_map(inputs); - for(auto&& [input, param] : param_map) - { - map_ins[param] = map_ins.at(input); - } - return sm->add_instructions(m, &map_ins, std::move(insert)); -} - static auto insert_module_in_submodule(module_ref sm, instruction_ref ins, - std::unordered_map& map_ins, - module::inserter insert = nullptr) + std::unordered_map* map_ins = nullptr, + module::inserter insert = nullptr) { - return insert_module_in_submodule( - sm, ins->inputs(), ins->module_inputs().front(), map_ins, std::move(insert)); -} - -static auto insert_module_in_submodule(module_ref sm, - const std::vector& inputs, - module_ref m, - module::inserter insert = nullptr) -{ - std::unordered_map map_ins; - return insert_module_in_submodule(sm, inputs, m, map_ins, std::move(insert)); -} - -static std::vector -find_inputs(const_module_ref sm, - const module& parent, - const std::unordered_map& map_ins) -{ - std::vector result; - std::map names; - for(auto&& [input, param] : map_ins) - { - if(not sm->has_instruction(param)) - continue; - if(param->name() != "@param") - continue; - if(not parent.has_instruction(input)) - continue; - auto v = param->get_operator().to_value(); - auto name = v.at("parameter").to(); - names[name] = input; - } - std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) { - return p.second; - }); - assert(result.size() == sm->get_parameter_shapes().size()); - return result; + assert(ins->module_inputs().size() == 1); + return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert)); } static void create_reduce_modules(module_pass_manager& mpm) @@ -194,7 +116,7 @@ static void create_reduce_modules(module_pass_manager& mpm) mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++)); rm->set_bypass(); - rm->add_return(insert_ins_in_submodule(rm, ins)); + rm->add_return(rm->fuse({ins})); auto v = ins->get_operator().to_value(); mpm.get_module().replace_instruction( ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm}); @@ -286,23 +208,23 @@ struct find_pointwise_reduce rm->set_bypass(); std::unordered_map map_ins; // Insert pointwise - auto rins = insert_ins_in_submodule(rm, input, map_ins).front(); + auto rins = rm->fuse({input}, &map_ins).front(); map_ins[input] = rins; if(contains(r.instructions, "broadcast")) { auto broadcast = r.instructions["broadcast"]; auto fbroadcast = r.instructions["final_broadcast"]; - map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front(); + map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front(); if(fbroadcast != broadcast) map_ins[fbroadcast] = map_ins[broadcast]; } // Insert fused_reduce - rm->add_return(insert_module_in_submodule(rm, reduce, map_ins)); + rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins)); finalize_reduce_module(rm); - auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); + auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm); mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm}); } }; @@ -327,12 +249,12 @@ struct find_reduce_pointwise rm->set_bypass(); std::unordered_map map_ins; // Copy module instructions - insert_module_in_submodule(rm, reduce, map_ins); + insert_module_in_submodule(rm, reduce, &map_ins); if(contains(r.instructions, "broadcast")) { auto broadcast = r.instructions["broadcast"]; map_ins[broadcast->inputs().front()] = rm->get_returns().front(); - auto bout = insert_ins_in_submodule(rm, broadcast, map_ins); + auto bout = rm->fuse({broadcast}, &map_ins); map_ins[input] = bout.front(); } else @@ -340,11 +262,11 @@ struct find_reduce_pointwise map_ins[input] = rm->get_returns().front(); } - auto out = insert_ins_in_submodule(rm, pw, map_ins); + auto out = rm->fuse({pw}, &map_ins); rm->replace_return(out); finalize_reduce_module(rm); - auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); + auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm); mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm}); } }; @@ -372,12 +294,12 @@ struct find_reduce_reduce std::unordered_map map_ins; // Copy reduce1 instructions - insert_module_in_submodule(rm, reduce2, map_ins); + insert_module_in_submodule(rm, reduce2, &map_ins); if(contains(r.instructions, "broadcast")) { auto broadcast = r.instructions["broadcast"]; map_ins[broadcast->inputs().front()] = rm->get_returns().front(); - auto bout = insert_ins_in_submodule(rm, broadcast, map_ins); + auto bout = rm->fuse({broadcast}, &map_ins); map_ins[input] = bout.front(); } else @@ -385,11 +307,11 @@ struct find_reduce_reduce map_ins[input] = rm->get_returns().front(); } - auto out = insert_module_in_submodule(rm, reduce1, map_ins); + auto out = insert_module_in_submodule(rm, reduce1, &map_ins); rm->replace_return(out); finalize_reduce_module(rm); - auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins); + auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm); mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm}); } }; @@ -429,14 +351,14 @@ struct reduce_reshape : rewrite_reshapes_base auto* oldm = ins->module_inputs().front(); auto* sm = mpm.create_module(oldm->name() + "_reshape"); sm->set_bypass(); - insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) { - if(contains(sop.name(), "reduce")) - return make_op(sop.name(), {{"axes", axes}}); - if(sop.name() == "multibroadcast") - return make_op("multibroadcast", {{"out_lens", dims}}); - assert(sop.name() == "pointwise"); - return sop; - })); + sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) { + if(contains(sop.name(), "reduce")) + return make_op(sop.name(), {{"axes", axes}}); + if(sop.name() == "multibroadcast") + return make_op("multibroadcast", {{"out_lens", dims}}); + assert(sop.name() == "pointwise"); + return sop; + })); return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm}); } diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index a842c4aac28..e477f0c8804 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -245,6 +245,21 @@ struct MIGRAPHX_EXPORT module const std::vector& splits1, const std::vector& splits2) const; + // Fuse the instruction into the module by inserting the instructions and + // parameters for any missing inputs. + std::vector + fuse(const std::vector& inss, + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); + + // Fuse another module into this module by inserting the instructions and + // parameters from the module + std::vector + fuse(const module& m, + const std::vector& inputs, + std::unordered_map* map_ins = nullptr, + inserter insert = nullptr); + void debug_print() const; void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins, diff --git a/src/include/migraphx/param_utils.hpp b/src/include/migraphx/param_utils.hpp index 3506fccbb30..1889d68fc38 100644 --- a/src/include/migraphx/param_utils.hpp +++ b/src/include/migraphx/param_utils.hpp @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -37,6 +38,13 @@ MIGRAPHX_EXPORT std::string param_name(std::size_t i, const std::string& prefix void sort_params(std::vector& params); +// Find the inputs for a module by finding instructions that are mapped to the +// parameters in the module +std::vector +find_inputs(const std::unordered_map& map_ins, + const_module_ref parent, + const_module_ref sub); + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP diff --git a/src/module.cpp b/src/module.cpp index fc93193acff..9d2229dd222 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -979,6 +979,63 @@ std::array module::split(const std::vector& inputs, + std::unordered_map& map_ins) +{ + auto n = m.get_parameter_shapes().size(); + for(auto input : inputs) + { + if(contains(map_ins, input)) + continue; + map_ins[input] = m.add_parameter(param_name(n++), input->get_shape().as_standard()); + } +} + +std::vector +module::fuse(const std::vector& inss, + std::unordered_map* map_ins, + module::inserter insert) +{ + std::unordered_map default_map_ins; + if(map_ins == nullptr) + map_ins = &default_map_ins; + std::vector inputs; + for(auto ins : inss) + { + for(auto input : ins->inputs()) + { + if(contains(inss, input)) + continue; + if(contains(inputs, input)) + continue; + inputs.push_back(input); + } + } + insert_params(*this, inputs, *map_ins); + return this->add_instructions(inss, map_ins, std::move(insert)); +} + +std::vector +module::fuse(const module& m, + const std::vector& inputs, + std::unordered_map* map_ins, + module::inserter insert) +{ + std::unordered_map default_map_ins; + if(map_ins == nullptr) + map_ins = &default_map_ins; + insert_params(*this, inputs, *map_ins); + auto param_map = m.get_ins_param_map(inputs); + for(auto&& [input, param] : param_map) + { + (*map_ins)[param] = map_ins->at(input); + } + return this->add_instructions(&m, map_ins, std::move(insert)); +} + void module_with_inputs::replace(instruction_ref ins, instruction_ref rep) { auto it = std::find(inputs.begin(), inputs.end(), ins); diff --git a/src/param_utils.cpp b/src/param_utils.cpp index 20c1ad8b0e2..9e985b7af71 100644 --- a/src/param_utils.cpp +++ b/src/param_utils.cpp @@ -25,6 +25,9 @@ #include #include #include +#include +#include +#include #include namespace migraphx { @@ -49,5 +52,31 @@ void sort_params(std::vector& params) })); } +std::vector +find_inputs(const std::unordered_map& map_ins, + const_module_ref parent, + const_module_ref sub) +{ + std::vector result; + std::map names; + for(auto&& [input, param] : map_ins) + { + if(sub != nullptr and not sub->has_instruction(param)) + continue; + if(param->name() != "@param") + continue; + if(parent != nullptr and not parent->has_instruction(input)) + continue; + auto v = param->get_operator().to_value(); + auto name = v.at("parameter").to(); + names[name] = input; + } + std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) { + return p.second; + }); + assert(not sub or result.size() == sub->get_parameter_shapes().size()); + return result; +} + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index d8a76be8641..e901dc24a2b 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -28,6 +28,8 @@ #include #include #include +#include +#include #include #include #include @@ -40,6 +42,7 @@ struct module; namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** * @brief Declares a new MIGraphX environment variable which forces to generate @@ -397,14 +400,27 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) return false; } +// A separate function so we can remove operators that are supported by mlir +// but not supported for an input fusion. +bool is_pointwise_op_supported_by_mlir_for_input(const instruction& i) +{ + return is_pointwise_op_supported_by_mlir(i); +} + MIGRAPHX_PRED_MATCHER(mlir_pointwise, instruction_ref ins) { if(ins->name() != "pointwise") return false; auto* pm = ins->module_inputs().front(); - return std::all_of(pm->begin(), pm->end(), [&](const auto& i) { - return is_pointwise_op_supported_by_mlir(i); - }); + return std::all_of(pm->begin(), pm->end(), &is_pointwise_op_supported_by_mlir); +} + +MIGRAPHX_PRED_MATCHER(mlir_input_pointwise, instruction_ref ins) +{ + if(ins->name() != "pointwise") + return false; + auto* pm = ins->module_inputs().front(); + return std::all_of(pm->begin(), pm->end(), &is_pointwise_op_supported_by_mlir_for_input); } std::vector mlir_contiguous(module_pass_manager& mpm, @@ -579,6 +595,48 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op } }; +struct find_pointwise_mlir +{ + auto matcher() const + { + return match::name("gpu::mlir_op")(match::any_of[match::inputs()]( + mlir_input_pointwise(match::used_once()).bind("pointwise"))); + } + + static instruction_ref insert_pointwise(module& m, + instruction_ref ins, + const operation& op, + const std::vector& inputs, + const std::vector& mod_args) + { + // Only used in assert + (void)mod_args; + assert(mod_args.empty()); + return insert_common_op(m, ins, op, inputs); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto pw = r.instructions["pointwise"]; + + auto* mm = ins->module_inputs().front(); + auto* pm = pw->module_inputs().front(); + + std::unordered_map map_ins; + module_ref m = mpm.create_module(pm->name() + ":" + mm->name()); + m->set_bypass(); + auto rins = m->fuse(*pm, pw->inputs(), &map_ins, &insert_pointwise).front(); + map_ins[pw] = rins; + + auto ret = m->fuse(*mm, ins->inputs(), &map_ins); + m->add_return({ret}); + + auto inputs = find_inputs(map_ins, &mpm.get_module(), m); + mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs, {m}); + } +}; + } // namespace #endif // MIGRAPHX_MLIR @@ -614,6 +672,11 @@ void fuse_mlir::apply(module_pass_manager& mpm) const mpm, find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)}, find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::fast)}); + + mpm.run_pass(dead_code_elimination{}); + + if(enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + match::find_matches(mpm, find_pointwise_mlir{}); #else (void)mpm; #endif diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 6b646720d66..e124b47da84 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -32,6 +32,8 @@ #include #include +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); + void run_pass(migraphx::program& p) { migraphx::run_passes( @@ -100,6 +102,44 @@ TEST_CASE(dot_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(add_dot) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, b); + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto b = mm->add_parameter("b", s); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto fused = + add_mlir(p2, + "main:pointwise0:mlir_dot1", + {x, y, b}, + {"x0", "x1", "x2"}, + [=](auto* pm, const auto& inputs) { + auto add = + pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + auto dot = pm->add_instruction(migraphx::make_op("dot"), add, inputs[2]); + return std::make_tuple(dot, dot); + }); + mm->add_return({fused}); + } + if(not migraphx::enabled(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION{})) + return; + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(int_quant_dot_abs) { migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}}; diff --git a/test/module_test.cpp b/test/module_test.cpp index 52981930bbb..3b910f8dfdf 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -28,10 +28,11 @@ #include #include #include -#include "test.hpp" #include #include +#include +#include migraphx::program create_program() { @@ -659,4 +660,35 @@ TEST_CASE(module_split3) EXPECT(bool{mods[2].inputs[1] == splits1.front()}); } +TEST_CASE(fuse_module) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3}}; + migraphx::module m1; + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add = add_pointwise(p, "main:pointwise0", {x, y}, single_pointwise("add")); + auto mul = add_pointwise(p, "main:pointwise1", {add, z}, single_pointwise("mul")); + + std::unordered_map map_ins; + auto rins = m1.fuse(*add->module_inputs().front(), add->inputs(), &map_ins).front(); + map_ins[add] = rins; + auto ret = m1.fuse(*mul->module_inputs().front(), mul->inputs(), &map_ins); + m1.add_return(ret); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x0", s); + auto y = m2.add_parameter("x1", s); + auto z = m2.add_parameter("x2", s); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + auto mul = m2.add_instruction(migraphx::make_op("mul"), add, z); + m2.add_return({mul}); + } + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/verify/test_add_dot.cpp b/test/verify/test_add_dot.cpp new file mode 100644 index 00000000000..ad23cc5acf6 --- /dev/null +++ b/test/verify/test_add_dot.cpp @@ -0,0 +1,49 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +template +struct test_add_dot : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{DType, {256, 256}}; + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add = mm->add_instruction(migraphx::make_op("add"), x, y); + auto dot = mm->add_instruction(migraphx::make_op("dot"), add, z); + mm->add_return({dot}); + return p; + } +}; + +template struct test_add_dot; +template struct test_add_dot;