Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrates to langgraph #165

Merged
merged 13 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions django_ai_assistant/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

DEFAULTS = {
"INIT_API_FN": "django_ai_assistant.api.views.init_api",
"USE_LANGGRAPH": False,
"CAN_CREATE_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_UPDATE_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
Expand Down
132 changes: 128 additions & 4 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import inspect
import re
from typing import Any, ClassVar, Sequence, cast
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
Expand All @@ -12,11 +12,20 @@
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
)
from langchain.tools import StructuredTool
from langchain_core.chat_history import (
BaseChatMessageHistory,
InMemoryChatMessageHistory,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
Expand All @@ -37,12 +46,15 @@
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from django_ai_assistant.conf import app_settings
from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
)
from django_ai_assistant.langchain.tools import Tool
from django_ai_assistant.langchain.tools import tool as tool_decorator


Expand Down Expand Up @@ -437,6 +449,115 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
prompt | llm | StrOutputParser() | retriever,
)

@with_cast_id
def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
    """Build and compile the LangGraph agent for this assistant.

    The resulting graph supports chat history, tool calling, and RAG
    (when `has_rag=True`). `as_graph` composes many other overridable
    methods (`get_llm`, `get_tools`, `get_message_history`, ...); prefer
    overriding those to customize behavior, and override this method only
    for lower-level graph changes.

    Args:
        thread_id (Any | None): The thread ID for the chat message history.
            If `None`, an in-memory chat message history is used.

    Returns:
        the compiled graph
    """
    model = self.get_llm()
    tool_list = self.get_tools()
    model_with_tools = model.bind_tools(tool_list) if tool_list else model
    history_store = self.get_message_history(thread_id)

    def persisting_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
        # Standard LangGraph message merge, plus persistence of the
        # conversation-visible messages (human/chat messages and final AI
        # answers — tool-call intermediates are not stored).
        merged = add_messages(left, right)

        if thread_id:
            persistable = [
                msg
                for msg in merged
                if isinstance(msg, HumanMessage | ChatMessage)
                or (isinstance(msg, AIMessage) and not msg.tool_calls)
            ]
            history_store.add_messages(persistable)

        return merged

    class AgentState(TypedDict):
        messages: Annotated[list[AnyMessage], persisting_add_messages]
        input: str  # noqa: A003
        context: str
        output: str

    def setup(state: AgentState):
        # Seed the conversation with the assistant's system instructions.
        return {"messages": [SystemMessage(content=self.get_instructions())]}

    def retriever(state: AgentState):
        # No-op unless RAG is enabled for this assistant.
        if not self.has_rag:
            return

        rag_retriever = self.get_history_aware_retriever()
        docs = rag_retriever.invoke({"input": state["input"], "history": state["messages"]})

        separator = self.get_document_separator()
        doc_prompt = self.get_document_prompt()
        joined_docs = separator.join(format_document(doc, doc_prompt) for doc in docs)

        return {
            "messages": SystemMessage(
                content=f"---START OF CONTEXT---\n{joined_docs}---END OF CONTEXT---\n"
            )
        }

    def history(state: AgentState):
        # Prepend persisted history (when threaded) and append the new user input.
        previous = history_store.messages if thread_id else []
        return {"messages": [*previous, HumanMessage(content=state["input"])]}

    def agent(state: AgentState):
        # One LLM step over the accumulated messages (may emit tool calls).
        return {"messages": [model_with_tools.invoke(state["messages"])]}

    def tool_selector(state: AgentState):
        # Route to the tool node while the model keeps requesting tools.
        newest = state["messages"][-1]
        if isinstance(newest, AIMessage) and newest.tool_calls:
            return "call_tool"
        return "continue"

    def record_response(state: AgentState):
        # Expose the final AI answer under the "output" key.
        return {"output": state["messages"][-1].content}

    workflow = StateGraph(AgentState)

    workflow.add_node("setup", setup)
    workflow.add_node("retriever", retriever)
    workflow.add_node("history", history)
    workflow.add_node("agent", agent)
    workflow.add_node("tools", ToolNode(tool_list))
    workflow.add_node("respond", record_response)

    workflow.set_entry_point("setup")
    for source, target in (
        ("setup", "retriever"),
        ("retriever", "history"),
        ("history", "agent"),
    ):
        workflow.add_edge(source, target)
    workflow.add_conditional_edges(
        "agent",
        tool_selector,
        {"call_tool": "tools", "continue": "respond"},
    )
    workflow.add_edge("tools", "agent")
    workflow.add_edge("respond", END)

    return workflow.compile()

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
Expand Down Expand Up @@ -539,7 +660,10 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
dict: The output of the assistant chain,
structured like `{"output": "assistant response", "history": ...}`.
"""
chain = self.as_chain(thread_id)
if app_settings.USE_LANGGRAPH:
chain = self.as_graph(thread_id)
else:
chain = self.as_chain(thread_id)
return chain.invoke(*args, **kwargs)

@with_cast_id
Expand Down Expand Up @@ -577,7 +701,7 @@ def as_tool(self, description: str) -> BaseTool:
Returns:
BaseTool: A tool that runs the assistant. The tool name is this assistant's id.
"""
return Tool.from_function(
return StructuredTool.from_function(
fjsj marked this conversation as resolved.
Show resolved Hide resolved
func=self._run_as_tool,
name=self.id,
description=description,
Expand Down
22 changes: 17 additions & 5 deletions django_ai_assistant/langchain/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from django.db import transaction

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.models import Message
Expand Down Expand Up @@ -73,13 +77,17 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
messages: A list of BaseMessage objects to store.
"""
with transaction.atomic():
existing_message_ids = [m.id for m in self.get_messages()]

messages_to_create = [m for m in messages if m.id not in existing_message_ids]
fjsj marked this conversation as resolved.
Show resolved Hide resolved

created_messages = Message.objects.bulk_create(
[Message(thread_id=self._thread_id, message=dict()) for message in messages]
[Message(thread_id=self._thread_id, message=dict()) for _ in messages_to_create]
)

# Update langchain message IDs with Django message IDs
for idx, created_message in enumerate(created_messages):
message_with_id = messages[idx]
message_with_id = messages_to_create[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

Expand All @@ -91,15 +99,19 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
Args:
messages: A list of BaseMessage objects to store.
"""
existing_message_ids = [m.id for m in await self.aget_messages()]

messages_to_create = [m for m in messages if m.id not in existing_message_ids]

# NOTE: This method does not use transactions because they do not yet work in async mode.
# Source: https://docs.djangoproject.com/en/5.0/topics/async/#queries-the-orm
created_messages = await Message.objects.abulk_create(
[Message(thread_id=self._thread_id, message=dict()) for message in messages]
[Message(thread_id=self._thread_id, message=dict()) for _ in messages_to_create]
)

# Update langchain message IDs with Django message IDs
for idx, created_message in enumerate(created_messages):
message_with_id = messages[idx]
message_with_id = messages_to_create[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

Expand Down
1 change: 1 addition & 0 deletions example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@

# django-ai-assistant

AI_ASSISTANT_USE_LANGGRAPH = True
AI_ASSISTANT_INIT_API_FN = "django_ai_assistant.api.views.init_api"
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
Expand Down
34 changes: 18 additions & 16 deletions example/tour_guide/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,17 @@ class TourGuideAIAssistant(AIAssistant):
name = "Tour Guide Assistant"
instructions = (
"You are a tour guide assistant that offers information about nearby attractions. "
"The application will capture the user coordinates, and should provide a list of nearby attractions. "
"Use the available tools to suggest nearby attractions to the user. "
"You don't need to include all the found items, only include attractions that are relevant for a tourist. "
"Select the top 10 best attractions for a tourist, if there are less then 10 relevant items only return these. "
"Order items by the most relevant to the least relevant. "
"If there are no relevant attractions nearby, just keep the list empty. "
"Your response will be integrated with a frontend web application therefore it's critical that "
"it only contains a valid JSON. DON'T include '```json' in your response. "
"You will receive the user coordinates and should use available tools to find nearby attractions. "
"Only call the find_nearby_attractions tool once. "
"Your response should only contain valid JSON data. DON'T include '```json' in your response. "
fjsj marked this conversation as resolved.
Show resolved Hide resolved
"The JSON should be formatted according to the following structure: \n"
f"\n\n{_tour_guide_example_json()}\n\n\n"
"In the 'attraction_name' field provide the name of the attraction in english. "
"In the 'attraction_description' field generate an overview about the attraction with the most important information, "
"curiosities and interesting facts. "
"Only include a value for the 'attraction_url' field if you find a real value in the provided data otherwise keep it empty. "
)
model = "gpt-4o"
model = "gpt-4o-2024-08-06"
fjsj marked this conversation as resolved.
Show resolved Hide resolved

def get_instructions(self):
# Warning: this will use the server's timezone
Expand All @@ -60,11 +55,18 @@ def get_instructions(self):
return f"Today is: {current_date_str}. {self.instructions}"

@method_tool
def find_nearby_attractions(self, latitude: float, longitude: float) -> str:
    """
    Find nearby attractions based on user's current location.
    Returns a JSON with the list of all types of points of interest,
    which may or may not include attractions.
    Calls to this tool are idempotent.
    """
    # Query the points-of-interest API within a 500m radius and serialize
    # the raw result for the LLM to filter/rank per its instructions.
    points_of_interest = fetch_points_of_interest(
        latitude=latitude,
        longitude=longitude,
        tags=["tourism", "leisure", "place", "building"],
        radius=500,
    )
    return json.dumps(points_of_interest)
42 changes: 38 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading