Skip to content

Commit

Permalink
Add a new logits processor to only generate allowed token IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
ae9is committed Dec 15, 2024
1 parent e717e30 commit fe38529
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/generation/logits_process.js
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,40 @@ export class NoBadWordsLogitsProcessor extends LogitsProcessor {
}
}

/**
 * A logits processor that restricts generation to an allow-list of token IDs:
 * every logit whose index is not in the allow-list is set to `-Infinity`.
 */
export class OnlyGoodWordsLogitsProcessor extends LogitsProcessor {
    /**
     * Create a `OnlyGoodWordsLogitsProcessor`.
     * @param {number[][]} good_words_ids List of list of token ids that are allowed to be generated.
     * The nesting is flattened one level when the processor is applied, so only the set of ids matters.
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     */
    constructor(good_words_ids, eos_token_id) {
        super();
        this.good_words_ids = good_words_ids;
        // Normalized to an array. NOTE: `_call` does not read this value, so an
        // end-of-sequence token is only kept if it also appears in `good_words_ids`.
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Apply logit processor.
     * @param {bigint[][]} input_ids The input IDs. Only the number of entries is used (one per batch item).
     * @param {Tensor} logits The logits. Modified in place.
     * @returns {Tensor} The same `logits` object, with every disallowed entry set to `-Infinity`.
     */
    _call(input_ids, logits) {
        // Build the lookup once per call: `Set.has` is a constant-time check, whereas
        // `Array.includes` inside the vocabulary loop would rescan the whole allow-list
        // for every logit. Rebuilt on each call so later changes to `good_words_ids` are honored.
        const allowed_ids = new Set(this.good_words_ids.flat());

        // Iterate over batches of input IDs and logits
        for (let i = 0; i < input_ids.length; ++i) {
            const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
            // For every ID, set its logit score to -Infinity unless it's in our list of valid token IDs
            for (let j = 0; j < batch_logits_data.length; ++j) {
                if (!allowed_ids.has(j)) {
                    batch_logits_data[j] = -Infinity;
                }
            }
        }
        return logits;
    }
}

/**
* [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
* where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
Expand Down
55 changes: 55 additions & 0 deletions tests/utils/logits_process.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,61 @@ describe("Logits Processors", () => {
);
});

describe("good_words_ids", () => {
    // Every case feeds the same prompt and requests the same number of new tokens;
    // only the allow-list and the expected continuation differ.
    const prompt = "hello";
    const new_token_count = 5;

    /**
     * Run the pipeline with the given allow-list and check the generated text.
     * @param {number[][]} good_words_ids Token ids the model is allowed to generate.
     * @param {string} expected_continuation Text expected to follow the prompt.
     */
    const expect_generation = async (good_words_ids, expected_continuation) => {
        const expected = [{ generated_text: `${prompt}${expected_continuation}` }];
        const output = await pipe(prompt, {
            max_new_tokens: new_token_count,
            good_words_ids,
        });
        compare(output, expected);
    };

    it(
        "generates nothing given empty good_words_ids",
        async () => {
            await expect_generation([[]], "");
        },
        MAX_TEST_EXECUTION_TIME,
    );

    it(
        "passes basic test",
        async () => {
            // Notes carried over from the original test for the unconstrained run:
            //   default output tokens: 22172,18547,8136,18547,8136
            //   default output text:   helloerdingsAndroid Load Между ligger
            const hello_and_android = [[22172, 8136]]; // hello, Android
            await expect_generation(hello_and_android, "Android helloAndroid hello hello");
        },
        MAX_TEST_EXECUTION_TIME,
    );

    it(
        "passes test with many good words",
        async () => {
            // Allow every odd id below 200000 (one single-element list each), plus the two ids used above.
            const many_ids = Array.from({ length: 100000 }, (_, k) => [2 * k + 1]);
            many_ids.push([22172, 8136]);
            await expect_generation(many_ids, "erdingsAndroidierraég migli");
        },
        MAX_TEST_EXECUTION_TIME,
    );
});

// Teardown: once the tests have run, dispose of the pipeline used by them.
// The optional call (`?.`) makes this a no-op when `pipe` is null or undefined.
afterAll(async () => {
await pipe?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
Expand Down

0 comments on commit fe38529

Please sign in to comment.