From 978a8a34ee631784aa15cbf53223411cd3e3e90a Mon Sep 17 00:00:00 2001 From: willtai Date: Tue, 23 Apr 2024 09:30:21 +0100 Subject: [PATCH] Retriever base type (#16) * Refactored Retriever object * Added comment about Embedder type in docstring --- .github/workflows/cla-check.yaml | 2 +- README.md | 8 ++- examples/similarity_search_for_text.py | 2 +- src/neo4j_genai/retrievers.py | 72 +++++++++++++++----------- src/neo4j_genai/types.py | 2 +- tests/test_retrievers.py | 8 +-- 6 files changed, 52 insertions(+), 42 deletions(-) diff --git a/.github/workflows/cla-check.yaml b/.github/workflows/cla-check.yaml index d9812a56e..f1189c25b 100644 --- a/.github/workflows/cla-check.yaml +++ b/.github/workflows/cla-check.yaml @@ -27,6 +27,6 @@ jobs: owner=$(echo "$GITHUB_REPOSITORY" | cut -d/ -f1) repository=$(echo "$GITHUB_REPOSITORY" | cut -d/ -f2) - ./bin/examine-pull-request "$owner" "$repository" "${{ secrets.NEO4J_TEAM_GRAPHQL_PERSONAL_ACCESS_TOKEN }}" "$PULL_REQUEST_NUMBER" cla-database.csv + ./bin/examine-pull-request "$owner" "$repository" "${{ secrets.NEO4J_TEAM_GENAI_PERSONAL_ACCESS_TOKEN }}" "$PULL_REQUEST_NUMBER" cla-database.csv env: PULL_REQUEST_NUMBER: ${{ github.event.number }} diff --git a/README.md b/README.md index b67caeeae..cc231d853 100644 --- a/README.md +++ b/README.md @@ -94,9 +94,9 @@ and/or [Discord](https://discord.gg/neo4j). ### Make changes -1. Fork the respository. -2. Install Node.js and Yarn. For more information, see [the development guide](./docs/contributing/DEVELOPING.md). -3. Create a working branch from `dev` and start with your changes! +1. Fork the repository. +2. Install Python and Poetry. For more information, see [the development guide](./docs/contributing/DEVELOPING.md). +3. Create a working branch from `main` and start with your changes! ### Pull request @@ -104,8 +104,6 @@ When you're finished with your changes, create a pull request, also known as a P * Ensure that you have [signed the CLA](https://neo4j.com/developer/contributing-code/#sign-cla). * Ensure that the base of your PR is set to `main`. -* Fill out the template so that we can easily review your PR. The template helps -reviewers understand your changes as well as the purpose of the pull request. * Don't forget to [link your PR to an issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) if you are solving one. * Enable the checkbox to [allow maintainer edits](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/allowing-changes-to-a-pull-request-branch-created-from-a-fork) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index d28aee115..e282b7d57 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -15,7 +15,7 @@ driver = GraphDatabase.driver(URI, auth=AUTH) -# Create Embedder object +# Create CustomEmbedder object with the required Embedder type class CustomEmbedder(Embedder): def embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index 2203704df..9ce9d6dc9 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -12,31 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from abc import abstractmethod, ABC from typing import Optional, Any from pydantic import ValidationError -from neo4j import Driver +from neo4j import Driver, Record from .embedder import Embedder -from .types import SimilaritySearchModel, Neo4jRecord, VectorCypherSearchModel +from .types import ( + SimilaritySearchModel, + VectorSearchRecord, + VectorCypherSearchModel, +) -class VectorRetriever: +class Retriever(ABC): """ - Provides retrieval methods using vector search over embeddings + Abstract class for Neo4j retrievers """ - def __init__( - self, - driver: Driver, - index_name: str, - embedder: Optional[Embedder] = None, - return_properties: Optional[list[str]] = None, - ) -> None: + def __init__(self, driver: Driver): self.driver = driver - self._verify_version() - self.index_name = index_name - self.return_properties = return_properties - self.embedder = embedder def _verify_version(self) -> None: """ @@ -65,12 +59,36 @@ def _verify_version(self) -> None: "This package only supports Neo4j version 5.18.1 or greater" ) + @abstractmethod + def search(self, *args, **kwargs) -> Any: + pass + + +class VectorRetriever(Retriever): + """ + Provides retrieval method using vector search over embeddings. + If an embedder is provided, it needs to have the required Embedder type. + """ + + def __init__( + self, + driver: Driver, + index_name: str, + embedder: Optional[Embedder] = None, + return_properties: Optional[list[str]] = None, + ) -> None: + super().__init__(driver) + self._verify_version() + self.index_name = index_name + self.return_properties = return_properties + self.embedder = embedder + def search( self, query_vector: Optional[list[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - ) -> list[Neo4jRecord]: + ) -> list[VectorSearchRecord]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -87,7 +105,7 @@ def search( ValueError: If no embedder is provided. Returns: - list[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores. + list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores. """ try: validated_data = SimilaritySearchModel( @@ -126,7 +144,7 @@ def search( try: return [ - Neo4jRecord(node=record["node"], score=record["score"]) + VectorSearchRecord(node=record["node"], score=record["score"]) for record in records ] except ValidationError as e: @@ -136,16 +154,10 @@ def search( ) -class VectorCypherRetriever(VectorRetriever): +class VectorCypherRetriever(Retriever): """ Provides retrieval method using vector similarity and custom Cypher query. - When providing the custom query, note that the existing variable `node` can be used. - The query prefix: - ``` - CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) - YIELD node, score - ``` - + If an embedder is provided, it needs to have the required Embedder type. """ def __init__( @@ -155,7 +167,7 @@ def __init__( retrieval_query: str, embedder: Optional[Embedder] = None, ) -> None: - self.driver = driver + super().__init__(driver) self._verify_version() self.index_name = index_name self.retrieval_query = retrieval_query @@ -167,7 +179,7 @@ def search( query_text: Optional[str] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, - ) -> list[Neo4jRecord]: + ) -> list[Record]: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -185,7 +197,7 @@ def search( ValueError: If no embedder is provided. Returns: - list[Neo4jRecord]: The results of the search query + list[Record]: The results of the search query """ try: validated_data = VectorCypherSearchModel( diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index dc65a481d..ebe8cd166 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -18,7 +18,7 @@ from neo4j import Driver -class Neo4jRecord(BaseModel): +class VectorSearchRecord(BaseModel): node: Any score: float diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index d6a6acb32..62900cb01 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -20,7 +20,7 @@ from neo4j_genai import VectorRetriever from neo4j_genai.retrievers import VectorCypherRetriever -from neo4j_genai.types import Neo4jRecord +from neo4j_genai.types import VectorSearchRecord def test_vector_retriever_supported_aura_version(driver): @@ -87,7 +87,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): }, ) - assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] @patch("neo4j_genai.VectorRetriever._verify_version") @@ -126,7 +126,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): }, ) - assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] @patch("neo4j_genai.VectorRetriever._verify_version") @@ -169,7 +169,7 @@ def test_similarity_search_text_return_properties(_verify_version_mock, driver): }, ) - assert records == [Neo4jRecord(node="dummy-node", score=1.0)] + assert records == [VectorSearchRecord(node="dummy-node", score=1.0)] def test_vector_retriever_search_missing_embedder_for_text(vector_retriever):