diff --git a/docs/extras/use_cases/more/graph/graph_hugegraph_qa.ipynb b/docs/extras/use_cases/more/graph/graph_hugegraph_qa.ipynb index dfd64125ae89d..7bcbe8e067802 100644 --- a/docs/extras/use_cases/more/graph/graph_hugegraph_qa.ipynb +++ b/docs/extras/use_cases/more/graph/graph_hugegraph_qa.ipynb @@ -216,7 +216,7 @@ } ], "source": [ - "print(graph.get_schema)" + "print(graph.schema)" ] }, { diff --git a/docs/extras/use_cases/more/graph/graph_kuzu_qa.ipynb b/docs/extras/use_cases/more/graph/graph_kuzu_qa.ipynb index 2604d0b5f2c9f..62668218e7eb0 100644 --- a/docs/extras/use_cases/more/graph/graph_kuzu_qa.ipynb +++ b/docs/extras/use_cases/more/graph/graph_kuzu_qa.ipynb @@ -189,7 +189,7 @@ } ], "source": [ - "print(graph.get_schema)" + "print(graph.schema)" ] }, { diff --git a/docs/extras/use_cases/more/graph/graph_nebula_qa.ipynb b/docs/extras/use_cases/more/graph/graph_nebula_qa.ipynb index 738fe5c9b0e10..3e3a26eab30f4 100644 --- a/docs/extras/use_cases/more/graph/graph_nebula_qa.ipynb +++ b/docs/extras/use_cases/more/graph/graph_nebula_qa.ipynb @@ -182,7 +182,7 @@ } ], "source": [ - "print(graph.get_schema)" + "print(graph.schema)" ] }, { diff --git a/docs/extras/use_cases/more/graph/graph_sparql_qa.ipynb b/docs/extras/use_cases/more/graph/graph_sparql_qa.ipynb index 288dc874ecd9b..4c593c29d6b2e 100644 --- a/docs/extras/use_cases/more/graph/graph_sparql_qa.ipynb +++ b/docs/extras/use_cases/more/graph/graph_sparql_qa.ipynb @@ -106,7 +106,7 @@ } ], "source": [ - "graph.get_schema" + "graph.schema" ] }, { @@ -300,4 +300,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/libs/langchain/langchain/chains/graph_qa/cypher.py b/libs/langchain/langchain/chains/graph_qa/cypher.py index bb2dcb0d9ec9f..f06d2b33a3aa1 100644 --- a/libs/langchain/langchain/chains/graph_qa/cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/cypher.py @@ -8,7 +8,7 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT from langchain.chains.llm import LLMChain -from langchain.graphs.neo4j_graph import Neo4jGraph +from langchain.graphs.graph_store import GraphStore from langchain.pydantic_v1 import Field from langchain.schema import BasePromptTemplate from langchain.schema.language_model import BaseLanguageModel @@ -78,7 +78,7 @@ def filter_func(x: str) -> bool: class GraphCypherQAChain(Chain): """Chain for question-answering against a graph by generating Cypher statements.""" - graph: Neo4jGraph = Field(exclude=True) + graph: GraphStore = Field(exclude=True) cypher_generation_chain: LLMChain qa_chain: LLMChain graph_schema: str diff --git a/libs/langchain/langchain/chains/graph_qa/hugegraph.py b/libs/langchain/langchain/chains/graph_qa/hugegraph.py index 9c11016cd1b3e..266a149c98aa2 100644 --- a/libs/langchain/langchain/chains/graph_qa/hugegraph.py +++ b/libs/langchain/langchain/chains/graph_qa/hugegraph.py @@ -72,7 +72,7 @@ def _call( question = inputs[self.input_key] generated_gremlin = self.gremlin_generation_chain.run( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + {"question": question, "schema": self.graph.schema}, callbacks=callbacks ) _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose) diff --git a/libs/langchain/langchain/chains/graph_qa/kuzu.py b/libs/langchain/langchain/chains/graph_qa/kuzu.py index 8246e92588be5..4e04f3dfcce66 100644 --- a/libs/langchain/langchain/chains/graph_qa/kuzu.py +++ b/libs/langchain/langchain/chains/graph_qa/kuzu.py @@ -71,7 +71,7 @@ def _call( question = inputs[self.input_key] generated_cypher = self.cypher_generation_chain.run( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + {"question": question, "schema": self.graph.schema}, callbacks=callbacks ) _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) diff --git a/libs/langchain/langchain/chains/graph_qa/nebulagraph.py b/libs/langchain/langchain/chains/graph_qa/nebulagraph.py index 09dd52671bac5..7013c66b07c0f 100644 --- a/libs/langchain/langchain/chains/graph_qa/nebulagraph.py +++ b/libs/langchain/langchain/chains/graph_qa/nebulagraph.py @@ -69,7 +69,7 @@ def _call( question = inputs[self.input_key] generated_ngql = self.ngql_generation_chain.run( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + {"question": question, "schema": self.graph.schema}, callbacks=callbacks ) _run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose) diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py index 1e3cb6464406a..a0328cf7cb4d2 100644 --- a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py @@ -156,7 +156,7 @@ def _call( intermediate_steps: List = [] generated_cypher = self.cypher_generation_chain.run( - {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + {"question": question, "schema": self.graph.schema}, callbacks=callbacks ) # Extract Cypher code if it is wrapped in backticks diff --git a/libs/langchain/langchain/chains/graph_qa/sparql.py b/libs/langchain/langchain/chains/graph_qa/sparql.py index 2e1c017748bb4..60d3b78f0f3d2 100644 --- a/libs/langchain/langchain/chains/graph_qa/sparql.py +++ b/libs/langchain/langchain/chains/graph_qa/sparql.py @@ -100,7 +100,7 @@ def _call( _run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose) generated_sparql = sparql_generation_chain.run( - {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks + {"prompt": prompt, "schema": self.graph.schema}, callbacks=callbacks ) _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) diff --git a/libs/langchain/langchain/graphs/falkordb_graph.py b/libs/langchain/langchain/graphs/falkordb_graph.py index f74281b3c18a7..43dae091d7bde 100644 --- a/libs/langchain/langchain/graphs/falkordb_graph.py +++ b/libs/langchain/langchain/graphs/falkordb_graph.py @@ -1,5 +1,8 @@ from typing import Any, Dict, List +from langchain.graphs.graph_document import GraphDocument +from langchain.graphs.graph_store import GraphStore + node_properties_query = """ MATCH (n) UNWIND labels(n) as l @@ -20,7 +23,7 @@ """ -class FalkorDBGraph: +class FalkorDBGraph(GraphStore): """FalkorDB wrapper for graph operations.""" def __init__( @@ -38,6 +41,8 @@ def __init__( self._driver = redis.Redis(host=host, port=port) self._graph = Graph(self._driver, database) + self._schema: str = "" + self._structured_schema: Dict[str, Any] = {} try: self.refresh_schema() @@ -45,16 +50,31 @@ def __init__( raise ValueError(f"Could not refresh schema. Error: {e}") @property - def get_schema(self) -> str: + def schema(self) -> str: """Returns the schema of the FalkorDB database""" - return self.schema + return self._schema + + @property + def structured_schema(self) -> Dict[str, Any]: + """Returns the structured schema of the Graph database""" + return self._structured_schema def refresh_schema(self) -> None: """Refreshes the schema of the FalkorDB database""" - self.schema = ( - f"Node properties: {self.query(node_properties_query)}\n" - f"Relationships properties: {self.query(rel_properties_query)}\n" - f"Relationships: {self.query(rel_query)}\n" + node_properties = self.query(node_properties_query) + rel_properties = self.query(rel_properties_query) + relationships = self.query(rel_query) + + self._structured_schema = { + "node_props": node_properties, + "rel_props": rel_properties, + "relationships": relationships, + } + + self._schema = ( + f"Node properties: {node_properties}\n" + f"Relationships properties: {rel_properties}\n" + f"Relationships: {relationships}\n" ) def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: @@ -65,3 +85,48 @@ def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: return data.result_set except Exception as e: raise ValueError("Generated Cypher Statement is not valid\n" f"{e}") + + def add_graph_documents( + self, graph_documents: List[GraphDocument], include_source: bool = False + ) -> None: + """ + Take GraphDocument as input as uses it to construct a graph. + """ + for document in graph_documents: + # Import nodes + for node in document.nodes: + props = " ".join( + "SET n.{0}='{1}'".format(k, v.replace("'", "\\'")) + if isinstance(v, str) + else "SET n.{0}={1}".format(k, v) + for k, v in node.properties.items() + ) + + self.query( + ( + # f"{include_docs_query if include_source else ''}" + f"MERGE (n:{node.type} {{id:'{node.id}'}}) " + f"{props} " + # f"{'MERGE (d)-[:MENTIONS]->(n) ' if include_source else ''}" + "RETURN distinct 'done' AS result" + ) + ) + + # Import relationships + for rel in document.relationships: + props = " ".join( + "SET r.{0}='{1}'".format(k, v.replace("'", "\\'")) + if isinstance(v, str) + else "SET r.{0}={1}".format(k, v) + for k, v in rel.properties.items() + ) + + self.query( + ( + f"MATCH (a:{rel.source.type} {{id:'{rel.source.id}'}}), " + f"(b:{rel.target.type} {{id:'{rel.target.id}'}}) " + f"MERGE (a)-[r:{(rel.type.replace(' ', '_').upper())}]->(b) " + f"{props} " + "RETURN distinct 'done' AS result" + ) + ) diff --git a/libs/langchain/langchain/graphs/graph_store.py b/libs/langchain/langchain/graphs/graph_store.py new file mode 100644 index 0000000000000..434f71a5ced1f --- /dev/null +++ b/libs/langchain/langchain/graphs/graph_store.py @@ -0,0 +1,37 @@ +from abc import abstractmethod +from typing import Any, Dict, List + +from langchain.graphs.graph_document import GraphDocument + + +class GraphStore: + """An abstract class wrapper for graph operations.""" + + @property + @abstractmethod + def schema(self) -> str: + """Returns the schema of the Graph database""" + pass + + @property + @abstractmethod + def structured_schema(self) -> Dict[str, Any]: + """Returns the structured schema of the Graph database""" + pass + + @abstractmethod + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: + """Query the graph.""" + pass + + @abstractmethod + def refresh_schema(self) -> None: + """Refreshes the graph schema information.""" + pass + + @abstractmethod + def add_graph_documents( + self, graph_documents: List[GraphDocument], include_source: bool = False + ) -> None: + """Take GraphDocument as input as uses it to construct a graph.""" + pass diff --git a/libs/langchain/langchain/graphs/hugegraph.py b/libs/langchain/langchain/graphs/hugegraph.py index f6afc99af9d54..b0a8639df2335 100644 --- a/libs/langchain/langchain/graphs/hugegraph.py +++ b/libs/langchain/langchain/graphs/hugegraph.py @@ -29,7 +29,7 @@ def __init__( self.client = PyHugeGraph( address, port, user=username, pwd=password, graph=graph ) - self.schema = "" + self._schema = "" # Set schema try: self.refresh_schema() @@ -37,9 +37,9 @@ def __init__( raise ValueError(f"Could not refresh schema. Error: {e}") @property - def get_schema(self) -> str: + def schema(self) -> str: """Returns the schema of the HugeGraph database""" - return self.schema + return self._schema def refresh_schema(self) -> None: """ @@ -50,7 +50,7 @@ def refresh_schema(self) -> None: edge_schema = schema.getEdgeLabels() relationships = schema.getRelations() - self.schema = ( + self._schema = ( f"Node properties: {vertex_schema}\n" f"Edge properties: {edge_schema}\n" f"Relationships: {relationships}\n" diff --git a/libs/langchain/langchain/graphs/kuzu_graph.py b/libs/langchain/langchain/graphs/kuzu_graph.py index 85841165d2375..31c5a0ab71ae7 100644 --- a/libs/langchain/langchain/graphs/kuzu_graph.py +++ b/libs/langchain/langchain/graphs/kuzu_graph.py @@ -18,9 +18,9 @@ def __init__(self, db: Any, database: str = "kuzu") -> None: self.refresh_schema() @property - def get_schema(self) -> str: + def schema(self) -> str: """Returns the schema of the Kùzu database""" - return self.schema + return self._schema def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Kùzu database""" @@ -83,7 +83,7 @@ def refresh_schema(self) -> None: ) rel_properties.append(current_table_schema) - self.schema = ( + self._schema = ( f"Node properties: {node_properties}\n" f"Relationships properties: {rel_properties}\n" f"Relationships: {relationships}\n" diff --git a/libs/langchain/langchain/graphs/memgraph_graph.py b/libs/langchain/langchain/graphs/memgraph_graph.py index 79ca390401ec4..a00dbe2671ef8 100644 --- a/libs/langchain/langchain/graphs/memgraph_graph.py +++ b/libs/langchain/langchain/graphs/memgraph_graph.py @@ -29,8 +29,8 @@ def refresh_schema(self) -> None: db_schema = self.query(SCHEMA_QUERY)[0].get("schema") assert db_schema is not None - self.schema = db_schema + self._schema = db_schema db_structured_schema = self.query(RAW_SCHEMA_QUERY)[0].get("schema") assert db_structured_schema is not None - self.structured_schema = db_structured_schema + self._structured_schema = db_structured_schema diff --git a/libs/langchain/langchain/graphs/nebula_graph.py b/libs/langchain/langchain/graphs/nebula_graph.py index 8a031e372becb..2206efaaa722b 100644 --- a/libs/langchain/langchain/graphs/nebula_graph.py +++ b/libs/langchain/langchain/graphs/nebula_graph.py @@ -48,7 +48,7 @@ def __init__( self.session_pool_size = session_pool_size self.session_pool = self._get_session_pool() - self.schema = "" + self._schema = "" # Set schema try: self.refresh_schema() @@ -102,9 +102,9 @@ def __del__(self) -> None: logger.warning(f"Could not close session pool. Error: {e}") @property - def get_schema(self) -> str: + def schema(self) -> str: """Returns the schema of the NebulaGraph database""" - return self.schema + return self._schema def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any: """Query NebulaGraph database.""" @@ -187,7 +187,7 @@ def refresh_schema(self) -> None: if len(r) > 0: relationships.append(r[0].cast()) - self.schema = ( + self._schema = ( f"Node properties: {tags_schema}\n" f"Edge properties: {edge_types_schema}\n" f"Relationships: {relationships}\n" diff --git a/libs/langchain/langchain/graphs/neo4j_graph.py b/libs/langchain/langchain/graphs/neo4j_graph.py index ec6c156018a76..03f07512cbd38 100644 --- a/libs/langchain/langchain/graphs/neo4j_graph.py +++ b/libs/langchain/langchain/graphs/neo4j_graph.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List from langchain.graphs.graph_document import GraphDocument +from langchain.graphs.graph_store import GraphStore node_properties_query = """ CALL apoc.meta.data() @@ -28,7 +29,7 @@ """ -class Neo4jGraph: +class Neo4jGraph(GraphStore): """Neo4j wrapper for graph operations.""" def __init__( @@ -45,8 +46,8 @@ def __init__( self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) self._database = database - self.schema: str = "" - self.structured_schema: Dict[str, Any] = {} + self._schema: str = "" + self._structured_schema: Dict[str, Any] = {} # Verify connection try: self._driver.verify_connectivity() @@ -70,6 +71,16 @@ def __init__( "'apoc.meta.data()' is allowed in Neo4j configuration " ) + @property + def schema(self) -> str: + """Returns the schema of the Graph""" + return self._schema + + @property + def structured_schema(self) -> Dict[str, Any]: + """Returns the structured schema of the Graph database""" + return self._structured_schema + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Neo4j database.""" from neo4j.exceptions import CypherSyntaxError @@ -89,12 +100,12 @@ def refresh_schema(self) -> None: rel_properties = [el["output"] for el in self.query(rel_properties_query)] relationships = [el["output"] for el in self.query(rel_query)] - self.structured_schema = { + self._structured_schema = { "node_props": {el["labels"]: el["properties"] for el in node_properties}, "rel_props": {el["type"]: el["properties"] for el in rel_properties}, "relationships": relationships, } - self.schema = f""" + self._schema = f""" Node properties are the following: {node_properties} Relationship properties are the following: diff --git a/libs/langchain/langchain/graphs/neptune_graph.py b/libs/langchain/langchain/graphs/neptune_graph.py index ac6e98eb253df..816cf5e4d491e 100644 --- a/libs/langchain/langchain/graphs/neptune_graph.py +++ b/libs/langchain/langchain/graphs/neptune_graph.py @@ -103,9 +103,9 @@ def __init__( ) @property - def get_schema(self) -> str: + def schema(self) -> str: """Returns the schema of the Neptune database""" - return self.schema + return self._schema def query(self, query: str, params: dict = {}) -> Dict[str, Any]: """Query Neptune database.""" @@ -229,7 +229,7 @@ def _refresh_schema(self) -> None: node_properties = self._get_node_properties(n_labels, types) edge_properties = self._get_edge_properties(e_labels, types) - self.schema = f""" + self._schema = f""" Node properties are the following: {node_properties} Relationship properties are the following: diff --git a/libs/langchain/langchain/graphs/rdf_graph.py b/libs/langchain/langchain/graphs/rdf_graph.py index 342d4dfc80d4e..ada8a3016a944 100644 --- a/libs/langchain/langchain/graphs/rdf_graph.py +++ b/libs/langchain/langchain/graphs/rdf_graph.py @@ -173,15 +173,15 @@ def __init__( raise AssertionError("The graph is empty.") # Set schema - self.schema = "" + self._schema = "" self.load_schema() @property - def get_schema(self) -> str: + def schema(self) -> str: """ Returns the schema of the graph database. """ - return self.schema + return self._schema def query( self, @@ -261,16 +261,16 @@ def _rdf_s_schema( if self.standard == "rdf": clss = self.query(cls_query_rdf) rels = self.query(rel_query_rdf) - self.schema = _rdf_s_schema(clss, rels) + self._schema = _rdf_s_schema(clss, rels) elif self.standard == "rdfs": clss = self.query(cls_query_rdfs) rels = self.query(rel_query_rdfs) - self.schema = _rdf_s_schema(clss, rels) + self._schema = _rdf_s_schema(clss, rels) elif self.standard == "owl": clss = self.query(cls_query_owl) ops = self.query(op_query_owl) dps = self.query(dp_query_owl) - self.schema = ( + self._schema = ( f"In the following, each IRI is followed by the local name and " f"optionally its description in parentheses. \n" f"The OWL graph supports the following node types:\n" diff --git a/libs/langchain/tests/integration_tests/graphs/test_falkordb.py b/libs/langchain/tests/integration_tests/graphs/test_falkordb.py index de6c77a49317b..7f2f24b4a2dbb 100644 --- a/libs/langchain/tests/integration_tests/graphs/test_falkordb.py +++ b/libs/langchain/tests/integration_tests/graphs/test_falkordb.py @@ -31,4 +31,4 @@ def test_refresh_schema(self, mock_client: Any) -> None: graph = FalkorDBGraph(database=self.graph, host=self.host, port=self.port) graph.refresh_schema() - self.assertNotEqual(graph.get_schema, "") + self.assertNotEqual(graph.schema, "") diff --git a/libs/langchain/tests/integration_tests/graphs/test_hugegraph.py b/libs/langchain/tests/integration_tests/graphs/test_hugegraph.py index 23a3893c2a6d0..3c8f386696159 100644 --- a/libs/langchain/tests/integration_tests/graphs/test_hugegraph.py +++ b/libs/langchain/tests/integration_tests/graphs/test_hugegraph.py @@ -43,4 +43,4 @@ def test_refresh_schema(self, mock_client: Any) -> None: self.username, self.password, self.address, self.port, self.graph ) huge_graph.refresh_schema() - self.assertNotEqual(huge_graph.get_schema, "") + self.assertNotEqual(huge_graph.schema, "") diff --git a/libs/langchain/tests/integration_tests/test_kuzu.py b/libs/langchain/tests/integration_tests/test_kuzu.py index c6cd36e34a5a3..94b192fece242 100644 --- a/libs/langchain/tests/integration_tests/test_kuzu.py +++ b/libs/langchain/tests/integration_tests/test_kuzu.py @@ -52,5 +52,5 @@ def test_refresh_schema(self) -> None: ) self.conn.execute("CREATE REL TABLE ActedIn (FROM Person TO Movie)") self.kuzu_graph.refresh_schema() - schema = self.kuzu_graph.get_schema + schema = self.kuzu_graph.schema self.assertEqual(schema, EXPECTED_SCHEMA) diff --git a/libs/langchain/tests/integration_tests/test_nebulagraph.py b/libs/langchain/tests/integration_tests/test_nebulagraph.py index bf10f9097bb63..720b43ad58b8f 100644 --- a/libs/langchain/tests/integration_tests/test_nebulagraph.py +++ b/libs/langchain/tests/integration_tests/test_nebulagraph.py @@ -87,4 +87,4 @@ def test_refresh_schema(self, mock_session_pool: Any) -> None: self.session_pool_size, ) nebula_graph.refresh_schema() - self.assertNotEqual(nebula_graph.get_schema, "") + self.assertNotEqual(nebula_graph.schema, "")