From 5bc25129f5a4c51377ca01fb417be344c5ec4a31 Mon Sep 17 00:00:00 2001 From: lixy9474 Date: Fri, 8 Sep 2023 15:53:37 +0800 Subject: [PATCH] [Embedding] Fix the graph of frequency recorder. Signed-off-by: lixy9474 --- .../python/ops/embedding_variable_ops_test.py | 74 +++++++++++++++++++ tensorflow/python/ops/kv_variable_ops.py | 4 +- .../python/training/gradient_descent.py | 15 +++- tensorflow/python/training/optimizer.py | 30 +++++++- 4 files changed, 115 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/ops/embedding_variable_ops_test.py b/tensorflow/python/ops/embedding_variable_ops_test.py index 25a0cb6ff11..c6cdf951a1e 100644 --- a/tensorflow/python/ops/embedding_variable_ops_test.py +++ b/tensorflow/python/ops/embedding_variable_ops_test.py @@ -2816,5 +2816,79 @@ def testSetInitializedWithRestore(self): result = sess.run(var._is_initialized_op) self.assertEqual(True, result) + def testCountsTensor(self): + os.environ["TF_RECORD_FREQ"] = "1" + checkpoint_directory = self.get_temp_dir() + ckpt_path = os.path.join(checkpoint_directory, "model.ckpt") + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + var = variable_scope.get_embedding_variable("var_1", + embedding_dim = 3) + sp1 = sparse_tensor.SparseTensor( + indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]], + values=math_ops.cast([0,0,0,1,1,2], dtypes.int64), + dense_shape=[6, 1]) + sp2 = sparse_tensor.SparseTensor( + indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]], + values=math_ops.cast([3,3,3,4,4,1], dtypes.int64), + dense_shape=[6, 1]) + emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None) + emb2 = embedding_ops.embedding_lookup_sparse(var, sp2, None) + emb = emb1 + emb2 + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + gs = training_util.get_or_create_global_step() + opt = adagrad_decay.AdagradDecayOptimizer(0.1, gs) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + saver = saver_module.Saver() + init = variables.global_variables_initializer() + with self.test_session(graph=g) as sess: + sess.run([init]) + sess.run(train_op) + saver.save(sess, ckpt_path) + + for name, shape in checkpoint_utils.list_variables(ckpt_path): + if name == "var_1-freqs": + value = checkpoint_utils.load_variable(ckpt_path, name) + self.assertAllEqual(value, [3, 3, 1, 3, 2]) + + def testCountsTensorWithGradientDescent(self): + os.environ["TF_RECORD_FREQ"] = "1" + checkpoint_directory = self.get_temp_dir() + ckpt_path = os.path.join(checkpoint_directory, "model.ckpt") + with ops.Graph().as_default() as g, ops.device('/cpu:0'): + var = variable_scope.get_embedding_variable("var_1", + embedding_dim = 3) + sp1 = sparse_tensor.SparseTensor( + indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]], + values=math_ops.cast([0,0,0,1,1,2], dtypes.int64), + dense_shape=[6, 1]) + sp2 = sparse_tensor.SparseTensor( + indices=[[0,0],[1,0],[2,0],[3,0],[4,0],[5,0]], + values=math_ops.cast([3,3,3,4,4,1], dtypes.int64), + dense_shape=[6, 1]) + emb1 = embedding_ops.embedding_lookup_sparse(var, sp1, None) + emb2 = embedding_ops.embedding_lookup_sparse(var, sp2, None) + emb = emb1 + emb2 + fun = math_ops.multiply(emb, 2.0, name='multiply') + loss = math_ops.reduce_sum(fun, name='reduce_sum') + gs = training_util.get_or_create_global_step() + opt = gradient_descent.GradientDescentOptimizer(0.1) + g_v = opt.compute_gradients(loss) + train_op = opt.apply_gradients(g_v) + saver = saver_module.Saver() + init = variables.global_variables_initializer() + with self.test_session(graph=g) as sess: + sess.run([init]) + sess.run(train_op) + saver.save(sess, ckpt_path) + + for name, shape in checkpoint_utils.list_variables(ckpt_path): + if name == "var_1-freqs": + value = checkpoint_utils.load_variable(ckpt_path, name) + self.assertAllEqual(value, [3, 3, 1, 3, 2]) + + del os.environ["TF_RECORD_FREQ"] + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/kv_variable_ops.py b/tensorflow/python/ops/kv_variable_ops.py index 701c03f6975..96329ca345b 100644 --- a/tensorflow/python/ops/kv_variable_ops.py +++ b/tensorflow/python/ops/kv_variable_ops.py @@ -368,7 +368,7 @@ def _init_from_args(self, self._dtype = initial_value.dtype.base_dtype self._constraint = constraint self._gather_op = None - self._counts_tensor = None + self._counts_tensor = {} if self._is_primary: self._slot_num = 0 else: @@ -850,7 +850,7 @@ def sparse_read(self, indices, name=None, ev_init_value=None, counts=None): default_value, counts, is_inference=True, name=name) - self._counts_tensor = counts + self._counts_tensor[indices] = counts else: value = gen_kv_variable_ops.kv_resource_gather(self._handle, indices, diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py index 32a12a0554f..799e3c5f5bd 100644 --- a/tensorflow/python/training/gradient_descent.py +++ b/tensorflow/python/training/gradient_descent.py @@ -71,12 +71,23 @@ def _resource_apply_dense(self, grad, handle): def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): if isinstance(handle, kv_variable_ops.EmbeddingVariable): global_step = training_util.get_or_create_global_step() - if handle.need_counts() and handle._counts_tensor is not None: + if handle.need_counts() and len(handle._counts_tensor.keys()) != 0: + if indices.op.type == "ConcatV2": + total_counts = [] + for tensor in indices.op.inputs: + if tensor.op.type == "Reshape": + indices_tensor = tensor.op.inputs[0] + total_counts.append(handle._counts_tensor[indices_tensor]) + from tensorflow.python.ops import array_ops + counts_tensor = array_ops.concat(total_counts, 0) + elif indices.op.type == "Reshape": + indices_tensor = indices.op.inputs[0] + counts_tensor = handle._counts_tensor[indices_tensor] return training_ops.kv_resource_sparse_apply_gradient_descent_with_counts( handle.handle, math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype), grad, indices, global_step, - handle._counts_tensor, use_locking=self._use_locking) + counts_tensor, use_locking=self._use_locking) else: return training_ops.kv_resource_sparse_apply_gradient_descent( handle.handle, math_ops.cast(self._learning_rate_tensor, diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 578d682cc11..7523604ccf9 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -93,6 +93,18 @@ def _deduplicate_indexed_slices_with_counts(values, indices): array_ops.shape(unique_indices)[0]) return (summed_values, unique_indices, indices_counts) +def _deduplicate_indexed_slices_with_counts_reduction(values, indices, counts): + """Sums `values` associated with any non-unique `indices` + and return counts of each count in `values`.""" + unique_indices, new_index_positions = array_ops.unique(indices) + summed_values = math_ops.unsorted_segment_sum( + values, new_index_positions, + array_ops.shape(unique_indices)[0]) + summed_counts = math_ops.unsorted_segment_sum( + counts, new_index_positions, + array_ops.shape(unique_indices)[0]) + return (summed_values, unique_indices, summed_counts) + def _var_key(var): # TODO(ashankar): Consolidate handling for eager and graph if hasattr(var, "op"): @@ -1088,14 +1100,24 @@ def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): """ from tensorflow.python.ops import kv_variable_ops if isinstance(handle, kv_variable_ops.EmbeddingVariable) and handle.need_counts(): - if handle._counts_tensor is None: + if len(handle._counts_tensor.keys()) == 0: summed_grad, unique_indices, indices_counts = \ _deduplicate_indexed_slices_with_counts( values=grad, indices=indices) else: - summed_grad, unique_indices = _deduplicate_indexed_slices( - values=grad, indices=indices) - indices_counts = handle._counts_tensor + if indices.op.type == "ConcatV2": + total_counts = [] + for tensor in indices.op.inputs: + if tensor.op.type == "Reshape": + indices_tensor = tensor.op.inputs[0] + total_counts.append(handle._counts_tensor[indices_tensor]) + counts_tensor = array_ops.concat(total_counts, 0) + elif indices.op.type == "Reshape": + indices_tensor = indices.op.inputs[0] + counts_tensor = handle._counts_tensor[indices_tensor] + summed_grad, unique_indices, indices_counts = \ + _deduplicate_indexed_slices_with_counts_reduction( + grad, indices, counts_tensor) return self._resource_apply_sparse( summed_grad, handle, unique_indices, indices_counts) else: