diff --git a/tensorflow/core/framework/embedding/cpu_hash_map_kv.h b/tensorflow/core/framework/embedding/cpu_hash_map_kv.h index 750ba282285..f9a6e1fff25 100644 --- a/tensorflow/core/framework/embedding/cpu_hash_map_kv.h +++ b/tensorflow/core/framework/embedding/cpu_hash_map_kv.h @@ -138,7 +138,8 @@ class LocklessHashMap : public KVInterface { } Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) override { std::pair *hash_map_dump; int64 bucket_count; @@ -147,11 +148,12 @@ class LocklessHashMap : public KVInterface { bucket_count = it.second; for (int64 j = 0; j < bucket_count; j++) { if (hash_map_dump[j].first != LocklessHashMap::EMPTY_KEY_ - && hash_map_dump[j].first != LocklessHashMap::DELETED_KEY_ - && hash_map_dump[j].first % kSavedPartitionNum - % partition_nums != partition_id) { - key_list->emplace_back(hash_map_dump[j].first); - value_ptr_list->emplace_back(hash_map_dump[j].second); + && hash_map_dump[j].first != LocklessHashMap::DELETED_KEY_) { + int part_id = hash_map_dump[j].first % kSavedPartitionNum % partition_nums; + if (part_id != partition_id) { + key_list[part_id].emplace_back(hash_map_dump[j].first); + value_ptr_list[part_id].emplace_back(hash_map_dump[j].second); + } } } diff --git a/tensorflow/core/framework/embedding/dense_hash_map_kv.h b/tensorflow/core/framework/embedding/dense_hash_map_kv.h index 8a27404b66f..12749a92e6e 100644 --- a/tensorflow/core/framework/embedding/dense_hash_map_kv.h +++ b/tensorflow/core/framework/embedding/dense_hash_map_kv.h @@ -122,7 +122,8 @@ class DenseHashMap : public KVInterface { } Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) override { dense_hash_map hash_map_dump[partition_num_]; for (int i = 0; i< partition_num_; i++) { @@ -131,9 +132,10 @@ class DenseHashMap : public KVInterface { } for (int i = 0; i< partition_num_; i++) { for (const auto it : hash_map_dump[i].hash_map) { - if (it.first % kSavedPartitionNum % partition_nums != partition_id) { - key_list->push_back(it.first); - value_ptr_list->push_back(it.second); + int part_id = it.first % kSavedPartitionNum % partition_nums; + if (part_id != partition_id) { + key_list[part_id].emplace_back(it.first); + value_ptr_list[part_id].emplace_back(it.second); } } } diff --git a/tensorflow/core/framework/embedding/embedding_var.h b/tensorflow/core/framework/embedding/embedding_var.h index a66ec19fb97..df6ae6f1277 100644 --- a/tensorflow/core/framework/embedding/embedding_var.h +++ b/tensorflow/core/framework/embedding/embedding_var.h @@ -520,8 +520,8 @@ class EmbeddingVar : public ResourceBase { } } - Status GetShardedSnapshot(std::vector* key_list, - std::vector* value_ptr_list, + Status GetShardedSnapshot(std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_num) { return storage_->GetShardedSnapshot(key_list, value_ptr_list, partition_id, partition_num); @@ -546,7 +546,7 @@ class EmbeddingVar : public ResourceBase { bool is_admit = feat_desc_->IsAdmit(value_ptr); bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0); - if (!is_admit) { + if (is_admit) { key_list[i] = tot_keys_list[i]; if (!is_in_dram) { @@ -571,7 +571,7 @@ class EmbeddingVar : public ResourceBase { } } else { if (!save_unfiltered_features) - return; + continue; //TODO(JUNQI) : currently not export filtered keys } @@ -584,6 +584,7 @@ class EmbeddingVar : public ResourceBase { feat_desc_->Deallocate(value_ptr); } } + return; } Status RestoreFromKeysAndValues(int64 key_num, int partition_id, diff --git a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h index e73839e3f76..68fecf690ba 100644 --- a/tensorflow/core/framework/embedding/gpu_hash_map_kv.h +++ b/tensorflow/core/framework/embedding/gpu_hash_map_kv.h @@ -253,7 +253,8 @@ class GPUHashMapKV : public KVInterface { } Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) override { LOG(INFO) << "GPUHashMapKV do not support GetShardedSnapshot"; return Status::OK(); diff --git a/tensorflow/core/framework/embedding/kv_interface.h b/tensorflow/core/framework/embedding/kv_interface.h index dc603680138..8480132a7d9 100644 --- a/tensorflow/core/framework/embedding/kv_interface.h +++ b/tensorflow/core/framework/embedding/kv_interface.h @@ -91,7 +91,8 @@ class KVInterface { std::vector* value_ptr_list) = 0; virtual Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) = 0; virtual std::string DebugString() const = 0; diff --git a/tensorflow/core/framework/embedding/leveldb_kv.h b/tensorflow/core/framework/embedding/leveldb_kv.h index 47c8a39dfbd..030a0969e5d 100644 --- a/tensorflow/core/framework/embedding/leveldb_kv.h +++ b/tensorflow/core/framework/embedding/leveldb_kv.h @@ -194,7 +194,8 @@ class LevelDBKV : public KVInterface { } Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) override { ReadOptions options; options.snapshot = db_->GetSnapshot(); @@ -203,8 +204,9 @@ class LevelDBKV : public KVInterface { for (it->SeekToFirst(); it->Valid(); it->Next()) { K key; memcpy((char*)&key, it->key().ToString().data(), sizeof(K)); - if (key % kSavedPartitionNum % partition_nums == partition_id) continue; - key_list->emplace_back(key); + int part_id = key % kSavedPartitionNum % partition_nums; + if (part_id == partition_id) continue; + key_list[part_id].emplace_back(key); FeatureDescriptor hbm_feat_desc( 1, 1, ev_allocator()/*useless*/, StorageType::HBM_DRAM, true, true, @@ -218,7 +220,7 @@ class LevelDBKV : public KVInterface { value_ptr, feat_desc_->GetFreq(dram_value_ptr)); hbm_feat_desc.UpdateVersion( value_ptr, feat_desc_->GetVersion(dram_value_ptr)); - value_ptr_list->emplace_back(value_ptr); + value_ptr_list[part_id].emplace_back(value_ptr); } delete it; feat_desc_->Deallocate(dram_value_ptr); diff --git a/tensorflow/core/framework/embedding/multi_tier_storage.h b/tensorflow/core/framework/embedding/multi_tier_storage.h index f77fec8c85a..e27521f1a65 100644 --- a/tensorflow/core/framework/embedding/multi_tier_storage.h +++ b/tensorflow/core/framework/embedding/multi_tier_storage.h @@ -91,7 +91,8 @@ class MultiTierStorage : public Storage { } Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) override { LOG(FATAL)<<"Can't get sharded snapshot of MultiTierStorage."; return Status::OK(); diff --git a/tensorflow/core/framework/embedding/single_tier_storage.h b/tensorflow/core/framework/embedding/single_tier_storage.h index db96c807c5e..1c6bdd90790 100644 --- a/tensorflow/core/framework/embedding/single_tier_storage.h +++ b/tensorflow/core/framework/embedding/single_tier_storage.h @@ -224,7 +224,8 @@ class SingleTierStorage : public Storage { } Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) override { mutex_lock l(Storage::mu_); return kv_->GetShardedSnapshot(key_list, value_ptr_list, diff --git a/tensorflow/core/framework/embedding/ssd_hash_kv.h b/tensorflow/core/framework/embedding/ssd_hash_kv.h index a56c9f73385..bdc38cc5d5e 100644 --- a/tensorflow/core/framework/embedding/ssd_hash_kv.h +++ b/tensorflow/core/framework/embedding/ssd_hash_kv.h @@ -350,7 +350,8 @@ class SSDHashKV : public KVInterface { } Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) override { return Status::OK(); } diff --git a/tensorflow/core/framework/embedding/storage.h b/tensorflow/core/framework/embedding/storage.h index a652de5fa5f..559588af7e1 100644 --- a/tensorflow/core/framework/embedding/storage.h +++ b/tensorflow/core/framework/embedding/storage.h @@ -96,7 +96,8 @@ class Storage { virtual Status GetSnapshot(std::vector* key_list, std::vector* value_ptr_list) = 0; virtual Status GetShardedSnapshot( - std::vector* key_list, std::vector* value_ptr_list, + std::vector>& key_list, + std::vector>& value_ptr_list, int partition_id, int partition_nums) = 0; virtual Status Save( const string& tensor_name,