Skip to content

Commit

Permalink
Schema Update
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Sep 13, 2024
1 parent 5a82f8a commit 23408cc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/react_agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from datetime import datetime, timezone
from typing import Dict, List, Literal, cast

from langchain.chat_models import init_chat_model
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
Expand All @@ -16,6 +15,7 @@
from react_agent.configuration import Configuration
from react_agent.state import InputState, State
from react_agent.tools import TOOLS
from react_agent.utils import load_chat_model

# Define the function that calls the model

Expand All @@ -42,7 +42,7 @@ async def call_model(
)

# Initialize the model with tool binding. Change the model or add more tools here.
model = init_chat_model(configuration.model_name).bind_tools(TOOLS)
model = load_chat_model(configuration.model_name).bind_tools(TOOLS)

# Prepare the input for the model, including the current system time
message_value = await prompt.ainvoke(
Expand Down
12 changes: 12 additions & 0 deletions src/react_agent/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility & helper functions."""

from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage


Expand All @@ -13,3 +15,13 @@ def get_message_text(msg: BaseMessage) -> str:
else:
txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
return "".join(txts).strip()


def load_chat_model(fully_specified_name: str) -> BaseChatModel:
"""Load a chat model from a fully specified name.
Args:
fully_specified_name (str): String in the format 'provider/model'.
"""
provider, model = fully_specified_name.split("/", maxsplit=1)
return init_chat_model(model, model_provider=provider)

0 comments on commit 23408cc

Please sign in to comment.