diff --git a/docs/source/api.rst b/docs/source/api.rst index 8e52f9a4..ca0647ee 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -338,12 +338,21 @@ RagTemplate .. autoclass:: neo4j_graphrag.generation.prompts.RagTemplate :members: + :exclude-members: format ERExtractionTemplate -------------------- .. autoclass:: neo4j_graphrag.generation.prompts.ERExtractionTemplate :members: + :exclude-members: format + +Text2CypherTemplate +-------------------- + +.. autoclass:: neo4j_graphrag.generation.prompts.Text2CypherTemplate + :members: + :exclude-members: format **** diff --git a/examples/README.md b/examples/README.md index 2faed5f8..a7308660 100644 --- a/examples/README.md +++ b/examples/README.md @@ -58,7 +58,8 @@ are listed in [the last section of this file](#customize). - [Control result format for VectorRetriever](customize/retrievers/result_formatter_vector_retriever.py) - [Control result format for VectorCypherRetriever](customize/retrievers/result_formatter_vector_cypher_retriever.py) - +- [Use pre-filters](customize/retrievers/use_pre_filters.py) +- [Text2Cypher: use a custom prompt](customize/retrievers/text2cypher_custom_prompt.py) ### LLMs @@ -74,7 +75,7 @@ are listed in [the last section of this file](#customize). ### Prompts -- [Using a custom prompt](old/graphrag_custom_prompt.py) +- [Using a custom prompt for RAG](customize/answer/custom_prompt.py) ### Embedders diff --git a/examples/customize/retrievers/text2cypher_custom_prompt.py b/examples/customize/retrievers/text2cypher_custom_prompt.py new file mode 100644 index 00000000..64ab5290 --- /dev/null +++ b/examples/customize/retrievers/text2cypher_custom_prompt.py @@ -0,0 +1,76 @@ +"""The example shows how to provide a custom prompt to Text2CypherRetriever. + +Example using the OpenAILLM, hence the OPENAI_API_KEY needs to be set in the +environment for this example to run. 
+""" + +import neo4j +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.retrievers import Text2CypherRetriever +from neo4j_graphrag.schema import get_schema + +# Define database credentials +URI = "neo4j+s://demo.neo4jlabs.com" +AUTH = ("recommendations", "recommendations") +DATABASE = "recommendations" + +# Create LLM object +llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0}) + +# (Optional) Specify your own Neo4j schema +# (also see get_structured_schema and get_schema functions) +neo4j_schema = """ +Node properties: +User {name: STRING} +Person {name: STRING, born: INTEGER} +Movie {tagline: STRING, title: STRING, released: INTEGER} +Relationship properties: +ACTED_IN {roles: LIST} +DIRECTED {} +REVIEWED {summary: STRING, rating: INTEGER} +The relationships: +(:Person)-[:ACTED_IN]->(:Movie) +(:Person)-[:DIRECTED]->(:Movie) +(:User)-[:REVIEWED]->(:Movie) +""" + +prompt = """Task: Generate a Cypher statement for querying a Neo4j graph database from a user input. + +Do not use any properties or relationships not included in the schema. +Do not include triple backticks ``` or any additional text except the generated Cypher statement in your response. + +Always filter movies that have not already been reviewed by the user with name: '{user_name}' using for instance: +(m:Movie)<-[:REVIEWED]-(:User {{name: '{user_name}'}}) + +Schema: +{schema} + +Input: +{query_text} + +Cypher query: +""" + +with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver: + # Initialize the retriever + retriever = Text2CypherRetriever( + driver=driver, + llm=llm, + neo4j_schema=neo4j_schema, + # here we provide a custom prompt + custom_prompt=prompt, + neo4j_database=DATABASE, + ) + + # Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results + query_text = "Which movies did Hugo Weaving star in?" 
+ print( + retriever.search( + query_text=query_text, + prompt_params={ + # you have to specify all placeholders except the {query_text} one + "schema": get_schema(driver), + "user_name": "the user asking question", + }, + ) + ) diff --git a/examples/retrieve/text2cypher_search.py b/examples/retrieve/text2cypher_search.py index e17cb65d..deb2f592 100644 --- a/examples/retrieve/text2cypher_search.py +++ b/examples/retrieve/text2cypher_search.py @@ -22,13 +22,11 @@ Movie {tagline: STRING, title: STRING, released: INTEGER} Relationship properties: ACTED_IN {roles: LIST} +DIRECTED {} REVIEWED {summary: STRING, rating: INTEGER} The relationships: (:Person)-[:ACTED_IN]->(:Movie) (:Person)-[:DIRECTED]->(:Movie) -(:Person)-[:PRODUCED]->(:Movie) -(:Person)-[:WROTE]->(:Movie) -(:Person)-[:FOLLOWS]->(:Person) (:Person)-[:REVIEWED]->(:Movie) """ diff --git a/src/neo4j_graphrag/retrievers/text2cypher.py b/src/neo4j_graphrag/retrievers/text2cypher.py index 039f42f0..fb0a3521 100644 --- a/src/neo4j_graphrag/retrievers/text2cypher.py +++ b/src/neo4j_graphrag/retrievers/text2cypher.py @@ -48,7 +48,7 @@ class Text2CypherRetriever(Retriever): """ Allows for the retrieval of records from a Neo4j database using natural language. Converts a user's natural language query to a Cypher query using an LLM, - then retrieves records from a Neo4j database using the generated Cypher query + then retrieves records from a Neo4j database using the generated Cypher query. Args: driver (neo4j.Driver): The Neo4j Python driver. 
@@ -98,23 +98,23 @@ def __init__( self.examples = validated_data.examples self.result_formatter = validated_data.result_formatter self.custom_prompt = validated_data.custom_prompt - try: + if validated_data.custom_prompt: + neo4j_schema = "" + else: if ( - not validated_data.custom_prompt - ): # don't need schema for a custom prompt - self.neo4j_schema = ( - validated_data.neo4j_schema_model.neo4j_schema - if validated_data.neo4j_schema_model - else get_schema(validated_data.driver_model.driver) - ) + validated_data.neo4j_schema_model + and validated_data.neo4j_schema_model.neo4j_schema + ): + neo4j_schema = validated_data.neo4j_schema_model.neo4j_schema else: - self.neo4j_schema = "" - - except (Neo4jError, DriverError) as e: - error_message = getattr(e, "message", str(e)) - raise SchemaFetchError( - f"Failed to fetch schema for Text2CypherRetriever: {error_message}" - ) from e + try: + neo4j_schema = get_schema(validated_data.driver_model.driver) + except (Neo4jError, DriverError) as e: + error_message = getattr(e, "message", str(e)) + raise SchemaFetchError( + f"Failed to fetch schema for Text2CypherRetriever: {error_message}" + ) from e + self.neo4j_schema = neo4j_schema def get_search_results( self, query_text: str, prompt_params: Optional[Dict[str, Any]] = None @@ -142,12 +142,10 @@ def get_search_results( if prompt_params is not None: # parse the schema and examples inputs - examples_to_use = prompt_params.get("examples") or ( + examples_to_use = prompt_params.pop("examples", None) or ( "\n".join(self.examples) if self.examples else "" ) - schema_to_use = prompt_params.get("schema") or self.neo4j_schema - prompt_params.pop("examples", None) - prompt_params.pop("schema", None) + schema_to_use = prompt_params.pop("schema", None) or self.neo4j_schema else: examples_to_use = "\n".join(self.examples) if self.examples else "" schema_to_use = self.neo4j_schema diff --git a/tests/unit/retrievers/test_text2cypher.py b/tests/unit/retrievers/test_text2cypher.py index 
4d110c8e..05b1e545 100644 --- a/tests/unit/retrievers/test_text2cypher.py +++ b/tests/unit/retrievers/test_text2cypher.py @@ -19,6 +19,7 @@ from neo4j.exceptions import CypherSyntaxError, Neo4jError from neo4j_graphrag.exceptions import ( RetrieverInitializationError, + SchemaFetchError, SearchValidationError, Text2CypherRetrievalError, ) @@ -39,8 +40,8 @@ def test_t2c_retriever_initialization(driver: MagicMock, llm: MagicMock) -> None @patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") @patch("neo4j_graphrag.retrievers.text2cypher.get_schema") def test_t2c_retriever_schema_retrieval( - _verify_version_mock: MagicMock, get_schema_mock: MagicMock, + _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock, ) -> None: @@ -51,13 +52,13 @@ def test_t2c_retriever_schema_retrieval( @patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") @patch("neo4j_graphrag.retrievers.text2cypher.get_schema") def test_t2c_retriever_schema_retrieval_failure( - _verify_version_mock: MagicMock, get_schema_mock: MagicMock, + _verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock, ) -> None: get_schema_mock.side_effect = Neo4jError - with pytest.raises(Neo4jError): + with pytest.raises(SchemaFetchError): Text2CypherRetriever(driver, llm) @@ -310,3 +311,31 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params( llm.invoke.assert_called_once_with( """This is a custom prompt. test ['example A', 'example B']""" ) + + +@patch("neo4j_graphrag.retrievers.base.Retriever._verify_version") +@patch("neo4j_graphrag.retrievers.text2cypher.get_schema") +def test_t2c_retriever_with_custom_prompt_and_schema( + get_schema_mock: MagicMock, + _verify_version_mock: MagicMock, + driver: MagicMock, + llm: MagicMock, + neo4j_record: MagicMock, +) -> None: + prompt = "This is a custom prompt. 
{query_text} {schema}" + query = "test" + + driver.execute_query.return_value = ( + [neo4j_record], + None, + None, + ) + + retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt) + retriever.search( + query_text=query, + prompt_params={}, + ) + + get_schema_mock.assert_not_called() + llm.invoke.assert_called_once_with("""This is a custom prompt. test """)