Skip to content

Commit

Permalink
normalize_order = 0 bug
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Dec 16, 2023
1 parent 3b5bd64 commit 6982280
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions py4DSTEM/process/phase/iterative_parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _populate_instance(self, group):

def preprocess(
self,
edge_blend: float = 8.0,
edge_blend: float = 16.0,
threshold_intensity: float = 0.8,
normalize_images: bool = True,
normalize_order=0,
Expand Down Expand Up @@ -462,18 +462,25 @@ def preprocess(
(1, 2, 0),
)

# initalize
# initialize
stack_shape = (
self._num_bf_images,
self._grid_scan_shape[0] + self._object_padding_px[0],
self._grid_scan_shape[1] + self._object_padding_px[1],
)
if normalize_images:
self._normalized_stack = True
self._stack_BF_shifted = xp.ones(stack_shape, dtype=xp.float32)
self._stack_BF_unshifted = xp.ones(stack_shape, xp.float32)

if normalize_order == 0:
all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None]
# all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None]
weights = xp.average(
all_bfs.reshape((self._num_bf_images, -1)),
weights=self._window_edge.ravel(),
axis=1,
)
all_bfs /= weights[:, None, None]

self._stack_BF_shifted[
:,
Expand Down Expand Up @@ -600,6 +607,7 @@ def preprocess(
)

else:
self._normalized_stack = False
all_means = xp.mean(all_bfs, axis=(1, 2))
self._stack_BF_shifted = xp.full(stack_shape, all_means[:, None, None])
self._stack_BF_unshifted = xp.full(stack_shape, all_means[:, None, None])
Expand Down Expand Up @@ -698,10 +706,11 @@ def preprocess(

if plot_average_bf:
figsize = kwargs.pop("figsize", (8, 4))
cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma")

fig, ax = plt.subplots(1, 2, figsize=figsize)

self._visualize_figax(fig, ax[0], **kwargs)
self._visualize_figax(fig, ax[0], cmap=cmap, **kwargs)

ax[0].set_ylabel("x [A]")
ax[0].set_xlabel("y [A]")
Expand Down Expand Up @@ -814,6 +823,7 @@ def tune_angle_and_defocus(
figsize = kwargs.get(
"figsize", (4 * num_defocus_values, 4 * num_angle_values)
)
cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma")

fig = plt.figure(figsize=figsize)

Expand All @@ -837,6 +847,8 @@ def tune_angle_and_defocus(
self._visualize_figax(
fig,
ax=object_ax,
cmap=cmap,
**kwargs,
)

object_ax.set_title(
Expand Down Expand Up @@ -1024,7 +1036,7 @@ def reconstruct(
height_ratios=[1] * nrows + [1 / 4],
)

figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows + 1))
figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows + 1))
else:
spec = GridSpec(
ncols=ncols,
Expand All @@ -1033,9 +1045,9 @@ def reconstruct(
wspace=0.15,
)

figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows))
figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows))

kwargs.pop("figsize", None)
cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma")
fig = plt.figure(figsize=figsize)

xy_center = (self._xy_inds - xp.median(self._xy_inds, axis=0)).astype("float")
Expand Down Expand Up @@ -1681,7 +1693,7 @@ def subpixel_alignment(
)

figsize = kwargs.pop("figsize", (4 * ncols, sum(height_ratios)))
cmap = kwargs.pop("cmap", "magma")
cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma")
fig = plt.figure(figsize=figsize)

row_index = 0
Expand Down Expand Up @@ -2664,7 +2676,7 @@ def aberration_correct(
# plotting
if plot_corrected_phase:
figsize = kwargs.pop("figsize", (6, 6))
cmap = kwargs.pop("cmap", "magma")
cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma")

fig, ax = plt.subplots(figsize=figsize)

Expand Down Expand Up @@ -2761,7 +2773,7 @@ def depth_section(
)

figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows))
cmap = kwargs.pop("cmap", "magma")
cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma")

fig = plt.figure(figsize=figsize)

Expand Down

0 comments on commit 6982280

Please sign in to comment.