From cf16856d01551c9d1cb005722d7f62a448df7095 Mon Sep 17 00:00:00 2001 From: Chen Bangduo Date: Tue, 26 Mar 2024 17:15:18 +0800 Subject: [PATCH] [Incremental Checkpoint] Fix import incremental embedding variable. (#983) Signed-off-by: chenbangduo.cbd --- .../embedding/embedding_var_restore.cc | 50 +++++++++-------- tensorflow/python/training/incr_ckpt_test.py | 54 +++++++++++++++++++ 2 files changed, 82 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/framework/embedding/embedding_var_restore.cc b/tensorflow/core/framework/embedding/embedding_var_restore.cc index 11c13008995..6ff07bf7e43 100644 --- a/tensorflow/core/framework/embedding/embedding_var_restore.cc +++ b/tensorflow/core/framework/embedding/embedding_var_restore.cc @@ -102,45 +102,48 @@ void CheckpointLoader::RestoreInternal( Tensor part_filter_offset_tensor; if (!restore_args_.m_is_oldform) { /****** InitPartOffsetTensor ******/ - TensorShape part_offset_shape, part_filter_offset_shape; - DataType part_offset_type, part_filter_offset_type; + TensorShape part_offset_shape; + DataType part_offset_type; string offset_tensor_name; if (!restore_args_.m_is_incr) { offset_tensor_name = name_string + kPartOffsetTensorSuffsix; } else { offset_tensor_name = name_string + kIncrPartOffsetTensorSuffsix; } - - string offset_filter_tensor_name = - name_string + kPartFilterOffsetTensorSuffsix; + Status s = reader_->LookupDtypeAndShape( offset_tensor_name, &part_offset_type, &part_offset_shape); if (!s.ok()) { LOG(ERROR) << "EV restoring fail:" << s.error_message(); } - s = reader_->LookupDtypeAndShape(offset_filter_tensor_name, - &part_filter_offset_type, - &part_filter_offset_shape); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail: " << s.error_message(); - } part_offset_tensor = Tensor(cpu_allocator(), part_offset_type, part_offset_shape); - part_filter_offset_tensor = Tensor( - cpu_allocator(), part_filter_offset_type, part_filter_offset_shape); s = reader_->Lookup(offset_tensor_name, &part_offset_tensor); if (!s.ok()) { LOG(ERROR) << "EV restoring fail:" << s.error_message(); } - s = reader_->Lookup(offset_filter_tensor_name, - &part_filter_offset_tensor); - if (!s.ok()) { - LOG(ERROR) << "EV restoring fail: " << s.error_message(); + if (restore_args_.m_has_filter) { + TensorShape part_filter_offset_shape; + DataType part_filter_offset_type; + string offset_filter_tensor_name = + name_string + kPartFilterOffsetTensorSuffsix; + s = reader_->LookupDtypeAndShape(offset_filter_tensor_name, + &part_filter_offset_type, + &part_filter_offset_shape); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail: " << s.error_message(); + } + part_filter_offset_tensor = \ + Tensor(cpu_allocator(), part_filter_offset_type, + part_filter_offset_shape); + s = reader_->Lookup(offset_filter_tensor_name, + &part_filter_offset_tensor); + if (!s.ok()) { + LOG(ERROR) << "EV restoring fail: " << s.error_message(); + } } } - auto part_offset_flat = part_offset_tensor.flat(); - auto part_filter_offset_flat = part_filter_offset_tensor.flat(); if (restore_args_.m_is_oldform) { VLOG(1) << "old form, EV name:" << name_string @@ -164,6 +167,7 @@ void CheckpointLoader::RestoreInternal( VLOG(1) << "new form checkpoint... :" << name_string << " , partition_id:" << restore_args_.m_partition_id << " , partition_num:" << restore_args_.m_partition_num; + auto part_offset_flat = part_offset_tensor.flat(); for (size_t i = 0; i < restore_args_.m_loaded_parts.size(); i++) { int subpart_id = restore_args_.m_loaded_parts[i]; size_t value_unit_bytes = sizeof(V) * restore_args_.m_old_dim; @@ -183,6 +187,7 @@ void CheckpointLoader::RestoreInternal( new_dim, emb_config, device); if (restore_args_.m_has_filter) { + auto part_filter_offset_flat = part_filter_offset_tensor.flat(); Status s = EVRestoreFilteredFeatures( subpart_id, new_dim, restore_buff, part_filter_offset_flat, emb_config, device); @@ -444,7 +449,7 @@ Status CheckpointLoader::EVInitTensorNameAndShape( } st = reader_->LookupHeader(restore_args_.m_tensor_version + "_filtered", sizeof(K) * version_filter_shape.dim_size(0)); - if (!st.ok()) { + if (!st.ok() && st.code() != error::NOT_FOUND) { return st; } st = reader_->LookupTensorShape(restore_args_.m_tensor_freq + "_filtered", @@ -463,7 +468,8 @@ Status CheckpointLoader::EVInitTensorNameAndShape( return st; } } - return st; + + return Status::OK(); } #define REGISTER_KERNELS(ktype, vtype) \ template Status CheckpointLoader::EVInitTensorNameAndShape(\ @@ -644,4 +650,4 @@ TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_ALL_INDEX) #undef REGISTER_KERNELS_ALL_INDEX #undef REGISTER_KERNELS -}// namespace tensorflow \ No newline at end of file +}// namespace tensorflow diff --git a/tensorflow/python/training/incr_ckpt_test.py b/tensorflow/python/training/incr_ckpt_test.py index b4f7ded3cea..55cf748a9d6 100644 --- a/tensorflow/python/training/incr_ckpt_test.py +++ b/tensorflow/python/training/incr_ckpt_test.py @@ -451,5 +451,59 @@ def testIncrementalSaverForResourceVariable(self): saver.build() incr_saver = incr_saver_module._get_incremental_saver(True, saver) + def testIncrementalSaverSaveAndRestore(self): + tmp_path = self.get_temp_dir() + full_ckpt_dir = os.path.join(tmp_path, "model.ckpt") + incr_ckpt_dir = os.path.join(tmp_path, "incr.ckpt") + full_ckpt_path = None + incr_ckpt_path = None + + # construct graph + emb_var = variable_scope.get_embedding_variable("emb", embedding_dim=3, + initializer = init_ops.ones_initializer(dtypes.float32)) + emb = embedding_ops.embedding_lookup(emb_var, + math_ops.cast([0, 1, 2, 3, 4], dtypes.int64)) + loss = math_ops.reduce_sum(emb, name = 'reduce_sum') + opt = adagrad.AdagradOptimizer(0.1) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + init = variables.global_variables_initializer() + saver = saver_module.Saver(sharded=True, incremental_save_restore=True) + incr_saver = \ + incr_saver_module.IncrementalSaver(sharded=True, + saver_def=saver.saver_def, defer_build=True) + incr_saver.build(saver._builder.filename_tensor) + + # generate full ckpt and incr ckpt. + full_ckpt_value=None + incr_ckpt_value=None + with self.test_session() as sess: + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) + sess.run([init]) + sess.run([train_op]) + full_ckpt_path = saver.save(sess, full_ckpt_dir, global_step = 10) + full_ckpt_value = sess.run([emb]) + print("full_ckpt: {}".format(full_ckpt_value)) + sess.run([train_op]) + incr_ckpt_path = \ + incr_saver.incremental_save(sess, incr_ckpt_dir, global_step=20) + incr_ckpt_value = sess.run([emb]) + print("incr_ckpt: {}".format(incr_ckpt_value)) + + # check the value after restoring parameter. + with self.test_session() as sess: + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_VAR_OPS)) + sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) + sess.run([init]) + saver.restore(sess, full_ckpt_path) + restore_full_ckpt_value = sess.run([emb]) + print("restore_full_ckpt: {}".format(restore_full_ckpt_value)) + incr_saver.incremental_restore(sess, full_ckpt_path, incr_ckpt_path) + restore_incr_ckpt_value = sess.run([emb]) + print("restore_incr_ckpt: {}".format(restore_incr_ckpt_value)) + self.assertAllClose(full_ckpt_value, restore_full_ckpt_value) + self.assertAllClose(incr_ckpt_value, restore_incr_ckpt_value) + if __name__ == "__main__": googletest.main()