diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index da509328..3f7f42d8 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -4,7 +4,7 @@ import json import logging from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Optional, Union, cast +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast import proto # type: ignore[import-untyped] from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart @@ -539,6 +539,48 @@ def _stream( ), ) + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + if not self._is_gemini_model: + raise NotImplementedError() + params = self._prepare_params(stop=stop, stream=True, **kwargs) + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, + ) + message = history_gemini.pop() + chat = self.client.start_chat(history=history_gemini) + raw_tools = params.pop("functions") if "functions" in params else None + tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + safety_settings = params.pop("safety_settings", None) + async for chunk in await chat.send_message_async( + message, + stream=True, + generation_config=params, + safety_settings=safety_settings, + tools=tools, + ): + message = _parse_response_candidate(chunk.candidates[0]) + if run_manager: + await run_manager.on_llm_new_token(message.content) + yield ChatGenerationChunk( + message=AIMessageChunk( + content=message.content, + additional_kwargs=message.additional_kwargs, + ), + generation_info=get_generation_info( + chunk.candidates[0], + self._is_gemini_model, + usage_metadata=chunk.to_dict().get("usage_metadata"), + ), + ) + def _start_chat( self, history: _ChatHistory, **kwargs: Any ) -> Union[ChatSession, CodeChatSession]: diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index 8b6f1062..c236ab3b 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -1,7 +1,7 @@ from __future__ import annotations from concurrent.futures import Executor -from typing import Any, ClassVar, Dict, Iterator, List, Optional, Union +from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional, Union import vertexai # type: ignore[import-untyped] from google.api_core.client_options import ClientOptions @@ -95,6 +95,7 @@ async def _acompletion_with_retry( llm: VertexAI, prompt: str, is_gemini: bool = False, + stream: bool = False, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Any: @@ -105,17 +106,22 @@ async def _acompletion_with_retry( @retry_decorator async def _acompletion_with_retry_inner( - prompt: str, is_gemini: bool = False, **kwargs: Any + prompt: str, is_gemini: bool = False, stream: bool = False, **kwargs: Any ) -> Any: if is_gemini: return await llm.client.generate_content_async( prompt, generation_config=kwargs, + stream=stream, safety_settings=kwargs.pop("safety_settings", None), ) + if stream: + raise ValueError("Async streaming is supported only for Gemini family!") return await llm.client.predict_async(prompt, **kwargs) - return await _acompletion_with_retry_inner(prompt, is_gemini, **kwargs) + return await _acompletion_with_retry_inner( + prompt, is_gemini, stream=stream, **kwargs + ) class _VertexAIBase(BaseModel): @@ -453,6 +459,34 @@ def _stream( verbose=self.verbose, ) + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + params = self._prepare_params(stop=stop, stream=True, **kwargs) + if not self._is_gemini_model: + raise ValueError("Async streaming is supported only for Gemini family!") + async for chunk in await _acompletion_with_retry( + self, + prompt, + stream=True, + is_gemini=self._is_gemini_model, + run_manager=run_manager, + **params, + ): + usage_metadata = chunk.to_dict().get("usage_metadata") + chunk = self._candidate_to_generation( + chunk.candidates[0], stream=True, usage_metadata=usage_metadata + ) + yield chunk + if run_manager: + await run_manager.on_llm_new_token( + chunk.text, chunk=chunk, verbose=self.verbose + ) + class VertexAIModelGarden(_VertexAIBase, BaseLLM): """Large language models served from Vertex AI Model Garden.""" diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py index c3721af3..365c7a21 100644 --- a/libs/vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/vertexai/tests/integration_tests/test_chat_models.py @@ -86,6 +86,14 @@ def test_vertexai_stream(model_name: str) -> None: assert isinstance(chunk, AIMessageChunk) +async def test_vertexai_astream() -> None: + model = ChatVertexAI(temperature=0, model_name="gemini-pro") + message = HumanMessage(content="Hello") + + async for chunk in model.astream([message]): + assert isinstance(chunk, AIMessageChunk) + + def test_vertexai_single_call_with_context() -> None: model = ChatVertexAI() raw_context = ( diff --git a/libs/vertexai/tests/integration_tests/test_image_utils.py b/libs/vertexai/tests/integration_tests/test_image_utils.py index b297fa0d..87f75a81 100644 --- a/libs/vertexai/tests/integration_tests/test_image_utils.py +++ b/libs/vertexai/tests/integration_tests/test_image_utils.py @@ -1,9 +1,11 @@ +import pytest from google.cloud import storage # type: ignore[attr-defined] from google.cloud.exceptions import NotFound from langchain_google_vertexai._image_utils import ImageBytesLoader +@pytest.mark.skip("CI testing not set up") def test_image_utils(): base64_image = ( "" diff --git a/libs/vertexai/tests/integration_tests/test_llms.py b/libs/vertexai/tests/integration_tests/test_llms.py index 29f29db2..6b252f42 100644 --- a/libs/vertexai/tests/integration_tests/test_llms.py +++ b/libs/vertexai/tests/integration_tests/test_llms.py @@ -113,6 +113,12 @@ async def test_vertex_consistency() -> None: assert output.generations[0][0].text == async_output.generations[0][0].text +async def test_astream() -> None: + llm = VertexAI(temperature=0, model_name="gemini-pro") + async for token in llm.astream("I'm Pickle Rick"): + assert isinstance(token, str) + + @pytest.mark.skip("CI testing not set up") @pytest.mark.parametrize( "endpoint_os_variable_name,result_arg",