diff --git a/pyproject.toml b/pyproject.toml index d4ee3cb..89f86a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fk-graph" packages = ["src/fk_graph"] -version = "0.0.10" +version = "0.0.11" authors = [ { name="Andrew Curtis", email="fk.graph@fastmail.com" }, { name="John C Thomas" }, diff --git a/src/fk_graph/graph.py b/src/fk_graph/graph.py index f46dd4f..0d4431e 100644 --- a/src/fk_graph/graph.py +++ b/src/fk_graph/graph.py @@ -52,7 +52,13 @@ def __str__(self): return self.str() -def get_graph(engine, table, primary_key, only_tables=None) -> Graph: +def get_graph( + engine, + table, + primary_key, + only_tables=None, + exclude_edge=None, +) -> Graph: """Construct the graph for a specified data-point Args: @@ -61,6 +67,13 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph: primary_key: The primary key for the row. only_tables: (Optional) A list of table names. Any rows for tables not in the list will not be included in the graph. + exclude_edge: (Optional) A callable used to determine whether an edge + will be included in the graph. It should be of the form + `f(input_row, output_row) -> bool`, where the rows are SQLAlchemy + ORM-mapped objects. If the function returns `True` for a given pair + of rows, then the corresponding edge is not included in the graph. The + returned graph is always connected, so any nodes that can only be + reached via an excluded edge will also be omitted. Raises: TableDoesNotExist - when the specified table does not exist. 
@@ -80,7 +93,7 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph: row = _get_row(session, _table, primary_key) row_node = _create_node_from_row(row) graph.add_node(row_node) - _add_related_rows_to_graph(row, row_node, graph, only_tables) + _add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge) return graph @@ -103,13 +116,17 @@ def _get_row(session, table, primary_key): ) -def _add_related_rows_to_graph(row, row_node, graph, only_tables): +def _add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge): related = [] relationships = row.__mapper__.relationships for relationship in relationships: related_rows = _get_related_rows_for_relationship(row, relationship) for related_row in related_rows: - if _row_is_from_an_included_table(related_row, only_tables): + if ( + _row_is_from_an_included_table(related_row, only_tables) + and + _edge_is_not_excluded(exclude_edge, row, related_row) + ): related_node = _create_node_from_row(related_row) related.append((related_row, related_node)) unvisited = [ @@ -119,7 +136,12 @@ def _add_related_rows_to_graph(row, row_node, graph, only_tables): for _, related_node in related: graph.add_edge(row_node, related_node) for unvisited_row, unvisited_node in unvisited: - _add_related_rows_to_graph(unvisited_row, unvisited_node, graph, only_tables) + _add_related_rows_to_graph( + unvisited_row, + unvisited_node, + graph, only_tables, + exclude_edge + ) def _create_node_from_row(row): return Node( @@ -135,6 +157,13 @@ def _row_is_from_an_included_table(row, only_tables): row.__table__.name in only_tables ) +def _edge_is_not_excluded(exclude_edge, row, related_row): + return ( + exclude_edge is None + or + not exclude_edge(row, related_row) + ) + def _get_related_rows_for_relationship(row, relationship): relationship_name = _get_relationship_name(relationship) related_rows = getattr(row, relationship_name) diff --git a/tests/test_graph.py b/tests/test_graph.py index 019153b..e259adf 
100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -175,6 +175,39 @@ def test_can_create_graph_when_some_rows_have_null_foreign_keys(self): # Check does not add null nodes to graph. self.assertEqual(len(graph.nodes), 1) + def test_excludes_edges(self): + self._create_three_entries_with_linear_foreign_key_relations(self.engine) + + def is_output_table_c(input_row, output_row): + return output_row.__table__.name == "table_c" + + graph = get_graph( + self.engine, + "table_a", + "1", + exclude_edge=is_output_table_c, + ) + + # The last node should be excluded. + self.assertEqual(len(graph.nodes), 2) + + def test_excluding_edges_excludes_all_nodes_only_reachable_via_these_edges(self): + self._create_three_entries_with_linear_foreign_key_relations(self.engine) + + def is_output_table_b(input_row, output_row): + return output_row.__table__.name == "table_b" + + graph = get_graph( + self.engine, + "table_a", + "1", + exclude_edge=is_output_table_b, + ) + + # The excluded edge links the first node to the second, so both the + # second and the third should be excluded. + self.assertEqual(len(graph.nodes), 1) + def _create_single_table_no_relations(self, engine): metadata_object = MetaData() table_a = Table(