diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 04cfd6a60..772f6b133 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -8,7 +8,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import show, show_complex +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import rotate try: @@ -23,7 +23,11 @@ from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( PtychographicConstraints, ) -from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases +from py4DSTEM.process.phase.utils import ( + AffineTransform, + generate_batches, + polar_aliases, +) from py4DSTEM.process.utils import ( electron_wavelength_angstrom, fourier_resample, @@ -1132,6 +1136,7 @@ def _normalize_diffraction_intensities( com_fitted_x, com_fitted_y, crop_patterns, + positions_mask, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1147,6 +1152,8 @@ def _normalize_diffraction_intensities( crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1160,6 +1167,11 @@ def _normalize_diffraction_intensities( mean_intensity = 0 diffraction_intensities = self._asnumpy(diffraction_intensities) + if positions_mask is not None: + number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) + else: + number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + if crop_patterns: crop_x = int( np.minimum( @@ -1178,8 +1190,7 @@ def _normalize_diffraction_intensities( region_of_interest_shape = (crop_w * 2, crop_w * 2) amplitudes = np.zeros( ( - diffraction_intensities.shape[0], - diffraction_intensities.shape[1], + number_of_patterns, crop_w * 2, crop_w * 2, ), @@ -1195,13 +1206,19 @@ def _normalize_diffraction_intensities( else: region_of_interest_shape = diffraction_intensities.shape[-2:] - amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) + amplitudes = np.zeros( + (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + ) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) + counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): + if positions_mask is not None: + if not self._positions_mask[rx, ry]: + continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], -com_fitted_x[rx, ry], @@ -1216,9 +1233,9 @@ def _normalize_diffraction_intensities( ) mean_intensity += np.sum(intensities) - amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) + amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) + counter += 1 - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] @@ -1257,7 +1274,7 @@ def show_complex_CoM( if pixelsize is None: pixelsize = self._scan_sampling[0] if pixelunits is None: - pixelunits = r"$\AA$" + pixelunits = self._scan_units[0] figsize = kwargs.pop("figsize", (6, 6)) fig, ax = plt.subplots(figsize=figsize) @@ -1535,7 +1552,9 @@ def _set_polar_parameters(self, parameters: dict): else: raise ValueError("{} not a recognized parameter".format(symbol)) - def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): + def _calculate_scan_positions_in_pixels( + self, positions: np.ndarray, positions_mask + ): """ Method to compute the initial guess of scan positions in pixels. @@ -1544,6 +1563,8 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions: (J,2) np.ndarray or None Input probe positions in Å. If None, a raster scan using experimental parameters is constructed. + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1592,6 +1613,9 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions = np.array([x.ravel(), y.ravel()]).T positions -= np.min(positions, axis=0) + if positions_mask is not None: + positions = positions[positions_mask.ravel()] + if self._object_padding_px is None: float_padding = self._region_of_interest_shape / 2 self._object_padding_px = (float_padding, float_padding) @@ -2217,6 +2241,243 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(asnumpy(obj)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped) + else: + projected_cropped_potential = self.object_cropped + + return projected_cropped_potential + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(xp.asarray(errors), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") + pix_count[pix_count == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") + pix_output /= pix_count + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + pix_output.get(), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5.25) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized Squared Error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + ) + + extent = [ + 0, + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], + 0, + ] + + ax.imshow( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) + def show_fourier_probe( self, probe=None, @@ -2286,22 +2547,16 @@ def show_object_fft(self, obj=None, **kwargs): figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -2366,6 +2621,6 @@ def positions(self): @property def object_cropped(self): - """cropped and rotated object""" + """Cropped and rotated object""" return self._crop_rotate_object_fov(self._object) diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index af3cbbb45..b390ce46d 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -799,6 +799,7 @@ def reconstruct( anti_gridding=anti_gridding, ) + self.error_iterations.append(self.error.item()) if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -807,7 +808,6 @@ def reconstruct( ].copy() ) ) - self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: if self._verbose: diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 3eeb07814..f4c10cb13 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -82,9 +82,17 @@ class MixedstateMultislicePtychographicReconstruction(PtychographicReconstructio initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -114,7 +122,11 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -162,6 +174,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: @@ -186,6 +217,13 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -201,6 +239,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -210,6 +249,8 @@ def __init__( self._num_probes = num_probes self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -217,6 +258,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -232,6 +275,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) Returns ------- @@ -251,6 +298,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -258,6 +309,12 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators @@ -445,7 +502,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -454,7 +515,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -597,6 +658,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps @@ -3060,6 +3123,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -3075,12 +3139,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -3098,8 +3170,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) @@ -3509,3 +3594,61 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 2e9fbd076..d68291143 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -74,6 +74,8 @@ class MixedstatePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -102,6 +104,7 @@ def __init__( initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "mixed-state_ptychographic_reconstruction", @@ -161,6 +164,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -178,6 +187,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -349,7 +359,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -358,7 +372,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -2327,3 +2341,50 @@ def show_fourier_probe( chroma_boost=chroma_boost, **kwargs, ) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 4515590fe..93e32b079 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -81,12 +81,16 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): Probe positions in Å for each diffraction intensity If None, initialized to a grid scan theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in degrees) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in degrees) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -117,7 +121,9 @@ def __init__( initial_scan_positions: np.ndarray = None, theta_x: float = 0, theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -150,6 +156,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: @@ -173,6 +198,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -189,6 +220,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -224,9 +256,9 @@ def _precompute_propagator_arrays( slice_thicknesses: Sequence[float] Array of slice thicknesses in A theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in degrees) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in degrees) Returns ------- @@ -450,7 +482,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -459,7 +495,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -2919,6 +2955,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -2934,12 +2971,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -2957,8 +3002,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) @@ -3075,7 +3133,7 @@ def show_depth( rotated_object = np.roll( rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - int(x1_0), + -int(x1_0), axis=1, ) @@ -3368,3 +3426,14 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 32b0f6fd4..c49a1faac 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -93,6 +93,8 @@ class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -115,6 +117,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -163,6 +166,13 @@ def __init__( if object_type != "potential": raise NotImplementedError() + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -179,6 +189,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -595,7 +606,11 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, com_fitted_x, com_fitted_y, crop_patterns + intensities, + com_fitted_x, + com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -615,7 +630,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels @@ -3288,22 +3303,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -3327,3 +3336,29 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 66cf46487..ddd13ac58 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -88,6 +88,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions to ignore in reconstruction name: str, optional Class name kwargs: @@ -111,6 +113,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -172,6 +175,13 @@ def __init__( if object_type != "potential": raise NotImplementedError() + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -188,6 +198,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -535,7 +546,11 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, com_fitted_x, com_fitted_y, crop_patterns + intensities, + com_fitted_x, + com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -555,7 +570,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels @@ -3168,22 +3183,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -3207,3 +3216,29 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 74688fa0b..716e1d782 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -136,7 +136,7 @@ def to_h5(self, group): if hasattr(self, "aberration_C1"): recon_metadata |= { "aberration_rotation_QR": self.rotation_Q_to_R_rads, - "aberration_transpose": self.transpose_detected, + "aberration_transpose": self.transpose, "aberration_C1": self.aberration_C1, "aberration_A1x": self.aberration_A1x, "aberration_A1y": self.aberration_A1y, @@ -236,7 +236,7 @@ def _populate_instance(self, group): if "aberration_C1" in reconstruction_md.keys: self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] - self.transpose_detected = reconstruction_md["aberration_transpose"] + self.transpose = reconstruction_md["aberration_transpose"] self.aberration_C1 = reconstruction_md["aberration_C1"] self.aberration_A1x = reconstruction_md["aberration_A1x"] self.aberration_A1y = reconstruction_md["aberration_A1y"] @@ -587,16 +587,27 @@ def preprocess( self.recon_BF = asnumpy(self._recon_BF) if plot_average_bf: - figsize = kwargs.pop("figsize", (6, 6)) + figsize = kwargs.pop("figsize", (6, 12)) - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(1, 2, figsize=figsize) - self._visualize_figax(fig, ax, **kwargs) + self._visualize_figax(fig, ax[0], **kwargs) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Average Bright Field Image") + ax[0].set_ylabel("x [A]") + ax[0].set_xlabel("y [A]") + ax[0].set_title("Average Bright Field Image") + reciprocal_extent = [ + -0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + -0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + ] + ax[1].imshow(self._dp_mask, extent=reciprocal_extent, cmap="gray") + ax[1].set_title("DP mask") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + plt.tight_layout() self._preprocessed = True if self._device == "gpu": @@ -1098,26 +1109,46 @@ def subpixel_alignment( BF_size = np.array(self._stack_BF_no_window.shape[-2:]) self._DF_upsample_limit = np.max( - self._region_of_interest_shape / self._scan_shape + 2 * self._region_of_interest_shape / self._scan_shape ) self._BF_upsample_limit = ( - 2 * self._kr.max() / self._reciprocal_sampling[0] + 4 * self._kr.max() / self._reciprocal_sampling[0] ) / self._scan_shape.max() if self._device == "gpu": self._BF_upsample_limit = self._BF_upsample_limit.item() if kde_upsample_factor is None: - kde_upsample_factor = np.minimum( - self._BF_upsample_limit * 3 / 2, self._DF_upsample_limit - ) + if self._BF_upsample_limit * 3 / 2 > self._DF_upsample_limit: + kde_upsample_factor = self._DF_upsample_limit - warnings.warn( - ( - f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " - f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." - ), - UserWarning, - ) + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (the " + f"dark-field upsampling limit)." + ), + UserWarning, + ) + + elif self._BF_upsample_limit * 3 / 2 > 1: + kde_upsample_factor = self._BF_upsample_limit * 3 / 2 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + else: + kde_upsample_factor = self._DF_upsample_limit * 2 / 3 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (2/3 times the " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f})." + ), + UserWarning, + ) if kde_upsample_factor < 1: raise ValueError("kde_upsample_factor must be larger than 1") @@ -1187,7 +1218,7 @@ def subpixel_alignment( # kernel density estimate sigma = kde_sigma * self._kde_upsample_factor pix_count = gaussian_filter(pix_count, sigma) - pix_count[pix_output == 0.0] = np.inf + pix_count[pix_count == 0.0] = np.inf pix_output = gaussian_filter(pix_output, sigma) pix_output /= pix_count @@ -1301,7 +1332,7 @@ def aberration_fit( plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, upsampled: bool = True, - force_transpose: bool = None, + force_transpose: bool = False, ): """ Fit aberrations to the measured image shifts. @@ -1342,17 +1373,13 @@ def aberration_fit( # Convert real space shifts to Angstroms - if force_transpose is None: - self.transpose_detected = False - else: - self.transpose_detected = force_transpose - if force_transpose is True: self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( self._scan_sampling ) else: self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + self.transpose = force_transpose # Solve affine transformation m = asnumpy( @@ -1369,9 +1396,15 @@ def aberration_fit( np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi ) m_aberration = -1.0 * m_aberration + self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 - self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + + if self.transpose: + self.aberration_A1x = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + else: + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 ### Second pass @@ -1417,12 +1450,26 @@ def aberration_fit( sx = self._scan_sampling[0] / self._kde_upsample_factor sy = self._scan_sampling[1] / self._kde_upsample_factor + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + else: im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) sx = self._scan_sampling[0] sy = self._scan_sampling[1] upsampled = False + reciprocal_extent = [ + -0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[0], + -0.5 / self._scan_sampling[0], + ] + # FFT coordinates qx = xp.fft.fftfreq(im_FFT.shape[0], sx) qy = xp.fft.fftfreq(im_FFT.shape[1], sy) @@ -1474,12 +1521,19 @@ def calculate_CTF_FFT(alpha_shape, *coefs): sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) - qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + qx, qy = np.meshgrid(qx, qy, indexing="ij") - u = qx[:, None] * self._wavelength - v = qy[None, :] * self._wavelength + # passive rotation basis by -theta + rotation_angle = -self.rotation_Q_to_R_rads + qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( + rotation_angle + ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) + + qr2 = qx**2 + qy**2 + u = qx * self._wavelength + v = qy * self._wavelength alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None, :], qx[:, None]) + theta = xp.arctan2(qy, qx) # Aberration basis self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) @@ -1541,10 +1595,17 @@ def calculate_CTF(alpha_shape, *coefs): # initial coefficients and plotting intensity range mask self._aberrations_coefs = np.zeros(self._aberrations_num) - ind = np.argmin( - np.abs(self._aberrations_mn[:, 0] - 1.0) + self._aberrations_mn[:, 1] - ) - self._aberrations_coefs[ind] = self.aberration_C1 + + aberrations_mn_list = self._aberrations_mn.tolist() + if [1, 0, 0] in aberrations_mn_list: + ind_C1 = aberrations_mn_list.index([1, 0, 0]) + self._aberrations_coefs[ind_C1] = self.aberration_C1 + + if [1, 2, 0] in aberrations_mn_list: + ind_A1x = aberrations_mn_list.index([1, 2, 0]) + ind_A1y = aberrations_mn_list.index([1, 2, 1]) + self._aberrations_coefs[ind_A1x] = self.aberration_A1x + self._aberrations_coefs[ind_A1y] = self.aberration_A1y # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: @@ -1597,54 +1658,84 @@ def score_CTF(coefs): ) # (Relative) untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + raveled_shifts = self._xy_shifts_Ang.T.ravel() aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None + gradients, raveled_shifts, rcond=None )[:2] - if force_transpose is None: - # (Relative) transposed fit - transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) - m_T = asnumpy( - xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ - 0 - ] + self._aberrations_coefs = asnumpy(aberrations_coefs) + + if self.transpose: + aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & ( + self._aberrations_mn[:, 2] == 0 ) - m_rotation_T, _ = polar(m_T, side="right") - rotation_Q_to_R_rads_T = -1 * np.arctan2( - m_rotation_T[1, 0], m_rotation_T[0, 0] + self._aberrations_coefs[aberrations_to_flip] *= -1 + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 ) - if np.abs( - np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi - ) > (np.pi * 0.5): - rotation_Q_to_R_rads_T = ( - np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi - ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] + + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] + + fitted_shifts = ( + xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) + .reshape((2, -1)) + .T + ) + + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] - tf_T = AffineTransform(angle=rotation_Q_to_R_rads_T) - rotated_shifts_T = tf_T(transposed_shifts, xp=xp).T.ravel() - aberrations_coefs_T, res_T = xp.linalg.lstsq( - gradients, rotated_shifts_T, rcond=None - )[:2] - - # Compare fits - if res_T.sum() < res.sum(): - self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = not self.transpose_detected - self._aberrations_coefs = asnumpy(aberrations_coefs_T) - self._rotated_shifts = rotated_shifts_T - - warnings.warn( - ( - "Data transpose detected. " - f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" - ), - UserWarning, + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] + + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] ) - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts + ) + + show( + [ + [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], + [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], + ], + cmap="PiYG", + vmin=-max_shift, + vmax=max_shift, + intensity_range="absolute", + axsize=(4, 4), + ticks=False, + title=[ + "Measured Vertical Shifts", + "Fitted Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Horizontal Shifts", + ], + ) # Plot the CTF comparison between experiment and fit if plot_CTF_comparison: @@ -1682,79 +1773,24 @@ def score_CTF(coefs): im_plot[:, :, 2] -= im_CTF im_plot = np.clip(im_plot, 0, 1) - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) - ax1.imshow(im_plot, vmin=int_range[0], vmax=int_range[1]) - - ax2.imshow(np.fft.fftshift(asnumpy(im_CTF_cos)), cmap="gray") - - fig.tight_layout() - - # Plot the measured/fitted shifts comparison - if plot_BF_shifts_comparison: - if not fit_BF_shifts: - raise ValueError() - - measured_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) + ax1.imshow( + im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._rotated_shifts[: self._xy_inds.shape[0]] - - measured_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_cos)), + cmap="gray", + extent=reciprocal_extent, ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._rotated_shifts[self._xy_inds.shape[0] :] - fitted_shifts = xp.tensordot( - gradients, xp.array(self._aberrations_coefs), axes=1 - ) + for ax in (ax1, ax2): + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") - fitted_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ - : self._xy_inds.shape[0] - ] - - fitted_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ - self._xy_inds.shape[0] : - ] - - max_shift = xp.max( - xp.array( - [ - xp.abs(measured_shifts_sx).max(), - xp.abs(measured_shifts_sy).max(), - xp.abs(fitted_shifts_sx).max(), - xp.abs(fitted_shifts_sy).max(), - ] - ) - ) + ax1.set_title("Aligned Bright Field FFT") + ax2.set_title("Fitted CTF Zero-Crossings") - show( - [ - [asnumpy(measured_shifts_sx), asnumpy(measured_shifts_sy)], - [asnumpy(fitted_shifts_sx), asnumpy(fitted_shifts_sy)], - ], - cmap="PiYG", - vmin=-max_shift, - vmax=max_shift, - intensity_range="absolute", - axsize=(4, 4), - ticks=False, - title=[ - "Measured Vertical Shifts", - "Measured Horizontal Shifts", - "Fitted Vertical Shifts", - "Fitted Horizontal Shifts", - ], - ) + fig.tight_layout() self.aberration_dict = { tuple(self._aberrations_mn[a0]): { @@ -1786,6 +1822,7 @@ def score_CTF(coefs): ) print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + print(f"Transpose = {self.transpose}") if fit_CTF_FFT or fit_BF_shifts: print() @@ -2268,6 +2305,7 @@ def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, + plot_rotated_shifts=True, **kwargs, ): """ @@ -2284,10 +2322,22 @@ def show_shifts( xp = self._xp asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (6, 6)) color = kwargs.pop("color", (1, 0, 0, 1)) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + scaling_factor = ( + xp.array(self._reciprocal_sampling) + / xp.array(self._scan_sampling) + * scale_arrows + ) + rotated_shifts = self._xy_shifts_Ang * scaling_factor - fig, ax = plt.subplots(figsize=figsize) + else: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + + shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -2297,29 +2347,68 @@ def show_shifts( masked_ind = xp.logical_and(freq_mask, self._dp_mask) plot_ind = masked_ind[dp_mask_ind] - ax.quiver( - asnumpy(self._kxy[plot_ind, 1]), - asnumpy(self._kxy[plot_ind, 0]), - asnumpy( - self._xy_shifts[plot_ind, 1] - * scale_arrows - * self._reciprocal_sampling[0] - ), - asnumpy( - self._xy_shifts[plot_ind, 0] - * scale_arrows - * self._reciprocal_sampling[1] - ), - color=color, - angles="xy", - scale_units="xy", - scale=1, - **kwargs, - ) - kr_max = xp.max(self._kr) - ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) - ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + ax[0].quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[0].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_title("Measured Bright Field Shifts") + ax[0].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[0].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[0].set_aspect("equal") + + # passive coordinate rotation + tf_T = AffineTransform(angle=-self.rotation_Q_to_R_rads) + rotated_kxy = tf_T(self._kxy[plot_ind], xp=xp) + ax[1].quiver( + asnumpy(rotated_kxy[:, 1]), + asnumpy(rotated_kxy[:, 0]), + asnumpy(rotated_shifts[plot_ind, 1]), + asnumpy(rotated_shifts[plot_ind, 0]), + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[1].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_title("Rotated Bright Field Shifts") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[1].set_aspect("equal") + else: + ax.quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_title("Measured BF Shifts") + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.set_aspect("equal") + + fig.tight_layout() def visualize( self, @@ -2345,3 +2434,21 @@ def visualize( ax.set_title("Reconstructed Bright Field Image") return self + + @property + def object_cropped(self): + """cropped object""" + if hasattr(self, "_recon_phase_corrected"): + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_phase_corrected, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_phase_corrected) + else: + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_BF) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 3eebdb068..d29aa1747 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -433,8 +433,8 @@ def _probe_amplitude_constraint( xp = self._xp erf = self._erf - # probe_intensity = xp.abs(current_probe) ** 2 - # current_probe_sum = xp.sum(probe_intensity) + probe_intensity = xp.abs(current_probe) ** 2 + current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] Y = xp.fft.fftfreq(current_probe.shape[1])[None] @@ -444,10 +444,10 @@ def _probe_amplitude_constraint( tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_fourier_amplitude_constraint( self, @@ -476,7 +476,7 @@ def _probe_fourier_amplitude_constraint( xp = self._xp asnumpy = self._asnumpy - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) updated_probe_fft, _, _, _ = regularize_probe_amplitude( @@ -489,10 +489,10 @@ def _probe_fourier_amplitude_constraint( updated_probe_fft = xp.asarray(updated_probe_fft) updated_probe = xp.fft.ifft2(updated_probe_fft) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aperture_constraint( self, @@ -514,16 +514,16 @@ def _probe_aperture_constraint( """ xp = self._xp - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) updated_probe = xp.fft.ifft2( xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture ) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aberration_fitting_constraint( self, @@ -566,7 +566,7 @@ def _probe_aberration_fitting_constraint( xp=xp, ) - fourier_probe = fourier_probe_abs * xp.exp(1.0j * fitted_angle) + fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) current_probe = xp.fft.ifft2(fourier_probe) return current_probe diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 37438852f..233d34e45 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -66,6 +66,8 @@ class SimultaneousPtychographicReconstruction(PtychographicReconstruction): object_padding_px: Tuple[int,int], optional Pixel dimensions to pad objects with If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction initial_object_guess: np.ndarray, optional Initial guess for complex-valued object of dimensions (Px,Py) If None, initialized to 1.0j @@ -102,6 +104,7 @@ def __init__( vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, @@ -150,6 +153,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -167,6 +176,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -341,6 +351,9 @@ def preprocess( ) ) + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + # 1st measurement sets rotation angle and transposition ( measurement_0, @@ -404,7 +417,11 @@ def preprocess( amplitudes_0, mean_diffraction_intensity_0, ) = self._normalize_diffraction_intensities( - intensities_0, com_fitted_x_0, com_fitted_y_0, crop_patterns + intensities_0, + com_fitted_x_0, + com_fitted_y_0, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -485,7 +502,11 @@ def preprocess( amplitudes_1, mean_diffraction_intensity_1, ) = self._normalize_diffraction_intensities( - intensities_1, com_fitted_x_1, com_fitted_y_1, crop_patterns + intensities_1, + com_fitted_x_1, + com_fitted_y_1, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -567,7 +588,11 @@ def preprocess( amplitudes_2, mean_diffraction_intensity_2, ) = self._normalize_diffraction_intensities( - intensities_2, com_fitted_x_2, com_fitted_y_2, crop_patterns + intensities_2, + com_fitted_x_2, + com_fitted_y_2, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -607,7 +632,7 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -3357,3 +3382,98 @@ def visualize( ) return self + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + error = xp.sum( + xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + error /= self._mean_diffraction_intensity + + return asnumpy(error) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[0][start:end] + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped[0]) + else: + projected_cropped_potential = self.object_cropped[0] + + return projected_cropped_potential + + @property + def object_cropped(self): + """Cropped and rotated object""" + + obj_e, obj_m = self._object + obj_e = self._crop_rotate_object_fov(obj_e) + obj_m = self._crop_rotate_object_fov(obj_m) + return (obj_e, obj_m) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 5dd19d7bd..350d0a3cb 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -79,6 +79,8 @@ class SingleslicePtychographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -102,6 +104,7 @@ def __init__( initial_scan_positions: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "ptychographic_reconstruction", @@ -147,6 +150,13 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -163,6 +173,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -333,7 +344,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -342,7 +357,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index d29765d04..a1eb54c80 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1620,7 +1620,7 @@ def fit_aberration_surface( ): """ """ probe_amp = xp.abs(complex_probe) - probe_angle = xp.angle(complex_probe) + probe_angle = -xp.angle(complex_probe) if xp is np: probe_angle = probe_angle.astype(np.float64) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 4e99c0de5..b6077c412 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -366,7 +366,9 @@ def show( from py4DSTEM.visualize import show if show_fft: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) for a0 in range(num_images): im = show( ar[a0], diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index d1efbd023..acacb6184 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -829,7 +829,7 @@ def show_complex( for ax_flat in ax.flatten(): divider = make_axes_locatable(ax_flat) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") @@ -839,3 +839,59 @@ def show_complex( if returnfig: return fig, ax + + +def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False): + """ + Utility function for calculating min and max values for plotting array + based on distribution of pixel values + + Parameters + ---------- + array: np.array + array to be plotted + vmin: float + lower fraction cut off of pixel values + vmax: float + upper fraction cut off of pixel values + normalize: bool + if True, rescales from 0 to 1 + + Returns + ---------- + scaled_array: np.array + array clipped outside vmin and vmax + vmin: float + lower value to be plotted + vmax: float + upper value to be plotted + """ + + if vmin is None: + vmin = 0.02 + if vmax is None: + vmax = 0.98 + + vals = np.sort(array.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + + scaled_array = array.copy() + scaled_array = np.where(scaled_array < vmin, vmin, scaled_array) + scaled_array = np.where(scaled_array > vmax, vmax, scaled_array) + + if normalize: + scaled_array -= scaled_array.min() + scaled_array /= scaled_array.max() + vmin = 0 + vmax = 1 + + return scaled_array, vmin, vmax