From 359ffd87ccc9d13fdb6086befd1bae5b195da732 Mon Sep 17 00:00:00 2001
From: MrCurtis <4184070+MrCurtis@users.noreply.github.com>
Date: Tue, 12 Sep 2023 13:58:13 +0000
Subject: [PATCH] Extend graph creation to include more complex cases.

We can now deal with nodes that are more than one step from the input row.
---
Review notes (text between the "---" line and the diffstat is not part of
the commit message):

- The graph.py hunk was revised during review; the post-image blob hash on
  its "index" line still refers to the originally submitted version.
- An edge is now added even when the related node is already in the graph,
  so rows reachable along more than one path keep all of their links.
- The TypeError guard now covers only the collection-vs-single-row check,
  so a TypeError raised further down the traversal is no longer swallowed.
- A relationship whose value is None (unset foreign key) is skipped.
- The relationship name is read from `relationship.key` instead of being
  parsed out of `str(relationship)`.

 graph.py      | 45 ++++++++++++++++++++++--------
 test_graph.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 108 insertions(+), 12 deletions(-)

diff --git a/graph.py b/graph.py
index 35efdcd..289b74f 100644
--- a/graph.py
+++ b/graph.py
@@ -47,20 +47,41 @@ def get_graph(engine, table, primary_key):
             primary_key=_get_primary_key_from_row(row),
         )
         graph.add_node(row_node)
-        relationships = row.__mapper__.relationships
-        for relationship in relationships:
-            # This is a bit hacky - but they don't call it a hackathon for nothing.
-            relationship_name = str(relationship).split(".")[-1]
-            related_rows = getattr(row, relationship_name)
-            for related_row in related_rows:
-                related_row_node = Node(
-                    table=_get_table_name_from_row(related_row),
-                    primary_key=_get_primary_key_from_row(related_row),
-                )
-                graph.add_node(related_row_node)
-                graph.add_edge(row_node, related_row_node)
+        _add_related_rows_to_graph(row, row_node, graph)
 
     return graph
+
+
+def _add_related_rows_to_graph(row, row_node, graph):
+    """Add every row reachable from ``row`` to ``graph``, recursively.
+
+    Args:
+        row: Mapped row whose relationships are followed.
+        row_node: The ``Node`` that represents ``row``.
+        graph: Graph that is extended in place with nodes and edges.
+    """
+    for relationship in row.__mapper__.relationships:
+        related = getattr(row, relationship.key)
+        if related is None:
+            # An unset foreign key has no row to link to.
+            continue
+        try:
+            # Reverse foreign keys give an iterable of rows.
+            related_rows = list(related)
+        except TypeError:
+            # Foreign keys give a single row.
+            related_rows = [related]
+        for related_row in related_rows:
+            related_row_node = Node(
+                table=_get_table_name_from_row(related_row),
+                primary_key=_get_primary_key_from_row(related_row),
+            )
+            is_new_node = related_row_node not in graph.nodes()
+            # Add the edge even when the node is already known, so rows that
+            # are reachable along several paths keep every link.
+            graph.add_edge(row_node, related_row_node)
+            if is_new_node:
+                _add_related_rows_to_graph(related_row, related_row_node, graph)
 
 
 def _get_table_name_from_row(row):
diff --git a/test_graph.py b/test_graph.py
index 6a50b92..42aba70 100644
--- a/test_graph.py
+++ b/test_graph.py
@@ -44,6 +44,40 @@ def test_can_build_from_reverse_foreign_key_relations(self):
         with self.subTest():
             self.assertTrue(is_isomorphic(graph, expected_graph))
 
+    def test_can_build_from_triple_row_linear_foreign_key_relations(self):
+        self._create_three_entires_with_linear_foreign_key_relations(self.engine)
+        node_1 = Node(table="table_c", primary_key=1)
+        node_2 = Node(table="table_b", primary_key=1)
+        node_3 = Node(table="table_a", primary_key=1)
+        expected_graph = Graph()
+        expected_graph.add_edge(node_1, node_2)
+        expected_graph.add_edge(node_2, node_3)
+
+        graph = get_graph(self.engine, "table_c", 1)
+
+        with self.subTest():
+            self.assertCountEqual(graph.nodes, expected_graph.nodes)
+
+        with self.subTest():
+            self.assertTrue(is_isomorphic(graph, expected_graph))
+
+    def test_can_build_from_triple_row_linear_reverse_foreign_key_relations(self):
+        self._create_three_entires_with_linear_foreign_key_relations(self.engine)
+        node_1 = Node(table="table_a", primary_key=1)
+        node_2 = Node(table="table_b", primary_key=1)
+        node_3 = Node(table="table_c", primary_key=1)
+        expected_graph = Graph()
+        expected_graph.add_edge(node_1, node_2)
+        expected_graph.add_edge(node_2, node_3)
+
+        graph = get_graph(self.engine, "table_a", 1)
+
+        with self.subTest():
+            self.assertCountEqual(graph.nodes, expected_graph.nodes)
+
+        with self.subTest():
+            self.assertTrue(is_isomorphic(graph, expected_graph))
+
     def _create_single_table_no_relations(self, engine):
         metadata_object = MetaData()
         table_a = Table(
@@ -90,3 +124,44 @@ def _create_db_with_reverse_foreign_key_relations(self, engine):
                 ]
             )
             conn.commit()
+
+    def _create_three_entires_with_linear_foreign_key_relations(self, engine):
+        metadata_object = MetaData()
+        table_a = Table(
+            "table_a",
+            metadata_object,
+            Column("id", Integer, primary_key=True),
+        )
+        table_b = Table(
+            "table_b",
+            metadata_object,
+            Column("id", Integer, primary_key=True),
+            Column("a_id", ForeignKey("table_a.id"), nullable=False),
+        )
+        table_c = Table(
+            "table_c",
+            metadata_object,
+            Column("id", Integer, primary_key=True),
+            Column("b_id", ForeignKey("table_b.id"), nullable=False),
+        )
+        metadata_object.create_all(engine)
+        with engine.connect() as conn:
+            conn.execute(
+                insert(table_a),
+                [
+                    {"id": 1},
+                ]
+            )
+            conn.execute(
+                insert(table_b),
+                [
+                    {"id": 1, "a_id": 1},
+                ]
+            )
+            conn.execute(
+                insert(table_c),
+                [
+                    {"id": 1, "b_id": 1},
+                ]
+            )
+            conn.commit()