Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph interface property #5

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion docs/extras/use_cases/more/graph/graph_hugegraph_qa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
}
],
"source": [
"print(graph.get_schema)"
"print(graph.schema)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/extras/use_cases/more/graph/graph_kuzu_qa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
}
],
"source": [
"print(graph.get_schema)"
"print(graph.schema)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/extras/use_cases/more/graph/graph_nebula_qa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@
}
],
"source": [
"print(graph.get_schema)"
"print(graph.schema)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/extras/use_cases/more/graph/graph_sparql_qa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
}
],
"source": [
"graph.get_schema"
"graph.schema"
]
},
{
Expand Down Expand Up @@ -300,4 +300,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
4 changes: 2 additions & 2 deletions libs/langchain/langchain/chains/graph_qa/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.neo4j_graph import Neo4jGraph
from langchain.graphs.graph_store import GraphStore
from langchain.pydantic_v1 import Field
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
Expand Down Expand Up @@ -78,7 +78,7 @@ def filter_func(x: str) -> bool:
class GraphCypherQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements."""

graph: Neo4jGraph = Field(exclude=True)
graph: GraphStore = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
graph_schema: str
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/graph_qa/hugegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _call(
question = inputs[self.input_key]

generated_gremlin = self.gremlin_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
)

_run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/graph_qa/kuzu.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _call(
question = inputs[self.input_key]

generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
)

_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/graph_qa/nebulagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _call(
question = inputs[self.input_key]

generated_ngql = self.ngql_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
)

_run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/graph_qa/neptune_cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _call(
intermediate_steps: List = []

generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
)

# Extract Cypher code if it is wrapped in backticks
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/graph_qa/sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _call(
_run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose)

generated_sparql = sparql_generation_chain.run(
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
{"prompt": prompt, "schema": self.graph.schema}, callbacks=callbacks
)

_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
Expand Down
79 changes: 72 additions & 7 deletions libs/langchain/langchain/graphs/falkordb_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Dict, List

from langchain.graphs.graph_document import GraphDocument
from langchain.graphs.graph_store import GraphStore

node_properties_query = """
MATCH (n)
UNWIND labels(n) as l
Expand All @@ -20,7 +23,7 @@
"""


class FalkorDBGraph:
class FalkorDBGraph(GraphStore):
"""FalkorDB wrapper for graph operations."""

def __init__(
Expand All @@ -38,23 +41,40 @@ def __init__(

self._driver = redis.Redis(host=host, port=port)
self._graph = Graph(self._driver, database)
self._schema: str = ""
self._structured_schema: Dict[str, Any] = {}

try:
self.refresh_schema()
except Exception as e:
raise ValueError(f"Could not refresh schema. Error: {e}")

@property
def get_schema(self) -> str:
def schema(self) -> str:
"""Returns the schema of the FalkorDB database"""
return self.schema
return self._schema

@property
def structured_schema(self) -> Dict[str, Any]:
"""Returns the structured schema of the Graph database"""
return self._structured_schema

def refresh_schema(self) -> None:
"""Refreshes the schema of the FalkorDB database"""
self.schema = (
f"Node properties: {self.query(node_properties_query)}\n"
f"Relationships properties: {self.query(rel_properties_query)}\n"
f"Relationships: {self.query(rel_query)}\n"
node_properties = self.query(node_properties_query)
rel_properties = self.query(rel_properties_query)
relationships = self.query(rel_query)

self._structured_schema = {
"node_props": node_properties,
"rel_props": rel_properties,
"relationships": relationships,
}

self._schema = (
f"Node properties: {node_properties}\n"
f"Relationships properties: {rel_properties}\n"
f"Relationships: {relationships}\n"
)

def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
Expand All @@ -65,3 +85,48 @@ def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
return data.result_set
except Exception as e:
raise ValueError("Generated Cypher Statement is not valid\n" f"{e}")

def add_graph_documents(
self, graph_documents: List[GraphDocument], include_source: bool = False
) -> None:
"""
Take GraphDocument as input as uses it to construct a graph.
"""
for document in graph_documents:
# Import nodes
for node in document.nodes:
props = " ".join(
"SET n.{0}='{1}'".format(k, v.replace("'", "\\'"))
if isinstance(v, str)
else "SET n.{0}={1}".format(k, v)
for k, v in node.properties.items()
)

self.query(
(
# f"{include_docs_query if include_source else ''}"
f"MERGE (n:{node.type} {{id:'{node.id}'}}) "
f"{props} "
# f"{'MERGE (d)-[:MENTIONS]->(n) ' if include_source else ''}"
"RETURN distinct 'done' AS result"
)
)

# Import relationships
for rel in document.relationships:
props = " ".join(
"SET r.{0}='{1}'".format(k, v.replace("'", "\\'"))
if isinstance(v, str)
else "SET r.{0}={1}".format(k, v)
for k, v in rel.properties.items()
)

self.query(
(
f"MATCH (a:{rel.source.type} {{id:'{rel.source.id}'}}), "
f"(b:{rel.target.type} {{id:'{rel.target.id}'}}) "
f"MERGE (a)-[r:{(rel.type.replace(' ', '_').upper())}]->(b) "
f"{props} "
"RETURN distinct 'done' AS result"
)
)
37 changes: 37 additions & 0 deletions libs/langchain/langchain/graphs/graph_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from abc import abstractmethod
from typing import Any, Dict, List

from langchain.graphs.graph_document import GraphDocument


class GraphStore:
"""An abstract class wrapper for graph operations."""

@property
@abstractmethod
def schema(self) -> str:
"""Returns the schema of the Graph database"""
pass

@property
@abstractmethod
def structured_schema(self) -> Dict[str, Any]:
"""Returns the structured schema of the Graph database"""
pass

@abstractmethod
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
"""Query the graph."""
pass

@abstractmethod
def refresh_schema(self) -> None:
"""Refreshes the graph schema information."""
pass

@abstractmethod
def add_graph_documents(
self, graph_documents: List[GraphDocument], include_source: bool = False
) -> None:
"""Take GraphDocument as input as uses it to construct a graph."""
pass
8 changes: 4 additions & 4 deletions libs/langchain/langchain/graphs/hugegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ def __init__(
self.client = PyHugeGraph(
address, port, user=username, pwd=password, graph=graph
)
self.schema = ""
self._schema = ""
# Set schema
try:
self.refresh_schema()
except Exception as e:
raise ValueError(f"Could not refresh schema. Error: {e}")

@property
def get_schema(self) -> str:
def schema(self) -> str:
"""Returns the schema of the HugeGraph database"""
return self.schema
return self._schema

def refresh_schema(self) -> None:
"""
Expand All @@ -50,7 +50,7 @@ def refresh_schema(self) -> None:
edge_schema = schema.getEdgeLabels()
relationships = schema.getRelations()

self.schema = (
self._schema = (
f"Node properties: {vertex_schema}\n"
f"Edge properties: {edge_schema}\n"
f"Relationships: {relationships}\n"
Expand Down
6 changes: 3 additions & 3 deletions libs/langchain/langchain/graphs/kuzu_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def __init__(self, db: Any, database: str = "kuzu") -> None:
self.refresh_schema()

@property
def get_schema(self) -> str:
def schema(self) -> str:
"""Returns the schema of the Kùzu database"""
return self.schema
return self._schema

def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
"""Query Kùzu database"""
Expand Down Expand Up @@ -83,7 +83,7 @@ def refresh_schema(self) -> None:
)
rel_properties.append(current_table_schema)

self.schema = (
self._schema = (
f"Node properties: {node_properties}\n"
f"Relationships properties: {rel_properties}\n"
f"Relationships: {relationships}\n"
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/graphs/memgraph_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def refresh_schema(self) -> None:

db_schema = self.query(SCHEMA_QUERY)[0].get("schema")
assert db_schema is not None
self.schema = db_schema
self._schema = db_schema

db_structured_schema = self.query(RAW_SCHEMA_QUERY)[0].get("schema")
assert db_structured_schema is not None
self.structured_schema = db_structured_schema
self._structured_schema = db_structured_schema
8 changes: 4 additions & 4 deletions libs/langchain/langchain/graphs/nebula_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self.session_pool_size = session_pool_size

self.session_pool = self._get_session_pool()
self.schema = ""
self._schema = ""
# Set schema
try:
self.refresh_schema()
Expand Down Expand Up @@ -102,9 +102,9 @@ def __del__(self) -> None:
logger.warning(f"Could not close session pool. Error: {e}")

@property
def get_schema(self) -> str:
def schema(self) -> str:
"""Returns the schema of the NebulaGraph database"""
return self.schema
return self._schema

def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any:
"""Query NebulaGraph database."""
Expand Down Expand Up @@ -187,7 +187,7 @@ def refresh_schema(self) -> None:
if len(r) > 0:
relationships.append(r[0].cast())

self.schema = (
self._schema = (
f"Node properties: {tags_schema}\n"
f"Edge properties: {edge_types_schema}\n"
f"Relationships: {relationships}\n"
Expand Down
Loading