diff --git a/pyproject.toml b/pyproject.toml
index 4e7474c..73bc3aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,7 @@ optional = true
 langchain-openai = ">=0.0.3,<0.1"
 
 [tool.ruff]
-select = [
+lint.select = [
     "E",  # pycodestyle
     "F",  # pyflakes
     "I",  # isort
diff --git a/tests/integration_tests/test_vectorstores.py b/tests/integration_tests/test_vectorstores.py
index 9350026..1dd9b6c 100644
--- a/tests/integration_tests/test_vectorstores.py
+++ b/tests/integration_tests/test_vectorstores.py
@@ -59,9 +59,28 @@ def weaviate_client(docker_ip, docker_services):
     client.close()
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def embedding_openai():
-    yield OpenAIEmbeddings()
+    class MemoizedOpenAIEmbeddings(OpenAIEmbeddings):
+        def __init__(self):
+            super().__init__()
+            object.__setattr__(self, "cache", {})
+
+        def embed_query(self, query):
+            if query not in self.cache:
+                # Call the base class method if result not cached
+                self.cache[query] = super().embed_query(query)
+            return self.cache[query]
+
+        def embed_documents(self, documents):
+            # Use tuple to allow documents list to be hashable
+            hashable_docs = tuple(documents)
+            if hashable_docs not in self.cache:
+                # Call the base class method if result not cached
+                self.cache[hashable_docs] = super().embed_documents(documents)
+            return self.cache[hashable_docs]
+
+    yield MemoizedOpenAIEmbeddings()
 
 
 @pytest.fixture
@@ -622,7 +641,7 @@ def test_embedding_property(weaviate_client, embedding_openai):
         embedding=embedding_openai,
     )
 
-    assert type(docsearch.embeddings) == OpenAIEmbeddings
+    assert type(docsearch.embeddings) == type(embedding_openai)
 
 
 def test_documents_with_many_properties(weaviate_client, embedding_openai):