Skip to content

Commit

Permalink
refine
Browse files — browse the repository at this point in the history
  • Loading branch information
xiangyuT committed Jul 11, 2024
1 parent 6f725d6 commit 858e09c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def load_model(self, model_path, world_size, low_bit='sym_int4'):
model = AutoModelForCausalLM.from_pretrained(model_path,
load_in_low_bit=low_bit,
torch_dtype=self.dtype,
cpu_embedding=True,
optimize_model=True,
trust_remote_code=True,
use_cache=True,
Expand Down Expand Up @@ -499,7 +500,8 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
return tuple(result)
else:
# num_layers = self.model.layer_end - self.model.layer_start
for layer_idx in range(self.model.num_layers):
num_cache = min(len(kv_cache_1.key_cache), self.model.num_layers)
for layer_idx in range(num_cache):
kv_cache_1.key_cache[layer_idx] = \
torch.cat([kv_cache_1.key_cache[layer_idx],
kv_cache_2.key_cache[layer_idx]], dim=0)
Expand Down

0 comments on commit 858e09c

Please sign in to comment.