diff --git a/src/fk_graph/graph.py b/src/fk_graph/graph.py index a7bf87f..571572d 100644 --- a/src/fk_graph/graph.py +++ b/src/fk_graph/graph.py @@ -52,7 +52,13 @@ def __str__(self): return self.str() -def get_graph(engine, table, primary_key, only_tables=None) -> Graph: +def get_graph( + engine, + table, + primary_key, + only_tables=None, + exclude_edge=None, +) -> Graph: """Construct the graph for a specified data-point Args: @@ -81,7 +87,7 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph: row = _get_row(session, _table, primary_key) row_node = _create_node_from_row(row) graph.add_node(row_node) - _add_related_rows_to_graph(row, row_node, graph, only_tables) + _add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge) return graph @@ -104,13 +110,17 @@ def _get_row(session, table, primary_key): ) -def _add_related_rows_to_graph(row, row_node, graph, only_tables): +def _add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge): related = [] relationships = row.__mapper__.relationships for relationship in relationships: related_rows = _get_related_rows_for_relationship(row, relationship) for related_row in related_rows: - if _row_is_from_an_included_table(related_row, only_tables): + if ( + _row_is_from_an_included_table(related_row, only_tables) + and + _edge_is_not_excluded(exclude_edge, row, related_row) + ): related_node = _create_node_from_row(related_row) related.append((related_row, related_node)) unvisited = [ @@ -120,7 +130,12 @@ def _add_related_rows_to_graph(row, row_node, graph, only_tables): for _, related_node in related: graph.add_edge(row_node, related_node) for unvisited_row, unvisited_node in unvisited: - _add_related_rows_to_graph(unvisited_row, unvisited_node, graph, only_tables) + _add_related_rows_to_graph( + unvisited_row, + unvisited_node, + graph, only_tables, + exclude_edge + ) def _create_node_from_row(row): return Node( @@ -136,6 +151,13 @@ def _row_is_from_an_included_table(row, only_tables): row.__table__.name in only_tables ) +def _edge_is_not_excluded(exclude_edge, row, related_row): + return ( + exclude_edge is None + or + not exclude_edge(row, related_row) + ) + def _get_related_rows_for_relationship(row, relationship): relationship_name = _get_relationship_name(relationship) related_rows = getattr(row, relationship_name)