Skip to content

Commit

Permalink
Support resource manager scoped Sentencepiece resources.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596019271
  • Loading branch information
aaudiber authored and tf-text-github-robot committed Jan 5, 2024
1 parent 8ada63e commit dba0c2f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
24 changes: 20 additions & 4 deletions tensorflow_text/python/ops/sentencepiece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from __future__ import division
from __future__ import print_function

import uuid

from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
Expand All @@ -43,16 +45,24 @@
class _SentencepieceModelResource(resource.TrackableResource):
  """Utility to track the model resource tensor (for SavedModel support).

  Args:
    model: The value forwarded as `model` to `sentencepiece_op` when the
      resource is created.
    name: Name passed to `ops.name_scope` around resource creation; the scope
      falls back to "SentenceTokenizerInitializer" as its default name.
    use_unique_shared_resource_name: When true, the resource is created with
      a freshly generated UUID string as its `shared_name`; otherwise an
      empty `shared_name` is used.
  """

  def __init__(self, model, name, use_unique_shared_resource_name=False):
    super(_SentencepieceModelResource, self).__init__()
    self._model = model
    self._name = name
    self._use_unique_shared_resource_name = use_unique_shared_resource_name
    _ = self.resource_handle  # Accessing this property creates the resource.

  def _create_resource(self):
    """Creates the Sentencepiece resource and returns the op's result."""
    model, name = self._model, self._name
    with ops.name_scope(name, "SentenceTokenizerInitializer", [model]):
      # TODO(b/318839908): Switch to ref-counted resources and remove this.
      # An empty shared_name is the default; a random UUID is substituted only
      # when the caller asked for a unique shared resource name.
      shared_name = ""
      if self._use_unique_shared_resource_name:
        shared_name = str(uuid.uuid4())
      return gen_sentencepiece_tokenizer.sentencepiece_op(
          model=model,
          shared_name=shared_name,
      )


class SentencepieceTokenizer(TokenizerWithOffsets, Detokenizer):
Expand All @@ -79,7 +89,8 @@ def __init__(self,
add_bos=False,
add_eos=False,
return_nbest=False,
name=None):
name=None,
use_unique_shared_resource_name=False):
"""Creates & initializes a Sentencepiece processor.
Args:
Expand All @@ -104,6 +115,10 @@ def __init__(self,
of a single one. The returned tensor has shape
`[batch * nbest, (tokens)]`.
name: The name argument that is passed to the op function.
use_unique_shared_resource_name: Whether to set a unique `shared_name` for
the Sentencepiece resource. This is useful for allowing the resource to
outlive the kernel that created it in eager mode. The resource's
lifetime will instead be tied to the eager context.
Returns:
pieces: A SentencepieceTokenizer.
Expand All @@ -117,7 +132,8 @@ def __init__(self,
self.add_bos = add_bos
self.add_eos = add_eos
self.return_nbest = return_nbest
self._model_resource = _SentencepieceModelResource(model, name)
self._model_resource = _SentencepieceModelResource(
model, name, use_unique_shared_resource_name)

def tokenize(self, input, name=None): # pylint: disable=redefined-builtin
"""Tokenizes a tensor of UTF-8 strings.
Expand Down
8 changes: 6 additions & 2 deletions tensorflow_text/python/ops/sentencepiece_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,15 @@ def testSavedModel(self):
self.assertAllEqual(restored_model.tokenize(inputs), expected_result)
file_io.delete_recursively(temp_dir)

def testBasicPipeline(self):
@parameterized.parameters([True, False])
def testBasicPipeline(self, use_unique_shared_resource_name):
if not context.executing_eagerly():
self.skipTest('testBasicPipeline only supported in eager mode.')

sp = SentencepieceTokenizer(self.model)
sp = SentencepieceTokenizer(
self.model,
use_unique_shared_resource_name=use_unique_shared_resource_name,
)

strings = ['hello', 'world']
dataset = dataset_ops.Dataset.from_tensor_slices(strings)
Expand Down

0 comments on commit dba0c2f

Please sign in to comment.