Skip to content

Commit

Permalink
Remove zero point parameter for dequantizelinear when it's zero (#3531)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Nov 13, 2024
1 parent c3c1bba commit ca70d73
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 81 deletions.
24 changes: 24 additions & 0 deletions src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,28 @@ void remove_qdq_pairs(module& m)
}
}

// Drops the third input (the zero point) from every "dequantizelinear"
// instruction in the module when that input can be evaluated and every one of
// its elements compares equal to zero via float_equal. The instruction is
// replaced by one using the same operator with only the first two inputs.
void remove_zero_point(module& m)
{
    for(auto ins : iterator_for(m))
    {
        // Only three-input dequantizelinear instructions carry a zero point.
        if(ins->name() != "dequantizelinear" or ins->inputs().size() != 3)
            continue;

        auto zero_point = ins->inputs().at(2);
        // The zero point must be evaluable here to inspect its values.
        if(not zero_point->can_eval())
            continue;

        auto zp_arg   = zero_point->eval();
        bool all_zero = false;
        zp_arg.visit([&](auto view) {
            all_zero =
                std::all_of(view.begin(), view.end(), [](auto v) { return float_equal(v, 0); });
        });

        if(all_zero)
        {
            m.replace_instruction(
                ins, ins->get_operator(), ins->inputs().at(0), ins->inputs().at(1));
        }
    }
}

void add_int4_pack_unpack_pair(module& m)
{
for(auto ins : iterator_for(m))
Expand Down Expand Up @@ -446,6 +468,8 @@ void simplify_qdq::apply(module& m) const
remove_qdq_pairs(m);
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_zero_point(m);
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
48 changes: 10 additions & 38 deletions test/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,6 @@ TEST_CASE(dot_float)
auto pb = mm->add_parameter("b", sb);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
auto zp_out = mm->add_literal(std::int32_t{0});
auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale);
auto zp_a =
Expand All @@ -685,10 +684,7 @@ TEST_CASE(dot_float)
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
zp_out = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out);
auto r =
mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});

return p;
Expand All @@ -705,11 +701,11 @@ TEST_CASE(dot_float)
migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();

EXPECT(p == qp);
EXPECT(p.sort() == qp.sort());

optimize_prog_int8(p);
auto op = create_int8_optimized_prog();
EXPECT(p == op);
EXPECT(p.sort() == op.sort());
}

TEST_CASE(dot_double_2args)
Expand Down Expand Up @@ -785,11 +781,7 @@ TEST_CASE(dot_double_2args)
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_b_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb);
auto zp_out = mm->add_literal(std::int32_t{0});
zp_out = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), zp_out);
auto r =
mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale, zp_out);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
Expand Down Expand Up @@ -856,18 +848,14 @@ TEST_CASE(dot_half_1arg)

auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
auto zp_out = mm->add_literal(std::int32_t{0});
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale);
zp_out = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), zp_out);
auto r =
mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale, zp_out);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
Expand Down Expand Up @@ -923,11 +911,7 @@ TEST_CASE(conv_float)
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto zp_out = mm->add_literal(std::int32_t{0});
zp_out = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out);
auto r =
mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});

return p;
Expand Down Expand Up @@ -1005,11 +989,7 @@ TEST_CASE(conv_half)
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto zp_out = mm->add_literal(std::int32_t{0});
zp_out = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out);
auto r =
mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});

return p;
Expand Down Expand Up @@ -1257,10 +1237,7 @@ TEST_CASE(int8_subgraph)
auto s1_mb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1);
auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb);
auto zp_out = then_mod->add_literal(std::int32_t{0});
zp_out = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), zp_out);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so, zp_out);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});

migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
Expand All @@ -1286,13 +1263,8 @@ TEST_CASE(int8_subgraph)
auto ssw_mb = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
ssw_lit);
auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto zp1_out = else_mod->add_literal(std::int32_t{0});
zp1_out = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
zp1_out);
auto r1 =
else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1, zp1_out);
auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1});

auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
Expand Down
Loading

0 comments on commit ca70d73

Please sign in to comment.