[Embedding] Refine KVInterface::GetShardedSnapshot API. (#953)
Signed-off-by: 泊霆 <[email protected]>
Mesilenceki authored Dec 5, 2023
1 parent d814969 commit 7ce8477
Showing 10 changed files with 37 additions and 24 deletions.
14 changes: 8 additions & 6 deletions tensorflow/core/framework/embedding/cpu_hash_map_kv.h
@@ -138,7 +138,8 @@ class LocklessHashMap : public KVInterface<K, V> {
   }
 
   Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) override {
     std::pair<const K, void*> *hash_map_dump;
     int64 bucket_count;
@@ -147,11 +148,12 @@ class LocklessHashMap : public KVInterface<K, V> {
       bucket_count = it.second;
       for (int64 j = 0; j < bucket_count; j++) {
         if (hash_map_dump[j].first != LocklessHashMap<K, V>::EMPTY_KEY_
-            && hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_
-            && hash_map_dump[j].first % kSavedPartitionNum
-                % partition_nums != partition_id) {
-          key_list->emplace_back(hash_map_dump[j].first);
-          value_ptr_list->emplace_back(hash_map_dump[j].second);
+            && hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_) {
+          int part_id = hash_map_dump[j].first % kSavedPartitionNum % partition_nums;
+          if (part_id != partition_id) {
+            key_list[part_id].emplace_back(hash_map_dump[j].first);
+            value_ptr_list[part_id].emplace_back(hash_map_dump[j].second);
+          }
         }
       }
 
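Across every backend touched by this commit, the destination shard of a key is computed as key % kSavedPartitionNum % partition_nums; keys that already map to the calling partition_id are skipped, and every other entry is appended to the bucket of the shard that should own it. A minimal sketch of that routing rule, assuming an illustrative value for kSavedPartitionNum and plain int64_t keys (both are assumptions for this sketch, not values taken from the headers above):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Illustrative stand-in; the real constant is defined in the embedding headers.
    constexpr int64_t kSavedPartitionNum = 1000;

    // Bucket keys by destination shard, skipping keys already owned by
    // `partition_id`, mirroring the loop bodies of the GetShardedSnapshot
    // overrides in this commit.
    std::vector<std::vector<int64_t>> ShardKeys(const std::vector<int64_t>& keys,
                                                int partition_id, int partition_nums) {
      std::vector<std::vector<int64_t>> key_list(partition_nums);
      for (int64_t key : keys) {
        int part_id = key % kSavedPartitionNum % partition_nums;
        if (part_id != partition_id) {
          key_list[part_id].push_back(key);
        }
      }
      return key_list;
    }

    int main() {
      auto buckets = ShardKeys({11, 12, 13, 14, 15, 16, 17, 18},
                               /*partition_id=*/0, /*partition_nums=*/4);
      for (int p = 0; p < static_cast<int>(buckets.size()); ++p) {
        std::cout << "shard " << p << ": " << buckets[p].size() << " keys\n";
      }
      return 0;
    }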
10 changes: 6 additions & 4 deletions tensorflow/core/framework/embedding/dense_hash_map_kv.h
@@ -122,7 +122,8 @@ class DenseHashMap : public KVInterface<K, V> {
   }
 
   Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) override {
     dense_hash_map hash_map_dump[partition_num_];
     for (int i = 0; i< partition_num_; i++) {
@@ -131,9 +132,10 @@ class DenseHashMap : public KVInterface<K, V> {
     }
     for (int i = 0; i< partition_num_; i++) {
       for (const auto it : hash_map_dump[i].hash_map) {
-        if (it.first % kSavedPartitionNum % partition_nums != partition_id) {
-          key_list->push_back(it.first);
-          value_ptr_list->push_back(it.second);
+        int part_id = it.first % kSavedPartitionNum % partition_nums;
+        if (part_id != partition_id) {
+          key_list[part_id].emplace_back(it.first);
+          value_ptr_list[part_id].emplace_back(it.second);
         }
       }
     }
9 changes: 5 additions & 4 deletions tensorflow/core/framework/embedding/embedding_var.h
@@ -520,8 +520,8 @@ class EmbeddingVar : public ResourceBase {
     }
   }
 
-  Status GetShardedSnapshot(std::vector<K>* key_list,
-                            std::vector<void*>* value_ptr_list,
+  Status GetShardedSnapshot(std::vector<std::vector<K>>& key_list,
+                            std::vector<std::vector<void*>>& value_ptr_list,
                             int partition_id, int partition_num) {
     return storage_->GetShardedSnapshot(key_list, value_ptr_list,
                                         partition_id, partition_num);
@@ -546,7 +546,7 @@ class EmbeddingVar : public ResourceBase {
       bool is_admit = feat_desc_->IsAdmit(value_ptr);
       bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0);
 
-      if (!is_admit) {
+      if (is_admit) {
         key_list[i] = tot_keys_list[i];
 
         if (!is_in_dram) {
@@ -571,7 +571,7 @@ class EmbeddingVar : public ResourceBase {
         }
       } else {
         if (!save_unfiltered_features)
-          return;
+          continue;
        //TODO(JUNQI) : currently not export filtered keys
       }
 
@@ -584,6 +584,7 @@ class EmbeddingVar : public ResourceBase {
        feat_desc_->Deallocate(value_ptr);
      }
    }
+    return;
  }
 
  Status RestoreFromKeysAndValues(int64 key_num, int partition_id,
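Besides the signature change, the embedding_var.h hunks fix the export loop's control flow: the admission check was inverted (if (!is_admit) became if (is_admit)), and a feature rejected by the filter now skips only its own iteration with continue instead of aborting the whole export with return. A stripped-down sketch of the corrected shape, where the Feature struct and loop setup are hypothetical stand-ins rather than code from the repository:

    #include <vector>

    // Hypothetical stand-in for an exported feature entry; only `admitted` matters here.
    struct Feature { bool admitted; };

    void ExportFeatures(const std::vector<Feature>& features,
                        bool save_unfiltered_features) {
      for (const auto& feature : features) {
        bool is_admit = feature.admitted;
        if (is_admit) {
          // ... copy the key and embedding value into the output buffers ...
        } else {
          if (!save_unfiltered_features)
            continue;  // previously `return;`, which dropped all remaining features
          // ... filtered keys are currently not exported (see the TODO above) ...
        }
      }
    }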
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/gpu_hash_map_kv.h
@@ -253,7 +253,8 @@ class GPUHashMapKV : public KVInterface<K, V> {
   }
 
   Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) override {
     LOG(INFO) << "GPUHashMapKV do not support GetShardedSnapshot";
     return Status::OK();
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/kv_interface.h
@@ -91,7 +91,8 @@ class KVInterface {
       std::vector<void*>* value_ptr_list) = 0;
 
   virtual Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) = 0;
 
   virtual std::string DebugString() const = 0;
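On the caller side, the refined interface hands back one bucket per destination shard: since the implementations index key_list[part_id] directly, the outer vectors presumably need to be sized to partition_nums before the call. A hedged usage sketch, written as if inside the embedding namespace with the usual TensorFlow headers available; SnapshotForResharding and the serialization step are assumptions for illustration, not functions from the repository:

    // Assumes `kv` points at a concrete KVInterface<int64, float> backend such
    // as LocklessHashMap; error handling and serialization are only sketched.
    Status SnapshotForResharding(KVInterface<int64, float>* kv,
                                 int partition_id, int partition_nums) {
      // One output bucket per destination shard, pre-sized by the caller.
      std::vector<std::vector<int64>> key_list(partition_nums);
      std::vector<std::vector<void*>> value_ptr_list(partition_nums);

      TF_RETURN_IF_ERROR(kv->GetShardedSnapshot(key_list, value_ptr_list,
                                                 partition_id, partition_nums));

      // Entries already owned by `partition_id` are not returned; every other
      // bucket holds the keys and value pointers that should move to that shard.
      for (int p = 0; p < partition_nums; ++p) {
        // ... write key_list[p] / value_ptr_list[p] into the checkpoint for shard p ...
      }
      return Status::OK();
    }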
10 changes: 6 additions & 4 deletions tensorflow/core/framework/embedding/leveldb_kv.h
@@ -194,7 +194,8 @@ class LevelDBKV : public KVInterface<K, V> {
   }
 
   Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) override {
     ReadOptions options;
     options.snapshot = db_->GetSnapshot();
@@ -203,8 +204,9 @@ class LevelDBKV : public KVInterface<K, V> {
     for (it->SeekToFirst(); it->Valid(); it->Next()) {
       K key;
       memcpy((char*)&key, it->key().ToString().data(), sizeof(K));
-      if (key % kSavedPartitionNum % partition_nums == partition_id) continue;
-      key_list->emplace_back(key);
+      int part_id = key % kSavedPartitionNum % partition_nums;
+      if (part_id == partition_id) continue;
+      key_list[part_id].emplace_back(key);
       FeatureDescriptor<V> hbm_feat_desc(
           1, 1, ev_allocator()/*useless*/,
           StorageType::HBM_DRAM, true, true,
@@ -218,7 +220,7 @@ class LevelDBKV : public KVInterface<K, V> {
           value_ptr, feat_desc_->GetFreq(dram_value_ptr));
       hbm_feat_desc.UpdateVersion(
           value_ptr, feat_desc_->GetVersion(dram_value_ptr));
-      value_ptr_list->emplace_back(value_ptr);
+      value_ptr_list[part_id].emplace_back(value_ptr);
     }
     delete it;
     feat_desc_->Deallocate(dram_value_ptr);
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -91,7 +91,8 @@ class MultiTierStorage : public Storage<K, V> {
   }
 
   Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) override {
     LOG(FATAL)<<"Can't get sharded snapshot of MultiTierStorage.";
     return Status::OK();
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/single_tier_storage.h
@@ -224,7 +224,8 @@ class SingleTierStorage : public Storage<K, V> {
   }
 
   Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) override {
     mutex_lock l(Storage<K, V>::mu_);
     return kv_->GetShardedSnapshot(key_list, value_ptr_list,
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/ssd_hash_kv.h
@@ -350,7 +350,8 @@ class SSDHashKV : public KVInterface<K, V> {
   }
 
   Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) override {
     return Status::OK();
   }
3 changes: 2 additions & 1 deletion tensorflow/core/framework/embedding/storage.h
@@ -96,7 +96,8 @@ class Storage {
   virtual Status GetSnapshot(std::vector<K>* key_list,
       std::vector<void*>* value_ptr_list) = 0;
   virtual Status GetShardedSnapshot(
-      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      std::vector<std::vector<K>>& key_list,
+      std::vector<std::vector<void*>>& value_ptr_list,
       int partition_id, int partition_nums) = 0;
   virtual Status Save(
       const string& tensor_name,
