diff --git a/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb b/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb index 885174f4a1b52..6e8d84e4e1c90 100644 --- a/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb +++ b/docs/docs/integrations/chat/google_vertex_ai_palm.ipynb @@ -42,18 +42,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "%pip install --upgrade --quiet langchain-google-vertexai" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/qiaowang/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_google_vertexai import ChatVertexAI" @@ -261,21 +278,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'is_blocked': False,\n", - " 'safety_ratings': [{'category': 'HARM_CATEGORY_HARASSMENT',\n", + "{'citation_metadata': ,\n", + " 'is_blocked': False,\n", + " 'safety_ratings': [{'blocked': False,\n", + " 'category': 'HARM_CATEGORY_HARASSMENT',\n", " 'probability_label': 'NEGLIGIBLE'},\n", - " {'category': 'HARM_CATEGORY_HATE_SPEECH',\n", + " {'blocked': False,\n", + " 'category': 'HARM_CATEGORY_HATE_SPEECH',\n", " 'probability_label': 'NEGLIGIBLE'},\n", - " {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n", + " {'blocked': False,\n", + " 'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n", " 'probability_label': 'NEGLIGIBLE'},\n", - " {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n", + " {'blocked': False,\n", + " 'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n", " 'probability_label': 'NEGLIGIBLE'}]}\n" ] } @@ -315,18 +337,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'is_blocked': False,\n", - " 'safety_attributes': {'Derogatory': 0.1,\n", - " 'Finance': 0.3,\n", - " 'Insult': 0.1,\n", - " 'Sexual': 0.1}}\n" + "{'errors': (),\n", + " 'grounding_metadata': {'citations': [], 'search_queries': []},\n", + " 'is_blocked': False,\n", + " 'safety_attributes': {'Derogatory': 0.1, 'Insult': 0.1, 'Sexual': 0.2}}\n" ] } ], @@ -482,15 +503,48 @@ " sys.stdout.write(chunk.content)\n", " sys.stdout.flush()" ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'chain' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mfor\u001b[39;00m chunk \u001b[38;5;129;01min\u001b[39;00m chain\u001b[38;5;241m.\u001b[39mastream({}):\n\u001b[1;32m 2\u001b[0m sys\u001b[38;5;241m.\u001b[39mstdout\u001b[38;5;241m.\u001b[39mwrite(chunk\u001b[38;5;241m.\u001b[39mcontent)\n\u001b[1;32m 3\u001b[0m sys\u001b[38;5;241m.\u001b[39mstdout\u001b[38;5;241m.\u001b[39mflush()\n", + "\u001b[0;31mNameError\u001b[0m: name 'chain' is not defined" + ] + } + ], + "source": [ + "async for chunk in chain.astream({}):\n", + " sys.stdout.write(chunk.content)\n", + " sys.stdout.flush()" + ] } ], "metadata": { "kernelspec": { - "display_name": "", - "name": "" + "display_name": "Python 3", + "language": "python", + "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" } }, "nbformat": 4, diff --git a/docs/docs/integrations/llms/google_vertex_ai_palm.ipynb b/docs/docs/integrations/llms/google_vertex_ai_palm.ipynb index 4f53a75c925b4..c9580917544a8 100644 --- a/docs/docs/integrations/llms/google_vertex_ai_palm.ipynb +++ b/docs/docs/integrations/llms/google_vertex_ai_palm.ipynb @@ -80,16 +80,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/qiaowang/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from langchain_google_vertexai import VertexAI" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -98,16 +107,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'**Pros:**\\n\\n* **Easy to learn and use:** Python is known for its simple syntax and readability, making it a great choice for beginners and experienced programmers alike.\\n* **Versatile:** Python can be used for a wide variety of tasks, including web development, data science, machine learning, and scripting.\\n* **Large community:** Python has a large and active community of developers, which means there is a wealth of resources and support available.\\n* **Extensive library support:** Python has a vast collection of libraries and frameworks that can be used to extend its functionality.\\n* **Cross-platform:** Python is available for a'" + "\"Pros:\\n\\n1.Simplicity and Readability: Python has a concise and easy-to-read syntax, which makes it accessible to new programmers and allows for quick development. Its use of significant whitespace contributes to its readability.\\n\\n2. Extensive Library and Module Ecosystem: Python boasts a comprehensive standard library and a vast repository of third-party libraries and modules. These resources facilitate various programming tasks, including data analysis, machine learning, web development, automation, GUI programming, and more.\\n\\n3. Object-Oriented Programming (OOP): Python supports OOP principles, enabling programmers to structure code in a modular and organized fashion. It allows for the creation of classes, objects, and methods, promoting code reusability and maintainability.\\n\\n4. Rapid Development and Prototyping: Python's rapid development cycle and prototyping capabilities make it suitable for quickly iterating through ideas and testing concepts. Its interactive interpreter allows for quick experimentation and rapid feedback.\\n\\n5. Flexibility and Versatility: Python is a versatile language that finds applications in various domains, including web development, data science, scientific computing, GUI programming, game development, automation, scripting, and DevOps. It provides cross-platform compatibility, allowing programs to run on different operating systems.\\n\\nCons:\\n\\n1. Performance and Speed: Python is an interpreted language, which inherently makes it slower than compiled languages like C or Java. This speed disadvantage can be a hindrance in applications requiring intensive computational tasks or real-time processing.\\n\\n2. Limited Type Checking and Static Syntax Checking: Python's dynamic type system lacks comprehensive type checking during development. Errors often arise only at runtime, which can lead to difficulties in identifying and resolving issues early on. Additionally, the lack of static syntax checking can introduce errors that go unnoticed until the program is executed.\\n\\n3. GIL (Global Interpreter Lock): Python employs a Global Interpreter Lock (GIL), which allows only one thread to execute Python bytecode at a time. This can limit parallelism and concurrency, impacting the performance of multi-threaded applications.\\n\\n4. Package Management Complexity: Python's package management system, while extensive, can be challenging to manage, particularly for large projects. Dealing with dependencies and version conflicts can sometimes be complex and time-consuming.\\n\\n5. Memory Management: Python's automatic garbage collection, while convenient, can introduce challenges in optimizing memory usage. Developers need to be aware of memory management considerations when working with large datasets or complex algorithms to prevent memory leaks and performance bottlenecks.\"" ] }, - "execution_count": null, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -161,6 +170,69 @@ " print(chunk, end=\"\", flush=True)" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "**Pros:**\n", + "\n", + "1. **Easy to Learn and Use:**\n", + " - Python is known for its simple syntax, making it easy to learn for beginners.\n", + " - The use of whitespace for indentation adds clarity to the code, making it more readable.\n", + " - Python's extensive documentation and vast online resources contribute to its ease of learning.\n", + "\n", + "2. **Versatile and General-Purpose:**\n", + " - Python can be used for a wide range of tasks, including web development, data science, machine learning, scripting, and automation.\n", + " - Its extensive standard library provides various modules for handling different tasks, reducing the need for external dependencies.\n", + "\n", + "3. **Interpreted Nature:**\n", + " - Python is an interpreted language, meaning the source code is directly executed by the interpreter without being compiled beforehand.\n", + " - This allows for rapid development, as changes can be made and executed quickly without having to recompile the code.\n", + "\n", + "4. **Strong Community Support:**\n", + " - Python has a large and active community of developers who contribute to the language's improvement and offer help and support to others.\n", + " - The continuous development of Python's ecosystem, with new libraries and frameworks being released regularly, contributes to its popularity.\n", + "\n", + "5. **Cross-Platform Availability:**\n", + " - Python is available on various operating systems, including Windows, macOS, and Linux, making it easy to develop and deploy code across different platforms.\n", + " - This portability allows developers to create applications that can run on multiple systems without major modifications.\n", + "\n", + "**Cons:**\n", + "\n", + "1. **Performance:**\n", + " - Compared to compiled languages like C or C++, Python is generally slower in execution speed.\n", + " - However, Python's extensive optimizations and the use of libraries written in compiled languages help mitigate this disadvantage.\n", + "\n", + "2. **Memory Usage:**\n", + " - Python uses a dynamic memory management system, which can lead to higher memory consumption compared to statically typed languages.\n", + " - Optimizing memory usage in Python requires careful consideration of data structures and memory management techniques.\n", + "\n", + "3. **Lack of Strong Typing:**\n", + " - Python is a dynamically typed language, which means data types are not strictly enforced and can change during program execution.\n", + " - This flexibility can lead to errors and make it more difficult to catch type-related issues early on.\n", + " - However, Python does provide mechanisms like type annotations and static type checkers to help with type checking and improving code quality.\n", + "\n", + "4. **Package Management:**\n", + " - Python's package management system, pip, can be complex to use and can sometimes lead to dependency conflicts or version issues.\n", + " - Managing dependencies in Python projects requires careful attention to ensure compatibility and to avoid potential conflicts.\n", + "\n", + "5. **Testing and Debugging:**\n", + " - Python's interpreted nature can make debugging more challenging compared to compiled languages.\n", + " - The dynamic nature of Python allows potential errors to go undetected until runtime, making it more difficult to pinpoint the exact cause of an issue.\n", + " - However, using debugging tools and implementing comprehensive testing can help mitigate these challenges." + ] + } + ], + "source": [ + "async for chunk in model.astream(message):\n", + " print(chunk, end=\"\", flush=True)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -574,7 +646,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py index 1c62d76c75c0d..ecbe9288829e6 100644 --- a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py @@ -6,7 +6,7 @@ import logging import re from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Optional, Union, cast +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Type, Union, cast from urllib.parse import urlparse import requests @@ -18,6 +18,7 @@ ) from langchain_core.language_models.chat_models import ( BaseChatModel, + agenerate_from_stream, generate_from_stream, ) from langchain_core.messages import ( @@ -29,7 +30,8 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import root_validator +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.tools import BaseTool from vertexai.language_models import ( # type: ignore ChatMessage, ChatModel, @@ -45,6 +47,9 @@ Image, Part, ) +from vertexai.preview.generative_models import ( + Tool as VertexTool, +) from langchain_google_vertexai._utils import ( get_generation_info, @@ -361,17 +366,11 @@ def _generate( msg_params["candidate_count"] = params.pop("candidate_count") if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) - # set param to `functions` until core tool/function calling implemented raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + chat, message, tools = self._start_gemini_chat( + messages=messages, raw_tools=raw_tools + ) response = chat.send_message( message, generation_config=params, @@ -424,8 +423,10 @@ async def _agenerate( ValueError: if the last message in the list is not from human. """ if "stream" in kwargs: - kwargs.pop("stream") - logger.warning("ChatVertexAI does not currently support async streaming.") + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) params = self._prepare_params(stop=stop, **kwargs) safety_settings = kwargs.pop("safety_settings", None) @@ -434,16 +435,11 @@ async def _agenerate( msg_params["candidate_count"] = params.pop("candidate_count") if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) # set param to `functions` until core tool/function calling implemented raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + chat, message, tools = self._start_gemini_chat( + messages=messages, raw_tools=raw_tools + ) response = await chat.send_message_async( message, generation_config=params, @@ -483,16 +479,11 @@ def _stream( ) -> Iterator[ChatGenerationChunk]: params = self._prepare_params(stop=stop, stream=True, **kwargs) if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) # set param to `functions` until core tool/function calling implemented raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + chat, message, tools = self._start_gemini_chat( + messages=messages, raw_tools=raw_tools + ) responses = chat.send_message( message, stream=True, generation_config=params, tools=tools ) @@ -514,13 +505,80 @@ def _stream( params["examples"] = _parse_examples(examples) chat = self._start_chat(history, **params) responses = chat.send_message_streaming(question.content, **params) - for response in responses: - if run_manager: - run_manager.on_llm_new_token(response.text) - yield ChatGenerationChunk( - message=AIMessageChunk(content=response.text), - generation_info=get_generation_info(response, self._is_gemini_model), + for response in responses: + if run_manager: + run_manager.on_llm_new_token(response.text) + yield ChatGenerationChunk( + message=AIMessageChunk(content=response.text), + generation_info=get_generation_info( + response, self._is_gemini_model + ), + ) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + if "stream" in kwargs: + kwargs.pop("stream") + params = self._prepare_params(stop=stop, stream=True, **kwargs) + if self._is_gemini_model: + # set param to `functions` until core tool/function calling implemented + raw_tools = params.pop("functions") if "functions" in params else None + chat, message, tools = self._start_gemini_chat( + messages=messages, raw_tools=raw_tools ) + async for response in await chat.send_message_async( + message, + generation_config=params, + tools=tools, + stream=True, + ): + message = _parse_response_candidate(response.candidates[0]) + if run_manager: + run_manager.on_llm_new_token(message.content) + yield ChatGenerationChunk( + message=AIMessageChunk( + content=message.content, + additional_kwargs=message.additional_kwargs, + ) + ) + else: + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + examples = kwargs.get("examples", None) + if examples: + params["examples"] = _parse_examples(examples) + chat = self._start_chat(history, **params) + async for response in chat.send_message_streaming_async( + question.content, **params + ): + if run_manager: + run_manager.on_llm_new_token(response.text) + yield ChatGenerationChunk( + message=AIMessageChunk(content=response.text), + generation_info=get_generation_info( + response, self._is_gemini_model + ), + ) + + def _start_gemini_chat( + self, + messages: List[BaseMessage], + raw_tools: List[Union[BaseTool, Type[BaseModel]]], + ) -> tuple[ChatSession, Content, VertexTool]: + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, + ) + message = history_gemini.pop() + chat = self.client.start_chat(history=history_gemini) + tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + return chat, message, tools def _start_chat( self, history: _ChatHistory, **kwargs: Any diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py index bd4d346a01910..bfa05737ca5c1 100644 --- a/libs/partners/google-vertexai/langchain_google_vertexai/llms.py +++ b/libs/partners/google-vertexai/langchain_google_vertexai/llms.py @@ -1,7 +1,7 @@ from __future__ import annotations from concurrent.futures import Executor -from typing import Any, ClassVar, Dict, Iterator, List, Optional, Union +from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Optional, Union import vertexai # type: ignore[import-untyped] from google.api_core.client_options import ClientOptions @@ -87,6 +87,7 @@ def _completion_with_retry_inner( async def _acompletion_with_retry( llm: VertexAI, prompt: str, + stream: bool = False, is_gemini: bool = False, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, @@ -103,10 +104,14 @@ async def _acompletion_with_retry_inner( if is_gemini: return await llm.client.generate_content_async( prompt, + stream=stream, generation_config=kwargs, safety_settings=kwargs.pop("safety_settings", None), ) - return await llm.client.predict_async(prompt, **kwargs) + else: + if stream: + return llm.client.predict_streaming_async(prompt, **kwargs) + return await llm.client.predict_async(prompt, **kwargs) return await _acompletion_with_retry_inner(prompt, is_gemini, **kwargs) @@ -403,21 +408,46 @@ def _stream( run_manager=run_manager, **params, ): - # Gemini models return GenerationResponse even when streaming, which has a - # candidates field. - stream_resp = ( - stream_resp - if isinstance(stream_resp, TextGenerationResponse) - else stream_resp.candidates[0] + for chunk in self._get_stream_chunk(stream_resp, run_manager): + yield chunk + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + params = self._prepare_params(stop=stop, stream=True, **kwargs) + async for stream_resp in await _acompletion_with_retry( + self, + prompt, + stream=True, + is_gemini=self._is_gemini_model, + run_manager=run_manager, + **params, + ): + for chunk in self._get_stream_chunk(stream_resp, run_manager): + yield chunk + + def _get_stream_chunk( + self, stream_resp: Any, run_manager: Optional[CallbackManagerForLLMRun] = None + ): + # Gemini models return GenerationResponse even when streaming, which has a + # candidates field. + stream_resp = ( + stream_resp + if isinstance(stream_resp, TextGenerationResponse) + else stream_resp.candidates[0] + ) + chunk = self._response_to_generation(stream_resp, stream=True) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + chunk.text, + chunk=chunk, + verbose=self.verbose, ) - chunk = self._response_to_generation(stream_resp, stream=True) - yield chunk - if run_manager: - run_manager.on_llm_new_token( - chunk.text, - chunk=chunk, - verbose=self.verbose, - ) class VertexAIModelGarden(_VertexAIBase, BaseLLM): diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py index f2981a0801812..82308de15d93b 100644 --- a/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/google-vertexai/tests/integration_tests/test_chat_models.py @@ -75,6 +75,28 @@ async def test_vertexai_agenerate(model_name: str) -> None: # assert sync_generation == async_generation +@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"]) +async def test_vertexai_agenerate_stream_enabled(model_name: str) -> None: + model = ChatVertexAI(temperature=0, model_name=model_name) + message = HumanMessage(content="Hello") + response = await model.agenerate([[message]], stream=True) + assert isinstance(response, LLMResult) + assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore + + sync_response = model.generate([[message]], stream=True) + sync_generation = cast(ChatGeneration, sync_response.generations[0][0]) + async_generation = cast(ChatGeneration, response.generations[0][0]) + + # assert some properties to make debugging easier + + # xfail: this is not equivalent with temp=0 right now + # assert sync_generation.message.content == async_generation.message.content + assert sync_generation.generation_info == async_generation.generation_info + + # xfail: content is not same right now + # assert sync_generation == async_generation + + @pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"]) def test_vertexai_stream(model_name: str) -> None: model = ChatVertexAI(temperature=0, model_name=model_name) @@ -85,6 +107,16 @@ def test_vertexai_stream(model_name: str) -> None: assert isinstance(chunk, AIMessageChunk) +@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"]) +async def test_vertexai_astream(model_name: str) -> None: + model = ChatVertexAI(temperature=0, model_name=model_name) + message = HumanMessage(content="Hello") + + async_response = model.astream([message]) + async for chunk in async_response: + assert isinstance(chunk, AIMessageChunk) + + def test_vertexai_single_call_with_context() -> None: model = ChatVertexAI() raw_context = ( diff --git a/libs/partners/google-vertexai/tests/integration_tests/test_llms.py b/libs/partners/google-vertexai/tests/integration_tests/test_llms.py index ae10d9377cbf2..246af3bf8dafc 100644 --- a/libs/partners/google-vertexai/tests/integration_tests/test_llms.py +++ b/libs/partners/google-vertexai/tests/integration_tests/test_llms.py @@ -95,6 +95,20 @@ def test_stream(model_name: str) -> None: assert isinstance(token, str) +@pytest.mark.parametrize( + "model_name", + model_names_to_test_with_default, +) +async def test_astream(model_name: str) -> None: + llm = ( + VertexAI(temperature=0, model_name=model_name) + if model_name + else VertexAI(temperature=0) + ) + async for token in llm.astream("I'm Pickle Rick"): + assert isinstance(token, str) + + async def test_vertex_consistency() -> None: llm = VertexAI(temperature=0) output = llm.generate(["Please say foo:"])