-
Notifications
You must be signed in to change notification settings - Fork 15.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
239 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from typing import Iterator, List, Optional, Sequence, Tuple | ||
|
||
from langchain_core.stores import BaseStore | ||
|
||
from langchain.schema import Document | ||
|
||
|
||
class MongodbStore(BaseStore[str, Document]):
    """BaseStore implementation using MongoDB as the underlying store.

    Every entry is persisted as one MongoDB document of the form
    ``{"_id": <key>, "value": <Document attributes>}``.

    Examples:
        Create a MongodbStore instance and perform operations on it:

        .. code-block:: python

            # Instantiate the MongodbStore with a MongoDB connection
            from langchain.storage import MongodbStore

            mongo_conn_str = "mongodb://localhost:27017/"
            mongodb_store = MongodbStore(mongo_conn_str, db_name="test-db",
                                         collection_name="test-collection")

            # Set values for keys
            doc1 = Document(...)
            doc2 = Document(...)
            mongodb_store.mset([("key1", doc1), ("key2", doc2)])

            # Get values for keys
            values = mongodb_store.mget(["key1", "key2"])
            # [doc1, doc2]

            # Iterate over keys
            for key in mongodb_store.yield_keys():
                print(key)

            # Delete keys
            mongodb_store.mdelete(["key1", "key2"])
    """

    def __init__(
        self,
        connection_string: str,
        db_name: str,
        collection_name: str,
        client_kwargs: Optional[dict] = None,
    ) -> None:
        """Initialize the MongodbStore with a MongoDB connection string.

        Args:
            connection_string (str): MongoDB connection string
            db_name (str): database name to use
            collection_name (str): collection name to use
            client_kwargs (dict): Keyword arguments to pass to the Mongo client

        Raises:
            ImportError: if the ``pymongo`` package is not installed.
            ValueError: if ``connection_string``, ``db_name`` or
                ``collection_name`` is empty.
        """
        # pymongo is an optional dependency, so it is imported lazily.
        try:
            from pymongo import MongoClient
        except ImportError as e:
            raise ImportError(
                "The MongodbStore requires the pymongo library to be "
                "installed. "
                "pip install pymongo"
            ) from e

        # Validate before creating the client.
        if not connection_string:
            raise ValueError("connection_string must be provided.")
        if not db_name:
            raise ValueError("db_name must be provided.")
        if not collection_name:
            raise ValueError("collection_name must be provided.")

        self.client = MongoClient(connection_string, **(client_kwargs or {}))
        self.collection = self.client[db_name][collection_name]

    def mget(self, keys: Sequence[str]) -> List[Optional[Document]]:
        """Get the list of documents associated with the given keys.

        Args:
            keys (list[str]): A list of keys representing Document IDs.

        Returns:
            list[Document]: A list of Documents corresponding to the provided
                keys, in the same order as ``keys``, where each Document is
                either retrieved successfully or represented as None if not
                found.
        """
        # Pass a plain list so any Sequence type is accepted as the $in operand.
        result = self.collection.find({"_id": {"$in": list(keys)}})
        # Rebuild each Document from the attribute dict written by ``mset``.
        result_dict = {doc["_id"]: Document(**doc["value"]) for doc in result}
        # Look keys up one by one to preserve the caller's order; misses -> None.
        return [result_dict.get(key) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, Document]]) -> None:
        """Set the given key-value pairs.

        Existing keys are overwritten, missing keys are inserted (upsert).
        An empty sequence is a no-op.

        Args:
            key_value_pairs (list[tuple[str, Document]]): A list of id-document
                pairs.

        Returns:
            None
        """
        # pymongo rejects an empty bulk request, so there is nothing to do.
        if not key_value_pairs:
            return

        from pymongo import UpdateOne

        # Only ``value`` is set; on upsert the ``_id`` is taken from the filter.
        operations = [
            UpdateOne({"_id": k}, {"$set": {"value": v.__dict__}}, upsert=True)
            for k, v in key_value_pairs
        ]
        self.collection.bulk_write(operations)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given ids.

        Args:
            keys (list[str]): A list of keys representing Document IDs.
        """
        self.collection.delete_many({"_id": {"$in": list(keys)}})

    def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield keys in the store.

        Args:
            prefix (str): prefix of keys to retrieve. It is matched literally;
                when None, every key in the collection is yielded.
        """
        import re

        if prefix is None:
            query: dict = {}
        else:
            # Escape the prefix so regex metacharacters in it (".", "(", "*",
            # ...) are matched literally instead of being read as a pattern.
            query = {"_id": {"$regex": f"^{re.escape(prefix)}"}}

        # Project only ``_id`` so stored values are not fetched.
        for doc in self.collection.find(query, projection=["_id"]):
            yield doc["_id"]
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
libs/langchain/tests/integration_tests/storage/test_mongodb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Generator | ||
|
||
import pytest | ||
from langchain_core.documents import Document | ||
|
||
from langchain.storage.mongodb import MongodbStore | ||
|
||
pytest.importorskip("pymongo") | ||
|
||
|
||
@pytest.fixture
def mongo_store() -> Generator[MongodbStore, None, None]:
    """Yield a MongodbStore connected to a mongomock-patched MongoDB client.

    The dependent test is skipped (instead of erroring) when the optional
    ``mongomock`` package is not installed.
    """
    mongomock = pytest.importorskip("mongomock")

    # mongomock creates a mock MongoDB instance for testing purposes
    with mongomock.patch(servers=(("localhost", 27017),)):
        yield MongodbStore("mongodb://localhost:27017/", "test_db", "test_collection")
|
||
|
||
def test_mset_and_mget(mongo_store: MongodbStore) -> None:
    """Documents written with mset come back from mget in key order."""
    first = Document(page_content="doc1")
    second = Document(page_content="doc2")

    # Write both documents in a single call
    mongo_store.mset([("key1", first), ("key2", second)])

    # Read them back in the same key order
    fetched = mongo_store.mget(["key1", "key2"])

    for position, expected_content in enumerate(("doc1", "doc2")):
        found = fetched[position]
        assert found and found.page_content == expected_content
|
||
|
||
def test_yield_keys(mongo_store: MongodbStore) -> None:
    """yield_keys lists every key, or only those starting with a prefix."""
    contents_by_key = {
        "key1": "doc1",
        "key2": "doc2",
        "another_key": "other",
    }
    mongo_store.mset(
        [
            (key, Document(page_content=content))
            for key, content in contents_by_key.items()
        ]
    )

    # No prefix: every stored key is returned
    all_keys = set(mongo_store.yield_keys())
    assert all_keys == {"key1", "key2", "another_key"}

    # Prefix "key": only the two keys that start with it
    prefixed_keys = set(mongo_store.yield_keys(prefix="key"))
    assert prefixed_keys == {"key1", "key2"}
|
||
|
||
def test_mdelete(mongo_store: MongodbStore) -> None:
    """mdelete removes exactly the requested keys, one or many at a time."""
    mongo_store.mset(
        [
            ("key1", Document(page_content="doc1")),
            ("key2", Document(page_content="doc2")),
            ("key3", Document(page_content="doc3")),
        ]
    )

    # Delete single document
    mongo_store.mdelete(["key1"])
    remaining_docs = list(mongo_store.yield_keys())
    assert "key1" not in remaining_docs
    assert "key2" in remaining_docs
    assert "key3" in remaining_docs

    # Delete multiple documents in one call
    mongo_store.mdelete(["key2", "key3"])
    remaining_docs = list(mongo_store.yield_keys())
    assert len(remaining_docs) == 0
|
||
|
||
def test_init_errors() -> None:
    """Each required constructor argument is rejected when empty."""
    # All arguments empty: the first check (connection string) fires.
    with pytest.raises(ValueError, match="connection_string"):
        MongodbStore("", "", "")

    # Valid connection string, empty database name.
    with pytest.raises(ValueError, match="db_name"):
        MongodbStore("mongodb://localhost:27017/", "", "test_collection")

    # Valid connection string and database name, empty collection name.
    with pytest.raises(ValueError, match="collection_name"):
        MongodbStore("mongodb://localhost:27017/", "test_db", "")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""Light weight unit test that attempts to import MongodbStore. | ||
The actual code is tested in integration tests. | ||
This test is intended to catch errors in the import process. | ||
""" | ||
|
||
|
||
def test_import_storage() -> None:
    """Check that MongodbStore can be imported from langchain.storage.mongodb."""
    # The import itself is the assertion; the imported name is unused.
    from langchain.storage.mongodb import MongodbStore  # noqa: F401