Skip to content

Commit

Permalink
Refactor image loading logic to be reusable in vision_models (#20)
Browse files Browse the repository at this point in the history
* Image generator as an LLM

---------

Co-authored-by: Jorge <[email protected]>
  • Loading branch information
jzaldi and Jorge authored Feb 21, 2024
1 parent 2689bd3 commit aa6f9a9
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 33 deletions.
170 changes: 170 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import base64
import os
import re
from typing import Union
from urllib.parse import urlparse

import requests
from google.cloud import storage # type: ignore[attr-defined]


class ImageBytesLoader:
    """Loads image bytes from multiple sources given a string.

    Currently supported:
        - Google cloud storage URI
        - B64 Encoded image string
        - Local file path
        - URL
    """

    def __init__(
        self,
        project: Union[str, None] = None,
    ) -> None:
        """Constructor

        Args:
            project: Google Cloud project id. Defaults to none.
        """
        self._project = project

    def load_bytes(self, image_string: str) -> bytes:
        """Routes to the correct loader based on the image_string.

        The checks run in a fixed order: ``gs://`` prefix, ``data:image/``
        prefix, anything that parses as a URL with scheme and host, and
        finally an existing local path.

        Args:
            image_string: Can be either:
                    - Google cloud storage URI
                    - B64 Encoded image string
                    - Local file path
                    - URL

        Raises:
            ValueError: If the string matches none of the supported sources.

        Returns:
            Image bytes.
        """

        if image_string.startswith("gs://"):
            return self._bytes_from_gsc(image_string)

        if image_string.startswith("data:image/"):
            return self._bytes_from_b64(image_string)

        if self._is_url(image_string):
            return self._bytes_from_url(image_string)

        if os.path.exists(image_string):
            return self._bytes_from_file(image_string)

        raise ValueError(
            "Image string must be one of: Google Cloud Storage URI, "
            "b64 encoded image string (data:image/...), valid image url, "
            f"or existing local image file. Instead got '{image_string}'."
        )

    def _bytes_from_b64(self, base64_image: str) -> bytes:
        """Gets image bytes from a base64 encoded string.

        Args:
            base64_image: Encoded image in b64 format, i.e.
                ``data:image/<type>;base64,<payload>``.

        Raises:
            ValueError: If the string does not follow the expected pattern.

        Returns:
            Image bytes
        """

        pattern = r"data:image/\w{2,4};base64,(.*)"
        # DOTALL lets the payload span several lines: without it ``.*`` stops
        # at the first newline and a line-wrapped payload is silently
        # truncated. b64decode itself discards the newline characters.
        match = re.search(pattern, base64_image, flags=re.DOTALL)

        if match is not None:
            encoded_string = match.group(1)
            return base64.b64decode(encoded_string)

        raise ValueError(f"Error in b64 encoded image. Must follow pattern: {pattern}")

    def _bytes_from_file(self, file_path: str) -> bytes:
        """Gets image bytes from a local file path.

        Args:
            file_path: Existing file path.

        Returns:
            Image bytes
        """
        with open(file_path, "rb") as image_file:
            return image_file.read()

    def _bytes_from_url(self, url: str) -> bytes:
        """Gets image bytes from a public url.

        Args:
            url: Valid url.

        Raises:
            HTTP Error if there is one.

        Returns:
            Image bytes
        """

        response = requests.get(url)

        # raise_for_status() is a no-op for successful responses, so no
        # separate ``response.ok`` check is needed.
        response.raise_for_status()

        return response.content

    def _bytes_from_gsc(self, gcs_uri: str) -> bytes:
        """Gets image bytes from a google cloud storage uri.

        Args:
            gcs_uri: Valid gcs uri.

        Raises:
            ValueError if no blob matches the uri, or if several blobs match
            and none of them is named exactly like the uri's object path.

        Returns:
            Image bytes
        """

        gcs_client = storage.Client(project=self._project)

        # "gs://bucket/path/to/blob" splits into
        # ["gs:", "", "bucket", "path", "to", "blob"].
        pieces = gcs_uri.split("/")
        bucket_name = pieces[2]
        blob_name = "/".join(pieces[3:])

        blobs = list(gcs_client.list_blobs(bucket_name, prefix=blob_name))

        if not blobs:
            # Previously this fell through to blobs[0] and raised IndexError.
            raise ValueError(f"Found no blob matching {gcs_uri}!")

        if len(blobs) > 1:
            # Listing is prefix based, so "img.png" also returns
            # "img.png.bak". Prefer the blob whose name is exactly the one
            # requested before declaring the uri ambiguous.
            exact_matches = [blob for blob in blobs if blob.name == blob_name]
            if len(exact_matches) != 1:
                raise ValueError(f"Found more than one candidate for {gcs_uri}!")
            blobs = exact_matches

        return blobs[0].download_as_bytes()

    def _is_url(self, url_string: str) -> bool:
        """Checks if a url is valid.

        A string counts as a url when it parses with both a scheme and a
        network location.

        Args:
            url_string: Url to check.

        Returns:
            Whether the url is valid.
        """
        try:
            result = urlparse(url_string)
            return all([result.scheme, result.netloc])
        except Exception:
            return False


def image_bytes_to_b64_string(
    image_bytes: bytes, encoding: str = "ascii", image_format: str = "png"
) -> str:
    """Builds a ``data:image/...;base64,...`` string from raw image bytes.

    Args:
        image_bytes: Raw bytes of the image.
        encoding: Codec used to turn the base64 bytes into text.
            'ascii' by default.
        image_format: Image subtype placed after ``data:image/``.
            'png' by default.

    Returns:
        The image as a base64 data string.
    """
    b64_payload = base64.b64encode(image_bytes)
    payload_text = str(b64_payload, encoding)
    return "data:image/{};base64,{}".format(image_format, payload_text)
36 changes: 3 additions & 33 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations

import base64
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse

import proto # type: ignore[import-untyped]
import requests
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
from langchain_core.callbacks import (
Expand Down Expand Up @@ -53,11 +49,11 @@
CodeChatModel as PreviewCodeChatModel,
)

from langchain_google_vertexai._image_utils import ImageBytesLoader
from langchain_google_vertexai._utils import (
get_generation_info,
is_codey_model,
is_gemini_model,
load_image_from_gcs,
)
from langchain_google_vertexai.functions_utils import (
_format_tools_to_vertex_tool,
Expand Down Expand Up @@ -108,15 +104,6 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
return chat_history


def _is_url(s: str) -> bool:
    """Tell whether ``s`` parses as a URL with both a scheme and a host.

    Any exception raised while parsing is logged at debug level and treated
    as "not a URL".
    """
    try:
        parsed = urlparse(s)
    except Exception as e:
        logger.debug(f"Unable to parse URL: {e}")
        return False
    return bool(parsed.scheme and parsed.netloc)


def _parse_chat_history_gemini(
history: List[BaseMessage],
project: Optional[str] = None,
Expand All @@ -134,25 +121,8 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Part:
return Part.from_text(part["text"])
elif part["type"] == "image_url":
path = part["image_url"]["url"]
if path.startswith("gs://"):
image = load_image_from_gcs(path=path, project=project)
elif path.startswith("data:image/"):
# extract base64 component from image uri
try:
regexp = r"data:image/\w{2,4};base64,(.*)"
encoded = re.search(regexp, path).group(1) # type: ignore
except AttributeError:
raise ValueError(
"Invalid image uri. It should be in the format "
"data:image/<image_type>;base64,<base64_encoded_image>."
)
image = Image.from_bytes(base64.b64decode(encoded))
elif _is_url(path):
response = requests.get(path)
response.raise_for_status()
image = Image.from_bytes(response.content)
else:
image = Image.load_from_file(path)
image_bytes = ImageBytesLoader(project=project).load_bytes(path)
image = Image.from_bytes(image_bytes)
else:
raise ValueError("Only text and image_url types are supported!")
return Part.from_image(image)
Expand Down
64 changes: 64 additions & 0 deletions libs/vertexai/tests/integration_tests/test_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from google.cloud import storage # type: ignore[attr-defined]
from google.cloud.exceptions import NotFound

from langchain_google_vertexai._image_utils import ImageBytesLoader


def test_image_utils():
    """Integration test: load image bytes from b64, from GCS and from a URL.

    Needs Google Cloud credentials (a storage client is created and a blob
    is uploaded) and network access for the URL check.
    """
    # The string must start with "data:image/" so that load_bytes routes it
    # to the base64 loader; without that header it is rejected as an
    # unsupported source.
    base64_image = (
        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
        "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
        "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
        "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
        "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
        "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
        "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
        "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
        "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
        "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
        "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
        "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
        "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
        "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
        "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
        "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
        "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
        "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
        "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
    )

    loader = ImageBytesLoader()

    image_bytes = loader.load_bytes(base64_image)

    assert isinstance(image_bytes, bytes)

    # Check loads image from blob

    bucket_name = "test_image_utils"
    blob_name = "my_image.png"

    client = storage.Client()
    bucket = client.bucket(bucket_name=bucket_name)
    blob = bucket.blob(blob_name)

    try:
        blob.upload_from_string(data=image_bytes)
    except NotFound:
        # The bucket does not exist yet: create it and retry the upload.
        client.create_bucket(bucket)
        blob.upload_from_string(data=image_bytes)

    gcs_uri = f"gs://{bucket.name}/{blob.name}"

    gcs_image_bytes = loader.load_bytes(gcs_uri)

    assert image_bytes == gcs_image_bytes

    # Checks loads image from url
    url = (
        "https://www.google.co.jp/images/branding/"
        "googlelogo/1x/googlelogo_color_272x92dp.png"
    )

    image_bytes = loader.load_bytes(url)
    assert isinstance(image_bytes, bytes)
57 changes: 57 additions & 0 deletions libs/vertexai/tests/unit_tests/test_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from tempfile import NamedTemporaryFile

import pytest

from langchain_google_vertexai._image_utils import (
ImageBytesLoader,
image_bytes_to_b64_string,
)


def test_image_bytes_loader():
    """Unit test: b64 and local-file loading, rejection, and b64 round trip."""
    loader = ImageBytesLoader()

    # The string must start with "data:image/" so that load_bytes routes it
    # to the base64 loader, and the final round-trip assertion compares
    # against the "data:image/png;base64," form produced by
    # image_bytes_to_b64_string.
    base64_image = (
        "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
        "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
        "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
        "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
        "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
        "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
        "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
        "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
        "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
        "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
        "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
        "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
        "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
        "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
        "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
        "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
        "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
        "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
        "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
    )

    # Check it loads from b64
    image_bytes = loader.load_bytes(base64_image)
    assert isinstance(image_bytes, bytes)

    # Check it loads from local file. The context manager guarantees the
    # temporary file is closed (and removed) even if loading fails.
    with NamedTemporaryFile() as file:
        file.write(image_bytes)
        # Push buffered bytes to disk before the loader reopens the path.
        file.flush()
        image_bytes_from_file = loader.load_bytes(file.name)
    assert image_bytes_from_file == image_bytes

    # Check if fails if nosense string
    with pytest.raises(ValueError):
        loader.load_bytes("No sense string")

    # Checks inverse conversion
    recovered_b64 = image_bytes_to_b64_string(
        image_bytes, encoding="ascii", image_format="png"
    )

    assert recovered_b64 == base64_image

0 comments on commit aa6f9a9

Please sign in to comment.