From 2f938dc2a18e57c9a302f5a8b988f6cd39f89e2f Mon Sep 17 00:00:00 2001
From: Junqi Hu <42396655+Mesilenceki@users.noreply.github.com>
Date: Tue, 9 Jan 2024 17:46:11 -0800
Subject: [PATCH] [TensorRT] Fix compilation issue for graphs containing
 EmbeddingVariable. (#964)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Stop threading global_step_name through the freezing path. Instead of
blacklisting global_step from constant folding, run its Assign op before
freezing (when the global_step collection is present) so it can be folded
like any other variable, and drop the now-unused global_step_name
parameter of create_kv_variable_init_graph. Also walk one extra level
through the restore_all inputs so the traversal reaches the actual
restore_shard ops.

Signed-off-by: 泊霆
Co-authored-by: 泊霆
---
 tensorflow/python/compiler/tensorrt/trt_convert.py | 12 +++++-------
 tensorflow/python/framework/graph_util_impl.py     | 12 +++++-------
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
index 2c8d603ba01..064e32c6984 100644
--- a/tensorflow/python/compiler/tensorrt/trt_convert.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -539,13 +539,10 @@ def _gather_names(tensor_info):
 
       # EmbeddingVariable can not be convert to constant, so we need to
       # load ev varibles at runtime always.
       if self._use_ev:
-        global_step_collection_ops = sess.graph.get_collection("global_step")
-        global_step_name = global_step_collection_ops[0].name.split(":")[0]
         output_node_names.add(filename_tensor_name)
         output_node_names.add(save_tensor_name)
         output_node_names.add(restore_op_name)
-        tf_logging.info("TensorRT - global_step_name: %s" % str(global_step_name))
         tf_logging.info("TensorRT - filename_tensor_name: %s" % str(filename_tensor_name))
         tf_logging.info("TensorRT - save_tensor_name: %s" % str(save_tensor_name))
         tf_logging.info("TensorRT - restore_op_name: %s" % str(restore_op_name))
@@ -559,18 +556,19 @@ def _gather_names(tensor_info):
 
       # Freeze the variables in the SavedModel graph and copy the frozen
      # graph over.
-      variable_names_blacklist = []
       if self._use_ev:
-        variable_names_blacklist.append(global_step_name)
+        global_step_collection_ops = sess.graph.get_collection("global_step")
+        if len(global_step_collection_ops) > 0:
+          sess.run([sess.graph.get_operation_by_name("global_step/Assign")])
 
       frozen_graph_def = graph_util.convert_variables_to_constants(
           sess, sess.graph.as_graph_def(add_shapes=True),
-          list(output_node_names), variable_names_blacklist=variable_names_blacklist)
+          list(output_node_names))
 
       if self._use_ev:
         # Keep KV Variable in saver_def, these kv-vars will be initialized at runtime.
         frozen_graph_def = graph_util.create_kv_variable_init_graph(
-            frozen_graph_def, global_step_name, restore_op_name)
+            frozen_graph_def, restore_op_name)
 
       self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
       self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index 76d69e886e7..c3fa37529c3 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -169,7 +169,7 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
   return nodes_to_keep
 
 @tf_export(v1=["graph_util.create_kv_variable_init_graph"])
-def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
+def create_kv_variable_init_graph(graph, restore_all_op_name):
   name_to_input_name, name_to_node, name_to_seq_num = \
       _extract_graph_summary(graph)
 
@@ -184,8 +184,10 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
                      " {} in current graph.".format(restore_all_op_name))
 
   for restore_shard_input_full_name in restore_all_op.input:
-    restore_shard_input_name = re.sub(r"^\^", "", restore_shard_input_full_name)
-    restore_shard_input_op = name_to_node[restore_shard_input_name]
+    restore_shard_input_no_op_name = re.sub(r"^\^", "", restore_shard_input_full_name)
+    restore_shard_input_no_op = name_to_node[restore_shard_input_no_op_name]
+    restore_shard_input_op_name = re.sub(r"^\^", "", restore_shard_input_no_op.input[0])
+    restore_shard_input_op = name_to_node[restore_shard_input_op_name]
     # go through all restore_shard ops
     new_node = node_def_pb2.NodeDef()
     new_node.CopyFrom(restore_shard_input_op)
@@ -198,10 +200,6 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
           n_node.op == "KvResourceImportV2" or \
           n_node.op == "KvResourceImport":
         new_node.input.append(n_full_name)
-      else:
-        # Keep global_step assign op in new save/restore_all
-        if n_node.input[0] == global_step_name:
-          new_node.input.append(n_full_name)
 
     graph.node.remove(restore_shard_input_op)
     graph.node.extend([new_node])
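
For reviewers, a minimal standalone sketch of the traversal change, under
assumptions: the node names and graph shape below are hypothetical, and only
the "^" control-dependency convention plus the extra one-hop dereference
(restore_all input -> NoOp -> restore_shard) mirror the patched
create_kv_variable_init_graph.

import re

# name -> inputs, mimicking the name_to_node map built by
# _extract_graph_summary. A leading "^" marks a control dependency,
# as in TensorFlow GraphDefs. All names here are made up for illustration.
name_to_inputs = {
    "save/restore_all": ["^save/restore_shard_wrapper"],   # control edge to a NoOp
    "save/restore_shard_wrapper": ["save/restore_shard"],  # extra hop the patch adds
    "save/restore_shard": ["save/KvResourceImportV2"],
    "save/KvResourceImportV2": [],
}

for full_name in name_to_inputs["save/restore_all"]:
    # Strip the control-dependency prefix, as the patch does with re.sub.
    no_op_name = re.sub(r"^\^", "", full_name)
    # New behavior: follow the NoOp's first input to reach the actual
    # restore_shard op (the old code stopped one hop earlier).
    shard_name = re.sub(r"^\^", "", name_to_inputs[no_op_name][0])
    print("restore_shard op:", shard_name)  # -> save/restore_shard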