-
Notifications
You must be signed in to change notification settings - Fork 4
/
server.py
40 lines (32 loc) · 1.02 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import sys
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import pipeline
# Flask application object; every route below is registered on it.
app = Flask(import_name=__name__)
# Attach flask_cors with its default settings (no origin restrictions are
# configured here).
CORS(app=app)
def set_up_gen_pipeline(model_path):
    """Create the text-generation pipeline and store it in ``gen_pipeline``.

    Args:
        model_path: model identifier or local path handed to
            ``transformers.pipeline`` as ``model``.

    The pipeline is built with the PyTorch backend (``framework='pt'``) and
    published as the module-level global ``gen_pipeline``, which the
    ``/autocomplete`` handler reads.
    """
    global gen_pipeline
    gen_pipeline = pipeline(
        task='text-generation',
        model=model_path,
        framework='pt',
    )
@app.route("/")
def hello():
    """Serve the root URL with the fixed JSON object ``{"hello": "world!"}``."""
    return jsonify(hello="world!")
@app.route("/autocomplete")
def prompt():
    """Return sampled continuations of the ``context`` query parameter.

    Query parameters:
        context: prompt text to continue (defaults to an empty string).

    Returns:
        JSON ``{"outputs": [...]}`` containing the pipeline's raw result list
        (``num_return_sequences=3`` sampled sequences), or a 503 JSON error if
        ``set_up_gen_pipeline`` has not been called yet.
    """
    # ``gen_pipeline`` is a global that only exists once set_up_gen_pipeline()
    # has run (the __main__ path does this). If the app is served some other
    # way, answer with an explicit 503 instead of crashing with NameError.
    generator = globals().get('gen_pipeline')
    if generator is None:
        return jsonify({"error": "text-generation pipeline is not initialized"}), 503

    context = request.args.get('context', default='', type=str)
    print(f'context = {context}')
    outputs = generator(
        context,
        # Per the Transformers generation docs, max_length counts the prompt
        # tokens as well as the newly generated ones.
        max_length=200,
        num_return_sequences=3,
        do_sample=True,
        # NOTE(review): hard-coded special-token ids; presumably these match
        # the tokenizer of the model being served — confirm.
        eos_token_id=2,
        pad_token_id=0,
        # NOTE(review): verify this kwarg is accepted by the installed
        # transformers version rather than forwarded to generate() as unused.
        skip_special_tokens=True,
        top_k=50,
        top_p=0.95,
    )
    print(f'outputs = {outputs}')
    return jsonify({
        "outputs": outputs
    })
if __name__ == '__main__':
    # Usage: python server.py <model path>
    if len(sys.argv) < 2:
        # Report the usage error on stderr and exit non-zero so shells and
        # scripts can tell the server did not start (the old exit(0) signalled
        # success). sys.exit is used because the bare `exit` helper is only
        # injected by the `site` module and is not guaranteed to exist.
        print("Missing required argument: model path.", file=sys.stderr)
        sys.exit(1)
    set_up_gen_pipeline(sys.argv[1])
    # debug=True turns on Werkzeug's reloader by default. The reloader
    # re-executes this script in a child process, which would load the model
    # a second time while this parent process keeps its own idle copy.
    # Keep the debugger but disable the reloader.
    app.run(host="localhost", port=5000, debug=True, use_reloader=False)