From 9441b52a4ef2e50607b795f9f81837f1217ba517 Mon Sep 17 00:00:00 2001
From: xieqihui
Date: Fri, 24 Nov 2023 14:38:08 +0800
Subject: [PATCH] add mongodb_store

---
Review notes (ignored by git am):
  * yield_keys() now escapes the prefix before building the $regex filter.
  * mset() returns early on an empty batch instead of calling bulk_write([]).
  * Hunk sizes and the diffstat were updated by hand; the blob ids on the
    "index" lines still refer to the originally submitted revision.

 libs/langchain/langchain/storage/mongodb.py   | 132 ++++++++++++++++++
 libs/langchain/poetry.lock                    |  29 +++-
 libs/langchain/pyproject.toml                 |   2 +
 .../integration_tests/storage/test_mongodb.py |  91 ++++++++++++
 .../tests/unit_tests/storage/test_mongodb.py  |  11 ++
 5 files changed, 263 insertions(+), 2 deletions(-)
 create mode 100644 libs/langchain/langchain/storage/mongodb.py
 create mode 100644 libs/langchain/tests/integration_tests/storage/test_mongodb.py
 create mode 100644 libs/langchain/tests/unit_tests/storage/test_mongodb.py

diff --git a/libs/langchain/langchain/storage/mongodb.py b/libs/langchain/langchain/storage/mongodb.py
new file mode 100644
index 0000000000000..9f48fcf6a6e0f
--- /dev/null
+++ b/libs/langchain/langchain/storage/mongodb.py
@@ -0,0 +1,132 @@
+import re
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from langchain_core.stores import BaseStore
+
+from langchain.schema import Document
+
+
+class MongodbStore(BaseStore[str, Document]):
+    """BaseStore implementation using MongoDB as the underlying store.
+
+    Examples:
+        Create a MongodbStore instance and perform operations on it:
+
+        .. code-block:: python
+
+            # Instantiate the MongodbStore with a MongoDB connection
+            from langchain.storage import MongodbStore
+
+            mongo_conn_str = "mongodb://localhost:27017/"
+            mongodb_store = MongodbStore(mongo_conn_str, db_name="test-db",
+                                         collection_name="test-collection")
+
+            # Set values for keys
+            doc1 = Document(...)
+            doc2 = Document(...)
+            mongodb_store.mset([("key1", doc1), ("key2", doc2)])
+
+            # Get values for keys
+            values = mongodb_store.mget(["key1", "key2"])
+            # [doc1, doc2]
+
+            # Iterate over keys
+            for key in mongodb_store.yield_keys():
+                print(key)
+
+            # Delete keys
+            mongodb_store.mdelete(["key1", "key2"])
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        db_name: str,
+        collection_name: str,
+        client_kwargs: Optional[dict] = None,
+    ) -> None:
+        """Initialize the MongodbStore with a MongoDB connection string.
+
+        Args:
+            connection_string (str): MongoDB connection string
+            db_name (str): database name to use
+            collection_name (str): collection name to use
+            client_kwargs (dict): Keyword arguments to pass to the Mongo client
+        """
+        try:
+            from pymongo import MongoClient
+        except ImportError as e:
+            raise ImportError(
+                "The MongodbStore requires the pymongo library to be "
+                "installed. "
+                "pip install pymongo"
+            ) from e
+
+        if not connection_string:
+            raise ValueError("connection_string must be provided.")
+        if not db_name:
+            raise ValueError("db_name must be provided.")
+        if not collection_name:
+            raise ValueError("collection_name must be provided.")
+
+        self.client = MongoClient(connection_string, **(client_kwargs or {}))
+        self.collection = self.client[db_name][collection_name]
+
+    def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
+        """Get the list of documents associated with the given keys.
+
+        Args:
+            keys (list[str]): A list of keys representing Document IDs.
+
+        Returns:
+            list[Document]: A list of Documents corresponding to the provided
+                keys, where each Document is either retrieved successfully or
+                represented as None if not found.
+        """
+        result = self.collection.find({"_id": {"$in": keys}})
+        result_dict = {doc["_id"]: Document(**doc["value"]) for doc in result}
+        return [result_dict.get(key) for key in keys]
+
+    def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
+        """Set the given key-value pairs.
+
+        Args:
+            key_value_pairs (list[tuple[str, Document]]): A list of id-document
+                pairs.
+        Returns:
+            None
+        """
+        from pymongo import UpdateOne
+
+        if not key_value_pairs:
+            # bulk_write raises InvalidOperation when given no operations
+            return
+
+        updates = [{"_id": k, "value": v.__dict__} for k, v in key_value_pairs]
+        self.collection.bulk_write(
+            [UpdateOne({"_id": u["_id"]}, {"$set": u}, upsert=True) for u in updates]
+        )
+
+    def mdelete(self, keys: Sequence[str]) -> None:
+        """Delete the given ids.
+
+        Args:
+            keys (list[str]): A list of keys representing Document IDs.
+        """
+        self.collection.delete_many({"_id": {"$in": keys}})
+
+    def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
+        """Yield keys in the store.
+
+        Args:
+            prefix (str): literal prefix of keys to retrieve.
+        """
+        if prefix is None:
+            for doc in self.collection.find(projection=["_id"]):
+                yield doc["_id"]
+        else:
+            # Escape the prefix so regex metacharacters in it match literally
+            for doc in self.collection.find(
+                {"_id": {"$regex": f"^{re.escape(prefix)}"}}, projection=["_id"]
+            ):
+                yield doc["_id"]
diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock
index 7f3e6c5c49383..c76dc64474fce 100644
--- a/libs/langchain/poetry.lock
+++ b/libs/langchain/poetry.lock
@@ -4794,6 +4794,21 @@ files = [
 grpcio = "*"
 protobuf = ">=3,<5"
 
+[[package]]
+name = "mongomock"
+version = "4.1.2"
+description = "Fake pymongo stub for testing simple MongoDB-dependent code"
+optional = true
+python-versions = "*"
+files = [
+    {file = "mongomock-4.1.2-py2.py3-none-any.whl", hash = "sha256:08a24938a05c80c69b6b8b19a09888d38d8c6e7328547f94d46cadb7f47209f2"},
+    {file = "mongomock-4.1.2.tar.gz", hash = "sha256:f06cd62afb8ae3ef63ba31349abd220a657ef0dd4f0243a29587c5213f931b7d"},
+]
+
+[package.dependencies]
+packaging = "*"
+sentinels = "*"
+
 [[package]]
 name = "more-itertools"
 version = "10.1.0"
@@ -8910,6 +8925,16 @@ files = [
     {file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"},
 ]
 
+[[package]]
+name = "sentinels"
+version = "1.0.0"
+description = "Various objects to denote special meanings in python"
+optional = true
+python-versions = "*"
+files = [
+    {file = "sentinels-1.0.0.tar.gz", hash = "sha256:7be0704d7fe1925e397e92d18669ace2f619c92b5d4eb21a89f31e026f9ff4b1"},
+]
+
 [[package]]
 name = "setuptools"
 version = "67.8.0"
@@ -11075,7 +11100,7 @@ cli = ["typer"]
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "mongomock", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
 javascript = ["esprima"]
 llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
@@ -11085,4 +11110,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "2e6a3d2370b12302f338bee317b467bff45e8a92954530cbaac164f01635076a"
+content-hash = "2a1c9f13a4baac4b233a8151660c3bfde06b69318f789de92b7eb203157d9109"
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index aed189232e78b..a11d7df99b855 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -143,6 +143,7 @@ azure-ai-textanalytics = {version = "^5.3.0", optional = true}
 google-cloud-documentai = {version = "^2.20.1", optional = true}
 fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
 javelin-sdk = {version = "^0.1.8", optional = true}
+mongomock = {version = "^4.1.2", optional = true}
 
 
 [tool.poetry.group.test.dependencies]
@@ -379,6 +380,7 @@ extended_testing = [
  "rspace_client",
  "fireworks-ai",
  "javelin-sdk",
+ "mongomock",
 ]
 
 [tool.ruff]
diff --git a/libs/langchain/tests/integration_tests/storage/test_mongodb.py b/libs/langchain/tests/integration_tests/storage/test_mongodb.py
new file mode 100644
index 0000000000000..87502c23fb91f
--- /dev/null
+++ b/libs/langchain/tests/integration_tests/storage/test_mongodb.py
@@ -0,0 +1,91 @@
+from typing import Generator
+
+import pytest
+from langchain_core.documents import Document
+
+from langchain.storage.mongodb import MongodbStore
+
+pytest.importorskip("pymongo")
+
+
+@pytest.fixture
+def mongo_store() -> Generator:
+    import mongomock
+
+    # mongomock creates a mock MongoDB instance for testing purposes
+    with mongomock.patch(servers=(("localhost", 27017),)):
+        yield MongodbStore("mongodb://localhost:27017/", "test_db", "test_collection")
+
+
+def test_mset_and_mget(mongo_store: MongodbStore) -> None:
+    doc1 = Document(page_content="doc1")
+    doc2 = Document(page_content="doc2")
+
+    # Set documents in the store
+    mongo_store.mset([("key1", doc1), ("key2", doc2)])
+
+    # Get documents from the store
+    retrieved_docs = mongo_store.mget(["key1", "key2"])
+
+    assert retrieved_docs[0] and retrieved_docs[0].page_content == "doc1"
+    assert retrieved_docs[1] and retrieved_docs[1].page_content == "doc2"
+
+
+def test_mset_empty(mongo_store: MongodbStore) -> None:
+    # An empty batch must be a no-op rather than an error
+    mongo_store.mset([])
+    assert list(mongo_store.yield_keys()) == []
+
+
+def test_yield_keys(mongo_store: MongodbStore) -> None:
+    mongo_store.mset(
+        [
+            ("key1", Document(page_content="doc1")),
+            ("key2", Document(page_content="doc2")),
+            ("another_key", Document(page_content="other")),
+        ]
+    )
+
+    # Test without prefix
+    keys = list(mongo_store.yield_keys())
+    assert set(keys) == {"key1", "key2", "another_key"}
+
+    # Test with prefix
+    keys_with_prefix = list(mongo_store.yield_keys(prefix="key"))
+    assert set(keys_with_prefix) == {"key1", "key2"}
+
+
+def test_yield_keys_prefix_is_literal(mongo_store: MongodbStore) -> None:
+    mongo_store.mset(
+        [
+            ("a.b", Document(page_content="doc1")),
+            ("aXb", Document(page_content="doc2")),
+        ]
+    )
+
+    # "." must match literally instead of acting as a regex wildcard
+    assert list(mongo_store.yield_keys(prefix="a.")) == ["a.b"]
+
+
+def test_mdelete(mongo_store: MongodbStore) -> None:
+    mongo_store.mset(
+        [
+            ("key1", Document(page_content="doc1")),
+            ("key2", Document(page_content="doc2")),
+        ]
+    )
+    # Delete single document
+    mongo_store.mdelete(["key1"])
+    remaining_docs = list(mongo_store.yield_keys())
+    assert "key1" not in remaining_docs
+    assert "key2" in remaining_docs
+
+    # Delete the remaining document
+    mongo_store.mdelete(["key2"])
+    remaining_docs = list(mongo_store.yield_keys())
+    assert len(remaining_docs) == 0
+
+
+def test_init_errors() -> None:
+    with pytest.raises(ValueError):
+        MongodbStore("", "", "")
diff --git a/libs/langchain/tests/unit_tests/storage/test_mongodb.py b/libs/langchain/tests/unit_tests/storage/test_mongodb.py
new file mode 100644
index 0000000000000..06ec3575a857d
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/storage/test_mongodb.py
@@ -0,0 +1,11 @@
+"""Light weight unit test that attempts to import MongodbStore.
+
+The actual code is tested in integration tests.
+
+This test is intended to catch errors in the import process.
+"""
+
+
+def test_import_storage() -> None:
+    """Attempt to import storage modules."""
+    from langchain.storage.mongodb import MongodbStore  # noqa