From 8956c577733b86c8f9207f0d44a4294b27723a3d Mon Sep 17 00:00:00 2001 From: Sergio Vidiella Pinto Date: Wed, 28 Feb 2024 16:23:01 +0100 Subject: [PATCH] feat: add support for multimodal model and add embed_image --- .../langchain_google_vertexai/embeddings.py | 55 ++++++++++++++++++- .../tests/integration_tests/conftest.py | 39 +++++++++++++ .../integration_tests/test_embeddings.py | 6 ++ .../integration_tests/test_vision_models.py | 25 --------- .../tests/unit_tests/test_embeddings.py | 39 +++++++++++++ 5 files changed, 138 insertions(+), 26 deletions(-) create mode 100644 libs/vertexai/tests/integration_tests/conftest.py create mode 100644 libs/vertexai/tests/unit_tests/test_embeddings.py diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index f7a9b97f2..60517c54e 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -19,6 +19,11 @@ TextEmbeddingInput, TextEmbeddingModel, ) +from vertexai.vision_models import ( # type: ignore + Image, + MultiModalEmbeddingModel, + MultiModalEmbeddingResponse, +) from langchain_google_vertexai._base import _VertexAICommon @@ -46,7 +51,11 @@ def validate_environment(cls, values: Dict) -> Dict: "textembedding-gecko@001" ) values["model_name"] = "textembedding-gecko@001" - values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) + if cls._is_multimodal_model(values["model_name"]): + values["client"] = MultiModalEmbeddingModel.from_pretrained( + values["model_name"]) + else: + values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) return values def __init__( @@ -82,6 +91,15 @@ def __init__( self.instance[ "embeddings_task_type_supported" ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001") + retry_errors: List[Type[BaseException]] = [ + ResourceExhausted, + ServiceUnavailable, + Aborted, + DeadlineExceeded, + ] + self.instance["retry_decorator"] = create_base_retry_decorator( + error_types=retry_errors, max_retries=self.max_retries + ) @staticmethod def _split_by_punctuation(text: str) -> List[str]: @@ -321,6 +339,8 @@ def embed_documents( Returns: List of embeddings, one for each text. """ + if self._is_multimodal_model(self.model_name): + raise NotImplementedError("Not supported for multimodal models") return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT") def embed_query(self, text: str) -> List[float]: @@ -332,5 +352,38 @@ def embed_query(self, text: str) -> List[float]: Returns: Embedding for the text. """ + if self._is_multimodal_model(self.model_name): + raise NotImplementedError("Not supported for multimodal models") embeddings = self.embed([text], 1, "RETRIEVAL_QUERY") return embeddings[0] + + def embed_image(self, image_path: str) -> List[float]: + """Embed an image. + + Args: + image_path: Path to image (local or Google Cloud Storage) to generate + embeddings for. + + Returns: + Embedding for the image. + """ + if not self._is_multimodal_model(self.model_name): + raise NotImplementedError("Only supported for multimodal models") + + embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings) + image = Image.load_from_file(image_path) + result: MultiModalEmbeddingResponse = embed_with_retry(image=image) + return result.image_embedding + + @staticmethod + def _is_multimodal_model(model_name: str) -> bool: + """ + Check if the embeddings model is multimodal or not. + + Args: + model_name: The embeddings model name. + + Returns: + A boolean, True if the model is multimodal. + """ + return "multimodalembedding" in model_name diff --git a/libs/vertexai/tests/integration_tests/conftest.py b/libs/vertexai/tests/integration_tests/conftest.py new file mode 100644 index 000000000..e3664fc8a --- /dev/null +++ b/libs/vertexai/tests/integration_tests/conftest.py @@ -0,0 +1,39 @@ +import base64 + +import pytest +from _pytest.tmpdir import TempPathFactory +from vertexai.vision_models import Image + + +@pytest.fixture +def base64_image() -> str: + return ( + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA" + "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" + "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" + "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" + "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" + "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" + "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" + "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" + "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" + "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" + "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" + "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" + "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" + "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" + "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" + "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" + "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" + "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" + "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" + ) + + +@pytest.fixture +def tmp_image(tmp_path_factory: TempPathFactory, base64_image) -> str: + img_data = base64.b64decode(base64_image.split(',')[1]) + image = Image(image_bytes=img_data) + fn = tmp_path_factory.mktemp("data") / "img.png" + image.save(str(fn)) + return str(fn) diff --git a/libs/vertexai/tests/integration_tests/test_embeddings.py b/libs/vertexai/tests/integration_tests/test_embeddings.py index 42495629b..89ca0d1bf 100644 --- a/libs/vertexai/tests/integration_tests/test_embeddings.py +++ b/libs/vertexai/tests/integration_tests/test_embeddings.py @@ -68,3 +68,9 @@ def test_warning(caplog: pytest.LogCaptureFixture) -> None: "Feb-01-2024. Currently the default is set to textembedding-gecko@001" ) assert record.message == expected_message + + +def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None: + model = VertexAIEmbeddings(model_name="multimodalembedding") + output = model.embed_image(tmp_image) + assert len(output) == 1408 diff --git a/libs/vertexai/tests/integration_tests/test_vision_models.py b/libs/vertexai/tests/integration_tests/test_vision_models.py index a3edaa047..8b735d96d 100644 --- a/libs/vertexai/tests/integration_tests/test_vision_models.py +++ b/libs/vertexai/tests/integration_tests/test_vision_models.py @@ -127,28 +127,3 @@ def test_vertex_ai_image_generation_and_edition(): response = editor.invoke(messages) assert isinstance(response, AIMessage) - - -@pytest.fixture -def base64_image() -> str: - return ( - "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA" - "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" - "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" - "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" - "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" - "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" - "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" - "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" - "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" - "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" - "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" - "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" - "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" - "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" - "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" - "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" - "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" - "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" - "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" - ) diff --git a/libs/vertexai/tests/unit_tests/test_embeddings.py b/libs/vertexai/tests/unit_tests/test_embeddings.py new file mode 100644 index 000000000..deb5b541d --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_embeddings.py @@ -0,0 +1,39 @@ +import pytest +from vertexai.language_models import TextEmbeddingModel +from vertexai.vision_models import MultiModalEmbeddingModel + +from langchain_google_vertexai import VertexAIEmbeddings + + +def test_langchain_google_vertexai_text_model() -> None: + embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001") + assert isinstance(embeddings_model.client, TextEmbeddingModel) + assert not embeddings_model._is_multimodal_model(embeddings_model.model_name) + + +def test_langchain_google_vertexai_multimodal_model() -> None: + embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001") + assert isinstance(embeddings_model.client, MultiModalEmbeddingModel) + assert embeddings_model._is_multimodal_model(embeddings_model.model_name) + + +def test_langchain_google_vertexai_embed_image_multimodal_only() -> None: + embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001") + with pytest.raises(NotImplementedError) as e: + embeddings_model.embed_image("test") + assert e.value == "Only supported for multimodal models" + + +def test_langchain_google_vertexai_embed_documents_text_only() -> None: + embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001") + with pytest.raises(NotImplementedError) as e: + embeddings_model.embed_documents(["test"]) + assert e.value == "Not supported for multimodal models" + + +def test_langchain_google_vertexai_embed_query_text_only() -> None: + embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001") + with pytest.raises(NotImplementedError) as e: + embeddings_model.embed_query("test") + assert e.value == "Not supported for multimodal models" +