Skip to content

Commit

Permalink
Merge pull request #39 from Azure-Samples/add-support-for-streaming
Browse files Browse the repository at this point in the history
Add streaming
  • Loading branch information
john0isaac authored Jun 21, 2024
2 parents 3f11262 + 2d5d224 commit 8f4ee48
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 60 deletions.
8 changes: 2 additions & 6 deletions frontend/src/pages/chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,9 @@ const Chat = () => {
for await (const event of readNDJSONStream(responseBody)) {
if (event["context"] && event["context"]["data_points"] && event["message"]) {
askResponse = event as ChatAppResponse;
} else if (event["message"]["content"]) {
} else if (event["content"]) {
setIsLoading(false);
await updateState(event["message"]["content"]);
} else if (event["context"]) {
// Update context with new keys from latest event
askResponse.context = { ...askResponse.context, ...event["context"] };
await updateState(event["content"]);
} else if (event["error"]) {
throw Error(event["error"]);
}
Expand Down Expand Up @@ -406,7 +403,6 @@ const Chat = () => {
onRenderLabel={(props: ICheckboxProps | undefined) => (
<HelpCallout labelId={shouldStreamId} fieldId={shouldStreamFieldId} helpText={toolTipText.streamChat} label={props?.label} />
)}
disabled={true}
/>
</Panel>
</div>
Expand Down
53 changes: 50 additions & 3 deletions src/quartapp/app.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
from collections.abc import AsyncGenerator
from json import dumps
from pathlib import Path
from typing import Any

from quart import Quart, Response, jsonify, request, send_file, send_from_directory
from quart import Quart, Response, jsonify, make_response, request, send_file, send_from_directory

from quartapp.approaches.schemas import RetrievalResponse
from quartapp.approaches.schemas import Message, RetrievalResponse
from quartapp.config import AppConfig

logging.basicConfig(
Expand All @@ -14,6 +16,18 @@
)


async def format_as_ndjson(r: AsyncGenerator[RetrievalResponse | Message, None]) -> AsyncGenerator[str, None]:
"""
Format the response as NDJSON
"""
try:
async for event in r:
yield dumps(event.to_dict(), ensure_ascii=False) + "\n"
except Exception as error:
logging.exception("Exception while generating response stream: %s", error)
yield dumps({"error": str(error)}, ensure_ascii=False) + "\n"


def create_app(app_config: AppConfig, test_config: dict[str, Any] | None = None) -> Quart:
app = Quart(__name__, static_folder="static")

Expand Down Expand Up @@ -56,8 +70,13 @@ async def chat() -> Any:
if not body:
return jsonify({"error": "request body is empty"}), 400

# Get the request message, session_state, context from the request body
# Get the request message
messages: list = body.get("messages", [])

if not messages:
return jsonify({"error": "request must have a message"}), 400

# Get the request session_state, context from the request body
session_state = body.get("session_state", None)
context = body.get("context", {})

Expand Down Expand Up @@ -90,6 +109,34 @@ async def stream_chat() -> Any:
if not body:
return jsonify({"error": "request body is empty"}), 400

# Get the request message
messages: list = body.get("messages", [])

if not messages:
return jsonify({"error": "request must have a message"}), 400

# Get the request session_state, context from the request body
session_state = body.get("session_state", None)
context = body.get("context", {})

# Get the overrides from the context
override = context.get("overrides", {})
retrieval_mode: str = override.get("retrieval_mode", "vector")
temperature: float = override.get("temperature", 0.3)
top: int = override.get("top", 3)
score_threshold: float = override.get("score_threshold", 0.5)

if retrieval_mode == "rag":
result: AsyncGenerator[RetrievalResponse | Message, None] = app_config.run_rag_stream(
session_state=session_state,
messages=messages,
temperature=temperature,
limit=top,
score_threshold=score_threshold,
)
response = await make_response(format_as_ndjson(result))
response.mimetype = "application/x-ndjson"
return response
return jsonify({"error": "Not Implemented!"}), 501

return app
Expand Down
120 changes: 99 additions & 21 deletions src/quartapp/approaches/rag.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,64 @@
from langchain.chains.combine_documents import create_stuff_documents_chain
import json
from collections.abc import AsyncIterator

from langchain.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage

from quartapp.approaches.base import ApproachesBase
from quartapp.approaches.schemas import DataPoint


def get_data_points(documents: list[Document]) -> list[DataPoint]:
data_points: list[DataPoint] = []

for res in documents:
raw_data = json.loads(res.page_content)
json_data_point: DataPoint = DataPoint()
json_data_point.name = raw_data.get("name")
json_data_point.description = raw_data.get("description")
json_data_point.price = raw_data.get("price")
json_data_point.category = raw_data.get("category")
data_points.append(json_data_point)
return data_points


chat_history_prompt = """\
REPHRASE_PROMPT = """\
Given the following conversation and a follow up question, rephrase the follow up \
question to be a standalone question.
Chat History:
{messages}
{chat_history}
Follow Up Input: {question}
Standalone Question:"""

context_prompt = """You are a chatbot that can have a conversation about Food dishes from the context.
Answer the following question based only on the provided context:
CONTEXT_PROMPT = """\
You are a restaurant chatbot, tasked with answering any question about \
food dishes from the contex.
Generate a response of 80 words or less for the \
given question based solely on the provided search results (name, description, price, and category). \
You must only use information from the provided search results. Use an unbiased and \
fun tone. Do not repeat text. Your response must be solely based on the provided context.
If there is nothing in the context relevant to the question at hand, just say "Hmm, \
I'm not sure." Don't try to make up an answer.
Anything between the following `context` html blocks is retrieved from a knowledge \
bank, not part of the conversation with the user.
Context:
{context}
<context>
{context}
<context/>
Question: {input}"""
REMEMBER: If there is no relevant information within the context, just say "Hmm, I'm \
not sure." Don't try to make up an answer. Anything between the preceding 'context' \
html blocks is retrieved from a knowledge bank, not part of the conversation with the \
user.\
User Question: {input}
Chatbot Response:"""


class RAG(ApproachesBase):
Expand All @@ -35,31 +74,70 @@ async def run(
self._chat.temperature = 0.3

# Create a vector context aware chat retriever
history_prompt_template = ChatPromptTemplate.from_template(chat_history_prompt)
history_chain = history_prompt_template | self._chat
rephrase_prompt_template = ChatPromptTemplate.from_template(REPHRASE_PROMPT)
rephrase_chain = rephrase_prompt_template | self._chat

# Rephrase the question
rephrased_question = await history_chain.ainvoke({"messages": messages})
rephrased_question = await rephrase_chain.ainvoke({"chat_history": messages[:-1], "question": messages[-1]})

print(rephrased_question.content)
# Perform vector search
vector_context = await retriever.ainvoke(str(rephrased_question.content))
data_points: list[DataPoint] = get_data_points(vector_context)

# Create a vector context aware chat retriever
context_prompt_template = ChatPromptTemplate.from_template(context_prompt)
document_chain = create_stuff_documents_chain(self._chat, context_prompt_template)

context_prompt_template = ChatPromptTemplate.from_template(CONTEXT_PROMPT)
self._chat.temperature = temperature
context_chain = context_prompt_template | self._chat

if vector_context:
if data_points:
# Perform RAG search
response = await document_chain.ainvoke(
{"context": vector_context, "input": rephrased_question.content}
response = await context_chain.ainvoke(
{"context": [dp.to_dict() for dp in data_points], "input": rephrased_question.content}
)

return vector_context, response
return vector_context, str(response.content)

# Perform RAG search with no context
response = await document_chain.ainvoke({"context": vector_context, "input": rephrased_question.content})
return [], response
response = await context_chain.ainvoke({"context": [], "input": rephrased_question.content})
return [], str(response.content)
return [], ""

async def run_stream(
self, messages: list, temperature: float, limit: int, score_threshold: float
) -> tuple[list[Document], AsyncIterator[BaseMessage]]:
# Create a vector store retriever
retriever = self._vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": limit, "score_threshold": score_threshold}
)

self._chat.temperature = 0.3

# Create a vector context aware chat retriever
rephrase_prompt_template = ChatPromptTemplate.from_template(REPHRASE_PROMPT)
rephrase_chain = rephrase_prompt_template | self._chat

# Rephrase the question
rephrased_question = await rephrase_chain.ainvoke({"chat_history": messages[:-1], "question": messages[-1]})

print(rephrased_question.content)
# Perform vector search
vector_context = await retriever.ainvoke(str(rephrased_question.content))
data_points: list[DataPoint] = get_data_points(vector_context)

# Create a vector context aware chat retriever
context_prompt_template = ChatPromptTemplate.from_template(CONTEXT_PROMPT)
self._chat.temperature = temperature
context_chain = context_prompt_template | self._chat

if data_points:
# Perform RAG search
response = context_chain.astream(
{"context": [dp.to_dict() for dp in data_points], "input": rephrased_question.content}
)

return vector_context, response

# Perform RAG search with no context
response = context_chain.astream({"context": [], "input": rephrased_question.content})
return [], response
25 changes: 25 additions & 0 deletions src/quartapp/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections.abc import AsyncGenerator
from uuid import uuid4

from quartapp.approaches.schemas import Context, DataPoint, Message, RetrievalResponse, Thought
Expand Down Expand Up @@ -109,3 +110,27 @@ async def run_rag(
)

return RetrievalResponse(context, message, new_session_state)

async def run_rag_stream(
self, session_state: str | None, messages: list, temperature: float, limit: int, score_threshold: float
) -> AsyncGenerator[RetrievalResponse | Message, None]:
rag_response, answer = await self.setup.rag.run_stream(messages, temperature, limit, score_threshold)

new_session_state: str = session_state if session_state else str(uuid4())

context: Context = await self.get_context(rag_response)

empty_message: Message = Message(content="", role="assistant")

yield RetrievalResponse(context, empty_message, new_session_state)

async for message_chunk in answer:
message: Message = Message(content=str(message_chunk.content), role="assistant")
yield message

await self.add_to_cosmos(
old_messages=messages,
new_message=message.to_dict(),
session_state=session_state,
new_session_state=new_session_state,
)
49 changes: 22 additions & 27 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch

import langchain_core
import mongomock
import pytest
import pytest_asyncio
Expand Down Expand Up @@ -46,29 +48,15 @@ def approaches_base_mock():

# Mock Vector Store
mock_vector_store = MagicMock()
mock_retriever = MagicMock()
mock_document = Document(
page_content='{"name": "test", "description": "test", "price": "5.0USD", "category": "test"}'
)
mock_retriever.ainvoke = AsyncMock(return_value=[mock_document]) # Assume there is always a response
mock_vector_store.as_retriever.return_value = mock_retriever
mock_vector_store.as_retriever.return_value.ainvoke = AsyncMock(return_value=[mock_document])

# Mock Chat
mock_chat = MagicMock()

mock_content = MagicMock()
mock_content.content = "content"

runnable_return = MagicMock()
runnable_return.to_json = MagicMock(return_value={"kwargs": {"messages": [mock_content]}})

mock_runnable = MagicMock()
mock_runnable.ainvoke = AsyncMock(return_value=runnable_return)

mock_chat.__or__.return_value = mock_runnable

mock_chat.ainvoke = AsyncMock(return_value=mock_content)

# Mock Data Collection
mock_mongo_document = {"textContent": mock_document.page_content, "source": "test"}
mock_data_collection = MagicMock()
mock_data_collection.find = MagicMock()
Expand Down Expand Up @@ -161,17 +149,6 @@ def app_config_mock(setup_mock):
return app_config


@pytest.fixture(autouse=True)
def create_stuff_documents_chain_mock(monkeypatch):
"""Mock quartapp.approaches.rag.create_stuff_documents_chain."""
document_chain_mock = MagicMock()
document_chain_mock.ainvoke = AsyncMock(return_value="content")
_mock = MagicMock()
_mock.return_value = document_chain_mock
monkeypatch.setattr(quartapp.approaches.rag, quartapp.approaches.rag.create_stuff_documents_chain.__name__, _mock)
return _mock


@pytest.fixture(autouse=True)
def setup_data_collection_mock(monkeypatch):
"""Mock quartapp.approaches.setup.setup_data_collection."""
Expand All @@ -180,6 +157,24 @@ def setup_data_collection_mock(monkeypatch):
return _mock


@pytest.fixture(autouse=True)
def mock_runnable_or(monkeypatch):
"""Mock langchain_core.runnables.base.Runnable.__or__."""

@dataclass
class MockContent:
content: str

or_return = MagicMock()
or_return.ainvoke = AsyncMock(return_value=MockContent(content="content"))
or_mock = MagicMock()
or_mock.return_value = or_return
monkeypatch.setattr(
langchain_core.runnables.base.Runnable, langchain_core.runnables.base.Runnable.__or__.__name__, or_mock
)
return or_mock


@pytest_asyncio.fixture
async def app_mock(app_config_mock):
"""Create a test app with the test config mock."""
Expand Down
Loading

0 comments on commit 8f4ee48

Please sign in to comment.