diff --git a/graph.py b/graph.py index 35efdcd..b37ef37 100644 --- a/graph.py +++ b/graph.py @@ -14,13 +14,24 @@ class Node: table:str primary_key:typing.Any + data:dict[str, typing.Any] = None + def str(self): + """table.primary_key""" return f"{self.table}.{str(self.primary_key)}" + def str_data(self, max_length=25): + """Convert addtional data to string for plotly, using
for newlines.""" + s = '
'.join([f"{k}:{str(v)[:max_length]}" + for k, v in self.data.items()]) + return s + + def __repr__(self): return self.str() -#Node = namedtuple("Node", ["table", "primary_key"]) + def __str__(self): + return self.str() def get_graph(engine, table, primary_key): diff --git a/plot_graph.py b/plot_graph.py index 7f5fe7e..3c18e73 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -1,7 +1,7 @@ -from plotly_functions import process_graph, plot_v1 +from plotly_functions import plot_v2 import networkx as nx import plotly as ply def plot(graph:nx.Graph) -> ply.graph_objs.Figure: - fig = plot_v1(*process_graph(graph)) + fig = plot_v2(graph) return fig \ No newline at end of file diff --git a/plotly_functions.py b/plotly_functions.py index 55ee20b..c20173a 100644 --- a/plotly_functions.py +++ b/plotly_functions.py @@ -1,39 +1,54 @@ +import dataclasses + import networkx as nx import pandas as pd import numpy as np -from typing import Tuple, Any, Union, Collection, Mapping +from typing import Tuple, Any, Union, Collection, Mapping, NamedTuple, NewType from plotly import graph_objects as go from graph import Node -def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame): +# @dataclasses.dataclass +# class NodeCollection: +# nodes:list[Node] +# + +class XYValues(NamedTuple): + x: list[float | None] + y: list[float | None] + +NodeLayout = NewType('NodeLayout', dict[Node, tuple[float, float]]) + +def get_edge_xy( + graph:nx.Graph, + layout:NodeLayout +) -> XYValues: edge_x = [] edge_y = [] for edge in graph.edges(): - x0, y0 = node_xy.loc[edge[0], ['X', 'Y']] - x1, y1 = node_xy.loc[edge[1], ['X', 'Y']] + x0, y0 = layout[edge[0]] + x1, y1 = layout[edge[1]] edge_x.append(x0) edge_x.append(x1) edge_x.append(None) edge_y.append(y0) edge_y.append(y1) edge_y.append(None) - df = pd.DataFrame(dict(X=edge_x, Y=edge_y)) - return df + return XYValues(x=edge_x, y=edge_y) -def get_nodes_df(graph:nx.Graph) -> pd.DataFrame: - df = pd.DataFrame([ - {'Node':node, 'X':x, 'Y':y} - for (node, (x, y)) in nx.spring_layout(graph).items() - ]).set_index('Node') - #df.index = [str(i) for i in df.index] +def get_nodes_xy(layout:NodeLayout) -> XYValues: + xs = [] + ys = [] + for (node, (x, y)) in layout.items(): + xs.append(x) + ys.append(y) - return df + return XYValues(x=xs, y=ys) def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]: @@ -41,35 +56,27 @@ def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]: return nodes_df.drop(['X', 'Y', 'AnnotationText'], axis=1, errors='ignore').to_dict('index') -def textify_additional_data( - data: dict[str, dict[str, Any]], - max_length:int=25, -) -> dict[str, str]: - """Take data in form of {nodeID:{dataKey:dataValue}} and return text string for annotation.""" - - out = {} - - for nodeid, values in data.items(): - #todo truncate long strings - s = '\n'.join([f"{k}:{v}" for k, v in values.items()]) - out[nodeid] = s - - return out +# def process_graph(graph:nx.Graph) -> Tuple[NodeLayout, XYValues, XYValues]: +# """Turn graph obj into DF used by plotting function""" +# +# +# return layout, node_xy, edge_xy +def plot_v2( + graph +): + """With graph object of table/row relationships, + plot those as a network graph""" -def process_graph(graph:nx.Graph) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Turn graph obj into DF used by plotting function""" - nodes = get_nodes_df(graph) - edges = get_edges_df(graph, nodes) - - return nodes, edges + #todo show hide additional data + #todo colour by table + #todo hide tables -def plot_v1( - nodes:pd.DataFrame, - edges:pd.DataFrame -): - """Plot some simple data.""" + layout = NodeLayout(nx.spring_layout(graph)) + nodes = list(layout.keys()) + node_xy = get_nodes_xy(layout) + edge_xy = get_edge_xy(graph, layout) node_fmt = dict( size=2, @@ -81,16 +88,16 @@ def plot_v1( ) nodes_go = go.Scatter( - x=nodes.X, - y=nodes.Y, - ids=nodes.index.map(str), + x=node_xy.x, + y=node_xy.y, + mode='markers', marker=node_fmt, ) edges_go = go.Scatter( - x=edges.X, - y=edges.Y, + x=edge_xy.x, + y=edge_xy.y, mode='lines' ) @@ -108,28 +115,95 @@ def plot_v1( showticklabels=False, showgrid=False, zeroline=False - ) + ), + showlegend=False, ) - for k, row in nodes.iterrows(): - annotation = f"{k}" + for node in nodes: + annotation = f"{node.str()}" - if 'AnnotationText' in row.index: - more = row.AnnotationText - annotation = f"
{more}" + if node.data is not None: + annotation += '
'+node.str_data() fig.add_annotation( text=annotation, yanchor='bottom', bgcolor='lightgrey', - x=row.X, - y=row.Y, - ax=row.X, - ay=row.Y + x=layout[node][0], + y=layout[node][1], + ax=layout[node][0], + ay=layout[node][1], ) return fig +# def plot_v1( +# nodes:pd.DataFrame, +# edges:pd.DataFrame +# ): +# """Plot some simple data.""" +# +# node_fmt = dict( +# size=2, +# color='white', +# line=dict( +# color='lightslategrey', +# width=2, +# ) +# ) +# +# print(nodes.x, nodes.y) +# +# nodes_go = go.Scatter( +# x=nodes.X, +# y=nodes.Y, +# ids=nodes.index.map(str), +# mode='markers', +# marker=node_fmt, +# ) +# +# edges_go = go.Scatter( +# x=edges.X, +# y=edges.Y, +# mode='lines' +# ) +# +# fig = go.Figure() +# fig.add_trace(edges_go) +# fig.add_trace(nodes_go) +# +# fig.update_layout( +# xaxis=dict( +# showticklabels=False, +# showgrid=False, +# zeroline=False +# ), +# yaxis=dict( +# showticklabels=False, +# showgrid=False, +# zeroline=False +# ) +# ) +# +# for k, row in nodes.iterrows(): +# annotation = f"{k}" +# +# if 'AnnotationText' in row.index: +# more = row.AnnotationText +# annotation = f"
{more}" +# +# fig.add_annotation( +# text=annotation, +# yanchor='bottom', +# bgcolor='lightgrey', +# x=row.X, +# y=row.Y, +# ax=row.X, +# ay=row.Y +# ) +# +# return fig + def basic_graph(data=(('A', 'B'), ('B', 'C'), ('C', 'A'))) -> nx.Graph: """Get a nx.Graph with some edges and nodes.""" @@ -140,14 +214,8 @@ def basic_graph(data=(('A', 'B'), ('B', 'C'), ('C', 'A'))) -> nx.Graph: def basic_test(): G = basic_graph() - nodes_df = get_nodes_df(G) - nodes_df.loc[:, 'Table'] = ['TableA', 'TableB', 'TableB'] - nodes_df.loc[:, 'OtherInfo'] = ['Something', 'Some longish text oadijwoidjaodjw ad', np.nan] - nodes_df.loc[:, 'AnnotationText'] = textify_additional_data(get_info_dicts(nodes_df)) - - edges_xy = get_edges_df(G, nodes_df) - f = plot_v1(nodes_df, edges_xy) + f = plot_v2(G) f.show()