From 1f6637b19c03149d44b8b5dbe2548f68d3ebdd56 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Mon, 26 Feb 2024 14:44:47 -0600 Subject: [PATCH] Change tool_context_manager context location --- .../langchain_google_vertexai/_base.py | 7 + .../langchain_google_vertexai/_utils.py | 4 +- .../langchain_google_vertexai/chat_models.py | 403 +++++++++--------- .../langchain_google_vertexai/embeddings.py | 4 +- .../langchain_google_vertexai/llms.py | 47 +- .../vision_models.py | 95 +++-- 6 files changed, 294 insertions(+), 266 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py index 1b2a9846..4efea07b 100644 --- a/libs/vertexai/langchain_google_vertexai/_base.py +++ b/libs/vertexai/langchain_google_vertexai/_base.py @@ -27,6 +27,7 @@ from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory from langchain_google_vertexai._utils import ( get_client_info, + get_user_agent, is_codey_model, is_gemini_model, ) @@ -142,6 +143,12 @@ def _default_params(self) -> Dict[str, Any]: ) return updated_params + @property + def _user_agent(self) -> str: + """Gets the User Agent.""" + _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}") + return user_agent + @classmethod def _init_vertexai(cls, values: Dict) -> None: vertexai.init( diff --git a/libs/vertexai/langchain_google_vertexai/_utils.py b/libs/vertexai/langchain_google_vertexai/_utils.py index adbcdb81..8e4e858b 100644 --- a/libs/vertexai/langchain_google_vertexai/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/_utils.py @@ -66,7 +66,7 @@ def get_user_agent(module: Optional[str] = None) -> Tuple[str, str]: module (Optional[str]): Optional. The module for a custom user agent header. Returns: - google.api_core.gapic_v1.client_info.ClientInfo + Tuple[str, str] """ try: langchain_version = metadata.version("langchain") @@ -95,7 +95,7 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo": def load_image_from_gcs(path: str, project: Optional[str] = None) -> Image: - """Loads im Image from GCS.""" + """Loads an Image from GCS.""" gcs_client = storage.Client(project=project) pieces = path.split("/") blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:]))) diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index cdd07e53..5ad8d0e2 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -1,6 +1,6 @@ """Wrapper around Google VertexAI chat-based models.""" -from __future__ import annotations +from __future__ import annotations # noqa import json import logging @@ -295,25 +295,24 @@ def validate_environment(cls, values: Dict) -> Dict: raise ValueError("Safety settings are only supported for Gemini models") cls._init_vertexai(values) - with tool_context_manager(get_user_agent("vertex-ai-llm")): - if is_gemini: - values["client"] = GenerativeModel( - model_name=values["model_name"], safety_settings=safety_settings - ) - values["client_preview"] = GenerativeModel( - model_name=values["model_name"], safety_settings=safety_settings - ) + if is_gemini: + values["client"] = GenerativeModel( + model_name=values["model_name"], safety_settings=safety_settings + ) + values["client_preview"] = GenerativeModel( + model_name=values["model_name"], safety_settings=safety_settings + ) + else: + if is_codey_model(values["model_name"]): + model_cls = CodeChatModel + model_cls_preview = PreviewCodeChatModel else: - if is_codey_model(values["model_name"]): - model_cls = CodeChatModel - model_cls_preview = PreviewCodeChatModel - else: - model_cls = ChatModel - model_cls_preview = PreviewChatModel - values["client"] = model_cls.from_pretrained(values["model_name"]) - values["client_preview"] = model_cls_preview.from_pretrained( - values["model_name"] - ) + model_cls = ChatModel + model_cls_preview = PreviewChatModel + values["client"] = model_cls.from_pretrained(values["model_name"]) + values["client_preview"] = model_cls_preview.from_pretrained( + values["model_name"] + ) return values def _generate( @@ -339,67 +338,68 @@ def _generate( Raises: ValueError: if the last message in the list is not from human. """ - should_stream = stream if stream is not None else self.streaming - safety_settings = kwargs.pop("safety_settings", None) - if should_stream: - stream_iter = self._stream( - messages, stop=stop, run_manager=run_manager, **kwargs - ) - return generate_from_stream(stream_iter) - - params = self._prepare_params(stop=stop, stream=False, **kwargs) - msg_params = {} - if "candidate_count" in params: - msg_params["candidate_count"] = params.pop("candidate_count") - - if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) - - # set param to `functions` until core tool/function calling implemented - raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None - response = chat.send_message( - message, - generation_config=params, - tools=tools, - safety_settings=safety_settings, - ) - generations = [ - ChatGeneration( - message=_parse_response_candidate(candidate), - generation_info=get_generation_info( - candidate, - self._is_gemini_model, - usage_metadata=response.to_dict().get("usage_metadata"), - ), + with tool_context_manager(self._user_agent): + should_stream = stream if stream is not None else self.streaming + safety_settings = kwargs.pop("safety_settings", None) + if should_stream: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs ) - for candidate in response.candidates - ] - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples") or self.examples - if examples: - params["examples"] = _parse_examples(examples) - chat = self._start_chat(history, **params) - response = chat.send_message(question.content, **msg_params) - generations = [ - ChatGeneration( - message=AIMessage(content=candidate.text), - generation_info=get_generation_info( - candidate, - self._is_gemini_model, - usage_metadata=response.raw_prediction_response.metadata, - ), + return generate_from_stream(stream_iter) + + params = self._prepare_params(stop=stop, stream=False, **kwargs) + msg_params = {} + if "candidate_count" in params: + msg_params["candidate_count"] = params.pop("candidate_count") + + if self._is_gemini_model: + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, ) - for candidate in response.candidates - ] + message = history_gemini.pop() + chat = self.client.start_chat(history=history_gemini) + + # set param to `functions` until core tool/function calling implemented + raw_tools = params.pop("functions") if "functions" in params else None + tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + response = chat.send_message( + message, + generation_config=params, + tools=tools, + safety_settings=safety_settings, + ) + generations = [ + ChatGeneration( + message=_parse_response_candidate(candidate), + generation_info=get_generation_info( + candidate, + self._is_gemini_model, + usage_metadata=response.to_dict().get("usage_metadata"), + ), + ) + for candidate in response.candidates + ] + else: + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + examples = kwargs.get("examples") or self.examples + if examples: + params["examples"] = _parse_examples(examples) + chat = self._start_chat(history, **params) + response = chat.send_message(question.content, **msg_params) + generations = [ + ChatGeneration( + message=AIMessage(content=candidate.text), + generation_info=get_generation_info( + candidate, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), + ) + for candidate in response.candidates + ] return ChatResult(generations=generations) async def _agenerate( @@ -427,59 +427,60 @@ async def _agenerate( kwargs.pop("stream") logger.warning("ChatVertexAI does not currently support async streaming.") - params = self._prepare_params(stop=stop, **kwargs) - safety_settings = kwargs.pop("safety_settings", None) - msg_params = {} - if "candidate_count" in params: - msg_params["candidate_count"] = params.pop("candidate_count") - - if self._is_gemini_model: - history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) - # set param to `functions` until core tool/function calling implemented - raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None - response = await chat.send_message_async( - message, - generation_config=params, - tools=tools, - safety_settings=safety_settings, - ) - generations = [ - ChatGeneration( - message=_parse_response_candidate(c), - generation_info=get_generation_info( - c, - self._is_gemini_model, - usage_metadata=response.to_dict().get("usage_metadata"), - ), + with tool_context_manager(self._user_agent): + params = self._prepare_params(stop=stop, **kwargs) + safety_settings = kwargs.pop("safety_settings", None) + msg_params = {} + if "candidate_count" in params: + msg_params["candidate_count"] = params.pop("candidate_count") + + if self._is_gemini_model: + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, ) - for c in response.candidates - ] - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples", None) or self.examples - if examples: - params["examples"] = _parse_examples(examples) - chat = self._start_chat(history, **params) - response = await chat.send_message_async(question.content, **msg_params) - generations = [ - ChatGeneration( - message=AIMessage(content=r.text), - generation_info=get_generation_info( - r, - self._is_gemini_model, - usage_metadata=response.raw_prediction_response.metadata, - ), + message = history_gemini.pop() + chat = self.client.start_chat(history=history_gemini) + # set param to `functions` until core tool/function calling implemented + raw_tools = params.pop("functions") if "functions" in params else None + tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + response = await chat.send_message_async( + message, + generation_config=params, + tools=tools, + safety_settings=safety_settings, ) - for r in response.candidates - ] + generations = [ + ChatGeneration( + message=_parse_response_candidate(c), + generation_info=get_generation_info( + c, + self._is_gemini_model, + usage_metadata=response.to_dict().get("usage_metadata"), + ), + ) + for c in response.candidates + ] + else: + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + examples = kwargs.get("examples", None) or self.examples + if examples: + params["examples"] = _parse_examples(examples) + chat = self._start_chat(history, **params) + response = await chat.send_message_async(question.content, **msg_params) + generations = [ + ChatGeneration( + message=AIMessage(content=r.text), + generation_info=get_generation_info( + r, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), + ) + for r in response.candidates + ] return ChatResult(generations=generations) def _stream( @@ -489,8 +490,73 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - params = self._prepare_params(stop=stop, stream=True, **kwargs) - if self._is_gemini_model: + with tool_context_manager(self._user_agent): + params = self._prepare_params(stop=stop, stream=True, **kwargs) + if self._is_gemini_model: + history_gemini = _parse_chat_history_gemini( + messages, + project=self.project, + convert_system_message_to_human=self.convert_system_message_to_human, + ) + message = history_gemini.pop() + chat = self.client.start_chat(history=history_gemini) + # set param to `functions` until core tool/function calling implemented + raw_tools = params.pop("functions") if "functions" in params else None + tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None + safety_settings = params.pop("safety_settings", None) + responses = chat.send_message( + message, + stream=True, + generation_config=params, + safety_settings=safety_settings, + tools=tools, + ) + for response in responses: + message = _parse_response_candidate(response.candidates[0]) + if run_manager: + run_manager.on_llm_new_token(message.content) + yield ChatGenerationChunk( + message=AIMessageChunk( + content=message.content, + additional_kwargs=message.additional_kwargs, + ), + generation_info=get_generation_info( + response.candidates[0], + self._is_gemini_model, + usage_metadata=response.to_dict().get("usage_metadata"), + ), + ) + else: + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + examples = kwargs.get("examples", None) + if examples: + params["examples"] = _parse_examples(examples) + chat = self._start_chat(history, **params) + responses = chat.send_message_streaming(question.content, **params) + for response in responses: + if run_manager: + run_manager.on_llm_new_token(response.text) + yield ChatGenerationChunk( + message=AIMessageChunk(content=response.text), + generation_info=get_generation_info( + response, + self._is_gemini_model, + usage_metadata=response.raw_prediction_response.metadata, + ), + ) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + if not self._is_gemini_model: + raise NotImplementedError() + with tool_context_manager(self._user_agent): + params = self._prepare_params(stop=stop, stream=True, **kwargs) history_gemini = _parse_chat_history_gemini( messages, project=self.project, @@ -498,94 +564,31 @@ def _stream( ) message = history_gemini.pop() chat = self.client.start_chat(history=history_gemini) - # set param to `functions` until core tool/function calling implemented raw_tools = params.pop("functions") if "functions" in params else None tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None safety_settings = params.pop("safety_settings", None) - responses = chat.send_message( + async for chunk in await chat.send_message_async( message, stream=True, generation_config=params, safety_settings=safety_settings, tools=tools, - ) - for response in responses: - message = _parse_response_candidate(response.candidates[0]) + ): + message = _parse_response_candidate(chunk.candidates[0]) if run_manager: - run_manager.on_llm_new_token(message.content) + await run_manager.on_llm_new_token(message.content) yield ChatGenerationChunk( message=AIMessageChunk( content=message.content, additional_kwargs=message.additional_kwargs, ), generation_info=get_generation_info( - response.candidates[0], - self._is_gemini_model, - usage_metadata=response.to_dict().get("usage_metadata"), - ), - ) - else: - question = _get_question(messages) - history = _parse_chat_history(messages[:-1]) - examples = kwargs.get("examples", None) - if examples: - params["examples"] = _parse_examples(examples) - chat = self._start_chat(history, **params) - responses = chat.send_message_streaming(question.content, **params) - for response in responses: - if run_manager: - run_manager.on_llm_new_token(response.text) - yield ChatGenerationChunk( - message=AIMessageChunk(content=response.text), - generation_info=get_generation_info( - response, + chunk.candidates[0], self._is_gemini_model, - usage_metadata=response.raw_prediction_response.metadata, + usage_metadata=chunk.to_dict().get("usage_metadata"), ), ) - async def _astream( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - if not self._is_gemini_model: - raise NotImplementedError() - params = self._prepare_params(stop=stop, stream=True, **kwargs) - history_gemini = _parse_chat_history_gemini( - messages, - project=self.project, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - message = history_gemini.pop() - chat = self.client.start_chat(history=history_gemini) - raw_tools = params.pop("functions") if "functions" in params else None - tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None - safety_settings = params.pop("safety_settings", None) - async for chunk in await chat.send_message_async( - message, - stream=True, - generation_config=params, - safety_settings=safety_settings, - tools=tools, - ): - message = _parse_response_candidate(chunk.candidates[0]) - if run_manager: - await run_manager.on_llm_new_token(message.content) - yield ChatGenerationChunk( - message=AIMessageChunk( - content=message.content, - additional_kwargs=message.additional_kwargs, - ), - generation_info=get_generation_info( - chunk.candidates[0], - self._is_gemini_model, - usage_metadata=chunk.to_dict().get("usage_metadata"), - ), - ) - def _start_chat( self, history: _ChatHistory, **kwargs: Any ) -> Union[ChatSession, CodeChatSession]: diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index ac813adc..eeeb1037 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -48,7 +48,6 @@ def validate_environment(cls, values: Dict) -> Dict: "textembedding-gecko@001" ) values["model_name"] = "textembedding-gecko@001" - with tool_context_manager(get_user_agent("vertex-ai-embeddings")): values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) return values @@ -173,7 +172,8 @@ def _completion_with_retry(texts_to_process: List[str]) -> Any: embeddings = self.client.get_embeddings(requests) return [embs.values for embs in embeddings] - return _completion_with_retry(texts) + with tool_context_manager(self._user_agent): + return _completion_with_retry(texts) def _prepare_and_validate_batches( self, texts: List[str], embeddings_type: Optional[str] = None diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index cf077ca4..73361177 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -91,7 +91,8 @@ def _completion_with_retry_inner( return llm.client.predict_streaming(prompt[0], **kwargs) return llm.client.predict(prompt[0], **kwargs) - return _completion_with_retry_inner(prompt, is_gemini, **kwargs) + with tool_context_manager(llm._user_agent): + return _completion_with_retry_inner(prompt, is_gemini, **kwargs) async def _acompletion_with_retry( @@ -122,9 +123,10 @@ async def _acompletion_with_retry_inner( raise ValueError("Async streaming is supported only for Gemini family!") return await llm.client.predict_async(prompt, **kwargs) - return await _acompletion_with_retry_inner( - prompt, is_gemini, stream=stream, **kwargs - ) + with tool_context_manager(llm._user_agent): + return await _acompletion_with_retry_inner( + prompt, is_gemini, stream=stream, **kwargs + ) class _VertexAIBase(BaseModel): @@ -232,6 +234,12 @@ def _default_params(self) -> Dict[str, Any]: ) return updated_params + @property + def _user_agent(self) -> str: + """Gets the User Agent.""" + _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}") + return user_agent + @classmethod def _init_vertexai(cls, values: Dict) -> None: vertexai.init( @@ -316,25 +324,22 @@ def validate_environment(cls, values: Dict) -> Dict: model_cls = TextGenerationModel preview_model_cls = PreviewTextGenerationModel - with tool_context_manager(get_user_agent("vertex-ai-llm")): - if tuned_model_name: - values["client"] = model_cls.get_tuned_model(tuned_model_name) - values["client_preview"] = preview_model_cls.get_tuned_model( - tuned_model_name + if tuned_model_name: + values["client"] = model_cls.get_tuned_model(tuned_model_name) + values["client_preview"] = preview_model_cls.get_tuned_model( + tuned_model_name + ) + else: + if is_gemini: + values["client"] = model_cls( + model_name=model_name, safety_settings=safety_settings + ) + values["client_preview"] = preview_model_cls( + model_name=model_name, safety_settings=safety_settings ) else: - if is_gemini: - values["client"] = model_cls( - model_name=model_name, safety_settings=safety_settings - ) - values["client_preview"] = preview_model_cls( - model_name=model_name, safety_settings=safety_settings - ) - else: - values["client"] = model_cls.from_pretrained(model_name) - values["client_preview"] = preview_model_cls.from_pretrained( - model_name - ) + values["client"] = model_cls.from_pretrained(model_name) + values["client_preview"] = preview_model_cls.from_pretrained(model_name) if values["streaming"] and values["n"] > 1: raise ValueError("Only one candidate can be generated with streaming!") diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py index a005fdbd..212c22f8 100644 --- a/libs/vertexai/langchain_google_vertexai/vision_models.py +++ b/libs/vertexai/langchain_google_vertexai/vision_models.py @@ -40,8 +40,7 @@ class _BaseImageTextModel(BaseModel): def _create_model(self) -> ImageTextModel: """Builds the model object from the class attributes.""" - with tool_context_manager(get_user_agent("vertex-ai-imagen")): - return ImageTextModel.from_pretrained(model_name=self.model_name) + return ImageTextModel.from_pretrained(model_name=self.model_name) def _get_image_from_message_part(self, message_part: str | Dict) -> Image | None: """Given a message part obtain a image if the part represents it. @@ -78,6 +77,12 @@ def _llm_type(self) -> str: """Returns the type of LLM""" return "vertexai-vision" + @property + def _user_agent(self) -> str: + """Gets the User Agent.""" + _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}") + return user_agent + class _BaseVertexAIImageCaptioning(_BaseImageTextModel): """Base class for Image Captioning models.""" @@ -91,13 +96,14 @@ def _get_captions(self, image: Image) -> List[str]: Returns: List of captions obtained from the image. """ - model = self._create_model() - captions = model.get_captions( - image=image, - number_of_results=self.number_of_results, - language=self.language, - ) - return captions + with tool_context_manager(self._user_agent): + model = self._create_model() + captions = model.get_captions( + image=image, + number_of_results=self.number_of_results, + language=self.language, + ) + return captions class VertexAIImageCaptioning(_BaseVertexAIImageCaptioning, BaseLLM): @@ -269,11 +275,12 @@ def _ask_questions(self, image: Image, query: str) -> List[str]: Returns: List of responses to the query. """ - model = self._create_model() - answers = model.ask_question( - image=image, question=query, number_of_results=self.number_of_results - ) - return answers + with tool_context_manager(self._user_agent): + model = self._create_model() + answers = model.ask_question( + image=image, question=query, number_of_results=self.number_of_results + ) + return answers class _BaseVertexAIImageGenerator(BaseModel): @@ -306,17 +313,17 @@ def _generate_images(self, prompt: str) -> List[str]: Returns: List of b64 encoded strings. """ - - model = ImageGenerationModel.from_pretrained(self.model_name) - - generation_result = model.generate_images( - prompt=prompt, - negative_prompt=self.negative_prompt, - number_of_images=self.number_of_images, - language=self.language, - guidance_scale=self.guidance_scale, - seed=self.seed, - ) + with tool_context_manager(self._user_agent): + model = ImageGenerationModel.from_pretrained(self.model_name) + + generation_result = model.generate_images( + prompt=prompt, + negative_prompt=self.negative_prompt, + number_of_images=self.number_of_images, + language=self.language, + guidance_scale=self.guidance_scale, + seed=self.seed, + ) image_str_list = [ self._to_b64_string(image) for image in generation_result.images @@ -334,22 +341,22 @@ def _edit_images(self, image_str: str, prompt: str) -> List[str]: Returns: List of b64 encoded strings. """ - - model = ImageGenerationModel.from_pretrained(self.model_name) - - image_loader = ImageBytesLoader(project=self.project) - image_bytes = image_loader.load_bytes(image_str) - image = Image(image_bytes=image_bytes) - - generation_result = model.edit_image( - prompt=prompt, - base_image=image, - negative_prompt=self.negative_prompt, - number_of_images=self.number_of_images, - language=self.language, - guidance_scale=self.guidance_scale, - seed=self.seed, - ) + with tool_context_manager(self._user_agent): + model = ImageGenerationModel.from_pretrained(self.model_name) + + image_loader = ImageBytesLoader(project=self.project) + image_bytes = image_loader.load_bytes(image_str) + image = Image(image_bytes=image_bytes) + + generation_result = model.edit_image( + prompt=prompt, + base_image=image, + negative_prompt=self.negative_prompt, + number_of_images=self.number_of_images, + language=self.language, + guidance_scale=self.guidance_scale, + seed=self.seed, + ) image_str_list = [ self._to_b64_string(image) for image in generation_result.images @@ -386,6 +393,12 @@ def _llm_type(self) -> str: """Returns the type of LLM""" return "vertexai-vision" + @property + def _user_agent(self) -> str: + """Gets the User Agent.""" + _, user_agent = get_user_agent(f"{type(self).__name__}_{self.model_name}") + return user_agent + class VertexAIImageGeneratorChat(_BaseVertexAIImageGenerator, BaseChatModel): """Generates an image from a prompt."""