From bb1d70a012ff8c5244a4c9ff004a7ac02a89f824 Mon Sep 17 00:00:00 2001
From: "TF.Text Team"
Date: Wed, 11 Dec 2024 15:51:43 -0800
Subject: [PATCH] Protect the fast wordpiece tokenizer from infinite looping.

PiperOrigin-RevId: 705269179
---
 .../core/kernels/fast_wordpiece_tokenizer.cc  | 30 ++++++++++++-------
 .../core/kernels/fast_wordpiece_tokenizer.h   |  7 +++--
 ...fast_wordpiece_tokenizer_kernel_template.h |  8 ++++-
 3 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
index ac23c5794..55386a3fa 100644
--- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
+++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
@@ -65,12 +65,13 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       std::vector<std::string>* output_pieces,
                                       std::vector<int>* output_ids,
                                       std::vector<int>* output_start_offsets,
                                       std::vector<int>* output_end_offsets,
-                                      int input_word_offset_in_text) const {
+                                      int input_word_offset_in_text,
+                                      bool* error) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                      /*kGetOffsets=*/true>(input, output_pieces, output_ids,
                                            output_start_offsets,
-                                           output_end_offsets);
+                                           output_end_offsets, error);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -86,9 +87,9 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/true>(input, /*output_pieces=*/nullptr,
-                                           output_ids, output_start_offsets,
-                                           output_end_offsets);
+                     /*kGetOffsets=*/true>(
+        input, /*output_pieces=*/nullptr, output_ids, output_start_offsets,
+        output_end_offsets, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -102,10 +103,10 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/false>(input, /*output_pieces=*/nullptr,
-                                            output_ids,
-                                            /*output_start_offsets=*/nullptr,
-                                            /*output_end_offsets=*/nullptr);
+                     /*kGetOffsets=*/false>(
+        input, /*output_pieces=*/nullptr, output_ids,
+        /*output_start_offsets=*/nullptr,
+        /*output_end_offsets=*/nullptr, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/false>(
@@ -186,13 +187,14 @@ template <bool kGetPieces, bool kGetIds, bool kGetOffsets>
 void FastWordpieceTokenizer::TokenizeTextImpl(
     absl::string_view input_text, std::vector<std::string>* output_pieces,
     std::vector<int>* output_ids, std::vector<int>* output_start_offsets,
-    std::vector<int>* output_end_offsets) const {
+    std::vector<int>* output_end_offsets, bool* error) const {
   static_assert(kGetPieces || kGetIds,
                 "At least one of `kGetPieces` and `kGetIds` should be true.");
   if (input_text.empty()) {
     return;
   }
   const int input_size = input_text.size();
+  int prev_pos = -1;
   int next_pos = 0;
   int cur_pos = 0;
   int original_num_tokens =
@@ -200,6 +202,14 @@ void FastWordpieceTokenizer::TokenizeTextImpl(
   UChar32 prev_unicode_char;
   UChar32 cur_unicode_char;
   while (cur_pos < input_size) {
+    if (prev_pos == cur_pos) {
+      if (error != nullptr) {
+        *error = true;
+        return;
+      }
+    }
+    prev_pos = cur_pos;
+
     int cur_offset_in_input_word = 0;
     // Tokenize the word starting at the current position.
     auto cur_node = trie_->CreateTraversalCursorPointToRoot();
diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h
index 9369dde2a..908e18862 100644
--- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h
+++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h
@@ -72,13 +72,15 @@ class FastWordpieceTokenizer {
   //    text, in utf-8 bytes.
   //  * input_word_offset_in_text: The relative offset of the input word in
   //    the whole text. Only used when not using end-to-end tokenizer.
+  //  * error: If not null, this will be set to true if the tokenizer failed to
+  //    make progress in decoding the input.
   // Note: the start offsets are inclusive and the end offsets are exclusive.
   void Tokenize(absl::string_view input,
                 std::vector<std::string>* output_pieces,
                 std::vector<int>* output_ids,
                 std::vector<int>* output_start_offsets,
                 std::vector<int>* output_end_offsets,
-                int input_word_offset_in_text = 0) const;
+                int input_word_offset_in_text = 0, bool* error = nullptr) const;
 
   // An override not returning `output_pieces`.
   void Tokenize(absl::string_view input, std::vector<int>* output_ids,
@@ -125,7 +127,8 @@ class FastWordpieceTokenizer {
                         std::vector<std::string>* output_pieces,
                         std::vector<int>* output_ids,
                         std::vector<int>* output_start_offsets,
-                        std::vector<int>* output_end_offsets) const;
+                        std::vector<int>* output_end_offsets,
+                        bool* error) const;
 
   // Try following the failure link to make the transition when trie matching
   // fails.
diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h
index 7586d2de1..96c25b93f 100644
--- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h
+++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h
@@ -163,8 +163,14 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp<Rt>::Invoke(
   for (int i = 0; i < values_vec.Dim(0); ++i) {
     // Tokenize into subwords and record the offset locations.
     const int original_num_wordpieces = subwords.size();
+    bool error = false;
     fast_wordpiece_tokenizer->Tokenize(values_vec(i), &subwords, &subword_ids,
-                                       &begin_offset, &end_offset);
+                                       &begin_offset, &end_offset,
+                                       /*input_word_offset_in_text=*/0, &error);
+    if (error) {
+      return absl::InternalError(
+          "Failed to make any progress in tokenizing the input text.");
+    }
     const int delta_num_wordpieces = subwords.size() - original_num_wordpieces;
 
     // Record the row splits.
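
Note on the fix: the essential change is the no-progress guard added to the scanning loop in TokenizeTextImpl. If `cur_pos` has not advanced since the previous iteration, the tokenizer sets `*error` and returns instead of spinning forever, and the op kernel surfaces the flag as an absl::InternalError. Below is a minimal standalone sketch of that guard pattern; `ScanAll` and `ProcessOneToken` are hypothetical names used only for illustration and are not part of the TF.Text API.

  #include <iostream>
  #include <string>

  // Hypothetical stand-in for one tokenization step: returns the next
  // position, or the same position when it cannot consume the current byte
  // (the failure mode the patch guards against).
  int ProcessOneToken(const std::string& text, int pos) {
    return (static_cast<unsigned char>(text[pos]) == 0xFF) ? pos : pos + 1;
  }

  // Mirrors the guard added by the patch: detect a non-advancing loop and
  // report it through an optional error flag instead of looping forever.
  void ScanAll(const std::string& text, bool* error) {
    int prev_pos = -1;
    int cur_pos = 0;
    while (cur_pos < static_cast<int>(text.size())) {
      if (prev_pos == cur_pos) {
        if (error != nullptr) {
          *error = true;
          return;
        }
      }
      prev_pos = cur_pos;
      cur_pos = ProcessOneToken(text, cur_pos);
    }
  }

  int main() {
    bool error = false;
    ScanAll(std::string("ok\xFF"), &error);
    std::cout << (error ? "no progress detected" : "tokenized") << "\n";
    return 0;
  }

As in the patched TokenizeTextImpl, a null `error` pointer means the guard does not break out of the loop; that is why the op kernel in fast_wordpiece_tokenizer_kernel_template.h passes an explicit `&error` and converts the flag into a returned error status.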