From 7e78253151ffa46e6141c8d2fbbb8a6ab313d731 Mon Sep 17 00:00:00 2001 From: "chenbangduo.cbd" Date: Thu, 25 Apr 2024 20:12:37 +0800 Subject: [PATCH] [Hook] Add 'before_create_session' interface to SessionRunHook. Signed-off-by: chenbangduo.cbd --- tensorflow/python/training/monitored_session.py | 3 +++ tensorflow/python/training/session_run_hook.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 6eb204785dd..9492028a200 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -957,6 +957,8 @@ def __init__(self, session_creator, hooks, stop_grace_period_secs): def create_session(self): """Creates a coordinated session.""" # Keep the tf_sess for unit testing. + for hook in self._hooks: + hook.before_create_session() self.tf_sess = self._session_creator.create_session() # We don't want coordinator to suppress any exception. self.coord = coordinator.Coordinator(clean_stop_exception_types=[]) @@ -1027,6 +1029,7 @@ class MonitoredSession(_MonitoredSession): in given order: * calls `hook.begin()` for each given hook + * calls `hook.before_create_session()` * finalizes the graph via `scaffold.finalize()` * create session * initializes the model via initialization ops provided by `Scaffold` diff --git a/tensorflow/python/training/session_run_hook.py b/tensorflow/python/training/session_run_hook.py index e598bc2d98c..9d05d04c139 100644 --- a/tensorflow/python/training/session_run_hook.py +++ b/tensorflow/python/training/session_run_hook.py @@ -109,6 +109,20 @@ def begin(self): """ pass + def before_create_session(self): + """Called before new TensorFlow session is created. + + This has two essential differences with the situation in which `begin` is + called: + + * Do not modify the graph in this method, ops should not be added to graph. + The modification of the graph should take place within the begin + interface. + * This method will also be called prior to the recovery of a wrapped + session, not just at the beginning of the overall session. + """ + pass + def after_create_session(self, session, coord): # pylint: disable=unused-argument """Called when new TensorFlow session is created.