
add compresskv back for mistral (#12607)
* add compresskv back for mistral

* fix

* fix
MeouSker77 authored Dec 25, 2024
1 parent 9c9800b commit 4e6b9d8
Showing 2 changed files with 29 additions and 6 deletions.
python/llm/src/ipex_llm/transformers/models/mistral.py (31 changes: 27 additions & 4 deletions)
@@ -37,6 +37,7 @@
 
 from typing import Optional, Tuple, Union, List
 
+import os
 import torch
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -45,8 +46,11 @@
 from ipex_llm.transformers.models.common import merge_qkv_base
 from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
+from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
+from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


@@ -69,11 +73,22 @@ def mistral_model_forward(
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
                                             self.config.num_attention_heads //
                                             self.config.num_key_value_heads)
+    use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
+        isinstance(past_key_values, DynamicCompressCache)
 
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif (
+            not use_quantize_kv
+            and not use_compress_kv
+            and not isinstance(past_key_values, DynamicNormalCache)
+        ):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # ipex-llm changes end

@@ -127,8 +142,16 @@ def mistral_attention_forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                         cos, sin, position_ids, "mistral")
 
-    key_states, value_states = past_key_value.update(key_states, value_states,
-                                                     self.layer_idx, None)
+    if isinstance(past_key_value, DynamicCompressCache):
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_idx,
+            query_states, attention_mask, self.num_key_value_groups,
+            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
+        )
+    else:
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, None)
 
     # IPEX-LLM OPT: sdpa
     attn_weights = None
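The net effect of the mistral.py changes is a three-way cache choice at the top of mistral_model_forward: KV compression wins over plain quantization, the FP8 compress cache is used when both are requested, and otherwise the existing FP8 or normal dynamic cache behaviour is kept. The sketch below is not part of the patch; it only restates that branch order in standalone form, importing the same cache classes the diff imports (the helper name select_kv_cache is made up for illustration).

from ipex_llm.transformers.kv import (DynamicCompressCache, DynamicCompressFp8Cache,
                                      DynamicFp8Cache, DynamicNormalCache)

def select_kv_cache(past_key_values, use_compress_kv, use_quantize_kv):
    # Mirrors the if/elif chain added to mistral_model_forward above.
    if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
        # Compress-KV requested but the cache is not a compress cache yet:
        # convert it, choosing the FP8 flavour when quantized KV cache is also enabled.
        if use_quantize_kv:
            return DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
        return DynamicCompressCache.from_legacy_cache(past_key_values)
    elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
        return DynamicFp8Cache.from_legacy_cache(past_key_values)
    elif (not use_quantize_kv and not use_compress_kv
          and not isinstance(past_key_values, DynamicNormalCache)):
        return DynamicNormalCache.from_legacy_cache(past_key_values)
    return past_key_values

The new module-level constant also means the compress cache's allocation block size can be tuned through the KV_CACHE_ALLOC_BLOCK_LENGTH environment variable, read once at import time and defaulting to 256.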
python/llm/src/ipex_llm/transformers/models/mixtral.py (4 changes: 2 additions & 2 deletions)
@@ -52,7 +52,7 @@
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.mistral import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
@@ -171,7 +171,7 @@ def mixtral_attention_forward(
     # for flash attention
     original_dtype = hidden_states.dtype
 
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
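The mixtral.py part is only a call-site update: should_use_fuse_rope is now imported from models.utils and receives the hidden states, the position ids, and the module's training flag instead of the attention module itself. The snippet below is a rough, assumption-labelled stand-in for what such a check typically gates (the real helper lives in ipex_llm.transformers.models.utils and may differ); it is only meant to show the new argument order.

import torch

def should_use_fuse_rope_approx(hidden_states: torch.Tensor,
                                position_ids: torch.Tensor,
                                training: bool) -> bool:
    # Assumed behaviour, not the library implementation: fused rotary embedding
    # is only worthwhile on an XPU device, at inference time, with explicit
    # position ids available.
    return (hidden_states.device.type == "xpu"
            and not training
            and position_ids is not None)

# Mirrors the new mixtral call site: tensors plus a training flag, no module instance.
hidden_states = torch.randn(1, 8, 4096)
position_ids = torch.arange(8).unsqueeze(0)
print(should_use_fuse_rope_approx(hidden_states, position_ids, training=False))  # False on CPU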

