From 7fc903464a753ac10e9b671906c2e9889e4d598e Mon Sep 17 00:00:00 2001
From: 2jimoo <107998986+2jimoo@users.noreply.github.com>
Date: Sat, 24 Feb 2024 11:32:52 +0900
Subject: [PATCH] community: Add document manager and mongo document manager
 (#17320)

- **Description:**
    - Add DocumentManager class, which is a nosql record manager.
    - In order to use index and aindex in
libs/langchain/langchain/indexes/_api.py, DocumentManager inherits
RecordManager.
    - Also I added the MongoDB implementation of Document Manager too.
- **Dependencies:** pymongo, motor

---------

Co-authored-by: Eugene Yurtsev
---
 docker/docker-compose.yml                     |  13 +
 .../indexes/_document_manager.py              | 283 +++++++++++++++++
 .../integration_tests/indexes/__init__.py     |   0
 .../indexes/test_document_manager.py          | 297 ++++++++++++++++++
 4 files changed, 593 insertions(+)
 create mode 100644 libs/community/langchain_community/indexes/_document_manager.py
 create mode 100644 libs/community/tests/integration_tests/indexes/__init__.py
 create mode 100644 libs/community/tests/integration_tests/indexes/test_document_manager.py

diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 968e32469a34a..33c873e60e0af 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -1,5 +1,10 @@
 # docker-compose to make it easier to spin up integration tests.
 # Services should use NON standard ports to avoid collision with
+# any existing services that might be used for development.
+# ATTENTION: When adding a service below use a non-standard port
+# increment by one from the preceding port.
+# For credentials always use `langchain` and `langchain` for the
+# username and password.
 version: "3"
 
 name: langchain-tests
@@ -19,3 +24,11 @@ services:
     image: graphdb
     ports:
       - "6021:7200"
+  mongo:
+    image: mongo:latest
+    container_name: mongo_container
+    ports:
+      - "6022:27017"
+    environment:
+      MONGO_INITDB_ROOT_USERNAME: langchain
+      MONGO_INITDB_ROOT_PASSWORD: langchain
diff --git a/libs/community/langchain_community/indexes/_document_manager.py b/libs/community/langchain_community/indexes/_document_manager.py
new file mode 100644
index 0000000000000..64fbe1b94616e
--- /dev/null
+++ b/libs/community/langchain_community/indexes/_document_manager.py
@@ -0,0 +1,283 @@
+from typing import Any, Dict, List, Optional, Sequence
+
+from langchain_community.indexes.base import RecordManager
+
+IMPORT_PYMONGO_ERROR = (
+    "Could not import MongoClient. Please install it with `pip install pymongo`."
+)
+IMPORT_MOTOR_ASYNCIO_ERROR = (
+    "Could not import AsyncIOMotorClient. Please install it with `pip install motor`."
+)
+
+
+def _import_pymongo() -> Any:
+    """Return the MongoClient class; raise ImportError if pymongo is missing."""
+    try:
+        from pymongo import MongoClient
+    except ImportError as e:
+        raise ImportError(IMPORT_PYMONGO_ERROR) from e
+    return MongoClient
+
+
+def _get_pymongo_client(mongodb_url: str, **kwargs: Any) -> Any:
+    """Get MongoClient for sync operations from the mongodb_url.
+
+    Raises:
+        ImportError: If pymongo is missing, or if constructing the client
+            raises ValueError (ImportError is kept for backward compatibility).
+    """
+    try:
+        pymongo = _import_pymongo()
+        client = pymongo(mongodb_url, **kwargs)
+    except ValueError as e:
+        raise ImportError(
+            f"MongoClient string provided is not in proper format. "
+            f"Got error: {e} "
+        ) from e
+    return client
+
+
+def _import_motor_asyncio() -> Any:
+    """Return the AsyncIOMotorClient class; raise ImportError if motor is missing."""
+    try:
+        from motor.motor_asyncio import AsyncIOMotorClient
+    except ImportError as e:
+        raise ImportError(IMPORT_MOTOR_ASYNCIO_ERROR) from e
+    return AsyncIOMotorClient
+
+
+def _get_motor_client(mongodb_url: str, **kwargs: Any) -> Any:
+    """Get AsyncIOMotorClient for async operations from the mongodb_url.
+
+    Raises:
+        ImportError: If motor is missing, or if constructing the client
+            raises ValueError (ImportError is kept for backward compatibility).
+    """
+    try:
+        motor = _import_motor_asyncio()
+        client = motor(mongodb_url, **kwargs)
+    except ValueError as e:
+        raise ImportError(
+            f"AsyncIOMotorClient string provided is not in proper format. "
+            f"Got error: {e} "
+        ) from e
+    return client
+
+
+class MongoDocumentManager(RecordManager):
+    """A MongoDB based implementation of the document manager."""
+
+    def __init__(
+        self,
+        namespace: str,
+        *,
+        mongodb_url: str,
+        db_name: str,
+        collection_name: str = "documentMetadata",
+    ) -> None:
+        """Initialize the MongoDocumentManager.
+
+        Args:
+            namespace: The namespace associated with this document manager.
+            mongodb_url: The connection string used to create both the sync
+                (pymongo) and async (motor) clients.
+            db_name: The name of the database to use.
+            collection_name: The name of the collection to use.
+                Default is 'documentMetadata'.
+        """
+        super().__init__(namespace=namespace)
+        self.sync_client = _get_pymongo_client(mongodb_url)
+        self.sync_db = self.sync_client[db_name]
+        self.sync_collection = self.sync_db[collection_name]
+        self.async_client = _get_motor_client(mongodb_url)
+        self.async_db = self.async_client[db_name]
+        self.async_collection = self.async_db[collection_name]
+
+    def create_schema(self) -> None:
+        """Create the database schema for the document manager."""
+        pass
+
+    async def acreate_schema(self) -> None:
+        """Create the database schema for the document manager."""
+        pass
+
+    def update(
+        self,
+        keys: Sequence[str],
+        *,
+        group_ids: Optional[Sequence[Optional[str]]] = None,
+        time_at_least: Optional[float] = None,
+    ) -> None:
+        """Upsert documents into the MongoDB collection.
+
+        Args:
+            keys: Keys of the records to upsert.
+            group_ids: Optional group id per key; must match keys in length.
+            time_at_least: If given, the server time must not be behind it.
+
+        Raises:
+            ValueError: If keys and group_ids differ in length, or if the
+                server time is behind time_at_least.
+        """
+        if group_ids is None:
+            group_ids = [None] * len(keys)
+
+        if len(keys) != len(group_ids):
+            raise ValueError("Number of keys does not match number of group_ids")
+
+        # Read the server time once so the whole batch shares one timestamp and
+        # time_at_least is enforced the same way as in aupdate().
+        update_time = self.get_time()
+        if time_at_least and update_time < time_at_least:
+            raise ValueError("Server time is behind the expected time_at_least")
+
+        for key, group_id in zip(keys, group_ids):
+            self.sync_collection.find_one_and_update(
+                {"namespace": self.namespace, "key": key},
+                {"$set": {"group_id": group_id, "updated_at": update_time}},
+                upsert=True,
+            )
+
+    async def aupdate(
+        self,
+        keys: Sequence[str],
+        *,
+        group_ids: Optional[Sequence[Optional[str]]] = None,
+        time_at_least: Optional[float] = None,
+    ) -> None:
+        """Asynchronously upsert documents into the MongoDB collection.
+
+        Args:
+            keys: Keys of the records to upsert.
+            group_ids: Optional group id per key; must match keys in length.
+            time_at_least: If given, the server time must not be behind it.
+
+        Raises:
+            ValueError: If keys and group_ids differ in length, or if the
+                server time is behind time_at_least.
+        """
+        if group_ids is None:
+            group_ids = [None] * len(keys)
+
+        if len(keys) != len(group_ids):
+            raise ValueError("Number of keys does not match number of group_ids")
+
+        update_time = await self.aget_time()
+        if time_at_least and update_time < time_at_least:
+            raise ValueError("Server time is behind the expected time_at_least")
+
+        for key, group_id in zip(keys, group_ids):
+            await self.async_collection.find_one_and_update(
+                {"namespace": self.namespace, "key": key},
+                {"$set": {"group_id": group_id, "updated_at": update_time}},
+                upsert=True,
+            )
+
+    def get_time(self) -> float:
+        """Get the current server time as a timestamp."""
+        server_info = self.sync_db.command("hostInfo")
+        local_time = server_info["system"]["currentTime"]
+        return local_time.timestamp()
+
+    async def aget_time(self) -> float:
+        """Asynchronously get the current server time as a timestamp."""
+        host_info = await self.async_collection.database.command("hostInfo")
+        local_time = host_info["system"]["currentTime"]
+        return local_time.timestamp()
+
+    def exists(self, keys: Sequence[str]) -> List[bool]:
+        """Check if the given keys exist in the MongoDB collection."""
+        existing_keys = {
+            doc["key"]
+            for doc in self.sync_collection.find(
+                {"namespace": self.namespace, "key": {"$in": list(keys)}}, {"key": 1}
+            )
+        }
+        return [key in existing_keys for key in keys]
+
+    async def aexists(self, keys: Sequence[str]) -> List[bool]:
+        """Asynchronously check if the given keys exist in the MongoDB collection."""
+        cursor = self.async_collection.find(
+            {"namespace": self.namespace, "key": {"$in": list(keys)}}, {"key": 1}
+        )
+        existing_keys = {doc["key"] async for doc in cursor}
+        return [key in existing_keys for key in keys]
+
+    def _build_list_query(
+        self,
+        before: Optional[float],
+        after: Optional[float],
+        group_ids: Optional[Sequence[str]],
+    ) -> Dict[str, Any]:
+        """Build the filter shared by list_keys and alist_keys.
+
+        before and after are merged into one updated_at condition so that
+        passing both keeps both bounds instead of one overwriting the other.
+        """
+        query: Dict[str, Any] = {"namespace": self.namespace}
+        time_range: Dict[str, float] = {}
+        if before is not None:
+            time_range["$lt"] = before
+        if after is not None:
+            time_range["$gt"] = after
+        if time_range:
+            query["updated_at"] = time_range
+        if group_ids:
+            query["group_id"] = {"$in": list(group_ids)}
+        return query
+
+    def list_keys(
+        self,
+        *,
+        before: Optional[float] = None,
+        after: Optional[float] = None,
+        group_ids: Optional[Sequence[str]] = None,
+        limit: Optional[int] = None,
+    ) -> List[str]:
+        """List keys in the MongoDB collection matching the given filters.
+
+        Args:
+            before: Only return keys updated strictly before this timestamp.
+            after: Only return keys updated strictly after this timestamp.
+            group_ids: Only return keys whose group_id is in this sequence.
+            limit: Maximum number of keys to return.
+        """
+        query = self._build_list_query(before, after, group_ids)
+        cursor = self.sync_collection.find(query, {"key": 1})
+        if limit:
+            cursor = cursor.limit(limit)
+        return [doc["key"] for doc in cursor]
+
+    async def alist_keys(
+        self,
+        *,
+        before: Optional[float] = None,
+        after: Optional[float] = None,
+        group_ids: Optional[Sequence[str]] = None,
+        limit: Optional[int] = None,
+    ) -> List[str]:
+        """Asynchronously list keys in the MongoDB collection.
+
+        Args:
+            before: Only return keys updated strictly before this timestamp.
+            after: Only return keys updated strictly after this timestamp.
+            group_ids: Only return keys whose group_id is in this sequence.
+            limit: Maximum number of keys to return.
+        """
+        query = self._build_list_query(before, after, group_ids)
+        cursor = self.async_collection.find(query, {"key": 1})
+        if limit:
+            cursor = cursor.limit(limit)
+        return [doc["key"] async for doc in cursor]
+
+    def delete_keys(self, keys: Sequence[str]) -> None:
+        """Delete documents from the MongoDB collection."""
+        self.sync_collection.delete_many(
+            {"namespace": self.namespace, "key": {"$in": list(keys)}}
+        )
+
+    async def adelete_keys(self, keys: Sequence[str]) -> None:
+        """Asynchronously delete documents from the MongoDB collection."""
+        await self.async_collection.delete_many(
+            {"namespace": self.namespace, "key": {"$in": list(keys)}}
+        )
diff --git a/libs/community/tests/integration_tests/indexes/__init__.py b/libs/community/tests/integration_tests/indexes/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/libs/community/tests/integration_tests/indexes/test_document_manager.py b/libs/community/tests/integration_tests/indexes/test_document_manager.py
new file mode 100644
index 0000000000000..f9d88d8398ae7
--- /dev/null
+++ b/libs/community/tests/integration_tests/indexes/test_document_manager.py
@@ -0,0 +1,297 @@
+from datetime import datetime
+from unittest.mock import patch
+
+import pytest
+import pytest_asyncio
+
+from langchain_community.indexes._document_manager import MongoDocumentManager
+
+
+@pytest.fixture
+def manager() -> MongoDocumentManager:
+    """Return a MongoDocumentManager connected to the test MongoDB."""
+    document_manager = MongoDocumentManager(
+        namespace="kittens",
+        mongodb_url="mongodb://langchain:langchain@localhost:6022/",
+        db_name="test_db",
+        collection_name="test_collection",
+    )
+    return document_manager
+
+
+@pytest_asyncio.fixture
+async def amanager() -> MongoDocumentManager:
+    """Return a MongoDocumentManager connected to the test MongoDB."""
+    document_manager = MongoDocumentManager(
+        namespace="kittens",
+        mongodb_url="mongodb://langchain:langchain@localhost:6022/",
+        db_name="test_db",
+        collection_name="test_collection",
+    )
+    return document_manager
+
+
+@pytest.mark.requires("pymongo")
+def test_update(manager: MongoDocumentManager) -> None:
+    """Test updating records in the MongoDB."""
+    read_keys = manager.list_keys()
+    updated_keys = ["update_key1", "update_key2", "update_key3"]
+    manager.update(updated_keys)
+    all_keys = manager.list_keys()
+    assert sorted(all_keys) == sorted(set(read_keys) | set(updated_keys))
+
+
+@pytest.mark.asyncio
+@pytest.mark.requires("motor")
+async def test_aupdate(amanager: MongoDocumentManager) -> None:
+    """Test updating records in the MongoDB."""
+    read_keys = await amanager.alist_keys()
+    aupdated_keys = ["aupdate_key1", "aupdate_key2", "aupdate_key3"]
+    await amanager.aupdate(aupdated_keys)
+    all_keys = await amanager.alist_keys()
+    assert sorted(all_keys) == sorted(set(read_keys) | set(aupdated_keys))
+
+
+@pytest.mark.requires("pymongo")
+def test_update_timestamp(manager: MongoDocumentManager) -> None:
+    """Test updating records with timestamps in MongoDB."""
+    with patch.object(
+        manager, "get_time", return_value=datetime(2024, 2, 23).timestamp()
+    ):
+        manager.update(["key1"])
+    records = list(
+        manager.sync_collection.find({"namespace": manager.namespace, "key": "key1"})
+    )
+
+    assert [
+        {
+            "key": record["key"],
+            "namespace": record["namespace"],
+            "updated_at": record["updated_at"],
+            "group_id": record.get("group_id"),
+        }
+        for record in records
+    ] == [
+        {
+            "group_id": None,
+            "key": "key1",
+            "namespace": "kittens",
+            "updated_at": datetime(2024, 2, 23).timestamp(),
+        }
+    ]
+
+
+@pytest.mark.asyncio
+@pytest.mark.requires("motor")
+async def test_aupdate_timestamp(amanager: MongoDocumentManager) -> None:
+    """Test asynchronously updating records with timestamps in MongoDB."""
+    with patch.object(
+        amanager, "aget_time", return_value=datetime(2024, 2, 23).timestamp()
+    ):
+        await amanager.aupdate(["key1"])
+
+    records = [
+        doc
+        async for doc in amanager.async_collection.find(
+            {"namespace": amanager.namespace, "key": "key1"}
+        )
+    ]
+
+    assert [
+        {
+            "key": record["key"],
+            "namespace": record["namespace"],
+            "updated_at": record["updated_at"],
+            "group_id": record.get("group_id"),
+        }
+        for record in records
+    ] == [
+        {
+            "group_id": None,
+            "key": "key1",
+            "namespace": "kittens",
+            "updated_at": datetime(2024, 2, 23).timestamp(),
+        }
+    ]
+
+
+@pytest.mark.requires("pymongo")
+def test_update_time_at_least(manager: MongoDocumentManager) -> None:
+    """Test that update rejects a server time behind time_at_least."""
+    with patch.object(
+        manager, "get_time", return_value=datetime(2021, 1, 1).timestamp()
+    ):
+        with pytest.raises(ValueError):
+            manager.update(["key1"], time_at_least=datetime(2022, 1, 1).timestamp())
+
+
+@pytest.mark.requires("pymongo")
+def test_exists(manager: MongoDocumentManager) -> None:
+    """Test checking if keys exist in MongoDB."""
+    keys = ["key1", "key2", "key3"]
+    manager.update(keys)
+    exists = manager.exists(keys)
+    assert len(exists) == len(keys)
+    assert all(exists)
+
+    exists = manager.exists(["key1", "key4"])
+    assert len(exists) == 2
+    assert exists == [True, False]
+
+
+@pytest.mark.asyncio
+@pytest.mark.requires("motor")
+async def test_aexists(amanager: MongoDocumentManager) -> None:
+    """Test asynchronously checking if keys exist in MongoDB."""
+    keys = ["key1", "key2", "key3"]
+    await amanager.aupdate(keys)
+    exists = await amanager.aexists(keys)
+    assert len(exists) == len(keys)
+    assert all(exists)
+
+    exists = await amanager.aexists(["key1", "key4"])
+    assert len(exists) == 2
+    assert exists == [True, False]
+
+
+@pytest.mark.requires("pymongo")
+def test_list_keys(manager: MongoDocumentManager) -> None:
+    """Test listing keys in MongoDB."""
+    manager.delete_keys(manager.list_keys())
+    with patch.object(
+        manager, "get_time", return_value=datetime(2021, 1, 1).timestamp()
+    ):
+        manager.update(["key1"])
+    with patch.object(
+        manager, "get_time", return_value=datetime(2022, 1, 1).timestamp()
+    ):
+        manager.update(["key2"])
+    with patch.object(
+        manager, "get_time", return_value=datetime(2023, 1, 1).timestamp()
+    ):
+        manager.update(["key3"])
+    with patch.object(
+        manager, "get_time", return_value=datetime(2024, 1, 1).timestamp()
+    ):
+        manager.update(["key4"], group_ids=["group1"])
+    assert sorted(manager.list_keys()) == sorted(["key1", "key2", "key3", "key4"])
+    assert sorted(manager.list_keys(after=datetime(2022, 2, 1).timestamp())) == sorted(
+        ["key3", "key4"]
+    )
+    # Both bounds must apply together, not just the last one set.
+    assert sorted(
+        manager.list_keys(
+            after=datetime(2021, 2, 1).timestamp(),
+            before=datetime(2023, 2, 1).timestamp(),
+        )
+    ) == sorted(["key2", "key3"])
+    assert sorted(manager.list_keys(group_ids=["group1", "group2"])) == sorted(["key4"])
+
+
+@pytest.mark.asyncio
+@pytest.mark.requires("motor")
+async def test_alist_keys(amanager: MongoDocumentManager) -> None:
+    """Test asynchronously listing keys in MongoDB."""
+    await amanager.adelete_keys(await amanager.alist_keys())
+    with patch.object(
+        amanager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
+    ):
+        await amanager.aupdate(["key1"])
+    with patch.object(
+        amanager, "aget_time", return_value=datetime(2022, 1, 1).timestamp()
+    ):
+        await amanager.aupdate(["key2"])
+    with patch.object(
+        amanager, "aget_time", return_value=datetime(2023, 1, 1).timestamp()
+    ):
+        await amanager.aupdate(["key3"])
+    with patch.object(
+        amanager, "aget_time", return_value=datetime(2024, 1, 1).timestamp()
+    ):
+        await amanager.aupdate(["key4"], group_ids=["group1"])
+    assert sorted(await amanager.alist_keys()) == sorted(
+        ["key1", "key2", "key3", "key4"]
+    )
+    assert sorted(
+        await amanager.alist_keys(after=datetime(2022, 2, 1).timestamp())
+    ) == sorted(["key3", "key4"])
+    # Both bounds must apply together, not just the last one set.
+    assert sorted(
+        await amanager.alist_keys(
+            after=datetime(2021, 2, 1).timestamp(),
+            before=datetime(2023, 2, 1).timestamp(),
+        )
+    ) == sorted(["key2", "key3"])
+    assert sorted(await amanager.alist_keys(group_ids=["group1", "group2"])) == sorted(
+        ["key4"]
+    )
+
+
+@pytest.mark.requires("pymongo")
+def test_namespace_is_used(manager: MongoDocumentManager) -> None:
+    """Verify that namespace is taken into account for all operations in MongoDB."""
+    manager.delete_keys(manager.list_keys())
+    manager.update(["key1", "key2"], group_ids=["group1", "group2"])
+    manager.sync_collection.insert_many(
+        [
+            {"key": "key1", "namespace": "puppies", "group_id": None},
+            {"key": "key3", "namespace": "puppies", "group_id": None},
+        ]
+    )
+    assert sorted(manager.list_keys()) == sorted(["key1", "key2"])
+    manager.delete_keys(["key1"])
+    assert sorted(manager.list_keys()) == sorted(["key2"])
+    manager.update(["key3"], group_ids=["group3"])
+    assert (
+        manager.sync_collection.find_one({"key": "key3", "namespace": "kittens"})[
+            "group_id"
+        ]
+        == "group3"
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.requires("motor")
+async def test_anamespace_is_used(amanager: MongoDocumentManager) -> None:
+    """
+    Verify that namespace is taken into account for all operations
+    in MongoDB asynchronously.
+    """
+    await amanager.adelete_keys(await amanager.alist_keys())
+    await amanager.aupdate(["key1", "key2"], group_ids=["group1", "group2"])
+    await amanager.async_collection.insert_many(
+        [
+            {"key": "key1", "namespace": "puppies", "group_id": None},
+            {"key": "key3", "namespace": "puppies", "group_id": None},
+        ]
+    )
+    assert sorted(await amanager.alist_keys()) == sorted(["key1", "key2"])
+    await amanager.adelete_keys(["key1"])
+    assert sorted(await amanager.alist_keys()) == sorted(["key2"])
+    await amanager.aupdate(["key3"], group_ids=["group3"])
+    assert (
+        await amanager.async_collection.find_one(
+            {"key": "key3", "namespace": "kittens"}
+        )
+    )["group_id"] == "group3"
+
+
+@pytest.mark.requires("pymongo")
+def test_delete_keys(manager: MongoDocumentManager) -> None:
+    """Test deleting keys from MongoDB."""
+    manager.delete_keys(manager.list_keys())
+    manager.update(["key1", "key2", "key3"])
+    manager.delete_keys(["key1", "key2"])
+    remaining_keys = manager.list_keys()
+    assert sorted(remaining_keys) == sorted(["key3"])
+
+
+@pytest.mark.asyncio
+@pytest.mark.requires("motor")
+async def test_adelete_keys(amanager: MongoDocumentManager) -> None:
+    """Test asynchronously deleting keys from MongoDB."""
+    await amanager.adelete_keys(await amanager.alist_keys())
+    await amanager.aupdate(["key1", "key2", "key3"])
+    await amanager.adelete_keys(["key1", "key2"])
+    remaining_keys = await amanager.alist_keys()
+    assert sorted(remaining_keys) == sorted(["key3"])