diff --git a/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py b/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py index 8082578..0022fae 100644 --- a/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py +++ b/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Optional from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import ( @@ -17,23 +17,6 @@ DEFAULT_SESSION_ID_KEY = "SessionId" DEFAULT_HISTORY_KEY = "History" -try: - from motor.motor_asyncio import ( - AsyncIOMotorClient, - AsyncIOMotorCollection, - AsyncIOMotorDatabase, - ) - - _motor_available = True -except ImportError: - AsyncIOMotorClient = None # type: ignore - _motor_available = False - logger.warning( - "Motor library is not installed. Asynchronous methods will fall back to using " - "`run_in_executor`, which is less efficient. " - "Install motor with `pip install motor` for improved performance." - ) - class MongoDBChatMessageHistory(BaseChatMessageHistory): """Chat message history that stores history in MongoDB. @@ -138,15 +121,6 @@ def __init__( self.db = self.client[database_name] self.collection = self.db[collection_name] - if _motor_available: - self.async_client: AsyncIOMotorClient = AsyncIOMotorClient( - connection_string - ) - self.async_db: AsyncIOMotorDatabase = self.async_client[database_name] - self.async_collection: AsyncIOMotorCollection = self.async_db[ - collection_name - ] - if create_index: index_kwargs = index_kwargs or {} self.collection.create_index(self.session_id_key, **index_kwargs) @@ -179,37 +153,6 @@ def messages(self) -> List[BaseMessage]: # type: ignore messages = messages_from_dict(items) return messages - async def aget_messages(self) -> List[BaseMessage]: - """Async version of getting messages from MongoDB""" - if not _motor_available: - logger.warning( - "Motor library is not installed. " - "Using `run_in_executor` for aget_messages, " - "which may be less efficient." - ) - return await super().aget_messages() - - if self.history_size is None: - cursor = self.async_collection.find({self.session_id_key: self.session_id}) - else: - total_count = await self.async_collection.count_documents( - {self.session_id_key: self.session_id} - ) - skip_count = max( - 0, - total_count - self.history_size, - ) - cursor = self.async_collection.find( - {self.session_id_key: self.session_id}, skip=skip_count - ) - - items = [] - async for document in cursor: - items.append(json.loads(document[self.history_key])) - - messages = messages_from_dict(items) - return messages - def add_message(self, message: BaseMessage) -> None: """Append the message to the record in MongoDB""" try: @@ -222,39 +165,9 @@ def add_message(self, message: BaseMessage) -> None: except errors.WriteError as err: logger.error(err) - async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: - """Async add a list of messages to MongoDB""" - if not _motor_available: - logger.warning( - "Motor library is not installed. " - "Using `run_in_executor` for aadd_messages, " - "which may be less efficient." - ) - return await super().aadd_messages(messages) - - documents = [ - { - self.session_id_key: self.session_id, - self.history_key: json.dumps(message_to_dict(message)), - } - for message in messages - ] - await self.async_collection.insert_many(documents) - def clear(self) -> None: """Clear session memory from MongoDB""" try: self.collection.delete_many({self.session_id_key: self.session_id}) except errors.WriteError as err: logger.error(err) - - async def aclear(self) -> None: - """Async clear session memory from MongoDB""" - if not _motor_available: - logger.warning( - "Motor library is not installed. " - "Using `run_in_executor` for aclear, which may be less efficient." - ) - return await super().aclear() - - await self.async_collection.delete_many({self.session_id_key: self.session_id}) diff --git a/libs/langchain-mongodb/langchain_mongodb/indexes.py b/libs/langchain-mongodb/langchain_mongodb/indexes.py index f8bf21b..c97aa0b 100644 --- a/libs/langchain-mongodb/langchain_mongodb/indexes.py +++ b/libs/langchain-mongodb/langchain_mongodb/indexes.py @@ -1,47 +1,19 @@ +# Based on https://github.com/langchain-ai/langchain/blob/8f5e72de057bc07df19f7d7aefb7673b64fbb1b4/libs/community/langchain_community/indexes/_document_manager.py#L58 +from __future__ import annotations + +import functools from typing import Any, Dict, List, Optional, Sequence from langchain_core.indexing.base import RecordManager - - -def _get_pymongo_client(mongodb_url: str, **kwargs: Any) -> Any: - """Get MongoClient for sync operations from the mongodb_url, - otherwise raise error.""" - from pymongo import MongoClient - - try: - client: MongoClient = MongoClient(mongodb_url, **kwargs) - except ValueError as e: - raise ImportError( - f"MongoClient string provided is not in proper format. " f"Got error: {e} " - ) from None - return client - - -def _get_motor_client(mongodb_url: str, **kwargs: Any) -> Any: - """Get AsyncIOMotorClient for async operations from the mongodb_url, - otherwise raise error.""" - from motor.motor_asyncio import AsyncIOMotorClient - - try: - client: AsyncIOMotorClient = AsyncIOMotorClient(mongodb_url, **kwargs) - except ValueError as e: - raise ImportError( - f"AsyncIOMotorClient string provided is not in proper format. " - f"Got error: {e} " - ) from None - return client +from langchain_core.runnables.config import run_in_executor +from pymongo import MongoClient +from pymongo.collection import Collection class MongoDBRecordManager(RecordManager): """A MongoDB-based implementation of the record manager.""" - def __init__( - self, - *, - connection_string: str, - db_name: str, - collection_name: str, - ) -> None: + def __init__(self, collection: Collection) -> None: """Initialize the MongoDBRecordManager. Args: @@ -49,13 +21,27 @@ def __init__( db_name: The name of the database to use. collection_name: The name of the collection to use. """ - super().__init__(namespace=".".join([db_name, collection_name])) - self.sync_client = _get_pymongo_client(connection_string) - self.sync_db = self.sync_client[db_name] - self.sync_collection = self.sync_db[collection_name] - self.async_client = _get_motor_client(connection_string) - self.async_db = self.async_client[db_name] - self.async_collection = self.async_db[collection_name] + namespace = f"{collection.database.name}.{collection.name}" + super().__init__(namespace=namespace) + self._collection = collection + + @classmethod + def from_connection_string( + cls, connection_string: str, namespace: str + ) -> MongoDBRecordManager: + """Construct a RecordManager from a MongoDB connection URI. + + Args: + connection_string: A valid MongoDB connection URI. + namespace: A valid MongoDB namespace (in form f"{database}.{collection}") + + Returns: + A new MongoDBRecordManager instance. + """ + client: MongoClient = MongoClient(connection_string) + db_name, collection_name = namespace.split(".") + collection = client[db_name][collection_name] + return cls(collection=collection) def create_schema(self) -> None: """Create the database schema for the document manager.""" @@ -80,7 +66,7 @@ def update( raise ValueError("Number of keys does not match number of group_ids") for key, group_id in zip(keys, group_ids): - self.sync_collection.find_one_and_update( + self.collection.find_one_and_update( {"namespace": self.namespace, "key": key}, {"$set": {"group_id": group_id, "updated_at": self.get_time()}}, upsert=True, @@ -94,22 +80,10 @@ async def aupdate( time_at_least: Optional[float] = None, ) -> None: """Asynchronously upsert documents into the MongoDB collection.""" - if group_ids is None: - group_ids = [None] * len(keys) - - if len(keys) != len(group_ids): - raise ValueError("Number of keys does not match number of group_ids") - - update_time = await self.aget_time() - if time_at_least and update_time < time_at_least: - raise ValueError("Server time is behind the expected time_at_least") - - for key, group_id in zip(keys, group_ids): - await self.async_collection.find_one_and_update( - {"namespace": self.namespace, "key": key}, - {"$set": {"group_id": group_id, "updated_at": update_time}}, - upsert=True, - ) + func = functools.partial( + self.update, keys, group_ids=group_ids, time_at_least=time_at_least + ) + return await run_in_executor(None, func) def get_time(self) -> float: """Get the current server time as a timestamp.""" @@ -120,9 +94,8 @@ def get_time(self) -> float: async def aget_time(self) -> float: """Asynchronously get the current server time as a timestamp.""" - host_info = await self.async_collection.database.command("hostInfo") - local_time = host_info["system"]["currentTime"] - return local_time.timestamp() + func = functools.partial(self.get_time) + return await run_in_executor(None, func) def exists(self, keys: Sequence[str]) -> List[bool]: """Check if the given keys exist in the MongoDB collection.""" @@ -136,11 +109,8 @@ def exists(self, keys: Sequence[str]) -> List[bool]: async def aexists(self, keys: Sequence[str]) -> List[bool]: """Asynchronously check if the given keys exist in the MongoDB collection.""" - cursor = self.async_collection.find( - {"namespace": self.namespace, "key": {"$in": keys}}, {"key": 1} - ) - existing_keys = {doc["key"] async for doc in cursor} - return [key in existing_keys for key in keys] + func = functools.partial(self.exists, keys) + return await run_in_executor(None, func) def list_keys( self, @@ -178,20 +148,10 @@ async def alist_keys( Asynchronously list documents in the MongoDB collection based on the provided date range. """ - query: Dict[str, Any] = {"namespace": self.namespace} - if before: - query["updated_at"] = {"$lt": before} - if after: - query["updated_at"] = {"$gt": after} - if group_ids: - query["group_id"] = {"$in": group_ids} - - cursor = ( - self.async_collection.find(query, {"key": 1}).limit(limit) - if limit - else self.async_collection.find(query, {"key": 1}) + func = functools.partial( + self.list_keys, before=before, after=after, group_ids=group_ids, limit=limit ) - return [doc["key"] async for doc in cursor] + return await run_in_executor(None, func) def delete_keys(self, keys: Sequence[str]) -> None: """Delete documents from the MongoDB collection.""" @@ -201,6 +161,5 @@ def delete_keys(self, keys: Sequence[str]) -> None: async def adelete_keys(self, keys: Sequence[str]) -> None: """Asynchronously delete documents from the MongoDB collection.""" - await self.async_collection.delete_many( - {"namespace": self.namespace, "key": {"$in": keys}} - ) + func = functools.partial(self.delete_keys, keys) + return await run_in_executor(None, func) diff --git a/libs/langchain-mongodb/pyproject.toml b/libs/langchain-mongodb/pyproject.toml index 91b0be6..7da3f4a 100644 --- a/libs/langchain-mongodb/pyproject.toml +++ b/libs/langchain-mongodb/pyproject.toml @@ -17,7 +17,6 @@ pymongo = "^4.6.1" langchain-core = "^0.3" langchain = "^0.3" langchain-text-splitters = "^0.3" -motor = "^3.5" [[tool.poetry.dependencies.numpy]] version = "^1" @@ -71,7 +70,7 @@ lint.select = [ "B", # flake8-bugbear "I", # isort ] -lint.ignore = ["E501", "B008", "UP007", "UP006"] +lint.ignore = ["E501", "B008", "UP007", "UP006", "UP035"] [tool.coverage.run] omit = ["tests/*"]