This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[Model Enabling] Support ChatGLM3 #182

Merged: 15 commits, Mar 21, 2024
81 changes: 81 additions & 0 deletions docs/prompt_template.md
@@ -0,0 +1,81 @@
# Prompt template

This document shows examples of how to use prompt templates correctly in Neural Speed and [ITREX](https://github.com/intel/intel-extension-for-transformers).

For a base model (pre-trained only, without SFT), the prompt can be encoded directly into token ids without adding any special prefix or suffix tokens. For a chat model, however, a prompt template is required to generate correct and human-readable responses, because these models are usually fine-tuned with specific prompt templates.
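For example, here is a minimal sketch of the base-model case, using the same ITREX weight-only quantization API as the chat examples below (the model id and prompt are illustrative assumptions; any supported base model works the same way):

```python
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

# A base (non-chat) model: the raw prompt is tokenized as-is, no chat template needed.
model_name = "EleutherAI/gpt-j-6b"  # illustrative; or a local path to any supported base model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

prompt = "Once upon a time, there existed a little girl,"
inputs = tokenizer(prompt, return_tensors="pt").input_ids  # no special prefix/suffix tokens added
outputs = model.generate(inputs, max_new_tokens=300, do_sample=True)
print(tokenizer.decode(outputs[0]))
```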

## Chat with ChatGLM3:
```python
from transformers import AutoTokenizer
from neural_speed import Model

# args.model_path: Hugging Face model id (e.g. "THUDM/chatglm3-6b") or a local model path
# args.model_name: the model type name passed to Neural Speed
# gguf_path: path to the converted model binary
prompt = "你好"  # "Hello"
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
inputs = tokenizer.build_chat_input(prompt)['input_ids']
model = Model()
model.init_from_bin(args.model_name, gguf_path)
outputs = model.generate(inputs, max_new_tokens=300, do_sample=True)
words = tokenizer.decode(outputs[0])
```

## Chat with LLaMA2:

```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

# Please change to a local model path; llama2 does not currently support online conversion.
model_name = "meta-llama/Llama-2-7b-chat-hf"
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    b_prompt = "[INST]{}[/INST]".format(prompt)  # prompt template for llama2
    inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True)
```

## Chat with ChatGLM2:
```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "THUDM/chatglm2-6b" # or local path to model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    prompt = tokenizer.build_prompt(prompt)  # prompt template for chatglm2
    inputs = tokenizer([prompt], return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True, n_keep=2)
```

## Chat with Qwen:
```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "Qwen/Qwen-7B-Chat" # or local path to model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    prompt = "\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".format(prompt)  # prompt template for qwen
    inputs = tokenizer([prompt], return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True)
```
3 changes: 2 additions & 1 deletion docs/supported_models.md
@@ -219,7 +219,8 @@ Neural Speed supports the following models:
</tr>
<tr>
<td><a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank" rel="noopener noreferrer">ChatGLM-6B</a>,
<a href="https://huggingface.co/THUDM/chatglm2-6b" target="_blank" rel="noopener noreferrer">ChatGLM2-6B</a></td>
<a href="https://huggingface.co/THUDM/chatglm2-6b" target="_blank" rel="noopener noreferrer">ChatGLM2-6B</a>,
<a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a></td>
<td>✅</td>
<td> </td>
<td> </td>
10 changes: 8 additions & 2 deletions neural_speed/__init__.py
@@ -24,6 +24,7 @@


class Model:

    def __init__(self):
        self.module = None
        self.model = None
@@ -55,7 +56,7 @@ def __import_package(self, model_type):
            import neural_speed.bloom_cpp as cpp_model
        elif model_type == "chatglm":
            import neural_speed.chatglm_cpp as cpp_model
        elif model_type == "chatglm2":
        elif model_type == "chatglm2" or model_type == "chatglm3":
            import neural_speed.chatglm2_cpp as cpp_model
        elif model_type == "baichuan":
            import neural_speed.baichuan_cpp as cpp_model
@@ -85,6 +86,11 @@ def get_model_type(model_config):
        if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
            model_type = "chatglm2"

        # ChatGLM3 shares the ChatGLM2 model architecture, so reuse the chatglm2 implementation.
        if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
            model_type = "chatglm2"

        # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
        if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
            model_type = "falcon"
@@ -200,7 +206,7 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):

        def get_max_seq_length():
            config = self.config.to_dict()
            # chatglm2, bloom
            # chatglm2, bloom, chatglm3
            if 'seq_length' in config:
                return config['seq_length']
            # qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi
4 changes: 3 additions & 1 deletion neural_speed/application/CMakeLists.txt
@@ -65,6 +65,7 @@ compile_quant(quant_bloom quant_model.cpp bloom bloom)

compile_quant(quant_chatglm quant_model.cpp chatglm chatglm)
compile_quant(quant_chatglm2 quant_model.cpp chatglm2 chatglm2)
compile_quant(quant_chatglm3 quant_model.cpp chatglm2 chatglm2)
compile_quant(quant_baichuan quant_model.cpp baichuan baichuan)
compile_quant(quant_mistral quant_model.cpp mistral llama)
compile_quant(quant_mixtral quant_model.cpp mixtral llama)
@@ -97,7 +98,7 @@ set(mymap_phi 16)
set(mymap_stablelm 17)
set(mymap_whisper 18)
set(mymap_mixtral 19)

set(mymap_chatglm3 20)


function(compile_run TARGET MAIN_CPP MAIN_PY MODEL_NAME MODEL_LIB)
@@ -128,6 +129,7 @@ compile_run(run_starcoder main_run.cpp main_pybind.cpp starcoder starcoder)
compile_run(run_opt main_run.cpp main_pybind.cpp opt opt)
compile_run(run_bloom main_run.cpp main_pybind.cpp bloom bloom)
compile_run(run_chatglm2 main_run.cpp main_pybind.cpp chatglm2 chatglm2)
compile_run(run_chatglm3 main_run.cpp main_pybind.cpp chatglm3 chatglm3)
compile_run(run_chatglm main_run.cpp main_pybind.cpp chatglm chatglm)
compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan)
compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama)
4 changes: 4 additions & 0 deletions neural_speed/application/main_pybind.cpp
@@ -921,6 +921,10 @@ PYBIND11_MODULE(whisper_cpp, m)

PYBIND11_MODULE(mixtral_cpp, m)

#elif MODEL_NAME_ID == 20

PYBIND11_MODULE(chatglm3_cpp, m)

#endif
{
m.doc() = "cpp model python binding";
5 changes: 3 additions & 2 deletions neural_speed/application/main_run.cpp
@@ -240,7 +240,8 @@ int main(int argc, char** argv) { // NOLINT
std::string prompt = build_prompt_glm2(prompts);
embd_inp = ::model_tokenize(ctx, prompt, false);
embd_inp.insert(embd_inp.begin(), {64790, 64792}); // special prefix
} else if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_BAICHUAN) {
} else if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_BAICHUAN ||
params.model_arch == MODEL_CHATGLM3) {
for (auto& i : params.ids) {
embd_inp.emplace_back(i);
}
@@ -646,7 +647,7 @@

// display text
if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_CHATGLM2 ||
params.model_arch == MODEL_BAICHUAN) {
params.model_arch == MODEL_BAICHUAN || params.model_arch == MODEL_CHATGLM3) {
static bool is_prompt = true;
if (input_echo) {
if (is_prompt == true) {
49 changes: 49 additions & 0 deletions neural_speed/convert/common.py
@@ -90,7 +90,56 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
    return tensor


def quantize_q8_0(tensor: torch.Tensor) -> torch.Tensor:
    # equivalent to ggml_quantize_q8_0 in ggml.c
    assert tensor.shape[1] % GGML_QK8_0 == 0
    tensor = tensor.view(-1, GGML_QK8_0)
    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
    tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
    # add scale into each block
    tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
    return tensor


def quantize_q5_0(tensor: torch.Tensor) -> torch.Tensor:
    # equivalent to ggml_quantize_q5_0 in ggml.c
    assert tensor.shape[1] % GGML_QK5_0 == 0
    tensor = tensor.view(-1, GGML_QK5_0)
    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
    scale = max_values / -16
    tensor = (tensor / scale + 16).round().clamp(min=0, max=31).char()
    qs = (tensor[:, :16] & 0x0F) | (tensor[:, 16:] << 4)
    qh = torch.zeros(tensor.shape[:-1], dtype=torch.int32)
    for i in range(32):
        qh |= ((tensor[:, i] & 0x10) >> 4).int() << i

    # add scale into each block
    tensor = torch.cat((scale.half().view(torch.int8), qh[..., None].view(torch.int8), qs), dim=-1)
    return tensor


def quantize_q5_1(tensor: torch.Tensor) -> torch.Tensor:
    # equivalent to ggml_quantize_q5_1 in ggml.c
    assert tensor.shape[1] % GGML_QK5_1 == 0
    tensor = tensor.view(-1, GGML_QK5_1)
    min_vals = tensor.min(dim=-1, keepdim=True).values
    max_vals = tensor.max(dim=-1, keepdim=True).values
    scale = (max_vals - min_vals) / ((1 << 5) - 1)
    tensor = ((tensor - min_vals) / scale).round().clamp(min=0, max=31).char()
    qs = (tensor[:, :16] & 0x0F) | (tensor[:, 16:] << 4)
    qh = torch.zeros(tensor.shape[:-1], dtype=torch.int32)
    for i in range(32):
        qh |= ((tensor[:, i] & 0x10) >> 4).int() << i

    # add scale & min into each block
    tensor = torch.cat(
        (scale.half().view(torch.int8), min_vals.half().view(torch.int8), qh[..., None].view(torch.int8), qs), dim=-1)
    return tensor


class SentencePieceVocab:

    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
        added_tokens: Dict[str, int]
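As a quick sanity check of the Q8_0 layout produced by `quantize_q8_0` above, here is a minimal sketch (it assumes ggml's block size of 32 for `GGML_QK8_0` and that `neural_speed` is installed with its conversion dependencies; each output row then packs one fp16 scale, viewed as two int8 values, followed by 32 int8 quants):

```python
import torch

from neural_speed.convert.common import quantize_q8_0  # function added in this PR

weights = torch.randn(4, 128)      # 4 rows of 128 values -> 4 * (128 / 32) = 16 blocks
packed = quantize_q8_0(weights)
print(packed.shape, packed.dtype)  # expected: torch.Size([16, 34]) torch.int8 -> 34 bytes per block
```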