Skip to content

Commit

Permalink
[Embedding] Adjust the header file of embedding variable.
Browse files Browse the repository at this point in the history
Signed-off-by: chenbangduo.cbd <[email protected]>
  • Loading branch information
JackMoriarty committed Mar 7, 2024
1 parent 186afd0 commit aee6311
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
1 change: 0 additions & 1 deletion tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/framework/embedding/gpu_hash_map_kv.h"
#include "tensorflow/core/framework/embedding/embedding_config.h"
#include "tensorflow/core/framework/embedding/storage.h"
#include "tensorflow/core/framework/embedding/storage_factory.h"
#include "tensorflow/core/framework/typed_allocator.h"

namespace tensorflow {
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/kv_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/embedding/cache.h"
#include "tensorflow/core/framework/embedding/config.pb.h"
#include "tensorflow/core/framework/embedding/embedding_var.h"
#include "tensorflow/core/framework/embedding/storage_factory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/kv_variable_restore_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/embedding/cache.h"
#include "tensorflow/core/framework/embedding/config.pb.h"
#include "tensorflow/core/framework/embedding/embedding_var.h"
#include "tensorflow/core/framework/embedding/storage_factory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/core/kernels/training_ali_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class KvSparseApplyAdagradGPUOp : public OpKernel {
T** dev_a = dev_v + task_size;
CHECK(dev_a);
CHECK(dev_v);
DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2);
se::DeviceMemoryBase dev_v_ptr(dev_v, sizeof(T*) * task_size * 2);
stream->ThenMemcpy(&dev_v_ptr, v, sizeof(T*) * task_size * 2);

int block_size = 128;
Expand Down Expand Up @@ -1606,7 +1606,7 @@ class KvSparseApplyAdamGPUOp : public OpKernel {
CHECK(dev_m_ptr);
CHECK(dev_v_ptr);

DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3);

int block_size = 128;
Expand Down Expand Up @@ -2579,7 +2579,7 @@ class KvSparseApplyAdamAsyncGPUOp : public OpKernel {
CHECK(dev_m_ptr);
CHECK(dev_v_ptr);

DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3);

int block_size = 128;
Expand Down Expand Up @@ -3236,7 +3236,7 @@ class KvSparseApplyAdamWGPUOp : public OpKernel {
CHECK(dev_m_ptr);
CHECK(dev_v_ptr);

DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
se::DeviceMemoryBase dst_ptr(dev_var_ptr, sizeof(T*) * task_size * 3);
stream->ThenMemcpy(&dst_ptr, var_ptr, sizeof(T*) * task_size * 3);

int block_size = 128;
Expand Down

0 comments on commit aee6311

Please sign in to comment.