diff --git a/integrations/langchain/src/databricks_langchain/chat_models.py b/integrations/langchain/src/databricks_langchain/chat_models.py
index 2566387..85ed684 100644
--- a/integrations/langchain/src/databricks_langchain/chat_models.py
+++ b/integrations/langchain/src/databricks_langchain/chat_models.py
@@ -1,36 +1,43 @@
 ### langchain/chat_models.py ###
 
-from typing import Iterator, List, Dict, Any, Optional, Union
+from typing import Any, Dict, Iterator, List, Optional
+
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import (
-    AIMessage, AIMessageChunk, BaseMessage, ChatResult, ChatGeneration, ChatGenerationChunk
-)
-from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
-from langchain_core.output_parsers.base import OutputParserLike
-from langchain_core.output_parsers.openai_tools import (
-    JsonOutputKeyToolsParser, PydanticToolsParser, make_invalid_tool_call, parse_tool_call
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
 )
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
-from langchain_core.tools import BaseTool
-from langchain_core.utils.function_calling import convert_to_openai_tool
-from langchain_core.utils.pydantic import is_basemodel_subclass
-from databricks_langchain.utils import get_deployment_client
+from langchain_core.outputs import ChatGenerationChunk, ChatResult
+
 from .base_chat_models import BaseChatDatabricks
 
 
 class ChatDatabricks(BaseChatDatabricks, BaseChatModel):
-    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs: Any) -> ChatResult:
-        data = self._prepare_inputs([_convert_message_to_dict(msg) for msg in messages], stop, **kwargs)
+    def _generate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> ChatResult:
+        data = self._prepare_inputs(
+            [_convert_message_to_dict(msg) for msg in messages], stop, **kwargs
+        )
         resp = self.client.predict(endpoint=self.endpoint, inputs=data)
         return self._convert_response_to_chat_result(resp)
 
     def _stream(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, **kwargs: Any
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
+        **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         if stream_usage is None:
             stream_usage = self.stream_usage
-        data = self._prepare_inputs([_convert_message_to_dict(msg) for msg in messages], stop, **kwargs)
+        data = self._prepare_inputs(
+            [_convert_message_to_dict(msg) for msg in messages], stop, **kwargs
+        )
         first_chunk_role = None
         for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data):
             if chunk["choices"]:
@@ -40,7 +47,9 @@ def _stream(
                     first_chunk_role = chunk_delta.get("role")
 
                 usage = chunk.get("usage") if stream_usage else None
-                chunk_message = _convert_dict_to_message_chunk(chunk_delta, first_chunk_role, usage=usage)
+                chunk_message = _convert_dict_to_message_chunk(
+                    chunk_delta, first_chunk_role, usage=usage
+                )
                 generation_info = {
                     "finish_reason": choice.get("finish_reason", ""),
                     "logprobs": choice.get("logprobs", {}),
@@ -57,6 +66,7 @@ def _llm_type(self) -> str:
 
 ### Conversion Functions ###
 
+
 def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
     message_dict = {"content": message.content}
     if isinstance(message, AIMessage):
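### Reviewer note: usage sketch (not part of the patch) ###

A minimal sketch of the two code paths reformatted above, assuming the public
`databricks_langchain` package layout. The endpoint name is illustrative, and
passing `stream_usage=True` through `.stream()` relies on LangChain forwarding
extra kwargs to `_stream`, where the new keyword-only parameter picks it up.

from databricks_langchain import ChatDatabricks

llm = ChatDatabricks(
    endpoint="databricks-meta-llama-3-1-70b-instruct",  # hypothetical serving endpoint
    temperature=0.1,
)

# _generate path: a single client.predict() round trip against the endpoint.
print(llm.invoke("What is MLflow?").content)

# _stream path: chunks arrive as AIMessageChunk; with stream_usage=True the
# final chunk also carries token usage (see the `usage` handling in _stream).
for chunk in llm.stream("What is MLflow?", stream_usage=True):
    print(chunk.content, end="", flush=True)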
diff --git a/src/databricks_ai_bridge/chat_models.py b/src/databricks_ai_bridge/chat_models.py
index c2e388f..4ae9348 100644
--- a/src/databricks_ai_bridge/chat_models.py
+++ b/src/databricks_ai_bridge/chat_models.py
@@ -1,6 +1,7 @@
 ### base_chat_models.py ###
 
-from typing import List, Dict, Any, Optional, Union
+from typing import Any, Dict, List, Optional
+
 from databricks_ai_bridge.utils import get_deployment_client
 
 
@@ -38,7 +39,9 @@ def _default_params(self) -> Dict[str, Any]:
             "extra_params": self.extra_params,
         }
 
-    def _prepare_inputs(self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any) -> Dict[str, Any]:
+    def _prepare_inputs(
+        self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Dict[str, Any]:
         data = {
             "messages": messages,
             "temperature": self.temperature,
@@ -63,7 +66,9 @@ def _convert_response_to_chat_result(self, response: Dict[str, Any]) -> Dict[str
         usage = response.get("usage", {})
         return {"generations": generations, "llm_output": usage}
 
-    def _stream(self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any):
+    def _stream(
+        self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any
+    ):
         data = self._prepare_inputs(messages, stop, **kwargs)
         for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data):
             if chunk["choices"]:
diff --git a/src/databricks_ai_bridge/utils.py b/src/databricks_ai_bridge/utils.py
index 9df05ea..b465506 100644
--- a/src/databricks_ai_bridge/utils.py
+++ b/src/databricks_ai_bridge/utils.py
@@ -1,6 +1,7 @@
 from typing import Any
 from urllib.parse import urlparse
 
+
 def get_deployment_client(target_uri: str) -> Any:
     if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
         raise ValueError("Invalid target URI. The target URI must be a valid databricks URI.")
@@ -14,4 +15,4 @@
         "Failed to create the client. "
         "Please run `pip install mlflow` to install "
         "required dependencies."
-    ) from e
\ No newline at end of file
+    ) from e
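### Reviewer note: deployment client sketch (not part of the patch) ###

A minimal sketch of how the chat models above use the client returned by
`get_deployment_client`, assuming `mlflow` is installed and the endpoint
returns an OpenAI-style response; the endpoint name and payload are
illustrative.

from databricks_ai_bridge.utils import get_deployment_client

# "databricks" and URIs with a databricks:// scheme pass validation;
# anything else raises ValueError before a client is created.
client = get_deployment_client("databricks")

# The chat models call predict()/predict_stream() with this payload shape.
resp = client.predict(
    endpoint="databricks-meta-llama-3-1-70b-instruct",  # hypothetical endpoint
    inputs={"messages": [{"role": "user", "content": "Hello"}]},
)
print(resp["choices"][0]["message"]["content"])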