diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py index db9b8fd9afc..4d29835c3fe 100644 --- a/python/llm/src/ipex_llm/transformers/models/glm.py +++ b/python/llm/src/ipex_llm/transformers/models/glm.py @@ -37,12 +37,12 @@ from typing import Optional, Tuple from transformers.cache_utils import Cache -from transformers.models.glm.modeling_glm import repeat_kv, apply_rotary_pos_emb +from transformers.models.glm.modeling_glm import apply_rotary_pos_emb from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache from ipex_llm.transformers.models.common import merge_qkv_base -from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import make_cache_contiguous_inplaced -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import use_quantize_kv_cache def merge_qkv(module: torch.nn.Module): @@ -102,52 +102,16 @@ def glm_attention_forward( else: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - use_quantizekv = isinstance(past_key_value, DynamicFp8Cache) # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - kv_seq_len = key_states.size(-2) - if attention_mask is not None: # no matter the length, we just slice it - attention_mask = attention_mask[:, :, :, : kv_seq_len] - - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantizekv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantizekv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) - else: - if use_quantizekv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_weights = None + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == key_states.size(2), self.scaling + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)