diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 38119d0f3c4..d138b3e9967 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -149,9 +149,11 @@ def is_linear_module(module):
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
-
+        from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
         VLLM_LINEAR_LIST = [
-            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
+            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
+            MergedColumnParallelLinear,
+            ParallelLMHead
         ]
         if is_module_in_classes(module, VLLM_LINEAR_LIST):
             if 'xpu' in _VLLM_VERSION:
@@ -167,6 +169,12 @@ def is_linear_module(module):
             else:
                 # For vllm cpu
                 tp_size = 1
+            if isinstance(module, ParallelLMHead) and 'xpu' in _VLLM_VERSION:
+                in_features = module.embedding_dim
+                out_features = module.num_embeddings_per_partition
+                result = True
+                mp_group = None
+                return result, (in_features, out_features, mp_group)
             in_features = module.input_size
             out_features = module.output_size
             result = True
diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index 7a6f1d2561c..85719ccdc84 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -15,18 +15,71 @@
 #
 import torch
 from vllm.logger import init_logger
-from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
-from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
-from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
+from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
+from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
+from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
 from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
+from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
-from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
+from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.config import DeviceConfig
-from typing import Tuple
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
+
+from typing import Tuple, Optional
 from ipex_llm.utils.common import invalidInputError
+from vllm.sequence import SamplerOutput
+
+
+def _Llama_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.lm_head, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Qwen2_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    if self.config.tie_word_embeddings:
+        lm_head_weight = self.model.embed_tokens
+    else:
+        lm_head_weight = self.lm_head
+    next_tokens = self.sampler(lm_head_weight, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Chatglm_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
+                               sampling_metadata)
+
+    return next_tokens
+
+
+def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    logits = embedding(hidden_states)
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = tensor_model_parallel_gather(logits)
+    # Remove paddings in vocab (if any).
+    if logits is not None:
+        logits = logits[:, :self.org_vocab_size]
+    return logits
 
 
 def _MLP_forward(self, x):
@@ -139,12 +192,26 @@ def _ChatGLM_Attention_forward(
     GLMAttention: _ChatGLM_Attention_forward
 }
 
+_REPLACED_SAMPLER_LAYERS = {
+    LlamaForCausalLM: _Llama_sample,
+    QWenLMHeadModel: _Llama_sample,
+    ChatGLMForCausalLM: _Chatglm_sample,
+    Qwen2ForCausalLM: _Qwen2_sample,
+    BaiChuanBaseForCausalLM: _Llama_sample,
+}
+
 
 def _model_mlp_convert():
     for module, replaced_func in _REPLACED_MLP_LAYERS.items():
         setattr(module, "forward", replaced_func)
 
 
+def _model_sample_convert():
+    setattr(Sampler, "_get_logits", _sample_get_logits)
+    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
+        setattr(module, "sample", replaced_func)
+
+
 def _model_attention_convert():
     for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
         setattr(module, "forward", replaced_func)
@@ -160,6 +227,7 @@ def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
         _model_mlp_convert()
         _model_attention_convert()
+        _model_sample_convert()
 
         from vllm.utils import measure_device_memory
         with measure_device_memory() as m:
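
Note on the approach: the patch works by class-level monkey patching, replacing methods on the vLLM classes themselves via setattr so that every model instance created afterwards uses the IPEX-LLM versions. Below is a minimal, self-contained sketch of that pattern only; ToySampler and ToyCausalLM are hypothetical stand-ins for illustration, not vLLM classes.

# Sketch of the setattr-based patching used by _model_sample_convert():
# replace methods on the class so all instances pick up the new behaviour.

class ToySampler:
    def _get_logits(self, hidden_states):
        return hidden_states              # original behaviour

class ToyCausalLM:
    def __init__(self):
        self.sampler = ToySampler()

    def sample(self, hidden_states):
        return self.sampler._get_logits(hidden_states)

def _patched_get_logits(self, hidden_states):
    # stand-in for _sample_get_logits: post-process the logits
    return [h * 2 for h in hidden_states]

def _patched_sample(self, hidden_states):
    # stand-in for _Llama_sample: route sampling through the patched sampler
    return self.sampler._get_logits(hidden_states)

# Patch the classes in place, mirroring _model_sample_convert().
setattr(ToySampler, "_get_logits", _patched_get_logits)
setattr(ToyCausalLM, "sample", _patched_sample)

print(ToyCausalLM().sample([1, 2, 3]))    # -> [2, 4, 6]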