Skip to content

Commit

Permalink
Retriever base type (#16)
Browse files Browse the repository at this point in the history
* Refactored Retriever object

* Added comment about Embedder type in docstring
  • Loading branch information
willtai committed Apr 23, 2024
1 parent 4880034 commit 978a8a3
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cla-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ jobs:
owner=$(echo "$GITHUB_REPOSITORY" | cut -d/ -f1)
repository=$(echo "$GITHUB_REPOSITORY" | cut -d/ -f2)
./bin/examine-pull-request "$owner" "$repository" "${{ secrets.NEO4J_TEAM_GRAPHQL_PERSONAL_ACCESS_TOKEN }}" "$PULL_REQUEST_NUMBER" cla-database.csv
./bin/examine-pull-request "$owner" "$repository" "${{ secrets.NEO4J_TEAM_GENAI_PERSONAL_ACCESS_TOKEN }}" "$PULL_REQUEST_NUMBER" cla-database.csv
env:
PULL_REQUEST_NUMBER: ${{ github.event.number }}
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,16 @@ and/or [Discord](https://discord.gg/neo4j).

### Make changes

1. Fork the respository.
2. Install Node.js and Yarn. For more information, see [the development guide](./docs/contributing/DEVELOPING.md).
3. Create a working branch from `dev` and start with your changes!
1. Fork the repository.
2. Install Python and Poetry. For more information, see [the development guide](./docs/contributing/DEVELOPING.md).
3. Create a working branch from `main` and start with your changes!

### Pull request

When you're finished with your changes, create a pull request, also known as a PR.

* Ensure that you have [signed the CLA](https://neo4j.com/developer/contributing-code/#sign-cla).
* Ensure that the base of your PR is set to `main`.
* Fill out the template so that we can easily review your PR. The template helps
reviewers understand your changes as well as the purpose of the pull request.
* Don't forget to [link your PR to an issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue)
if you are solving one.
* Enable the checkbox to [allow maintainer edits](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/allowing-changes-to-a-pull-request-branch-created-from-a-fork)
Expand Down
2 changes: 1 addition & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
# Create CustomEmbedder object with the required Embedder type
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random() for _ in range(DIMENSION)]
Expand Down
72 changes: 42 additions & 30 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod, ABC
from typing import Optional, Any
from pydantic import ValidationError
from neo4j import Driver
from neo4j import Driver, Record
from .embedder import Embedder
from .types import SimilaritySearchModel, Neo4jRecord, VectorCypherSearchModel
from .types import (
SimilaritySearchModel,
VectorSearchRecord,
VectorCypherSearchModel,
)


class VectorRetriever:
class Retriever(ABC):
"""
Provides retrieval methods using vector search over embeddings
Abstract class for Neo4j retrievers
"""

def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
def __init__(self, driver: Driver):
self.driver = driver
self._verify_version()
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder

def _verify_version(self) -> None:
"""
Expand Down Expand Up @@ -65,12 +59,36 @@ def _verify_version(self) -> None:
"This package only supports Neo4j version 5.18.1 or greater"
)

@abstractmethod
def search(self, *args, **kwargs) -> Any:
pass


class VectorRetriever(Retriever):
"""
Provides retrieval method using vector search over embeddings.
If an embedder is provided, it needs to have the required Embedder type.
"""

def __init__(
self,
driver: Driver,
index_name: str,
embedder: Optional[Embedder] = None,
return_properties: Optional[list[str]] = None,
) -> None:
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.return_properties = return_properties
self.embedder = embedder

def search(
self,
query_vector: Optional[list[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> list[Neo4jRecord]:
) -> list[VectorSearchRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -87,7 +105,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
"""
try:
validated_data = SimilaritySearchModel(
Expand Down Expand Up @@ -126,7 +144,7 @@ def search(

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
VectorSearchRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
Expand All @@ -136,16 +154,10 @@ def search(
)


class VectorCypherRetriever(VectorRetriever):
class VectorCypherRetriever(Retriever):
"""
Provides retrieval method using vector similarity and custom Cypher query.
When providing the custom query, note that the existing variable `node` can be used.
The query prefix:
```
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
```
If an embedder is provided, it needs to have the required Embedder type.
"""

def __init__(
Expand All @@ -155,7 +167,7 @@ def __init__(
retrieval_query: str,
embedder: Optional[Embedder] = None,
) -> None:
self.driver = driver
super().__init__(driver)
self._verify_version()
self.index_name = index_name
self.retrieval_query = retrieval_query
Expand All @@ -167,7 +179,7 @@ def search(
query_text: Optional[str] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
) -> list[Neo4jRecord]:
) -> list[Record]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -185,7 +197,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
list[Neo4jRecord]: The results of the search query
list[Record]: The results of the search query
"""
try:
validated_data = VectorCypherSearchModel(
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from neo4j import Driver


class Neo4jRecord(BaseModel):
class VectorSearchRecord(BaseModel):
node: Any
score: float

Expand Down
8 changes: 4 additions & 4 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from neo4j_genai import VectorRetriever
from neo4j_genai.retrievers import VectorCypherRetriever
from neo4j_genai.types import Neo4jRecord
from neo4j_genai.types import VectorSearchRecord


def test_vector_retriever_supported_aura_version(driver):
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.VectorRetriever._verify_version")
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


@patch("neo4j_genai.VectorRetriever._verify_version")
Expand Down Expand Up @@ -169,7 +169,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [VectorSearchRecord(node="dummy-node", score=1.0)]


def test_vector_retriever_search_missing_embedder_for_text(vector_retriever):
Expand Down

0 comments on commit 978a8a3

Please sign in to comment.