Skip to content

Commit

Permalink
[Embedding] Check the sharded property of tf.train.Saver. (#996)
Browse files Browse the repository at this point in the history
Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty authored May 23, 2024
1 parent 93c69ad commit 9e30ab6
Show file tree
Hide file tree
Showing 22 changed files with 76 additions and 71 deletions.
3 changes: 1 addition & 2 deletions modelzoo/bst/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,10 +612,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dbmtl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dcnv2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/deepfm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dien/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,10 +776,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/din/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dlrm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,10 +507,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/dssm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/esmm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,10 +534,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/masknet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/mlperf/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/mmoe/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/ple/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,10 +592,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/simple_multitask/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
3 changes: 1 addition & 2 deletions modelzoo/wide_and_deep/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,9 @@ def train(sess_config,
hooks = []
hooks.extend(input_hooks)

sharded_saver = tf_config != None
scaffold = tf.train.Scaffold(
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=True))

stop_hook = tf.train.StopAtStepHook(last_step=steps)
log_hook = tf.train.LoggingTensorHook(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/python/feature_column/feature_column_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7527,7 +7527,7 @@ def testEmbeddingVariableForL2FeatureEviction(self):
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
saver = saver_module.Saver()
saver = saver_module.Saver(sharded=True)
init = variables_lib.global_variables_initializer()
with self.test_session() as sess:
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
Expand Down Expand Up @@ -7758,7 +7758,7 @@ def testEmbeddingVariableForSharedEmbeddingColumnsWithPartitionNum(self):
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
init = variables_lib.global_variables_initializer()
saver = saver_module.Saver()
saver = saver_module.Saver(sharded=True)

@test_util.run_deprecated_v1
def testEmbeddingVariableForInt32ID(self):
Expand All @@ -7783,7 +7783,7 @@ def testEmbeddingVariableForInt32ID(self):
opt = ftrl.FtrlOptimizer(0.1, l1_regularization_strength=2.0, l2_regularization_strength=0.00001)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
saver = saver_module.Saver()
saver = saver_module.Saver(sharded=True)
init = variables_lib.global_variables_initializer()
with self.test_session() as sess:
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS))
Expand Down
7 changes: 4 additions & 3 deletions tensorflow/python/ops/embedding_variable_ops_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def testEmbeddingVariableForInitFromProto(self):
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v)
graph = ops.get_default_graph()
meta_graph_def = saver_module.export_meta_graph()
saver = saver_module.Saver(sharded=True)
meta_graph_def = saver_module.export_meta_graph(saver_def=saver.as_saver_def())
ops.reset_default_graph()
with self.test_session() as sess:
res = saver_module.import_meta_graph(meta_graph_def)
Expand Down Expand Up @@ -748,7 +749,7 @@ def testSaveV3(self):
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v, global_step=gs)
init = variables.global_variables_initializer()
saver = saver = saver_module.Saver()
saver = saver = saver_module.Saver(sharded=True)
checkpoint_directory = self.get_temp_dir()
model_path = os.path.join(checkpoint_directory, "model.ckpt")
with self.test_session() as sess:
Expand Down Expand Up @@ -816,7 +817,7 @@ def testEmbeddingVariableSaveAndRestoreOptimzierStatesForMultiTierWithHbm(self):
opt = adagrad.AdagradOptimizer(0.1)
g_v = opt.compute_gradients(loss)
train_op = opt.apply_gradients(g_v, gs)
saver = saver_module.Saver()
saver = saver_module.Saver(sharded=True)
graph = ops.get_default_graph()
with self.test_session(graph = graph) as sess:
saver.restore(sess, os.path.join(checkpoint_directory, "model.ckpt-12345"))
Expand Down
Loading

0 comments on commit 9e30ab6

Please sign in to comment.