Skip to content

Commit

Permalink
fix internlm xcomposser stream chat (#11564)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Jul 11, 2024
1 parent b9c6699 commit a945500
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 24 deletions.
13 changes: 10 additions & 3 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,12 +1259,19 @@ def _optimize_post(model, lightweight_bmm=False):
elif model.config.model_type == "internlmxcomposer2":
modeling_module_name = model.model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.internlm import internlm_xcomposser2_attention_forward
from ipex_llm.transformers.models.internlm import (
internlm_xcomposser2_attention_forward,
internlm_xcomposser2_mlp_forward,
internlm_xcomposser2_model_forward_wrapper,
internlm_xcomposser2_chat
)
convert_forward(model, module.InternLM2Attention, internlm_xcomposser2_attention_forward)
from ipex_llm.transformers.models.internlm import internlm_xcomposser2_mlp_forward
convert_forward(model, module.InternLM2MLP, internlm_xcomposser2_mlp_forward)
convert_forward(model, module.InternLM2RMSNorm, llama_rms_norm_forward)
from ipex_llm.transformers.models.internlm import internlm_xcomposser2_chat
internlm_xcomposser2_model_forward = internlm_xcomposser2_model_forward_wrapper(
module.InternLM2Model.forward
)
convert_forward(model, module.InternLM2Model, internlm_xcomposser2_model_forward)
model.chat = MethodType(internlm_xcomposser2_chat, model)
elif model.config.model_type == "qwen":
if hasattr(model.config, "visual"):
Expand Down
86 changes: 65 additions & 21 deletions python/llm/src/ipex_llm/transformers/models/internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
Plora_A: torch.nn.Linear = None, Plora_B: torch.nn.Linear = None):
invalidInputError(x.dim() == 3 and result.dim() == 3,
"`x` and `result` should have 3 dims")
if len(im_mask) == 0 or x.size(1) == 1:
if isinstance(im_mask, torch.Tensor) or len(im_mask) == 0:
return result
else:
for start_idx, end_idx in im_mask:
Expand All @@ -320,6 +320,56 @@ def add_lora(x: torch.Tensor, result: torch.Tensor,
return result


def internlm_xcomposser2_model_forward_wrapper(origin_forward):
def internlm_xcomposser2_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
):
im_mask = kwargs.get('im_mask', None)
if im_mask is None or im_mask.size(-1) <= 1 or im_mask.sum() == 0:
# decoding or no image input, `im_mask` is not needed
kwargs['im_mask'] = []
else:
# replace im_mask with start_idx and end_idx to improve performance
im_mask = im_mask.cpu().flatten().tolist()
length = len(im_mask)
new_mask = []
i = 0
while i < length:
while i < length and not im_mask[i]:
i = i + 1
start_idx = i
while i < length and im_mask[i]:
i = i + 1
end_idx = i
if start_idx != end_idx:
new_mask.append((start_idx, end_idx))
kwargs['im_mask'] = new_mask
return origin_forward(
self=self,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs
)
return internlm_xcomposser2_model_forward


def internlm_xcomposser2_attention_forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -466,32 +516,26 @@ def internlm_xcomposser2_chat(
**kwargs,
):
# ipex-llm changes start: fix device and dtype conversion
# replace im_mask with start_idx and end_idx to improve performance
if image is None:
inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
im_mask = []
im_mask = torch.zeros(inputs['input_ids'].shape[:2]).bool()
else:
image = self.encode_img(image)
inputs, im_mask = self.interleav_wrap_chat(tokenizer, query, image,
history, meta_instruction)
mask = im_mask.cpu().flatten().tolist()
length = len(mask)
im_mask = []
i = 0
while i < length:
while i < length and not mask[i]:
i = i + 1
start_idx = i
while i < length and mask[i]:
i = i + 1
end_idx = i
if start_idx != end_idx:
im_mask.append((start_idx, end_idx))

inputs = {
k: v.to(device=self.device, dtype=self.dtype)
for k, v in inputs.items() if torch.is_tensor(v)
}

new_inputs = {}
for k, v in inputs.items():
if torch.is_tensor(v):
if v.dtype.is_floating_point:
new_inputs[k] = v.to(device=self.device, dtype=self.dtype)
else:
# input_ids, don't convert its dtype
new_inputs[k] = v.to(device=self.device)
else:
new_inputs[k] = v
inputs = new_inputs
im_mask = im_mask.to(self.device)
# ipex-llm changes end

# also add end-of-assistant token in eos token id to avoid unnecessary generation
Expand Down

0 comments on commit a945500

Please sign in to comment.