diff --git a/src/models.js b/src/models.js index acc7742d9..fa8e89d5f 100644 --- a/src/models.js +++ b/src/models.js @@ -1187,7 +1187,8 @@ export class PreTrainedModel extends Callable { Object.assign(decoderFeeds, pastKeyValues) } else { // TODO support batches (i.e., batch_size > 1) - if (this.config.is_encoder_decoder) { + // @ts-ignore + if (this.config.is_encoder_decoder && (this.add_encoder_pkv ?? true)) { // @ts-ignore let encoder_dims = [1, this.num_encoder_heads, 0, this.encoder_dim_kv]; // @ts-ignore @@ -2538,7 +2539,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { */ export class VisionEncoderDecoderModel extends PreTrainedModel { main_input_name = 'pixel_values'; - // add_decoder_pkv = false; + add_encoder_pkv = false; /** * Creates a new instance of the `VisionEncoderDecoderModel` class.