diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index d3f75670..e282b7d5 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -3,7 +3,7 @@ from random import random from neo4j_genai.embedder import Embedder -from neo4j_genai.indexes import create_vector_index, drop_index +from neo4j_genai.indexes import create_vector_index URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") @@ -15,7 +15,7 @@ driver = GraphDatabase.driver(URI, auth=AUTH) -# Create Embedder object +# Create CustomEmbedder object with the required Embedder type class CustomEmbedder(Embedder): def embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] @@ -52,4 +52,4 @@ def embed_query(self, text: str) -> list[float]: # Perform the similarity search for a text query query_text = "hello world" -# print(retriever.search(query_text=query_text, top_k=5)) +print(retriever.search(query_text=query_text, top_k=5)) diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index d3b008b4..9ce9d6dc 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -66,7 +66,8 @@ def search(self, *args, **kwargs) -> Any: class VectorRetriever(Retriever): """ - Provides retrieval method using vector search over embeddings + Provides retrieval method using vector search over embeddings. + If an embedder is provided, it needs to have the required Embedder type. """ def __init__( @@ -80,11 +81,6 @@ def __init__( self._verify_version() self.index_name = index_name self.return_properties = return_properties - - if embedder and not isinstance(embedder, Embedder): - raise TypeError( - "Provided 'embedder' must be an instance of Embedder with an 'embed_query' method." - ) self.embedder = embedder def search( @@ -160,7 +156,8 @@ def search( class VectorCypherRetriever(Retriever): """ - Provides retrieval method using vector similarity and custom Cypher query + Provides retrieval method using vector similarity and custom Cypher query. + If an embedder is provided, it needs to have the required Embedder type. """ def __init__( @@ -174,11 +171,6 @@ def __init__( self._verify_version() self.index_name = index_name self.retrieval_query = retrieval_query - - if embedder and not isinstance(embedder, Embedder): - raise TypeError( - "Provided 'embedder' must be an instance of Embedder with an 'embed_query' method." - ) self.embedder = embedder def search(