Skip to content

Commit

Permalink
added unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
e-ddykim committed Nov 4, 2024
1 parent 22d8f06 commit ebca03d
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class TRANSFORMATIONS_API ActivationsScaling;
namespace activations_scaling {

class TRANSFORMATIONS_API ScaleDownSingleLayer;
class TRANSFORMATIONS_API MulGroupNormFusion;
class TRANSFORMATIONS_API MulMulAddFusion;
class TRANSFORMATIONS_API CropTransformation;
class TRANSFORMATIONS_API MulGroupNormTransformation;
class TRANSFORMATIONS_API MulMulAddTransformation;
class TRANSFORMATIONS_API SplitTransformation;
class TRANSFORMATIONS_API ReshapeTransformation;
class TRANSFORMATIONS_API MulMulMulTransformation;
class TRANSFORMATIONS_API MulMVNTransformation;
Expand All @@ -29,7 +29,10 @@ class TRANSFORMATIONS_API ConcatTransformation;
} // namespace pass
} // namespace ov

// ActivationsScaling scales down activations to prevent overflow due to the limited range of FP16
// ActivationsScaling makes activation values smaller to prevent overflow due to the limited range of FP16
// This feature is controlled by ov::hint::activations_scale_factor.
// For example, when this property is set to 16, activations are divided by 16.
// If ov::hint::activations_scale_factor is less than zero, activations scaling is disabled.
class ov::pass::ActivationsScaling : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ActivationsScaling", "0");
Expand All @@ -46,22 +49,22 @@ class ov::pass::activations_scaling::ScaleDownSingleLayer : public ov::pass::Mat
ScaleDownSingleLayer(float scale_factor);
};

class ov::pass::activations_scaling::MulGroupNormFusion : public ov::pass::MatcherPass {
class ov::pass::activations_scaling::MulGroupNormTransformation : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MulGroupNormFusion", "0");
MulGroupNormFusion();
OPENVINO_RTTI("MulGroupNormTransformation", "0");
MulGroupNormTransformation();
};

class ov::pass::activations_scaling::MulMulAddFusion : public ov::pass::MatcherPass {
class ov::pass::activations_scaling::MulMulAddTransformation : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MulMulAddFusion", "0");
MulMulAddFusion();
OPENVINO_RTTI("MulMulAddTransformation", "0");
MulMulAddTransformation();
};

class ov::pass::activations_scaling::CropTransformation : public ov::pass::MatcherPass {
class ov::pass::activations_scaling::SplitTransformation : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("CropTransformation", "0");
CropTransformation();
OPENVINO_RTTI("SplitTransformation", "0");
SplitTransformation();
};

class ov::pass::activations_scaling::ReshapeTransformation : public ov::pass::MatcherPass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ using namespace ov::pass::pattern;
using ov::pass::pattern::op::Or;

// Add scale_down and scale_up layers around Convolution and MatMul nodes
// Conv/MatMul ==> Multiply(scale_down) --> Conv/MatMul --> Multiply(scale_up)
// Conv/MatMul
// ==>
// Multiply(scale_down by scale_factor) --> Conv/MatMul --> Multiply(scale_up by scale_factor)
ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float scale_factor) {
MATCHER_SCOPE(ScaleDownSingleLayer);

Expand Down Expand Up @@ -137,7 +139,7 @@ ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float
this->register_matcher(m, callback);
}

// MulMulAddFusion makes the target pattern to be easy to be merged with other nodes.
// MulMulAddTransformation makes the target pattern easier to merge with the following nodes.
//
// input_a const_a input_b const_b input_a (const_a/const_b)
// \ / \ / \ /
Expand All @@ -148,8 +150,8 @@ ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float
// Add Multiply_b_mma
//
// (input_a * const_a) + (input_b * const_b) ==> ((input_a * (const_a / const_b)) + input_b) * const_b
ov::pass::activations_scaling::MulMulAddFusion::MulMulAddFusion() {
MATCHER_SCOPE(MulMulAddFusion);
ov::pass::activations_scaling::MulMulAddTransformation::MulMulAddTransformation() {
MATCHER_SCOPE(MulMulAddTransformation);

auto activation0_m = any_input(is_non_const_node);
auto scale_const0_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(is_scalar_node);
Expand Down Expand Up @@ -204,7 +206,7 @@ ov::pass::activations_scaling::MulMulAddFusion::MulMulAddFusion() {
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(add_m, "MulMulAddFusion");
auto m = std::make_shared<ov::pass::pattern::Matcher>(add_m, "MulMulAddTransformation");
this->register_matcher(m, callback);
}

Expand All @@ -214,9 +216,11 @@ ov::pass::activations_scaling::MulMulAddFusion::MulMulAddFusion() {
//
// So, we can skip Multiply that is connected to GroupNormalization.
//
// input --> Multiply --> GroupNormalization ==> input --> GroupNormalization
ov::pass::activations_scaling::MulGroupNormFusion::MulGroupNormFusion() {
MATCHER_SCOPE(MulGroupNormFusion);
// input --> Multiply --> GroupNormalization
// ==>
// input --> GroupNormalization
ov::pass::activations_scaling::MulGroupNormTransformation::MulGroupNormTransformation() {
MATCHER_SCOPE(MulGroupNormTransformation);

auto activation_m = any_input(is_non_const_node);
auto scale_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(is_scalar_node);
Expand All @@ -239,13 +243,14 @@ ov::pass::activations_scaling::MulGroupNormFusion::MulGroupNormFusion() {
}

if (mul && norm) {
norm->input(0).replace_source_output(mul->get_input_source_output(0));
size_t activation_index = ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
norm->input(0).replace_source_output(mul->get_input_source_output(activation_index));
return true;
}
return false;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(norm_m, "MulGroupNormFusion");
auto m = std::make_shared<ov::pass::pattern::Matcher>(norm_m, "MulGroupNormTransformation");
this->register_matcher(m, callback);
}

Expand All @@ -255,7 +260,9 @@ ov::pass::activations_scaling::MulGroupNormFusion::MulGroupNormFusion() {
//
// So, we can skip Multiply that is connected to MVN.
//
// input --> Multiply --> MVN ==> input --> MVN
// input --> Multiply --> MVN
// ==>
// input --> MVN
ov::pass::activations_scaling::MulMVNTransformation::MulMVNTransformation() {
MATCHER_SCOPE(MulMVNTransformation);

Expand All @@ -279,7 +286,8 @@ ov::pass::activations_scaling::MulMVNTransformation::MulMVNTransformation() {
}

if (mul && norm) {
norm->input(0).replace_source_output(mul->get_input_source_output(0));
size_t activation_index = ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
norm->input(0).replace_source_output(mul->get_input_source_output(activation_index));
return true;
}
return false;
Expand All @@ -289,8 +297,16 @@ ov::pass::activations_scaling::MulMVNTransformation::MulMVNTransformation() {
this->register_matcher(m, callback);
}

ov::pass::activations_scaling::CropTransformation::CropTransformation() {
MATCHER_SCOPE(CropTransformation);
// input const input
// \ / |
// Multiply ==> VariadicSplit
// | const / | const \ const
// VariadicSplit | / | / \ /
// / | \ Multiply_a Multiply_b Multiply_c
// output_a output_b output_c | | |
// output_a output_b output_c
ov::pass::activations_scaling::SplitTransformation::SplitTransformation() {
MATCHER_SCOPE(SplitTransformation);

auto activation_m = any_input(is_non_const_node);
auto scale_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(is_scalar_node);
Expand Down Expand Up @@ -321,12 +337,14 @@ ov::pass::activations_scaling::CropTransformation::CropTransformation() {
target_inputs[i] = split->get_output_target_inputs(i);
}

split->input(0).replace_source_output(mul->input(0).get_source_output());
size_t activation_index = ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
size_t const_index = (activation_index == 1) ? 0 : 1;
split->input(0).replace_source_output(mul->input(activation_index).get_source_output());

for (size_t i = 0; i < num_split_outputs; i++) {
auto new_mul = register_new_node<ov::op::v1::Multiply>(
split->output(i),
mul->input(1).get_source_output());
mul->input(const_index).get_source_output());
new_mul->set_friendly_name(mul->get_friendly_name() + "_" + std::to_string(i));
ov::copy_runtime_info(mul, new_mul);

Expand All @@ -340,10 +358,15 @@ ov::pass::activations_scaling::CropTransformation::CropTransformation() {
return false;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(split_m, "CropTransformation");
auto m = std::make_shared<ov::pass::pattern::Matcher>(split_m, "SplitTransformation");
this->register_matcher(m, callback);
}

// input const input
// \ / |
// Multiply ==> Reshape const
// | | /
// Reshape Multiply
ov::pass::activations_scaling::ReshapeTransformation::ReshapeTransformation() {
MATCHER_SCOPE(ReshapeTransformation);

Expand Down Expand Up @@ -389,7 +412,7 @@ ov::pass::activations_scaling::ReshapeTransformation::ReshapeTransformation() {
this->register_matcher(m, callback);
}

// MulMulAddFusion makes the target pattern to be easy to be merged with other nodes.
// MulMulMulTransformation makes the target pattern easier to merge with other nodes.
//
// input_a const_a input_b const_b input_a input_b
// \ / \ / \ /
Expand Down Expand Up @@ -456,6 +479,23 @@ ov::pass::activations_scaling::MulMulMulTransformation::MulMulMulTransformation(
this->register_matcher(m, callback);
}

// input_a const_a input_b const_b input_c const_c
// \ / \ / \ /
// Multiply_a Multiply_b Multiply_c
// \ | /
// \ | /
// ---------- Concat ------------
// ==>
// (const_a (const_b (const_c
// input_a /const_c) input_b /const_c) input_c /const_c)
// \ / \ / \ /
// Multiply_a Multiply_b Multiply_c
// \ | /
// \ | /
// ---------- Concat ------------
// | const_c
// | /
// Multiply
ov::pass::activations_scaling::ConcatTransformation::ConcatTransformation() {
MATCHER_SCOPE(ConcatTransformation);

Expand All @@ -473,30 +513,23 @@ ov::pass::activations_scaling::ConcatTransformation::ConcatTransformation() {
}

// check if all inputs are Multiply with scalar operand
bool can_be_transformed = true;
ov::Output<ov::Node> last_dep_const;
for (auto &input : concat->inputs()) {
auto dep_node = std::dynamic_pointer_cast<ov::op::v1::Multiply>(input.get_source_output().get_node_shared_ptr());
if (!dep_node) {
can_be_transformed = false;
break;
return false;
}
auto dep_const0 = std::dynamic_pointer_cast<ov::op::v0::Constant>(dep_node->input(0).get_source_output().get_node_shared_ptr());
auto dep_const1 = std::dynamic_pointer_cast<ov::op::v0::Constant>(dep_node->input(1).get_source_output().get_node_shared_ptr());
if (!dep_const0 && !dep_const1) {
can_be_transformed = false;
break;
return false;
}
last_dep_const = dep_const0 ? dep_node->input(0).get_source_output() : dep_node->input(1).get_source_output();
if (!is_scalar_node(last_dep_const)) {
can_be_transformed = false;
break;
return false;
}
}

if (!can_be_transformed)
return false;

auto target_inputs = concat->get_output_target_inputs(0);

for (auto &input : concat->inputs()) {
Expand Down Expand Up @@ -540,17 +573,17 @@ bool ov::pass::ActivationsScaling::run_on_model(const std::shared_ptr<ov::Model>

manager.register_pass<ScaleDownSingleLayer>(m_scale_factor);
manager.register_pass<LinOpSequenceFusion>();
manager.register_pass<MulGroupNormFusion>();
manager.register_pass<MulGroupNormTransformation>();
manager.register_pass<LinOpSequenceFusion>();
manager.register_pass<MulMulAddFusion>();
manager.register_pass<MulGroupNormFusion>();
manager.register_pass<CropTransformation>();
manager.register_pass<MulMulAddTransformation>();
manager.register_pass<MulGroupNormTransformation>();
manager.register_pass<SplitTransformation>();
manager.register_pass<ReshapeTransformation>();
manager.register_pass<MulMulMulTransformation>();
manager.register_pass<MulMulAddFusion>();
manager.register_pass<MulMulAddTransformation>();
manager.register_pass<ConcatTransformation>();
manager.register_pass<MulMVNTransformation>();
manager.register_pass<MulMulAddFusion>();
manager.register_pass<MulMulAddTransformation>();
manager.register_pass<MulMVNTransformation>();

manager.run_passes(f);
Expand Down
Loading

0 comments on commit ebca03d

Please sign in to comment.