Skip to content

Commit

Permalink
reasonable option names and variable consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Dec 5, 2023
1 parent b4d606a commit df5fe2c
Showing 1 changed file with 75 additions and 41 deletions.
116 changes: 75 additions & 41 deletions py4DSTEM/process/phase/iterative_parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,12 +1191,12 @@ def subpixel_alignment(
kde_lowpass_filter=False,
plot_upsampled_BF_comparison: bool = True,
plot_upsampled_FFT_comparison: bool = False,
position_corr_num_iter=None,
position_corr_step_start=1.0,
position_corr_step_min=0.1,
position_corr_step_reduce=0.75,
position_corr_sigma_reg=(0.25, 0.25),
plot_position_corr_convergence: bool = True,
position_correction_num_iter=None,
position_correction_initial_step_size=1.0,
position_correction_min_step_size=0.1,
position_correction_step_size_factor=0.75,
position_correction_regularization_sigma=(0.25, 0.25),
plot_position_correction_convergence: bool = True,
progress_bar: bool = True,
**kwargs,
):
Expand All @@ -1216,6 +1216,20 @@ def subpixel_alignment(
If True, the pre/post alignment BF images are plotted for comparison
plot_upsampled_FFT_comparison: bool, optional
If True, the pre/post alignment BF FFTs are plotted for comparison
position_correction_num_iter: int, optional
If not None, parallax positions are corrected iteratively for this many iterations
position_correction_initial_step_size: float, optional
Initial position correction step-size in pixels
position_correction_min_step_size: float, optional
Minimum position correction step-size in pixels
position_correction_step_size_factor: float, optional
Factor to multiply step-size by between iterations
position_correction_regularization_sigma, tuple(float, float), optional
Bandwidth to regularize corrected positions in pixels
plot_position_correction_convergence: bool, optional
If True, position correction convergence is plotted
progress_bar: bool, optional
If True, a progress bar is printed with position correction progress
"""
xp = self._xp
Expand Down Expand Up @@ -1291,6 +1305,7 @@ def subpixel_alignment(
# kernel density output the upsampled BF image
xa = (xa_init + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor
ya = (ya_init + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor

pix_output = self._kernel_density_estimate(
xa,
ya,
Expand All @@ -1301,17 +1316,19 @@ def subpixel_alignment(
)

# Perform probe position correction if needed
if position_corr_num_iter is not None:
if position_correction_num_iter is not None:
# init position shift array
self.probe_dx = np.zeros_like(xa_init)
self.probe_dy = np.zeros_like(xa_init)
self._probe_dx = xp.zeros_like(xa_init)
self._probe_dy = xp.zeros_like(xa_init)

# step size of initial search, cost function
step = np.ones_like(xa_init) * position_corr_step_start
step = xp.ones_like(xa_init) * position_correction_initial_step_size

# init scores and stats
scores = np.mean(
np.abs(
position_correction_stats = np.zeros(position_correction_num_iter + 1)

scores = xp.mean(
xp.abs(
self._bilinear_sample_array(
pix_output,
xa,
Expand All @@ -1321,10 +1338,14 @@ def subpixel_alignment(
),
axis=0,
)
position_corr_stats = np.zeros(position_corr_num_iter + 1)
position_corr_stats[0] = np.mean(scores)

self._scores = scores.copy()

position_correction_stats[0] = scores.mean()

# gradient search directions

# 4 isotropic directions
dxy = np.array(
[
[-1.0, 0.0],
Expand All @@ -1333,6 +1354,8 @@ def subpixel_alignment(
[0.0, 1.0],
]
)

# 8 isotropic directions
# dxy = np.array([
# [-1.0, 0.0],
# [ 1.0, 0.0],
Expand All @@ -1343,6 +1366,7 @@ def subpixel_alignment(
# [-0.71, 0.71],
# [ 0.71, 0.71],
# ])

scores_test = xp.zeros(
(
dxy.shape[0],
Expand All @@ -1351,30 +1375,34 @@ def subpixel_alignment(
)
)

self._scores_test = scores_test.copy()

# main loop for position correction
# for a0 in range(position_corr_num_iter):
for a0 in tqdmnd(
position_corr_num_iter,
position_correction_num_iter,
desc="Correcting positions: ",
unit=" iteration",
disable=not progress_bar,
):
# Evaluate scores for step directions and magnitudes

for a1 in range(dxy.shape[0]):
xa = (
xa_init
+ self.probe_dx
+ self._probe_dx
+ dxy[a1, 0] * step
+ xy_shifts[:, 0, None, None]
) * self._kde_upsample_factor
ya = (
ya_init
+ self.probe_dy
+ self._probe_dy
+ dxy[a1, 1] * step
+ xy_shifts[:, 1, None, None]
) * self._kde_upsample_factor
scores_test[a1] = np.mean(
np.abs(

scores_test[a1] = xp.mean(
xp.abs(
self._bilinear_sample_array(
pix_output,
xa,
Expand All @@ -1385,48 +1413,54 @@ def subpixel_alignment(
axis=0,
)

self._scores_test[a1] = scores_test[a1].copy()

# Check where cost function has improved
update = np.min(scores_test, axis=0) < scores
scores_ind = np.argmin(scores_test, axis=0)

update = xp.min(scores_test, axis=0) < scores
scores_ind = xp.argmin(scores_test, axis=0)

# update the scores and probe shifts
for a1 in range(dxy.shape[0]):
sub = np.logical_and(update, scores_ind == a1)
self.probe_dx[sub] += dxy[a1, 0] * step[sub]
self.probe_dy[sub] += dxy[a1, 1] * step[sub]
sub = xp.logical_and(update, scores_ind == a1)
self._probe_dx[sub] += dxy[a1, 0] * step[sub]
self._probe_dy[sub] += dxy[a1, 1] * step[sub]
# scores[sub] = scores_test[a1][sub]

# reduce gradient step for sites which did not improve
step[np.logical_not(update)] *= position_corr_step_reduce
step[xp.logical_not(update)] *= position_correction_step_size_factor

# enforce minimum step size
step = np.maximum(step, position_corr_step_min)
step = xp.maximum(step, position_correction_min_step_size)

# apply regularization if needed
if position_corr_sigma_reg is not None:
self.probe_dx = gaussian_filter(
self.probe_dx,
position_corr_sigma_reg,
if position_correction_regularization_sigma is not None:
self._probe_dx = gaussian_filter(
self._probe_dx,
position_correction_regularization_sigma[0],
mode="nearest",
)
self.probe_dy = gaussian_filter(
self.probe_dy,
position_corr_sigma_reg,
self._probe_dy = gaussian_filter(
self._probe_dy,
position_correction_regularization_sigma[1],
mode="nearest",
)

# kernel density output the upsampled BF image
xa = (
xa_init
+ self.probe_dx
+ self._probe_dx
+ dxy[a1, 0] * step
+ xy_shifts[:, 0, None, None]
) * self._kde_upsample_factor

ya = (
ya_init
+ self.probe_dy
+ self._probe_dy
+ dxy[a1, 1] * step
+ xy_shifts[:, 1, None, None]
) * self._kde_upsample_factor

pix_output = self._kernel_density_estimate(
xa,
ya,
Expand All @@ -1437,8 +1471,8 @@ def subpixel_alignment(
)

# update cost function and stats
scores = np.mean(
np.abs(
scores = xp.mean(
xp.abs(
self._bilinear_sample_array(
pix_output,
xa,
Expand All @@ -1448,13 +1482,13 @@ def subpixel_alignment(
),
axis=0,
)
position_corr_stats[a0 + 1] = np.mean(scores)
position_correction_stats[a0 + 1] = scores.mean()

if plot_position_corr_convergence:
if plot_position_correction_convergence:
fig, ax = plt.subplots(figsize=(8, 2))
ax.plot(
np.arange(position_corr_num_iter + 1),
position_corr_stats,
np.arange(position_correction_num_iter + 1),
position_correction_stats,
color=(1, 0, 0),
)
ax.set_xlabel("iterations")
Expand Down

0 comments on commit df5fe2c

Please sign in to comment.