Skip to content

Commit

Permalink
google-vertexai[patch]: fix _parse_response_candidate issue (#16647)
Browse files Browse the repository at this point in the history
**Description:** Enable `_parse_response_candidate` to handle function-call
arguments with complex (nested) structures.
  **Issue:** 
Currently, if Gemini responds with function-call args in a complex (nested)
format, `_parse_response_candidate` fails with the error
"TypeError: Object of type RepeatedComposite is not JSON serializable".
  
 Example response candidate:
```
content {
  role: "model"
  parts {
    function_call {
      name: "Information"
      args {
        fields {
          key: "people"
          value {
            list_value {
              values {
                string_value: "Joe is 30, his mom is Martha"
              }
            }
          }
        }
      }
    }
  }
}
finish_reason: STOP
safety_ratings {
  category: HARM_CATEGORY_HARASSMENT
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_HATE_SPEECH
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_SEXUALLY_EXPLICIT
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_DANGEROUS_CONTENT
  probability: NEGLIGIBLE
}
```
 
Error message:
```
Traceback (most recent call last):
  File "/home/jupyter/user/abehsu/gemini_langchain_tools/example2.py", line 36, in <module>
    print(tagging_chain.invoke({"input": "Joe is 30, his mom is Martha"}))
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 2053, in invoke
    input = step.invoke(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3887, in invoke
    return self.bound.invoke(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 165, in invoke
    self.generate_prompt(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 543, in generate_prompt
    return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 407, in generate
    raise e
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 397, in generate
    self._generate_with_cache(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 576, in _generate_with_cache
    return self._generate(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 406, in _generate
    generations = [
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 408, in <listcomp>
    message=_parse_response_candidate(c),
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 280, in _parse_response_candidate
    function_call["arguments"] = json.dumps(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/__init__.py", line 231, in dumps
    return _default_encoder.encode(obj)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type RepeatedComposite is not JSON serializable
```
  

  **Twitter handle:**  @abehsu1992626
  • Loading branch information
hsuyuming authored Feb 8, 2024
1 parent d77bb7b commit e22c4d4
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse

import proto # type: ignore[import-untyped]
import requests
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
Expand Down Expand Up @@ -278,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
first_part = response_candidate.content.parts[0]
if first_part.function_call:
function_call = {"name": first_part.function_call.name}

# dump to match other function calling llm for now
function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
"args"
]
function_call["arguments"] = json.dumps(
{k: first_part.function_call.args[k] for k in first_part.function_call.args}
{k: function_call_args_dict[k] for k in function_call_args_dict}
)
additional_kwargs["function_call"] = function_call
return AIMessage(content=content, additional_kwargs=additional_kwargs)
Expand Down
114 changes: 114 additions & 0 deletions libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
"""Test chat model integration."""

import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch

import pytest
from google.cloud.aiplatform_v1beta1.types import (
Content,
FunctionCall,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
)
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
from vertexai.preview.generative_models import ( # type: ignore
Candidate,
)

from langchain_google_vertexai.chat_models import (
ChatVertexAI,
_parse_chat_history,
_parse_chat_history_gemini,
_parse_examples,
_parse_response_candidate,
)


Expand Down Expand Up @@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
message = HumanMessage(content=user_prompt)
_ = model([message])
mock_start_chat.assert_called_once_with(history=[])


@pytest.mark.parametrize(
    "raw_candidate, expected",
    [
        # Flat args: a single string value.
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={"name": "Ben"},
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {"name": "Ben"},
            },
        ),
        # A list of scalars.
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={"info": ["A", "B", "C"]},
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {"info": ["A", "B", "C"]},
            },
        ),
        # A list of nested objects.
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={
                                    "people": [
                                        {"name": "Joe", "age": 30},
                                        {"name": "Martha"},
                                    ]
                                },
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {
                    "people": [
                        {"name": "Joe", "age": 30},
                        {"name": "Martha"},
                    ]
                },
            },
        ),
        # A list of lists.
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={"info": [[1, 2, 3], [4, 5, 6]]},
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
            },
        ),
    ],
)
def test_parse_response_candidate(
    raw_candidate: gapic_content_types.Candidate, expected: Dict[str, Any]
) -> None:
    """Check that function-call candidates are parsed into JSON-safe kwargs.

    Each case wraps a ``FunctionCall`` in a GAPIC ``Candidate``, converts it
    with ``Candidate._from_gapic`` and passes it to
    ``_parse_response_candidate``. The resulting message must expose the
    function name and a JSON string of arguments that decodes to the
    original (possibly nested) structure.

    Args:
        raw_candidate: GAPIC candidate whose first part is a function call.
        expected: Dict with the expected ``"name"`` and decoded
            ``"arguments"`` of the parsed function call.
    """
    response_candidate = Candidate._from_gapic(raw_candidate)
    result = _parse_response_candidate(response_candidate)
    function_call = result.additional_kwargs["function_call"]

    # The function name was previously listed in ``expected`` but never checked.
    assert function_call["name"] == expected["name"]

    # "arguments" is a JSON string; decoding it proves it was serializable.
    # NOTE(review): numeric args presumably round-trip as floats (e.g. 30.0);
    # Python's ``==`` still treats those as equal to the int literals above.
    result_arguments = json.loads(function_call["arguments"])
    assert result_arguments == expected["arguments"]

0 comments on commit e22c4d4

Please sign in to comment.