Add support for cross-encoder models (+fix token type ids) (#501)
* Formatting

* Update ESM pair template

* Fix token type ids

* Update JSDoc

* Cleanup

* Remove unused `prepare_model_inputs` function

* Move pad and truncate logic to helper functions

* Add static padding/truncation unit tests

* Fix padding/truncation

* Remove unused `add_token_types` function

* Reduce duplication

* `let` -> `const` where possible

* Add cross-encoder models
xenova authored Jan 4, 2024
1 parent f3482ba commit ebd5335
Showing 6 changed files with 535 additions and 325 deletions.
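
The reworked padding/truncation logic ("Move pad and truncate logic to helper functions", "Add static padding/truncation unit tests") lives in src/tokenizers.js, whose diff is not rendered below. As a rough sketch of the behaviour those tests exercise — assuming the tokenizer call options `padding`, `truncation`, `max_length` and `return_tensor`, and an illustrative checkpoint id — uneven batches come back rectangular:

// Minimal sketch only: the checkpoint id is an assumption, not taken from this commit.
import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');

// `padding: true` right-pads the short item, `truncation: true` + `max_length`
// clips the long one, so every row ends up with exactly `max_length` entries.
const encoded = tokenizer(['a', 'a b c d e f'], {
    padding: true,
    truncation: true,
    max_length: 5,
    return_tensor: false, // plain nested arrays instead of Tensors
});
console.log(encoded.input_ids.map(row => row.length));      // [5, 5]
console.log(encoded.attention_mask.map(row => row.length)); // [5, 5]
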
37 changes: 21 additions & 16 deletions scripts/extra/esm.py
@@ -2,33 +2,37 @@
 from tokenizers import Tokenizer, pre_tokenizers, processors
 from tokenizers.models import WordPiece
 
+
 class EsmConverter(Converter):
     def converted(self) -> Tokenizer:
         vocab = self.original_tokenizer.vocab
-        tokenizer = Tokenizer(WordPiece(vocab, continuing_subword_prefix='', max_input_chars_per_word=int(1e10), unk_token=str(self.original_tokenizer.unk_token)))
+        tokenizer = Tokenizer(WordPiece(vocab, continuing_subword_prefix='', max_input_chars_per_word=int(
+            1e10), unk_token=str(self.original_tokenizer.unk_token)))
 
         tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
 
         cls = str(self.original_tokenizer.cls_token)
         cls_token_id = self.original_tokenizer.cls_token_id
-        sep = str(self.original_tokenizer.eos_token) # No sep token in ESM vocabulary
+        # No sep token in ESM vocabulary
+        sep = str(self.original_tokenizer.eos_token)
         sep_token_id = self.original_tokenizer.eos_token_id
 
         if sep_token_id is None:
-            tokenizer.post_processor = processors.TemplateProcessing(
-                single=f"{cls}:0 $A:0",
-                special_tokens=[
-                    (cls, cls_token_id),
-                ],
-            )
+            tokenizer.post_processor = processors.TemplateProcessing(
+                single=f"{cls}:0 $A:0",
+                special_tokens=[
+                    (cls, cls_token_id),
+                ],
+            )
         else:
-            tokenizer.post_processor = processors.TemplateProcessing(
-                single=f"{cls}:0 $A:0 {sep}:0",
-                special_tokens=[
-                    (cls, cls_token_id),
-                    (sep, sep_token_id),
-                ],
-            )
+            tokenizer.post_processor = processors.TemplateProcessing(
+                single=f"{cls}:0 $A:0 {sep}:0",
+                pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
+                special_tokens=[
+                    (cls, cls_token_id),
+                    (sep, sep_token_id),
+                ],
+            )
 
         # For some reason, all tokens are added: none of them are special, but they all need special splitting.
         # See https://github.com/huggingface/transformers/blob/df5c5c62ae253055336f5bb0828ca8e3e15ab6bd/src/transformers/models/esm/tokenization_esm.py#L79-L80
@@ -44,6 +48,7 @@ def converted(self) -> Tokenizer:
         tokenizer.add_tokens(other_tokens)
         return tokenizer
 
+
 def generate_fast_tokenizer(tokenizer):
     tokenizer.vocab = tokenizer._token_to_id
     return EsmConverter(tokenizer).converted()
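
With the new pair template, an ESM tokenizer that has an <eos> token encodes sequence pairs as `{cls} $A {sep} $B {sep}`, labelling the second sequence with token type 1. A minimal sketch of the effect from the JavaScript side, assuming a converted checkpoint such as Xenova/esm2_t6_8M_UR50D and the tokenizer's `text_pair` option:

// Sketch only: the checkpoint id is an assumption (any converted ESM-2 tokenizer with an <eos> token).
import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/esm2_t6_8M_UR50D');

// Two short protein sequences: the post-processor inserts <cls> ... <eos> ... <eos>
// and labels everything up to the first <eos> with type 0, the rest with type 1.
const { input_ids, token_type_ids } = tokenizer('MKT', { text_pair: 'GAV', return_tensor: false });
console.log(input_ids.flat().length);  // 9 (<cls> M K T <eos> G A V <eos>)
console.log(token_type_ids.flat());    // [0, 0, 0, 0, 0, 1, 1, 1, 1]
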
5 changes: 5 additions & 0 deletions scripts/supported_models.py
@@ -109,6 +109,11 @@
         'unitary/toxic-bert',
         'BAAI/bge-reranker-large',
         'BAAI/bge-reranker-base',
+        'cross-encoder/ms-marco-TinyBERT-L-2-v2',
+        'cross-encoder/ms-marco-MiniLM-L-2-v2',
+        'cross-encoder/ms-marco-MiniLM-L-4-v2',
+        'cross-encoder/ms-marco-MiniLM-L-6-v2',
+        'cross-encoder/ms-marco-MiniLM-L-12-v2',
     ],
 
     # Token classification
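
With these checkpoints listed for conversion, a cross-encoder can be used for passage reranking: each (query, passage) pair is encoded jointly and scored with a single logit. A minimal sketch, assuming the corresponding ONNX export is published under an id such as Xenova/ms-marco-MiniLM-L-6-v2:

// Sketch only: the hub id is an assumption; any converted cross-encoder works the same way.
import { AutoTokenizer, AutoModelForSequenceClassification } from '@xenova/transformers';

const model_id = 'Xenova/ms-marco-MiniLM-L-6-v2';
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForSequenceClassification.from_pretrained(model_id);

const query = 'How many people live in Berlin?';
const passages = [
    'Berlin had 3.5 million registered inhabitants in 2016.',
    'Berlin is well known for its museums.',
];

// Encode every (query, passage) pair in one padded/truncated batch.
const inputs = tokenizer(new Array(passages.length).fill(query), {
    text_pair: passages,
    padding: true,
    truncation: true,
});
const { logits } = await model(inputs);
console.log(logits.tolist()); // one relevance score per passage; higher = more relevant
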
12 changes: 6 additions & 6 deletions src/models.js
@@ -42,10 +42,6 @@ import {
     AutoConfig,
 } from './configs.js';
 
-import {
-    add_token_types,
-} from './tokenizers.js';
-
 import {
     Callable,
     isIntegralNumber,
@@ -512,9 +508,13 @@ async function encoderForward(self, model_inputs) {
         encoderFeeds[key] = model_inputs[key];
     }
     if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
-        // Assign default `token_type_ids` to the `encoderFeeds` if the model expects it,
+        // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
         // but they weren't created by the tokenizer.
-        add_token_types(encoderFeeds);
+        encoderFeeds.token_type_ids = new Tensor(
+            'int64',
+            new BigInt64Array(encoderFeeds.input_ids.data.length),
+            encoderFeeds.input_ids.dims
+        )
     }
     return await sessionRun(self.session, encoderFeeds);
 }
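
From the caller's side, the `encoderForward` change means a model whose ONNX session lists `token_type_ids` among its inputs no longer fails when the tokenizer output lacks them: an all-zero int64 tensor shaped like `input_ids` is supplied instead. A rough sketch, with the checkpoint id as an assumption:

// Sketch only: assumes 'Xenova/bert-base-uncased' is available and its ONNX export takes token_type_ids.
import { AutoTokenizer, AutoModel } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased');
const model = await AutoModel.from_pretrained('Xenova/bert-base-uncased');

const inputs = tokenizer('Hello world');
delete inputs.token_type_ids; // pretend the tokenizer never produced them

// encoderForward notices the missing feed and fills in zeros before running the session.
const { last_hidden_state } = await model(inputs);
console.log(last_hidden_state.dims); // e.g. [1, 4, 768]
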
