From d439fa0f9ebf3e6b76a2c5dbf194db9bb58db41f Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Wed, 3 Jan 2024 15:01:10 -0800 Subject: [PATCH] Support resource manager scoped Sentencepiece resources. PiperOrigin-RevId: 595512471 --- .../python/ops/sentencepiece_tokenizer.py | 24 +++++++++++++++---- .../ops/sentencepiece_tokenizer_test.py | 8 +++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/tensorflow_text/python/ops/sentencepiece_tokenizer.py b/tensorflow_text/python/ops/sentencepiece_tokenizer.py index 3ff17c434..f286abab5 100644 --- a/tensorflow_text/python/ops/sentencepiece_tokenizer.py +++ b/tensorflow_text/python/ops/sentencepiece_tokenizer.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function +import uuid + from tensorflow.python.eager import monitoring from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -43,16 +45,24 @@ class _SentencepieceModelResource(resource.TrackableResource): """Utility to track the model resource tensor (for SavedModel support).""" - def __init__(self, model, name): + def __init__(self, model, name, use_unique_shared_resource_name): super(_SentencepieceModelResource, self).__init__() self._model = model self._name = name + self._use_unique_shared_resource_name = use_unique_shared_resource_name _ = self.resource_handle # Accessing this property creates the resource. def _create_resource(self): model, name = self._model, self._name with ops.name_scope(name, "SentenceTokenizerInitializer", [model]): - return gen_sentencepiece_tokenizer.sentencepiece_op(model=model) + # TODO(b/318839908): Switch to ref-counted resources and remove this. + shared_name = "" + if self._use_unique_shared_resource_name: + shared_name = str(uuid.uuid4()) + return gen_sentencepiece_tokenizer.sentencepiece_op( + model=model, + shared_name=shared_name, + ) class SentencepieceTokenizer(TokenizerWithOffsets, Detokenizer): @@ -79,7 +89,8 @@ def __init__(self, add_bos=False, add_eos=False, return_nbest=False, - name=None): + name=None, + use_unique_shared_resource_name=False): """Creates & initializes a Sentencepiece processor. Args: @@ -104,6 +115,10 @@ def __init__(self, of a single one. The returned tensor has shape `[batch * nbest, (tokens)]`. name: The name argument that is passed to the op function. + use_unique_shared_resource_name: Whether to set a unique `shared_name` for + the Sentencepiece resource. This is useful for allowing the resource to + outlive the kernel that created it in eager mode. The resource's + lifetime will instead be tied to the eager context. Returns: pieces: A SentencepieceTokenizer. @@ -117,7 +132,8 @@ def __init__(self, self.add_bos = add_bos self.add_eos = add_eos self.return_nbest = return_nbest - self._model_resource = _SentencepieceModelResource(model, name) + self._model_resource = _SentencepieceModelResource( + model, name, use_unique_shared_resource_name) def tokenize(self, input, name=None): # pylint: disable=redefined-builtin """Tokenizes a tensor of UTF-8 strings. diff --git a/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py b/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py index 53c506288..0651a4b57 100644 --- a/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py +++ b/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py @@ -485,11 +485,15 @@ def testSavedModel(self): self.assertAllEqual(restored_model.tokenize(inputs), expected_result) file_io.delete_recursively(temp_dir) - def testBasicPipeline(self): + @parameterized.parameters([True, False]) + def testBasicPipeline(self, use_unique_shared_resource_name): if not context.executing_eagerly(): self.skipTest('testBasicPipeline only supported in eager mode.') - sp = SentencepieceTokenizer(self.model) + sp = SentencepieceTokenizer( + self.model, + use_unique_shared_resource_name=use_unique_shared_resource_name, + ) strings = ['hello', 'world'] dataset = dataset_ops.Dataset.from_tensor_slices(strings)