Skip to content

Commit

Permalink
ENH(plot): shift transpose in x and y direction (#23)
Browse files Browse the repository at this point in the history
Adds separate `trxshift=` and `tryshift=` parameters to
`postage_stamps()` to shift the transpose along the x-axis and/or
y-axis.
  • Loading branch information
ntessore authored Sep 6, 2023
1 parent 624dcbe commit 6231aa7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
6 changes: 3 additions & 3 deletions examples/example.ipynb

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions heracles/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def postage_stamps(
transpose=None,
*,
scale=None,
shift_transpose=0,
trxshift=0,
tryshift=0,
stampsize=1.0,
hatch_empty=False,
linscale=0.01,
Expand Down Expand Up @@ -92,7 +93,8 @@ def postage_stamps(
ny = len(sy)

if trkeys:
ny += shift_transpose
nx += trxshift
ny += tryshift

fig, axes = plt.subplots(
nx,
Expand All @@ -117,11 +119,11 @@ def postage_stamps(
ki, kj, i, j = key

if n < len(keys):
idx = (sx.index(j), sy.index(i))
idx = (sx.index(j) + trxshift, sy.index(i))
cls = (x.get(key) for x in plot)
axidx.add(idx)
else:
idx = (sx.index(i), sy.index(j) + shift_transpose)
idx = (sx.index(i), sy.index(j) + tryshift)
cls = (x.get(key) for x in transpose)
traxidx.add(idx)

Expand Down
18 changes: 9 additions & 9 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ def test_postage_stamps():
("P", "P", 1, 1): cl,
}

fig = postage_stamps(plot, transpose, shift_transpose=2, hatch_empty=True)
fig = postage_stamps(plot, transpose, trxshift=3, tryshift=2, hatch_empty=True)

assert len(fig.axes) == 2 * 4
assert len(fig.axes) == 5 * 4

axes = np.reshape(fig.axes, (2, 4))
axes = np.reshape(fig.axes, (5, 4))

for i in range(2): # rows
for j in range(4): # columns: 2 + shift
for i in range(5): # rows: 2 + trxshift
for j in range(4): # columns: 2 + tryshift
lines = axes[i, j].get_lines()
if i >= j:
if i - j > 2:
assert len(lines) == 3 # E, B and axhline in lower
elif i + 1 == j:
assert len(lines) == 0 # empty diagonal
else:
elif i - j < -1:
assert len(lines) == 2 # P and axhline in upper
else:
assert len(lines) == 0 # empty diagonal

plt.close()

0 comments on commit 6231aa7

Please sign in to comment.