-
Notifications
You must be signed in to change notification settings - Fork 516
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AI Agents Masterclass #7 - LangGraph Guide
- Loading branch information
Showing
5 changed files
with
442 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Rename this file to .env once you have filled in the below environment variables! | ||
|
||
# Get your Open AI API Key by following these instructions - | ||
# https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key | ||
# You only need this environment variable set if you set LLM_MODEL to a GPT model | ||
OPENAI_API_KEY= | ||
|
||
# Get your Anthropic API Key in your account settings - | ||
# https://console.anthropic.com/settings/keys | ||
# You only need this environment variable set if you set LLM_MODEL to a Claude model | ||
ANTHROPIC_API_KEY= | ||
|
||
# See all Open AI models you can use here - | ||
# https://platform.openai.com/docs/models | ||
# And all Anthropic models you can use here - | ||
# https://docs.anthropic.com/en/docs/about-claude/models | ||
# A good default to go with here is gpt-4o or claude-3-5-sonnet-20240620 | ||
LLM_MODEL=gpt-4o | ||
|
||
# Get your personal Asana access token through the developer console in Asana. | ||
# Feel free to follow these instructions - | ||
# https://developers.asana.com/docs/personal-access-token | ||
ASANA_ACCESS_TOKEN= | ||
|
||
# The Asana workspace ID is in the URL when you visit your Asana Admin Console (when logged in). | ||
# Go to the URL "https://app.asana.com/admin" and then your workspace ID | ||
# will appear in the URL as a slew of digits once the site loads. | ||
# If your URL is https://app.asana.com/admin/987654321/insights, then your | ||
# Asana workspace ID is 987654321 | ||
ASANA_WORKPLACE_ID= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from datetime import datetime | ||
import streamlit as st | ||
import asyncio | ||
import json | ||
import uuid | ||
import os | ||
|
||
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage, ToolMessage | ||
|
||
from runnable import get_runnable | ||
|
||
@st.cache_resource | ||
def create_chatbot_instance(): | ||
return get_runnable() | ||
|
||
chatbot = create_chatbot_instance() | ||
|
||
@st.cache_resource | ||
def get_thread_id(): | ||
return str(uuid.uuid4()) | ||
|
||
thread_id = get_thread_id() | ||
|
||
system_message = f""" | ||
You are a personal assistant who helps manage tasks in Asana. | ||
You never give IDs to the user since those are just for you to keep track of. | ||
When a user asks to create a task and you don't know the project to add it to for sure, clarify with the user. | ||
The current date is: {datetime.now().date()} | ||
""" | ||
|
||
async def prompt_ai(messages): | ||
config = { | ||
"configurable": { | ||
"thread_id": thread_id | ||
} | ||
} | ||
|
||
async for event in chatbot.astream_events( | ||
{"messages": messages}, config, version="v2" | ||
): | ||
if event["event"] == "on_chat_model_stream": | ||
yield event["data"]["chunk"].content | ||
|
||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
# ~~~~~~~~~~~~~~~~~~ Main Function with UI Creation ~~~~~~~~~~~~~~~~~~~~ | ||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
async def main(): | ||
st.title("Asana Chatbot with LangGraph") | ||
|
||
# Initialize chat history | ||
if "messages" not in st.session_state: | ||
st.session_state.messages = [ | ||
SystemMessage(content=system_message) | ||
] | ||
|
||
# Display chat messages from history on app rerun | ||
for message in st.session_state.messages: | ||
message_json = json.loads(message.json()) | ||
message_type = message_json["type"] | ||
if message_type in ["human", "ai", "system"]: | ||
with st.chat_message(message_type): | ||
st.markdown(message_json["content"]) | ||
|
||
# React to user input | ||
if prompt := st.chat_input("What would you like to do today?"): | ||
# Display user message in chat message container | ||
st.chat_message("user").markdown(prompt) | ||
# Add user message to chat history | ||
st.session_state.messages.append(HumanMessage(content=prompt)) | ||
|
||
# Display assistant response in chat message container | ||
response_content = "" | ||
with st.chat_message("assistant"): | ||
message_placeholder = st.empty() # Placeholder for updating the message | ||
# Run the async generator to fetch responses | ||
async for chunk in prompt_ai(st.session_state.messages): | ||
response_content += chunk | ||
# Update the placeholder with the current response content | ||
message_placeholder.markdown(response_content) | ||
|
||
st.session_state.messages.append(AIMessage(content=response_content)) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
asana==5.0.7 | ||
python-dotenv==0.13.0 | ||
langchain==0.2.12 | ||
langchain-anthropic==0.1.22 | ||
langchain-community==0.2.11 | ||
langchain-core==0.2.28 | ||
langchain-openai==0.1.20 | ||
streamlit==1.36.0 | ||
langgraph==0.1.19 | ||
aiosqlite==0.20.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from langgraph.graph.message import AnyMessage, add_messages | ||
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver | ||
from langchain_core.runnables import RunnableConfig | ||
from langgraph.graph import END, StateGraph | ||
from typing_extensions import TypedDict | ||
from typing import Annotated, Literal, Dict | ||
from dotenv import load_dotenv | ||
import os | ||
|
||
from langchain_openai import ChatOpenAI | ||
from langchain_anthropic import ChatAnthropic | ||
from langchain_core.messages import ToolMessage | ||
|
||
from tools import available_functions | ||
|
||
load_dotenv() | ||
model = os.getenv('LLM_MODEL', 'gpt-4o') | ||
|
||
tools = [tool for _, tool in available_functions.items()] | ||
chatbot = ChatOpenAI(model=model, streaming=True) if "gpt" in model.lower() else ChatAnthropic(model=model, streaming=True) | ||
chatbot_with_tools = chatbot.bind_tools(tools) | ||
|
||
### State | ||
class GraphState(TypedDict): | ||
""" | ||
Represents the state of our graph. | ||
Attributes: | ||
messages: List of chat messages. | ||
""" | ||
messages: Annotated[list[AnyMessage], add_messages] | ||
|
||
async def call_model(state: GraphState, config: RunnableConfig) -> Dict[str, AnyMessage]: | ||
""" | ||
Function that calls the model to generate a response. | ||
Args: | ||
state (GraphState): The current graph state | ||
Returns: | ||
dict: The updated state with a new AI message | ||
""" | ||
print("---CALL MODEL---") | ||
messages = state["messages"] | ||
|
||
# Invoke the chatbot with the binded tools | ||
response = await chatbot_with_tools.ainvoke(messages, config) | ||
print("Response from model:", response) | ||
|
||
# We return an object because this will get added to the existing list | ||
return {"messages": response} | ||
|
||
def tool_node(state: GraphState) -> Dict[str, AnyMessage]: | ||
""" | ||
Function that handles all tool calls. | ||
Args: | ||
state (GraphState): The current graph state | ||
Returns: | ||
dict: The updated state with tool messages | ||
""" | ||
print("---TOOL NODE---") | ||
messages = state["messages"] | ||
last_message = messages[-1] if messages else None | ||
|
||
outputs = [] | ||
|
||
if last_message and last_message.tool_calls: | ||
for call in last_message.tool_calls: | ||
tool = available_functions.get(call['name'], None) | ||
|
||
if tool is None: | ||
raise Exception(f"Tool '{call['name']}' not found.") | ||
|
||
output = tool.invoke(call['args']) | ||
outputs.append(ToolMessage( | ||
output if isinstance(output, str) else json.dumps(output), | ||
tool_call_id=call['id'] | ||
)) | ||
|
||
return {'messages': outputs} | ||
|
||
def should_continue(state: GraphState) -> Literal["__end__", "tools"]: | ||
""" | ||
Determine whether to continue or end the workflow based on if there are tool calls to make. | ||
Args: | ||
state (GraphState): The current graph state | ||
Returns: | ||
str: The next node to execute or END | ||
""" | ||
print("---SHOULD CONTINUE---") | ||
messages = state["messages"] | ||
last_message = messages[-1] if messages else None | ||
|
||
# If there is no function call, then we finish | ||
if not last_message or not last_message.tool_calls: | ||
return END | ||
else: | ||
return "tools" | ||
|
||
def get_runnable(): | ||
workflow = StateGraph(GraphState) | ||
|
||
# Define the nodes and how they connect | ||
workflow.add_node("agent", call_model) | ||
workflow.add_node("tools", tool_node) | ||
|
||
workflow.set_entry_point("agent") | ||
|
||
workflow.add_conditional_edges( | ||
"agent", | ||
should_continue | ||
) | ||
workflow.add_edge("tools", "agent") | ||
|
||
# Compile the LangGraph graph into a runnable | ||
memory = AsyncSqliteSaver.from_conn_string(":memory:") | ||
app = workflow.compile(checkpointer=memory) | ||
|
||
return app |
Oops, something went wrong.