Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor image loading logic to be reusable in vision_models #20

Merged
merged 11 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import base64
import os
import re
from typing import Union
from urllib.parse import urlparse

import requests
from google.cloud import storage # type: ignore[attr-defined]


class ImageBytesLoader:
"""Loads image bytes from multiple sources given a string.

Currently supported:
- Google cloud storage URI
- B64 Encoded image string
- Local file path
- URL
"""

def __init__(
self,
project: Union[str, None] = None,
) -> None:
"""Constructor

Args:
project: Google Cloud project id. Defaults to none.
"""
self._project = project

def load_bytes(self, image_string: str) -> bytes:
"""Routes to the correct loader based on the image_string.

Args:
image_string: Can be either:
- Google cloud storage URI
- B64 Encoded image string
- Local file path
- URL

Returns:
Image bytes.
"""

if image_string.startswith("gs://"):
return self._bytes_from_gsc(image_string)

if image_string.startswith("data:image/"):
return self._bytes_from_b64(image_string)

if self._is_url(image_string):
return self._bytes_from_url(image_string)

if os.path.exists(image_string):
return self._bytes_from_file(image_string)

raise ValueError(
"Image string must be one of: Google Cloud Storage URI, "
"b64 encoded image string (data:image/...), valid image url, "
f"or existing local image file. Instead got '{image_string}'."
)

def _bytes_from_b64(self, base64_image: str) -> bytes:
"""Gets image bytes from a base64 encoded string.

Args:
base64_image: Encoded image in b64 format.

Returns:
Image bytes
"""

pattern = r"data:image/\w{2,4};base64,(.*)"
match = re.search(pattern, base64_image)

if match is not None:
encoded_string = match.group(1)
return base64.b64decode(encoded_string)

raise ValueError(f"Error in b64 encoded image. Must follow pattern: {pattern}")

def _bytes_from_file(self, file_path: str) -> bytes:
"""Gets image bytes from a local file path.

Args:
file_path: Existing file path.

Returns:
Image bytes
"""
with open(file_path, "rb") as image_file:
image_bytes = image_file.read()
return image_bytes

def _bytes_from_url(self, url: str) -> bytes:
"""Gets image bytes from a public url.

Args:
url: Valid url.

Raises:
HTTP Error if there is one.

Returns:
Image bytes
"""

response = requests.get(url)

if not response.ok:
response.raise_for_status()

return response.content

def _bytes_from_gsc(self, gcs_uri: str) -> bytes:
"""Gets image bytes from a google cloud storage uri.

Args:
gcs_uri: Valid gcs uri.

Raises:
ValueError if there are more than one blob matching the uri.

Returns:
Image bytes
"""

gcs_client = storage.Client(project=self._project)

pieces = gcs_uri.split("/")

blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))

if len(blobs) > 1:
raise ValueError(f"Found more than one candidate for {gcs_uri}!")

return blobs[0].download_as_bytes()

def _is_url(self, url_string: str) -> bool:
"""Checks if a url is valid.

Args:
url_string: Url to check.

Returns:
Whether the url is valid.
"""
try:
result = urlparse(url_string)
return all([result.scheme, result.netloc])
except Exception:
return False


def image_bytes_to_b64_string(
image_bytes: bytes, encoding: str = "ascii", image_format: str = "png"
) -> str:
"""Encodes image bytes into a b64 encoded string.

Args:
image_bytes: Bytes of the image.
encoding: Type of encoding in the string. 'ascii' by default.
image_format: Format of the image. 'png' by default.

Returns:
B64 image encoded string.
"""
encoded_bytes = base64.b64encode(image_bytes).decode(encoding)
return f"data:image/{image_format};base64,{encoded_bytes}"
36 changes: 3 additions & 33 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations

import base64
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse

import proto # type: ignore[import-untyped]
import requests
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
from langchain_core.callbacks import (
Expand Down Expand Up @@ -53,11 +49,11 @@
CodeChatModel as PreviewCodeChatModel,
)

from langchain_google_vertexai._image_utils import ImageBytesLoader
from langchain_google_vertexai._utils import (
get_generation_info,
is_codey_model,
is_gemini_model,
load_image_from_gcs,
)
from langchain_google_vertexai.functions_utils import (
_format_tools_to_vertex_tool,
Expand Down Expand Up @@ -108,15 +104,6 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
return chat_history


def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False


def _parse_chat_history_gemini(
history: List[BaseMessage],
project: Optional[str] = None,
Expand All @@ -134,25 +121,8 @@ def _convert_to_prompt(part: Union[str, Dict]) -> Part:
return Part.from_text(part["text"])
elif part["type"] == "image_url":
path = part["image_url"]["url"]
if path.startswith("gs://"):
image = load_image_from_gcs(path=path, project=project)
elif path.startswith("data:image/"):
# extract base64 component from image uri
try:
regexp = r"data:image/\w{2,4};base64,(.*)"
encoded = re.search(regexp, path).group(1) # type: ignore
except AttributeError:
raise ValueError(
"Invalid image uri. It should be in the format "
"data:image/<image_type>;base64,<base64_encoded_image>."
)
image = Image.from_bytes(base64.b64decode(encoded))
elif _is_url(path):
response = requests.get(path)
response.raise_for_status()
image = Image.from_bytes(response.content)
else:
image = Image.load_from_file(path)
image_bytes = ImageBytesLoader(project=project).load_bytes(path)
image = Image.from_bytes(image_bytes)
else:
raise ValueError("Only text and image_url types are supported!")
return Part.from_image(image)
Expand Down
64 changes: 64 additions & 0 deletions libs/vertexai/tests/integration_tests/test_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from google.cloud import storage # type: ignore[attr-defined]
from google.cloud.exceptions import NotFound

from langchain_google_vertexai._image_utils import ImageBytesLoader


def test_image_utils():
base64_image = (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)

loader = ImageBytesLoader()

image_bytes = loader.load_bytes(base64_image)

assert isinstance(image_bytes, bytes)

# Check loads image from blob

bucket_name = "test_image_utils"
blob_name = "my_image.png"

client = storage.Client()
bucket = client.bucket(bucket_name=bucket_name)
blob = bucket.blob(blob_name)

try:
blob.upload_from_string(data=image_bytes)
except NotFound:
client.create_bucket(bucket)
blob.upload_from_string(data=image_bytes)

gcs_uri = f"gs://{bucket.name}/{blob.name}"

gcs_image_bytes = loader.load_bytes(gcs_uri)

assert image_bytes == gcs_image_bytes

# Checks loads image from url
url = (
"https://www.google.co.jp/images/branding/"
"googlelogo/1x/googlelogo_color_272x92dp.png"
)

image_bytes = loader.load_bytes(url)
assert isinstance(image_bytes, bytes)
57 changes: 57 additions & 0 deletions libs/vertexai/tests/unit_tests/test_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from tempfile import NamedTemporaryFile

import pytest

from langchain_google_vertexai._image_utils import (
ImageBytesLoader,
image_bytes_to_b64_string,
)


def test_image_bytes_loader():
loader = ImageBytesLoader()

base64_image = (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)

# Check it loads from b64
image_bytes = loader.load_bytes(base64_image)
assert isinstance(image_bytes, bytes)

# Check it loads from local file.
file = NamedTemporaryFile()
file.write(image_bytes)
file.seek(0)
image_bytes_from_file = loader.load_bytes(file.name)
assert image_bytes_from_file == image_bytes
file.close()

# Check if fails if nosense string
with pytest.raises(ValueError):
loader.load_bytes("No sense string")

# Checks inverse conversion
recovered_b64 = image_bytes_to_b64_string(
image_bytes, encoding="ascii", image_format="png"
)

assert recovered_b64 == base64_image
Loading