Skip to content

Commit

Permalink
optimize new minicpm model
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Dec 19, 2024
1 parent 4540424 commit b654b8a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 62 deletions.
10 changes: 6 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de
return mask


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
mask: torch.Tensor = None,
def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, mask: torch.Tensor = None,
is_causal: bool = False, scale: float = None) -> torch.Tensor:
bsz, n_heads, seq_length, head_dim = query.shape
_, n_kv_heads, kv_length, _ = key.shape
Expand Down Expand Up @@ -268,7 +268,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
attn_output = xe_addons.sdp(query, key, value, mask)
else:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
else:
attn_output = xe_addons.sdp_non_causal(query, key, value, mask)

Expand All @@ -281,6 +281,8 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
key = repeat_kv(key, n_heads // n_kv_heads)
value = repeat_kv(value, n_heads // n_kv_heads)

return torch.nn.functional.scaled_dot_product_attention(
attn_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, mask, is_causal=is_causal, scale=scale
)
attn_output = attn_output.to(dtype) # workaround ipex 2.1's bug
return attn_output
47 changes: 5 additions & 42 deletions python/llm/src/ipex_llm/transformers/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,49 +127,12 @@ def minicpm_attention_forward(
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)

from ipex_llm.transformers.models.common import scaled_dot_product_attention
attn_weights = None
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
# [CompressKV]
if use_compresskv:
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)

if use_quantizekv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons
if use_quantizekv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
else:
if use_quantizekv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)

if attention_mask is not None:
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
Expand Down
20 changes: 4 additions & 16 deletions python/llm/src/ipex_llm/transformers/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from torch.nn.functional import linear
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.utils import use_sdp_non_causal
from transformers import AutoProcessor, TextIteratorStreamer
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

Expand Down Expand Up @@ -73,21 +72,10 @@ def siglip_attention_forward(
72, 80
)

if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
import xe_addons
attn_weights = None
attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
value_states.contiguous(), attention_mask)
else:
attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

attn_weights = attention_softmax(attn_weights)

attn_weights = torch.nn.functional.dropout(attn_weights,
p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
from ipex_llm.transformers.models.common import scaled_dot_product_attention
attn_weights = None
attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
attention_mask, False, math.sqrt(self.head_dim))

attn_output = attn_output[:, :, :, :self.head_dim]

Expand Down

0 comments on commit b654b8a

Please sign in to comment.