Skip to content

Commit

Permalink
Cache calls to OpenAI during tests (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsm207 authored Feb 8, 2024
1 parent 0844c9a commit 7798b4f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ optional = true
langchain-openai = ">=0.0.3,<0.1"

[tool.ruff]
lint.select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
Expand Down
25 changes: 22 additions & 3 deletions tests/integration_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,28 @@ def weaviate_client(docker_ip, docker_services):
client.close()


@pytest.fixture(scope="session")
def embedding_openai():
    """Session-scoped fixture providing OpenAI embeddings with memoization.

    Yields a subclass of ``OpenAIEmbeddings`` that caches the results of
    ``embed_query`` and ``embed_documents`` in a per-instance dict, so
    repeated identical embedding requests during the test session hit the
    OpenAI API only once.  ``scope="session"`` keeps one shared instance
    (and therefore one shared cache) alive for the whole test run.
    """

    class MemoizedOpenAIEmbeddings(OpenAIEmbeddings):
        def __init__(self):
            super().__init__()
            # OpenAIEmbeddings is a pydantic model that restricts attribute
            # assignment, so bypass its __setattr__ to attach the cache dict.
            # NOTE(review): assumes pydantic's __setattr__ is the only guard —
            # confirm against the installed langchain-openai version.
            object.__setattr__(self, "cache", {})

        def embed_query(self, query):
            # Call the base class method only if the result is not cached.
            # str keys (queries) and tuple keys (document batches) can never
            # collide, so one dict safely serves both methods.
            if query not in self.cache:
                self.cache[query] = super().embed_query(query)
            return self.cache[query]

        def embed_documents(self, documents):
            # Lists are unhashable; use a tuple of the documents as the key.
            hashable_docs = tuple(documents)
            if hashable_docs not in self.cache:
                self.cache[hashable_docs] = super().embed_documents(documents)
            return self.cache[hashable_docs]

    yield MemoizedOpenAIEmbeddings()


@pytest.fixture
Expand Down Expand Up @@ -622,7 +641,7 @@ def test_embedding_property(weaviate_client, embedding_openai):
embedding=embedding_openai,
)

assert type(docsearch.embeddings) == OpenAIEmbeddings
assert type(docsearch.embeddings) == type(embedding_openai)


def test_documents_with_many_properties(weaviate_client, embedding_openai):
Expand Down

0 comments on commit 7798b4f

Please sign in to comment.