Skip to content

Commit

Permalink
replacing fourier resampling with zero-padded integer pixel rolling
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Dec 10, 2023
1 parent 96f8ebc commit 167b677
Showing 1 changed file with 20 additions and 102 deletions.
122 changes: 20 additions & 102 deletions py4DSTEM/process/phase/iterative_parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
from py4DSTEM import Calibration, DataCube
from py4DSTEM.preprocess.utils import get_shifted_ar
from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction
from py4DSTEM.process.phase.utils import AffineTransform, bilinearly_interpolate_array, lanczos_interpolate_array, bilinear_kernel_density_estimate, lanczos_kernel_density_estimate, vectorized_fourier_resample, pixel_rolling_kernel_density_estimate
from py4DSTEM.process.phase.utils import (
AffineTransform,
bilinear_kernel_density_estimate,
bilinearly_interpolate_array,
lanczos_interpolate_array,
lanczos_kernel_density_estimate,
pixel_rolling_kernel_density_estimate,
)
from py4DSTEM.process.utils.cross_correlate import align_images_fourier
from py4DSTEM.process.utils.utils import electron_wavelength_angstrom, fourier_resample
from py4DSTEM.process.utils.utils import electron_wavelength_angstrom
from py4DSTEM.visualize import return_scaled_histogram_ordering, show
from scipy.linalg import polar
from scipy.ndimage import distance_transform_edt
Expand Down Expand Up @@ -1191,7 +1198,7 @@ def subpixel_alignment(
kde_sigma_px=0.125,
kde_lowpass_filter=False,
lanczos_interpolation_order=None,
integer_pixel_rolling_alignment = False,
integer_pixel_rolling_alignment=False,
plot_upsampled_BF_comparison: bool = True,
plot_upsampled_FFT_comparison: bool = False,
position_correction_num_iter=None,
Expand Down Expand Up @@ -1339,15 +1346,17 @@ def subpixel_alignment(
lowpass_filter=kde_lowpass_filter,
)
else:
self._kde_upsample_factor = np.round(self._kde_upsample_factor).astype("int")
self._kde_upsample_factor = np.round(self._kde_upsample_factor).astype(
"int"
)

pix_output = pixel_rolling_kernel_density_estimate(
self._stack_BF_unshifted,
xy_shifts,
self._kde_upsample_factor,
kde_sigma_px * self._kde_upsample_factor,
xp = xp,
gaussian_filter = gaussian_filter,
xp=xp,
gaussian_filter=gaussian_filter,
)

# Perform probe position correction if needed
Expand Down Expand Up @@ -1845,7 +1854,7 @@ def _interpolate_array(
xp = self._xp

if lanczos_alpha is not None:
return lanczos_interpolate_array(image, xa, ya, lanczos_alpha,xp=xp)
return lanczos_interpolate_array(image, xa, ya, lanczos_alpha, xp=xp)
else:
return bilinearly_interpolate_array(
image,
Expand Down Expand Up @@ -1878,8 +1887,8 @@ def _kernel_density_estimate(
kde_sigma,
lanczos_alpha,
lowpass_filter=lowpass_filter,
xp = xp,
gaussian_filter = gaussian_filter,
xp=xp,
gaussian_filter=gaussian_filter,
)
else:
return bilinear_kernel_density_estimate(
Expand All @@ -1889,100 +1898,9 @@ def _kernel_density_estimate(
output_shape,
kde_sigma,
lowpass_filter=lowpass_filter,
xp = xp,
gaussian_filter = gaussian_filter,
)

def _vectorized_fourier_resample_stack(
self,
stack,
shifts,
upsampling_factor,
extra_factor,
):
""" """
xp = self._xp

upsampled_stack = vectorized_fourier_resample(
stack,
scale=upsampling_factor * extra_factor,
xp = xp,
)

upsampled_shifts = shifts * upsampling_factor * extra_factor
upsampled_shifts_int = xp.modf(upsampled_shifts)[-1].astype("int")

for BF_index in range(upsampled_stack.shape[0]):
shift = upsampled_shifts_int[BF_index]
upsampled_stack[BF_index] = xp.roll(upsampled_stack[BF_index], shift, axis=(0, 1))

upsampled_stack = vectorized_fourier_resample(
upsampled_stack,
scale=1/extra_factor,
xp = xp,
)

return upsampled_stack

def _serial_fourier_resample_stack(
self,
stack,
shifts,
upsampling_factor,
extra_factor,
):
""" """
xp = self._xp
asnumpy = self._asnumpy

numpy_stack = asnumpy(stack)
numpy_shifts = asnumpy(shifts)

upsampled_shape = np.array(numpy_stack.shape)
upsampled_shape *= (1, upsampling_factor, upsampling_factor)

upsampled_shifts = numpy_shifts * upsampling_factor * extra_factor
upsampled_shifts_int = np.modf(upsampled_shifts)[-1].astype("int")

resampled_stack = xp.empty(upsampled_shape, dtype=xp.float32)

for BF_index in tqdmnd(upsampled_shape[0]):

resampled = fourier_resample(
numpy_stack[BF_index], upsampling_factor * extra_factor
xp=xp,
gaussian_filter=gaussian_filter,
)
shift = upsampled_shifts_int[BF_index]

resampled = np.roll(resampled, shift, axis=(0, 1))
resampled = fourier_resample(resampled, 1 / extra_factor)
resampled_stack[BF_index] = xp.asarray(resampled)

return resampled_stack

def _fourier_resample_stack(
self,
stack,
shifts,
upsampling_factor,
extra_factor,
vectorized = True,
):
""" """

if vectorized:
return self._vectorized_fourier_resample_stack(
stack,
shifts,
upsampling_factor,
extra_factor
)
else:
return self._serial_fourier_resample_stack(
stack,
shifts,
upsampling_factor,
extra_factor
)

def aberration_fit(
self,
Expand Down

0 comments on commit 167b677

Please sign in to comment.