Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for min_length and min_new_tokens generation parameters #308

Merged
merged 7 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 62 additions & 30 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ import {
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,

Sampler,
} from './utils/generation.js';
Expand Down Expand Up @@ -678,6 +680,7 @@ export class PreTrainedModel extends Callable {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options),
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);

} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
Expand Down Expand Up @@ -782,17 +785,17 @@ export class PreTrainedModel extends Callable {
// processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
// }

// if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
// processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
// }
if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
}

// if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
// processors.push(new MinNewTokensLengthLogitsProcessor(
// input_ids_seq_length,
// generation_config.min_new_tokens,
// generation_config.eos_token_id
// ));
// }
if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
processors.push(new MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id
));
}

// if (prefix_allowed_tokens_fn !== null) {
// processors.push(new PrefixConstrainedLogitsProcessor(
Expand Down Expand Up @@ -866,7 +869,8 @@ export class PreTrainedModel extends Callable {
*/
_get_generation_config(generation_config) {
// Create empty generation config (contains defaults)
let gen_config = new GenerationConfig();
// We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
let gen_config = new GenerationConfig(this.config);

// Apply model's generation config, if it exists
if ('generation_config' in this) {
Expand Down Expand Up @@ -928,7 +932,7 @@ export class PreTrainedModel extends Callable {
input_ids_seq_length = 0;

} else {
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims[0] : inputs.length;
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length;

// decoder-only
if (input_ids_seq_length === 0) {
Expand All @@ -948,6 +952,12 @@ export class PreTrainedModel extends Callable {
logits_processor
)

/** @type {number[]} */
let eos_token_ids = generation_config.eos_token_id;
if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) {
eos_token_ids = [eos_token_ids];
}

// TODO implement early_stopping
// https://huggingface.co/blog/how-to-generate

Expand Down Expand Up @@ -1007,7 +1017,7 @@ export class PreTrainedModel extends Callable {

newBeam.score += logProb;

if (newTokenId === this.config.eos_token_id) {
if (eos_token_ids && eos_token_ids.includes(newTokenId)) {
newBeam.done = true;
}

Expand Down Expand Up @@ -2476,10 +2486,12 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
* @param {Object} config The configuration object specifying the hyperparameters and other model settings.
* @param {Object} session The ONNX session containing the encoder model.
* @param {any} decoder_merged_session The ONNX session containing the merged decoder model.
* @param {Object} generation_config Configuration object for the generation process.
*/
constructor(config, session, decoder_merged_session) {
constructor(config, session, decoder_merged_session, generation_config) {
super(config, session);
this.decoder_merged_session = decoder_merged_session;
this.generation_config = generation_config;

this.num_layers = this.config.decoder.n_layer;
this.num_heads = this.config.decoder.n_head;
Expand Down Expand Up @@ -2617,9 +2629,11 @@ export class GPT2PreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPT2PreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2649,9 +2663,11 @@ export class GPTNeoPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2673,9 +2689,11 @@ export class GPTNeoXPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTNeoXPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2698,9 +2716,11 @@ export class GPTJPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTJPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2724,9 +2744,11 @@ export class GPTBigCodePreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `GPTBigCodePreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand All @@ -2747,11 +2769,13 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
export class CodeGenPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `CodeGenPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
*/
constructor(config, session) {
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2785,11 +2809,13 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
export class LlamaPreTrainedModel extends PreTrainedModel {
/**
* Creates a new instance of the `LlamaPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
*/
constructor(config, session) {
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2817,9 +2843,11 @@ export class BloomPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `BloomPreTrainedModel` class.
* @param {Object} config The configuration of the model.
* @param {any} session The ONNX session containing the model weights.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2848,9 +2876,11 @@ export class MptPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `MptPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down Expand Up @@ -2880,9 +2910,11 @@ export class OPTPreTrainedModel extends PreTrainedModel {
* Creates a new instance of the `OPTPreTrainedModel` class.
* @param {Object} config The model configuration object.
* @param {Object} session The ONNX session object.
* @param {GenerationConfig} generation_config The generation configuration.
*/
constructor(config, session) {
constructor(config, session, generation_config) {
super(config, session);
this.generation_config = generation_config;

// config doesn't contain pad_token_id, so we assume it is the eos_token_id
this.config.pad_token_id = this.config.eos_token_id
Expand Down
71 changes: 71 additions & 0 deletions src/utils/generation.js
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,77 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
}
}

/**
* A logits processor that enforces a minimum number of tokens.
*
* @extends LogitsProcessor
*/
export class MinLengthLogitsProcessor extends LogitsProcessor {
/**
* Create a MinLengthLogitsProcessor.
* @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
* @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
*/
constructor(min_length, eos_token_id) {
super();
this.min_length = min_length;
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
}

/**
* Apply logit processor.
* @param {Array} input_ids The input IDs.
* @param {Object} logits The logits.
* @returns {Object} The processed logits.
*/
_call(input_ids, logits) {
if (input_ids.length < this.min_length) {
for (const eos_token of this.eos_token_id) {
logits.data[eos_token] = -Infinity;
}
}

return logits
}
}

/**
* A logits processor that enforces a minimum number of new tokens.
*
* @extends LogitsProcessor
*/
export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
/**
* Create a MinNewTokensLengthLogitsProcessor.
* @param {number} prompt_length_to_skip The input tokens length.
* @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
* @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
*/
constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
super();
this.prompt_length_to_skip = prompt_length_to_skip;
this.min_new_tokens = min_new_tokens;
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
}

/**
* Apply logit processor.
* @param {Array} input_ids The input IDs.
* @param {Object} logits The logits.
* @returns {Object} The processed logits.
*/
_call(input_ids, logits) {
const new_tokens_length = input_ids.length - this.prompt_length_to_skip;
if (new_tokens_length < this.min_new_tokens) {
for (const eos_token of this.eos_token_id) {
logits.data[eos_token] = -Infinity;
}
}

return logits
}
}

/**
* Class that holds a configuration for a generation task.
*/
Expand Down
Loading