Skip to content

Commit

Permalink
Added example for VectorCypherRetriever
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Apr 19, 2024
1 parent de74a0f commit 53aa093
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
67 changes: 67 additions & 0 deletions examples/vector_cypher_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from neo4j import GraphDatabase
from neo4j_genai import VectorCypherRetriever

import random
import string
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index


URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

INDEX_NAME = "embedding-name"
DIMENSION = 1536

# Connect to Neo4j database
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
class CustomEmbedder(Embedder):
def embed_query(self, text: str) -> list[float]:
return [random.random() for _ in range(DIMENSION)]


# Generate random strings
def random_str(n: int) -> str:
return "".join([random.choice(string.ascii_letters) for _ in range(n)])


embedder = CustomEmbedder()

# Creating the index
create_vector_index(
driver,
INDEX_NAME,
label="Document",
property="propertyKey",
dimensions=DIMENSION,
similarity_fn="euclidean",
)

# Initialize the retriever
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author)" "RETURN author.name"
retriever = VectorCypherRetriever(driver, INDEX_NAME, retrieval_query, embedder)

# Upsert the query
vector = [random.random() for _ in range(DIMENSION)]
insert_query = (
"MERGE (doc:Document {id: $id})"
"WITH doc "
"CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector)"
"WITH doc "
"MERGE (author:Author {name: $authorName})"
"MERGE (doc)-[:AUTHORED_BY]->(author)"
"RETURN doc, author"
)
parameters = {
"id": random.randint(0, 10000),
"vector": vector,
"authorName": random_str(10),
}
driver.execute_query(insert_query, parameters)

# Perform the search
query_text = "Find me the closest text"
print(retriever.search(query_text=query_text, top_k=1))
2 changes: 1 addition & 1 deletion src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def search(
ValueError: If no embedder is provided.
Returns:
Any: The results of the search query
list[Neo4jRecord]: The results of the search query
"""
try:
validated_data = VectorCypherSearchModel(
Expand Down

0 comments on commit 53aa093

Please sign in to comment.