Skip to content

Commit

Permalink
Add exclude edge check
Browse files Browse the repository at this point in the history
  • Loading branch information
MrCurtis committed Jan 28, 2024
1 parent 9a1a3f7 commit 8a2d652
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions src/fk_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def __str__(self):
return self.str()


def get_graph(engine, table, primary_key, only_tables=None) -> Graph:
def get_graph(
engine,
table,
primary_key,
only_tables=None,
exclude_edge=None,
) -> Graph:
"""Construct the graph for a specified data-point
Args:
Expand Down Expand Up @@ -81,7 +87,7 @@ def get_graph(engine, table, primary_key, only_tables=None) -> Graph:
row = _get_row(session, _table, primary_key)
row_node = _create_node_from_row(row)
graph.add_node(row_node)
_add_related_rows_to_graph(row, row_node, graph, only_tables)
_add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge)

return graph

Expand All @@ -104,13 +110,17 @@ def _get_row(session, table, primary_key):
)


def _add_related_rows_to_graph(row, row_node, graph, only_tables):
def _add_related_rows_to_graph(row, row_node, graph, only_tables, exclude_edge):
related = []
relationships = row.__mapper__.relationships
for relationship in relationships:
related_rows = _get_related_rows_for_relationship(row, relationship)
for related_row in related_rows:
if _row_is_from_an_included_table(related_row, only_tables):
if (
_row_is_from_an_included_table(related_row, only_tables)
and
_edge_is_not_excluded(exclude_edge, row, related_row)
):
related_node = _create_node_from_row(related_row)
related.append((related_row, related_node))
unvisited = [
Expand All @@ -120,7 +130,12 @@ def _add_related_rows_to_graph(row, row_node, graph, only_tables):
for _, related_node in related:
graph.add_edge(row_node, related_node)
for unvisited_row, unvisited_node in unvisited:
_add_related_rows_to_graph(unvisited_row, unvisited_node, graph, only_tables)
_add_related_rows_to_graph(
unvisited_row,
unvisited_node,
graph, only_tables,
exclude_edge
)

def _create_node_from_row(row):
return Node(
Expand All @@ -136,6 +151,13 @@ def _row_is_from_an_included_table(row, only_tables):
row.__table__.name in only_tables
)

def _edge_is_not_excluded(exclude_edge, row, related_row):
return (
exclude_edge is None
or
not exclude_edge(row, related_row)
)

def _get_related_rows_for_relationship(row, relationship):
relationship_name = _get_relationship_name(relationship)
related_rows = getattr(row, relationship_name)
Expand Down

0 comments on commit 8a2d652

Please sign in to comment.