Commit e7b4572
fix embeddings with max_batch_size (langchain-ai#338)
lkuligin authored Jun 28, 2024
1 parent a22a608 commit e7b4572
Showing 2 changed files with 3 additions and 2 deletions.
libs/vertexai/langchain_google_vertexai/embeddings.py (2 additions, 1 deletion)
@@ -320,7 +320,6 @@ def _prepare_and_validate_batches(
                     first_result = self._get_embeddings_with_retry(
                         first_batch, embeddings_type
                     )
-                    batches = batches[1:]
                     break
                 except InvalidArgument:
                     had_failure = True
@@ -347,6 +346,8 @@ def _prepare_and_validate_batches(
                 batches = VertexAIEmbeddings._prepare_batches(
                     texts[first_batch_len:], self.instance["batch_size"]
                 )
+            else:
+                batches = batches[1:]
             else:
                 # Still figuring out max batch size.
                 batches = batches[1:]
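The two hunks above move the `batches = batches[1:]` slice out of the `try` block and into an explicit `else` branch, so the already-embedded first batch is dropped from the pending list exactly once, and only when the remaining batches are not rebuilt with a new batch size. Below is a minimal, self-contained sketch of the resulting control flow; it is not the library's implementation, and the names `probe_batches`, `prepare_batches`, and `embed_batch` are illustrative stand-ins.

```python
# Sketch only: mirrors the post-fix control flow, not langchain_google_vertexai itself.
from typing import Callable, List, Tuple


def prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
    """Split texts into consecutive batches of at most batch_size items."""
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]


def probe_batches(
    texts: List[str],
    batch_size: int,
    max_batch_size: int,
    embed_batch: Callable[[List[str]], List[List[float]]],
) -> Tuple[List[List[float]], List[List[str]]]:
    """Embed the first batch, shrinking it on failure, and return its
    embeddings plus the batches that still need to be processed."""
    batches = prepare_batches(texts, batch_size)
    first_batch = batches[0]
    had_failure = False
    while True:
        try:
            first_result = embed_batch(first_batch)
            # The fix: do NOT slice batches here; decide below instead.
            break
        except ValueError:  # stand-in for InvalidArgument
            had_failure = True
            if len(first_batch) == 1:
                raise  # cannot shrink further
            first_batch = first_batch[: max(1, len(first_batch) // 2)]
    first_batch_len = len(first_batch)
    if had_failure or first_batch_len == max_batch_size:
        if first_batch_len != max_batch_size:
            # Batch size changed: rebuild batches for the remaining texts.
            batches = prepare_batches(texts[first_batch_len:], first_batch_len)
        else:
            # Batches unchanged: just drop the one already embedded.
            batches = batches[1:]
    else:
        # Still figuring out the max batch size; drop the processed batch once.
        batches = batches[1:]
    return first_result, batches
```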
libs/vertexai/tests/unit_tests/test_embeddings.py (1 addition, 1 deletion)
@@ -19,7 +19,7 @@ def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
 def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None:
     mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001")
     default_batch_size = mock_embeddings.instance["batch_size"]
-    texts = ["text_{i}" for i in range(default_batch_size * 2)]
+    texts = ["text {i}" for i in range(default_batch_size * 2)]
     # It should only return one batch (out of two) still to process
     _, batches = mock_embeddings._prepare_and_validate_batches(texts=texts)
     assert len(batches) == 1
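As a quick usage check of the sketch above (reusing `probe_batches` with a dummy embedder; the sizes here are illustrative, not taken from the test fixtures), a scenario analogous to this test: two full batches where the first succeeds at the maximum size should leave exactly one batch still to process.

```python
def demo() -> None:
    texts = [f"text {i}" for i in range(10)]  # two batches of five
    embeddings, pending = probe_batches(
        texts,
        batch_size=5,
        max_batch_size=5,
        embed_batch=lambda batch: [[0.0] for _ in batch],  # dummy embedder
    )
    assert len(embeddings) == 5  # first batch embedded during probing
    assert len(pending) == 1     # one batch (out of two) still to process


if __name__ == "__main__":
    demo()
```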
