Skip to content

Commit

Permalink
GEMM pointwise fusion for hipBLASLt (#3662)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsan-ca authored Dec 5, 2024
1 parent 8576973 commit af3f716
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 15 deletions.
101 changes: 86 additions & 15 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/compile_hipblaslt.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/hip_gemm.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/array.hpp>
Expand All @@ -41,6 +43,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)
#if MIGRAPHX_USE_MIOPEN
struct fusion
Expand Down Expand Up @@ -555,20 +558,9 @@ struct find_conv_pointwise
};
#endif

#if MIGRAPHX_USE_ROCBLAS
struct find_gemm_pointwise
#if MIGRAPHX_USE_ROCBLAS or MIGRAPHX_USE_HIPBLASLT
struct gemm_pointwise
{
auto matcher() const
{
auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
auto binary_op = match::all_of(
match::nargs(3),
match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op));
auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
}

// TODO: Move to matcher.hpp
static auto match_param(const std::string& name)
{
Expand Down Expand Up @@ -642,6 +634,22 @@ struct find_gemm_pointwise
return false;
}
}
};
#endif

#if MIGRAPHX_USE_ROCBLAS
struct find_rocblas_gemm_pointwise : gemm_pointwise
{
auto matcher() const
{
auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
auto binary_op = match::all_of(
match::nargs(3),
match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op));
auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
}

void apply(module& m, const match::matcher_result& r) const
{
Expand Down Expand Up @@ -685,6 +693,66 @@ struct find_gemm_pointwise
};
#endif

#if MIGRAPHX_USE_HIPBLASLT
struct find_hipblas_gemm_pointwise : gemm_pointwise
{
auto matcher() const
{
auto gemm_op =
match::name("gpu::hipblaslt_op")(match::nargs(3), match::used_once()).bind("hip_gemm");
auto binary_op = match::all_of(
match::nargs(3),
match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op));
auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["hip_gemm"];

auto gemm_op = any_cast<hipblaslt_op>(gemm_ins->get_operator()).op;

auto gemm = any_cast<hip_gemm<op::dot>>(gemm_op);

// Already fused gemm
if(not float_equal(gemm.beta, 0))
return;
if(ins->inputs().size() == 3)
gemm.beta = 1;
if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
{
return;
}
auto inputs = gemm_ins->inputs();
inputs.pop_back();
if(ins->inputs().size() == 3)
{
auto c_ins = r.instructions["c"];
shape s = c_ins->get_shape();
// const-fold input if not standard shape
// Updated for a case where "standard" shape has out-of-sequence strides
if(not s.standard() or s.normalize_standard() != s)
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data());
}
inputs.push_back(c_ins);
}
inputs.push_back(ins->inputs().back());

operation new_gemm_op = gemm;
auto new_ins = m.insert_instruction(
ins, make_op("gpu::hipblaslt_op", {{"op", to_value(new_gemm_op)}}), inputs);
m.replace_instruction(ins, new_ins);
}
};
#endif

struct find_contiguous_tranpose_gemm
{
auto matcher() const
Expand Down Expand Up @@ -903,10 +971,13 @@ void fuse_ops::apply(module& m) const
match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx});
run_passes(m, {dead_code_elimination{}});
#endif
match::find_matches(m,
#if MIGRAPHX_USE_ROCBLAS
find_gemm_pointwise{},
match::find_matches(m, find_rocblas_gemm_pointwise{});
#endif
#if MIGRAPHX_USE_HIPBLASLT
match::find_matches(m, find_hipblas_gemm_pointwise{});
#endif
match::find_matches(m,
find_layernorm_pointwise{},
find_concat_pointwise{},
find_contiguous_tranpose_gemm{},
Expand Down
106 changes: 106 additions & 0 deletions test/gpu/fuse_gemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <basic_ops.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/gpu/compile_hipblaslt.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/hip_gemm.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/allocate.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <pointwise.hpp>
#include <test.hpp>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM)

void run_lowering(migraphx::program& p, bool offload_copy = false)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(
*p.get_main_module(),
{migraphx::auto_contiguous{}, migraphx::gpu::lowering{&ctx, offload_copy}});
}

void run_fuse_ops(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_ops{}, migraphx::dead_code_elimination{}});
}

#if MIGRAPHX_USE_HIPBLASLT
TEST_CASE(gemm_pointwise_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
mm->add_return({add});
}
run_lowering(p1);
run_fuse_ops(p1);

migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);

auto output = mm->add_instruction(migraphx::op::allocate{s, std::nullopt});

if(migraphx::enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{}) and
migraphx::gpu::hipblaslt_supported())
{
migraphx::op::dot dot_instance;
migraphx::gpu::hipblaslt_op hipblaslt_operator;
hipblaslt_operator.op = migraphx::gpu::hip_gemm<migraphx::op::dot>{dot_instance, 1, 1};
auto add = mm->add_instruction(hipblaslt_operator, a, b, x, output);
mm->add_return({add});
}
else
{
auto gemm_oper =
migraphx::make_op("gpu::gemm", {{"alpha", 1}, {"beta", 1}, {"compute_fp32", 1}});
auto add = mm->add_instruction(gemm_oper, a, b, x, output);
mm->add_return({add});
}
}
EXPECT(p1.sort() == p2.sort());
}
#endif

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit af3f716

Please sign in to comment.