Skip to content

Commit

Permalink
feat: add support for multimodal text embedding
Browse files (browse the repository at this point in the history)
  • Loading branch information
svidiella committed Feb 29, 2024
1 parent b72e14c commit a41532b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 75 deletions.
72 changes: 41 additions & 31 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,12 @@ def __init__(
Aborted,
DeadlineExceeded,
]
self.instance["retry_decorator"] = create_base_retry_decorator(
retry_decorator = create_base_retry_decorator(
error_types=retry_errors, max_retries=self.max_retries
)
self.instance["get_embeddings_with_retry"] = retry_decorator(
self.client.get_embeddings
)

@property
def model_type(self) -> str:
Expand Down Expand Up @@ -188,30 +191,41 @@ def _get_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""

errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
DeadlineExceeded,
]
retry_decorator = create_base_retry_decorator(
error_types=errors, max_retries=self.max_retries
if self.model_type == GoogleEmbeddingModelType.MULTIMODAL:
return self._get_multimodal_embeddings_with_retry(texts)
return self._get_text_embeddings_with_retry(
texts, embeddings_type=embeddings_type
)

@retry_decorator
def _completion_with_retry(texts_to_process: List[str]) -> Any:
if embeddings_type and self.instance["embeddings_task_type_supported"]:
requests = [
TextEmbeddingInput(text=t, task_type=embeddings_type)
for t in texts_to_process
]
else:
requests = texts_to_process
embeddings = self.client.get_embeddings(requests)
return [embs.values for embs in embeddings]
def _get_multimodal_embeddings_with_retry(
self, texts: List[str]
) -> List[List[float]]:
tasks = []
for text in texts:
tasks.append(
self.instance["task_executor"].submit(
self.instance["get_embeddings_with_retry"],
contextual_text=text,
)
)
if len(tasks) > 0:
wait(tasks)
embeddings = [task.result().text_embedding for task in tasks]
return embeddings

return _completion_with_retry(texts)
def _get_text_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""

if embeddings_type and self.instance["embeddings_task_type_supported"]:
requests = [
TextEmbeddingInput(text=t, task_type=embeddings_type) for t in texts
]
else:
requests = texts
embeddings = self.instance["get_embeddings_with_retry"](requests)
return [embedding.values for embedding in embeddings]

def _prepare_and_validate_batches(
self, texts: List[str], embeddings_type: Optional[str] = None
Expand Down Expand Up @@ -240,7 +254,7 @@ def _prepare_and_validate_batches(
return [], VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
)
# Figure out largest possible batch size by trying to push
# Figure out the largest possible batch size by trying to push
# batches and lowering their size in half after every failure.
first_batch = batches[0]
first_result = []
Expand Down Expand Up @@ -362,8 +376,6 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

def embed_query(self, text: str) -> List[float]:
Expand All @@ -375,10 +387,7 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]
return self.embed([text], 1, "RETRIEVAL_QUERY")[0]

def embed_image(self, image_path: str) -> List[float]:
"""Embed an image.
Expand All @@ -393,7 +402,8 @@ def embed_image(self, image_path: str) -> List[float]:
if self.model_type != GoogleEmbeddingModelType.MULTIMODAL:
raise NotImplementedError("Only supported for multimodal models")

embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings)
image = Image.load_from_file(image_path)
result: MultiModalEmbeddingResponse = embed_with_retry(image=image)
result: MultiModalEmbeddingResponse = self.instance[
"get_embeddings_with_retry"
](image=image)
return result.image_embedding
52 changes: 24 additions & 28 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,37 @@ def test_initialization() -> None:


@pytest.mark.release
def test_langchain_google_vertexai_embedding_documents() -> None:
documents = ["foo bar"]
model = VertexAIEmbeddings()
@pytest.mark.parametrize(
"number_of_docs",
[1, 8],
)
@pytest.mark.parametrize(
"model_name, embeddings_dim",
[("textembedding-gecko@001", 768), ("multimodalembedding@001", 1408)],
)
def test_langchain_google_vertexai_embedding_documents(
number_of_docs: int, model_name: str, embeddings_dim: int
) -> None:
documents = ["foo bar"] * number_of_docs
model = VertexAIEmbeddings(model_name)
output = model.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
assert len(output) == number_of_docs
for embedding in output:
assert len(embedding) == embeddings_dim
assert model.model_name == model.client._model_id
assert model.model_name == "textembedding-gecko@001"
assert model.model_name == model_name


@pytest.mark.release
def test_langchain_google_vertexai_embedding_query() -> None:
@pytest.mark.parametrize(
"model_name, embeddings_dim",
[("textembedding-gecko@001", 768), ("multimodalembedding@001", 1408)],
)
def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) -> None:
document = "foo bar"
model = VertexAIEmbeddings()
model = VertexAIEmbeddings(model_name)
output = model.embed_query(document)
assert len(output) == 768
assert len(output) == embeddings_dim


@pytest.mark.release
Expand All @@ -49,25 +64,6 @@ def test_langchain_google_vertexai_large_batches() -> None:
assert model_asianortheast1.instance["batch_size"] < 50


@pytest.mark.release
def test_langchain_google_vertexai_paginated_texts() -> None:
documents = [
"foo bar",
"foo baz",
"bar foo",
"baz foo",
"bar bar",
"foo foo",
"baz baz",
"baz bar",
]
model = VertexAIEmbeddings()
output = model.embed_documents(documents)
assert len(output) == 8
assert len(output[0]) == 768
assert model.model_name == model.client._model_id


@pytest.mark.release
def test_warning(caplog: pytest.LogCaptureFixture) -> None:
_ = VertexAIEmbeddings()
Expand Down
16 changes: 0 additions & 16 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,6 @@ def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
assert e.value == "Only supported for multimodal models"


def test_langchain_google_vertexai_embed_documents_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_documents(["test"])
assert e.value == "Not supported for multimodal models"


def test_langchain_google_vertexai_embed_query_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_query("test")
assert e.value == "Not supported for multimodal models"


class MockVertexAIEmbeddings(VertexAIEmbeddings):
"""
A mock class for avoiding instantiating VertexAI and the EmbeddingModel client
Expand Down

0 comments on commit a41532b

Please sign in to comment.