Skip to content

Commit

Permalink
Ensure rows added for only selected foreign-keys.
Browse files Browse the repository at this point in the history
Previously, when building the graph, the code would only add rows
realted through a reverse foreign-key relation only if they were
from a table included in the `only_tables` argument. However, rows
related by forward foreign-keys would be added regardless. This was a
mistake, and now the forward case is similar to the reverse case.
  • Loading branch information
MrCurtis committed Nov 10, 2023
1 parent fedca84 commit 9010799
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
8 changes: 4 additions & 4 deletions graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph:
data=_get_data(row),
)
graph.add_node(row_node)
_add_related_rows_to_graph(row, row_node, graph)
_add_related_rows_to_graph(row, row_node, graph, only_tables)

return graph

Expand All @@ -107,7 +107,7 @@ def _get_row(session, table, primary_key):
)


def _add_related_rows_to_graph(row, row_node, graph):
def _add_related_rows_to_graph(row, row_node, graph, only_tables):
related = []
relationships = row.__mapper__.relationships
for relationship in relationships:
Expand All @@ -127,7 +127,7 @@ def _add_related_rows_to_graph(row, row_node, graph):
# This path for foreign keys.
related_row = related_rows
# Ignore null foreign-keys.
if related_row is not None:
if related_row is not None and (only_tables is None or related_row.__table__.name in only_tables):
related_node = Node(
table=_get_table_name_from_row(related_row),
primary_key=_get_primary_key_from_row(related_row),
Expand All @@ -141,7 +141,7 @@ def _add_related_rows_to_graph(row, row_node, graph):
for _, related_node in related:
graph.add_edge(row_node, related_node)
for unvisited_row, unvisited_node in unvisited:
_add_related_rows_to_graph(unvisited_row, unvisited_node, graph)
_add_related_rows_to_graph(unvisited_row, unvisited_node, graph, only_tables)


def _get_table_name_from_row(row):
Expand Down
18 changes: 17 additions & 1 deletion test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_raises_primary_key_does_not_exist_when_no_key_in_table(self):
):
get_graph(self.engine, "table_a", "9876")

def test_can_restrict_to_selected_tables(self):
def test_can_restrict_to_selected_tables__reverse_foreign_key_case(self):
self._create_three_entries_with_linear_foreign_key_relations(self.engine)

graph = get_graph(self.engine, "table_a", "1", only_tables=["table_a", "table_b"])
Expand All @@ -151,6 +151,22 @@ def test_can_restrict_to_selected_tables(self):
any([n.table == "table_c" for n in graph.nodes])
)

def test_can_restrict_to_selected_tables__forward_foreign_key_case(self):
self._create_three_entries_with_linear_foreign_key_relations(self.engine)

graph = get_graph(self.engine, "table_c", "1", only_tables=["table_c", "table_b"])

with self.subTest("includes selected tables"):
self.assertTrue(
any([n.table == "table_c" for n in graph.nodes])
and
any([n.table == "table_b" for n in graph.nodes])
)
with self.subTest("excludes non-selected tables"):
self.assertFalse(
any([n.table == "table_a" for n in graph.nodes])
)

def test_can_create_graph_when_some_rows_have_null_foreign_keys(self):
self._create_entries_with_null_foreign_key(self.engine)

Expand Down

0 comments on commit 9010799

Please sign in to comment.