Skip to content

Commit

Permalink
[Graph] Fix hang bug for async embedding lookup. (#934)
Browse files Browse the repository at this point in the history
Skip edges to 'SaveV3' Op.

Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty authored Oct 18, 2023
1 parent 06f81cc commit be62ec3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
7 changes: 6 additions & 1 deletion tensorflow/python/training/async_embedding_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ def __init__(self, options, checkpoint_dir = None):
self._checkpoint_dir = checkpoint_dir if checkpoint_dir else ""
self._use_stage_subgraph_thread_pool = options.use_stage_subgraph_thread_pool
self._stage_subgraph_thread_pool_id = options.stage_subgraph_thread_pool_id
self._is_staged = False
self._control_flow_ops = ['Switch', '_SwitchN', 'Merge', '_XlaMerge',
'Enter', 'Exit']
self._variable_ops = ['Variable', 'VariableV2', 'VarHandleOp',
'KvVarHandleOp', 'HashTableV2']
self._variable_is_init_ops = ['IsVariableInitialized',
'VarIsInitializedOp', 'KvVarIsInitializedOp']
self._saver_ops = ['SaveV2']
self._saver_ops = ['SaveV2', 'SaveV3']
self._no_data_input_ops = self._variable_ops + ['Placeholder', 'PlaceholderV2', 'Const']
self._boundary_ops = set()
for tensor in ops.get_collection(ops.GraphKeys.ASYNC_EMBEDDING_OUTPUT_TENSORS):
Expand All @@ -74,6 +75,10 @@ def __init__(self, options, checkpoint_dir = None):
def stage(self, graph):
""" add async embedding stage node to graph
"""
if self._is_staged:
return
self._is_staged = True

logging.info('async embedding stage begin')
logging.info('async embedding thread num: ' + str(self._threads_num))
logging.info('async embedding capacity: ' + str(self._capacity))
Expand Down
10 changes: 6 additions & 4 deletions tensorflow/python/training/monitored_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def __init__(self,
self._saver = saver
self._incremental_save_restore = incremental_save_restore
self._incr_saver = None
self._async_embedding_stage = None
self._enable_async_embedding = False
self._async_embedding_checkpoint_dir = None
self._async_embedding_options = None
Expand Down Expand Up @@ -247,10 +248,11 @@ def default_ready_for_local_init_op():
self._incr_saver = incr_saver._get_incremental_saver(self._incremental_save_restore, self._saver)

if self._enable_async_embedding:
async_embedding_stage = async_embedding.AsyncEmbeddingStage(
self._async_embedding_options,
self._async_embedding_checkpoint_dir)
async_embedding_stage.stage(ops.get_default_graph())
if self._async_embedding_stage is None:
self._async_embedding_stage = async_embedding.AsyncEmbeddingStage(
self._async_embedding_options,
self._async_embedding_checkpoint_dir)
self._async_embedding_stage.stage(ops.get_default_graph())

ops.get_default_graph().finalize()
logging.info('Graph was finalized.')
Expand Down

0 comments on commit be62ec3

Please sign in to comment.