[MiniCPM3] support minicpm3-4b
sgwzy22 committed Nov 23, 2024
1 parent c67420a commit 93aff7b
Showing 19 changed files with 181,540 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -11,3 +11,5 @@ core
 *.tar.gz
 *.deb
 chat*.so
+gen*
+verify*
16 changes: 9 additions & 7 deletions models/Llama2/compile/compile.sh
@@ -61,8 +61,11 @@ elif [ "$name" = "llama2-13b" ]; then
     num_layers=40
     hidden_size=5120
     echo "Compile Llama2-13B"
+elif [ "$name" = "llama2-1.3b" ]; then
+    num_layers=24
+    hidden_size=2048
 else
-    >&2 echo -e "Error: Invalid name $name, the input name must be \033[31mllama2-7b|llama2-13b\033[0m"
+    >&2 echo -e "Error: Invalid name $name, the input name must be \033[31mllama2-1.3b|llama2-7b|llama2-13b\033[0m"
     exit 1
 fi
 
@@ -79,22 +82,22 @@ fi
 
 if [ x$num_device != x1 ]; then
     device_args="--num_device $num_device"
-    out_model=$name'_'$mode'_'$num_device'dev_'$seq_length'.bmodel'
+    out_model=$name'_'$mode'_'$num_device'dev_'$seq_length'seq.bmodel'
 else
-    out_model=$name'_'$mode'_1dev_'$seq_length'.bmodel'
+    out_model=$name'_'$mode'_1dev_'$seq_length'seq.bmodel'
 fi
 
 if [ x$addr_mode == x"io_alone" ]; then
     addr_args="--addr_mode io_alone"
 fi
 
-outdir=${folder}/${mode}_${num_device}/embedding
+outdir=${folder}/embedding
 mkdir -p $outdir
 pushd $outdir
 
 model_transform.py \
     --model_name embedding \
-    --model_def ../../onnx/embedding.pt \
+    --model_def ../onnx/embedding.pt \
     --input_shapes [[1,$seq_length]] \
     --input_types "int32" \
     --mlir embedding.mlir
@@ -110,7 +113,7 @@ model_deploy.py \
 
 model_transform.py \
     --model_name embedding_cache \
-    --model_def ../../onnx/embedding.pt \
+    --model_def ../onnx/embedding.pt \
     --input_shapes [[1,1]] \
     --input_types "int32" \
     --mlir embedding_cache.mlir
@@ -202,7 +205,6 @@ outdir=${folder}/$mode"_"$num_device"dev"/block
 mkdir -p $outdir
 
 pushd $outdir
-mkdir -p $outdir
 
 for ((i=0; i<$num_layers; i++)); do
     model_transform.py \
7 changes: 3 additions & 4 deletions models/Llama2/compile/export_onnx.py
@@ -12,12 +12,12 @@
 import torch
 import argparse
 from tqdm import tqdm
-from transformers import LlamaTokenizer, LlamaForCausalLM
+from transformers import LlamaForCausalLM
 torch.set_grad_enabled(False)
 
 parser = argparse.ArgumentParser(description='export onnx.')
-parser.add_argument('--model_path', type=str, help='path to the torch model.')
-parser.add_argument('--seq_length', type=int, default=512, help="sequence length")
+parser.add_argument('-m', '--model_path', type=str, help='path to the torch model.')
+parser.add_argument('-s', '--seq_length', type=int, default=512, help="sequence length")
 parser.add_argument('--lmhead_with_topk', type=int, default=0, help="only trace the LmHeadWithTopK")
 
 args = parser.parse_args()
@@ -44,7 +44,6 @@
 
 print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')
 
-tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 class Embedding(torch.nn.Module):
 
2 changes: 1 addition & 1 deletion models/Llama3_2-Vision/python_demo/pipeline.py
@@ -91,7 +91,7 @@ def chat(self):
                 self.image = Image.open(image_path)
                 print(f'load new image:"{image_path}"')
             except:
-                print(f"load image:"{image_path}" faild, load origin image:"{args.image_path}" instead")
+                print(f'load image:"{image_path}" faild, load origin image:"{args.image_path}" instead')
                 self.clear()
             # Chat
         else:
80 changes: 80 additions & 0 deletions models/MiniCPM3/compile/README.md
@@ -0,0 +1,80 @@
# Command

## Export onnx

```shell
pip install -r requirements.txt
cp files/MiniCPM3-4B/modeling_minicpm.py ${your_torch_model}/modeling_minicpm.py
```
`your_torch_model` is the directory where you downloaded the model, e.g. `MiniCPM3-4B/`.

```shell
python3 export_onnx.py --model_path your_torch_model --seq_length 8192 --device cpu
```
* Caveat: if you export on CPU with `--device cpu`, the export runs in float32, which does not match the bfloat16 training precision and may cause accuracy problems.
* If CUDA is available, exporting on CUDA is recommended (see the sketch below).
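
Assuming `export_onnx.py` also accepts `cuda` for its `--device` flag (an inference from the CPU invocation above, not confirmed here), the CUDA export would presumably look like:
```shell
python3 export_onnx.py --model_path your_torch_model --seq_length 8192 --device cuda
```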

## Compile bmodel
With io_alone and int4 precision:
```shell
./compile.sh --mode int4 --name minicpm3-4b --addr_mode io_alone --seq_length 8192
```
With io_alone and int8 precision:
```shell
./compile.sh --mode int8 --name minicpm3-4b --addr_mode io_alone --seq_length 8192
```

### Download a precompiled model
You can also download a precompiled bmodel directly instead of compiling it yourself:
```shell
pip3 install dfss
python3 -m dfss [email protected]:/ext_model_information/LLM/LLM-TPU/minicpm3-4b_int4_seq512_1dev.bmodel
```

## python demo

See the README inside `python_demo`.

### Changes to modeling_minicpm.py

#### Change 1: MiniCPM3Model

* Add the following in the initializer:
```python
config._attn_implementation = "eager"
```
This avoids torch's flash attention and SDPA attention, so the original attention structure is exported and the optimization patterns are easier to match during MLIR compilation.

#### Change 2: rotary position embedding
Original code:
```python
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
```
Modified (only the `unsqueeze_dim` default changes from 1 to 2, so cos/sin broadcast against a [bs, seq_len, heads, dim] layout rather than [bs, heads, seq_len, dim]):
```python
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=2):
orig_dtype = k.dtype
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, seq_len, 1, dim]
q_fp32 = q.to(dtype=torch.float32, device=q.device)
k_fp32 = k.to(dtype=torch.float32, device=k.device)
q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
```
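
As a quick illustration of the broadcast the new default enables (the shapes below are hypothetical, not MiniCPM3's actual head count or rotary dimension):
```python
import torch

# q laid out as [bs, seq_len, heads, head_dim]; cos as [bs, seq_len, dim].
q = torch.randn(1, 8192, 40, 32)
cos = torch.randn(1, 8192, 32)

# unsqueeze_dim=2 turns cos into [1, 8192, 1, 32], broadcasting over the heads axis.
print((q * cos.unsqueeze(2)).shape)  # torch.Size([1, 8192, 40, 32])
```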

#### Change 3: MiniCPMAttention

* The main change is how past_key_value and position_embedding are fed in: position_embedding is constant-folded so the ONNX structure can be exported (a rough sketch of the cache-as-I/O pattern follows this list).

* Some permute and concat operations in the attention computation were also adjusted to help pattern matching and model simplification when the model is later compiled with MLIR.

* For the exact changes, diff against the original model file.
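
As a rough sketch only (not the repo's actual code; all names and shapes are hypothetical and mask handling is omitted), an export-friendly attention step that passes the KV cache as explicit tensor inputs and outputs typically looks like this:
```python
import torch

def attention_forward(q, k, v, past_k, past_v):
    """Hypothetical export-friendly attention: past_k/past_v are plain tensors
    passed in and returned, so ONNX tracing captures a static graph instead of
    a mutable Cache object. The attention mask is omitted for brevity."""
    # q, k, v: [bs, new_len, heads, head_dim]; past_k/past_v: [bs, cache_len, heads, head_dim]
    k = torch.cat([past_k, k], dim=1)   # append new keys along the sequence axis
    v = torch.cat([past_v, v], dim=1)
    qh = q.transpose(1, 2)              # [bs, heads, new_len, head_dim]
    kh = k.transpose(1, 2)              # [bs, heads, total_len, head_dim]
    vh = v.transpose(1, 2)
    scores = qh @ kh.transpose(2, 3) / qh.shape[-1] ** 0.5
    out = scores.softmax(dim=-1) @ vh   # [bs, heads, new_len, head_dim]
    return out.transpose(1, 2), k, v    # also return the updated cache tensors
```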