diff --git a/recipes/natural_language_processing/chatbot/chatbot_ui.py b/recipes/natural_language_processing/chatbot/chatbot_ui.py index cc8e696b1..348b1db93 100644 --- a/recipes/natural_language_processing/chatbot/chatbot_ui.py +++ b/recipes/natural_language_processing/chatbot/chatbot_ui.py @@ -1,7 +1,8 @@ from langchain_openai import ChatOpenAI from langchain.chains import LLMChain from langchain_community.callbacks import StreamlitCallbackHandler -from langchain_core.prompts import ChatPromptTemplate +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain.memory import ConversationBufferWindowMemory import streamlit as st import requests import time @@ -37,18 +38,28 @@ def checking_model_service(): for msg in st.session_state.messages: st.chat_message(msg["role"]).write(msg["content"]) +@st.cache_resource() +def memory(): + memory = ConversationBufferWindowMemory(return_messages=True,k=10) + return memory + llm = ChatOpenAI(base_url=model_service, - api_key="sk-no-key-required", - streaming=True, - callbacks=[StreamlitCallbackHandler(st.container(), - expand_new_thoughts=True, - collapse_completed_thoughts=True)]) + api_key="sk-no-key-required", + streaming=True, + callbacks=[StreamlitCallbackHandler(st.empty(), + expand_new_thoughts=True, + collapse_completed_thoughts=True)]) + prompt = ChatPromptTemplate.from_messages([ ("system", "You are world class technical advisor."), + MessagesPlaceholder(variable_name="history"), ("user", "{input}") ]) -chain = LLMChain(llm=llm, prompt=prompt) +chain = LLMChain(llm=llm, + prompt=prompt, + verbose=False, + memory=memory()) if prompt := st.chat_input(): st.session_state.messages.append({"role": "user", "content": prompt})