-
Notifications
You must be signed in to change notification settings - Fork 161
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor image loading logic to be reusable in
vision_models
(#20)
* Image generator as an LLM --------- Co-authored-by: Jorge <[email protected]>
- Loading branch information
Showing
4 changed files
with
294 additions
and
33 deletions.
There are no files selected for viewing
170 changes: 170 additions & 0 deletions
170
libs/vertexai/langchain_google_vertexai/_image_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import base64 | ||
import os | ||
import re | ||
from typing import Union | ||
from urllib.parse import urlparse | ||
|
||
import requests | ||
from google.cloud import storage # type: ignore[attr-defined] | ||
|
||
|
||
class ImageBytesLoader: | ||
"""Loads image bytes from multiple sources given a string. | ||
Currently supported: | ||
- Google cloud storage URI | ||
- B64 Encoded image string | ||
- Local file path | ||
- URL | ||
""" | ||
|
||
def __init__( | ||
self, | ||
project: Union[str, None] = None, | ||
) -> None: | ||
"""Constructor | ||
Args: | ||
project: Google Cloud project id. Defaults to none. | ||
""" | ||
self._project = project | ||
|
||
def load_bytes(self, image_string: str) -> bytes: | ||
"""Routes to the correct loader based on the image_string. | ||
Args: | ||
image_string: Can be either: | ||
- Google cloud storage URI | ||
- B64 Encoded image string | ||
- Local file path | ||
- URL | ||
Returns: | ||
Image bytes. | ||
""" | ||
|
||
if image_string.startswith("gs://"): | ||
return self._bytes_from_gsc(image_string) | ||
|
||
if image_string.startswith("data:image/"): | ||
return self._bytes_from_b64(image_string) | ||
|
||
if self._is_url(image_string): | ||
return self._bytes_from_url(image_string) | ||
|
||
if os.path.exists(image_string): | ||
return self._bytes_from_file(image_string) | ||
|
||
raise ValueError( | ||
"Image string must be one of: Google Cloud Storage URI, " | ||
"b64 encoded image string (data:image/...), valid image url, " | ||
f"or existing local image file. Instead got '{image_string}'." | ||
) | ||
|
||
def _bytes_from_b64(self, base64_image: str) -> bytes: | ||
"""Gets image bytes from a base64 encoded string. | ||
Args: | ||
base64_image: Encoded image in b64 format. | ||
Returns: | ||
Image bytes | ||
""" | ||
|
||
pattern = r"data:image/\w{2,4};base64,(.*)" | ||
match = re.search(pattern, base64_image) | ||
|
||
if match is not None: | ||
encoded_string = match.group(1) | ||
return base64.b64decode(encoded_string) | ||
|
||
raise ValueError(f"Error in b64 encoded image. Must follow pattern: {pattern}") | ||
|
||
def _bytes_from_file(self, file_path: str) -> bytes: | ||
"""Gets image bytes from a local file path. | ||
Args: | ||
file_path: Existing file path. | ||
Returns: | ||
Image bytes | ||
""" | ||
with open(file_path, "rb") as image_file: | ||
image_bytes = image_file.read() | ||
return image_bytes | ||
|
||
def _bytes_from_url(self, url: str) -> bytes: | ||
"""Gets image bytes from a public url. | ||
Args: | ||
url: Valid url. | ||
Raises: | ||
HTTP Error if there is one. | ||
Returns: | ||
Image bytes | ||
""" | ||
|
||
response = requests.get(url) | ||
|
||
if not response.ok: | ||
response.raise_for_status() | ||
|
||
return response.content | ||
|
||
def _bytes_from_gsc(self, gcs_uri: str) -> bytes: | ||
"""Gets image bytes from a google cloud storage uri. | ||
Args: | ||
gcs_uri: Valid gcs uri. | ||
Raises: | ||
ValueError if there are more than one blob matching the uri. | ||
Returns: | ||
Image bytes | ||
""" | ||
|
||
gcs_client = storage.Client(project=self._project) | ||
|
||
pieces = gcs_uri.split("/") | ||
|
||
blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:]))) | ||
|
||
if len(blobs) > 1: | ||
raise ValueError(f"Found more than one candidate for {gcs_uri}!") | ||
|
||
return blobs[0].download_as_bytes() | ||
|
||
def _is_url(self, url_string: str) -> bool: | ||
"""Checks if a url is valid. | ||
Args: | ||
url_string: Url to check. | ||
Returns: | ||
Whether the url is valid. | ||
""" | ||
try: | ||
result = urlparse(url_string) | ||
return all([result.scheme, result.netloc]) | ||
except Exception: | ||
return False | ||
|
||
|
||
def image_bytes_to_b64_string( | ||
image_bytes: bytes, encoding: str = "ascii", image_format: str = "png" | ||
) -> str: | ||
"""Encodes image bytes into a b64 encoded string. | ||
Args: | ||
image_bytes: Bytes of the image. | ||
encoding: Type of encoding in the string. 'ascii' by default. | ||
image_format: Format of the image. 'png' by default. | ||
Returns: | ||
B64 image encoded string. | ||
""" | ||
encoded_bytes = base64.b64encode(image_bytes).decode(encoding) | ||
return f"data:image/{image_format};base64,{encoded_bytes}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from google.cloud import storage # type: ignore[attr-defined] | ||
from google.cloud.exceptions import NotFound | ||
|
||
from langchain_google_vertexai._image_utils import ImageBytesLoader | ||
|
||
|
||
def test_image_utils(): | ||
base64_image = ( | ||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA" | ||
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" | ||
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" | ||
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" | ||
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" | ||
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" | ||
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" | ||
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" | ||
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" | ||
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" | ||
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" | ||
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" | ||
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" | ||
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" | ||
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" | ||
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" | ||
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" | ||
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" | ||
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" | ||
) | ||
|
||
loader = ImageBytesLoader() | ||
|
||
image_bytes = loader.load_bytes(base64_image) | ||
|
||
assert isinstance(image_bytes, bytes) | ||
|
||
# Check loads image from blob | ||
|
||
bucket_name = "test_image_utils" | ||
blob_name = "my_image.png" | ||
|
||
client = storage.Client() | ||
bucket = client.bucket(bucket_name=bucket_name) | ||
blob = bucket.blob(blob_name) | ||
|
||
try: | ||
blob.upload_from_string(data=image_bytes) | ||
except NotFound: | ||
client.create_bucket(bucket) | ||
blob.upload_from_string(data=image_bytes) | ||
|
||
gcs_uri = f"gs://{bucket.name}/{blob.name}" | ||
|
||
gcs_image_bytes = loader.load_bytes(gcs_uri) | ||
|
||
assert image_bytes == gcs_image_bytes | ||
|
||
# Checks loads image from url | ||
url = ( | ||
"https://www.google.co.jp/images/branding/" | ||
"googlelogo/1x/googlelogo_color_272x92dp.png" | ||
) | ||
|
||
image_bytes = loader.load_bytes(url) | ||
assert isinstance(image_bytes, bytes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from tempfile import NamedTemporaryFile | ||
|
||
import pytest | ||
|
||
from langchain_google_vertexai._image_utils import ( | ||
ImageBytesLoader, | ||
image_bytes_to_b64_string, | ||
) | ||
|
||
|
||
def test_image_bytes_loader(): | ||
loader = ImageBytesLoader() | ||
|
||
base64_image = ( | ||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA" | ||
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" | ||
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" | ||
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" | ||
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" | ||
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" | ||
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" | ||
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" | ||
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" | ||
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" | ||
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" | ||
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" | ||
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" | ||
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" | ||
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" | ||
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" | ||
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" | ||
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" | ||
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" | ||
) | ||
|
||
# Check it loads from b64 | ||
image_bytes = loader.load_bytes(base64_image) | ||
assert isinstance(image_bytes, bytes) | ||
|
||
# Check it loads from local file. | ||
file = NamedTemporaryFile() | ||
file.write(image_bytes) | ||
file.seek(0) | ||
image_bytes_from_file = loader.load_bytes(file.name) | ||
assert image_bytes_from_file == image_bytes | ||
file.close() | ||
|
||
# Check if fails if nosense string | ||
with pytest.raises(ValueError): | ||
loader.load_bytes("No sense string") | ||
|
||
# Checks inverse conversion | ||
recovered_b64 = image_bytes_to_b64_string( | ||
image_bytes, encoding="ascii", image_format="png" | ||
) | ||
|
||
assert recovered_b64 == base64_image |