diff --git a/libs/community/langchain_community/document_loaders/mongodb.py b/libs/community/langchain_community/document_loaders/mongodb.py index b1e062226546f..7eddb2f7eb1a1 100644 --- a/libs/community/langchain_community/document_loaders/mongodb.py +++ b/libs/community/langchain_community/document_loaders/mongodb.py @@ -20,13 +20,37 @@ def __init__( *, filter_criteria: Optional[Dict] = None, field_names: Optional[Sequence[str]] = None, + metadata_names: Optional[Sequence[str]] = None, + include_db_collection_in_metadata: bool = True, ) -> None: + """ + Initializes the MongoDB loader with necessary database connection + details and configurations. + + Args: + connection_string (str): MongoDB connection URI. + db_name (str):Name of the database to connect to. + collection_name (str): Name of the collection to fetch documents from. + filter_criteria (Optional[Dict]): MongoDB filter criteria for querying + documents. + field_names (Optional[Sequence[str]]): List of field names to retrieve + from documents. + metadata_names (Optional[Sequence[str]]): Additional metadata fields to + extract from documents. + include_db_collection_in_metadata (bool): Flag to include database and + collection names in metadata. + + Raises: + ImportError: If the motor library is not installed. + ValueError: If any necessary argument is missing. + """ try: from motor.motor_asyncio import AsyncIOMotorClient except ImportError as e: raise ImportError( "Cannot import from motor, please install with `pip install motor`." ) from e + if not connection_string: raise ValueError("connection_string must be provided.") @@ -39,8 +63,10 @@ def __init__( self.client = AsyncIOMotorClient(connection_string) self.db_name = db_name self.collection_name = collection_name - self.field_names = field_names + self.field_names = field_names or [] self.filter_criteria = filter_criteria or {} + self.metadata_names = metadata_names or [] + self.include_db_collection_in_metadata = include_db_collection_in_metadata self.db = self.client.get_database(db_name) self.collection = self.db.get_collection(collection_name) @@ -60,36 +86,24 @@ def load(self) -> List[Document]: return asyncio.run(self.aload()) async def aload(self) -> List[Document]: - """Load data into Document objects.""" + """Asynchronously loads data into Document objects.""" result = [] total_docs = await self.collection.count_documents(self.filter_criteria) - # Construct the projection dictionary if field_names are specified - projection = ( - {field: 1 for field in self.field_names} if self.field_names else None - ) + projection = self._construct_projection() async for doc in self.collection.find(self.filter_criteria, projection): - metadata = { - "database": self.db_name, - "collection": self.collection_name, - } + metadata = self._extract_fields(doc, self.metadata_names, default="") + + # Optionally add database and collection names to metadata + if self.include_db_collection_in_metadata: + metadata.update( + {"database": self.db_name, "collection": self.collection_name} + ) # Extract text content from filtered fields or use the entire document if self.field_names is not None: - fields = {} - for name in self.field_names: - # Split the field names to handle nested fields - keys = name.split(".") - value = doc - for key in keys: - if key in value: - value = value[key] - else: - value = "" - break - fields[name] = value - + fields = self._extract_fields(doc, self.field_names, default="") texts = [str(value) for value in fields.values()] text = " ".join(texts) else: @@ -104,3 +118,29 @@ async def aload(self) -> List[Document]: ) return result + + def _construct_projection(self) -> Optional[Dict]: + """Constructs the projection dictionary for MongoDB query based + on the specified field names and metadata names.""" + field_names = list(self.field_names) or [] + metadata_names = list(self.metadata_names) or [] + all_fields = field_names + metadata_names + return {field: 1 for field in all_fields} if all_fields else None + + def _extract_fields( + self, + document: Dict, + fields: Sequence[str], + default: str = "", + ) -> Dict: + """Extracts and returns values for specified fields from a document.""" + extracted = {} + for field in fields or []: + value = document + for key in field.split("."): + value = value.get(key, default) + if value == default: + break + new_field_name = field.replace(".", "_") + extracted[new_field_name] = value + return extracted diff --git a/libs/community/tests/unit_tests/document_loaders/test_mongodb.py b/libs/community/tests/unit_tests/document_loaders/test_mongodb.py index 121e8b39fd6be..72ed08905f745 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_mongodb.py +++ b/libs/community/tests/unit_tests/document_loaders/test_mongodb.py @@ -12,6 +12,7 @@ def raw_docs() -> List[Dict]: return [ {"_id": "1", "address": {"building": "1", "room": "1"}}, {"_id": "2", "address": {"building": "2", "room": "2"}}, + {"_id": "3", "address": {"building": "3", "room": "2"}}, ] @@ -19,18 +20,23 @@ def raw_docs() -> List[Dict]: def expected_documents() -> List[Document]: return [ Document( - page_content="{'_id': '1', 'address': {'building': '1', 'room': '1'}}", + page_content="{'_id': '2', 'address': {'building': '2', 'room': '2'}}", metadata={"database": "sample_restaurants", "collection": "restaurants"}, ), Document( - page_content="{'_id': '2', 'address': {'building': '2', 'room': '2'}}", + page_content="{'_id': '3', 'address': {'building': '3', 'room': '2'}}", metadata={"database": "sample_restaurants", "collection": "restaurants"}, ), ] @pytest.mark.requires("motor") -async def test_load_mocked(expected_documents: List[Document]) -> None: +async def test_load_mocked_with_filters(expected_documents: List[Document]) -> None: + filter_criteria = {"address.room": {"$eq": "2"}} + field_names = ["address.building", "address.room"] + metadata_names = ["_id"] + include_db_collection_in_metadata = True + mock_async_load = AsyncMock() mock_async_load.return_value = expected_documents @@ -51,7 +57,13 @@ async def test_load_mocked(expected_documents: List[Document]) -> None: new=mock_async_load, ): loader = MongodbLoader( - "mongodb://localhost:27017", "test_db", "test_collection" + "mongodb://localhost:27017", + "test_db", + "test_collection", + filter_criteria=filter_criteria, + field_names=field_names, + metadata_names=metadata_names, + include_db_collection_in_metadata=include_db_collection_in_metadata, ) loader.collection = mock_collection documents = await loader.aload()