Skip to content

Commit

Permalink
Updated to use memory instead of chat history, fix for Bedrock Anthr…
Browse files Browse the repository at this point in the history
…opic

(cherry picked from commit 7c09863)
  • Loading branch information
3coins committed Oct 23, 2023
1 parent eef6dea commit 9b346e6
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@
from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate

from .base import BaseChatHandler

PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)


class AskChatHandler(BaseChatHandler):
"""Processes messages prefixed with /ask. This actor will
Expand All @@ -27,9 +37,15 @@ def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
self.llm = provider(**provider_params)
self.chat_history = []
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
self.llm_chain = ConversationalRetrievalChain.from_llm(
self.llm, self._retriever, verbose=True
self.llm,
self._retriever,
memory=memory,
condense_question_prompt=CONDENSE_PROMPT,
verbose=False,
)

async def _process_message(self, message: HumanChatMessage):
Expand All @@ -44,14 +60,8 @@ async def _process_message(self, message: HumanChatMessage):
self.get_llm_chain()

try:
# limit chat history to last 2 exchanges
self.chat_history = self.chat_history[-2:]

result = await self.llm_chain.acall(
{"question": query, "chat_history": self.chat_history}
)
result = await self.llm_chain.acall({"question": query})
response = result["answer"]
self.chat_history.append((query, response))
self.reply(response, message)
except AssertionError as e:
self.log.error(e)
Expand Down

0 comments on commit 9b346e6

Please sign in to comment.