Skip to content

Commit

Permalink
Text2Cypher custom prompt: doc, example and bug fix (#229)
Browse files Browse the repository at this point in the history
* Doc + bug fix

* Do not change the behavior, just document they said

* Use same order for patched functions and check order of mocked object
  • Loading branch information
stellasia authored Dec 12, 2024
1 parent 140a057 commit c33f9c8
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 28 deletions.
9 changes: 9 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,21 @@ RagTemplate

.. autoclass:: neo4j_graphrag.generation.prompts.RagTemplate
:members:
:exclude-members: format

ERExtractionTemplate
--------------------

.. autoclass:: neo4j_graphrag.generation.prompts.ERExtractionTemplate
:members:
:exclude-members: format

Text2CypherTemplate
--------------------

.. autoclass:: neo4j_graphrag.generation.prompts.Text2CypherTemplate
:members:
:exclude-members: format


****
Expand Down
5 changes: 3 additions & 2 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ are listed in [the last section of this file](#customize).

- [Control result format for VectorRetriever](customize/retrievers/result_formatter_vector_retriever.py)
- [Control result format for VectorCypherRetriever](customize/retrievers/result_formatter_vector_cypher_retriever.py)

- [Use pre-filters](customize/retrievers/use_pre_filters.py)
- [Text2Cypher: use a custom prompt](customize/retrievers/text2cypher_custom_prompt.py)

### LLMs

Expand All @@ -74,7 +75,7 @@ are listed in [the last section of this file](#customize).

### Prompts

- [Using a custom prompt](old/graphrag_custom_prompt.py)
- [Using a custom prompt for RAG](customize/answer/custom_prompt.py)


### Embedders
Expand Down
76 changes: 76 additions & 0 deletions examples/customize/retrievers/text2cypher_custom_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""The example shows how to provide a custom prompt to Text2CypherRetriever.
Example using the OpenAILLM, hence the OPENAI_API_KEY needs to be set in the
environment for this example to run.
"""

import neo4j
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.retrievers import Text2CypherRetriever
from neo4j_graphrag.schema import get_schema

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"

# Create LLM object
llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

# (Optional) Specify your own Neo4j schema
# (also see get_structured_schema and get_schema functions)
neo4j_schema = """
Node properties:
User {name: STRING}
Person {name: STRING, born: INTEGER}
Movie {tagline: STRING, title: STRING, released: INTEGER}
Relationship properties:
ACTED_IN {roles: LIST}
DIRECTED {}
REVIEWED {summary: STRING, rating: INTEGER}
The relationships:
(:Person)-[:ACTED_IN]->(:Movie)
(:Person)-[:DIRECTED]->(:Movie)
(:User)-[:REVIEWED]->(:Movie)
"""

prompt = """Task: Generate a Cypher statement for querying a Neo4j graph database from a user input.
Do not use any properties or relationships not included in the schema.
Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response.
Always filter movies that have not already been reviewed by the user with name: '{user_name}' using for instance:
(m:Movie)<-[:REVIEWED]-(:User {{name: <the_user_name>}})
Schema:
{schema}
Input:
{query_text}
Cypher query:
"""

with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = Text2CypherRetriever(
driver=driver,
llm=llm,
neo4j_schema=neo4j_schema,
# here we provide a custom prompt
custom_prompt=prompt,
neo4j_database=DATABASE,
)

# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
query_text = "Which movies did Hugo Weaving star in?"
print(
retriever.search(
query_text=query_text,
prompt_params={
# you have to specify all placeholder except the {query_text} one
"schema": get_schema(driver),
"user_name": "the user asking question",
},
)
)
4 changes: 1 addition & 3 deletions examples/retrieve/text2cypher_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
Movie {tagline: STRING, title: STRING, released: INTEGER}
Relationship properties:
ACTED_IN {roles: LIST}
DIRECTED {}
REVIEWED {summary: STRING, rating: INTEGER}
The relationships:
(:Person)-[:ACTED_IN]->(:Movie)
(:Person)-[:DIRECTED]->(:Movie)
(:Person)-[:PRODUCED]->(:Movie)
(:Person)-[:WROTE]->(:Movie)
(:Person)-[:FOLLOWS]->(:Person)
(:Person)-[:REVIEWED]->(:Movie)
"""

Expand Down
38 changes: 18 additions & 20 deletions src/neo4j_graphrag/retrievers/text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Text2CypherRetriever(Retriever):
"""
Allows for the retrieval of records from a Neo4j database using natural language.
Converts a user's natural language query to a Cypher query using an LLM,
then retrieves records from a Neo4j database using the generated Cypher query
then retrieves records from a Neo4j database using the generated Cypher query.
Args:
driver (neo4j.Driver): The Neo4j Python driver.
Expand Down Expand Up @@ -98,23 +98,23 @@ def __init__(
self.examples = validated_data.examples
self.result_formatter = validated_data.result_formatter
self.custom_prompt = validated_data.custom_prompt
try:
if validated_data.custom_prompt:
neo4j_schema = ""
else:
if (
not validated_data.custom_prompt
): # don't need schema for a custom prompt
self.neo4j_schema = (
validated_data.neo4j_schema_model.neo4j_schema
if validated_data.neo4j_schema_model
else get_schema(validated_data.driver_model.driver)
)
validated_data.neo4j_schema_model
and validated_data.neo4j_schema_model.neo4j_schema
):
neo4j_schema = validated_data.neo4j_schema_model.neo4j_schema
else:
self.neo4j_schema = ""

except (Neo4jError, DriverError) as e:
error_message = getattr(e, "message", str(e))
raise SchemaFetchError(
f"Failed to fetch schema for Text2CypherRetriever: {error_message}"
) from e
try:
neo4j_schema = get_schema(validated_data.driver_model.driver)
except (Neo4jError, DriverError) as e:
error_message = getattr(e, "message", str(e))
raise SchemaFetchError(
f"Failed to fetch schema for Text2CypherRetriever: {error_message}"
) from e
self.neo4j_schema = neo4j_schema

def get_search_results(
self, query_text: str, prompt_params: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -142,12 +142,10 @@ def get_search_results(

if prompt_params is not None:
# parse the schema and examples inputs
examples_to_use = prompt_params.get("examples") or (
examples_to_use = prompt_params.pop("examples", None) or (
"\n".join(self.examples) if self.examples else ""
)
schema_to_use = prompt_params.get("schema") or self.neo4j_schema
prompt_params.pop("examples", None)
prompt_params.pop("schema", None)
schema_to_use = prompt_params.pop("schema", None) or self.neo4j_schema
else:
examples_to_use = "\n".join(self.examples) if self.examples else ""
schema_to_use = self.neo4j_schema
Expand Down
35 changes: 32 additions & 3 deletions tests/unit/retrievers/test_text2cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from neo4j.exceptions import CypherSyntaxError, Neo4jError
from neo4j_graphrag.exceptions import (
RetrieverInitializationError,
SchemaFetchError,
SearchValidationError,
Text2CypherRetrievalError,
)
Expand All @@ -39,8 +40,8 @@ def test_t2c_retriever_initialization(driver: MagicMock, llm: MagicMock) -> None
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_schema_retrieval(
_verify_version_mock: MagicMock,
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
) -> None:
Expand All @@ -51,13 +52,13 @@ def test_t2c_retriever_schema_retrieval(
@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_schema_retrieval_failure(
_verify_version_mock: MagicMock,
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
) -> None:
get_schema_mock.side_effect = Neo4jError
with pytest.raises(Neo4jError):
with pytest.raises(SchemaFetchError):
Text2CypherRetriever(driver, llm)


Expand Down Expand Up @@ -310,3 +311,31 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
llm.invoke.assert_called_once_with(
"""This is a custom prompt. test ['example A', 'example B']"""
)


@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version")
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
def test_t2c_retriever_with_custom_prompt_and_schema(
get_schema_mock: MagicMock,
_verify_version_mock: MagicMock,
driver: MagicMock,
llm: MagicMock,
neo4j_record: MagicMock,
) -> None:
prompt = "This is a custom prompt. {query_text} {schema}"
query = "test"

driver.execute_query.return_value = (
[neo4j_record],
None,
None,
)

retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt)
retriever.search(
query_text=query,
prompt_params={},
)

get_schema_mock.assert_not_called()
llm.invoke.assert_called_once_with("""This is a custom prompt. test """)

0 comments on commit c33f9c8

Please sign in to comment.