Skip to content

Commit

Permalink
add npu sdp (#11562)
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardozcm authored Jul 11, 2024
1 parent 2b8ad87 commit b9c6699
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ In this directory, you will find examples on how you could apply IPEX-LLM INT4 o
| Phi-3 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
| Stablelm | [stabilityai/stablelm-zephyr-3b](https://huggingface.co/stabilityai/stablelm-zephyr-3b) |
| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
| Deepseek | [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) |

## 0. Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
Expand Down
64 changes: 64 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@


import torch
from torch.nn import functional as F
import importlib
from typing import Optional, Tuple
from ipex_llm.transformers.npu_models.common import merge_linear


Expand All @@ -51,3 +54,64 @@ def baichuan_mlp_forward(self, x):
gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1)
down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj)
return down_proj


def baichuan_attention_fwd(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
modeling_module_name = self.__class__.__module__
module = importlib.import_module(modeling_module_name)
apply_rotary_pos_emb = module.apply_rotary_pos_emb

bsz, q_len, _ = hidden_states.size()

proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
# [bsz, nh, t, hd]

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

if query_states.size(2) == key_states.size(2):
# first token
from intel_npu_acceleration_library.functional import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask
)
attn_weights = None
else:
with torch.backends.cuda.sdp_kernel(enable_flash=True,
enable_math=True, enable_mem_efficient=True):
attn_output = F.scaled_dot_product_attention(query_states, key_states,
value_states, attn_mask=attention_mask)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
12 changes: 12 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,11 @@ def optimize_llm(model: torch.nn.Module):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.npu_models.baichuan import baichuan_mlp_forward, merge_mlp
from ipex_llm.transformers.npu_models.baichuan import baichuan_attention_fwd
model.apply(merge_mlp)

convert_forward(model, module.MLP, baichuan_mlp_forward)
convert_forward(model, module.Attention, baichuan_attention_fwd)

elif model.config.model_type == "phi3_v":
modeling_module_name = model.__class__.__module__
Expand All @@ -189,3 +191,13 @@ def optimize_llm(model: torch.nn.Module):
from transformers.models.clip.modeling_clip import CLIPAttention
convert_forward(model, CLIPAttention, phi3v_encoder_attention_forward)
convert_forward(model, module.Phi3VModel, phi3v_model_forward)

from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward
convert_forward(model, module.Phi3Attention, phi3_attention_forward)

elif model.config.model_type == "phi3":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward

convert_forward(model, module.Phi3Attention, phi3_attention_forward)
157 changes: 157 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_models/phi3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Tuple, List
import torch
from torch import nn
import math
import importlib
from transformers.cache_utils import Cache
from ipex_llm.utils.common.log4Error import invalidInputError


def phi3_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
modeling_module_name = self.__class__.__module__
module = importlib.import_module(modeling_module_name)
apply_rotary_pos_emb, repeat_kv = module.apply_rotary_pos_emb, module.repeat_kv
bsz, q_len, _ = hidden_states.size()

qkv = self.qkv_proj(hidden_states)
query_pos = self.num_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[..., query_pos:query_pos + self.num_key_value_heads * self.head_dim]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim:]

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(
False,
f"The cache structure has changed since version v4.36."
f"If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching,"
"please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
else:
causal_mask = None

if query_states.size(2) == key_states.size(2):
# first token
from intel_npu_acceleration_library.functional import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=self.is_causal and causal_mask is None and q_len > 1,
)
attn_weights = None
else:

attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of"
f"size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(value_states.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.attention_dropout, training=self.training)

attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

0 comments on commit b9c6699

Please sign in to comment.