From 8e87dab8f14882aa15abebad91f8f37cefea3801 Mon Sep 17 00:00:00 2001 From: "TF.Text Team" Date: Mon, 19 Feb 2024 09:12:09 -0800 Subject: [PATCH] Expose sentencepiece parallel tokenize in tf text sentencepiece wrapper. PiperOrigin-RevId: 608351975 --- .../core/kernels/sentencepiece_kernels.cc | 26 ++++++++++++++++--- tensorflow_text/core/ops/sentencepiece_ops.cc | 4 +++ .../python/ops/sentencepiece_tokenizer.py | 11 +++++++- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/tensorflow_text/core/kernels/sentencepiece_kernels.cc b/tensorflow_text/core/kernels/sentencepiece_kernels.cc index c81f8cb1f..bc387fa38 100644 --- a/tensorflow_text/core/kernels/sentencepiece_kernels.cc +++ b/tensorflow_text/core/kernels/sentencepiece_kernels.cc @@ -28,6 +28,7 @@ #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" @@ -266,6 +267,10 @@ class SentencepieceTokenizeOp : public OpKernel { public: explicit SentencepieceTokenizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { ctx->GetAttr("return_nbest", &return_nbest_).IgnoreError(); + + // Parallel encode options. + ctx->GetAttr("num_threads", &num_threads_).IgnoreError(); + ctx->GetAttr("chunk_size", &chunk_size_).IgnoreError(); } void Compute(OpKernelContext* ctx) override { @@ -309,6 +314,8 @@ class SentencepieceTokenizeOp : public OpKernel { nbest_tokens(return_nbest_ ? 
num_of_input_values : 0); if (num_of_input_values > 0) { const bool return_nbest = return_nbest_; + const int32 num_threads = num_threads_; + const int32 chunk_size = chunk_size_; const auto& worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); ::tensorflow::Shard( @@ -317,8 +324,8 @@ class SentencepieceTokenizeOp : public OpKernel { num_of_input_values, // total number of data to process. kCostPerUnit, // cost per unit [ctx, sp, &input_values_flat, &tokens, &nbest_tokens, - &nbest_size_tensor, &alpha_tensor, - return_nbest](int64 start, int64 limit) { + &nbest_size_tensor, &alpha_tensor, &chunk_size, + &num_threads, return_nbest](int64 start, int64 limit) { absl::ReaderMutexLock lock(&sp->mu); for (int i = start; i < limit; ++i) { const int32 nbest_size = nbest_size_tensor->dims() == 1 @@ -329,8 +336,17 @@ class SentencepieceTokenizeOp : public OpKernel { input_values_flat(i), nbest_size, &nbest_tokens[i]))); } else if (nbest_size == 0 || nbest_size == 1) { - OP_REQUIRES_OK(ctx, ToTFStatus(sp->processor.Encode( - input_values_flat(i), &tokens[i]))); + if (num_threads == 1) { + OP_REQUIRES_OK( + ctx, + ToTFStatus(sp->processor.Encode( + input_values_flat(i), &tokens[i]))); + } else { + OP_REQUIRES_OK( + ctx, ToTFStatus(sp->processor.ParallelEncode( + input_values_flat(i), chunk_size, num_threads, + &tokens[i]))); + } } else { const float alpha = alpha_tensor->dims() == 1 ? 
alpha_tensor->vec<float>()(i)
@@ -379,6 +395,8 @@ class SentencepieceTokenizeOp : public OpKernel {
   }
 
   bool return_nbest_{false};
+  int32_t num_threads_{1};
+  int32_t chunk_size_{0};
 };
 
 REGISTER_KERNEL_BUILDER(Name("SentencepieceTokenizeOp")
diff --git a/tensorflow_text/core/ops/sentencepiece_ops.cc b/tensorflow_text/core/ops/sentencepiece_ops.cc
index d830eb921..df6ef90d0 100644
--- a/tensorflow_text/core/ops/sentencepiece_ops.cc
+++ b/tensorflow_text/core/ops/sentencepiece_ops.cc
@@ -43,6 +43,8 @@ REGISTER_OP("SentencepieceTokenizeOp")
     .Input("add_bos: bool")
     .Input("add_eos: bool")
     .Input("reverse: bool")
+    .Attr("num_threads: int = 1")
+    .Attr("chunk_size: int = 0")
     .Attr("out_type: {int32, string} = DT_INT32")
     .Attr("Tsplits: {int32, int64} = DT_INT64")
     .Attr("return_nbest: bool = false")
@@ -57,6 +59,8 @@ REGISTER_OP("SentencepieceTokenizeOp")
       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
+      // num_threads and chunk_size are attrs, not inputs; the op has only
+      // inputs 0..6, so there are no further input ranks to validate here.
       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
 
       bool return_nbest = false;
diff --git a/tensorflow_text/python/ops/sentencepiece_tokenizer.py b/tensorflow_text/python/ops/sentencepiece_tokenizer.py
index 8c81f64a6..de2d7bd0d 100644
--- a/tensorflow_text/python/ops/sentencepiece_tokenizer.py
+++ b/tensorflow_text/python/ops/sentencepiece_tokenizer.py
@@ -80,6 +80,8 @@ def __init__(self,
                reverse=False,
                add_bos=False,
                add_eos=False,
+               num_threads=1,
+               chunk_size=0,
                return_nbest=False,
                name=None):
     """Creates & initializes a Sentencepiece processor.
@@ -101,6 +103,10 @@ def __init__(self,
       add_eos: Add end of sentence token to the result (Default = false). When
         `reverse=True` beginning/end of sentence tokens are added after
         reversing.
+ num_threads: If `> 1`, the input is split up into chunks of size + `chunk_size` and tokenized in parallel with this many threads. + chunk_size: Only used if `num_threads > 1`. The input is split into + chunks of this size and tokenized in parallel. return_nbest: If True requires that `nbest_size` is a scalar and `> 1`. Returns the `nbest_size` best tokenizations for each sentence instead of a single one. The returned tensor has shape @@ -118,6 +124,8 @@ def __init__(self, self.reverse = reverse self.add_bos = add_bos self.add_eos = add_eos + self.num_threads = num_threads + self.chunk_size = chunk_size self.return_nbest = return_nbest self._model_resource = _SentencepieceModelResource(model, name) @@ -154,7 +162,8 @@ def tokenize(self, input, name=None): # pylint: disable=redefined-builtin gen_sentencepiece_tokenizer.sentencepiece_tokenize_op( self._model_resource.resource_handle, input_tensor, self.nbest_size, self.alpha, self.add_bos, self.add_eos, - self.reverse, self.out_type, return_nbest=self.return_nbest)) + self.reverse, self.num_threads, self.chunk_size, + self.out_type, return_nbest=self.return_nbest)) tokens = RaggedTensor.from_nested_row_splits( flat_values=output_values, nested_row_splits=[row_splits],