Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added example for VectorCypherRetriever #12

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/vector_cypher_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from neo4j import GraphDatabase
from neo4j_genai import VectorCypherRetriever

import random
import string
from neo4j_genai.embedder import Embedder
from neo4j_genai.indexes import create_vector_index


# Vector index settings shared by index creation, the embedder and the retriever
DIMENSION = 1536
INDEX_NAME = "embedding-name"

# Connection settings for the Neo4j instance used by this example
AUTH = ("neo4j", "password")
URI = "neo4j://localhost:7687"

# Open the driver that every later step of the script reuses
driver = GraphDatabase.driver(URI, auth=AUTH)


# Create Embedder object
class CustomEmbedder(Embedder):
    """Placeholder embedder for this example that returns random vectors.

    The input text is ignored: every call yields ``DIMENSION`` values drawn
    from ``random.random()``, so search results are not semantically
    meaningful.
    """

    def embed_query(self, text: str) -> list[float]:
        """Return a random vector of length ``DIMENSION``.

        Args:
            text: The query text. Unused by this placeholder implementation.

        Returns:
            A list of ``DIMENSION`` floats, each in the range [0.0, 1.0).
        """
        return [random.random() for _ in range(DIMENSION)]


# Generate random strings
def random_str(n: int) -> str:
    """Return a string of ``n`` randomly chosen ASCII letters.

    Args:
        n: Number of characters to generate. Values of zero or less
            produce an empty string.

    Returns:
        A string made only of characters from ``string.ascii_letters``.
    """
    return "".join(random.choice(string.ascii_letters) for _ in range(n))
willtai marked this conversation as resolved.
Show resolved Hide resolved


# Create the vector index over Document.propertyKey that the retriever queries
create_vector_index(
    driver,
    INDEX_NAME,
    label="Document",
    property="propertyKey",
    dimensions=DIMENSION,
    similarity_fn="euclidean",
)

# Instantiate the example embedder used to embed the search text
embedder = CustomEmbedder()

# Initialize the retriever
# Cypher fragment for the retriever; `node` is the vector-search match.
# The clauses are separated by an explicit space: adjacent string literals
# concatenate with no separator, which previously fused ")" and "RETURN".
retrieval_query = "MATCH (node)-[:AUTHORED_BY]->(author:Author) RETURN author.name"
willtai marked this conversation as resolved.
Show resolved Hide resolved
try:
    retriever = VectorCypherRetriever(driver, INDEX_NAME, retrieval_query, embedder)

    # Upsert one Document node carrying a random vector and link it to a
    # randomly named Author, so the search below has something to return.
    vector = [random.random() for _ in range(DIMENSION)]
    # Each fragment ends with a space: adjacent literals are concatenated
    # without any separator, so clauses would otherwise run together.
    insert_query = (
        "MERGE (doc:Document {id: $id}) "
        "WITH doc "
        "CALL db.create.setNodeVectorProperty(doc, 'propertyKey', $vector) "
        "WITH doc "
        "MERGE (author:Author {name: $authorName}) "
        "MERGE (doc)-[:AUTHORED_BY]->(author) "
        "RETURN doc, author"
    )
    parameters = {
        "id": random.randint(0, 10000),
        "vector": vector,
        "authorName": random_str(10),
    }
    driver.execute_query(insert_query, parameters)

    # Perform the search
    query_text = "Find me the closest text"
    print(retriever.search(query_text=query_text, top_k=1))
finally:
    # Release the driver's connections even if a step above fails.
    driver.close()
11 changes: 9 additions & 2 deletions src/neo4j_genai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,14 @@ def search(

class VectorCypherRetriever(VectorRetriever):
"""
Provides retrieval method using vector similarity and custom Cypher query
Provides retrieval method using vector similarity and custom Cypher query.
When providing the custom query, note that the existing variable `node` can be used.
The query prefix:
```
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
```

"""

def __init__(
Expand Down Expand Up @@ -167,7 +174,7 @@ def search(
ValueError: If no embedder is provided.

Returns:
Any: The results of the search query
list[Neo4jRecord]: The results of the search query
"""
try:
validated_data = VectorCypherSearchModel(
Expand Down