Enable ipex-llm optimization for lm head #11589

Merged · 3 commits · Jul 16, 2024
python/llm/src/ipex_llm/transformers/convert.py (12 changes: 10 additions & 2 deletions)
@@ -149,9 +149,11 @@ def is_linear_module(module):
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
-
+        from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
         VLLM_LINEAR_LIST = [
-            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
+            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
+            MergedColumnParallelLinear,
+            ParallelLMHead
         ]
         if is_module_in_classes(module, VLLM_LINEAR_LIST):
             if 'xpu' in _VLLM_VERSION:
@@ -167,6 +169,12 @@ def is_linear_module(module):
             else:
                 # For vllm cpu
                 tp_size = 1
+            if isinstance(module, ParallelLMHead) and 'xpu' in _VLLM_VERSION:
+                in_features = module.embedding_dim
+                out_features = module.num_embeddings_per_partition
+                result = True
+                mp_group = None
+                return result, (in_features, out_features, mp_group)
             in_features = module.input_size
             out_features = module.output_size
             result = True
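What the convert.py change does: vLLM's ParallelLMHead is now treated as a convertible linear layer, so is_linear_module reports its shape (embedding_dim in, num_embeddings_per_partition out) and the lm head can be swapped for an ipex-llm low-bit linear like any other projection. Below is a minimal sketch of that shape-detection idea, assuming a hypothetical stand-in class (FakeParallelLMHead, describe_linear) rather than vLLM's real ParallelLMHead or the actual ipex-llm implementation:

```python
# Minimal sketch, not the ipex-llm implementation: how an lm_head-style module
# can be reported as a convertible linear. FakeParallelLMHead is a stand-in for
# vLLM's ParallelLMHead; the attribute names mirror the diff above.
import torch


class FakeParallelLMHead(torch.nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        # An lm-head embedding stores its shape as embedding attributes,
        # not as input_size/output_size like the parallel linear layers.
        self.num_embeddings_per_partition = num_embeddings
        self.embedding_dim = embedding_dim
        self.weight = torch.nn.Parameter(torch.empty(num_embeddings, embedding_dim))


def describe_linear(module):
    """Return (is_linear, (in_features, out_features, mp_group)), like is_linear_module."""
    if isinstance(module, FakeParallelLMHead):
        # Shape comes from the embedding attributes; no tensor-parallel group here.
        return True, (module.embedding_dim, module.num_embeddings_per_partition, None)
    if isinstance(module, torch.nn.Linear):
        return True, (module.in_features, module.out_features, None)
    return False, None


head = FakeParallelLMHead(num_embeddings=32000, embedding_dim=4096)
print(describe_linear(head))  # (True, (4096, 32000, None))
```

Once the lm head is reported this way, the existing low-bit replacement path can handle it without any lm-head-specific conversion code.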
python/llm/src/ipex_llm/vllm/xpu/model_convert.py (78 changes: 73 additions & 5 deletions)
@@ -15,18 +15,71 @@
 #
 import torch
 from vllm.logger import init_logger
-from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
-from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
-from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
+from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
+from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
+from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
 from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
-from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
+from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.layers.sampler import Sampler
 
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.config import DeviceConfig
-from typing import Tuple
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_gather
+
+from typing import Tuple, Optional
 from ipex_llm.utils.common import invalidInputError
+from vllm.sequence import SamplerOutput
+
+
+def _Llama_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.lm_head, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Qwen2_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    if self.config.tie_word_embeddings:
+        lm_head_weight = self.model.embed_tokens
+    else:
+        lm_head_weight = self.lm_head
+    next_tokens = self.sampler(lm_head_weight, hidden_states,
+                               sampling_metadata)
+    return next_tokens
+
+
+def _Chatglm_sample(
+    self,
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> Optional[SamplerOutput]:
+    next_tokens = self.sampler(self.transformer.output_layer, hidden_states,
+                               sampling_metadata)
+
+    return next_tokens
+
+
+def _sample_get_logits(self, hidden_states: torch.Tensor, embedding: torch.nn.Module,
+                       embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+    logits = embedding(hidden_states)
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = tensor_model_parallel_gather(logits)
+    # Remove paddings in vocab (if any).
+    if logits is not None:
+        logits = logits[:, :self.org_vocab_size]
+    return logits
 
 
 def _MLP_forward(self, x):
@@ -139,12 +192,26 @@ def _ChatGLM_Attention_forward(
     GLMAttention: _ChatGLM_Attention_forward
 }
 
+_REPLACED_SAMPLER_LAYERS = {
+    LlamaForCausalLM: _Llama_sample,
+    QWenLMHeadModel: _Llama_sample,
+    ChatGLMForCausalLM: _Chatglm_sample,
+    Qwen2ForCausalLM: _Qwen2_sample,
+    BaiChuanBaseForCausalLM: _Llama_sample,
+}
+
 
 def _model_mlp_convert():
     for module, replaced_func in _REPLACED_MLP_LAYERS.items():
         setattr(module, "forward", replaced_func)
 
 
+def _model_sample_convert():
+    setattr(Sampler, "_get_logits", _sample_get_logits)
+    for module, replaced_func in _REPLACED_SAMPLER_LAYERS.items():
+        setattr(module, "sample", replaced_func)
+
+
 def _model_attention_convert():
     for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
         setattr(module, "forward", replaced_func)
@@ -160,6 +227,7 @@ def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
         _model_mlp_convert()
         _model_attention_convert()
+        _model_sample_convert()
 
         from vllm.utils import measure_device_memory
         with measure_device_memory() as m:
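The model_convert.py changes wire the new sampling path in by monkey-patching: _model_sample_convert replaces Sampler._get_logits with _sample_get_logits (which calls the lm head module instead of multiplying by a raw weight tensor) and replaces each supported model's sample method so it passes self.lm_head, the tied embedding, or the ChatGLM output layer to the sampler. The sketch below reproduces only that setattr-based patching pattern with toy classes; ToySampler, ToyCausalLM, and toy_sample_convert are illustrative stand-ins, not part of vLLM or ipex-llm:

```python
# Toy illustration of the setattr-based patching used by _model_sample_convert.
# The classes here are stand-ins, not vLLM's; only the patching pattern matches the diff.
import torch


class ToySampler:
    """Stand-in for vLLM's Sampler; only the logits path is modeled."""

    def _get_logits(self, hidden_states, lm_head_weight, bias=None):
        # Original-style path: plain matmul against a raw weight tensor.
        return torch.matmul(hidden_states, lm_head_weight.t())


class ToyCausalLM:
    """Stand-in for a *ForCausalLM class with an lm_head and a sampler."""

    def __init__(self, hidden: int, vocab: int):
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)
        self.sampler = ToySampler()

    def sample(self, hidden_states):
        # Original-style path: hands the sampler the raw weight.
        return self.sampler._get_logits(hidden_states, self.lm_head.weight)


def _patched_get_logits(self, hidden_states, embedding, embedding_bias=None):
    # Patched path calls the lm_head module itself, so a low-bit replacement
    # of that module is exercised automatically (mirrors _sample_get_logits).
    logits = embedding(hidden_states)
    if embedding_bias is not None:
        logits = logits + embedding_bias
    return logits


def _patched_sample(self, hidden_states):
    # Pass the module, not its weight, mirroring _Llama_sample passing self.lm_head.
    return self.sampler._get_logits(hidden_states, self.lm_head)


def toy_sample_convert():
    # Same pattern as _model_sample_convert: swap methods in place with setattr.
    setattr(ToySampler, "_get_logits", _patched_get_logits)
    setattr(ToyCausalLM, "sample", _patched_sample)


toy_sample_convert()
model = ToyCausalLM(hidden=8, vocab=16)
print(model.sample(torch.randn(2, 8)).shape)  # torch.Size([2, 16])
```

Because the patched logits path invokes the lm head module's forward, whatever low-bit linear convert.py installs for ParallelLMHead is exercised during sampling, which is the point of this PR.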