
Commit

[Embedding] Add interface of EmbeddingVar for Elastic Training. (#933)
Signed-off-by: JunqiHu <[email protected]>
Mesilenceki authored Oct 26, 2023
1 parent 0e8127a commit 2d31c8e
Showing 21 changed files with 244 additions and 20 deletions.
2 changes: 1 addition & 1 deletion configure.py
@@ -1434,7 +1434,7 @@ def main():
                 True, 'star')

   set_build_var(environ_cp, 'TF_NEED_ELASTIC', 'ELASTIC TRAINING', 'with_elastic_support',
-                True, 'elastic')
+                False, 'elastic')

   set_build_var(environ_cp, 'TF_ENABLE_PMEM', 'PMEM', 'with_pmem_support',
                 False, 'pmem')
3 changes: 2 additions & 1 deletion tensorflow/contrib/elastic_grpc_server/BUILD
@@ -56,7 +56,8 @@ cc_library(
 tf_cc_test(
     name = "elastic_grpc_test",
     size = "small",
-    srcs = ["elastic_grpc_server_lib_test.cc"],
+    srcs = select({"//tensorflow:with_elastic_support": ["elastic_grpc_server_lib_test.cc"],
+                   "//conditions:default": []}),
     deps = [
         ":elastic_grpc_server_lib",
         "//tensorflow/core/distributed_runtime/rpc:grpc_util",
5 changes: 4 additions & 1 deletion tensorflow/core/BUILD
@@ -128,6 +128,7 @@ load(
     "tf_additional_numa_deps",
     "tf_additional_numa_lib_defines",
     "tf_additional_star_lib_defines",
+    "tf_additional_elastic_server_lib_defines",
     "tf_additional_api_compatible_defines",
     "tf_additional_pmem_lib_defines",
     "tf_additional_test_deps",
@@ -1441,6 +1442,7 @@ tf_cc_test(
 cc_library(
     name = "ops",
     visibility = ["//visibility:public"],
+    defines = tf_additional_elastic_server_lib_defines(),
     deps = [
         ":array_ops_op_lib",
         ":parquet_ops_op_lib",
@@ -2562,7 +2564,8 @@ LIB_INTERNAL_DEFINES = (
     tf_additional_gdr_lib_defines() +
     tf_additional_numa_lib_defines() +
     tf_additional_star_lib_defines() +
-    tf_additional_pmem_lib_defines()
+    tf_additional_pmem_lib_defines() +
+    tf_additional_elastic_server_lib_defines()
 )

 cc_library(
2 changes: 1 addition & 1 deletion tensorflow/core/framework/embedding/bloom_filter_policy.h
@@ -333,7 +333,7 @@ class BloomFilterPolicy : public FilterPolicy<K, V, EV> {
       // this can describe by graph(Mod + DynamicPartition),
       // but memory waste and slow
       if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
-        LOG(INFO) << "skip EV key:" << *(key_buff + i);
+        VLOG(1) << "skip EV key:" << *(key_buff + i);
         continue;
       }
       void* value_ptr = nullptr;
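Here and in the counter and nullable filter policies below, the per-key skip message drops from LOG(INFO) to VLOG(1), so restoring a large sharded checkpoint no longer floods the log; the messages reappear only when verbose logging is requested. A minimal standalone sketch of the gating behavior (not part of this commit; the env var name follows TF 1.x logging):

#include "tensorflow/core/platform/logging.h"

int main() {
  // Silent by default; printed when verbose logging is enabled, e.g.
  //   TF_CPP_MIN_VLOG_LEVEL=1 ./a.out
  VLOG(1) << "skip EV key:" << 42;
  // Always printed.
  LOG(INFO) << "restore finished";
  return 0;
}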
@@ -159,7 +159,7 @@ class CounterFilterPolicy : public FilterPolicy<K, V, EV> {
       // this can describe by graph(Mod + DynamicPartition),
       // but memory waste and slow
       if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
-        LOG(INFO) << "skip EV key:" << *(key_buff + i);
+        VLOG(1) << "skip EV key:" << *(key_buff + i);
         continue;
       }
       int64 import_freq = 0;
22 changes: 22 additions & 0 deletions tensorflow/core/framework/embedding/cpu_hash_map_kv.h
@@ -137,6 +137,28 @@ class LocklessHashMap : public KVInterface<K, V> {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    std::pair<const K, void*>* hash_map_dump;
+    int64 bucket_count;
+    auto it = hash_map_.GetSnapshot();
+    hash_map_dump = it.first;
+    bucket_count = it.second;
+    for (int64 j = 0; j < bucket_count; j++) {
+      if (hash_map_dump[j].first != LocklessHashMap<K, V>::EMPTY_KEY_
+          && hash_map_dump[j].first != LocklessHashMap<K, V>::DELETED_KEY_
+          && hash_map_dump[j].first % kSavedPartitionNum
+              % partition_nums != partition_id) {
+        key_list->emplace_back(hash_map_dump[j].first);
+        value_ptr_list->emplace_back(hash_map_dump[j].second);
+      }
+    }
+
+    free(hash_map_dump);
+    return Status::OK();
+  }
+
   std::string DebugString() const override {
     LOG(INFO) << "map info size:" << Size()
         << "map info bucket_count:" << hash_map_.bucket_count()
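The filter above keeps every live slot whose key does not map to the local shard — exactly the entries that must move elsewhere when the partition count changes. A key is first folded into one of kSavedPartitionNum (1000) fixed logical buckets, and the buckets are then folded onto the current partition count, so ownership depends only on how many partitions exist now, not on the count at save time. A standalone sketch of the mapping (OwnerOf is an illustrative name, not part of this commit):

#include <cstdint>
#include <cstdio>

// Mirrors the two-level mod used by GetShardedSnapshot: 1000 fixed
// logical buckets decouple key placement from the worker count.
constexpr int kSavedPartitionNum = 1000;

int OwnerOf(int64_t key, int partition_nums) {
  return key % kSavedPartitionNum % partition_nums;
}

int main() {
  // Key 123457 falls in bucket 457; with 4 partitions it belongs to shard 1.
  std::printf("owner=%d\n", OwnerOf(123457, 4));
  return 0;
}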
19 changes: 19 additions & 0 deletions tensorflow/core/framework/embedding/dense_hash_map_kv.h
@@ -121,6 +121,25 @@ class DenseHashMap : public KVInterface<K, V> {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    dense_hash_map hash_map_dump[partition_num_];
+    for (int i = 0; i < partition_num_; i++) {
+      spin_rd_lock l(hash_map_[i].mu);
+      hash_map_dump[i].hash_map = hash_map_[i].hash_map;
+    }
+    for (int i = 0; i < partition_num_; i++) {
+      for (const auto it : hash_map_dump[i].hash_map) {
+        if (it.first % kSavedPartitionNum % partition_nums != partition_id) {
+          key_list->push_back(it.first);
+          value_ptr_list->push_back(it.second);
+        }
+      }
+    }
+    return Status::OK();
+  }
+
   std::string DebugString() const override {
     return "";
   }
86 changes: 85 additions & 1 deletion tensorflow/core/framework/embedding/embedding_var.h
@@ -435,6 +435,10 @@ class EmbeddingVar : public ResourceBase {
     return storage_->CacheSize();
   }

+  int64 MemoryUsage() const {
+    return storage_->Size() * (sizeof(K) + feat_desc_->data_bytes());
+  }
+
   int64 MinFreq() {
     return emb_config_.filter_freq;
   }
@@ -516,6 +520,85 @@
     }
   }

+  Status GetShardedSnapshot(std::vector<K>* key_list,
+                            std::vector<void*>* value_ptr_list,
+                            int partition_id, int partition_num) {
+    return storage_->GetShardedSnapshot(key_list, value_ptr_list,
+                                        partition_id, partition_num);
+  }
+
+  void ExportAndRemove(K* key_list, V* value_list,
+                       int64* version_list, int64* freq_list,
+                       std::vector<K>& tot_keys_list,
+                       std::vector<void*>& tot_value_ptr_list) {
+    bool save_unfiltered_features = true;
+    TF_CHECK_OK(ReadBoolFromEnvVar(
+        "TF_EV_SAVE_FILTERED_FEATURES", true, &save_unfiltered_features));
+
+    bool is_save_freq = emb_config_.is_save_freq();
+    bool is_save_version = emb_config_.is_save_version();
+
+    for (int64 i = 0; i < tot_keys_list.size(); ++i) {
+      auto& value_ptr = tot_value_ptr_list[i];
+      if ((int64)value_ptr == embedding::ValuePtrStatus::IS_DELETED)
+        continue;
+
+      bool is_admit = feat_desc_->IsAdmit(value_ptr);
+      bool is_in_dram = ((int64)value_ptr >> kDramFlagOffset == 0);
+
+      if (!is_admit) {
+        key_list[i] = tot_keys_list[i];
+
+        if (!is_in_dram) {
+          auto tmp_value = value_list + i * value_len_;
+          tmp_value = (V*)embedding::ValuePtrStatus::NOT_IN_DRAM;
+          value_ptr = (void*)((int64)value_ptr & ((1L << kDramFlagOffset) - 1));
+        } else if (feat_desc_->GetEmbedding(value_ptr, 0) == nullptr) {
+          memcpy(value_list + i * value_len_, default_value_, sizeof(V) * value_len_);
+        } else {
+          V* val = feat_desc_->GetEmbedding(value_ptr, emb_config_.emb_index);
+          memcpy(value_list + i * value_len_, val, sizeof(V) * value_len_);
+        }
+
+        if (is_save_version) {
+          int64 dump_version = feat_desc_->GetVersion(value_ptr);
+          version_list[i] = dump_version;
+        }
+
+        if (is_save_freq) {
+          int64 dump_freq = feat_desc_->GetFreq(value_ptr);
+          freq_list[i] = dump_freq;
+        }
+      } else {
+        if (!save_unfiltered_features)
+          return;
+        // TODO(JUNQI): currently not export filtered keys
+      }

+      if (emb_config_.is_primary()) {
+        Status s;
+        s = storage_->Remove(tot_keys_list[i]);
+        if (!s.ok()) {
+          LOG(ERROR) << "Remove keys error: " << s.error_message();
+        }
+        feat_desc_->Deallocate(value_ptr);
+      }
+    }
+  }
+
+  Status RestoreFromKeysAndValues(int64 key_num, int partition_id,
+                                  int partition_num, const K* key_list,
+                                  const V* value_list, const int64* version_list,
+                                  const int64* freq_list,
+                                  const Eigen::GpuDevice* device = nullptr) {
+    RestoreBuffer restore_buff((char*)key_list, (char*)value_list,
+                               (char*)version_list, (char*)freq_list);
+    return storage_->RestoreFeatures(key_num, kSavedPartitionNum,
+                                     partition_id, partition_num,
+                                     value_len_, false /* is_filter */, false /* is_incr */,
+                                     emb_config_, device, filter_, restore_buff);
+  }
+
   mutex* mu() {
     return &mu_;
   }
@@ -537,6 +620,8 @@
     }
   }

+  string Name() { return name_; }
+
   V* GetDefaultValuePtr() {
     return default_value_;
   }
@@ -645,7 +730,6 @@
   GPUHashTable<K, V>* HashTable() {
     return storage_->HashTable();
   }
-
   FilterPolicy<K, V, EmbeddingVar<K, V>>* GetFilter() const {
     return filter_;
   }
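Taken together, the new methods form the migrate-out path for elastic scaling: GetShardedSnapshot selects the entries that no longer belong to this shard, ExportAndRemove flattens them into key/value/version/freq buffers and deletes them from local storage, and RestoreFromKeysAndValues ingests such buffers on the receiving worker. A rough sketch of the round trip, assuming ev_src and ev_dst are initialized EmbeddingVar<int64, float>* and that ValueLen() returns the embedding width; buffer transport between workers and all error handling are elided, so this is illustrative, not part of the commit:

// Sketch only: redistributes one EmbeddingVar after the worker count changes.
Status Repartition(EmbeddingVar<int64, float>* ev_src,
                   EmbeddingVar<int64, float>* ev_dst,
                   int partition_id, int partition_num) {
  std::vector<int64> keys;
  std::vector<void*> value_ptrs;
  // 1. Entries whose keys now hash to another shard.
  TF_RETURN_IF_ERROR(ev_src->GetShardedSnapshot(&keys, &value_ptrs,
                                                partition_id, partition_num));

  const int64 n = keys.size();
  const int64 dim = ev_src->ValueLen();
  std::vector<int64> key_buf(n), version_buf(n), freq_buf(n);
  std::vector<float> value_buf(n * dim);

  // 2. Flatten into buffers and drop the exported entries locally.
  ev_src->ExportAndRemove(key_buf.data(), value_buf.data(),
                          version_buf.data(), freq_buf.data(),
                          keys, value_ptrs);

  // 3. The new owner replays the buffers into its own storage.
  return ev_dst->RestoreFromKeysAndValues(n, partition_id, partition_num,
                                          key_buf.data(), value_buf.data(),
                                          version_buf.data(), freq_buf.data());
}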
@@ -20,7 +20,6 @@ limitations under the License.
 namespace tensorflow {
 class BundleWriter;
 namespace {
-const int kSavedPartitionNum = 1000;
 const int kDramFlagOffset = 49;
 }

20 changes: 16 additions & 4 deletions tensorflow/core/framework/embedding/filter_policy.h
@@ -27,19 +27,31 @@ struct RestoreBuffer {
   char* value_buffer = nullptr;
   char* version_buffer = nullptr;
   char* freq_buffer = nullptr;
+  bool should_release = false;

   explicit RestoreBuffer(size_t buffer_size) {
     key_buffer = new char[buffer_size];
     value_buffer = new char[buffer_size];
     version_buffer = new char[buffer_size];
     freq_buffer = new char[buffer_size];
+    should_release = true;
   }

+  explicit RestoreBuffer(char* i_key_buffer, char* i_value_buffer,
+                         char* i_version_buffer, char* i_freq_buffer) {
+    key_buffer = i_key_buffer;
+    value_buffer = i_value_buffer;
+    version_buffer = i_version_buffer;
+    freq_buffer = i_freq_buffer;
+  }
+
   ~RestoreBuffer() {
-    delete []key_buffer;
-    delete []value_buffer;
-    delete []version_buffer;
-    delete []freq_buffer;
+    if (should_release) {
+      delete []key_buffer;
+      delete []value_buffer;
+      delete []version_buffer;
+      delete []freq_buffer;
+    }
   }
 };
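The second constructor turns RestoreBuffer into a non-owning view: RestoreFromKeysAndValues in embedding_var.h wraps caller-owned tensors in it without copying, and should_release keeps the destructor from freeing memory the buffer never allocated. A minimal sketch of the two modes (assuming this header is included; names and sizes are illustrative):

#include <cstdint>
#include <vector>

void RestoreBufferModes(std::vector<int64_t>& keys, std::vector<float>& values,
                        std::vector<int64_t>& versions,
                        std::vector<int64_t>& freqs) {
  // Owning mode: allocates four 1 MiB buffers, freed by the destructor.
  RestoreBuffer owned(1 << 20);

  // Non-owning mode: wraps caller memory; should_release stays false,
  // so the destructor leaves these vectors untouched.
  RestoreBuffer view((char*)keys.data(), (char*)values.data(),
                     (char*)versions.data(), (char*)freqs.data());
}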
7 changes: 7 additions & 0 deletions tensorflow/core/framework/embedding/gpu_hash_map_kv.h
@@ -252,6 +252,13 @@ class GPUHashMapKV : public KVInterface<K, V> {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    LOG(INFO) << "GPUHashMapKV do not support GetShardedSnapshot";
+    return Status::OK();
+  }
+
   std::string DebugString() const override { return std::string(); }

   GPUHashTable<K, V>* HashTable() override { return hash_table_; }
5 changes: 5 additions & 0 deletions tensorflow/core/framework/embedding/kv_interface.h
@@ -23,6 +23,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 const char* kInferenceMode = "INFERENCE_MODE";
+const int kSavedPartitionNum = 1000;
 }

 template <class K, class V>
@@ -89,6 +90,10 @@ class KVInterface {
   virtual Status GetSnapshot(std::vector<K>* key_list,
                              std::vector<void*>* value_ptr_list) = 0;

+  virtual Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) = 0;
+
   virtual std::string DebugString() const = 0;

   virtual Status BatchLookupOrCreate(const K* keys, V* val, V* default_v,
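Because GetShardedSnapshot is pure virtual here (and kSavedPartitionNum now lives in this header so every backend can see it), each KV backend must implement the same contract: return the live entries whose keys map away from partition_id. For a hypothetical single-map backend the whole obligation reduces to a filter loop; SimpleKV below is illustrative, not part of this commit, and assumes tensorflow::Status is visible:

#include <map>
#include <vector>

template <class K>
class SimpleKV {
 public:
  Status GetShardedSnapshot(std::vector<K>* key_list,
                            std::vector<void*>* value_ptr_list,
                            int partition_id, int partition_nums) {
    for (const auto& kv : map_) {
      // Collect entries that now belong to some other partition.
      if (kv.first % kSavedPartitionNum % partition_nums != partition_id) {
        key_list->push_back(kv.first);
        value_ptr_list->push_back(kv.second);
      }
    }
    return Status::OK();
  }

 private:
  std::map<K, void*> map_;
};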
32 changes: 32 additions & 0 deletions tensorflow/core/framework/embedding/leveldb_kv.h
@@ -193,6 +193,38 @@ class LevelDBKV : public KVInterface<K, V> {
     return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    ReadOptions options;
+    options.snapshot = db_->GetSnapshot();
+    leveldb::Iterator* it = db_->NewIterator(options);
+    void* dram_value_ptr = feat_desc_->Allocate();
+    for (it->SeekToFirst(); it->Valid(); it->Next()) {
+      K key;
+      memcpy((char*)&key, it->key().ToString().data(), sizeof(K));
+      if (key % kSavedPartitionNum % partition_nums == partition_id) continue;
+      key_list->emplace_back(key);
+      FeatureDescriptor<V> hbm_feat_desc(
+          1, 1, ev_allocator() /* useless */,
+          StorageType::HBM_DRAM, true, true,
+          {false, 0});
+      void* value_ptr = cpu_allocator()->AllocateRaw(
+          Allocator::kAllocatorAlignment, hbm_feat_desc.data_bytes());
+      memcpy(dram_value_ptr,
+             it->value().ToString().data(),
+             feat_desc_->data_bytes());
+      hbm_feat_desc.SetFreq(
+          value_ptr, feat_desc_->GetFreq(dram_value_ptr));
+      hbm_feat_desc.UpdateVersion(
+          value_ptr, feat_desc_->GetVersion(dram_value_ptr));
+      value_ptr_list->emplace_back(value_ptr);
+    }
+    delete it;
+    feat_desc_->Deallocate(dram_value_ptr);
+    return Status::OK();
+  }
+
   int64 Size() const override {
     return counter_->size();
   }
9 changes: 8 additions & 1 deletion tensorflow/core/framework/embedding/multi_tier_storage.h
@@ -87,6 +87,14 @@ class MultiTierStorage : public Storage<K, V> {
   Status GetSnapshot(std::vector<K>* key_list,
                      std::vector<void*>* value_ptr_list) override {
     LOG(FATAL) << "Can't get snapshot of MultiTierStorage.";
+    return Status::OK();
   }

+  Status GetShardedSnapshot(
+      std::vector<K>* key_list, std::vector<void*>* value_ptr_list,
+      int partition_id, int partition_nums) override {
+    LOG(FATAL) << "Can't get sharded snapshot of MultiTierStorage.";
+    return Status::OK();
+  }
+
   void CopyEmbeddingsFromCPUToGPU(
@@ -170,7 +178,6 @@
     });
   }

- protected:
   Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
                          int64 partition_num, int64 value_len, bool is_filter,
                          bool is_incr, const EmbeddingConfig& emb_config,
@@ -150,7 +150,7 @@ class NullableFilterPolicy : public FilterPolicy<K, V, EV> {
       // this can describe by graph(Mod + DynamicPartition),
       // but memory waste and slow
       if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
-        LOG(INFO) << "skip EV key:" << *(key_buff + i);
+        VLOG(1) << "skip EV key:" << *(key_buff + i);
         continue;
       }
       int64 import_freq = 0;