From bb3f6bfb5fc83ba949a50d847d8137ba3272844f Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Tue, 16 Jul 2024 13:48:10 +0800
Subject: [PATCH 1/3] basic

---
 .../llm/src/ipex_llm/transformers/convert.py  | 11 ++++-
 .../src/ipex_llm/vllm/xpu/model_convert.py    | 43 ++++++++++++++++++-
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 38119d0f3c4..6a97c5cb7c5 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -149,9 +149,12 @@ def is_linear_module(module):
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
+        from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+
 
         VLLM_LINEAR_LIST = [
-            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
+            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear,
+            ParallelLMHead
         ]
         if is_module_in_classes(module, VLLM_LINEAR_LIST):
             if 'xpu' in _VLLM_VERSION:
@@ -167,6 +170,12 @@ def is_linear_module(module):
             else:
                 # For vllm cpu
                 tp_size = 1
+            if isinstance(module, ParallelLMHead) and 'xpu' in _VLLM_VERSION:
+                in_features = module.embedding_dim
+                out_features = module.num_embeddings_per_partition
+                result = True
+                mp_group = None
+                return result, (in_features, out_features, mp_group)
             in_features = module.input_size
             out_features = module.output_size
             result = True
diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index 7a6f1d2561c..d937066c96e 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -15,18 +15,46 @@
 #
 import torch
 from vllm.logger import init_logger
-from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
+from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
 from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
 from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
 from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
 from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.config import DeviceConfig
-from typing import Tuple
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
+
+from typing import Tuple, Optional
 from ipex_llm.utils.common import invalidInputError
+from vllm.sequence import SamplerOutput
+
+def _Llama_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.lm_head, hidden_states,
+                                   sampling_metadata)
+    return next_tokens
+
+
+def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+                      embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    logits = embedding(hidden_states)
+    # print(f"Before gather: {logits.shape}")
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = tensor_model_parallel_gather(logits)
+    # Remove paddings in vocab (if any).
+    if logits is not None:
+        logits = logits[:, :self.org_vocab_size]
+    return logits
+
 
 
 def _MLP_forward(self, x):
@@ -139,12 +167,22 @@ def _ChatGLM_Attention_forward(
     GLMAttention: _ChatGLM_Attention_forward
 }
 
+_REPLACED_SAMPLER_LAYERS = {
+    LlamaForCausalLM: _Llama_sample,
+}
+
 
 def _model_mlp_convert():
     for module, replaced_func in _REPLACED_MLP_LAYERS.items():
         setattr(module, "forward", replaced_func)
 
 
+def _model_sample_convert():
+    setattr(Sampler, "_get_logits", _sample_get_logits)
+    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
+        setattr(module, "sample", replaced_func)
+
+
 def _model_attention_convert():
     for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
         setattr(module, "forward", replaced_func)
@@ -160,6 +198,7 @@ def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
         _model_mlp_convert()
         _model_attention_convert()
+        _model_sample_convert()
 
         from vllm.utils import measure_device_memory
         with measure_device_memory() as m:

From 629e7d5d385e4917664d18f93aad445296e446e6 Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Tue, 16 Jul 2024 14:33:58 +0800
Subject: [PATCH 2/3] Modify convert.py

---
 .../src/ipex_llm/vllm/xpu/model_convert.py    | 37 +++++++++++++++++--
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index d937066c96e..fccd015382d 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -16,10 +16,11 @@
 import torch
 from vllm.logger import init_logger
 from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
-from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
-from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
+from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
+from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
 from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
-from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
+from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.layers.sampler import Sampler
@@ -42,11 +43,35 @@ def _Llama_sample(
                                    sampling_metadata)
     return next_tokens
 
+def _Qwen2_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    if self.config.tie_word_embeddings:
+        lm_head_weight = self.model.embed_tokens
+    else:
+        lm_head_weight = self.lm_head
+    next_tokens = self.sampler(lm_head_weight, hidden_states,
+                                   sampling_metadata)
+    return next_tokens
+
+
+def _Chatglm_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
+                                   sampling_metadata)
+
+    return next_tokens
+
+
 def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
     logits = embedding(hidden_states)
-    # print(f"Before gather: {logits.shape}")
     if embedding_bias is not None:
         logits += embedding_bias
     logits = tensor_model_parallel_gather(logits)
@@ -169,6 +194,10 @@ def _ChatGLM_Attention_forward(
 
 _REPLACED_SAMPLER_LAYERS = {
     LlamaForCausalLM: _Llama_sample,
+    QWenLMHeadModel: _Llama_sample,
+    ChatGLMForCausalLM: _Chatglm_sample,
+    Qwen2ForCausalLM: _Qwen2_sample,
+    BaiChuanBaseForCausalLM: _Llama_sample,
 }
 

From 7e395edcf49cfb4e4f2ddd4ab91067860f387a69 Mon Sep 17 00:00:00 2001
From: gc-fu
Date: Tue, 16 Jul 2024 14:58:29 +0800
Subject: [PATCH 3/3] fix

---
 python/llm/src/ipex_llm/transformers/convert.py   |  5 ++---
 python/llm/src/ipex_llm/vllm/xpu/model_convert.py | 12 ++++++------
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 6a97c5cb7c5..d138b3e9967 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -150,10 +150,9 @@ def is_linear_module(module):
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
         from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-
-
         VLLM_LINEAR_LIST = [
-            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear,
+            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
+            MergedColumnParallelLinear,
             ParallelLMHead
         ]
         if is_module_in_classes(module, VLLM_LINEAR_LIST):
diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
index fccd015382d..85719ccdc84 100644
--- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
+++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py
@@ -34,15 +34,17 @@ from ipex_llm.utils.common import invalidInputError
 from vllm.sequence import SamplerOutput
 
+
 def _Llama_sample(
     self,
     hidden_states: torch.Tensor,
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     next_tokens = self.sampler(self.lm_head, hidden_states,
-                                   sampling_metadata)
+                               sampling_metadata)
     return next_tokens
 
+
 def _Qwen2_sample(
     self,
     hidden_states: torch.Tensor,
@@ -53,7 +55,7 @@ def _Qwen2_sample(
     else:
         lm_head_weight = self.lm_head
     next_tokens = self.sampler(lm_head_weight, hidden_states,
-                                   sampling_metadata)
+                               sampling_metadata)
     return next_tokens
 
 
@@ -63,14 +65,13 @@ def _Chatglm_sample(
     sampling_metadata: SamplingMetadata,
 ) -> Optional[SamplerOutput]:
     next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
-                                   sampling_metadata)
-
+                               sampling_metadata)
     return next_tokens
 
 
 def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
-                      embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
     logits = embedding(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
@@ -81,7 +82,6 @@ def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Mo
     return logits
 
 
-
 def _MLP_forward(self, x):
     gate_up = self.gate_up_proj(x)
     x = self.act_fn(gate_up)
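
Editor's note (appended after the series, not part of the patches): the model_convert.py changes work by monkey-patching vLLM classes at runtime — _model_sample_convert() overwrites Sampler._get_logits and each model class's sample method with setattr. Below is a minimal, self-contained sketch of that replacement pattern. ToyLlamaForCausalLM and _toy_llama_sample are hypothetical names used only for illustration; they are not vLLM or IPEX-LLM APIs.

from typing import Callable, Dict


class ToyLlamaForCausalLM:
    """Toy stand-in for a model class whose `sample` method gets replaced."""

    def sample(self, hidden_states, sampling_metadata):
        return "original sample"


def _toy_llama_sample(self, hidden_states, sampling_metadata):
    # A plain function assigned onto the class; Python binds it as a method,
    # so `self` is the model instance at call time.
    return "patched sample"


_REPLACED_SAMPLER_LAYERS: Dict[type, Callable] = {
    ToyLlamaForCausalLM: _toy_llama_sample,
}


def _model_sample_convert():
    # Same pattern as the patch: overwrite the class attribute so existing
    # and future instances all pick up the replacement implementation.
    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
        setattr(module, "sample", replaced_func)


if __name__ == "__main__":
    model = ToyLlamaForCausalLM()
    print(model.sample(None, None))   # -> original sample
    _model_sample_convert()
    print(model.sample(None, None))   # -> patched sample

In the real patch the replacement sample functions additionally pass the model's head (self.lm_head, the tied embedding, or transformer.output_layer for ChatGLM) into self.sampler, so logits are computed through the layer that IPEX-LLM has converted.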
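
A second illustrative sketch, this time of the padded-vocabulary trimming that _sample_get_logits performs with logits[:, :self.org_vocab_size]. The sizes below are assumptions for the example only — say a 32000-token vocabulary padded to 32064 so it splits evenly across tensor-parallel ranks; after the gather, the padding columns carry no real tokens and are sliced off.

import torch

# Assumed sizes for illustration; real values come from the model config.
org_vocab_size = 32000      # true vocabulary size
padded_vocab_size = 32064   # padded so it divides evenly across TP ranks
hidden_size = 4096

lm_head = torch.nn.Linear(hidden_size, padded_vocab_size, bias=False)
hidden_states = torch.randn(4, hidden_size)   # [num_tokens, hidden_size]

logits = lm_head(hidden_states)               # [4, 32064], includes padding
logits = logits[:, :org_vocab_size]           # [4, 32000], padding removed
print(logits.shape)                           # torch.Size([4, 32000])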