From 3da19cb3eb201a234c3d9835e1b52286b99824c5 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 21 Mar 2024 16:23:42 +0200 Subject: [PATCH 1/5] Add `return_full_text` option for text-generation models --- src/pipelines.js | 17 ++++++++++++++++- tests/generation.test.js | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/pipelines.js b/src/pipelines.js index 2b064d522..0a81e8a39 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -848,6 +848,7 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC * * @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines. * @property {boolean} [add_special_tokens] Whether or not to add special tokens when tokenizing the sequences. + * @property {boolean} [return_full_text=true] If set to `false` only added text is returned, otherwise the full text is returned. * @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig * * @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs. @@ -929,6 +930,9 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli // By default, do not add special tokens const add_special_tokens = generate_kwargs.add_special_tokens ?? false; + // By default, return full text + const return_full_text = generate_kwargs.return_full_text ?? true; + this.tokenizer.padding_side = 'left'; const { input_ids, attention_mask } = this.tokenizer(texts, { add_special_tokens, @@ -940,15 +944,26 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli inputs_attention_mask: attention_mask }); - const decoded = this.tokenizer.batch_decode(outputTokenIds, { + let decoded = this.tokenizer.batch_decode(outputTokenIds, { skip_special_tokens: true, }); + + let promptLengths; + if (!return_full_text && input_ids.dims.at(-1) > 0) { + promptLengths = this.tokenizer.batch_decode(input_ids, { + skip_special_tokens: true, + }).map(x => x.length); + } + /** @type {TextGenerationOutput[]} */ const toReturn = Array.from({ length: texts.length }, _ => []); for (let i = 0; i < decoded.length; ++i) { const textIndex = Math.floor(i / outputTokenIds.length * texts.length); + if (promptLengths) { + decoded[i] = decoded[i].slice(promptLengths[textIndex]); + } toReturn[textIndex].push({ generated_text: decoded[i] }); diff --git a/tests/generation.test.js b/tests/generation.test.js index eb6b87f49..454348e2d 100644 --- a/tests/generation.test.js +++ b/tests/generation.test.js @@ -11,6 +11,8 @@ describe('Generation parameters', () => { const models = [ 'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder 'MBZUAI/LaMini-GPT-124M', // decoder-only + + 'Xenova/llama2.c-stories15M', // decoder-only ]; // encoder-decoder model @@ -135,4 +137,35 @@ describe('Generation parameters', () => { }, MAX_TEST_EXECUTION_TIME); + // decoder-only model + it(models[2], async () => { + const MAX_NEW_TOKENS = 1; + + const text = [ + 'Once upon a time,', + 'Lily', + 'Suddenly,', + ]; + + { // return_full_text=false + const output = await generator(text, { + return_full_text: false, + max_new_tokens: MAX_NEW_TOKENS, + num_beams: 2, + num_return_sequences: 2, + }); + const lengths = output.flatMap( + x => x.flatMap( + y => generator.tokenizer.encode(y.generated_text.trim(), null, { + add_special_tokens: false, + }).length + ) + ).every(x => x === MAX_NEW_TOKENS); + + expect(lengths).toBe(true); + } + await generator.dispose(); + + }, MAX_TEST_EXECUTION_TIME); + }); \ No newline at end of file From 385de36470fbfcc61287b7f351f07c06383e136c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 21 Mar 2024 23:14:56 +0200 Subject: [PATCH 2/5] [wip] Support chat inputs in text-generation pipeline --- src/pipelines.js | 39 ++++++++++++++++++++++++++++++++++----- src/tokenizers.js | 12 ++++++------ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index 0a81e8a39..95df5a441 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -840,8 +840,13 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC } } +function isChat(x) { + return Array.isArray(x) && x.every(x => 'role' in x && 'content' in x); +} /** + * @typedef {import('./tokenizers.js').Message[]} Chat + * * @typedef {Object} TextGenerationSingle * @property {string} generated_text The generated text. * @typedef {TextGenerationSingle[]} TextGenerationOutput @@ -852,7 +857,7 @@ export class TranslationPipeline extends (/** @type {new (options: TextPipelineC * @typedef {import('./utils/generation.js').GenerationConfigType & TextGenerationSpecificParams} TextGenerationConfig * * @callback TextGenerationPipelineCallback Complete the prompt(s) given as inputs. - * @param {string|string[]} texts One or several prompts (or one list of prompts) to complete. + * @param {string|string[]|Chat|Chat[]} texts One or several prompts (or one list of prompts) to complete. * @param {TextGenerationConfig} [options] Additional keyword arguments to pass along to the generate method of the model. * @returns {Promise} An array or object containing the generated texts. * @@ -921,17 +926,41 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli /** @type {TextGenerationPipelineCallback} */ async _call(texts, generate_kwargs = {}) { + let isBatched; - const isBatched = Array.isArray(texts); - if (!isBatched) { - texts = [/** @type {string}*/ (texts)]; + let isChatInput = false; + + // Normalize inputs + if (typeof texts === 'string') { + texts = [texts]; + } else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) { + isBatched = true; + } else { + if (isChat(texts)) { + texts = [/** @type {Chat} */(texts)]; + } else if (Array.isArray(texts) && texts.every(isChat)) { + isBatched = true; + } else { + throw new Error('Input must be a string, an array of strings, a Chat, or an array of Chats'); + } + isChatInput = true; + + // If the input is a chat, we need to apply the chat template + texts = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map( + x => this.tokenizer.apply_chat_template(x, { + tokenize: false, + add_generation_prompt: true, + }) + )); } // By default, do not add special tokens const add_special_tokens = generate_kwargs.add_special_tokens ?? false; // By default, return full text - const return_full_text = generate_kwargs.return_full_text ?? true; + const return_full_text = !isChatInput + ? false + : generate_kwargs.return_full_text ?? true; this.tokenizer.padding_side = 'left'; const { input_ids, attention_mask } = this.tokenizer(texts, { diff --git a/src/tokenizers.js b/src/tokenizers.js index 5b58e37c0..964a79f3f 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2429,6 +2429,12 @@ function truncateHelper(item, length) { } +/** + * @typedef {Object} Message + * @property {string} role The role of the message (e.g., "user" or "assistant" or "system"). + * @property {string} content The content of the message. + */ + export class PreTrainedTokenizer extends Callable { return_token_type_ids = false; @@ -2959,12 +2965,6 @@ export class PreTrainedTokenizer extends Callable { return this._default_chat_template; } - /** - * @typedef {Object} Message - * @property {string} role The role of the message (e.g., "user" or "assistant" or "system"). - * @property {string} content The content of the message. - */ - /** * Converts a list of message objects with `"role"` and `"content"` keys to a list of token * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to From 0fc8090b535b667db9af56d3ab11abbfd9cb92cc Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Mar 2024 00:43:38 +0200 Subject: [PATCH 3/5] Align return type with python version --- src/pipelines.js | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index 95df5a441..719326a8a 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -848,7 +848,7 @@ function isChat(x) { * @typedef {import('./tokenizers.js').Message[]} Chat * * @typedef {Object} TextGenerationSingle - * @property {string} generated_text The generated text. + * @property {string|Chat} generated_text The generated text. * @typedef {TextGenerationSingle[]} TextGenerationOutput * * @typedef {Object} TextGenerationSpecificParams Parameters specific to text-generation pipelines. @@ -926,15 +926,17 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli /** @type {TextGenerationPipelineCallback} */ async _call(texts, generate_kwargs = {}) { - let isBatched; - + let isBatched = false; let isChatInput = false; // Normalize inputs + /** @type {string[]} */ + let inputs; if (typeof texts === 'string') { - texts = [texts]; + inputs = texts = [texts]; } else if (Array.isArray(texts) && texts.every(x => typeof x === 'string')) { isBatched = true; + inputs = /** @type {string[]} */(texts); } else { if (isChat(texts)) { texts = [/** @type {Chat} */(texts)]; @@ -946,7 +948,7 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli isChatInput = true; // If the input is a chat, we need to apply the chat template - texts = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map( + inputs = /** @type {string[]} */(/** @type {Chat[]} */ (texts).map( x => this.tokenizer.apply_chat_template(x, { tokenize: false, add_generation_prompt: true, @@ -963,7 +965,7 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli : generate_kwargs.return_full_text ?? true; this.tokenizer.padding_side = 'left'; - const { input_ids, attention_mask } = this.tokenizer(texts, { + const { input_ids, attention_mask } = this.tokenizer(inputs, { add_special_tokens, padding: true, truncation: true, @@ -991,10 +993,16 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli const textIndex = Math.floor(i / outputTokenIds.length * texts.length); if (promptLengths) { + // Trim the decoded text to only include the generated part decoded[i] = decoded[i].slice(promptLengths[textIndex]); } toReturn[textIndex].push({ - generated_text: decoded[i] + generated_text: isChatInput + ? [ + ...((/** @type {Chat[]} */(texts)[textIndex])), + { role: 'assistant', content: decoded[i] }, + ] + : decoded[i] }); } return (!isBatched && toReturn.length === 1) ? toReturn[0] : toReturn; From 0a71a0dccda22a4889c704596a6b736f0f8d060e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Mar 2024 00:53:04 +0200 Subject: [PATCH 4/5] Remove conversational task (moved to text-generation) --- README.md | 1 - docs/snippets/5_supported-tasks.snippet | 1 - 2 files changed, 2 deletions(-) diff --git a/README.md b/README.md index 755d0c505..f4b804c5f 100644 --- a/README.md +++ b/README.md @@ -198,7 +198,6 @@ You can refine your search by selecting the task you're interested in (e.g., [te | Task | ID | Description | Supported? | |--------------------------|----|-------------|------------| -| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ | | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) | | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) | | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) | diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet index ac71ee528..ee682ffca 100644 --- a/docs/snippets/5_supported-tasks.snippet +++ b/docs/snippets/5_supported-tasks.snippet @@ -5,7 +5,6 @@ | Task | ID | Description | Supported? | |--------------------------|----|-------------|------------| -| [Conversational](https://huggingface.co/tasks/conversational) | `conversational` | Generating conversational text that is relevant, coherent and knowledgable given a prompt. | ❌ | | [Fill-Mask](https://huggingface.co/tasks/fill-mask) | `fill-mask` | Masking some of the words in a sentence and predicting which words should replace those masks. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FillMaskPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers.js) | | [Question Answering](https://huggingface.co/tasks/question-answering) | `question-answering` | Retrieve the answer to a question from a given text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.QuestionAnsweringPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=question-answering&library=transformers.js) | | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | `sentence-similarity` | Determining how similar two texts are. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.FeatureExtractionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=feature-extraction&library=transformers.js) | From f8bd6ff9eb9a201102c13dc05fdb40d82216c612 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 22 Mar 2024 01:22:05 +0200 Subject: [PATCH 5/5] Fix typos --- src/pipelines.js | 2 +- tests/generation.test.js | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pipelines.js b/src/pipelines.js index 719326a8a..d68547bf7 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -960,7 +960,7 @@ export class TextGenerationPipeline extends (/** @type {new (options: TextPipeli const add_special_tokens = generate_kwargs.add_special_tokens ?? false; // By default, return full text - const return_full_text = !isChatInput + const return_full_text = isChatInput ? false : generate_kwargs.return_full_text ?? true; diff --git a/tests/generation.test.js b/tests/generation.test.js index 454348e2d..da50388aa 100644 --- a/tests/generation.test.js +++ b/tests/generation.test.js @@ -147,6 +147,8 @@ describe('Generation parameters', () => { 'Suddenly,', ]; + const generator = await pipeline('text-generation', m(models[2])); + { // return_full_text=false const output = await generator(text, { return_full_text: false,