Skip to content

Commit

Permalink
plot graph version 2, lots of internal changes
Browse files Browse the repository at this point in the history
  • Loading branch information
johncthomas committed Sep 12, 2023
1 parent 4542892 commit 0d0ea6b
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 63 deletions.
13 changes: 12 additions & 1 deletion graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@ class Node:
table:str
primary_key:typing.Any

data:dict[str, typing.Any] = None

def str(self):
"""table.primary_key"""
return f"{self.table}.{str(self.primary_key)}"

def str_data(self, max_length=25):
"""Convert addtional data to string for plotly, using <br> for newlines."""
s = '<br>'.join([f"{k}:{str(v)[:max_length]}"
for k, v in self.data.items()])
return s


def __repr__(self):
return self.str()

#Node = namedtuple("Node", ["table", "primary_key"])
def __str__(self):
return self.str()


def get_graph(engine, table, primary_key):
Expand Down
4 changes: 2 additions & 2 deletions plot_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from plotly_functions import process_graph, plot_v1
from plotly_functions import plot_v2
import networkx as nx
import plotly as ply

def plot(graph:nx.Graph) -> ply.graph_objs.Figure:
fig = plot_v1(*process_graph(graph))
fig = plot_v2(graph)
return fig
188 changes: 128 additions & 60 deletions plotly_functions.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,82 @@
import dataclasses

import networkx as nx

import pandas as pd
import numpy as np

from typing import Tuple, Any, Union, Collection, Mapping
from typing import Tuple, Any, Union, Collection, Mapping, NamedTuple, NewType

from plotly import graph_objects as go

from graph import Node

def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame):
# @dataclasses.dataclass
# class NodeCollection:
# nodes:list[Node]
#

class XYValues(NamedTuple):
x: list[float | None]
y: list[float | None]

NodeLayout = NewType('NodeLayout', dict[Node, tuple[float, float]])

def get_edge_xy(
graph:nx.Graph,
layout:NodeLayout
) -> XYValues:
edge_x = []
edge_y = []
for edge in graph.edges():
x0, y0 = node_xy.loc[edge[0], ['X', 'Y']]
x1, y1 = node_xy.loc[edge[1], ['X', 'Y']]
x0, y0 = layout[edge[0]]
x1, y1 = layout[edge[1]]
edge_x.append(x0)
edge_x.append(x1)
edge_x.append(None)
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(None)
df = pd.DataFrame(dict(X=edge_x, Y=edge_y))
return df

return XYValues(x=edge_x, y=edge_y)

def get_nodes_df(graph:nx.Graph) -> pd.DataFrame:
df = pd.DataFrame([
{'Node':node, 'X':x, 'Y':y}
for (node, (x, y)) in nx.spring_layout(graph).items()
]).set_index('Node')

#df.index = [str(i) for i in df.index]
def get_nodes_xy(layout:NodeLayout) -> XYValues:
xs = []
ys = []
for (node, (x, y)) in layout.items():
xs.append(x)
ys.append(y)

return df
return XYValues(x=xs, y=ys)


def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]:
"""Dicts, keyed by index (node) then column."""
return nodes_df.drop(['X', 'Y', 'AnnotationText'], axis=1, errors='ignore').to_dict('index')


def textify_additional_data(
data: dict[str, dict[str, Any]],
max_length:int=25,
) -> dict[str, str]:
"""Take data in form of {nodeID:{dataKey:dataValue}} and return text string for annotation."""

out = {}

for nodeid, values in data.items():
#todo truncate long strings
s = '\n'.join([f"{k}:{v}" for k, v in values.items()])
out[nodeid] = s

return out
# def process_graph(graph:nx.Graph) -> Tuple[NodeLayout, XYValues, XYValues]:
# """Turn graph obj into DF used by plotting function"""
#
#
# return layout, node_xy, edge_xy


def plot_v2(
graph
):
"""With graph object of table/row relationships,
plot those as a network graph"""

def process_graph(graph:nx.Graph) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Turn graph obj into DF used by plotting function"""
nodes = get_nodes_df(graph)
edges = get_edges_df(graph, nodes)

return nodes, edges
#todo show hide additional data
#todo colour by table
#todo hide tables

def plot_v1(
nodes:pd.DataFrame,
edges:pd.DataFrame
):
"""Plot some simple data."""
layout = NodeLayout(nx.spring_layout(graph))
nodes = list(layout.keys())
node_xy = get_nodes_xy(layout)
edge_xy = get_edge_xy(graph, layout)

node_fmt = dict(
size=2,
Expand All @@ -81,16 +88,16 @@ def plot_v1(
)

nodes_go = go.Scatter(
x=nodes.X,
y=nodes.Y,
ids=nodes.index.map(str),
x=node_xy.x,
y=node_xy.y,

mode='markers',
marker=node_fmt,
)

edges_go = go.Scatter(
x=edges.X,
y=edges.Y,
x=edge_xy.x,
y=edge_xy.y,
mode='lines'
)

Expand All @@ -108,28 +115,95 @@ def plot_v1(
showticklabels=False,
showgrid=False,
zeroline=False
)
),
showlegend=False,
)

for k, row in nodes.iterrows():
annotation = f"<b>{k}</b>"
for node in nodes:
annotation = f"<b>{node.str()}</b>"

if 'AnnotationText' in row.index:
more = row.AnnotationText
annotation = f"<br>{more}"
if node.data is not None:
annotation += '<br>'+node.str_data()

fig.add_annotation(
text=annotation,
yanchor='bottom',
bgcolor='lightgrey',
x=row.X,
y=row.Y,
ax=row.X,
ay=row.Y
x=layout[node][0],
y=layout[node][1],
ax=layout[node][0],
ay=layout[node][1],
)

return fig

# def plot_v1(
# nodes:pd.DataFrame,
# edges:pd.DataFrame
# ):
# """Plot some simple data."""
#
# node_fmt = dict(
# size=2,
# color='white',
# line=dict(
# color='lightslategrey',
# width=2,
# )
# )
#
# print(nodes.x, nodes.y)
#
# nodes_go = go.Scatter(
# x=nodes.X,
# y=nodes.Y,
# ids=nodes.index.map(str),
# mode='markers',
# marker=node_fmt,
# )
#
# edges_go = go.Scatter(
# x=edges.X,
# y=edges.Y,
# mode='lines'
# )
#
# fig = go.Figure()
# fig.add_trace(edges_go)
# fig.add_trace(nodes_go)
#
# fig.update_layout(
# xaxis=dict(
# showticklabels=False,
# showgrid=False,
# zeroline=False
# ),
# yaxis=dict(
# showticklabels=False,
# showgrid=False,
# zeroline=False
# )
# )
#
# for k, row in nodes.iterrows():
# annotation = f"<b>{k}</b>"
#
# if 'AnnotationText' in row.index:
# more = row.AnnotationText
# annotation = f"<br>{more}"
#
# fig.add_annotation(
# text=annotation,
# yanchor='bottom',
# bgcolor='lightgrey',
# x=row.X,
# y=row.Y,
# ax=row.X,
# ay=row.Y
# )
#
# return fig


def basic_graph(data=(('A', 'B'), ('B', 'C'), ('C', 'A'))) -> nx.Graph:
"""Get a nx.Graph with some edges and nodes."""
Expand All @@ -140,14 +214,8 @@ def basic_graph(data=(('A', 'B'), ('B', 'C'), ('C', 'A'))) -> nx.Graph:

def basic_test():
G = basic_graph()
nodes_df = get_nodes_df(G)
nodes_df.loc[:, 'Table'] = ['TableA', 'TableB', 'TableB']
nodes_df.loc[:, 'OtherInfo'] = ['Something', 'Some longish text oadijwoidjaodjw ad', np.nan]
nodes_df.loc[:, 'AnnotationText'] = textify_additional_data(get_info_dicts(nodes_df))

edges_xy = get_edges_df(G, nodes_df)

f = plot_v1(nodes_df, edges_xy)
f = plot_v2(G)
f.show()


Expand Down

0 comments on commit 0d0ea6b

Please sign in to comment.