From f573826dbfd613c81d648be79162200b2431e01b Mon Sep 17 00:00:00 2001
From: kaarthik108
Date: Sat, 3 Feb 2024 18:14:15 +0530
Subject: [PATCH] Move to Langchain LCEL

---
 .vscode/settings.json |  22 ++++++-
 chain.py              | 133 ++++++++++++++++++++++++------------
 main.py               |  43 +++++++-------
 requirements.txt      |   7 +--
 template.py           |  20 ++++---
 utils/snowchat_ui.py  |  13 ++---
 6 files changed, 140 insertions(+), 98 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index ba77eac..48d8756 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -2,5 +2,25 @@
   "[python]": {
     "editor.defaultFormatter": "ms-python.python"
   },
-  "python.formatting.provider": "none"
+  "python.formatting.provider": "none",
+  "workbench.colorCustomizations": {
+    "activityBar.activeBackground": "#7c185f",
+    "activityBar.background": "#7c185f",
+    "activityBar.foreground": "#e7e7e7",
+    "activityBar.inactiveForeground": "#e7e7e799",
+    "activityBarBadge.background": "#000000",
+    "activityBarBadge.foreground": "#e7e7e7",
+    "commandCenter.border": "#e7e7e799",
+    "sash.hoverBorder": "#7c185f",
+    "statusBar.background": "#51103e",
+    "statusBar.foreground": "#e7e7e7",
+    "statusBarItem.hoverBackground": "#7c185f",
+    "statusBarItem.remoteBackground": "#51103e",
+    "statusBarItem.remoteForeground": "#e7e7e7",
+    "titleBar.activeBackground": "#51103e",
+    "titleBar.activeForeground": "#e7e7e7",
+    "titleBar.inactiveBackground": "#51103e99",
+    "titleBar.inactiveForeground": "#e7e7e799"
+  },
+  "peacock.color": "#51103e"
 }
\ No newline at end of file
diff --git a/chain.py b/chain.py
index 5a55bcb..b5b64df 100644
--- a/chain.py
+++ b/chain.py
@@ -2,16 +2,25 @@
 import boto3
 import streamlit as st
-from langchain.chains import ConversationalRetrievalChain, LLMChain
-from langchain.chains.question_answering import load_qa_chain
-from langchain.chat_models import ChatOpenAI, BedrockChat
+from langchain.chat_models import BedrockChat, ChatOpenAI
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms import OpenAI
 from langchain.vectorstores import SupabaseVectorStore
 from pydantic import BaseModel, validator
 from supabase.client import Client, create_client
-from template import CONDENSE_QUESTION_PROMPT, LLAMA_PROMPT, QA_PROMPT
+from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+
+from operator import itemgetter
+
+from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import format_document
+from langchain_core.messages import get_buffer_string
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
 supabase_url = st.secrets["SUPABASE_URL"]
 supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
@@ -25,7 +34,7 @@ class ModelConfig(BaseModel):
 
     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "claude", "mixtral"]:
+        if v not in ["gpt", "codellama", "mixtral"]:
            raise ValueError(f"Unsupported model type: {v}")
        return v
 
@@ -44,23 +53,15 @@ def __init__(self, config: ModelConfig):
     def setup(self):
         if self.model_type == "gpt":
             self.setup_gpt()
-        elif self.model_type == "claude":
-            self.setup_claude()
+        elif self.model_type == "codellama":
+            self.setup_codellama()
         elif self.model_type == "mixtral":
             self.setup_mixtral()
 
     def setup_gpt(self):
-        self.q_llm = OpenAI(
-            temperature=0.1,
-            api_key=self.secrets["OPENAI_API_KEY"],
-            model_name="gpt-3.5-turbo-16k",
-            max_tokens=500,
-            base_url=self.gateway_url,
-        )
-
         self.llm = ChatOpenAI(
-            model_name="gpt-3.5-turbo-16k",
-            temperature=0.5,
+            model_name="gpt-3.5-turbo-0125",
+            temperature=0.2,
             api_key=self.secrets["OPENAI_API_KEY"],
             max_tokens=500,
             callbacks=[self.callback_handler],
@@ -69,17 +70,9 @@ def setup_gpt(self):
         )
 
     def setup_mixtral(self):
-        self.q_llm = OpenAI(
-            temperature=0.1,
-            api_key=self.secrets["MIXTRAL_API_KEY"],
-            model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            max_tokens=500,
-            base_url="https://api.together.xyz/v1",
-        )
-
         self.llm = ChatOpenAI(
             model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            temperature=0.5,
+            temperature=0.2,
             api_key=self.secrets["MIXTRAL_API_KEY"],
             max_tokens=500,
             callbacks=[self.callback_handler],
@@ -87,42 +80,66 @@ def setup_mixtral(self):
             base_url="https://api.together.xyz/v1",
         )
 
-    def setup_claude(self):
-        bedrock_runtime = boto3.client(
-            service_name="bedrock-runtime",
-            aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
-            aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
-            region_name="us-east-1",
-        )
-        parameters = {
-            "max_tokens_to_sample": 1000,
-            "stop_sequences": [],
-            "temperature": 0,
-            "top_p": 0.9,
-        }
-        self.q_llm = BedrockChat(
-            model_id="anthropic.claude-instant-v1", client=bedrock_runtime
-        )
-
-        self.llm = BedrockChat(
-            model_id="anthropic.claude-instant-v1",
-            client=bedrock_runtime,
+    def setup_codellama(self):
+        self.llm = ChatOpenAI(
+            model_name="codellama/codellama-70b-instruct",
+            temperature=0.2,
+            api_key=self.secrets["OPENROUTER_API_KEY"],
+            max_tokens=500,
             callbacks=[self.callback_handler],
             streaming=True,
-            model_kwargs=parameters,
+            base_url="https://openrouter.ai/api/v1",
         )
 
+    # def setup_claude(self):
+    #     bedrock_runtime = boto3.client(
+    #         service_name="bedrock-runtime",
+    #         aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
+    #         aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
+    #         region_name="us-east-1",
+    #     )
+    #     parameters = {
+    #         "max_tokens_to_sample": 1000,
+    #         "stop_sequences": [],
+    #         "temperature": 0,
+    #         "top_p": 0.9,
+    #     }
+    #     self.q_llm = BedrockChat(
+    #         model_id="anthropic.claude-instant-v1", client=bedrock_runtime
+    #     )
+
+    #     self.llm = BedrockChat(
+    #         model_id="anthropic.claude-instant-v1",
+    #         client=bedrock_runtime,
+    #         callbacks=[self.callback_handler],
+    #         streaming=True,
+    #         model_kwargs=parameters,
+    #     )
+
     def get_chain(self, vectorstore):
-        if not self.q_llm or not self.llm:
-            raise ValueError("Models have not been properly initialized.")
-        question_generator = LLMChain(llm=self.q_llm, prompt=CONDENSE_QUESTION_PROMPT)
-        doc_chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=QA_PROMPT)
-        conv_chain = ConversationalRetrievalChain(
-            retriever=vectorstore.as_retriever(),
-            combine_docs_chain=doc_chain,
-            question_generator=question_generator,
+        def _combine_documents(
+            docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+        ):
+            doc_strings = [format_document(doc, document_prompt) for doc in docs]
+            return document_separator.join(doc_strings)
+
+        _inputs = RunnableParallel(
+            standalone_question=RunnablePassthrough.assign(
+                chat_history=lambda x: get_buffer_string(x["chat_history"])
+            )
+            | CONDENSE_QUESTION_PROMPT
+            | OpenAI()
+            | StrOutputParser(),
         )
-        return conv_chain
+        _context = {
+            "context": itemgetter("standalone_question")
+            | vectorstore.as_retriever()
+            | _combine_documents,
+            "question": lambda x: x["standalone_question"],
+        }
+        conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
+
+        return conversational_qa_chain
 
 
 def load_chain(model_name="GPT-3.5", callback_handler=None):
@@ -136,8 +153,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         query_name="v_match_documents",
     )
 
-    if "claude" in model_name.lower():
-        model_type = "claude"
+    if "codellama" in model_name.lower():
+        model_type = "codellama"
     elif "GPT-3.5" in model_name:
         model_type = "gpt"
     elif "mixtral" in model_name.lower():
diff --git a/main.py b/main.py
index 6e96d9e..c446141 100644
--- a/main.py
+++ b/main.py
@@ -5,7 +5,8 @@
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
 from chain import load_chain
-from utils.snow_connect import SnowflakeConnection
+
+# from utils.snow_connect import SnowflakeConnection
 from utils.snowchat_ui import StreamlitUICallbackHandler, message_func
 from utils.snowddl import Snowddl
 
@@ -17,11 +18,10 @@
 st.caption("Talk your way through data")
 model = st.radio(
     "",
-    options=["✨ GPT-3.5", "♾️ Claude", "⛰️ Mixtral"],
+    options=["✨ GPT-3.5", "♾️ codellama", "⛰️ Mixtral"],
     index=0,
     horizontal=True,
 )
-
 st.session_state["model"] = model
 
 INITIAL_MESSAGE = [
@@ -97,15 +97,10 @@ def get_sql(text):
     return sql_match.group(1) if sql_match else None
 
 
-def append_message(content, role="assistant", display=False):
-    message = {"role": role, "content": content}
-    st.session_state.messages.append(message)
-    if role != "data":
-        append_chat_history(st.session_state.messages[-2]["content"], content)
-
-    if callback_handler.has_streaming_ended:
-        callback_handler.has_streaming_ended = False
-    return
+def append_message(content, role="assistant"):
+    """Appends a message to the session state messages."""
+    if content.strip():
+        st.session_state.messages.append({"role": role, "content": content})
 
 
 def handle_sql_exception(query, conn, e, retries=2):
@@ -135,14 +130,22 @@ def execute_sql(query, conn, retries=2):
         return handle_sql_exception(query, conn, e, retries)
 
 
-if st.session_state.messages[-1]["role"] != "assistant":
-    content = st.session_state.messages[-1]["content"]
-    if isinstance(content, str):
-        result = chain(
-            {"question": content, "chat_history": st.session_state["history"]}
-        )["answer"]
-        print(result)
-        append_message(result)
+if (
+    "messages" in st.session_state
+    and st.session_state["messages"][-1]["role"] != "assistant"
+):
+    user_input_content = st.session_state["messages"][-1]["content"]
+    # print(f"User input content is: {user_input_content}")
+
+    if isinstance(user_input_content, str):
+        result = chain.invoke(
+            {
+                "question": user_input_content,
+                "chat_history": [h for h in st.session_state["history"]],
+            }
+        )
+        append_message(result.content)
+
     # if get_sql(result):
     #     conn = SnowflakeConnection().get_session()
     #     df = execute_sql(get_sql(result), conn)
diff --git a/requirements.txt b/requirements.txt
index 4aa7a89..51db6d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,13 +1,12 @@
-langchain==0.0.350
+langchain==0.1.5
 pandas==1.5.0
 pydantic==1.10.8
 snowflake_snowpark_python==1.5.0
 snowflake-snowpark-python[pandas]
-streamlit==1.27.1
+streamlit==1.31.0
 supabase==1.0.3
 unstructured==0.7.12
 tiktoken==0.4.0
-openai==0.27.8
+openai==1.11.0
 black==23.3.0
-replicate==0.8.4
 boto3==1.28.57
\ No newline at end of file
diff --git a/template.py b/template.py
index 031af36..58e222f 100644
--- a/template.py
+++ b/template.py
@@ -1,4 +1,5 @@
 from langchain.prompts.prompt import PromptTemplate
+from langchain_core.prompts import ChatPromptTemplate
 
 template = """You are an AI chatbot having a conversation with a human.
 
@@ -27,11 +28,13 @@
 
 Write your response in markdown format.
 
-Human: ```{question}```
+User: {question}
 
 {context}
 
 Assistant: """
+
+
 B_INST, E_INST = "[INST]", "[/INST]"
 
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
@@ -54,11 +57,14 @@
 
 """
 
-LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST
+# LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST
+
+CONDENSE_QUESTION_PROMPT = ChatPromptTemplate.from_template(template)
+
+# QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
+# LLAMA_PROMPT = PromptTemplate(
+#     template=LLAMA_TEMPLATE, input_variables=["question", "context"]
+# )
 
-CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template)
-QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
-LLAMA_PROMPT = PromptTemplate(
-    template=LLAMA_TEMPLATE, input_variables=["question", "context"]
-)
+QA_PROMPT = ChatPromptTemplate.from_template(TEMPLATE)
diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py
index b13dc62..d8d4507 100644
--- a/utils/snowchat_ui.py
+++ b/utils/snowchat_ui.py
@@ -86,7 +86,7 @@ class StreamlitUICallbackHandler(BaseCallbackHandler):
     def __init__(self):
         # Buffer to accumulate tokens
         self.token_buffer = []
-        self.placeholder = None
+        self.placeholder = st.empty()
         self.has_streaming_ended = False
 
     def _get_bot_message_container(self, text):
@@ -111,13 +111,10 @@ def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
         """
         self.token_buffer.append(token)
         complete_message = "".join(self.token_buffer)
-        if self.placeholder is None:
-            container_content = self._get_bot_message_container(complete_message)
-            self.placeholder = st.markdown(container_content, unsafe_allow_html=True)
-        else:
-            # Update the placeholder content
-            container_content = self._get_bot_message_container(complete_message)
-            self.placeholder.markdown(container_content, unsafe_allow_html=True)
+
+        # Update the placeholder content with the complete message
+        container_content = self._get_bot_message_container(complete_message)
+        self.placeholder.markdown(container_content, unsafe_allow_html=True)
 
     def display_dataframe(self, df):
         """