diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index df6ae6f1277..c0d26a2f4d8 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -140,13 +140,6 @@ class EmbeddingVar : public ResourceBase { return storage_->Get(key, value_ptr); } - void BatchLookupKey(const EmbeddingVarContext& ctx, - const K* keys, - void** value_ptr_list, - int64 num_of_keys) { - storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys); - } - Status LookupOrCreateKey(K key, void** value_ptr, bool* is_filter, bool indices_as_pointer, int64 count = 1) { @@ -167,45 +160,6 @@ class EmbeddingVar : public ResourceBase { return Status::OK(); } - Status LookupOrCreateKey(const EmbeddingVarContext& context, - const K* keys, - void** value_ptrs, - int64 num_of_keys, - int64* indices_counts, - bool indices_as_pointer = false) { - if (indices_as_pointer) { - auto lookup_key_and_set_version_fn = [keys, value_ptrs] - (int64 start, int64 limit) { - for (int i = start; i < limit; i++) { - value_ptrs[i] = (void*)keys[i]; - } - }; - const int64 unit_cost = 1000; //very unreliable estimate for cost per step. - auto worker_threads = context.worker_threads; - Shard(worker_threads->num_threads, - worker_threads->workers, num_of_keys, unit_cost, - lookup_key_and_set_version_fn); - } else { - filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys); - } - - if (indices_counts != nullptr) { - auto add_freq_fn = [this, value_ptrs, indices_counts] - (int64 start, int64 limit) { - for (int i = start; i < limit; i++) { - feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]); - } - }; - const int64 unit_cost = 1000; //very unreliable estimate for cost per step. - auto worker_threads = context.worker_threads; - Shard(worker_threads->num_threads, - worker_threads->workers, num_of_keys, unit_cost, - add_freq_fn); - } - return Status::OK(); - } - - Status LookupOrCreateKey(K key, void** value_ptr) { Status s = storage_->GetOrCreate(key, value_ptr); TF_CHECK_OK(s); @@ -402,6 +356,51 @@ class EmbeddingVar : public ResourceBase { storage_->AddToCache(keys_tensor); } + + void BatchLookupKey(const EmbeddingVarContext& ctx, + const K* keys, + void** value_ptr_list, + int64 num_of_keys) { + storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys); + } + + Status LookupOrCreateKey(const EmbeddingVarContext& context, + const K* keys, + void** value_ptrs, + int64 num_of_keys, + int64* indices_counts, + bool indices_as_pointer = false) { + if (indices_as_pointer) { + auto lookup_key_and_set_version_fn = [keys, value_ptrs] + (int64 start, int64 limit) { + for (int i = start; i < limit; i++) { + value_ptrs[i] = (void*)keys[i]; + } + }; + const int64 unit_cost = 1000; //very unreliable estimate for cost per step. + auto worker_threads = context.worker_threads; + Shard(worker_threads->num_threads, + worker_threads->workers, num_of_keys, unit_cost, + lookup_key_and_set_version_fn); + } else { + filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys); + } + + if (indices_counts != nullptr) { + auto add_freq_fn = [this, value_ptrs, indices_counts] + (int64 start, int64 limit) { + for (int i = start; i < limit; i++) { + feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]); + } + }; + const int64 unit_cost = 1000; //very unreliable estimate for cost per step. + auto worker_threads = context.worker_threads; + Shard(worker_threads->num_threads, + worker_threads->workers, num_of_keys, unit_cost, + add_freq_fn); + } + return Status::OK(); + } #endif #if GOOGLE_CUDA