Skip to content

Commit

Permalink
ptycho-tomo slim_preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 14, 2024
1 parent f58a763 commit 91d3bb5
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 1 deletion.
2 changes: 1 addition & 1 deletion py4DSTEM/process/phase/parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2860,7 +2860,7 @@ def aberration_correct(

# if needed, add low pass filter output image
if q_lowpass is not None:
im_fft_corr /= 1 + (xp.sqrt(qra) / q_lowpass) ** (2 * butterworth_order)
im_fft_corr /= 1 + (xp.sqrt(kra2) / q_lowpass) ** (2 * butterworth_order)

# Output phase image
self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr))
Expand Down
162 changes: 162 additions & 0 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
polar_aliases,
polar_symbols,
)
from py4DSTEM.process.utils import electron_wavelength_angstrom


class PtychographicTomography(
Expand Down Expand Up @@ -219,6 +220,167 @@ def __init__(
self._tilt_orientation_matrices = tuple(tilt_orientation_matrices)
self._num_measurements = num_tilts

def slim_preprocess(
self,
list_of_amplitudes,
list_of_probe_arrays,
object_array,
list_of_positions_px,
main_tilt_axis: str = "vertical",
reciprocal_sampling=None,
angular_sampling=None,
store_initial_arrays=True,
):
"""
Alternative function for ptychographic preprocessing.
This accepts the necessary arrays, and simply sets the appropriate attributes.
Parameters
----------
list_of_amplitudes: np.ndarray
List of corner-centered diffraction amplitudes each with dimension (N,Qx,Qy)
list_of_probe_arrays: np.ndarray
List of corner-centered guess for complex-valued probe of dimensions (...,Qx,Qy)
object_array: np.ndarray
Initial guess for object of dimensions (...,Px,Py)
list_of_positions_px: np.ndarray
List of initial guess for probe positions in pixels of dimensions (N,2)
reciprocal_sampling: (float,float), optional
(dk_x, dk_y) reciprocal space sampling in inverse Angstroms
angular_sampling: (float,float), optional
(dalpha_x, dalpha_y) angluar sampling in mrad
store_initial_arrays: bool
If True, preprocesed object and probe arrays are stored allowing reset=True in reconstruct.
Returns
--------
self: PtychographicReconstruction
Self to accommodate chaining
"""
xp = self._xp
xp_storage = self._xp_storage

if len(list_of_amplitudes) != self._num_measurements:
raise ValueError()
if len(list_of_probe_arrays) != self._num_measurements:
raise ValueError()
if len(list_of_positions_px) != self._num_measurements:
raise ValueError()

# attach arrays
num_probes_per_measurement = [0] + [amp.shape[0] for amp in list_of_amplitudes]
num_probes_per_measurement = np.array(num_probes_per_measurement)

self._mean_diffraction_intensity = []
self._probes_all = []
self._num_diffraction_patterns = num_probes_per_measurement.sum()
self._cum_probes_per_measurement = np.cumsum(num_probes_per_measurement)
self._positions_px_all = np.empty((self._num_diffraction_patterns, 2))

self._amplitudes_shape = np.array(list_of_amplitudes[0][0].shape)
self._amplitudes = xp_storage.empty(
(self._num_diffraction_patterns,) + self._amplitudes_shape
)
self._object = xp.asarray(object_array, dtype=xp.float32)

if store_initial_arrays:
self._probes_all_initial = []
self._probes_all_initial_aperture = []
else:
self._probes_all_initial_aperture = [None] * self._num_measurements

for index in range(self._num_measurements):
amps = xp_storage.asarray(
list_of_amplitudes[index], dtype=xp_storage.float32
)
probe = xp.asarray(list_of_probe_arrays[index], dtype=xp.complex64)
pos = xp_storage.asarray(
list_of_positions_px[index], dtype=xp_storage.float32
)
idx_start = self._cum_probes_per_measurement[index]
idx_end = self._cum_probes_per_measurement[index + 1]
self._amplitudes[idx_start:idx_end] = amps
self._positions_px_all[idx_start:idx_end] = pos
self._mean_diffraction_intensity.append((amps**2).sum((-1, -2)).mean(0))

self._probes_all.append(probe)
if store_initial_arrays:
self._probe_initial = probe.copy()
self._probe_initial_aperture = xp.abs(xp.fft.fft2(probe))

# specify sampling
if angular_sampling is None and reciprocal_sampling is None:
raise ValueError(
"One of angular or reciprocal calibration has to be specified."
)

wavelength = electron_wavelength_angstrom(self._energy)
if angular_sampling is not None:
if reciprocal_sampling is not None:
raise ValueError(
"Only one of angular or reciprocal calibration can be specified."
)
self._angular_sampling = tuple(angular_sampling)
self._reciprocal_sampling = tuple(
d_alpha / wavelength / 1e3 for d_alpha in self._angular_sampling
)
else:
self._reciprocal_sampling = tuple(reciprocal_sampling)
self._angular_sampling = tuple(
d_k * wavelength * 1e3 for d_k in self._reciprocal_sampling
)

# necessary probe attributes
self._resample_exit_waves = False
self._region_of_interest_shape = self._amplitudes_shape

# necessary object attributes
self._object_shape = self._object.shape[-2:]
self._num_voxels = self._object.shape[0]
self._object_fov_mask_inverse = np.full(self._object_shape, False)

# Precomputed propagator arrays
if main_tilt_axis == "vertical":
thickness = self._object_shape[1] * self.sampling[1]
elif main_tilt_axis == "horizontal":
thickness = self._object_shape[0] * self.sampling[0]
else:
thickness_h = self._object_shape[1] * self.sampling[1]
thickness_v = self._object_shape[0] * self.sampling[0]
thickness = max(thickness_h, thickness_v)

self._slice_thicknesses = np.tile(
thickness / self._num_slices, self._num_slices - 1
)
self._propagator_arrays = self._precompute_propagator_arrays(
self._region_of_interest_shape,
self.sampling,
self._energy,
self._slice_thicknesses,
)

# necessary general attributes
self._rotation_best_transpose = False
self._rotation_best_rad = 0
self._preprocessed = True

# necessary restarting attributes
if store_initial_arrays:
self._object_initial = self._object.copy()
self._object_initial_type = self._object_type

self._positions_px_initial_all = self._positions_px_all.copy()
self._positions_initial_all = self._positions_px_initial_all.copy()
self._positions_initial_all[:, 0] *= self.sampling[0]
self._positions_initial_all[:, 1] *= self.sampling[1]

self._positions_initial = self._return_average_positions()
if self._positions_initial is not None:
self._positions_initial[:, 0] *= self.sampling[0]
self._positions_initial[:, 1] *= self.sampling[1]

return self

def preprocess(
self,
diffraction_intensities_shape: Tuple[int, int] = None,
Expand Down

0 comments on commit 91d3bb5

Please sign in to comment.