Enable ipex-llm optimization for lm head (intel-analytics#11589)
* basic

* Modify convert.py

* fix
gc-fu authored and MeouSker77 committed Jul 19, 2024
1 parent 1b74af2 commit 13477e1
Showing 2 changed files with 83 additions and 7 deletions.
12 changes: 10 additions & 2 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -149,9 +149,11 @@ def is_linear_module(module):
        from vllm.model_executor.layers.linear import (
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
        )

        from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
        VLLM_LINEAR_LIST = [
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
            MergedColumnParallelLinear,
            ParallelLMHead
        ]
        if is_module_in_classes(module, VLLM_LINEAR_LIST):
            if 'xpu' in _VLLM_VERSION:
@@ -167,6 +169,12 @@ def is_linear_module(module):
            else:
                # For vllm cpu
                tp_size = 1
            if isinstance(module, ParallelLMHead) and 'xpu' in _VLLM_VERSION:
                in_features = module.embedding_dim
                out_features = module.num_embeddings_per_partition
                result = True
                mp_group = None
                return result, (in_features, out_features, mp_group)
            in_features = module.input_size
            out_features = module.output_size
            result = True
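Note (reviewer illustration, not part of the commit): adding ParallelLMHead to VLLM_LINEAR_LIST makes is_linear_module report the vLLM LM head as a convertible linear layer, and the new branch takes its shape from embedding_dim / num_embeddings_per_partition instead of input_size / output_size. A minimal sketch of how the returned tuple could be consumed; describe_module is hypothetical, not ipex-llm API:

import torch
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from ipex_llm.transformers.convert import is_linear_module

def describe_module(module: torch.nn.Module):
    # is_linear_module returns (is_linear, (in_features, out_features, mp_group)).
    is_linear, (in_features, out_features, mp_group) = is_linear_module(module)
    if not is_linear:
        return None
    # For ParallelLMHead on XPU the shape comes from the embedding attributes,
    # so the LM head can be replaced by a low-bit linear of the same shape.
    kind = "lm_head" if isinstance(module, ParallelLMHead) else "linear"
    return kind, in_features, out_features, mp_group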
78 changes: 73 additions & 5 deletions python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -15,18 +15,71 @@
#
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.sampler import Sampler

from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.input_metadata import InputMetadata
from vllm.config import DeviceConfig
from typing import Tuple
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather

from typing import Tuple, Optional
from ipex_llm.utils.common import invalidInputError
from vllm.sequence import SamplerOutput


def _Llama_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.lm_head, hidden_states,
                               sampling_metadata)
    return next_tokens


def _Qwen2_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    if self.config.tie_word_embeddings:
        lm_head_weight = self.model.embed_tokens
    else:
        lm_head_weight = self.lm_head
    next_tokens = self.sampler(lm_head_weight, hidden_states,
                               sampling_metadata)
    return next_tokens


def _Chatglm_sample(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
                               sampling_metadata)

    return next_tokens


def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
    logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    logits = tensor_model_parallel_gather(logits)
    # Remove paddings in vocab (if any).
    if logits is not None:
        logits = logits[:, :self.org_vocab_size]
    return logits
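# Added note (not in the original commit): in the patched path, `embedding` is the
# LM head module itself (e.g. a ParallelLMHead that ipex-llm has converted to a
# low-bit linear), so `embedding(hidden_states)` runs the optimized kernel instead
# of a plain matmul against a full-precision weight tensor.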


def _MLP_forward(self, x):
@@ -139,12 +192,26 @@ def _ChatGLM_Attention_forward(
    GLMAttention: _ChatGLM_Attention_forward
}

_REPLACED_SAMPLER_LAYERS = {
    LlamaForCausalLM: _Llama_sample,
    QWenLMHeadModel: _Llama_sample,
    ChatGLMForCausalLM: _Chatglm_sample,
    Qwen2ForCausalLM: _Qwen2_sample,
    BaiChuanBaseForCausalLM: _Llama_sample,
}


def _model_mlp_convert():
    for module, replaced_func in _REPLACED_MLP_LAYERS.items():
        setattr(module, "forward", replaced_func)


def _model_sample_convert():
    setattr(Sampler, "_get_logits", _sample_get_logits)
    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
        setattr(module, "sample", replaced_func)


def _model_attention_convert():
    for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
        setattr(module, "forward", replaced_func)
@@ -160,6 +227,7 @@ def get_load_function(low_bit):
    def _ipex_llm_load_model(self) -> None:
        _model_mlp_convert()
        _model_attention_convert()
        _model_sample_convert()

        from vllm.utils import measure_device_memory
        with measure_device_memory() as m:
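Note (illustration only, not part of this commit): _model_sample_convert patches the classes themselves, so every existing and future instance of Sampler and of the listed *ForCausalLM models picks up the replacement _get_logits / sample methods. A minimal, self-contained sketch of that class-level patching pattern, using hypothetical stand-in classes:

class Sampler:
    def _get_logits(self, hidden_states):
        return "original"

def _patched_get_logits(self, hidden_states):
    # Stand-in for _sample_get_logits: compute logits via the module call path.
    return "patched"

# Patch at class level, the same way _model_sample_convert uses setattr.
setattr(Sampler, "_get_logits", _patched_get_logits)

assert Sampler()._get_logits(None) == "patched"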
