Skip to content

Commit

Permalink
Merge branch 'river/gpu_p2p_support' into river/gpu_p2p_support_rebase_master
Browse files Browse the repository at this point in the history
  • Loading branch information
riverlijunjie committed Sep 19, 2024
2 parents c29cea2 + 29d6c82 commit 4dfbe14
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class RemainFCParallelFusion: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RemainFCParallelFusion", "0");
RemainFCParallelFusion(size_t world_size, size_t world_rank);
std::shared_ptr<ov::Node> find_first_fc_after_multiply(std::shared_ptr<ov::Node> root_node);
};

} // namespace intel_gpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@
#include "openvino/op/add.hpp"
namespace ov {
namespace intel_gpu {
/// Follows the chain of single consumers that starts at `root_node` and returns the first
/// FullyConnected or FullyConnectedCompressed node reached.
///
/// `root_node` itself is never type-checked; only the nodes downstream of it are.
///
/// @param root_node Node whose consumers are searched; may be null.
/// @return The first fully-connected node on the chain, or nullptr when
///         - `root_node` is null,
///         - some node on the chain does not have exactly one user, or
///         - a Result or PagedAttentionExtension node is reached first.
std::shared_ptr<ov::Node> RemainFCParallelFusion::find_first_fc_after_multiply(std::shared_ptr<ov::Node> root_node) {
    // Walk iteratively so the stack depth does not grow with the length of the chain.
    std::shared_ptr<ov::Node> cur_node = root_node;
    while (cur_node) {
        const auto users = cur_node->get_users();
        // Only an unambiguous, non-branching chain is followed.
        if (users.size() != 1) {
            return nullptr;
        }
        cur_node = users.front();

        // Stop at nodes the search must not cross.
        if (ov::is_type<ov::op::v0::Result>(cur_node) || ov::is_type<ov::op::PagedAttentionExtension>(cur_node)) {
            return nullptr;
        }

        if (ov::is_type<ov::intel_gpu::op::FullyConnected>(cur_node) ||
            ov::is_type<ov::intel_gpu::op::FullyConnectedCompressed>(cur_node)) {
            return cur_node;
        }
    }
    return nullptr;
}

RemainFCParallelFusion::RemainFCParallelFusion(size_t world_size, size_t world_rank) {
using namespace ov::pass::pattern;
auto data = any_input();
Expand Down Expand Up @@ -88,6 +111,7 @@ RemainFCParallelFusion::RemainFCParallelFusion(size_t world_size, size_t world_r
if (compressed_fc) {
auto scale_node = compressed_fc->get_input_node_shared_ptr(3);
if (tp_mode == op::TP_MODE::ALL_REDUCE) {
ranked_scale = scale_node;
if (scale_node->get_shape()[1] > 1)
ranked_scale =
std::make_shared<ov::intel_gpu::op::RankConstant>(scale_node, world_size, world_rank, tp_mode);
Expand Down Expand Up @@ -139,73 +163,209 @@ RemainFCParallelFusion::RemainFCParallelFusion(size_t world_size, size_t world_r
}
};
{
// Maps an output port to the node that owns it.
auto get_output_node = [](const ov::Output<ov::Node>& port) -> std::shared_ptr<ov::Node> {
    auto owner = port.get_node_shared_ptr();
    return owner;
};
// Maps an input port to the node that owns the output feeding it.
auto get_input_node = [&get_output_node](const ov::Input<ov::Node>& port) -> std::shared_ptr<ov::Node> {
    const auto source = port.get_source_output();
    return get_output_node(source);
};
// auto print_shape = [&](const std::shared_ptr<ov::Node>& m_data) {
// std::cout << m_data->get_friendly_name() << ": '";
// for (size_t shape_id = 0; shape_id < m_data->get_output_partial_shape(0).size(); shape_id++) {
// if (!m_data->get_output_partial_shape(0)[shape_id].is_dynamic()) {
// int64_t len = m_data->get_output_partial_shape(0)[shape_id].get_length();
// std::cout << len << ", ";
// } else {
// std::cout << "?" << ", ";
// }
// }
// std::cout << "'\n";
// };
// Splits `fc_node` in ALL_REDUCE tensor-parallel mode and re-routes every consumer of the original
// node to a SyncTensor node that is fed by the split FC.
//
// fc_node: the fully-connected node to split; its consumers are rewired, the node itself is left
//          in the graph without consumers.
auto fc_after_pa_sync = [&](std::shared_ptr<ov::Node>& fc_node) {
    // Snapshot every (input index, consumer) pair before the graph is touched.
    // A multimap is used because several consumers may read fc_node through the same input index;
    // a std::map keyed by the index would keep only the first of them and leave the others attached
    // to the unsplit node.
    std::multimap<size_t, std::shared_ptr<ov::Node>> org_users;
    for (const auto& u : fc_node->get_users()) {
        for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
            if (u->get_input_node_shared_ptr(idx) == fc_node) {
                org_users.insert({idx, u});
            }
        }
    }

    auto new_fc = split_fc(fc_node, op::TP_MODE::ALL_REDUCE).first;
    // Mark the split node so the matcher callback does not process it again.
    new_fc->get_rt_info().insert({"splitted", true});

    // The last dimension of input 1 is taken with back() so the code does not depend on
    // negative-index support in the shape's operator[].
    // NOTE(review): input 1 is presumably the weights tensor - confirm against split_fc.
    auto sync_node =
        std::make_shared<ov::intel_gpu::op::SyncTensor>(new_fc,
                                                        world_size,
                                                        world_rank,
                                                        fc_node->get_input_node_shared_ptr(1)->get_shape().back(),
                                                        fc_node->get_element_type(),
                                                        ov::intel_gpu::op::TP_MODE::ALL_REDUCE);
    sync_node->set_friendly_name(fc_node->get_friendly_name() + "_TP");
    copy_runtime_info(fc_node, new_fc);

    // Point every recorded consumer input at the synchronised result.
    for (auto& iter : org_users) {
        iter.second->input(iter.first).replace_source_output(sync_node->output(0));
    }
    fc_node->clear_control_dependencies();
};
if (m_fc->get_rt_info().find("splitted") != m_fc->get_rt_info().end()) {
if (m_fc->get_rt_info()["splitted"].as<bool>()) {
return false;
}
}
#if 0
// std::cout << "m_fc->get_friendly_name(): " << m_fc->get_friendly_name() << std::endl;
// some accuracy lost, disable for now
if (m_fc->get_friendly_name().find("mlp.down_proj") != std::string::npos)
return false;
#endif
auto splitted_context = split_fc(m_fc, op::TP_MODE::ALL_GATHERH);
auto new_fc = splitted_context.first;
new_fc->set_friendly_name(m_fc->get_friendly_name());
copy_runtime_info(m_fc, new_fc);
replace_node(m_fc, new_fc);

if (new_fc->get_users().size() == 1) {
if (m_fc->get_friendly_name().find("mlp.gate_proj") != std::string::npos) {
auto splitted_context = split_fc(m_fc, op::TP_MODE::ALL_GATHERH);
auto new_fc = splitted_context.first;
new_fc->set_friendly_name(m_fc->get_friendly_name());
copy_runtime_info(m_fc, new_fc);
replace_node(m_fc, new_fc);

// if (new_fc->get_users().size() == 1) {
// for (auto& iter : new_fc->get_users()) {
// if (ov::is_type<ov::op::v1::Multiply>(iter)) {
// // return true;
// std::shared_ptr<ov::Node> first_fc_after_pa = find_first_fc_after_multiply(new_fc);
// if (first_fc_after_pa != nullptr) {
// std::cout << "first_fc_after_pa: " << first_fc_after_pa->get_friendly_name() << std::endl;
// fc_after_pa_sync(first_fc_after_pa);
// }
// }
// }
// }
std::shared_ptr<ov::op::v4::Swish> activation;
std::shared_ptr<ov::op::v1::Multiply> eltwise_node;
//bool elwise_flag = false;
for (auto& iter : new_fc->get_users()) {
if (ov::is_type<ov::op::v1::Multiply>(iter))
return true;
}
}
std::shared_ptr<ov::op::v4::Swish> activation;
std::shared_ptr<ov::op::v1::Multiply> eltwise_node;
//bool elwise_flag = false;
for (auto& iter : new_fc->get_users()) {
if (ov::is_type<ov::op::v4::Swish>(iter)) {
activation = std::dynamic_pointer_cast<ov::op::v4::Swish>(iter);
if (activation->get_users().size() == 1) {
for (auto& iter2 : activation->get_users())
if (ov::is_type<ov::op::v1::Multiply>(iter2))
eltwise_node = std::dynamic_pointer_cast<ov::op::v1::Multiply>(iter2);
}
}
}
{
std::map<int, std::shared_ptr<ov::Node>> org_users;
auto node_to_operate = eltwise_node ? eltwise_node : activation ? activation : new_fc;
for (auto u : node_to_operate->get_users()) {
for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
if (u->get_input_node_shared_ptr(idx) == node_to_operate) {
org_users.insert({idx, u});
if (ov::is_type<ov::op::v4::Swish>(iter)) {
activation = std::dynamic_pointer_cast<ov::op::v4::Swish>(iter);
// print_shape(activation);
// print_shape(activation->get_input_node_shared_ptr(0));
// print_shape(new_fc);
auto new_swish = std::make_shared<ov::op::v4::Swish>(activation->get_input_source_output(0));
new_swish->set_friendly_name(activation->get_friendly_name());
copy_runtime_info(activation, new_swish);
replace_node(activation, new_swish);
// print_shape(new_swish);


if (new_swish->get_users().size() == 1) {
for (auto& iter2 : new_swish->get_users())
if (ov::is_type<ov::op::v1::Multiply>(iter2)) {
eltwise_node = std::dynamic_pointer_cast<ov::op::v1::Multiply>(iter2);
// std::cout << eltwise_node->get_friendly_name() << std::endl;
// print_shape(eltwise_node);
auto up_node = get_input_node(eltwise_node->inputs()[1]);

std::map<int, std::shared_ptr<ov::Node>> org_users;
for (auto u : up_node->get_users()) {
for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
if (u->get_input_node_shared_ptr(idx) == up_node) {
org_users.insert({idx, u});
}
}
}

// std::cout << up_node->get_friendly_name() << std::endl;
// print_shape(up_node);
auto splitted_context = split_fc(up_node, op::TP_MODE::ALL_GATHERH);
auto new_up = splitted_context.first;
new_up->set_friendly_name(up_node->get_friendly_name());
copy_runtime_info(up_node, new_up);
// replace_node(up_node, new_up);
for (auto& iter : org_users) {
iter.second->input(iter.first).replace_source_output(new_up->output(0));
}
// print_shape(new_up);
// std::cout << "**********************************\n";

auto new_multiply = std::make_shared<ov::op::v1::Multiply>(
eltwise_node->get_input_source_output(0),
eltwise_node->get_input_source_output(1));
new_multiply->set_friendly_name(eltwise_node->get_friendly_name());
copy_runtime_info(eltwise_node, new_multiply);
replace_node(eltwise_node, new_multiply);
// print_shape(new_multiply);
// print_shape(new_multiply->get_input_node_shared_ptr(0));
// print_shape(new_multiply->get_input_node_shared_ptr(1));


std::shared_ptr<ov::Node> first_fc_after_pa = find_first_fc_after_multiply(new_multiply);
if (first_fc_after_pa != nullptr) {
// std::cout << "first_fc_after_pa: " << first_fc_after_pa->get_friendly_name() << std::endl;
fc_after_pa_sync(first_fc_after_pa);
}
}
}
}
}
std::shared_ptr<ov::intel_gpu::op::SyncTensor> sync_node;
sync_node = std::make_shared<ov::intel_gpu::op::SyncTensor>(node_to_operate,
world_size,
world_rank,
splitted_context.second,
new_fc->get_element_type());
sync_node->set_friendly_name(new_fc->get_friendly_name()+ "_TP_remain");
if (sync_node->get_gpu_p2p_enabled()) {
copy_runtime_info(new_fc, sync_node);
for (auto& iter : org_users) {
iter.second->input(iter.first).replace_source_output(sync_node->output(0));
}
} else {
auto concat_node = std::make_shared<ov::op::v0::Concat>(sync_node->outputs(), -1);
concat_node->set_friendly_name(new_fc->get_friendly_name() + "_ALLGATHER");
copy_runtime_info(new_fc, concat_node);
for (auto& iter : org_users) {
iter.second->input(iter.first).replace_source_output(concat_node->output(0));
}
}
new_fc->clear_control_dependencies();
}
// auto splitted_context = split_fc(m_fc, op::TP_MODE::ALL_GATHERH);
// auto new_fc = splitted_context.first;
// new_fc->set_friendly_name(m_fc->get_friendly_name());
// copy_runtime_info(m_fc, new_fc);
// replace_node(m_fc, new_fc);

// // if (new_fc->get_users().size() == 1) {
// // for (auto& iter : new_fc->get_users()) {
// // if (ov::is_type<ov::op::v1::Multiply>(iter)) {
// // // return true;
// // std::shared_ptr<ov::Node> first_fc_after_pa = find_first_fc_after_multiply(new_fc);
// // if (first_fc_after_pa != nullptr) {
// // std::cout << "first_fc_after_pa: " << first_fc_after_pa->get_friendly_name() << std::endl;
// // fc_after_pa_sync(first_fc_after_pa);
// // }
// // }
// // }
// // }
// std::shared_ptr<ov::op::v4::Swish> activation;
// std::shared_ptr<ov::op::v1::Multiply> eltwise_node;
// //bool elwise_flag = false;
// for (auto& iter : new_fc->get_users()) {
// if (ov::is_type<ov::op::v4::Swish>(iter)) {
// activation = std::dynamic_pointer_cast<ov::op::v4::Swish>(iter);
// if (activation->get_users().size() == 1) {
// for (auto& iter2 : activation->get_users())
// if (ov::is_type<ov::op::v1::Multiply>(iter2)) {
// eltwise_node = std::dynamic_pointer_cast<ov::op::v1::Multiply>(iter2);
// std::cout << eltwise_node->get_friendly_name() << std::endl;
// }
// }
// }
// }
// {
// std::map<int, std::shared_ptr<ov::Node>> org_users;
// auto node_to_operate = eltwise_node ? eltwise_node : activation ? activation : new_fc;
// for (auto u : node_to_operate->get_users()) {
// for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
// if (u->get_input_node_shared_ptr(idx) == node_to_operate) {
// org_users.insert({idx, u});
// }
// }
// }
// std::shared_ptr<ov::Node> first_fc_after_pa = find_first_fc_after_multiply(node_to_operate);
// if (first_fc_after_pa != nullptr) {
// std::cout << "first_fc_after_pa: " << first_fc_after_pa->get_friendly_name() << std::endl;
// fc_after_pa_sync(first_fc_after_pa);
// }

// std::shared_ptr<ov::intel_gpu::op::SyncTensor> sync_node;
// sync_node = std::make_shared<ov::intel_gpu::op::SyncTensor>(node_to_operate, world_size, splitted_context.second,
// new_fc->get_element_type());
// sync_node->set_friendly_name(new_fc->get_friendly_name()+ "_TP");

// auto concat_node = std::make_shared<ov::op::v0::Concat>(sync_node->outputs(), -1);
// concat_node->set_friendly_name(new_fc->get_friendly_name()+ "_ALLGATHER");
// copy_runtime_info(new_fc, concat_node);
// for (auto& iter : org_users) {
// iter.second->input(iter.first).replace_source_output(concat_node->output(0));
// }
// new_fc->clear_control_dependencies();
// }
}
return true;
};
Expand Down

0 comments on commit 4dfbe14

Please sign in to comment.