From 69845943be6a3914f82fc2048e55aae82bead785 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 20 Nov 2023 16:37:05 -0800 Subject: [PATCH] residual aberration fixes across all classes --- ...tive_mixedstate_multislice_ptychography.py | 63 ++++++++++++------- .../iterative_mixedstate_ptychography.py | 55 ++++++++++------ .../iterative_multislice_ptychography.py | 30 ++++----- .../phase/iterative_overlap_tomography.py | 20 ++---- .../iterative_simultaneous_ptychography.py | 11 ++-- .../iterative_singleslice_ptychography.py | 5 +- 6 files changed, 101 insertions(+), 83 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index e873d9199..03483c04e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -2569,9 +2569,6 @@ def _visualize_last_iteration( Pixels to pad by post rotating-cropping object """ - xp = self._xp - asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") @@ -2670,11 +2667,11 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = self.probe_fourier[0] if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array)) - ) + probe_array = self.probe_fourier_residual[0] + else: + probe_array = self.probe_fourier[0] + probe_array = Complex2RGB( probe_array, chroma_boost=chroma_boost, @@ -2774,7 +2771,6 @@ def _visualize_all_iterations( Pixels to pad by post rotating-cropping object """ asnumpy = self._asnumpy - xp = self._xp if not hasattr(self, "object_iterations"): raise ValueError( @@ -2914,17 +2910,11 @@ def _visualize_all_iterations( if plot_fourier_probe: probe_array = asnumpy( self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0] + probes[grid_range[n]][0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, ) ) - if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift( - xp.conjugate(self._known_aberrations_array) - ) - ) - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") @@ -3032,7 +3022,14 @@ def visualize( return self def show_fourier_probe( - self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, ): """ Plot probe in fourier space @@ -3041,6 +3038,8 @@ def show_fourier_probe( ---------- probe: complex array, optional if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe scalebar: bool, optional if True, adds scalebar to probe pixelunits: str, optional @@ -3051,21 +3050,37 @@ def show_fourier_probe( asnumpy = self._asnumpy if probe is None: - probe = list(self.probe_fourier) + probe = list( + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + ) else: if isinstance(probe, np.ndarray) and probe.ndim == 2: probe = [probe] - probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] if pixelsize is None: pixelsize = self._reciprocal_sampling[1] if pixelunits is None: pixelunits = r"$\AA^{-1}$" - chroma_boost = kwargs.pop("chroma_boost", 2) + chroma_boost = kwargs.pop("chroma_boost", 1) show_complex( probe if len(probe) > 1 else probe[0], + cbar=cbar, scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, @@ -3077,6 +3092,7 @@ def show_fourier_probe( def show_transmitted_probe( self, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, **kwargs, ): """ @@ -3119,7 +3135,12 @@ def show_transmitted_probe( if plot_fourier_probe: bottom_row = [ - asnumpy(self._return_fourier_probe(probe)) + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) for probe in [ mean_transmitted, min_intensity_transmitted, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 426216116..f3deba614 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1864,9 +1864,6 @@ def _visualize_last_iteration( padding : int, optional Pixels to pad by post rotating-cropping object """ - xp = self._xp - asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") @@ -1962,11 +1959,11 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = self.probe_fourier[0] if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array)) - ) + probe_array = self.probe_fourier_residual[0] + else: + probe_array = self.probe_fourier[0] + probe_array = Complex2RGB( probe_array, chroma_boost=chroma_boost, @@ -2068,7 +2065,6 @@ def _visualize_all_iterations( Pixels to pad by post rotating-cropping object """ asnumpy = self._asnumpy - xp = self._xp if not hasattr(self, "object_iterations"): raise ValueError( @@ -2206,17 +2202,11 @@ def _visualize_all_iterations( if plot_fourier_probe: probe_array = asnumpy( self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0] + probes[grid_range[n]][0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, ) ) - if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift( - xp.conjugate(self._known_aberrations_array) - ) - ) - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") @@ -2325,7 +2315,14 @@ def visualize( return self def show_fourier_probe( - self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, ): """ Plot probe in fourier space @@ -2334,6 +2331,8 @@ def show_fourier_probe( ---------- probe: complex array, optional if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe scalebar: bool, optional if True, adds scalebar to probe pixelunits: str, optional @@ -2344,21 +2343,37 @@ def show_fourier_probe( asnumpy = self._asnumpy if probe is None: - probe = list(self.probe_fourier) + probe = list( + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + ) else: if isinstance(probe, np.ndarray) and probe.ndim == 2: probe = [probe] - probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] if pixelsize is None: pixelsize = self._reciprocal_sampling[1] if pixelunits is None: pixelunits = r"$\AA^{-1}$" - chroma_boost = kwargs.pop("chroma_boost", 2) + chroma_boost = kwargs.pop("chroma_boost", 1) show_complex( probe if len(probe) > 1 else probe[0], + cbar=cbar, scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index f2304f72b..269919ddc 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2446,9 +2446,6 @@ def _visualize_last_iteration( Pixels to pad by post rotating-cropping object """ - xp = self._xp - asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") @@ -2547,11 +2544,11 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = self.probe_fourier if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array)) - ) + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( probe_array, chroma_boost=chroma_boost, @@ -2651,7 +2648,6 @@ def _visualize_all_iterations( Pixels to pad by post rotating-cropping object """ asnumpy = self._asnumpy - xp = self._xp if not hasattr(self, "object_iterations"): raise ValueError( @@ -2791,17 +2787,11 @@ def _visualize_all_iterations( if plot_fourier_probe: probe_array = asnumpy( self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, ) ) - if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift( - xp.conjugate(self._known_aberrations_array) - ) - ) - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") @@ -2909,6 +2899,7 @@ def visualize( def show_transmitted_probe( self, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, **kwargs, ): """ @@ -2951,7 +2942,12 @@ def show_transmitted_probe( if plot_fourier_probe: bottom_row = [ - asnumpy(self._return_fourier_probe(probe)) + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) for probe in [ mean_transmitted, min_intensity_transmitted, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 4f4d6eb9e..e8639d469 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -2627,7 +2627,6 @@ def _visualize_last_iteration( y_lims: tuple(float,float) min/max y indices """ - xp = self._xp asnumpy = self._asnumpy figsize = kwargs.pop("figsize", (8, 5)) @@ -2734,11 +2733,11 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = self.probe_fourier if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array)) - ) + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( probe_array, chroma_boost=chroma_boost, @@ -2849,7 +2848,6 @@ def _visualize_all_iterations( min/max y indices """ asnumpy = self._asnumpy - xp = self._xp if not hasattr(self, "object_iterations"): raise ValueError( @@ -3002,17 +3000,11 @@ def _visualize_all_iterations( if plot_fourier_probe: probe_array = asnumpy( self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, ) ) - if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift( - xp.conjugate(self._known_aberrations_array) - ) - ) - probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 3aac72c3b..6e31707c1 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -3096,9 +3096,6 @@ def _visualize_last_iteration( padding : int, optional Pixels to pad by post rotating-cropping object """ - xp = self._xp - asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = kwargs.pop("cmap_m", "PuOr") @@ -3228,11 +3225,11 @@ def _visualize_last_iteration( # Probe ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: - probe_array = self.probe_fourier if remove_initial_probe_aberrations: - probe_array *= asnumpy( - xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array)) - ) + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( probe_array, chroma_boost=chroma_boost, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index fa280a4ae..2c5e506e3 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1774,9 +1774,6 @@ def _visualize_last_iteration( padding : int, optional Pixels to pad by post rotating-cropping object """ - xp = self._xp - asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") @@ -1876,6 +1873,7 @@ def _visualize_last_iteration( probe_array = self.probe_fourier_residual else: probe_array = self.probe_fourier + probe_array = Complex2RGB( probe_array, chroma_boost=chroma_boost, @@ -1977,7 +1975,6 @@ def _visualize_all_iterations( Pixels to pad by post rotating-cropping object """ asnumpy = self._asnumpy - xp = self._xp if not hasattr(self, "object_iterations"): raise ValueError(