diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index c11a0354..7a4776f9 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -10,6 +10,7 @@ from typing import ( Any, AsyncIterator, + Callable, Dict, Iterator, List, @@ -91,6 +92,7 @@ ) from langchain_google_vertexai._image_utils import ImageBytesLoader from langchain_google_vertexai._utils import ( + create_retry_decorator, get_generation_info, is_codey_model, is_gemini_model, @@ -420,6 +422,54 @@ def _parse_response_candidate( ) +def _completion_with_retry( + generation_method: Callable, + *, + max_retries: int, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = create_retry_decorator( + max_retries=max_retries, run_manager=run_manager + ) + + @retry_decorator + def _completion_with_retry_inner(generation_method: Callable, **kwargs: Any) -> Any: + response = generation_method(**kwargs) + if len(response.candidates): + return response + else: + raise ValueError("Got 0 candidates from generations.") + + return _completion_with_retry_inner(generation_method, **kwargs) + + +async def _acompletion_with_retry( + generation_method: Callable, + *, + max_retries: int, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = create_retry_decorator( + max_retries=max_retries, run_manager=run_manager + ) + + @retry_decorator + async def _completion_with_retry_inner( + generation_method: Callable, **kwargs: Any + ) -> Any: + response = await generation_method(**kwargs) + if len(response.candidates): + return response + else: + raise ValueError("Got 0 candidates from generations.") + + return await _completion_with_retry_inner(generation_method, **kwargs) + + class ChatVertexAI(_VertexAICommon, BaseChatModel): """`Vertex AI` Chat large language models API.""" @@ -524,7 +574,12 @@ def _generate( client, contents = self._gemini_client_and_contents(messages) params = self._gemini_params(stop=stop, **kwargs) with telemetry.tool_context_manager(self._user_agent): - response = client.generate_content(contents, **params) + response = _completion_with_retry( + client.generate_content, + max_retries=self.max_retries, + contents=contents, + **params, + ) return self._gemini_response_to_chat_result(response) def _generate_non_gemini( @@ -545,7 +600,12 @@ def _generate_non_gemini( params["examples"] = _parse_examples(examples) with telemetry.tool_context_manager(self._user_agent): chat = self._start_chat(history, **params) - response = chat.send_message(question.content, **msg_params) + response = _completion_with_retry( + chat.send_message, + max_retries=self.max_retries, + message=question.content, + **msg_params, + ) generations = [ ChatGeneration( message=AIMessage(content=candidate.text), @@ -590,7 +650,12 @@ async def _agenerate( client, contents = self._gemini_client_and_contents(messages) params = self._gemini_params(stop=stop, **kwargs) with telemetry.tool_context_manager(self._user_agent): - response = await client.generate_content_async(contents, **params) + response = await _acompletion_with_retry( + client.generate_content_async, + max_retries=self.max_retries, + contents=contents, + **params, + ) return self._gemini_response_to_chat_result(response) async def _agenerate_non_gemini( @@ -611,7 +676,12 @@ async def _agenerate_non_gemini( params["examples"] = _parse_examples(examples) with telemetry.tool_context_manager(self._user_agent): chat = self._start_chat(history, **params) - response = await chat.send_message_async(question.content, **msg_params) + response = await _acompletion_with_retry( + chat.send_message_async, + message=question.content, + max_retries=self.max_retries, + **msg_params, + ) generations = [ ChatGeneration( message=AIMessage(content=candidate.text), diff --git a/libs/vertexai/tests/integration_tests/test_embeddings.py b/libs/vertexai/tests/integration_tests/test_embeddings.py index fdc2481f..1a0c1468 100644 --- a/libs/vertexai/tests/integration_tests/test_embeddings.py +++ b/libs/vertexai/tests/integration_tests/test_embeddings.py @@ -59,13 +59,13 @@ def test_langchain_google_vertexai_large_batches() -> None: model_uscentral1 = VertexAIEmbeddings( model_name="textembedding-gecko@001", location="us-central1" ) - model_asianortheast1 = VertexAIEmbeddings( - model_name="textembedding-gecko@001", location="asia-northeast1" - ) + # model_asianortheast1 = VertexAIEmbeddings( + # model_name="textembedding-gecko@001", location="asia-northeast1" + # ) model_uscentral1.embed_documents(documents) - model_asianortheast1.embed_documents(documents) + # model_asianortheast1.embed_documents(documents) assert model_uscentral1.instance["batch_size"] >= 250 - assert model_asianortheast1.instance["batch_size"] < 50 + # assert model_asianortheast1.instance["batch_size"] < 50 @pytest.mark.release diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index 4b66dabe..c0d18486 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -133,7 +133,9 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None: response = model([message]) assert response.content == response_text - mock_send_message.assert_called_once_with(user_prompt, candidate_count=1) + mock_send_message.assert_called_once_with( + message=user_prompt, candidate_count=1 + ) expected_stop_sequence = [stop] if stop else None mock_start_chat.assert_called_once_with( context=None, @@ -491,7 +493,10 @@ def test_default_params_gemini() -> None: message = HumanMessage(content=user_prompt) _ = model.invoke([message]) mock_generate_content.assert_called_once() - assert mock_generate_content.call_args.args[0][0].parts[0].text == user_prompt + assert ( + mock_generate_content.call_args.kwargs["contents"][0].parts[0].text + == user_prompt + ) @pytest.mark.parametrize(