diff --git a/heracles/plot.py b/heracles/plot.py index 012651b..3afb14b 100644 --- a/heracles/plot.py +++ b/heracles/plot.py @@ -18,7 +18,7 @@ # License along with Heracles. If not, see . """utility functions for plotting""" -from collections import defaultdict +from collections import Counter, defaultdict from collections.abc import Mapping from itertools import chain, count, cycle @@ -65,6 +65,7 @@ def postage_stamps( hatch_empty=False, linscale=0.01, cycler=None, + group=True, ): """create a postage stamp plot for cls""" @@ -89,26 +90,30 @@ def postage_stamps( else: trkeys = {} - stamps = sorted( - set(key[-2:] for key in keys) | set(key[-2:][::-1] for key in trkeys), - ) - - sx = list(set(i for i, _ in stamps)) - sy = list(set(j for _, j in stamps)) + # either group the plots by last two indices or not + # uses Counter as a set that remembers order + if group: + si = list(Counter(key[-2] for key in keys)) + sj = list(Counter(key[-1] for key in keys)) + ti = list(Counter(key[-2] for key in trkeys)) + tj = list(Counter(key[-1] for key in trkeys)) + else: + si = sj = list(Counter(key[k::2] for key in keys for k in (0, 1))) + ti = tj = list(Counter(key[k::2] for key in trkeys for k in (0, 1))) - nx = len(sx) - ny = len(sy) + nx = max(len(si), len(tj)) + ny = max(len(sj), len(ti)) if trkeys: nx += trxshift ny += tryshift - figw = (ny + (ny - 1) * space) * stampsize - figh = (nx + (nx - 1) * space) * stampsize + figw = (nx + (nx - 1) * space) * stampsize + figh = (ny + (ny - 1) * space) * stampsize fig, axes = plt.subplots( - nx, ny, + nx, figsize=(figw, figh), squeeze=False, sharex=False, @@ -129,11 +134,19 @@ def postage_stamps( ki, kj, i, j = key if n < len(keys): - idx = (sx.index(j) + trxshift, sy.index(i)) + if group: + idx_y, idx_x = sj.index(j), si.index(i) + else: + idx_x, idx_y = sorted([si.index((ki, i)), sj.index((kj, j))]) + idx = (idx_y + tryshift, idx_x) cls = (x.get(key) for x in plot) axidx.add(idx) else: - idx = (sx.index(i), sy.index(j) + tryshift) + if group: + idx_y, idx_x = ti.index(i), tj.index(j) + else: + idx_y, idx_x = sorted([ti.index((ki, i)), tj.index((kj, j))]) + idx = (idx_y, idx_x + trxshift) cls = (x.get(key) for x in transpose) traxidx.add(idx) @@ -185,9 +198,9 @@ def postage_stamps( label = None xmin, xmax = _pad_xlim(xmin, xmax) - ymin, ymax = _pad_ylim(ymin, ymax) - ylin = 10 ** np.ceil(np.log10(max(abs(ymin), abs(ymax)) * linscale)) - + if keys: + ymin, ymax = _pad_ylim(ymin, ymax) + ylin = 10 ** np.ceil(np.log10(max(abs(ymin), abs(ymax)) * linscale)) if trkeys: trymin, trymax = _pad_ylim(trymin, trymax) trylin = 10 ** np.ceil(np.log10(max(abs(trymin), abs(trymax)) * linscale)) @@ -210,9 +223,9 @@ def postage_stamps( left=True, right=True, labeltop=(idx[0] == 0), - labelbottom=(idx[0] == nx - 1), + labelbottom=(idx[0] == ny - 1), labelleft=(idx[1] == 0), - labelright=(idx[1] == ny - 1), + labelright=(idx[1] == nx - 1), ) ax.set_axisbelow(False) diff --git a/tests/test_plot.py b/tests/test_plot.py index 7d8f8b7..43b57f5 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -21,14 +21,14 @@ def test_postage_stamps(): ("P", "P", 1, 1): cl, } - fig = postage_stamps(plot, transpose, trxshift=3, tryshift=2, hatch_empty=True) + fig = postage_stamps(plot, transpose, trxshift=2, tryshift=3, hatch_empty=True) assert len(fig.axes) == 5 * 4 axes = np.reshape(fig.axes, (5, 4)) - for i in range(5): # rows: 2 + trxshift - for j in range(4): # columns: 2 + tryshift + for i in range(5): # rows: 2 + tryshift + for j in range(4): # columns: 2 + trxshift lines = axes[i, j].get_lines() if i - j > 2: assert len(lines) == 2 # E, B in lower