diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 19543f374..9255000e4 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -2860,7 +2860,7 @@ def aberration_correct( # if needed, add low pass filter output image if q_lowpass is not None: - im_fft_corr /= 1 + (xp.sqrt(qra) / q_lowpass) ** (2 * butterworth_order) + im_fft_corr /= 1 + (xp.sqrt(kra2) / q_lowpass) ** (2 * butterworth_order) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) diff --git a/py4DSTEM/process/phase/ptychographic_tomography.py b/py4DSTEM/process/phase/ptychographic_tomography.py index e6fadedc7..4a22957ea 100644 --- a/py4DSTEM/process/phase/ptychographic_tomography.py +++ b/py4DSTEM/process/phase/ptychographic_tomography.py @@ -46,6 +46,7 @@ polar_aliases, polar_symbols, ) +from py4DSTEM.process.utils import electron_wavelength_angstrom class PtychographicTomography( @@ -219,6 +220,167 @@ def __init__( self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) self._num_measurements = num_tilts + def slim_preprocess( + self, + list_of_amplitudes, + list_of_probe_arrays, + object_array, + list_of_positions_px, + main_tilt_axis: str = "vertical", + reciprocal_sampling=None, + angular_sampling=None, + store_initial_arrays=True, + ): + """ + Alternative function for ptychographic preprocessing. + This accepts the necessary arrays, and simply sets the appropriate attributes. + + Parameters + ---------- + list_of_amplitudes: np.ndarray + List of corner-centered diffraction amplitudes each with dimension (N,Qx,Qy) + list_of_probe_arrays: np.ndarray + List of corner-centered guess for complex-valued probe of dimensions (...,Qx,Qy) + object_array: np.ndarray + Initial guess for object of dimensions (...,Px,Py) + list_of_positions_px: np.ndarray + List of initial guess for probe positions in pixels of dimensions (N,2) + reciprocal_sampling: (float,float), optional + (dk_x, dk_y) reciprocal space sampling in inverse Angstroms + angular_sampling: (float,float), optional + (dalpha_x, dalpha_y) angluar sampling in mrad + store_initial_arrays: bool + If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct. + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + xp = self._xp + xp_storage = self._xp_storage + + if len(list_of_amplitudes) != self._num_measurements: + raise ValueError() + if len(list_of_probe_arrays) != self._num_measurements: + raise ValueError() + if len(list_of_positions_px) != self._num_measurements: + raise ValueError() + + # attach arrays + num_probes_per_measurement = [0] + [amp.shape[0] for amp in list_of_amplitudes] + num_probes_per_measurement = np.array(num_probes_per_measurement) + + self._mean_diffraction_intensity = [] + self._probes_all = [] + self._num_diffraction_patterns = num_probes_per_measurement.sum() + self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement) + self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) + + self._amplitudes_shape = np.array(list_of_amplitudes[0][0].shape) + self._amplitudes = xp_storage.empty( + (self._num_diffraction_patterns,) + self._amplitudes_shape + ) + self._object = xp.asarray(object_array, dtype=xp.float32) + + if store_initial_arrays: + self._probes_all_initial = [] + self._probes_all_initial_aperture = [] + else: + self._probes_all_initial_aperture = [None] * self._num_measurements + + for index in range(self._num_measurements): + amps = xp_storage.asarray( + list_of_amplitudes[index], dtype=xp_storage.float32 + ) + probe = xp.asarray(list_of_probe_arrays[index], dtype=xp.complex64) + pos = xp_storage.asarray( + list_of_positions_px[index], dtype=xp_storage.float32 + ) + idx_start = self._cum_probes_per_measurement[index] + idx_end = self._cum_probes_per_measurement[index + 1] + self._amplitudes[idx_start:idx_end] = amps + self._positions_px_all[idx_start:idx_end] = pos + self._mean_diffraction_intensity.append((amps**2).sum((-1, -2)).mean(0)) + + self._probes_all.append(probe) + if store_initial_arrays: + self._probe_initial = probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(probe)) + + # specify sampling + if angular_sampling is None and reciprocal_sampling is None: + raise ValueError( + "One of angular or reciprocal calibration has to be specified." + ) + + wavelength = electron_wavelength_angstrom(self._energy) + if angular_sampling is not None: + if reciprocal_sampling is not None: + raise ValueError( + "Only one of angular or reciprocal calibration can be specified." + ) + self._angular_sampling = tuple(angular_sampling) + self._reciprocal_sampling = tuple( + d_alpha / wavelength / 1e3 for d_alpha in self._angular_sampling + ) + else: + self._reciprocal_sampling = tuple(reciprocal_sampling) + self._angular_sampling = tuple( + d_k * wavelength * 1e3 for d_k in self._reciprocal_sampling + ) + + # necessary probe attributes + self._resample_exit_waves = False + self._region_of_interest_shape = self._amplitudes_shape + + # necessary object attributes + self._object_shape = self._object.shape[-2:] + self._num_voxels = self._object.shape[0] + self._object_fov_mask_inverse = np.full(self._object_shape, False) + + # Precomputed propagator arrays + if main_tilt_axis == "vertical": + thickness = self._object_shape[1] * self.sampling[1] + elif main_tilt_axis == "horizontal": + thickness = self._object_shape[0] * self.sampling[0] + else: + thickness_h = self._object_shape[1] * self.sampling[1] + thickness_v = self._object_shape[0] * self.sampling[0] + thickness = max(thickness_h, thickness_v) + + self._slice_thicknesses = np.tile( + thickness / self._num_slices, self._num_slices - 1 + ) + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + # necessary general attributes + self._rotation_best_transpose = False + self._rotation_best_rad = 0 + self._preprocessed = True + + # necessary restarting attributes + if store_initial_arrays: + self._object_initial = self._object.copy() + self._object_initial_type = self._object_type + + self._positions_px_initial_all = self._positions_px_all.copy() + self._positions_initial_all = self._positions_px_initial_all.copy() + self._positions_initial_all[:, 0] *= self.sampling[0] + self._positions_initial_all[:, 1] *= self.sampling[1] + + self._positions_initial = self._return_average_positions() + if self._positions_initial is not None: + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + return self + def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None,