Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mistral: read tool calls from AIMessage #20554

Merged
merged 10 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import logging
import uuid
from operator import itemgetter
Expand Down Expand Up @@ -42,8 +43,10 @@
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
Expand Down Expand Up @@ -223,6 +226,34 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)


def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
    """Convert a LangChain ``ToolCall`` into the dict structure Mistral expects.

    The tool-call ``args`` (a parsed dict) are serialized back to a JSON
    string, since Mistral's API takes ``function.arguments`` as a string.
    The ``id`` key is emitted only when the tool call carries one.
    """
    formatted: Dict[str, Any] = {
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        }
    }
    tool_call_id = tool_call.get("id")
    if tool_call_id:
        formatted["id"] = tool_call_id

    return formatted


def _format_invalid_tool_call_for_mistral(tool_call: InvalidToolCall) -> dict:
    """Format a LangChain ``InvalidToolCall`` to the dict expected by Mistral.

    Unlike a valid ``ToolCall`` (see ``_format_tool_call_for_mistral``), the
    ``args`` of an invalid tool call are a raw string that failed to parse,
    so they are passed through verbatim rather than re-serialized with
    ``json.dumps``.
    """
    result: Dict[str, Any] = {
        "function": {
            "name": tool_call["name"],
            "arguments": tool_call["args"],
        }
    }
    # Include the id only when one is present.
    if _id := tool_call.get("id"):
        result["id"] = _id

    return result


def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> Dict:
Expand All @@ -231,8 +262,8 @@ def _convert_message_to_mistral_chat_message(
elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
tool_calls = []
if "tool_calls" in message.additional_kwargs:
tool_calls = []
for tc in message.additional_kwargs["tool_calls"]:
chunk = {
"function": {
Expand All @@ -243,8 +274,15 @@ def _convert_message_to_mistral_chat_message(
if _id := tc.get("id"):
chunk["id"] = _id
tool_calls.append(chunk)
elif message.tool_calls or message.invalid_tool_calls:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be prioritized over additional_kwargs? The latter should be the source of truth if swapping providers etc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm open to either. I actually implemented it that way first, but it introduces small issues with casting to/from Mistral dicts, since the additional_kwargs contain (1) information on order of tool calls (we lose that with the separation into valid and invalid tool call lists), and (2) JSON formatting, since we parse and then dumps the args. Our unit tests maintain that we can go back/forth so I wasn't sure if we wanted to lose that. These are small issues; if we think there's a compensating benefit to prioritizing .tool_calls then happy to change it.

Copy link
Collaborator Author

@ccurme ccurme Apr 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yeah if we have a provider that (1) formats tool_calls into additional_kwargs, and (2) does not use OpenAI-style formatting then this will break. updated this. see changes to unit tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@efriis see above - is it OK to lose the behavior maintained in unit tests, in which applying _convert_message_to_mistral_chat_message and _convert_mistral_chat_message_to_message returns the original message identically?

for tool_call in message.tool_calls:
tool_calls.append(_format_tool_call_for_mistral(tool_call))
for invalid_tool_call in message.invalid_tool_calls:
tool_calls.append(
_format_invalid_tool_call_for_mistral(invalid_tool_call)
)
else:
tool_calls = []
pass
return {
"role": "assistant",
"content": message.content,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_structured_output() -> None:


def test_streaming_structured_output() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)

class Person(BaseModel):
name: str
Expand All @@ -156,7 +156,7 @@ class Person(BaseModel):


def test_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)

class Person(BaseModel):
name: str
Expand All @@ -173,7 +173,7 @@ class Person(BaseModel):


def test_streaming_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)

class Person(BaseModel):
name: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ class TestMistralStandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
    """Return the chat model class exercised by the standard integration tests."""
    return ChatMistralAI

@pytest.fixture
def chat_model_params(self) -> dict:
    """Constructor kwargs for the model under test.

    Pins a concrete model alias and temperature 0 so the standard
    integration tests run deterministically.
    """
    return {
        "model": "mistral-large-latest",
        "temperature": 0,
    }
Loading