Skip to content

Commit

Permalink
INTPYTHON-449 Fix langchain local atlas usage
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Dec 13, 2024
1 parent 5ede100 commit 5820d15
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 14 deletions.
38 changes: 36 additions & 2 deletions libs/mongodb/tests/integration_tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, Generation, LLMResult

from pymongo.collection import Collection
from pymongo import MongoClient
from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
from langchain_mongodb.index import (
create_vector_search_index,
)

from ..utils import ConsistentFakeEmbeddings, FakeChatModel, FakeLLM

Expand All @@ -18,21 +23,48 @@
DATABASE = "langchain_test_db"
COLLECTION = "langchain_test_cache"

DIMENSIONS = 1536 # Meets OpenAI model
TIMEOUT = 60.0

def random_string() -> str:
return str(uuid.uuid4())


@pytest.fixture(scope="module")
def collection() -> Collection:
"""A Collection with both a Vector and a Full-text Search Index"""
client: MongoClient = MongoClient(CONN_STRING)
if COLLECTION not in client[DATABASE].list_collection_names():
clxn = client[DATABASE].create_collection(COLLECTION)
else:
clxn = client[DATABASE][COLLECTION]

clxn.delete_many({})

if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
create_vector_search_index(
collection=clxn,
index_name=INDEX_NAME,
dimensions=DIMENSIONS,
path="embedding",
filters=['llm_string'],
similarity="cosine",
wait_until_complete=TIMEOUT,
)

return clxn


def llm_cache(cls: Any) -> BaseCache:
set_llm_cache(
cls(
embedding=ConsistentFakeEmbeddings(dimensionality=1536),
embedding=ConsistentFakeEmbeddings(dimensionality=DIMENSIONS),
connection_string=CONN_STRING,
collection_name=COLLECTION,
database_name=DATABASE,
index_name=INDEX_NAME,
score_threshold=0.5,
wait_until_ready=15.0,
wait_until_ready=TIMEOUT,
)
)
assert get_llm_cache()
Expand Down Expand Up @@ -101,6 +133,7 @@ def test_mongodb_cache(
prompt: Union[str, List[BaseMessage]],
llm: Union[str, FakeLLM, FakeChatModel],
response: List[Generation],
collection: Collection,
) -> None:
llm_cache(cacher)
if remove_score:
Expand Down Expand Up @@ -136,6 +169,7 @@ def test_mongodb_cache(
def test_mongodb_atlas_cache_matrix(
prompts: List[str],
generations: List[List[str]],
collection: Collection,
) -> None:
llm_cache(MongoDBAtlasSemanticCache)
llm = FakeLLM()
Expand Down
3 changes: 2 additions & 1 deletion libs/mongodb/tests/integration_tests/test_chain_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def collection() -> Collection:


@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is not None, reason="Requires OpenAI for chat responses."
os.environ.get("OPENAI_API_KEY") is not None,
reason="Requires OpenAI for chat responses.",
)
def test_chain(
collection: Collection,
Expand Down
27 changes: 24 additions & 3 deletions libs/mongodb/tests/integration_tests/test_mmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection

from langchain_mongodb.index import (
create_vector_search_index,
)
from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")

Check failure on line 16 in libs/mongodb/tests/integration_tests/test_mmr.py

View workflow job for this annotation

GitHub Actions / cd libs/mongodb / make lint #3.9

Ruff (I001)

tests/integration_tests/test_mmr.py:3:1: I001 Import block is un-sorted or un-formatted

Check failure on line 16 in libs/mongodb/tests/integration_tests/test_mmr.py

View workflow job for this annotation

GitHub Actions / cd libs/mongodb / make lint #3.12

Ruff (I001)

tests/integration_tests/test_mmr.py:3:1: I001 Import block is un-sorted or un-formatted
Expand All @@ -20,8 +22,27 @@

@pytest.fixture()
def collection() -> Collection:
test_client: MongoClient = MongoClient(CONNECTION_STRING)
return test_client[DB_NAME][COLLECTION_NAME]
client: MongoClient = MongoClient(CONNECTION_STRING)

if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
clxn = client[DB_NAME][COLLECTION_NAME]

clxn.delete_many({})

if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
create_vector_search_index(
collection=clxn,
index_name=INDEX_NAME,
dimensions=5,
path="embedding",
filters=["c"],
similarity="cosine",
wait_until_complete=60,
)

return clxn


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from langchain_core.embeddings import Embeddings
from pymongo import MongoClient
from pymongo.collection import Collection

from langchain_mongodb.index import (
create_vector_search_index,
)
from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")

Check failure on line 18 in libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py

View workflow job for this annotation

GitHub Actions / cd libs/mongodb / make lint #3.9

Ruff (I001)

tests/integration_tests/test_vectorstore_from_documents.py:3:1: I001 Import block is un-sorted or un-formatted

Check failure on line 18 in libs/mongodb/tests/integration_tests/test_vectorstore_from_documents.py

View workflow job for this annotation

GitHub Actions / cd libs/mongodb / make lint #3.12

Ruff (I001)

tests/integration_tests/test_vectorstore_from_documents.py:3:1: I001 Import block is un-sorted or un-formatted
Expand All @@ -21,12 +23,28 @@


@pytest.fixture(scope="module")
def collection() -> Generator[Collection, None, None]:
test_client: MongoClient = MongoClient(CONNECTION_STRING)
clxn = test_client[DB_NAME][COLLECTION_NAME]
yield clxn
def collection() -> Collection:
client: MongoClient = MongoClient(CONNECTION_STRING)

if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
clxn = client[DB_NAME][COLLECTION_NAME]

clxn.delete_many({})

if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
create_vector_search_index(
collection=clxn,
index_name=INDEX_NAME,
dimensions=DIMENSIONS,
path="embedding",
similarity="cosine",
wait_until_complete=60,
)

return clxn


@pytest.fixture(scope="module")
def example_documents() -> List[Document]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from pymongo.collection import Collection

from langchain_mongodb import MongoDBAtlasVectorSearch

from langchain_mongodb.index import (
create_vector_search_index,
)
from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch

CONNECTION_STRING = os.environ.get("MONGODB_URI")

Check failure on line 19 in libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py

View workflow job for this annotation

GitHub Actions / cd libs/mongodb / make lint #3.9

Ruff (I001)

tests/integration_tests/test_vectorstore_from_texts.py:3:1: I001 Import block is un-sorted or un-formatted

Check failure on line 19 in libs/mongodb/tests/integration_tests/test_vectorstore_from_texts.py

View workflow job for this annotation

GitHub Actions / cd libs/mongodb / make lint #3.12

Ruff (I001)

tests/integration_tests/test_vectorstore_from_texts.py:3:1: I001 Import block is un-sorted or un-formatted
Expand All @@ -23,8 +25,27 @@

@pytest.fixture(scope="module")
def collection() -> Collection:
test_client: MongoClient = MongoClient(CONNECTION_STRING)
return test_client[DB_NAME][COLLECTION_NAME]
client: MongoClient = MongoClient(CONNECTION_STRING)

if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
else:
clxn = client[DB_NAME][COLLECTION_NAME]

clxn.delete_many({})

if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
create_vector_search_index(
collection=clxn,
index_name=INDEX_NAME,
dimensions=DIMENSIONS,
path="embedding",
filters=['c'],
similarity="cosine",
wait_until_complete=60,
)

return clxn


@pytest.fixture(scope="module")
Expand Down

0 comments on commit 5820d15

Please sign in to comment.