Skip to content

Commit

Permalink
residual aberration fixes across all classes
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 21, 2023
1 parent ed830e0 commit 6984594
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2569,9 +2569,6 @@ def _visualize_last_iteration(
Pixels to pad by post rotating-cropping object
"""
xp = self._xp
asnumpy = self._asnumpy

figsize = kwargs.pop("figsize", (8, 5))
cmap = kwargs.pop("cmap", "magma")

Expand Down Expand Up @@ -2670,11 +2667,11 @@ def _visualize_last_iteration(

ax = fig.add_subplot(spec[0, 1])
if plot_fourier_probe:
probe_array = self.probe_fourier[0]
if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array))
)
probe_array = self.probe_fourier_residual[0]
else:
probe_array = self.probe_fourier[0]

probe_array = Complex2RGB(
probe_array,
chroma_boost=chroma_boost,
Expand Down Expand Up @@ -2774,7 +2771,6 @@ def _visualize_all_iterations(
Pixels to pad by post rotating-cropping object
"""
asnumpy = self._asnumpy
xp = self._xp

if not hasattr(self, "object_iterations"):
raise ValueError(
Expand Down Expand Up @@ -2914,17 +2910,11 @@ def _visualize_all_iterations(
if plot_fourier_probe:
probe_array = asnumpy(
self._return_fourier_probe_from_centered_probe(
probes[grid_range[n]][0]
probes[grid_range[n]][0],
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)

if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(
xp.conjugate(self._known_aberrations_array)
)
)

probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost)

ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]")
Expand Down Expand Up @@ -3032,7 +3022,14 @@ def visualize(
return self

def show_fourier_probe(
self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs
self,
probe=None,
remove_initial_probe_aberrations=False,
cbar=True,
scalebar=True,
pixelsize=None,
pixelunits=None,
**kwargs,
):
"""
Plot probe in fourier space
Expand All @@ -3041,6 +3038,8 @@ def show_fourier_probe(
----------
probe: complex array, optional
if None is specified, uses the `probe_fourier` property
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe
scalebar: bool, optional
if True, adds scalebar to probe
pixelunits: str, optional
Expand All @@ -3051,21 +3050,37 @@ def show_fourier_probe(
asnumpy = self._asnumpy

if probe is None:
probe = list(self.probe_fourier)
probe = list(
asnumpy(
self._return_fourier_probe(
probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
)
else:
if isinstance(probe, np.ndarray) and probe.ndim == 2:
probe = [probe]
probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe]
probe = [
asnumpy(
self._return_fourier_probe(
pr,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
for pr in probe
]

if pixelsize is None:
pixelsize = self._reciprocal_sampling[1]
if pixelunits is None:
pixelunits = r"$\AA^{-1}$"

chroma_boost = kwargs.pop("chroma_boost", 2)
chroma_boost = kwargs.pop("chroma_boost", 1)

show_complex(
probe if len(probe) > 1 else probe[0],
cbar=cbar,
scalebar=scalebar,
pixelsize=pixelsize,
pixelunits=pixelunits,
Expand All @@ -3077,6 +3092,7 @@ def show_fourier_probe(
def show_transmitted_probe(
self,
plot_fourier_probe: bool = False,
remove_initial_probe_aberrations=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -3119,7 +3135,12 @@ def show_transmitted_probe(

if plot_fourier_probe:
bottom_row = [
asnumpy(self._return_fourier_probe(probe))
asnumpy(
self._return_fourier_probe(
probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
for probe in [
mean_transmitted,
min_intensity_transmitted,
Expand Down
55 changes: 35 additions & 20 deletions py4DSTEM/process/phase/iterative_mixedstate_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,9 +1864,6 @@ def _visualize_last_iteration(
padding : int, optional
Pixels to pad by post rotating-cropping object
"""
xp = self._xp
asnumpy = self._asnumpy

figsize = kwargs.pop("figsize", (8, 5))
cmap = kwargs.pop("cmap", "magma")

Expand Down Expand Up @@ -1962,11 +1959,11 @@ def _visualize_last_iteration(
ax = fig.add_subplot(spec[0, 1])

if plot_fourier_probe:
probe_array = self.probe_fourier[0]
if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array))
)
probe_array = self.probe_fourier_residual[0]
else:
probe_array = self.probe_fourier[0]

probe_array = Complex2RGB(
probe_array,
chroma_boost=chroma_boost,
Expand Down Expand Up @@ -2068,7 +2065,6 @@ def _visualize_all_iterations(
Pixels to pad by post rotating-cropping object
"""
asnumpy = self._asnumpy
xp = self._xp

if not hasattr(self, "object_iterations"):
raise ValueError(
Expand Down Expand Up @@ -2206,17 +2202,11 @@ def _visualize_all_iterations(
if plot_fourier_probe:
probe_array = asnumpy(
self._return_fourier_probe_from_centered_probe(
probes[grid_range[n]][0]
probes[grid_range[n]][0],
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)

if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(
xp.conjugate(self._known_aberrations_array)
)
)

probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost)

ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]")
Expand Down Expand Up @@ -2325,7 +2315,14 @@ def visualize(
return self

def show_fourier_probe(
self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs
self,
probe=None,
remove_initial_probe_aberrations=False,
cbar=True,
scalebar=True,
pixelsize=None,
pixelunits=None,
**kwargs,
):
"""
Plot probe in fourier space
Expand All @@ -2334,6 +2331,8 @@ def show_fourier_probe(
----------
probe: complex array, optional
if None is specified, uses the `probe_fourier` property
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe
scalebar: bool, optional
if True, adds scalebar to probe
pixelunits: str, optional
Expand All @@ -2344,21 +2343,37 @@ def show_fourier_probe(
asnumpy = self._asnumpy

if probe is None:
probe = list(self.probe_fourier)
probe = list(
asnumpy(
self._return_fourier_probe(
probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
)
else:
if isinstance(probe, np.ndarray) and probe.ndim == 2:
probe = [probe]
probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe]
probe = [
asnumpy(
self._return_fourier_probe(
pr,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
for pr in probe
]

if pixelsize is None:
pixelsize = self._reciprocal_sampling[1]
if pixelunits is None:
pixelunits = r"$\AA^{-1}$"

chroma_boost = kwargs.pop("chroma_boost", 2)
chroma_boost = kwargs.pop("chroma_boost", 1)

show_complex(
probe if len(probe) > 1 else probe[0],
cbar=cbar,
scalebar=scalebar,
pixelsize=pixelsize,
pixelunits=pixelunits,
Expand Down
30 changes: 13 additions & 17 deletions py4DSTEM/process/phase/iterative_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -2446,9 +2446,6 @@ def _visualize_last_iteration(
Pixels to pad by post rotating-cropping object
"""
xp = self._xp
asnumpy = self._asnumpy

figsize = kwargs.pop("figsize", (8, 5))
cmap = kwargs.pop("cmap", "magma")

Expand Down Expand Up @@ -2547,11 +2544,11 @@ def _visualize_last_iteration(

ax = fig.add_subplot(spec[0, 1])
if plot_fourier_probe:
probe_array = self.probe_fourier
if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array))
)
probe_array = self.probe_fourier_residual
else:
probe_array = self.probe_fourier

probe_array = Complex2RGB(
probe_array,
chroma_boost=chroma_boost,
Expand Down Expand Up @@ -2651,7 +2648,6 @@ def _visualize_all_iterations(
Pixels to pad by post rotating-cropping object
"""
asnumpy = self._asnumpy
xp = self._xp

if not hasattr(self, "object_iterations"):
raise ValueError(
Expand Down Expand Up @@ -2791,17 +2787,11 @@ def _visualize_all_iterations(
if plot_fourier_probe:
probe_array = asnumpy(
self._return_fourier_probe_from_centered_probe(
probes[grid_range[n]]
probes[grid_range[n]],
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)

if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(
xp.conjugate(self._known_aberrations_array)
)
)

probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost)

ax.set_title(f"Iter: {grid_range[n]} Fourier probe")
Expand Down Expand Up @@ -2909,6 +2899,7 @@ def visualize(
def show_transmitted_probe(
self,
plot_fourier_probe: bool = False,
remove_initial_probe_aberrations=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -2951,7 +2942,12 @@ def show_transmitted_probe(

if plot_fourier_probe:
bottom_row = [
asnumpy(self._return_fourier_probe(probe))
asnumpy(
self._return_fourier_probe(
probe,
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)
for probe in [
mean_transmitted,
min_intensity_transmitted,
Expand Down
20 changes: 6 additions & 14 deletions py4DSTEM/process/phase/iterative_overlap_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,7 +2627,6 @@ def _visualize_last_iteration(
y_lims: tuple(float,float)
min/max y indices
"""
xp = self._xp
asnumpy = self._asnumpy

figsize = kwargs.pop("figsize", (8, 5))
Expand Down Expand Up @@ -2734,11 +2733,11 @@ def _visualize_last_iteration(

ax = fig.add_subplot(spec[0, 1])
if plot_fourier_probe:
probe_array = self.probe_fourier
if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(xp.conjugate(self._known_aberrations_array))
)
probe_array = self.probe_fourier_residual
else:
probe_array = self.probe_fourier

probe_array = Complex2RGB(
probe_array,
chroma_boost=chroma_boost,
Expand Down Expand Up @@ -2849,7 +2848,6 @@ def _visualize_all_iterations(
min/max y indices
"""
asnumpy = self._asnumpy
xp = self._xp

if not hasattr(self, "object_iterations"):
raise ValueError(
Expand Down Expand Up @@ -3002,17 +3000,11 @@ def _visualize_all_iterations(
if plot_fourier_probe:
probe_array = asnumpy(
self._return_fourier_probe_from_centered_probe(
probes[grid_range[n]]
probes[grid_range[n]],
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)
)

if remove_initial_probe_aberrations:
probe_array *= asnumpy(
xp.fft.ifftshift(
xp.conjugate(self._known_aberrations_array)
)
)

probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost)

ax.set_title(f"Iter: {grid_range[n]} Fourier probe")
Expand Down
Loading

0 comments on commit 6984594

Please sign in to comment.