Skip to content

Commit

Permalink
[Embedding] Fix bug of saving EmbeddingVariable with int32 type. (#692)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixy9474 authored Feb 17, 2023
1 parent 60d515b commit 08c81ad
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/save_restore_v2_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ class SaveV2 : public OpKernel {
const string& tensor_name = tensor_names_flat(i);
if (tensor_types_[i] == DT_RESOURCE) {
auto& handle = HandleFromInput(context, i + kFixedInputs);
if (IsHandle<EmbeddingVar<int64, float>>(handle)) {
if (IsHandle<EmbeddingVar<int64, float>>(handle) ||
IsHandle<EmbeddingVar<int32, float>>(handle)) {
if (ev_key_types_[start_ev_key_index] == DT_INT32) {
DumpEvWithGlobalStep<int32, float>(context,
i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
Expand Down

0 comments on commit 08c81ad

Please sign in to comment.