Skip to content

Commit

Permalink
Add default token_type_ids for multilingual-e5-* models (#403)
Browse files Browse the repository at this point in the history
* Fix #267 & #324

Add default token_type_ids. Fix for multilingual-e5-* family.

* Add add_token_types import

* export `add_token_types`

* Improvements

---------

Co-authored-by: Joshua Lochner <[email protected]>
  • Loading branch information
do-me and xenova authored Nov 19, 2023
1 parent b8719b1 commit ac0096e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
13 changes: 11 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ import {
AutoConfig,
} from './configs.js';

import {
add_token_types,
} from './tokenizers.js';

import {
Callable,
isIntegralNumber,
Expand Down Expand Up @@ -488,10 +492,15 @@ function seq2seqUpdatebeam(beam, newTokenId) {
* @private
*/
async function encoderForward(self, model_inputs) {
let encoderFeeds = {};
for (let key of self.session.inputNames) {
const encoderFeeds = Object.create(null);
for (const key of self.session.inputNames) {
encoderFeeds[key] = model_inputs[key];
}
if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
// Assign default `token_type_ids` to the `encoderFeeds` if the model expects it,
// but they weren't created by the tokenizer.
add_token_types(encoderFeeds);
}
return await sessionRun(self.session, encoderFeeds);
}

Expand Down
2 changes: 1 addition & 1 deletion src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -2578,7 +2578,7 @@ export class PreTrainedTokenizer extends Callable {
* @param {Object} inputs An object containing the input ids and attention mask.
* @returns {Object} The prepared inputs object.
*/
function add_token_types(inputs) {
export function add_token_types(inputs) {
// TODO ensure correctness when token pair is present
if (inputs.input_ids instanceof Tensor) {
inputs.token_type_ids = new Tensor(
Expand Down

0 comments on commit ac0096e

Please sign in to comment.