Skip to content

Commit

Permalink
Refactored Retriever object
Browse files (browse the repository at this point in the history)
  • Loading branch information
willtai committed Apr 22, 2024
1 parent 40a6d69 commit 2a21e0d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 38 deletions.
4 changes: 2 additions & 2 deletions examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from random import random
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index
from neo4j_genai.indexes import create_vector_index, drop_index

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
Expand Down Expand Up @@ -52,4 +52,4 @@ def embed_query(self, text: str) -> list[float]:

# Perform the similarity search for a text query
query_text = "hello world"
print(retriever.search(query_text=query_text, top_k=5))
# print(retriever.search(query_text=query_text, top_k=5))
82 changes: 51 additions & 31 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod, ABC
from typing import Optional, Any
from pydantic import ValidationError
from neo4j import Driver
from neo4j import Driver, Record
from .embedder import Embedder
from .types import SimilaritySearchModel, Neo4jRecord, VectorCypherSearchModel
from .types import (
SimilaritySearchModel,
VectorSearchRecord,
VectorCypherSearchModel,
)


class VectorRetriever:
class Retriever(ABC):
"""
Provides retrieval methods using vector search over embeddings
Abstract class for Neo4j retrievers
"""

def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
def __init__(self, driver: Driver) -> None:
    """Store the driver and run the version check.

    Args:
        driver: Driver object kept on the instance as ``self.driver``.
    """
    self.driver = driver
    # Called only after self.driver has been assigned.
    self._verify_version()
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder

def _verify_version(self) -> None:
"""
Expand Down Expand Up @@ -65,12 +59,40 @@ def _verify_version(self) -> None:
"This package only supports Neo4j version 5.18.1 or greater"
)

@abstractmethod
def search(self, *args: Any, **kwargs: Any) -> Any:
    """Abstract search entry point that each concrete retriever overrides.

    The accepted arguments and the return value are left to the
    overriding subclass; this base declaration has no behavior.
    """
    pass


class VectorRetriever(Retriever):
"""
Provides retrieval method using vector search over embeddings
"""

def __init__(
    self,
    driver: Driver,
    index_name: str,
    embedder: Optional[Embedder] = None,
    return_properties: Optional[list[str]] = None,
) -> None:
    """Set up a retriever that searches a named vector index.

    Args:
        driver: Driver object, passed on to ``Retriever.__init__``.
        index_name: Name of the index; stored as ``self.index_name``.
        embedder: Optional ``Embedder`` instance; stored as
            ``self.embedder``.
        return_properties: Optional list of strings; stored as
            ``self.return_properties``.

    Raises:
        TypeError: If ``embedder`` is given and is not an ``Embedder``.
    """
    # Retriever.__init__ assigns self.driver and already calls
    # _verify_version(), so the version check is not repeated here.
    super().__init__(driver)
    self.index_name = index_name
    self.return_properties = return_properties

    # Test against None instead of truthiness so that a falsy value of the
    # wrong type (e.g. "" or 0) is rejected rather than silently stored.
    if embedder is not None and not isinstance(embedder, Embedder):
        raise TypeError(
            "Provided 'embedder' must be an instance of Embedder with an 'embed_query' method."
        )
    self.embedder = embedder

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> list[Neo4jRecord]:
) -> list[VectorSearchRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -87,7 +109,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
"""
try:
validated_data = SimilaritySearchModel(
Expand Down Expand Up @@ -126,7 +148,7 @@ def search(

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
VectorSearchRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
Expand All @@ -136,16 +158,9 @@ def search(
)


class VectorCypherRetriever(VectorRetriever):
class VectorCypherRetriever(Retriever):
"""
Provides retrieval method using vector similarity and custom Cypher query.
When providing the custom query, note that the existing variable `node` can be used.
The query prefix:
```
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
```
Provides retrieval method using vector similarity and custom Cypher query
"""

def __init__(
Expand All @@ -155,10 +170,15 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
self.driver = driver
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.retrieval_query = retrieval_query

if embedder and not isinstance(embedder, Embedder):
raise TypeError(
"Provided 'embedder' must be an instance of Embedder with an 'embed_query' method."
)
self.embedder = embedder

def search(
Expand All @@ -167,7 +187,7 @@ def search(
query_text: Optional[str] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
) -> list[Neo4jRecord]:
) -> list[Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -185,7 +205,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Neo4jRecord]: The results of the search query
list[Record]: The results of the search query
"""
try:
validated_data = VectorCypherSearchModel(
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from neo4j import Driver


class Neo4jRecord(BaseModel):
class VectorSearchRecord(BaseModel):
node: Any
score: float

Expand Down
8 changes: 4 additions & 4 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from neo4j_genai import VectorRetriever
from neo4j_genai.retrievers import VectorCypherRetriever
from neo4j_genai.types import Neo4jRecord
from neo4j_genai.types import VectorSearchRecord


def test_vector_retriever_supported_aura_version(driver):
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.VectorRetriever._verify_version")
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.VectorRetriever._verify_version")
Expand Down Expand Up @@ -169,7 +169,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


def test_vector_retriever_search_missing_embedder_for_text(vector_retriever):
Expand Down

0 comments on commit 2a21e0d

Please sign in to comment.