Skip to content

Commit

Permalink
Merge pull request #11 from MrCurtis/connectingGraphPlot
Browse files Browse the repository at this point in the history
Process and plot nx.Graph with plot_graph.plot(graph)
  • Loading branch information
johncthomas authored Sep 12, 2023
2 parents e4c015b + e70a748 commit a154309
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
3 changes: 3 additions & 0 deletions graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class Node:
def str(self):
return f"{self.table}.{str(self.primary_key)}"

def __repr__(self):
return self.str()

#Node = namedtuple("Node", ["table", "primary_key"])


Expand Down
9 changes: 7 additions & 2 deletions plot_graph.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
def plot(graph):
raise NotImplementedError
from plotly_functions import process_graph, plot_v1
import networkx as nx
import plotly as ply

def plot(graph:nx.Graph) -> ply.graph_objs.Figure:
fig = plot_v1(*process_graph(graph))
return fig
39 changes: 34 additions & 5 deletions plotly_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import networkx as nx
import dash

import pandas as pd
import numpy as np
Expand All @@ -8,6 +7,8 @@

from plotly import graph_objects as go

from graph import Node

def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame):
edge_x = []
edge_y = []
Expand All @@ -25,11 +26,15 @@ def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame):


def get_nodes_df(graph:nx.Graph) -> pd.DataFrame:
return pd.DataFrame([
df = pd.DataFrame([
{'Node':node, 'X':x, 'Y':y}
for (node, (x, y)) in nx.spring_layout(graph).items()
]).set_index('Node')

#df.index = [str(i) for i in df.index]

return df


def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]:
"""Dicts, keyed by index (node) then column."""
Expand All @@ -52,6 +57,14 @@ def textify_additional_data(
return out



def process_graph(graph:nx.Graph) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Turn graph obj into DF used by plotting function"""
nodes = get_nodes_df(graph)
edges = get_edges_df(graph, nodes)

return nodes, edges

def plot_v1(
nodes:pd.DataFrame,
edges:pd.DataFrame
Expand All @@ -70,7 +83,7 @@ def plot_v1(
nodes_go = go.Scatter(
x=nodes.X,
y=nodes.Y,
ids=nodes.index,
ids=nodes.index.map(str),
mode='markers',
marker=node_fmt,
)
Expand Down Expand Up @@ -99,8 +112,22 @@ def plot_v1(
)

for k, row in nodes.iterrows():
fig.add_annotation(text=row.AnnotationText, yanchor='bottom', bgcolor='lightgrey', x=row.X, y=row.Y, ax=row.X,
ay=row.Y)
annotation = f"<b>{k}</b>"

if 'AnnotationText' in row.index:
more = row.AnnotationText
annotation = f"<br>{more}"

fig.add_annotation(
text=annotation,
yanchor='bottom',
bgcolor='lightgrey',
x=row.X,
y=row.Y,
ax=row.X,
ay=row.Y
)

return fig


Expand All @@ -124,5 +151,7 @@ def basic_test():
f.show()




if __name__ == '__main__':
basic_test()

0 comments on commit a154309

Please sign in to comment.