Merge branch 'develop' into remove-redundant-headers
causten authored Sep 16, 2023
2 parents b181007 + 0064036 · commit ceb4eff
Showing 7 changed files with 111 additions and 55 deletions.
65 changes: 41 additions & 24 deletions src/targets/gpu/fuse_mlir.cpp
@@ -103,7 +103,10 @@ struct mlir_op
         }
         if(ins->name() == "@return")
         {
-            return ins_shapes[ins->inputs().at(0)].with_type(type);
+            auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
+            if(not s.standard())
+                MIGRAPHX_THROW("MLIR doesn't support non-standard output");
+            return s;
         }
         std::vector<shape> input_shapes;
         input_shapes.resize(ins->inputs().size());
@@ -299,10 +302,8 @@ struct find_mlir_fused_ops
     }
 };
 
-struct find_mlir_standalone_convolution_op
+struct find_mlir_standalone_op
 {
-    auto matcher() const { return match::name("convolution"); }
-
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto conv_based_op = r.result;
@@ -324,14 +325,24 @@ struct find_mlir_standalone_convolution_op
     }
 };
 
+struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
+{
+    auto matcher() const { return match::name("convolution"); }
+};
+
+struct find_mlir_standalone_dot_op : find_mlir_standalone_op
+{
+    auto matcher() const { return match::name("dot"); }
+};
+
 /**
  * @brief Declares a new MIGraphX environment variable which forces to generate
  * only specific MLIR operations.
  *
  * The variable, if defined, forces MIGraphX to use only specific operations
  * with MLIR regardless of the underlying GPU architecture. The variable accepts
  * a list of operations separated by comma. The variable recognizes the following
- * operations: "fused", "convolution". If the variable is not defined MIGraphX
+ * operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
 * will decide by itself which operations to delegate to MLIR. The variable is
 * intended to be primarily used by rocMLIR developers.
 */
@@ -346,31 +357,33 @@ bool is_requested(std::string_view option)
     return contains(options, option);
 }
 
-bool is_fusion_enabled()
-{
-    if(is_self_decide())
-    {
-        return true;
-    }
-    return is_requested("fused");
-}
-
-bool is_standalone_convs_enabled(context* ctx)
+bool is_enabled(std::string_view op_name, context* ctx)
 {
     if(is_self_decide())
     {
-        if(ctx == nullptr)
+        if(op_name == "fused")
         {
-            return false;
+            return true;
        }
-        else
+        else if(op_name == "convolution")
        {
-            const auto& device = ctx->get_current_device();
-            const std::string navi_family{"gfx110"};
-            return starts_with(device.get_gfx_name(), navi_family);
+            if(ctx == nullptr)
+            {
+                return false;
+            }
+            else
+            {
+                const auto& device = ctx->get_current_device();
+                const std::string navi_family{"gfx110"};
+                return starts_with(device.get_gfx_name(), navi_family);
+            }
+        }
+        else
+        {
+            return false;
        }
    }
-    return is_requested("convolution");
+    return is_requested(op_name);
 }
 } // namespace
@@ -379,21 +392,25 @@ bool is_standalone_convs_enabled(context* ctx)
 void fuse_mlir::apply(module_pass_manager& mpm) const
 {
 #ifdef MIGRAPHX_MLIR
-    if(is_fusion_enabled())
+    if(is_enabled("fused", this->ctx))
     {
         match::find_matches(mpm, find_mlir_fused_ops{});
     }
 
-    if(is_standalone_convs_enabled(this->ctx))
+    if(is_enabled("convolution", this->ctx))
     {
         match::find_matches(mpm, find_mlir_standalone_convolution_op{});
     }
+
+    if(is_enabled("dot", this->ctx))
+    {
+        match::find_matches(mpm, find_mlir_standalone_dot_op{});
+    }
 #else
     (void)mpm;
 #endif
 }
 
 } // namespace gpu
 
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
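
Note on usage: the per-operation gating above is driven by the environment variable documented in the comment block, whose actual declaration is collapsed in this diff. The sketch below is illustrative only, written in Python because all of its API calls appear in tools/test_runner.py later in this commit; the variable name MIGRAPHX_MLIR_USE_SPECIFIC_OPS and the model path are assumptions.

    import os

    # Assumed variable name (its declaration is collapsed in the diff above):
    # delegate only fused kernels and standalone dot ops to rocMLIR,
    # regardless of the GPU architecture.
    os.environ["MIGRAPHX_MLIR_USE_SPECIFIC_OPS"] = "fused,dot"

    import migraphx

    model = migraphx.parse_onnx("model.onnx")  # hypothetical model path
    model.compile(migraphx.get_target("gpu"))  # fuse_mlir::apply consults the variable here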
3 changes: 2 additions & 1 deletion src/targets/gpu/include/migraphx/gpu/mlir.hpp
@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
 
 MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
                                                          module m,
-                                                         const std::vector<shape>& inputs);
+                                                         const std::vector<shape>& inputs,
+                                                         bool exhaustive);
 
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
4 changes: 1 addition & 3 deletions src/targets/gpu/jit/mlir.cpp
@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
                            const operation&,
                            bool exhaustive) const
     {
-        if(not exhaustive)
-            return nullopt;
         auto shapes = to_shapes(ins->inputs());
         auto* smod  = ins->module_inputs().front();
-        return get_tuning_config_mlir(ctx, *smod, shapes);
+        return get_tuning_config_mlir(ctx, *smod, shapes, exhaustive);
     }
 };

15 changes: 9 additions & 6 deletions src/targets/gpu/mlir.cpp
@@ -682,11 +682,12 @@ struct mlir_program
         MIGRAPHX_THROW("Failed setting tuning key: " + *str);
     }
 
-    tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST
+    tuning_config get_tuning_config(bool exhaustive) MIGRAPHX_TIDY_CONST
     {
         tuning_config tc;
         run_high_level_pipeline();
-        auto tuning_mode = RocmlirTuningParamSetKindFull;
+        auto tuning_mode =
+            exhaustive ? RocmlirTuningParamSetKindFull : RocmlirTuningParamSetKindQuick;
         if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
             tuning_mode = RocmlirTuningParamSetKindExhaustive;
         mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
@@ -914,15 +915,17 @@ instruction_ref insert_mlir(module& m,
     return m.insert_instruction(ins, co, refs);
 }
 
-tuning_config
-get_tuning_config_mlir(const context& migraphx_ctx, module m, const std::vector<shape>& inputs)
+tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
+                                     module m,
+                                     const std::vector<shape>& inputs,
+                                     bool exhaustive)
 {
     adjust_param_shapes(m, inputs);
 
     mlir_program mp;
     mp.set_gpu_properties(migraphx_ctx);
     mp.parse(m);
-    return mp.get_tuning_config();
+    return mp.get_tuning_config(exhaustive);
 }
 
#else
Expand Down Expand Up @@ -951,7 +954,7 @@ insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<ins
return m.end();
}

tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&)
tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&, bool)
{
return {};
}
39 changes: 21 additions & 18 deletions src/targets/gpu/prefuse_ops.cpp
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include <migraphx/permutation.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>
 #include <migraphx/match/layernorm.hpp>
 #include <migraphx/check_shapes.hpp>
@@ -45,40 +46,42 @@ struct layernorm_base
     }
     shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
     {
-        std::size_t nargs = 1;
+        std::size_t nargs = N;
         if(not mods.empty())
         {
             auto* pm = mods.front();
-            nargs    = pm->get_parameter_names().size();
+            nargs += pm->get_parameter_names().size() - 1;
         }
-        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
-        auto s = inputs.at(0);
+        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs);
+        auto s = inputs.front();
         auto t = s.type();
         if(not mods.empty())
             t = mods.front()->get_output_shapes().front().type();
-        if(s.scalar())
-        {
-            return s;
-        }
-        else if(s.broadcasted())
-        {
-            return {t, s.lens()};
-        }
-        else
-        {
-            return s.with_lens(t, s.lens());
-        }
+
+        // Scalar output if all inputs are scalar
+        if(inputs.front().elements() == 1 and
+           all_of(inputs, [](const auto& ss) { return ss.scalar(); }))
+            return inputs.front();
+        auto l_s = shape::from_permutation(
+            t, s.lens(), find_permutation(std::vector<shape>(inputs.begin(), inputs.begin() + N)));
+        // just prelayernorm or preadd_layernorm
+        if(nargs <= N)
+            return l_s;
+        // else, layernorm + pointwise fusion, preserve layout of fused op
+        std::vector<shape> lp_s(inputs.begin() + N, inputs.end());
+        lp_s.insert(lp_s.begin(), l_s);
+        return shape::from_permutation(t, s.lens(), find_permutation(lp_s));
     }
 };
 
-struct layernorm : layernorm_base<layernorm, 0>
+struct layernorm : layernorm_base<layernorm, 1>
 {
-
     std::string name() const { return "gpu::prelayernorm"; }
 };
 MIGRAPHX_REGISTER_OP(layernorm);
 
-struct add_layernorm : layernorm_base<add_layernorm, 1>
+struct add_layernorm : layernorm_base<add_layernorm, 2>
 {
     std::string name() const { return "gpu::preadd_layernorm"; }
 };
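
The from_permutation/find_permutation logic above makes the fused layernorm preserve a permuted input layout instead of forcing a standard (row-major, packed) one; that is exactly the kind of output the new guard in fuse_mlir.cpp refuses to hand to MLIR. A minimal Python sketch of the idea, using plain lists rather than the MIGraphX shape API (the direction of the permutation convention here is an assumption):

    # Row-major ("standard") strides for the given lengths.
    def row_major_strides(lens):
        strides, acc = [], 1
        for l in reversed(lens):
            strides.insert(0, acc)
            acc *= l
        return strides

    def is_standard(lens, strides):
        return strides == row_major_strides(lens)

    # Strides of a {8, 1, 16} tensor whose axes sit in memory in order (1, 2, 0):
    lens, perm = [8, 1, 16], [1, 2, 0]
    inner = row_major_strides([lens[a] for a in perm])  # strides in memory order
    strides = [0] * len(lens)
    for pos, axis in enumerate(perm):
        strides[axis] = inner[pos]   # scatter back to logical axis order
    print(strides)                   # [1, 128, 8]
    print(is_standard(lens, strides))  # False -> a non-standard shape

This is the same {8, 1, 16} / {1, 2, 0} combination that the new verify test below builds with migraphx::shape::from_permutation.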
24 changes: 22 additions & 2 deletions test/verify/test_layernorm.cpp
@@ -49,15 +49,17 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
     auto pow = m.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
     auto var = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
     auto epsilon_mbcast = m.add_instruction(
-        migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
+        migraphx::make_op("multibroadcast", {{"out_lens", {dims.at(0), dims.at(1), 1}}}), epsilon);
+
     auto add_epsilon = m.add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
     auto sqrt = m.add_instruction(migraphx::make_op("sqrt"), add_epsilon);
     auto sqrt_mbcast =
         m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), sqrt);
     auto div = m.add_instruction(migraphx::make_op("div"), sub, sqrt_mbcast);
     auto scale_mbcast =
         m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), scale);
-    auto mul = m.add_instruction(migraphx::make_op("mul"), scale_mbcast, div);
+    auto mul = m.add_instruction(migraphx::make_op("mul"), div, scale_mbcast);
+
     auto bias_mbcast =
         m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), bias);
     return m.add_instruction(migraphx::make_op("add"), mul, bias_mbcast);
@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
         return p;
     }
 };
+
+struct test_add_layernorm_add_gemm_nonstd : verify_program<test_add_layernorm_add_gemm_nonstd>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        auto s =
+            migraphx::shape::from_permutation(migraphx::shape::float_type, {8, 1, 16}, {1, 2, 0});
+        auto x = mm->add_parameter("x", s);
+        auto y = mm->add_parameter("y", s);
+        auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8, 16, 64}});
+        auto add = mm->add_instruction(migraphx::make_op("add"), x, y);
+        auto layernorm_ins = add_layernorm(*mm, add, s.lens());
+        mm->add_instruction(migraphx::make_op("dot"), layernorm_ins, z);
+        return p;
+    }
+};
16 changes: 15 additions & 1 deletion tools/test_runner.py
@@ -39,6 +39,15 @@ def parse_args():
                         type=str,
                         default='gpu',
                         help='Specify where the tests execute (ref, gpu)')
+    parser.add_argument('--fp16', action='store_true', help='Quantize to fp16')
+    parser.add_argument('--atol',
+                        type=float,
+                        default=1e-3,
+                        help='The absolute tolerance parameter')
+    parser.add_argument('--rtol',
+                        type=float,
+                        default=1e-3,
+                        help='The relative tolerance parameter')
     args = parser.parse_args()
 
     return args
@@ -257,6 +266,8 @@ def main():
 
     # read and compile model
     model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
+    if args.fp16:
+        migraphx.quantize_fp16(model)
     model.compile(migraphx.get_target(target))
 
     # get test cases
@@ -279,7 +290,10 @@ def main():
         output_data = run_one_case(model, input_data)
 
     # check output correctness
-    ret = check_correctness(gold_outputs, output_data)
+    ret = check_correctness(gold_outputs,
+                            output_data,
+                            atol=args.atol,
+                            rtol=args.rtol)
     if ret:
         correct_num += 1
 
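
With the new flags, an fp16 run would look something like "python tools/test_runner.py --fp16 --atol 1e-2 --rtol 1e-2" with the remaining arguments as before; looser tolerances are typically needed once the model is quantized. A rough programmatic equivalent, reusing only names defined in test_runner.py (the model path and tolerance values are made up for illustration):

    import migraphx

    model = migraphx.parse_onnx("model.onnx", map_input_dims=param_shapes)
    migraphx.quantize_fp16(model)  # what --fp16 triggers before compilation
    model.compile(migraphx.get_target("gpu"))

    # Inside the test loop, compare against golden outputs with the
    # loosened tolerances that --atol/--rtol now pass through:
    output_data = run_one_case(model, input_data)
    ret = check_correctness(gold_outputs, output_data, atol=1e-2, rtol=1e-2)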
