From 2d31c8e37ea28d7c169879ebd9c3a89bd8d26cb5 Mon Sep 17 00:00:00 2001
From: Junqi Hu <42396655+Mesilenceki@users.noreply.github.com>
Date: Thu, 26 Oct 2023 04:02:41 -0700
Subject: [PATCH] [Embedding] Add interface of EmbeddingVar for Elastic
 Training. (#933)

Signed-off-by: JunqiHu
---
 configure.py                                       |  2 +-
 tensorflow/contrib/elastic_grpc_server/BUILD       |  3 +-
 tensorflow/core/BUILD                              |  5 +-
 .../framework/embedding/bloom_filter_policy.h      |  2 +-
 .../embedding/counter_filter_policy.h              |  2 +-
 .../framework/embedding/cpu_hash_map_kv.h          | 22 +++++
 .../framework/embedding/dense_hash_map_kv.h        | 19 ++++
 .../core/framework/embedding/embedding_var.h       | 86 ++++++++++++++++++-
 .../embedding/embedding_var_ckpt_data.h            |  1 -
 .../core/framework/embedding/filter_policy.h       | 20 ++++-
 .../framework/embedding/gpu_hash_map_kv.h          |  7 ++
 .../core/framework/embedding/kv_interface.h        |  5 ++
 .../core/framework/embedding/leveldb_kv.h          | 32 +++++++
 .../framework/embedding/multi_tier_storage.h       |  9 +-
 .../embedding/nullable_filter_policy.h             |  2 +-
 .../framework/embedding/single_tier_storage.h      | 13 ++-
 .../core/framework/embedding/ssd_hash_kv.h         |  6 ++
 tensorflow/core/framework/embedding/storage.h      |  7 +-
 tensorflow/core/kernels/data/BUILD                 |  6 ++
 tensorflow/core/kernels/data/iterator_ops.cc       | 12 ++-
 tensorflow/python/ops/embedding_ops.py             |  3 +-
 21 files changed, 244 insertions(+), 20 deletions(-)

diff --git a/configure.py b/configure.py
index 6aeaf7d12af..4fb1c78c40b 100644
--- a/configure.py
+++ b/configure.py
@@ -1434,7 +1434,7 @@ def main():
                 True, 'star')
   set_build_var(environ_cp, 'TF_NEED_ELASTIC', 'ELASTIC TRAINING',
                 'with_elastic_support',
-                True, 'elastic')
+                False, 'elastic')
   set_build_var(environ_cp, 'TF_ENABLE_PMEM', 'PMEM', 'with_pmem_support',
                 False, 'pmem')
diff --git a/tensorflow/contrib/elastic_grpc_server/BUILD b/tensorflow/contrib/elastic_grpc_server/BUILD
index ea4b87e3b58..16ec91f4435 100644
--- a/tensorflow/contrib/elastic_grpc_server/BUILD
+++ b/tensorflow/contrib/elastic_grpc_server/BUILD
@@ -56,7 +56,8 @@ cc_library(
 tf_cc_test(
     name = "elastic_grpc_test",
     size = "small",
-    srcs = ["elastic_grpc_server_lib_test.cc"],
+    srcs = select({"//tensorflow:with_elastic_support": ["elastic_grpc_server_lib_test.cc"],
+                   "//conditions:default": []}),
     deps = [
         ":elastic_grpc_server_lib",
         "//tensorflow/core/distributed_runtime/rpc:grpc_util",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 0531200e7ab..ef1ebcb6dcf 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -128,6 +128,7 @@ load(
     "tf_additional_numa_deps",
     "tf_additional_numa_lib_defines",
     "tf_additional_star_lib_defines",
+    "tf_additional_elastic_server_lib_defines",
     "tf_additional_api_compatible_defines",
     "tf_additional_pmem_lib_defines",
     "tf_additional_test_deps",
@@ -1441,6 +1442,7 @@ tf_cc_test(
 cc_library(
     name = "ops",
     visibility = ["//visibility:public"],
+    defines = tf_additional_elastic_server_lib_defines(),
     deps = [
         ":array_ops_op_lib",
         ":parquet_ops_op_lib",
@@ -2562,7 +2564,8 @@ LIB_INTERNAL_DEFINES = (
     tf_additional_gdr_lib_defines() +
     tf_additional_numa_lib_defines() +
     tf_additional_star_lib_defines() +
-    tf_additional_pmem_lib_defines()
+    tf_additional_pmem_lib_defines() +
+    tf_additional_elastic_server_lib_defines()
 )

 cc_library(
diff --git a/tensorflow/core/framework/embedding/bloom_filter_policy.h b/tensorflow/core/framework/embedding/bloom_filter_policy.h
index 781511578af..8019e70a312 100644
--- a/tensorflow/core/framework/embedding/bloom_filter_policy.h
+++ b/tensorflow/core/framework/embedding/bloom_filter_policy.h
@@ -333,7 +333,7 @@ class BloomFilterPolicy : public FilterPolicy {
       // this can describe by graph(Mod + DynamicPartition),
       // but memory waste and slow
       if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
-        LOG(INFO) << "skip EV key:" << *(key_buff + i);
+        VLOG(1) << "skip EV key:" << *(key_buff + i);
         continue;
       }
       void* value_ptr = nullptr;
diff --git a/tensorflow/core/framework/embedding/counter_filter_policy.h b/tensorflow/core/framework/embedding/counter_filter_policy.h
index 19cd90ad01c..e53d574182c 100644
--- a/tensorflow/core/framework/embedding/counter_filter_policy.h
+++ b/tensorflow/core/framework/embedding/counter_filter_policy.h
@@ -159,7 +159,7 @@ class CounterFilterPolicy : public FilterPolicy {
       // this can describe by graph(Mod + DynamicPartition),
      // but memory waste and slow
       if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
-        LOG(INFO) << "skip EV key:" << *(key_buff + i);
+        VLOG(1) << "skip EV key:" << *(key_buff + i);
         continue;
       }
       int64 import_freq = 0;
diff --git a/tensorflow/core/framework/embedding/cpu_hash_map_kv.h b/tensorflow/core/framework/embedding/cpu_hash_map_kv.h
index 8476c399c40..750ba282285 100644
--- a/tensorflow/core/framework/embedding/cpu_hash_map_kv.h
+++ b/tensorflow/core/framework/embedding/cpu_hash_map_kv.h
@@ -137,6 +137,28 @@ class LocklessHashMap : public KVInterface {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    std::pair<K, void*> *hash_map_dump;
+    int64 bucket_count;
+    auto it = hash_map_.GetSnapshot();
+    hash_map_dump = it.first;
+    bucket_count = it.second;
+    for (int64 j = 0; j < bucket_count; j++) {
+      if (hash_map_dump[j].first != LocklessHashMap<K, V>::EMPTY_KEY_
+          && hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_
+          && hash_map_dump[j].first % kSavedPartitionNum
+              % partition_nums != partition_id) {
+        key_list->emplace_back(hash_map_dump[j].first);
+        value_ptr_list->emplace_back(hash_map_dump[j].second);
+      }
+    }
+
+    free(hash_map_dump);
+    return Status::OK();
+  }
+
   std::string DebugString() const override {
     LOG(INFO) << "map info size:" << Size()
               << "map info bucket_count:" << hash_map_.bucket_count()
diff --git a/tensorflow/core/framework/embedding/dense_hash_map_kv.h b/tensorflow/core/framework/embedding/dense_hash_map_kv.h
index ffaf2e335dc..8a27404b66f 100644
--- a/tensorflow/core/framework/embedding/dense_hash_map_kv.h
+++ b/tensorflow/core/framework/embedding/dense_hash_map_kv.h
@@ -121,6 +121,25 @@ class DenseHashMap : public KVInterface {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    dense_hash_map hash_map_dump[partition_num_];
+    for (int i = 0; i< partition_num_; i++) {
+      spin_rd_lock l(hash_map_[i].mu);
+      hash_map_dump[i].hash_map = hash_map_[i].hash_map;
+    }
+    for (int i = 0; i< partition_num_; i++) {
+      for (const auto it : hash_map_dump[i].hash_map) {
+        if (it.first % kSavedPartitionNum % partition_nums != partition_id) {
+          key_list->push_back(it.first);
+          value_ptr_list->push_back(it.second);
+        }
+      }
+    }
+    return Status::OK();
+  }
+
   std::string DebugString() const override {
     return "";
   }
diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h
index 487f595bf31..a66ec19fb97 100644
--- a/tensorflow/core/framework/embedding/embedding_var.h
+++ b/tensorflow/core/framework/embedding/embedding_var.h
@@ -435,6 +435,10 @@ class EmbeddingVar : public ResourceBase {
     return storage_->CacheSize();
   }

+  int64 MemoryUsage() const {
+    return storage_->Size() * (sizeof(K) + feat_desc_->data_bytes());
+  }
+
   int64 MinFreq() {
     return emb_config_.filter_freq;
   }
@@ -516,6 +520,85 @@ class EmbeddingVar : public ResourceBase {
     }
   }

+  Status GetShardedSnapshot(std::vector<K>* key_list,
+                            std::vector<void*>* value_ptr_list,
+                            int partition_id, int partition_num) {
+    return storage_->GetShardedSnapshot(key_list, value_ptr_list,
+                                        partition_id, partition_num);
+  }
+
+  void ExportAndRemove(K* key_list, V* value_list,
+                       int64* version_list, int64* freq_list,
+                       std::vector<K>& tot_keys_list,
+                       std::vector<void*>& tot_value_ptr_list) {
+    bool save_unfiltered_features = true;
+    TF_CHECK_OK(ReadBoolFromEnvVar(
+        "TF_EV_SAVE_FILTERED_FEATURES", true, &save_unfiltered_features));
+
+    bool is_save_freq = emb_config_.is_save_freq();
+    bool is_save_version = emb_config_.is_save_version();
+
+    for (int64 i = 0; i < tot_keys_list.size(); ++i) {
+      auto& value_ptr = tot_value_ptr_list[i];
+      if((int64)value_ptr == embedding::ValuePtrStatus::IS_DELETED)
+        continue;
+
+      bool is_admit = feat_desc_->IsAdmit(value_ptr);
+      bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0);
+
+      if (!is_admit) {
+        key_list[i] = tot_keys_list[i];
+
+        if (!is_in_dram) {
+          auto tmp_value = value_list + i * value_len_;
+          tmp_value = (V*)embedding::ValuePtrStatus::NOT_IN_DRAM;
+          value_ptr = (void*)((int64)value_ptr & ((1L << kDramFlagOffset) - 1));
+        } else if (feat_desc_->GetEmbedding(value_ptr, 0) == nullptr) {
+          memcpy(value_list + i * value_len_, default_value_, sizeof(V) * value_len_);
+        } else {
+          V* val = feat_desc_->GetEmbedding(value_ptr, emb_config_.emb_index);
+          memcpy(value_list + i * value_len_, val, sizeof(V) * value_len_);
+        }
+
+        if(is_save_version) {
+          int64 dump_version = feat_desc_->GetVersion(value_ptr);
+          version_list[i] = dump_version;
+        }
+
+        if(is_save_freq) {
+          int64 dump_freq = feat_desc_->GetFreq(value_ptr);
+          freq_list[i] = dump_freq;
+        }
+      } else {
+        if (!save_unfiltered_features)
+          return;
+        //TODO(JUNQI) : currently not export filtered keys
+      }
+
+      if (emb_config_.is_primary()) {
+        Status s;
+        s = storage_->Remove(tot_keys_list[i]);
+        if (!s.ok()) {
+          LOG(ERROR) << "Remove keys error: " << s.error_message();
+        }
+        feat_desc_->Deallocate(value_ptr);
+      }
+    }
+  }
+
+  Status RestoreFromKeysAndValues(int64 key_num, int partition_id,
+                                  int partition_num, const K* key_list,
+                                  const V* value_list, const int64* version_list,
+                                  const int64* freq_list,
+                                  const Eigen::GpuDevice* device = nullptr) {
+    RestoreBuffer restore_buff((char*)key_list, (char*)value_list,
+                               (char*)version_list, (char*)freq_list);
+    return storage_->RestoreFeatures(key_num, kSavedPartitionNum,
+                                     partition_id, partition_num,
+                                     value_len_, false/* is_filter*/, false/* is_incr*/,
+                                     emb_config_, device, filter_, restore_buff);
+  }
+
   mutex* mu() {
     return &mu_;
   }
@@ -537,6 +620,8 @@ class EmbeddingVar : public ResourceBase {
     }
   }

+  string Name() {return name_; }
+
   V* GetDefaultValuePtr() {
     return default_value_;
   }
@@ -645,7 +730,6 @@ class EmbeddingVar : public ResourceBase {
   GPUHashTable<K, V>* HashTable() {
     return storage_->HashTable();
   }
-
   FilterPolicy<K, V, EmbeddingVar<K, V>>* GetFilter() const {
     return filter_;
   }
diff --git a/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h
index 10bf0d0e43b..13072f9cdd1 100644
--- a/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h
+++ b/tensorflow/core/framework/embedding/embedding_var_ckpt_data.h
@@ -20,7 +20,6 @@ limitations under the License.
 namespace tensorflow {
 class BundleWriter;
 namespace {
-  const int kSavedPartitionNum = 1000;
   const int kDramFlagOffset = 49;
 }
diff --git a/tensorflow/core/framework/embedding/filter_policy.h b/tensorflow/core/framework/embedding/filter_policy.h
index 256d3b044d4..c994829bafc 100644
--- a/tensorflow/core/framework/embedding/filter_policy.h
+++ b/tensorflow/core/framework/embedding/filter_policy.h
@@ -27,19 +27,31 @@ struct RestoreBuffer {
   char* value_buffer = nullptr;
   char* version_buffer = nullptr;
   char* freq_buffer = nullptr;
+  bool should_release = false;

   explicit RestoreBuffer(size_t buffer_size) {
     key_buffer = new char[buffer_size];
     value_buffer = new char[buffer_size];
     version_buffer = new char[buffer_size];
     freq_buffer = new char[buffer_size];
+    should_release = true;
+  }
+
+  explicit RestoreBuffer(char* i_key_buffer, char* i_value_buffer,
+                         char* i_version_buffer, char* i_freq_buffer) {
+    key_buffer = i_key_buffer;
+    value_buffer = i_value_buffer;
+    version_buffer = i_version_buffer;
+    freq_buffer = i_freq_buffer;
   }

   ~RestoreBuffer() {
-    delete []key_buffer;
-    delete []value_buffer;
-    delete []version_buffer;
-    delete []freq_buffer;
+    if (should_release) {
+      delete []key_buffer;
+      delete []value_buffer;
+      delete []version_buffer;
+      delete []freq_buffer;
+    }
   }
 };
diff --git a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h
index fc4a2506313..e73839e3f76 100644
--- a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h
+++ b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h
@@ -252,6 +252,13 @@ class GPUHashMapKV : public KVInterface {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    LOG(INFO) << "GPUHashMapKV do not support GetShardedSnapshot";
+    return Status::OK();
+  }
+
   std::string DebugString() const override { return std::string(); }

   GPUHashTable<K, V>* HashTable() override { return hash_table_; }
diff --git a/tensorflow/core/framework/embedding/kv_interface.h b/tensorflow/core/framework/embedding/kv_interface.h
index 3659187c825..dc603680138 100644
--- a/tensorflow/core/framework/embedding/kv_interface.h
+++ b/tensorflow/core/framework/embedding/kv_interface.h
@@ -23,6 +23,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 const char* kInferenceMode = "INFERENCE_MODE";
+const int kSavedPartitionNum = 1000;
 }

 template <class K, class V>
@@ -89,6 +90,10 @@ class KVInterface {
   virtual Status GetSnapshot(std::vector<K>* key_list,
                              std::vector<void*>* value_ptr_list) = 0;

+  virtual Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) = 0;
+
   virtual std::string DebugString() const = 0;

   virtual Status BatchLookupOrCreate(const K* keys, V* val, V* default_v,
diff --git a/tensorflow/core/framework/embedding/leveldb_kv.h b/tensorflow/core/framework/embedding/leveldb_kv.h
index e488ab3776d..47c8a39dfbd 100644
--- a/tensorflow/core/framework/embedding/leveldb_kv.h
+++ b/tensorflow/core/framework/embedding/leveldb_kv.h
@@ -193,6 +193,38 @@ class LevelDBKV : public KVInterface {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    ReadOptions options;
+    options.snapshot = db_->GetSnapshot();
+    leveldb::Iterator* it = db_->NewIterator(options);
+    void* dram_value_ptr = feat_desc_->Allocate();
+    for (it->SeekToFirst(); it->Valid(); it->Next()) {
+      K key;
+      memcpy((char*)&key, it->key().ToString().data(), sizeof(K));
+      if (key % kSavedPartitionNum % partition_nums == partition_id) continue;
+      key_list->emplace_back(key);
+      FeatureDescriptor<V> hbm_feat_desc(
+          1, 1, ev_allocator()/*useless*/,
+          StorageType::HBM_DRAM, true, true,
+          {false, 0});
+      void* value_ptr = cpu_allocator()->AllocateRaw(
+          Allocator::kAllocatorAlignment, hbm_feat_desc.data_bytes());
+      memcpy(dram_value_ptr,
+             it->value().ToString().data(),
+             feat_desc_->data_bytes());
+      hbm_feat_desc.SetFreq(
+          value_ptr, feat_desc_->GetFreq(dram_value_ptr));
+      hbm_feat_desc.UpdateVersion(
+          value_ptr, feat_desc_->GetVersion(dram_value_ptr));
+      value_ptr_list->emplace_back(value_ptr);
+    }
+    delete it;
+    feat_desc_->Deallocate(dram_value_ptr);
+    return Status::OK();
+  }
+
   int64 Size() const override {
     return counter_->size();
   }
diff --git a/tensorflow/core/framework/embedding/multi_tier_storage.h b/tensorflow/core/framework/embedding/multi_tier_storage.h
index 7955322aca6..f77fec8c85a 100644
--- a/tensorflow/core/framework/embedding/multi_tier_storage.h
+++ b/tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -87,6 +87,14 @@ class MultiTierStorage : public Storage {
   Status GetSnapshot(std::vector<K>* key_list,
                      std::vector<void*>* value_ptr_list) override {
     LOG(FATAL)<<"Can't get snapshot of MultiTierStorage.";
+    return Status::OK();
+  }
+
+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    LOG(FATAL)<<"Can't get sharded snapshot of MultiTierStorage.";
+    return Status::OK();
   }

   void CopyEmbeddingsFromCPUToGPU(
@@ -170,7 +178,6 @@ class MultiTierStorage : public Storage {
     });
   }

- protected:
   Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
       int64 partition_num, int64 value_len, bool is_filter,
       bool is_incr, const EmbeddingConfig& emb_config,
diff --git a/tensorflow/core/framework/embedding/nullable_filter_policy.h b/tensorflow/core/framework/embedding/nullable_filter_policy.h
index 7e3ace0063d..55f718d7ca4 100644
--- a/tensorflow/core/framework/embedding/nullable_filter_policy.h
+++ b/tensorflow/core/framework/embedding/nullable_filter_policy.h
@@ -150,7 +150,7 @@ class NullableFilterPolicy : public FilterPolicy {
       // this can describe by graph(Mod + DynamicPartition),
       // but memory waste and slow
       if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
-        LOG(INFO) << "skip EV key:" << *(key_buff + i);
+        VLOG(1) << "skip EV key:" << *(key_buff + i);
         continue;
       }
       int64 import_freq = 0;
diff --git a/tensorflow/core/framework/embedding/single_tier_storage.h b/tensorflow/core/framework/embedding/single_tier_storage.h
index be08afd7f50..db96c807c5e 100644
--- a/tensorflow/core/framework/embedding/single_tier_storage.h
+++ b/tensorflow/core/framework/embedding/single_tier_storage.h
@@ -223,6 +223,14 @@ class SingleTierStorage : public Storage {
     return kv_->GetSnapshot(key_list, value_ptr_list);
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    mutex_lock l(Storage<K, V>::mu_);
+    return kv_->GetShardedSnapshot(key_list, value_ptr_list,
+                                   partition_id, partition_nums);
+  }
+
   Status Save(
       const std::string& tensor_name,
       const std::string& prefix,
@@ -286,7 +294,7 @@ class SingleTierStorage : public Storage {
   FeatureDescriptor<V>* feature_descriptor() {
     return feat_desc_;
   }
- protected:
+
   virtual Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
       int64 partition_num, int64 value_len, bool is_filter,
       bool is_incr, const EmbeddingConfig& emb_config,
@@ -298,7 +306,8 @@ class SingleTierStorage : public Storage {
         false/*to_dram*/, is_incr, restore_buff);
     return s;
   }
-
+
+ protected:
   virtual void Shrink(std::vector<K>& key_list,
                       std::vector<void*>& value_ptr_list,
                       ShrinkArgs& shrink_args,
diff --git a/tensorflow/core/framework/embedding/ssd_hash_kv.h b/tensorflow/core/framework/embedding/ssd_hash_kv.h
index f51c6904a50..a56c9f73385 100644
--- a/tensorflow/core/framework/embedding/ssd_hash_kv.h
+++ b/tensorflow/core/framework/embedding/ssd_hash_kv.h
@@ -349,6 +349,12 @@ class SSDHashKV : public KVInterface {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    return Status::OK();
+  }
+
   Status GetSnapshot(
       std::vector<K>* key_list,
       std::vector* file_list) {
diff --git a/tensorflow/core/framework/embedding/storage.h b/tensorflow/core/framework/embedding/storage.h
index 1ffb435054b..a652de5fa5f 100644
--- a/tensorflow/core/framework/embedding/storage.h
+++ b/tensorflow/core/framework/embedding/storage.h
@@ -95,6 +95,9 @@ class Storage {
   virtual int64 Size(int level) const = 0;
   virtual Status GetSnapshot(std::vector<K>* key_list,
                              std::vector<void*>* value_ptr_list) = 0;
+  virtual Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) = 0;
   virtual Status Save(
       const string& tensor_name,
       const string& prefix,
@@ -197,7 +200,6 @@ class Storage {
       int64 freq, int64 version, int emb_index) = 0;

- protected:
   virtual Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
       int64 partition_num, int64 value_len, bool is_filter,
       bool is_incr, const EmbeddingConfig& emb_config,
@@ -206,7 +208,8 @@ class Storage {
       RestoreBuffer& restore_buff) {
     return Status::OK();
   }
-
+
+ protected:
   virtual Status RestoreSSD(int64 emb_index, int64 emb_slot_num,
       int64 value_len,
       const std::string& ssd_emb_file_name,
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 08445403b58..6878c5f8350 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -9,6 +9,11 @@ load(
     "transitive_hdrs",
 )

+load(
+    "//tensorflow/core/platform:default/build_config.bzl",
+    "tf_additional_elastic_server_lib_defines",
+)
+
 package(
     default_visibility = ["//visibility:public"],
     licenses = ["notice"],  # Apache 2.0
@@ -1119,6 +1124,7 @@ tf_kernel_library(
     name = "iterator_ops",
     srcs = ["iterator_ops.cc"],
     hdrs = ["iterator_ops.h"],
+    defines = tf_additional_elastic_server_lib_defines(),
     deps = [
         ":captured_function",
         ":dataset_utils",
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 08d9d936537..ed6b40a38a0 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -308,7 +308,11 @@ void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
     }
     ResourceMgr* mgr = context->resource_manager();
-    OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+#ifdef TENSORFLOW_USE_ELASTIC_SERVER
+    OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), true));
+#else
+    OP_REQUIRES_OK(context, cinfo_.Init(mgr, def(), false));
+#endif

     IteratorResource* resource;
     OP_REQUIRES_OK(
@@ -783,7 +787,11 @@ class OneShotIteratorOp : public AsyncOpKernel {
   Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
                  ContainerInfo* cinfo) {
-    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));
+#ifdef TENSORFLOW_USE_ELASTIC_SERVER
+    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), true));
+#else
+    TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def(), false));
+#endif

     FunctionLibraryRuntime* flr;
     std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index cb2b7bb8154..e239c9ba8d5 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -44,6 +44,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import tf_export

+SAVED_PARTITIONED_NUM = 1000

 def _clip(params, ids, max_norm):
   """Helper function for _embedding_lookup_and_transform.
@@ -216,7 +217,7 @@ def _embedding_lookup_and_transform(params,
       if isinstance(params[0], kv_variable_ops.EmbeddingVariable):
         new_ids = flat_ids
-        p_assignments = flat_ids % 1000 % np
+        p_assignments = flat_ids % SAVED_PARTITIONED_NUM % np
       elif partition_strategy == "mod":
         p_assignments = flat_ids % np
         new_ids = flat_ids // np