Skip to content

Commit

Permalink
add mongodb_store
Browse files Browse the repository at this point in the history
  • Loading branch information
xieqihui committed Nov 24, 2023
1 parent 751226e commit 9441b52
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 2 deletions.
126 changes: 126 additions & 0 deletions libs/langchain/langchain/storage/mongodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import re
from typing import Iterator, List, Optional, Sequence, Tuple

from langchain_core.stores import BaseStore

from langchain.schema import Document


class MongodbStore(BaseStore[str, Document]):
    """BaseStore implementation using MongoDB as the underlying store.

    Every entry is persisted as one MongoDB document of the form
    ``{"_id": <key>, "value": <attributes of the Document>}``.

    Examples:
        Create a MongodbStore instance and perform operations on it:

        .. code-block:: python

            # Instantiate the MongodbStore with a MongoDB connection
            from langchain.storage.mongodb import MongodbStore

            mongo_conn_str = "mongodb://localhost:27017/"
            mongodb_store = MongodbStore(mongo_conn_str, db_name="test-db",
                                         collection_name="test-collection")

            # Set values for keys
            doc1 = Document(...)
            doc2 = Document(...)
            mongodb_store.mset([("key1", doc1), ("key2", doc2)])

            # Get values for keys
            values = mongodb_store.mget(["key1", "key2"])
            # [doc1, doc2]

            # Iterate over keys
            for key in mongodb_store.yield_keys():
                print(key)

            # Delete keys
            mongodb_store.mdelete(["key1", "key2"])
    """

    def __init__(
        self,
        connection_string: str,
        db_name: str,
        collection_name: str,
        client_kwargs: Optional[dict] = None,
    ) -> None:
        """Initialize the MongodbStore with a MongoDB connection string.

        Args:
            connection_string (str): MongoDB connection string.
            db_name (str): Name of the database to use.
            collection_name (str): Name of the collection to use.
            client_kwargs (dict): Keyword arguments to pass to the Mongo client.

        Raises:
            ImportError: If the pymongo package is not installed.
            ValueError: If ``connection_string``, ``db_name`` or
                ``collection_name`` is empty.
        """
        # pymongo is imported lazily so that merely importing this module
        # does not require the optional dependency.
        try:
            from pymongo import MongoClient
        except ImportError as e:
            raise ImportError(
                "The MongodbStore requires the pymongo library to be "
                "installed. "
                "pip install pymongo"
            ) from e

        if not connection_string:
            raise ValueError("connection_string must be provided.")
        if not db_name:
            raise ValueError("db_name must be provided.")
        if not collection_name:
            raise ValueError("collection_name must be provided.")

        self.client = MongoClient(connection_string, **(client_kwargs or {}))
        self.collection = self.client[db_name][collection_name]

    def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
        """Get the list of documents associated with the given keys.

        Args:
            keys (list[str]): A list of keys representing Document IDs.

        Returns:
            list[Document]: One entry per requested key, in the same order
            as ``keys``; the entry is None when the key is not in the store.
        """
        key_list = list(keys)
        if not key_list:
            return []

        cursor = self.collection.find({"_id": {"$in": key_list}})
        found = {doc["_id"]: Document(**doc["value"]) for doc in cursor}
        # Re-align with the caller's order; missing keys map to None.
        return [found.get(key) for key in key_list]

    def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
        """Set the given key-value pairs.

        Existing keys are overwritten, new keys are inserted (upsert).

        Args:
            key_value_pairs (list[tuple[str, Document]]): A list of
                id-document pairs.

        Returns:
            None
        """
        from pymongo import UpdateOne

        operations = [
            UpdateOne(
                {"_id": key},
                {"$set": {"_id": key, "value": document.__dict__}},
                upsert=True,
            )
            for key, document in key_value_pairs
        ]
        # bulk_write rejects an empty request list, so setting nothing must
        # be handled here as a no-op.
        if not operations:
            return
        self.collection.bulk_write(operations)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given ids.

        Keys that are not present in the store are ignored.

        Args:
            keys (list[str]): A list of keys representing Document IDs.
        """
        key_list = list(keys)
        if not key_list:
            return
        self.collection.delete_many({"_id": {"$in": key_list}})

    def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield keys in the store.

        Args:
            prefix (str): If given, only keys starting with this literal
                prefix are yielded. If None, all keys are yielded.

        Yields:
            str: The ``_id`` of each matching stored entry.
        """
        if prefix is None:
            query: dict = {}
        else:
            # Escape the prefix so characters such as "." or "(" are matched
            # literally instead of being interpreted as regex syntax.
            query = {"_id": {"$regex": "^" + re.escape(prefix)}}

        for doc in self.collection.find(query, projection=["_id"]):
            yield doc["_id"]
29 changes: 27 additions & 2 deletions libs/langchain/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions libs/langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ azure-ai-textanalytics = {version = "^5.3.0", optional = true}
google-cloud-documentai = {version = "^2.20.1", optional = true}
fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
javelin-sdk = {version = "^0.1.8", optional = true}
mongomock = {version = "^4.1.2", optional = true}


[tool.poetry.group.test.dependencies]
Expand Down Expand Up @@ -379,6 +380,7 @@ extended_testing = [
"rspace_client",
"fireworks-ai",
"javelin-sdk",
"mongomock",
]

[tool.ruff]
Expand Down
73 changes: 73 additions & 0 deletions libs/langchain/tests/integration_tests/storage/test_mongodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Generator

import pytest
from langchain_core.documents import Document

from langchain.storage.mongodb import MongodbStore

pytest.importorskip("pymongo")


@pytest.fixture
def mongo_store() -> Generator:
    """Yield a MongodbStore wired to a mocked MongoDB server.

    Tests using this fixture are skipped, rather than failing with an
    ImportError, when the optional mongomock package is not installed.
    """
    mongomock = pytest.importorskip("mongomock")

    # mongomock creates a mock MongoDB instance for testing purposes; the
    # store must be created and used while the patch is active.
    with mongomock.patch(servers=(("localhost", 27017),)):
        yield MongodbStore("mongodb://localhost:27017/", "test_db", "test_collection")


def test_mset_and_mget(mongo_store: MongodbStore) -> None:
    """Documents written with mset come back from mget in key order."""
    expected_by_key = {"key1": "doc1", "key2": "doc2"}

    # Store one document per key.
    mongo_store.mset(
        [(key, Document(page_content=text)) for key, text in expected_by_key.items()]
    )

    # Read them back using the same key order.
    fetched = mongo_store.mget(list(expected_by_key))

    for index, text in enumerate(expected_by_key.values()):
        assert fetched[index] and fetched[index].page_content == text


def test_yield_keys(mongo_store: MongodbStore) -> None:
    """yield_keys reports all keys, or only those with the given prefix."""
    contents = {"key1": "doc1", "key2": "doc2", "another_key": "other"}
    mongo_store.mset(
        [(key, Document(page_content=text)) for key, text in contents.items()]
    )

    # Without a prefix every stored key is reported.
    all_keys = set(mongo_store.yield_keys())
    assert all_keys == {"key1", "key2", "another_key"}

    # With a prefix only keys that start with it are reported;
    # "another_key" merely contains "key" and must be excluded.
    prefixed_keys = set(mongo_store.yield_keys(prefix="key"))
    assert prefixed_keys == {"key1", "key2"}


def test_mdelete(mongo_store: MongodbStore) -> None:
    """mdelete removes exactly the requested keys."""
    seeded = [
        ("key1", Document(page_content="doc1")),
        ("key2", Document(page_content="doc2")),
    ]
    mongo_store.mset(seeded)

    # Removing the first key must leave the second one in place.
    mongo_store.mdelete(["key1"])
    keys_left = list(mongo_store.yield_keys())
    assert "key1" not in keys_left
    assert "key2" in keys_left

    # Removing the last remaining key empties the store.
    mongo_store.mdelete(["key2"])
    keys_left = list(mongo_store.yield_keys())
    assert not keys_left


def test_init_errors() -> None:
    """Each required constructor argument is rejected when it is empty."""
    # All arguments empty.
    with pytest.raises(ValueError):
        MongodbStore("", "", "")

    # One argument empty at a time, so every validation branch is reached.
    # The constructor validates its arguments before creating the client,
    # so no connection is attempted here.
    valid_kwargs = {
        "connection_string": "mongodb://localhost:27017/",
        "db_name": "test_db",
        "collection_name": "test_collection",
    }
    for missing in valid_kwargs:
        kwargs = {**valid_kwargs, missing: ""}
        with pytest.raises(ValueError):
            MongodbStore(**kwargs)
11 changes: 11 additions & 0 deletions libs/langchain/tests/unit_tests/storage/test_mongodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Lightweight unit test that attempts to import MongodbStore.

The actual code is tested in integration tests.
This test is intended to catch errors in the import process.
"""


def test_import_storage() -> None:
    """Check that the MongodbStore module can be imported.

    The test passes when the import statement below succeeds; the imported
    name is intentionally unused.
    """
    from langchain.storage.mongodb import MongodbStore  # noqa

0 comments on commit 9441b52

Please sign in to comment.