diff --git a/graph.py b/graph.py
index 35efdcd..b37ef37 100644
--- a/graph.py
+++ b/graph.py
@@ -14,13 +14,24 @@ class Node:
table:str
primary_key:typing.Any
+ data:dict[str, typing.Any] = None
+
def str(self):
+ """table.primary_key"""
return f"{self.table}.{str(self.primary_key)}"
+ def str_data(self, max_length=25):
+ """Convert addtional data to string for plotly, using
for newlines."""
+ s = '
'.join([f"{k}:{str(v)[:max_length]}"
+ for k, v in self.data.items()])
+ return s
+
+
def __repr__(self):
return self.str()
-#Node = namedtuple("Node", ["table", "primary_key"])
+ def __str__(self):
+ return self.str()
def get_graph(engine, table, primary_key):
diff --git a/plot_graph.py b/plot_graph.py
index 7f5fe7e..3c18e73 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -1,7 +1,7 @@
-from plotly_functions import process_graph, plot_v1
+from plotly_functions import plot_v2
import networkx as nx
import plotly as ply
def plot(graph:nx.Graph) -> ply.graph_objs.Figure:
- fig = plot_v1(*process_graph(graph))
+ fig = plot_v2(graph)
return fig
\ No newline at end of file
diff --git a/plotly_functions.py b/plotly_functions.py
index 55ee20b..c20173a 100644
--- a/plotly_functions.py
+++ b/plotly_functions.py
@@ -1,39 +1,54 @@
+import dataclasses
+
import networkx as nx
import pandas as pd
import numpy as np
-from typing import Tuple, Any, Union, Collection, Mapping
+from typing import Tuple, Any, Union, Collection, Mapping, NamedTuple, NewType
from plotly import graph_objects as go
from graph import Node
-def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame):
+# @dataclasses.dataclass
+# class NodeCollection:
+# nodes:list[Node]
+#
+
+class XYValues(NamedTuple):
+ x: list[float | None]
+ y: list[float | None]
+
+NodeLayout = NewType('NodeLayout', dict[Node, tuple[float, float]])
+
+def get_edge_xy(
+ graph:nx.Graph,
+ layout:NodeLayout
+) -> XYValues:
edge_x = []
edge_y = []
for edge in graph.edges():
- x0, y0 = node_xy.loc[edge[0], ['X', 'Y']]
- x1, y1 = node_xy.loc[edge[1], ['X', 'Y']]
+ x0, y0 = layout[edge[0]]
+ x1, y1 = layout[edge[1]]
edge_x.append(x0)
edge_x.append(x1)
edge_x.append(None)
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(None)
- df = pd.DataFrame(dict(X=edge_x, Y=edge_y))
- return df
+ return XYValues(x=edge_x, y=edge_y)
-def get_nodes_df(graph:nx.Graph) -> pd.DataFrame:
- df = pd.DataFrame([
- {'Node':node, 'X':x, 'Y':y}
- for (node, (x, y)) in nx.spring_layout(graph).items()
- ]).set_index('Node')
- #df.index = [str(i) for i in df.index]
+def get_nodes_xy(layout:NodeLayout) -> XYValues:
+ xs = []
+ ys = []
+ for (node, (x, y)) in layout.items():
+ xs.append(x)
+ ys.append(y)
- return df
+ return XYValues(x=xs, y=ys)
def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]:
@@ -41,35 +56,27 @@ def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]:
return nodes_df.drop(['X', 'Y', 'AnnotationText'], axis=1, errors='ignore').to_dict('index')
-def textify_additional_data(
- data: dict[str, dict[str, Any]],
- max_length:int=25,
-) -> dict[str, str]:
- """Take data in form of {nodeID:{dataKey:dataValue}} and return text string for annotation."""
-
- out = {}
-
- for nodeid, values in data.items():
- #todo truncate long strings
- s = '\n'.join([f"{k}:{v}" for k, v in values.items()])
- out[nodeid] = s
-
- return out
+# def process_graph(graph:nx.Graph) -> Tuple[NodeLayout, XYValues, XYValues]:
+# """Turn graph obj into DF used by plotting function"""
+#
+#
+# return layout, node_xy, edge_xy
+def plot_v2(
+ graph
+):
+ """With graph object of table/row relationships,
+ plot those as a network graph"""
-def process_graph(graph:nx.Graph) -> Tuple[pd.DataFrame, pd.DataFrame]:
- """Turn graph obj into DF used by plotting function"""
- nodes = get_nodes_df(graph)
- edges = get_edges_df(graph, nodes)
-
- return nodes, edges
+ #todo show hide additional data
+ #todo colour by table
+ #todo hide tables
-def plot_v1(
- nodes:pd.DataFrame,
- edges:pd.DataFrame
-):
- """Plot some simple data."""
+ layout = NodeLayout(nx.spring_layout(graph))
+ nodes = list(layout.keys())
+ node_xy = get_nodes_xy(layout)
+ edge_xy = get_edge_xy(graph, layout)
node_fmt = dict(
size=2,
@@ -81,16 +88,16 @@ def plot_v1(
)
nodes_go = go.Scatter(
- x=nodes.X,
- y=nodes.Y,
- ids=nodes.index.map(str),
+ x=node_xy.x,
+ y=node_xy.y,
+
mode='markers',
marker=node_fmt,
)
edges_go = go.Scatter(
- x=edges.X,
- y=edges.Y,
+ x=edge_xy.x,
+ y=edge_xy.y,
mode='lines'
)
@@ -108,28 +115,95 @@ def plot_v1(
showticklabels=False,
showgrid=False,
zeroline=False
- )
+ ),
+ showlegend=False,
)
- for k, row in nodes.iterrows():
- annotation = f"{k}"
+ for node in nodes:
+ annotation = f"{node.str()}"
- if 'AnnotationText' in row.index:
- more = row.AnnotationText
- annotation = f"
{more}"
+ if node.data is not None:
+ annotation += '
'+node.str_data()
fig.add_annotation(
text=annotation,
yanchor='bottom',
bgcolor='lightgrey',
- x=row.X,
- y=row.Y,
- ax=row.X,
- ay=row.Y
+ x=layout[node][0],
+ y=layout[node][1],
+ ax=layout[node][0],
+ ay=layout[node][1],
)
return fig
+# def plot_v1(
+# nodes:pd.DataFrame,
+# edges:pd.DataFrame
+# ):
+# """Plot some simple data."""
+#
+# node_fmt = dict(
+# size=2,
+# color='white',
+# line=dict(
+# color='lightslategrey',
+# width=2,
+# )
+# )
+#
+# print(nodes.x, nodes.y)
+#
+# nodes_go = go.Scatter(
+# x=nodes.X,
+# y=nodes.Y,
+# ids=nodes.index.map(str),
+# mode='markers',
+# marker=node_fmt,
+# )
+#
+# edges_go = go.Scatter(
+# x=edges.X,
+# y=edges.Y,
+# mode='lines'
+# )
+#
+# fig = go.Figure()
+# fig.add_trace(edges_go)
+# fig.add_trace(nodes_go)
+#
+# fig.update_layout(
+# xaxis=dict(
+# showticklabels=False,
+# showgrid=False,
+# zeroline=False
+# ),
+# yaxis=dict(
+# showticklabels=False,
+# showgrid=False,
+# zeroline=False
+# )
+# )
+#
+# for k, row in nodes.iterrows():
+# annotation = f"{k}"
+#
+# if 'AnnotationText' in row.index:
+# more = row.AnnotationText
+# annotation = f"
{more}"
+#
+# fig.add_annotation(
+# text=annotation,
+# yanchor='bottom',
+# bgcolor='lightgrey',
+# x=row.X,
+# y=row.Y,
+# ax=row.X,
+# ay=row.Y
+# )
+#
+# return fig
+
def basic_graph(data=(('A', 'B'), ('B', 'C'), ('C', 'A'))) -> nx.Graph:
"""Get a nx.Graph with some edges and nodes."""
@@ -140,14 +214,8 @@ def basic_graph(data=(('A', 'B'), ('B', 'C'), ('C', 'A'))) -> nx.Graph:
def basic_test():
G = basic_graph()
- nodes_df = get_nodes_df(G)
- nodes_df.loc[:, 'Table'] = ['TableA', 'TableB', 'TableB']
- nodes_df.loc[:, 'OtherInfo'] = ['Something', 'Some longish text oadijwoidjaodjw ad', np.nan]
- nodes_df.loc[:, 'AnnotationText'] = textify_additional_data(get_info_dicts(nodes_df))
-
- edges_xy = get_edges_df(G, nodes_df)
- f = plot_v1(nodes_df, edges_xy)
+ f = plot_v2(G)
f.show()