Skip to content

Commit

Permalink
slow fourier resampling
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Dec 9, 2023
1 parent 2efda6b commit d905cb3
Showing 1 changed file with 148 additions and 26 deletions.
174 changes: 148 additions & 26 deletions py4DSTEM/process/phase/iterative_parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction
from py4DSTEM.process.phase.utils import AffineTransform
from py4DSTEM.process.utils.cross_correlate import align_images_fourier
from py4DSTEM.process.utils.utils import electron_wavelength_angstrom
from py4DSTEM.process.utils.utils import electron_wavelength_angstrom, fourier_resample
from py4DSTEM.visualize import return_scaled_histogram_ordering, show
from scipy.linalg import polar
from scipy.ndimage import distance_transform_edt
Expand Down Expand Up @@ -1190,14 +1190,18 @@ def subpixel_alignment(
kde_upsample_factor=None,
kde_sigma_px=0.125,
kde_lowpass_filter=False,
additional_fourier_resampling_factor=None,
plot_upsampled_BF_comparison: bool = True,
plot_upsampled_FFT_comparison: bool = False,
position_correction_num_iter=None,
position_correction_initial_step_size=1.0,
position_correction_min_step_size=0.1,
position_correction_step_size_factor=0.75,
position_correction_checkerboard_steps=False,
position_correction_regularization_sigma=None,
position_correction_gaussian_filter_sigma=None,
position_correction_butterworth_q_lowpass=None,
position_correction_butterworth_q_highpass=None,
position_correction_butterworth_order=(2, 2),
plot_position_correction_convergence: bool = True,
progress_bar: bool = True,
**kwargs,
Expand Down Expand Up @@ -1228,8 +1232,14 @@ def subpixel_alignment(
Factor to multiply step-size by between iterations
position_correction_checkerboard_steps: bool, optional
If True, uses steepest-descent checkerboarding steps, as opposed to gradient direction
position_correction_regularization_sigma, tuple(float, float), optional
Bandwidth to regularize corrected positions in pixels
position_correction_gaussian_filter_sigma: tuple(float, float), optional
Standard deviation of gaussian kernel in A
position_correction_butterworth_q_lowpass: tuple(float, float), optional
Cut-off frequency in A^-1 for low-pass butterworth filter
position_correction_butterworth_q_highpass: tuple(float, float), optional
Cut-off frequency in A^-1 for high-pass butterworth filter
position_correction_butterworth_order: tuple(int,int), optional
Butterworth filter order. Smaller gives a smoother filter
plot_position_correction_convergence: bool, optional
If True, position correction convergence is plotted
progress_bar: bool, optional
Expand Down Expand Up @@ -1301,23 +1311,34 @@ def subpixel_alignment(
self._kde_upsample_factor = kde_upsample_factor
pixel_output_shape = np.round(BF_size * self._kde_upsample_factor).astype("int")

# shifted coordinates
x = xp.arange(BF_size[0], dtype=xp.float32)
y = xp.arange(BF_size[1], dtype=xp.float32)
xa_init, ya_init = xp.meshgrid(x, y, indexing="ij")

# kernel density output the upsampled BF image
xa = (xa_init + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor
ya = (ya_init + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor

pix_output = self._kernel_density_estimate(
xa,
ya,
self._stack_BF_unshifted,
pixel_output_shape,
kde_sigma_px * self._kde_upsample_factor,
lowpass_filter=kde_lowpass_filter,
)
if (
additional_fourier_resampling_factor is None
or position_correction_num_iter is not None
):
# shifted coordinates
x = xp.arange(BF_size[0], dtype=xp.float32)
y = xp.arange(BF_size[1], dtype=xp.float32)
xa_init, ya_init = xp.meshgrid(x, y, indexing="ij")

# kernel density output the upsampled BF image
xa = (xa_init + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor
ya = (ya_init + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor

pix_output = self._kernel_density_estimate(
xa,
ya,
self._stack_BF_unshifted,
pixel_output_shape,
kde_sigma_px * self._kde_upsample_factor,
lowpass_filter=kde_lowpass_filter,
)
else:
pix_output = self._fourier_resample_stack(
self._stack_BF_unshifted,
xy_shifts,
self._kde_upsample_factor,
additional_fourier_resampling_factor,
).mean(0)

# Perform probe position correction if needed
if position_correction_num_iter is not None:
Expand Down Expand Up @@ -1478,18 +1499,75 @@ def subpixel_alignment(
step = xp.maximum(step, position_correction_min_step_size)

# apply regularization if needed
if position_correction_regularization_sigma is not None:
if position_correction_gaussian_filter_sigma is not None:
self._probe_dx = gaussian_filter(
self._probe_dx,
position_correction_regularization_sigma[0],
mode="nearest",
position_correction_gaussian_filter_sigma[0]
/ self._scan_sampling[0],
# mode="nearest",
)
self._probe_dy = gaussian_filter(
self._probe_dy,
position_correction_regularization_sigma[1],
mode="nearest",
position_correction_gaussian_filter_sigma[1]
/ self._scan_sampling[1],
# mode="nearest",
)

if (
position_correction_butterworth_q_lowpass is not None
or position_correction_butterworth_q_highpass is not None
):
qx = xp.fft.fftfreq(BF_size[0], self._scan_sampling[0])
qy = xp.fft.fftfreq(BF_size[1], self._scan_sampling[1])

qya, qxa = xp.meshgrid(qy, qx)
qra = xp.sqrt(qxa**2 + qya**2)

if position_correction_butterworth_q_lowpass:
(
q_lowpass_x,
q_lowpass_y,
) = position_correction_butterworth_q_lowpass
else:
q_lowpass_x, q_lowpass_y = (None, None)
if position_correction_butterworth_q_highpass:
(
q_highpass_x,
q_highpass_y,
) = position_correction_butterworth_q_highpass
else:
q_highpass_x, q_highpass_y = (None, None)

order_x, order_y = position_correction_butterworth_order

# dx
env = xp.ones_like(qra)
if q_highpass_x:
env *= 1 - 1 / (1 + (qra / q_highpass_x) ** (2 * order_x))
if q_lowpass_x:
env *= 1 / (1 + (qra / q_lowpass_x) ** (2 * order_x))

probe_dx_mean = xp.mean(self._probe_dx)
self._probe_dx -= probe_dx_mean
self._probe_dx = xp.real(
xp.fft.ifft2(xp.fft.fft2(self._probe_dx) * env)
)
self._probe_dx += probe_dx_mean

# dy
env = xp.ones_like(qra)
if q_highpass_y:
env *= 1 - 1 / (1 + (qra / q_highpass_y) ** (2 * order_y))
if q_lowpass_y:
env *= 1 / (1 + (qra / q_lowpass_y) ** (2 * order_y))

probe_dy_mean = xp.mean(self._probe_dy)
self._probe_dy -= probe_dy_mean
self._probe_dy = xp.real(
xp.fft.ifft2(xp.fft.fft2(self._probe_dy) * env)
)
self._probe_dy += probe_dy_mean

# kernel density output the upsampled BF image
xa = (
xa_init + self._probe_dx + xy_shifts[:, 0, None, None]
Expand Down Expand Up @@ -1525,12 +1603,24 @@ def subpixel_alignment(
)

position_correction_stats[a0 + 1] = scores.mean()

if additional_fourier_resampling_factor is not None:
pix_output = self._fourier_resample_stack(
self._stack_BF_unshifted,
xy_shifts,
self._kde_upsample_factor,
additional_fourier_resampling_factor,
).mean(0)
else:
plot_position_correction_convergence = False

self._recon_BF_subpixel_aligned = pix_output
self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned)

if self._device == "gpu":
xp._default_memory_pool.free_all_blocks()
xp.clear_memo()

# plotting
nrows = np.count_nonzero(
np.array(
Expand Down Expand Up @@ -1948,6 +2038,38 @@ def _kernel_density_estimate(

return pix_output

def _fourier_resample_stack(
self,
stack,
shifts,
upsampling_factor,
extra_factor,
):
""" """
xp = self._xp
asnumpy = self._asnumpy

stack = asnumpy(stack)
upsampled_shape = np.array(stack.shape)
upsampled_shape *= (1, upsampling_factor, upsampling_factor)

upsampled_shifts = shifts * upsampling_factor * extra_factor
upsampled_shifts_int = asnumpy(xp.modf(upsampled_shifts)[-1].astype("int"))

resampled_stack = xp.empty(upsampled_shape, dtype=xp.float32)

for BF_index in tqdmnd(upsampled_shape[0]):
resampled = fourier_resample(
stack[BF_index], upsampling_factor * extra_factor
)
shift = upsampled_shifts_int[BF_index]

resampled = np.roll(resampled, shift, axis=(0, 1))
resampled = fourier_resample(resampled, 1 / extra_factor)
resampled_stack[BF_index] = xp.asarray(resampled)

return resampled_stack

def aberration_fit(
self,
fit_BF_shifts: bool = False,
Expand Down

0 comments on commit d905cb3

Please sign in to comment.