Skip to content

Commit

Permalink
added retries for ChatVertexAI models (langchain-ai#170)
Browse files Browse the repository at this point in the history
* added retries for ChatVertexAI models
  • Loading branch information
lkuligin authored Apr 20, 2024
1 parent 9c10541 commit 6e2cef0
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 11 deletions.
78 changes: 74 additions & 4 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Expand Down Expand Up @@ -91,6 +92,7 @@
)
from langchain_google_vertexai._image_utils import ImageBytesLoader
from langchain_google_vertexai._utils import (
create_retry_decorator,
get_generation_info,
is_codey_model,
is_gemini_model,
Expand Down Expand Up @@ -420,6 +422,54 @@ def _parse_response_candidate(
)


def _completion_with_retry(
generation_method: Callable,
*,
max_retries: int,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = create_retry_decorator(
max_retries=max_retries, run_manager=run_manager
)

@retry_decorator
def _completion_with_retry_inner(generation_method: Callable, **kwargs: Any) -> Any:
response = generation_method(**kwargs)
if len(response.candidates):
return response
else:
raise ValueError("Got 0 candidates from generations.")

return _completion_with_retry_inner(generation_method, **kwargs)


async def _acompletion_with_retry(
generation_method: Callable,
*,
max_retries: int,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = create_retry_decorator(
max_retries=max_retries, run_manager=run_manager
)

@retry_decorator
async def _completion_with_retry_inner(
generation_method: Callable, **kwargs: Any
) -> Any:
response = await generation_method(**kwargs)
if len(response.candidates):
return response
else:
raise ValueError("Got 0 candidates from generations.")

return await _completion_with_retry_inner(generation_method, **kwargs)


class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""`Vertex AI` Chat large language models API."""

Expand Down Expand Up @@ -524,7 +574,12 @@ def _generate(
client, contents = self._gemini_client_and_contents(messages)
params = self._gemini_params(stop=stop, **kwargs)
with telemetry.tool_context_manager(self._user_agent):
response = client.generate_content(contents, **params)
response = _completion_with_retry(
client.generate_content,
max_retries=self.max_retries,
contents=contents,
**params,
)
return self._gemini_response_to_chat_result(response)

def _generate_non_gemini(
Expand All @@ -545,7 +600,12 @@ def _generate_non_gemini(
params["examples"] = _parse_examples(examples)
with telemetry.tool_context_manager(self._user_agent):
chat = self._start_chat(history, **params)
response = chat.send_message(question.content, **msg_params)
response = _completion_with_retry(
chat.send_message,
max_retries=self.max_retries,
message=question.content,
**msg_params,
)
generations = [
ChatGeneration(
message=AIMessage(content=candidate.text),
Expand Down Expand Up @@ -590,7 +650,12 @@ async def _agenerate(
client, contents = self._gemini_client_and_contents(messages)
params = self._gemini_params(stop=stop, **kwargs)
with telemetry.tool_context_manager(self._user_agent):
response = await client.generate_content_async(contents, **params)
response = await _acompletion_with_retry(
client.generate_content_async,
max_retries=self.max_retries,
contents=contents,
**params,
)
return self._gemini_response_to_chat_result(response)

async def _agenerate_non_gemini(
Expand All @@ -611,7 +676,12 @@ async def _agenerate_non_gemini(
params["examples"] = _parse_examples(examples)
with telemetry.tool_context_manager(self._user_agent):
chat = self._start_chat(history, **params)
response = await chat.send_message_async(question.content, **msg_params)
response = await _acompletion_with_retry(
chat.send_message_async,
message=question.content,
max_retries=self.max_retries,
**msg_params,
)
generations = [
ChatGeneration(
message=AIMessage(content=candidate.text),
Expand Down
10 changes: 5 additions & 5 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def test_langchain_google_vertexai_large_batches() -> None:
model_uscentral1 = VertexAIEmbeddings(
model_name="textembedding-gecko@001", location="us-central1"
)
model_asianortheast1 = VertexAIEmbeddings(
model_name="textembedding-gecko@001", location="asia-northeast1"
)
# model_asianortheast1 = VertexAIEmbeddings(
# model_name="textembedding-gecko@001", location="asia-northeast1"
# )
model_uscentral1.embed_documents(documents)
model_asianortheast1.embed_documents(documents)
# model_asianortheast1.embed_documents(documents)
assert model_uscentral1.instance["batch_size"] >= 250
assert model_asianortheast1.instance["batch_size"] < 50
# assert model_asianortheast1.instance["batch_size"] < 50


@pytest.mark.release
Expand Down
9 changes: 7 additions & 2 deletions libs/vertexai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
response = model([message])

assert response.content == response_text
mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
mock_send_message.assert_called_once_with(
message=user_prompt, candidate_count=1
)
expected_stop_sequence = [stop] if stop else None
mock_start_chat.assert_called_once_with(
context=None,
Expand Down Expand Up @@ -491,7 +493,10 @@ def test_default_params_gemini() -> None:
message = HumanMessage(content=user_prompt)
_ = model.invoke([message])
mock_generate_content.assert_called_once()
assert mock_generate_content.call_args.args[0][0].parts[0].text == user_prompt
assert (
mock_generate_content.call_args.kwargs["contents"][0].parts[0].text
== user_prompt
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 6e2cef0

Please sign in to comment.