Skip to content

Commit

Permalink
Moved index_name definition to constructor level of VectorRetriever
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 5, 2024
1 parent 6e4b3fe commit 8e1e641
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 20 deletions.
5 changes: 3 additions & 2 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ class VectorRetriever:
def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
) -> None:
self.driver = driver
self._verify_version()
self.index_name = index_name
self.embedder = embedder

def _verify_version(self) -> None:
Expand Down Expand Up @@ -48,7 +50,6 @@ def _verify_version(self) -> None:

def search(
self,
name: str,
query_vector: Optional[List[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
Expand All @@ -74,7 +75,7 @@ def search(
"""
try:
validated_data = SimilaritySearchModel(
index_name=name,
index_name=self.index_name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ def driver():
@pytest.fixture
@patch("neo4j_genai.VectorRetriever._verify_version")
def retriever(_verify_version_mock, driver):
return VectorRetriever(driver)
return VectorRetriever(driver, "my-index")
31 changes: 14 additions & 17 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,29 @@
def test_vector_retriever_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None]

VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")


def test_vector_retriever_no_supported_aura_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None]

with pytest.raises(ValueError) as excinfo:
VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")

assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)


def test_vector_retriever_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None]

VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")


def test_vector_retriever_no_supported_version(driver):
driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None]

with pytest.raises(ValueError) as excinfo:
VectorRetriever(driver=driver)
VectorRetriever(driver=driver, index_name="my-index")

assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo)

Expand All @@ -38,13 +38,13 @@ def test_vector_retriever_no_supported_version(driver):
def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
custom_embeddings = MagicMock()

retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)

retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
Expand All @@ -55,7 +55,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = retriever.search(name=index_name, query_vector=query_vector, top_k=top_k)
records = retriever.search(query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

Expand All @@ -77,12 +77,12 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector

retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)

driver.execute_query.return_value = [
[{"node": "dummy-node", "score": 1.0}],
None,
Expand All @@ -94,7 +94,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
YIELD node, score
"""

records = retriever.search(name=index_name, query_text=query_text, top_k=top_k)
records = retriever.search(query_text=query_text, top_k=top_k)

custom_embeddings.embed_query.assert_called_once_with(query_text)

Expand All @@ -111,16 +111,14 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):


def test_similarity_search_missing_embedder_for_text(retriever):
index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

with pytest.raises(ValueError, match="Embedding method required for text query"):
retriever.search(name=index_name, query_text=query_text, top_k=top_k)
retriever.search(query_text=query_text, top_k=top_k)


def test_similarity_search_both_text_and_vector(retriever):
index_name = "my-index"
query_text = "may thy knife chip and shatter"
query_vector = [1.1, 2.2, 3.3]
top_k = 5
Expand All @@ -129,7 +127,6 @@ def test_similarity_search_both_text_and_vector(retriever):
ValueError, match="You must provide exactly one of query_vector or query_text."
):
retriever.search(
name=index_name,
query_text=query_text,
query_vector=query_vector,
top_k=top_k,
Expand All @@ -140,13 +137,13 @@ def test_similarity_search_both_text_and_vector(retriever):
def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
custom_embeddings = MagicMock()

retriever = VectorRetriever(driver, custom_embeddings)

index_name = "my-index"
dimensions = 1536
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

retriever = VectorRetriever(driver, index_name, custom_embeddings)

retriever.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": "adsa"}],
None,
Expand All @@ -158,7 +155,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
"""

with pytest.raises(ValueError):
retriever.search(name=index_name, query_vector=query_vector, top_k=top_k)
retriever.search(query_vector=query_vector, top_k=top_k)

custom_embeddings.embed_query.assert_not_called()

Expand Down

0 comments on commit 8e1e641

Please sign in to comment.