[Qwen] add compile & demo
chuxiaoyi2023 committed Feb 8, 2024
1 parent ea88e9d commit d736824
Showing 4 changed files with 848 additions and 0 deletions.
185 changes: 185 additions & 0 deletions models/Qwen/compile/compile.sh
@@ -0,0 +1,185 @@
#!/bin/bash
set -ex
models=
mode="int8"
folder="tmp"
num_device=1
mode_args=""
device_args=""
quantize_args="--quantize W8BF16"
name=""
num_layers=
out_model=$name.bmodel

while [[ $# -gt 0 ]]; do
    key="$1"

    case $key in
    --mode)
        mode="$2"
        shift 2
        ;;
    --num_device)
        num_device="$2"
        shift 2
        ;;
    --name)
        name="$2"
        shift 2
        ;;
    *)
        echo "Invalid option: $key" >&2
        exit 1
        ;;
    esac
done

if [ "$name" = "qwen-1_8b" ]; then
num_layers=23
echo "Compile Qwen-1_8B"
elif [ "$name" = "qwen-7b" ]; then
num_layers=31
echo "Compile Qwen-7B"
elif [ "$name" = "qwen-14b" ]; then
num_layers=39
echo "Compile Qwen-14B"
else
>&2 echo -e "Error: Invalid name $name, the input name must be \033[31mqwen-1_8b|qwen-7b|qwen-14b\033[0m"
exit 1
fi

if [ x$mode == x"int8" ]; then
    quantize_args="--quantize W8BF16"
elif [ x$mode == x"bf16" ]; then
    quantize_args="--quantize BF16"
elif [ x$mode == x"int4" ]; then
    quantize_args="--quantize W4BF16 --q_group_size 64"
else
    echo "Error, unknown quantize mode"
    exit 1
fi

if [ x$num_device != x1 ]; then
    device_args="--num_device $num_device"
    out_model=$name'_'$mode'_'$num_device'dev.bmodel'
else
    out_model=$name'_'$mode'_1dev.bmodel'
fi

# Compile the embedding graph twice: a full-sequence (prefill) variant and a
# one-token (decode) variant, both in BF16.
outdir=${folder}/embedding
mkdir -p $outdir
pushd $outdir

model_transform.py \
    --model_name embedding \
    --model_def ../onnx/embedding.onnx \
    --mlir embedding.mlir

model_deploy.py \
    --mlir embedding.mlir \
    --quantize BF16 \
    --quant_input \
    --quant_output \
    --chip bm1684x \
    $device_args \
    --model embedding.bmodel

model_transform.py \
    --model_name embedding_cache \
    --model_def ../onnx/embedding.onnx \
    --input_shapes [[1,1]] \
    --mlir embedding_cache.mlir

model_deploy.py \
    --mlir embedding_cache.mlir \
    --quantize BF16 \
    --quant_input \
    --quant_output \
    --chip bm1684x \
    $device_args \
    --model embedding_cache.bmodel

rm *.npz

models=$models' '$outdir'/embedding.bmodel '$outdir'/embedding_cache.bmodel '

popd

echo $models

# Compile the lm_head graph with the selected weight quantization.
outdir=${folder}/$mode"_"$num_device"dev"/lm_head
mkdir -p $outdir
pushd $outdir

model_transform.py \
    --model_name lm_head \
    --model_def ../../onnx/lm_head.onnx \
    --mlir lm_head.mlir

model_deploy.py \
    --mlir lm_head.mlir \
    $quantize_args \
    --quant_input \
    --quant_output \
    --chip bm1684x \
    $device_args \
    --model lm_head.bmodel

rm *.npz

models=${models}${outdir}'/lm_head.bmodel '
popd

echo $models

# Compile every transformer block: a full-sequence (prefill) version and a
# KV-cached single-token (decode) version per layer.
outdir=${folder}/$mode"_"$num_device"dev"/block
mkdir -p $outdir
pushd $outdir

for ((i=0; i<=$num_layers; i++)); do

    model_transform.py \
        --model_name block_$i \
        --model_def ../../onnx/block_$i.onnx \
        --mlir block_$i.mlir

    model_deploy.py \
        --mlir block_$i.mlir \
        $quantize_args \
        --quant_input \
        --quant_output \
        --chip bm1684x \
        $device_args \
        --model block_$i.bmodel

    model_transform.py \
        --model_name block_cache_$i \
        --model_def ../../onnx/block_cache_$i.onnx \
        --mlir block_cache_$i.mlir

    model_deploy.py \
        --mlir block_cache_$i.mlir \
        $quantize_args \
        --quant_input \
        --quant_output \
        --chip bm1684x \
        $device_args \
        --io_alone \
        --model block_cache_$i.bmodel

    rm *.npz

    models=${models}${outdir}'/block_'$i'.bmodel '$outdir'/block_cache_'$i'.bmodel '

done
popd
echo $models

model_tool --combine $models -o $out_model
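
For orientation, a hedged usage sketch of compile.sh: the --mode, --name, and --num_device flags come from its argument parser above, while the working directory, the presence of the ONNX graphs under tmp/onnx, and the quantization choice are assumptions.

# Hypothetical invocation, run from models/Qwen/compile after the ONNX graphs
# have been placed in tmp/onnx with the names referenced above:
./compile.sh --mode int4 --name qwen-7b --num_device 1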

194 changes: 194 additions & 0 deletions models/Qwen/compile/export_onnx.py
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================

import os
import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(description='export onnx.')
parser.add_argument('--model_path', type=str, help='path to the torch model.')

args = parser.parse_args()

model_path = args.model_path
folder = f"./tmp/onnx"

device = torch.device("cuda:0")
origin_model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True,
    torch_dtype=torch.bfloat16, device_map="auto").eval()
config = origin_model.config
transformer = origin_model.transformer
layers = transformer.h

SEQ_LENGTH = config.seq_length
NUM_LAYERS = config.num_hidden_layers
HIDDEN_SIZE = config.hidden_size
NUM_ATTENTION_HEADS = config.num_attention_heads
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS

for param in origin_model.parameters():
    param.requires_grad = False

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

class Embedding(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input_ids):
        out = transformer.wte(input_ids)
        return out.float()


class QwenBlock(torch.nn.Module):
    # Full-sequence (prefill) wrapper around one transformer layer: applies the
    # precomputed rotary embeddings and returns the output plus its KV cache.

    def __init__(self, layer_id):
        super().__init__()
        # params
        self.layer_id = layer_id
        self.layer = layers[layer_id]
        self.rotary_emb = transformer.rotary_emb(SEQ_LENGTH)
        self.cos_emb = self.rotary_emb[0].view(SEQ_LENGTH, HEAD_DIM)
        self.sin_emb = self.rotary_emb[1].view(SEQ_LENGTH, HEAD_DIM)

    def forward(self, hidden_states, position_ids, attention_mask):
        cos_pos = self.cos_emb[position_ids].unsqueeze(2)
        sin_pos = self.sin_emb[position_ids].unsqueeze(2)
        hidden_states, past_kv = self.layer(
            hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb_list=[[cos_pos, sin_pos]],
            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states.float(), present_k.float(), present_v.float()


class QwenBlockCache(torch.nn.Module):
    # Single-token (decode) wrapper: feeds the previous KV cache in via
    # layer_past and returns the updated cache alongside the hidden states.

    def __init__(self, layer_id):
        super().__init__()
        # params
        self.layer_id = layer_id
        self.layer = layers[layer_id]
        self.rotary_emb = transformer.rotary_emb(SEQ_LENGTH)
        self.cos_emb = self.rotary_emb[0].view(SEQ_LENGTH, HEAD_DIM)
        self.sin_emb = self.rotary_emb[1].view(SEQ_LENGTH, HEAD_DIM)

    def forward(self, hidden_states, position_ids, attention_mask, past_k,
                past_v):
        cos_pos = self.cos_emb[position_ids].unsqueeze(2)
        sin_pos = self.sin_emb[position_ids].unsqueeze(2)
        hidden_states, past_kv = self.layer(
            hidden_states,
            layer_past=(past_k, past_v),
            attention_mask=attention_mask,
            rotary_pos_emb_list=[[cos_pos, sin_pos]],
            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states.float(), present_k.float(), present_v.float()


class LmHead(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states):
        hidden_states = transformer.ln_f(hidden_states)
        m_logits = origin_model.lm_head(hidden_states)
        _, token = torch.topk(m_logits.float(), 1)
        return token


def convert_block(layer_id):
    # input
    hidden_states = torch.randn(
        (1, SEQ_LENGTH, HIDDEN_SIZE)).bfloat16().to(device)
    position_ids = torch.tensor(
        [range(SEQ_LENGTH)], dtype=torch.long).to(device)
    attention_mask = torch.randn(
        (1, 1, SEQ_LENGTH, SEQ_LENGTH)).bfloat16().to(device)
    model = QwenBlock(layer_id)
    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask),
        f'{folder}/qwen_block_{layer_id}.onnx',
        verbose=False,
        input_names=['input_states', 'position_ids', 'attention_mask'],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


def convert_block_cache(layer_id):
    # input
    hidden_states = torch.randn((1, 1, HIDDEN_SIZE)).bfloat16().to(device)
    position_ids = torch.tensor([range(1)], dtype=torch.long).to(device)
    attention_mask = torch.ones(
        (1, 1, 1, SEQ_LENGTH + 1)).bfloat16().to(device)
    past_k = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)).bfloat16().to(device)
    past_v = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM)).bfloat16().to(device)
    model = QwenBlockCache(layer_id)

    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask, past_k, past_v),
        f'{folder}/qwen_block_cache_{layer_id}.onnx',
        verbose=False,
        input_names=[
            'input_states', 'position_ids', 'attention_mask', 'history_k',
            'history_v'
        ],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


def convert_embedding():
    model = Embedding()
    input = torch.tensor([range(SEQ_LENGTH)]).to(device)
    torch.onnx.export(model, (input),
                      f'{folder}/embedding.onnx',
                      verbose=False,
                      input_names=['input_ids'],
                      output_names=['input_embed'],
                      do_constant_folding=True,
                      opset_version=15)


def convert_lm_head():
    model = LmHead()
    input = torch.randn(1, HIDDEN_SIZE).bfloat16().to(device)
    torch.onnx.export(model, (input),
                      f'{folder}/lm_head.onnx',
                      verbose=False,
                      input_names=['hidden_states'],
                      output_names=['token'],
                      do_constant_folding=True,
                      opset_version=15)

# create folder to store onnx
if not os.path.exists(folder):
    os.makedirs(folder)

# export models
for i in range(NUM_LAYERS):
    print("convert_block_{}".format(i))
    convert_block_cache(i)
    convert_block(i)

print("convert_embedding")
convert_embedding()

print("convert_lm_head")
convert_lm_head()
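
Similarly, a hedged sketch of how this export script might be invoked: --model_path is the only flag it defines, and the checkpoint path shown is a placeholder, not a path taken from this commit.

# Hypothetical invocation; the exported ONNX files are written to ./tmp/onnx:
python3 export_onnx.py --model_path /path/to/Qwen-7B-Chat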

