def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
    """Fetch stored documents by ID.

    Each returned ``Document`` carries the ID reported by the store in its
    ``id`` field, so callers should match results to requests through that
    field rather than by position: the output order is not guaranteed to
    follow the order of ``ids``.

    The result may be shorter than ``ids`` when some IDs are unknown or when
    ``ids`` contains duplicates. Unknown IDs are not meant to raise an
    exception.

    Args:
        ids: IDs of the documents to retrieve (positional-only).

    Returns:
        One ``Document`` per record returned by the underlying ``get`` call.

    .. versionadded:: 0.2.1
    """
    # ``get`` is handed a concrete list, so any Sequence is accepted here.
    fetched = self.get(ids=list(ids))

    documents: list[Document] = []
    # The three result columns are parallel: entry i of each describes
    # the same stored record.
    for text, metadata, doc_id in zip(
        fetched["documents"], fetched["metadatas"], fetched["ids"]
    ):
        documents.append(Document(page_content=text, metadata=metadata, id=doc_id))
    return documents
"""Run the LangChain standard vector-store test suites against Chroma."""

from typing import AsyncGenerator, Generator

import pytest
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.vectorstores import VectorStore
from langchain_tests.integration_tests.vectorstores import (
    AsyncReadWriteTestSuite,
    ReadWriteTestSuite,
)

from langchain_chroma import Chroma


class TestSync(ReadWriteTestSuite):
    """Synchronous read/write standard suite bound to a Chroma store."""

    @pytest.fixture()
    def vectorstore(self) -> Generator[VectorStore, None, None]:  # type: ignore
        """Yield a Chroma store for one integration test.

        The store uses 10-dimensional deterministic fake embeddings. Its
        collection is deleted on teardown, whether or not the test passed.
        """
        embeddings = DeterministicFakeEmbedding(size=10)
        store = Chroma(embedding_function=embeddings)
        try:
            yield store
        finally:
            # Cleanup runs even when the test body raised.
            store.delete_collection()


class TestAsync(AsyncReadWriteTestSuite):
    """Asynchronous read/write standard suite bound to a Chroma store."""

    @pytest.fixture()
    async def vectorstore(self) -> AsyncGenerator[VectorStore, None]:  # type: ignore
        """Yield a Chroma store for one async integration test.

        The store uses 10-dimensional deterministic fake embeddings. Its
        collection is deleted on teardown, whether or not the test passed.
        """
        embeddings = DeterministicFakeEmbedding(size=10)
        store = Chroma(embedding_function=embeddings)
        try:
            yield store
        finally:
            # Cleanup runs even when the test body raised.
            store.delete_collection()