diff --git a/src/generation/logits_process.js b/src/generation/logits_process.js index f82634f75..4949a137d 100644 --- a/src/generation/logits_process.js +++ b/src/generation/logits_process.js @@ -572,6 +572,40 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor { } } +export class OnlyGoodWordsLogitsProcessor extends LogitsProcessor { + /** + * Create a `OnlyGoodWordsLogitsProcessor`. + * @param {number[][]} good_words_ids List of list of token ids that are allowed to be generated. + * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + */ + constructor(good_words_ids, eos_token_id) { + super(); + this.good_words_ids = good_words_ids; + this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id]; + } + + /** + * Apply logit processor. + * @param {bigint[][]} input_ids The input IDs. + * @param {Tensor} logits The logits. + * @returns {Object} The processed logits. + */ + _call(input_ids, logits) { + const good_ids = this.good_words_ids.flat(); + // Iterate over batches of input IDs and logits + for (let i = 0; i < input_ids.length; ++i) { + const batch_logits_data = /** @type {Float32Array} */(logits[i].data); + // For every ID, set its logit score to -Infinity unless it's in our list of valid token IDs + for (let j = 0; j < batch_logits_data.length; ++j) { + if (!good_ids.includes(j)) { + batch_logits_data[j] = -Infinity; + } + } + } + return logits + } +} + /** * [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, * where the first half correspond to the conditional logits (predicted from the input prompt) and the second half diff --git a/tests/utils/logits_process.test.js b/tests/utils/logits_process.test.js index 5da188ed4..0055af5e0 100644 --- a/tests/utils/logits_process.test.js +++ b/tests/utils/logits_process.test.js @@ -81,6 +81,61 @@ describe("Logits Processors", () => { ); }); + describe("good_words_ids", () => { + it( + "generates nothing given empty good_words_ids", + async () => { + const text_input = "hello"; + const generated_text_target = ""; + const text_target = [{ generated_text: text_input + generated_text_target }]; + const output = await pipe(text_input, { + max_new_tokens: 5, + good_words_ids: [ + [], + ], + }); + compare(output, text_target); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "passes basic test", + async () => { + const text_input = "hello"; + // Default output tokens for this input: 22172,18547,8136,18547,8136 + // Default output text for this input: helloerdingsAndroid Load Между ligger + const generated_text_target = "Android helloAndroid hello hello"; + const text_target = [{ generated_text: text_input + generated_text_target }]; + const output = await pipe(text_input, { + max_new_tokens: 5, + good_words_ids: [ + [22172, 8136], // hello, Android + ], + }); + compare(output, text_target); + }, + MAX_TEST_EXECUTION_TIME, + ); + + it( + "passes test with many good words", + async () => { + const text_input = "hello"; + const generated_text_target = "erdingsAndroidierraég migli"; + const text_target = [{ generated_text: text_input + generated_text_target }]; + const good_words_ids = []; + for (let i = 0; i < 100000; ++i) { + good_words_ids.push([i * 2 + 1]); // allow all odd numbers + } + good_words_ids.push([22172, 8136]); + const output = await pipe(text_input, { max_new_tokens: 5, good_words_ids }); + compare(output, text_target); + }, + MAX_TEST_EXECUTION_TIME, + ); + }); + afterAll(async () => { await pipe?.dispose(); }, MAX_MODEL_DISPOSE_TIME);