From ed648094c98a18f3c8c97508bf7f44e3eefceb10 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 12 Sep 2023 01:05:31 +0200 Subject: [PATCH] Improve `addPastKeyValues` function --- src/models.js | 70 +++++++++++++++++++++++---------------------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/src/models.js b/src/models.js index 0db929f10..acc7742d9 100644 --- a/src/models.js +++ b/src/models.js @@ -308,7 +308,6 @@ function boolTensor(value) { * @private */ async function seq2seqForward(self, model_inputs) { - const add_decoder_pkv = self.add_decoder_pkv ?? true; let { encoder_outputs, past_key_values } = model_inputs; @@ -325,7 +324,7 @@ async function seq2seqForward(self, model_inputs) { if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { decoderFeeds.encoder_attention_mask = model_inputs.attention_mask } - self.addPastKeyValues(decoderFeeds, past_key_values, add_decoder_pkv); + self.addPastKeyValues(decoderFeeds, past_key_values); const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds); let logits = decoderResults.logits; @@ -1182,57 +1181,50 @@ export class PreTrainedModel extends Callable { * * @param {Object} decoderFeeds The decoder feeds object to add past key values to. * @param {Object} pastKeyValues An object containing past key values. - * @param {boolean} [hasDecoder=false] Whether the model has a decoder. */ - addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) { + addPastKeyValues(decoderFeeds, pastKeyValues) { if (pastKeyValues) { Object.assign(decoderFeeds, pastKeyValues) } else { // TODO support batches (i.e., batch_size > 1) - if (hasDecoder) { + if (this.config.is_encoder_decoder) { // @ts-ignore let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv]; - // @ts-ignore - for (let i = 0; i < this.num_encoder_layers; ++i) { - decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims) - decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims) - } - // @ts-ignore let decoder_dims = [1, this.num_decoder_heads, 0, this.decoder_dim_kv]; // @ts-ignore for (let i = 0; i < this.num_decoder_layers; ++i) { + decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims) + decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims) decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims) decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims) } + } else if (this.config.multi_query) { // e.g., for `gpt_bigcode` + // @ts-ignore + let dims = [1, 0, 2 * this.dim_kv] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims) + } + } else if (this.config.model_type === 'bloom') { + // NOTE: Custom implementation for Bloom - } else { - if (this.config.multi_query) { - // @ts-ignore - let dims = [1, 0, 2 * this.dim_kv] - // @ts-ignore - for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims) - } - } else if (this.config.model_type === 'bloom') { - // Custom implementation for Bloom - // @ts-ignore - let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length] - // @ts-ignore - let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64] - // @ts-ignore - for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims) - } - } else { - // @ts-ignore - let dims = [1, this.num_heads, 0, this.dim_kv] - // @ts-ignore - for (let i = 0; i < this.num_layers; ++i) { - decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) - decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) - } + // @ts-ignore + let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length] + // @ts-ignore + let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims) + } + } else { // Decoder-only + // @ts-ignore + let dims = [1, this.num_heads, 0, this.dim_kv] + // @ts-ignore + for (let i = 0; i < this.num_layers; ++i) { + decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims) + decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) } } } @@ -2546,7 +2538,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { */ export class VisionEncoderDecoderModel extends PreTrainedModel { main_input_name = 'pixel_values'; - add_decoder_pkv = false; + // add_decoder_pkv = false; /** * Creates a new instance of the `VisionEncoderDecoderModel` class.