update tests, update eliminate_contiguous to remove contiguous at end of program
kahmed10 committed Aug 11, 2023
1 parent 0ede5e1 commit 225cd3a
Showing 4 changed files with 48 additions and 55 deletions.
src/auto_contiguous.cpp (1 addition, 1 deletion)

@@ -60,7 +60,7 @@ void auto_contiguous::apply(module& m) const
     auto last = std::prev(m.end());
     for(auto ins : iterator_for(m))
     {
-        if(contains({"layout", "contiguous", "@return", "@param"}, ins->name()))
+        if(contains({"layout", "contiguous", "@return", "@param", "@outline"}, ins->name()))
             continue;
         // for last instruction that is NOT a return
         if(ins->outputs().empty() and ins != last)
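
For context on the one-line change above: auto_contiguous appends a "contiguous" instruction after any op whose name is not in the skip list, and this commit extends that list with "@outline". A minimal sketch of the pass's effect, using the same public module and pass API the tests in this commit use; the example module and the pass_manager header name are my assumptions, not part of the change.

#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/pass_manager.hpp> // declares run_passes (assumed header)

void auto_contiguous_sketch()
{
    migraphx::module m;
    auto x = m.add_parameter("x", {migraphx::shape::float_type, {2, 3}});
    // transpose yields a view with permuted strides, i.e. a non-standard shape
    auto t = m.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x);
    m.add_return({t});

    // Inserts a "contiguous" after the transpose; instructions whose names are
    // in the skip list ("layout", "contiguous", "@return", "@param", and now
    // "@outline") are left untouched.
    migraphx::run_passes(m, {migraphx::auto_contiguous{}});
}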
src/eliminate_contiguous.cpp (14 additions, 0 deletions)

@@ -30,6 +30,7 @@
 #include <migraphx/op/contiguous.hpp>
 #include <migraphx/op/identity.hpp>
 #include <migraphx/par_for.hpp>
+#include <type_traits>
 #include <utility>
 
 namespace migraphx {
@@ -161,6 +162,18 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
     }
 }
 
+static void remove_contiguous_noops(const std::string& op_name, module& m)
+{
+    for(auto ins : iterator_for(m))
+    {
+        if(ins->name() != op_name)
+            continue;
+        if(ins->inputs().front()->get_shape() != ins->get_shape())
+            continue;
+        m.replace_instruction(ins, ins->inputs().front());
+    }
+}
+
 void eliminate_contiguous::apply(module& m) const
 {
     // Skip contiguous from splits first
@@ -170,6 +183,7 @@ void eliminate_contiguous::apply(module& m) const
         return (ins->inputs().front()->outputs().size() == 1);
     });
     remove_contiguous(op_name, m, [](auto) { return true; });
+    remove_contiguous_noops(op_name, m);
 }
 
 } // namespace MIGRAPHX_INLINE_NS
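
The new remove_contiguous_noops step is what the commit title describes: a contiguous whose output shape already equals its input shape (for example, one that auto_contiguous appended at the end of the program) is replaced by its input. A minimal sketch of the intended effect; the module is illustrative, and passing "contiguous" as the op name mirrors the "gpu::contiguous" argument used by the GPU test further down.

#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> // declares run_passes (assumed header)

void noop_contiguous_sketch()
{
    migraphx::module m;
    auto x   = m.add_parameter("x", {migraphx::shape::float_type, {4, 4}});
    auto sum = m.add_instruction(migraphx::make_op("add"), x, x);
    // sum already has a standard shape, so this contiguous is a no-op copy:
    // its output shape equals its input shape.
    auto c = m.add_instruction(migraphx::make_op("contiguous"), sum);
    m.add_return({c});

    // remove_contiguous_noops rewires the return to use sum directly; the
    // orphaned contiguous instruction is then removed by dead-code elimination.
    migraphx::run_passes(
        m, {migraphx::eliminate_contiguous{"contiguous"}, migraphx::dead_code_elimination{}});
}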
test/auto_contiguous_test.cpp (9 additions, 4 deletions)

@@ -148,11 +148,13 @@ TEST_CASE(two_transpose_gather)
             migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), data);
         auto ctd = m2.add_instruction(migraphx::make_op("contiguous"), td);
         auto sd = m2.add_instruction(migraphx::make_op("softmax", {{"axis", 2}}), ctd);
+        auto csd = m2.add_instruction(migraphx::make_op("contiguous"), sd);
         auto bd =
-            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), sd);
+            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), csd);
         auto cbd = m2.add_instruction(migraphx::make_op("contiguous"), bd);
         auto r = m2.add_instruction(migraphx::make_op("gather", {{"axis", 2}}), cbd, ind);
-        m2.add_return({r});
+        auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
+        m2.add_return({cr});
     }
 
     EXPECT(m1 == m2);
@@ -174,8 +176,11 @@ TEST_CASE(standard_reshape)
         auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
         auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
         auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
-        auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
-        m2.add_return({r});
+        // extra contiguous coming from reshape logic which has "requires_std_shape" attribute
+        auto cb = m2.add_instruction(migraphx::make_op("contiguous"), ca);
+        auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), cb);
+        auto cr = m2.add_instruction(migraphx::make_op("contiguous"), r);
+        m2.add_return({cr});
     }
 
     EXPECT(m1 == m2);
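
The in-test comment names the mechanism behind the new expectations: reshape carries a "requires_std_shape" attribute, so auto_contiguous guards its input with a contiguous even when the producer is itself a contiguous, and it now also appends one before the return. The back-to-back copies are redundant here, which is exactly the kind of no-op the updated eliminate_contiguous can fold away. A hedged illustration of the standard versus non-standard shape distinction these copies exist to repair (module and shapes are mine, not the test's):

#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> // get_shape() on instruction_ref

void std_shape_sketch()
{
    migraphx::module m;
    auto x  = m.add_parameter("x", {migraphx::shape::float_type, {2, 3, 4, 5}});
    auto tr = m.add_instruction(
        migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
    // tr->get_shape().standard() == false: the lens become {2, 4, 3, 5} but
    // the strides still describe the original {2, 3, 4, 5} layout.
    auto c = m.add_instruction(migraphx::make_op("contiguous"), tr);
    // c->get_shape().standard() == true: a packed row-major copy, the kind of
    // input an op marked "requires_std_shape" (such as reshape) insists on.
    m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12, 5}}}), c);
}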
test/gpu/pack_int8_args.cpp (24 additions, 50 deletions)

@@ -32,6 +32,7 @@
 #include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/auto_contiguous.hpp>
+#include <migraphx/eliminate_contiguous.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/instruction.hpp>
@@ -49,6 +50,7 @@ void run_passes(migraphx::module& m, migraphx::gpu::context& ctx)
     migraphx::run_passes(m,
                          {migraphx::auto_contiguous{},
                           migraphx::gpu::lowering{&ctx, false},
+                          migraphx::eliminate_contiguous{"gpu::contiguous"},
                           migraphx::dead_code_elimination{},
                           migraphx::replace_allocate{migraphx::gpu::gpu_allocation_model{}},
                           migraphx::dead_code_elimination{},
@@ -104,13 +106,9 @@ TEST_CASE(quant_dot)
 
     auto beta_broadcast = m.add_instruction(
         migraphx::make_op("multibroadcast", {{"out_lens", m3_shape.lens()}}), beta);
-    auto beta_alloc = m.add_instruction(
-        migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
-    auto beta_contiguous =
-        m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
     auto mul_alloc = m.add_instruction(
         migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m3_shape)}}));
-    auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_contiguous, mul_alloc);
+    auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc);
     auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
     m.add_return({gemm_add});
 
@@ -158,53 +156,43 @@ TEST_CASE(quant_dot_trans)
     auto tl1 =
         m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
     migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
-    auto alloca = m.add_instruction(
-        migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
-    auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
 
     auto tl2 =
         m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
     migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
-    auto allocb = m.add_instruction(
-        migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
-    auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb);
 
     auto alpha_broadcast = m.add_instruction(
-        migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha);
+        migraphx::make_op("multibroadcast", {{"out_lens", tl1->get_shape().lens()}}), alpha);
     auto alpha_alloc = m.add_instruction(migraphx::make_op(
         "hip::allocate",
         {{"shape",
           migraphx::to_value(migraphx::shape(migraphx::shape::int32_type, {3, 2, 5, 8}))}}));
-    auto alpha_contiguous =
-        m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
     // alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
     // back result to int8
-    auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
-        "hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
     auto tl1_convert =
         m.add_instruction(make_precompile_op(migraphx::make_op(
                               "convert", {{"target_type", alpha->get_shape().type()}})),
-                          conta,
-                          tl1_convert_alloc);
+                          tl1,
+                          alpha_alloc);
     auto mul_alloc = m.add_instruction(migraphx::make_op(
-        "hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
+        "hip::allocate", {{"shape", migraphx::to_value(alpha_alloc->get_shape())}}));
     auto tl1_alpha_int32 =
-        m.add_instruction(make_precompile_op("mul"), alpha_contiguous, tl1_convert, mul_alloc);
+        m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc);
     // convert mul_res to int8
     auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
-        "hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
+        "hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
     auto tl1_alpha_int8 =
         m.add_instruction(make_precompile_op(migraphx::make_op(
-                              "convert", {{"target_type", conta->get_shape().type()}})),
+                              "convert", {{"target_type", tl1->get_shape().type()}})),
                           tl1_alpha_int32,
                           tl1_alpha_int8_alloc);
 
-    auto packb = contb;
+    auto packb = tl2;
     if(int8_x4)
     {
         auto allocpb = m.add_instruction(
             migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
-        packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb);
+        packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), tl2, allocpb);
     }
 
     auto gemm = m.add_instruction(
@@ -301,13 +289,9 @@ TEST_CASE(quant_dot_pad)
 
     auto beta_broadcast =
         m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s3.lens()}}), beta);
-    auto beta_alloc = m.add_instruction(
-        migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
-    auto beta_contiguous =
-        m.add_instruction(migraphx::make_op("gpu::contiguous"), beta_broadcast, beta_alloc);
     auto mul_alloc = m.add_instruction(
         migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(s3)}}));
-    auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_contiguous, mul_alloc);
+    auto m3_beta = m.add_instruction(make_precompile_op("mul"), l3, beta_broadcast, mul_alloc);
     auto gemm_add = m.add_instruction(make_precompile_op("add"), gemm, m3_beta, output);
     m.add_return({gemm_add});
     return m;
@@ -357,58 +341,48 @@ TEST_CASE(quant_dot_trans_pad)
     auto tl1 =
         m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l1);
     migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
-    auto ta = m.add_instruction(
-        migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
-    auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta);
 
     auto tl2 =
         m.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
     migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
-    auto tb = m.add_instruction(
-        migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
-
+
     migraphx::instruction_ref ptb{};
     if(int8_x4)
     {
         ptb = m.add_instruction(
             migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
     }
-    auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb);
-    auto pb = contb;
+    auto pb = tl2;
     if(int8_x4)
     {
         pb = m.add_instruction(
             migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}),
-            contb,
+            tl2,
             ptb);
     }
 
     auto alpha_broadcast = m.add_instruction(
-        migraphx::make_op("multibroadcast", {{"out_lens", conta->get_shape().lens()}}), alpha);
+        migraphx::make_op("multibroadcast", {{"out_lens", tl1->get_shape().lens()}}), alpha);
     auto alpha_alloc = m.add_instruction(
         migraphx::make_op("hip::allocate",
                           {{"shape",
                             migraphx::to_value(migraphx::shape(migraphx::shape::int32_type,
-                                                               conta->get_shape().lens()))}}));
-    auto alpha_contiguous =
-        m.add_instruction(migraphx::make_op("gpu::contiguous"), alpha_broadcast, alpha_alloc);
+                                                               tl1->get_shape().lens()))}}));
 
     // alpha = int32 and tl1 = int8, convert tl1 to int32 for multiplication and then convert
     // back result to int8
-    auto tl1_convert_alloc = m.add_instruction(migraphx::make_op(
-        "hip::allocate", {{"shape", migraphx::to_value(alpha_contiguous->get_shape())}}));
    auto tl1_convert =
         m.add_instruction(make_precompile_op(migraphx::make_op(
                               "convert", {{"target_type", alpha->get_shape().type()}})),
-                          conta,
-                          tl1_convert_alloc);
+                          tl1,
+                          alpha_alloc);
     auto mul_alloc = m.add_instruction(migraphx::make_op(
-        "hip::allocate", {{"shape", migraphx::to_value(tl1_convert->get_shape())}}));
+        "hip::allocate", {{"shape", migraphx::to_value(alpha_alloc->get_shape())}}));
     auto tl1_alpha_int32 =
-        m.add_instruction(make_precompile_op("mul"), alpha_contiguous, tl1_convert, mul_alloc);
+        m.add_instruction(make_precompile_op("mul"), alpha_broadcast, tl1_convert, mul_alloc);
     // convert mul_res to int8
     auto tl1_alpha_int8_alloc = m.add_instruction(migraphx::make_op(
-        "hip::allocate", {{"shape", migraphx::to_value(conta->get_shape())}}));
+        "hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
 
     migraphx::instruction_ref pta{};
     if(int8_x4)
@@ -419,7 +393,7 @@ TEST_CASE(quant_dot_trans_pad)
 
     auto tl1_alpha_int8 =
         m.add_instruction(make_precompile_op(migraphx::make_op(
-                              "convert", {{"target_type", conta->get_shape().type()}})),
+                              "convert", {{"target_type", ts1.type()}})),
                           tl1_alpha_int32,
                           tl1_alpha_int8_alloc);
 
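
Tying the expected-IR changes above back to the pipeline: gpu::lowering turns each contiguous into a gpu::contiguous writing into its own hip::allocate, the newly added eliminate_contiguous{"gpu::contiguous"} pass then drops copies whose consumers can read the strided input directly, and dead-code elimination sweeps the unused allocations. That is why conta, contb, and beta_contiguous (and their allocations) disappear from the expected modules. The excerpt below repeats the test's pass list with annotations; the comments are my interpretation, not from the source, and the diff hunk cuts the list off after the second dead_code_elimination.

migraphx::run_passes(m,
                     {migraphx::auto_contiguous{},          // insert contiguous copies
                      migraphx::gpu::lowering{&ctx, false}, // contiguous -> gpu::contiguous + hip::allocate
                      migraphx::eliminate_contiguous{"gpu::contiguous"}, // drop removable copies
                      migraphx::dead_code_elimination{},    // sweep the orphaned allocations
                      migraphx::replace_allocate{migraphx::gpu::gpu_allocation_model{}},
                      migraphx::dead_code_elimination{},
                      /* ...remaining passes elided in the diff... */});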
