Skip to content

Commit

Permalink
Merge branch 'main' into willtai/add-cla-check
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai authored Apr 23, 2024
2 parents 4880034 + 944c4e9 commit dad8e8a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 36 deletions.
2 changes: 1 addition & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
# Create CustomEmbedder object with the required Embedder type
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random() for _ in range(DIMENSION)]
Expand Down
72 changes: 42 additions & 30 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod, ABC
from typing import Optional, Any
from pydantic import ValidationError
from neo4j import Driver
from neo4j import Driver, Record
from .embedder import Embedder
from .types import SimilaritySearchModel, Neo4jRecord, VectorCypherSearchModel
from .types import (
SimilaritySearchModel,
VectorSearchRecord,
VectorCypherSearchModel,
)


class VectorRetriever:
class Retriever(ABC):
"""
Provides retrieval methods using vector search over embeddings
Abstract class for Neo4j retrievers
"""

def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
def __init__(self, driver: Driver):
self.driver = driver
self._verify_version()
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder

def _verify_version(self) -> None:
"""
Expand Down Expand Up @@ -65,12 +59,36 @@ def _verify_version(self) -> None:
"This package only supports Neo4j version 5.18.1 or greater"
)

@abstractmethod
def search(self, *args, **kwargs) -> Any:
pass


class VectorRetriever(Retriever):
"""
Provides retrieval method using vector search over embeddings.
If an embedder is provided, it needs to have the required Embedder type.
"""

def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> list[Neo4jRecord]:
) -> list[VectorSearchRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -87,7 +105,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
"""
try:
validated_data = SimilaritySearchModel(
Expand Down Expand Up @@ -126,7 +144,7 @@ def search(

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
VectorSearchRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
Expand All @@ -136,16 +154,10 @@ def search(
)


class VectorCypherRetriever(VectorRetriever):
class VectorCypherRetriever(Retriever):
"""
Provides retrieval method using vector similarity and custom Cypher query.
When providing the custom query, note that the existing variable `node` can be used.
The query prefix:
```
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
```
If an embedder is provided, it needs to have the required Embedder type.
"""

def __init__(
Expand All @@ -155,7 +167,7 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
self.driver = driver
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.retrieval_query = retrieval_query
Expand All @@ -167,7 +179,7 @@ def search(
query_text: Optional[str] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
) -> list[Neo4jRecord]:
) -> list[Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -185,7 +197,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Neo4jRecord]: The results of the search query
list[Record]: The results of the search query
"""
try:
validated_data = VectorCypherSearchModel(
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from neo4j import Driver


class Neo4jRecord(BaseModel):
class VectorSearchRecord(BaseModel):
node: Any
score: float

Expand Down
8 changes: 4 additions & 4 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from neo4j_genai import VectorRetriever
from neo4j_genai.retrievers import VectorCypherRetriever
from neo4j_genai.types import Neo4jRecord
from neo4j_genai.types import VectorSearchRecord


def test_vector_retriever_supported_aura_version(driver):
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.VectorRetriever._verify_version")
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.VectorRetriever._verify_version")
Expand Down Expand Up @@ -169,7 +169,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


def test_vector_retriever_search_missing_embedder_for_text(vector_retriever):
Expand Down

0 comments on commit dad8e8a

Please sign in to comment.