diff --git a/libs/langchain/langchain/chains/graph_qa/cypher.py b/libs/langchain/langchain/chains/graph_qa/cypher.py index f8507408195f1..50ee3771d6a63 100644 --- a/libs/langchain/langchain/chains/graph_qa/cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/cypher.py @@ -37,7 +37,7 @@ def extract_cypher(text: str) -> str: def construct_schema( - structured_schema: Dict[str, Dict[str, Any]], + structured_schema: Dict[str, Any], include_types: List[str], exclude_types: List[str], ) -> str: @@ -46,17 +46,12 @@ def construct_schema( def filter_func(x: str) -> bool: return x in include_types if include_types else x not in exclude_types + node_props: Dict[str, Any] = structured_schema.get("node_props", {}) + rel_props: Dict[str, Any] = structured_schema.get("rel_props", {}) + filtered_schema = { - "node_props": { - k: v - for k, v in structured_schema.get("node_props", {}).items() - if filter_func(k) - }, - "rel_props": { - k: v - for k, v in structured_schema.get("rel_props", {}).items() - if filter_func(k) - }, + "node_props": {k: v for k, v in node_props.items() if filter_func(k)}, + "rel_props": {k: v for k, v in rel_props.items() if filter_func(k)}, "relationships": [ r for r in structured_schema.get("relationships", [])