Skip to content

Commit

Permalink
[TensorRT] Add inputs of the very first ConcatOp to exclude list.
Browse files Browse the repository at this point in the history
Signed-off-by: 泊霆 <[email protected]>
  • Loading branch information
Mesilenceki committed Jan 22, 2024
1 parent 0eaa93e commit 92547b0
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,57 @@ Status BuildNodeMap(const Graph& graph,
return Status::OK();
}

Status FindExtraConcatInput(const Graph& graph,
const std::vector<std::string>& input_output_names,
std::vector<const Node*>* filter_concat_node) {
std::unordered_set<const Node*> candidate_node;
std::unordered_set<Node*> concat_nodes;
for (auto* node : graph.nodes()) {
if (node->type_string() == "ConcatV2") {
concat_nodes.insert(node);
}
}
std::unordered_set<std::string> in_out_names;
for (auto& name : input_output_names) {
in_out_names.insert(name);
}
for (const Node* c_nodes : concat_nodes) {
std::vector<const Node*> in_placeholder;
ReverseDFSFrom(
graph, {c_nodes},
[&in_placeholder, in_out_names](const Node* node) {
if (in_out_names.find(node->name()) != in_out_names.end()) {
in_placeholder.emplace_back(node);
}
},
/*end*/ nullptr);
if (in_placeholder.size() > 1) { // verify node in common sub-graph
DataType t_types;
TF_RETURN_IF_ERROR(GetNodeAttr(c_nodes->attrs(), "T", &t_types));
if (t_types == DT_FLOAT) {
candidate_node.insert(c_nodes);
}
}
}

for (const Node* cnode : candidate_node) {
bool is_admit = true;
ReverseDFSFrom(graph, {cnode},
[&filter_concat_node, &is_admit, candidate_node,
cnode](const Node* node) {
if ((candidate_node.find(node) != candidate_node.end()) &&
(cnode->name() != node->name())) {
is_admit = false;
}
},
/*end*/ nullptr);
if (is_admit) {
filter_concat_node->emplace_back(cnode);
}
}
return Status::OK();
}

EngineInfo::EngineType GetEngineType(
const TRTOptimizationPass::ConversionParams& params) {
return (params.is_dynamic_op || params.use_calibration)
Expand Down Expand Up @@ -773,6 +824,14 @@ Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params,
for (const auto& node : input_output_names) {
segment_options.exclude_node_list.insert(node);
}
std::vector<const Node*> filter_concat_node;
TF_RETURN_IF_ERROR(
FindExtraConcatInput(graph, input_output_names, &filter_concat_node));
for (const auto* node : filter_concat_node) {
for (auto* inode : node->in_nodes()) {
segment_options.exclude_node_list.insert(inode->name());
}
}
segment_options.minimum_segment_size = params.minimum_segment_size;
segment_options.use_implicit_batch = params.use_implicit_batch;
if (segment_options.use_implicit_batch)
Expand Down

0 comments on commit 92547b0

Please sign in to comment.