Skip to content

Commit

Permalink
do not use motor
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Dec 17, 2024
1 parent 2941b94 commit 8a2ea3b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 174 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Dict, List, Optional, Sequence
from typing import Dict, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
Expand All @@ -17,23 +17,6 @@
DEFAULT_SESSION_ID_KEY = "SessionId"
DEFAULT_HISTORY_KEY = "History"

try:
from motor.motor_asyncio import (
AsyncIOMotorClient,
AsyncIOMotorCollection,
AsyncIOMotorDatabase,
)

_motor_available = True
except ImportError:
AsyncIOMotorClient = None # type: ignore
_motor_available = False
logger.warning(
"Motor library is not installed. Asynchronous methods will fall back to using "
"`run_in_executor`, which is less efficient. "
"Install motor with `pip install motor` for improved performance."
)


class MongoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in MongoDB.
Expand Down Expand Up @@ -138,15 +121,6 @@ def __init__(
self.db = self.client[database_name]
self.collection = self.db[collection_name]

if _motor_available:
self.async_client: AsyncIOMotorClient = AsyncIOMotorClient(
connection_string
)
self.async_db: AsyncIOMotorDatabase = self.async_client[database_name]
self.async_collection: AsyncIOMotorCollection = self.async_db[
collection_name
]

if create_index:
index_kwargs = index_kwargs or {}
self.collection.create_index(self.session_id_key, **index_kwargs)
Expand Down Expand Up @@ -179,37 +153,6 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages = messages_from_dict(items)
return messages

async def aget_messages(self) -> List[BaseMessage]:
    """Asynchronously load this session's messages from MongoDB.

    When motor is not available a warning is logged and the call is handed
    to the parent class implementation. Otherwise the session's documents
    are read through ``self.async_collection``; if ``history_size`` is set,
    all but the last ``history_size`` matching documents (in cursor order)
    are skipped.

    Returns:
        The messages decoded from the JSON stored under ``history_key``.
    """
    if not _motor_available:
        logger.warning(
            "Motor library is not installed. "
            "Using `run_in_executor` for aget_messages, "
            "which may be less efficient."
        )
        return await super().aget_messages()

    session_filter = {self.session_id_key: self.session_id}
    if self.history_size is not None:
        # Count first so the leading documents beyond the window can be skipped.
        matching = await self.async_collection.count_documents(session_filter)
        offset = max(0, matching - self.history_size)
        cursor = self.async_collection.find(session_filter, skip=offset)
    else:
        cursor = self.async_collection.find(session_filter)

    decoded = [json.loads(doc[self.history_key]) async for doc in cursor]
    return messages_from_dict(decoded)

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in MongoDB"""
try:
Expand All @@ -222,39 +165,9 @@ def add_message(self, message: BaseMessage) -> None:
except errors.WriteError as err:
logger.error(err)

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
    """Asynchronously append a batch of messages to this session in MongoDB.

    When motor is not available a warning is logged and the call is handed
    to the parent class implementation. Otherwise each message is serialized
    to JSON and all of them are written with a single ``insert_many``.

    Args:
        messages: The messages to store. An empty sequence is a no-op.
    """
    if not _motor_available:
        logger.warning(
            "Motor library is not installed. "
            "Using `run_in_executor` for aadd_messages, "
            "which may be less efficient."
        )
        return await super().aadd_messages(messages)

    documents = [
        {
            self.session_id_key: self.session_id,
            self.history_key: json.dumps(message_to_dict(message)),
        }
        for message in messages
    ]
    if not documents:
        # insert_many rejects an empty batch, so there is nothing to write.
        return None
    await self.async_collection.insert_many(documents)

def clear(self) -> None:
    """Delete every document stored for this session from MongoDB.

    A ``WriteError`` raised by the delete is logged rather than propagated.
    """
    session_filter = {self.session_id_key: self.session_id}
    try:
        self.collection.delete_many(session_filter)
    except errors.WriteError as exc:
        logger.error(exc)

async def aclear(self) -> None:
    """Asynchronously delete every document stored for this session.

    With motor available the delete goes through ``self.async_collection``;
    otherwise a warning is logged and the parent class implementation is used.
    """
    if _motor_available:
        session_filter = {self.session_id_key: self.session_id}
        await self.async_collection.delete_many(session_filter)
        return None

    logger.warning(
        "Motor library is not installed. "
        "Using `run_in_executor` for aclear, which may be less efficient."
    )
    return await super().aclear()
127 changes: 43 additions & 84 deletions libs/langchain-mongodb/langchain_mongodb/indexes.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,47 @@
# Based on https://github.com/langchain-ai/langchain/blob/8f5e72de057bc07df19f7d7aefb7673b64fbb1b4/libs/community/langchain_community/indexes/_document_manager.py#L58
from __future__ import annotations

import functools
from typing import Any, Dict, List, Optional, Sequence

from langchain_core.indexing.base import RecordManager


def _get_pymongo_client(mongodb_url: str, **kwargs: Any) -> Any:
    """Build a synchronous ``MongoClient`` for ``mongodb_url``.

    Args:
        mongodb_url: Connection string passed to ``MongoClient``.
        **kwargs: Extra keyword arguments forwarded to ``MongoClient``.

    Returns:
        The constructed ``MongoClient``.

    Raises:
        ImportError: If ``MongoClient`` raises ``ValueError``. The original
            error text is embedded in the message and the exception context
            is suppressed.
    """
    # Imported at call time rather than at module import time.
    from pymongo import MongoClient

    try:
        return MongoClient(mongodb_url, **kwargs)
    except ValueError as err:
        reason = (
            f"MongoClient string provided is not in proper format. "
            f"Got error: {err} "
        )
        raise ImportError(reason) from None


def _get_motor_client(mongodb_url: str, **kwargs: Any) -> Any:
    """Build an ``AsyncIOMotorClient`` for ``mongodb_url``.

    Args:
        mongodb_url: Connection string passed to ``AsyncIOMotorClient``.
        **kwargs: Extra keyword arguments forwarded to ``AsyncIOMotorClient``.

    Returns:
        The constructed ``AsyncIOMotorClient``.

    Raises:
        ImportError: If ``AsyncIOMotorClient`` raises ``ValueError``. The
            original error text is embedded in the message and the exception
            context is suppressed.
    """
    # Imported at call time rather than at module import time.
    from motor.motor_asyncio import AsyncIOMotorClient

    try:
        return AsyncIOMotorClient(mongodb_url, **kwargs)
    except ValueError as err:
        reason = (
            f"AsyncIOMotorClient string provided is not in proper format. "
            f"Got error: {err} "
        )
        raise ImportError(reason) from None
from langchain_core.runnables.config import run_in_executor
from pymongo import MongoClient
from pymongo.collection import Collection


class MongoDBRecordManager(RecordManager):
"""A MongoDB-based implementation of the record manager."""

def __init__(
self,
*,
connection_string: str,
db_name: str,
collection_name: str,
) -> None:
def __init__(self, collection: Collection) -> None:
"""Initialize the MongoDBRecordManager.
Args:
connection_string: A valid MongoDB connection URI.
db_name: The name of the database to use.
collection_name: The name of the collection to use.
"""
super().__init__(namespace=".".join([db_name, collection_name]))
self.sync_client = _get_pymongo_client(connection_string)
self.sync_db = self.sync_client[db_name]
self.sync_collection = self.sync_db[collection_name]
self.async_client = _get_motor_client(connection_string)
self.async_db = self.async_client[db_name]
self.async_collection = self.async_db[collection_name]
namespace = f"{collection.database.name}.{collection.name}"
super().__init__(namespace=namespace)
self._collection = collection

@classmethod
def from_connection_string(
    cls, connection_string: str, namespace: str
) -> MongoDBRecordManager:
    """Construct a RecordManager from a MongoDB connection URI.

    Args:
        connection_string: A valid MongoDB connection URI.
        namespace: A valid MongoDB namespace (in form f"{database}.{collection}").
            Only the first "." separates the database from the collection,
            so collection names that themselves contain dots are accepted.

    Returns:
        A new MongoDBRecordManager instance.

    Raises:
        ValueError: If ``namespace`` is not of the form "database.collection".
    """
    # Validate before creating the client so a bad namespace does not leave
    # an unused MongoClient behind.
    db_name, separator, collection_name = namespace.partition(".")
    if not (separator and db_name and collection_name):
        raise ValueError(
            f"Invalid namespace {namespace!r}; "
            'expected the form "database.collection"'
        )
    client: MongoClient = MongoClient(connection_string)
    collection = client[db_name][collection_name]
    return cls(collection=collection)

def create_schema(self) -> None:
"""Create the database schema for the document manager."""
Expand All @@ -80,7 +66,7 @@ def update(
raise ValueError("Number of keys does not match number of group_ids")

for key, group_id in zip(keys, group_ids):
self.sync_collection.find_one_and_update(
self.collection.find_one_and_update(
{"namespace": self.namespace, "key": key},
{"$set": {"group_id": group_id, "updated_at": self.get_time()}},
upsert=True,
Expand All @@ -94,22 +80,10 @@ async def aupdate(
time_at_least: Optional[float] = None,
) -> None:
"""Asynchronously upsert documents into the MongoDB collection."""
if group_ids is None:
group_ids = [None] * len(keys)

if len(keys) != len(group_ids):
raise ValueError("Number of keys does not match number of group_ids")

update_time = await self.aget_time()
if time_at_least and update_time < time_at_least:
raise ValueError("Server time is behind the expected time_at_least")

for key, group_id in zip(keys, group_ids):
await self.async_collection.find_one_and_update(
{"namespace": self.namespace, "key": key},
{"$set": {"group_id": group_id, "updated_at": update_time}},
upsert=True,
)
func = functools.partial(
self.update, keys, group_ids=group_ids, time_at_least=time_at_least
)
return await run_in_executor(None, func)

def get_time(self) -> float:
"""Get the current server time as a timestamp."""
Expand All @@ -120,9 +94,8 @@ def get_time(self) -> float:

async def aget_time(self) -> float:
"""Asynchronously get the current server time as a timestamp."""
host_info = await self.async_collection.database.command("hostInfo")
local_time = host_info["system"]["currentTime"]
return local_time.timestamp()
func = functools.partial(self.get_time)
return await run_in_executor(None, func)

def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the MongoDB collection."""
Expand All @@ -136,11 +109,8 @@ def exists(self, keys: Sequence[str]) -> List[bool]:

async def aexists(self, keys: Sequence[str]) -> List[bool]:
"""Asynchronously check if the given keys exist in the MongoDB collection."""
cursor = self.async_collection.find(
{"namespace": self.namespace, "key": {"$in": keys}}, {"key": 1}
)
existing_keys = {doc["key"] async for doc in cursor}
return [key in existing_keys for key in keys]
func = functools.partial(self.exists, keys)
return await run_in_executor(None, func)

def list_keys(
self,
Expand Down Expand Up @@ -178,20 +148,10 @@ async def alist_keys(
Asynchronously list documents in the MongoDB collection
based on the provided date range.
"""
query: Dict[str, Any] = {"namespace": self.namespace}
if before:
query["updated_at"] = {"$lt": before}
if after:
query["updated_at"] = {"$gt": after}
if group_ids:
query["group_id"] = {"$in": group_ids}

cursor = (
self.async_collection.find(query, {"key": 1}).limit(limit)
if limit
else self.async_collection.find(query, {"key": 1})
func = functools.partial(
self.list_keys, before=before, after=after, group_ids=group_ids, limit=limit
)
return [doc["key"] async for doc in cursor]
return await run_in_executor(None, func)

def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete documents from the MongoDB collection."""
Expand All @@ -201,6 +161,5 @@ def delete_keys(self, keys: Sequence[str]) -> None:

async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Asynchronously delete documents from the MongoDB collection."""
await self.async_collection.delete_many(
{"namespace": self.namespace, "key": {"$in": keys}}
)
func = functools.partial(self.delete_keys, keys)
return await run_in_executor(None, func)
3 changes: 1 addition & 2 deletions libs/langchain-mongodb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ pymongo = "^4.6.1"
langchain-core = "^0.3"
langchain = "^0.3"
langchain-text-splitters = "^0.3"
motor = "^3.5"

[[tool.poetry.dependencies.numpy]]
version = "^1"
Expand Down Expand Up @@ -71,7 +70,7 @@ lint.select = [
"B", # flake8-bugbear
"I", # isort
]
lint.ignore = ["E501", "B008", "UP007", "UP006"]
lint.ignore = ["E501", "B008", "UP007", "UP006", "UP035"]

[tool.coverage.run]
omit = ["tests/*"]

0 comments on commit 8a2ea3b

Please sign in to comment.