forked from sophgo/LLM-TPU
Commit d736824 (1 parent: ea88e9d). Showing 4 changed files with 848 additions and 0 deletions.
@@ -0,0 +1,185 @@
#!/bin/bash
set -ex
models=
mode="int8"
folder="tmp"
num_device=1
mode_args=""
device_args=""
quantize_args="--quantize W8BF16"
name=""
num_layers=
out_model=$name.bmodel

while [[ $# -gt 0 ]]; do
    key="$1"

    case $key in
    --mode)
        mode="$2"
        shift 2
        ;;
    --num_device)
        num_device="$2"
        shift 2
        ;;
    --name)
        name="$2"
        shift 2
        ;;
    *)
        echo "Invalid option: $key" >&2
        exit 1
        ;;
    esac
done

if [ "$name" = "qwen-1_8b" ]; then | ||
num_layers=23 | ||
echo "Compile Qwen-1_8B" | ||
elif [ "$name" = "qwen-7b" ]; then | ||
num_layers=31 | ||
echo "Compile Qwen-7B" | ||
elif [ "$name" = "qwen-14b" ]; then | ||
num_layers=39 | ||
echo "Compile Qwen-14B" | ||
else | ||
>&2 echo -e "Error: Invalid name $name, the input name must be \033[31mqwen-1_8b|qwen-7b|qwen-14b\033[0m" | ||
exit 1 | ||
fi | ||
|
||
if [ "$mode" == "int8" ]; then
    quantize_args="--quantize W8BF16"
elif [ "$mode" == "bf16" ]; then
    quantize_args="--quantize BF16"
elif [ "$mode" == "int4" ]; then
    quantize_args="--quantize W4BF16 --q_group_size 64"
else
    echo "Error, unknown quantize mode"
    exit 1
fi

if [ "$num_device" != "1" ]; then
    device_args="--num_device $num_device"
    out_model=$name'_'$mode'_'$num_device'dev.bmodel'
else
    out_model=$name'_'$mode'_1dev.bmodel'
fi
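# Example of the resulting name, derived from the rules above: running with
# "--name qwen-7b --mode int4 --num_device 2" gives qwen-7b_int4_2dev.bmodel,
# while the single-device default gives qwen-7b_int4_1dev.bmodel.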

outdir=${folder}/embedding
mkdir -p $outdir
pushd $outdir

model_transform.py \
    --model_name embedding \
    --model_def ../onnx/embedding.onnx \
    --mlir embedding.mlir

model_deploy.py \
    --mlir embedding.mlir \
    --quantize BF16 \
    --quant_input \
    --quant_output \
    --chip bm1684x \
    $device_args \
    --model embedding.bmodel

model_transform.py \
    --model_name embedding_cache \
    --model_def ../onnx/embedding.onnx \
    --input_shapes [[1,1]] \
    --mlir embedding_cache.mlir

model_deploy.py \
    --mlir embedding_cache.mlir \
    --quantize BF16 \
    --quant_input \
    --quant_output \
    --chip bm1684x \
    $device_args \
    --model embedding_cache.bmodel

rm *.npz

models=$models' '$outdir'/embedding.bmodel '$outdir'/embedding_cache.bmodel '

popd

echo $models

outdir=${folder}/$mode"_"$num_device"dev"/lm_head
mkdir -p $outdir
pushd $outdir

model_transform.py \
    --model_name lm_head \
    --model_def ../../onnx/lm_head.onnx \
    --mlir lm_head.mlir

model_deploy.py \
    --mlir lm_head.mlir \
    $quantize_args \
    --quant_input \
    --quant_output \
    --chip bm1684x \
    $device_args \
    --model lm_head.bmodel

rm *.npz

models=${models}${outdir}'/lm_head.bmodel '
popd

echo $models

outdir=${folder}/$mode"_"$num_device"dev"/block
mkdir -p $outdir
pushd $outdir

for ((i=0; i<=$num_layers; i++)); do

    model_transform.py \
        --model_name block_$i \
        --model_def ../../onnx/block_$i.onnx \
        --mlir block_$i.mlir

    model_deploy.py \
        --mlir block_$i.mlir \
        $quantize_args \
        --quant_input \
        --quant_output \
        --chip bm1684x \
        $device_args \
        --model block_$i.bmodel

    model_transform.py \
        --model_name block_cache_$i \
        --model_def ../../onnx/block_cache_$i.onnx \
        --mlir block_cache_$i.mlir

    model_deploy.py \
        --mlir block_cache_$i.mlir \
        $quantize_args \
        --quant_input \
        --quant_output \
        --chip bm1684x \
        $device_args \
        --io_alone \
        --model block_cache_$i.bmodel

    rm *.npz

    models=${models}${outdir}'/block_'$i'.bmodel '$outdir'/block_cache_'$i'.bmodel '

done
popd
echo $models

model_tool --combine $models -o $out_model
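A hypothetical invocation of the compile script above (the file name compile.sh is an assumption, since the diff does not show it; run it from the directory that contains tmp/onnx, with the TPU-MLIR tools on PATH):

    ./compile.sh --name qwen-7b --mode int8 --num_device 1
    # produces qwen-7b_int8_1dev.bmodel, combined from the embedding, lm_head
    # and per-layer block/block_cache bmodels via model_tool --combine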
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================

import os
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(description='export onnx.')
parser.add_argument('--model_path', type=str, help='path to the torch model.')

args = parser.parse_args()

model_path = args.model_path
folder = "./tmp/onnx"

device = torch.device("cuda:0")
origin_model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True,
    torch_dtype=torch.bfloat16, device_map="auto").eval()
config = origin_model.config
transformer = origin_model.transformer
layers = transformer.h

SEQ_LENGTH = config.seq_length
NUM_LAYERS = config.num_hidden_layers
HIDDEN_SIZE = config.hidden_size
NUM_ATTENTION_HEADS = config.num_attention_heads
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS

for param in origin_model.parameters():
    param.requires_grad = False

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)


class Embedding(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input_ids):
        out = transformer.wte(input_ids)
        return out.float()


class QwenBlock(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        # params
        self.layer_id = layer_id
        self.layer = layers[layer_id]
        self.rotary_emb = transformer.rotary_emb(SEQ_LENGTH)
        self.cos_emb = self.rotary_emb[0].view(SEQ_LENGTH, HEAD_DIM)
        self.sin_emb = self.rotary_emb[1].view(SEQ_LENGTH, HEAD_DIM)

    def forward(self, hidden_states, position_ids, attention_mask):
        cos_pos = self.cos_emb[position_ids].unsqueeze(2)
        sin_pos = self.sin_emb[position_ids].unsqueeze(2)
        hidden_states, past_kv = self.layer(
            hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb_list=[[cos_pos, sin_pos]],
            # registered_causal_mask=attention_mask,
            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states.float(), present_k.float(), present_v.float()


class QwenBlockCache(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        # params
        self.layer_id = layer_id
        self.layer = layers[layer_id]
        self.rotary_emb = transformer.rotary_emb(SEQ_LENGTH)
        self.cos_emb = self.rotary_emb[0].view(SEQ_LENGTH, HEAD_DIM)
        self.sin_emb = self.rotary_emb[1].view(SEQ_LENGTH, HEAD_DIM)

    def forward(self, hidden_states, position_ids, attention_mask, past_k,
                past_v):
        cos_pos = self.cos_emb[position_ids].unsqueeze(2)
        sin_pos = self.sin_emb[position_ids].unsqueeze(2)
        hidden_states, past_kv = self.layer(
            hidden_states,
            layer_past=(past_k, past_v),
            attention_mask=attention_mask,
            rotary_pos_emb_list=[[cos_pos, sin_pos]],
            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states.float(), present_k.float(), present_v.float()


class LmHead(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states):
        hidden_states = transformer.ln_f(hidden_states)
        m_logits = origin_model.lm_head(hidden_states)
        _, token = torch.topk(m_logits.float(), 1)
        return token


def convert_block(layer_id):
    # input: a full SEQ_LENGTH prompt (prefill graph), so hidden states,
    # position ids and the causal mask all cover SEQ_LENGTH positions
    hidden_states = torch.randn(
        (1, SEQ_LENGTH, HIDDEN_SIZE)).bfloat16().to(device)
    position_ids = torch.tensor(
        [range(SEQ_LENGTH)], dtype=torch.long).to(device)
    attention_mask = torch.randn(
        (1, 1, SEQ_LENGTH, SEQ_LENGTH)).bfloat16().to(device)
    model = QwenBlock(layer_id)
    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask),
        f'{folder}/qwen_block_{layer_id}.onnx',
        verbose=False,
        input_names=['input_states', 'position_ids', 'attention_mask'],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


def convert_block_cache(layer_id):
    # input: a single new token (decode graph) attending to a SEQ_LENGTH-long
    # KV cache, so the mask covers SEQ_LENGTH + 1 positions
    hidden_states = torch.randn((1, 1, HIDDEN_SIZE)).bfloat16().to(device)
    position_ids = torch.tensor([range(1)], dtype=torch.long).to(device)
    attention_mask = torch.ones(
        (1, 1, 1, SEQ_LENGTH + 1)).bfloat16().to(device)
    past_k = torch.randn(
        (1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)).bfloat16().to(device)
    past_v = torch.randn(
        (1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)).bfloat16().to(device)
    model = QwenBlockCache(layer_id)

    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask, past_k, past_v),
        f'{folder}/qwen_block_cache_{layer_id}.onnx',
        verbose=False,
        input_names=[
            'input_states', 'position_ids', 'attention_mask', 'history_k',
            'history_v'
        ],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


def convert_embedding():
    model = Embedding()
    input = torch.tensor([range(SEQ_LENGTH)]).to(device)
    torch.onnx.export(model, (input),
                      f'{folder}/embedding.onnx',
                      verbose=False,
                      input_names=['input_ids'],
                      output_names=['input_embed'],
                      do_constant_folding=True,
                      opset_version=15)


def convert_lm_head():
    model = LmHead()
    input = torch.randn(1, HIDDEN_SIZE).bfloat16().to(device)
    torch.onnx.export(model, (input),
                      f'{folder}/lm_head.onnx',
                      verbose=False,
                      input_names=['hidden_states'],
                      output_names=['token'],
                      do_constant_folding=True,
                      opset_version=15)

# create folder to store onnx
if not os.path.exists(folder):
    os.makedirs(folder)

# export models
for i in range(NUM_LAYERS):
    print("convert_block_{}".format(i))
    convert_block_cache(i)
    convert_block(i)

print("convert_embedding")
convert_embedding()

print("convert_lm_head")
convert_lm_head()
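A hypothetical end-to-end run of the two files in this commit (the file names export_onnx.py and compile.sh are assumptions, since the diff does not show them, and the model path is a placeholder):

    python3 export_onnx.py --model_path /path/to/Qwen-7B-Chat
    ./compile.sh --name qwen-7b --mode int8 --num_device 1
    # optional sanity check of one exported graph's I/O names
    # (assumes the onnx Python package is installed):
    python3 -c "import onnx; m = onnx.load('tmp/onnx/qwen_block_0.onnx'); print([i.name for i in m.graph.input], [o.name for o in m.graph.output])"

Note that the exporter writes qwen_block_$i.onnx and qwen_block_cache_$i.onnx into ./tmp/onnx, while the compile script reads block_$i.onnx and block_cache_$i.onnx, so the block file names (or one of the scripts) need to be aligned before compiling the blocks.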