diff --git a/py4DSTEM/process/phase/direct_ptychography.py b/py4DSTEM/process/phase/direct_ptychography.py index 9d0e72828..cb33a23c4 100644 --- a/py4DSTEM/process/phase/direct_ptychography.py +++ b/py4DSTEM/process/phase/direct_ptychography.py @@ -25,6 +25,7 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube +from py4DSTEM.process.calibration import get_probe_size from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import ( ComplexProbe, @@ -38,7 +39,6 @@ unwrap_phase_2d_skimage, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar -from py4DSTEM.process.calibration import get_probe_size _aberration_names = { (1, 0): "C1", diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index e6442821f..823319658 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -4,7 +4,7 @@ """ import warnings -from typing import Tuple +from typing import Mapping, Tuple import matplotlib.pyplot as plt import numpy as np @@ -24,13 +24,15 @@ lanczos_interpolate_array, lanczos_kernel_density_estimate, pixel_rolling_kernel_density_estimate, + polar_aberrations_to_cartesian, + polar_aliases, + polar_symbols, ) from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom from py4DSTEM.visualize import return_scaled_histogram_ordering from scipy.linalg import polar from scipy.ndimage import distance_transform_edt -from scipy.optimize import minimize from scipy.special import comb try: @@ -39,20 +41,20 @@ cp = np _aberration_names = { - (1, 0): "C1 ", - (1, 2): "stig ", - (2, 1): "coma ", - (2, 3): "trefoil ", - (3, 0): "C3 ", - (3, 2): "stig2 ", - (3, 4): "quadfoil ", - (4, 1): "coma2 ", - (4, 3): "trefoil2 ", - (4, 5): "pentafoil ", - (5, 0): "C5 ", - (5, 2): "stig3 ", - (5, 4): "quadfoil2 ", - (5, 6): "hexafoil ", + (1, 0): "C1", + (1, 2): "stig", + (2, 1): "coma", + (2, 3): "trefoil", + (3, 0): "C3", + (3, 2): "stig2", + (3, 4): "quadfoil", + (4, 1): "coma2", + (4, 3): "trefoil2", + (4, 5): "pentafoil", + (5, 0): "C5", + (5, 2): "stig3", + (5, 4): "quadfoil2", + (5, 6): "hexafoil", } @@ -81,10 +83,12 @@ def __init__( datacube: DataCube = None, verbose: bool = True, object_padding_px: Tuple[int, int] = (32, 32), + polar_parameters: Mapping[str, float] = None, device: str = "cpu", storage: str = None, clear_fft_cache: bool = True, name: str = "parallax_reconstruction", + **kwargs, ): Custom.__init__(self, name=name) @@ -97,6 +101,17 @@ def __init__( self.set_device(device, clear_fft_cache) self.set_storage(storage) + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) self.set_save_defaults() # Data @@ -137,13 +152,13 @@ def to_h5(self, group): # reconstruction metadata recon_metadata = {"reconstruction_error": float(self._recon_error)} - if hasattr(self, "aberration_C1"): + if hasattr(self, "aberrations_C1"): recon_metadata |= { - "aberration_rotation_QR": self.rotation_Q_to_R_rads, - "aberration_transpose": self.transpose, - "aberration_C1": self.aberration_C1, - "aberration_A1x": self.aberration_A1x, - "aberration_A1y": self.aberration_A1y, + "aberrations_rotation_QR": self.rotation_Q_to_R_rads, + "aberrations_transpose": self.transpose, + "aberrations_C1": self.aberrations_C1, + "aberrations_C12a": self.aberrations_C12a, + "aberrations_C12b": self.aberrations_C12b, } if hasattr(self, "_kde_upsample_factor"): @@ -155,10 +170,15 @@ def to_h5(self, group): data=self._asnumpy(self._recon_BF_subpixel_aligned), ) - if hasattr(self, "aberration_dict_cartesian"): + if hasattr(self, "aberrations_dict_cartesian"): self.metadata = Metadata( - name="aberrations_polar_metadata", - data=self.aberration_dict_polar, + name="aberrations_dict_polar", + data=self.aberrations_dict_polar, + ) + + self.metadata = Metadata( + name="aberrations_dict_cartesian", + data=self.aberrations_dict_cartesian, ) self.metadata = Metadata( @@ -237,12 +257,12 @@ def _populate_instance(self, group): # Data dict_data = Custom._get_emd_attr_data(Custom, group) - if "aberration_C1" in reconstruction_md.keys: - self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] - self.transpose = reconstruction_md["aberration_transpose"] - self.aberration_C1 = reconstruction_md["aberration_C1"] - self.aberration_A1x = reconstruction_md["aberration_A1x"] - self.aberration_A1y = reconstruction_md["aberration_A1y"] + if "aberrations_C1" in reconstruction_md.keys: + self.rotation_Q_to_R_rads = reconstruction_md["aberrations_rotation_QR"] + self.transpose = reconstruction_md["aberrations_transpose"] + self.aberrations_C1 = reconstruction_md["aberrations_C1"] + self.aberrations_C12a = reconstruction_md["aberrations_C12a"] + self.aberrations_C12b = reconstruction_md["aberrations_C12b"] if "kde_upsample_factor" in reconstruction_md.keys: self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"] @@ -252,6 +272,29 @@ def _populate_instance(self, group): self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32) + def _set_polar_parameters(self, parameters: dict): + """ + Set the probe aberrations dictionary. + + Parameters + ---------- + parameters: dict + Mapping from aberration symbols to their corresponding values. + """ + + for symbol, value in parameters.items(): + if symbol in self._polar_parameters.keys(): + self._polar_parameters[symbol] = value + + elif symbol == "defocus": + self._polar_parameters[polar_aliases[symbol]] = -value + + elif symbol in polar_aliases.keys(): + self._polar_parameters[polar_aliases[symbol]] = value + + else: + raise ValueError("{} not a recognized parameter".format(symbol)) + def preprocess( self, edge_blend: float = 16.0, @@ -260,8 +303,8 @@ def preprocess( normalize_images: bool = True, normalize_order=0, descan_correction_fit_function: str = None, - defocus_guess: float = None, - rotation_guess: float = None, + force_rotation_angle_deg: float = 0, + force_transpose: bool = None, aligned_bf_image_guess: np.ndarray = None, plot_average_bf: bool = True, realspace_mask: np.ndarray = None, @@ -288,15 +331,13 @@ def preprocess( If True, bright images normalized to have a mean of 1 normalize_order: integer, optional Polynomial order for normalization. 0 means constant, 1 means linear, etc. - defocus_guess: float, optional - Initial guess of defocus value (defocus dF) in A - If None, first iteration is assumed to be in-focus aligned_bf_image_guess: np.ndarray, optional Guess for the reference BF image to cross-correlate against during the first iteration If None, the incoherent BF image is used instead. - rotation_guess: float, optional - Initial guess of rotation value in degrees - If None, first iteration assumed to be 0 + force_rotation: float, optional + Initial guess of rotation value in degrees. + force_transpose: bool, optional + Whether or not the dataset should be transposed. descan_correction_fit_function: str, optional If not None, descan correction will be performed using fit function. One of "constant", "plane", "parabola", or "bezier_two". @@ -656,33 +697,70 @@ def preprocess( self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) - if defocus_guess is not None: + # aberrations_coefs + aberrations_mn = [] + aberrations_coefs = [] + cartesian_dict = polar_aberrations_to_cartesian(self._polar_parameters) + for key, val in cartesian_dict.items(): + if np.abs(val) > 0: + m = int(key[1]) + n = int(key[2]) + a = 1 if n > 0 and key[3] != "a" else 0 + aberrations_mn.append([m, n, a]) + aberrations_coefs.append(val) + + if len(aberrations_mn) > 0: + # aberrations_basis + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + + # transpose rotation matrix + if force_transpose: + force_rotation_angle_deg *= -1 + + aberrations_basis, aberrations_basis_du, aberrations_basis_dv = ( + calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=np.deg2rad(force_rotation_angle_deg), + xp=xp, + ) + ) + + # shifts + corner_indices = self._xy_inds - xp.array( + self._region_of_interest_shape // 2 + ) + raveled_indices = xp.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.array( + ( + aberrations_basis_du[raveled_indices, :], + aberrations_basis_dv[raveled_indices, :], + ) + ) + + aberrations_coefs = xp.asarray(aberrations_coefs, dtype=xp.float32) + shifts_ang = xp.tensordot(gradients, aberrations_coefs, axes=1).T + + if force_transpose: + shifts_ang = xp.flip(shifts_ang, axis=1) + + shifts_px = shifts_ang / xp.array(self._scan_sampling) + for start, end in generate_batches( self._num_bf_images, max_batch=max_batch_size ): shifted_BFs = self._stack_BF_shifted[start:end] - probe_angles = self._probe_angles[start:end] + xy_shifts = shifts_px[start:end] stack_mask = self._stack_mask[start:end] Gs = xp.fft.fft2(shifted_BFs) - xy_shifts = ( - -probe_angles - * defocus_guess - / xp.array(self._scan_sampling, dtype=xp.float32) - ) - - if rotation_guess is not None: - angle = xp.deg2rad(rotation_guess) - rotation_matrix = xp.array( - [ - [np.cos(angle), np.sin(angle)], - [-np.sin(angle), np.cos(angle)], - ], - dtype=xp.float32, - ) - xy_shifts = xp.dot(xy_shifts, rotation_matrix) - dx = xy_shifts[:, 0] dy = xy_shifts[:, 1] @@ -724,6 +802,7 @@ def preprocess( ) else: + self._recon_BF = ( self._stack_mean * mask_inv + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) @@ -2131,22 +2210,104 @@ def _kernel_density_estimate( max_batch_size=max_batch_size, ) + def _aberration_fit_polar_decomposition( + self, + xy_shifts, + scan_sampling, + probe_angles, + force_transpose: bool = False, + force_rotation_angle_deg: float = None, + ): + """ """ + + xp = self._xp + asnumpy = self._asnumpy + + # Numpy arrays + shifts = asnumpy(xy_shifts) + angles = asnumpy(probe_angles) + sampling = asnumpy(scan_sampling) + + if force_transpose: + shifts = np.flip(shifts, axis=1) + + shifts_Ang = shifts * sampling + + # Solve affine transformation + m = np.linalg.lstsq(angles, shifts_Ang, rcond=None)[0] + + if force_rotation_angle_deg is None: + m_rotation, m_aberration = polar(m, side="right") + + if force_transpose: + m_rotation = m_rotation.T + + # Convert into rotation and aberration coefficients + rotation_rad = -1 * np.arctan2(m_rotation[1, 0], m_rotation[0, 0]) + if 2 * np.abs(np.mod(rotation_rad + np.pi, 2 * np.pi) - np.pi) > np.pi: + rotation_rad = np.mod(rotation_rad, 2 * np.pi) - np.pi + m_aberration *= -1.0 + else: + rotation_rad = np.deg2rad(force_rotation_angle_deg) + c, s = np.cos(rotation_rad), np.sin(rotation_rad) + + m_rotation = np.array([[c, -s], [s, c]]) + if force_transpose: + m_rotation = m_rotation.T + + m_aberration = m_rotation @ m + + aberrations_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2 + + if force_transpose: + aberrations_C12a = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2 + aberrations_C12b = (m_aberration[1, 0] + m_aberration[0, 1]) / 2 + else: + aberrations_C12a = (m_aberration[0, 0] - m_aberration[1, 1]) / 2 + aberrations_C12b = (m_aberration[1, 0] + m_aberration[0, 1]) / 2 + + return ( + xp.asarray(shifts_Ang), + rotation_rad, + aberrations_C1, + aberrations_C12a, + aberrations_C12b, + ) + + def _aberration_fit_deltas_and_increment( + self, + measured_shifts, + fitted_shifts, + gradients, + aberrations_coefs, + indices, + ): + """ """ + xp = self._xp + asnumpy = self._asnumpy + + delta_shifts = (measured_shifts - fitted_shifts).T.ravel() + coefs = xp.linalg.lstsq(gradients, delta_shifts, rcond=None)[0] + deltas = xp.tensordot(gradients, coefs, axes=1).reshape((2, -1)).T + + aberrations_coefs[indices] += asnumpy(coefs) + fitted_shifts += deltas + + return aberrations_coefs, fitted_shifts + def aberration_fit( self, - fit_BF_shifts: bool = False, - fit_CTF_FFT: bool = False, - fit_aberrations_max_radial_order: int = 3, - fit_aberrations_max_angular_order: int = 4, - fit_aberrations_min_radial_order: int = 2, - fit_aberrations_min_angular_order: int = 0, - fit_aberrations_mn: list = None, - fit_max_thon_rings: int = 6, - fit_power_alpha: float = 1.0, - plot_CTF_comparison: bool = None, - plot_BF_shifts_comparison: bool = None, - upsampled: bool = True, + max_radial_order: int = 3, + max_angular_order: int = 4, + min_radial_order: int = 2, + min_angular_order: int = 0, + aberrations_mn: list = None, + initialize_fit_with_polar_decomposition: bool = True, + fit_method="recursive", force_transpose: bool = False, - force_rotation_deg: float = None, + force_rotation_angle_deg: float = None, + plot_CTF_comparison: bool = False, + plot_BF_shifts_comparison: bool = False, **kwargs, ): """ @@ -2188,85 +2349,45 @@ def aberration_fit( xp = self._xp asnumpy = self._asnumpy - ### First pass - - # Convert real space shifts to Angstroms - - if force_transpose is True: - self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( - self._scan_sampling + # Initial estimate + shifts_Ang, rotation_rad, aberrations_C1, aberrations_C12a, aberrations_C12b = ( + self._aberration_fit_polar_decomposition( + self._xy_shifts, + self._scan_sampling, + self._probe_angles, + force_transpose=force_transpose, + force_rotation_angle_deg=force_rotation_angle_deg, ) - else: - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) - - self.transpose = force_transpose - - # Solve affine transformation - m = asnumpy( - xp.linalg.lstsq(self._probe_angles, self._xy_shifts_Ang, rcond=None)[0] ) - if force_rotation_deg is None: - m_rotation, m_aberration = polar(m, side="right") - - if force_transpose: - m_rotation = m_rotation.T - - # Convert into rotation and aberration coefficients - - self.rotation_Q_to_R_rads = -1 * np.arctan2( - m_rotation[1, 0], m_rotation[0, 0] - ) - if np.abs( - np.mod(self.rotation_Q_to_R_rads + np.pi, 2.0 * np.pi) - np.pi - ) > (np.pi * 0.5): - self.rotation_Q_to_R_rads = ( - np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi - ) - m_aberration = -1.0 * m_aberration - else: - self.rotation_Q_to_R_rads = np.deg2rad(force_rotation_deg) - c, s = np.cos(self.rotation_Q_to_R_rads), np.sin(self.rotation_Q_to_R_rads) - - m_rotation = np.array([[c, -s], [s, c]]) - if force_transpose: - m_rotation = m_rotation.T - - m_aberration = m_rotation @ m - - self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 + self.aberrations_C1 = aberrations_C1 + self.aberrations_C12a = aberrations_C12a + self.aberrations_C12b = aberrations_C12b + self.rotation_Q_to_R_rads = rotation_rad + self.transpose = force_transpose - if self.transpose: - self.aberration_A1x = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 - self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - else: - self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 - self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + # Aberration coefs - ### Second pass + if min_radial_order < 2 or max_radial_order < min_radial_order: + raise ValueError() - # Aberration coefs + if min_angular_order < 0 or max_angular_order < min_angular_order: + raise ValueError() - if fit_aberrations_mn is None: + if aberrations_mn is None: mn = [] - for m in range( - fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order - ): - n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) - for n in range(fit_aberrations_min_angular_order, n_max + 1): + for m in range(min_radial_order - 1, max_radial_order): + n_max = np.minimum(max_angular_order, m + 1) + for n in range(min_angular_order, n_max + 1): if (m + n) % 2: mn.append([m, n, 0]) if n > 0: mn.append([m, n, 1]) else: - mn = fit_aberrations_mn + mn = aberrations_mn self._aberrations_mn = np.array(mn) - self._aberrations_mn = self._aberrations_mn[ - np.argsort(self._aberrations_mn[:, 1]), : - ] - sub = self._aberrations_mn[:, 1] > 0 self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ np.argsort(self._aberrations_mn[sub, 0]), : @@ -2275,387 +2396,373 @@ def aberration_fit( np.argsort(self._aberrations_mn[~sub, 0]), : ] self._aberrations_num = self._aberrations_mn.shape[0] + self._aberrations_coefs = np.zeros(self._aberrations_num) - # Thon Rings Fitting - if fit_CTF_FFT or plot_CTF_comparison: - if upsampled and hasattr(self, "_kde_upsample_factor"): - im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - sx = self._scan_sampling[0] / self._kde_upsample_factor - sy = self._scan_sampling[1] / self._kde_upsample_factor + # Basis functions + sampling = 1 / ( + np.array(self._reciprocal_sampling) * self._region_of_interest_shape + ) + ( + self._aberrations_basis, + self._aberrations_basis_du, + self._aberrations_basis_dv, + ) = calculate_aberration_gradient_basis( + self._aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=rotation_rad, + xp=xp, + ) - reciprocal_extent = [ - -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), - 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), - 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), - -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), - ] + corner_indices = self._xy_inds - xp.asarray(self._region_of_interest_shape // 2) + raveled_indices = np.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) - else: - im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) - sx = self._scan_sampling[0] - sy = self._scan_sampling[1] - upsampled = False + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:, a0] + return xp.reshape(chi, alpha_shape) + + # Initialization + if initialize_fit_with_polar_decomposition: + aberrations_mn_list = self._aberrations_mn.tolist() + initialization_inds = [] + if [1, 0, 0] in aberrations_mn_list: + ind_C1 = aberrations_mn_list.index([1, 0, 0]) + self._aberrations_coefs[ind_C1] = aberrations_C1 + initialization_inds.append(ind_C1) + + if [1, 2, 0] in aberrations_mn_list: + ind_C12a = aberrations_mn_list.index([1, 2, 0]) + ind_C12b = aberrations_mn_list.index([1, 2, 1]) + self._aberrations_coefs[ind_C12a] = aberrations_C12a + self._aberrations_coefs[ind_C12b] = aberrations_C12b + initialization_inds.append(ind_C12a) + initialization_inds.append(ind_C12b) + + initialization_inds = np.array(initialization_inds) + gradients = xp.array( + ( + self._aberrations_basis_du[ + raveled_indices[:, None], initialization_inds[None, :] + ], + self._aberrations_basis_dv[ + raveled_indices[:, None], initialization_inds[None, :] + ], + ) + ) - reciprocal_extent = [ - -0.5 / self._scan_sampling[1], - 0.5 / self._scan_sampling[1], - 0.5 / self._scan_sampling[0], - -0.5 / self._scan_sampling[0], - ] + aberrations_coefs = xp.asarray( + self._aberrations_coefs[initialization_inds], dtype=xp.float32 + ) + fitted_shifts_Ang = xp.tensordot(gradients, aberrations_coefs, axes=1).T + else: + fitted_shifts_Ang = xp.zeros_like(shifts_Ang) - # FFT coordinates - qx = xp.fft.fftfreq(im_FFT.shape[0], sx) - qy = xp.fft.fftfreq(im_FFT.shape[1], sy) - qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + # Incremental fitting + chunks = np.unique(self._aberrations_mn[:, 0], return_index=True)[1][1:] + split_order = np.split(np.arange(self._aberrations_num), chunks) - alpha_FFT = xp.sqrt(qr2) * self._wavelength - theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) + if fit_method == "recursive-exclusive": + self._aberrations_split_order = split_order + elif fit_method == "recursive": + self._aberrations_split_order = [ + np.concatenate(split_order[:n]) for n in range(1, len(split_order) + 1) + ] + elif fit_method == "global": + self._aberrations_split_order = [np.concatenate(split_order)] + else: + raise ValueError() - # Aberration basis - self._aberrations_basis_FFT = xp.zeros( - (alpha_FFT.size, self._aberrations_num) + for indices in self._aberrations_split_order: + gradients = xp.vstack( + ( + self._aberrations_basis_du[ + raveled_indices[:, None], indices[None, :] + ], + self._aberrations_basis_dv[ + raveled_indices[:, None], indices[None, :] + ], + ) ) - for a0 in range(self._aberrations_num): - m, n, a = self._aberrations_mn[a0] - if n == 0: - # Radially symmetric basis - self._aberrations_basis_FFT[:, a0] = ( - alpha_FFT ** (m + 1) / (m + 1) - ).ravel() - - elif a == 0: - # cos coef - self._aberrations_basis_FFT[:, a0] = ( - alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) - ).ravel() - else: - # sin coef - self._aberrations_basis_FFT[:, a0] = ( - alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) - ).ravel() - - # global scaling - self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength - self._aberrations_surface_shape_FFT = alpha_FFT.shape - plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) - angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 - - # CTF function - def calculate_CTF_FFT(alpha_shape, *coefs): - chi = xp.zeros_like(self._aberrations_basis_FFT[:, 0]) - for a0 in range(len(coefs)): - chi += coefs[a0] * self._aberrations_basis_FFT[:, a0] - return xp.reshape(chi, alpha_shape) - - # Direct Shifts Fitting - if fit_BF_shifts: - sampling = 1 / ( - np.array(self._reciprocal_sampling) * self._region_of_interest_shape + + self._aberrations_coefs, fitted_shifts_Ang = ( + self._aberration_fit_deltas_and_increment( + shifts_Ang, + fitted_shifts_Ang, + gradients, + self._aberrations_coefs, + indices, + ) ) - ( - self._aberrations_babis, - self._aberrations_basis_du, - self._aberrations_basis_dv, - ) = calculate_aberration_gradient_basis( - self._aberrations_mn, - sampling, - self._region_of_interest_shape, - self._wavelength, - rotation_angle=self.rotation_Q_to_R_rads, - xp=xp, + + if force_transpose: + aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & ( + self._aberrations_mn[:, 2] == 0 ) + self._aberrations_coefs[aberrations_to_flip] *= -1 - # CTF function - def calculate_CTF(alpha_shape, *coefs): - chi = xp.zeros_like(self._aberrations_basis[:, 0]) - for a0 in range(len(coefs)): - chi += coefs[a0] * self._aberrations_basis[:, a0] - return xp.reshape(chi, alpha_shape) + # format aberrations + dict_cartesian = { + tuple(self._aberrations_mn[a0]): self._aberrations_coefs[a0] + for a0 in range(self._aberrations_num) + } + dict_polar = {} + unique_aberrations = np.unique(self._aberrations_mn[:, :2], axis=0) + for aberration_order in unique_aberrations: + m, n = aberration_order + modulus_name = "C" + str(m) + str(n) - # initial coefficients and plotting intensity range mask - self._aberrations_coefs = np.zeros(self._aberrations_num) + if n != 0: + value_a = dict_cartesian[(m, n, 0)] + value_b = dict_cartesian[(m, n, 1)] + dict_polar[modulus_name] = np.sqrt(value_a**2 + value_b**2) - aberrations_mn_list = self._aberrations_mn.tolist() - if [1, 0, 0] in aberrations_mn_list: - ind_C1 = aberrations_mn_list.index([1, 0, 0]) - self._aberrations_coefs[ind_C1] = self.aberration_C1 - - if [1, 2, 0] in aberrations_mn_list: - ind_A1x = aberrations_mn_list.index([1, 2, 0]) - ind_A1y = aberrations_mn_list.index([1, 2, 1]) - self._aberrations_coefs[ind_A1x] = self.aberration_A1x - self._aberrations_coefs[ind_A1y] = self.aberration_A1y - - # Refinement using CTF fitting / Thon rings - if fit_CTF_FFT: - # scoring function to minimize - mean value of zero crossing regions of FFT - def score_CTF(coefs): - im_CTF = xp.abs( - calculate_CTF_FFT(self._aberrations_surface_shape_FFT, *coefs) - ) - mask = xp.logical_and( - im_CTF > 0.5 * np.pi, - im_CTF < (max_num_rings + 0.5) * np.pi, - ) - if np.any(mask): - weights = xp.cos(im_CTF[mask]) ** 4 - return asnumpy( - xp.sum( - weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha - ) - / xp.sum(weights) - ) - else: - return np.inf + argument_name = "phi" + str(m) + str(n) + dict_polar[argument_name] = np.arctan2(value_b, value_a) / n + else: + dict_polar[modulus_name] = dict_cartesian[(m, n, 0)] - for max_num_rings in range(1, fit_max_thon_rings + 1): - # minimization - res = minimize( - score_CTF, - self._aberrations_coefs, - # method = 'Nelder-Mead', - # method = 'CG', - method="BFGS", - tol=1e-8, - ) - self._aberrations_coefs = res.x + dict_cartesian = polar_aberrations_to_cartesian(dict_polar) + self.aberrations_dict_cartesian = dict_cartesian + self.aberrations_dict_polar = dict_polar - # Refinement using CTF fitting / Thon rings - elif fit_BF_shifts: - # Gradient basis - corner_indices = self._xy_inds - xp.asarray( - self._region_of_interest_shape // 2 - ) - raveled_indices = np.ravel_multi_index( - corner_indices.T, self._region_of_interest_shape, mode="wrap" - ) - gradients = xp.vstack( - ( - self._aberrations_basis_du[raveled_indices, :], - self._aberrations_basis_dv[raveled_indices, :], - ) + # Plot the measured/fitted shifts comparison + nrows = np.count_nonzero( + np.array( + [ + plot_BF_shifts_comparison, + plot_CTF_comparison, + ] ) + ) - # (Relative) untransposed fit - raveled_shifts = self._xy_shifts_Ang.T.ravel() - aberrations_coefs, res = xp.linalg.lstsq( - gradients, raveled_shifts, rcond=None - )[:2] + if nrows > 0: + spec = GridSpec(ncols=2, nrows=nrows) - self._aberrations_coefs = asnumpy(aberrations_coefs) + figsize = kwargs.pop("figsize", (8, 4 * nrows)) + fig = plt.figure(figsize=figsize) - if self.transpose: - aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & ( - self._aberrations_mn[:, 2] == 0 + row_index = 0 + + if plot_CTF_comparison: + if hasattr(self, "_kde_upsample_factor"): + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + + else: + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + + reciprocal_extent = [ + -0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[0], + -0.5 / self._scan_sampling[0], + ] + + # FFT coordinates + qx = xp.fft.fftfreq(im_FFT.shape[0], sx) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha_FFT = xp.sqrt(qr2) * self._wavelength + theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + + chi_FFT = xp.zeros(alpha_FFT.shape) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + coeff = self._aberrations_coefs[a0] + if n == 0: + # Radially symmetric basis + chi_FFT += (alpha_FFT ** (m + 1) / (m + 1)) * coeff + + elif a == 0: + # cos coef + chi_FFT += ( + alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) + ) * coeff + else: + # sin coef + chi_FFT += ( + alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) + ) * coeff + + # global scaling + chi_FFT *= 2 * np.pi / self._wavelength + plot_mask = qr2 > np.pi**2 / 4 / np.abs(aberrations_C1) + angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 + + # Generate FFT plotting image + im_scale = np.fft.fftshift(asnumpy(im_FFT)) + im_scale, vmin, vmax = return_scaled_histogram_ordering( + im_scale, normalize=True ) - self._aberrations_coefs[aberrations_to_flip] *= -1 + im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) - # Plot the measured/fitted shifts comparison - if plot_BF_shifts_comparison: + # Add CTF zero crossings + im_CTF_plot = xp.abs(xp.sin(chi_FFT)) + + chi_FFT[xp.abs(chi_FFT) > 12.5 * np.pi] = np.pi / 2 + chi_FFT = xp.abs(xp.sin(chi_FFT)) < 0.15 + chi_FFT[xp.logical_not(plot_mask)] = 0 + + chi_FFT = np.fft.fftshift(asnumpy(chi_FFT * angular_mask)) + im_plot[:, :, 0] += chi_FFT + im_plot[:, :, 1] -= chi_FFT + im_plot[:, :, 2] -= chi_FFT + im_plot = np.clip(im_plot, 0, 1) - fitted_shifts = ( - xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) - .reshape((2, -1)) - .T + ax1 = fig.add_subplot(spec[row_index, 0]) + ax2 = fig.add_subplot(spec[row_index, 1]) + + ax1.imshow(im_plot, vmin=vmin, vmax=vmax, extent=reciprocal_extent) + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_plot)), + cmap="gray", + extent=reciprocal_extent, ) + for ax in (ax1, ax2): + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + + ax1.set_title("Aligned Bright Field FFT") + ax2.set_title("Fitted CTF ") + row_index += 1 + + if plot_BF_shifts_comparison: + scale_arrows = kwargs.pop("scale_arrows", 1) plot_arrow_freq = kwargs.pop("plot_arrow_freq", 1) - figsize = kwargs.pop("figsize", (4, 4)) - fig, ax = plt.subplots(figsize=figsize) + ax1 = fig.add_subplot(spec[row_index, 0]) + ax2 = fig.add_subplot(spec[row_index, 1]) self.show_shifts( - shifts_ang=self._xy_shifts_Ang, + shifts_ang=shifts_Ang, plot_rotated_shifts=False, plot_arrow_freq=plot_arrow_freq, scale_arrows=scale_arrows, - color=(1, 0, 0, 0.5), - figax=(fig, ax), + color=(1, 0, 0), + figax=(fig, ax1), ) self.show_shifts( - shifts_ang=fitted_shifts, + shifts_ang=fitted_shifts_Ang, plot_rotated_shifts=False, plot_arrow_freq=plot_arrow_freq, scale_arrows=scale_arrows, - color=(0, 0, 1, 0.5), - figax=(fig, ax), + color=(0, 0, 1), + figax=(fig, ax2), ) + ax2.set_title("Fitted BF Shifts") + row_index += 1 - # Plot the CTF comparison between experiment and fit - if plot_CTF_comparison: - # Generate FFT plotting image - im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha) - int_vals = np.sort(im_scale.ravel()) - int_range = ( - int_vals[np.round(0.02 * im_scale.size).astype("int")], - int_vals[np.round(0.98 * im_scale.size).astype("int")], - ) - int_range = ( - int_range[0], - (int_range[1] - int_range[0]) * 1.0 + int_range[0], - ) - im_scale = np.clip( - (np.fft.fftshift(im_scale) - int_range[0]) - / (int_range[1] - int_range[0]), - 0, - 1, - ) - im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) - - # Add CTF zero crossings - im_CTF = calculate_CTF_FFT( - self._aberrations_surface_shape_FFT, *self._aberrations_coefs - ) - - im_CTF_plot = xp.abs(xp.sin(im_CTF)) - - im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 - im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 - im_CTF[xp.logical_not(plot_mask)] = 0 - - im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) - im_plot[:, :, 0] += im_CTF - im_plot[:, :, 1] -= im_CTF - im_plot[:, :, 2] -= im_CTF - im_plot = np.clip(im_plot, 0, 1) - - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) - ax1.imshow( - im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent - ) - ax2.imshow( - np.fft.fftshift(asnumpy(im_CTF_plot)), - cmap="gray", - extent=reciprocal_extent, - ) - - for ax in (ax1, ax2): - ax.set_ylabel(r"$k_x$ [$A^{-1}$]") - ax.set_xlabel(r"$k_y$ [$A^{-1}$]") - - ax1.set_title("Aligned Bright Field FFT") - ax2.set_title("Fitted CTF ") - - fig.tight_layout() - - self.aberration_dict_cartesian = { - tuple(self._aberrations_mn[a0]): { - "aberration name": _aberration_names.get( - tuple(self._aberrations_mn[a0, :2]), "-" - ).strip(), - "value [Ang]": self._aberrations_coefs[a0], - } - for a0 in range(self._aberrations_num) - } + spec.tight_layout(fig) # Print results if self._verbose: - if fit_CTF_FFT or fit_BF_shifts: - print("Initial Aberration coefficients") - print("-------------------------------") - print( - ( - "Rotation of Q w.r.t. R = " - f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" - ) - ) - print( - ( - "Astigmatism (A1x,A1y) = (" - f"{self.aberration_A1x:.0f}," - f"{self.aberration_A1y:.0f}) Ang" - ) - ) - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - print(f"Transpose = {self.transpose}") - - if fit_CTF_FFT or fit_BF_shifts: - print() - print("Refined Aberration coefficients") - print("-------------------------------") - print("aberration radial angular dir. coefs") - print("name order order Ang ") - print("---------- ------- ------- ---- -----") - - for a0 in range(self._aberrations_mn.shape[0]): - m, n, a = self._aberrations_mn[a0] - name = _aberration_names.get((m, n), " -- ") - if n == 0: - print( - name - + " " - + str(m + 1) - + " 0 - " - + str(np.round(self._aberrations_coefs[a0]).astype("int")) - ) - elif a == 0: - print( - name - + " " - + str(m + 1) - + " " - + str(n) - + " x " - + str(np.round(self._aberrations_coefs[a0]).astype("int")) - ) - else: - print( - name - + " " - + str(m + 1) - + " " - + str(n) - + " y " - + str(np.round(self._aberrations_coefs[a0]).astype("int")) - ) + heading = "Initial aberration coefficients" + print(f"{heading:^50}") + print("-" * 50) + print(" rotation transpose C1 stig stig angle") + print(" [deg] --- [Ang] [Ang] [deg] ") + print("---------- ------- ------- ----- ---------") + + angle = f"{np.round(np.rad2deg(rotation_rad),decimals=1):^10}" + transpose = f"{str(force_transpose):^7}" + C1 = np.round(aberrations_C1).astype("int") + C1 = f"{C1:^7}" + stig = np.round( + np.sqrt(aberrations_C12a**2 + aberrations_C12b**2) + ).astype("int") + stig = f"{stig:^5}" + stig_angle = np.round( + np.rad2deg(np.arctan2(aberrations_C12b, aberrations_C12a) / 2), + decimals=1, + ) + stig_angle = f"{stig_angle:^9}" + print(" ".join([angle, transpose, C1, stig, stig_angle])) + + print() + heading = "Refined aberration coefficients" + print(f"{heading:^50}") + print("-" * 50) + print("aberration radial angular angle magnitude") + print(" name order order [deg] [Ang] ") + print("---------- ------- ------- ----- ---------") + + for mn in np.unique(self._aberrations_mn[:, :2], axis=0): + m, n = mn + name = _aberration_names.get((m, n), "---") + mag = dict_polar.get(f"C{m}{n}", "---") + angle = dict_polar.get(f"phi{m}{n}", "---") + if angle != "---": + angle = np.round(np.rad2deg(angle), decimals=1) + if mag != "---": + mag = np.round(mag).astype("int") + + name = f"{name:^10}" + radial_order = f"{m+1:^7}" + angular_order = f"{n:^7}" + angle = f"{angle:^5}" + mag = f"{mag:^9}" + print(" ".join([name, radial_order, angular_order, angle, mag])) self.clear_device_mem(self._device, self._clear_fft_cache) return self - def _calculate_CTF(self, alpha_shape, sampling, *coefs): + def _calculate_CTF(self, alpha_shape, sampling, aberrations_mn, coefs): xp = self._xp # FFT coordinates sx, sy = sampling - qx = xp.fft.fftfreq(alpha_shape[0], sx) - qy = xp.fft.fftfreq(alpha_shape[1], sy) + nx, ny = alpha_shape + qx = xp.fft.fftfreq(nx, sx).astype(xp.float32) + qy = xp.fft.fftfreq(ny, sy).astype(xp.float32) qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 alpha = xp.sqrt(qr2) * self._wavelength theta = xp.arctan2(qy[None, :], qx[:, None]) - # Aberration basis - aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) - for a0 in range(self._aberrations_num): - m, n, a = self._aberrations_mn[a0] + chi = xp.zeros(alpha_shape, dtype=xp.float32) + aberrations_num = len(aberrations_mn) + + for a0 in range(aberrations_num): + m, n, a = aberrations_mn[a0] + coef = coefs[a0] if n == 0: # Radially symmetric basis - aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + chi += (alpha ** (m + 1) / (m + 1)) * coef elif a == 0: # cos coef - aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) - ).ravel() + chi += (alpha ** (m + 1) * xp.cos(n * theta) / (m + 1)) * coef else: # sin coef - aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) - ).ravel() + chi += (alpha ** (m + 1) * xp.sin(n * theta) / (m + 1)) * coef # global scaling - aberrations_basis *= 2 * np.pi / self._wavelength - - chi = xp.zeros_like(aberrations_basis[:, 0]) + chi *= 2 * np.pi / self._wavelength - for a0 in range(len(coefs)): - chi += coefs[a0] * aberrations_basis[:, a0] - - return xp.reshape(chi, alpha_shape) + return chi def aberration_correct( self, @@ -2685,7 +2792,7 @@ def aberration_correct( xp = self._xp asnumpy = self._asnumpy - if not hasattr(self, "aberration_C1"): + if not hasattr(self, "aberrations_C1"): raise ValueError( ( "CTF correction is meant to be ran after alignment and aberration fitting. " @@ -2713,22 +2820,23 @@ def aberration_correct( use_CTF_fit = True if use_CTF_fit: + # note m+1 is radial order even_radial_orders = (self._aberrations_mn[:, 0] % 2) == 1 odd_radial_orders = (self._aberrations_mn[:, 0] % 2) == 0 - odd_coefs = self._aberrations_coefs.copy() - odd_coefs[even_radial_orders] = 0 - chi_odd = self._calculate_CTF(im.shape, (sx, sy), *odd_coefs) + odd_mn = self._aberrations_mn[odd_radial_orders] + odd_coefs = self._aberrations_coefs[odd_radial_orders] + chi_odd = self._calculate_CTF(im.shape, (sx, sy), odd_mn, odd_coefs) - even_coefs = self._aberrations_coefs.copy() - even_coefs[odd_radial_orders] = 0 - chi_even = self._calculate_CTF(im.shape, (sx, sy), *even_coefs) + even_mn = self._aberrations_mn[even_radial_orders] + even_coefs = self._aberrations_coefs[even_radial_orders] + chi_even = self._calculate_CTF(im.shape, (sx, sy), even_mn, even_coefs) if not chi_even.any(): # check if all zeros chi_even = xp.ones_like(chi_even) else: - chi_even = (xp.pi * self._wavelength * self.aberration_C1) * kra2 + chi_even = (xp.pi * self._wavelength * self.aberrations_C1) * kra2 chi_odd = xp.zeros_like(chi_even) CTF_corr = xp.sign(xp.sin(chi_even)) * xp.exp(-1j * chi_odd) @@ -2812,7 +2920,7 @@ def depth_section( xp = self._xp asnumpy = self._asnumpy - if not hasattr(self, "aberration_C1"): + if not hasattr(self, "aberrations_C1"): raise ValueError( ( "Depth sectioning is meant to be ran after alignment and aberration fitting. " @@ -2836,7 +2944,7 @@ def depth_section( self._calculate_CTF((nx, ny), (sx, sy), *self._aberrations_coefs) ) else: - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberrations_C1) * kra2) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 @@ -3260,26 +3368,3 @@ def object_cropped(self): ) else: return self._crop_padded_object(self._recon_BF) - - @property - def aberration_dict_polar(self): - """converts cartesian aberration dictionary to the polar convention used in ptycho""" - polar_dict = {} - unique_aberrations = np.unique(self._aberrations_mn[:, :2], axis=0) - aberrations_dict = self.aberration_dict_cartesian - - for aberration_order in unique_aberrations: - m, n = aberration_order - modulus_name = "C" + str(m) + str(n) - - if n != 0: - value_a = aberrations_dict[(m, n, 0)]["value [Ang]"] - value_b = aberrations_dict[(m, n, 1)]["value [Ang]"] - polar_dict[modulus_name] = np.sqrt(value_a**2 + value_b**2) - - argument_name = "phi" + str(m) + str(n) - polar_dict[argument_name] = np.arctan2(value_b, value_a) / n - else: - polar_dict[modulus_name] = aberrations_dict[(m, n, 0)]["value [Ang]"] - - return polar_dict