diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py
index 7a4776f9..c1f929fd 100644
--- a/libs/vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -426,6 +426,7 @@ def _completion_with_retry(
     generation_method: Callable,
     *,
     max_retries: int,
+    check_stream_response_for_candidates: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -437,6 +438,14 @@ def _completion_with_retry(
     @retry_decorator
     def _completion_with_retry_inner(generation_method: Callable, **kwargs: Any) -> Any:
         response = generation_method(**kwargs)
+        if kwargs.get("stream") and check_stream_response_for_candidates:
+            chunks = list(response)
+            for chunk in chunks:
+                if not chunk.candidates:
+                    raise ValueError("Got 0 candidates from generations.")
+            return iter(chunks)
+        if kwargs.get("stream"):
+            return response
         if len(response.candidates):
             return response
         else:
@@ -449,6 +458,7 @@ async def _acompletion_with_retry(
     generation_method: Callable,
     *,
     max_retries: int,
+    check_stream_response_for_candidates: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -462,6 +472,14 @@ async def _completion_with_retry_inner(
         generation_method: Callable, **kwargs: Any
     ) -> Any:
         response = await generation_method(**kwargs)
+        if kwargs.get("stream") and check_stream_response_for_candidates:
+            chunks = list(response)
+            for chunk in chunks:
+                if not chunk.candidates:
+                    raise ValueError("Got 0 candidates from generations.")
+            return iter(chunks)
+        if kwargs.get("stream"):
+            return response
         if len(response.candidates):
             return response
         else:
@@ -484,6 +502,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """[Deprecated] Since new Gemini models support setting a System Message,
     setting this parameter to True is discouraged.
     """
+    check_stream_response_for_candidates: bool = False
+    """Retrieves all chunks from a streaming response and checks that all of
+    them have candidates. If not, retries.
+    It makes streaming mode essentially useless."""
 
     @classmethod
     def is_lc_serializable(self) -> bool:
@@ -711,7 +733,13 @@ def _stream(
         client, contents = self._gemini_client_and_contents(messages)
         params = self._gemini_params(stop=stop, stream=True, **kwargs)
         with telemetry.tool_context_manager(self._user_agent):
-            response_iter = client.generate_content(contents, **params, stream=True)
+            response_iter = _completion_with_retry(
+                client.generate_content,
+                max_retries=self.max_retries,
+                contents=contents,
+                check_stream_response_for_candidates=self.check_stream_response_for_candidates,
+                **params,
+            )
         for response_chunk in response_iter:
             chunk = self._gemini_chunk_to_generation_chunk(response_chunk)
             if run_manager and isinstance(chunk.message.content, str):
@@ -758,8 +786,12 @@ async def _astream(
         client, contents = self._gemini_client_and_contents(messages)
         params = self._gemini_params(stop=stop, stream=True, **kwargs)
         with telemetry.tool_context_manager(self._user_agent):
-            async for response_chunk in await client.generate_content_async(
-                contents, **params, stream=True
+            async for response_chunk in await _acompletion_with_retry(
+                client.generate_content_async,
+                max_retries=self.max_retries,
+                contents=contents,
+                check_stream_response_for_candidates=self.check_stream_response_for_candidates,
+                **params,
             ):
                 chunk = self._gemini_chunk_to_generation_chunk(response_chunk)
                 if run_manager and isinstance(chunk.message.content, str):
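
For context, a minimal usage sketch of the new option (the model name, retry count, and prompt below are illustrative placeholders, not part of this patch). With check_stream_response_for_candidates=True, the retry wrapper consumes the entire stream, verifies that every chunk has candidates, and retries the request via the existing retry decorator if any chunk comes back empty, so chunks are only yielded once the full response has been buffered.

# Usage sketch (illustrative; model name, max_retries, and prompt are placeholders).
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(
    model_name="gemini-pro",
    max_retries=3,
    check_stream_response_for_candidates=True,
)

for chunk in llm.stream("Write a haiku about retries."):
    # Chunks arrive only after the whole response has been buffered and
    # validated, trading streaming latency for reliability.
    print(chunk.content, end="", flush=True)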