
Commit

linter
Signed-off-by: Prithvi Kannan <[email protected]>
prithvikannan committed Dec 12, 2024
1 parent f18088f commit e226ae8
Showing 3 changed files with 38 additions and 21 deletions.
45 changes: 28 additions & 17 deletions integrations/langchain/src/databricks_langchain/chat_models.py
@@ -1,36 +1,44 @@
 ### langchain/chat_models.py ###

-from typing import Iterator, List, Dict, Any, Optional, Union
+from typing import Any, Dict, Iterator, List, Optional
+
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import (
-    AIMessage, AIMessageChunk, BaseMessage, ChatResult, ChatGeneration, ChatGenerationChunk
-)
-from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
-from langchain_core.output_parsers.base import OutputParserLike
-from langchain_core.output_parsers.openai_tools import (
-    JsonOutputKeyToolsParser, PydanticToolsParser, make_invalid_tool_call, parse_tool_call
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    ChatGenerationChunk,
+    ChatResult,
 )
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
-from langchain_core.tools import BaseTool
-from langchain_core.utils.function_calling import convert_to_openai_tool
-from langchain_core.utils.pydantic import is_basemodel_subclass
-from databricks_langchain.utils import get_deployment_client

 from .base_chat_models import BaseChatDatabricks

+
 class ChatDatabricks(BaseChatDatabricks, BaseChatModel):
-    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs: Any) -> ChatResult:
-        data = self._prepare_inputs([_convert_message_to_dict(msg) for msg in messages], stop, **kwargs)
+    def _generate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> ChatResult:
+        data = self._prepare_inputs(
+            [_convert_message_to_dict(msg) for msg in messages], stop, **kwargs
+        )
         resp = self.client.predict(endpoint=self.endpoint, inputs=data)
         return self._convert_response_to_chat_result(resp)

     def _stream(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, *, stream_usage: Optional[bool] = None, **kwargs: Any
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
+        **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         if stream_usage is None:
             stream_usage = self.stream_usage
-        data = self._prepare_inputs([_convert_message_to_dict(msg) for msg in messages], stop, **kwargs)
+        data = self._prepare_inputs(
+            [_convert_message_to_dict(msg) for msg in messages], stop, **kwargs
+        )
         first_chunk_role = None
         for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data):
             if chunk["choices"]:
@@ -40,7 +48,9 @@ def _stream(
                     first_chunk_role = chunk_delta.get("role")

                 usage = chunk.get("usage") if stream_usage else None
-                chunk_message = _convert_dict_to_message_chunk(chunk_delta, first_chunk_role, usage=usage)
+                chunk_message = _convert_dict_to_message_chunk(
+                    chunk_delta, first_chunk_role, usage=usage
+                )
                 generation_info = {
                     "finish_reason": choice.get("finish_reason", ""),
                     "logprobs": choice.get("logprobs", {}),
@@ -57,6 +67,7 @@ def _llm_type(self) -> str:

 ### Conversion Functions ###

+
 def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
     message_dict = {"content": message.content}
     if isinstance(message, AIMessage):
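Note: the hunks above are formatting-only; ChatDatabricks behaves the same. For orientation, a minimal sketch of how the class is exercised (not part of this commit; the endpoint name is a placeholder):

from databricks_langchain import ChatDatabricks

llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")  # placeholder endpoint

# One-shot completion goes through _generate -> client.predict
print(llm.invoke("What is MLflow?").content)

# Token-by-token streaming goes through _stream -> client.predict_stream
for chunk in llm.stream("What is MLflow?"):
    print(chunk.content, end="")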
11 changes: 8 additions & 3 deletions src/databricks_ai_bridge/chat_models.py
@@ -1,6 +1,7 @@
 ### base_chat_models.py ###

-from typing import List, Dict, Any, Optional, Union
+from typing import Any, Dict, List, Optional

 from databricks_ai_bridge.utils import get_deployment_client

+
@@ -38,7 +39,9 @@ def _default_params(self) -> Dict[str, Any]:
"extra_params": self.extra_params,
}

def _prepare_inputs(self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any) -> Dict[str, Any]:
def _prepare_inputs(
self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
data = {
"messages": messages,
"temperature": self.temperature,
@@ -63,7 +66,9 @@ def _convert_response_to_chat_result(self, response: Dict[str, Any]) -> Dict[str
         usage = response.get("usage", {})
         return {"generations": generations, "llm_output": usage}

-    def _stream(self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any):
+    def _stream(
+        self, messages: List[Dict[str, Any]], stop: Optional[List[str]] = None, **kwargs: Any
+    ):
         data = self._prepare_inputs(messages, stop, **kwargs)
         for chunk in self.client.predict_stream(endpoint=self.endpoint, inputs=data):
             if chunk["choices"]:
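For reference, a sketch of the request payload _prepare_inputs assembles for client.predict / client.predict_stream. Only the "messages" and "temperature" keys are visible in this diff; the stop and extra_params handling below is an assumption inferred from _default_params:

data = {
    "messages": [{"role": "user", "content": "Hello"}],  # already dict-form messages
    "temperature": 0.0,  # self.temperature
}
# Assumed, not shown in the diff: stop sequences and extra endpoint
# parameters are merged into the payload when present.
# if stop:
#     data["stop"] = stop
# data.update(self.extra_params or {})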
3 changes: 2 additions & 1 deletion src/databricks_ai_bridge/utils.py
@@ -1,6 +1,7 @@
 from typing import Any
 from urllib.parse import urlparse

+
 def get_deployment_client(target_uri: str) -> Any:
     if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
         raise ValueError("Invalid target URI. The target URI must be a valid databricks URI.")
@@ -14,4 +15,4 @@ def get_deployment_client(target_uri: str) -> Any:
"Failed to create the client. "
"Please run `pip install mlflow` to install "
"required dependencies."
) from e
) from e

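Judging by the `pip install mlflow` hint, get_deployment_client returns an MLflow deployments client. A usage sketch (not part of this commit; the endpoint name is a placeholder):

from databricks_ai_bridge.utils import get_deployment_client

# Accepts "databricks" or a databricks://-scheme URI; anything else raises ValueError.
client = get_deployment_client("databricks")
resp = client.predict(
    endpoint="my-chat-endpoint",  # placeholder serving endpoint name
    inputs={"messages": [{"role": "user", "content": "Hi"}]},
)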