From 7ccf149197103083a8a4001d54eb9bf110551af7 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 19 Dec 2024 15:15:54 -0600 Subject: [PATCH] INTPYTHON-459 Refactor handling of client into a pytest fixture (#44) --- .../tests/integration_tests/conftest.py | 12 ++++++++++++ .../tests/integration_tests/test_cache.py | 3 +-- .../tests/integration_tests/test_chain_example.py | 4 +--- .../integration_tests/test_chat_message_histories.py | 6 +----- .../tests/integration_tests/test_docstore.py | 5 +---- .../tests/integration_tests/test_index.py | 5 +---- .../tests/integration_tests/test_indexes.py | 8 ++------ .../tests/integration_tests/test_mmr.py | 7 +------ .../tests/integration_tests/test_parent_document.py | 7 ++++--- .../tests/integration_tests/test_retrievers.py | 4 +--- .../integration_tests/test_vectorstore_add_delete.py | 7 ++----- .../test_vectorstore_from_documents.py | 6 +----- .../integration_tests/test_vectorstore_from_texts.py | 6 +----- 13 files changed, 29 insertions(+), 51 deletions(-) diff --git a/libs/langchain-mongodb/tests/integration_tests/conftest.py b/libs/langchain-mongodb/tests/integration_tests/conftest.py index 0a980be..ccd12d0 100644 --- a/libs/langchain-mongodb/tests/integration_tests/conftest.py +++ b/libs/langchain-mongodb/tests/integration_tests/conftest.py @@ -1,8 +1,10 @@ +import os from typing import List import pytest from langchain_community.document_loaders import PyPDFLoader from langchain_core.documents import Document +from pymongo import MongoClient @pytest.fixture(scope="session") @@ -11,3 +13,13 @@ def technical_report_pages() -> List[Document]: loader = PyPDFLoader("https://arxiv.org/pdf/2303.08774.pdf") pages = loader.load() return pages + + +@pytest.fixture(scope="session") +def connection_string() -> str: + return os.environ["MONGODB_URI"] + + +@pytest.fixture(scope="session") +def client(connection_string: str) -> MongoClient: + return MongoClient(connection_string) diff --git a/libs/langchain-mongodb/tests/integration_tests/test_cache.py b/libs/langchain-mongodb/tests/integration_tests/test_cache.py index 06b078f..35a24aa 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_cache.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_cache.py @@ -32,9 +32,8 @@ def random_string() -> str: @pytest.fixture(scope="module") -def collection() -> Collection: +def collection(client: MongoClient) -> Collection: """A Collection with both a Vector and a Full-text Search Index""" - client: MongoClient = MongoClient(CONN_STRING) if COLLECTION not in client[DATABASE].list_collection_names(): clxn = client[DATABASE].create_collection(COLLECTION) else: diff --git a/libs/langchain-mongodb/tests/integration_tests/test_chain_example.py b/libs/langchain-mongodb/tests/integration_tests/test_chain_example.py index 8119a74..94b7bad 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_chain_example.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_chain_example.py @@ -16,7 +16,6 @@ from ..utils import PatchedMongoDBAtlasVectorSearch -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_chain_example" INDEX_NAME = "langchain-test-chain-example-vector-index" @@ -26,9 +25,8 @@ @pytest.fixture -def collection() -> Collection: +def collection(client: MongoClient) -> Collection: """A Collection with both a Vector and a Full-text Search Index""" - client: MongoClient = MongoClient(CONNECTION_STRING) if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): clxn = client[DB_NAME].create_collection(COLLECTION_NAME) else: diff --git a/libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py b/libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py index 0c6c929..1832dcb 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py @@ -1,5 +1,4 @@ import json -import os from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found] from langchain_core.messages import message_to_dict @@ -9,11 +8,8 @@ DATABASE = "langchain_test_db" COLLECTION = "langchain_test_chat" -# Replace these with your mongodb connection string -connection_string = os.environ.get("MONGODB_URI", "") - -def test_memory_with_message_store() -> None: +def test_memory_with_message_store(connection_string: str) -> None: """Test the memory with a message store.""" # setup MongoDB as a message store message_history = MongoDBChatMessageHistory( diff --git a/libs/langchain-mongodb/tests/integration_tests/test_docstore.py b/libs/langchain-mongodb/tests/integration_tests/test_docstore.py index e4cfd63..d5a12f7 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_docstore.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_docstore.py @@ -1,4 +1,3 @@ -import os from typing import List from langchain_core.documents import Document @@ -6,13 +5,11 @@ from langchain_mongodb.docstores import MongoDBDocStore -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_docstore" -def test_docstore(technical_report_pages: List[Document]) -> None: - client: MongoClient = MongoClient(CONNECTION_STRING) +def test_docstore(client: MongoClient, technical_report_pages: List[Document]) -> None: db = client[DB_NAME] db.drop_collection(COLLECTION_NAME) clxn = db[COLLECTION_NAME] diff --git a/libs/langchain-mongodb/tests/integration_tests/test_index.py b/libs/langchain-mongodb/tests/integration_tests/test_index.py index b43831b..7534df0 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_index.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_index.py @@ -1,4 +1,3 @@ -import os from typing import Generator, List, Optional import pytest @@ -18,10 +17,8 @@ @pytest.fixture -def collection() -> Generator: +def collection(client: MongoClient) -> Generator: """Depending on uri, this could point to any type of cluster.""" - uri = os.environ.get("MONGODB_URI") - client: MongoClient = MongoClient(uri) if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): clxn = client[DB_NAME].create_collection(COLLECTION_NAME) else: diff --git a/libs/langchain-mongodb/tests/integration_tests/test_indexes.py b/libs/langchain-mongodb/tests/integration_tests/test_indexes.py index b12bfd2..cbc2ea4 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_indexes.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_indexes.py @@ -1,4 +1,3 @@ -import os from datetime import datetime from unittest.mock import patch @@ -8,25 +7,22 @@ from langchain_mongodb.indexes import MongoDBRecordManager -CONNECTION_STRING = os.environ["MONGODB_URI"] DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_docstore" NAMESPACE = f"{DB_NAME}.{COLLECTION_NAME}" @pytest.fixture -def manager() -> MongoDBRecordManager: +def manager(client: MongoClient) -> MongoDBRecordManager: """Initialize the test MongoDB and yield the DocumentManager instance.""" - client: MongoClient = MongoClient(CONNECTION_STRING) collection = client[DB_NAME][COLLECTION_NAME] document_manager = MongoDBRecordManager(collection=collection) return document_manager @pytest_asyncio.fixture -async def amanager() -> MongoDBRecordManager: +async def amanager(client: MongoClient) -> MongoDBRecordManager: """Initialize the test MongoDB and yield the DocumentManager instance.""" - client: MongoClient = MongoClient(CONNECTION_STRING) collection = client[DB_NAME][COLLECTION_NAME] document_manager = MongoDBRecordManager(collection=collection) return document_manager diff --git a/libs/langchain-mongodb/tests/integration_tests/test_mmr.py b/libs/langchain-mongodb/tests/integration_tests/test_mmr.py index b0173ec..47c3974 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_mmr.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_mmr.py @@ -2,8 +2,6 @@ from __future__ import annotations -import os - import pytest # type: ignore[import-not-found] from langchain_core.embeddings import Embeddings from pymongo import MongoClient @@ -15,7 +13,6 @@ from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_vectorstores" INDEX_NAME = "langchain-test-index-vectorstores" @@ -23,9 +20,7 @@ @pytest.fixture() -def collection() -> Collection: - client: MongoClient = MongoClient(CONNECTION_STRING) - +def collection(client: MongoClient) -> Collection: if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): clxn = client[DB_NAME].create_collection(COLLECTION_NAME) else: diff --git a/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py b/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py index f61aa6b..683db81 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py @@ -17,7 +17,6 @@ from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_parent_document_combined" VECTOR_INDEX_NAME = "langchain-test-parent-document-vector-index" @@ -41,11 +40,13 @@ def embedding_model() -> Embeddings: def test_1clxn_retriever( - technical_report_pages: List[Document], embedding_model: Embeddings + connection_string: str, + technical_report_pages: List[Document], + embedding_model: Embeddings, ) -> None: # Setup client: MongoClient = MongoClient( - CONNECTION_STRING, + connection_string, driver=DriverInfo(name="langchain", version=version("langchain-mongodb")), ) db = client[DB_NAME] diff --git a/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py b/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py index b144f3a..ad6dae4 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py @@ -20,7 +20,6 @@ from ..utils import PatchedMongoDBAtlasVectorSearch -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_retrievers" VECTOR_INDEX_NAME = "vector_index" @@ -57,9 +56,8 @@ def embedding_openai() -> Embeddings: @pytest.fixture(scope="module") -def collection() -> Collection: +def collection(client: MongoClient) -> Collection: """A Collection with both a Vector and a Full-text Search Index""" - client: MongoClient = MongoClient(CONNECTION_STRING) if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): clxn = client[DB_NAME].create_collection(COLLECTION_NAME) else: diff --git a/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_add_delete.py b/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_add_delete.py index 85862a1..94986d8 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_add_delete.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_add_delete.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os from typing import Any, Dict, List import pytest # type: ignore[import-not-found] @@ -17,7 +16,6 @@ from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" INDEX_NAME = "langchain-test-index-vectorstores" COLLECTION_NAME = "langchain_test_vectorstores" @@ -25,9 +23,8 @@ @pytest.fixture(scope="module") -def collection() -> Collection: - test_client: MongoClient = MongoClient(CONNECTION_STRING) - return test_client[DB_NAME][COLLECTION_NAME] +def collection(client: MongoClient) -> Collection: + return client[DB_NAME][COLLECTION_NAME] @pytest.fixture(scope="module") diff --git a/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_documents.py b/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_documents.py index eec5d89..5ec0157 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_documents.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_documents.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os from typing import List import pytest # type: ignore[import-not-found] @@ -17,7 +16,6 @@ from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_from_documents" INDEX_NAME = "langchain-test-index-from-documents" @@ -25,9 +23,7 @@ @pytest.fixture(scope="module") -def collection() -> Collection: - client: MongoClient = MongoClient(CONNECTION_STRING) - +def collection(client: MongoClient) -> Collection: if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): clxn = client[DB_NAME].create_collection(COLLECTION_NAME) else: diff --git a/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_texts.py b/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_texts.py index dcf98ae..41e8118 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_texts.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_texts.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os from typing import Dict, Generator, List import pytest # type: ignore[import-not-found] @@ -17,7 +16,6 @@ from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch -CONNECTION_STRING = os.environ.get("MONGODB_URI") DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_from_texts" INDEX_NAME = "langchain-test-index-from-texts" @@ -25,9 +23,7 @@ @pytest.fixture(scope="module") -def collection() -> Collection: - client: MongoClient = MongoClient(CONNECTION_STRING) - +def collection(client: MongoClient) -> Collection: if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): clxn = client[DB_NAME].create_collection(COLLECTION_NAME) else: