Fixing multi-part response parsing (langchain-ai#165)
* fixed LC->Content conversion

* fixed candidate parsing
alx13 authored Apr 18, 2024
1 parent cf77485 commit fc8e92d
Showing 4 changed files with 513 additions and 117 deletions.
141 changes: 81 additions & 60 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -199,6 +199,10 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
         elif isinstance(message, AIMessage):
             raw_function_call = message.additional_kwargs.get("function_call")
             role = "model"
+
+            parts = []
+            if message.content:
+                parts = _convert_to_parts(message)
             if raw_function_call:
                 function_call = FunctionCall(
                     {
@@ -207,9 +211,7 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                     }
                 )
                 gapic_part = GapicPart(function_call=function_call)
-                parts = [Part._from_gapic(gapic_part)]
-            else:
-                parts = _convert_to_parts(message)
+                parts.append(Part._from_gapic(gapic_part))
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message)
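
These hunks fix the LC -> Content conversion: previously, an AIMessage carrying both text content and a function_call in additional_kwargs produced only the function-call Part, silently dropping the text. Now text parts are converted first and the function-call Part is appended rather than replacing the list. A minimal sketch of the fixed control flow, using plain dicts in place of the Vertex AI Part and FunctionCall types (the message values here are hypothetical):

from typing import Any, Dict, List

def convert_ai_message(
    content: str, raw_function_call: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Mirrors the fixed branch: text parts first, then the function call."""
    parts: List[Dict[str, Any]] = []
    if content:  # new behavior: text content is no longer discarded
        parts.append({"text": content})
    if raw_function_call:
        # appended to the existing list instead of overwriting it
        parts.append({"function_call": raw_function_call})
    return parts

parts = convert_ai_message(
    "Let me look that up.",
    {"name": "search", "arguments": "{\"question\": \"sparrow color\"}"},
)
assert len(parts) == 2  # both the text and the tool call survive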
@@ -316,70 +318,89 @@ def _get_client_with_sys_instruction(
 def _parse_response_candidate(
     response_candidate: "Candidate", streaming: bool = False
 ) -> AIMessage:
-    try:
-        content = response_candidate.text
-    except AttributeError:
-        content = ""
-
+    content: Union[None, str, List[str]] = None
     additional_kwargs = {}
-    first_part = response_candidate.content.parts[0]
-    if first_part.function_call:
-        function_call = {"name": first_part.function_call.name}
-        # dump to match other function calling llm for now
-        function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
-            "args"
-        ]
-        function_call["arguments"] = json.dumps(
-            {k: function_call_args_dict[k] for k in function_call_args_dict}
-        )
-        additional_kwargs["function_call"] = function_call
-        if streaming:
-            tool_call_chunks = [
-                ToolCallChunk(
-                    name=function_call.get("name"),
-                    args=function_call.get("arguments"),
-                    id=function_call.get("id", str(uuid.uuid4())),
-                    index=function_call.get("index"),
-                )
-            ]
-            return AIMessageChunk(
-                content=content,
-                additional_kwargs=additional_kwargs,
-                tool_call_chunks=tool_call_chunks,
-            )
-        else:
-            tool_calls = []
-            invalid_tool_calls = []
-            try:
-                tool_calls_dicts = parse_tool_calls(
-                    [{"function": function_call}],
-                    return_id=False,
-                )
-                tool_calls = [
-                    ToolCall(
-                        name=tool_call["name"],
-                        args=tool_call["args"],
-                        id=tool_call.get("id", str(uuid.uuid4())),
-                    )
-                    for tool_call in tool_calls_dicts
-                ]
-            except Exception as e:
-                invalid_tool_calls = [
-                    InvalidToolCall(
-                        name=function_call.get("name"),
-                        args=function_call.get("arguments"),
-                        id=function_call.get("id", str(uuid.uuid4())),
-                        error=str(e),
-                    )
-                ]
-            return AIMessage(
-                content=content,
-                additional_kwargs=additional_kwargs,
-                tool_calls=tool_calls,
-                invalid_tool_calls=invalid_tool_calls,
-            )
-    return AIMessage(content=content, additional_kwargs=additional_kwargs)
+    tool_calls = []
+    invalid_tool_calls = []
+    tool_call_chunks = []
+
+    for part in response_candidate.content.parts:
+        text = None
+        try:
+            text = part.text
+        except AttributeError:
+            text = None
+
+        if text is not None:
+            if content is None:
+                content = text
+            elif isinstance(content, str):
+                content = [content, text]
+            elif isinstance(content, list):
+                content.append(text)
+            else:
+                raise Exception("Unexpected content type")
+
+        if part.function_call:
+            # TODO: support multiple function calls
+            if "function_call" in additional_kwargs:
+                raise Exception("Multiple function calls are not currently supported")
+            function_call = {"name": part.function_call.name}
+            # dump to match other function calling llm for now
+            function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
+            function_call["arguments"] = json.dumps(
+                {k: function_call_args_dict[k] for k in function_call_args_dict}
+            )
+            additional_kwargs["function_call"] = function_call
+
+            if streaming:
+                tool_call_chunks.append(
+                    ToolCallChunk(
+                        name=function_call.get("name"),
+                        args=function_call.get("arguments"),
+                        id=function_call.get("id", str(uuid.uuid4())),
+                        index=function_call.get("index"),
+                    )
+                )
+            else:
+                try:
+                    tool_calls_dicts = parse_tool_calls(
+                        [{"function": function_call}],
+                        return_id=False,
+                    )
+                    tool_calls = [
+                        ToolCall(
+                            name=tool_call["name"],
+                            args=tool_call["args"],
+                            id=tool_call.get("id", str(uuid.uuid4())),
+                        )
+                        for tool_call in tool_calls_dicts
+                    ]
+                except Exception as e:
+                    invalid_tool_calls = [
+                        InvalidToolCall(
+                            name=function_call.get("name"),
+                            args=function_call.get("arguments"),
+                            id=function_call.get("id", str(uuid.uuid4())),
+                            error=str(e),
+                        )
+                    ]
+    if content is None:
+        content = ""
+
+    if streaming:
+        return AIMessageChunk(
+            content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
+            additional_kwargs=additional_kwargs,
+            tool_call_chunks=tool_call_chunks,
+        )
+
+    return AIMessage(
+        content=cast(Union[str, List[Union[str, Dict[Any, Any]]]], content),
+        additional_kwargs=additional_kwargs,
+        tool_calls=tool_calls,
+        invalid_tool_calls=invalid_tool_calls,
+    )
 
 
 class ChatVertexAI(_VertexAICommon, BaseChatModel):
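
The rewritten _parse_response_candidate iterates over every part of the candidate instead of reading only parts[0], so a response that mixes several text parts and a function call is no longer truncated. Text accumulates from None to a str (one part) to a List[str] (several parts). A standalone sketch of just that merging rule, with hypothetical inputs:

from typing import List, Optional, Union

def accumulate_text(texts: List[Optional[str]]) -> Union[str, List[str]]:
    """Same merge rule as the new parser: None -> str -> List[str]."""
    content: Union[None, str, List[str]] = None
    for text in texts:
        if text is None:
            continue
        if content is None:
            content = text  # first text part stays a plain string
        elif isinstance(content, str):
            content = [content, text]  # second part promotes it to a list
        else:
            content.append(text)  # later parts keep appending
    return "" if content is None else content

assert accumulate_text(["Hello"]) == "Hello"
assert accumulate_text(["Hello", " world"]) == ["Hello", " world"]
assert accumulate_text([None]) == ""  # function-call-only candidates yield ""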
19 changes: 10 additions & 9 deletions libs/vertexai/poetry.lock

Some generated files are not rendered by default.

55 changes: 55 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -10,6 +10,7 @@
     BaseMessage,
     HumanMessage,
     SystemMessage,
+    ToolMessage,
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.pydantic_v1 import BaseModel
@@ -433,3 +434,57 @@ class MyModel(BaseModel):
         "name": "Erick",
         "age": 27,
     }
+
+
+# Can be flaky
+@pytest.mark.release
+def test_chat_vertexai_gemini_function_calling_with_multiple_parts() -> None:
+    @tool
+    def search(
+        question: str,
+    ):
+        """
+        Useful for when you need to answer questions or visit websites.
+        You should ask targeted questions.
+        """
+        return "brown"
+
+    tools = [search]
+
+    safety = {
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
+    }
+    llm = ChatVertexAI(model_name="gemini-1.5-pro-preview-0409", safety_settings=safety)
+    llm_with_search = llm.bind(
+        functions=tools,
+    )
+    llm_with_search_force = llm_with_search.bind(
+        tool_config={
+            "function_calling_config": {
+                "mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
+                "allowed_function_names": ["search"],
+            }
+        },
+    )
+    request = HumanMessage(
+        content="Please tell the primary color of following birds: sparrow, hawk, crow",
+    )
+    response = llm_with_search_force.invoke([request])
+
+    assert isinstance(response, AIMessage)
+    assert len(response.tool_calls) > 0
+    tool_call = response.tool_calls[0]
+    assert tool_call["name"] == "search"
+
+    tool_response = search("sparrow")
+    tool_message = ToolMessage(
+        name="search",
+        content=json.dumps(tool_response),
+        tool_call_id="0",
+    )
+
+    result = llm_with_search.invoke([request, response, tool_message])
+
+    assert isinstance(result, AIMessage)
+    assert "brown" in result.content
+    assert len(result.tool_calls) > 0
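
One consequence the new test touches only lightly: since candidates can now carry several parts, result.content may come back as either a str or a list of strings, and the check assert "brown" in result.content behaves differently for the two shapes (substring match for a str, exact-element match for a list). Callers that want one display string regardless of shape can flatten it first; a hedged helper sketch (flatten_content is our name, not a library API):

from typing import Any, List, Union

def flatten_content(content: Union[str, List[Any]]) -> str:
    """Join possibly multi-part message content into one string."""
    if isinstance(content, list):
        return "".join(str(block) for block in content)
    return content

assert "brown" in flatten_content("the sparrow is brown")
assert "brown" in flatten_content(["the sparrow is ", "brown"])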