Skip to content

Commit

Permalink
feat: add support for multimodal model and add embed_image
Browse files Browse the repository at this point in the history
  • Loading branch information
svidiella committed Feb 28, 2024
1 parent 6af9e3b commit 8956c57
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 26 deletions.
55 changes: 54 additions & 1 deletion libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
TextEmbeddingInput,
TextEmbeddingModel,
)
from vertexai.vision_models import ( # type: ignore
Image,
MultiModalEmbeddingModel,
MultiModalEmbeddingResponse,
)

from langchain_google_vertexai._base import _VertexAICommon

Expand Down Expand Up @@ -46,7 +51,11 @@ def validate_environment(cls, values: Dict) -> Dict:
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
if cls._is_multimodal_model(values["model_name"]):
values["client"] = MultiModalEmbeddingModel.from_pretrained(
values["model_name"])
else:
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values

def __init__(
Expand Down Expand Up @@ -82,6 +91,15 @@ def __init__(
self.instance[
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
retry_errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
DeadlineExceeded,
]
self.instance["retry_decorator"] = create_base_retry_decorator(
error_types=retry_errors, max_retries=self.max_retries
)

@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
Expand Down Expand Up @@ -321,6 +339,8 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
if self._is_multimodal_model(self.model_name):
raise NotImplementedError("Not supported for multimodal models")
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

def embed_query(self, text: str) -> List[float]:
Expand All @@ -332,5 +352,38 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
if self._is_multimodal_model(self.model_name):
raise NotImplementedError("Not supported for multimodal models")
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]

def embed_image(self, image_path: str) -> List[float]:
"""Embed an image.
Args:
image_path: Path to image (local or Google Cloud Storage) to generate
embeddings for.
Returns:
Embedding for the image.
"""
if not self._is_multimodal_model(self.model_name):
raise NotImplementedError("Only supported for multimodal models")

embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings)
image = Image.load_from_file(image_path)
result: MultiModalEmbeddingResponse = embed_with_retry(image=image)
return result.image_embedding

@staticmethod
def _is_multimodal_model(model_name: str) -> bool:
"""
Check if the embeddings model is multimodal or not.
Args:
model_name: The embeddings model name.
Returns:
A boolean, True if the model is multimodal.
"""
return "multimodalembedding" in model_name
39 changes: 39 additions & 0 deletions libs/vertexai/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import base64

import pytest
from _pytest.tmpdir import TempPathFactory
from vertexai.vision_models import Image


@pytest.fixture
def base64_image() -> str:
return (
""
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)


@pytest.fixture
def tmp_image(tmp_path_factory: TempPathFactory, base64_image) -> str:
img_data = base64.b64decode(base64_image.split(',')[1])
image = Image(image_bytes=img_data)
fn = tmp_path_factory.mktemp("data") / "img.png"
image.save(str(fn))
return str(fn)
6 changes: 6 additions & 0 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,9 @@ def test_warning(caplog: pytest.LogCaptureFixture) -> None:
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
)
assert record.message == expected_message


def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None:
model = VertexAIEmbeddings(model_name="multimodalembedding")
output = model.embed_image(tmp_image)
assert len(output) == 1408
25 changes: 0 additions & 25 deletions libs/vertexai/tests/integration_tests/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,28 +127,3 @@ def test_vertex_ai_image_generation_and_edition():

response = editor.invoke(messages)
assert isinstance(response, AIMessage)


@pytest.fixture
def base64_image() -> str:
return (
""
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)
39 changes: 39 additions & 0 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from vertexai.language_models import TextEmbeddingModel
from vertexai.vision_models import MultiModalEmbeddingModel

from langchain_google_vertexai import VertexAIEmbeddings


def test_langchain_google_vertexai_text_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001")
assert isinstance(embeddings_model.client, TextEmbeddingModel)
assert not embeddings_model._is_multimodal_model(embeddings_model.model_name)


def test_langchain_google_vertexai_multimodal_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001")
assert isinstance(embeddings_model.client, MultiModalEmbeddingModel)
assert embeddings_model._is_multimodal_model(embeddings_model.model_name)


def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001")
with pytest.raises(NotImplementedError) as e:
embeddings_model.embed_image("test")
assert e.value == "Only supported for multimodal models"


def test_langchain_google_vertexai_embed_documents_text_only() -> None:
embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001")
with pytest.raises(NotImplementedError) as e:
embeddings_model.embed_documents(["test"])
assert e.value == "Not supported for multimodal models"


def test_langchain_google_vertexai_embed_query_text_only() -> None:
embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001")
with pytest.raises(NotImplementedError) as e:
embeddings_model.embed_query("test")
assert e.value == "Not supported for multimodal models"

0 comments on commit 8956c57

Please sign in to comment.