Commit

GLM4 9B chat adaptation (#51)
* Update download_model.py

* Update requirements.txt

* Update test_models.py

* glm4

* Update requirements.txt

* Create glm4_predictor.py

* Update glm4_predictor.py

* Update glm4_predictor.py

* Update app.py

* Update requirements.txt

* Update glm4_predictor.py

* Update requirements.txt

* Update requirements.txt

* Update glm4_predictor.py

* Update app.py

* Update base.py

* predict_mode dict

* predict_mode tuple
ypwhs authored Jul 6, 2024
1 parent 9f30f51 commit 8e77164
Showing 14 changed files with 2,037 additions and 40 deletions.
app.py (70 changes: 53 additions & 17 deletions)
@@ -8,9 +8,13 @@
print('Done'.center(64, '-'))

# 加载模型
model_name = 'THUDM/chatglm3-6b'
model_name = 'THUDM/glm-4-9b-chat-1m'
int4 = True

if 'chatglm3' in model_name.lower():
if 'glm-4' in model_name.lower():
from predictors.glm4_predictor import GLM4
predictor = GLM4(model_name, int4=int4)
elif 'chatglm3' in model_name.lower():
from predictors.chatglm3_predictor import ChatGLM3
predictor = ChatGLM3(model_name)
elif 'chatglm2' in model_name.lower():
@@ -59,8 +63,9 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
if len(query) == 0:
print("Please input a query first.")
return
for x in predictor.predict_continue(query, continue_message, max_length, top_p,
temperature, allow_generate, history, last_state):
for x in predictor.predict_continue(query, continue_message, max_length,
top_p, temperature, allow_generate,
history, last_state):
yield x


@@ -69,8 +74,7 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
width: inherit !important;
padding-left: 20px !important;
}""") as demo:
gr.Markdown(
f"""
gr.Markdown(f"""
# 💡Creative ChatGLM WebUI
👋 欢迎来到 ChatGLM 创意世界![https://github.com/ypwhs/CreativeChatGLM](https://github.com/ypwhs/CreativeChatGLM)
@@ -82,14 +86,34 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
""")
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(elem_id="chat-box", show_label=False, height=850)
chatbot = gr.Chatbot(
elem_id="chat-box", show_label=False, height=850)
with gr.Column(scale=1):
with gr.Row():
max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0.01, 5, value=0.95, step=0.01, label="Temperature", interactive=True)
max_length = gr.Slider(
32,
4096,
value=2048,
step=1.0,
label="Maximum length",
interactive=True)
top_p = gr.Slider(
0.01,
1,
value=0.7,
step=0.01,
label="Top P",
interactive=True)
temperature = gr.Slider(
0.01,
5,
value=0.95,
step=0.01,
label="Temperature",
interactive=True)
with gr.Row():
query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4)
query = gr.Textbox(
show_label=False, placeholder="Prompts", lines=4)
generate_button = gr.Button("生成")
with gr.Row():
continue_message = gr.Textbox(
@@ -108,17 +132,29 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
last_state = gr.State([[], '', '']) # history, query, continue_message
generate_button.click(
predictor.predict_continue,
inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history, last_state],
inputs=[
query, blank_input, max_length, top_p, temperature, allow_generate,
history, last_state
],
outputs=[chatbot, query])
revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message])
revise_btn.click(
revise,
inputs=[history, revise_message],
outputs=[chatbot, revise_message])
revoke_btn.click(revoke, inputs=[history, last_state], outputs=[chatbot])
continue_btn.click(
predictor.predict_continue,
inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state],
inputs=[
query, continue_message, max_length, top_p, temperature,
allow_generate, history, last_state
],
outputs=[chatbot, query, continue_message])
regenerate_btn.click(
regenerate,
inputs=[last_state, max_length, top_p, temperature, allow_generate],
outputs=[chatbot, query, continue_message])
regenerate_btn.click(regenerate, inputs=[last_state, max_length, top_p, temperature, allow_generate],
outputs=[chatbot, query, continue_message])
interrupt_btn.click(interrupt, inputs=[allow_generate])

demo.queue().launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
demo.queue().launch(
server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
demo.close()
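
The app.py change above only shows how the new predictor is selected; the predictors/glm4_predictor.py module it imports is among the changed files not included in this excerpt. Below is a minimal, hypothetical sketch of what such a predictor could look like: the GLM4(model_name, int4=int4) constructor signature is taken from the app.py diff, while the stream_chat helper, the bitsandbytes-based 4-bit loading, and the chat-template call are assumptions rather than the repository's actual code (which also has to provide the predict_continue method that app.py calls).

```python
# Hypothetical sketch only: predictors/glm4_predictor.py added by this commit is
# not shown in this excerpt, and the repository's implementation may differ.
from threading import Thread

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, TextIteratorStreamer)


class GLM4:
    """Streaming predictor for THUDM/glm-4-9b-chat style checkpoints."""

    def __init__(self, model_name: str, int4: bool = False):
        # int4=True loads the weights with 4-bit quantization (assumed to use
        # bitsandbytes here; the actual commit may quantize differently).
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True)
        quant_config = BitsAndBytesConfig(load_in_4bit=True) if int4 else None
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map='auto',
            trust_remote_code=True,
            quantization_config=quant_config).eval()

    def stream_chat(self, query: str, max_new_tokens: int = 2048,
                    top_p: float = 0.7, temperature: float = 0.95):
        # Run generation in a background thread and yield the growing response,
        # which is what the Gradio UI needs for incremental updates.
        input_ids = self.tokenizer.apply_chat_template(
            [{'role': 'user', 'content': query}],
            add_generation_prompt=True,
            return_tensors='pt').to(self.model.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        Thread(
            target=self.model.generate,
            kwargs=dict(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                streamer=streamer)).start()
        response = ''
        for new_text in streamer:
            response += new_text
            yield response
```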
download_model.py (17 changes: 9 additions & 8 deletions)
@@ -7,15 +7,17 @@

model_name_list = [
# 'THUDM/chatglm-6b-int4-qe',
'THUDM/chatglm-6b-int4',
'THUDM/chatglm-6b',
# 'THUDM/chatglm-6b-int4',
# 'THUDM/chatglm-6b',
# 'THUDM/glm-10b-chinese',
#
# 'THUDM/chatglm2-6b',
# 'THUDM/chatglm2-6b-int4',
#
# 'THUDM/chatglm3-6b',
# 'THUDM/chatglm3-6b-128k',

'THUDM/chatglm2-6b',
'THUDM/chatglm2-6b-int4',

'THUDM/chatglm3-6b',
'THUDM/chatglm3-6b-128k',
'THUDM/glm-4-9b-chat-1m',

# 'silver/chatglm-6b-slim',
# 'silver/chatglm-6b-int4-slim',
@@ -33,7 +35,6 @@
print(f'Downloading {model_name}')
snapshot_download(
repo_id=model_name,
resume_download=True,
max_workers=2,
# proxies={'https': 'http://127.0.0.1:7890'}
)
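
Besides switching the model list to the newer checkpoints, the change drops resume_download=True from the snapshot_download call; recent huggingface_hub releases deprecate that flag because interrupted downloads resume by default, which is presumably why it was removed. A minimal usage sketch of the resulting call (the printed path is illustrative):

```python
# Minimal sketch mirroring the simplified download call in the diff above,
# with an illustrative print added for clarity.
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id='THUDM/glm-4-9b-chat-1m',
    max_workers=2,  # limit concurrent file downloads
    # proxies={'https': 'http://127.0.0.1:7890'},  # optional proxy, as in the script
)
print(f'Downloaded to {local_path}')
```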
glm4/configuration_chatglm.py (58 changes: 58 additions & 0 deletions)
@@ -0,0 +1,58 @@
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"

def __init__(
self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
classifier_dropout=None,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
rope_ratio=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
**kwargs
):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.classifier_dropout = classifier_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.rope_ratio = rope_ratio
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
super().__init__(**kwargs)
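
The configuration class above mirrors THUDM's remote-code ChatGLMConfig; the keyword defaults are generic ChatGLM values, and the actual hyperparameters for glm-4-9b-chat come from the config.json shipped with the checkpoint. A small, hypothetical usage sketch (the import path assumes the glm4/ directory is importable as a package):

```python
# Illustrative only: the checkpoint's config.json values override these defaults
# when the model is loaded through transformers with trust_remote_code=True.
from glm4.configuration_chatglm import ChatGLMConfig

config = ChatGLMConfig(seq_length=8192, multi_query_attention=True)
print(config.model_type)             # "chatglm"
print(config.hidden_size)            # 4096 (class default)
print(config.multi_query_attention)  # True (overridden above)

# Loading the checkpoint's own config instead:
# from transformers import AutoConfig
# config = AutoConfig.from_pretrained('THUDM/glm-4-9b-chat-1m', trust_remote_code=True)
```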
