Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thankfully these phase_contrast changes are as easy as pie #577

Merged
merged 10 commits into from
Nov 22, 2023
59 changes: 49 additions & 10 deletions py4DSTEM/process/phase/iterative_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def _normalize_diffraction_intensities(

diffraction_intensities = self._asnumpy(diffraction_intensities)
if positions_mask is not None:
number_of_patterns = np.count_nonzero(self._positions_mask.ravel())
number_of_patterns = np.count_nonzero(positions_mask.ravel())
else:
number_of_patterns = np.prod(diffraction_intensities.shape[:2])

Expand Down Expand Up @@ -1217,7 +1217,7 @@ def _normalize_diffraction_intensities(
for rx in range(diffraction_intensities.shape[0]):
for ry in range(diffraction_intensities.shape[1]):
if positions_mask is not None:
if not self._positions_mask[rx, ry]:
if not positions_mask[rx, ry]:
continue
Comment on lines +1220 to 1221
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this skipping posistions where posistions_mask[rx,ry] == False if so is the following correct and more readable"

if posistion_mask[rx,ry] is False:
    continue 

Is it worth splitting this into two for loops so that posistion_mask is Not None, is checked once, and not per probe position?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, this is skipping positions where the mask is false. Not sure I understand the for loop suggestion, this has to be checked per probe position.

Re: style - PEP8 explicitly says comparing boolean values to True/False with == or is is bad form.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. One day I'll learn to not try and refactor numpy Boolean masks

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About the for loop, this checks if posistion_mask is not None at every probe position. Would it be better to do ~:

 if positions_mask is None:
     intensities = get_shifted_ar(
                    diffraction_intensities[rx, ry],
                    -com_fitted_x[rx, ry],
                    -com_fitted_y[rx, ry],
                    bilinear=True,
                    device="cpu",
                )

                if crop_patterns:
                    intensities = intensities[crop_mask].reshape(
                        region_of_interest_shape
                    )

                mean_intensity += np.sum(intensities)
                amplitudes[counter] = np.sqrt(np.maximum(intensities, 0))
                counter += 1
else:
    for rx in range(diffraction_intensities.shape[0]):
            for ry in range(diffraction_intensities.shape[1]):
                  if not positions_mask[rx, ry]:
                      continue
              intensities = get_shifted_ar(
                  diffraction_intensities[rx, ry],
                  -com_fitted_x[rx, ry],
                  -com_fitted_y[rx, ry],
                  bilinear=True,
                  device="cpu",
              )

              if crop_patterns:
                  intensities = intensities[crop_mask].reshape(
                      region_of_interest_shape
                  )

                mean_intensity += np.sum(intensities)
                amplitudes[counter] = np.sqrt(np.maximum(intensities, 0))
                counter += 1

intensities = get_shifted_ar(
diffraction_intensities[rx, ry],
Expand Down Expand Up @@ -1348,6 +1348,14 @@ def to_h5(self, group):
data=metadata,
)

# saving multiple None positions_mask fix
if self._positions_mask is None:
positions_mask = None
elif self._positions_mask[0] is None:
positions_mask = None
else:
positions_mask = self._positions_mask

# preprocessing metadata
self.metadata = Metadata(
name="preprocess_metadata",
Expand All @@ -1359,7 +1367,7 @@ def to_h5(self, group):
"num_diffraction_patterns": self._num_diffraction_patterns,
"sampling": self.sampling,
"angular_sampling": self.angular_sampling,
"positions_mask": self._positions_mask,
"positions_mask": positions_mask,
},
)

Expand Down Expand Up @@ -2146,6 +2154,7 @@ def plot_position_correction(
def _return_fourier_probe(
self,
probe=None,
remove_initial_probe_aberrations=False,
):
"""
Returns complex fourier probe shifted to center of array from
Expand All @@ -2155,6 +2164,8 @@ def _return_fourier_probe(
----------
probe: complex array, optional
if None is specified, uses self._probe
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe

Returns
-------
Expand All @@ -2168,11 +2179,17 @@ def _return_fourier_probe(
else:
probe = xp.asarray(probe, dtype=xp.complex64)

return xp.fft.fftshift(xp.fft.fft2(probe), axes=(-2, -1))
fourier_probe = xp.fft.fft2(probe)

if remove_initial_probe_aberrations:
fourier_probe *= xp.conjugate(self._known_aberrations_array)

return xp.fft.fftshift(fourier_probe, axes=(-2, -1))

def _return_fourier_probe_from_centered_probe(
self,
probe=None,
remove_initial_probe_aberrations=False,
):
"""
Returns complex fourier probe shifted to center of array from
Expand All @@ -2182,14 +2199,19 @@ def _return_fourier_probe_from_centered_probe(
----------
probe: complex array, optional
if None is specified, uses self._probe
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe

Returns
-------
fourier_probe: np.ndarray
Fourier-transformed and center-shifted probe.
"""
xp = self._xp
return self._return_fourier_probe(xp.fft.ifftshift(probe, axes=(-2, -1)))
return self._return_fourier_probe(
xp.fft.ifftshift(probe, axes=(-2, -1)),
remove_initial_probe_aberrations=remove_initial_probe_aberrations,
)

def _return_centered_probe(
self,
Expand Down Expand Up @@ -2482,6 +2504,7 @@ def show_uncertainty_visualization(
def show_fourier_probe(
self,
probe=None,
remove_initial_probe_aberrations=False,
cbar=True,
scalebar=True,
pixelsize=None,
Comment on lines 2504 to 2510
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kind of trivial, but it would be nice if these were type hinted

Expand All @@ -2495,6 +2518,8 @@ def show_fourier_probe(
----------
probe: complex array, optional
if None is specified, uses the `probe_fourier` property
remove_initial_probe_aberrations: bool, optional
If True, removes initial probe aberrations from Fourier probe
cbar: bool, optional
if True, adds colorbar
scalebar: bool, optional
Expand All @@ -2506,18 +2531,19 @@ def show_fourier_probe(
"""
asnumpy = self._asnumpy

if probe is None:
probe = self.probe_fourier
else:
probe = asnumpy(self._return_fourier_probe(probe))
probe = asnumpy(
self._return_fourier_probe(
probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations
)
)

if pixelsize is None:
pixelsize = self._reciprocal_sampling[1]
if pixelunits is None:
pixelunits = r"$\AA^{-1}$"

figsize = kwargs.pop("figsize", (6, 6))
chroma_boost = kwargs.pop("chroma_boost", 2)
chroma_boost = kwargs.pop("chroma_boost", 1)
Comment on lines 2545 to +2546
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is cool I didn't know you could pop from a dict


fig, ax = plt.subplots(figsize=figsize)
show_complex(
Expand Down Expand Up @@ -2570,6 +2596,19 @@ def probe_fourier(self):
asnumpy = self._asnumpy
return asnumpy(self._return_fourier_probe(self._probe))

@property
def probe_fourier_residual(self):
"""Current probe estimate in Fourier space"""
if not hasattr(self, "_probe"):
return None

asnumpy = self._asnumpy
return asnumpy(
self._return_fourier_probe(
self._probe, remove_initial_probe_aberrations=True
)
)

@property
def probe_centered(self):
"""Current probe estimate shifted to the center"""
Expand Down
Loading