From 6982280273a46d1d052e856470a7adab59843232 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 16 Dec 2023 11:56:02 -0800 Subject: [PATCH] normalize_order = 0 bug --- py4DSTEM/process/phase/iterative_parallax.py | 32 ++++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 3d834a0a7..d93a4a27f 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -260,7 +260,7 @@ def _populate_instance(self, group): def preprocess( self, - edge_blend: float = 8.0, + edge_blend: float = 16.0, threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, @@ -462,18 +462,25 @@ def preprocess( (1, 2, 0), ) - # initalize + # initialize stack_shape = ( self._num_bf_images, self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1], ) if normalize_images: + self._normalized_stack = True self._stack_BF_shifted = xp.ones(stack_shape, dtype=xp.float32) self._stack_BF_unshifted = xp.ones(stack_shape, xp.float32) if normalize_order == 0: - all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] + # all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] + weights = xp.average( + all_bfs.reshape((self._num_bf_images, -1)), + weights=self._window_edge.ravel(), + axis=1, + ) + all_bfs /= weights[:, None, None] self._stack_BF_shifted[ :, @@ -600,6 +607,7 @@ def preprocess( ) else: + self._normalized_stack = False all_means = xp.mean(all_bfs, axis=(1, 2)) self._stack_BF_shifted = xp.full(stack_shape, all_means[:, None, None]) self._stack_BF_unshifted = xp.full(stack_shape, all_means[:, None, None]) @@ -698,10 +706,11 @@ def preprocess( if plot_average_bf: figsize = kwargs.pop("figsize", (8, 4)) + cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma") fig, ax = plt.subplots(1, 2, figsize=figsize) - self._visualize_figax(fig, ax[0], **kwargs) + self._visualize_figax(fig, ax[0], cmap=cmap, **kwargs) ax[0].set_ylabel("x [A]") ax[0].set_xlabel("y [A]") @@ -814,6 +823,7 @@ def tune_angle_and_defocus( figsize = kwargs.get( "figsize", (4 * num_defocus_values, 4 * num_angle_values) ) + cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma") fig = plt.figure(figsize=figsize) @@ -837,6 +847,8 @@ def tune_angle_and_defocus( self._visualize_figax( fig, ax=object_ax, + cmap=cmap, + **kwargs, ) object_ax.set_title( @@ -1024,7 +1036,7 @@ def reconstruct( height_ratios=[1] * nrows + [1 / 4], ) - figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows + 1)) + figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows + 1)) else: spec = GridSpec( ncols=ncols, @@ -1033,9 +1045,9 @@ def reconstruct( wspace=0.15, ) - figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows)) + figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows)) - kwargs.pop("figsize", None) + cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma") fig = plt.figure(figsize=figsize) xy_center = (self._xy_inds - xp.median(self._xy_inds, axis=0)).astype("float") @@ -1681,7 +1693,7 @@ def subpixel_alignment( ) figsize = kwargs.pop("figsize", (4 * ncols, sum(height_ratios))) - cmap = kwargs.pop("cmap", "magma") + cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma") fig = plt.figure(figsize=figsize) row_index = 0 @@ -2664,7 +2676,7 @@ def aberration_correct( # plotting if plot_corrected_phase: figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") + cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma") fig, ax = plt.subplots(figsize=figsize) @@ -2761,7 +2773,7 @@ def depth_section( ) figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows)) - cmap = kwargs.pop("cmap", "magma") + cmap = kwargs.pop("cmap", "RdBu_r" if self._normalized_stack else "magma") fig = plt.figure(figsize=figsize)