diff --git a/src/fk_graph/graph.py b/src/fk_graph/graph.py index 7ec57c2..f46dd4f 100644 --- a/src/fk_graph/graph.py +++ b/src/fk_graph/graph.py @@ -78,11 +78,7 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph: graph = Graph() with Session(engine) as session: row = _get_row(session, _table, primary_key) - row_node = Node( - table=_get_table_name_from_row(row), - primary_key=_get_primary_key_from_row(row), - data=_get_data(row), - ) + row_node = _create_node_from_row(row) graph.add_node(row_node) _add_related_rows_to_graph(row, row_node, graph, only_tables) @@ -111,28 +107,10 @@ def _add_related_rows_to_graph(row, row_node, graph, only_tables): related = [] relationships = row.__mapper__.relationships for relationship in relationships: - # This is a bit hacky - but they don't call it a hackathon for nothing. - relationship_name = str(relationship).split(".")[-1] - related_rows = getattr(row, relationship_name) - try: - # This path for reverse foreign keys - for related_row in related_rows: - related_node = Node( - table=_get_table_name_from_row(related_row), - primary_key=_get_primary_key_from_row(related_row), - data=_get_data(related_row), - ) - related.append((related_row, related_node)) - except TypeError: - # This path for foreign keys. - related_row = related_rows - # Ignore null foreign-keys. - if related_row is not None and (only_tables is None or related_row.__table__.name in only_tables): - related_node = Node( - table=_get_table_name_from_row(related_row), - primary_key=_get_primary_key_from_row(related_row), - data=_get_data(related_row), - ) + related_rows = _get_related_rows_for_relationship(row, relationship) + for related_row in related_rows: + if _row_is_from_an_included_table(related_row, only_tables): + related_node = _create_node_from_row(related_row) related.append((related_row, related_node)) unvisited = [ (row, node) for (row, node) in related @@ -143,6 +121,37 @@ def _add_related_rows_to_graph(row, row_node, graph, only_tables): for unvisited_row, unvisited_node in unvisited: _add_related_rows_to_graph(unvisited_row, unvisited_node, graph, only_tables) +def _create_node_from_row(row): + return Node( + table=_get_table_name_from_row(row), + primary_key=_get_primary_key_from_row(row), + data=_get_data(row), + ) + +def _row_is_from_an_included_table(row, only_tables): + return ( + only_tables is None + or + row.__table__.name in only_tables + ) + +def _get_related_rows_for_relationship(row, relationship): + relationship_name = _get_relationship_name(relationship) + related_rows = getattr(row, relationship_name) + # We always return a list of rows, to ensure that subsequent code is simpler. + try: + for _ in related_rows: + pass + except TypeError: + if related_rows is None: + return [] + return [related_rows] + else: + return related_rows + +def _get_relationship_name(relationship): + # This is a bit hacky - but they don't call it a hackathon for nothing. + return str(relationship).split(".")[-1] def _get_table_name_from_row(row): return row.__table__.name