diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 767789df2..d88e09551 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1168,7 +1168,7 @@ def _normalize_diffraction_intensities( diffraction_intensities = self._asnumpy(diffraction_intensities) if positions_mask is not None: - number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) + number_of_patterns = np.count_nonzero(positions_mask.ravel()) else: number_of_patterns = np.prod(diffraction_intensities.shape[:2]) @@ -1217,7 +1217,7 @@ def _normalize_diffraction_intensities( for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): if positions_mask is not None: - if not self._positions_mask[rx, ry]: + if not positions_mask[rx, ry]: continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], @@ -1348,6 +1348,14 @@ def to_h5(self, group): data=metadata, ) + # saving multiple None positions_mask fix + if self._positions_mask is None: + positions_mask = None + elif self._positions_mask[0] is None: + positions_mask = None + else: + positions_mask = self._positions_mask + # preprocessing metadata self.metadata = Metadata( name="preprocess_metadata", @@ -1359,7 +1367,7 @@ def to_h5(self, group): "num_diffraction_patterns": self._num_diffraction_patterns, "sampling": self.sampling, "angular_sampling": self.angular_sampling, - "positions_mask": self._positions_mask, + "positions_mask": positions_mask, }, ) @@ -2146,6 +2154,7 @@ def plot_position_correction( def _return_fourier_probe( self, probe=None, + remove_initial_probe_aberrations=False, ): """ Returns complex fourier probe shifted to center of array from @@ -2155,6 +2164,8 @@ def _return_fourier_probe( ---------- probe: complex array, optional if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe Returns ------- @@ -2168,11 +2179,17 @@ def _return_fourier_probe( else: probe = xp.asarray(probe, dtype=xp.complex64) - return xp.fft.fftshift(xp.fft.fft2(probe), axes=(-2, -1)) + fourier_probe = xp.fft.fft2(probe) + + if remove_initial_probe_aberrations: + fourier_probe *= xp.conjugate(self._known_aberrations_array) + + return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) def _return_fourier_probe_from_centered_probe( self, probe=None, + remove_initial_probe_aberrations=False, ): """ Returns complex fourier probe shifted to center of array from @@ -2182,6 +2199,8 @@ def _return_fourier_probe_from_centered_probe( ---------- probe: complex array, optional if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe Returns ------- @@ -2189,7 +2208,10 @@ def _return_fourier_probe_from_centered_probe( Fourier-transformed and center-shifted probe. """ xp = self._xp - return self._return_fourier_probe(xp.fft.ifftshift(probe, axes=(-2, -1))) + return self._return_fourier_probe( + xp.fft.ifftshift(probe, axes=(-2, -1)), + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) def _return_centered_probe( self, @@ -2482,6 +2504,7 @@ def show_uncertainty_visualization( def show_fourier_probe( self, probe=None, + remove_initial_probe_aberrations=False, cbar=True, scalebar=True, pixelsize=None, @@ -2495,6 +2518,8 @@ def show_fourier_probe( ---------- probe: complex array, optional if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe cbar: bool, optional if True, adds colorbar scalebar: bool, optional @@ -2506,10 +2531,11 @@ def show_fourier_probe( """ asnumpy = self._asnumpy - if probe is None: - probe = self.probe_fourier - else: - probe = asnumpy(self._return_fourier_probe(probe)) + probe = asnumpy( + self._return_fourier_probe( + probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations + ) + ) if pixelsize is None: pixelsize = self._reciprocal_sampling[1] @@ -2517,7 +2543,7 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" figsize = kwargs.pop("figsize", (6, 6)) - chroma_boost = kwargs.pop("chroma_boost", 2) + chroma_boost = kwargs.pop("chroma_boost", 1) fig, ax = plt.subplots(figsize=figsize) show_complex( @@ -2570,6 +2596,19 @@ def probe_fourier(self): asnumpy = self._asnumpy return asnumpy(self._return_fourier_probe(self._probe)) + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy( + self._return_fourier_probe( + self._probe, remove_initial_probe_aberrations=True + ) + ) + @property def probe_centered(self): """Current probe estimate shifted to the center""" diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index f4c10cb13..03483c04e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -217,13 +217,6 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask is not None and positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") - self.set_save_defaults() # Data @@ -440,6 +433,13 @@ def preprocess( ) ) + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + ( self._datacube, self._vacuum_probe_intensity, @@ -2544,6 +2544,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -2562,6 +2563,8 @@ def _visualize_last_iteration( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe padding : int, optional Pixels to pad by post rotating-cropping object @@ -2569,10 +2572,7 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2667,9 +2667,16 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual[0] + else: + probe_array = self.probe_fourier[0] + probe_array = Complex2RGB( - self.probe_fourier[0], chroma_boost=chroma_boost + probe_array, + chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2735,6 +2742,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -2756,6 +2764,9 @@ def _visualize_all_iterations( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -2798,10 +2809,7 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2900,14 +2908,15 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0] - ) - ), - chroma_boost=chroma_boost, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2953,6 +2962,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -2974,6 +2984,9 @@ def visualize( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -2989,6 +3002,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -3000,6 +3014,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -3007,7 +3022,14 @@ def visualize( return self def show_fourier_probe( - self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, ): """ Plot probe in fourier space @@ -3016,6 +3038,8 @@ def show_fourier_probe( ---------- probe: complex array, optional if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe scalebar: bool, optional if True, adds scalebar to probe pixelunits: str, optional @@ -3026,21 +3050,37 @@ def show_fourier_probe( asnumpy = self._asnumpy if probe is None: - probe = list(self.probe_fourier) + probe = list( + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + ) else: if isinstance(probe, np.ndarray) and probe.ndim == 2: probe = [probe] - probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] if pixelsize is None: pixelsize = self._reciprocal_sampling[1] if pixelunits is None: pixelunits = r"$\AA^{-1}$" - chroma_boost = kwargs.pop("chroma_boost", 2) + chroma_boost = kwargs.pop("chroma_boost", 1) show_complex( probe if len(probe) > 1 else probe[0], + cbar=cbar, scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, @@ -3052,6 +3092,7 @@ def show_fourier_probe( def show_transmitted_probe( self, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, **kwargs, ): """ @@ -3094,7 +3135,12 @@ def show_transmitted_probe( if plot_fourier_probe: bottom_row = [ - asnumpy(self._return_fourier_probe(probe)) + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) for probe in [ mean_transmitted, min_intensity_transmitted, @@ -3313,7 +3359,9 @@ def show_depth( gaussian_filter_sigma /= self.sampling[0] rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + plot_im = rotated_object[ + :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) + ] extent = [ 0, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index d68291143..f3deba614 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -164,12 +164,6 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask is not None and positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -297,6 +291,13 @@ def preprocess( ) ) + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + ( self._datacube, self._vacuum_probe_intensity, @@ -1841,6 +1842,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -1857,16 +1859,15 @@ def _visualize_last_iteration( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe padding : int, optional Pixels to pad by post rotating-cropping object """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1958,10 +1959,16 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual[0] + else: + probe_array = self.probe_fourier[0] + probe_array = Complex2RGB( - self.probe_fourier[0], + probe_array, chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2029,6 +2036,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -2050,6 +2058,9 @@ def _visualize_all_iterations( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -2092,10 +2103,7 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2192,14 +2200,15 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0] - ) - ), - chroma_boost=chroma_boost, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2245,6 +2254,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -2266,6 +2276,9 @@ def visualize( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -2281,6 +2294,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2292,6 +2306,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2300,7 +2315,14 @@ def visualize( return self def show_fourier_probe( - self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, ): """ Plot probe in fourier space @@ -2309,6 +2331,8 @@ def show_fourier_probe( ---------- probe: complex array, optional if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe scalebar: bool, optional if True, adds scalebar to probe pixelunits: str, optional @@ -2319,21 +2343,37 @@ def show_fourier_probe( asnumpy = self._asnumpy if probe is None: - probe = list(self.probe_fourier) + probe = list( + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + ) else: if isinstance(probe, np.ndarray) and probe.ndim == 2: probe = [probe] - probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] if pixelsize is None: pixelsize = self._reciprocal_sampling[1] if pixelunits is None: pixelunits = r"$\AA^{-1}$" - chroma_boost = kwargs.pop("chroma_boost", 2) + chroma_boost = kwargs.pop("chroma_boost", 1) show_complex( probe if len(probe) > 1 else probe[0], + cbar=cbar, scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 93e32b079..269919ddc 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -198,12 +198,6 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask is not None and positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -420,6 +414,13 @@ def preprocess( ) ) + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + ( self._datacube, self._vacuum_probe_intensity, @@ -2420,6 +2421,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -2438,6 +2440,8 @@ def _visualize_last_iteration( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe padding : int, optional Pixels to pad by post rotating-cropping object @@ -2445,10 +2449,7 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2543,10 +2544,16 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, + probe_array, chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2612,6 +2619,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -2633,6 +2641,9 @@ def _visualize_all_iterations( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -2675,10 +2686,7 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2777,14 +2785,15 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] - ) - ), - chroma_boost=chroma_boost, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2828,6 +2837,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -2849,6 +2859,9 @@ def visualize( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -2864,6 +2877,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2875,6 +2889,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2884,6 +2899,7 @@ def visualize( def show_transmitted_probe( self, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, **kwargs, ): """ @@ -2926,7 +2942,12 @@ def show_transmitted_probe( if plot_fourier_probe: bottom_row = [ - asnumpy(self._return_fourier_probe(probe)) + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) for probe in [ mean_transmitted, min_intensity_transmitted, @@ -3145,7 +3166,9 @@ def show_depth( gaussian_filter_sigma /= self.sampling[0] rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) - plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + plot_im = rotated_object[ + :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) + ] extent = [ 0, diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index c49a1faac..bdd04b4b4 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -166,13 +166,6 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask is not None and positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") - self.set_save_defaults() # Data @@ -512,15 +505,43 @@ def preprocess( ) ) + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask) + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_tilts, 1, 1) + ) + + if self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array."), + UserWarning, + ) + self._positions_mask = self._positions_mask.astype("bool") + else: + self._positions_mask = [None] * self._num_tilts + # Prepopulate various arrays - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) + if self._positions_mask[0] is None: + num_probes_per_tilt = [0] + for dc in self._datacube: + rx, ry = dc.Rshape + num_probes_per_tilt.append(rx * ry) + + num_probes_per_tilt = np.array(num_probes_per_tilt) + else: + num_probes_per_tilt = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) - self._num_diffraction_patterns = sum(num_probes_per_tilt) - self._cum_probes_per_tilt = np.cumsum(np.array(num_probes_per_tilt)) + self._num_diffraction_patterns = num_probes_per_tilt.sum() + self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) self._mean_diffraction_intensity = [] self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index ddd13ac58..e8639d469 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -175,13 +175,6 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask is not None and positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") - self.set_save_defaults() # Data @@ -453,14 +446,43 @@ def preprocess( ) ) + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask) + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_tilts, 1, 1) + ) + + if self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array."), + UserWarning, + ) + self._positions_mask = self._positions_mask.astype("bool") + else: + self._positions_mask = [None] * self._num_tilts + # Prepopulate various arrays - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - self._num_diffraction_patterns = sum(num_probes_per_tilt) - self._cum_probes_per_tilt = np.cumsum(np.array(num_probes_per_tilt)) + if self._positions_mask[0] is None: + num_probes_per_tilt = [0] + for dc in self._datacube: + rx, ry = dc.Rshape + num_probes_per_tilt.append(rx * ry) + + num_probes_per_tilt = np.array(num_probes_per_tilt) + else: + num_probes_per_tilt = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + self._num_diffraction_patterns = num_probes_per_tilt.sum() + self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) self._mean_diffraction_intensity = [] self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) @@ -2572,6 +2594,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, projection_angle_deg: float, projection_axes: Tuple[int, int], x_lims: Tuple[int, int], @@ -2593,6 +2616,8 @@ def _visualize_last_iteration( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe projection_angle_deg: float Angle in degrees to rotate 3D array around prior to projection projection_axes: tuple(int,int) @@ -2602,13 +2627,12 @@ def _visualize_last_iteration( y_lims: tuple(float,float) min/max y indices """ + asnumpy = self._asnumpy + figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) asnumpy = self._asnumpy @@ -2709,10 +2733,16 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, + probe_array, chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2780,6 +2810,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], projection_angle_deg: float, projection_axes: Tuple[int, int], @@ -2802,6 +2833,9 @@ def _visualize_all_iterations( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes iterations_grid: Tuple[int,int] Grid dimensions to plot reconstruction iterations projection_angle_deg: float @@ -2855,10 +2889,7 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2967,14 +2998,15 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] - ) - ), - chroma_boost=chroma_boost, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -3020,6 +3052,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, projection_angle_deg: float = None, projection_axes: Tuple[int, int] = (0, 2), @@ -3042,6 +3075,9 @@ def visualize( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes iterations_grid: Tuple[int,int] Grid dimensions to plot reconstruction iterations projection_angle_deg: float @@ -3065,6 +3101,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, projection_angle_deg=projection_angle_deg, projection_axes=projection_axes, @@ -3079,6 +3116,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, projection_angle_deg=projection_angle_deg, projection_axes=projection_axes, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 716e1d782..c9dbb2fcf 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -587,7 +587,7 @@ def preprocess( self.recon_BF = asnumpy(self._recon_BF) if plot_average_bf: - figsize = kwargs.pop("figsize", (6, 12)) + figsize = kwargs.pop("figsize", (8, 4)) fig, ax = plt.subplots(1, 2, figsize=figsize) @@ -603,7 +603,9 @@ def preprocess( 0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), -0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), ] - ax[1].imshow(self._dp_mask, extent=reciprocal_extent, cmap="gray") + ax[1].imshow( + self._asnumpy(self._dp_mask), extent=reciprocal_extent, cmap="gray" + ) ax[1].set_title("DP mask") ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") @@ -1333,6 +1335,7 @@ def aberration_fit( plot_BF_shifts_comparison: bool = None, upsampled: bool = True, force_transpose: bool = False, + force_rotation_deg: float = None, ): """ Fit aberrations to the measured image shifts. @@ -1363,7 +1366,9 @@ def aberration_fit( upsampled: bool If True, and upsampled BF is available, uses that for CTF FFT fitting. force_transpose: bool - If True, and fit_BF_shifts is True, flips the measured x and y shifts + If True, flips the measured x and y shifts. + force_rotation_deg: float + If not None, sets the rotation angle to value in degrees. """ xp = self._xp @@ -1379,23 +1384,41 @@ def aberration_fit( ) else: self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + self.transpose = force_transpose # Solve affine transformation m = asnumpy( xp.linalg.lstsq(self._probe_angles, self._xy_shifts_Ang, rcond=None)[0] ) - m_rotation, m_aberration = polar(m, side="right") - # Convert into rotation and aberration coefficients - self.rotation_Q_to_R_rads = -1 * np.arctan2(m_rotation[1, 0], m_rotation[0, 0]) - if np.abs(np.mod(self.rotation_Q_to_R_rads + np.pi, 2.0 * np.pi) - np.pi) > ( - np.pi * 0.5 - ): - self.rotation_Q_to_R_rads = ( - np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi + if force_rotation_deg is None: + m_rotation, m_aberration = polar(m, side="right") + + if force_transpose: + m_rotation = m_rotation.T + + # Convert into rotation and aberration coefficients + + self.rotation_Q_to_R_rads = -1 * np.arctan2( + m_rotation[1, 0], m_rotation[0, 0] ) - m_aberration = -1.0 * m_aberration + if np.abs( + np.mod(self.rotation_Q_to_R_rads + np.pi, 2.0 * np.pi) - np.pi + ) > (np.pi * 0.5): + self.rotation_Q_to_R_rads = ( + np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi + ) + m_aberration = -1.0 * m_aberration + else: + self.rotation_Q_to_R_rads = np.deg2rad(force_rotation_deg) + c, s = np.cos(self.rotation_Q_to_R_rads), np.sin(self.rotation_Q_to_R_rads) + + m_rotation = np.array([[c, -s], [s, c]]) + if force_transpose: + m_rotation = m_rotation.T + + m_aberration = m_rotation @ m self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 233d34e45..6e31707c1 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -153,12 +153,6 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask is not None and positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -339,6 +333,27 @@ def preprocess( f"simultaneous_measurements_mode must be either '-+', '-0+', or '0+', not {self._simultaneous_measurements_mode}" ) + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask) + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_sim_measurements, 1, 1) + ) + + if self._positions_mask.dtype != "bool": + warnings.warn( + "`positions_mask` converted to `bool` array.", + UserWarning, + ) + self._positions_mask = self._positions_mask.astype("bool") + else: + self._positions_mask = [None] * self._num_sim_measurements + if force_com_shifts is None: force_com_shifts = [None, None, None] elif len(force_com_shifts) == self._num_sim_measurements: @@ -421,7 +436,7 @@ def preprocess( com_fitted_x_0, com_fitted_y_0, crop_patterns, - self._positions_mask, + self._positions_mask[0], ) # explicitly delete namescapes @@ -506,7 +521,7 @@ def preprocess( com_fitted_x_1, com_fitted_y_1, crop_patterns, - self._positions_mask, + self._positions_mask[1], ) # explicitly delete namescapes @@ -592,7 +607,7 @@ def preprocess( com_fitted_x_2, com_fitted_y_2, crop_patterns, - self._positions_mask, + self._positions_mask[2], ) # explicitly delete namescapes @@ -632,8 +647,8 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions, self._positions_mask - ) + self._scan_positions, self._positions_mask[0] + ) # TO-DO: generaltize to per-dataset probe positions # handle semiangle specified in pixels if self._semiangle_cutoff_pixels: @@ -3056,6 +3071,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -3074,6 +3090,9 @@ def _visualize_last_iteration( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -3101,10 +3120,7 @@ def _visualize_last_iteration( vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) extent = [ 0, @@ -3209,8 +3225,13 @@ def _visualize_last_iteration( # Probe ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, + probe_array, chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") @@ -3296,6 +3317,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -3317,6 +3339,9 @@ def _visualize_all_iterations( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -3329,6 +3354,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -3350,6 +3376,9 @@ def visualize( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -3365,6 +3394,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -3376,6 +3406,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 350d0a3cb..2c5e506e3 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -150,13 +150,6 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask is not None and positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converted to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") - self.set_save_defaults() # Data @@ -282,6 +275,13 @@ def preprocess( ) ) + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + ( self._datacube, self._vacuum_probe_intensity, @@ -472,7 +472,6 @@ def preprocess( self._probe_initial = self._probe.copy() self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -1750,6 +1749,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -1768,16 +1768,16 @@ def _visualize_last_iteration( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1869,10 +1869,16 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, + probe_array, chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -1940,6 +1946,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -1961,6 +1968,9 @@ def _visualize_all_iterations( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -2003,10 +2013,7 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - if plot_fourier_probe: - chroma_boost = kwargs.pop("chroma_boost", 2) - else: - chroma_boost = kwargs.pop("chroma_boost", 1) + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2104,14 +2111,14 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] - ) - ), - chroma_boost=chroma_boost, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") @@ -2158,6 +2165,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -2179,6 +2187,9 @@ def visualize( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -2195,6 +2206,7 @@ def visualize( plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, cbar=cbar, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, padding=padding, **kwargs, ) @@ -2205,6 +2217,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs,