Improve addPastKeyValues function
xenova committed Sep 11, 2023
1 parent 0dbe7b2 commit ed64809
Showing 1 changed file with 31 additions and 39 deletions.
70 changes: 31 additions & 39 deletions src/models.js
@@ -308,7 +308,6 @@ function boolTensor(value) {
  * @private
  */
 async function seq2seqForward(self, model_inputs) {
-    const add_decoder_pkv = self.add_decoder_pkv ?? true;
 
     let { encoder_outputs, past_key_values } = model_inputs;
 
@@ -325,7 +324,7 @@ async function seq2seqForward(self, model_inputs) {
     if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) {
         decoderFeeds.encoder_attention_mask = model_inputs.attention_mask
     }
-    self.addPastKeyValues(decoderFeeds, past_key_values, add_decoder_pkv);
+    self.addPastKeyValues(decoderFeeds, past_key_values);
 
     const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds);
     let logits = decoderResults.logits;
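
Note (not part of the commit): the call site no longer decides whether decoder past key values are needed; addPastKeyValues now reads the model config itself. A minimal standalone sketch of that dispatch, using a mocked model object and placeholder values rather than this repository's real classes:

// sketch.js -- hypothetical mock, for illustration only
const mockModel = {
    config: { is_encoder_decoder: true },
    addPastKeyValues(decoderFeeds, pastKeyValues) {
        if (pastKeyValues) {
            // Later generation steps: reuse the cache returned by the previous run.
            Object.assign(decoderFeeds, pastKeyValues);
        } else if (this.config.is_encoder_decoder) {
            // First step of an encoder-decoder model: placeholder for the empty cache.
            decoderFeeds['past_key_values.0.encoder.key'] = 'empty tensor';
        } else {
            // First step of a decoder-only model.
            decoderFeeds['past_key_values.0.key'] = 'empty tensor';
        }
    },
};

const feeds = { input_ids: [1, 2, 3] };
mockModel.addPastKeyValues(feeds, null); // no add_decoder_pkv flag any more
console.log(Object.keys(feeds)); // [ 'input_ids', 'past_key_values.0.encoder.key' ]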
@@ -1182,57 +1181,50 @@ export class PreTrainedModel extends Callable {
      *
      * @param {Object} decoderFeeds The decoder feeds object to add past key values to.
      * @param {Object} pastKeyValues An object containing past key values.
-     * @param {boolean} [hasDecoder=false] Whether the model has a decoder.
      */
-    addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) {
+    addPastKeyValues(decoderFeeds, pastKeyValues) {
         if (pastKeyValues) {
             Object.assign(decoderFeeds, pastKeyValues)
         } else {
             // TODO support batches (i.e., batch_size > 1)
-            if (hasDecoder) {
+            if (this.config.is_encoder_decoder) {
                 // @ts-ignore
                 let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv];
                 // @ts-ignore
-                for (let i = 0; i < this.num_encoder_layers; ++i) {
-                    decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
-                    decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
-                }
-
-                // @ts-ignore
                 let decoder_dims = [1, this.num_decoder_heads, 0, this.decoder_dim_kv];
                 // @ts-ignore
                 for (let i = 0; i < this.num_decoder_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.encoder.key`] = new Tensor('float32', [], encoder_dims)
+                    decoderFeeds[`past_key_values.${i}.encoder.value`] = new Tensor('float32', [], encoder_dims)
                     decoderFeeds[`past_key_values.${i}.decoder.key`] = new Tensor('float32', [], decoder_dims)
                     decoderFeeds[`past_key_values.${i}.decoder.value`] = new Tensor('float32', [], decoder_dims)
                 }
-            } else {
-                if (this.config.multi_query) {
-                    // @ts-ignore
-                    let dims = [1, 0, 2 * this.dim_kv]
-                    // @ts-ignore
-                    for (let i = 0; i < this.num_layers; ++i) {
-                        decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
-                    }
-                } else if (this.config.model_type === 'bloom') {
-                    // Custom implementation for Bloom
-                    // @ts-ignore
-                    let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
-                    // @ts-ignore
-                    let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
-                    // @ts-ignore
-                    for (let i = 0; i < this.num_layers; ++i) {
-                        decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
-                        decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
-                    }
-                } else {
-                    // @ts-ignore
-                    let dims = [1, this.num_heads, 0, this.dim_kv]
-                    // @ts-ignore
-                    for (let i = 0; i < this.num_layers; ++i) {
-                        decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
-                        decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
-                    }
-                }
-            }
+            } else if (this.config.multi_query) { // e.g., for `gpt_bigcode`
+                // @ts-ignore
+                let dims = [1, 0, 2 * this.dim_kv]
+                // @ts-ignore
+                for (let i = 0; i < this.num_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.key_value`] = new Tensor('float32', [], dims)
+                }
+            } else if (this.config.model_type === 'bloom') {
+                // NOTE: Custom implementation for Bloom
+                // @ts-ignore
+                let keyDims = [1 * this.num_heads, this.dim_kv, 0] // [batch_size x num_heads,64,past_sequence_length]
+                // @ts-ignore
+                let valueDims = [1 * this.num_heads, 0, this.dim_kv] // [batch_size x num_heads,past_sequence_length,64]
+                // @ts-ignore
+                for (let i = 0; i < this.num_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], keyDims)
+                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], valueDims)
+                }
+            } else { // Decoder-only
+                // @ts-ignore
+                let dims = [1, this.num_heads, 0, this.dim_kv]
+                // @ts-ignore
+                for (let i = 0; i < this.num_layers; ++i) {
+                    decoderFeeds[`past_key_values.${i}.key`] = new Tensor('float32', [], dims)
+                    decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims)
+                }
+            }
         }
     }
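
Note (not from the commit): the branches above differ only in the names and shapes of the zero-length tensors used to prime the cache on the first generation step. A small self-contained sketch of the per-layer dimensions each branch produces, assuming batch_size = 1 and illustrative config fields (the real code reads these values from the model instance):

// shapes.js -- illustrative only; field names mirror the values used above
function emptyPastKeyValueDims(config) {
    if (config.is_encoder_decoder) {
        // Separate cross-attention (encoder) and self-attention (decoder) caches.
        return {
            'encoder.key/value': [1, config.num_encoder_heads, 0, config.encoder_dim_kv],
            'decoder.key/value': [1, config.num_decoder_heads, 0, config.decoder_dim_kv],
        };
    } else if (config.multi_query) {
        // e.g., gpt_bigcode: keys and values are packed into one tensor per layer.
        return { 'key_value': [1, 0, 2 * config.dim_kv] };
    } else if (config.model_type === 'bloom') {
        // Bloom folds batch and heads together and transposes the key cache.
        return {
            'key': [1 * config.num_heads, config.dim_kv, 0],
            'value': [1 * config.num_heads, 0, config.dim_kv],
        };
    }
    // Decoder-only default.
    return {
        'key': [1, config.num_heads, 0, config.dim_kv],
        'value': [1, config.num_heads, 0, config.dim_kv],
    };
}

// Example with a hypothetical decoder-only config:
console.log(emptyPastKeyValueDims({ num_heads: 12, dim_kv: 64, model_type: 'gpt2' }));
// -> { key: [ 1, 12, 0, 64 ], value: [ 1, 12, 0, 64 ] }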
@@ -2546,7 +2538,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
  */
 export class VisionEncoderDecoderModel extends PreTrainedModel {
     main_input_name = 'pixel_values';
-    add_decoder_pkv = false;
+    // add_decoder_pkv = false;
 
     /**
      * Creates a new instance of the `VisionEncoderDecoderModel` class.