
Commit

[Embedding] Add interface of EmbeddingVar for Elastic Training.
Signed-off-by: JunqiHu <[email protected]>
Mesilenceki committed Oct 19, 2023
1 parent be62ec3 commit eb68cfd
Showing 15 changed files with 219 additions and 15 deletions.
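Taken together, the new EmbeddingVar methods in this commit form a re-sharding path for elastic training: GetShardedSnapshot collects the features that no longer hash to the local partition, ExportAndRemove flattens them into caller-owned buffers and deletes them locally, and RestoreFromKeysAndValues replays those buffers on the worker that now owns them. The following is a minimal sketch of how the three calls might be chained; the driver function RepartitionEmbedding, the buffer sizing via ValueLen(), and the omitted inter-worker transport are assumptions for illustration, not part of this commit.

// Hypothetical driver chaining the new interfaces (illustration only).
template <typename K, typename V>
Status RepartitionEmbedding(EmbeddingVar<K, V>* ev,
                            int partition_id, int partition_num) {
  // 1. Collect the live features that now belong to another partition.
  std::vector<K> keys;
  std::vector<void*> value_ptrs;
  TF_RETURN_IF_ERROR(ev->GetShardedSnapshot(&keys, &value_ptrs,
                                            partition_id, partition_num));

  // 2. Flatten them into caller-owned buffers and remove them locally.
  //    One slot per key; value_len_ elements per key for the payload.
  std::vector<K> key_buf(keys.size());
  std::vector<V> value_buf(keys.size() * ev->ValueLen());
  std::vector<int64> version_buf(keys.size(), -1);
  std::vector<int64> freq_buf(keys.size(), 0);
  ev->ExportAndRemove(key_buf.data(), value_buf.data(),
                      version_buf.data(), freq_buf.data(),
                      keys, value_ptrs);

  // 3. In a real elastic-training job the buffers would be shipped to the
  //    worker that now owns these keys before being restored there.
  return ev->RestoreFromKeysAndValues(keys.size(), partition_id,
                                      partition_num, key_buf.data(),
                                      value_buf.data(), version_buf.data(),
                                      freq_buf.data());
}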
2 changes: 1 addition & 1 deletion tensorflow/core/framework/embedding/bloom_filter_policy.h
@@ -333,7 +333,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
// this can describe by graph(Mod + DynamicPartition),
// but memory waste and slow
if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
LOG(INFO) << "skip EV key:" << *(key_buff + i);
VLOG(1) << "skip EV key:" << *(key_buff + i);
continue;
}
void* value_ptr = nullptr;
2 changes: 1 addition & 1 deletion tensorflow/core/framework/embedding/counter_filter_policy.h
@@ -159,7 +159,7 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
// this can describe by graph(Mod + DynamicPartition),
// but memory waste and slow
if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
LOG(INFO) << "skip EV key:" << *(key_buff + i);
VLOG(1) << "skip EV key:" << *(key_buff + i);
continue;
}
int64 import_freq = 0;
21 changes: 21 additions & 0 deletions tensorflow/core/framework/embedding/cpu_hash_map_kv.h
@@ -137,6 +137,27 @@ class LocklessHashMap : public KVInterface<K, V> {
return Status::OK();
}

Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) override {
std::pair<const K, void*> *hash_map_dump;
int64 bucket_count;
auto it = hash_map_.GetSnapshot();
hash_map_dump = it.first;
bucket_count = it.second;
for (int64 j = 0; j < bucket_count; j++) {
if (hash_map_dump[j].first != LocklessHashMap<K, V>::EMPTY_KEY_
&& hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_
&& hash_map_dump[j].first % 1000 % partition_nums != partition_id) {
key_list->emplace_back(hash_map_dump[j].first);
value_ptr_list->emplace_back(hash_map_dump[j].second);
}
}

free(hash_map_dump);
return Status::OK();
}

std::string DebugString() const override {
LOG(INFO) << "map info size:" << Size()
<< "map info bucket_count:" << hash_map_.bucket_count()
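The selection predicate in GetShardedSnapshot above keeps exactly the live entries (neither empty nor deleted) whose bucket is now owned by a different partition, so the caller receives the features it has to hand off. Below is a small worked example of that predicate; the literal 1000 presumably mirrors the kSavedPartitionNum bucket count used on the restore side.

// Sketch of the sharding test used above: a key is exported when the bucket
// it hashes into (key % 1000) belongs to a different partition.
template <typename K>
bool ShouldExport(K key, int partition_id, int partition_nums) {
  return key % 1000 % partition_nums != partition_id;
}

// Example with partition_nums = 4, partition_id = 1:
//   key = 13 -> 13 % 1000 = 13, 13 % 4 = 1 -> stays on this partition
//   key = 14 -> 14 % 1000 = 14, 14 % 4 = 2 -> exported for migration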
19 changes: 19 additions & 0 deletions tensorflow/core/framework/embedding/dense_hash_map_kv.h
@@ -121,6 +121,25 @@ class DenseHashMap : public KVInterface<K, V> {
return Status::OK();
}

Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) override {
dense_hash_map hash_map_dump[partition_num_];
for (int i = 0; i< partition_num_; i++) {
spin_rd_lock l(hash_map_[i].mu);
hash_map_dump[i].hash_map = hash_map_[i].hash_map;
}
for (int i = 0; i< partition_num_; i++) {
for (const auto it : hash_map_dump[i].hash_map) {
if (it.first % 1000 % partition_nums != partition_id) {
key_list->push_back(it.first);
value_ptr_list->push_back(it.second);
}
}
}
return Status::OK();
}

std::string DebugString() const override {
return "";
}
86 changes: 85 additions & 1 deletion tensorflow/core/framework/embedding/embedding_var.h
@@ -435,6 +435,10 @@ class EmbeddingVar : public ResourceBase {
return storage_->CacheSize();
}

int64 MemoryUsage() const {
return storage_->Size() * (sizeof(K) + feat_desc_->data_bytes());
}

int64 MinFreq() {
return emb_config_.filter_freq;
}
@@ -516,6 +520,85 @@
}
}

Status GetShardedSnapshot(std::vector<K>* key_list,
std::vector<void*>* value_ptr_list,
int partition_id, int partition_num) {
return storage_->GetShardedSnapshot(key_list, value_ptr_list,
partition_id, partition_num);
}

void ExportAndRemove(K* key_list, V* value_list,
int64* version_list, int64* freq_list,
std::vector<K>& tot_keys_list,
std::vector<void*>& tot_value_ptr_list) {
bool save_unfiltered_features = true;
TF_CHECK_OK(ReadBoolFromEnvVar(
"TF_EV_SAVE_FILTERED_FEATURES", true, &save_unfiltered_features));

bool is_save_freq = emb_config_.is_save_freq();
bool is_save_version = emb_config_.is_save_version();

for (int64 i = 0; i < tot_keys_list.size(); ++i) {
auto& value_ptr = tot_value_ptr_list[i];
if((int64)value_ptr == embedding::ValuePtrStatus::IS_DELETED)
continue;

bool is_admit = feat_desc_->IsAdmit(value_ptr);
bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0);

if (!is_admit) {
key_list[i] = tot_keys_list[i];

if (!is_in_dram) {
auto tmp_value = value_list + i * value_len_;
tmp_value = (V*)embedding::ValuePtrStatus::NOT_IN_DRAM;
value_ptr = (void*)((int64)value_ptr & ((1L << kDramFlagOffset) - 1));
} else if (feat_desc_->GetEmbedding(value_ptr, 0) == nullptr) {
memcpy(value_list + i * value_len_, default_value_, sizeof(V) * value_len_);
} else {
V* val = feat_desc_->GetEmbedding(value_ptr, emb_config_.emb_index);
memcpy(value_list + i * value_len_, val, sizeof(V) * value_len_);
}

if(is_save_version) {
int64 dump_version = feat_desc_->GetVersion(value_ptr);
version_list[i] = dump_version;
}

if(is_save_freq) {
int64 dump_freq = feat_desc_->GetFreq(value_ptr);
freq_list[i] = dump_freq;
}
} else {
if (!save_unfiltered_features)
return;
//TODO(JUNQI) : currently not export filtered keys
}

if (emb_config_.is_primary()) {
Status s;
s = storage_->Remove(tot_keys_list[i]);
if (!s.ok()) {
LOG(ERROR) << "Remove keys error: " << s.error_message();
}
feat_desc_->Deallocate(value_ptr);
}
}
}

Status RestoreFromKeysAndValues(int64 key_num, int partition_id,
int partition_num, const K* key_list,
const V* value_list, const int64* version_list,
const int64* freq_list,
const Eigen::GpuDevice* device = nullptr) {
RestoreBuffer restore_buff((char*)key_list, (char*)value_list,
(char*)version_list, (char*)freq_list);
return storage_->RestoreFeatures(key_num, kSavedPartitionNum,
partition_id, partition_num,
value_len_, false/* is_filter*/, false/* is_incr*/,
emb_config_, device, filter_, restore_buff);
}

mutex* mu() {
return &mu_;
}
@@ -537,6 +620,8 @@
}
}

string Name() {return name_; }

V* GetDefaultValuePtr() {
return default_value_;
}
@@ -645,7 +730,6 @@
GPUHashTable<K, V>* HashTable() {
return storage_->HashTable();
}

FilterPolicy<K, V, EmbeddingVar<K, V>>* GetFilter() const {
return filter_;
}
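Two details of ExportAndRemove above are easy to miss: the four output arrays are caller-allocated and indexed in lockstep with tot_keys_list, and the void* entries returned by GetShardedSnapshot can encode status flags rather than plain pointers. The helpers below merely restate that flag decoding; the names are illustrative and not part of the commit.

// Illustrative decoding of the value_ptr states handled inline above.
// IS_DELETED entries are skipped entirely; DRAM residency decides whether
// the embedding payload is copied out, and for non-DRAM entries the flag
// bits above kDramFlagOffset are masked off before deallocation.
inline bool IsDeletedEntry(void* value_ptr) {
  return (int64)value_ptr == embedding::ValuePtrStatus::IS_DELETED;
}

inline bool IsDramResident(void* value_ptr) {
  return ((int64)value_ptr >> kDramFlagOffset) == 0;
}

inline void* StripDramFlag(void* value_ptr) {
  return (void*)((int64)value_ptr & ((1L << kDramFlagOffset) - 1));
}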
20 changes: 16 additions & 4 deletions tensorflow/core/framework/embedding/filter_policy.h
@@ -27,19 +27,31 @@ struct RestoreBuffer {
char* value_buffer = nullptr;
char* version_buffer = nullptr;
char* freq_buffer = nullptr;
bool should_release = false;

explicit RestoreBuffer(size_t buffer_size) {
key_buffer = new char[buffer_size];
value_buffer = new char[buffer_size];
version_buffer = new char[buffer_size];
freq_buffer = new char[buffer_size];
should_release = true;
}

explicit RestoreBuffer(char* i_key_buffer, char* i_value_buffer,
char* i_version_buffer, char* i_freq_buffer) {
key_buffer = i_key_buffer;
value_buffer = i_value_buffer;
version_buffer = i_version_buffer;
freq_buffer = i_freq_buffer;
}

~RestoreBuffer() {
delete []key_buffer;
delete []value_buffer;
delete []version_buffer;
delete []freq_buffer;
if (should_release) {
delete []key_buffer;
delete []value_buffer;
delete []version_buffer;
delete []freq_buffer;
}
}
};

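With the should_release flag added above, RestoreBuffer now works in two modes: the original owning constructor, and a new borrowing constructor that wraps caller-owned arrays. A brief sketch follows; the wrapper function exists only for illustration.

// Illustration of the two RestoreBuffer modes after this change.
void RestoreBufferModes(size_t buffer_size,
                        char* key_list, char* value_list,
                        char* version_list, char* freq_list) {
  // Owning mode: four heap buffers, freed by the destructor because
  // should_release is set to true -- same behavior as before this commit.
  RestoreBuffer owned(buffer_size);

  // Borrowing mode: wraps caller-owned arrays; should_release stays false,
  // so the destructor leaves them alone. EmbeddingVar::RestoreFromKeysAndValues
  // relies on this when it passes user-supplied key/value/version/freq arrays
  // straight through to storage_->RestoreFeatures().
  RestoreBuffer borrowed(key_list, value_list, version_list, freq_list);
}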
7 changes: 7 additions & 0 deletions tensorflow/core/framework/embedding/gpu_hash_map_kv.h
@@ -252,6 +252,13 @@ class GPUHashMapKV : public KVInterface<K, V> {
return Status::OK();
}

Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) override {
LOG(INFO) << "GPUHashMapKV do not support GetShardedSnapshot";
return Status::OK();
}

std::string DebugString() const override { return std::string(); }

GPUHashTable<K, V>* HashTable() override { return hash_table_; }
4 changes: 4 additions & 0 deletions tensorflow/core/framework/embedding/kv_interface.h
@@ -89,6 +89,10 @@ class KVInterface {
virtual Status GetSnapshot(std::vector<K>* key_list,
std::vector<void*>* value_ptr_list) = 0;

virtual Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) = 0;

virtual std::string DebugString() const = 0;

virtual Status BatchLookupOrCreate(const K* keys, V* val, V* default_v,
32 changes: 32 additions & 0 deletions tensorflow/core/framework/embedding/leveldb_kv.h
@@ -193,6 +193,38 @@ class LevelDBKV : public KVInterface<K, V> {
return Status::OK();
}

Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) override {
ReadOptions options;
options.snapshot = db_->GetSnapshot();
leveldb::Iterator* it = db_->NewIterator(options);
void* dram_value_ptr = feat_desc_->Allocate();
for (it->SeekToFirst(); it->Valid(); it->Next()) {
K key;
memcpy((char*)&key, it->key().ToString().data(), sizeof(K));
if (key % 1000 % partition_nums == partition_id) continue;
key_list->emplace_back(key);
FeatureDescriptor<V> hbm_feat_desc(
1, 1, ev_allocator()/*useless*/,
StorageType::HBM_DRAM, true, true,
{false, 0});
void* value_ptr = cpu_allocator()->AllocateRaw(
Allocator::kAllocatorAlignment, hbm_feat_desc.data_bytes());
memcpy(dram_value_ptr,
it->value().ToString().data(),
feat_desc_->data_bytes());
hbm_feat_desc.SetFreq(
value_ptr, feat_desc_->GetFreq(dram_value_ptr));
hbm_feat_desc.UpdateVersion(
value_ptr, feat_desc_->GetVersion(dram_value_ptr));
value_ptr_list->emplace_back(value_ptr);
}
delete it;
feat_desc_->Deallocate(dram_value_ptr);
return Status::OK();
}

int64 Size() const override {
return counter_->size();
}
9 changes: 8 additions & 1 deletion tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -87,6 +87,14 @@ class MultiTierStorage : public Storage<K, V> {
Status GetSnapshot(std::vector<K>* key_list,
std::vector<void*>* value_ptr_list) override {
LOG(FATAL)<<"Can't get snapshot of MultiTierStorage.";
return Status::OK();
}

Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) override {
LOG(FATAL)<<"Can't get sharded snapshot of MultiTierStorage.";
return Status::OK();
}

void CopyEmbeddingsFromCPUToGPU(
@@ -170,7 +178,6 @@
});
}

protected:
Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
int64 partition_num, int64 value_len, bool is_filter,
bool is_incr, const EmbeddingConfig& emb_config,
2 changes: 1 addition & 1 deletion tensorflow/core/framework/embedding/nullable_filter_policy.h
@@ -150,7 +150,7 @@ class NullableFilterPolicy : public FilterPolicy<K, V, EV> {
// this can describe by graph(Mod + DynamicPartition),
// but memory waste and slow
if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
LOG(INFO) << "skip EV key:" << *(key_buff + i);
VLOG(1) << "skip EV key:" << *(key_buff + i);
continue;
}
int64 import_freq = 0;
Expand Down
13 changes: 11 additions & 2 deletions tensorflow/core/framework/embedding/single_tier_storage.h
@@ -223,6 +223,14 @@ class SingleTierStorage : public Storage<K, V> {
return kv_->GetSnapshot(key_list, value_ptr_list);
}

Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) override {
mutex_lock l(Storage<K, V>::mu_);
return kv_->GetShardedSnapshot(key_list, value_ptr_list,
partition_id, partition_nums);
}

Status Save(
const std::string& tensor_name,
const std::string& prefix,
@@ -286,7 +294,7 @@
FeatureDescriptor<V>* feature_descriptor() {
return feat_desc_;
}
protected:

virtual Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
int64 partition_num, int64 value_len, bool is_filter,
bool is_incr, const EmbeddingConfig& emb_config,
@@ -298,7 +306,8 @@
false/*to_dram*/, is_incr, restore_buff);
return s;
}


protected:
virtual void Shrink(std::vector<K>& key_list,
std::vector<void*>& value_ptr_list,
ShrinkArgs& shrink_args,
6 changes: 6 additions & 0 deletions tensorflow/core/framework/embedding/ssd_hash_kv.h
@@ -349,6 +349,12 @@ class SSDHashKV : public KVInterface<K, V> {
return Status::OK();
}

Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) override {
return Status::OK();
}

Status GetSnapshot(
std::vector<K>* key_list,
std::vector<EmbFile*>* file_list) {
7 changes: 5 additions & 2 deletions tensorflow/core/framework/embedding/storage.h
@@ -95,6 +95,9 @@ class Storage {
virtual int64 Size(int level) const = 0;
virtual Status GetSnapshot(std::vector<K>* key_list,
std::vector<void*>* value_ptr_list) = 0;
virtual Status GetShardedSnapshot(
std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
int partition_id, int partition_nums) = 0;
virtual Status Save(
const string& tensor_name,
const string& prefix,
@@ -197,7 +200,6 @@
int64 freq, int64 version,
int emb_index) = 0;

protected:
virtual Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
int64 partition_num, int64 value_len, bool is_filter,
bool is_incr, const EmbeddingConfig& emb_config,
Expand All @@ -206,7 +208,8 @@ class Storage {
RestoreBuffer& restore_buff) {
return Status::OK();
}


protected:
virtual Status RestoreSSD(int64 emb_index, int64 emb_slot_num,
int64 value_len,
const std::string& ssd_emb_file_name,
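One structural change that is easy to overlook in the storage hunks: RestoreFeatures moves out of the protected: region in Storage, SingleTierStorage, and MultiTierStorage. That appears to be what allows the new EmbeddingVar::RestoreFromKeysAndValues to call it directly on the storage object, as in this excerpt from embedding_var.h above:

// Call path enabled by making RestoreFeatures public (from embedding_var.h):
RestoreBuffer restore_buff((char*)key_list, (char*)value_list,
                           (char*)version_list, (char*)freq_list);
return storage_->RestoreFeatures(key_num, kSavedPartitionNum,
                                 partition_id, partition_num,
                                 value_len_, false/* is_filter*/,
                                 false/* is_incr*/, emb_config_,
                                 device, filter_, restore_buff);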