import os
import sys
import time

import numpy as np
import torch

np.set_printoptions(precision=4, suppress=True, linewidth=200)

current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f'{current_path}/../rwkv_pip_package/src')

# set these before importing RWKV
os.environ['RWKV_JIT_ON'] = '1'
os.environ['RWKV_CUDA_ON'] = '1'  # '1' compiles the CUDA kernel (10x faster); requires a C++ compiler & CUDA libraries

from rwkv_numba.model import RWKV  # modified copy of the rwkv pip package, added to sys.path above
model = RWKV(model='model/RWKV-4-World-CHNtuned-3B-v1-20230625-ctx4096', strategy='cuda fp16i8')
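# 'cuda fp16i8' runs the model on the GPU with weights quantized to int8
# (fp16 activations), trading a little accuracy for a much smaller VRAM footprint.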
from rwkv_numba.utils import PIPELINE, PIPELINE_ARGS

#pipeline = PIPELINE(model, f"{current_path}/20B_tokenizer.json")  # tokenizer for the older Pile-trained models
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")  # vocabulary used by the RWKV World models
args = PIPELINE_ARGS(temperature=1.2, top_p=0.2, top_k=0,  # top_k = 0 disables top-k filtering
                     alpha_frequency=0.0,  # frequency penalty
                     alpha_presence=0.0,   # presence penalty
                     token_ban=[0],        # ban the generation of these tokens
                     token_stop=[],        # stop generation whenever any of these tokens appears
                     chunk_len=256)        # split input into chunks to save VRAM (shorter -> slower)
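# For illustration only (a minimal sketch, not part of the original script):
# how temperature and top_p reshape a logits vector before a token is sampled.
# The function is never called; it just documents the sampling parameters above.
def _sampling_demo(logits, temperature=1.2, top_p=0.2):
    logits = logits - logits.max()        # for numerical stability
    probs = np.exp(logits / temperature)  # temperature > 1 flattens the distribution
    probs /= probs.sum()
    # nucleus (top_p) sampling: keep only the smallest set of tokens whose
    # cumulative probability reaches top_p, then renormalize
    sorted_probs = np.sort(probs)[::-1]
    cutoff = sorted_probs[np.argmax(np.cumsum(sorted_probs) >= top_p)]
    probs[probs < cutoff] = 0
    probs /= probs.sum()
    return np.random.choice(len(probs), p=probs)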
########################################################################################################

#msg = "你好"  # "Hello"
def gen(msg):
    # rwkv_numba's pipeline.generate returns the generated text along with the model state
    answer, state = pipeline.generate(msg, token_count=500, args=args)
    return answer
# the same prompt repeated four times, used in the (commented-out) batch timing
# runs below ("布洛芬的作用" = "the effects of ibuprofen")
msg1 = ["Q: 布洛芬的作用\n\nA:",
        "Q: 布洛芬的作用\n\nA:",
        "Q: 布洛芬的作用\n\nA:",
        "Q: 布洛芬的作用\n\nA:",
        ]
msg2 = ["Q: 布洛芬的作用\n\nA:"]  # a single-prompt batch
#gen(msg)

# Earlier timing runs on the 4-prompt batch, kept commented out for reference.
# The first two omit torch.cuda.synchronize(), so their timings can under-count:
# CUDA work is asynchronous, and time.time() may be read before the GPU finishes.
# start = time.time()
# gen(msg1)
# end = time.time()
# print(end-start)
#
# start = time.time()
# gen(msg1)
# end = time.time()
# print(end-start)
#
# torch.cuda.synchronize()
# start = time.time()
# gen(msg1)
# torch.cuda.synchronize()
# end = time.time()
# print(end-start)
#
# start = time.time()
# gen(msg1)
# torch.cuda.synchronize()
# end = time.time()
# print(end-start)
# Time two consecutive single-prompt generations. The first run is typically
# slower because it includes one-time setup such as CUDA kernel warm-up.
start = time.time()
gen(msg2)
torch.cuda.synchronize()  # wait for all queued GPU work before reading the clock
end = time.time()
print(end - start)

start = time.time()
gen(msg2)
torch.cuda.synchronize()
end = time.time()
print(end - start)
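
# A possible extension (a sketch, not in the original): print the generated
# answers rather than only the elapsed times. This assumes gen() returns one
# string per prompt when given a list, as the batched msg1/msg2 inputs suggest.
# for answer in gen(msg2):
#     print(answer)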