[TensorRT] Fix compiling issue when the graph contains EmbeddingVariable. (#964)
Signed-off-by: 泊霆 <[email protected]>
Co-authored-by: 泊霆 <[email protected]>
Mesilenceki authored Jan 10, 2024
1 parent 0f536a2 commit 2f938dc
Showing 2 changed files with 10 additions and 14 deletions.
12 changes: 5 additions & 7 deletions tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -539,13 +539,10 @@ def _gather_names(tensor_info):
       # EmbeddingVariable cannot be converted to a constant, so EV variables
       # must always be loaded at runtime.
       if self._use_ev:
-        global_step_collection_ops = sess.graph.get_collection("global_step")
-        global_step_name = global_step_collection_ops[0].name.split(":")[0]
         output_node_names.add(filename_tensor_name)
         output_node_names.add(save_tensor_name)
         output_node_names.add(restore_op_name)
 
-        tf_logging.info("TensorRT - global_step_name: %s" % str(global_step_name))
         tf_logging.info("TensorRT - filename_tensor_name: %s" % str(filename_tensor_name))
         tf_logging.info("TensorRT - save_tensor_name: %s" % str(save_tensor_name))
         tf_logging.info("TensorRT - restore_op_name: %s" % str(restore_op_name))
@@ -559,18 +556,19 @@ def _gather_names(tensor_info):
 
       # Freeze the variables in the SavedModel graph and copy the frozen
       # graph over.
-      variable_names_blacklist = []
       if self._use_ev:
-        variable_names_blacklist.append(global_step_name)
+        global_step_collection_ops = sess.graph.get_collection("global_step")
+        if len(global_step_collection_ops) > 0:
+          sess.run([sess.graph.get_operation_by_name("global_step/Assign")])
 
       frozen_graph_def = graph_util.convert_variables_to_constants(
           sess, sess.graph.as_graph_def(add_shapes=True),
-          list(output_node_names), variable_names_blacklist=variable_names_blacklist)
+          list(output_node_names))
 
       if self._use_ev:
         # Keep KV Variables in saver_def; these kv-vars will be initialized at runtime.
         frozen_graph_def = graph_util.create_kv_variable_init_graph(
-            frozen_graph_def, global_step_name, restore_op_name)
+            frozen_graph_def, restore_op_name)
 
       self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
       self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
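Instead of blacklisting global_step from constant folding (and threading its name through create_kv_variable_init_graph), the converter now runs the variable's Assign op once so convert_variables_to_constants can fold it like any other variable. A rough sketch of that freeze sequence, assuming a TF1-style graph whose global_step has the usual "global_step/Assign" initializer op:

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import graph_util

tf.disable_v2_behavior()

with tf.Session() as sess:
  global_step = tf.train.get_or_create_global_step()
  sess.run(tf.global_variables_initializer())
  # Materialize global_step so freezing can fold it into a constant,
  # rather than blacklisting it as the old code did.
  sess.run([sess.graph.get_operation_by_name("global_step/Assign")])
  frozen = graph_util.convert_variables_to_constants(
      sess, sess.graph.as_graph_def(add_shapes=True), ["global_step"])
  # On DeepRec, the EV rewrite now takes only the restore op name, e.g.:
  # frozen = graph_util.create_kv_variable_init_graph(frozen, "save/restore_all")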
12 changes: 5 additions & 7 deletions tensorflow/python/framework/graph_util_impl.py
@@ -169,7 +169,7 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
   return nodes_to_keep
 
 @tf_export(v1=["graph_util.create_kv_variable_init_graph"])
-def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
+def create_kv_variable_init_graph(graph, restore_all_op_name):
   name_to_input_name, name_to_node, name_to_seq_num = \
       _extract_graph_summary(graph)
 
@@ -184,8 +184,10 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
                      " {} in current graph.".format(restore_all_op_name))
 
   for restore_shard_input_full_name in restore_all_op.input:
-    restore_shard_input_name = re.sub(r"^\^", "", restore_shard_input_full_name)
-    restore_shard_input_op = name_to_node[restore_shard_input_name]
+    restore_shard_input_no_op_name = re.sub(r"^\^", "", restore_shard_input_full_name)
+    restore_shard_input_no_op = name_to_node[restore_shard_input_no_op_name]
+    restore_shard_input_op_name = re.sub(r"^\^", "", restore_shard_input_no_op.input[0])
+    restore_shard_input_op = name_to_node[restore_shard_input_op_name]
     # go through all restore_shard ops
     new_node = node_def_pb2.NodeDef()
     new_node.CopyFrom(restore_shard_input_op)
@@ -198,10 +200,6 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
          n_node.op == "KvResourceImportV2" or \
          n_node.op == "KvResourceImport":
         new_node.input.append(n_full_name)
-      else:
-        # Keep global_step assign op in new save/restore_all
-        if n_node.input[0] == global_step_name:
-          new_node.input.append(n_full_name)
 
   graph.node.remove(restore_shard_input_op)
   graph.node.extend([new_node])
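The rewritten loop accounts for an extra NoOp hop between restore_all and each restore_shard: restore_all's inputs are "^"-prefixed control edges naming NoOps, and each NoOp's first input names the actual restore_shard node. A self-contained sketch of that two-hop lookup, with hypothetical node names:

import re
from tensorflow.core.framework import node_def_pb2

def make_node(name, op, inputs=()):
  node = node_def_pb2.NodeDef(name=name, op=op)
  node.input.extend(inputs)
  return node

# Toy graph index mirroring what _extract_graph_summary produces.
name_to_node = {
    "save/restore_all": make_node("save/restore_all", "NoOp",
                                  ["^save/restore_wrapper"]),
    "save/restore_wrapper": make_node("save/restore_wrapper", "NoOp",
                                      ["^save/restore_shard"]),
    "save/restore_shard": make_node("save/restore_shard", "NoOp"),
}

restore_all_op = name_to_node["save/restore_all"]
for full_name in restore_all_op.input:
  no_op_name = re.sub(r"^\^", "", full_name)       # -> "save/restore_wrapper"
  no_op = name_to_node[no_op_name]
  shard_name = re.sub(r"^\^", "", no_op.input[0])  # -> "save/restore_shard"
  print(name_to_node[shard_name].name)             # -> "save/restore_shard"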
