Skip to content

Commit

Permalink
plotting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Dec 8, 2023
1 parent f5e31dd commit 0abd641
Showing 1 changed file with 149 additions and 60 deletions.
209 changes: 149 additions & 60 deletions py4DSTEM/process/phase/iterative_parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import PercentFormatter
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from py4DSTEM import Calibration, DataCube
from py4DSTEM.preprocess.utils import get_shifted_ar
Expand Down Expand Up @@ -1320,6 +1321,8 @@ def subpixel_alignment(

# Perform probe position correction if needed
if position_correction_num_iter is not None:
recon_BF_subpixel_aligned_reference = pix_output.copy()

# init position shift array
self._probe_dx = xp.zeros_like(xa_init)
self._probe_dy = xp.zeros_like(xa_init)
Expand Down Expand Up @@ -1523,66 +1526,113 @@ def subpixel_alignment(
)

position_correction_stats[a0 + 1] = scores.mean()

if plot_position_correction_convergence:
fig, ax = plt.subplots(figsize=(8, 2))
ax.plot(
np.arange(position_correction_num_iter + 1),
position_correction_stats,
color=(1, 0, 0),
)
ax.set_xlabel("iterations")
ax.set_ylabel("position error")
else:
plot_position_correction_convergence = False

self._recon_BF_subpixel_aligned = pix_output
self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned)

# plotting
if plot_upsampled_BF_comparison:
if plot_upsampled_FFT_comparison:
figsize = kwargs.pop("figsize", (8, 8))
fig, axs = plt.subplots(2, 2, figsize=figsize)
else:
figsize = kwargs.pop("figsize", (8, 4))
fig, axs = plt.subplots(1, 2, figsize=figsize)
nrows = np.count_nonzero(
np.array(
[
plot_upsampled_BF_comparison,
plot_upsampled_FFT_comparison,
plot_position_correction_convergence,
]
)
)
if nrows > 0:
ncols = 3 if position_correction_num_iter is not None else 2
height_ratios = (
[4, 4, 2][-nrows:]
if plot_position_correction_convergence
else [4, 4, 2][:nrows]
)
spec = GridSpec(
ncols=ncols, nrows=nrows, height_ratios=height_ratios, hspace=0.15
)

axs = axs.flat
figsize = kwargs.pop("figsize", (4 * ncols, sum(height_ratios)))
cmap = kwargs.pop("cmap", "magma")
fig = plt.figure(figsize=figsize)

cropped_object = self._crop_padded_object(self._recon_BF)
cropped_object_aligned = self._crop_padded_object(
self._recon_BF_subpixel_aligned, upsampled=True
)
row_index = 0

extent = [
0,
self._scan_sampling[1] * cropped_object.shape[1],
self._scan_sampling[0] * cropped_object.shape[0],
0,
]
if plot_upsampled_BF_comparison:
ax1 = fig.add_subplot(spec[row_index, 0])
ax2 = fig.add_subplot(spec[row_index, 1])

axs[0].imshow(
cropped_object,
extent=extent,
cmap=cmap,
**kwargs,
)
axs[0].set_title("Aligned Bright Field")
cropped_object = self._crop_padded_object(self._recon_BF)

axs[1].imshow(
cropped_object_aligned,
extent=extent,
cmap=cmap,
**kwargs,
)
axs[1].set_title("Upsampled Bright Field")
if ncols == 3:
ax3 = fig.add_subplot(spec[row_index, 2])

for ax in axs[:2]:
ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")
cropped_object_reference_aligned = self._crop_padded_object(
recon_BF_subpixel_aligned_reference, upsampled=True
)
cropped_object_aligned = self._crop_padded_object(
self._recon_BF_subpixel_aligned, upsampled=True
)
axs = [ax1, ax2, ax3]

else:
cropped_object_reference_aligned = self._crop_padded_object(
self._recon_BF_subpixel_aligned, upsampled=True
)
axs = [ax1, ax2]

extent = [
0,
self._scan_sampling[1] * cropped_object.shape[1],
self._scan_sampling[0] * cropped_object.shape[0],
0,
]

axs[0].imshow(
cropped_object,
extent=extent,
cmap=cmap,
**kwargs,
)
axs[0].set_title("Aligned Bright Field")

axs[1].imshow(
cropped_object_reference_aligned,
extent=extent,
cmap=cmap,
**kwargs,
)
axs[1].set_title("Upsampled Bright Field")

if ncols == 3:
axs[2].imshow(
cropped_object_aligned,
extent=extent,
cmap=cmap,
**kwargs,
)
axs[2].set_title("Probe-Corrected Bright Field")

for ax in axs:
ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")

row_index += 1

if plot_upsampled_FFT_comparison:
ax1 = fig.add_subplot(spec[row_index, 0])
ax2 = fig.add_subplot(spec[row_index, 1])

reciprocal_extent = [
-0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
-0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
]

recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF)))

pad_x = np.round(
BF_size[0] * (self._kde_upsample_factor - 1) / 2
).astype("int")
Expand All @@ -1593,43 +1643,82 @@ def subpixel_alignment(
xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y)))
)

upsampled_fft = asnumpy(
xp.fft.fftshift(
xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned))
if ncols == 3:
ax3 = fig.add_subplot(spec[row_index, 2])
upsampled_fft_reference = asnumpy(
xp.fft.fftshift(
xp.abs(xp.fft.fft2(recon_BF_subpixel_aligned_reference))
)
)
)

reciprocal_extent = [
-0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
-0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
]
upsampled_fft = asnumpy(
xp.fft.fftshift(
xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned))
)
)
axs = [ax1, ax2, ax3]
else:
upsampled_fft_reference = asnumpy(
xp.fft.fftshift(
xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned))
)
)
axs = [ax1, ax2]

show(
pad_recon_fft,
figax=(fig, axs[2]),
figax=(fig, axs[0]),
extent=reciprocal_extent,
cmap="gray",
title="Aligned Bright Field FFT",
**kwargs,
)

show(
upsampled_fft,
figax=(fig, axs[3]),
upsampled_fft_reference,
figax=(fig, axs[1]),
extent=reciprocal_extent,
cmap="gray",
title="Upsampled Bright Field FFT",
**kwargs,
)

for ax in axs[2:]:
if ncols == 3:
show(
upsampled_fft,
figax=(fig, axs[2]),
extent=reciprocal_extent,
cmap="gray",
title="Probe-Corrected Bright Field FFT",
**kwargs,
)

for ax in axs:
ax.set_ylabel(r"$k_x$ [$A^{-1}$]")
ax.set_xlabel(r"$k_y$ [$A^{-1}$]")
ax.xaxis.set_ticks_position("bottom")

fig.tight_layout()
row_index += 1

if plot_position_correction_convergence:
axs = fig.add_subplot(spec[row_index, :])

kwargs.pop("vmin", None)
kwargs.pop("vmax", None)
color = kwargs.pop("color", (1, 0, 0))

axs.semilogy(
np.arange(position_correction_num_iter + 1),
position_correction_stats / position_correction_stats[0],
color=color,
**kwargs,
)
axs.set_xlabel("Iteration number")
axs.set_ylabel("NMSE")
axs.yaxis.set_major_formatter(PercentFormatter(1.0, decimals=0))
axs.yaxis.set_minor_formatter(PercentFormatter(1.0, decimals=0))

spec.tight_layout(fig)

def _bilinearly_sample_array(
self,
Expand Down

0 comments on commit 0abd641

Please sign in to comment.