diff --git a/.gitignore b/.gitignore index f8261eca..745df54f 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ nosetests.xml coverage.xml *.cover .hypothesis/ +test-reports/ # Translations *.mo diff --git a/asap/em_montage_qc/detect_montage_defects.py b/asap/em_montage_qc/detect_montage_defects.py index 0bab75e6..82a1dc98 100644 --- a/asap/em_montage_qc/detect_montage_defects.py +++ b/asap/em_montage_qc/detect_montage_defects.py @@ -1,13 +1,15 @@ from functools import partial import time +import igraph import networkx as nx import numpy as np import renderapi +import renderapi.utils import requests -from rtree import index as rindex -from six import viewkeys from scipy.spatial import cKDTree +import shapely +import shapely.strtree from asap.residuals import compute_residuals as cr from asap.em_montage_qc.schemas import ( @@ -17,10 +19,10 @@ from asap.em_montage_qc.plots import plot_section_maps from asap.em_montage_qc.distorted_montages import ( - do_get_z_scales_nopm, - get_z_scales_nopm, + get_scales_from_tilespecs, + get_rts_fallthrough, get_scale_statistics_mad - ) +) example = { "render": { @@ -41,157 +43,277 @@ } -def detect_seams( - render, stack, match_collection, match_owner, z, - residual_threshold=8, distance=60, min_cluster_size=15, tspecs=None): - # seams will always be computed for montages using montage point matches - # but the input stack can be either montage, rough, or fine - # Compute residuals and other stats for this z - stats, allmatches = cr.compute_residuals_within_group( - render, stack, match_owner, match_collection, z, tilespecs=tspecs) +# TODO methods for tile borders/boundary points should go in render-python +def determine_numX_numY_rectangle(width, height, meshcellsize=32): + numX = max([2, np.around(width / meshcellsize)]) + numY = max([2, np.around(height / meshcellsize)]) + return int(numX), int(numY) + + +def determine_numX_numY_triangle(width, height, meshcellsize=32): + # TODO do we always want width to define the geometry? + numX = max([2, np.around(width / meshcellsize)]) + numY = max([ + 2, + np.around( + height / + ( + 2 * np.sqrt(3. / 4. * (width / (numX - 1)) ** 2) + ) + 1) + ]) + return int(numX), int(numY) + + +def generate_border_mesh_pts( + width, height, meshcellsize=64, mesh_type="square", **kwargs): + numfunc = { + "square": determine_numX_numY_rectangle, + "triangle": determine_numX_numY_triangle + }[mesh_type] + numX, numY = numfunc(width, height, meshcellsize) + + xs = np.linspace(0, width - 1, numX).reshape(-1, 1) + ys = np.linspace(0, height - 1, numY).reshape(-1, 1) + + perim = np.vstack([ + np.hstack([xs, np.zeros(xs.shape)]), + np.hstack([np.ones(ys.shape) * float(height - 1), ys])[1:-1], + np.hstack([xs, np.ones(xs.shape) * float(width - 1)])[::-1], + np.hstack([np.zeros(ys.shape), ys])[1:-1][::-1], + + ]) + return perim + + +def polygon_from_ts(ts, ref_tforms, **kwargs): + tsarr = generate_border_mesh_pts(ts.width, ts.height, **kwargs) + return shapely.geometry.Polygon( + renderapi.transform.estimate_dstpts( + ts.tforms, src=tsarr, reference_tforms=ref_tforms)) + + +def polygons_from_rts(rts, **kwargs): + return [ + polygon_from_ts(ts, rts.transforms, **kwargs) + for ts in rts.tilespecs + ] + + +def strtree_query_geometries(tree, q): + res = tree.query(q) + return tree.geometries[res] + + +def pair_clusters_networkx(pairs, min_cluster_size=25): + G = nx.Graph() + G.add_edges_from(pairs) + + # get the connected subraphs from G + Gc = nx.connected_components(G) + # get the list of nodes in each component + fnodes = sorted((list(n) for n in Gc if len(n) > min_cluster_size), + key=len, reverse=True) + return fnodes + + +def pair_clusters_igraph(pairs, min_cluster_size=25): + G_ig = igraph.Graph(edges=pairs, directed=False) + + # get the connected subraphs from G + cc_ig = G_ig.connected_components(mode='strong') + # filter nodes list with min_cluster_size + fnodes = sorted((c for c in cc_ig if len(c) > min_cluster_size), + key=len, reverse=True) + return fnodes + + +def detect_seams(tilespecs, matches, residual_threshold=10, + distance=80, min_cluster_size=25, cluster_method="igraph"): + stats, allmatches = cr.compute_residuals(tilespecs, matches) # get mean positions of the point matches as numpy array pt_match_positions = np.concatenate( - list(stats['pt_match_positions'].values()), - 0) + list(stats['pt_match_positions'].values()), 0 + ) # get the tile residuals - tile_residuals = np.concatenate(list(stats['tile_residuals'].values())) + tile_residuals = np.concatenate( + list(stats['tile_residuals'].values()) + ) # threshold the points based on residuals new_pts = pt_match_positions[ - np.where(tile_residuals >= residual_threshold), :][0] - - if len(new_pts) > 0: - # construct a KD Tree using these points - tree = cKDTree(new_pts) - # construct a networkx graph - G = nx.Graph() - # find the pairs of points within a distance to each other - pairs = tree.query_pairs(r=distance) - G.add_edges_from(pairs) - # get the connected subraphs from G - Gc = nx.connected_components(G) - # get the list of nodes in each component - nodes = sorted(Gc, key=len, reverse=True) - # filter nodes list with min_cluster_size - fnodes = [list(nn) for nn in nodes if len(nn) > min_cluster_size] - # get pts list for each filtered node list - points_list = [new_pts[mm, :] for mm in fnodes] - centroids = [[np.sum(pt[:, 0])/len(pt), np.sum(pt[:, 1])/len(pt)] - for pt in points_list] - else: - centroids = [] + np.where(tile_residuals >= residual_threshold), : + ][0] + + # construct a KD Tree using these points + tree = cKDTree(new_pts) + # construct a networkx graph + + # find the pairs of points within a distance to each other + pairs = tree.query_pairs(r=distance) + + fnodes = { + "igraph": pair_clusters_igraph, + "networkx": pair_clusters_networkx + }[cluster_method](pairs, min_cluster_size=min_cluster_size) + + # get pts list for each filtered node list + points_list = [new_pts[mm, :] for mm in fnodes] + centroids = np.array([(np.sum(pt, axis=0) / len(pt)).tolist() + for pt in points_list]) return centroids, allmatches, stats -def detect_disconnected_tiles(render, prestitched_stack, poststitched_stack, - z, pre_tilespecs=None, post_tilespecs=None): +def detect_seams_from_collections( + render, stack, match_collection, match_owner, z, + residual_threshold=8, distance=60, min_cluster_size=15, tspecs=None): + session = requests.session() + + groupId = render.run( + renderapi.stack.get_sectionId_for_z, stack, z, session=session) + allmatches = render.run( + renderapi.pointmatch.get_matches_within_group, + match_collection, + groupId, + owner=match_owner, + session=session) + if tspecs is None: + tspecs = render.run( + renderapi.tilespec.get_tile_specs_from_z, + stack, z, session=session) + session.close() + + return detect_seams( + tspecs, allmatches, + residual_threshold=residual_threshold, + distance=distance, + min_cluster_size=min_cluster_size) + + +def detect_disconnected_tiles(pre_tilespecs, post_tilespecs): + pre_tileIds = {ts.tileId for ts in pre_tilespecs} + post_tileIds = {ts.tileId for ts in post_tilespecs} + missing_tileIds = list(pre_tileIds - post_tileIds) + return missing_tileIds + + +def detect_disconnected_tiles_from_collections( + render, prestitched_stack, poststitched_stack, + z, pre_tilespecs=None, post_tilespecs=None): session = requests.session() # get the tilespecs for both prestitched_stack and poststitched_stack if pre_tilespecs is None: pre_tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - prestitched_stack, - z, - session=session) + renderapi.tilespec.get_tile_specs_from_z, + prestitched_stack, + z, + session=session) if post_tilespecs is None: post_tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - poststitched_stack, - z, - session=session) - # pre tile_ids - pre_tileIds = [] - pre_tileIds = [ts.tileId for ts in pre_tilespecs] - # post tile_ids - post_tileIds = [] - post_tileIds = [ts.tileId for ts in post_tilespecs] - missing_tileIds = list(set(pre_tileIds) - set(post_tileIds)) + renderapi.tilespec.get_tile_specs_from_z, + poststitched_stack, + z, + session=session) session.close() - return missing_tileIds + return detect_disconnected_tiles(pre_tilespecs, post_tilespecs) + + +def detect_stitching_gaps(pre_rts, post_rts, polygon_kwargs={}, use_bbox=False): + tId_to_pre_polys = { + ts.tileId: ( + polygon_from_ts( + ts, pre_rts.transforms, **polygon_kwargs) + if not use_bbox else shapely.geometry.box(*ts.bbox) + ) + for ts in pre_rts.tilespecs + } + + tId_to_post_polys = { + ts.tileId: ( + polygon_from_ts( + ts, post_rts.transforms, **polygon_kwargs) + if not use_bbox else shapely.geometry.box(*ts.bbox) + ) + for ts in post_rts.tilespecs + } + + poly_id_to_tId = { + id(p): tId for tId, p in + (i for l in ( + tId_to_pre_polys.items(), tId_to_post_polys.items()) + for i in l) + } + + pre_polys = [*tId_to_pre_polys.values()] + post_polys = [*tId_to_post_polys.values()] + + pre_tree = shapely.strtree.STRtree(pre_polys) + post_tree = shapely.strtree.STRtree(post_polys) + + pre_graph = nx.Graph({ + poly_id_to_tId[id(p)]: [ + poly_id_to_tId[id(r)] + for r in strtree_query_geometries(pre_tree, p) + ] + for p in pre_polys}) + post_graph = nx.Graph({ + poly_id_to_tId[id(p)]: [ + poly_id_to_tId[id(r)] + for r in strtree_query_geometries(post_tree, p) + ] + for p in post_polys}) + + diff_g = nx.Graph(pre_graph.edges - post_graph.edges) + gap_tiles = [n for n in diff_g.nodes() if diff_g.degree(n) > 0] + return gap_tiles -def detect_stitching_gaps(render, prestitched_stack, poststitched_stack, - z, pre_tilespecs=None, tilespecs=None): +def detect_stitching_gaps_legacy(render, prestitched_stack, poststitched_stack, + z, pre_tilespecs=None, tilespecs=None): session = requests.session() - # setup an rtree to find overlapping tiles - pre_ridx = rindex.Index() - # setup a graph to store overlapping tiles - G1 = nx.Graph() - # get the tilespecs for both prestitched_stack and poststitched_stack + if pre_tilespecs is None: pre_tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - prestitched_stack, - z, - session=session) + renderapi.tilespec.get_tile_specs_from_z, + prestitched_stack, + z, + session=session) if tilespecs is None: tilespecs = render.run( - renderapi.tilespec.get_tile_specs_from_z, - poststitched_stack, - z, - session=session) - # insert the prestitched_tilespecs into rtree - # with their bounding boxes to find overlaps - for i, ts in enumerate(pre_tilespecs): - pre_ridx.insert(i, ts.bbox) - - pre_tileIds = {} - for i, ts in enumerate(pre_tilespecs): - pre_tileIds[ts.tileId] = i - nodes = list(pre_ridx.intersection(ts.bbox)) - nodes.remove(i) - [G1.add_edge(i, node) for node in nodes] - # G1 contains the prestitched_stack tile)s and the degree - # of each node representing the number of tiles that overlap. - # This overlap count has to match in the poststitched_stack - G2 = nx.Graph() - post_ridx = rindex.Index() - tileId_to_ts = {ts.tileId: ts for ts in tilespecs} - shared_tileIds = viewkeys(tileId_to_ts) & viewkeys(pre_tileIds) - [post_ridx.insert(pre_tileIds[tId], tileId_to_ts[tId].bbox) - for tId in shared_tileIds] - for ts in tilespecs: - try: - i = pre_tileIds[ts.tileId] - except KeyError: - continue - nodes = list(post_ridx.intersection(ts.bbox)) - nodes.remove(i) - [G2.add_edge(i, node) for node in nodes] - # Now G1 and G2 have the same index for the same tileId - # comparing the degree of each node pre and post - # stitching should reveal stitching gaps - gap_tiles = [] - for n in G2.nodes(): - if G1.degree(n) > G2.degree(n): - tileId = list(pre_tileIds.keys())[ - list(pre_tileIds.values()).index(n)] - gap_tiles.append(tileId) + renderapi.tilespec.get_tile_specs_from_z, + poststitched_stack, + z, + session=session) + + gap_tiles = detect_stitching_gaps( + renderapi.resolvedtiles.ResolvedTiles(tilespecs=pre_tilespecs), + renderapi.resolvedtiles.ResolvedTiles(tilespecs=tilespecs), + use_bbox=True) session.close() return gap_tiles -def detect_distortion(render, poststitched_stack, zvalue, threshold_cutoff=[0.005, 0.005], pool_size=20): - #z_to_scales = {zvalue: do_get_z_scales_nopm(zvalue, [poststitched_stack], render)} - z_to_scales = {} - # check if any scale is None - #zs = [z for z, scales in z_to_scales.items() if scales is None] - #for z in zs: - # z_to_scales[z] = get_z_scales_nopm(z, [poststitched_stack], render) +def detect_distortion_tilespecs(tilespecs, zvalue, threshold_cutoff=[0.005, 0.005]): + scales = get_scales_from_tilespecs(tilespecs) + mad_stats = get_scale_statistics_mad(scales) + badzs_cutoff = ( + [zvalue] if ( + mad_stats[0] > threshold_cutoff[0] or + mad_stats[1] > threshold_cutoff[1] + ) + else []) + return badzs_cutoff - try: - z_to_scales[zvalue] = get_z_scales_nopm(zvalue, [poststitched_stack], render) - except Exception: - z_to_scales[zvalue] = None - # get the mad statistics - z_to_scalestats = {z: get_scale_statistics_mad(scales) for z, scales in z_to_scales.items() if scales is not None} +def detect_distortion( + render, poststitched_stack, zvalue, + threshold_cutoff=[0.005, 0.005], pool_size=20, tilespecs=None): + if tilespecs is None: + rts = get_rts_fallthrough([poststitched_stack], zvalue, render=render) + tilespecs = rts.tilespecs - # find zs that fall outside cutoff - badzs_cutoff = [z for z, s in z_to_scalestats.items() if s[0] > threshold_cutoff[0] or s[1] > threshold_cutoff[1]] - return badzs_cutoff + return detect_distortion_tilespecs(tilespecs, zvalue, threshold_cutoff) def get_pre_post_tspecs(render, prestitched_stack, poststitched_stack, z): @@ -216,18 +338,19 @@ def run_analysis( min_cluster_size, threshold_cutoff, z): pre_tspecs, post_tspecs = get_pre_post_tspecs( render, prestitched_stack, poststitched_stack, z) - disconnected_tiles = detect_disconnected_tiles( + disconnected_tiles = detect_disconnected_tiles_from_collections( render, prestitched_stack, poststitched_stack, z, pre_tspecs, post_tspecs) - gap_tiles = detect_stitching_gaps( + gap_tiles = detect_stitching_gaps_legacy( render, prestitched_stack, poststitched_stack, z, pre_tspecs, post_tspecs) - seam_centroids, matches, stats = detect_seams( - render, poststitched_stack, match_collection, match_collection_owner, + seam_centroids, matches, stats = detect_seams_from_collections( + render, poststitched_stack, match_collection, match_collection_owner, z, residual_threshold=residual_threshold, distance=neighbor_distance, min_cluster_size=min_cluster_size, tspecs=post_tspecs) distorted_zs = detect_distortion( - render, poststitched_stack, z, threshold_cutoff=threshold_cutoff) + render, poststitched_stack, z, threshold_cutoff=threshold_cutoff, + tilespecs=post_tspecs) return (disconnected_tiles, gap_tiles, seam_centroids, distorted_zs, post_tspecs, matches, stats) @@ -346,8 +469,8 @@ def run(self): 'gap_sections': gaps, 'seam_sections': seams, 'distorted_sections': distorted_zs, - 'seam_centroids': np.array(centroids, dtype=object)}) - print(self.output) + 'seam_centroids': np.array(centroids, dtype=object)}, + cls=renderapi.utils.RenderEncoder) # delete the stacks that were cloned if status1 == 'LOADING': self.render.run(renderapi.stack.delete_stack, new_prestitched) diff --git a/asap/em_montage_qc/distorted_montages.py b/asap/em_montage_qc/distorted_montages.py index 6346f325..a422f142 100644 --- a/asap/em_montage_qc/distorted_montages.py +++ b/asap/em_montage_qc/distorted_montages.py @@ -90,6 +90,11 @@ def groupId_from_tilespec(ts): def sections_from_resolvedtiles(rts): return list({groupId_from_tilespec(ts) for ts in rts.tilespecs}) + +def get_scales_from_tilespecs(tilespecs): + return np.array([ts.tforms[-1].scale for ts in tilespecs]) + + def get_z_scales_nopm(z, input_stacks, render): rts = get_rts_fallthrough(input_stacks, z, render=render) scales = np.array([ts.tforms[-1].scale for ts in rts.tilespecs]) diff --git a/asap/em_montage_qc/plots.py b/asap/em_montage_qc/plots.py index e80dfc2c..146adfbb 100644 --- a/asap/em_montage_qc/plots.py +++ b/asap/em_montage_qc/plots.py @@ -7,13 +7,13 @@ import renderapi from bokeh.palettes import Plasma256, Viridis256 -from bokeh.plotting import figure, output_file, save +from bokeh.plotting import figure, save from bokeh.layouts import row -from bokeh.models.widgets import Tabs, Panel from bokeh.models import (HoverTool, ColumnDataSource, CustomJS, CategoricalColorMapper, LinearColorMapper, - TapTool, OpenURL, Div, ColorBar) + TapTool, OpenURL, Div, ColorBar, + Tabs, TabPanel) from asap.residuals import compute_residuals as cr @@ -33,13 +33,29 @@ xrange = range +def bbox_tup_to_xs_ys(bbox_tup): + min_x, min_y, max_x, max_y = bbox_tup + return ( + [min_x, max_x, max_x, min_x, min_x], + [min_y, min_y, max_y, max_y, min_y] + ) + + +def tilespecs_to_xs_ys(tilespecs): + return zip(*(bbox_tup_to_xs_ys(ts.bbox) for ts in tilespecs)) + + def point_match_plot(tilespecsA, matches, tilespecsB=None): if tilespecsB is None: tilespecsB = tilespecsA if len(matches) > 0: - x1, y1, id1 = cr.get_tile_centers(tilespecsA) - x2, y2, id2 = cr.get_tile_centers(tilespecsB) + tId_to_ctr_a = { + idx: (x, y) + for x, y, idx in zip(*cr.get_tile_centers(tilespecsA))} + tId_to_ctr_b = { + idx: (x, y) + for x, y, idx in zip(*cr.get_tile_centers(tilespecsB))} xs = [] ys = [] @@ -54,20 +70,25 @@ def point_match_plot(tilespecsA, matches, tilespecsB=None): xs.append([0.1, 0]) ys.append([0.1, 0.1]) - if (set(x1) != set(x2)): + if (set(tId_to_ctr_a.keys()) != set(tId_to_ctr_b.keys())): clist.append(500) else: clist.append(200) - for k in np.arange(len(matches)): - t1 = np.argwhere(id1 == matches[k]['qId']).flatten() - t2 = np.argwhere(id2 == matches[k]['pId']).flatten() - if (t1.size != 0) & (t2.size != 0): - t1 = t1[0] - t2 = t2[0] - xs.append([x1[t1], x2[t2]]) - ys.append([y1[t1], y2[t2]]) - clist.append(len(matches[k]['matches']['q'][0])) + for m in matches: + match_qId = m['qId'] + match_pId = m['pId'] + num_pts = len(m['matches']['q'][0]) + + try: + a_ctr = tId_to_ctr_a[match_qId] + b_ctr = tId_to_ctr_b[match_pId] + except KeyError: + continue + + xs.append([a_ctr[0], b_ctr[0]]) + ys.append([a_ctr[1], b_ctr[1]]) + clist.append(num_pts) mapper = LinearColorMapper( palette=Plasma256, low=min(clist), high=max(clist)) @@ -77,9 +98,24 @@ def point_match_plot(tilespecsA, matches, tilespecsB=None): TOOLS = "pan,box_zoom,reset,hover,save" + w = np.ptp(xs) + h = np.ptp(ys) + base_dim = 1000 + if w > h: + h = int(np.round(base_dim * (h / w))) + w = base_dim + else: + h = base_dim + w = int(np.round(base_dim * w / h)) + plot = figure( - plot_width=800, plot_height=700, - background_fill_color='gray', tools=TOOLS) + # plot_width=800, plot_height=700, + width=w, height=h, + background_fill_color='gray', tools=TOOLS, + # sizing_mode="stretch_both", + match_aspect=True + ) + plot.tools[1].match_aspect = True plot.multi_line( xs="xs", ys="ys", source=source, color={'field': 'colors', 'transform': mapper}, line_width=2) @@ -88,14 +124,24 @@ def point_match_plot(tilespecsA, matches, tilespecsB=None): plot.ygrid.visible = False else: - plot = figure(plot_width=800, plot_height=700, - background_fill_color='gray') + plot = figure( + # width=800, height=700, + background_fill_color='gray') return plot def plot_residual(xs, ys, residual): - p = figure(width=1000, height=1000) + w = np.ptp(xs) + h = np.ptp(ys) + base_dim = 1000 + if w > h: + h = int(np.round(base_dim * (h / w))) + w = base_dim + else: + h = base_dim + w = int(np.round(base_dim * w / h)) + p = figure(width=w, height=h) color_mapper = LinearColorMapper( palette=Viridis256, low=min(residual), high=max(residual)) @@ -107,7 +153,7 @@ def plot_residual(xs, ys, residual): fill_alpha=1.0, line_color="black", line_width=0.05) color_bar = ColorBar(color_mapper=color_mapper, label_standoff=12, - border_line_color=None, location=(0,0)) + border_line_color=None, location=(0, 0)) p.add_layout(color_bar, 'right') @@ -117,68 +163,37 @@ def plot_residual(xs, ys, residual): return p -def plot_defects(render, stack, out_html_dir, args): - tspecs = args[0] - matches = args[1] - dis_tiles = args[2] - gap_tiles = args[3] - seam_centroids = np.array(args[4]) - stats = args[5] - z = args[6] - - # Tile residual mean +def plot_residual_tilespecs( + tspecs, tileId_to_point_residuals, default_residual=50): tile_residual_mean = cr.compute_mean_tile_residuals( - stats['tile_residuals']) - - tile_positions = [] - tile_ids = [] - residual = [] - for ts in tspecs: - tile_ids.append(ts.tileId) - pts = [] - pts.append([ts.minX, ts.minY]) - pts.append([ts.maxX, ts.minY]) - pts.append([ts.maxX, ts.maxY]) - pts.append([ts.minX, ts.maxY]) - pts.append([ts.minX, ts.minY]) - tile_positions.append(pts) - - try: - residual.append(tile_residual_mean[ts.tileId]) - except KeyError: - residual.append(50) # a high value for residual for that tile + tileId_to_point_residuals) + xs, ys = tilespecs_to_xs_ys(tspecs) + residual = [ + tile_residual_mean.get(ts.tileId, default_residual) + for ts in tspecs + ] - out_html = os.path.join( - out_html_dir, - "%s_%d_%s.html" % ( - stack, - z, - datetime.datetime.now().strftime('%Y%m%d%H%S%M%f'))) - - output_file(out_html) - xs = [] - ys = [] - alphas = [] - for tp in tile_positions: - sp = np.array(tp) - x = list(sp[:, 0]) - y = list(sp[:, 1]) - xs.append(x) - ys.append(y) - alphas.append(0.5) + return plot_residual(xs, ys, residual) + + +def montage_defect_plot( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=None): + xs, ys = tilespecs_to_xs_ys(tspecs) + alphas = [0.5] * len(xs) fill_color = [] label = [] - for t in tile_ids: - if t in gap_tiles: - label.append("Gap tiles") - fill_color.append("red") - elif t in dis_tiles: - label.append("Disconnected tiles") - fill_color.append("yellow") - else: - label.append("Stitched tiles") - fill_color.append("blue") + gap_tiles = set(gap_tiles) + disconnected_tiles = set(disconnected_tiles) + + tile_ids = [ts.tileId for ts in tspecs] + + label, fill_color = zip(*( + (("Gap tiles", "red") if tileId in gap_tiles + else ("Disconnected tiles", "yellow") if tileId in disconnected_tiles + else ("Stitched tiles", "blue")) + for tileId in tile_ids)) color_mapper = CategoricalColorMapper( factors=['Gap tiles', 'Disconnected tiles', 'Stitched tiles'], @@ -194,13 +209,28 @@ def plot_defects(render, stack, out_html_dir, args): TOOLS = "pan,box_zoom,reset,hover,tap,save" - p = figure(title=str(z), width=1000, height=1000, - tools=TOOLS, match_aspect=True) + w = np.ptp(xs) + h = np.ptp(ys) + base_dim = 1000 + if w > h: + h = int(np.round(base_dim * (h / w))) + w = base_dim + else: + h = base_dim + w = int(np.round(base_dim * w / h)) + + p = figure(title=str(z), + # width=1000, height=1000, + width=w, height=h, + tools=TOOLS, + match_aspect=True) + p.tools[1].match_aspect = True pp = p.patches( 'x', 'y', source=source, alpha='alpha', line_width=2, - color={'field': 'labels', 'transform': color_mapper}, legend='labels') - cp = p.circle('x', 'y', source=seam_source, legend='lbl', size=11) + color={'field': 'labels', 'transform': color_mapper}, + legend_group='labels') + cp = p.scatter('x', 'y', source=seam_source, legend_group='lbl', size=11) jscode = """ var inds = cb_obj.selected['1d'].indices; @@ -211,38 +241,74 @@ def plot_defects(render, stack, out_html_dir, args): if ( lines.length > 35 ) { lines.shift(); } div.text = lines.join("\\n"); """ - div = Div(width=1000) + div = Div(width=w) layout = row(p, div) - urls = "%s:%d/render-ws/v1/owner/%s/project/%s/stack/%s/tile/@names/withNeighbors/jpeg-image?scale=0.1" % (render.DEFAULT_HOST, render.DEFAULT_PORT, render.DEFAULT_OWNER, render.DEFAULT_PROJECT, stack) - - taptool = p.select(type=TapTool) - taptool.renderers = [pp] - taptool.callback = OpenURL(url=urls) + if tile_url_format: + taptool = p.select(type=TapTool) + taptool.renderers = [pp] + taptool.callback = OpenURL(url=tile_url_format) hover = p.select(dict(type=HoverTool)) hover.renderers = [pp] hover.point_policy = "follow_mouse" hover.tooltips = [("tileId", "@names"), ("x", "$x{int}"), ("y", "$y{int}")] - source.callback = CustomJS(args=dict(div=div), code=jscode % ('names')) + source.js_event_callbacks['identify'] = [ + CustomJS(args=dict(div=div), code=jscode % ('names')) + ] + return layout + + +def create_montage_qc_plots( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=None): + # montage qc + layout = montage_defect_plot( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=None) # add point match plot in another tab plot = point_match_plot(tspecs, matches) # montage statistics plots in other tabs - - stat_layout = plot_residual(xs, ys, residual) + stat_layout = plot_residual_tilespecs(tspecs, stats["tile_residuals"]) tabs = [] - tabs.append(Panel(child=layout, title="Defects")) - tabs.append(Panel(child=plot, title="Point match plot")) - tabs.append(Panel(child=stat_layout, title="Mean tile residual")) + tabs.append(TabPanel(child=layout, title="Defects")) + tabs.append(TabPanel(child=plot, title="Point match plot")) + tabs.append(TabPanel(child=stat_layout, title="Mean tile residual")) plot_tabs = Tabs(tabs=tabs) - save(plot_tabs) + return plot_tabs + + +def write_montage_qc_plots(out_fn, plt): + return save(plt, out_fn) + +def run_montage_qc_plots_legacy(render, stack, out_html_dir, args): + tspecs = args[0] + matches = args[1] + disconnected_tiles = args[2] + gap_tiles = args[3] + seam_centroids = np.array(args[4]) + stats = args[5] + z = args[6] + + out_html = os.path.join( + out_html_dir, + "%s_%d_%s.html" % ( + stack, + z, + datetime.datetime.now().strftime('%Y%m%d%H%S%M%f'))) + tile_url_format = f"{render.DEFAULT_HOST}:{render.DEFAULT_PORT}/render-ws/v1/owner/{render.DEFAULT_OWNER}/project/{render.DEFAULT_PROJECT}/stack/{stack}/tile/@names/withNeighbors/jpeg-image?scale=0.1" + + qc_plot = create_montage_qc_plots( + tspecs, matches, disconnected_tiles, gap_tiles, + seam_centroids, stats, z, tile_url_format=tile_url_format) + write_montage_qc_plots(out_html, qc_plot) return out_html @@ -253,7 +319,9 @@ def plot_section_maps( if out_html_dir is None: out_html_dir = tempfile.mkdtemp() - mypartial = partial(plot_defects, render, stack, out_html_dir) + mypartial = partial( + run_montage_qc_plots_legacy, render, stack, out_html_dir + ) args = zip(post_tspecs, matches, disconnected_tiles, gap_tiles, seam_centroids, stats, zvalues) diff --git a/asap/residuals/compute_residuals.py b/asap/residuals/compute_residuals.py index 5201eb33..f5ce9d1d 100644 --- a/asap/residuals/compute_residuals.py +++ b/asap/residuals/compute_residuals.py @@ -3,48 +3,33 @@ import renderapi -def compute_residuals_within_group(render, stack, matchCollectionOwner, - matchCollection, z, min_points=1, - tilespecs=None): - session = requests.session() +def compute_residuals(tilespecs, matches, min_points=1, extra_statistics=None): + """from compute_residuals_in_group for in-memory""" + extra_statistics = extra_statistics or {} - # get the sectionID which is the group ID in point match collection - groupId = render.run( - renderapi.stack.get_sectionId_for_z, stack, z, session=session) + tId_to_tforms = {ts.tileId: ts.tforms for ts in tilespecs} - # get matches within the group for this section - allmatches = render.run( - renderapi.pointmatch.get_matches_within_group, - matchCollection, - groupId, - owner=matchCollectionOwner, - session=session) - - # get the tilespecs to extract the transformations - if tilespecs is None: - tilespecs = render.run(renderapi.tilespec.get_tile_specs_from_z, - stack, z, session=session) - tforms = {ts.tileId: ts.tforms for ts in tilespecs} - - tile_residuals = {key: np.empty((0, 1)) for key in tforms.keys()} - tile_rmse = {key: np.empty((0, 1)) for key in tforms.keys()} - pt_match_positions = {key: np.empty((0, 2)) for key in tforms.keys()} + tile_residuals = {key: np.empty((0, 1)) for key in tId_to_tforms.keys()} + tile_rmse = {key: np.empty((0, 1)) for key in tId_to_tforms.keys()} + pt_match_positions = {key: np.empty((0, 2)) for key in tId_to_tforms.keys()} statistics = {} - for i, match in enumerate(allmatches): + + for match in matches: pts_p = np.array(match['matches']['p']) pts_q = np.array(match['matches']['q']) - + if pts_p.shape[1] < min_points: continue + try: - t_p = tforms[match['pId']][-1].tform(pts_p.T) - t_q = tforms[match['qId']][-1].tform(pts_q.T) + t_p = tId_to_tforms[match['pId']][-1].tform(pts_p.T) + t_q = tId_to_tforms[match['qId']][-1].tform(pts_q.T) except KeyError: continue positions = (t_p + t_q) / 2. - + res = np.linalg.norm(t_p - t_q, axis=1) rmse = np.true_divide(res, res.shape[0]) @@ -55,38 +40,66 @@ def compute_residuals_within_group(render, stack, matchCollectionOwner, pt_match_positions[match['pId']] = np.append( pt_match_positions[match['pId']], positions, axis=0) - # remove empty entries from these dicts empty_keys = [k for k in tile_residuals if tile_residuals[k].size == 0] + for k in empty_keys: tile_residuals.pop(k) tile_rmse.pop(k) pt_match_positions.pop(k) statistics['tile_rmse'] = tile_rmse - statistics['z'] = z statistics['tile_residuals'] = tile_residuals statistics['pt_match_positions'] = pt_match_positions - session.close() + statistics = {**extra_statistics, **statistics} + + return statistics, matches - return statistics, allmatches +def compute_residuals_within_group(render, stack, matchCollectionOwner, + matchCollection, z, min_points=1, + tilespecs=None): + session = requests.session() -def compute_mean_tile_residuals(residuals): - tile_mean = {} + # get the sectionID which is the group ID in point match collection + groupId = render.run( + renderapi.stack.get_sectionId_for_z, stack, z, session=session) + + # get matches within the group for this section + allmatches = render.run( + renderapi.pointmatch.get_matches_within_group, + matchCollection, + groupId, + owner=matchCollectionOwner, + session=session) - # loop over each tile and compute the mean residual for each tile - # iteritems is specific to py2.7 - maxes = [np.nanmean(v) for v in residuals.values() if len(v) > 0] - maximum = np.max(maxes) + # get the tilespecs to extract the transformations + if tilespecs is None: + tilespecs = render.run(renderapi.tilespec.get_tile_specs_from_z, + stack, z, session=session) - for key in residuals: - if len(residuals[key]) == 0: - tile_mean[key] = maximum - else: - tile_mean[key] = np.nanmean(residuals[key]) + statistics, allmatches = compute_residuals( + tilespecs, allmatches, min_points=min_points, + extra_statistics={"z": z}) - return tile_mean + session.close() + + return statistics, allmatches + + +def compute_mean_tile_residuals(residuals): + tile_mean = { + tileId: (np.nanmean(tile_residuals) if tile_residuals.size else np.nan) + for tileId, tile_residuals in residuals.items() + } + tile_residual_max = np.nanmax(tile_mean.values()) + + return { + tileId: ( + tile_residual if not np.isnan(tile_residual) + else tile_residual_max) + for tileId, tile_residual in tile_mean.items() + } def get_tile_centers(tilespecs): diff --git a/integration_tests/test_em_montage_qc.py b/integration_tests/test_em_montage_qc.py index c4d9732c..3faca2a5 100644 --- a/integration_tests/test_em_montage_qc.py +++ b/integration_tests/test_em_montage_qc.py @@ -19,6 +19,7 @@ from asap.em_montage_qc.distorted_montages import DetectDistortedMontagesModule except: IMPORTS_ERRORED = True + raise @pytest.fixture(scope='module') @@ -152,9 +153,10 @@ def test_detect_montage_defects(render, # read the output json with open(ex['output_json'], 'r') as f: data = json.load(f) - f.close() assert(len(data['output_html']) > 0) + assert all([os.path.isfile(html_file) for html_file in data['output_html']]) + assert(len(data['seam_sections']) > 0) assert(len(data['hole_sections']) == 1) @@ -162,6 +164,7 @@ def test_detect_montage_defects(render, assert(len(data['distorted_sections']) == 0) assert(len(data['qc_passed_sections']) == 0) + for s in data['seam_centroids']: assert(len(s) > 0) diff --git a/pyproject.toml b/pyproject.toml index 92fbbc77..e4c5eedb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "mpld3", "descartes", "lxml", - "pymongo==3.11.1" + "pymongo==3.11.1", + "igraph" ] [project.optional-dependencies] @@ -79,6 +80,7 @@ scipy = "*" imageio = "*" matplotlib = "*" petsc4py = "*" +python-igraph = "*" # version-specific python features [tool.pixi.feature.py310.dependencies] diff --git a/requirements.txt b/requirements.txt index 26ec0b91..ad076f88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ pathlib2 scipy rtree networkx -bokeh<=1.4.0 +bokeh bigfeta opencv-contrib-python em_stitch @@ -24,3 +24,4 @@ seaborn mpld3 descartes lxml +igraph