diff --git a/libs/mongodb/tests/integration_tests/test_cache.py b/libs/mongodb/tests/integration_tests/test_cache.py index 62d4b00..719aa3c 100644 --- a/libs/mongodb/tests/integration_tests/test_cache.py +++ b/libs/mongodb/tests/integration_tests/test_cache.py @@ -9,7 +9,12 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, Generation, LLMResult +from pymongo.collection import Collection +from pymongo import MongoClient from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache +from langchain_mongodb.index import ( + create_vector_search_index, +) from ..utils import ConsistentFakeEmbeddings, FakeChatModel, FakeLLM @@ -18,21 +23,48 @@ DATABASE = "langchain_test_db" COLLECTION = "langchain_test_cache" +DIMENSIONS = 1536 # Meets OpenAI model +TIMEOUT = 60.0 def random_string() -> str: return str(uuid.uuid4()) +@pytest.fixture(scope="module") +def collection() -> Collection: + """A Collection with both a Vector and a Full-text Search Index""" + client: MongoClient = MongoClient(CONN_STRING) + if COLLECTION not in client[DATABASE].list_collection_names(): + clxn = client[DATABASE].create_collection(COLLECTION) + else: + clxn = client[DATABASE][COLLECTION] + + clxn.delete_many({}) + + if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_vector_search_index( + collection=clxn, + index_name=INDEX_NAME, + dimensions=DIMENSIONS, + path="embedding", + filters=['llm_string'], + similarity="cosine", + wait_until_complete=TIMEOUT, + ) + + return clxn + + def llm_cache(cls: Any) -> BaseCache: set_llm_cache( cls( - embedding=ConsistentFakeEmbeddings(dimensionality=1536), + embedding=ConsistentFakeEmbeddings(dimensionality=DIMENSIONS), connection_string=CONN_STRING, collection_name=COLLECTION, database_name=DATABASE, index_name=INDEX_NAME, score_threshold=0.5, - wait_until_ready=15.0, + wait_until_ready=TIMEOUT, ) ) assert get_llm_cache() @@ -101,6 +133,7 @@ def test_mongodb_cache( prompt: Union[str, List[BaseMessage]], llm: Union[str, FakeLLM, FakeChatModel], response: List[Generation], + collection: Collection, ) -> None: llm_cache(cacher) if remove_score: @@ -136,6 +169,7 @@ def test_mongodb_cache( def test_mongodb_atlas_cache_matrix( prompts: List[str], generations: List[List[str]], + collection: Collection, ) -> None: llm_cache(MongoDBAtlasSemanticCache) llm = FakeLLM() diff --git a/libs/mongodb/tests/integration_tests/test_chain_example.py b/libs/mongodb/tests/integration_tests/test_chain_example.py index 1e61698..c4c2195 100644 --- a/libs/mongodb/tests/integration_tests/test_chain_example.py +++ b/libs/mongodb/tests/integration_tests/test_chain_example.py @@ -51,7 +51,8 @@ def collection() -> Collection: @pytest.mark.skipif( - os.environ.get("OPENAI_API_KEY") is not None, reason="Requires OpenAI for chat responses." + os.environ.get("OPENAI_API_KEY") is not None, + reason="Requires OpenAI for chat responses.", ) def test_chain( collection: Collection, diff --git a/libs/mongodb/tests/integration_tests/test_mmr.py b/libs/mongodb/tests/integration_tests/test_mmr.py index 22934e5..9c0d767 100644 --- a/libs/mongodb/tests/integration_tests/test_mmr.py +++ b/libs/mongodb/tests/integration_tests/test_mmr.py @@ -8,7 +8,9 @@ from langchain_core.embeddings import Embeddings from pymongo import MongoClient from pymongo.collection import Collection - +from langchain_mongodb.index import ( + create_vector_search_index, +) from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch CONNECTION_STRING = os.environ.get("MONGODB_URI") @@ -20,8 +22,27 @@ @pytest.fixture() def collection() -> Collection: - test_client: MongoClient = MongoClient(CONNECTION_STRING) - return test_client[DB_NAME][COLLECTION_NAME] + client: MongoClient = MongoClient(CONNECTION_STRING) + + if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): + clxn = client[DB_NAME].create_collection(COLLECTION_NAME) + else: + clxn = client[DB_NAME][COLLECTION_NAME] + + clxn.delete_many({}) + + if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_vector_search_index( + collection=clxn, + index_name=INDEX_NAME, + dimensions=5, + path="embedding", + filters=["c"], + similarity="cosine", + wait_until_complete=60, + ) + + return clxn @pytest.fixture diff --git a/libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py b/libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py index 0336f1f..451dc92 100644 --- a/libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py +++ b/libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py @@ -10,7 +10,9 @@ from langchain_core.embeddings import Embeddings from pymongo import MongoClient from pymongo.collection import Collection - +from langchain_mongodb.index import ( + create_vector_search_index, +) from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch CONNECTION_STRING = os.environ.get("MONGODB_URI") @@ -21,12 +23,28 @@ @pytest.fixture(scope="module") -def collection() -> Generator[Collection, None, None]: - test_client: MongoClient = MongoClient(CONNECTION_STRING) - clxn = test_client[DB_NAME][COLLECTION_NAME] - yield clxn +def collection() -> Collection: + client: MongoClient = MongoClient(CONNECTION_STRING) + + if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): + clxn = client[DB_NAME].create_collection(COLLECTION_NAME) + else: + clxn = client[DB_NAME][COLLECTION_NAME] + clxn.delete_many({}) + if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_vector_search_index( + collection=clxn, + index_name=INDEX_NAME, + dimensions=DIMENSIONS, + path="embedding", + similarity="cosine", + wait_until_complete=60, + ) + + return clxn + @pytest.fixture(scope="module") def example_documents() -> List[Document]: diff --git a/libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py b/libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py index 4dc2e9d..dad75e6 100644 --- a/libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py +++ b/libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py @@ -11,7 +11,9 @@ from pymongo.collection import Collection from langchain_mongodb import MongoDBAtlasVectorSearch - +from langchain_mongodb.index import ( + create_vector_search_index, +) from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch CONNECTION_STRING = os.environ.get("MONGODB_URI") @@ -23,8 +25,27 @@ @pytest.fixture(scope="module") def collection() -> Collection: - test_client: MongoClient = MongoClient(CONNECTION_STRING) - return test_client[DB_NAME][COLLECTION_NAME] + client: MongoClient = MongoClient(CONNECTION_STRING) + + if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): + clxn = client[DB_NAME].create_collection(COLLECTION_NAME) + else: + clxn = client[DB_NAME][COLLECTION_NAME] + + clxn.delete_many({}) + + if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_vector_search_index( + collection=clxn, + index_name=INDEX_NAME, + dimensions=DIMENSIONS, + path="embedding", + filters=['c'], + similarity="cosine", + wait_until_complete=60, + ) + + return clxn @pytest.fixture(scope="module")