From 3e2cb4e8a4e5818df220daad7a299734c6e21dcf Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Sat, 21 Sep 2024 01:58:45 +0200 Subject: [PATCH] openai: embeddings: supported chunk_size when check_embedding_ctx_length is disabled (#23767) Chunking of the input array controlled by `self.chunk_size` is being ignored when `self.check_embedding_ctx_length` is disabled. Effectively, the chunk size is assumed to be equal to 1 in such a case. This is surprising. The PR takes into account `self.chunk_size` passed by the user. --------- Co-authored-by: Erick Friis --- .../langchain_openai/embeddings/base.py | 22 +++++++++++-------- .../integration_tests/embeddings/test_base.py | 9 ++++++++ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 3e14da153cd6c..252996dc8b32f 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -254,14 +254,14 @@ class OpenAIEmbeddings(BaseModel, Embeddings): retry_max_seconds: int = 20 """Max number of seconds to wait between retries""" http_client: Union[Any, None] = None - """Optional httpx.Client. Only used for sync invocations. Must specify + """Optional httpx.Client. Only used for sync invocations. Must specify http_async_client as well if you'd like a custom client for async invocations. """ http_async_client: Union[Any, None] = None - """Optional httpx.AsyncClient. Only used for async invocations. Must specify + """Optional httpx.AsyncClient. Only used for async invocations. 
Must specify http_client as well if you'd like a custom client for sync invocations.""" check_embedding_ctx_length: bool = True - """Whether to check the token length of inputs and automatically split inputs + """Whether to check the token length of inputs and automatically split inputs longer than embedding_ctx_length.""" model_config = ConfigDict( @@ -558,7 +558,7 @@ async def empty_embedding() -> List[float]: return [e if e is not None else await empty_embedding() for e in embeddings] def embed_documents( - self, texts: List[str], chunk_size: Optional[int] = 0 + self, texts: List[str], chunk_size: int | None = None ) -> List[List[float]]: """Call out to OpenAI's embedding endpoint for embedding search docs. @@ -570,10 +570,13 @@ def embed_documents( Returns: List of embeddings, one for each text. """ + chunk_size_ = chunk_size or self.chunk_size if not self.check_embedding_ctx_length: embeddings: List[List[float]] = [] - for text in texts: - response = self.client.create(input=text, **self._invocation_params) + for i in range(0, len(texts), chunk_size_): + response = self.client.create( + input=texts[i : i + chunk_size_], **self._invocation_params + ) if not isinstance(response, dict): response = response.dict() embeddings.extend(r["embedding"] for r in response["data"]) @@ -585,7 +588,7 @@ return self._get_len_safe_embeddings(texts, engine=engine) async def aembed_documents( - self, texts: List[str], chunk_size: Optional[int] = 0 + self, texts: List[str], chunk_size: int | None = None ) -> List[List[float]]: """Call out to OpenAI's embedding endpoint async for embedding search docs. @@ -597,11 +600,12 @@ async def aembed_documents( Returns: List of embeddings, one for each text. 
""" + chunk_size_ = chunk_size or self.chunk_size if not self.check_embedding_ctx_length: embeddings: List[List[float]] = [] - for text in texts: + for i in range(0, len(texts), chunk_size_): response = await self.async_client.create( - input=text, **self._invocation_params + input=texts[i : i + chunk_size_], **self._invocation_params ) if not isinstance(response, dict): response = response.dict() diff --git a/libs/partners/openai/tests/integration_tests/embeddings/test_base.py b/libs/partners/openai/tests/integration_tests/embeddings/test_base.py index b7a4aa7c8c70c..ef16dd2f48a67 100644 --- a/libs/partners/openai/tests/integration_tests/embeddings/test_base.py +++ b/libs/partners/openai/tests/integration_tests/embeddings/test_base.py @@ -61,3 +61,12 @@ async def test_langchain_openai_embeddings_equivalent_to_raw_async() -> None: .embedding ) assert np.isclose(lc_output, direct_output).all() + + +def test_langchain_openai_embeddings_dimensions_large_num() -> None: + """Test openai embeddings.""" + documents = [f"foo bar {i}" for i in range(2000)] + embedding = OpenAIEmbeddings(model="text-embedding-3-small", dimensions=128) + output = embedding.embed_documents(documents) + assert len(output) == 2000 + assert len(output[0]) == 128