Merge pull request #1 from andrew-k-park/enable_fc_horizontal_fusion
[GPU] Enable FullyConnectedHorizontalFusion with ActivationsScaling
e-ddykim authored Nov 6, 2024
2 parents cc4b37f + 4fa090f commit 46b17ca
Showing 2 changed files with 118 additions and 3 deletions.
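Context for the change: FullyConnectedHorizontalFusion in the GPU plugin fuses sibling fully-connected (MatMul) layers that share one activation input, and a per-layer scale-down inserted by ScaleDownSingleLayer would break that shared-input pattern. The sketch below is illustrative only and is not part of the commit (shapes, names, and weight values are invented); it simply builds the kind of sibling-MatMul subgraph that the new ScaleDownMultipleLayers pass targets with a single shared scale-down Multiply.

// Illustrative only: a minimal ov::Model in which one activation feeds two MatMuls,
// i.e. the subgraph shape handled by ScaleDownMultipleLayers (all shapes and values
// here are made up for the example).
#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"

std::shared_ptr<ov::Model> make_sibling_matmul_model() {
    auto act = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1, 128, 512});
    auto w_q = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{512, 512},
                                            std::vector<float>(512 * 512, 0.01f));
    auto w_k = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{512, 512},
                                            std::vector<float>(512 * 512, 0.01f));
    // Both MatMuls consume the same activation; scaling each one separately would
    // insert two different Multiply nodes between them and their shared input.
    auto q = std::make_shared<ov::op::v0::MatMul>(act, w_q);
    auto k = std::make_shared<ov::op::v0::MatMul>(act, w_k);
    return std::make_shared<ov::Model>(ov::OutputVector{q->output(0), k->output(0)},
                                       ov::ParameterVector{act});
}

On such a graph every MatMul sees an activation with more than one MatMul user, so the is_single_matmul predicate added below rejects it and the new ScaleDownMultipleLayers pass rewrites the whole group at once.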
File 1 of 2 — activations scaling transformations header (declarations):
@@ -17,6 +17,7 @@ class TRANSFORMATIONS_API ActivationsScaling;
namespace activations_scaling {

class TRANSFORMATIONS_API ScaleDownSingleLayer;
class TRANSFORMATIONS_API ScaleDownMultipleLayers;
class TRANSFORMATIONS_API MulGroupNormTransformation;
class TRANSFORMATIONS_API MulMulAddTransformation;
class TRANSFORMATIONS_API SplitTransformation;
@@ -49,6 +50,12 @@ class ov::pass::activations_scaling::ScaleDownSingleLayer : public ov::pass::MatcherPass {
    ScaleDownSingleLayer(float scale_factor);
};

class ov::pass::activations_scaling::ScaleDownMultipleLayers : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ScaleDownMultipleLayers", "0");
    ScaleDownMultipleLayers(float scale_factor);
};

class ov::pass::activations_scaling::MulGroupNormTransformation : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("MulGroupNormTransformation", "0");
File 2 of 2 — activations scaling transformations implementation:
@@ -20,6 +20,7 @@
#include "openvino/op/reshape.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -59,10 +60,23 @@ using ov::pass::pattern::op::Or;
ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float scale_factor) {
    MATCHER_SCOPE(ScaleDownSingleLayer);

    // Matches a MatMul whose activation input feeds exactly one MatMul; activations
    // consumed by several MatMuls are skipped here and handled by ScaleDownMultipleLayers.
    auto is_single_matmul = [](const Output<Node>& output) {
        auto matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(output.get_node_shared_ptr());
        auto input = matmul->get_input_node_shared_ptr(0);
        size_t user_matmul_count = 0;
        for (const auto& u : input->get_users()) {
            auto matmul_user = std::dynamic_pointer_cast<ov::op::v0::MatMul>(u);
            if (!matmul_user)
                continue;
            user_matmul_count++;
        }
        return user_matmul_count == 1;
    };

    auto activation_m = any_input();
    auto weights_m = any_input();
    auto convolution_m = wrap_type<ov::op::v1::Convolution>({activation_m, weights_m});
    auto matmul_m = wrap_type<ov::op::v0::MatMul>({activation_m, weights_m});
    auto matmul_m = wrap_type<ov::op::v0::MatMul>({activation_m, weights_m}, is_single_matmul);
    auto scaled_op_m = std::make_shared<Or>(OutputVector{convolution_m, matmul_m});

    ov::Shape scale_const_shape = {1};
@@ -139,6 +153,98 @@ ov::pass::activations_scaling::ScaleDownSingleLayer::ScaleDownSingleLayer(float scale_factor) {
    this->register_matcher(m, callback);
}
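
// Overview of the rewrite performed by ScaleDownMultipleLayers below (illustrative
// diagram): every MatMul sharing one activation is scaled through a single common
// scale-down Multiply, and each branch is scaled back up after its MatMul (or after
// the Add that follows it), so the sibling MatMuls keep a shared input and remain
// candidates for FullyConnectedHorizontalFusion.
//
//         input                              input
//        /     \                               |
//   MatMul_0   MatMul_1        ==>      Multiply(1/scale)
//                                        /             \
//                                   MatMul_0         MatMul_1
//                                       |                |
//                                Multiply(scale)   Multiply(scale)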

ov::pass::activations_scaling::ScaleDownMultipleLayers::ScaleDownMultipleLayers(float scale_factor) {
    MATCHER_SCOPE(ScaleDownMultipleLayers);

    // Matches a MatMul whose activation input (not already a Multiply) is shared by
    // two or more MatMuls and by nothing else.
    auto is_multiple_matmuls = [](const Output<Node>& output) {
        auto matmul = std::dynamic_pointer_cast<ov::op::v0::MatMul>(output.get_node_shared_ptr());
        auto input = matmul->get_input_node_shared_ptr(0);
        size_t user_matmul_count = 0;
        for (const auto& u : input->get_users()) {
            auto matmul_user = std::dynamic_pointer_cast<ov::op::v0::MatMul>(u);
            if (!matmul_user)
                continue;
            user_matmul_count++;
        }
        return !ov::is_type<ov::op::v1::Multiply>(input) &&
               input->get_users().size() > 1 &&
               input->get_users().size() == user_matmul_count;
    };

    auto activation_m = any_input();
    auto weights_m = any_input();
    auto scaled_op_m = wrap_type<ov::op::v0::MatMul>({activation_m, weights_m}, is_multiple_matmuls);

    ov::Shape scale_const_shape = {1};
    std::vector<float> scale_down_value = {1.f / scale_factor};
    std::shared_ptr<ov::Node> scale_down_const_f16 =
        std::make_shared<ov::op::v0::Constant>(ov::element::f16, scale_const_shape, scale_down_value);
    std::shared_ptr<ov::Node> scale_down_const_f32 =
        std::make_shared<ov::op::v0::Constant>(ov::element::f32, scale_const_shape, scale_down_value);
    std::vector<float> scale_up_value = {scale_factor};
    std::shared_ptr<ov::Node> scale_up_const_f16 =
        std::make_shared<ov::op::v0::Constant>(ov::element::f16, scale_const_shape, scale_up_value);
    std::shared_ptr<ov::Node> scale_up_const_f32 =
        std::make_shared<ov::op::v0::Constant>(ov::element::f32, scale_const_shape, scale_up_value);

    ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();

        auto scaled_op =
            std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(scaled_op_m).get_node_shared_ptr());
        if (!scaled_op || transformation_callback(scaled_op))
            return false;

        // Insert a single shared scale-down Multiply on the common activation input.
        auto input_node = scaled_op->get_input_node_shared_ptr(0);
        auto scale_down = std::make_shared<ov::op::v1::Multiply>(
            input_node,
            (input_node->get_element_type() == ov::element::f32) ? scale_down_const_f32 : scale_down_const_f16);
        scale_down->set_friendly_name(scaled_op->get_friendly_name() + "_scale_down");
        ov::copy_runtime_info(scaled_op, scale_down);

        for (const auto& u : input_node->get_users()) {
            auto matmul_user = std::dynamic_pointer_cast<ov::op::v0::MatMul>(u);
            if (matmul_user) {
                matmul_user->input(0).replace_source_output(scale_down);
                auto child = matmul_user->get_output_target_inputs(0).begin()->get_node();
                if (matmul_user->get_output_target_inputs(0).size() == 1 && ov::is_type<ov::op::v1::Add>(child)) {
                    // MatMul is followed by a single Add: scale the Add's second input down
                    // as well and scale back up after the Add.
                    auto add = child->shared_from_this();
                    auto target_inputs = add->get_output_target_inputs(0);
                    auto scale_down_bias = std::make_shared<ov::op::v1::Multiply>(
                        add->input(1).get_source_output(),
                        (add->input(1).get_element_type() == ov::element::f32) ? scale_down_const_f32
                                                                               : scale_down_const_f16);
                    scale_down_bias->set_friendly_name(add->get_friendly_name() + "_scale_down");
                    ov::copy_runtime_info(add, scale_down_bias);
                    add->input(1).replace_source_output(scale_down_bias->output(0));

                    auto scale_up = register_new_node<ov::op::v1::Multiply>(
                        add->output(0),
                        (add->output(0).get_element_type() == ov::element::f32) ? scale_up_const_f32
                                                                                : scale_up_const_f16);
                    scale_up->set_friendly_name(matmul_user->get_friendly_name() + "_scale_up");
                    ov::copy_runtime_info(matmul_user, scale_up);
                    for (auto& in : target_inputs) {
                        in.replace_source_output(scale_up);
                    }
                } else {
                    // Otherwise scale back up directly after the MatMul.
                    auto target_inputs = matmul_user->get_output_target_inputs(0);
                    auto scale_up = register_new_node<ov::op::v1::Multiply>(
                        matmul_user->output(0),
                        (matmul_user->output(0).get_element_type() == ov::element::f32) ? scale_up_const_f32
                                                                                        : scale_up_const_f16);
                    scale_up->set_friendly_name(matmul_user->get_friendly_name() + "_scale_up");
                    ov::copy_runtime_info(matmul_user, scale_up);
                    for (auto& in : target_inputs) {
                        in.replace_source_output(scale_up);
                    }
                }
            }
        }
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(scaled_op_m, "ScaleDownMultipleLayers");
    this->register_matcher(m, callback);
}


// MulMulAddTransformation makes the target pattern easier to merge with following nodes.
//
// input_a const_a input_b const_b input_a (const_a/const_b)
@@ -582,9 +688,10 @@ bool ov::pass::ActivationsScaling::run_on_model(const std::shared_ptr<ov::Model>& f) {
    manager.set_per_pass_validation(false);

    manager.register_pass<ScaleDownSingleLayer>(m_scale_factor);
    manager.register_pass<LinOpSequenceFusion>();
    manager.register_pass<ScaleDownMultipleLayers>(m_scale_factor);
    manager.register_pass<MultiplyMultiplyFusion>();
    manager.register_pass<MulGroupNormTransformation>();
    manager.register_pass<LinOpSequenceFusion>();
    manager.register_pass<MultiplyMultiplyFusion>();
    manager.register_pass<MulMulAddTransformation>();
    manager.register_pass<MulGroupNormTransformation>();
    manager.register_pass<SplitTransformation>();
@@ -595,6 +702,7 @@ bool ov::pass::ActivationsScaling::run_on_model(const std::shared_ptr<ov::Model>& f) {
    manager.register_pass<MulMVNTransformation>();
    manager.register_pass<MulMulAddTransformation>();
    manager.register_pass<MulMVNTransformation>();
    manager.register_pass<ConstantFolding>();
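    // Likely rationale for the added ConstantFolding: it collapses the constant-by-constant
    // Multiplys introduced above (for example a bias Constant times a scale-down Constant)
    // back into single Constants.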

    manager.run_passes(f);

