Skip to content

Commit

Permalink
next iter typo, window mask, gradient direction steps
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Dec 8, 2023
1 parent df5fe2c commit f5e31dd
Showing 1 changed file with 110 additions and 71 deletions.
181 changes: 110 additions & 71 deletions py4DSTEM/process/phase/iterative_parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,8 @@ def subpixel_alignment(
position_correction_initial_step_size=1.0,
position_correction_min_step_size=0.1,
position_correction_step_size_factor=0.75,
position_correction_regularization_sigma=(0.25, 0.25),
position_correction_checkerboard_steps=False,
position_correction_regularization_sigma=None,
plot_position_correction_convergence: bool = True,
progress_bar: bool = True,
**kwargs,
Expand Down Expand Up @@ -1224,6 +1225,8 @@ def subpixel_alignment(
Minimum position correction step-size in pixels
position_correction_step_size_factor: float, optional
Factor to multiply step-size by between iterations
position_correction_checkerboard_steps: bool, optional
If True, uses steepest-descent checkerboarding steps, as opposed to gradient direction
position_correction_regularization_sigma, tuple(float, float), optional
Bandwidth to regularize corrected positions in pixels
plot_position_correction_convergence: bool, optional
Expand Down Expand Up @@ -1311,7 +1314,7 @@ def subpixel_alignment(
ya,
self._stack_BF_unshifted,
pixel_output_shape,
kde_sigma_px,
kde_sigma_px * self._kde_upsample_factor,
lowpass_filter=kde_lowpass_filter,
)

Expand All @@ -1327,45 +1330,46 @@ def subpixel_alignment(
# init scores and stats
position_correction_stats = np.zeros(position_correction_num_iter + 1)

scores = xp.mean(
xp.abs(
self._bilinear_sample_array(
pix_output,
xa,
ya,
)
- self._stack_BF_unshifted
),
axis=0,
scores = (
xp.mean(
xp.abs(
self._bilinearly_sample_array(
pix_output,
xa,
ya,
)
- self._stack_BF_unshifted
),
axis=0,
)
* self._window_pad
)

self._scores = scores.copy()

position_correction_stats[0] = scores.mean()

# gradient search directions

# 4 isotropic directions
dxy = np.array(
[
[-1.0, 0.0],
[1.0, 0.0],
[0.0, -1.0],
[0.0, 1.0],
]
)
if position_correction_checkerboard_steps:
# checkerboard steps
dxy = np.array(
[
[-1.0, 0.0],
[1.0, 0.0],
[0.0, -1.0],
[0.0, 1.0],
]
)

# 8 isotropic directions
# dxy = np.array([
# [-1.0, 0.0],
# [ 1.0, 0.0],
# [ 0.0, -1.0],
# [ 0.0, 1.0],
# [-0.71, -0.71],
# [ 0.71, -0.71],
# [-0.71, 0.71],
# [ 0.71, 0.71],
# ])
else:
# centered finite-difference directions
dxy = np.array(
[
[-0.5, 0.0],
[0.5, 0.0],
[0.0, -0.5],
[0.0, 0.5],
]
)

scores_test = xp.zeros(
(
Expand All @@ -1375,10 +1379,7 @@ def subpixel_alignment(
)
)

self._scores_test = scores_test.copy()

# main loop for position correction
# for a0 in range(position_corr_num_iter):
for a0 in tqdmnd(
position_correction_num_iter,
desc="Correcting positions: ",
Expand All @@ -1394,6 +1395,7 @@ def subpixel_alignment(
+ dxy[a1, 0] * step
+ xy_shifts[:, 0, None, None]
) * self._kde_upsample_factor

ya = (
ya_init
+ self._probe_dy
Expand All @@ -1403,7 +1405,7 @@ def subpixel_alignment(

scores_test[a1] = xp.mean(
xp.abs(
self._bilinear_sample_array(
self._bilinearly_sample_array(
pix_output,
xa,
ya,
Expand All @@ -1413,19 +1415,59 @@ def subpixel_alignment(
axis=0,
)

self._scores_test[a1] = scores_test[a1].copy()
if position_correction_checkerboard_steps:
# Check where cost function has improved

# Check where cost function has improved
scores_test *= self._window_pad[None]
update = np.min(scores_test, axis=0) < scores
scores_ind = np.argmin(scores_test, axis=0)

update = xp.min(scores_test, axis=0) < scores
scores_ind = xp.argmin(scores_test, axis=0)
for a1 in range(dxy.shape[0]):
sub = np.logical_and(update, scores_ind == a1)
self._probe_dx[sub] += (
dxy[a1, 0] * step[sub] * self._window_pad[sub]
)
self._probe_dy[sub] += (
dxy[a1, 1] * step[sub] * self._window_pad[sub]
)

# update the scores and probe shifts
for a1 in range(dxy.shape[0]):
sub = xp.logical_and(update, scores_ind == a1)
self._probe_dx[sub] += dxy[a1, 0] * step[sub]
self._probe_dy[sub] += dxy[a1, 1] * step[sub]
# scores[sub] = scores_test[a1][sub]
else:
# Check where cost function has improved
dx = scores_test[0] - scores_test[1]
dy = scores_test[2] - scores_test[3]

dr = xp.sqrt(dx**2 + dy**2) / step
dx *= self._window_pad / dr
dy *= self._window_pad / dr

# Fixed-size step
xa = (
xa_init + self._probe_dx + dx + xy_shifts[:, 0, None, None]
) * self._kde_upsample_factor

ya = (
ya_init + self._probe_dy + dy + xy_shifts[:, 1, None, None]
) * self._kde_upsample_factor

fixed_step_scores = (
xp.mean(
xp.abs(
self._bilinearly_sample_array(
pix_output,
xa,
ya,
)
- self._stack_BF_unshifted
),
axis=0,
)
* self._window_pad
)

update = fixed_step_scores < scores

self._probe_dx[update] += dx[update]
self._probe_dy[update] += dy[update]

# reduce gradient step for sites which did not improve
step[xp.logical_not(update)] *= position_correction_step_size_factor
Expand All @@ -1448,40 +1490,38 @@ def subpixel_alignment(

# kernel density output the upsampled BF image
xa = (
xa_init
+ self._probe_dx
+ dxy[a1, 0] * step
+ xy_shifts[:, 0, None, None]
xa_init + self._probe_dx + xy_shifts[:, 0, None, None]
) * self._kde_upsample_factor

ya = (
ya_init
+ self._probe_dy
+ dxy[a1, 1] * step
+ xy_shifts[:, 1, None, None]
ya_init + self._probe_dy + xy_shifts[:, 1, None, None]
) * self._kde_upsample_factor

pix_output = self._kernel_density_estimate(
xa,
ya,
self._stack_BF_unshifted,
pixel_output_shape,
kde_sigma_px,
kde_sigma_px * self._kde_upsample_factor,
lowpass_filter=kde_lowpass_filter,
)

# update cost function and stats
scores = xp.mean(
xp.abs(
self._bilinear_sample_array(
pix_output,
xa,
ya,
)
- self._stack_BF_unshifted
),
axis=0,
scores = (
xp.mean(
xp.abs(
self._bilinearly_sample_array(
pix_output,
xa,
ya,
)
- self._stack_BF_unshifted
),
axis=0,
)
* self._window_pad
)

position_correction_stats[a0 + 1] = scores.mean()

if plot_position_correction_convergence:
Expand Down Expand Up @@ -1591,7 +1631,7 @@ def subpixel_alignment(

fig.tight_layout()

def _bilinear_sample_array(
def _bilinearly_sample_array(
self,
image,
xa,
Expand Down Expand Up @@ -1789,9 +1829,8 @@ def _kernel_density_estimate(
)

# kernel density estimate
sigma = kde_sigma * self._kde_upsample_factor
pix_count = gaussian_filter(pix_count, sigma)
pix_output = gaussian_filter(pix_output, sigma)
pix_count = gaussian_filter(pix_count, kde_sigma)
pix_output = gaussian_filter(pix_output, kde_sigma)
sub = pix_count > 1e-3
pix_output[sub] /= pix_count[sub]
pix_output[np.logical_not(sub)] = 1
Expand Down

0 comments on commit f5e31dd

Please sign in to comment.