Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
Signed-off-by: Prithvi Kannan <[email protected]>
  • Loading branch information
prithvikannan committed Dec 10, 2024
1 parent 2c18892 commit 6163780
Showing 1 changed file with 6 additions and 17 deletions.
23 changes: 6 additions & 17 deletions integrations/langchain/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,19 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain_databricks.chat_models import ChatDatabricks
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, create_react_agent, tools_condition
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langchain_databricks.chat_models import ChatDatabricks

_TEST_ENDPOINT = "databricks-meta-llama-3-70b-instruct"


def test_chat_databricks_invoke():
chat = ChatDatabricks(
endpoint=_TEST_ENDPOINT, temperature=0, max_tokens=10, stop=["Java"]
)
chat = ChatDatabricks(endpoint=_TEST_ENDPOINT, temperature=0, max_tokens=10, stop=["Java"])

response = chat.invoke("How to learn Java? Start the response by 'To learn Java,'")
assert isinstance(response, AIMessage)
Expand All @@ -49,13 +46,9 @@ def test_chat_databricks_invoke():
assert response.response_metadata["completion_tokens"] == 3
assert response.response_metadata["total_tokens"] == 27

response = chat.invoke(
"How to learn Python? Start the response by 'To learn Python,'"
)
response = chat.invoke("How to learn Python? Start the response by 'To learn Python,'")
assert response.content.startswith("To learn Python,")
assert (
len(response.content.split(" ")) <= 15
) # Give some margin for tokenization difference
assert len(response.content.split(" ")) <= 15 # Give some margin for tokenization difference

# Call with a system message
response = chat.invoke(
Expand Down Expand Up @@ -156,9 +149,7 @@ async def test_chat_databricks_ainvoke():
max_tokens=10,
)

response = await chat.ainvoke(
"How to learn Python? Start the response by 'To learn Python,'"
)
response = await chat.ainvoke("How to learn Python? Start the response by 'To learn Python,'")
assert isinstance(response, AIMessage)
assert response.content.startswith("To learn Python,")

Expand Down Expand Up @@ -206,9 +197,7 @@ def test_chat_databricks_tool_calls(tool_choice):
class GetWeather(BaseModel):
"""Get the current weather in a given location"""

location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

llm_with_tools = chat.bind_tools([GetWeather], tool_choice=tool_choice)
question = "Which is the current weather in Los Angeles, CA?"
Expand Down

0 comments on commit 6163780

Please sign in to comment.