Skip to content

Commit

Permalink
Fuse inputs with mlir (#3010)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Jul 16, 2024
1 parent b4c29f0 commit ff81caa
Show file tree
Hide file tree
Showing 11 changed files with 329 additions and 109 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
}
}, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot']) {
withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1', 'MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,attention,convolution,dot', 'MIGRAPHX_ENABLE_MLIR_INPUT_FUSION=1']) {
def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
Expand Down
5 changes: 5 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ Performs exhaustive tuning for MLIR.
Set to an integer greater than 1.
Limits the number of solutions available to MLIR for tuning.

.. envvar:: MIGRAPHX_ENABLE_MLIR_INPUT_FUSION

Enables input fusions in MLIR.
Set to "1", "enable", "enabled", "yes", or "true" to use.

CK vars
-----------

Expand Down
130 changes: 26 additions & 104 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/rewrite_reshapes.hpp>
#include <migraphx/param_utils.hpp>
#include <iterator>
#include <map>

Expand Down Expand Up @@ -91,93 +92,14 @@ MIGRAPHX_PRED_MATCHER(input_output_ndim_match, instruction_ref ins)
return input_shape.ndim() == output_shape.ndim();
}

// For every input that is not yet present in map_ins, add a standard-shape
// parameter to the submodule and record the input -> parameter mapping.
// Parameter names continue the submodule's existing "x<N>" numbering.
static void insert_params(module_ref sm,
                          const std::vector<instruction_ref>& inputs,
                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    auto index = sm->get_parameter_shapes().size();
    for(const auto& input : inputs)
    {
        if(contains(map_ins, input))
            continue;
        const auto pname = "x" + std::to_string(index++);
        map_ins[input]   = sm->add_parameter(pname, input->get_shape().as_standard());
    }
}

// Insert a single instruction into the submodule, first creating submodule
// parameters for any of its inputs that are not yet in map_ins.
// Returns the instructions added to the submodule.
static auto insert_ins_in_submodule(module_ref sm,
                                    instruction_ref ins,
                                    std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    insert_params(sm, ins->inputs(), map_ins);
    return sm->add_instructions({ins}, &map_ins);
}

// Convenience overload that starts from an empty input->submodule mapping.
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    return insert_ins_in_submodule(sm, ins, map_ins);
}

// Copy the instructions of module m into submodule sm. Parameters are first
// created in sm for any of the given inputs that are not already mapped;
// each parameter of m is then aliased to the submodule instruction its
// corresponding input maps to, so the copied instructions are wired up to
// the submodule's parameters. Returns the instructions added to sm.
static auto
insert_module_in_submodule(module_ref sm,
                           const std::vector<instruction_ref>& inputs,
                           module_ref m,
                           std::unordered_map<instruction_ref, instruction_ref>& map_ins,
                           module::inserter insert = nullptr)
{
    insert_params(sm, inputs, map_ins);
    auto param_map = m->get_ins_param_map(inputs);
    for(auto&& [input, param] : param_map)
    {
        map_ins[param] = map_ins.at(input);
    }
    return sm->add_instructions(m, &map_ins, std::move(insert));
}

static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins,
module::inserter insert = nullptr)
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
module::inserter insert = nullptr)
{
return insert_module_in_submodule(
sm, ins->inputs(), ins->module_inputs().front(), map_ins, std::move(insert));
}

static auto insert_module_in_submodule(module_ref sm,
const std::vector<instruction_ref>& inputs,
module_ref m,
module::inserter insert = nullptr)
{
std::unordered_map<instruction_ref, instruction_ref> map_ins;
return insert_module_in_submodule(sm, inputs, m, map_ins, std::move(insert));
}

static std::vector<instruction_ref>
find_inputs(const_module_ref sm,
const module& parent,
const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
for(auto&& [input, param] : map_ins)
{
if(not sm->has_instruction(param))
continue;
if(param->name() != "@param")
continue;
if(not parent.has_instruction(input))
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
}
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second;
});
assert(result.size() == sm->get_parameter_shapes().size());
return result;
assert(ins->module_inputs().size() == 1);
return sm->fuse(*ins->module_inputs().front(), ins->inputs(), map_ins, std::move(insert));
}

static void create_reduce_modules(module_pass_manager& mpm)
Expand All @@ -194,7 +116,7 @@ static void create_reduce_modules(module_pass_manager& mpm)
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
rm->set_bypass();

rm->add_return(insert_ins_in_submodule(rm, ins));
rm->add_return(rm->fuse({ins}));
auto v = ins->get_operator().to_value();
mpm.get_module().replace_instruction(
ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
Expand Down Expand Up @@ -286,23 +208,23 @@ struct find_pointwise_reduce
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Insert pointwise
auto rins = insert_ins_in_submodule(rm, input, map_ins).front();
auto rins = rm->fuse({input}, &map_ins).front();
map_ins[input] = rins;

if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
auto fbroadcast = r.instructions["final_broadcast"];
map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front();
map_ins[broadcast] = rm->fuse({broadcast}, &map_ins).front();
if(fbroadcast != broadcast)
map_ins[fbroadcast] = map_ins[broadcast];
}

// Insert fused_reduce
rm->add_return(insert_module_in_submodule(rm, reduce, map_ins));
rm->add_return(insert_module_in_submodule(rm, reduce, &map_ins));
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
}
};
Expand All @@ -327,24 +249,24 @@ struct find_reduce_pointwise
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy module instructions
insert_module_in_submodule(rm, reduce, map_ins);
insert_module_in_submodule(rm, reduce, &map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
auto bout = rm->fuse({broadcast}, &map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}

auto out = insert_ins_in_submodule(rm, pw, map_ins);
auto out = rm->fuse({pw}, &map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
}
};
Expand Down Expand Up @@ -372,24 +294,24 @@ struct find_reduce_reduce

std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy reduce1 instructions
insert_module_in_submodule(rm, reduce2, map_ins);
insert_module_in_submodule(rm, reduce2, &map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
auto bout = rm->fuse({broadcast}, &map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}

auto out = insert_module_in_submodule(rm, reduce1, map_ins);
auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
}
};
Expand Down Expand Up @@ -429,14 +351,14 @@ struct reduce_reshape : rewrite_reshapes_base
auto* oldm = ins->module_inputs().front();
auto* sm = mpm.create_module(oldm->name() + "_reshape");
sm->set_bypass();
insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) {
if(contains(sop.name(), "reduce"))
return make_op(sop.name(), {{"axes", axes}});
if(sop.name() == "multibroadcast")
return make_op("multibroadcast", {{"out_lens", dims}});
assert(sop.name() == "pointwise");
return sop;
}));
sm->fuse(*oldm, inputs, nullptr, transform_op([&](const operation& sop) {
if(contains(sop.name(), "reduce"))
return make_op(sop.name(), {{"axes", axes}});
if(sop.name() == "multibroadcast")
return make_op("multibroadcast", {{"out_lens", dims}});
assert(sop.name() == "pointwise");
return sop;
}));
return mpm.get_module().insert_instruction(ins, fused_reduce{axes}, inputs, {sm});
}

Expand Down
15 changes: 15 additions & 0 deletions src/include/migraphx/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,21 @@ struct MIGRAPHX_EXPORT module
const std::vector<instruction_ref>& splits1,
const std::vector<instruction_ref>& splits2) const;

// Fuse the given instructions into the module by inserting them along with
// parameters for any missing inputs.
std::vector<instruction_ref>
fuse(const std::vector<instruction_ref>& inss,
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

// Fuse another module into this module by inserting the instructions and
// parameters from the module
std::vector<instruction_ref>
fuse(const module& m,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
inserter insert = nullptr);

void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(instruction_ref ins,
Expand Down
8 changes: 8 additions & 0 deletions src/include/migraphx/param_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/module_ref.hpp>
#include <vector>
#include <string>

Expand All @@ -37,6 +38,13 @@ MIGRAPHX_EXPORT std::string param_name(std::size_t i, const std::string& prefix

void sort_params(std::vector<instruction_ref>& params);

// Find the inputs for a module by finding instructions that are mapped to the
// parameters in the module
std::vector<instruction_ref>
find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
const_module_ref parent,
const_module_ref sub);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_PARAM_UTILS_HPP
57 changes: 57 additions & 0 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,63 @@ std::array<module::with_inputs, 3> module::split(const std::vector<instruction_r
return {{std::move(mods1[0]), std::move(mods2[0]), std::move(mods2[1])}};
}

// Insert parameters into the module based on the input instructions and then
// update the map_ins to map the input to the parameter.
// Insert parameters into the module for any input that is not already
// present in map_ins, and record the input -> parameter mapping.
// Parameter names continue the module's existing parameter numbering.
static void insert_params(module& m,
                          const std::vector<instruction_ref>& inputs,
                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    std::size_t index = m.get_parameter_shapes().size();
    for(const auto& input : inputs)
    {
        if(contains(map_ins, input))
            continue;
        const auto param = m.add_parameter(param_name(index++), input->get_shape().as_standard());
        map_ins[input]   = param;
    }
}

std::vector<instruction_ref>
module::fuse(const std::vector<instruction_ref>& inss,
std::unordered_map<instruction_ref, instruction_ref>* map_ins,
module::inserter insert)
{
std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
if(map_ins == nullptr)
map_ins = &default_map_ins;
std::vector<instruction_ref> inputs;
for(auto ins : inss)
{
for(auto input : ins->inputs())
{
if(contains(inss, input))
continue;
if(contains(inputs, input))
continue;
inputs.push_back(input);
}
}
insert_params(*this, inputs, *map_ins);
return this->add_instructions(inss, map_ins, std::move(insert));
}

std::vector<instruction_ref>
module::fuse(const module& m,
const std::vector<instruction_ref>& inputs,
std::unordered_map<instruction_ref, instruction_ref>* map_ins,
module::inserter insert)
{
std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
if(map_ins == nullptr)
map_ins = &default_map_ins;
insert_params(*this, inputs, *map_ins);
auto param_map = m.get_ins_param_map(inputs);
for(auto&& [input, param] : param_map)
{
(*map_ins)[param] = map_ins->at(input);
}
return this->add_instructions(&m, map_ins, std::move(insert));
}

void module_with_inputs::replace(instruction_ref ins, instruction_ref rep)
{
auto it = std::find(inputs.begin(), inputs.end(), ins);
Expand Down
29 changes: 29 additions & 0 deletions src/param_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
#include <migraphx/param_utils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/module.hpp>
#include <migraphx/ranges.hpp>
#include <map>
#include <cmath>

namespace migraphx {
Expand All @@ -49,5 +52,31 @@ void sort_params(std::vector<instruction_ref>& params)
}));
}

std::vector<instruction_ref>
find_inputs(const std::unordered_map<instruction_ref, instruction_ref>& map_ins,
const_module_ref parent,
const_module_ref sub)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
for(auto&& [input, param] : map_ins)
{
if(sub != nullptr and not sub->has_instruction(param))
continue;
if(param->name() != "@param")
continue;
if(parent != nullptr and not parent->has_instruction(input))
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
}
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second;
});
assert(not sub or result.size() == sub->get_parameter_shapes().size());
return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Loading

0 comments on commit ff81caa

Please sign in to comment.