diff --git a/graph.py b/graph.py index fd98ec9..35efdcd 100644 --- a/graph.py +++ b/graph.py @@ -17,6 +17,9 @@ class Node: def str(self): return f"{self.table}.{str(self.primary_key)}" + def __repr__(self): + return self.str() + #Node = namedtuple("Node", ["table", "primary_key"]) diff --git a/plot_graph.py b/plot_graph.py index 7962a7c..7f5fe7e 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -1,2 +1,7 @@ -def plot(graph): - raise NotImplementedError \ No newline at end of file +from plotly_functions import process_graph, plot_v1 +import networkx as nx +import plotly as ply + +def plot(graph:nx.Graph) -> ply.graph_objs.Figure: + fig = plot_v1(*process_graph(graph)) + return fig \ No newline at end of file diff --git a/plotly_functions.py b/plotly_functions.py index 5192135..55ee20b 100644 --- a/plotly_functions.py +++ b/plotly_functions.py @@ -1,5 +1,4 @@ import networkx as nx -import dash import pandas as pd import numpy as np @@ -8,6 +7,8 @@ from plotly import graph_objects as go +from graph import Node + def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame): edge_x = [] edge_y = [] @@ -25,11 +26,15 @@ def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame): def get_nodes_df(graph:nx.Graph) -> pd.DataFrame: - return pd.DataFrame([ + df = pd.DataFrame([ {'Node':node, 'X':x, 'Y':y} for (node, (x, y)) in nx.spring_layout(graph).items() ]).set_index('Node') + #df.index = [str(i) for i in df.index] + + return df + def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]: """Dicts, keyed by index (node) then column.""" @@ -52,6 +57,14 @@ def textify_additional_data( return out + +def process_graph(graph:nx.Graph) -> Tuple[pd.DataFrame, pd.DataFrame]: + """Turn graph obj into DF used by plotting function""" + nodes = get_nodes_df(graph) + edges = get_edges_df(graph, nodes) + + return nodes, edges + def plot_v1( nodes:pd.DataFrame, edges:pd.DataFrame @@ -70,7 +83,7 @@ def plot_v1( nodes_go = go.Scatter( x=nodes.X, y=nodes.Y, - ids=nodes.index, + ids=nodes.index.map(str), mode='markers', marker=node_fmt, ) @@ -99,8 +112,22 @@ def plot_v1( ) for k, row in nodes.iterrows(): - fig.add_annotation(text=row.AnnotationText, yanchor='bottom', bgcolor='lightgrey', x=row.X, y=row.Y, ax=row.X, - ay=row.Y) + annotation = f"{k}" + + if 'AnnotationText' in row.index: + more = row.AnnotationText + annotation = f"
{more}" + + fig.add_annotation( + text=annotation, + yanchor='bottom', + bgcolor='lightgrey', + x=row.X, + y=row.Y, + ax=row.X, + ay=row.Y + ) + return fig @@ -124,5 +151,7 @@ def basic_test(): f.show() + + if __name__ == '__main__': basic_test() \ No newline at end of file