added retries for streaming (langchain-ai#172)
lkuligin authored Apr 20, 2024
1 parent 6e2cef0 commit 0486bc2
Showing 1 changed file with 35 additions and 3 deletions.
libs/vertexai/langchain_google_vertexai/chat_models.py (35 additions & 3 deletions)
@@ -426,6 +426,7 @@ def _completion_with_retry(
     generation_method: Callable,
     *,
     max_retries: int,
+    check_stream_response_for_candidates: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -437,6 +438,14 @@ def _completion_with_retry(
     @retry_decorator
     def _completion_with_retry_inner(generation_method: Callable, **kwargs: Any) -> Any:
         response = generation_method(**kwargs)
+        if kwargs.get("stream") and check_stream_response_for_candidates:
+            chunks = list(response)
+            for chunk in chunks:
+                if not chunk.candidates:
+                    raise ValueError("Got 0 candidates from generations.")
+            return iter(chunks)
+        if kwargs.get("stream"):
+            return response
         if len(response.candidates):
             return response
         else:
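The hunk above drains the stream eagerly so that an empty-candidate chunk surfaces as an exception inside the retried function rather than later, during caller-side iteration. A minimal, self-contained sketch of that buffer-and-validate pattern (all names here are hypothetical stand-ins, not part of the diff):

```python
from typing import Any, Callable, Iterator, List


class FakeChunk:
    """Hypothetical stand-in for a Gemini streaming response chunk."""

    def __init__(self, candidates: List[str]) -> None:
        self.candidates = candidates


def buffer_and_validate(
    generation_method: Callable[..., Iterator[FakeChunk]], **kwargs: Any
) -> Iterator[FakeChunk]:
    # Drain the stream eagerly; this trades streaming latency for the
    # ability to retry the whole call if any chunk comes back empty.
    chunks = list(generation_method(**kwargs))
    for chunk in chunks:
        if not chunk.candidates:
            # Raising inside the retried function is what lets the
            # surrounding retry decorator re-invoke generation_method.
            raise ValueError("Got 0 candidates from generations.")
    # Hand the caller a fresh iterator over the already-buffered chunks.
    return iter(chunks)
```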
@@ -449,6 +458,7 @@ async def _acompletion_with_retry(
     generation_method: Callable,
     *,
     max_retries: int,
+    check_stream_response_for_candidates: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -462,6 +472,14 @@ async def _completion_with_retry_inner(
         generation_method: Callable, **kwargs: Any
     ) -> Any:
         response = await generation_method(**kwargs)
+        if kwargs.get("stream") and check_stream_response_for_candidates:
+            chunks = list(response)
+            for chunk in chunks:
+                if not chunk.candidates:
+                    raise ValueError("Got 0 candidates from generations.")
+            return iter(chunks)
+        if kwargs.get("stream"):
+            return response
         if len(response.candidates):
             return response
         else:
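The `retry_decorator` applied in both hunks comes from `_create_retry_decorator`, which sits outside this diff. A plausible tenacity-based equivalent, assuming ValueError (raised by the candidate check above) is among the retryable exception types, might look like:

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


def make_retry_decorator(max_retries: int):
    # Assumption: the real _create_retry_decorator also retries transient
    # Google API errors; only ValueError is shown here so that the
    # empty-candidate check in the hunks above is retryable.
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=2, min=1, max=60),
        retry=retry_if_exception_type(ValueError),
    )
```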
@@ -484,6 +502,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     """[Deprecated] Since new Gemini models support setting a System Message,
     setting this parameter to True is discouraged.
     """
+    check_stream_response_for_candidates: bool = False
+    """Retrieves all chunks from the streaming response and checks that all of
+    them have candidates. If not, retries. This makes streaming mode
+    essentially useless."""
 
     @classmethod
     def is_lc_serializable(self) -> bool:
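A hypothetical usage sketch of the new flag (model name and retry count are illustrative, not taken from the diff):

```python
from langchain_google_vertexai import ChatVertexAI

# With the flag on, each streamed response is fully buffered and validated;
# an empty-candidate chunk triggers up to max_retries re-invocations.
llm = ChatVertexAI(
    model_name="gemini-1.0-pro",  # illustrative model name
    max_retries=3,
    check_stream_response_for_candidates=True,
)
```

As the docstring warns, enabling the check means no chunk is yielded until the entire response has been received, which defeats the latency benefit of streaming.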
@@ -711,7 +733,13 @@ def _stream(
         client, contents = self._gemini_client_and_contents(messages)
         params = self._gemini_params(stop=stop, stream=True, **kwargs)
         with telemetry.tool_context_manager(self._user_agent):
-            response_iter = client.generate_content(contents, **params, stream=True)
+            response_iter = _completion_with_retry(
+                client.generate_content,
+                max_retries=self.max_retries,
+                contents=contents,
+                check_stream_response_for_candidates=self.check_stream_response_for_candidates,
+                **params,
+            )
         for response_chunk in response_iter:
             chunk = self._gemini_chunk_to_generation_chunk(response_chunk)
             if run_manager and isinstance(chunk.message.content, str):
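With `_stream` now routed through `_completion_with_retry`, a failed stream can be retried before the first chunk reaches the caller, but only when the candidate check is enabled; with the default `check_stream_response_for_candidates=False`, the retry wraps only the initial `generate_content` call and the lazy iterator is returned untouched. Continuing the hypothetical `llm` from the sketch above:

```python
# Sync streaming; chunks arrive only after the buffered response has been
# validated (given check_stream_response_for_candidates=True above).
for chunk in llm.stream("Tell me a joke about retries."):
    print(chunk.content, end="", flush=True)
```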
@@ -758,8 +786,12 @@ async def _astream(
         client, contents = self._gemini_client_and_contents(messages)
         params = self._gemini_params(stop=stop, stream=True, **kwargs)
         with telemetry.tool_context_manager(self._user_agent):
-            async for response_chunk in await client.generate_content_async(
-                contents, **params, stream=True
+            async for response_chunk in await _acompletion_with_retry(
+                client.generate_content_async,
+                max_retries=self.max_retries,
+                contents=contents,
+                check_stream_response_for_candidates=self.check_stream_response_for_candidates,
+                **params,
             ):
                 chunk = self._gemini_chunk_to_generation_chunk(response_chunk)
                 if run_manager and isinstance(chunk.message.content, str):
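And the async counterpart, mirroring the `_astream` change above (same hypothetical `llm`):

```python
import asyncio


async def main() -> None:
    # Async streaming routed through _acompletion_with_retry.
    async for chunk in llm.astream("Tell me a joke about retries."):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```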