Skip to content

Commit

Permalink
ENH(plot): option to not group plots in postage stamps (#45)
Browse files Browse the repository at this point in the history
Add a new `group=` parameter to `postage_stamps()` that controls whether
or not plots are grouped by their bin indices. For `group=False`, all
combinations are separated into a triangular plot.

Also fixes a number of bugs and inconsistencies in the plotting routine.

Closes: #44
  • Loading branch information
ntessore authored Oct 9, 2023
1 parent bbb741c commit d1678cd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 22 deletions.
51 changes: 32 additions & 19 deletions heracles/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# License along with Heracles. If not, see <https://www.gnu.org/licenses/>.
"""utility functions for plotting"""

from collections import defaultdict
from collections import Counter, defaultdict
from collections.abc import Mapping
from itertools import chain, count, cycle

Expand Down Expand Up @@ -65,6 +65,7 @@ def postage_stamps(
hatch_empty=False,
linscale=0.01,
cycler=None,
group=True,
):
"""create a postage stamp plot for cls"""

Expand All @@ -89,26 +90,30 @@ def postage_stamps(
else:
trkeys = {}

stamps = sorted(
set(key[-2:] for key in keys) | set(key[-2:][::-1] for key in trkeys),
)

sx = list(set(i for i, _ in stamps))
sy = list(set(j for _, j in stamps))
# either group the plots by last two indices or not
# uses Counter as a set that remembers order
if group:
si = list(Counter(key[-2] for key in keys))
sj = list(Counter(key[-1] for key in keys))
ti = list(Counter(key[-2] for key in trkeys))
tj = list(Counter(key[-1] for key in trkeys))
else:
si = sj = list(Counter(key[k::2] for key in keys for k in (0, 1)))
ti = tj = list(Counter(key[k::2] for key in trkeys for k in (0, 1)))

nx = len(sx)
ny = len(sy)
nx = max(len(si), len(tj))
ny = max(len(sj), len(ti))

if trkeys:
nx += trxshift
ny += tryshift

figw = (ny + (ny - 1) * space) * stampsize
figh = (nx + (nx - 1) * space) * stampsize
figw = (nx + (nx - 1) * space) * stampsize
figh = (ny + (ny - 1) * space) * stampsize

fig, axes = plt.subplots(
nx,
ny,
nx,
figsize=(figw, figh),
squeeze=False,
sharex=False,
Expand All @@ -129,11 +134,19 @@ def postage_stamps(
ki, kj, i, j = key

if n < len(keys):
idx = (sx.index(j) + trxshift, sy.index(i))
if group:
idx_y, idx_x = sj.index(j), si.index(i)
else:
idx_x, idx_y = sorted([si.index((ki, i)), sj.index((kj, j))])
idx = (idx_y + tryshift, idx_x)
cls = (x.get(key) for x in plot)
axidx.add(idx)
else:
idx = (sx.index(i), sy.index(j) + tryshift)
if group:
idx_y, idx_x = ti.index(i), tj.index(j)
else:
idx_y, idx_x = sorted([ti.index((ki, i)), tj.index((kj, j))])
idx = (idx_y, idx_x + trxshift)
cls = (x.get(key) for x in transpose)
traxidx.add(idx)

Expand Down Expand Up @@ -185,9 +198,9 @@ def postage_stamps(
label = None

xmin, xmax = _pad_xlim(xmin, xmax)
ymin, ymax = _pad_ylim(ymin, ymax)
ylin = 10 ** np.ceil(np.log10(max(abs(ymin), abs(ymax)) * linscale))

if keys:
ymin, ymax = _pad_ylim(ymin, ymax)
ylin = 10 ** np.ceil(np.log10(max(abs(ymin), abs(ymax)) * linscale))
if trkeys:
trymin, trymax = _pad_ylim(trymin, trymax)
trylin = 10 ** np.ceil(np.log10(max(abs(trymin), abs(trymax)) * linscale))
Expand All @@ -210,9 +223,9 @@ def postage_stamps(
left=True,
right=True,
labeltop=(idx[0] == 0),
labelbottom=(idx[0] == nx - 1),
labelbottom=(idx[0] == ny - 1),
labelleft=(idx[1] == 0),
labelright=(idx[1] == ny - 1),
labelright=(idx[1] == nx - 1),
)
ax.set_axisbelow(False)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def test_postage_stamps():
("P", "P", 1, 1): cl,
}

fig = postage_stamps(plot, transpose, trxshift=3, tryshift=2, hatch_empty=True)
fig = postage_stamps(plot, transpose, trxshift=2, tryshift=3, hatch_empty=True)

assert len(fig.axes) == 5 * 4

axes = np.reshape(fig.axes, (5, 4))

for i in range(5): # rows: 2 + trxshift
for j in range(4): # columns: 2 + tryshift
for i in range(5): # rows: 2 + tryshift
for j in range(4): # columns: 2 + trxshift
lines = axes[i, j].get_lines()
if i - j > 2:
assert len(lines) == 2 # E, B in lower
Expand Down

0 comments on commit d1678cd

Please sign in to comment.