diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index 1c62d76c..096f1800 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -9,6 +9,7 @@
 from typing import Any, Dict, Iterator, List, Optional, Union, cast
 from urllib.parse import urlparse
 
+import proto  # type: ignore[import-untyped]
 import requests
 from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
 from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
@@ -45,6 +46,12 @@
     Image,
     Part,
 )
+from vertexai.preview.language_models import (  # type: ignore
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
 
 from langchain_google_vertexai._utils import (
     get_generation_info,
@@ -272,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = {"name": first_part.function_call.name}
-        # dump to match other function calling llm for now
+        function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
+            "args"
+        ]
         function_call["arguments"] = json.dumps(
-            {k: first_part.function_call.args[k] for k in first_part.function_call.args}
+            {k: function_call_args_dict[k] for k in function_call_args_dict}
         )
         additional_kwargs["function_call"] = function_call
     return AIMessage(content=content, additional_kwargs=additional_kwargs)
 
@@ -316,12 +325,20 @@ def validate_environment(cls, values: Dict) -> Dict:
             values["client"] = GenerativeModel(
                 model_name=values["model_name"], safety_settings=safety_settings
             )
+            values["client_preview"] = GenerativeModel(
+                model_name=values["model_name"], safety_settings=safety_settings
+            )
         else:
             if is_codey_model(values["model_name"]):
                 model_cls = CodeChatModel
+                model_cls_preview = PreviewCodeChatModel
             else:
                 model_cls = ChatModel
+                model_cls_preview = PreviewChatModel
             values["client"] = model_cls.from_pretrained(values["model_name"])
+            values["client_preview"] = model_cls_preview.from_pretrained(
+                values["model_name"]
+            )
         return values
 
     def _generate(
@@ -493,8 +510,13 @@ def _stream(
             # set param to `functions` until core tool/function calling implemented
             raw_tools = params.pop("functions") if "functions" in params else None
             tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
+            safety_settings = params.pop("safety_settings", None)
             responses = chat.send_message(
-                message, stream=True, generation_config=params, tools=tools
+                message,
+                stream=True,
+                generation_config=params,
+                safety_settings=safety_settings,
+                tools=tools,
             )
             for response in responses:
                 message = _parse_response_candidate(response.candidates[0])
diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py
index bd4d346a..9b3099b4 100644
--- a/libs/vertexai/langchain_google_vertexai/llms.py
+++ b/libs/vertexai/langchain_google_vertexai/llms.py
@@ -31,6 +31,12 @@
     Image,
 )
 from vertexai.preview.language_models import (  # type: ignore[import-untyped]
+    ChatModel as PreviewChatModel,
+)
+from vertexai.preview.language_models import (
+    CodeChatModel as PreviewCodeChatModel,
+)
+from vertexai.preview.language_models import (
     CodeGenerationModel as PreviewCodeGenerationModel,
 )
 from vertexai.preview.language_models import (
@@ -239,6 +245,27 @@ def _prepare_params(
             params.pop("candidate_count")
         return params
 
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        is_palm_chat_model = isinstance(
+            self.client_preview, PreviewChatModel
+        ) or isinstance(self.client_preview, PreviewCodeChatModel)
+        if is_palm_chat_model:
+            result = self.client_preview.start_chat().count_tokens(text)
+        else:
+            result = self.client_preview.count_tokens([text])
+
+        return result.total_tokens
+
 
 class VertexAI(_VertexAICommon, BaseLLM):
     """Google Vertex AI large language models."""
@@ -300,20 +327,6 @@ def validate_environment(cls, values: Dict) -> Dict:
             raise ValueError("Only one candidate can be generated with streaming!")
         return values
 
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text.
-
-        Useful for checking if an input will fit in a model's context window.
-
-        Args:
-            text: The string input to tokenize.
-
-        Returns:
-            The integer number of tokens in the text.
-        """
-        result = self.client_preview.count_tokens([text])
-        return result.total_tokens
-
     def _response_to_generation(
         self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:
diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py
index 94c7ea6a..33853310 100644
--- a/libs/vertexai/tests/integration_tests/test_chat_models.py
+++ b/libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -225,6 +225,18 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
     assert isinstance(response.content, str)
 
 
+@pytest.mark.parametrize("model_name", model_names_to_test)
+def test_get_num_tokens_from_messages(model_name: str) -> None:
+    if model_name:
+        model = ChatVertexAI(model_name=model_name, temperature=0.0)
+    else:
+        model = ChatVertexAI(temperature=0.0)
+    message = HumanMessage(content="Hello")
+    token = model.get_num_tokens_from_messages(messages=[message])
+    assert isinstance(token, int)
+    assert token == 3
+
+
 def test_chat_vertexai_gemini_function_calling() -> None:
     class MyModel(BaseModel):
         name: str
diff --git a/libs/vertexai/tests/integration_tests/test_tools.py b/libs/vertexai/tests/integration_tests/test_tools.py
index 58bc7ecc..5da72240 100644
--- a/libs/vertexai/tests/integration_tests/test_tools.py
+++ b/libs/vertexai/tests/integration_tests/test_tools.py
@@ -81,7 +81,6 @@ def test_tools() -> None:
     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
 
     response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
-    print(response)
     assert isinstance(response, dict)
     assert response["input"] == "What is 6 raised to the 0.43 power?"
 
@@ -106,7 +105,6 @@ def test_stream() -> None:
     ]
     response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
     assert len(response) == 1
-    # for chunk in response:
     assert isinstance(response[0], AIMessageChunk)
     assert "function_call" in response[0].additional_kwargs
 
diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py
index 6bcbb6e5..f24f418b 100644
--- a/libs/vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -1,22 +1,35 @@
 """Test chat model integration."""
 
+import json
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
+from google.cloud.aiplatform_v1beta1.types import (
+    Content,
+    FunctionCall,
+    Part,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    content as gapic_content_types,
+)
 from langchain_core.messages import (
     AIMessage,
     HumanMessage,
     SystemMessage,
 )
 from vertexai.language_models import ChatMessage, InputOutputTextPair  # type: ignore
+from vertexai.preview.generative_models import (  # type: ignore
+    Candidate,
+)
 
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history,
     _parse_chat_history_gemini,
     _parse_examples,
+    _parse_response_candidate,
 )
 
 
@@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
         message = HumanMessage(content=user_prompt)
         _ = model([message])
         mock_start_chat.assert_called_once_with(history=[])
+
+
+@pytest.mark.parametrize(
+    "raw_candidate, expected",
+    [
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"name": "Ben"},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"name": "Ben"},
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"info": ["A", "B", "C"]},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"info": ["A", "B", "C"]},
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={
+                                    "people": [
+                                        {"name": "Joe", "age": 30},
+                                        {"name": "Martha"},
+                                    ]
+                                },
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {
+                    "people": [
+                        {"name": "Joe", "age": 30},
+                        {"name": "Martha"},
+                    ]
+                },
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"info": [[1, 2, 3], [4, 5, 6]]},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
+            },
+        ),
+    ],
+)
+def test_parse_response_candidate(raw_candidate, expected) -> None:
+    response_candidate = Candidate._from_gapic(raw_candidate)
+    result = _parse_response_candidate(response_candidate)
+    result_arguments = json.loads(
+        result.additional_kwargs["function_call"]["arguments"]
+    )
+
+    assert result_arguments == expected["arguments"]
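
Note on the _parse_response_candidate change (editor's sketch, not part of the patch): indexing first_part.function_call.args directly yields proto-plus composite wrappers for nested values, which json.dumps cannot serialize; proto.Message.to_dict converts the whole FunctionCall to native Python types first. A minimal standalone illustration of the behavior the new unit tests pin down, assuming google-cloud-aiplatform (which bundles proto-plus) is installed:

    import json

    import proto
    from google.cloud.aiplatform_v1beta1.types import FunctionCall

    fc = FunctionCall(name="Information", args={"info": ["A", "B", "C"]})

    # to_dict() recursively converts Struct/ListValue fields into plain dicts
    # and lists, so nested values survive json.dumps; iterating fc.args and
    # indexing each key (the old code path) returns composite wrappers that
    # json.dumps rejects once values nest.
    args = proto.Message.to_dict(fc)["args"]
    print(json.dumps(args))  # {"info": ["A", "B", "C"]}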
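Usage sketch for the relocated get_num_tokens (illustrative, not part of the patch): moving the method onto _VertexAICommon, backed by the new client_preview attribute, makes token counting available on ChatVertexAI as well as VertexAI. The model names and credentials below are assumptions, not pinned by this change:

    from langchain_google_vertexai import ChatVertexAI, VertexAI

    # PaLM text/code models: the preview client exposes count_tokens() directly.
    llm = VertexAI(model_name="text-bison@001")
    print(llm.get_num_tokens("Hello"))

    # PaLM chat models: the preview ChatModel/CodeChatModel only count tokens
    # on a chat session, hence the start_chat().count_tokens(text) branch.
    chat = ChatVertexAI(model_name="chat-bison@001")
    print(chat.get_num_tokens("Hello"))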