From f30002ea802b8ab66b0916b543274c32c39c4e8f Mon Sep 17 00:00:00 2001 From: ae9is <125031666+ae9is@users.noreply.github.com> Date: Sun, 15 Dec 2024 00:16:21 +0000 Subject: [PATCH] Improve Whisper language detection performance --- src/generation/configuration_utils.js | 7 +++++++ src/generation/stopping_criteria.js | 13 +++++++++++++ src/models.js | 22 ++++++++++++++++++++-- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/generation/configuration_utils.js b/src/generation/configuration_utils.js index 8474057da..639069b5a 100644 --- a/src/generation/configuration_utils.js +++ b/src/generation/configuration_utils.js @@ -197,6 +197,13 @@ export class GenerationConfig { */ bad_words_ids = null; + /** + * List of token ids that are allowed to be generated. + * @type {number[][]} + * @default null + */ + good_words_ids = null; + /** * List of token ids that must be generated. * If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. diff --git a/src/generation/stopping_criteria.js b/src/generation/stopping_criteria.js index 08434f2b4..aa3327a66 100644 --- a/src/generation/stopping_criteria.js +++ b/src/generation/stopping_criteria.js @@ -154,3 +154,16 @@ export class InterruptableStoppingCriteria extends StoppingCriteria { return new Array(input_ids.length).fill(this.interrupted); } } + +/** + * This class can be used to always stop generation after one pass. + */ +export class AlwaysStopCriteria extends StoppingCriteria { + constructor() { + super(); + } + + _call(input_ids, scores) { + return new Array(input_ids.length).fill(true); + } +} diff --git a/src/models.js b/src/models.js index c9d073fe4..791df5833 100644 --- a/src/models.js +++ b/src/models.js @@ -90,6 +90,7 @@ import { TopKLogitsWarper, TopPLogitsWarper, ClassifierFreeGuidanceLogitsProcessor, + OnlyGoodWordsLogitsProcessor, } from './generation/logits_process.js'; import { @@ -112,7 +113,7 @@ import { import { RawImage } from './utils/image.js'; import { dynamic_time_warping, max, medianFilter } from './utils/maths.js'; -import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; +import { AlwaysStopCriteria, EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js'; import { LogitsSampler } from './generation/logits_sampler.js'; import { apis } from './env.js'; @@ -1195,6 +1196,10 @@ export class PreTrainedModel extends Callable { processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)); } + if (generation_config.good_words_ids !== null) { + processors.push(new OnlyGoodWordsLogitsProcessor(generation_config.good_words_ids, generation_config.eos_token_id)); + } + if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) { processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)); } @@ -3120,7 +3125,20 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { if (!all_lang_ids || all_lang_ids.length <= 0) { throw new Error("Cannot detect language without language code to token ID map for model"); } - const output = await this.generate({ ...options, decoder_input_ids }); + const stopping_criteria = new StoppingCriteriaList(); + stopping_criteria.push(new AlwaysStopCriteria()); + const good_words_ids = [all_lang_ids]; + const output = await this.generate({ + ...options, + generation_config: { + ...generation_config, + good_words_ids, + num_beams: 1, + do_sample: false, + }, + stopping_criteria, + decoder_input_ids, + }); const sane = Array.from((/**@type {Tensor}**/(output)).data).flatMap(x => Number(x)); const lang_ids = sane.filter(x => Object.values(generation_config.lang_to_id).includes(x)); return lang_ids;