From 6163780b3e69922aaf6456f282fd21101ed66566 Mon Sep 17 00:00:00 2001 From: Prithvi Kannan Date: Tue, 10 Dec 2024 11:41:17 -0800 Subject: [PATCH] format Signed-off-by: Prithvi Kannan --- .../integration_tests/test_chat_models.py | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/integrations/langchain/tests/integration_tests/test_chat_models.py b/integrations/langchain/tests/integration_tests/test_chat_models.py index d2fe498..a81f5dc 100644 --- a/integrations/langchain/tests/integration_tests/test_chat_models.py +++ b/integrations/langchain/tests/integration_tests/test_chat_models.py @@ -25,6 +25,7 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import tool +from langchain_databricks.chat_models import ChatDatabricks from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import START, StateGraph from langgraph.graph.message import add_messages @@ -32,15 +33,11 @@ from pydantic import BaseModel, Field from typing_extensions import TypedDict -from langchain_databricks.chat_models import ChatDatabricks - _TEST_ENDPOINT = "databricks-meta-llama-3-70b-instruct" def test_chat_databricks_invoke(): - chat = ChatDatabricks( - endpoint=_TEST_ENDPOINT, temperature=0, max_tokens=10, stop=["Java"] - ) + chat = ChatDatabricks(endpoint=_TEST_ENDPOINT, temperature=0, max_tokens=10, stop=["Java"]) response = chat.invoke("How to learn Java? Start the response by 'To learn Java,'") assert isinstance(response, AIMessage) @@ -49,13 +46,9 @@ def test_chat_databricks_invoke(): assert response.response_metadata["completion_tokens"] == 3 assert response.response_metadata["total_tokens"] == 27 - response = chat.invoke( - "How to learn Python? Start the response by 'To learn Python,'" - ) + response = chat.invoke("How to learn Python? Start the response by 'To learn Python,'") assert response.content.startswith("To learn Python,") - assert ( - len(response.content.split(" ")) <= 15 - ) # Give some margin for tokenization difference + assert len(response.content.split(" ")) <= 15 # Give some margin for tokenization difference # Call with a system message response = chat.invoke( @@ -156,9 +149,7 @@ async def test_chat_databricks_ainvoke(): max_tokens=10, ) - response = await chat.ainvoke( - "How to learn Python? Start the response by 'To learn Python,'" - ) + response = await chat.ainvoke("How to learn Python? Start the response by 'To learn Python,'") assert isinstance(response, AIMessage) assert response.content.startswith("To learn Python,") @@ -206,9 +197,7 @@ def test_chat_databricks_tool_calls(tool_choice): class GetWeather(BaseModel): """Get the current weather in a given location""" - location: str = Field( - ..., description="The city and state, e.g. San Francisco, CA" - ) + location: str = Field(..., description="The city and state, e.g. San Francisco, CA") llm_with_tools = chat.bind_tools([GetWeather], tool_choice=tool_choice) question = "Which is the current weather in Los Angeles, CA?"