Improve support of conversational models #658

Merged · 5 commits · Apr 10, 2024
README.md (1 change: 0 additions & 1 deletion)

```diff
@@ -198,7 +198,6 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
-| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
 | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
 | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
 | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |
```
docs/snippets/5_supported-tasks.snippet (1 change: 0 additions & 1 deletion)

```diff
@@ -5,7 +5,6 @@
 
 | Task | ID | Description | Supported? |
 |--------------------------|----|-------------|------------|
-| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ |
 | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) |
 | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) |
 | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) |
```
src/pipelines.js (68 changes: 60 additions & 8 deletions)

```diff
@@ -840,18 +840,24 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC
     }
 }
 
+function isChat(x) {
+    return Array.isArray(x) && x.every(x => 'role' in x && 'content' in x);
+}
+
 /**
+ * @typedef {import('./tokenizers.js').Message[]} Chat
+ *
  * @typedef {Object} TextGenerationSingle
- * @property {string} generated_text The generated text.
+ * @property {string|Chat} generated_text The generated text.
  * @typedef {TextGenerationSingle[]} TextGenerationOutput
  *
  * @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines.
  * @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences.
+ * @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned.
  * @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig
  *
  * @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs.
- * @param {string|string[]} texts One or several prompts (or one list of prompts) to complete.
+ * @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete.
  * @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model.
  * @returns {Promise<TextGenerationOutput|TextGenerationOutput[]>} An array or object containing the generated texts.
  *
@@ -920,17 +926,46 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
 
     /** @type {TextGenerationPipelineCallback} */
     async _call(texts, generate_kwargs = {}) {
-        const isBatched = Array.isArray(texts);
-        if (!isBatched) {
-            texts = [/** @type {string}*/ (texts)];
-        }
+        let isBatched = false;
+        let isChatInput = false;
+
+        // Normalize inputs
+        /** @type {string[]} */
+        let inputs;
+        if (typeof texts === 'string') {
+            inputs = texts = [texts];
+        } else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) {
+            isBatched = true;
+            inputs = /** @type {string[]} */(texts);
+        } else {
+            if (isChat(texts)) {
+                texts = [/** @type {Chat} */(texts)];
+            } else if (Array.isArray(texts) && texts.every(isChat)) {
+                isBatched = true;
+            } else {
+                throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats');
+            }
+            isChatInput = true;
+
+            // If the input is a chat, we need to apply the chat template
+            inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map(
+                x => this.tokenizer.apply_chat_template(x, {
+                    tokenize: false,
+                    add_generation_prompt: true,
+                })
+            ));
+        }
 
         // By default, do not add special tokens
         const add_special_tokens = generate_kwargs.add_special_tokens ?? false;
 
+        // By default, return full text
+        const return_full_text = isChatInput
+            ? false
+            : generate_kwargs.return_full_text ?? true;
+
         this.tokenizer.padding_side = 'left';
-        const { input_ids, attention_mask } = this.tokenizer(texts, {
+        const { input_ids, attention_mask } = this.tokenizer(inputs, {
             add_special_tokens,
             padding: true,
             truncation: true,
@@ -940,17 +975,34 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli
             inputs_attention_mask: attention_mask
         });
 
-        const decoded = this.tokenizer.batch_decode(outputTokenIds, {
+        let decoded = this.tokenizer.batch_decode(outputTokenIds, {
             skip_special_tokens: true,
         });
 
+        let promptLengths;
+        if (!return_full_text && input_ids.dims.at(-1) > 0) {
+            promptLengths = this.tokenizer.batch_decode(input_ids, {
+                skip_special_tokens: true,
+            }).map(x => x.length);
+        }
+
         /** @type {TextGenerationOutput[]} */
         const toReturn = Array.from({ length: texts.length }, _ => []);
         for (let i = 0; i < decoded.length; ++i) {
             const textIndex = Math.floor(i / outputTokenIds.length * texts.length);
 
+            if (promptLengths) {
+                // Trim the decoded text to only include the generated part
+                decoded[i] = decoded[i].slice(promptLengths[textIndex]);
+            }
             toReturn[textIndex].push({
-                generated_text: decoded[i]
+                generated_text: isChatInput
+                    ? [
+                        ...((/** @type {Chat[]} */(texts)[textIndex])),
+                        { role: 'assistant', content: decoded[i] },
+                    ]
+                    : decoded[i]
             });
         }
         return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn;
```
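Taken together, these changes let `TextGenerationPipeline` accept chat-style inputs (lists of messages) in place of the previously unsupported `conversational` task removed from the tables above. Below is a minimal usage sketch; the model id is an illustrative assumption, and any text-generation model that ships a chat template should behave the same way:

```js
import { pipeline } from '@xenova/transformers';

// Illustrative model id (an assumption, not part of this PR); substitute
// any text-generation model on the Hub that defines a chat template.
const generator = await pipeline('text-generation', 'Xenova/Qwen1.5-0.5B-Chat');

// A Chat is an array of { role, content } messages, per the Message typedef.
const chat = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Tell me a short story.' },
];

// isChat(chat) is true, so the pipeline applies the tokenizer's chat template,
// defaults return_full_text to false, and returns the input chat extended
// with the assistant's reply.
const output = await generator(chat, { max_new_tokens: 64 });
console.log(output[0].generated_text.at(-1)); // { role: 'assistant', content: '...' }
```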
src/tokenizers.js (12 changes: 6 additions & 6 deletions)

```diff
@@ -2429,6 +2429,12 @@ function truncateHelper(item, length) {
 }
 
 
+/**
+ * @typedef {Object} Message
+ * @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
+ * @property {string} content The content of the message.
+ */
+
 export class PreTrainedTokenizer extends Callable {
     return_token_type_ids = false;
 
@@ -2959,12 +2965,6 @@ export class PreTrainedTokenizer extends Callable {
         return this._default_chat_template;
     }
 
-    /**
-     * @typedef {Object} Message
-     * @property {string} role The role of the message (e.g., "user" or "assistant" or "system").
-     * @property {string} content The content of the message.
-     */
-
     /**
      * Converts a list of message objects with `"role"` and `"content"` keys to a list of token
      * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
```
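The `Message` typedef moves to module level so that `pipelines.js` can reference it for the `Chat` type. For reference, here is a sketch of how the pipeline consumes it via `apply_chat_template`, assuming `tokenizer` is an already-loaded `PreTrainedTokenizer`; the exact prompt string depends on the model's `chat_template`:

```js
// Assumes `tokenizer` was loaded elsewhere, e.g. via AutoTokenizer.from_pretrained(...).
const messages = [
    { role: 'system', content: 'You are a helpful assistant.' }, // Message objects
    { role: 'user', content: 'Hello!' },
];

// Mirrors the call added to TextGenerationPipeline._call above:
const prompt = tokenizer.apply_chat_template(messages, {
    tokenize: false,             // return a formatted string, not token ids
    add_generation_prompt: true, // append the template's assistant-turn prefix
});
// `prompt` is a single string; its exact format is model-specific.
```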
tests/generation.test.js (35 changes: 35 additions & 0 deletions)

```diff
@@ -11,6 +11,8 @@ describe('Generation parameters', () => {
     const models = [
         'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
         'MBZUAI/LaMini-GPT-124M', // decoder-only
+
+        'Xenova/llama2.c-stories15M', // decoder-only
     ];
 
     // encoder-decoder model
@@ -135,4 +137,37 @@ describe('Generation parameters', () => {
 
     }, MAX_TEST_EXECUTION_TIME);
 
+    // decoder-only model
+    it(models[2], async () => {
+        const MAX_NEW_TOKENS = 1;
+
+        const text = [
+            'Once upon a time,',
+            'Lily',
+            'Suddenly,',
+        ];
+
+        const generator = await pipeline('text-generation', m(models[2]));
+
+        { // return_full_text=false
+            const output = await generator(text, {
+                return_full_text: false,
+                max_new_tokens: MAX_NEW_TOKENS,
+                num_beams: 2,
+                num_return_sequences: 2,
+            });
+            const lengths = output.flatMap(
+                x => x.flatMap(
+                    y => generator.tokenizer.encode(y.generated_text.trim(), null, {
+                        add_special_tokens: false,
+                    }).length
+                )
+            ).every(x => x === MAX_NEW_TOKENS);
+
+            expect(lengths).toBe(true);
+        }
+        await generator.dispose();
+
+    }, MAX_TEST_EXECUTION_TIME);
+
 });
```
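The new test exercises only the plain-string path: with `return_full_text: false`, each returned `generated_text` should contain exactly `MAX_NEW_TOKENS` tokens once the prompt has been trimmed off. A simplified sketch of the trimming step this relies on, with made-up values for illustration:

```js
// Simplified view of the promptLengths logic added in pipelines.js:
// decode the prompt, then drop that many characters from the decoded output.
const decodedPrompt = 'Once upon a time,';         // batch_decode(input_ids)[i]
const decodedFull = 'Once upon a time, there was'; // batch_decode(outputTokenIds)[i]
const generatedOnly = decodedFull.slice(decodedPrompt.length);
console.log(generatedOnly); // ' there was'
```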