
Commit 534859c
Protect the fast wordpiece tokenizer from infinite looping.
PiperOrigin-RevId: 705292812
tf-text-github-robot committed Dec 12, 2024
1 parent 6365dba commit 534859c
Showing 3 changed files with 39 additions and 13 deletions.
tensorflow_text/core/kernels/fast_wordpiece_tokenizer.cc (27 additions & 10 deletions)

@@ -65,12 +65,13 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       std::vector<int>* output_ids,
                                       std::vector<int>* output_start_offsets,
                                       std::vector<int>* output_end_offsets,
-                                      int input_word_offset_in_text) const {
+                                      int input_word_offset_in_text,
+                                      bool* error) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                      /*kGetOffsets=*/true>(input, output_pieces, output_ids,
                                            output_start_offsets,
-                                           output_end_offsets);
+                                           output_end_offsets, error);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/true, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -86,9 +87,9 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/true>(input, /*output_pieces=*/nullptr,
-                                           output_ids, output_start_offsets,
-                                           output_end_offsets);
+                     /*kGetOffsets=*/true>(
+        input, /*output_pieces=*/nullptr, output_ids, output_start_offsets,
+        output_end_offsets, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/true>(
@@ -102,10 +103,10 @@ void FastWordpieceTokenizer::Tokenize(absl::string_view input,
                                       int input_word_offset_in_text) const {
   if (config_->end_to_end()) {
     TokenizeTextImpl</*kGetPieces=*/false, /*kGetIds=*/true,
-                     /*kGetOffsets=*/false>(input, /*output_pieces=*/nullptr,
-                                            output_ids,
-                                            /*output_start_offsets=*/nullptr,
-                                            /*output_end_offsets=*/nullptr);
+                     /*kGetOffsets=*/false>(
+        input, /*output_pieces=*/nullptr, output_ids,
+        /*output_start_offsets=*/nullptr,
+        /*output_end_offsets=*/nullptr, /*error=*/nullptr);
   } else {
     TokenizeSingleWordImpl</*kGetPieces=*/false, /*kGetIds=*/true,
                            /*kGetOffsets=*/false>(
@@ -186,20 +187,28 @@ template <bool kGetPieces, bool kGetIds, bool kGetOffsets>
 void FastWordpieceTokenizer::TokenizeTextImpl(
     absl::string_view input_text, std::vector<std::string>* output_pieces,
     std::vector<int>* output_ids, std::vector<int>* output_start_offsets,
-    std::vector<int>* output_end_offsets) const {
+    std::vector<int>* output_end_offsets, bool* error) const {
   static_assert(kGetPieces || kGetIds,
                 "At least one of `kGetPieces` and `kGetIds` should be true.");
   if (input_text.empty()) {
     return;
   }
   const int input_size = input_text.size();
+  int prev_pos = -1;
   int next_pos = 0;
   int cur_pos = 0;
   int original_num_tokens =
       GetCurrentOutputSize<kGetPieces>(output_pieces, output_ids);
   UChar32 prev_unicode_char;
   UChar32 cur_unicode_char;
   while (cur_pos < input_size) {
+    // Prevent looping without progress in cur_pos.
+    if (prev_pos == cur_pos && error != nullptr) {
+      *error = true;
+      return;
+    }
+    prev_pos = cur_pos;
+
     int cur_offset_in_input_word = 0;
     // Tokenize the word starting at the current position.
     auto cur_node = trie_->CreateTraversalCursorPointToRoot();
@@ -210,7 +219,15 @@ void FastWordpieceTokenizer::TokenizeTextImpl(
       // 1. it steps over the input boundary, or
       // 2. the length of the current word reaches 'max_bytes_per_token', or
       // 3. it sees a whitespace / punctuation / unknown character.
+      int prev_pos_inner = -1;
       while (cur_pos < input_size) {
+        // Prevent looping without progress in cur_pos.
+        if (prev_pos_inner == cur_pos && error != nullptr) {
+          *error = true;
+          return;
+        }
+        prev_pos_inner = cur_pos;
+
         prev_unicode_char = cur_unicode_char;
         next_pos = cur_pos;
         U8_NEXT(input_text, next_pos, input_text.length(), cur_unicode_char);
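
Both guards follow the same defensive idiom for byte-position scanners: record the position at the top of each iteration and bail out, setting *error, if an iteration ends without advancing it. Below is a minimal self-contained sketch of that idiom; it is not taken from this commit, and DecodeNext is a hypothetical stand-in for the trie walk and U8_NEXT stepping.

#include <cstdio>
#include <string>

// Hypothetical decoder step: normally consumes one unit of input, but (like
// a buggy decoder or an adversarial input) it may fail to advance *pos at
// all, which is the failure mode the guard catches.
static bool DecodeNext(const std::string& input, int* pos) {
  if (*pos < static_cast<int>(input.size()) && input[*pos] != '\xff') {
    ++*pos;  // Consume one byte.
    return true;
  }
  return false;  // Leaves *pos unchanged: would spin forever unguarded.
}

// Scans `input`, setting *error instead of looping forever when an
// iteration makes no progress (the idiom added to TokenizeTextImpl).
static void Scan(const std::string& input, bool* error) {
  int prev_pos = -1;
  int cur_pos = 0;
  while (cur_pos < static_cast<int>(input.size())) {
    // Prevent looping without progress in cur_pos.
    if (prev_pos == cur_pos && error != nullptr) {
      *error = true;
      return;
    }
    prev_pos = cur_pos;
    DecodeNext(input, &cur_pos);  // May fail to advance cur_pos.
  }
}

int main() {
  bool error = false;
  Scan("ok\xff!", &error);  // The '\xff' byte stalls the scan.
  std::printf("error = %s\n", error ? "true" : "false");  // error = true
}
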
tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h (5 additions & 2 deletions)

@@ -72,13 +72,15 @@ class FastWordpieceTokenizer {
   //     text, in utf-8 bytes.
   //   * input_word_offset_in_text: The relative offset of the input word in
   //     the whole text. Only used when not using end-to-end tokenizer.
+  //   * error: If not null, this will be set to true if the tokenizer failed to
+  //     make progress in decoding the input.
   // Note: the start offsets are inclusive and the end offsets are exclusive.
   void Tokenize(absl::string_view input,
                 std::vector<std::string>* output_pieces,
                 std::vector<int>* output_ids,
                 std::vector<int>* output_start_offsets,
                 std::vector<int>* output_end_offsets,
-                int input_word_offset_in_text = 0) const;
+                int input_word_offset_in_text = 0, bool* error = nullptr) const;

   // An override not returning `output_pieces`.
   void Tokenize(absl::string_view input, std::vector<int>* output_ids,
@@ -125,7 +127,8 @@
       std::vector<std::string>* output_pieces,
       std::vector<int>* output_ids,
       std::vector<int>* output_start_offsets,
-      std::vector<int>* output_end_offsets) const;
+      std::vector<int>* output_end_offsets,
+      bool* error) const;

   // Try following the failure link to make the transition when trie matching
   // fails.
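
Both new parameters are defaulted, so existing call sites compile unchanged; a caller that wants the failure signal passes a flag explicitly. A hedged usage sketch follows: it assumes a constructed tokenizer, the tensorflow::text namespace used by the library's kernels, and an invented wrapper name TokenizeOrFail; it is not code from this commit.

#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h"

// Sketch: run tokenization and surface the new error signal to the caller.
// Returns false if the tokenizer reported that it could not make progress,
// in which case the outputs may be incomplete and should be discarded.
bool TokenizeOrFail(const tensorflow::text::FastWordpieceTokenizer& tokenizer,
                    absl::string_view text, std::vector<std::string>* pieces,
                    std::vector<int>* ids) {
  std::vector<int> starts, ends;
  bool error = false;
  tokenizer.Tokenize(text, pieces, ids, &starts, &ends,
                     /*input_word_offset_in_text=*/0, &error);
  return !error;
}
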
@@ -163,8 +163,14 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp<Rt>::Invoke(
   for (int i = 0; i < values_vec.Dim(0); ++i) {
     // Tokenize into subwords and record the offset locations.
     const int original_num_wordpieces = subwords.size();
+    bool error = false;
     fast_wordpiece_tokenizer->Tokenize(values_vec(i), &subwords, &subword_ids,
-                                       &begin_offset, &end_offset);
+                                       &begin_offset, &end_offset,
+                                       /*input_word_offset_in_text=*/0, &error);
+    if (error) {
+      return absl::InternalError(
+          "Failed to make any progress in tokenizing the input text.");
+    }
     const int delta_num_wordpieces = subwords.size() - original_num_wordpieces;

     // Record the row splits.
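
The kernel promotes the out-parameter flag to a status, so a single input that makes no progress fails the op with absl::InternalError instead of hanging it. The same flag-to-status idiom, reduced to a self-contained sketch (TokenizeOnce is a hypothetical stand-in with a toy failure condition, not the library's API):

#include <string>
#include <vector>

#include "absl/status/status.h"

// Hypothetical stand-in for FastWordpieceTokenizer::Tokenize: failure is
// reported through a bool out-parameter rather than a return value.
void TokenizeOnce(const std::string& text, std::vector<std::string>* pieces,
                  bool* error) {
  if (text.empty()) {  // Toy failure condition for demonstration only.
    *error = true;
    return;
  }
  pieces->push_back(text);
}

// The kernel-side idiom from this diff: check the flag after every call and
// promote it to an error status, so one bad input fails the whole op cleanly
// instead of spinning the batch loop forever.
absl::Status TokenizeBatch(const std::vector<std::string>& batch,
                           std::vector<std::string>* pieces) {
  for (const std::string& text : batch) {
    bool error = false;
    TokenizeOnce(text, pieces, &error);
    if (error) {
      return absl::InternalError(
          "Failed to make any progress in tokenizing the input text.");
    }
  }
  return absl::OkStatus();
}
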
