diff --git a/query_data.py b/query_data.py index c0028317f..9e592b1b0 100644 --- a/query_data.py +++ b/query_data.py @@ -1,7 +1,7 @@ """Create a ChatVectorDBChain for question/answering.""" from langchain.callbacks.base import AsyncCallbackManager from langchain.callbacks.tracers import LangChainTracer -from langchain.chains import ChatVectorDBChain +from langchain.chains import ConversationalRetrievalChain from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT, QA_PROMPT) from langchain.chains.llm import LLMChain @@ -12,9 +12,9 @@ def get_chain( vectorstore: VectorStore, question_handler, stream_handler, tracing: bool = False -) -> ChatVectorDBChain: - """Create a ChatVectorDBChain for question/answering.""" - # Construct a ChatVectorDBChain with a streaming llm for combine docs +) -> ConversationalRetrievalChain: + """Create a ConversationalRetrievalChain for question/answering.""" + # Construct a ConversationalRetrievalChain with a streaming llm for combine docs # and a separate, non-streaming llm for question generation manager = AsyncCallbackManager([]) question_manager = AsyncCallbackManager([question_handler]) @@ -45,8 +45,8 @@ def get_chain( streaming_llm, chain_type="stuff", prompt=QA_PROMPT, callback_manager=manager ) - qa = ChatVectorDBChain( - vectorstore=vectorstore, + qa = ConversationalRetrievalChain( + retriever=vectorstore.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator, callback_manager=manager, diff --git a/requirements.txt b/requirements.txt index 1b7831d96..91c0f0d76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,10 +4,11 @@ black isort websockets pydantic -langchain +langchain==0.0.152 uvicorn jinja2 faiss-cpu bs4 unstructured libmagic +tiktoken