rag_inf_api.py

from typing import AsyncGenerator, List, Tuple

import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from search import create_text_searcher

# System prompt (Japanese): "You are a sincere and excellent Japanese assistant."
DEFAULT_SYSTEM_PROMPT = 'あなたは誠実で優秀な日本人のアシスタントです。'

# QA prompt template (Japanese). It tells the model to answer using only the retrieved
# reference passages ({contexts}), to avoid guesses or general knowledge, and to say so
# when no answer is found, before presenting the user question ({question}).
DEFAULT_QA_PROMPT = """
## Instruction
参考情報を元に、質問に回答してください。
回答は参考情報だけを元に作成し、推測や一般的な知識を含めないでください。参考情報に答えが見つからなかった場合は、その旨を述べてください。
## 参考情報
{contexts}
## 質問
{question}
""".strip()

LLM_MODEL_ID = 'elyza/ELYZA-japanese-Llama-2-13b-instruct'


class InferenceEngine:
    """Retrieval-augmented inference: search for contexts, build a prompt, generate with vLLM."""

    def __init__(self) -> None:
        if not torch.cuda.is_available():
            raise EnvironmentError('A CUDA-capable environment is required.')
        self.llm_engine = self.init_llm_engine()
        self.searcher = create_text_searcher()
        # The tokenizer is loaded alongside the engine but is not used elsewhere in this module.
        self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)

    def init_llm_engine(self) -> AsyncLLMEngine:
        engine_args = AsyncEngineArgs(
            model=LLM_MODEL_ID,
            dtype='bfloat16',
            tensor_parallel_size=4,
            disable_log_requests=True,
            disable_log_stats=True,
            gpu_memory_utilization=0.6,
        )
        return AsyncLLMEngine.from_engine_args(engine_args)

    def get_prompt(self, question: str, contexts: List[str], system_prompt: str) -> str:
        # Build a Llama-2 chat-style prompt: the system prompt inside <<SYS>> tags,
        # followed by the QA template filled with the retrieved contexts and the question.
        texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
        message = DEFAULT_QA_PROMPT.format(contexts=self.searcher.to_contexts(contexts), question=question)
        texts.append(f'{message} [/INST]')
        return ''.join(texts)
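
    # Illustration (added note, not in the original file): with the template above,
    # get_prompt() yields a Llama-2 chat-format string shaped roughly like:
    #
    #   <s>[INST] <<SYS>>
    #   {system_prompt}
    #   <</SYS>>
    #
    #   ## Instruction
    #   ...(QA instructions)...
    #   ## 参考情報
    #   {output of self.searcher.to_contexts(contexts)}
    #   ## 質問
    #   {question} [/INST]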

    async def run(
        self,
        question: str,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50,
        do_sample: bool = False,
        repetition_penalty: float = 1.2,
        stream: bool = True,
    ) -> AsyncGenerator | str:
        # Retrieve supporting contexts, build the RAG prompt, then generate with vLLM.
        contexts, scores = await self.search_contexts(question)
        prompt = self.get_prompt(question, contexts, system_prompt)

        # Greedy decoding when sampling is disabled.
        if not do_sample:
            temperature = 0
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        request_id = random_uuid()
        results_generator = self.llm_engine.generate(prompt, sampling_params, request_id)

        async def stream_results() -> AsyncGenerator:
            # Each request_output carries the cumulative text generated so far.
            async for request_output in results_generator:
                yield ''.join([output.text for output in request_output.outputs])

        if stream:
            return stream_results()
        else:
            # Drain the generator and return only the final output.
            async for request_output in results_generator:
                pass
            return ''.join([output.text for output in request_output.outputs])

    async def search_contexts(self, question: str) -> Tuple[List[str], List[float]]:
        search_results = self.searcher.search(
            question,
            top_k=5,
        )
        # Each search result unpacks to a (score, context_text) pair.
        scores, contexts = zip(*search_results)
        return contexts, scores
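
# Assumed interface of the object returned by create_text_searcher(), inferred from
# its use above (search.py is not shown here):
#   search(query, top_k)   -> iterable of (score, context_text) pairs
#   to_contexts(contexts)  -> str substituted for {contexts} in DEFAULT_QA_PROMPT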


class QuestionRequest(BaseModel):
    """Request body for the /question endpoint; defaults mirror InferenceEngine.run()."""

    question: str
    max_new_tokens: int = 1024
    temperature: float = 0.8
    top_p: float = 0.95
    top_k: int = 50
    do_sample: bool = False
    repetition_penalty: float = 1.2


app = FastAPI()
inferenceEngine = InferenceEngine()
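
# Launch sketch (assumption, not part of the original file): serve this module with an
# ASGI server, e.g. `uvicorn rag_inf_api:app --host 0.0.0.0 --port 8000`.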


@app.post("/question")
async def instruct(body: QuestionRequest):
    # Stream the generated answer back to the client as a server-sent event stream.
    async def generate():
        async for item in await inferenceEngine.run(
            question=body.question,
            max_new_tokens=body.max_new_tokens,
            temperature=body.temperature,
            top_p=body.top_p,
            top_k=body.top_k,
            do_sample=body.do_sample,
            repetition_penalty=body.repetition_penalty,
            stream=True,
        ):
            yield item

    return StreamingResponse(generate(), media_type="text/event-stream")
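
# --- Client usage sketch (illustrative, not part of the original module) ---
# Once the API is running, a minimal streaming client could look like this, assuming
# the `httpx` package is available (any HTTP client that supports streamed responses
# would work). Note that each streamed item is the cumulative generation so far
# (see stream_results above).
#
#   import asyncio
#   import httpx
#
#   async def ask(question: str) -> None:
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream(
#               'POST',
#               'http://localhost:8000/question',
#               json={'question': question},
#           ) as response:
#               async for chunk in response.aiter_text():
#                   print(chunk, flush=True)
#
#   asyncio.run(ask('...'))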