Skip to content

Commit

Permalink
fix order in layernorm matcher and add test for the same (#2189)
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Sep 27, 2023
1 parent 90c8684 commit 03d8a25
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 10 deletions.
14 changes: 9 additions & 5 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,22 +790,26 @@ struct find_layernorm_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(match::arg(0)(
return precompile_name("pointwise")(match::any_of[match::inputs()](
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto pw_ins = r.result;
auto layernorm = r.instructions["layernorm"];
if(not layernorm->module_inputs().empty())
return;
auto* pm = ins->module_inputs().front();
auto* pm = pw_ins->module_inputs().front();
auto pw_inputs = pw_ins->inputs();
auto ln_pos = std::find(pw_inputs.begin(), pw_inputs.end(), layernorm);
assert(ln_pos != pw_inputs.end());
pw_inputs.erase(ln_pos);
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
inputs.insert(inputs.end(), pw_inputs.begin(), pw_inputs.end());

m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
m.replace_instruction(pw_ins, layernorm->get_operator(), inputs, {pm});
}
};

Expand Down
107 changes: 107 additions & 0 deletions test/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "make_precompile_op.hpp"
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>

void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_ops{}, migraphx::dead_code_elimination{}});
}

TEST_CASE(layernorm_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto create_program = [=](bool first_arg_layernorm) {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x, y}, single_pointwise("add"));
auto add1 =
mm->add_instruction(make_precompile_op("pointwise"), {x, y, alloc_ins}, {pw_add1});
auto alloc_ins2 = mm->add_instruction(alloc);
auto layernorm_ins =
mm->add_instruction(make_precompile_op("gpu::prelayernorm"), add1, alloc_ins2);
std::vector<migraphx::instruction_ref> pw_inputs = {layernorm_ins, z};
if(not first_arg_layernorm)
{
pw_inputs = {z, layernorm_ins};
}
auto* pw_add2 =
create_pointwise_module(p, "main:pointwise1", pw_inputs, single_pointwise("add"));
auto alloc_ins3 = mm->add_instruction(alloc);
pw_inputs.push_back(alloc_ins3);
auto add2 = mm->add_instruction(make_precompile_op("pointwise"), pw_inputs, {pw_add2});
mm->add_return({add2});
return p;
};

auto create_fused_program = [=]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x, y}, single_pointwise("add"));
auto add1 =
mm->add_instruction(make_precompile_op("pointwise"), {x, y, alloc_ins}, {pw_add1});
auto alloc_ins2 = mm->add_instruction(alloc);
auto* pw_add2 =
create_pointwise_module(p, "main:pointwise1", {x, z}, single_pointwise("add"));
auto layernorm_ins = mm->add_instruction(
make_precompile_op("gpu::prelayernorm"), {add1, z, alloc_ins2}, {pw_add2});
mm->add_return({layernorm_ins});
return p;
};

{
migraphx::program p1 = create_program(true);
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}
{
migraphx::program p1 = create_program(false);
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
21 changes: 16 additions & 5 deletions test/include/pointwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@
#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP

#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>

template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
migraphx::module_ref mm,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
migraphx::module_ref create_pointwise_module(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
pm->set_bypass();
Expand All @@ -44,6 +44,17 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
});
auto r = f(pm, params);
pm->add_return({r});
return pm;
}

template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
migraphx::module_ref mm,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = create_pointwise_module(p, name, inputs, f);
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}

Expand Down

0 comments on commit 03d8a25

Please sign in to comment.