forked from saxenaakansha30/documentor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rag.py
70 lines (58 loc) · 2.06 KB
/
rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from chunk_vector_store import ChunkVectorStore as cvs
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama, ChatOpenAI
from config import Config
class Rag:
    """Small retrieval-augmented-generation pipeline around a chat model.

    Typical flow: ``feed(path)`` stores a file in the vector database and
    builds the retriever and chain; ``ask(query)`` then runs the chain;
    ``clear()`` drops the stored context again.
    """

    # Class-level defaults so these attributes exist (as ``None``) on every
    # instance before ``feed`` has been called.
    vector_store = None
    retriever = None
    chain = None

    def __init__(self) -> None:
        """Create the chunk/vector-store helper, the prompt and the chat model.

        Raises:
            ValueError: If ``Config.LLM_PROVIDER`` is neither ``"ollama"``
                nor ``"openai"``.
        """
        self.csv_obj = cvs()
        # Adjacent literals keep the template text independent of the
        # indentation of this source file.
        self.prompt = PromptTemplate.from_template(
            "\n"
            "<s> [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context\n"
            "to answer the question. If you don't know the answer, just say that you don't know. Use three sentences\n"
            "maximum and keep the answer concise. [/INST] </s>\n"
            "[INST] Question: {question}\n"
            "Context: {context}\n"
            "Answer: [/INST]\n"
        )
        llm_config = Config.get_llm_config()
        if Config.LLM_PROVIDER == "ollama":
            self.model = ChatOllama(**llm_config)
        elif Config.LLM_PROVIDER == "openai":
            self.model = ChatOpenAI(**llm_config)
        else:
            raise ValueError(f"Unsupported LLM provider: {Config.LLM_PROVIDER}")

    def set_retriever(self, k: int = 3, score_threshold: float = 0.5) -> None:
        """Build ``self.retriever`` from the current vector store.

        Args:
            k: Value passed as ``search_kwargs["k"]``.
            score_threshold: Value passed as
                ``search_kwargs["score_threshold"]``.

        Raises:
            ValueError: If no vector store has been set yet (``feed`` was
                not called, or ``clear`` was called since).
        """
        if self.vector_store is None:
            raise ValueError("No vector store available; call feed() first")
        self.retriever = self.vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": k,
                "score_threshold": score_threshold,
            },
        )

    def augment(self) -> None:
        """Build ``self.chain``: retriever + question -> prompt -> model -> str.

        Raises:
            ValueError: If no retriever has been set yet.
        """
        if self.retriever is None:
            raise ValueError("No retriever available; call set_retriever() first")
        # The raw query is passed through unchanged as "question", while the
        # retriever supplies "context" for the same query.
        chain_inputs = {"context": self.retriever, "question": RunnablePassthrough()}
        self.chain = chain_inputs | self.prompt | self.model | StrOutputParser()

    def ask(self, query: str):
        """Run the chain on ``query`` and return its output.

        Returns the fixed hint ``"Please upload a PDF file for context"``
        when no chain has been built yet.
        """
        if self.chain is None:
            return "Please upload a PDF file for context"
        return self.chain.invoke(query)

    def feed(self, file_path: str) -> None:
        """Store the file at ``file_path`` in the vector database and rebuild
        the retriever and chain on top of it.
        """
        chunks = self.csv_obj.split_into_chunks(file_path)
        self.vector_store = self.csv_obj.store_to_vector_database(chunks)
        self.set_retriever()
        self.augment()

    def clear(self) -> None:
        """Forget the vector store, retriever and chain."""
        self.vector_store = None
        self.chain = None
        self.retriever = None