diff --git a/libs/vertexai/langchain_google_vertexai/_image_utils.py b/libs/vertexai/langchain_google_vertexai/_image_utils.py new file mode 100644 index 00000000..4f94e0b1 --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/_image_utils.py @@ -0,0 +1,170 @@ +import base64 +import os +import re +from typing import Union +from urllib.parse import urlparse + +import requests +from google.cloud import storage # type: ignore[attr-defined] + + +class ImageBytesLoader: + """Loads image bytes from multiple sources given a string. + + Currently supported: + - Google cloud storage URI + - B64 Encoded image string + - Local file path + - URL + """ + + def __init__( + self, + project: Union[str, None] = None, + ) -> None: + """Constructor + + Args: + project: Google Cloud project id. Defaults to none. + """ + self._project = project + + def load_bytes(self, image_string: str) -> bytes: + """Routes to the correct loader based on the image_string. + + Args: + image_string: Can be either: + - Google cloud storage URI + - B64 Encoded image string + - Local file path + - URL + + Returns: + Image bytes. + """ + + if image_string.startswith("gs://"): + return self._bytes_from_gsc(image_string) + + if image_string.startswith("data:image/"): + return self._bytes_from_b64(image_string) + + if self._is_url(image_string): + return self._bytes_from_url(image_string) + + if os.path.exists(image_string): + return self._bytes_from_file(image_string) + + raise ValueError( + "Image string must be one of: Google Cloud Storage URI, " + "b64 encoded image string (data:image/...), valid image url, " + f"or existing local image file. Instead got '{image_string}'." + ) + + def _bytes_from_b64(self, base64_image: str) -> bytes: + """Gets image bytes from a base64 encoded string. + + Args: + base64_image: Encoded image in b64 format. + + Returns: + Image bytes + """ + + pattern = r"data:image/\w{2,4};base64,(.*)" + match = re.search(pattern, base64_image) + + if match is not None: + encoded_string = match.group(1) + return base64.b64decode(encoded_string) + + raise ValueError(f"Error in b64 encoded image. Must follow pattern: {pattern}") + + def _bytes_from_file(self, file_path: str) -> bytes: + """Gets image bytes from a local file path. + + Args: + file_path: Existing file path. + + Returns: + Image bytes + """ + with open(file_path, "rb") as image_file: + image_bytes = image_file.read() + return image_bytes + + def _bytes_from_url(self, url: str) -> bytes: + """Gets image bytes from a public url. + + Args: + url: Valid url. + + Raises: + HTTP Error if there is one. + + Returns: + Image bytes + """ + + response = requests.get(url) + + if not response.ok: + response.raise_for_status() + + return response.content + + def _bytes_from_gsc(self, gcs_uri: str) -> bytes: + """Gets image bytes from a google cloud storage uri. + + Args: + gcs_uri: Valid gcs uri. + + Raises: + ValueError if there are more than one blob matching the uri. + + Returns: + Image bytes + """ + + gcs_client = storage.Client(project=self._project) + + pieces = gcs_uri.split("/") + + blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:]))) + + if len(blobs) > 1: + raise ValueError(f"Found more than one candidate for {gcs_uri}!") + + return blobs[0].download_as_bytes() + + def _is_url(self, url_string: str) -> bool: + """Checks if a url is valid. + + Args: + url_string: Url to check. + + Returns: + Whether the url is valid. + """ + try: + result = urlparse(url_string) + return all([result.scheme, result.netloc]) + except Exception: + return False + + +def image_bytes_to_b64_string( + image_bytes: bytes, encoding: str = "ascii", image_format: str = "png" +) -> str: + """Encodes image bytes into a b64 encoded string. + + Args: + image_bytes: Bytes of the image. + encoding: Type of encoding in the string. 'ascii' by default. + image_format: Format of the image. 'png' by default. + + Returns: + B64 image encoded string. + """ + encoded_bytes = base64.b64encode(image_bytes).decode(encoding) + return f"data:image/{image_format};base64,{encoded_bytes}" diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index ab620b91..da509328 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -1,16 +1,12 @@ """Wrapper around Google VertexAI chat-based models.""" from __future__ import annotations -import base64 import json import logging -import re from dataclasses import dataclass, field from typing import Any, Dict, Iterator, List, Optional, Union, cast -from urllib.parse import urlparse import proto # type: ignore[import-untyped] -import requests from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall from langchain_core.callbacks import ( @@ -53,11 +49,11 @@ CodeChatModel as PreviewCodeChatModel, ) +from langchain_google_vertexai._image_utils import ImageBytesLoader from langchain_google_vertexai._utils import ( get_generation_info, is_codey_model, is_gemini_model, - load_image_from_gcs, ) from langchain_google_vertexai.functions_utils import ( _format_tools_to_vertex_tool, @@ -108,15 +104,6 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory: return chat_history -def _is_url(s: str) -> bool: - try: - result = urlparse(s) - return all([result.scheme, result.netloc]) - except Exception as e: - logger.debug(f"Unable to parse URL: {e}") - return False - - def _parse_chat_history_gemini( history: List[BaseMessage], project: Optional[str] = None, @@ -134,25 +121,8 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Part: return Part.from_text(part["text"]) elif part["type"] == "image_url": path = part["image_url"]["url"] - if path.startswith("gs://"): - image = load_image_from_gcs(path=path, project=project) - elif path.startswith("data:image/"): - # extract base64 component from image uri - try: - regexp = r"data:image/\w{2,4};base64,(.*)" - encoded = re.search(regexp, path).group(1) # type: ignore - except AttributeError: - raise ValueError( - "Invalid image uri. It should be in the format " - "data:image/;base64,." - ) - image = Image.from_bytes(base64.b64decode(encoded)) - elif _is_url(path): - response = requests.get(path) - response.raise_for_status() - image = Image.from_bytes(response.content) - else: - image = Image.load_from_file(path) + image_bytes = ImageBytesLoader(project=project).load_bytes(path) + image = Image.from_bytes(image_bytes) else: raise ValueError("Only text and image_url types are supported!") return Part.from_image(image) diff --git a/libs/vertexai/tests/integration_tests/test_image_utils.py b/libs/vertexai/tests/integration_tests/test_image_utils.py new file mode 100644 index 00000000..b297fa0d --- /dev/null +++ b/libs/vertexai/tests/integration_tests/test_image_utils.py @@ -0,0 +1,64 @@ +from google.cloud import storage # type: ignore[attr-defined] +from google.cloud.exceptions import NotFound + +from langchain_google_vertexai._image_utils import ImageBytesLoader + + +def test_image_utils(): + base64_image = ( + "" + "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" + "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" + "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" + "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" + "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" + "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" + "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" + "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" + "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" + "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" + "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" + "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" + "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" + "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" + "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" + "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" + "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" + "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" + ) + + loader = ImageBytesLoader() + + image_bytes = loader.load_bytes(base64_image) + + assert isinstance(image_bytes, bytes) + + # Check loads image from blob + + bucket_name = "test_image_utils" + blob_name = "my_image.png" + + client = storage.Client() + bucket = client.bucket(bucket_name=bucket_name) + blob = bucket.blob(blob_name) + + try: + blob.upload_from_string(data=image_bytes) + except NotFound: + client.create_bucket(bucket) + blob.upload_from_string(data=image_bytes) + + gcs_uri = f"gs://{bucket.name}/{blob.name}" + + gcs_image_bytes = loader.load_bytes(gcs_uri) + + assert image_bytes == gcs_image_bytes + + # Checks loads image from url + url = ( + "https://www.google.co.jp/images/branding/" + "googlelogo/1x/googlelogo_color_272x92dp.png" + ) + + image_bytes = loader.load_bytes(url) + assert isinstance(image_bytes, bytes) diff --git a/libs/vertexai/tests/unit_tests/test_image_utils.py b/libs/vertexai/tests/unit_tests/test_image_utils.py new file mode 100644 index 00000000..5d900d8b --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_image_utils.py @@ -0,0 +1,57 @@ +from tempfile import NamedTemporaryFile + +import pytest + +from langchain_google_vertexai._image_utils import ( + ImageBytesLoader, + image_bytes_to_b64_string, +) + + +def test_image_bytes_loader(): + loader = ImageBytesLoader() + + base64_image = ( + "" + "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" + "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" + "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" + "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" + "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" + "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" + "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" + "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" + "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" + "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" + "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" + "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" + "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" + "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" + "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" + "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" + "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" + "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" + ) + + # Check it loads from b64 + image_bytes = loader.load_bytes(base64_image) + assert isinstance(image_bytes, bytes) + + # Check it loads from local file. + file = NamedTemporaryFile() + file.write(image_bytes) + file.seek(0) + image_bytes_from_file = loader.load_bytes(file.name) + assert image_bytes_from_file == image_bytes + file.close() + + # Check if fails if nosense string + with pytest.raises(ValueError): + loader.load_bytes("No sense string") + + # Checks inverse conversion + recovered_b64 = image_bytes_to_b64_string( + image_bytes, encoding="ascii", image_format="png" + ) + + assert recovered_b64 == base64_image