[XPU] Migrate xpu_embedding_with_eltwise_add_fuse_pass (#50590)
csy0225 authored Feb 23, 2023
1 parent d7673e2 commit 8d325d8
Showing 13 changed files with 655 additions and 42 deletions.
paddle/fluid/framework/ir/CMakeLists.txt (1 change: 1 addition & 0 deletions)
@@ -221,6 +221,7 @@ if(WITH_XPU)
     SRCS xpu/pass_utils.cc
     DEPS pass)
   set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
+  pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu)
   pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
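The new pass_library entry compiles the migrated pass into the XPU inference pass set, next to fc_xpu_fuse_pass and multi_encoder_xpu_fuse_pass. The pass source itself (embedding_with_eltwise_add_xpu_fuse_pass.cc, one of the 13 changed files not rendered in this view) is not shown above, so the following is only a hypothetical skeleton of how such a pass is conventionally declared and registered in Paddle; the class name, includes, and the comment about the fused op are illustrative assumptions, not the commit's actual code.

// Hypothetical skeleton only; follows the convention of the neighbouring XPU passes.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class EmbeddingWithEltwiseAddXPUFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override {
    FusePassBase::Init("embedding_with_eltwise_add_xpu_fuse_pass", graph);
    // Match lookup_table(_v2) + elementwise_add subgraphs and replace them
    // with a single fused XPU op (details omitted in this sketch).
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(embedding_with_eltwise_add_xpu_fuse_pass,
              paddle::framework::ir::EmbeddingWithEltwiseAddXPUFusePass);

REGISTER_PASS is what makes a pass retrievable from the registry by name, which is how inference pass lists refer to it.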
paddle/fluid/framework/ir/delete_dropout_op_pass.cc (74 changes: 39 additions & 35 deletions)
@@ -30,46 +30,50 @@ namespace ir {
 void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_dropout_op_pattern";
   FusePassBase::Init(pattern_name, graph);
+  int found_subgraph_count = 0;

-  GraphPatternDetector gpd;
-  patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(),
-                                           pattern_name);
-  pattern();
-
-  int found_subgraph_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    GET_IR_NODE(dropout_op_x);
-    GET_IR_NODE(dropout_op);
-    GET_IR_NODE(dropout_op_out);
-    GET_IR_NODE(dropout_op_mask);
-
-    // link dropout_op_out to pre_op
-    auto dropout_op_x_name = dropout_op_x->Var()->Name();
-    auto dropout_op_out_name = dropout_op_out->Var()->Name();
-    auto pre_ops = dropout_op_x->inputs;
-    if (pre_ops.empty()) return;
-    auto pre_op_desc = pre_ops[0]->Op();
-    auto pre_op_outs = pre_op_desc->Outputs();
-    for (auto& out_var : pre_op_outs) {
-      auto names = out_var.second;
-      for (size_t i = 0; i < names.size(); i++) {
-        if (names[i] == dropout_op_x_name) {
-          names[i] = dropout_op_out_name;
-          pre_op_desc->SetOutput(out_var.first, names);
-          break;
-        }
-      }
-    }
-    IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);
-
-    // delete useless node
-    std::unordered_set<const Node*> delete_nodes{
-        dropout_op_x, dropout_op, dropout_op_mask};
-    GraphSafeRemoveNodes(graph, delete_nodes);
-    found_subgraph_count++;
-  };
-
-  gpd(graph, handler);
+  for (auto with_mask : {true, false}) {
+    GraphPatternDetector gpd;
+    patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(),
+                                             pattern_name);
+    pattern(with_mask);
+
+    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                       Graph* g) {
+      GET_IR_NODE(dropout_op_x);
+      GET_IR_NODE(dropout_op);
+      GET_IR_NODE(dropout_op_out);
+
+      // link dropout_op_x to next_op
+      auto dropout_op_x_name = dropout_op_x->Var()->Name();
+      auto dropout_op_out_name = dropout_op_out->Var()->Name();
+      auto next_op_nodes = dropout_op_out->outputs;
+      for (auto next_op_node : next_op_nodes) {
+        auto next_op_desc = next_op_node->Op();
+        auto next_op_inputs = next_op_desc->Inputs();
+        for (auto& input_var : next_op_inputs) {
+          auto names = input_var.second;
+          for (size_t i = 0; i < names.size(); i++) {
+            if (names[i] == dropout_op_out_name) {
+              names[i] = dropout_op_x_name;
+              next_op_desc->SetInput(input_var.first, names);
+              break;
+            }
+          }
+        }
+        IR_NODE_LINK_TO(dropout_op_x, next_op_node);
+      }
+
+      // delete useless node
+      std::unordered_set<const Node*> delete_nodes{dropout_op, dropout_op_out};
+      if (with_mask) {
+        GET_IR_NODE(dropout_op_mask);
+        delete_nodes.insert(dropout_op_mask);
+      }
+      GraphSafeRemoveNodes(graph, delete_nodes);
+      found_subgraph_count++;
+    };
+    gpd(graph, handler);
+  }
   AddStatis(found_subgraph_count);
 }
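The rewritten handler removes dropout by rewiring its consumers rather than its producer: every downstream op that read dropout's Out is re-pointed at dropout's X, after which the dropout op and its Out variable can be deleted. Wrapping detection in a for (auto with_mask : {true, false}) loop also lets the pass match dropout ops that have no Mask output, which the old single pattern (which required Mask) would have skipped. Below is a minimal sketch of driving the pass over a graph in the style of Paddle's pass unit tests; the helper name, the example op layout, and the registered pass string "delete_dropout_op_pass" are assumptions for illustration.

// Sketch only: assumes the pass is registered as "delete_dropout_op_pass".
#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {
namespace ir {

void RunDeleteDropout(ProgramDesc* prog) {
  // Assume *prog contains, e.g., relu -> dropout -> fc, where dropout has
  // dropout_implementation="upscale_in_train" and may or may not emit Mask.
  auto graph = std::make_unique<Graph>(*prog);
  auto pass = PassRegistry::Instance().Get("delete_dropout_op_pass");
  graph.reset(pass->Apply(graph.release()));
  // Afterwards fc reads relu's output directly; the dropout op, its Out
  // variable, and (if present) its Mask variable are gone from the graph.
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Pull in the pass registration when linking this sketch as a test binary.
USE_PASS(delete_dropout_op_pass);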
paddle/fluid/framework/ir/graph_pattern_detector.cc (14 changes: 9 additions & 5 deletions)
@@ -3032,7 +3032,7 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
   return concat_out;
 }

-void patterns::DeleteDropoutOpPattern::operator()() {
+void patterns::DeleteDropoutOpPattern::operator()(bool with_mask) {
   auto dropout_op_x = pattern->NewNode(dropout_op_x_repr())
                           ->assert_is_op_input("dropout", "X")
                           ->AsInput();
@@ -3042,10 +3042,14 @@ void patterns::DeleteDropoutOpPattern::operator()() {
                                          std::string("upscale_in_train"));
   auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
                             ->assert_is_op_output("dropout", "Out");
-  auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
-                             ->assert_is_op_output("dropout", "Mask");
-  dropout_op->LinksFrom({dropout_op_x})
-      .LinksTo({dropout_op_out, dropout_op_mask});
+  if (with_mask) {
+    auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
+                               ->assert_is_op_output("dropout", "Mask");
+    dropout_op->LinksFrom({dropout_op_x})
+        .LinksTo({dropout_op_out, dropout_op_mask});
+  } else {
+    dropout_op->LinksFrom({dropout_op_x}).LinksTo({dropout_op_out});
+  }
 }

 void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
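With the new boolean parameter, one pattern class describes a dropout that either does or does not expose a Mask output; the assertion on dropout_implementation == "upscale_in_train" is unchanged. A comment-only illustration of the two subgraph shapes the pattern now covers, with the call sites taken from the pass above:

// Illustration only: the two dropout topologies matched by the pattern.
//
//   pattern(true):                pattern(false):
//     X -> dropout -> Out           X -> dropout -> Out
//               \--> Mask
//
// The pass builds one detector per variant, as shown earlier:
//   patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), name);
//   pattern(/*with_mask=*/true);   // the second detection round passes false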
paddle/fluid/framework/ir/graph_pattern_detector.h (2 changes: 1 addition & 1 deletion)
@@ -1759,7 +1759,7 @@ struct DeleteDropoutOpPattern : public PatternBase {
   DeleteDropoutOpPattern(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "delete_dropout_op_pattern") {}

-  void operator()();
+  void operator()(bool with_mask);

   PATTERN_DECL_NODE(dropout_op_x);
   PATTERN_DECL_NODE(dropout_op);