diff --git a/bigraph_viz/__init__.py b/bigraph_viz/__init__.py index 80d6991..12991ad 100644 --- a/bigraph_viz/__init__.py +++ b/bigraph_viz/__init__.py @@ -1,3 +1,17 @@ +import pprint from bigraph_viz.plot import plot_bigraph, plot_flow, plot_multitimestep from bigraph_viz.dict_utils import pp, pf, schema_state_to_dict from bigraph_viz.convert import convert_vivarium_composite + + +pretty = pprint.PrettyPrinter(indent=2) + + +def pp(x): + """Print ``x`` in a pretty format.""" + pretty.pprint(x) + + +def pf(x): + """Format ``x`` for display.""" + return pretty.pformat(x) \ No newline at end of file diff --git a/bigraph_viz/diagram.py b/bigraph_viz/diagram.py new file mode 100644 index 0000000..fede0a8 --- /dev/null +++ b/bigraph_viz/diagram.py @@ -0,0 +1,423 @@ +""" +Bigraph diagram +""" +import os +from bigraph_schema import TypeSystem, Edge +from bigraph_viz.plot import absolute_path, make_label, check_if_path_in_removed_nodes +import graphviz + + +PROCESS_SCHEMA_KEYS = ['config', 'address', 'interval', 'inputs', 'outputs'] + + +step_type = { + '_type': 'step', + '_inherit': 'edge', + 'address': 'string', + 'config': 'schema'} + + +process_type = { + '_type': 'process', + '_inherit': 'step', + 'interval': 'float'} + + +def get_graph_wires(schema, wires, graph_dict, schema_key, edge_path, port): + + if isinstance(schema, dict) and schema: + for port, subschema in schema.items(): + subwire = wires.get(port) + if subwire: + graph_dict = get_graph_wires( + subschema, subwire, graph_dict, schema_key, edge_path, port) + else: # this is a disconnected port + graph_dict['disconnected_hyper_edges'].append({ + 'edge_path': edge_path, + 'port': port, + 'type': schema_key}) + elif isinstance(wires, dict): + for port, subwire in wires.items(): + subschema = schema.get(port, schema) + graph_dict = get_graph_wires( + subschema, subwire, graph_dict, schema_key, edge_path, port) + elif isinstance(wires, (list, tuple)): + target_path = absolute_path(edge_path[:-1], tuple(wires)) # TODO -- make sure this resolves ".." + graph_dict['hyper_edges'].append({ + 'edge_path': edge_path, + 'target_path': target_path, + 'port': port, + 'type': schema_key}) + else: + raise ValueError(f"Unexpected wire type: {wires}") + + return graph_dict + + +def get_graph_dict( + schema, + state, + core, + graph_dict=None, + path=None, + top_state=None, + retain_type_keys=False, + retain_process_keys=False, + remove_nodes=None, +): + path = path or () + top_state = top_state or state + remove_nodes = remove_nodes or [] + + # initialize bigraph + graph_dict = graph_dict or { + 'state_nodes': [], + 'process_nodes': [], + 'place_edges': [], + 'hyper_edges': [], + 'disconnected_hyper_edges': [], + 'bridges': [], + } + + for key, value in state.items(): + if key.startswith('_') and not retain_type_keys: + continue + + subpath = path + (key,) + if check_if_path_in_removed_nodes(subpath, remove_nodes): + # skip node if path in removed_nodes + continue + + node_spec = { + 'name': key, + 'path': subpath, + 'value': None, + 'type': None + } + + is_edge = core.check('edge', value) + if is_edge: # this is a process/edge node + if key in PROCESS_SCHEMA_KEYS and not retain_process_keys: + continue + + graph_dict['process_nodes'].append(node_spec) + + # this is an edge, get its inputs and outputs + input_wires = value.get('inputs', {}) + output_wires = value.get('outputs', {}) + input_schema = value.get('_inputs', {}) + output_schema = value.get('_outputs', {}) + + # get the input and output wires + graph_dict = get_graph_wires( + input_schema, input_wires, graph_dict, schema_key='inputs', edge_path=subpath, port=()) + graph_dict = get_graph_wires( + output_schema, output_wires, graph_dict, schema_key='outputs', edge_path=subpath, port=()) + + else: # this is a state node + if not isinstance(value, dict): # this is a leaf node + node_spec['value'] = value + node_spec['type'] = schema.get(key, {}).get('_type') + else: + # node_spec['value'] = str(value) + node_spec['type'] = schema.get(key, {}).get('_type') + graph_dict['state_nodes'].append(node_spec) + + if isinstance(value, dict): # get subgraph + if is_edge: + removed_process_schema_keys = [subpath + (schema_key,) for schema_key in PROCESS_SCHEMA_KEYS] + remove_nodes.extend(removed_process_schema_keys) + + graph_dict = get_graph_dict( + schema=schema.get(key, schema), + state=value, + core=core, + graph_dict=graph_dict, + path=subpath, + top_state=top_state, + remove_nodes=remove_nodes + ) + + # get the place edge + for node in value.keys(): + if node.startswith('_') and not retain_type_keys: + continue + + child_path = subpath + (node,) + if check_if_path_in_removed_nodes(child_path, remove_nodes): + continue + graph_dict['place_edges'].append({ + 'parent': subpath, + 'child': child_path}) + + return graph_dict + + +def get_graphviz_fig( + graph_dict, + label_margin='0.05', + node_label_size='12pt', + size='16,10', + rankdir='TB', + dpi='70', + show_values=False, + show_types=False, + port_labels=True, + port_label_size='10pt', +): + """make a graphviz figure from a graph_dict""" + node_names = [] + + # node specs + state_node_spec = { + 'shape': 'circle', 'penwidth': '2', 'margin': label_margin, 'fontsize': node_label_size} + process_node_spec = { + 'shape': 'box', 'penwidth': '2', 'constraint': 'false', 'margin': label_margin, 'fontsize': node_label_size} + hyper_edge_spec = { + 'style': 'dashed', 'penwidth': '1', 'arrowhead': 'dot', 'arrowsize': '0.5'} + input_edge_spec = { + 'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0'} + output_edge_spec = { + 'style': 'dashed', 'penwidth': '1', 'arrowhead': 'normal', 'arrowsize': '1.0', 'dir': 'back'} + + # initialize graph + graph = graphviz.Digraph(name='bigraph', engine='dot') + graph.attr(size=size, overlap='false', rankdir=rankdir, dpi=dpi) + + # state nodes + graph.attr('node', **state_node_spec) + for node in graph_dict['state_nodes']: + node_path = node['path'] + node_name = str(node_path) + node_names.append(node_name) + + # make the label + label = node_path[-1] + schema_label = None + if show_values: + if node.get('value'): + if not schema_label: + schema_label = '' + schema_label += f": {node['value']}" + if show_types: + if node.get('type'): + if not schema_label: + schema_label = '
' + schema_label += f"[{node['type']}]" + if schema_label: + label += schema_label + label = make_label(label) + graph.node(node_name, label=label) + + # process nodes + process_paths = [] + graph.attr('node', **process_node_spec) + for node in graph_dict['process_nodes']: + node_path = node['path'] + process_paths.append(node_path) + node_name = str(node_path) + node_names.append(node_name) + label = make_label(node_path[-1]) + graph.node(node_name, label=label) + + # place edges + graph.attr('edge', arrowhead='none', penwidth='2') + for edge in graph_dict['place_edges']: + graph.attr('edge', style='filled') + parent_node = str(edge['parent']) + child_node = str(edge['child']) + graph.edge(parent_node, child_node) + + # hyper edges + for edge in graph_dict['hyper_edges']: + process_path = edge['edge_path'] + process_name = str(process_path) + target_path = edge['target_path'] + port = edge['port'] + edge_type = edge['type'] # input or output + target_name = str(target_path) + + # place it in the graph + if target_name not in graph.body: # is the source node already in the graph? + label = make_label(target_path[-1]) + graph.node(target_name, label=label, **state_node_spec) + + if edge_type == 'inputs': + graph.attr('edge', **input_edge_spec) + elif edge_type == 'outputs': + graph.attr('edge', **output_edge_spec) + else: + graph.attr('edge', **hyper_edge_spec) + with graph.subgraph(name=process_name) as c: + if port_labels: + label = make_label(port) + c.edge(target_name, process_name, label=label, labelloc="t", fontsize=port_label_size) + else: + c.edge(target_name, process_name) + + # disconnected hyper edges + graph.attr('edge', **hyper_edge_spec) + for edge in graph_dict['disconnected_hyper_edges']: + process_path = edge['edge_path'] + process_name = str(process_path) + port = edge['port'] + edge_type = edge['type'] # input or output + + # add invisible node for port + node_name2 = str(absolute_path(process_path, port)) + graph.node(node_name2, label='', style='invis', width='0') + + # add the edge + if edge_type == 'inputs': + graph.attr('edge', **input_edge_spec) + elif edge_type == 'outputs': + graph.attr('edge', **output_edge_spec) + else: + graph.attr('edge', **hyper_edge_spec) + with graph.subgraph(name=process_name) as c: + if port_labels: + label = make_label(port) + c.edge(node_name2, process_name, label=label, labelloc="t", fontsize=port_label_size + ) + else: + c.edge(node_name2, process_name) + + return graph + + +def plot_bigraph( + state, + schema=None, + core=None, + out_dir=None, + filename=None, + file_format='png', + size='16,10', + node_label_size='12pt', + show_values=False, + show_types=False, + port_labels=True, + port_label_size='10pt', + rankdir='TB', + print_source=False, + dpi='70', + label_margin='0.05', + # show_process_schema=False, + # collapse_processes=False, + # node_border_colors=None, + # node_fill_colors=None, + # node_groups=False, + remove_nodes=None, + # invisible_edges=False, + # mark_top=False, + # remove_process_place_edges=False, +): + # get kwargs dict and remove plotting-specific kwargs + kwargs = locals() + state = kwargs.pop('state') + schema = kwargs.pop('schema') + core = kwargs.pop('core') + file_format = kwargs.pop('file_format') + out_dir = kwargs.pop('out_dir') + filename = kwargs.pop('filename') + print_source = kwargs.pop('print_source') + remove_nodes = kwargs.pop('remove_nodes') + # show_process_schema = kwargs.pop('show_process_schema') + + # set defaults if none provided + core = core or TypeSystem() + schema = schema or {} + + if not core.exists('step'): + core.register('step', step_type) + if not core.exists('process'): + core.register('process', process_type) + + schema, state = core.complete(schema, state) + + # parse out the network + graph_dict = get_graph_dict( + schema=schema, + state=state, + core=core, + remove_nodes=remove_nodes, + ) + + # make a figure + graph = get_graphviz_fig(graph_dict, **kwargs) + + # display or save results + if print_source: + print(graph.source) + if filename is not None: + out_dir = out_dir or 'out' + os.makedirs(out_dir, exist_ok=True) + fig_path = os.path.join(out_dir, filename) + print(f"Writing {fig_path}") + graph.render(filename=fig_path, format=file_format) + return graph + + +def test_diagram_plot(): + cell = { + 'config': { + '_type': 'map[float]', + 'a': 11.0, #{'_type': 'float', '_value': 11.0}, + 'b': 3333.33}, + 'cell': { + '_type': 'process', # TODO -- this should also accept process, step, but how in bigraph-schema? + 'config': {}, + 'address': 'local:cell', # TODO -- this is where the ports/inputs/outputs come from + 'internal': 1.0, + '_inputs': { + 'nutrients': 'float', + }, + '_outputs': { + 'secretions': 'float', + 'biomass': 'float', + }, + 'inputs': { + 'nutrients': ['down', 'nutrients_store'], + }, + 'outputs': { + # 'secretions': ['secretions_store'], + 'biomass': ['biomass_store'], + } + } + } + plot_bigraph(cell, filename='bigraph_cell', + show_values=True, + show_types=True, + # port_labels=False, + # rankdir='BT', + # remove_nodes=[ + # ('cell', 'address',), + # ('cell', 'config'), + # ('cell', 'interval'), + # ] + ) + +def test_bio_schema(): + b = { + 'environment': { + 'cells': {}, + 'fields': {}, + 'barriers': {}, + 'diffusion': { + '_type': 'process', + # '_inputs': { + # 'fields': 'array' + # }, + 'inputs': { + 'fields': ['fields',] + } + } + }} + + plot_bigraph(b, filename='bioschema') + + + +if __name__ == '__main__': + # test_diagram_plot() + test_bio_schema() diff --git a/bigraph_viz/plot.py b/bigraph_viz/plot.py index 7ed55be..5c8b086 100644 --- a/bigraph_viz/plot.py +++ b/bigraph_viz/plot.py @@ -67,7 +67,7 @@ def get_bigraph_network( for key, child in bigraph_dict.items(): # skip process schema - if not show_process_schema and key in process_schema_keys: + if not show_process_schema and key in process_schema_keys: continue if key in schema_keys: continue diff --git a/bigraph_viz/test.py b/bigraph_viz/test.py index bd23168..088bf99 100644 --- a/bigraph_viz/test.py +++ b/bigraph_viz/test.py @@ -47,7 +47,7 @@ def test_composite_spec(): '_type': 'int' }, 'process1': { - '_ports': { + '_inputs': { 'port1': {'_type': 'type'}, 'port2': {'_type': 'type'}, },