Skip to content

Commit

Permalink
Allow exclusion of specified edges.
Browse files Browse the repository at this point in the history
By use of a function, the user can specify which edges should be
excluded from the returned graph.
  • Loading branch information
MrCurtis committed Jan 29, 2024
1 parent c4e1bcc commit 8e3ee83
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "fk-graph"
packages = ["src/fk_graph"]
version = "0.0.10"
version = "0.0.11"
authors = [
{ name="Andrew Curtis", email="[email protected]" },
{ name="John C Thomas" },
Expand Down
39 changes: 34 additions & 5 deletions src/fk_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def __str__(self):
return self.str()


def get_graph(engine, table, primary_key, only_tables=None) -> Graph:
def get_graph(
engine,
table,
primary_key,
only_tables=None,
exclude_edge=None,
) -> Graph:
"""Construct the graph for a specified data-point
Args:
Expand All @@ -61,6 +67,13 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph:
primary_key: The primary key for the row.
only_tables: (Optional) A list of table names. Any rows for tables not
in the list will not be included in the graph.
exclude_edge: (Optional) A callable used to determine whether an edge
will be included in the graph. It should be of the form
`f(input_row, output_row) -> bool`, where the rows are SQL-Alchemy
`Row` instances. If the function returns `True` for a given pair of
nodes, then the corresponding edge is not included in the graph. The
returned graph is always connected, so any nodes that can only be
reached via the edge will also be omitted.
Raises:
TableDoesNotExist - when the specified table does not exist.
Expand All @@ -80,7 +93,7 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph:
row = _get_row(session, _table, primary_key)
row_node = _create_node_from_row(row)
graph.add_node(row_node)
_add_related_rows_to_graph(row, row_node, graph, only_tables)
_add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge)

return graph

Expand All @@ -103,13 +116,17 @@ def _get_row(session, table, primary_key):
)


def _add_related_rows_to_graph(row, row_node, graph, only_tables):
def _add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge):
related = []
relationships = row.__mapper__.relationships
for relationship in relationships:
related_rows = _get_related_rows_for_relationship(row, relationship)
for related_row in related_rows:
if _row_is_from_an_included_table(related_row, only_tables):
if (
_row_is_from_an_included_table(related_row, only_tables)
and
_edge_is_not_excluded(exclude_edge, row, related_row)
):
related_node = _create_node_from_row(related_row)
related.append((related_row, related_node))
unvisited = [
Expand All @@ -119,7 +136,12 @@ def _add_related_rows_to_graph(row, row_node, graph, only_tables):
for _, related_node in related:
graph.add_edge(row_node, related_node)
for unvisited_row, unvisited_node in unvisited:
_add_related_rows_to_graph(unvisited_row, unvisited_node, graph, only_tables)
_add_related_rows_to_graph(
unvisited_row,
unvisited_node,
graph, only_tables,
exclude_edge
)

def _create_node_from_row(row):
return Node(
Expand All @@ -135,6 +157,13 @@ def _row_is_from_an_included_table(row, only_tables):
row.__table__.name in only_tables
)

def _edge_is_not_excluded(exclude_edge, row, related_row):
return (
exclude_edge is None
or
not exclude_edge(row, related_row)
)

def _get_related_rows_for_relationship(row, relationship):
relationship_name = _get_relationship_name(relationship)
related_rows = getattr(row, relationship_name)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,39 @@ def test_can_create_graph_when_some_rows_have_null_foreign_keys(self):
# Check does not add null nodes to graph.
self.assertEqual(len(graph.nodes), 1)

def test_excludes_edges(self):
    # Fixture helper populates the database used by this test case.
    self._create_three_entries_with_linear_foreign_key_relations(self.engine)

    def leads_to_table_c(_source_row, target_row):
        # Reject every edge whose far end is a row of "table_c".
        target_table = target_row.__table__.name
        return target_table == "table_c"

    graph = get_graph(self.engine, "table_a", "1", exclude_edge=leads_to_table_c)
    node_count = len(graph.nodes)

    # The last node should be excluded, leaving two.
    self.assertEqual(node_count, 2)

def test_excluding_edges_excludes_all_nodes_only_reachable_via_these_edges(self):
    # Fixture helper populates the database used by this test case.
    self._create_three_entries_with_linear_foreign_key_relations(self.engine)

    def leads_to_table_b(_source_row, target_row):
        # Reject every edge whose far end is a row of "table_b".
        target_table = target_row.__table__.name
        return target_table == "table_b"

    graph = get_graph(self.engine, "table_a", "1", exclude_edge=leads_to_table_b)
    node_count = len(graph.nodes)

    # The excluded edge links the first node to the second, so neither the
    # second nor the third node should appear; only the first remains.
    self.assertEqual(node_count, 1)

def _create_single_table_no_relations(self, engine):
metadata_object = MetaData()
table_a = Table(
Expand Down

0 comments on commit 8e3ee83

Please sign in to comment.