Skip to content

Commit

Permalink
[Embedding] undefine EV GPU interface in CPU compile. (#956)
Browse files Browse the repository at this point in the history
Signed-off-by: candy.dc <[email protected]>
  • Loading branch information
candyzone authored Dec 20, 2023
1 parent 717f7c5 commit 6bf5621
Showing 1 changed file with 45 additions and 46 deletions.
91 changes: 45 additions & 46 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,6 @@ class EmbeddingVar : public ResourceBase {
return storage_->Get(key, value_ptr);
}

// Batched GPU-path lookup: fetches the value pointer for each of the
// num_of_keys entries in `keys`, writing one result per key into
// value_ptr_list. Delegates the actual batched fetch to the storage layer.
// NOTE(review): assumes value_ptr_list has room for num_of_keys entries --
// confirm with callers.
void BatchLookupKey(const EmbeddingVarContext<GPUDevice>& ctx,
const K* keys,
void** value_ptr_list,
int64 num_of_keys) {
storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys);
}

Status LookupOrCreateKey(K key, void** value_ptr,
bool* is_filter, bool indices_as_pointer,
int64 count = 1) {
Expand All @@ -167,45 +160,6 @@ class EmbeddingVar : public ResourceBase {
return Status::OK();
}

// Batched GPU-path lookup-or-create: resolves a value pointer for each of
// the num_of_keys entries in `keys`, writing one result per key into
// value_ptrs.
//
// If indices_as_pointer is true, each key is itself an encoded value
// pointer, so the "lookup" is just a reinterpretation of the key bits,
// sharded across the inter-op worker threads. Otherwise the filter
// performs a real batched lookup-or-create.
//
// If indices_counts is non-null, indices_counts[i] is added to the access
// frequency of the i-th resolved entry (also sharded).
//
// Always returns Status::OK(); the delegated calls handle their own errors.
Status LookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& context,
                         const K* keys,
                         void** value_ptrs,
                         int64 num_of_keys,
                         int64* indices_counts,
                         bool indices_as_pointer = false) {
  if (indices_as_pointer) {
    auto lookup_key_and_set_version_fn = [keys, value_ptrs]
        (int64 start, int64 limit) {
      // int64 index: `start`/`limit` are int64, and a plain int would
      // narrow and overflow for batches of 2^31 keys or more.
      for (int64 i = start; i < limit; i++) {
        // Key bits already encode the pointer on this path.
        value_ptrs[i] = reinterpret_cast<void*>(keys[i]);
      }
    };
    const int64 unit_cost = 1000;  // very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          lookup_key_and_set_version_fn);
  } else {
    filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys);
  }

  if (indices_counts != nullptr) {
    auto add_freq_fn = [this, value_ptrs, indices_counts]
        (int64 start, int64 limit) {
      // int64 index for the same narrowing reason as above.
      for (int64 i = start; i < limit; i++) {
        feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]);
      }
    };
    const int64 unit_cost = 1000;  // very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          add_freq_fn);
  }
  return Status::OK();
}


Status LookupOrCreateKey(K key, void** value_ptr) {
Status s = storage_->GetOrCreate(key, value_ptr);
TF_CHECK_OK(s);
Expand Down Expand Up @@ -402,6 +356,51 @@ class EmbeddingVar : public ResourceBase {

storage_->AddToCache(keys_tensor);
}

// Batched GPU-path lookup: fetches the value pointer for each of the
// num_of_keys entries in `keys`, writing one result per key into
// value_ptr_list. Delegates the actual batched fetch to the storage layer.
// NOTE(review): assumes value_ptr_list has room for num_of_keys entries --
// confirm with callers.
void BatchLookupKey(const EmbeddingVarContext<GPUDevice>& ctx,
const K* keys,
void** value_ptr_list,
int64 num_of_keys) {
storage_->BatchGet(ctx, keys, value_ptr_list, num_of_keys);
}

// Batched GPU-path lookup-or-create: resolves a value pointer for each of
// the num_of_keys entries in `keys`, writing one result per key into
// value_ptrs.
//
// If indices_as_pointer is true, each key is itself an encoded value
// pointer, so the "lookup" is just a reinterpretation of the key bits,
// sharded across the inter-op worker threads. Otherwise the filter
// performs a real batched lookup-or-create.
//
// If indices_counts is non-null, indices_counts[i] is added to the access
// frequency of the i-th resolved entry (also sharded).
//
// Always returns Status::OK(); the delegated calls handle their own errors.
Status LookupOrCreateKey(const EmbeddingVarContext<GPUDevice>& context,
                         const K* keys,
                         void** value_ptrs,
                         int64 num_of_keys,
                         int64* indices_counts,
                         bool indices_as_pointer = false) {
  if (indices_as_pointer) {
    auto lookup_key_and_set_version_fn = [keys, value_ptrs]
        (int64 start, int64 limit) {
      // int64 index: `start`/`limit` are int64, and a plain int would
      // narrow and overflow for batches of 2^31 keys or more.
      for (int64 i = start; i < limit; i++) {
        // Key bits already encode the pointer on this path.
        value_ptrs[i] = reinterpret_cast<void*>(keys[i]);
      }
    };
    const int64 unit_cost = 1000;  // very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          lookup_key_and_set_version_fn);
  } else {
    filter_->BatchLookupOrCreateKey(context, keys, value_ptrs, num_of_keys);
  }

  if (indices_counts != nullptr) {
    auto add_freq_fn = [this, value_ptrs, indices_counts]
        (int64 start, int64 limit) {
      // int64 index for the same narrowing reason as above.
      for (int64 i = start; i < limit; i++) {
        feat_desc_->AddFreq(value_ptrs[i], indices_counts[i]);
      }
    };
    const int64 unit_cost = 1000;  // very unreliable estimate for cost per step.
    auto worker_threads = context.worker_threads;
    Shard(worker_threads->num_threads,
          worker_threads->workers, num_of_keys, unit_cost,
          add_freq_fn);
  }
  return Status::OK();
}
#endif

#if GOOGLE_CUDA
Expand Down

0 comments on commit 6bf5621

Please sign in to comment.