-
Notifications
You must be signed in to change notification settings - Fork 31
/
test_models.py
56 lines (48 loc) · 1.75 KB
/
test_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
# Route Hugging Face Hub traffic through the hf-mirror.com mirror
# (set before any predictor/transformers import so downloads pick it up).
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
def test_model(model_name):
    """Smoke-test one model: select its predictor class and stream a reply.

    The predictor is chosen by substring match on the lowercased model name.
    Branch order matters: 'chatglm3'/'chatglm2' must be tested before the
    plain 'chatglm' catch-all. Unrecognized names fall back to ChatGLM.
    """
    lowered = model_name.lower()
    if 'glm-4' in lowered:
        from predictors.glm4_predictor import GLM4
        predictor = GLM4(model_name)
    elif 'chatglm3' in lowered:
        from predictors.chatglm3_predictor import ChatGLM3
        predictor = ChatGLM3(model_name)
    elif 'chatglm2' in lowered:
        from predictors.chatglm2_predictor import ChatGLM2
        predictor = ChatGLM2(model_name)
    elif 'chatglm' in lowered:
        from predictors.chatglm_predictor import ChatGLM
        predictor = ChatGLM(model_name)
    elif 'gptq' in lowered:
        from predictors.llama_gptq import LLaMaGPTQ
        predictor = LLaMaGPTQ(model_name)
    elif 'llama' in lowered:
        from predictors.llama import LLaMa
        predictor = LLaMa(model_name)
    elif 'debug' in lowered:
        from predictors.debug import Debug
        predictor = Debug(model_name)
    else:
        # Default predictor for anything not matched above.
        from predictors.chatglm_predictor import ChatGLM
        predictor = ChatGLM(model_name)

    # Single-turn prompt with near-greedy sampling settings.
    query = '你是谁?'
    previous_reply = '我是张三丰,我是武当派'
    print(query)
    stream = predictor.predict_continue(
        query=query,
        latest_message=previous_reply,
        max_length=128,
        top_p=0.01,
        temperature=0.01,
        allow_generate=[True],
        history=[],
        last_state=[[], None, None],
    )
    for state in stream:
        # Each yielded state holds the conversation; print the newest reply text.
        print(state[0][-1][1])
def main():
    """Run the smoke test for every model in the list."""
    models = (
        'THUDM/glm-4-9b-chat-1m',
    )
    for name in models:
        print(f'Testing {name}')
        test_model(name)


if __name__ == '__main__':
    main()