diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 7f6dbd75939..fa5f94d7724 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -237,27 +237,50 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
 
         # compute
+        # import xe_addons
+        # if is_causal:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+        # elif seq_length != kv_length and seq_length <= 32:
+        #     # todo: add scale support
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+        #     else:
+        #         attn_output = xe_addons.sdp(query, key, value, mask)
+        # else:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask)
         elif seq_length != kv_length and seq_length <= 32:
-            # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
         return attn_output
     else:
         mask = mask[..., :seq_length, :kv_length] if mask is not None else None
+
+        from ipex_llm.transformers.models.utils import repeat_kv
+        if n_heads != n_kv_heads:
+            key = repeat_kv(key, n_heads // n_kv_heads)
+            value = repeat_kv(value, n_heads // n_kv_heads)
+
         return torch.nn.functional.scaled_dot_product_attention(
             query, key, value, mask, is_causal=is_causal, scale=scale
         )
 
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index accfd2dc0a8..7f9bba681cc 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -59,6 +59,10 @@ def siglip_attention_forward(
         and get_xpu_device_type(query_states) in ["arc", "flex"]
         and query_states.dtype in [torch.float, torch.half]
     ):
+        n_heads, kv_length = query_states.size(1), key_states.size(2)
+        from ipex_llm.transformers.models.common import prepare_mask
+        attention_mask = prepare_mask(attention_mask, bsz, n_heads, q_len, kv_length,
+                                      False, query_states.dtype, query_states.device)
         import xe_addons
         attn_weights = None
         attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 904ffe9b727..8d085ee8a32 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -388,6 +388,15 @@ def fp16_fusion_check(proj, x, training):
     return True
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def update_past_key_value(past_key_value, key_states, value_states,
                           kv_seq_len, use_quantize_kv, device):
     bsz, num_heads, _, head_dim = key_states.shape
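
Illustrative sketch (not part of the patch): the repeat_kv helper added in utils.py expands grouped key/value heads to match the query head count, which common.py now does before falling back to torch.nn.functional.scaled_dot_product_attention. The shapes below are invented for the example, and the import assumes ipex_llm is installed; otherwise the function body from the utils.py hunk can be pasted in directly.

    import torch
    from ipex_llm.transformers.models.utils import repeat_kv  # helper added in this patch

    # hypothetical GQA shapes: 32 query heads sharing 8 KV heads, seq len 16, head dim 128
    query = torch.randn(1, 32, 16, 128)
    key = torch.randn(1, 8, 16, 128)
    value = torch.randn(1, 8, 16, 128)

    # expand KV heads so they match the query head count, as the new fallback path does
    key = repeat_kv(key, 32 // 8)      # (1, 8, 16, 128) -> (1, 32, 16, 128)
    value = repeat_kv(value, 32 // 8)  # (1, 8, 16, 128) -> (1, 32, 16, 128)
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=True)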