Skip to content

Commit

Permalink
Improve Whisper language detection performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ae9is committed Dec 15, 2024
1 parent fe38529 commit f30002e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/generation/configuration_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ export class GenerationConfig {
*/
bad_words_ids = null;

/**
* List of token ids that are allowed to be generated.
* @type {number[][]}
* @default null
*/
good_words_ids = null;

/**
* List of token ids that must be generated.
* If given a `number[][]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`.
Expand Down
13 changes: 13 additions & 0 deletions src/generation/stopping_criteria.js
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,16 @@ export class InterruptableStoppingCriteria extends StoppingCriteria {
return new Array(input_ids.length).fill(this.interrupted);
}
}

/**
* This class can be used to always stop generation after one pass.
*/
export class AlwaysStopCriteria extends StoppingCriteria {
constructor() {
super();
}

_call(input_ids, scores) {
return new Array(input_ids.length).fill(true);
}
}
22 changes: 20 additions & 2 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ import {
TopKLogitsWarper,
TopPLogitsWarper,
ClassifierFreeGuidanceLogitsProcessor,
OnlyGoodWordsLogitsProcessor,
} from './generation/logits_process.js';

import {
Expand All @@ -112,7 +113,7 @@ import {
import { RawImage } from './utils/image.js';

import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { AlwaysStopCriteria, EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
import { LogitsSampler } from './generation/logits_sampler.js';
import { apis } from './env.js';

Expand Down Expand Up @@ -1195,6 +1196,10 @@ export class PreTrainedModel extends Callable {
processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
}

if (generation_config.good_words_ids !== null) {
processors.push(new OnlyGoodWordsLogitsProcessor(generation_config.good_words_ids, generation_config.eos_token_id));
}

if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}
Expand Down Expand Up @@ -3120,7 +3125,20 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
if (!all_lang_ids || all_lang_ids.length <= 0) {
throw new Error("Cannot detect language without language code to token ID map for model");
}
const output = await this.generate({ ...options, decoder_input_ids });
const stopping_criteria = new StoppingCriteriaList();
stopping_criteria.push(new AlwaysStopCriteria());
const good_words_ids = [all_lang_ids];
const output = await this.generate({
...options,
generation_config: {
...generation_config,
good_words_ids,
num_beams: 1,
do_sample: false,
},
stopping_criteria,
decoder_input_ids,
});
const sane = Array.from((/**@type {Tensor}**/(output)).data).flatMap(x => Number(x));
const lang_ids = sane.filter(x => Object.values(generation_config.lang_to_id).includes(x));
return lang_ids;
Expand Down

0 comments on commit f30002e

Please sign in to comment.