From b050368efc48249847f8b6bd2c76119a4202c25f Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 20 Dec 2024 16:41:50 +0800
Subject: [PATCH] refactor yuan2 and starcoder2 and fix (#12589)

---
 .../ipex_llm/transformers/models/llama32.py   |  2 +-
 .../ipex_llm/transformers/models/minicpm.py   |  9 ++--
 .../ipex_llm/transformers/models/minicpmv.py  |  8 +--
 .../src/ipex_llm/transformers/models/qwen2.py |  2 +-
 .../transformers/models/starcoder2.py         | 50 ++++---------------
 .../src/ipex_llm/transformers/models/yuan.py  | 40 ++++-----------
 6 files changed, 28 insertions(+), 83 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama32.py b/python/llm/src/ipex_llm/transformers/models/llama32.py
index 15c156a192c..14b11afb2d8 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama32.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -234,7 +234,7 @@ def llama_attention_forward(
     attn_weights = None
     attn_output = scaled_dot_product_attention(
         query_states, key_states, value_states,
-        attention_mask, q_len == key_states.size(2), math.sqrt(self.head_dim)
+        attention_mask, q_len == key_states.size(2)
     )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index 3bc95d6c3c7..f3c454251fd 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -38,15 +38,13 @@
 
 import torch
 import warnings
-import torch.nn as nn
 from typing import Optional, Tuple, Union, List
 import math
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compresskv_attn_mask
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
-from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
     DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.cache_utils import Cache
@@ -127,11 +125,10 @@ def minicpm_attention_forward(
         key_states, value_states = past_key_value.update(key_states, value_states,
                                                           self.layer_idx, None)
 
-    from ipex_llm.transformers.models.common import scaled_dot_product_attention
     attn_weights = None
     attn_output = scaled_dot_product_attention(
         query_states, key_states, value_states,
-        attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
+        attention_mask, q_len == kv_seq_len
     )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index dc996691c71..1a71663dd9c 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -28,6 +28,7 @@
 from torch.nn.functional import linear
 from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
 from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from transformers import AutoProcessor, TextIteratorStreamer
 from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
 
@@ -72,10 +73,11 @@ def siglip_attention_forward(
         72, 80
     )
 
-    from ipex_llm.transformers.models.common import scaled_dot_product_attention
     attn_weights = None
-    attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
-                                               attention_mask, False, math.sqrt(self.head_dim))
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, False, 1 / math.sqrt(self.head_dim)
+    )
 
     attn_output = attn_output[:, :, :, :self.head_dim]
 
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index d746f079991..011d6c22d03 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -595,7 +595,7 @@ def qwen2_attention_forward(
     else:
         attn_output = scaled_dot_product_attention(
             query_states, key_states, value_states,
-            attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
+            attention_mask, q_len == kv_seq_len
         )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/python/llm/src/ipex_llm/transformers/models/starcoder2.py b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
index 1bffc1ee67f..7a23c80f3ae 100644
--- a/python/llm/src/ipex_llm/transformers/models/starcoder2.py
+++ b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -40,17 +40,15 @@
 import torch
 import warnings
 
-from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import (
-    use_quantize_kv_cache, restore_fp8_kv_cache,
-    should_use_fuse_rope, use_sdp, use_sdp_causal
-)
+from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, should_use_fuse_rope
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
 
 from typing import Optional, Tuple, List
 from transformers.cache_utils import Cache
-from transformers.models.starcoder2.modeling_starcoder2 import repeat_kv, apply_rotary_pos_emb
+from transformers.models.starcoder2.modeling_starcoder2 import apply_rotary_pos_emb
 from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model, Starcoder2Attention
 
 
@@ -103,41 +101,11 @@ def attention_forward(
                                                       self.layer_idx, None)
 
     # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                                   training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py
index 800f0273c06..afa99da17f1 100644
--- a/python/llm/src/ipex_llm/transformers/models/yuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -26,12 +26,12 @@
 import torch
 
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.common import attention_softmax
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
     mlp_fusion_check, fp16_fusion_check
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import SILU, update_past_key_value
-from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 
 
 def merge_qk(module: torch.nn.Module):
@@ -214,34 +214,12 @@ def yuan_attention_forward(
     )
     past_key_value = (key_states, value_states, before_hidden_states) if use_cache else None
 
-    # IPEX-LLM OPT: sdp
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    # IPEX-LLM OPT: sdpa
+    attn_weights = None
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )
 
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
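
Note on the refactor: every call site touched by this patch now goes through the shared scaled_dot_product_attention helper from ipex_llm.transformers.models.common, passing the attention mask, a causal flag (q_len == kv_seq_len), and, where a non-default scale is needed (siglip), an explicit 1 / math.sqrt(self.head_dim); the old math.sqrt(self.head_dim) argument is dropped or inverted, so the helper presumably takes a softmax scale multiplier that defaults to 1/sqrt(head_dim). The sketch below is only a minimal approximation of such a unified wrapper built on plain PyTorch SDPA, to illustrate the dispatch the removed per-model branches used to do by hand. The name sdpa_sketch, the repeat_interleave-based GQA expansion, and the default-scale handling are illustrative assumptions; the real helper additionally routes to the xe_addons kernels and FP8 KV-cache paths visible in the deleted code.

    # Illustrative sketch only -- not the ipex_llm implementation. Assumes a plain
    # PyTorch backend and omits the XPU (xe_addons) and FP8 KV-cache branches that
    # the deleted code handled.
    import math
    from typing import Optional

    import torch


    def sdpa_sketch(query: torch.Tensor,            # [bsz, n_heads, q_len, head_dim]
                    key: torch.Tensor,              # [bsz, n_kv_heads, kv_len, head_dim]
                    value: torch.Tensor,            # [bsz, n_kv_heads, kv_len, head_dim]
                    attention_mask: Optional[torch.Tensor],
                    is_causal: bool,                # call sites pass q_len == kv_seq_len
                    scale: Optional[float] = None):  # defaults to 1/sqrt(head_dim)
        if scale is None:
            scale = 1.0 / math.sqrt(query.size(-1))

        # Expand grouped KV heads so every query head has a matching key/value head
        # (what the deleted repeat_kv calls did).
        n_rep = query.size(1) // key.size(1)
        if n_rep > 1:
            key = key.repeat_interleave(n_rep, dim=1)
            value = value.repeat_interleave(n_rep, dim=1)

        # An explicit mask already encodes causality, so only request SDPA's causal
        # masking when no mask tensor is given (PyTorch rejects setting both).
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            is_causal=is_causal and attention_mask is None,
            scale=scale,
        )

Centralizing that branching in one helper is what lets each model file drop its use_sdp, use_sdp_causal, restore_fp8_kv_cache, and repeat_kv imports, which accounts for most of the 83 deleted lines in the diffstat.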