diff --git a/graph.py b/graph.py
index fd98ec9..35efdcd 100644
--- a/graph.py
+++ b/graph.py
@@ -17,6 +17,9 @@ class Node:
def str(self):
return f"{self.table}.{str(self.primary_key)}"
+ def __repr__(self):
+ return self.str()
+
#Node = namedtuple("Node", ["table", "primary_key"])
diff --git a/plot_graph.py b/plot_graph.py
index 7962a7c..7f5fe7e 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -1,2 +1,7 @@
-def plot(graph):
- raise NotImplementedError
\ No newline at end of file
+from plotly_functions import process_graph, plot_v1
+import networkx as nx
+import plotly as ply
+
+def plot(graph:nx.Graph) -> ply.graph_objs.Figure:
+ fig = plot_v1(*process_graph(graph))
+ return fig
\ No newline at end of file
diff --git a/plotly_functions.py b/plotly_functions.py
index 5192135..55ee20b 100644
--- a/plotly_functions.py
+++ b/plotly_functions.py
@@ -1,5 +1,4 @@
import networkx as nx
-import dash
import pandas as pd
import numpy as np
@@ -8,6 +7,8 @@
from plotly import graph_objects as go
+from graph import Node
+
def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame):
edge_x = []
edge_y = []
@@ -25,11 +26,15 @@ def get_edges_df(graph:nx.Graph, node_xy:pd.DataFrame):
def get_nodes_df(graph:nx.Graph) -> pd.DataFrame:
- return pd.DataFrame([
+ df = pd.DataFrame([
{'Node':node, 'X':x, 'Y':y}
for (node, (x, y)) in nx.spring_layout(graph).items()
]).set_index('Node')
+ #df.index = [str(i) for i in df.index]
+
+ return df
+
def get_info_dicts(nodes_df: pd.DataFrame) -> dict[str, dict[str, Any]]:
"""Dicts, keyed by index (node) then column."""
@@ -52,6 +57,14 @@ def textify_additional_data(
return out
+
+def process_graph(graph:nx.Graph) -> Tuple[pd.DataFrame, pd.DataFrame]:
+ """Turn graph obj into DF used by plotting function"""
+ nodes = get_nodes_df(graph)
+ edges = get_edges_df(graph, nodes)
+
+ return nodes, edges
+
def plot_v1(
nodes:pd.DataFrame,
edges:pd.DataFrame
@@ -70,7 +83,7 @@ def plot_v1(
nodes_go = go.Scatter(
x=nodes.X,
y=nodes.Y,
- ids=nodes.index,
+ ids=nodes.index.map(str),
mode='markers',
marker=node_fmt,
)
@@ -99,8 +112,22 @@ def plot_v1(
)
for k, row in nodes.iterrows():
- fig.add_annotation(text=row.AnnotationText, yanchor='bottom', bgcolor='lightgrey', x=row.X, y=row.Y, ax=row.X,
- ay=row.Y)
+ annotation = f"{k}"
+
+ if 'AnnotationText' in row.index:
+ more = row.AnnotationText
+ annotation = f"
{more}"
+
+ fig.add_annotation(
+ text=annotation,
+ yanchor='bottom',
+ bgcolor='lightgrey',
+ x=row.X,
+ y=row.Y,
+ ax=row.X,
+ ay=row.Y
+ )
+
return fig
@@ -124,5 +151,7 @@ def basic_test():
f.show()
+
+
if __name__ == '__main__':
basic_test()
\ No newline at end of file