Skip to content

Commit

Permalink
mistral: read tool calls from AIMessage (#20554)
Browse files Browse the repository at this point in the history
Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
ccurme and eyurtsev authored Apr 17, 2024
1 parent f257909 commit 4a17951
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 11 deletions.
44 changes: 41 additions & 3 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import logging
import uuid
from operator import itemgetter
Expand Down Expand Up @@ -42,8 +43,10 @@
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
InvalidToolCall,
SystemMessage,
SystemMessageChunk,
ToolCall,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
Expand Down Expand Up @@ -223,6 +226,34 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)


def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
"""Format Langchain ToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
}
}
if _id := tool_call.get("id"):
result["id"] = _id

return result


def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) -> dict:
"""Format Langchain InvalidToolCall to dict expected by Mistral."""
result: Dict[str, Any] = {
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
}
}
if _id := invalid_tool_call.get("id"):
result["id"] = _id

return result


def _convert_message_to_mistral_chat_message(
message: BaseMessage,
) -> Dict:
Expand All @@ -231,8 +262,15 @@ def _convert_message_to_mistral_chat_message(
elif isinstance(message, HumanMessage):
return dict(role="user", content=message.content)
elif isinstance(message, AIMessage):
if "tool_calls" in message.additional_kwargs:
tool_calls = []
tool_calls = []
if message.tool_calls or message.invalid_tool_calls:
for tool_call in message.tool_calls:
tool_calls.append(_format_tool_call_for_mistral(tool_call))
for invalid_tool_call in message.invalid_tool_calls:
tool_calls.append(
_format_invalid_tool_call_for_mistral(invalid_tool_call)
)
elif "tool_calls" in message.additional_kwargs:
for tc in message.additional_kwargs["tool_calls"]:
chunk = {
"function": {
Expand All @@ -244,7 +282,7 @@ def _convert_message_to_mistral_chat_message(
chunk["id"] = _id
tool_calls.append(chunk)
else:
tool_calls = []
pass
return {
"role": "assistant",
"content": message.content,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_structured_output() -> None:


def test_streaming_structured_output() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)

class Person(BaseModel):
name: str
Expand All @@ -156,7 +156,7 @@ class Person(BaseModel):


def test_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)

class Person(BaseModel):
name: str
Expand All @@ -173,7 +173,7 @@ class Person(BaseModel):


def test_streaming_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)

class Person(BaseModel):
name: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ class TestMistralStandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatMistralAI

@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "mistral-large-latest",
"temperature": 0,
}
10 changes: 5 additions & 5 deletions libs/partners/mistralai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"arguments": '{"name": "Sally", "hair_color": "green"}',
"name": "GenerateUsername",
},
}
Expand All @@ -153,16 +153,16 @@ def test__convert_dict_to_message_tool_call() -> None:
# Test malformed tool call
raw_tool_calls = [
{
"id": "abc123",
"id": "def456",
"function": {
"arguments": "oops",
"arguments": '{"name": "Sally", "hair_color": "green"}',
"name": "GenerateUsername",
},
},
{
"id": "def456",
"id": "abc123",
"function": {
"arguments": '{"name":"Sally","hair_color":"green"}',
"arguments": "oops",
"name": "GenerateUsername",
},
},
Expand Down

0 comments on commit 4a17951

Please sign in to comment.