diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
index ac23c5794..8268566ea 100644
--- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
+++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc
@@ -65,12 +65,13 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       std::vector<int>* output_ids,
                                       std::vector<int>* output_start_offsets,
                                       std::vector<int>* output_end_offsets,
-                                      int input_word_offset_in_text) const {
+                                      int input_word_offset_in_text,
+                                      bool* error) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                      /*kGetOffsets=*/true>(input, output_pieces, output_ids,
                                            output_start_offsets,
-                                           output_end_offsets);
+                                           output_end_offsets, error);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -86,9 +87,9 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/true>(input, /*output_pieces=*/nullptr,
-                                           output_ids, output_start_offsets,
-                                           output_end_offsets);
+                     /*kGetOffsets=*/true>(
+        input, /*output_pieces=*/nullptr, output_ids, output_start_offsets,
+        output_end_offsets, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -102,10 +103,10 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/false>(input, /*output_pieces=*/nullptr,
-                                            output_ids,
-                                            /*output_start_offsets=*/nullptr,
-                                            /*output_end_offsets=*/nullptr);
+                     /*kGetOffsets=*/false>(
+        input, /*output_pieces=*/nullptr, output_ids,
+        /*output_start_offsets=*/nullptr,
+        /*output_end_offsets=*/nullptr, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/false>(
@@ -186,13 +187,14 @@ template <bool kGetPieces, bool kGetIds, bool kGetOffsets>
 void FastWordpieceTokenizer::TokenizeTextImpl(
     absl::string_view input_text, std::vector<std::string>* output_pieces,
     std::vector<int>* output_ids, std::vector<int>* output_start_offsets,
-    std::vector<int>* output_end_offsets) const {
+    std::vector<int>* output_end_offsets, bool* error) const {
   static_assert(kGetPieces || kGetIds,
                 "At least one of `kGetPieces` and `kGetIds` should be true.");
   if (input_text.empty()) {
     return;
   }
   const int input_size = input_text.size();
+  int prev_pos = -1;
   int next_pos = 0;
   int cur_pos = 0;
   int original_num_tokens =
@@ -200,6 +202,13 @@ void FastWordpieceTokenizer::TokenizeTextImpl(
   UChar32 prev_unicode_char;
   UChar32 cur_unicode_char;
   while (cur_pos < input_size) {
+    // Prevent looping without progress in cur_pos.
+    if (prev_pos == cur_pos && error != nullptr) {
+      *error = true;
+      return;
+    }
+    prev_pos = cur_pos;
+
     int cur_offset_in_input_word = 0;
     // Tokenize the word starting at the current position.
     auto cur_node = trie_->CreateTraversalCursorPointToRoot();
@@ -210,7 +219,15 @@
       // 1. it steps over the input boundary, or
       // 2. the length of the current word reaches 'max_bytes_per_token', or
       // 3. it sees a whitespace / punctuation / unknown character.
+      int prev_pos_inner = -1;
       while (cur_pos < input_size) {
+        // Prevent looping without progress in cur_pos.
+        if (prev_pos_inner == cur_pos && error != nullptr) {
+          *error = true;
+          return;
+        }
+        prev_pos_inner = cur_pos;
+
         prev_unicode_char = cur_unicode_char;
         next_pos = cur_pos;
         U8_NEXT(input_text, next_pos, input_text.length(), cur_unicode_char);
diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h
index 9369dde2a..908e18862 100644
--- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h
+++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h
@@ -72,13 +72,15 @@ class FastWordpieceTokenizer {
   //      text, in utf-8 bytes.
   //  * input_word_offset_in_text: The relative offset of the input word in
   //      the whole text. Only used when not using end-to-end tokenizer.
+  //  * error: If not null, this will be set to true if the tokenizer failed to
+  //      make progress in decoding the input.
   // Note: the start offsets are inclusive and the end offsets are exclusive.
   void Tokenize(absl::string_view input,
                 std::vector<std::string>* output_pieces,
                 std::vector<int>* output_ids,
                 std::vector<int>* output_start_offsets,
                 std::vector<int>* output_end_offsets,
-                int input_word_offset_in_text = 0) const;
+                int input_word_offset_in_text = 0, bool* error = nullptr) const;

   // An override not returning `output_pieces`.
   void Tokenize(absl::string_view input, std::vector<int>* output_ids,
@@ -125,7 +127,8 @@
                         std::vector<std::string>* output_pieces,
                         std::vector<int>* output_ids,
                         std::vector<int>* output_start_offsets,
-                        std::vector<int>* output_end_offsets) const;
+                        std::vector<int>* output_end_offsets,
+                        bool* error) const;

   // Try following the failure link to make the transition when trie matching
   // fails.
diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h
index 7586d2de1..96c25b93f 100644
--- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h
+++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h
@@ -163,8 +163,14 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp::Invoke(
   for (int i = 0; i < values_vec.Dim(0); ++i) {
     // Tokenize into subwords and record the offset locations.
     const int original_num_wordpieces = subwords.size();
+    bool error = false;
     fast_wordpiece_tokenizer->Tokenize(values_vec(i), &subwords, &subword_ids,
-                                       &begin_offset, &end_offset);
+                                       &begin_offset, &end_offset,
+                                       /*input_word_offset_in_text=*/0, &error);
+    if (error) {
+      return absl::InternalError(
+          "Failed to make any progress in tokenizing the input text.");
+    }
     const int delta_num_wordpieces = subwords.size() - original_num_wordpieces;

     // Record the row splits.
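
For callers of the library API (outside the op kernel patched above), the new out-parameter is used the same way the kernel hunk uses it: pass a bool, check it after Tokenize() returns, and treat the outputs as invalid when it is set. Below is a minimal sketch, not part of this change; it assumes an already-constructed end-to-end FastWordpieceTokenizer in namespace tensorflow::text, and the wrapper name TokenizeOrFail is illustrative only.

    #include <string>
    #include <vector>

    #include "absl/strings/string_view.h"
    #include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h"

    // Sketch only: `tokenizer` must already be built with an end-to-end
    // config. Returns false when the tokenizer could not make progress.
    bool TokenizeOrFail(const tensorflow::text::FastWordpieceTokenizer& tokenizer,
                        absl::string_view text,
                        std::vector<std::string>* pieces,
                        std::vector<int>* ids) {
      std::vector<int> starts, ends;
      bool error = false;
      tokenizer.Tokenize(text, pieces, ids, &starts, &ends,
                         /*input_word_offset_in_text=*/0, &error);
      // Per the new header comment, `error` is only written when a non-null
      // pointer is passed; discard partially filled outputs if it is true.
      return !error;
    }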