Skip to content

Commit

Permalink
Only cut decoder_input_ids if past model output
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Sep 23, 2023
1 parent 279cb7c commit 4fa3b7d
Showing 1 changed file with 8 additions and 17 deletions.
25 changes: 8 additions & 17 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,17 @@ function seq2seqStartBeams(self, inputTokenIds, generation_config, numOutputToke
async function seq2seqRunBeam(self, beam) {
const input_name = self.main_input_name;

let decoder_input_ids = beam.output_token_ids;
if (beam.prev_model_outputs) {
// `prev_model_outputs` is only set after the first step. When it is present,
// its `past_key_values` are passed to the model below, so we keep only the
// last token of decoder_input_ids.
decoder_input_ids = decoder_input_ids.slice(-1);
}

// 1. Prepare
let model_inputs = {
[input_name]: beam.inputs,
decoder_input_ids: toI64Tensor(beam.output_token_ids.slice(-1)),
decoder_input_ids: toI64Tensor(decoder_input_ids),
encoder_outputs: beam.encoder_outputs,
past_key_values: beam.prev_model_outputs?.past_key_values,
}
Expand Down Expand Up @@ -3294,15 +3301,6 @@ export class MarianMTModel extends MarianPreTrainedModel {
this.num_encoder_heads = this.config.encoder_attention_heads;
this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
}


/**
* @param {any} model_inputs
* @returns {Promise<Seq2SeqLMOutput>}
*/
async forward(model_inputs) {
return await seq2seqForward(this, model_inputs);
}
}
//////////////////////////////////////////////////

Expand Down Expand Up @@ -3335,13 +3333,6 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel {
this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads;
}

/**
* @param {any} model_inputs
* @returns {Promise<Seq2SeqLMOutput>}
*/
async forward(model_inputs) {
return await seq2seqForward(this, model_inputs);
}
}
//////////////////////////////////////////////////

Expand Down

0 comments on commit 4fa3b7d

Please sign in to comment.