From 20bb94f51191d58d9a1f2d084ec000e9fc9f688b Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Fri, 14 Jul 2023 01:58:04 -0400 Subject: [PATCH 001/176] bugfix --- py4DSTEM/preprocess/preprocess.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index 755d07ae4..f1052d72d 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -283,12 +283,6 @@ def bin_data_diffraction( # set calibration pixel size datacube.calibration.set_Q_pixel_size(Qpixsize) - # remake Cartesian coordinate system - datacube.qyy,datacube.qxx = np.meshgrid( - np.arange(0,datacube.Q_Ny), - np.arange(0,datacube.Q_Nx) - ) - # return return datacube From 7b9d330440f02bacbc75673b5a4f63a92c91b687 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Fri, 14 Jul 2023 11:36:39 -0400 Subject: [PATCH 002/176] bugfix --- py4DSTEM/datacube/virtualimage.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index 6731c0fd0..d91b23906 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -325,7 +325,7 @@ def position_detector( shift_center = None, scan_position = None, invert = False, - color = 'r', + color = 'c', alpha = 0.7, **kwargs ): @@ -382,10 +382,10 @@ def position_detector( # data if data is None: keys = ['dp_mean','dp_max','dp_median'] + image = None for k in keys: - image = None try: - image = data.tree(k) + image = self.tree(k) break except: pass @@ -393,6 +393,7 @@ def position_detector( image = self[0,0] elif isinstance(data, np.ndarray): assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}" + image = data elif isinstance(data,tuple): rx,ry = data[:2] image = self[rx,ry] From 6db566e0f83bd1a75e112c95ba75232f28a7be30 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 3 Aug 2023 17:32:51 -0700 Subject: [PATCH 003/176] start for depth profile --- .../iterative_multislice_ptychography.py | 68 ++++++++++++++++++- py4DSTEM/process/phase/utils.py | 14 ++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 92f8c0bf3..cce65f6c3 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -29,6 +29,7 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate warnings.simplefilter(action="always", category=UserWarning) @@ -974,7 +975,7 @@ def _gradient_descent_adjoint( ) # back-transmit - exit_waves *= xp.conj(obj) #/ xp.abs(obj) ** 2 + exit_waves *= xp.conj(obj) # / xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -1076,7 +1077,7 @@ def _projection_sets_adjoint( ) # back-transmit - exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 + exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -2841,6 +2842,67 @@ def show_slices( spec.tight_layout(fig) + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + **kwargs, + ): + """ + doc strings go here + """ + ms_obj = self.object_cropped + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + fig, ax = plt.subplots() + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("y [A]") + ax.set_ylabel("x [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + def tune_num_slices_and_thicknesses( self, num_slices_guess=None, @@ -3067,4 +3129,4 @@ def _return_object_fft( obj = np.angle(obj) obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) \ No newline at end of file + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index c2e1d3b77..118e9990a 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1610,3 +1610,17 @@ def fit_aberration_surface( fitted_angle = xp.tensordot(coeff, basis, axes=1) return fitted_angle, coeff + + +def rotate_point(origin, point, angle): + """ + Rotate a point counterclockwise by a given angle around a given origin. + + The angle should be given in radians. + """ + ox, oy = origin + px, py = point + + qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy) + qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy) + return qx, qy From 4eb461de2ffe9d245e64f8bcf96fb1361d625c12 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 7 Aug 2023 18:06:55 -0700 Subject: [PATCH 004/176] adding real-space kde upsampling --- py4DSTEM/process/phase/iterative_parallax.py | 224 ++++++++++++++++++- 1 file changed, 219 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 80cdd8cd8..c2bfc8739 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -14,6 +14,7 @@ from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.visualize import show from scipy.linalg import polar from scipy.special import comb @@ -246,6 +247,7 @@ def preprocess( ) if normalize_images: self._stack_BF = xp.ones(stack_shape) + self._stack_BF_no_window = xp.ones(stack_shape) if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] @@ -259,6 +261,14 @@ def preprocess( self._window_inv[None] + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + elif normalize_order == 1: x = xp.linspace(-0.5, 0.5, all_bfs.shape[1]) y = xp.linspace(-0.5, 0.5, all_bfs.shape[2]) @@ -285,9 +295,18 @@ def preprocess( basis @ coefs[0], all_bfs.shape[1:3] ) + self._stack_BF_no_window[ + a0, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3]) + else: all_means = xp.mean(all_bfs, axis=(1, 2)) self._stack_BF = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None]) self._stack_BF[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -299,6 +318,14 @@ def preprocess( + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) @@ -533,9 +560,9 @@ def tune_angle_and_defocus( divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) + fig.colorbar(im, cax=cax) - plt.tight_layout() + fig.tight_layout() if return_values: convergence = np.array(convergence).reshape( @@ -548,7 +575,7 @@ def reconstruct( max_alignment_bin: int = None, min_alignment_bin: int = 1, max_iter_at_min_bin: int = 2, - upsample_factor: int = 8, + cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = True, running_average: bool = True, @@ -570,7 +597,7 @@ def reconstruct( Minimum bin size for bright field alignment max_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size - upsample_factor: int, optional + cross_correlation_upsample_factor: int, optional DFT upsample factor for subpixel alignment regularizer_matrix_size: Tuple[int,int], optional Bernstein basis degree used for regularizing shifts @@ -730,7 +757,7 @@ def reconstruct( xy_shift = align_images_fourier( G_ref, G, - upsample_factor=upsample_factor, + upsample_factor=cross_correlation_upsample_factor, device=self._device, ) @@ -837,6 +864,193 @@ def reconstruct( return self + def subpixel_alignment( + self, + kde_upsample_factor=4, + kde_sigma=0.125, + plot_upsampled_BF_comparison: bool = True, + plot_upsampled_FFT_comparison: bool = False, + **kwargs, + ): + """ + Upsample and subpixel-align BFs using the measured image shifts. + Uses kernel density estimation (KDE) to align upsampled BFs. + + Parameters + ---------- + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma: float, optional + KDE gaussian kernel bandwidth + plot_upsampled_BF_comparison: bool, optional + If True, the pre/post alignment BF images are plotted for comparison + plot_upsampled_FFT_comparison: bool, optional + If True, the pre/post alignment BF FFTs are plotted for comparison + + """ + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + xy_shifts = self._xy_shifts + BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + + pixel_output = BF_size * kde_upsample_factor + pixel_size = pixel_output.prod() + + # shifted coordinates + x = xp.arange(BF_size[0]) + y = xp.arange(BF_size[1]) + + xa, ya = xp.meshgrid(x, y, indexing="ij") + xa = ((xa + xy_shifts[:, 0, None, None]) * kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * kde_upsample_factor).ravel() + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + sigma = kde_sigma * kde_upsample_factor + pix_count = gaussian_filter(pix_count, sigma) + pix_count[pix_output == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, sigma) + pix_output /= pix_count + + self._recon_BF_subpixel_aligned = pix_output + self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) + + # plotting + if plot_upsampled_BF_comparison: + if plot_upsampled_FFT_comparison: + figsize = kwargs.pop("figsize", (8, 8)) + fig, axs = plt.subplots(2, 2, figsize=figsize) + else: + figsize = kwargs.pop("figsize", (8, 4)) + fig, axs = plt.subplots(1, 2, figsize=figsize) + + axs = axs.flat + cmap = kwargs.pop("cmap", "magma") + + cropped_object = self._crop_padded_object(self._recon_BF) + upsampled_pad_x = self._object_padding_px[0] * kde_upsample_factor // 2 + upsampled_pad_y = self._object_padding_px[1] * kde_upsample_factor // 2 + cropped_object_aligned = self.recon_BF_subpixel_aligned[ + upsampled_pad_x:-upsampled_pad_x, + upsampled_pad_y:-upsampled_pad_y, + ] + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[0].set_title("Aligned Bright Field") + + axs[1].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[1].set_title("Upsampled Bright Field") + + for ax in axs[:2]: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if plot_upsampled_FFT_comparison: + recon_fft = xp.fft.fft2(self._recon_BF) + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + pad_x = BF_size[0] * (kde_upsample_factor - 1) // 2 + pad_y = BF_size[1] * (kde_upsample_factor - 1) // 2 + pad_recon_fft = asnumpy( + xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + ) + + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + + reciprocal_extent = [ + 0, + self._reciprocal_sampling[1] * cropped_object_aligned.shape[1], + self._reciprocal_sampling[0] * cropped_object_aligned.shape[0], + 0, + ] + + show( + pad_recon_fft, + figax=(fig, axs[2]), + extent=reciprocal_extent, + cmap="gray", + title="Aligned Bright Field FFT", + **kwargs, + ) + + show( + upsampled_fft, + figax=(fig, axs[3]), + extent=reciprocal_extent, + cmap="gray", + title="Upsampled Bright Field FFT", + **kwargs, + ) + + for ax in axs[2:]: + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.xaxis.set_ticks_position("bottom") + + fig.tight_layout() + def aberration_fit( self, plot_CTF_compare: bool = False, From 8c61a5b6591d410ea27377ce264eba839313c2e7 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 8 Aug 2023 11:58:24 -0700 Subject: [PATCH 005/176] more depth profile --- .../iterative_multislice_ptychography.py | 78 +++++++++++++++---- py4DSTEM/process/phase/utils.py | 16 +++- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index cce65f6c3..111f012ad 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2852,12 +2852,31 @@ def show_depth( ms_object=None, cbar: bool = False, aspect: float = None, + plot_line_profile: bool = False, **kwargs, ): """ - doc strings go here + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats + line profile for dpeth seciton runs from (x1,y1) to (x2,y2) + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken """ - ms_obj = self.object_cropped + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped angle = np.arctan((x2 - x1) / (y2 - y1)) x0 = ms_obj.shape[1] / 2 @@ -2879,6 +2898,7 @@ def show_depth( if gaussian_filter_sigma is not None: from scipy.ndimage import gaussian_filter + gaussian_filter_sigma /= self.sampling[0] rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] @@ -2890,18 +2910,48 @@ def show_depth( 0, ] - fig, ax = plt.subplots() - im = ax.imshow(plot_im, cmap="magma", extent=extent) - if aspect is not None: - ax.set_aspect(aspect) - ax.set_xlabel("y [A]") - ax.set_ylabel("x [A]") - ax.set_title("Multislice depth profile") - if cbar: - divider = make_axes_locatable(ax) - ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - fig.add_axes(ax_cb) - fig.colorbar(im, cax=ax_cb) + if plot_line_profile == False: + fig, ax = plt.subplots() + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("y [A]") + ax.set_ylabel("x [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[0] * ms_obj.shape[1], + self.sampling[1] * ms_obj.shape[2], + 0, + ] + fig, ax = plt.subplots(2, 1) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 / self.sampling[0], y2 / self.sampling[1]], + [x1 / self.sampling[0], x2 / self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("y [A]") + ax[1].set_ylabel("x [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) def tune_num_slices_and_thicknesses( self, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 118e9990a..a8a702f89 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1614,9 +1614,21 @@ def fit_aberration_surface( def rotate_point(origin, point, angle): """ - Rotate a point counterclockwise by a given angle around a given origin. + Rotate a point (x1, y1) counterclockwise by a given angle around + a given origin (x0, y0). + + Parameters + -------- + origin: 2-tuple of floats + (x0, y0) + point: 2-tuple of floats + (x1, y1) + angle: float (radians) + + Returns + -------- + rotated points (2-tuple) - The angle should be given in radians. """ ox, oy = origin px, py = point From 22d1bb489320cbcfcd0a089ea7ca2e384d542536 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 8 Aug 2023 12:01:38 -0700 Subject: [PATCH 006/176] saving error --- .../process/phase/iterative_multislice_ptychography.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 111f012ad..e000910c3 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2860,8 +2860,8 @@ def show_depth( Parameters -------- - x1, x2, y1, y2: floats - line profile for dpeth seciton runs from (x1,y1) to (x2,y2) + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) gaussian_filter_sigma: float (optional) Standard deviation of gaussian kernel in A ms_object: np.array @@ -2933,8 +2933,8 @@ def show_depth( fig, ax = plt.subplots(2, 1) ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) ax[0].plot( - [y1 / self.sampling[0], y2 / self.sampling[1]], - [x1 / self.sampling[0], x2 / self.sampling[1]], + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], color="red", ) ax[0].set_xlabel("y [A]") @@ -2952,6 +2952,7 @@ def show_depth( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) + plt.tight_layout() def tune_num_slices_and_thicknesses( self, From 5f919bdfb26196e30409a2852b633d413ebd2e13 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 8 Aug 2023 12:04:18 -0700 Subject: [PATCH 007/176] small name changes --- .../process/phase/iterative_multislice_ptychography.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index e000910c3..438c9d1fb 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2915,8 +2915,8 @@ def show_depth( im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax.set_aspect(aspect) - ax.set_xlabel("y [A]") - ax.set_ylabel("x [A]") + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") ax.set_title("Multislice depth profile") if cbar: divider = make_axes_locatable(ax) @@ -2944,8 +2944,8 @@ def show_depth( im = ax[1].imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax[1].set_aspect(aspect) - ax[1].set_xlabel("y [A]") - ax[1].set_ylabel("x [A]") + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") ax[1].set_title("Multislice depth profile") if cbar: divider = make_axes_locatable(ax[1]) From fa4736ffe9ea2535ed50c9cfe6f4250299f7f050 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 9 Aug 2023 16:30:39 -0700 Subject: [PATCH 008/176] updates to polardata --- py4DSTEM/braggvectors/braggvector_methods.py | 243 ++----- py4DSTEM/braggvectors/braggvectors.py | 29 +- py4DSTEM/data/calibration.py | 10 +- py4DSTEM/datacube/datacube.py | 35 +- py4DSTEM/datacube/virtualimage.py | 17 +- py4DSTEM/io/filereaders/__init__.py | 3 +- py4DSTEM/io/filereaders/read_abTEM.py | 81 +++ py4DSTEM/io/filereaders/read_arina.py | 115 ++++ py4DSTEM/io/google_drive_downloader.py | 57 ++ py4DSTEM/io/importfile.py | 37 +- .../legacy/legacy13/v13_emd_classes/array.py | 6 +- py4DSTEM/io/parsefiletype.py | 84 ++- py4DSTEM/io/read.py | 124 ++-- py4DSTEM/preprocess/preprocess.py | 8 +- py4DSTEM/process/calibration/origin.py | 9 +- py4DSTEM/process/diffraction/crystal.py | 150 ++++- py4DSTEM/process/diffraction/crystal_ACOM.py | 48 +- .../process/diffraction/crystal_calibrate.py | 35 +- py4DSTEM/process/diffraction/crystal_viz.py | 80 ++- py4DSTEM/process/diffraction/flowlines.py | 24 +- py4DSTEM/process/fit/fit.py | 4 - py4DSTEM/process/latticevectors/fit.py | 15 +- py4DSTEM/process/latticevectors/index.py | 57 +- py4DSTEM/process/latticevectors/strain.py | 33 +- .../iterative_multislice_ptychography.py | 6 +- py4DSTEM/process/polar/__init__.py | 2 +- py4DSTEM/process/polar/polar_analysis.py | 46 ++ py4DSTEM/process/polar/polar_datacube.py | 16 +- py4DSTEM/process/polar/polar_fits.py | 51 +- py4DSTEM/process/polar/polar_peaks.py | 114 +++- py4DSTEM/process/rdf/amorph.py | 2 +- py4DSTEM/process/strain.py | 617 +++++++++++++++--- py4DSTEM/process/utils/utils.py | 4 +- py4DSTEM/version.py | 2 +- py4DSTEM/visualize/overlay.py | 22 +- py4DSTEM/visualize/show.py | 8 +- py4DSTEM/visualize/vis_special.py | 96 --- setup.py | 1 + test/gettestdata.py | 21 +- test/test_nonnative_io/test_arina.py | 19 + test/test_strain.py | 2 + 41 files changed, 1613 insertions(+), 720 deletions(-) create mode 100644 py4DSTEM/io/filereaders/read_abTEM.py create mode 100644 py4DSTEM/io/filereaders/read_arina.py create mode 100644 test/test_nonnative_io/test_arina.py diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 53f800ed1..be766ad49 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -186,7 +186,11 @@ def get_virtual_image( mode = None, geometry = None, name = 'bragg_virtual_image', - returncalc = True + returncalc = True, + center = True, + ellipse = True, + pixel = True, + rotate = True, ): ''' Calculates a virtual image based on the values of the Braggvectors @@ -204,13 +208,22 @@ def get_virtual_image( - 'circle', 'circular': nested 2-tuple, ((qx,qy),radius) - 'annular' or 'annulus': nested 2-tuple, ((qx,qy),(radius_i,radius_o)) - All values are in pixels. Note that (qx,qy) can be skipped, which - assumes peaks centered at (0,0) + Values can be in pixels or calibrated units. Note that (qx,qy) + can be skipped, which assumes peaks centered at (0,0). + center: bool + Apply calibration - center coordinate. + ellipse: bool + Apply calibration - elliptical correction. + pixel: bool + Apply calibration - pixel size. + rotate: bool + Apply calibration - QR rotation. Returns ------- virtual_im : VirtualImage ''' + # parse inputs circle_modes = ['circular','circle'] annulus_modes = ['annular','annulus'] @@ -220,13 +233,13 @@ def get_virtual_image( # set geometry if mode is None: if geometry is None: - center = None + qxy_center = None radial_range = np.array((0,np.inf)) else: if len(geometry[0]) == 0: - center = None + qxy_center = None else: - center = np.array(geometry[0]) + qxy_center = np.array(geometry[0]) if isinstance(geometry[1], int) or isinstance(geometry[1], float): radial_range = np.array((0,geometry[1])) elif len(geometry[1]) == 0: @@ -236,30 +249,44 @@ def get_virtual_image( elif mode == 'circular' or mode == 'circle': radial_range = np.array((0,geometry[1])) if len(geometry[0]) == 0: - center = None + qxy_center = None else: - center = np.array(geometry[0]) + qxy_center = np.array(geometry[0]) elif mode == 'annular' or mode == 'annulus': radial_range = np.array(geometry[1]) if len(geometry[0]) == 0: - center = None + qxy_center = None else: - center = np.array(geometry[0]) + qxy_center = np.array(geometry[0]) # allocate space im_virtual = np.zeros(self.shape) # generate image - for rx,ry in tqdmnd(self.shape[0],self.shape[1]): - p = self.raw[rx,ry] + for rx,ry in tqdmnd( + self.shape[0], + self.shape[1], + ): + # Get user-specified Bragg vectors + p = self.get_vectors( + rx, + ry, + center = center, + ellipse = ellipse, + pixel = pixel, + rotate = rotate, + ) + if p.data.shape[0] > 0: if radial_range is None: im_virtual[rx,ry] = np.sum(p.I) else: - if center is None: + if qxy_center is None: qr = np.hypot(p.qx,p.qy) else: - qr = np.hypot(p.qx - center[0],p.qy - center[1]) + qr = np.hypot( + p.qx - qxy_center[0], + p.qy - qxy_center[1]) sub = np.logical_and( qr >= radial_range[0], qr < radial_range[1]) @@ -284,7 +311,7 @@ def get_virtual_image( } ) # attach to the tree - self.attach( ans) + self.attach(ans) # return if returncalc: @@ -634,192 +661,6 @@ def fit_p_ellipse( if returncalc: return p_ellipse - - # Deprecated?? - - def index_bragg_directions( - self, - x0 = None, - y0 = None, - plot = True, - bvm_vis_params = {}, - returncalc = False, - ): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - Args: - x0 (float): x-coord of origin - y0 (float): y-coord of origin - Plot (bool): plot results - """ - - if x0 is None: - x0 = self.Qshape[0]/2 - if y0 is None: - y0 = self.Qshape[0]/2 - - from py4DSTEM.process.latticevectors import index_bragg_directions - _, _, braggdirections = index_bragg_directions( - x0, - y0, - self.g['x'], - self.g['y'], - self.g1, - self.g2 - ) - - self.braggdirections = braggdirections - - if plot: - from py4DSTEM.visualize import show_bragg_indexing - show_bragg_indexing( - self.bvm_centered, - **bvm_vis_params, - braggdirections = braggdirections, - points = True - ) - - if returncalc: - return braggdirections - - - - def add_indices_to_braggpeaks( - self, - maxPeakSpacing, - mask = None, - returncalc = False, - ): - """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, - identify the indices for each peak in the PointListArray braggpeaks. - Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed - or not with the bool index_mask. If `mask` is specified, only the locations where - mask is True are indexed. - - Args: - maxPeakSpacing (float): Maximum distance from the ideal lattice points - to include a peak for indexing - qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList - relative to the `braggpeaks` PointListArray - mask (bool): Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - """ - from py4DSTEM.process.latticevectors import add_indices_to_braggpeaks - - bragg_peaks_indexed = add_indices_to_braggpeaks( - self.vectors, - self.braggdirections, - maxPeakSpacing = maxPeakSpacing, - qx_shift = self.Qshape[0]/2, - qy_shift = self.Qshape[1]/2, - ) - - self.bragg_peaks_indexed = bragg_peaks_indexed - - if returncalc: - return bragg_peaks_indexed - - - def fit_lattice_vectors_all_DPs(self, returncalc = False): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some - known (h,k) indexing. - - - """ - - from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs - g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_peaks_indexed) - self.g1g2_map = g1g2_map - if returncalc: - return g1g2_map - - def get_strain_from_reference_region(self, mask, returncalc = False): - """ - Gets a strain map from the reference region of real space specified by mask and the - lattice vector map g1g2_map. - - Args: - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - """ - from py4DSTEM.process.latticevectors import get_strain_from_reference_region - - strainmap_median_g1g2 = get_strain_from_reference_region( - self.g1g2_map, - mask = mask, - ) - - self.strainmap_median_g1g2 = strainmap_median_g1g2 - - if returncalc: - return strainmap_median_g1g2 - - - def get_strain_from_reference_g1g2(self, mask, returncalc = False): - """ - Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map - g1g2_map. - - - Args: - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - """ - from py4DSTEM.process.latticevectors import get_reference_g1g2 - g1_ref,g2_ref = get_reference_g1g2(self.g1g2_map, mask) - - from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 - strainmap_reference_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) - - self.strainmap_reference_g1g2 = strainmap_reference_g1g2 - - if returncalc: - return strainmap_reference_g1g2 - - def get_rotated_strain_map(self, mode, g_reference = None, returncalc = True, flip_theta = False): - """ - Starting from a strain map defined with respect to the xy coordinate system of - diffraction space, i.e. where exx and eyy are the compression/tension along the Qx - and Qy directions, respectively, get a strain map defined with respect to some other - right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, - xaxis_y). - - Args: - g_referencce (tupe): reference coordinate system for xaxis_x and xaxis_y - """ - - assert mode in ("median","reference") - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) - - from py4DSTEM.process.latticevectors import get_rotated_strain_map - - if mode == "median": - strainmap_raw = self.strainmap_median_g1g2 - elif mode == "reference": - strainmap_raw = self.strainmap_reference_g1g2 - - strainmap = get_rotated_strain_map( - strainmap_raw, - xaxis_x = g_reference[0], - xaxis_y = g_reference[1], - flip_theta = flip_theta - ) - - if returncalc: - return strainmap - - def mask_in_Q( self, mask, diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index 14b89fd98..f1ff406d0 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -65,7 +65,7 @@ def __init__( Rshape, Qshape, name = 'braggvectors', - verbose = True, + verbose = False, calibration = None ): Custom.__init__(self,name=name) @@ -236,9 +236,16 @@ def setcal( "rotate" : rotate, } if self.verbose: - print('current calstate: ', self.calstate) + print('current calibration state: ', self.calstate) pass + def calibrate(self): + """ + Autoupdate the calstate when relevant calibrations are set + """ + self.setcal() + + # vector getter method @@ -250,7 +257,7 @@ def get_vectors( ellipse, pixel, rotate - ): + ): """ Returns the bragg vectors at the specified scan position with the specified calibration state. @@ -268,6 +275,7 @@ def get_vectors( ------- vectors : BVects """ + ans = self._v_uncal[scan_x,scan_y].data ans = self.cal._transform( data = ans, @@ -282,17 +290,16 @@ def get_vectors( # copy - def copy(self, name=None): name = name if name is not None else self.name+"_copy" braggvector_copy = BraggVectors( - self.Rshape, - self.Qshape, - name=name, + self.Rshape, + self.Qshape, + name=name, calibration = self.calibration.copy() ) - - braggvector_copy._v_uncal = self._v_uncal.copy() + + braggvector_copy.set_raw_vectors( self._v_uncal.copy() ) for k in self.metadata.keys(): braggvector_copy.metadata = self.metadata[k].copy() return braggvector_copy @@ -526,6 +533,4 @@ def _transform( # return - return ans - - + return ans \ No newline at end of file diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py index 6077864f3..50ec8f6f9 100644 --- a/py4DSTEM/data/calibration.py +++ b/py4DSTEM/data/calibration.py @@ -64,8 +64,8 @@ class Calibration(Metadata): theta, * p_ellipse, * ellipse, * - QR_rotation_degrees, - QR_flip, + QR_rotation_degrees, * + QR_flip, * QR_rotflip, * probe_semiangle, probe_param, @@ -598,11 +598,13 @@ def ellipse(self,x): # Q/R-space rotation and flip + @call_calibrate def set_QR_rotation_degrees(self,x): self._params['QR_rotation_degrees'] = x def get_QR_rotation_degrees(self): return self._get_value('QR_rotation_degrees') + @call_calibrate def set_QR_flip(self,x): self._params['QR_flip'] = x def get_QR_flip(self): @@ -617,8 +619,8 @@ def set_QR_rotflip(self, rot_flip): flip (bool): True indicates a Q/R axes flip """ rot,flip = rot_flip - self.set_QR_rotation_degrees(rot) - self.set_QR_flip(flip) + self._params['QR_rotation_degrees'] = rot + self._params['QR_flip'] = flip def get_QR_rotflip(self): rot = self.get_QR_rotation_degrees() flip = self.get_QR_flip() diff --git a/py4DSTEM/datacube/datacube.py b/py4DSTEM/datacube/datacube.py index 81db0ed9b..ae3a82a36 100644 --- a/py4DSTEM/datacube/datacube.py +++ b/py4DSTEM/datacube/datacube.py @@ -2,8 +2,8 @@ import numpy as np from scipy.interpolate import interp1d -from scipy.ndimage import (binary_opening, binary_dilation,distance_transform_edt, - binary_fill_holes, gaussian_filter1d,gaussian_filter) +from scipy.ndimage import (binary_opening, binary_dilation, + distance_transform_edt, binary_fill_holes, gaussian_filter1d, gaussian_filter) from typing import Optional,Union from emdfile import Array, Metadata, Node, Root, tqdmnd @@ -125,6 +125,9 @@ def calibrate(self): self._qxx,self._qyy = np.meshgrid( dim_qx,dim_qy ) self._rxx,self._ryy = np.meshgrid( dim_rx,dim_ry ) + self._qyy_raw,self._qxx_raw = np.meshgrid( np.arange(self.Q_Ny),np.arange(self.Q_Nx) ) + self._ryy_raw,self._rxx_raw = np.meshgrid( np.arange(self.R_Ny),np.arange(self.R_Nx) ) + # coordinate meshgrids @@ -140,6 +143,18 @@ def qxx(self): @property def qyy(self): return self._qyy + @property + def rxx_raw(self): + return self._rxx_raw + @property + def ryy_raw(self): + return self._ryy_raw + @property + def qxx_raw(self): + return self._qxx_raw + @property + def qyy_raw(self): + return self._qyy_raw # coordinate meshgrids with shifted origin def qxxs(self,rx,ry): @@ -1061,11 +1076,12 @@ def get_beamstop_mask( # im = self.tree["dp_max"].data.astype('float') if not "dp_max" in self._branch.keys(): self.get_dp_max(); - im = self.tree("dp_max").data.astype('float') + im = self.tree("dp_max").data.copy().astype('float') else: if not "dp_mean" in self._branch.keys(): self.get_dp_mean(); - im = self.tree("dp_mean").data + im = self.tree("dp_mean").data.copy() + # if not "dp_mean" in self.tree.keys(): # self.get_dp_mean(); # im = self.tree["dp_mean"].data.astype('float') @@ -1119,7 +1135,7 @@ def get_beamstop_mask( ) # Add to tree - self.attach( x ) + self.tree(x) # return if returncalc: @@ -1170,7 +1186,7 @@ def get_radial_bkgrnd( # define the 2D cartesian coordinate system origin = self.calibration.get_origin() origin = origin[0][rx,ry],origin[1][rx,ry] - qxx,qyy = self.qxx-origin[0], self.qyy-origin[1] + qxx,qyy = self.qxx_raw-origin[0], self.qyy_raw-origin[1] # get distance qr in polar-elliptical coords ellipse = self.calibration.get_ellipse() @@ -1455,9 +1471,6 @@ def get_braggmask( vects = braggvectors.raw[rx,ry] # loop for idx in range(len(vects.data)): - qr = np.hypot(self.qxx-vects.qx[idx], self.qyy-vects.qy[idx]) + qr = np.hypot(self.qxx_raw-vects.qx[idx], self.qyy_raw-vects.qy[idx]) mask = np.logical_and(mask, qr>radius) - return mask - - - + return mask \ No newline at end of file diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index d91b23906..5e2681eb6 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -10,7 +10,7 @@ import inspect from emdfile import tqdmnd,Metadata -from py4DSTEM.data import Calibration, RealSlice, Data +from py4DSTEM.data import Calibration, RealSlice, Data, DiffractionSlice from py4DSTEM.visualize.show import show @@ -322,7 +322,7 @@ def position_detector( data = None, centered = None, calibrated = None, - shift_center = None, + shift_center = False, scan_position = None, invert = False, color = 'c', @@ -381,6 +381,7 @@ def position_detector( # data if data is None: + image = None keys = ['dp_mean','dp_max','dp_median'] image = None for k in keys: @@ -389,11 +390,14 @@ def position_detector( break except: pass - if image is None: - image = self[0,0] + if image is None: + image = self[0,0] elif isinstance(data, np.ndarray): assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}" image = data + elif isinstance(data, DiffractionSlice): + assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}" + image = data.data elif isinstance(data,tuple): rx,ry = data[:2] image = self[rx,ry] @@ -402,10 +406,7 @@ def position_detector( # shift center if shift_center is None: - if isinstance(data,tuple): - shift_center = True - else: - shift_center = False + shift_center = False elif shift_center == True: assert(isinstance(data,tuple)), "If shift_center is set to True, `data` should be a 2-tuple (rx,ry). To shift the detector mask while using some other input for `data`, set `shift_center` to a 2-tuple (rx,ry)" elif isinstance(shift_center,tuple): diff --git a/py4DSTEM/io/filereaders/__init__.py b/py4DSTEM/io/filereaders/__init__.py index d256334a8..b6f4eb0a2 100644 --- a/py4DSTEM/io/filereaders/__init__.py +++ b/py4DSTEM/io/filereaders/__init__.py @@ -2,4 +2,5 @@ from py4DSTEM.io.filereaders.read_K2 import read_gatan_K2_bin from py4DSTEM.io.filereaders.empad import read_empad from py4DSTEM.io.filereaders.read_mib import load_mib - +from py4DSTEM.io.filereaders.read_arina import read_arina +from py4DSTEM.io.filereaders.read_abTEM import read_abTEM diff --git a/py4DSTEM/io/filereaders/read_abTEM.py b/py4DSTEM/io/filereaders/read_abTEM.py new file mode 100644 index 000000000..1fec9e73e --- /dev/null +++ b/py4DSTEM/io/filereaders/read_abTEM.py @@ -0,0 +1,81 @@ +import h5py +from py4DSTEM.data import DiffractionSlice, RealSlice +from py4DSTEM.datacube import DataCube + +def read_abTEM( + filename, + mem="RAM", + binfactor: int = 1, +): + """ + File reader for abTEM datasets + Args: + filename: str with path to file + mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is + loaded; "RAM" transfer the data from storage to RAM, while "MEMMAP" + leaves the data in storage and creates a memory map which points to + the diffraction patterns, allowing them to be retrieved individually + from storage. + binfactor (int): Diffraction space binning factor for bin-on-load. + + Returns: + DataCube + """ + assert mem == "RAM", "read_abTEM does not support memory mapping" + assert binfactor == 1, "abTEM files can only be read at full resolution" + + with h5py.File(filename, "r") as f: + datasets = {} + for key in f.keys(): + datasets[key] = f.get(key)[()] + + data = datasets["array"] + + sampling = datasets["sampling"] + units = datasets["units"] + + assert len(data.shape) in (2, 4), "abtem reader supports only 4D and 2D data" + + if len(data.shape) == 4: + + datacube = DataCube(data=data) + + datacube.calibration.set_R_pixel_size(sampling[0]) + if sampling[0] != sampling[1]: + print( + "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration" + ) + datacube.calibration.set_Q_pixel_size(sampling[2]) + if sampling[2] != sampling[3]: + print( + "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with qx calibration" + ) + + if units[0] == b"\xc3\x85": + datacube.calibration.set_R_pixel_units("A") + else: + datacube.calibration.set_R_pixel_units(units[0].decode("utf-8")) + + datacube.calibration.set_Q_pixel_units(units[2].decode("utf-8")) + + return datacube + + else: + if units[0] == b"mrad": + diffraction = DiffractionSlice(data=data) + if sampling[0] != sampling[1]: + print( + "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with x calibration" + ) + diffraction.calibration.set_Q_pixel_units(units[0].decode("utf-8")) + diffraction.calibration.set_Q_pixel_size(sampling[0]) + return diffraction + else: + image = RealSlice(data=data) + if sampling[0] != sampling[1]: + print( + "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration" + ) + image.calibration.set_Q_pixel_units("A") + image.calibration.set_Q_pixel_size(sampling[0]) + return image diff --git a/py4DSTEM/io/filereaders/read_arina.py b/py4DSTEM/io/filereaders/read_arina.py new file mode 100644 index 000000000..323b5643f --- /dev/null +++ b/py4DSTEM/io/filereaders/read_arina.py @@ -0,0 +1,115 @@ +import h5py +import hdf5plugin +import numpy as np +from py4DSTEM.datacube import DataCube +from py4DSTEM.preprocess.utils import bin2D + + +def read_arina( + filename, + scan_width=1, + mem="RAM", + binfactor: int = 1, + dtype_bin: float = None, + flatfield: np.ndarray = None, +): + + """ + File reader for arina 4D-STEM datasets + Args: + filename: str with path to master file + scan_width: x dimension of scan + mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is + loaded; "RAM" transfer the data from storage to RAM, while "MEMMAP" + leaves the data in storage and creates a memory map which points to + the diffraction patterns, allowing them to be retrieved individually + from storage. + binfactor (int): Diffraction space binning factor for bin-on-load. + dtype_bin(float): specify datatype for bin on load if need something + other than uint16 + flatfield (np.ndarray): + flatfield forcorrection factors + + Returns: + DataCube + """ + assert mem == "RAM", "read_arina does not support memory mapping" + + f = h5py.File(filename, "r") + nimages = 0 + + # Count the number of images in all datasets + for dset in f["entry"]["data"]: + nimages = nimages + f["entry"]["data"][dset].shape[0] + height = f["entry"]["data"][dset].shape[1] + width = f["entry"]["data"][dset].shape[2] + dtype = f["entry"]["data"][dset].dtype + + width = width // binfactor + height = height // binfactor + + assert ( + nimages % scan_width < 1e-6 + ), "scan_width must be integer multiple of x*y size" + + if dtype.type is np.uint32: + print("Dataset is uint32 but will be converted to uint16") + dtype = np.dtype(np.uint16) + + if dtype_bin: + array_3D = np.empty((nimages, width, height), dtype=dtype_bin) + else: + array_3D = np.empty((nimages, width, height), dtype=dtype) + + image_index = 0 + + if flatfield is None: + correction_factors = 1 + else: + # Avoid div by 0 errors -> pixel with value 0 will be set to meadian + flatfield[flatfield == 0] = 1 + correction_factors = np.median(flatfield) / flatfield + + for dset in f["entry"]["data"]: + image_index = _processDataSet( + f["entry"]["data"][dset], + image_index, + array_3D, + binfactor, + correction_factors, + ) + + if f.__bool__(): + f.close() + + scan_height = int(nimages / scan_width) + + datacube = DataCube( + np.flip( + array_3D.reshape( + scan_width, scan_height, array_3D.data.shape[1], array_3D.data.shape[2] + ), + 0, + ) + ) + + return datacube + + +def _processDataSet(dset, start_index, array_3D, binfactor, correction_factors): + image_index = start_index + nimages_dset = dset.shape[0] + + for i in range(nimages_dset): + if binfactor == 1: + array_3D[image_index] = np.multiply( + dset[i].astype(array_3D.dtype), correction_factors + ) + else: + array_3D[image_index] = bin2D( + np.multiply(dset[i].astype(array_3D.dtype), correction_factors), + binfactor, + ) + + image_index = image_index + 1 + return image_index diff --git a/py4DSTEM/io/google_drive_downloader.py b/py4DSTEM/io/google_drive_downloader.py index 51e3a70d7..86ad1a9f4 100644 --- a/py4DSTEM/io/google_drive_downloader.py +++ b/py4DSTEM/io/google_drive_downloader.py @@ -83,6 +83,50 @@ 'test_realslice_io.h5', '1siH80-eRJwG5R6AnU4vkoqGWByrrEz1y' ), + 'test_arina_master' : ( + 'STO_STEM_bench_20us_master.h5', + '1q_4IjFuWRkw5VM84NhxrNTdIq4563BOC' + ), + 'test_arina_01' : ( + 'STO_STEM_bench_20us_data_000001.h5', + '1_3Dbm22-hV58iffwK9x-3vqJUsEXZBFQ' + ), + 'test_arina_02' : ( + 'STO_STEM_bench_20us_data_000002.h5', + '1x29RzHLnCzP0qthLhA1kdlUQ09ENViR8' + ), + 'test_arina_03' : ( + 'STO_STEM_bench_20us_data_000003.h5', + '1qsbzdEVD8gt4DYKnpwjfoS_Mg4ggObAA' + ), + 'test_arina_04' : ( + 'STO_STEM_bench_20us_data_000004.h5', + '1Lcswld0Y9fNBk4-__C9iJbc854BuHq-h' + ), + 'test_arina_05' : ( + 'STO_STEM_bench_20us_data_000005.h5', + '13YTO2ABsTK5nObEr7RjOZYCV3sEk3gt9' + ), + 'test_arina_06' : ( + 'STO_STEM_bench_20us_data_000006.h5', + '1RywPXt6HRbCvjgjSuYFf60QHWlOPYXwy' + ), + 'test_arina_07' : ( + 'STO_STEM_bench_20us_data_000007.h5', + '1GRoBecCvAUeSIujzsPywv1vXKSIsNyoT' + ), + 'test_arina_08' : ( + 'STO_STEM_bench_20us_data_000008.h5', + '1sTFuuvgKbTjZz1lVUfkZbbTDTQmwqhuU' + ), + 'test_arina_09' : ( + 'STO_STEM_bench_20us_data_000009.h5', + '1JmBiMg16iMVfZ5wz8z_QqcNPVRym1Ezh' + ), + 'test_arina_10' : ( + 'STO_STEM_bench_20us_data_000010.h5', + '1_90xAfclNVwMWwQ-YKxNNwBbfR1nfHoB' + ), 'test_strain' : ( 'downsample_Si_SiGe_analysis_braggdisks_cal.h5', '1bYgDdAlnWHyFmY-SwN3KVpMutWBI5MhP' @@ -112,6 +156,19 @@ 'legacy_v0.14', 'test_realslice_io', ), + 'test_arina' : ( + 'test_arina_master', + 'test_arina_01', + 'test_arina_02', + 'test_arina_03', + 'test_arina_04', + 'test_arina_05', + 'test_arina_06', + 'test_arina_07', + 'test_arina_08', + 'test_arina_09', + 'test_arina_10', + ), 'test_braggvectors' : ( 'Au_sim', ), diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py index 17b052601..20a3759a2 100644 --- a/py4DSTEM/io/importfile.py +++ b/py4DSTEM/io/importfile.py @@ -1,18 +1,18 @@ # Reader functions for non-native file types import pathlib -from os.path import exists, splitext -from typing import Union, Optional +from os.path import exists +from typing import Optional, Union -from py4DSTEM.io.parsefiletype import _parse_filetype from py4DSTEM.io.filereaders import ( - read_empad, + load_mib, + read_abTEM, + read_arina, read_dm, + read_empad, read_gatan_K2_bin, - load_mib ) - - +from py4DSTEM.io.parsefiletype import _parse_filetype def import_file( @@ -37,6 +37,7 @@ def import_file( from storage. binfactor (int): Diffraction space binning factor for bin-on-load. filetype (str): Used to override automatic filetype detection. + options include "dm", "empad", "gatan_K2_bin", "mib", "arina", "abTEM" **kwargs: any additional kwargs are passed to the downstream reader - refer to the individual filetype reader function call signatures and docstrings for more details. @@ -55,9 +56,7 @@ def import_file( "RAM", "MEMMAP", ], 'Error: argument mem must be either "RAM" or "MEMMAP"' - assert isinstance( - binfactor, int - ), "Error: argument binfactor must be an integer" + assert isinstance(binfactor, int), "Error: argument binfactor must be an integer" assert binfactor >= 1, "Error: binfactor must be >= 1" if binfactor > 1: assert ( @@ -66,13 +65,17 @@ def import_file( filetype = _parse_filetype(filepath) if filetype is None else filetype - if filetype == 'EMD': - raise Exception("EMD file detected - use py4DSTEM.read, not py4DSTEM.import_file!") + if filetype in ("emd", "legacy"): + raise Exception( + "EMD file or py4DSTEM detected - use py4DSTEM.read, not py4DSTEM.import_file!" + ) assert filetype in [ "dm", "empad", "gatan_K2_bin", - "mib" + "mib", + "arina", + "abTEM" # "kitware_counted", ], "Error: filetype not recognized" @@ -85,10 +88,12 @@ def import_file( # elif filetype == "kitware_counted": # data = read_kitware_counted(filepath, mem, binfactor, metadata=metadata, **kwargs) elif filetype == "mib": - data = load_mib(filepath, mem=mem, binfactor=binfactor,**kwargs) + data = load_mib(filepath, mem=mem, binfactor=binfactor, **kwargs) + elif filetype == "arina": + data = read_arina(filepath, mem=mem, binfactor=binfactor, **kwargs) + elif filetype == "abTEM": + data = read_abTEM(filepath, mem=mem, binfactor=binfactor, **kwargs) else: raise Exception("Bad filetype!") return data - - diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py index 4b385a694..8b20779f8 100644 --- a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py +++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py @@ -322,8 +322,10 @@ def set_dim( values for the n'th dim vector. Accepts: n (int): specifies which dim vector - dim (list or array): length must be either 2, or equal to the - length of the n'th axis of the data array + dim (list or array): length must be either 1 or 2, or equal to the + length of the n'th axis of the data array. If length is 1 specifies step + size of dim vector and starts at 0. If length is 2, specifies start + and step of dim vector. units (Optional, str): name: (Optional, str): """ diff --git a/py4DSTEM/io/parsefiletype.py b/py4DSTEM/io/parsefiletype.py index 5903ce814..1838f89b6 100644 --- a/py4DSTEM/io/parsefiletype.py +++ b/py4DSTEM/io/parsefiletype.py @@ -1,9 +1,18 @@ # File parser utility from os.path import splitext +import py4DSTEM.io.legacy as legacy +import emdfile as emd +import h5py + +import emdfile as emd +import h5py +import py4DSTEM.io.legacy as legacy + def _parse_filetype(fp): - """ Accepts a path to a data file, and returns the file type as a string. + """ + Accepts a path to a data file, and returns the file type as a string. """ _, fext = splitext(fp) fext = fext.lower() @@ -13,7 +22,20 @@ def _parse_filetype(fp): ".py4dstem", ".emd", ]: - return "H5" + if emd._is_EMD_file(fp): + return "emd" + + elif legacy.is_py4DSTEM_file(fp): + return "legacy" + + elif _is_arina(fp): + return "arina" + + elif _is_abTEM(fp): + return "abTEM" + else: + raise Exception("not supported `h5` data type") + elif fext in [ ".dm", ".dm3", @@ -21,17 +43,67 @@ def _parse_filetype(fp): ]: return "dm" elif fext in [".raw"]: - return "empad" + return "empad" elif fext in [".mrc"]: - return "mrc_relativity" + return "mrc_relativity" elif fext in [".gtg", ".bin"]: - return "gatan_K2_bin" + return "gatan_K2_bin" elif fext in [".kitware_counted"]: - return "kitware_counted" + return "kitware_counted" elif fext in [".mib", ".MIB"]: return "mib" else: raise Exception(f"Unrecognized file extension {fext}.") +def _is_arina(filepath): + """ + Check if an h5 file is an Arina file. + """ + with h5py.File(filepath,'r') as f: + try: + assert("entry" in f.keys()) + except AssertionError: + return False + try: + assert("NX_class" in f["entry"].attrs.keys()) + except AssertionError: + return False + return True + +def _is_abTEM(filepath): + """ + Check if an h5 file is an abTEM file. + """ + with h5py.File(filepath,'r') as f: + try: + assert("array" in f.keys()) + except AssertionError: + return False + return True + +def _is_arina(filepath): + """ + Check if an h5 file is an Arina file. + """ + with h5py.File(filepath, "r") as f: + try: + assert "entry" in f.keys() + except AssertionError: + return False + try: + assert "NX_class" in f["entry"].attrs.keys() + except AssertionError: + return False + return True +def _is_abTEM(filepath): + """ + Check if an h5 file is an abTEM file. + """ + with h5py.File(filepath, "r") as f: + try: + assert "array" in f.keys() + except AssertionError: + return False + return True diff --git a/py4DSTEM/io/read.py b/py4DSTEM/io/read.py index 79291fe86..bab555eaf 100644 --- a/py4DSTEM/io/read.py +++ b/py4DSTEM/io/read.py @@ -1,25 +1,23 @@ # Reader for native files -from pathlib import Path -from os.path import exists -from typing import Optional,Union import warnings +from os.path import exists +from pathlib import Path +from typing import Optional, Union -import py4DSTEM import emdfile as emd -from py4DSTEM.io.parsefiletype import _parse_filetype import py4DSTEM.io.legacy as legacy - - +from py4DSTEM.data import Data +from py4DSTEM.io.parsefiletype import _parse_filetype def read( - filepath: Union[str,Path], + filepath: Union[str, Path], datapath: Optional[str] = None, - tree: Optional[Union[bool,str]] = True, + tree: Optional[Union[bool, str]] = True, verbose: Optional[bool] = False, **kwargs, - ): +): """ A file reader for native py4DSTEM / EMD files. To read non-native formats, use `py4DSTEM.import_file`. @@ -66,53 +64,61 @@ def read( # parse filetype er1 = f"filepath must be a string or Path, not {type(filepath)}" er2 = f"specified filepath '{filepath}' does not exist" - assert(isinstance(filepath, (str,Path) )), er1 - assert(exists(filepath)), er2 + assert isinstance(filepath, (str, Path)), er1 + assert exists(filepath), er2 filetype = _parse_filetype(filepath) - assert filetype == "H5", f"`py4DSTEM.read` loads native HDF5 formatted files, but a file of type {filetype} was detected. Try loading it using py4DSTEM.import_file" + assert filetype in ( + "emd", + "legacy", + ), f"`py4DSTEM.read` loads native HDF5 formatted files, but a file of type {filetype} was detected. Try loading it using py4DSTEM.import_file" # support older `root` input if datapath is None: - if 'root' in kwargs: - datapath = kwargs['root'] + if "root" in kwargs: + datapath = kwargs["root"] # EMD 1.0 formatted files (py4DSTEM v0.14+) - if emd._is_EMD_file(filepath): + if filetype == "emd": + + # check version version = emd._get_EMD_version(filepath) - if verbose: print(f"EMD version {version[0]}.{version[1]}.{version[2]} detected. Reading...") - assert emd._version_is_geq(version,(1,0,0)), f"EMD version {version} detected. Expected version >= 1.0.0" + if verbose: + print( + f"EMD version {version[0]}.{version[1]}.{version[2]} detected. Reading..." + ) + assert emd._version_is_geq( + version, (1, 0, 0) + ), f"EMD version {version} detected. Expected version >= 1.0.0" # read - data = emd.read( - filepath, - emdpath = datapath, - tree = tree - ) + data = emd.read(filepath, emdpath=datapath, tree=tree) + if verbose: + print("Data was read from file. Adding calibration links...") # add calibration links - if isinstance(data,py4DSTEM.Data): + if isinstance(data, Data): with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") cal = data.calibration - elif isinstance(data,py4DSTEM.Root): + elif isinstance(data, emd.Root): try: - cal = data.metadata['calibration'] + cal = data.metadata["calibration"] except KeyError: cal = None else: cal = None if cal is not None: try: - root_treepath = cal['_root_treepath'] - target_paths = cal['_target_paths'] - del(cal._params['_target_paths']) + root_treepath = cal["_root_treepath"] + target_paths = cal["_target_paths"] + del cal._params["_target_paths"] for p in target_paths: try: - p = p.replace(root_treepath,'') + p = p.replace(root_treepath, "") d = data.root.tree(p) - cal.register_target( d ) - if hasattr(d,'setcal'): + cal.register_target(d) + if hasattr(d, "setcal"): d.setcal() except AssertionError: pass @@ -121,68 +127,70 @@ def read( cal.calibrate() # return - if verbose: print("Done.") + if verbose: + print("Done.") return data - # legacy py4DSTEM files (v <= 0.13) else: - assert legacy.is_py4DSTEM_file(filepath), "path points to an H5 file which is neither an EMD 1.0+ file, nor a recognized legacy py4DSTEM file." - + assert ( + filetype == "legacy" + ), "path points to an H5 file which is neither an EMD 1.0+ file, nor a recognized legacy py4DSTEM file." # read v13 if legacy.is_py4DSTEM_version13(filepath): # load the data - if verbose: print(f"Legacy py4DSTEM version 13 file detected. Reading...") - kwargs['root'] = datapath - kwargs['tree'] = tree + if verbose: + print("Legacy py4DSTEM version 13 file detected. Reading...") + kwargs["root"] = datapath + kwargs["tree"] = tree data = legacy.read_legacy13( filepath=filepath, **kwargs, ) - if verbose: print("Done.") + if verbose: + print("Done.") return data - # read <= v12 else: # parse the root/data_id from the datapath arg if datapath is not None: - datapath = datapath.split('/') + datapath = datapath.split("/") try: - datapath.remove('') + datapath.remove("") except ValueError: pass rootgroup = datapath[0] - if len(datapath)>1: - datapath = '/'.join(rootgroup[1:]) + if len(datapath) > 1: + datapath = "/".join(rootgroup[1:]) else: datapath = None else: rootgroups = legacy.get_py4DSTEM_topgroups(filepath) - if len(rootgroups)>1: - print('multiple root groups in a legacy file found - returning list of root names; please pass one as `datapath`') + if len(rootgroups) > 1: + print( + "multiple root groups in a legacy file found - returning list of root names; please pass one as `datapath`" + ) return rootgroups - elif len(rootgroups)==0: - raise Exception('No rootgroups found') + elif len(rootgroups) == 0: + raise Exception("No rootgroups found") else: rootgroup = rootgroups[0] datapath = None - # load the data - if verbose: print(f"Legacy py4DSTEM version <= 12 file detected. Reading...") - kwargs['topgroup'] = rootgroup + if verbose: + print("Legacy py4DSTEM version <= 12 file detected. Reading...") + kwargs["topgroup"] = rootgroup if datapath is not None: - kwargs['data_id'] = datapath + kwargs["data_id"] = datapath data = legacy.read_legacy12( filepath=filepath, **kwargs, ) - if verbose: print("Done.") + if verbose: + print("Done.") return data - - - diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index f1052d72d..4001f80cb 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -283,6 +283,7 @@ def bin_data_diffraction( # set calibration pixel size datacube.calibration.set_Q_pixel_size(Qpixsize) + # return return datacube @@ -647,14 +648,9 @@ def resample_data_diffraction( datacube.data = fourier_resample( datacube.data, scale=resampling_factor, output_size=output_size ) - - if not resampling_factor: - resampling_factor = old_size[2] / output_size[0] - if datacube.calibration.get_Q_pixel_size() is not None: - datacube.calibration.set_Q_pixel_size(datacube.calibration.get_Q_pixel_size() / resampling_factor) if not resampling_factor: - resampling_factor = old_size[2] / output_size[0] + resampling_factor = output_size[0] / old_size[2] if datacube.calibration.get_Q_pixel_size() is not None: datacube.calibration.set_Q_pixel_size(datacube.calibration.get_Q_pixel_size() / resampling_factor) diff --git a/py4DSTEM/process/calibration/origin.py b/py4DSTEM/process/calibration/origin.py index 8821b43d0..19d2f0c55 100644 --- a/py4DSTEM/process/calibration/origin.py +++ b/py4DSTEM/process/calibration/origin.py @@ -134,14 +134,14 @@ def fit_origin( # Fit data if mask is None: - popt_x, pcov_x, qx0_fit = fit_2D( + popt_x, pcov_x, qx0_fit, _ = fit_2D( f, qx0_meas, robust=robust, robust_steps=robust_steps, robust_thresh=robust_thresh, ) - popt_y, pcov_y, qy0_fit = fit_2D( + popt_y, pcov_y, qy0_fit, _ = fit_2D( f, qy0_meas, robust=robust, @@ -150,7 +150,7 @@ def fit_origin( ) else: - popt_x, pcov_x, qx0_fit = fit_2D( + popt_x, pcov_x, qx0_fit, _ = fit_2D( f, qx0_meas, robust=robust, @@ -158,7 +158,7 @@ def fit_origin( robust_thresh=robust_thresh, data_mask=mask == True, ) - popt_y, pcov_y, qy0_fit = fit_2D( + popt_y, pcov_y, qy0_fit, _ = fit_2D( f, qy0_meas, robust=robust, @@ -359,4 +359,3 @@ def get_origin_beamstop(datacube: DataCube, mask: np.ndarray, **kwargs): return qx0, qy0 - diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 228602692..4d4d4a248 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -4,7 +4,6 @@ import matplotlib.pyplot as plt from fractions import Fraction from typing import Union, Optional -from copy import deepcopy from scipy.optimize import curve_fit import sys @@ -78,6 +77,7 @@ def __init__( 1 number: the lattice parameter for a cubic cell 3 numbers: the three lattice parameters for an orthorhombic cell 6 numbers: the a,b,c lattice parameters and ɑ,β,ɣ angles for any cell + 3x3 array: row vectors containing the (u,v,w) lattice vectors. """ # Initialize Crystal @@ -92,7 +92,10 @@ def __init__( else: raise Exception("Number of positions and atomic numbers do not match") - # unit cell, as either [a a a 90 90 90], [a b c 90 90 90], or [a b c alpha beta gamma] + # unit cell, as one of: + # [a a a 90 90 90] + # [a b c 90 90 90] + # [a b c alpha beta gamma] cell = np.asarray(cell, dtype="float_") if np.size(cell) == 1: self.cell = np.hstack([cell, cell, cell, 90, 90, 90]) @@ -100,34 +103,48 @@ def __init__( self.cell = np.hstack([cell, 90, 90, 90]) elif np.size(cell) == 6: self.cell = cell + elif np.shape(cell)[0] == 3 and np.shape(cell)[1] == 3: + self.lat_real = np.array(cell) + a = np.linalg.norm(self.lat_real[0,:]) + b = np.linalg.norm(self.lat_real[1,:]) + c = np.linalg.norm(self.lat_real[2,:]) + alpha = np.rad2deg(np.arccos(np.clip(np.sum( + self.lat_real[1,:]*self.lat_real[2,:])/b/c,-1,1))) + beta = np.rad2deg(np.arccos(np.clip(np.sum( + self.lat_real[0,:]*self.lat_real[2,:])/a/c,-1,1))) + gamma = np.rad2deg(np.arccos(np.clip(np.sum( + self.lat_real[0,:]*self.lat_real[1,:])/a/b,-1,1))) + self.cell = (a,b,c,alpha,beta,gamma) else: - raise Exception("Cell cannot contain " + np.size(cell) + " elements") + raise Exception("Cell cannot contain " + np.size(cell) + " entries") # pymatgen flag self.pymatgen_available = False # Calculate lattice parameters self.calculate_lattice() - + def calculate_lattice(self): - # calculate unit cell lattice vectors - a = self.cell[0] - b = self.cell[1] - c = self.cell[2] - alpha = np.deg2rad(self.cell[3]) - beta = np.deg2rad(self.cell[4]) - gamma = np.deg2rad(self.cell[5]) - f = np.cos(beta) * np.cos(gamma) - np.cos(alpha) - vol = a*b*c*np.sqrt(1 \ - + 2*np.cos(alpha)*np.cos(beta)*np.cos(gamma) \ - - np.cos(alpha)**2 - np.cos(beta)**2 - np.cos(gamma)**2) - self.lat_real = np.array( - [ - [a, 0, 0], - [b*np.cos(gamma), b*np.sin(gamma), 0], - [c*np.cos(beta), -c*f/np.sin(gamma), vol/(a*b*np.sin(gamma))], - ] - ) + + if not hasattr(self, 'lat_real'): + # calculate unit cell lattice vectors + a = self.cell[0] + b = self.cell[1] + c = self.cell[2] + alpha = np.deg2rad(self.cell[3]) + beta = np.deg2rad(self.cell[4]) + gamma = np.deg2rad(self.cell[5]) + f = np.cos(beta) * np.cos(gamma) - np.cos(alpha) + vol = a*b*c*np.sqrt(1 \ + + 2*np.cos(alpha)*np.cos(beta)*np.cos(gamma) \ + - np.cos(alpha)**2 - np.cos(beta)**2 - np.cos(gamma)**2) + self.lat_real = np.array( + [ + [a, 0, 0], + [b*np.cos(gamma), b*np.sin(gamma), 0], + [c*np.cos(beta), -c*f/np.sin(gamma), vol/(a*b*np.sin(gamma))], + ] + ) # Inverse lattice, metric tensors self.metric_real = self.lat_real @ self.lat_real.T @@ -139,6 +156,49 @@ def calculate_lattice(self): self.pymatgen_available = True else: self.pymatgen_available = False + + def get_strained_crystal( + self, + exx = 0.0, + eyy = 0.0, + ezz = 0.0, + exy = 0.0, + exz = 0.0, + eyz = 0.0, + deformation_matrix = None, + return_deformation_matrix = False, + ): + """ + This method returns new Crystal class with strain applied. The directions of (x,y,z) + are with respect to the default Crystal orientation, which can be checked with + print(Crystal.lat_real) applied to the original Crystal. + + Strains are given in fractional values, so exx = 0.01 is 1% strain along the x direction. + """ + + # deformation matrix + if deformation_matrix is None: + deformation_matrix = np.array([ + [1.0+exx, 1.0*exy, 1.0*exz], + [1.0*exy, 1.0+eyy, 1.0*eyz], + [1.0*exz, 1.0*eyz, 1.0+ezz], + ]) + + # new unit cell + lat_new = self.lat_real @ deformation_matrix + + # make new crystal class + from py4DSTEM.process.diffraction import Crystal + crystal_strained = Crystal( + positions = self.positions.copy(), + numbers = self.numbers.copy(), + cell = lat_new, + ) + + if return_deformation_matrix: + return crystal_strained, deformation_matrix + else: + return crystal_strained def from_CIF(CIF, conventional_standard_structure=True): @@ -386,13 +446,28 @@ def calculate_structure_factors( k_max: float = 2.0, tol_structure_factor: float = 1e-4, return_intensities: bool = False, - ): + ): + + """ Calculate structure factors for all hkl indices up to max scattering vector k_max - Args: - k_max (numpy float): max scattering vector to include (1/Angstroms) - tol_structure_factor (numpy float): tolerance for removing low-valued structure factors + Parameters + -------- + + k_max: float + max scattering vector to include (1/Angstroms) + tol_structure_factor: float + tolerance for removing low-valued structure factors + return_intensities: bool + return the intensities and positions of all structure factor peaks. + + Returns + -------- + (q_SF, I_SF) + Tuple of the q vectors and intensities of each structure factor. + + """ # Store k_max @@ -425,7 +500,7 @@ def calculate_structure_factors( hkl = np.vstack([xa.ravel(), ya.ravel(), za.ravel()]) # g_vec_all = self.lat_inv @ hkl g_vec_all = (hkl.T @ self.lat_inv).T - + # Delete lattice vectors outside of k_max keep = np.linalg.norm(g_vec_all, axis=0) <= self.k_max self.hkl = hkl[:, keep] @@ -898,12 +973,25 @@ def calculate_bragg_peak_histogram( k = np.arange(k_min, k_max + k_step, k_step) k_num = k.shape[0] - # experimental data histogram + # set rotate and ellipse based on their availability + rotate = bragg_peaks.calibration.get_QR_rotation_degrees() + ellipse = bragg_peaks.calibration.get_ellipse() + rotate = False if rotate is None else True + ellipse = False if ellipse is None else True + + # concatenate all peaks bigpl = np.concatenate( [ - bragg_peaks.cal[i, j].data - for i in range(bragg_peaks.shape[0]) - for j in range(bragg_peaks.shape[1]) + bragg_peaks.get_vectors( + rx, + ry, + center = True, + ellipse = ellipse, + pixel = True, + rotate = rotate, + ).data + for rx in range(bragg_peaks.shape[0]) + for ry in range(bragg_peaks.shape[1]) ] ) qr = np.sqrt(bigpl["qx"] ** 2 + bigpl["qy"] ** 2) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 4d02dcb0b..bffa5b620 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -8,6 +8,8 @@ from py4DSTEM.process.diffraction.utils import Orientation, OrientationMap, axisEqual3D from py4DSTEM.process.utils import electron_wavelength_angstrom +from warnings import warn + from numpy.linalg import lstsq try: import cupy as cp @@ -767,6 +769,18 @@ def match_orientations( num_x=bragg_peaks_array.shape[0], num_y=bragg_peaks_array.shape[1], num_matches=num_matches_return) + + #check cal state + if bragg_peaks_array.calstate['ellipse'] == False: + ellipse = False + warn('Warning: bragg peaks not elliptically calibrated') + else: + ellipse = True + if bragg_peaks_array.calstate['rotate'] == False: + rotate = False + warn('bragg peaks not rotationally calibrated') + else: + rotate = True for rx, ry in tqdmnd( *bragg_peaks_array.shape, @@ -774,9 +788,17 @@ def match_orientations( unit=" PointList", disable=not progress_bar, ): + vectors = bragg_peaks_array.get_vectors( + scan_x=rx, + scan_y=ry, + center=True, + ellipse=ellipse, + pixel=True, + rotate=rotate + ) orientation = self.match_single_pattern( - bragg_peaks_array.cal[rx, ry], + bragg_peaks=vectors, num_matches_return=num_matches_return, min_number_peaks=min_number_peaks, inversion_symmetry=inversion_symmetry, @@ -1639,6 +1661,18 @@ def calculate_strain( corr_kernel_size = self.orientation_kernel_size radius_max_2 = corr_kernel_size**2 + #check cal state + if bragg_peaks_array.calstate['ellipse'] == False: + ellipse = False + warn('bragg peaks not elliptically calibrated') + else: + ellipse = True + if bragg_peaks_array.calstate['rotate'] == False: + rotate = False + warn('bragg peaks not rotationally calibrated') + else: + rotate = True + # Loop over all probe positions for rx, ry in tqdmnd( *bragg_peaks_array.shape, @@ -1647,7 +1681,14 @@ def calculate_strain( disable=not progress_bar, ): # Get bragg peaks from experiment and reference - p = bragg_peaks_array.cal[rx,ry] + p = bragg_peaks_array.get_vectors( + scan_x=rx, + scan_y=ry, + center=True, + ellipse=ellipse, + pixel=True, + rotate=rotate + ) if p.data.shape[0] >= min_num_peaks: p_ref = self.generate_diffraction_pattern( @@ -2070,5 +2111,4 @@ def symmetry_reduce_directions( } # "-3m": ["fiber", [0, 0, 1], [90.0, 60.0]], - # "-3m": ["fiber", [0, 0, 1], [180.0, 30.0]], - + # "-3m": ["fiber", [0, 0, 1], [180.0, 30.0]], \ No newline at end of file diff --git a/py4DSTEM/process/diffraction/crystal_calibrate.py b/py4DSTEM/process/diffraction/crystal_calibrate.py index 2d08cd03c..1b65480f5 100644 --- a/py4DSTEM/process/diffraction/crystal_calibrate.py +++ b/py4DSTEM/process/diffraction/crystal_calibrate.py @@ -24,7 +24,7 @@ def calibrate_pixel_size( k_step = 0.002, k_broadening = 0.002, fit_all_intensities = True, - set_calibration = True, + set_calibration_in_place = False, verbose = True, plot_result = False, figsize: Union[list, tuple, np.ndarray] = (12, 6), @@ -60,8 +60,13 @@ def calibrate_pixel_size( figsize (list, tuple, np.ndarray): Figure size of the plot. returnfig (bool): Return handles figure and axis - Returns: - fig, ax (handles): Optional figure and axis handles, if returnfig=True. + Returns + _______ + + + + fig, ax: handles, optional + Figure and axis handles, if returnfig=True. """ @@ -112,17 +117,21 @@ def fit_profile(k, *coefs): # Get the answer pix_size_prev = bragg_peaks.calibration.get_Q_pixel_size() - ans = pix_size_prev / scale_pixel_size + pixel_size_new = pix_size_prev / scale_pixel_size - # if requested, apply calibrations - if set_calibration: - bragg_peaks.calibration.set_Q_pixel_size( ans ) + # if requested, apply calibrations in place + if set_calibration_in_place: + bragg_peaks.calibration.set_Q_pixel_size( pixel_size_new ) bragg_peaks.calibration.set_Q_pixel_units('A^-1') - bragg_peaks.setcal() - # Output + # Output calibrated Bragg peaks + bragg_peaks_cali = bragg_peaks.copy() + bragg_peaks_cali.calibration.set_Q_pixel_size( pixel_size_new ) + bragg_peaks_cali.calibration.set_Q_pixel_units('A^-1') + + # Output pixel size if verbose: - print(f"Calibrated pixel size = {np.round(ans, decimals=8)} A^-1") + print(f"Calibrated pixel size = {np.round(pixel_size_new, decimals=8)} A^-1") # Plotting if plot_result: @@ -163,9 +172,9 @@ def fit_profile(k, *coefs): # return if returnfig and plot_result: - return ans, (fig,ax) + return bragg_peaks_cali, (fig,ax) else: - return ans + return bragg_peaks_cali @@ -463,4 +472,4 @@ def fitfun(self, k, *coefs_fit): "432": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic "-43m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic "m-3m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic - } + } \ No newline at end of file diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 9c1f5b667..a9420fee4 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -296,30 +296,48 @@ def plot_scattering_intensity( bragg_k_power=0.0, bragg_intensity_power=1.0, bragg_k_broadening=0.005, - figsize: Union[list, tuple, np.ndarray] = (12, 6), + figsize: Union[list, tuple, np.ndarray] = (10, 4), returnfig: bool = False, ): """ 1D plot of the structure factors - Args: - k_min (float): min k value for profile range. - k_max (float): max k value for profile range. - k_step (float): step size of k in profile range. - k_broadening (float): Broadening of simulated pattern. - k_power_scale (float): Scale SF intensities by k**k_power_scale. - int_power_scale (float): Scale SF intensities**int_power_scale. - int_scale (float): Scale output profile by this value. - remove_origin (bool): Remove origin from plot. - bragg_peaks (BraggVectors): Passed in bragg_peaks for comparison with simulated pattern. - bragg_k_power (float): bragg_peaks scaled by k**bragg_k_power. - bragg_intensity_power (float): bragg_peaks scaled by intensities**bragg_intensity_power. - bragg_k_broadening float): Broadening applied to bragg_peaks. - figsize (list, tuple, np.ndarray): Figure size for plot. - returnfig (bool): Return figure and axes handles if this is True. - - Returns: - fig, ax (optional) figure and axes handles + Parameters + -------- + + k_min: float + min k value for profile range. + k_max: float + max k value for profile range. + k_step: float + Step size of k in profile range. + k_broadening: float + Broadening of simulated pattern. + k_power_scale: float + Scale SF intensities by k**k_power_scale. + int_power_scale: float + Scale SF intensities**int_power_scale. + int_scale: float + Scale output profile by this value. + remove_origin: bool + Remove origin from plot. + bragg_peaks: BraggVectors + Passed in bragg_peaks for comparison with simulated pattern. + bragg_k_power: float + bragg_peaks scaled by k**bragg_k_power. + bragg_intensity_power: float + bragg_peaks scaled by intensities**bragg_intensity_power. + bragg_k_broadening: float + Broadening applied to bragg_peaks. + figsize: list, tuple, np.ndarray + Figure size for plot. + returnfig (bool): + Return figure and axes handles if this is True. + + Returns + -------- + fig, ax (optional) + figure and axes handles """ # k coordinates @@ -342,12 +360,25 @@ def plot_scattering_intensity( # If Bragg peaks are passed in, compute 1D integral if bragg_peaks is not None: + # set rotate and ellipse based on their availability + rotate = bragg_peaks.calibration.get_QR_rotation_degrees() + ellipse = bragg_peaks.calibration.get_ellipse() + rotate = False if rotate is None else True + ellipse = False if ellipse is None else True + # concatenate all peaks bigpl = np.concatenate( [ - bragg_peaks.cal[i, j].data - for i in range(bragg_peaks.shape[0]) - for j in range(bragg_peaks.shape[1]) + bragg_peaks.get_vectors( + rx, + ry, + center = True, + ellipse = ellipse, + pixel = True, + rotate = rotate, + ).data + for rx in range(bragg_peaks.shape[0]) + for ry in range(bragg_peaks.shape[1]) ] ) @@ -903,6 +934,9 @@ def plot_diffraction_pattern( ax.set_ylabel("$q_x$ [Å$^{-1}$]") if plot_range_kx_ky is not None: + plot_range_kx_ky = np.array(plot_range_kx_ky) + if plot_range_kx_ky.ndim == 0: + plot_range_kx_ky = np.array((plot_range_kx_ky,plot_range_kx_ky)) ax.set_xlim((-plot_range_kx_ky[0], plot_range_kx_ky[0])) ax.set_ylim((-plot_range_kx_ky[1], plot_range_kx_ky[1])) else: @@ -1846,4 +1880,4 @@ def plot_ring_pattern( plt.show() if returnfig: - return fig, ax + return fig, ax \ No newline at end of file diff --git a/py4DSTEM/process/diffraction/flowlines.py b/py4DSTEM/process/diffraction/flowlines.py index 27d4f9381..cf84f69f5 100644 --- a/py4DSTEM/process/diffraction/flowlines.py +++ b/py4DSTEM/process/diffraction/flowlines.py @@ -519,6 +519,7 @@ def make_flowline_rainbow_image( power_scaling = 1.0, sum_radial_bins = False, plot_images = True, + figsize = None, ): """ Generate RGB output images from the flowline arrays. @@ -535,6 +536,7 @@ def make_flowline_rainbow_image( power_scaling (float): Power law scaling for flowline intensity output. sum_radial_bins (bool): Sum all radial bins (alternative is to output separate images). plot_images (bool): Plot the outputs for quick visualization. + figsize (2-tuple): Size of output figure. Returns: im_flowline (array): 3D or 4D array containing flowline images @@ -613,7 +615,14 @@ def make_flowline_rainbow_image( im_flowline = np.min(im_flowline,axis=0)[None,:,:,:] if plot_images is True: - fig,ax = plt.subplots(im_flowline.shape[0],1,figsize=(10,im_flowline.shape[0]*10)) + if figsize is None: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=(10,im_flowline.shape[0]*10)) + else: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=figsize) if im_flowline.shape[0] > 1: for a0 in range(im_flowline.shape[0]): @@ -729,6 +738,7 @@ def make_flowline_combined_image( power_scaling = 1.0, sum_radial_bins = True, plot_images = True, + figsize = None, ): """ Generate RGB output images from the flowline arrays. @@ -742,6 +752,7 @@ def make_flowline_combined_image( power_scaling (float): Power law scaling for flowline intensities. sum_radial_bins (bool): Sum outputs over radial bins. plot_images (bool): Plot the output images for quick visualization. + figsize (2-tuple): Size of output figure. Returns: im_flowline (array): flowline images @@ -787,7 +798,14 @@ def make_flowline_combined_image( if plot_images is True: - fig,ax = plt.subplots(im_flowline.shape[0],1,figsize=(10,im_flowline.shape[0]*10)) + if figsize is None: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=(10,im_flowline.shape[0]*10)) + else: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=figsize) if im_flowline.shape[0] > 1: for a0 in range(im_flowline.shape[0]): @@ -1143,4 +1161,4 @@ def set_intensity(orient,xy_t_int): mode=['clip','clip','wrap']) orient.ravel()[inds_1D] = orient.ravel()[inds_1D] + xy_t_int[:,3]*( dx)*( dy)*( dt) - return orient + return orient \ No newline at end of file diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 9abb713f7..32809ddb1 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -269,7 +269,3 @@ def fit_2D_polar_gaussian( robust_steps = robust_steps, robust_thresh = robust_thresh ) - - - - diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py index 822ffdea5..fef72aca3 100644 --- a/py4DSTEM/process/latticevectors/fit.py +++ b/py4DSTEM/process/latticevectors/fit.py @@ -104,13 +104,22 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): # Make RealSlice to contain outputs slicelabels = ('x0','y0','g1x','g1y','g2x','g2y','error','mask') - g1g2_map = RealSlice(data=np.zeros((braggpeaks.shape[0],braggpeaks.shape[1],8)), - slicelabels=slicelabels, name='g1g2_map') + g1g2_map = RealSlice( + data=np.zeros( + (8, braggpeaks.shape[0],braggpeaks.shape[1]) + ), + slicelabels=slicelabels, name='g1g2_map' + ) # Fit lattice vectors for (Rx, Ry) in tqdmnd(braggpeaks.shape[0],braggpeaks.shape[1]): braggpeaks_curr = braggpeaks.get_pointlist(Rx,Ry) - qx0,qy0,g1x,g1y,g2x,g2y,error = fit_lattice_vectors(braggpeaks_curr, x0, y0, minNumPeaks) + qx0,qy0,g1x,g1y,g2x,g2y,error = fit_lattice_vectors( + braggpeaks_curr, + x0, + y0, + minNumPeaks + ) # Store data if g1x is not None: g1g2_map.get_slice('x0').data[Rx,Ry] = qx0 diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py index cdf6b00fd..189e7f10f 100644 --- a/py4DSTEM/process/latticevectors/index.py +++ b/py4DSTEM/process/latticevectors/index.py @@ -80,6 +80,9 @@ def index_bragg_directions(x0, y0, gx, gy, g1, g2): temp_array = np.zeros([], dtype = coords) bragg_directions = PointList(data = temp_array) bragg_directions.add_data_by_field((gx,gy,h,k)) + mask = np.zeros(bragg_directions['qx'].shape[0]) + mask[0] = 1 + bragg_directions.remove(mask) return h,k, bragg_directions @@ -152,8 +155,14 @@ def generate_lattice(ux,uy,vx,vy,x0,y0,Q_Nx,Q_Ny,h_max=None,k_max=None): return ideal_lattice -def add_indices_to_braggpeaks(braggpeaks, lattice, maxPeakSpacing, qx_shift=0, - qy_shift=0, mask=None): +def add_indices_to_braggvectors( + braggpeaks, + lattice, + maxPeakSpacing, + qx_shift=0, + qy_shift=0, + mask=None + ): """ Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, identify the indices for each peak in the PointListArray braggpeaks. @@ -181,43 +190,41 @@ def add_indices_to_braggpeaks(braggpeaks, lattice, maxPeakSpacing, qx_shift=0, 'h', 'k', containing the indices of each indexable peak. """ - assert isinstance(braggpeaks,PointListArray) - assert np.all([name in braggpeaks.dtype.names for name in ('qx','qy','intensity')]) - assert isinstance(lattice, PointList) - assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) + # assert isinstance(braggpeaks,BraggVectors) + # assert isinstance(lattice, PointList) + # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) if mask is None: - mask = np.ones(braggpeaks.shape,dtype=bool) + mask = np.ones(braggpeaks.Rshape,dtype=bool) - assert mask.shape == braggpeaks.shape, 'mask must have same shape as pointlistarray' + assert mask.shape == braggpeaks.Rshape, 'mask must have same shape as pointlistarray' assert mask.dtype == bool, 'mask must be boolean' - indexed_braggpeaks = braggpeaks.copy() - # add the coordinates if they don't exist - if not ('h' in braggpeaks.dtype.names): - indexed_braggpeaks = indexed_braggpeaks.add_fields([('h',int)]) - if not ('k' in braggpeaks.dtype.names): - indexed_braggpeaks = indexed_braggpeaks.add_fields([('k',int)]) + coords = [('qx',float),('qy',float),('intensity',float),('h',int),('k',int)] + + indexed_braggpeaks = PointListArray( + dtype = coords, + shape = braggpeaks.Rshape, + ) # loop over all the scan positions for Rx, Ry in tqdmnd(mask.shape[0],mask.shape[1]): - if mask[Rx,Ry]: - pl = indexed_braggpeaks.get_pointlist(Rx,Ry) - rm_peak_mask = np.zeros(pl.length,dtype=bool) - - for i in range(pl.length): + if mask[Rx,Ry]: + pl = braggpeaks.cal[Rx,Ry] + for i in range(pl.data.shape[0]): r2 = (pl.data['qx'][i]-lattice.data['qx'] + qx_shift)**2 + \ (pl.data['qy'][i]-lattice.data['qy'] + qy_shift)**2 ind = np.argmin(r2) if r2[ind] <= maxPeakSpacing**2: - pl.data['h'][i] = lattice.data['h'][ind] - pl.data['k'][i] = lattice.data['k'][ind] - else: - rm_peak_mask[i] = True - pl.remove(rm_peak_mask) + indexed_braggpeaks[Rx,Ry].add_data_by_field(( + pl.data['qx'][i], + pl.data['qy'][i], + pl.data['intensity'][i], + lattice.data['h'][ind], + lattice.data['k'][ind] + )) - indexed_braggpeaks.name = braggpeaks.name + "_indexed" return indexed_braggpeaks diff --git a/py4DSTEM/process/latticevectors/strain.py b/py4DSTEM/process/latticevectors/strain.py index 50b9bddc9..7a586bd69 100644 --- a/py4DSTEM/process/latticevectors/strain.py +++ b/py4DSTEM/process/latticevectors/strain.py @@ -71,9 +71,11 @@ def get_strain_from_reference_g1g2(g1g2_map, g1, g2): # Get RealSlice for output storage R_Nx,R_Ny = g1g2_map.get_slice('g1x').shape - strain_map = RealSlice(data=np.zeros((R_Nx,R_Ny,5)), - slicelabels=('e_xx','e_yy','e_xy','theta','mask'), - name='strain_map') + strain_map = RealSlice( + data=np.zeros((5, R_Nx, R_Ny)), + slicelabels=('e_xx','e_yy','e_xy','theta','mask'), + name='strain_map' + ) # Get reference lattice matrix g1x,g1y = g1 @@ -130,7 +132,8 @@ def get_strain_from_reference_region(g1g2_map, mask): Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical """ assert isinstance(g1g2_map, RealSlice) - assert np.all([name in g1g2_map.slicelabels for name in ('g1x','g1y','g2x','g2y','mask')]) + assert np.all( + [name in g1g2_map.slicelabels for name in ('g1x','g1y','g2x','g2y','mask')]) assert mask.dtype == bool g1,g2 = get_reference_g1g2(g1g2_map,mask) @@ -169,18 +172,20 @@ def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): sint2 = sint**2 Rx,Ry = unrotated_strain_map.get_slice('e_xx').data.shape - rotated_strain_map = RealSlice(data=np.zeros((Rx,Ry,5)), - slicelabels=['e_xx','e_xy','e_yy','theta','mask'], - name=unrotated_strain_map.name+"_rotated".format(np.degrees(theta))) - - rotated_strain_map.data[:,:,0] = cost2*unrotated_strain_map.get_slice('e_xx').data - 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + sint2*unrotated_strain_map.get_slice('e_yy').data - rotated_strain_map.data[:,:,1] = cost*sint*(unrotated_strain_map.get_slice('e_xx').data-unrotated_strain_map.get_slice('e_yy').data) + (cost2-sint2)*unrotated_strain_map.get_slice('e_xy').data - rotated_strain_map.data[:,:,2] = sint2*unrotated_strain_map.get_slice('e_xx').data + 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + cost2*unrotated_strain_map.get_slice('e_yy').data + rotated_strain_map = RealSlice( + data=np.zeros((5, Rx,Ry)), + slicelabels=['e_xx','e_xy','e_yy','theta','mask'], + name=unrotated_strain_map.name+"_rotated".format(np.degrees(theta)) + ) + + rotated_strain_map.data[0,:,:] = cost2*unrotated_strain_map.get_slice('e_xx').data - 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + sint2*unrotated_strain_map.get_slice('e_yy').data + rotated_strain_map.data[1,:,:] = cost*sint*(unrotated_strain_map.get_slice('e_xx').data-unrotated_strain_map.get_slice('e_yy').data) + (cost2-sint2)*unrotated_strain_map.get_slice('e_xy').data + rotated_strain_map.data[2,:,:] = sint2*unrotated_strain_map.get_slice('e_xx').data + 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + cost2*unrotated_strain_map.get_slice('e_yy').data if flip_theta == True: - rotated_strain_map.data[:,:,3] = -unrotated_strain_map.get_slice('theta').data + rotated_strain_map.data[3,:,:] = -unrotated_strain_map.get_slice('theta').data else: - rotated_strain_map.data[:,:,3] = unrotated_strain_map.get_slice('theta').data - rotated_strain_map.data[:,:,4] = unrotated_strain_map.get_slice('mask').data + rotated_strain_map.data[3,:,:] = unrotated_strain_map.get_slice('theta').data + rotated_strain_map.data[4,:,:] = unrotated_strain_map.get_slice('mask').data return rotated_strain_map diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 1e9bd2cbb..92f8c0bf3 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -974,7 +974,7 @@ def _gradient_descent_adjoint( ) # back-transmit - exit_waves *= xp.conj(obj) / xp.abs(obj) ** 2 + exit_waves *= xp.conj(obj) #/ xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -1076,7 +1076,7 @@ def _projection_sets_adjoint( ) # back-transmit - exit_waves_copy *= xp.conj(obj) / xp.abs(obj) ** 2 + exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -3067,4 +3067,4 @@ def _return_object_fft( obj = np.angle(obj) obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) \ No newline at end of file diff --git a/py4DSTEM/process/polar/__init__.py b/py4DSTEM/process/polar/__init__.py index ddf0a9e50..79e13a054 100644 --- a/py4DSTEM/process/polar/__init__.py +++ b/py4DSTEM/process/polar/__init__.py @@ -1,3 +1,3 @@ from py4DSTEM.process.polar.polar_datacube import PolarDatacube from py4DSTEM.process.polar.polar_fits import fit_amorphous_ring, plot_amorphous_ring -from py4DSTEM.process.polar.polar_peaks import find_peaks_single_pattern, find_peaks, refine_peaks, plot_radial_peaks, plot_radial_background, make_orientation_histogram +from py4DSTEM.process.polar.polar_peaks import find_peaks_single_pattern, find_peaks, refine_peaks, plot_radial_peaks, plot_radial_background, make_orientation_histogram \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 0a6089f4e..fa6a40a4f 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -6,10 +6,17 @@ from emdfile import tqdmnd +<<<<<<< Updated upstream def calculate_FEM_global( self, use_median_local = False, use_median_global = False, +======= +def calculate_radial_statistics( + self, + median_local = False, + median_global = False, +>>>>>>> Stashed changes plot_results = False, figsize = (8,4), returnval = False, @@ -42,22 +49,48 @@ def calculate_FEM_global( self.scattering_vector = self.radial_bins * self.qstep * self.calibration.get_Q_pixel_size() self.scattering_vector_units = self.calibration.get_Q_pixel_units() +<<<<<<< Updated upstream # init radial data array +======= + # init radial data arrays +>>>>>>> Stashed changes self.radial_all = np.zeros(( self._datacube.shape[0], self._datacube.shape[1], self.polar_shape[1], )) +<<<<<<< Updated upstream +======= + self.radial_all_std = np.zeros(( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + )) + +>>>>>>> Stashed changes # Compute the radial mean for each probe position for rx, ry in tqdmnd( self._datacube.shape[0], self._datacube.shape[1], +<<<<<<< Updated upstream desc="Global FEM", unit=" probe positions", disable=not progress_bar): self.radial_all[rx,ry] = np.mean(self.data[rx,ry],axis=0) +======= + desc="Radial statistics", + unit=" probe positions", + disable=not progress_bar): + + self.radial_all[rx,ry] = np.mean( + self.data[rx,ry], + axis=0) + self.radial_all_std[rx,ry] = np.sqrt(np.mean( + (self.data[rx,ry] - self.radial_all[rx,ry][None])**2, + axis=0)) +>>>>>>> Stashed changes self.radial_avg = np.mean(self.radial_all, axis=(0,1)) self.radial_var = np.mean( @@ -138,5 +171,18 @@ def calculate_FEM_local( """ +<<<<<<< Updated upstream 1+1 +======= + pass + + +# def radial_average( +# self, +# figsize = (8,6), +# returnfig = False, +# ): + + +>>>>>>> Stashed changes diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index 3f3db0eca..b6ef4ee66 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -19,7 +19,7 @@ def __init__( n_annular = 180, qscale = None, mask = None, - mask_thresh = 0.25, + mask_thresh = 0.1, ellipse = True, two_fold_symmetry = False, ): @@ -95,7 +95,12 @@ def __init__( pass from py4DSTEM.process.polar.polar_analysis import ( +<<<<<<< Updated upstream calculate_FEM_global, +======= + # calculate_FEM_global, + calculate_radial_statistics, +>>>>>>> Stashed changes plot_FEM_global, calculate_FEM_local, ) @@ -127,9 +132,9 @@ def set_radial_bins( self._qmax, self._qstep ) - self.qscale = self._qscale self._radial_step = self._datacube.calibration.get_Q_pixel_size() * self._qstep self.set_polar_shape() + self.qscale = self._qscale @property def qmin(self): @@ -241,7 +246,7 @@ def qscale(self): def qscale(self,x): self._qscale = x if x is not None: - self._qscale_ar = np.arange(self.polar_shape[1])**x + self._qscale_ar = (self.qq / self.qq[-1])**x # expose raw data @@ -453,7 +458,7 @@ def _transform( ) # scale the normalization array by the bin density - norm_array = ans_norm*self._polarcube._annular_bin_step[np.newaxis] + norm_array = ans_norm * self._polarcube._annular_bin_step[np.newaxis] mask_bool = norm_array < mask_thresh # apply normalization @@ -588,5 +593,4 @@ def __repr__(self): space = ' '*len(self.__class__.__name__)+' ' string = f"{self.__class__.__name__}( " string += "Retrieves the diffraction pattern at scan position (x,y) in polar coordinates when sliced with [x,y]." - return string - + return string \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index 82085d6fb..e231dda07 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -7,11 +7,12 @@ def fit_amorphous_ring( im, - center, - radial_range, + center = None, + radial_range = None, coefs = None, mask_dp = None, show_fit_mask = False, + maxfev = None, verbose = False, plot_result = True, plot_log_scale = False, @@ -28,15 +29,19 @@ def fit_amorphous_ring( im: np.array 2D image array to perform fitting on center: np.array - (x,y) center coordinates for fitting mask + (x,y) center coordinates for fitting mask. If not specified + by the user, we will assume the center coordinate is (im.shape-1)/2. radial_range: np.array - (radius_inner, radius_outer) radial range to perform fitting over + (radius_inner, radius_outer) radial range to perform fitting over. + If not specified by the user, we will assume (im.shape[0]/4,im.shape[0]/2). coefs: np.array (optional) Array containing initial fitting coefficients for the amorphous fit. mask_dp: np.array Dark field mask for fitting, in addition to the radial range specified above. show_fit_mask: bool Set to true to preview the fitting mask and initial guess for the ellipse params + maxfev: int + Max number of fitting evaluations for curve_fit. verbose: bool Print fit results plot_result: bool @@ -58,6 +63,14 @@ def fit_amorphous_ring( 11 parameter elliptic fit coefficients """ + # Default values + if center is None: + center = np.array(( + (im.shape[0]-1)/2, + (im.shape[1]-1)/2)) + if radial_range is None: + radial_range = (im.shape[0]/4, im.shape[0]/2) + # coordinates xa,ya = np.meshgrid( np.arange(im.shape[0]), @@ -149,14 +162,26 @@ def fit_amorphous_ring( else: # Perform elliptic fitting int_mean = np.mean(vals) - coefs = curve_fit( - amorphous_model, - basis, - vals / int_mean, - p0=coefs, - xtol = 1e-12, - bounds = (lb,ub), - )[0] + + if maxfev is None: + coefs = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + )[0] + else: + coefs = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + maxfev = maxfev, + )[0] coefs[4] = np.mod(coefs[4],2*np.pi) coefs[5:8] *= int_mean # bounds=bounds @@ -356,4 +381,4 @@ def amorphous_model(basis, *coefs): sub = np.logical_not(sub) int_model[sub] += int12*np.exp(dr2[sub]/(-2*sigma2**2)) - return int_model + return int_model \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index 3f367b398..6a6e0860a 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -5,7 +5,7 @@ from scipy.ndimage import gaussian_filter, gaussian_filter1d from scipy.signal import peak_prominences from skimage.feature import peak_local_max -from scipy.optimize import curve_fit +from scipy.optimize import curve_fit, leastsq import warnings # from emdfile import tqdmnd, PointList, PointListArray @@ -34,7 +34,8 @@ def find_peaks_single_pattern( return_background = False, plot_result = True, plot_power_scale = 1.0, - plot_scale_size = 100.0, + plot_scale_size = 10.0, + figsize = (12,6), returnfig = False, ): """ @@ -62,10 +63,41 @@ def find_peaks_single_pattern( radial_background_thresh: float Relative order of sorted values to use as background estimate. Setting to 0.5 is equivalent to median, 0.0 is min value. - + num_peaks_max = 100 + Max number of peaks to return. + threshold_abs: float + Absolute image intensity threshold for peaks. + threshold_prom_annular: float + Threshold for prominance, along annular direction. + threshold_prom_radial: float + Threshold for prominance, along radial direction. + remove_masked_peaks: bool + Delete peaks that are in the region masked by "mask" + scale_sigma_annular: float + Scaling of the estimated annular standard deviation. + scale_sigma_radial: float + Scaling of the estimated radial standard deviation. + return_background: bool + Return the background signal. + plot_result: + Plot the detector peaks + plot_power_scale: float + Image intensity power law scaling. + plot_scale_size: float + Marker scaling in the plot. + figsize: 2-tuple + Size of the result plotting figure. + returnfig: bool + Return the figure and axes handles. + Returns -------- + peaks_polar : pointlist + The detected peaks + fig, ax : (optional) + Figure and axes handles + """ # if needed, generate mask from Bragg peaks @@ -151,7 +183,7 @@ def find_peaks_single_pattern( trace_annular, annular_ind_center, ) - sigma_annular = scale_sigma_annular * np.maximum( + sigma_annular = scale_sigma_annular * np.minimum( annular_ind_center - p_annular[1], p_annular[2] - annular_ind_center) @@ -161,7 +193,7 @@ def find_peaks_single_pattern( trace_radial, np.atleast_1d(peaks[a0,1]), ) - sigma_radial = scale_sigma_radial * np.maximum( + sigma_radial = scale_sigma_radial * np.minimum( peaks[a0,1] - p_radial[1], p_radial[2] - peaks[a0,1]) @@ -266,7 +298,7 @@ def find_peaks_single_pattern( st = np.sin(t) - fig,ax = plt.subplots(figsize=(12,6)) + fig,ax = plt.subplots(figsize=figsize) ax.imshow( im_plot, @@ -685,6 +717,7 @@ def model_radial_background( ring_int = None, refine_model = True, plot_result = True, + figsize = (8,4), ): """ User provided radial background model, of the form: @@ -751,6 +784,8 @@ def model_radial_background( self.background_coefs[3*a0+3] = ring_int[a0] self.background_coefs[3*a0+4] = ring_sigma[a0] self.background_coefs[3*a0+5] = ring_position[a0] + lb = np.zeros_like(self.background_coefs) + ub = np.ones_like(self.background_coefs) * np.inf # Create background model def background_model(q, *coefs): @@ -776,7 +811,7 @@ def background_model(q, *coefs): self.background_radial_mean[self.background_mask], p0 = self.background_coefs, xtol = 1e-12, - # bounds = (lb,ub), + bounds = (lb,ub), )[0] # plotting @@ -784,6 +819,7 @@ def background_model(q, *coefs): self.plot_radial_background( q_pixel_units = False, plot_background_model = True, + figsize = figsize, ) @@ -794,6 +830,7 @@ def refine_peaks( # reset_fits_to_init_positions = False, scale_sigma_estimate = 0.5, min_num_pixels_fit = 10, + maxfev = None, progress_bar = True, ): """ @@ -816,6 +853,8 @@ def refine_peaks( Factor to reduce sigma of peaks by, to prevent fit from running away. min_num_pixels_fit: int Minimum number of pixels to perform fitting + maxfev: int + Maximum number of iterations in fit. Set to a low number for a fast fit. progress_bar: bool Enable progress bar @@ -896,6 +935,11 @@ def refine_peaks( s_radial * scale_sigma_estimate, )) + # bounds + lb = np.zeros_like(coefs_all) + ub = np.ones_like(coefs_all) * np.inf + + # Construct fitting model def fit_image(basis, *coefs): coefs = np.squeeze(np.array(coefs)) @@ -928,14 +972,25 @@ def fit_image(basis, *coefs): try: with warnings.catch_warnings(): warnings.simplefilter('ignore') - coefs_all = curve_fit( - fit_image, - basis[mask_bool.ravel(),:], - im_polar[mask_bool], - p0 = coefs_all, - xtol = 1e-12, - # bounds = (lb,ub), - )[0] + if maxfev is None: + coefs_all = curve_fit( + fit_image, + basis[mask_bool.ravel(),:], + im_polar[mask_bool], + p0 = coefs_all, + xtol = 1e-12, + bounds = (lb,ub), + )[0] + else: + coefs_all = curve_fit( + fit_image, + basis[mask_bool.ravel(),:], + im_polar[mask_bool], + p0 = coefs_all, + xtol = 1e-12, + maxfev = maxfev, + bounds = (lb,ub), + )[0] # Output refined peak parameters coefs_peaks = np.reshape( @@ -951,9 +1006,24 @@ def fit_image(basis, *coefs): ]), name = 'peaks_polar') except: - # if fitting has failed, we will output the mean background signal, - # but none of the peaks. - pass + # if fitting has failed, we will still output the last iteration + # TODO - add a flag for unconverged fits + coefs_peaks = np.reshape( + coefs_all[(3*num_rings+3):], + (5,num_peaks)).T + self.peaks_refine[rx,ry] = PointList( + coefs_peaks.ravel().view([ + ('qt', float), + ('qr', float), + ('intensity', float), + ('sigma_annular', float), + ('sigma_radial', float), + ]), + name = 'peaks_polar') + + # mean background signal, + # # but none of the peaks. + # pass # Output refined parameters for background coefs_bg = coefs_all[:(3*num_rings+3)] @@ -1154,6 +1224,9 @@ def make_orientation_histogram( v_sigma = np.linspace(-2,2,2*peak_sigma_samples+1) w_sigma = np.exp(-v_sigma**2/2) + if use_refined_peaks is False: + warnings.warn("Orientation histogram is using non-refined peak positions") + # Loop over all probe positions for a0 in range(num_radii): t = "Generating histogram " + str(a0) @@ -1199,7 +1272,10 @@ def make_orientation_histogram( # If needed, expand signal using peak sigma to write into multiple bins if use_peak_sigma: - theta_std = self.peaks_refine[rx,ry]['sigma_annular'][sub] / dtheta + if use_refined_peaks: + theta_std = self.peaks_refine[rx,ry]['sigma_annular'][sub] / dtheta + else: + theta_std = self.peaks[rx,ry]['sigma_annular'][sub] / dtheta t = (t[:,None] + theta_std[:,None]*v_sigma[None,:]).ravel() intensity = (intensity[:,None] * w_sigma[None,:]).ravel() diff --git a/py4DSTEM/process/rdf/amorph.py b/py4DSTEM/process/rdf/amorph.py index 9c80a2807..a537896b9 100644 --- a/py4DSTEM/process/rdf/amorph.py +++ b/py4DSTEM/process/rdf/amorph.py @@ -111,7 +111,7 @@ def plot_strains(strains, cmap="RdBu_r", vmin=None, vmax=None, mask=None): cmap, vmin, vmax: imshow parameters mask: real space mask of values not to show (black) """ - cmap = matplotlib.cm.get_cmap(cmap) + cmap = plt.get_cmap(cmap) if vmin is None: vmin = np.min(strains) if vmax is None: diff --git a/py4DSTEM/process/strain.py b/py4DSTEM/process/strain.py index e999c02d1..db252f75b 100644 --- a/py4DSTEM/process/strain.py +++ b/py4DSTEM/process/strain.py @@ -1,13 +1,17 @@ # Defines the Strain class -import numpy as np from typing import Optional -from py4DSTEM.data import RealSlice, Data -from py4DSTEM.braggvectors import BraggVectors +import matplotlib.pyplot as plt +import numpy as np +from py4DSTEM import PointList +from py4DSTEM.braggvectors import BraggVectors +from py4DSTEM.data import Data, RealSlice +from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show -class StrainMap(RealSlice,Data): +class StrainMap(RealSlice, Data): """ Stores strain map. @@ -15,64 +19,80 @@ class StrainMap(RealSlice,Data): """ - def __init__( - self, - braggvectors: BraggVectors, - name: Optional[str] = 'strainmap' - ): + def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): """ TODO """ - assert(isinstance(braggvectors,BraggVectors)), f"braggvectors myst be BraggVectors, not type {type(braggvectors)}" + assert isinstance( + braggvectors, BraggVectors + ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" # initialize as a RealSlice RealSlice.__init__( self, - name = name, - data = np.empty(( - 6, - braggvectors.Rshape[0], - braggvectors.Rshape[1], - )), - slicelabels = [ - 'exx', - 'eyy', - 'exy', - 'theta', - 'mask', - 'error' - ] + name=name, + data=np.empty( + ( + 6, + braggvectors.Rshape[0], + braggvectors.Rshape[1], + ) + ), + slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], ) # set up braggvectors + # this assigns the bvs, ensures the origin is calibrated, + # and adds the strainmap to the bvs' tree self.braggvectors = braggvectors - # TODO - how to handle changes to braggvectors - # option: register with calibrations and add a .calibrate method - # which {{does something}} when origin changes - # TODO - include ellipse cal or no? - - assert(self.root is not None) # initialize as Data - Data.__init__( - self, - calibration = self.braggvectors.calibration - ) - + Data.__init__(self) + + # set calstate + # this property is used only to check to make sure that + # the braggvectors being used throughout a workflow are + # the same. The state of calibration of the vectors is noted + # here, and then checked each time the vectors are used - + # if they differ, an error message and instructions for + # re-calibration are issued + self.calstate = self.braggvectors.calstate + assert self.calstate["center"], "braggvectors must be centered" + # get the BVM + # a new BVM using the current calstate is computed + self.bvm = self.braggvectors.histogram(mode="cal") # braggvector properties @property def braggvectors(self): return self._braggvectors + @braggvectors.setter - def braggvectors(self,x): - assert(isinstance(x,BraggVectors)), f".braggvectors must be BraggVectors, not type {type(x)}" - assert(x.calibration.origin is not None), f"braggvectors must have a calibrated origin" + def braggvectors(self, x): + assert isinstance( + x, BraggVectors + ), f".braggvectors must be BraggVectors, not type {type(x)}" + assert ( + x.calibration.origin is not None + ), f"braggvectors must have a calibrated origin" self._braggvectors = x - self._braggvectors.tree(self,force=True) - + self._braggvectors.tree(self, force=True) + def reset_calstate(self): + """ + Resets the calibration state. This recomputes the BVM, and removes any computations + this StrainMap instance has stored, which will need to be recomputed. + """ + for attr in ( + "g0", + "g1", + "g2", + ): + if hasattr(self, attr): + delattr(self, attr) + self.calstate = self.braggvectors.calstate + pass # Class methods @@ -81,10 +101,8 @@ def choose_lattice_vectors( index_g0, index_g1, index_g2, - mode = 'centered', - plot = True, - subpixel = 'multicorr', - upsample_factor = 16, + subpixel="multicorr", + upsample_factor=16, sigma=0, minAbsoluteIntensity=0, minRelativeIntensity=0, @@ -92,95 +110,492 @@ def choose_lattice_vectors( minSpacing=0, edgeBoundary=1, maxNumPeaks=10, - bvm_vis_params = {}, - returncalc = False, - ): + figsize=(12, 6), + c_indices="lightblue", + c0="g", + c1="r", + c2="r", + c_vectors="r", + c_vectorlabels="w", + size_indices=20, + width_vectors=1, + size_vectorlabels=20, + vis_params={}, + returncalc=False, + returnfig=False, + ): """ Choose which lattice vectors to use for strain mapping. - Args: - index_g0 (int): origin - index_g1 (int): second point of vector 1 - index_g2 (int): second point of vector 2 - mode (str): centered or raw bragg map - plot (bool): plot bragg vector maps and vectors - subpixel (str): specifies the subpixel resolution algorithm to use. - must be in ('pixel','poly','multicorr'), which correspond - to pixel resolution, subpixel resolution by fitting a - parabola, and subpixel resultion by Fourier upsampling. - upsample_factor: the upsampling factor for the 'multicorr' - algorithm - sigma: if >0, applies a gaussian filter - maxNumPeaks: the maximum number of maxima to return - minAbsoluteIntensity, minRelativeIntensity, relativeToPeak, - minSpacing, edgeBoundary, maxNumPeaks: filtering applied - after maximum detection and before subpixel refinement + Overlays the bvm with the points detected via local 2D + maxima detection, plus an index for each point. User selects + 3 points using the overlaid indices, which are identified as + the origin and the termini of the lattice vectors g1 and g2. + + Parameters + ---------- + index_g0 : int + selected index for the origin + index_g1 : int + selected index for g1 + index_g2 :int + selected index for g2 + subpixel : str in ('pixel','poly','multicorr') + See the docstring for py4DSTEM.preprocess.get_maxima_2D + upsample_factor : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + sigma : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minAbsoluteIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minRelativeIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + relativeToPeak : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + minSpacing : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + edgeBoundary : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + maxNumPeaks : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + figsize : 2-tuple + the size of the figure + c_indices : color + color of the maxima + c0 : color + color of the origin + c1 : color + color of g1 point + c2 : color + color of g2 point + c_vectors : color + color of the g1/g2 vectors + c_vectorlabels : color + color of the vector labels + size_indices : number + size of the indices + width_vectors : number + width of the vectors + size_vectorlabels : number + size of the vector labels + vis_params : dict + additional visualization parameters passed to `show` + returncalc : bool + toggles returning the answer + returnfig : bool + toggles returning the figure + + Returns + ------- + (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter """ - from py4DSTEM.process.utils import get_maxima_2D - - if mode == "centered": - bvm = self.bvm_centered - else: - bvm = self.bvm_raw - + # validate inputs + for i in (index_g0, index_g1, index_g2): + assert isinstance(i, (int, np.integer)), "indices must be integers!" + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + # find the maxima g = get_maxima_2D( - bvm, - subpixel = subpixel, - upsample_factor = upsample_factor, - sigma = sigma, - minAbsoluteIntensity = minAbsoluteIntensity, - minRelativeIntensity = minRelativeIntensity, - relativeToPeak = relativeToPeak, - minSpacing = minSpacing, - edgeBoundary = edgeBoundary, - maxNumPeaks = maxNumPeaks, + self.bvm.data, + subpixel=subpixel, + upsample_factor=upsample_factor, + sigma=sigma, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, ) + # get the lattice vectors + gx, gy = g["x"], g["y"] + g0 = gx[index_g0], gy[index_g0] + g1x = gx[index_g1] - g0[0] + g1y = gy[index_g1] - g0[1] + g2x = gx[index_g2] - g0[0] + g2y = gy[index_g2] - g0[1] + g1, g2 = (g1x, g1y), (g2x, g2y) + + # make the figure + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + show(self.bvm.data, figax=(fig, ax1), **vis_params) + show(self.bvm.data, figax=(fig, ax2), **vis_params) + + # Add indices to left panel + d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} + d0 = { + "x": gx[index_g0], + "y": gy[index_g0], + "size": size_indices, + "color": c0, + "fontweight": "bold", + "labels": [str(index_g0)], + } + d1 = { + "x": gx[index_g1], + "y": gy[index_g1], + "size": size_indices, + "color": c1, + "fontweight": "bold", + "labels": [str(index_g1)], + } + d2 = { + "x": gx[index_g2], + "y": gy[index_g2], + "size": size_indices, + "color": c2, + "fontweight": "bold", + "labels": [str(index_g2)], + } + add_pointlabels(ax1, d) + add_pointlabels(ax1, d0) + add_pointlabels(ax1, d1) + add_pointlabels(ax1, d2) + + # Add vectors to right panel + dg1 = { + "x0": gx[index_g0], + "y0": gy[index_g0], + "vx": g1[0], + "vy": g1[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_1$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + dg2 = { + "x0": gx[index_g0], + "y0": gy[index_g0], + "vx": g2[0], + "vy": g2[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_2$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + add_vector(ax2, dg1) + add_vector(ax2, dg2) + + # store vectors self.g = g + self.g0 = g0 + self.g1 = g1 + self.g2 = g2 + + # return + if returncalc and returnfig: + return (g0, g1, g2), (fig, (ax1, ax2)) + elif returncalc: + return (g0, g1, g2) + elif returnfig: + return (fig, (ax1, ax2)) + else: + return + + def fit_lattice_vectors( + self, + x0=None, + y0=None, + max_peak_spacing=2, + mask=None, + plot=True, + vis_params={}, + returncalc=False, + ): + """ + From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of + lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the + reciprocal lattice directions. + + Args: + x0 : floagt + x-coord of origin + y0 : float + y-coord of origin + max_peak_spacing: float + Maximum distance from the ideal lattice points + to include a peak for indexing + mask: bool + Boolean mask, same shape as the pointlistarray, indicating which + locations should be indexed. This can be used to index different regions of + the scan with different lattices + plot:bool + plot results if tru + vis_params : dict + additional visualization parameters passed to `show` + returncalc : bool + if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map + """ + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - from py4DSTEM.visualize import select_lattice_vectors - g1,g2 = select_lattice_vectors( - bvm, - gx = g['x'], - gy = g['y'], - i0 = index_g0, - i1 = index_g1, - i2 = index_g2, - **bvm_vis_params, + if x0 is None: + x0 = self.braggvectors.Qshape[0] / 2 + if y0 is None: + y0 = self.braggvectors.Qshape[0] / 2 + + # index braggvectors + from py4DSTEM.process.latticevectors import index_bragg_directions + + _, _, braggdirections = index_bragg_directions( + x0, y0, self.g["x"], self.g["y"], self.g1, self.g2 ) - self.g1 = g1 - self.g2 = g2 + self.braggdirections = braggdirections + + if plot: + self.show_bragg_indexing( + self.bvm, + bragg_directions=braggdirections, + points=True, + **vis_params, + ) + + # add indicies to braggvectors + from py4DSTEM.process.latticevectors import add_indices_to_braggvectors + + bragg_vectors_indexed = add_indices_to_braggvectors( + self.braggvectors, + self.braggdirections, + maxPeakSpacing=max_peak_spacing, + qx_shift=self.braggvectors.Qshape[0] / 2, + qy_shift=self.braggvectors.Qshape[1] / 2, + mask=mask, + ) + + self.bragg_vectors_indexed = bragg_vectors_indexed + + # fit bragg vectors + from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs + + g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) + self.g1g2_map = g1g2_map if returncalc: - return g1, g2 + braggdirections, bragg_vectors_indexed, g1g2_map + + def get_strain( + self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs + ): + """ + mask: nd.array (bool) + Use lattice vectors from g1g2_map scan positions + wherever mask==True. If mask is None gets median strain + map from entire field of view. If mask is not None, gets + reference g1 and g2 from region and then calculates strain. + g_reference: nd.array of form [x,y] + G_reference (tupe): reference coordinate system for + xaxis_x and xaxis_y + flip_theta: bool + If True, flips rotation coordinate system + returncal: bool + It True, returns rotated map + """ + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + if mask is None: + mask = np.ones(self.g1g2_map.shape, dtype="bool") + from py4DSTEM.process.latticevectors import get_strain_from_reference_region + strainmap_g1g2 = get_strain_from_reference_region( + self.g1g2_map, + mask=mask, + ) + else: + from py4DSTEM.process.latticevectors import get_reference_g1g2 + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 + strainmap_g1g2 = get_strain_from_reference_g1g2( + self.g1g2_map, g1_ref, g2_ref + ) + self.strainmap_g1g2 = strainmap_g1g2 + if g_reference is None: + g_reference = np.subtract(self.g1, self.g2) + from py4DSTEM.process.latticevectors import get_rotated_strain_map + strainmap_rotated = get_rotated_strain_map( + self.strainmap_g1g2, + xaxis_x=g_reference[0], + xaxis_y=g_reference[1], + flip_theta=flip_theta, + ) - # IO methods + self.strainmap_rotated = strainmap_rotated + + from py4DSTEM.visualize import show_strain + + figsize = kwargs.pop("figsize", (14, 4)) + vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) + vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) + ticknumber = kwargs.pop("ticknumber", 3) + bkgrd = kwargs.pop("bkgrd", False) + axes_plots = kwargs.pop("axes_plots", ()) + + fig, ax = show_strain( + self.strainmap_rotated, + vrange_exx=vrange_exx, + vrange_theta=vrange_theta, + ticknumber=ticknumber, + axes_plots=axes_plots, + bkgrd=bkgrd, + figsize=figsize, + **kwargs, + returnfig=True, + ) + + if not np.all(mask == True): + ax[0][0].imshow(mask, alpha=0.2, cmap="binary") + ax[0][1].imshow(mask, alpha=0.2, cmap="binary") + ax[1][0].imshow(mask, alpha=0.2, cmap="binary") + ax[1][1].imshow(mask, alpha=0.2, cmap="binary") + + if returncalc: + return self.strainmap_rotated + + def show_lattice_vectors( + ar, + x0, + y0, + g1, + g2, + color="r", + width=1, + labelsize=20, + labelcolor="w", + returnfig=False, + **kwargs, + ): + """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" + fig, ax = show(ar, returnfig=True, **kwargs) + + # Add vectors + dg1 = { + "x0": x0, + "y0": y0, + "vx": g1[0], + "vy": g1[1], + "width": width, + "color": color, + "label": r"$g_1$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + dg2 = { + "x0": x0, + "y0": y0, + "vx": g2[0], + "vy": g2[1], + "width": width, + "color": color, + "label": r"$g_2$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + add_vector(ax, dg1) + add_vector(ax, dg2) + + if returnfig: + return fig, ax + else: + plt.show() + return + + def show_bragg_indexing( + self, + ar, + bragg_directions, + voffset=5, + hoffset=0, + color="w", + size=20, + points=True, + pointcolor="r", + pointsize=50, + returnfig=False, + **kwargs, + ): + """ + Shows an array with an overlay describing the Bragg directions + + Accepts: + ar (arrray) the image + bragg_directions (PointList) the bragg scattering directions; must have coordinates + 'qx','qy','h', and 'k'. Optionally may also have 'l'. + """ + assert isinstance(bragg_directions, PointList) + for k in ("qx", "qy", "h", "k"): + assert k in bragg_directions.data.dtype.fields + + fig, ax = show(ar, returnfig=True, **kwargs) + d = { + "bragg_directions": bragg_directions, + "voffset": voffset, + "hoffset": hoffset, + "color": color, + "size": size, + "points": points, + "pointsize": pointsize, + "pointcolor": pointcolor, + } + add_bragg_index_labels(ax, d) - # TODO - copy method + if returnfig: + return fig, ax + else: + plt.show() + return + + def copy(self, name=None): + name = name if name is not None else self.name + "_copy" + strainmap_copy = StrainMap(self.braggvectors) + for attr in ( + "g", + "g0", + "g1", + "g2", + "calstate", + "bragg_directions", + "bragg_vectors_indexed", + "g1g2_map", + "strainmap_g1g2", + "strainmap_rotated", + ): + if hasattr(self, attr): + setattr(strainmap_copy, attr, getattr(self, attr)) + + for k in self.metadata.keys(): + strainmap_copy.metadata = self.metadata[k].copy() + return strainmap_copy + + # IO methods # read @classmethod - def _get_constructor_args(cls,group): + def _get_constructor_args(cls, group): """ Returns a dictionary of args/values to pass to the class constructor """ ar_constr_args = RealSlice._get_constructor_args(group) args = { - 'data' : ar_constr_args['data'], - 'name' : ar_constr_args['name'], + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], } return args - - - diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 1df4e78c5..86257b4dc 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -59,7 +59,7 @@ def radial_reduction( def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, figsize=(10, 10), scale=None): fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(img, interpolation='nearest', cmap=plt.cm.get_cmap(cmap), vmax=vmax) + im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax) divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) plt.colorbar(im, cax=cax) @@ -636,7 +636,7 @@ def fourier_resample( #def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, # figsize=(10, 10), scale=None): # fig, ax = plt.subplots(figsize=figsize) -# im = ax.imshow(img, interpolation='nearest', cmap=plt.cm.get_cmap(cmap), vmax=vmax) +# im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax) # divider = make_axes_locatable(ax) # cax = divider.append_axes("right", size="5%", pad=0.05) # plt.colorbar(im, cax=cax) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 4009e43c9..9df5075b8 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1,2 +1,2 @@ -__version__='0.14.2' +__version__='0.14.3' diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index e0c87a427..7e7147a15 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -437,7 +437,7 @@ def add_bragg_index_labels(ax,d): Adds labels for indexed bragg directions to a plot, using the parameters in dict d. The dictionary d has required and optional parameters as follows: - braggdirections (req'd) (PointList) the Bragg directions. This PointList must have + bragg_directions (req'd) (PointList) the Bragg directions. This PointList must have the fields 'qx','qy','h', and 'k', and may optionally have 'l' voffset (number) vertical offset for the labels hoffset (number) horizontal offset for the labels @@ -450,12 +450,12 @@ def add_bragg_index_labels(ax,d): # handle inputs assert isinstance(ax,Axes) # bragg directions - assert('braggdirections' in d.keys()) - braggdirections = d['braggdirections'] - assert isinstance(braggdirections,PointList) + assert('bragg_directions' in d.keys()) + bragg_directions = d['bragg_directions'] + assert isinstance(bragg_directions,PointList) for k in ('qx','qy','h','k'): - assert k in braggdirections.data.dtype.fields - include_l = True if 'l' in braggdirections.data.dtype.fields else False + assert k in bragg_directions.data.dtype.fields + include_l = True if 'l' in bragg_directions.data.dtype.fields else False # offsets hoffset = d['hoffset'] if 'hoffset' in d.keys() else 0 voffset = d['voffset'] if 'voffset' in d.keys() else 5 @@ -474,20 +474,20 @@ def add_bragg_index_labels(ax,d): # add the points if points: - ax.scatter(braggdirections.data['qy'],braggdirections.data['qx'], + ax.scatter(bragg_directions.data['qy'],bragg_directions.data['qx'], color=pointcolor,s=pointsize) # add index labels - for i in range(braggdirections.length): - x,y = braggdirections.data['qx'][i],braggdirections.data['qy'][i] + for i in range(bragg_directions.length): + x,y = bragg_directions.data['qx'][i],bragg_directions.data['qy'][i] x -= voffset y += hoffset - h,k = braggdirections.data['h'][i],braggdirections.data['k'][i] + h,k = bragg_directions.data['h'][i],bragg_directions.data['k'][i] h = str(h) if h>=0 else r'$\overline{{{}}}$'.format(np.abs(h)) k = str(k) if k>=0 else r'$\overline{{{}}}$'.format(np.abs(k)) s = h+','+k if include_l: - l = braggdirections.data['l'][i] + l = bragg_directions.data['l'][i] l = str(l) if l>=0 else r'$\overline{{{}}}$'.format(np.abs(l)) s += l ax.text(y,x,s,color=color,size=size,ha='center',va='bottom') diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index f63bda993..3b9d99e43 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -567,12 +567,8 @@ def show( ax.matshow(mask_display,cmap=cmap,alpha=mask_alpha,vmin=vmin,vmax=vmax) # ...or, plot its histogram else: - # hist,bin_edges = np.histogram( - # _ar, - # bins=np.linspace(np.min(_ar),np.max(_ar),num=n_bins)) - hist,bin_edges = np.histogram( - _ar, - bins=np.linspace(vmin,vmax,num=n_bins)) + hist,bin_edges = np.histogram(_ar,bins=np.linspace(np.min(_ar), + np.max(_ar),num=n_bins)) w = bin_edges[1]-bin_edges[0] x = bin_edges[:-1]+w/2. ax.bar(x,hist,width=w) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index b487048e2..43cf7fff8 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -590,102 +590,6 @@ def select_point(ar,x,y,i,color='lightblue',color_selected='r',size=20,returnfig return -def select_lattice_vectors(ar,gx,gy,i0,i1,i2, - c_indices='lightblue',c0='g',c1='r',c2='r',c_vectors='r',c_vectorlabels='w', - size_indices=20,width_vectors=1,size_vectorlabels=20, - figsize=(12,6),returnfig=False,**kwargs): - """ - This function accepts a set of reciprocal lattice points (gx,gy) and three indices - (i0,i1,i2). Using those indices as, respectively, the origin, the endpoint of g1, and - the endpoint of g2, this function computes the basis lattice vectors g1,g2, visualizes - them, and returns them. To compute these vectors without visualizing, use - latticevectors.get_selected_lattice_vectors(). - - Returns: - if returnfig==False: g1,g2 - if returnfig==True g1,g2,fig,ax - """ - from py4DSTEM.process.latticevectors import get_selected_lattice_vectors - - # Make the figure - fig,(ax1,ax2) = plt.subplots(1,2,figsize=figsize) - show(ar,figax=(fig,ax1),**kwargs) - show(ar,figax=(fig,ax2),**kwargs) - - # Add indices to left panel - d = {'x':gx,'y':gy,'size':size_indices,'color':c_indices} - d0 = {'x':gx[i0],'y':gy[i0],'size':size_indices,'color':c0,'fontweight':'bold','labels':[str(i0)]} - d1 = {'x':gx[i1],'y':gy[i1],'size':size_indices,'color':c1,'fontweight':'bold','labels':[str(i1)]} - d2 = {'x':gx[i2],'y':gy[i2],'size':size_indices,'color':c2,'fontweight':'bold','labels':[str(i2)]} - add_pointlabels(ax1,d) - add_pointlabels(ax1,d0) - add_pointlabels(ax1,d1) - add_pointlabels(ax1,d2) - - # Compute vectors - g1,g2 = get_selected_lattice_vectors(gx,gy,i0,i1,i2) - - # Add vectors to right panel - dg1 = {'x0':gx[i0],'y0':gy[i0],'vx':g1[0],'vy':g1[1],'width':width_vectors, - 'color':c_vectors,'label':r'$g_1$','labelsize':size_vectorlabels,'labelcolor':c_vectorlabels} - dg2 = {'x0':gx[i0],'y0':gy[i0],'vx':g2[0],'vy':g2[1],'width':width_vectors, - 'color':c_vectors,'label':r'$g_2$','labelsize':size_vectorlabels,'labelcolor':c_vectorlabels} - add_vector(ax2,dg1) - add_vector(ax2,dg2) - - if returnfig: - return g1,g2,fig,(ax1,ax2) - else: - plt.show() - return g1,g2 - - -def show_lattice_vectors(ar,x0,y0,g1,g2,color='r',width=1,labelsize=20,labelcolor='w',returnfig=False,**kwargs): - """ Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy). - """ - fig,ax = show(ar,returnfig=True,**kwargs) - - # Add vectors - dg1 = {'x0':x0,'y0':y0,'vx':g1[0],'vy':g1[1],'width':width, - 'color':color,'label':r'$g_1$','labelsize':labelsize,'labelcolor':labelcolor} - dg2 = {'x0':x0,'y0':y0,'vx':g2[0],'vy':g2[1],'width':width, - 'color':color,'label':r'$g_2$','labelsize':labelsize,'labelcolor':labelcolor} - add_vector(ax,dg1) - add_vector(ax,dg2) - - if returnfig: - return fig,ax - else: - plt.show() - return - - -def show_bragg_indexing(ar,braggdirections,voffset=5,hoffset=0,color='w',size=20, - points=True,pointcolor='r',pointsize=50,returnfig=False,**kwargs): - """ - Shows an array with an overlay describing the Bragg directions - - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. - """ - assert isinstance(braggdirections,PointList) - for k in ('qx','qy','h','k'): - assert k in braggdirections.data.dtype.fields - - fig,ax = show(ar,returnfig=True,**kwargs) - d = {'braggdirections':braggdirections,'voffset':voffset,'hoffset':hoffset,'color':color, - 'size':size,'points':points,'pointsize':pointsize,'pointcolor':pointcolor} - add_bragg_index_labels(ax,d) - - if returnfig: - return fig,ax - else: - plt.show() - return - - def show_max_peak_spacing(ar,spacing,braggdirections,color='g',lw=2,returnfig=False,**kwargs): """ Show a circle of radius `spacing` about each Bragg direction """ diff --git a/setup.py b/setup.py index cb9da8169..b0c7fa081 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ 'numpy >= 1.19', 'scipy >= 1.5.2', 'h5py >= 3.2.0', + 'hdf5plugin >= 4.1.3', 'ncempy >= 1.8.1', 'matplotlib >= 3.2.2', 'scikit-image >= 0.17.2', diff --git a/test/gettestdata.py b/test/gettestdata.py index b3d8a0a40..a84e5b9b3 100644 --- a/test/gettestdata.py +++ b/test/gettestdata.py @@ -53,23 +53,24 @@ # Set data collection key if args.data == 'tutorials': - data = 'tutorials' + data = ['tutorials'] elif args.data == 'io': - data = 'test_io' + data = ['test_io','test_arina'] elif args.data == 'basic': - data = 'small_datacube' + data = ['small_datacube'] elif args.data == 'strain': - data = 'strain' + data = ['strain'] else: raise Exception(f"invalid data choice, {parser.data}") # Download data -download( - data, - destination = testpath, - overwrite = args.overwrite, - verbose = args.verbose -) +for d in data: + download( + d, + destination = testpath, + overwrite = args.overwrite, + verbose = args.verbose + ) # Always download the basic datacube if args.data != 'basic': diff --git a/test/test_nonnative_io/test_arina.py b/test/test_nonnative_io/test_arina.py new file mode 100644 index 000000000..c27cb8ef5 --- /dev/null +++ b/test/test_nonnative_io/test_arina.py @@ -0,0 +1,19 @@ +import py4DSTEM +import emdfile +from os.path import join + + +# Set filepaths +filepath = join(py4DSTEM._TESTPATH, "test_arina/STO_STEM_bench_20us_master.h5") + + +def test_read_arina(): + + # read + data = py4DSTEM.import_file( filepath ) + + # check imported data + assert isinstance(data, emdfile.Array) + assert isinstance(data, py4DSTEM.DataCube) + + diff --git a/test/test_strain.py b/test/test_strain.py index 5bfa0efd3..bc9b8b58c 100644 --- a/test/test_strain.py +++ b/test/test_strain.py @@ -27,5 +27,7 @@ def test_strainmap_instantiation(self): ) assert(isinstance(strainmap, StrainMap)) + assert(strainmap.calibration is not None) + assert(strainmap.calibration is strainmap.braggvectors.calibration) From 4e1be96c7f58d267abf7c85ffd83d82687f40fb6 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 9 Aug 2023 16:31:53 -0700 Subject: [PATCH 009/176] Fixing merge conflicts --- py4DSTEM/process/polar/polar_analysis.py | 30 +----------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index fa6a40a4f..447baaf1d 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -6,17 +6,10 @@ from emdfile import tqdmnd -<<<<<<< Updated upstream -def calculate_FEM_global( - self, - use_median_local = False, - use_median_global = False, -======= def calculate_radial_statistics( self, median_local = False, median_global = False, ->>>>>>> Stashed changes plot_results = False, figsize = (8,4), returnval = False, @@ -49,37 +42,22 @@ def calculate_radial_statistics( self.scattering_vector = self.radial_bins * self.qstep * self.calibration.get_Q_pixel_size() self.scattering_vector_units = self.calibration.get_Q_pixel_units() -<<<<<<< Updated upstream - # init radial data array -======= # init radial data arrays ->>>>>>> Stashed changes self.radial_all = np.zeros(( self._datacube.shape[0], self._datacube.shape[1], self.polar_shape[1], )) -<<<<<<< Updated upstream -======= self.radial_all_std = np.zeros(( self._datacube.shape[0], self._datacube.shape[1], self.polar_shape[1], )) ->>>>>>> Stashed changes - # Compute the radial mean for each probe position for rx, ry in tqdmnd( self._datacube.shape[0], self._datacube.shape[1], -<<<<<<< Updated upstream - desc="Global FEM", - unit=" probe positions", - disable=not progress_bar): - - self.radial_all[rx,ry] = np.mean(self.data[rx,ry],axis=0) -======= desc="Radial statistics", unit=" probe positions", disable=not progress_bar): @@ -90,7 +68,6 @@ def calculate_radial_statistics( self.radial_all_std[rx,ry] = np.sqrt(np.mean( (self.data[rx,ry] - self.radial_all[rx,ry][None])**2, axis=0)) ->>>>>>> Stashed changes self.radial_avg = np.mean(self.radial_all, axis=(0,1)) self.radial_var = np.mean( @@ -171,9 +148,7 @@ def calculate_FEM_local( """ -<<<<<<< Updated upstream - 1+1 -======= + pass @@ -183,6 +158,3 @@ def calculate_FEM_local( # returnfig = False, # ): - ->>>>>>> Stashed changes - From 46effbb395e253e573e35d0b3cfa3165dac492a9 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 9 Aug 2023 16:32:31 -0700 Subject: [PATCH 010/176] Merge conflict fixing --- py4DSTEM/process/polar/polar_datacube.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index b6ef4ee66..c0b8871f9 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -95,12 +95,8 @@ def __init__( pass from py4DSTEM.process.polar.polar_analysis import ( -<<<<<<< Updated upstream - calculate_FEM_global, -======= # calculate_FEM_global, calculate_radial_statistics, ->>>>>>> Stashed changes plot_FEM_global, calculate_FEM_local, ) From 68b31a7c8dd0700c611eff6c56232558fe2e9514 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 16 Aug 2023 17:40:21 -0700 Subject: [PATCH 011/176] single slice tv denoise --- .../iterative_ptychographic_constraints.py | 81 ++++++++++++++++--- .../iterative_singleslice_ptychography.py | 12 +++ setup.py | 1 + 3 files changed, 81 insertions(+), 13 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 9af22ba92..6ae9e176d 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,4 +1,5 @@ import numpy as np +import pylops from py4DSTEM.process.phase.utils import ( array_slice, estimate_global_transformation_ransac, @@ -7,6 +8,7 @@ regularize_probe_amplitude, ) from py4DSTEM.process.utils import get_CoM +import warnings class PtychographicConstraints: @@ -183,6 +185,59 @@ def _object_butterworth_constraint( return current_object + def _object_denoise_tv_pylops(self, current_object, weight): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float, optional + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("tv_denoise currently for potential objects only"), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = 40 + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + def _object_denoise_tv_chambolle( self, current_object, @@ -363,8 +418,8 @@ def _probe_amplitude_constraint( xp = self._xp erf = self._erf - probe_intensity = xp.abs(current_probe) ** 2 - #current_probe_sum = xp.sum(probe_intensity) + # probe_intensity = xp.abs(current_probe) ** 2 + # current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] Y = xp.fft.fftfreq(current_probe.shape[1])[None] @@ -374,10 +429,10 @@ def _probe_amplitude_constraint( tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe #* normalization + return updated_probe # * normalization def _probe_fourier_amplitude_constraint( self, @@ -406,7 +461,7 @@ def _probe_fourier_amplitude_constraint( xp = self._xp asnumpy = self._asnumpy - #current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) updated_probe_fft, _, _, _ = regularize_probe_amplitude( @@ -419,10 +474,10 @@ def _probe_fourier_amplitude_constraint( updated_probe_fft = xp.asarray(updated_probe_fft) updated_probe = xp.fft.ifft2(updated_probe_fft) - #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe #* normalization + return updated_probe # * normalization def _probe_aperture_constraint( self, @@ -444,16 +499,16 @@ def _probe_aperture_constraint( """ xp = self._xp - #current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) updated_probe = xp.fft.ifft2( xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture ) - #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe #* normalization + return updated_probe # * normalization def _probe_aberration_fitting_constraint( self, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0480bae8a..0c9af9649 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1023,6 +1023,8 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, object_positivity, shrinkage_rad, object_mask, @@ -1108,6 +1110,12 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weight, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1198,6 +1206,8 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1618,6 +1628,8 @@ def reconstruct( q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse diff --git a/setup.py b/setup.py index b0c7fa081..d8baff354 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ 'dask >= 2.3.0', 'distributed >= 2.3.0', 'emdfile >= 0.0.10', + 'pylops >= 2.1.0' ], extras_require={ 'ipyparallel': ['ipyparallel >= 6.2.4', 'dill >= 0.3.3'], From 5acdd5809e51910c6f1a4242248eb4f12c4be5a7 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 17 Aug 2023 14:56:10 -0700 Subject: [PATCH 012/176] multislice tv denoise... will test more before adding to other classes... --- .../iterative_multislice_ptychography.py | 119 +++++++++++-- .../iterative_ptychographic_constraints.py | 163 ++++++++++-------- .../iterative_singleslice_ptychography.py | 8 + 3 files changed, 197 insertions(+), 93 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 438c9d1fb..5966ca07a 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex @@ -1450,6 +1451,70 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_denoise_tv_pylops(self, current_object, weights): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("tv_denoise currently for potential objects only"), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = 40 + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + # remove padding + + return current_object_tv[1:-1] + def _constraints( self, current_object, @@ -1482,9 +1547,11 @@ def _constraints( shrinkage_rad, object_mask, pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, tv_denoise, - tv_denoise_weight, - tv_denoise_pad, + tv_denoise_weights, ): """ Ptychographic constraints operator. @@ -1549,12 +1616,17 @@ def _constraints( If not None, used to calculate additional shrinkage using masked-mean of object pure_phase_object: bool If True, object amplitude is set to unity - tv_denoise: bool + tv_denoise_chambolle: bool If True, performs TV denoising along z - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. Returns -------- @@ -1586,13 +1658,16 @@ def _constraints( current_object, kz_regularization_gamma ) elif tv_denoise: - if self._object_type == "complex": - raise NotImplementedError() + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + ) + elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, - tv_denoise_weight, + tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad, + pad_object=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1691,9 +1766,11 @@ def reconstruct( shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, - tv_denoise_weight=None, - tv_denoise_pad=True, + tv_denoise_weights=None, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1786,12 +1863,17 @@ def reconstruct( If true, the potential mean outside the FOV is forced to zero at each iteration pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - tv_denoise_iter: bool + tv_denoise_iter_chambolle: bool Number of iterations with TV denoisining - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -2134,9 +2216,12 @@ def reconstruct( else None, pure_phase_object=a0 < pure_phase_object_iter and self._object_type == "complex", - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_pad=tv_denoise_pad, + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 6ae9e176d..e300e1154 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -193,7 +193,7 @@ def _object_denoise_tv_pylops(self, current_object, weight): ---------- current_object: np.ndarray Current object estimate - weight : float, optional + weight : float Denoising weight. The greater `weight`, the more denoising (at the expense of fidelity to `input`). Returns @@ -284,90 +284,101 @@ def _object_denoise_tv_chambolle( Adapted skimage.restoration.denoise_tv_chambolle. """ xp = self._xp - - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) + if xp.iscomplexobj(current_object): + updated_object = current_object + warnings.warn( + ("tv_denoise currently for potential objects only"), + UserWarning, + ) else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" + + current_object_sum = xp.sum(current_object) + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if pad_object: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (1, 1) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() - p = xp.zeros( - (current_object.ndim,) + current_object.shape, dtype=current_object.dtype - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ + # g stores the gradients of updated_object along each axis + # e.g. g[0] is the first order finite difference along axis 0 + slices_g = [ slice(None), ] * (current_object.ndim + 1) for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] - E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] + E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E E_previous = E - i += 1 + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] + if pad_object: + for ax in range(len(ndim)): + slices = array_slice(ndim[ax], current_object.ndim, 1, -1) + updated_object = updated_object[slices] + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) - return updated_object / xp.sum(updated_object) * current_object_sum + return updated_object def _probe_center_of_mass_constraint(self, current_probe): """ diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0c9af9649..2480974f3 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1080,6 +1080,10 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1294,6 +1298,10 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float From 1fbfd3bb10dccf812d938fffa089caa556710bea Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 11:30:15 +1000 Subject: [PATCH 013/176] RDF working --- py4DSTEM/process/polar/polar_analysis.py | 346 ++++++++++++++++++++++- py4DSTEM/process/polar/polar_datacube.py | 4 +- 2 files changed, 338 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 447baaf1d..38068e578 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -2,6 +2,8 @@ import numpy as np import matplotlib.pyplot as plt +from scipy.optimize import curve_fit + from emdfile import tqdmnd @@ -69,11 +71,14 @@ def calculate_radial_statistics( (self.data[rx,ry] - self.radial_all[rx,ry][None])**2, axis=0)) - self.radial_avg = np.mean(self.radial_all, axis=(0,1)) + self.radial_mean = np.mean(self.radial_all, axis=(0,1)) self.radial_var = np.mean( - (self.radial_all - self.radial_avg[None,None])**2, + (self.radial_all - self.radial_mean[None,None])**2, axis=(0,1)) - self.radial_var_norm = self.radial_var / self.radial_avg**2 + + self.radial_var_norm = self.radial_var + sub = self.radial_mean > 0.0 + self.radial_var_norm[sub] /= self.radial_mean[sub]**2 # plot results if plot_results: @@ -92,16 +97,49 @@ def calculate_radial_statistics( # Return values if returnval: if returnfig: - return self.radial_avg, self.radial_var, fig, ax + return self.radial_mean, self.radial_var, fig, ax else: - return self.radial_avg, self.radial_var + return self.radial_mean, self.radial_var else: if returnfig: return fig, ax else: pass -def plot_FEM_global( + +def plot_radial_mean( + self, + log_x = False, + log_y = False, + figsize = (8,4), + returnfig = False, + ): + """ + Plot radial mean + """ + fig,ax = plt.subplots(figsize=figsize) + ax.plot( + self.scattering_vector, + self.radial_mean, + ) + + if log_x: + ax.set_xscale('log') + if log_y: + ax.set_yscale('log') + + ax.set_xlabel('Scattering Vector (' + self.scattering_vector_units + ')') + ax.set_ylabel('Radial Mean') + if log_x and self.scattering_vector[0] == 0.0: + ax.set_xlim((self.scattering_vector[1],self.scattering_vector[-1])) + else: + ax.set_xlim((self.scattering_vector[0],self.scattering_vector[-1])) + + if returnfig: + return fig, ax + + +def plot_radial_var_norm( self, figsize = (8,4), returnfig = False, @@ -123,6 +161,279 @@ def plot_FEM_global( return fig, ax +def calculate_pair_dist_function( + self, + k_min = 0.05, + k_max = None, + k_width = 0.25, + # k_pad_max = 10.0, + r_min = 0.0, + r_max = 20.0, + r_step = 0.02, + # iterative_pdf_refine = True, + num_iter = 10, + plot_fits = False, + plot_sf_estimate = False, + plot_pdf = True, + figsize = (8,4), + maxfev = None, + ): + """ + Calculate the pair distribution function (PDF). + + """ + + # init + k = self.scattering_vector + dk = k[1] - k[0] + k2 = k**2 + Ik = self.radial_mean + int_mean = np.mean(Ik) + sub_fit = k >= k_min + + # initial coefs + const_bg = np.min(self.radial_mean) + int0 = np.median(self.radial_mean) - const_bg + sigma0 = np.mean(k) + coefs = [const_bg, int0, sigma0, int0, sigma0] + lb = [0,0,0,0,0] + ub = [np.inf, np.inf, np.inf, np.inf, np.inf] + # noise_est = 1/k + # noise_est = np.divide(1.0, k, out=np.zeros_like(k), where=k!=0) + noise_est = k[-1] - k + dk + + # Estimate the mean atomic form factor + background + if maxfev is None: + coefs = curve_fit( + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma = noise_est[sub_fit], + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + )[0] + else: + coefs = curve_fit( + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma = noise_est[sub_fit], + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + maxfev = maxfev, + )[0] + coefs[0] *= int_mean + coefs[1] *= int_mean + coefs[3] *= int_mean + + # Calculate the mean atomic form factor wthout any background + coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) + fk = scattering_model(k2, coefs_fk) + bg = scattering_model(k2, coefs) + + # mask for structure factor estimate + if k_max is None: + k_max = np.max(k) + mask = np.clip(np.minimum( + (k - k_min) / k_width, + (k_max - k) / k_width, + ),0,1) + mask = np.sin(mask*(np.pi/2))**2 + + # Estimate the reduced structure factor S(k) + Sk = (Ik - bg) * k / fk + mask_sum = np.sum(mask) + Sk = (Sk - np.sum(Sk*mask)/mask_sum) * mask + + # # pad or crop S(k) to 0 and k_pad_max + # k_pad = np.arange(0, k_pad_max, dk) + # Sk_pad = np.zeros_like(k_pad) + # ind_0 = np.argmin(np.abs(k_pad-k[0])) + # ind_max = ind_0 + k.size + # if ind_max > k_pad.size: + # Sk_pad[ind_0:] = Sk[ind_0:k_pad.size] + # else: + # Sk_pad[ind_0:ind_max] = Sk + + # Calculate the real space PDF + # dr = 1/(2*k_pad[-1]) + r = np.arange(r_min, r_max, r_step) + ra,ka = np.meshgrid(r,k) + pdf = (2/np.pi)*np.pi*dk*np.sum( + np.sin( + 2*np.pi*ra*ka + ) * Sk[:,None], + axis=0, + ) + + # invert + + ind_max = np.argmax(pdf) + r_ind_max = r[ind_max] + r_mask = np.minimum(r / r_ind_max, 1.0) + r_mask = np.sin(r_mask*np.pi/2)**2 + + Sk_back_proj = (2*r_step)*np.sum( + np.sin( + 2*np.pi*ra*ka + ) * pdf[None,:] * r_mask[None,:], + axis=1, + ) + + fig,ax = plt.subplots(figsize=figsize) + ax.plot( + k, + Sk, + color = 'k', + ) + ax.plot( + k, + Sk_back_proj, + color = 'r', + ) + + + + # # iterative refinement of the PDF + # if iterative_pdf_refine: + # # pdf = np.maximum(pdf + (r/r[-1]), 0.0) + + # ind_max = np.argmax(pdf) + # r_ind_max = r[ind_max] + # r_mask = np.minimum(r / r_ind_max, 1.0) + # r_mask = np.sin(r_mask*np.pi/2)**2 + + # pdf = np.maximum(pdf * r_mask + (r/r[-1]), 0.0) + # r_weight = r_mask * (1 - r / r[-1])**2 + + + + # # basis = np.vstack((np.ones_like(r),r)).T + # # coefs_lin = np.linalg.lstsq(basis, pdf, rcond=None)[0] + # # pdf_lin = basis * coefs_lin + # # print(coefs_lin) + + + # for a0 in range(10): + # Sk_back_proj = (1*np.pi/r.size)*np.sum( + # np.sin( + # 2*np.pi*ra*ka + # ) * pdf[None,:], + # axis=1, + # ) + + # Sk_diff = Sk - Sk_back_proj + # Sk_diff = (Sk_diff - np.mean(Sk_diff*mask)/mask_sum) * mask + + # pdf_update = 4*np.pi*dk*np.sum( + # np.sin( + # 8*np.pi*ra*ka + # ) * Sk_diff[:,None], + # axis=0, + # ) * r_weight + + # pdf = np.maximum(pdf + 0.5*pdf_update, 0.0) + + # fig,ax = plt.subplots(figsize=figsize) + # ax.plot( + # r, + # pdf, + # color = 'k', + # ) + # # ax.plot( + # # r, + # # pdf_lin, + # # color = 'r', + # # ) + + # # ax.plot( + # # r, + # # pdf + pdf_update, + # # color = 'r', + # # ) + + # # ax.plot( + # # k, + # # Sk, + # # color = 'k', + # # ) + # # ax.plot( + # # k, + # # Sk_back_proj, + # # color = 'r', + # # ) + # # ax.plot( + # # Sk_diff, + # # color = 'r', + # # ) + + + # Plots + if plot_fits: + fig,ax = plt.subplots(figsize=figsize) + ax.plot( + self.scattering_vector, + self.radial_mean, + color = 'k', + ) + ax.plot( + k, + np.ones(k.size)*coefs[0], + color = 'r', + ) + ax.plot( + k, + fk + coefs[0], + color = 'r', + ) + ax.set_xlabel('Scattering Vector (' + self.scattering_vector_units + ')') + ax.set_ylabel('Radial Mean') + ax.set_xlim((self.scattering_vector[0],self.scattering_vector[-1])) + ax.set_ylim((0,2e-5)) + ax.set_xlabel('Scattering Vector [A^-1]') + ax.set_ylabel('I(k) and Fit Estimates') + + + if plot_sf_estimate: + fig,ax = plt.subplots(figsize=figsize) + ax.plot( + k, + Sk, + color = 'r', + ) + yr = (np.min(Sk),np.max(Sk)) + ax.set_ylim(( + yr[0]-0.05*(yr[1]-yr[0]), + yr[1]+0.05*(yr[1]-yr[0]), + )) + ax.set_xlabel('Scattering Vector [A^-1]') + ax.set_ylabel('Structure Factor') + + if plot_pdf: + fig,ax = plt.subplots(figsize=figsize) + ax.plot( + r, + pdf, + color = 'r', + ) + ax.set_xlabel('Radius [A]') + ax.set_ylabel('Pair Distribution Function') + # r = (np.min(Sk),np.max(Sk)) + # ax.set_ylim(( + # r[0]-0.05*(r[1]-r[0]), + # r[1]+0.05*(r[1]-r[0]), + # )) + + + # ax.set_yscale('log') + + + + + + def calculate_FEM_local( self, figsize = (8,6), @@ -152,9 +463,22 @@ def calculate_FEM_local( pass -# def radial_average( -# self, -# figsize = (8,6), -# returnfig = False, -# ): + + +def scattering_model(k2, *coefs): + coefs = np.squeeze(np.array(coefs)) + + const_bg = coefs[0] + int0 = coefs[1] + sigma0 = coefs[2] + int1 = coefs[3] + sigma1 = coefs[4] + + + + int_model = const_bg + \ + int0*np.exp(k2/(-2*sigma0**2)) + \ + (int1*sigma1)**2/(k2 + sigma1**2) + + return int_model diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index c0b8871f9..cc3f534c6 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -97,7 +97,9 @@ def __init__( from py4DSTEM.process.polar.polar_analysis import ( # calculate_FEM_global, calculate_radial_statistics, - plot_FEM_global, + plot_radial_mean, + plot_radial_var_norm, + calculate_pair_dist_function, calculate_FEM_local, ) from py4DSTEM.process.polar.polar_peaks import ( From aa47b7078eb2d8b0814c04268290f9f35a0a5cc4 Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 13:49:26 +1000 Subject: [PATCH 014/176] Adding more options --- py4DSTEM/process/polar/polar_analysis.py | 143 +++++++++++++++++------ 1 file changed, 105 insertions(+), 38 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 38068e578..fc26408a4 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -3,7 +3,8 @@ import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit - +from scipy.special import comb +from scipy.ndimage import gaussian_filter from emdfile import tqdmnd @@ -166,12 +167,16 @@ def calculate_pair_dist_function( k_min = 0.05, k_max = None, k_width = 0.25, + k_lowpass = None, + k_highpass = None, # k_pad_max = 10.0, r_min = 0.0, r_max = 20.0, r_step = 0.02, + damp_origin_fluctuations = False, + # poly_background_order = 2, # iterative_pdf_refine = True, - num_iter = 10, + # num_iter = 10, plot_fits = False, plot_sf_estimate = False, plot_pdf = True, @@ -192,8 +197,8 @@ def calculate_pair_dist_function( sub_fit = k >= k_min # initial coefs - const_bg = np.min(self.radial_mean) - int0 = np.median(self.radial_mean) - const_bg + const_bg = np.min(self.radial_mean) / int_mean + int0 = np.median(self.radial_mean) / int_mean - const_bg sigma0 = np.mean(k) coefs = [const_bg, int0, sigma0, int0, sigma0] lb = [0,0,0,0,0] @@ -202,6 +207,11 @@ def calculate_pair_dist_function( # noise_est = np.divide(1.0, k, out=np.zeros_like(k), where=k!=0) noise_est = k[-1] - k + dk + # print( + # np.round(coefs[0],3), + # np.round(coefs[1],3), + # np.round(coefs[3],3)) + # Estimate the mean atomic form factor + background if maxfev is None: coefs = curve_fit( @@ -224,10 +234,16 @@ def calculate_pair_dist_function( bounds = (lb,ub), maxfev = maxfev, )[0] + + # print( + # np.round(coefs[0],3), + # np.round(coefs[1],3), + # np.round(coefs[3],3)) coefs[0] *= int_mean coefs[1] *= int_mean coefs[3] *= int_mean + # Calculate the mean atomic form factor wthout any background coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) fk = scattering_model(k2, coefs_fk) @@ -244,18 +260,23 @@ def calculate_pair_dist_function( # Estimate the reduced structure factor S(k) Sk = (Ik - bg) * k / fk + + # Masking edges of S(k) mask_sum = np.sum(mask) Sk = (Sk - np.sum(Sk*mask)/mask_sum) * mask - # # pad or crop S(k) to 0 and k_pad_max - # k_pad = np.arange(0, k_pad_max, dk) - # Sk_pad = np.zeros_like(k_pad) - # ind_0 = np.argmin(np.abs(k_pad-k[0])) - # ind_max = ind_0 + k.size - # if ind_max > k_pad.size: - # Sk_pad[ind_0:] = Sk[ind_0:k_pad.size] - # else: - # Sk_pad[ind_0:ind_max] = Sk + if k_lowpass is not None and k_lowpass > 0.0: + Sk = gaussian_filter( + Sk, + sigma=k_lowpass / dk, + mode = 'nearest') + if k_highpass is not None: + Sk_lowpass = gaussian_filter( + Sk, + sigma=k_highpass / dk, + mode = 'nearest') + Sk -= Sk_lowpass + # Calculate the real space PDF # dr = 1/(2*k_pad[-1]) @@ -268,11 +289,17 @@ def calculate_pair_dist_function( axis=0, ) - # invert + if damp_origin_fluctuations: + ind_max = np.argmax(pdf) + r_ind_max = r[ind_max] + r_mask = np.minimum(r / r_ind_max, 1.0) + r_mask = np.sin(r_mask*np.pi/2)**2 + pdf *= r_mask - ind_max = np.argmax(pdf) - r_ind_max = r[ind_max] - r_mask = np.minimum(r / r_ind_max, 1.0) + # invert + ind_max = np.argmax(pdf * np.sqrt(r)) + r_ind_max = r[ind_max-1] + r_mask = np.minimum(r / (r_ind_max), 1.0) r_mask = np.sin(r_mask*np.pi/2)**2 Sk_back_proj = (2*r_step)*np.sum( @@ -282,19 +309,49 @@ def calculate_pair_dist_function( axis=1, ) - fig,ax = plt.subplots(figsize=figsize) - ax.plot( - k, - Sk, - color = 'k', - ) - ax.plot( - k, - Sk_back_proj, - color = 'r', - ) + # fig,ax = plt.subplots(figsize=figsize) + # ax.plot( + # r, + # pdf*np.sqrt(r), + # color = 'r', + # ) + # ax.plot( + # k, + # Sk, + # color = 'k', + # ) + # ax.plot( + # k, + # Sk_back_proj, + # color = 'r', + # ) + + + # # polynomial high pass filtering + # if poly_background_order is not None: + # u = np.linspace(0,1,k.size) + # basis = np.zeros((k.size,poly_background_order+1)) + # for ii in range(poly_background_order+1): + # basis[:,ii] = comb(poly_background_order,ii) * \ + # ((1-u)**(poly_background_order-ii)) * (u**ii) + # coefs = np.linalg.lstsq( + # basis[sub_fit,:], + # Sk[sub_fit], + # rcond=None)[0] + # bg_poly = basis @ coefs + # Sk -= bg_poly + # # pad or crop S(k) to 0 and k_pad_max + # k_pad = np.arange(0, k_pad_max, dk) + # Sk_pad = np.zeros_like(k_pad) + # ind_0 = np.argmin(np.abs(k_pad-k[0])) + # ind_max = ind_0 + k.size + # if ind_max > k_pad.size: + # Sk_pad[ind_0:] = Sk[ind_0:k_pad.size] + # else: + # Sk_pad[ind_0:ind_max] = Sk + # # iterative refinement of the PDF # if iterative_pdf_refine: @@ -378,23 +435,30 @@ def calculate_pair_dist_function( self.radial_mean, color = 'k', ) + # ax.plot( + # k, + # np.ones(k.size)*coefs[0], + # color = 'r', + # ) ax.plot( k, - np.ones(k.size)*coefs[0], - color = 'r', - ) - ax.plot( - k, - fk + coefs[0], + bg, color = 'r', ) ax.set_xlabel('Scattering Vector (' + self.scattering_vector_units + ')') ax.set_ylabel('Radial Mean') ax.set_xlim((self.scattering_vector[0],self.scattering_vector[-1])) - ax.set_ylim((0,2e-5)) + # ax.set_ylim((0,2e-5)) ax.set_xlabel('Scattering Vector [A^-1]') ax.set_ylabel('I(k) and Fit Estimates') + ax.set_ylim((np.min(self.radial_mean[self.radial_mean>0])*0.8, + np.max(self.radial_mean*mask)*1.25)) + ax.set_yscale('log') + # print(np.min(self.radial_mean)*0.8) + # print(np.min(self.radial_mean)*0.8) + + if plot_sf_estimate: fig,ax = plt.subplots(figsize=figsize) @@ -474,11 +538,14 @@ def scattering_model(k2, *coefs): int1 = coefs[3] sigma1 = coefs[4] - - int_model = const_bg + \ int0*np.exp(k2/(-2*sigma0**2)) + \ - (int1*sigma1)**2/(k2 + sigma1**2) + int1*np.exp(k2**2/(-2*sigma1**4)) + + # (int1*sigma1)/(k2 + sigma1**2) + # int1*np.exp(k2/(-2*sigma1**2)) + # int1*np.exp(k2/(-2*sigma1**2)) + return int_model From 8adcbde720016abf88dc9cb7e7951d20a9636391 Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 15:32:42 +1000 Subject: [PATCH 015/176] Added full pdf --- py4DSTEM/process/polar/polar_analysis.py | 118 ++++++++++++++--------- 1 file changed, 71 insertions(+), 47 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index fc26408a4..3f6df4d60 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -3,7 +3,7 @@ import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit -from scipy.special import comb +from scipy.special import comb, erf from scipy.ndimage import gaussian_filter from emdfile import tqdmnd @@ -177,9 +177,11 @@ def calculate_pair_dist_function( # poly_background_order = 2, # iterative_pdf_refine = True, # num_iter = 10, + dens = None, plot_fits = False, plot_sf_estimate = False, - plot_pdf = True, + plot_reduced_pdf = True, + plot_pdf = False, figsize = (8,4), maxfev = None, ): @@ -203,15 +205,9 @@ def calculate_pair_dist_function( coefs = [const_bg, int0, sigma0, int0, sigma0] lb = [0,0,0,0,0] ub = [np.inf, np.inf, np.inf, np.inf, np.inf] - # noise_est = 1/k - # noise_est = np.divide(1.0, k, out=np.zeros_like(k), where=k!=0) + # Weight the fit towards high k values noise_est = k[-1] - k + dk - # print( - # np.round(coefs[0],3), - # np.round(coefs[1],3), - # np.round(coefs[3],3)) - # Estimate the mean atomic form factor + background if maxfev is None: coefs = curve_fit( @@ -235,10 +231,6 @@ def calculate_pair_dist_function( maxfev = maxfev, )[0] - # print( - # np.round(coefs[0],3), - # np.round(coefs[1],3), - # np.round(coefs[3],3)) coefs[0] *= int_mean coefs[1] *= int_mean coefs[3] *= int_mean @@ -252,11 +244,15 @@ def calculate_pair_dist_function( # mask for structure factor estimate if k_max is None: k_max = np.max(k) + # mask = np.clip(np.minimum( + # (k - k_min) / k_width, + # (k_max - k) / k_width, + # ),0,1) mask = np.clip(np.minimum( - (k - k_min) / k_width, + (k - 0.0) / k_width, (k_max - k) / k_width, ),0,1) - mask = np.sin(mask*(np.pi/2))**2 + mask = np.sin(mask*(np.pi/2)) # Estimate the reduced structure factor S(k) Sk = (Ik - bg) * k / fk @@ -265,6 +261,7 @@ def calculate_pair_dist_function( mask_sum = np.sum(mask) Sk = (Sk - np.sum(Sk*mask)/mask_sum) * mask + # Filtering of S(k) if k_lowpass is not None and k_lowpass > 0.0: Sk = gaussian_filter( Sk, @@ -277,44 +274,68 @@ def calculate_pair_dist_function( mode = 'nearest') Sk -= Sk_lowpass - # Calculate the real space PDF - # dr = 1/(2*k_pad[-1]) r = np.arange(r_min, r_max, r_step) ra,ka = np.meshgrid(r,k) - pdf = (2/np.pi)*np.pi*dk*np.sum( + pdf_reduced = (2/np.pi)*dk*np.sum( np.sin( 2*np.pi*ra*ka ) * Sk[:,None], axis=0, ) + # Damp the unphysical fluctuations at the PDF origin if damp_origin_fluctuations: - ind_max = np.argmax(pdf) + ind_max = np.argmax(pdf_reduced) r_ind_max = r[ind_max] r_mask = np.minimum(r / r_ind_max, 1.0) r_mask = np.sin(r_mask*np.pi/2)**2 - pdf *= r_mask + pdf_reduced *= r_mask - # invert - ind_max = np.argmax(pdf * np.sqrt(r)) - r_ind_max = r[ind_max-1] - r_mask = np.minimum(r / (r_ind_max), 1.0) - r_mask = np.sin(r_mask*np.pi/2)**2 + # Store results + self.pdf_r = r + self.pdf_reduced = pdf_reduced - Sk_back_proj = (2*r_step)*np.sum( - np.sin( - 2*np.pi*ra*ka - ) * pdf[None,:] * r_mask[None,:], - axis=1, - ) + # if density is provided, we can estimate the full PDF + if dens is not None: + pdf = pdf_reduced.copy() + pdf[1:] /= (4*np.pi*dens*r[1:]*(r[1]-r[0])) + pdf += 1 + + if damp_origin_fluctuations: + pdf *= r_mask + pdf = np.maximum(pdf, 0.0) # fig,ax = plt.subplots(figsize=figsize) # ax.plot( - # r, - # pdf*np.sqrt(r), + # k, + # mask, # color = 'r', # ) + # # invert + # ind_max = np.argmax(pdf_reduced* np.sqrt(r)) + # r_ind_max = r[ind_max-1] + # r_mask = np.minimum(r / (r_ind_max), 1.0) + # r_mask = np.sin(r_mask*np.pi/2)**2 + + # pdf_corr = np.maximum(pdf*6 + erf((r - 1.5)/0.5)*0.5 + 0.5, 0.0) + # fig,ax = plt.subplots(figsize=figsize) + # ax.plot( + # r, + # pdf_corr, + # color = 'k', + # ) + + + # Sk_back_proj = (0.5*r_step)*np.sum( + # np.sin( + # 2*np.pi*ra*ka + # ) * pdf_corr[None,:],# * r_mask[None,:], + # # ) * pdf_corr[None,:],# * r_mask[None,:], + # axis=1, + # ) + + # fig,ax = plt.subplots(figsize=figsize) # ax.plot( # k, # Sk, @@ -355,14 +376,14 @@ def calculate_pair_dist_function( # # iterative refinement of the PDF # if iterative_pdf_refine: - # # pdf = np.maximum(pdf + (r/r[-1]), 0.0) + # # pdf_reduced= np.maximum(pdf_reduced+ (r/r[-1]), 0.0) # ind_max = np.argmax(pdf) # r_ind_max = r[ind_max] # r_mask = np.minimum(r / r_ind_max, 1.0) # r_mask = np.sin(r_mask*np.pi/2)**2 - # pdf = np.maximum(pdf * r_mask + (r/r[-1]), 0.0) + # pdf_reduced= np.maximum(pdf_reduced* r_mask + (r/r[-1]), 0.0) # r_weight = r_mask * (1 - r / r[-1])**2 @@ -391,7 +412,7 @@ def calculate_pair_dist_function( # axis=0, # ) * r_weight - # pdf = np.maximum(pdf + 0.5*pdf_update, 0.0) + # pdf_reduced= np.maximum(pdf_reduced+ 0.5*pdf_update, 0.0) # fig,ax = plt.subplots(figsize=figsize) # ax.plot( @@ -407,7 +428,7 @@ def calculate_pair_dist_function( # # ax.plot( # # r, - # # pdf + pdf_update, + # # pdf_reduced+ pdf_update, # # color = 'r', # # ) @@ -435,11 +456,6 @@ def calculate_pair_dist_function( self.radial_mean, color = 'k', ) - # ax.plot( - # k, - # np.ones(k.size)*coefs[0], - # color = 'r', - # ) ax.plot( k, bg, @@ -455,10 +471,6 @@ def calculate_pair_dist_function( ax.set_ylim((np.min(self.radial_mean[self.radial_mean>0])*0.8, np.max(self.radial_mean*mask)*1.25)) ax.set_yscale('log') - # print(np.min(self.radial_mean)*0.8) - # print(np.min(self.radial_mean)*0.8) - - if plot_sf_estimate: fig,ax = plt.subplots(figsize=figsize) @@ -473,7 +485,17 @@ def calculate_pair_dist_function( yr[1]+0.05*(yr[1]-yr[0]), )) ax.set_xlabel('Scattering Vector [A^-1]') - ax.set_ylabel('Structure Factor') + ax.set_ylabel('Reduced Structure Factor') + + if plot_reduced_pdf: + fig,ax = plt.subplots(figsize=figsize) + ax.plot( + r, + pdf_reduced, + color = 'r', + ) + ax.set_xlabel('Radius [A]') + ax.set_ylabel('Reduced Pair Distribution Function') if plot_pdf: fig,ax = plt.subplots(figsize=figsize) @@ -484,6 +506,8 @@ def calculate_pair_dist_function( ) ax.set_xlabel('Radius [A]') ax.set_ylabel('Pair Distribution Function') + + # r = (np.min(Sk),np.max(Sk)) # ax.set_ylim(( # r[0]-0.05*(r[1]-r[0]), From 5167c456cbfcd32cebb83c03fbd66931bafc7c5b Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 16:23:27 +1000 Subject: [PATCH 016/176] Extra factor of 2/pi? --- py4DSTEM/process/polar/polar_analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 3f6df4d60..86d97476d 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -300,6 +300,7 @@ def calculate_pair_dist_function( if dens is not None: pdf = pdf_reduced.copy() pdf[1:] /= (4*np.pi*dens*r[1:]*(r[1]-r[0])) + pdf *= (2/np.pi) pdf += 1 if damp_origin_fluctuations: From 2b30bc44e41a15883811696ead37c2319c2c4088 Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 16:26:35 +1000 Subject: [PATCH 017/176] cleaning up --- py4DSTEM/process/polar/polar_analysis.py | 160 ++--------------------- 1 file changed, 13 insertions(+), 147 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 86d97476d..22b38d62f 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -307,146 +307,7 @@ def calculate_pair_dist_function( pdf *= r_mask pdf = np.maximum(pdf, 0.0) - # fig,ax = plt.subplots(figsize=figsize) - # ax.plot( - # k, - # mask, - # color = 'r', - # ) - # # invert - # ind_max = np.argmax(pdf_reduced* np.sqrt(r)) - # r_ind_max = r[ind_max-1] - # r_mask = np.minimum(r / (r_ind_max), 1.0) - # r_mask = np.sin(r_mask*np.pi/2)**2 - - # pdf_corr = np.maximum(pdf*6 + erf((r - 1.5)/0.5)*0.5 + 0.5, 0.0) - # fig,ax = plt.subplots(figsize=figsize) - # ax.plot( - # r, - # pdf_corr, - # color = 'k', - # ) - - - # Sk_back_proj = (0.5*r_step)*np.sum( - # np.sin( - # 2*np.pi*ra*ka - # ) * pdf_corr[None,:],# * r_mask[None,:], - # # ) * pdf_corr[None,:],# * r_mask[None,:], - # axis=1, - # ) - # fig,ax = plt.subplots(figsize=figsize) - # ax.plot( - # k, - # Sk, - # color = 'k', - # ) - # ax.plot( - # k, - # Sk_back_proj, - # color = 'r', - # ) - - - # # polynomial high pass filtering - # if poly_background_order is not None: - # u = np.linspace(0,1,k.size) - # basis = np.zeros((k.size,poly_background_order+1)) - # for ii in range(poly_background_order+1): - # basis[:,ii] = comb(poly_background_order,ii) * \ - # ((1-u)**(poly_background_order-ii)) * (u**ii) - # coefs = np.linalg.lstsq( - # basis[sub_fit,:], - # Sk[sub_fit], - # rcond=None)[0] - # bg_poly = basis @ coefs - # Sk -= bg_poly - - - # # pad or crop S(k) to 0 and k_pad_max - # k_pad = np.arange(0, k_pad_max, dk) - # Sk_pad = np.zeros_like(k_pad) - # ind_0 = np.argmin(np.abs(k_pad-k[0])) - # ind_max = ind_0 + k.size - # if ind_max > k_pad.size: - # Sk_pad[ind_0:] = Sk[ind_0:k_pad.size] - # else: - # Sk_pad[ind_0:ind_max] = Sk - - - # # iterative refinement of the PDF - # if iterative_pdf_refine: - # # pdf_reduced= np.maximum(pdf_reduced+ (r/r[-1]), 0.0) - - # ind_max = np.argmax(pdf) - # r_ind_max = r[ind_max] - # r_mask = np.minimum(r / r_ind_max, 1.0) - # r_mask = np.sin(r_mask*np.pi/2)**2 - - # pdf_reduced= np.maximum(pdf_reduced* r_mask + (r/r[-1]), 0.0) - # r_weight = r_mask * (1 - r / r[-1])**2 - - - - # # basis = np.vstack((np.ones_like(r),r)).T - # # coefs_lin = np.linalg.lstsq(basis, pdf, rcond=None)[0] - # # pdf_lin = basis * coefs_lin - # # print(coefs_lin) - - - # for a0 in range(10): - # Sk_back_proj = (1*np.pi/r.size)*np.sum( - # np.sin( - # 2*np.pi*ra*ka - # ) * pdf[None,:], - # axis=1, - # ) - - # Sk_diff = Sk - Sk_back_proj - # Sk_diff = (Sk_diff - np.mean(Sk_diff*mask)/mask_sum) * mask - - # pdf_update = 4*np.pi*dk*np.sum( - # np.sin( - # 8*np.pi*ra*ka - # ) * Sk_diff[:,None], - # axis=0, - # ) * r_weight - - # pdf_reduced= np.maximum(pdf_reduced+ 0.5*pdf_update, 0.0) - - # fig,ax = plt.subplots(figsize=figsize) - # ax.plot( - # r, - # pdf, - # color = 'k', - # ) - # # ax.plot( - # # r, - # # pdf_lin, - # # color = 'r', - # # ) - - # # ax.plot( - # # r, - # # pdf_reduced+ pdf_update, - # # color = 'r', - # # ) - - # # ax.plot( - # # k, - # # Sk, - # # color = 'k', - # # ) - # # ax.plot( - # # k, - # # Sk_back_proj, - # # color = 'r', - # # ) - # # ax.plot( - # # Sk_diff, - # # color = 'r', - # # ) # Plots @@ -509,17 +370,22 @@ def calculate_pair_dist_function( ax.set_ylabel('Pair Distribution Function') - # r = (np.min(Sk),np.max(Sk)) - # ax.set_ylim(( - # r[0]-0.05*(r[1]-r[0]), - # r[1]+0.05*(r[1]-r[0]), - # )) - - - # ax.set_yscale('log') + # functions for inverting from reduced PDF back to S(k) + # # invert + # ind_max = np.argmax(pdf_reduced* np.sqrt(r)) + # r_ind_max = r[ind_max-1] + # r_mask = np.minimum(r / (r_ind_max), 1.0) + # r_mask = np.sin(r_mask*np.pi/2)**2 + # Sk_back_proj = (0.5*r_step)*np.sum( + # np.sin( + # 2*np.pi*ra*ka + # ) * pdf_corr[None,:],# * r_mask[None,:], + # # ) * pdf_corr[None,:],# * r_mask[None,:], + # axis=1, + # ) From ea5e723c11d3167f5dbeb80d9c9721ce2bfb26b2 Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 16:32:52 +1000 Subject: [PATCH 018/176] Revert unneeded changes --- py4DSTEM/datacube/virtualimage.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index 5e2681eb6..ad6344c7d 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -325,7 +325,7 @@ def position_detector( shift_center = False, scan_position = None, invert = False, - color = 'c', + color = 'r', alpha = 0.7, **kwargs ): @@ -383,7 +383,6 @@ def position_detector( if data is None: image = None keys = ['dp_mean','dp_max','dp_median'] - image = None for k in keys: try: image = self.tree(k) From 1ec56a9ab56e6c57cd804184482bd75899132998 Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 16:34:07 +1000 Subject: [PATCH 019/176] minor --- py4DSTEM/preprocess/preprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index 4001f80cb..9b67392d1 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -257,8 +257,8 @@ def bin_data_diffraction( R_Nx, R_Ny, int(Q_Nx / bin_factor), - bin_factor, - int(Q_Ny / bin_factor), + bin_factor, + int(Q_Ny / bin_factor), bin_factor, ).sum(axis=(3, 5)).astype(dtype) From 703dd7d14a93acbb1eba9df62e9378f9a6ee7d54 Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 16:46:26 +1000 Subject: [PATCH 020/176] Removing deprecated function Need to pass checks! --- py4DSTEM/process/polar/polar_datacube.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index cc3f534c6..591943444 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -95,7 +95,6 @@ def __init__( pass from py4DSTEM.process.polar.polar_analysis import ( - # calculate_FEM_global, calculate_radial_statistics, plot_radial_mean, plot_radial_var_norm, From 86b8f9469bc50b04d3ad9a282cfbe8deab4032ee Mon Sep 17 00:00:00 2001 From: cophus Date: Mon, 21 Aug 2023 16:50:13 +1000 Subject: [PATCH 021/176] Fixing build --- py4DSTEM/process/polar/polar_analysis.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 22b38d62f..0f355bbd1 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -13,7 +13,8 @@ def calculate_radial_statistics( self, median_local = False, median_global = False, - plot_results = False, + plot_results_mean = False, + plot_results_var = False, figsize = (8,4), returnval = False, returnfig = False, @@ -82,15 +83,27 @@ def calculate_radial_statistics( self.radial_var_norm[sub] /= self.radial_mean[sub]**2 # plot results - if plot_results: + if plot_results_mean: if returnfig: - fig,ax = plot_FEM_global( + fig,ax = plot_radial_mean( self, figsize = figsize, returnfig = True, ) else: - plot_FEM_global( + plot_radial_mean( + self, + figsize = figsize, + ) + elif plot_results_var: + if returnfig: + fig,ax = plot_radial_var_norm( + self, + figsize = figsize, + returnfig = True, + ) + else: + plot_radial_var_norm( self, figsize = figsize, ) From aac2d12a3c7d17ca471be504c828119ab3a127a7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 21 Aug 2023 08:43:11 -0700 Subject: [PATCH 022/176] minor tv bugfix --- .../process/phase/iterative_multislice_ptychography.py | 8 ++++---- .../process/phase/iterative_ptychographic_constraints.py | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 5966ca07a..307095960 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1474,7 +1474,7 @@ def _object_denoise_tv_pylops(self, current_object, weights): if xp.iscomplexobj(current_object): current_object_tv = current_object warnings.warn( - ("tv_denoise currently for potential objects only"), + ("TV denoising is currently only supported for potential objects."), UserWarning, ) @@ -1484,6 +1484,7 @@ def _object_denoise_tv_pylops(self, current_object, weights): current_object = xp.pad( current_object, pad_width=pad_width, mode="constant" ) + # run tv denoising nz, nx, ny = current_object.shape niter_out = 40 @@ -1509,11 +1510,10 @@ def _object_denoise_tv_pylops(self, current_object, weights): show=False, )[0] - current_object_tv = current_object_tv.reshape(current_object.shape) - # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] - return current_object_tv[1:-1] + return current_object_tv def _constraints( self, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index e300e1154..95c2a9531 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -207,7 +207,7 @@ def _object_denoise_tv_pylops(self, current_object, weight): if xp.iscomplexobj(current_object): current_object_tv = current_object warnings.warn( - ("tv_denoise currently for potential objects only"), + ("TV denoising is currently only supported for potential objects."), UserWarning, ) @@ -287,11 +287,10 @@ def _object_denoise_tv_chambolle( if xp.iscomplexobj(current_object): updated_object = current_object warnings.warn( - ("tv_denoise currently for potential objects only"), + ("TV denoising is currently only supported for potential objects."), UserWarning, ) else: - current_object_sum = xp.sum(current_object) if axis is None: ndim = xp.arange(current_object.ndim).tolist() From da85db0da8524d935d4caa2a798c7f9e0c2cb079 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 21 Aug 2023 09:53:13 -0700 Subject: [PATCH 023/176] improvements for depth plotting --- .../iterative_multislice_ptychography.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 307095960..ddedac229 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2933,6 +2933,7 @@ def show_depth( x2: float, y1: float, y2: float, + specify_calibrated: bool = False, gaussian_filter_sigma: float = None, ms_object=None, cbar: bool = False, @@ -2947,6 +2948,9 @@ def show_depth( -------- x1, x2, y1, y2: floats (pixels) Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels gaussian_filter_sigma: float (optional) Standard deviation of gaussian kernel in A ms_object: np.array @@ -2962,11 +2966,31 @@ def show_depth( ms_obj = ms_object else: ms_obj = self.object_cropped - angle = np.arctan((x2 - x1) / (y2 - y1)) + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) x0 = ms_obj.shape[1] / 2 y0 = ms_obj.shape[2] / 2 + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + from py4DSTEM.process.phase.utils import rotate_point x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) From a5822129726f28992844f98e94cbcb45b79d9786 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 21 Aug 2023 19:26:26 -0700 Subject: [PATCH 024/176] tv fixes --- .../iterative_multislice_ptychography.py | 91 ++++++++++++++----- .../iterative_ptychographic_constraints.py | 8 +- .../iterative_singleslice_ptychography.py | 14 ++- 3 files changed, 86 insertions(+), 27 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index ddedac229..d9bf6bacf 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -1451,7 +1451,7 @@ def _object_identical_slices_constraint(self, current_object): return current_object - def _object_denoise_tv_pylops(self, current_object, weights): + def _object_denoise_tv_pylops(self, current_object, weights, iterations): """ Performs second order TV denoising along x and y @@ -1462,6 +1462,9 @@ def _object_denoise_tv_pylops(self, current_object, weights): weights : [float, float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops Returns ------- @@ -1487,28 +1490,66 @@ def _object_denoise_tv_pylops(self, current_object, weights): # run tv denoising nz, nx, ny = current_object.shape - niter_out = 40 + niter_out = iterations niter_in = 1 Iop = pylops.Identity(nx * ny * nz) - z_gradient = pylops.FirstDerivative( - (nz, nx, ny), axis=0, edge=False, kind="backward" - ) - xy_laplacian = pylops.Laplacian( - (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" - ) - l1_regs = [z_gradient, xy_laplacian] - - current_object_tv = pylops.optimization.sparsity.splitbregman( - Op=Iop, - y=current_object.ravel(), - RegsL1=l1_regs, - niter_outer=niter_out, - niter_inner=niter_in, - epsRL1s=weights, - tol=1e-4, - tau=1.0, - show=False, - )[0] + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] # remove padding current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] @@ -1552,6 +1593,7 @@ def _constraints( tv_denoise_pad_chambolle, tv_denoise, tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1627,6 +1669,8 @@ def _constraints( tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1661,6 +1705,7 @@ def _constraints( current_object = self._object_denoise_tv_pylops( current_object, tv_denoise_weights, + tv_denoise_inner_iter, ) elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( @@ -1771,6 +1816,7 @@ def reconstruct( tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, tv_denoise_weights=None, + tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1874,6 +1920,8 @@ def reconstruct( tv_denoise_weights: [float,float] Denoising weights[z weight, r weight]. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -2222,6 +2270,7 @@ def reconstruct( tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 95c2a9531..217253945 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -185,7 +185,7 @@ def _object_butterworth_constraint( return current_object - def _object_denoise_tv_pylops(self, current_object, weight): + def _object_denoise_tv_pylops(self, current_object, weight, iterations): """ Performs second order TV denoising along x and y @@ -196,6 +196,10 @@ def _object_denoise_tv_pylops(self, current_object, weight): weight : float Denoising weight. The greater `weight`, the more denoising (at the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + Returns ------- constrained_object: np.ndarray @@ -213,7 +217,7 @@ def _object_denoise_tv_pylops(self, current_object, weight): else: nx, ny = current_object.shape - niter_out = 40 + niter_out = iterations niter_in = 1 Iop = pylops.Identity(nx * ny) xy_laplacian = pylops.Laplacian( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 2480974f3..97c7a3e5d 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1025,6 +1025,7 @@ def _constraints( butterworth_order, tv_denoise, tv_denoise_weight, + tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, @@ -1082,8 +1083,10 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter tv_denoise: bool If True, applies TV denoising on object - tv_denoise_weight: float + tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1116,8 +1119,7 @@ def _constraints( if tv_denoise: current_object = self._object_denoise_tv_pylops( - current_object, - tv_denoise_weight, + current_object, tv_denoise_weight, tv_denoise_inner_iter ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1212,6 +1214,7 @@ def reconstruct( butterworth_order: float = 2, tv_denoise_iter: int = np.inf, tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1300,8 +1303,10 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter tv_denoise_iter: int, optional Number of iterations to run using tv denoise filter on object - tv_denoise_weight: float + tv_denoise_weight: float Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1638,6 +1643,7 @@ def reconstruct( butterworth_order=butterworth_order, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse From eaa3699a9a8c00a48bb18a6e8e8efe6f595f2397 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 24 Aug 2023 17:22:14 -0700 Subject: [PATCH 025/176] everyone gets TV denoising --- .../iterative_mixedstate_ptychography.py | 26 +++ .../iterative_overlap_magnetic_tomography.py | 168 +++++++++++++++++- .../phase/iterative_overlap_tomography.py | 147 ++++++++++++++- .../iterative_simultaneous_ptychography.py | 30 ++++ 4 files changed, 361 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 56fec1004..d066c7f3f 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1125,6 +1125,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, orthogonalize_probe, object_positivity, shrinkage_rad, @@ -1183,6 +1186,12 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter orthogonalize_probe: bool If True, probe will be orthogonalized + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1213,6 +1222,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1290,6 +1304,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1373,6 +1390,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1707,6 +1730,9 @@ def reconstruct( q_highpass=q_highpass, butterworth_order=butterworth_order, orthogonalize_probe=orthogonalize_probe, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 8691a121d..2642b7193 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize import show @@ -1679,6 +1680,111 @@ def _divergence_free_constraint(self, vector_field): return vector_field + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1710,6 +1816,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1771,6 +1880,15 @@ def _constraints( If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1822,6 +1940,31 @@ def _constraints( butterworth_order, ) + elif tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[1] = self._object_denoise_tv_pylops( + current_object[1], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[2] = self._object_denoise_tv_pylops( + current_object[2], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[3] = self._object_denoise_tv_pylops( + current_object[3], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object[0] = self._object_shrinkage_constraint( current_object[0], @@ -1913,6 +2056,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1998,6 +2144,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -2477,6 +2632,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) # Normalize Error Over Tilts @@ -2487,11 +2646,7 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - ( - self._object, - self._probe, - _, - ) = self._constraints( + (self._object, self._probe, _,) = self._constraints( self._object, self._probe, None, @@ -2530,6 +2685,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + v_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index d6bee12fd..9cfec2b39 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize import show @@ -1527,6 +1528,111 @@ def _object_butterworth_constraint( current_object += current_object_mean return xp.real(current_object) + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1555,6 +1661,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1611,6 +1720,13 @@ def _constraints( Phase shift in radians to be subtracted from the potential at each iteration object_mask: np.ndarray (boolean) If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1634,6 +1750,12 @@ def _constraints( q_highpass, butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( @@ -1723,6 +1845,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1806,6 +1931,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -2203,6 +2337,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) # Normalize Error Over Tilts @@ -2211,11 +2349,7 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - ( - self._object, - self._probe, - _, - ) = self._constraints( + (self._object, self._probe, _,) = self._constraints( self._object, self._probe, None, @@ -2251,6 +2385,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + v_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 8881d021c..a19fc82d3 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -2232,6 +2232,9 @@ def _constraints( q_highpass_e, q_highpass_m, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, warmup_iteration, object_positivity, shrinkage_rad, @@ -2300,6 +2303,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising warmup_iteration: bool If True, constraints electrostatic object only object_positivity: bool @@ -2349,6 +2358,15 @@ def _constraints( if self._object_type == "complex": magnetic_obj = magnetic_obj.real + if tv_denoise: + electrostatic_obj = self._object_denoise_tv_pylops( + electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) + + if not warmup_iteration: + magnetic_obj = self._object_denoise_tv_pylops( + magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) if shrinkage_rad > 0.0 or object_mask is not None: electrostatic_obj = self._object_shrinkage_constraint( @@ -2446,6 +2464,9 @@ def reconstruct( q_highpass_e: float = None, q_highpass_m: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -2538,6 +2559,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -2899,6 +2926,9 @@ def reconstruct( q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse From bc133b1be2c2396e584e1f30eed94bd7c848adb6 Mon Sep 17 00:00:00 2001 From: cophus Date: Tue, 29 Aug 2023 15:42:00 +1000 Subject: [PATCH 026/176] Adding moire lattice generation and plotting --- py4DSTEM/process/diffraction/crystal.py | 451 +++++++++++++++++- py4DSTEM/process/diffraction/crystal_phase.py | 1 + 2 files changed, 449 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index fa735438b..6e167dfbf 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -2,6 +2,7 @@ import numpy as np import matplotlib.pyplot as plt +from matplotlib.patches import Circle from fractions import Fraction from typing import Union, Optional from scipy.optimize import curve_fit @@ -744,9 +745,12 @@ def generate_diffraction_pattern( ) bragg_peaks = PointList(np.array([], dtype=pl_dtype)) if np.any(keep_int): - bragg_peaks.add_data_by_field( - [gx_proj, gy_proj, gz_proj, g_int[keep_int], h, k, l] - ) + bragg_peaks.add_data_by_field([ + gx_proj, + gy_proj, + gz_proj, + g_int[keep_int], + h,k,l]) else: pl_dtype = np.dtype( [ @@ -1074,3 +1078,444 @@ def calculate_bragg_peak_histogram( int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power) int_exp /= np.max(int_exp) return k, int_exp + + + +def generate_moire( + bragg_peaks_0, + bragg_peaks_1, + thresh_0 = 0.0002, + thresh_1 = 0.0002, + int_range = (0,5e-3), + exx_1 = 0.0, + eyy_1 = 0.0, + exy_1 = 0.0, + phi_1 = 0.0, + power = 2.0, + k_max = 1.0, + plot_result = True, + plot_subpixel = True, + labels = None, + marker_size_parent = 16, + marker_size_moire = 4, + text_size_parent = 10, + text_size_moire = 6, + add_labels_parent = False, + add_labels_moire = False, + dist_labels = 0.03, + dist_check = 0.06, + sep_labels = 0.03, + figsize = (8,6), + return_moire = False, + returnfig = False, + ): + """ + Calculate a Moire lattice from 2 parent diffraction patterns. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + thresh_0: float + thresh_1: float + int_range: (float, float) + exx_1: float + eyy_1: float + exy_1: float + phi_1: float + power: float + k_max: float + plot_result: bool + plot_subpixel: bool + labels: list + List of text labels for parent lattices + marker_size_parent: float + marker_size_moire: float + text_size_parent: float + text_size_moire: float + add_labels_parent: bool + add_labels_moire: bool + dist_labels: float + dist_check: float + sep_labels: float + figsize: (float,float) + return_moire: bool + returnfig: bool + + Returns + -------- + bragg_peaksMoire: BraggVector (optjonal) + Bragg vectors for moire lattice. + fig, ax: matplotlib handles (optional) + Figure and axes handles for the moire plot. + + """ + + # peak labels + if labels is None: + labels = ('crystal 0', 'crystal 1') + + # get intenties of all peaks + int0 = bragg_peaks_0['intensity']**(power/2.0) + int1 = bragg_peaks_1['intensity']**(power/2.0) + + # peaks above threshold + sub0 = int0 >= thresh_0 + sub1 = int1 >= thresh_1 + + # Remove origin (assuming brightest peak) + ind0_or = np.argmax(bragg_peaks_0['intensity']) + ind1_or = np.argmax(bragg_peaks_1['intensity']) + sub0[ind0_or] = False + sub1[ind1_or] = False + int0_sub = int0[sub0] + int1_sub = int1[sub1] + + # Get peaks + qx0 = bragg_peaks_0['qx'][sub0] + qy0 = bragg_peaks_0['qy'][sub0] + qx1_init = bragg_peaks_1['qx'][sub1] + qy1_init = bragg_peaks_1['qy'][sub1] + + # peak labels + if add_labels_parent or add_labels_moire or return_moire: + def overline(x): + return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") + + h0 = bragg_peaks_0['h'][sub0] + k0 = bragg_peaks_0['k'][sub0] + l0 = bragg_peaks_0['l'][sub0] + h1 = bragg_peaks_1['h'][sub1] + k1 = bragg_peaks_1['k'][sub1] + l1 = bragg_peaks_1['l'][sub1] + + # apply strain tensor to lattice 1 + # infinitesimal + # m = np.array([ + # [1 + exx_1, (exy_1 - phi_1)*0.5], + # [(exy_1 _ phi_1)*0.5, 1 + eyy_1], + # ]) + # finite rotation + m = np.array([ + [np.cos(phi_1), -np.sin(phi_1)], + [np.sin(phi_1), np.cos(phi_1)], + ]) @ np.array([ + [1 + exx_1, exy_1*0.5], + [exy_1*0.5, 1 + eyy_1], + ]) + qx1 = m[0,0] * qx1_init + m[0,1] * qy1_init + qy1 = m[1,0] * qx1_init + m[1,1] * qy1_init + + # Generate moire lattice + ind0, ind1 = np.meshgrid( + np.arange(np.sum(sub0)), + np.arange(np.sum(sub1)), + indexing = 'ij', + ) + # ind0 = ind0.ravel() + # ind1 = ind1.ravel() + qx = qx0[ind0] + qx1[ind1] + qy = qy0[ind0] + qy1[ind1] + # int_moire = int0_sub[ind0] + int1_sub[ind1] + int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5 + + # moire labels + if add_labels_moire or return_moire: + m_h0 = h0[ind0] + m_k0 = k0[ind0] + m_l0 = l0[ind0] + m_h1 = h1[ind1] + m_k1 = k1[ind1] + m_l1 = l1[ind1] + + # If needed, convert moire peaks to BraggVector class + if return_moire: + pl_dtype = np.dtype([ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h0", "int"), + ("k0", "int"), + ("l0", "int"), + ("h1", "int"), + ("k1", "int"), + ("l1", "int"), + ]) + bragg_moire = PointList( + np.array([],dtype=pl_dtype) + ) + bragg_moire.add_data_by_field([ + qx.ravel(), + qy.ravel(), + int_moire.ravel(), + m_h0.ravel(),m_k0.ravel(),m_l0.ravel(), + m_h1.ravel(),m_k1.ravel(),m_l1.ravel(), + ]) + + + # plot outputs + if plot_result: + fig = plt.figure(figsize = figsize) + ax = fig.add_axes([0.09,0.09,0.65,0.9]) + ax_labels = fig.add_axes([0.75,0,0.25,1]) + + + text_params_parent = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_parent, + } + text_params_moire = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_moire, + } + + + if plot_subpixel is False: + + # moire + ax.scatter( + qy, + qx, + # color = (0,0,0,1), + c = int_moire, + s = marker_size_moire, + cmap = 'gray_r', + vmin = int_range[0], + vmax = int_range[1], + antialiased=True, + ) + + # parent lattices + ax.scatter( + qy0, + qx0, + color = (1,0,0,1), + s = marker_size_parent, + antialiased=True, + ) + ax.scatter( + qy1, + qx1, + color = (0,0.7,1,1), + s = marker_size_parent, + antialiased=True, + ) + + # origin + ax.scatter( + 0, + 0, + color = (0,0,0,1), + s = marker_size_parent, + antialiased=True, + ) + + else: + # moire peaks + int_all = np.clip( + (int_moire - int_range[0]) / (int_range[1] - int_range[0]), + 0,1) + keep = np.logical_and.reduce(( + qx >= -k_max, + qx <= k_max, + qy >= -k_max, + qy <= k_max + )) + for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): + ax.add_artist(Circle( + xy=(y, x), + radius = np.sqrt(marker_size_moire)/800.0, + color = (1-int_marker,1-int_marker,1-int_marker), + )) + if add_labels_moire: + for a0 in range(qx.size): + if keep.ravel()[a0]: + x0 = qx.ravel()[a0] + y0 = qy.ravel()[a0] + d2 = (qx.ravel()-x0)**2 + (qy.ravel()-y0)**2 + sub = d2 < dist_check**2 + xc = np.mean(qx.ravel()[sub]) + yc = np.mean(qy.ravel()[sub]) + xp = x0 - xc + yp = y0 - yc + if xp == 0 and yp == 0.0: + xp = x0 - dist_labels + yp = y0 + else: + leng = np.linalg.norm((xp,yp)) + xp = x0 + xp * dist_labels / leng + yp = y0 + yp * dist_labels / leng + + ax.text( + yp, + xp - sep_labels, + "$" + overline(m_h0.ravel()[a0]) \ + + overline(m_k0.ravel()[a0]) \ + + overline(m_l0.ravel()[a0]) + "$", + c = 'r', + **text_params_moire, + ) + ax.text( + yp, + xp, + "$" + overline(m_h1.ravel()[a0]) \ + + overline(m_k1.ravel()[a0]) \ + + overline(m_l1.ravel()[a0]) + "$", + c = (0,0.7,1.0), + **text_params_moire, + ) + + + keep = np.logical_and.reduce(( + qx0 >= -k_max, + qx0 <= k_max, + qy0 >= -k_max, + qy0 <= k_max + )) + for x, y in zip(qx0[keep], qy0[keep]): + ax.add_artist(Circle( + xy=(y, x), + radius = np.sqrt(marker_size_parent)/800.0, + color = (1,0,0), + )) + if add_labels_parent: + for a0 in range(qx0.size): + if keep.ravel()[a0]: + xp = qx0.ravel()[a0] - dist_labels + yp = qy0.ravel()[a0] + ax.text( + yp, + xp, + "$" + overline(h0.ravel()[a0]) \ + + overline(k0.ravel()[a0]) \ + + overline(l0.ravel()[a0]) + "$", + c = 'k', + **text_params_parent, + ) + + keep = np.logical_and.reduce(( + qx1 >= -k_max, + qx1 <= k_max, + qy1 >= -k_max, + qy1 <= k_max + )) + for x, y in zip(qx1[keep], qy1[keep]): + ax.add_artist(Circle( + xy=(y, x), + radius = np.sqrt(marker_size_parent)/800.0, + color = (0,0.7,1), + )) + if add_labels_parent: + for a0 in range(qx1.size): + if keep.ravel()[a0]: + xp = qx1.ravel()[a0] - dist_labels + yp = qy1.ravel()[a0] + ax.text( + yp, + xp, + "$" + overline(h1.ravel()[a0]) \ + + overline(k1.ravel()[a0]) \ + + overline(l1.ravel()[a0]) + "$", + c = 'k', + **text_params_parent, + ) + + # origin + ax.add_artist(Circle( + xy=(0, 0), + radius = np.sqrt(marker_size_parent)/800.0, + color = (0,0,0), + )) + + ax.set_xlim((-k_max,k_max)) + ax.set_ylim((-k_max,k_max)) + ax.set_ylabel('$q_x$ (1/A)') + ax.set_xlabel('$q_y$ (1/A)') + ax.invert_yaxis() + + # labels + ax_labels.scatter( + 0, + 0, + color = (1,0,0,1), + s = marker_size_parent, + ) + ax_labels.scatter( + 0, + -1, + color = (0,0.7,1,1), + s = marker_size_parent, + ) + ax_labels.scatter( + 0, + -2, + color = (0,0,0,1), + s = marker_size_moire, + ) + ax_labels.text( + 0.4, + -0.2, + labels[0], + fontsize = 14, + ) + ax_labels.text( + 0.4, + -1.2, + labels[1], + fontsize = 14, + ) + ax_labels.text( + 0.4, + -2.2, + 'Moiré lattice', + fontsize = 14, + ) + + ax_labels.text( + 0, + -4.2, + labels[1] + ' $\epsilon_{xx}$ = ' + str(np.round(exx_1*100,2)) + '%', + fontsize = 14, + ) + ax_labels.text( + 0, + -5.2, + labels[1] + ' $\epsilon_{yy}$ = ' + str(np.round(eyy_1*100,2)) + '%', + fontsize = 14, + ) + ax_labels.text( + 0, + -6.2, + labels[1] + ' $\epsilon_{xy}$ = ' + str(np.round(exy_1*100,2)) + '%', + fontsize = 14, + ) + ax_labels.text( + 0, + -7.2, + labels[1] + ' $\phi$ = ' + str(np.round(phi_1*180/np.pi,2)) + '$^\circ$', + fontsize = 14, + + ) + + ax_labels.set_xlim((-1,4)) + ax_labels.set_ylim((-21,1)) + + ax_labels.axis('off') + + if return_moire: + if returnfig: + return bragg_moire, fig, ax + else: + return bragg_moire + if returnfig: + return fig, ax + + diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 84824fe63..b0cb1fe16 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -4,6 +4,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt +from dataclasses import dataclass, field from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern From cc81ccd8de978f862908a2ace775ed5762fdc260 Mon Sep 17 00:00:00 2001 From: Steven Zeltmann Date: Tue, 29 Aug 2023 16:07:02 -0400 Subject: [PATCH 027/176] run black --- py4DSTEM/process/diffraction/crystal.py | 411 ++++++++++++------------ 1 file changed, 210 insertions(+), 201 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 6e167dfbf..12309ac4d 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -745,12 +745,9 @@ def generate_diffraction_pattern( ) bragg_peaks = PointList(np.array([], dtype=pl_dtype)) if np.any(keep_int): - bragg_peaks.add_data_by_field([ - gx_proj, - gy_proj, - gz_proj, - g_int[keep_int], - h,k,l]) + bragg_peaks.add_data_by_field( + [gx_proj, gy_proj, gz_proj, g_int[keep_int], h, k, l] + ) else: pl_dtype = np.dtype( [ @@ -1078,37 +1075,36 @@ def calculate_bragg_peak_histogram( int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power) int_exp /= np.max(int_exp) return k, int_exp - def generate_moire( bragg_peaks_0, bragg_peaks_1, - thresh_0 = 0.0002, - thresh_1 = 0.0002, - int_range = (0,5e-3), - exx_1 = 0.0, - eyy_1 = 0.0, - exy_1 = 0.0, - phi_1 = 0.0, - power = 2.0, - k_max = 1.0, - plot_result = True, - plot_subpixel = True, - labels = None, - marker_size_parent = 16, - marker_size_moire = 4, - text_size_parent = 10, - text_size_moire = 6, - add_labels_parent = False, - add_labels_moire = False, - dist_labels = 0.03, - dist_check = 0.06, - sep_labels = 0.03, - figsize = (8,6), - return_moire = False, - returnfig = False, - ): + thresh_0=0.0002, + thresh_1=0.0002, + int_range=(0, 5e-3), + exx_1=0.0, + eyy_1=0.0, + exy_1=0.0, + phi_1=0.0, + power=2.0, + k_max=1.0, + plot_result=True, + plot_subpixel=True, + labels=None, + marker_size_parent=16, + marker_size_moire=4, + text_size_parent=10, + text_size_moire=6, + add_labels_parent=False, + add_labels_moire=False, + dist_labels=0.03, + dist_check=0.06, + sep_labels=0.03, + figsize=(8, 6), + return_moire=False, + returnfig=False, +): """ Calculate a Moire lattice from 2 parent diffraction patterns. @@ -1143,7 +1139,7 @@ def generate_moire( figsize: (float,float) return_moire: bool returnfig: bool - + Returns -------- bragg_peaksMoire: BraggVector (optjonal) @@ -1155,42 +1151,43 @@ def generate_moire( # peak labels if labels is None: - labels = ('crystal 0', 'crystal 1') + labels = ("crystal 0", "crystal 1") # get intenties of all peaks - int0 = bragg_peaks_0['intensity']**(power/2.0) - int1 = bragg_peaks_1['intensity']**(power/2.0) - + int0 = bragg_peaks_0["intensity"] ** (power / 2.0) + int1 = bragg_peaks_1["intensity"] ** (power / 2.0) + # peaks above threshold sub0 = int0 >= thresh_0 sub1 = int1 >= thresh_1 - + # Remove origin (assuming brightest peak) - ind0_or = np.argmax(bragg_peaks_0['intensity']) - ind1_or = np.argmax(bragg_peaks_1['intensity']) + ind0_or = np.argmax(bragg_peaks_0["intensity"]) + ind1_or = np.argmax(bragg_peaks_1["intensity"]) sub0[ind0_or] = False sub1[ind1_or] = False int0_sub = int0[sub0] int1_sub = int1[sub1] - + # Get peaks - qx0 = bragg_peaks_0['qx'][sub0] - qy0 = bragg_peaks_0['qy'][sub0] - qx1_init = bragg_peaks_1['qx'][sub1] - qy1_init = bragg_peaks_1['qy'][sub1] - + qx0 = bragg_peaks_0["qx"][sub0] + qy0 = bragg_peaks_0["qy"][sub0] + qx1_init = bragg_peaks_1["qx"][sub1] + qy1_init = bragg_peaks_1["qy"][sub1] + # peak labels if add_labels_parent or add_labels_moire or return_moire: + def overline(x): return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") - h0 = bragg_peaks_0['h'][sub0] - k0 = bragg_peaks_0['k'][sub0] - l0 = bragg_peaks_0['l'][sub0] - h1 = bragg_peaks_1['h'][sub1] - k1 = bragg_peaks_1['k'][sub1] - l1 = bragg_peaks_1['l'][sub1] - + h0 = bragg_peaks_0["h"][sub0] + k0 = bragg_peaks_0["k"][sub0] + l0 = bragg_peaks_0["l"][sub0] + h1 = bragg_peaks_1["h"][sub1] + k1 = bragg_peaks_1["k"][sub1] + l1 = bragg_peaks_1["l"][sub1] + # apply strain tensor to lattice 1 # infinitesimal # m = np.array([ @@ -1198,21 +1195,25 @@ def overline(x): # [(exy_1 _ phi_1)*0.5, 1 + eyy_1], # ]) # finite rotation - m = np.array([ - [np.cos(phi_1), -np.sin(phi_1)], - [np.sin(phi_1), np.cos(phi_1)], - ]) @ np.array([ - [1 + exx_1, exy_1*0.5], - [exy_1*0.5, 1 + eyy_1], - ]) - qx1 = m[0,0] * qx1_init + m[0,1] * qy1_init - qy1 = m[1,0] * qx1_init + m[1,1] * qy1_init - + m = np.array( + [ + [np.cos(phi_1), -np.sin(phi_1)], + [np.sin(phi_1), np.cos(phi_1)], + ] + ) @ np.array( + [ + [1 + exx_1, exy_1 * 0.5], + [exy_1 * 0.5, 1 + eyy_1], + ] + ) + qx1 = m[0, 0] * qx1_init + m[0, 1] * qy1_init + qy1 = m[1, 0] * qx1_init + m[1, 1] * qy1_init + # Generate moire lattice ind0, ind1 = np.meshgrid( np.arange(np.sum(sub0)), np.arange(np.sum(sub1)), - indexing = 'ij', + indexing="ij", ) # ind0 = ind0.ravel() # ind1 = ind1.ravel() @@ -1232,36 +1233,40 @@ def overline(x): # If needed, convert moire peaks to BraggVector class if return_moire: - pl_dtype = np.dtype([ - ("qx", "float"), - ("qy", "float"), - ("intensity", "float"), - ("h0", "int"), - ("k0", "int"), - ("l0", "int"), - ("h1", "int"), - ("k1", "int"), - ("l1", "int"), - ]) - bragg_moire = PointList( - np.array([],dtype=pl_dtype) + pl_dtype = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h0", "int"), + ("k0", "int"), + ("l0", "int"), + ("h1", "int"), + ("k1", "int"), + ("l1", "int"), + ] + ) + bragg_moire = PointList(np.array([], dtype=pl_dtype)) + bragg_moire.add_data_by_field( + [ + qx.ravel(), + qy.ravel(), + int_moire.ravel(), + m_h0.ravel(), + m_k0.ravel(), + m_l0.ravel(), + m_h1.ravel(), + m_k1.ravel(), + m_l1.ravel(), + ] ) - bragg_moire.add_data_by_field([ - qx.ravel(), - qy.ravel(), - int_moire.ravel(), - m_h0.ravel(),m_k0.ravel(),m_l0.ravel(), - m_h1.ravel(),m_k1.ravel(),m_l1.ravel(), - ]) - - + # plot outputs if plot_result: - fig = plt.figure(figsize = figsize) - ax = fig.add_axes([0.09,0.09,0.65,0.9]) - ax_labels = fig.add_axes([0.75,0,0.25,1]) - - + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0.09, 0.09, 0.65, 0.9]) + ax_labels = fig.add_axes([0.75, 0, 0.25, 1]) + text_params_parent = { "ha": "center", "va": "center", @@ -1277,19 +1282,17 @@ def overline(x): "size": text_size_moire, } - if plot_subpixel is False: - # moire ax.scatter( qy, qx, # color = (0,0,0,1), - c = int_moire, - s = marker_size_moire, - cmap = 'gray_r', - vmin = int_range[0], - vmax = int_range[1], + c=int_moire, + s=marker_size_moire, + cmap="gray_r", + vmin=int_range[0], + vmax=int_range[1], antialiased=True, ) @@ -1297,15 +1300,15 @@ def overline(x): ax.scatter( qy0, qx0, - color = (1,0,0,1), - s = marker_size_parent, + color=(1, 0, 0, 1), + s=marker_size_parent, antialiased=True, ) ax.scatter( qy1, qx1, - color = (0,0.7,1,1), - s = marker_size_parent, + color=(0, 0.7, 1, 1), + s=marker_size_parent, antialiased=True, ) @@ -1313,34 +1316,33 @@ def overline(x): ax.scatter( 0, 0, - color = (0,0,0,1), - s = marker_size_parent, + color=(0, 0, 0, 1), + s=marker_size_parent, antialiased=True, ) - + else: # moire peaks int_all = np.clip( - (int_moire - int_range[0]) / (int_range[1] - int_range[0]), - 0,1) - keep = np.logical_and.reduce(( - qx >= -k_max, - qx <= k_max, - qy >= -k_max, - qy <= k_max - )) + (int_moire - int_range[0]) / (int_range[1] - int_range[0]), 0, 1 + ) + keep = np.logical_and.reduce( + (qx >= -k_max, qx <= k_max, qy >= -k_max, qy <= k_max) + ) for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): - ax.add_artist(Circle( - xy=(y, x), - radius = np.sqrt(marker_size_moire)/800.0, - color = (1-int_marker,1-int_marker,1-int_marker), - )) + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_moire) / 800.0, + color=(1 - int_marker, 1 - int_marker, 1 - int_marker), + ) + ) if add_labels_moire: for a0 in range(qx.size): if keep.ravel()[a0]: x0 = qx.ravel()[a0] y0 = qy.ravel()[a0] - d2 = (qx.ravel()-x0)**2 + (qy.ravel()-y0)**2 + d2 = (qx.ravel() - x0) ** 2 + (qy.ravel() - y0) ** 2 sub = d2 < dist_check**2 xc = np.mean(qx.ravel()[sub]) yc = np.mean(qy.ravel()[sub]) @@ -1350,42 +1352,44 @@ def overline(x): xp = x0 - dist_labels yp = y0 else: - leng = np.linalg.norm((xp,yp)) + leng = np.linalg.norm((xp, yp)) xp = x0 + xp * dist_labels / leng yp = y0 + yp * dist_labels / leng ax.text( yp, xp - sep_labels, - "$" + overline(m_h0.ravel()[a0]) \ - + overline(m_k0.ravel()[a0]) \ - + overline(m_l0.ravel()[a0]) + "$", - c = 'r', + "$" + + overline(m_h0.ravel()[a0]) + + overline(m_k0.ravel()[a0]) + + overline(m_l0.ravel()[a0]) + + "$", + c="r", **text_params_moire, ) ax.text( yp, xp, - "$" + overline(m_h1.ravel()[a0]) \ - + overline(m_k1.ravel()[a0]) \ - + overline(m_l1.ravel()[a0]) + "$", - c = (0,0.7,1.0), + "$" + + overline(m_h1.ravel()[a0]) + + overline(m_k1.ravel()[a0]) + + overline(m_l1.ravel()[a0]) + + "$", + c=(0, 0.7, 1.0), **text_params_moire, ) - - - keep = np.logical_and.reduce(( - qx0 >= -k_max, - qx0 <= k_max, - qy0 >= -k_max, - qy0 <= k_max - )) + + keep = np.logical_and.reduce( + (qx0 >= -k_max, qx0 <= k_max, qy0 >= -k_max, qy0 <= k_max) + ) for x, y in zip(qx0[keep], qy0[keep]): - ax.add_artist(Circle( - xy=(y, x), - radius = np.sqrt(marker_size_parent)/800.0, - color = (1,0,0), - )) + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(1, 0, 0), + ) + ) if add_labels_parent: for a0 in range(qx0.size): if keep.ravel()[a0]: @@ -1394,25 +1398,26 @@ def overline(x): ax.text( yp, xp, - "$" + overline(h0.ravel()[a0]) \ - + overline(k0.ravel()[a0]) \ - + overline(l0.ravel()[a0]) + "$", - c = 'k', + "$" + + overline(h0.ravel()[a0]) + + overline(k0.ravel()[a0]) + + overline(l0.ravel()[a0]) + + "$", + c="k", **text_params_parent, ) - - keep = np.logical_and.reduce(( - qx1 >= -k_max, - qx1 <= k_max, - qy1 >= -k_max, - qy1 <= k_max - )) + + keep = np.logical_and.reduce( + (qx1 >= -k_max, qx1 <= k_max, qy1 >= -k_max, qy1 <= k_max) + ) for x, y in zip(qx1[keep], qy1[keep]): - ax.add_artist(Circle( - xy=(y, x), - radius = np.sqrt(marker_size_parent)/800.0, - color = (0,0.7,1), - )) + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0.7, 1), + ) + ) if add_labels_parent: for a0 in range(qx1.size): if keep.ravel()[a0]: @@ -1421,101 +1426,105 @@ def overline(x): ax.text( yp, xp, - "$" + overline(h1.ravel()[a0]) \ - + overline(k1.ravel()[a0]) \ - + overline(l1.ravel()[a0]) + "$", - c = 'k', + "$" + + overline(h1.ravel()[a0]) + + overline(k1.ravel()[a0]) + + overline(l1.ravel()[a0]) + + "$", + c="k", **text_params_parent, ) - + # origin - ax.add_artist(Circle( - xy=(0, 0), - radius = np.sqrt(marker_size_parent)/800.0, - color = (0,0,0), - )) - - ax.set_xlim((-k_max,k_max)) - ax.set_ylim((-k_max,k_max)) - ax.set_ylabel('$q_x$ (1/A)') - ax.set_xlabel('$q_y$ (1/A)') + ax.add_artist( + Circle( + xy=(0, 0), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0, 0), + ) + ) + + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + ax.set_ylabel("$q_x$ (1/A)") + ax.set_xlabel("$q_y$ (1/A)") ax.invert_yaxis() # labels ax_labels.scatter( 0, 0, - color = (1,0,0,1), - s = marker_size_parent, + color=(1, 0, 0, 1), + s=marker_size_parent, ) ax_labels.scatter( 0, -1, - color = (0,0.7,1,1), - s = marker_size_parent, + color=(0, 0.7, 1, 1), + s=marker_size_parent, ) ax_labels.scatter( 0, -2, - color = (0,0,0,1), - s = marker_size_moire, + color=(0, 0, 0, 1), + s=marker_size_moire, ) ax_labels.text( 0.4, -0.2, labels[0], - fontsize = 14, + fontsize=14, ) ax_labels.text( 0.4, -1.2, labels[1], - fontsize = 14, + fontsize=14, ) ax_labels.text( 0.4, -2.2, - 'Moiré lattice', - fontsize = 14, + "Moiré lattice", + fontsize=14, ) - + ax_labels.text( 0, -4.2, - labels[1] + ' $\epsilon_{xx}$ = ' + str(np.round(exx_1*100,2)) + '%', - fontsize = 14, + labels[1] + " $\epsilon_{xx}$ = " + str(np.round(exx_1 * 100, 2)) + "%", + fontsize=14, ) ax_labels.text( 0, -5.2, - labels[1] + ' $\epsilon_{yy}$ = ' + str(np.round(eyy_1*100,2)) + '%', - fontsize = 14, + labels[1] + " $\epsilon_{yy}$ = " + str(np.round(eyy_1 * 100, 2)) + "%", + fontsize=14, ) ax_labels.text( 0, -6.2, - labels[1] + ' $\epsilon_{xy}$ = ' + str(np.round(exy_1*100,2)) + '%', - fontsize = 14, + labels[1] + " $\epsilon_{xy}$ = " + str(np.round(exy_1 * 100, 2)) + "%", + fontsize=14, ) ax_labels.text( 0, -7.2, - labels[1] + ' $\phi$ = ' + str(np.round(phi_1*180/np.pi,2)) + '$^\circ$', - fontsize = 14, - + labels[1] + + " $\phi$ = " + + str(np.round(phi_1 * 180 / np.pi, 2)) + + "$^\circ$", + fontsize=14, ) - - ax_labels.set_xlim((-1,4)) - ax_labels.set_ylim((-21,1)) - - ax_labels.axis('off') - + + ax_labels.set_xlim((-1, 4)) + ax_labels.set_ylim((-21, 1)) + + ax_labels.axis("off") + if return_moire: if returnfig: return bragg_moire, fig, ax else: return bragg_moire if returnfig: - return fig, ax - - + return fig, ax From cc364907faf66d6351e10861c5991ad492c8b380 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 29 Aug 2023 13:30:43 -0700 Subject: [PATCH 028/176] subpixel alignment phase correct, part 1 --- py4DSTEM/process/phase/iterative_parallax.py | 151 ++++++++++++++++++- 1 file changed, 143 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index c2bfc8739..9fe1fbc90 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -895,7 +895,8 @@ def subpixel_alignment( xy_shifts = self._xy_shifts BF_size = np.array(self._stack_BF_no_window.shape[-2:]) - pixel_output = BF_size * kde_upsample_factor + self._kde_upsample_factor = kde_upsample_factor + pixel_output = BF_size * self._kde_upsample_factor pixel_size = pixel_output.prod() # shifted coordinates @@ -903,8 +904,8 @@ def subpixel_alignment( y = xp.arange(BF_size[1]) xa, ya = xp.meshgrid(x, y, indexing="ij") - xa = ((xa + xy_shifts[:, 0, None, None]) * kde_upsample_factor).ravel() - ya = ((ya + xy_shifts[:, 1, None, None]) * kde_upsample_factor).ravel() + xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel() # bilinear sampling xF = xp.floor(xa).astype("int") @@ -948,7 +949,7 @@ def subpixel_alignment( ) # kernel density estimate - sigma = kde_sigma * kde_upsample_factor + sigma = kde_sigma * self._kde_upsample_factor pix_count = gaussian_filter(pix_count, sigma) pix_count[pix_output == 0.0] = np.inf pix_output = gaussian_filter(pix_output, sigma) @@ -970,8 +971,12 @@ def subpixel_alignment( cmap = kwargs.pop("cmap", "magma") cropped_object = self._crop_padded_object(self._recon_BF) - upsampled_pad_x = self._object_padding_px[0] * kde_upsample_factor // 2 - upsampled_pad_y = self._object_padding_px[1] * kde_upsample_factor // 2 + upsampled_pad_x = ( + self._object_padding_px[0] * self._kde_upsample_factor // 2 + ) + upsampled_pad_y = ( + self._object_padding_px[1] * self._kde_upsample_factor // 2 + ) cropped_object_aligned = self.recon_BF_subpixel_aligned[ upsampled_pad_x:-upsampled_pad_x, upsampled_pad_y:-upsampled_pad_y, @@ -1007,8 +1012,8 @@ def subpixel_alignment( if plot_upsampled_FFT_comparison: recon_fft = xp.fft.fft2(self._recon_BF) recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) - pad_x = BF_size[0] * (kde_upsample_factor - 1) // 2 - pad_y = BF_size[1] * (kde_upsample_factor - 1) // 2 + pad_x = BF_size[0] * (self._kde_upsample_factor - 1) // 2 + pad_y = BF_size[1] * (self._kde_upsample_factor - 1) // 2 pad_recon_fft = asnumpy( xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) ) @@ -1318,6 +1323,136 @@ def aberration_correct( ax.set_xlabel("y [A]") ax.set_title("Parallax-Corrected Phase Image") + def subpixel_aberration_correct( + self, + plot_corrected_phase: bool = True, + k_info_limit: float = None, + k_info_power: float = 1.0, + Wiener_filter=False, + Wiener_signal_noise_ratio=1.0, + Wiener_filter_low_only=False, + **kwargs, + ): + """ + CTF correction of the phase image using the measured defocus aberration. + + Parameters + ---------- + plot_corrected_phase: bool, optional + If True, the CTF-corrected phase is plotted + k_info_limit: float, optional + maximum allowed frequency in butterworth filter + k_info_power: float, optional + power of butterworth filter + Wiener_filter: bool, optional + Use Wiener filtering instead of CTF sign correction. + Wiener_signal_noise_ratio: float, optional + Signal to noise radio at k = 0 for Wiener filter + Wiener_filter_low_only: bool, optional + Apply Wiener filtering only to the CTF portions before the 1st CTF maxima. + """ + + xp = self._xp + asnumpy = self._asnumpy + + if not hasattr(self, "aberration_C1"): + raise ValueError( + ( + "CTF correction is meant to be ran after alignment and aberration fitting. " + "Please run the `reconstruct()` and `aberration_fit()` functions first." + ) + ) + + # Fourier coordinates + kx = xp.fft.fftfreq( + self._recon_BF_subpixel_aligned.shape[0], + self._scan_sampling[0] / self._kde_upsample_factor, + ) + ky = xp.fft.fftfreq( + self._recon_BF_subpixel_aligned.shape[1], + self._scan_sampling[1] / self._kde_upsample_factor, + ) + kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 + + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr + print(self._recon_BF_subpixel_aligned.shape) + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) + + # Output phase image + self._recon_phase_corrected_subpixel_aligned = xp.real( + xp.fft.ifft2(im_fft_corr) + ) + self.recon_phase_corrected_subpixel_aligned = asnumpy( + self._recon_phase_corrected_subpixel_aligned + ) + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + # plotting + if plot_corrected_phase: + figsize = kwargs.pop("figsize", (6, 6)) + cmap = kwargs.pop("cmap", "magma") + + fig, ax = plt.subplots(figsize=figsize) + + cropped_object = self._crop_padded_object(self._recon_phase_corrected) + + extent = [ + 0, + self._scan_sampling[1] + / self._kde_upsample_factor + * cropped_object.shape[1], + self._scan_sampling[0] + / self._kde_upsample_factor + * cropped_object.shape[0], + 0, + ] + + ax.imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_title("Parallax-Corrected Phase Image Subpixel Aligned") + def depth_section( self, depth_angstroms=np.arange(-250, 260, 100), From 41f319d09720e2cbc2fcfb8d04dd70f351905789 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 29 Aug 2023 13:31:53 -0700 Subject: [PATCH 029/176] removing print statement --- py4DSTEM/process/phase/iterative_parallax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 9fe1fbc90..36d90a6f0 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1403,7 +1403,6 @@ def subpixel_aberration_correct( # apply correction to mean reconstructed BF image im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr - print(self._recon_BF_subpixel_aligned.shape) # if needed, add low pass filter output image if k_info_limit is not None: im_fft_corr /= 1 + (kra2**k_info_power) / ( From 95beba51cfe84ba8a85614b5a7fc7f20dfd9222f Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 7 Sep 2023 05:38:41 -0700 Subject: [PATCH 030/176] parallax plotting fix --- py4DSTEM/process/phase/iterative_parallax.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 36d90a6f0..d8824b770 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1428,7 +1428,9 @@ def subpixel_aberration_correct( fig, ax = plt.subplots(figsize=figsize) - cropped_object = self._crop_padded_object(self._recon_phase_corrected) + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) extent = [ 0, @@ -1594,6 +1596,7 @@ def _crop_padded_object( self, padded_object: np.ndarray, remaining_padding: int = 0, + upsampled: bool = False, ): """ Utility function to crop padded object @@ -1617,6 +1620,10 @@ def _crop_padded_object( pad_x = self._object_padding_px[0] // 2 - remaining_padding pad_y = self._object_padding_px[1] // 2 - remaining_padding + if upsampled == True: + pad_x *= self._kde_upsample_factor + pad_y *= self._kde_upsample_factor + return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) def _visualize_figax( From a74b583f69c4c02459efaa8db74981c86a897947 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 7 Sep 2023 20:50:11 -0700 Subject: [PATCH 031/176] reorganization and bug fix --- py4DSTEM/__init__.py | 2 +- py4DSTEM/process/__init__.py | 2 +- py4DSTEM/process/latticevectors/__init__.py | 1 - py4DSTEM/process/latticevectors/fit.py | 129 ------ py4DSTEM/process/latticevectors/index.py | 131 ------ py4DSTEM/process/latticevectors/strain.py | 231 ---------- py4DSTEM/process/strain/__init__.py | 2 + py4DSTEM/process/strain/latticevectors.py | 448 ++++++++++++++++++++ py4DSTEM/process/{ => strain}/strain.py | 152 ++++--- 9 files changed, 537 insertions(+), 561 deletions(-) delete mode 100644 py4DSTEM/process/latticevectors/strain.py create mode 100644 py4DSTEM/process/strain/__init__.py create mode 100644 py4DSTEM/process/strain/latticevectors.py rename py4DSTEM/process/{ => strain}/strain.py (86%) diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index dcb6a861d..adf757d1b 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -53,7 +53,7 @@ ) # strain -from py4DSTEM.process import StrainMap +from py4DSTEM.process.strain.strain import StrainMap # TODO - crystal # TODO - ptycho diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index e711e907d..0df11ef01 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,5 +1,5 @@ from py4DSTEM.process.polar import PolarDatacube -from py4DSTEM.process.strain import StrainMap +from py4DSTEM.process.strain.strain import StrainMap from py4DSTEM.process import latticevectors from py4DSTEM.process import phase diff --git a/py4DSTEM/process/latticevectors/__init__.py b/py4DSTEM/process/latticevectors/__init__.py index 560a3b7e6..cda4f91e5 100644 --- a/py4DSTEM/process/latticevectors/__init__.py +++ b/py4DSTEM/process/latticevectors/__init__.py @@ -1,4 +1,3 @@ from py4DSTEM.process.latticevectors.initialguess import * from py4DSTEM.process.latticevectors.index import * from py4DSTEM.process.latticevectors.fit import * -from py4DSTEM.process.latticevectors.strain import * diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py index 659bc8940..d36b10bca 100644 --- a/py4DSTEM/process/latticevectors/fit.py +++ b/py4DSTEM/process/latticevectors/fit.py @@ -7,135 +7,6 @@ from py4DSTEM.data import RealSlice -def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (7-tuple) A 7-tuple containing: - - * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. - * **y0**: *(float)* the y-coord of the origin - * **g1x**: *(float)* x-coord of the first lattice vector - * **g1y**: *(float)* y-coord of the first lattice vector - * **g2x**: *(float)* x-coord of the second lattice vector - * **g2y**: *(float)* y-coord of the second lattice vector - * **error**: *(float)* the fit error - """ - assert isinstance(braggpeaks, PointList) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - braggpeaks = braggpeaks.copy() - - # Remove unindexed peaks - if "index_mask" in braggpeaks.dtype.names: - deletemask = braggpeaks.data["index_mask"] == False - braggpeaks.remove(deletemask) - - # Check to ensure enough peaks are present - if braggpeaks.length < minNumPeaks: - return None, None, None, None, None, None, None - - # Get M, the matrix of (h,k) indices - h, k = braggpeaks.data["h"], braggpeaks.data["k"] - M = np.vstack((np.ones_like(h, dtype=int), h, k)).T - - # Get alpha, the matrix of measured Bragg peak positions - alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T - - # Get weighted matrices - weights = braggpeaks.data["intensity"] - weighted_M = M * weights[:, np.newaxis] - weighted_alpha = alpha * weights[:, np.newaxis] - - # Solve for lattice vectors - beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] - x0, y0 = beta[0, 0], beta[0, 1] - g1x, g1y = beta[1, 0], beta[1, 1] - g2x, g2y = beta[2, 0], beta[2, 1] - - # Calculate the error - alpha_calculated = np.matmul(M, beta) - error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) - error = np.sum(error * weights) / np.sum(weights) - - return x0, y0, g1x, g1y, g2x, g2y, error - - -def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some - known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - - return g1g2_map - - def fit_lattice_vectors_masked(braggpeaks, mask, x0=0, y0=0, minNumPeaks=5): """ Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks corresponding diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py index 4ac7939e7..2d243cd0c 100644 --- a/py4DSTEM/process/latticevectors/index.py +++ b/py4DSTEM/process/latticevectors/index.py @@ -34,61 +34,6 @@ def get_selected_lattice_vectors(gx, gy, i0, i1, i2): return (g1x, g1y), (g2x, g2y) -def index_bragg_directions(x0, y0, gx, gy, g1, g2): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - The approach is to solve the matrix equation - ``alpha = beta * M`` - where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, - beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the - h,k indices. - - Args: - x0 (float): x-coord of origin - y0 (float): y-coord of origin - gx (1d array): x-coord of the reciprocal lattice vectors - gy (1d array): y-coord of the reciprocal lattice vectors - g1 (2-tuple of floats): g1x,g1y - g2 (2-tuple of floats): g2x,g2y - - Returns: - (3-tuple) A 3-tuple containing: - - * **h**: *(ndarray of ints)* first index of the bragg directions - * **k**: *(ndarray of ints)* second index of the bragg directions - * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the - indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y - coords 'h' and 'k' contain h and k. - """ - # Get beta, the matrix of lattice vectors - beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) - - # Get alpha, the matrix of measured bragg angles - alpha = np.vstack([gx - x0, gy - y0]) - - # Calculate M, the matrix of peak positions - M = lstsq(beta, alpha, rcond=None)[0].T - M = np.round(M).astype(int) - - # Get h,k - h = M[:, 0] - k = M[:, 1] - - # Store in a PointList - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - temp_array = np.zeros([], dtype=coords) - bragg_directions = PointList(data=temp_array) - bragg_directions.add_data_by_field((gx, gy, h, k)) - mask = np.zeros(bragg_directions["qx"].shape[0]) - mask[0] = 1 - bragg_directions.remove(mask) - - return h, k, bragg_directions - - def generate_lattice(ux, uy, vx, vy, x0, y0, Q_Nx, Q_Ny, h_max=None, k_max=None): """ Returns a full reciprocal lattice stretching to the limits of the diffraction pattern @@ -163,82 +108,6 @@ def generate_lattice(ux, uy, vx, vy, x0, y0, Q_Nx, Q_Ny, h_max=None, k_max=None) return ideal_lattice -def add_indices_to_braggvectors( - braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None -): - """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, - identify the indices for each peak in the PointListArray braggpeaks. - Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed - or not with the bool index_mask. If `mask` is specified, only the locations where - mask is True are indexed. - - Args: - braggpeaks (PointListArray): the braggpeaks to index. Must contain - the coordinates 'qx', 'qy', and 'intensity' - lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. - Must contain the coordinates 'qx', 'qy', 'h', and 'k' - maxPeakSpacing (float): Maximum distance from the ideal lattice points - to include a peak for indexing - qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList - relative to the `braggpeaks` PointListArray - mask (bool): Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - - Returns: - (PointListArray): The original braggpeaks pointlistarray, with new coordinates - 'h', 'k', containing the indices of each indexable peak. - """ - - # assert isinstance(braggpeaks,BraggVectors) - # assert isinstance(lattice, PointList) - # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) - - if mask is None: - mask = np.ones(braggpeaks.Rshape, dtype=bool) - - assert ( - mask.shape == braggpeaks.Rshape - ), "mask must have same shape as pointlistarray" - assert mask.dtype == bool, "mask must be boolean" - - coords = [ - ("qx", float), - ("qy", float), - ("intensity", float), - ("h", int), - ("k", int), - ] - - indexed_braggpeaks = PointListArray( - dtype=coords, - shape=braggpeaks.Rshape, - ) - - # loop over all the scan positions - for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): - if mask[Rx, Ry]: - pl = braggpeaks.cal[Rx, Ry] - for i in range(pl.data.shape[0]): - r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( - pl.data["qy"][i] - lattice.data["qy"] + qy_shift - ) ** 2 - ind = np.argmin(r2) - if r2[ind] <= maxPeakSpacing**2: - indexed_braggpeaks[Rx, Ry].add_data_by_field( - ( - pl.data["qx"][i], - pl.data["qy"][i], - pl.data["intensity"][i], - lattice.data["h"][ind], - lattice.data["k"][ind], - ) - ) - - return indexed_braggpeaks def bragg_vector_intensity_map_by_index(braggpeaks, h, k, symmetric=False): diff --git a/py4DSTEM/process/latticevectors/strain.py b/py4DSTEM/process/latticevectors/strain.py deleted file mode 100644 index 6f4000449..000000000 --- a/py4DSTEM/process/latticevectors/strain.py +++ /dev/null @@ -1,231 +0,0 @@ -# Functions for calculating strain from lattice vector maps - -import numpy as np -from numpy.linalg import lstsq - -from py4DSTEM.data import RealSlice - - -def get_reference_g1g2(g1g2_map, mask): - """ - Gets a pair of reference lattice vectors from a region of real space specified by - mask. Takes the median of the lattice vectors in g1g2_map within the specified - region. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever - mask==True - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing: - - * **g1**: *(2-tuple)* first reference lattice vector (x,y) - * **g2**: *(2-tuple)* second reference lattice vector (x,y) - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] - ) - assert mask.dtype == bool - g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) - g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) - g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) - g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) - return (g1x, g1y), (g2x, g2y) - - -def get_strain_from_reference_g1g2(g1g2_map, g1, g2): - """ - Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map - g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real and - diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - g1 (2-tuple): first reference lattice vector (x,y) - g2 (2-tuple): second reference lattice vector (x,y) - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - - # Get RealSlice for output storage - R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape - strain_map = RealSlice( - data=np.zeros((5, R_Nx, R_Ny)), - slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), - name="strain_map", - ) - - # Get reference lattice matrix - g1x, g1y = g1 - g2x, g2y = g2 - M = np.array([[g1x, g1y], [g2x, g2y]]) - - for Rx in range(R_Nx): - for Ry in range(R_Ny): - # Get lattice vectors for DP at Rx,Ry - alpha = np.array( - [ - [ - g1g2_map.get_slice("g1x").data[Rx, Ry], - g1g2_map.get_slice("g1y").data[Rx, Ry], - ], - [ - g1g2_map.get_slice("g2x").data[Rx, Ry], - g1g2_map.get_slice("g2y").data[Rx, Ry], - ], - ] - ) - # Get transformation matrix - beta = lstsq(M, alpha, rcond=None)[0].T - - # Get the infinitesimal strain matrix - strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] - strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] - strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 - strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 - strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ - Rx, Ry - ] - return strain_map - - -def get_strain_from_reference_region(g1g2_map, mask): - """ - Gets a strain map from the reference region of real space specified by mask and the - lattice vector map g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real - and diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - assert mask.dtype == bool - - g1, g2 = get_reference_g1g2(g1g2_map, mask) - strain_map = get_strain_from_reference_g1g2(g1g2_map, g1, g2) - return strain_map - - -def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): - """ - Starting from a strain map defined with respect to the xy coordinate system of - diffraction space, i.e. where exx and eyy are the compression/tension along the Qx - and Qy directions, respectively, get a strain map defined with respect to some other - right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, - xaxis_y). - - Args: - xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector - along the new x-axis - unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the - infinitessimal strain matrix elements, stored at - * unrotated_strain_map.get_slice('e_xx') - * unrotated_strain_map.get_slice('e_xy') - * unrotated_strain_map.get_slice('e_yy') - * unrotated_strain_map.get_slice('theta') - - Returns: - (RealSlice) the rotated counterpart to unrotated_strain_map, with the - rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate - system - """ - assert isinstance(unrotated_strain_map, RealSlice) - assert np.all( - [ - key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] - for key in unrotated_strain_map.slicelabels - ] - ) - theta = -np.arctan2(xaxis_y, xaxis_x) - cost = np.cos(theta) - sint = np.sin(theta) - cost2 = cost**2 - sint2 = sint**2 - - Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape - rotated_strain_map = RealSlice( - data=np.zeros((5, Rx, Ry)), - slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], - name=unrotated_strain_map.name + "_rotated".format(np.degrees(theta)), - ) - - rotated_strain_map.data[0, :, :] = ( - cost2 * unrotated_strain_map.get_slice("e_xx").data - - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + sint2 * unrotated_strain_map.get_slice("e_yy").data - ) - rotated_strain_map.data[1, :, :] = ( - cost - * sint - * ( - unrotated_strain_map.get_slice("e_xx").data - - unrotated_strain_map.get_slice("e_yy").data - ) - + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data - ) - rotated_strain_map.data[2, :, :] = ( - sint2 * unrotated_strain_map.get_slice("e_xx").data - + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + cost2 * unrotated_strain_map.get_slice("e_yy").data - ) - if flip_theta == True: - rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data - else: - rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data - rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data - return rotated_strain_map diff --git a/py4DSTEM/process/strain/__init__.py b/py4DSTEM/process/strain/__init__.py new file mode 100644 index 000000000..b47682aa4 --- /dev/null +++ b/py4DSTEM/process/strain/__init__.py @@ -0,0 +1,2 @@ +from py4DSTEM.process.strain.strain import StrainMap +from py4DSTEM.process.strain.latticevectors import * diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py new file mode 100644 index 000000000..90f7f938d --- /dev/null +++ b/py4DSTEM/process/strain/latticevectors.py @@ -0,0 +1,448 @@ +# Functions for indexing the Bragg directions + +import numpy as np +from emdfile import PointList, PointListArray, tqdmnd +from numpy.linalg import lstsq +from py4DSTEM.data import RealSlice + + +def index_bragg_directions(x0, y0, gx, gy, g1, g2): + """ + From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of + lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the + reciprocal lattice directions. + + The approach is to solve the matrix equation + ``alpha = beta * M`` + where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, + beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the + h,k indices. + + Args: + x0 (float): x-coord of origin + y0 (float): y-coord of origin + gx (1d array): x-coord of the reciprocal lattice vectors + gy (1d array): y-coord of the reciprocal lattice vectors + g1 (2-tuple of floats): g1x,g1y + g2 (2-tuple of floats): g2x,g2y + + Returns: + (3-tuple) A 3-tuple containing: + + * **h**: *(ndarray of ints)* first index of the bragg directions + * **k**: *(ndarray of ints)* second index of the bragg directions + * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the + indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y + coords 'h' and 'k' contain h and k. + """ + # Get beta, the matrix of lattice vectors + beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) + + # Get alpha, the matrix of measured bragg angles + alpha = np.vstack([gx - x0, gy - y0]) + + # Calculate M, the matrix of peak positions + M = lstsq(beta, alpha, rcond=None)[0].T + M = np.round(M).astype(int) + + # Get h,k + h = M[:, 0] + k = M[:, 1] + + # Store in a PointList + coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] + temp_array = np.zeros([], dtype=coords) + bragg_directions = PointList(data=temp_array) + bragg_directions.add_data_by_field((gx, gy, h, k)) + mask = np.zeros(bragg_directions["qx"].shape[0]) + mask[0] = 1 + bragg_directions.remove(mask) + + return h, k, bragg_directions + + +def add_indices_to_braggvectors( + braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None +): + """ + Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, + identify the indices for each peak in the PointListArray braggpeaks. + Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus + three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak + indices with the ints (h,k) and indicating whether the peak was successfully indexed + or not with the bool index_mask. If `mask` is specified, only the locations where + mask is True are indexed. + + Args: + braggpeaks (PointListArray): the braggpeaks to index. Must contain + the coordinates 'qx', 'qy', and 'intensity' + lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. + Must contain the coordinates 'qx', 'qy', 'h', and 'k' + maxPeakSpacing (float): Maximum distance from the ideal lattice points + to include a peak for indexing + qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList + relative to the `braggpeaks` PointListArray + mask (bool): Boolean mask, same shape as the pointlistarray, indicating which + locations should be indexed. This can be used to index different regions of + the scan with different lattices + + Returns: + (PointListArray): The original braggpeaks pointlistarray, with new coordinates + 'h', 'k', containing the indices of each indexable peak. + """ + + # assert isinstance(braggpeaks,BraggVectors) + # assert isinstance(lattice, PointList) + # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) + + if mask is None: + mask = np.ones(braggpeaks.Rshape, dtype=bool) + + assert ( + mask.shape == braggpeaks.Rshape + ), "mask must have same shape as pointlistarray" + assert mask.dtype == bool, "mask must be boolean" + + coords = [ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ] + + indexed_braggpeaks = PointListArray( + dtype=coords, + shape=braggpeaks.Rshape, + ) + + # loop over all the scan positions + for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + if mask[Rx, Ry]: + pl = braggpeaks.cal[Rx, Ry] + for i in range(pl.data.shape[0]): + r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( + pl.data["qy"][i] - lattice.data["qy"] + qy_shift + ) ** 2 + ind = np.argmin(r2) + if r2[ind] <= maxPeakSpacing**2: + indexed_braggpeaks[Rx, Ry].add_data_by_field( + ( + pl.data["qx"][i], + pl.data["qy"][i], + pl.data["intensity"][i], + lattice.data["h"][ind], + lattice.data["k"][ind], + ) + ) + + return indexed_braggpeaks + + +def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. + x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (7-tuple) A 7-tuple containing: + + * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. + * **y0**: *(float)* the y-coord of the origin + * **g1x**: *(float)* x-coord of the first lattice vector + * **g1y**: *(float)* y-coord of the first lattice vector + * **g2x**: *(float)* x-coord of the second lattice vector + * **g2y**: *(float)* y-coord of the second lattice vector + * **error**: *(float)* the fit error + """ + assert isinstance(braggpeaks, PointList) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + braggpeaks = braggpeaks.copy() + + # Remove unindexed peaks + if "index_mask" in braggpeaks.dtype.names: + deletemask = braggpeaks.data["index_mask"] == False + braggpeaks.remove(deletemask) + + # Check to ensure enough peaks are present + if braggpeaks.length < minNumPeaks: + return None, None, None, None, None, None, None + + # Get M, the matrix of (h,k) indices + h, k = braggpeaks.data["h"], braggpeaks.data["k"] + M = np.vstack((np.ones_like(h, dtype=int), h, k)).T + + # Get alpha, the matrix of measured Bragg peak positions + alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T + + # Get weighted matrices + weights = braggpeaks.data["intensity"] + weighted_M = M * weights[:, np.newaxis] + weighted_alpha = alpha * weights[:, np.newaxis] + + # Solve for lattice vectors + beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] + x0, y0 = beta[0, 0], beta[0, 1] + g1x, g1y = beta[1, 0], beta[1, 1] + g2x, g2y = beta[2, 0], beta[2, 1] + + # Calculate the error + alpha_calculated = np.matmul(M, beta) + error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) + error = np.sum(error * weights) / np.sum(weights) + + return x0, y0, g1x, g1y, g2x, g2y, error + + +def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some + known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. + x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: + + * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice + * ``g1g2_map.get_slice('y0')`` y-coord of the origin + * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector + * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector + * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector + * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector + * ``g1g2_map.get_slice('error')`` the fit error + * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful + fits + """ + assert isinstance(braggpeaks, PointListArray) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + + # Make RealSlice to contain outputs + slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") + g1g2_map = RealSlice( + data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), + slicelabels=slicelabels, + name="g1g2_map", + ) + + # Fit lattice vectors + for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): + braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) + qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( + braggpeaks_curr, x0, y0, minNumPeaks + ) + # Store data + if g1x is not None: + g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 + g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 + g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x + g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y + g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x + g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y + g1g2_map.get_slice("error").data[Rx, Ry] = error + g1g2_map.get_slice("mask").data[Rx, Ry] = 1 + + return g1g2_map + + +def get_reference_g1g2(g1g2_map, mask): + """ + Gets a pair of reference lattice vectors from a region of real space specified by + mask. Takes the median of the lattice vectors in g1g2_map within the specified + region. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever + mask==True + + Returns: + (2-tuple of 2-tuples) A 2-tuple containing: + + * **g1**: *(2-tuple)* first reference lattice vector (x,y) + * **g2**: *(2-tuple)* second reference lattice vector (x,y) + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] + ) + assert mask.dtype == bool + g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) + g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) + g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) + g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) + return (g1x, g1y), (g2x, g2y) + + +def get_strain_from_reference_g1g2(g1g2_map, g1, g2): + """ + Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map + g1g2_map. + + Note that this function will return the strain map oriented with respect to the x/y + axes of diffraction space - to rotate the coordinate system, use + get_rotated_strain_map(). Calibration of the rotational misalignment between real and + diffraction space may also be necessary. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + g1 (2-tuple): first reference lattice vector (x,y) + g2 (2-tuple): second reference lattice vector (x,y) + + Returns: + (RealSlice) the strain map; contains the elements of the infinitessimal strain + matrix, in the following 5 arrays: + + * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect + to x + * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect + to y + * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect + to y + * ``strain_map.get_slice('theta')``: rotation of lattice with respect to + reference + * ``strain_map.get_slice('mask')``: 0/False indicates unknown values + + Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] + ) + + # Get RealSlice for output storage + R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape + strain_map = RealSlice( + data=np.zeros((5, R_Nx, R_Ny)), + slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), + name="strain_map", + ) + + # Get reference lattice matrix + g1x, g1y = g1 + g2x, g2y = g2 + M = np.array([[g1x, g1y], [g2x, g2y]]) + + for Rx in range(R_Nx): + for Ry in range(R_Ny): + # Get lattice vectors for DP at Rx,Ry + alpha = np.array( + [ + [ + g1g2_map.get_slice("g1x").data[Rx, Ry], + g1g2_map.get_slice("g1y").data[Rx, Ry], + ], + [ + g1g2_map.get_slice("g2x").data[Rx, Ry], + g1g2_map.get_slice("g2y").data[Rx, Ry], + ], + ] + ) + # Get transformation matrix + beta = lstsq(M, alpha, rcond=None)[0].T + + # Get the infinitesimal strain matrix + strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] + strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] + strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 + strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 + strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ + Rx, Ry + ] + return strain_map + +def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): + """ + Starting from a strain map defined with respect to the xy coordinate system of + diffraction space, i.e. where exx and eyy are the compression/tension along the Qx + and Qy directions, respectively, get a strain map defined with respect to some other + right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, + xaxis_y). + + Args: + xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector + along the new x-axis + unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the + infinitessimal strain matrix elements, stored at + * unrotated_strain_map.get_slice('e_xx') + * unrotated_strain_map.get_slice('e_xy') + * unrotated_strain_map.get_slice('e_yy') + * unrotated_strain_map.get_slice('theta') + + Returns: + (RealSlice) the rotated counterpart to unrotated_strain_map, with the + rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate + system + """ + assert isinstance(unrotated_strain_map, RealSlice) + assert np.all( + [ + key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] + for key in unrotated_strain_map.slicelabels + ] + ) + theta = -np.arctan2(xaxis_y, xaxis_x) + cost = np.cos(theta) + sint = np.sin(theta) + cost2 = cost**2 + sint2 = sint**2 + + Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape + rotated_strain_map = RealSlice( + data=np.zeros((5, Rx, Ry)), + slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], + name=unrotated_strain_map.name + "_rotated".format(np.degrees(theta)), + ) + + rotated_strain_map.data[0, :, :] = ( + cost2 * unrotated_strain_map.get_slice("e_xx").data + - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + sint2 * unrotated_strain_map.get_slice("e_yy").data + ) + rotated_strain_map.data[1, :, :] = ( + cost + * sint + * ( + unrotated_strain_map.get_slice("e_xx").data + - unrotated_strain_map.get_slice("e_yy").data + ) + + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data + ) + rotated_strain_map.data[2, :, :] = ( + sint2 * unrotated_strain_map.get_slice("e_xx").data + + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + cost2 * unrotated_strain_map.get_slice("e_yy").data + ) + if flip_theta == True: + rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data + else: + rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data + rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data + return rotated_strain_map diff --git a/py4DSTEM/process/strain.py b/py4DSTEM/process/strain/strain.py similarity index 86% rename from py4DSTEM/process/strain.py rename to py4DSTEM/process/strain/strain.py index db252f75b..751016a89 100644 --- a/py4DSTEM/process/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -1,5 +1,6 @@ # Defines the Strain class +import warnings from typing import Optional import matplotlib.pyplot as plt @@ -8,20 +9,33 @@ from py4DSTEM.braggvectors import BraggVectors from py4DSTEM.data import Data, RealSlice from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.process.strain.latticevectors import ( + add_indices_to_braggvectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_rotated_strain_map, + get_strain_from_reference_g1g2, + index_bragg_directions, +) from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show +warnings.simplefilter(action="always", category=UserWarning) + class StrainMap(RealSlice, Data): """ - Stores strain map. - - TODO add docs - + Storage and processing methods for 4D-STEM datasets. + """ def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): + """ - TODO + Accepts: + braggvectors (BraggVectors): BraggVectors for Strain Map + name (str): the name of the strainmap + Returns: + A new StrainMap instance. """ assert isinstance( braggvectors, BraggVectors @@ -58,6 +72,12 @@ def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap" # re-calibration are issued self.calstate = self.braggvectors.calstate assert self.calstate["center"], "braggvectors must be centered" + if self.calstate["rotate"] == False: + warnings.warn( + ("Real to reciprocal space rotaiton not calibrated"), + UserWarning, + ) + # get the BVM # a new BVM using the current calstate is computed self.bvm = self.braggvectors.histogram(mode="cal") @@ -110,16 +130,18 @@ def choose_lattice_vectors( minSpacing=0, edgeBoundary=1, maxNumPeaks=10, - figsize=(12, 6), + x0=None, + y0=None, + figsize=(14, 9), c_indices="lightblue", c0="g", c1="r", c2="r", c_vectors="r", c_vectorlabels="w", - size_indices=20, + size_indices=15, width_vectors=1, - size_vectorlabels=20, + size_vectorlabels=15, vis_params={}, returncalc=False, returnfig=False, @@ -198,6 +220,7 @@ def choose_lattice_vectors( ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." # find the maxima + g = get_maxima_2D( self.bvm.data, subpixel=subpixel, @@ -220,10 +243,34 @@ def choose_lattice_vectors( g2y = gy[index_g2] - g0[1] g1, g2 = (g1x, g1y), (g2x, g2y) + # if x0 is None: + # x0 = self.braggvectors.Qshape[0] / 2 + # if y0 is None: + # y0 = self.braggvectors.Qshape[0] / 2 + + # index braggvectors + # _, _, braggdirections = index_bragg_directions( + # x0, y0, g["x"], g["y"], g1, g2 + # ) + + _, _, braggdirections = index_bragg_directions( + g0[0], g0[1], g["x"], g["y"], g1, g2 + ) + + self.braggdirections = braggdirections + # make the figure - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - show(self.bvm.data, figax=(fig, ax1), **vis_params) - show(self.bvm.data, figax=(fig, ax2), **vis_params) + fig, ax = plt.subplots(1, 3, figsize=figsize) + show(self.bvm.data, figax=(fig, ax[0]), **vis_params) + show(self.bvm.data, figax=(fig, ax[1]), **vis_params) + self.show_bragg_indexing( + self.bvm.data, + bragg_directions=braggdirections, + points=True, + figax=(fig, ax[2]), + size=size_indices, + **vis_params, + ) # Add indices to left panel d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} @@ -251,10 +298,10 @@ def choose_lattice_vectors( "fontweight": "bold", "labels": [str(index_g2)], } - add_pointlabels(ax1, d) - add_pointlabels(ax1, d0) - add_pointlabels(ax1, d1) - add_pointlabels(ax1, d2) + add_pointlabels(ax[0], d) + add_pointlabels(ax[0], d0) + add_pointlabels(ax[0], d1) + add_pointlabels(ax[0], d2) # Add vectors to right panel dg1 = { @@ -279,8 +326,8 @@ def choose_lattice_vectors( "labelsize": size_vectorlabels, "labelcolor": c_vectorlabels, } - add_vector(ax2, dg1) - add_vector(ax2, dg2) + add_vector(ax[1], dg1) + add_vector(ax[1], dg2) # store vectors self.g = g @@ -290,18 +337,16 @@ def choose_lattice_vectors( # return if returncalc and returnfig: - return (g0, g1, g2), (fig, (ax1, ax2)) + return (g0, g1, g2), (fig, ax) elif returncalc: return (g0, g1, g2) elif returnfig: - return (fig, (ax1, ax2)) + return (fig, ax) else: return def fit_lattice_vectors( self, - x0=None, - y0=None, max_peak_spacing=2, mask=None, plot=True, @@ -337,31 +382,6 @@ def fit_lattice_vectors( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - if x0 is None: - x0 = self.braggvectors.Qshape[0] / 2 - if y0 is None: - y0 = self.braggvectors.Qshape[0] / 2 - - # index braggvectors - from py4DSTEM.process.latticevectors import index_bragg_directions - - _, _, braggdirections = index_bragg_directions( - x0, y0, self.g["x"], self.g["y"], self.g1, self.g2 - ) - - self.braggdirections = braggdirections - - if plot: - self.show_bragg_indexing( - self.bvm, - bragg_directions=braggdirections, - points=True, - **vis_params, - ) - - # add indicies to braggvectors - from py4DSTEM.process.latticevectors import add_indices_to_braggvectors - bragg_vectors_indexed = add_indices_to_braggvectors( self.braggvectors, self.braggdirections, @@ -374,13 +394,11 @@ def fit_lattice_vectors( self.bragg_vectors_indexed = bragg_vectors_indexed # fit bragg vectors - from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs - g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) self.g1g2_map = g1g2_map if returncalc: - braggdirections, bragg_vectors_indexed, g1g2_map + self.braggdirections, self.bragg_vectors_indexed, self.g1g2_map def get_strain( self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs @@ -407,30 +425,24 @@ def get_strain( if mask is None: mask = np.ones(self.g1g2_map.shape, dtype="bool") - from py4DSTEM.process.latticevectors import get_strain_from_reference_region + # strainmap_g1g2 = get_strain_from_reference_region( + # self.g1g2_map, + # mask=mask, + # ) - strainmap_g1g2 = get_strain_from_reference_region( - self.g1g2_map, - mask=mask, - ) - else: - from py4DSTEM.process.latticevectors import get_reference_g1g2 + # g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + # strain_map = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) + # else: - g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) - from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 - - strainmap_g1g2 = get_strain_from_reference_g1g2( - self.g1g2_map, g1_ref, g2_ref - ) + strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) self.strainmap_g1g2 = strainmap_g1g2 if g_reference is None: g_reference = np.subtract(self.g1, self.g2) - from py4DSTEM.process.latticevectors import get_rotated_strain_map - strainmap_rotated = get_rotated_strain_map( self.strainmap_g1g2, xaxis_x=g_reference[0], @@ -529,6 +541,7 @@ def show_bragg_indexing( points=True, pointcolor="r", pointsize=50, + figax=None, returnfig=False, **kwargs, ): @@ -544,7 +557,13 @@ def show_bragg_indexing( for k in ("qx", "qy", "h", "k"): assert k in bragg_directions.data.dtype.fields - fig, ax = show(ar, returnfig=True, **kwargs) + if figax is None: + fig, ax = show(ar, returnfig=True, **kwargs) + else: + fig = figax[0] + ax = figax[1] + show(ar, figax=figax, **kwargs) + d = { "bragg_directions": bragg_directions, "voffset": voffset, @@ -560,7 +579,6 @@ def show_bragg_indexing( if returnfig: return fig, ax else: - plt.show() return def copy(self, name=None): @@ -585,7 +603,7 @@ def copy(self, name=None): strainmap_copy.metadata = self.metadata[k].copy() return strainmap_copy - # IO methods + # TODO IO methods # read @classmethod From c25e24264249da686a3ad8d14f01cb098aee25ea Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 08:58:06 -0700 Subject: [PATCH 032/176] generalizing overlap tomo to orientation matrices --- .../phase/iterative_overlap_tomography.py | 107 +++++++++++------- 1 file changed, 66 insertions(+), 41 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index d6bee12fd..98bfb7b5f 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -55,8 +55,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): The electron energy of the wave functions in eV num_slices: int Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of tilt angles in degrees, + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional @@ -94,13 +94,18 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") + _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") + _swap_zxy_to_xyz = np.array([ + [0,1,0], + [0,0,1], + [1,0,0] + ]) def __init__( self, energy: float, num_slices: int, - tilt_angles_deg: Sequence[float], + tilt_orientation_matrices: Sequence[np.ndarray], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None, @@ -122,22 +127,24 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom + from scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from scipy.special import erf self._erf = erf elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom + from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from cupyx.scipy.special import erf self._erf = erf @@ -156,7 +163,7 @@ def __init__( polar_parameters.update(kwargs) self._set_polar_parameters(polar_parameters) - num_tilts = len(tilt_angles_deg) + num_tilts = len(tilt_orientation_matrices) if initial_scan_positions is None: initial_scan_positions = [None] * num_tilts @@ -185,7 +192,7 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) self._num_tilts = num_tilts def _precompute_propagator_arrays( @@ -323,6 +330,30 @@ def _expand_sliced_object(self, array: np.ndarray, output_z): normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + ): + """ + """ + + xp = self._xp + affine_transform = self._affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T@rot_matrix.T@swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume,tf,offset=offset,order=3) + + return volume + def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -663,15 +694,15 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - current_angle_deg = self._tilt_angles_deg[tilt_index] - probe_overlap_3D = self._rotate( + + rot_matrix = self._tilt_orientation_matrices[tilt_index] + + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, + rot_matrix@old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -691,14 +722,12 @@ def preprocess( ) probe_overlap_3D += probe_overlap[None] - - probe_overlap_3D = self._rotate( - probe_overlap_3D, - -current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, - ) + old_rot_matrix = rot_matrix + + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( @@ -2018,17 +2047,17 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices) + old_rot_matrix = np.eye(3) # identity + for tilt_index in tilt_indices: self._active_tilt_index = tilt_index tilt_error = 0.0 - self._object = self._rotate( + rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] + self._object = self._rotate_zxy_volume( self._object, - self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + rot_matrix @ old_rot_matrix.T, ) object_sliced = self._project_sliced_object( @@ -2132,24 +2161,15 @@ def reconstruct( ) if collective_tilt_updates: - collective_object += self._rotate( + collective_object += self._rotate_zxy_volume( object_update, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + rot_matrix.T ) else: self._object += object_update - - self._object = self._rotate( - self._object, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, - ) - + + old_rot_matrix = rot_matrix + # Normalize Error tilt_error /= ( self._mean_diffraction_intensity[self._active_tilt_index] @@ -2205,6 +2225,11 @@ def reconstruct( else None, ) + self._object = self._rotate_zxy_volume( + self._object, + old_rot_matrix.T + ) + # Normalize Error Over Tilts error /= self._num_tilts From 31b5525b25bb7c07b15a18e20c918b4edb7ec8d5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 09:03:45 -0700 Subject: [PATCH 033/176] black formatting --- .../iterative_overlap_magnetic_tomography.py | 6 +- .../phase/iterative_overlap_tomography.py | 61 +++++++++---------- .../iterative_ptychographic_constraints.py | 2 +- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 2642b7193..712a35647 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -2646,7 +2646,11 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - (self._object, self._probe, _,) = self._constraints( + ( + self._object, + self._probe, + _, + ) = self._constraints( self._object, self._probe, None, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index e8cde608e..110a547b6 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -96,11 +96,7 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): # Class-specific Metadata _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") - _swap_zxy_to_xyz = np.array([ - [0,1,0], - [0,0,1], - [1,0,0] - ]) + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) def __init__( self, @@ -140,7 +136,12 @@ def __init__( elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform + from cupyx.scipy.ndimage import ( + gaussian_filter, + rotate, + zoom, + affine_transform, + ) self._gaussian_filter = gaussian_filter self._zoom = zoom @@ -335,24 +336,23 @@ def _rotate_zxy_volume( self, volume_array, rot_matrix, - ): - """ - """ - + ): + """ """ + xp = self._xp affine_transform = self._affine_transform swap_zxy_to_xyz = self._swap_zxy_to_xyz - + volume = volume_array.copy() volume_shape = xp.asarray(volume.shape) - tf = xp.asarray(swap_zxy_to_xyz.T@rot_matrix.T@swap_zxy_to_xyz) - + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + in_center = (volume_shape - 1) / 2 out_center = tf @ in_center offset = in_center - out_center - - volume = affine_transform(volume,tf,offset=offset,order=3) - + + volume = affine_transform(volume, tf, offset=offset, order=3) + return volume def preprocess( @@ -695,15 +695,14 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) - old_rot_matrix = np.eye(3) # identity + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - rot_matrix = self._tilt_orientation_matrices[tilt_index] probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - rot_matrix@old_rot_matrix.T, + rot_matrix @ old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -724,7 +723,7 @@ def preprocess( probe_overlap_3D += probe_overlap[None] old_rot_matrix = rot_matrix - + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, old_rot_matrix.T, @@ -2181,8 +2180,8 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices) - old_rot_matrix = np.eye(3) # identity - + old_rot_matrix = np.eye(3) # identity + for tilt_index in tilt_indices: self._active_tilt_index = tilt_index @@ -2296,14 +2295,13 @@ def reconstruct( if collective_tilt_updates: collective_object += self._rotate_zxy_volume( - object_update, - rot_matrix.T + object_update, rot_matrix.T ) else: self._object += object_update - + old_rot_matrix = rot_matrix - + # Normalize Error tilt_error /= ( self._mean_diffraction_intensity[self._active_tilt_index] @@ -2363,10 +2361,7 @@ def reconstruct( tv_denoise_inner_iter=tv_denoise_inner_iter, ) - self._object = self._rotate_zxy_volume( - self._object, - old_rot_matrix.T - ) + self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) # Normalize Error Over Tilts error /= self._num_tilts @@ -2374,7 +2369,11 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - (self._object, self._probe, _,) = self._constraints( + ( + self._object, + self._probe, + _, + ) = self._constraints( self._object, self._probe, None, diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 217253945..ba9f28332 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -197,7 +197,7 @@ def _object_denoise_tv_pylops(self, current_object, weight, iterations): Denoising weight. The greater `weight`, the more denoising (at the expense of fidelity to `input`). iterations: float - Number of iterations to run in denoising algorithm. + Number of iterations to run in denoising algorithm. `niter_out` in pylops Returns From d4363f4eb081b3deb54ce559d3987caf7967782f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 09:30:35 -0700 Subject: [PATCH 034/176] flake8 6.1.0 found some more issues --- py4DSTEM/process/phase/iterative_multislice_ptychography.py | 4 ++-- py4DSTEM/process/phase/iterative_overlap_tomography.py | 4 ++-- py4DSTEM/process/phase/iterative_parallax.py | 2 +- py4DSTEM/process/phase/iterative_ptychographic_constraints.py | 3 ++- py4DSTEM/process/phase/utils.py | 4 ++-- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index d9bf6bacf..3e36978e4 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2254,7 +2254,7 @@ def reconstruct( and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma[a0] if kz_regularization_gamma is not None - and type(kz_regularization_gamma) == np.ndarray + and isinstance(kz_regularization_gamma, np.ndarray) else kz_regularization_gamma, identical_slices=a0 < identical_slices_iter, object_positivity=object_positivity, @@ -3068,7 +3068,7 @@ def show_depth( 0, ] - if plot_line_profile == False: + if not plot_line_profile: fig, ax = plt.subplots() im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 110a547b6..fd9af0bb2 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -124,7 +124,7 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom, affine_transform + from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom self._gaussian_filter = gaussian_filter self._zoom = zoom @@ -137,10 +137,10 @@ def __init__( self._xp = cp self._asnumpy = cp.asnumpy from cupyx.scipy.ndimage import ( + affine_transform, gaussian_filter, rotate, zoom, - affine_transform, ) self._gaussian_filter = gaussian_filter diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index d8824b770..b23fe2cae 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1620,7 +1620,7 @@ def _crop_padded_object( pad_x = self._object_padding_px[0] // 2 - remaining_padding pad_y = self._object_padding_px[1] // 2 - remaining_padding - if upsampled == True: + if upsampled: pad_x *= self._kde_upsample_factor pad_y *= self._kde_upsample_factor diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index ba9f28332..4721ed12b 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np import pylops from py4DSTEM.process.phase.utils import ( @@ -8,7 +10,6 @@ regularize_probe_amplitude, ) from py4DSTEM.process.utils import get_CoM -import warnings class PtychographicConstraints: diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index a8a702f89..d06db111c 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1614,9 +1614,9 @@ def fit_aberration_surface( def rotate_point(origin, point, angle): """ - Rotate a point (x1, y1) counterclockwise by a given angle around + Rotate a point (x1, y1) counterclockwise by a given angle around a given origin (x0, y0). - + Parameters -------- origin: 2-tuple of floats From e60071a3859283667cefeaad9b9f95495928418a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 20 Sep 2023 14:28:18 -0700 Subject: [PATCH 035/176] adding mixed-state multi-slice ptycho class --- py4DSTEM/process/phase/__init__.py | 30 +- ...tive_mixedstate_multislice_ptychography.py | 3513 +++++++++++++++++ 2 files changed, 3521 insertions(+), 22 deletions(-) create mode 100644 py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 178079349..1005a619d 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,28 +3,14 @@ _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import ( - MixedstatePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_multislice_ptychography import ( - MultislicePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import ( - OverlapMagneticTomographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_tomography import ( - OverlapTomographicReconstruction, -) +from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction +from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_simultaneous_ptychography import ( - SimultaneousPtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_singleslice_ptychography import ( - SingleslicePtychographicReconstruction, -) -from py4DSTEM.process.phase.parameter_optimize import ( - OptimizationParameter, - PtychographyOptimizer, -) +from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction +from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..acb9f12a2 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -0,0 +1,3513 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pylops +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex + +try: + import cupy as cp +except ImportError: + cp = None + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate + +warnings.simplefilter(action="always", category=UserWarning) + + +class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + num_probes: int = None, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + object_type: str = "complex", + verbose: bool = True, + device: str = "cpu", + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." + ) + num_probes = initial_probe_guess.shape[0] + + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." + ) + ) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._object_padding_px = object_padding_px + self._verbose = verbose + self._device = device + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. + + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + probe_roi_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + probe_roi_shape, (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + + Returns + -------- + self: MixedstateMultislicePtychographicReconstruction + Self to accommodate chaining + """ + xp = self._xp + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._probe_roi_shape = probe_roi_shape + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + self._intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + ) + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + self.com_x, + self.com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + ( + self._amplitudes, + self._mean_diffraction_intensity, + ) = self._normalize_diffraction_intensities( + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + ) + + # explicitly delete namespace + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + del self._intensities + + self._positions_px = self._calculate_scan_positions_in_pixels( + self._scan_positions + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + if self._object is None: + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) + p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( + "int" + ) + if self._object_type == "potential": + self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) + else: + if self._object_type == "potential": + self._object = xp.asarray(self._object, dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.asarray(self._object, dtype=xp.complex64) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # Vectorized Patches + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Probe Initialization + if self._probe is None or isinstance(self._probe, ComplexProbe): + if self._probe is None: + if self._vacuum_probe_intensity is not None: + self._semiangle_cutoff = np.inf + self._vacuum_probe_intensity = xp.asarray( + self._vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + self._vacuum_probe_intensity, + device=self._device, + ) + self._vacuum_probe_intensity = get_shifted_ar( + self._vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=self._device, + ) + + _probe = ( + ComplexProbe( + gpts=self._region_of_interest_shape, + sampling=self.sampling, + energy=self._energy, + semiangle_cutoff=self._semiangle_cutoff, + rolloff=self._rolloff, + vacuum_probe_intensity=self._vacuum_probe_intensity, + parameters=self._polar_parameters, + device=self._device, + ) + .build() + ._array + ) + + else: + if self._probe._gpts != self._region_of_interest_shape: + raise ValueError() + if hasattr(self._probe, "_array"): + _probe = self._probe._array + else: + self._probe._xp = xp + _probe = self._probe.build()._array + + self._probe = xp.zeros( + (self._num_probes,) + tuple(self._region_of_interest_shape), + dtype=xp.complex64, + ) + sx, sy = self._region_of_interest_shape + self._probe[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, self._num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + self._probe[i_probe] = ( + self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) + self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) + + else: + self._probe = xp.asarray(self._probe, dtype=xp.complex64) + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + ) + + # overlaps + shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) + probe_intensities = xp.abs(shifted_probes) ** 2 + probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = self._gaussian_filter(probe_overlap, 1.0) + + if object_fov_mask is None: + self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + cmap = kwargs.pop("cmap", "Greys_r") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + hue_start = kwargs.pop("hue_start", 0) + invert = kwargs.pop("invert", False) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + vmin=vmin, + vmax=vmax, + hue_start=hue_start, + invert=invert, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + vmin=vmin, + vmax=vmax, + hue_start=hue_start, + invert=invert, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + **kwargs, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial Probe[0]") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + **kwargs, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + ) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated Probe[0]") + + ax3.imshow( + asnumpy(probe_overlap), + extent=extent, + cmap=cmap, + **kwargs, + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object Field of View") + + fig.tight_layout() + + self._preprocessed = True + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + num_probe_positions = object_patches.shape[1] + + propagated_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) + propagated_probes[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes = ( + xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes[s + 1] = self._propagate_array( + transmitted_probes, self._propagator_arrays[s] + ) + + return propagated_probes, object_patches, transmitted_probes + + def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + + Returns + -------- + exit_waves:np.ndarray + Exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + intensity_norm[intensity_norm == 0.0] = np.inf + amplitude_modification = amplitudes / intensity_norm + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves + modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_exit_wave - transmitted_probes + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. + Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = transmitted_probes.copy() + + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + factor_to_be_projected = ( + projection_c * transmitted_probes + projection_y * exit_waves + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + intensity_norm_projected = xp.sqrt( + xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) + ) + intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + + amplitude_modification = amplitudes / intensity_norm_projected + fourier_projected_factor *= amplitude_modification[:, None] + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * transmitted_probes + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + ( + propagated_probes, + object_patches, + transmitted_probes, + ) = self._overlap_projection(current_object, current_probe) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, transmitted_probes + ) + + return propagated_probes, object_patches, transmitted_probes, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims( + xp.conj(obj), axis=1 + ) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + current_probe, + transmitted_probes, + amplitudes, + current_positions, + positions_step_size, + constrain_position_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe:np.ndarray + fractionally-shifted probes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + constrain_position_distance: float + Distance to constrain position correction within original + field of view in A + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + + # Intensity gradient + exit_waves_fft = xp.fft.fft2(transmitted_probes) + exit_waves_fft_conj = xp.conj(exit_waves_fft) + estimated_intensity = xp.abs(exit_waves_fft) ** 2 + measured_intensity = amplitudes**2 + + flat_shape = (transmitted_probes.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # Computing perturbed exit waves one at a time to save on memory + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + # dx + obj_rolled_patches = complex_object[ + :, + (self._vectorized_patch_indices_row + 1) % self._object_shape[0], + self._vectorized_patch_indices_col, + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + # dy + obj_rolled_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + (self._vectorized_patch_indices_col + 1) % self._object_shape[1], + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + partial_intensity_dx = 2 * xp.real( + exit_waves_dx_fft * exit_waves_fft_conj + ).reshape(flat_shape) + partial_intensity_dy = 2 * xp.real( + exit_waves_dy_fft * exit_waves_fft_conj + ).reshape(flat_shape) + + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + + # positions_update = xp.einsum( + # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity + # ) + + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + if constrain_position_distance is not None: + constrain_position_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + x1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 0 + ] + y1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 1 + ] + x0 = self._positions_px_initial[:, 0] + y0 = self._positions_px_initial[:, 1] + if self._rotation_best_transpose: + x0, y0 = xp.array([y0, x0]) + x1, y1 = xp.array([y1, x1]) + + if self._rotation_best_rad is not None: + rotation_angle = self._rotation_best_rad + x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( + -rotation_angle + ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) + x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( + -rotation_angle + ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) + + outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( + x1 < (xp.min(x0) - constrain_position_distance) + ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( + y1 < (xp.min(y0) - constrain_position_distance) + ) > 0 + + positions_update[..., 0][outlier_ind] = 0 + + current_positions -= positions_step_size * positions_update[..., 0] + + return current_positions + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + 2D Butterworth filter + Used for low/high-pass filtering object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + current_object = xp.pad( + current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" + ) + + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[1:] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + + def _constraints( + self, + current_object, + current_probe, + current_positions, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, + fix_positions, + global_affine_transformation, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + q_lowpass, + q_highpass, + butterworth_order, + kz_regularization_filter, + kz_regularization_gamma, + identical_slices, + object_positivity, + shrinkage_rad, + object_mask, + pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + orthogonalize_probe, + ): + """ + Ptychographic constraints operator. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + current_positions: np.ndarray + Current positions estimate + fix_com: bool + If True, probe CoM is fixed to the center + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool + If True, probe amplitude is constrained by top hat function + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool + If True, probe Fourier amplitude is replaced by initial_probe_aperture + initial_probe_aperture: np.ndarray + Initial probe aperture to use in replacing probe Fourier amplitude + fix_positions: bool + If True, positions are not updated + gaussian_filter: bool + If True, applies real-space gaussian filter in A + gaussian_filter_sigma: float + Standard deviation of gaussian kernel + butterworth_filter: bool + If True, applies fourier-space butterworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool + If True, applies fourier-space arctan regularization filter + kz_regularization_gamma: float + Slice regularization strength + identical_slices: bool + If True, forces all object slices to be identical + object_positivity: bool + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + pure_phase_object: bool + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True, performs TV denoising along z + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + orthogonalize_probe: bool + If True, probe will be orthogonalized + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + constrained_probe: np.ndarray + Constrained probe estimate + constrained_positions: np.ndarray + Constrained positions estimate + """ + + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, kz_regularization_gamma + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + pad_object=tv_denoise_pad_chambolle, + ) + + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # These constraints don't _really_ make sense for mixed-state + if fix_probe_aperture: + raise NotImplementedError() + elif constrain_probe_fourier_amplitude: + raise NotImplementedError() + if fit_probe_aberrations: + raise NotImplementedError() + if constrain_probe_amplitude: + raise NotImplementedError() + + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + max_iter: int = 64, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe_iter: int = 0, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions_iter: int = np.inf, + constrain_position_distance: float = None, + global_affine_transformation: bool = True, + gaussian_filter_sigma: float = None, + gaussian_filter_iter: int = np.inf, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + butterworth_filter_iter: int = np.inf, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter_iter: int = np.inf, + kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices_iter: int = 0, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + switch_object_iter: int = np.inf, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + max_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_com: bool, optional + If True, fixes center of mass of probe + fix_probe_iter: int, optional + Number of iterations to run with a fixed probe before updating probe estimate + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions_iter: int, optional + Number of iterations to run with fixed positions before updating positions estimate + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter_iter: int, optional + Number of iterations to run using object smoothness constraint + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + butterworth_filter_iter: int, optional + Number of iterations to run using high-pass butteworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter_iter: int, optional + Number of iterations to run using kz regularization filter + kz_regularization_gamma, float, optional + kz regularization strength + identical_slices_iter: int, optional + Number of iterations to run using identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object_iter: int, optional + Number of iterations where object amplitude is set to unity + tv_denoise_iter_chambolle: bool + Number of iterations with TV denoisining + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + switch_object_iter: int, optional + Iteration to switch object type between 'complex' and 'potential' or between + 'potential' and 'complex' + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + asnumpy = self._asnumpy + xp = self._xp + + # Reconstruction method + + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." + ) + ) + + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) + + if self._verbose: + if switch_object_iter > max_iter: + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " + else: + switch_object_type = ( + "complex" if self._object_type == "potential" else "potential" + ) + first_line = ( + f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " + f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " + ) + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." + ) + ) + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ) + ) + + else: + if reconstruction_parameter is not None: + if np.array(reconstruction_parameter).shape == (3,): + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ) + ) + else: + if step_size is not None: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ) + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) + + if max_batch_size is not None: + xp.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + if reset: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + self._exit_waves = None + self._object_type = self._object_type_initial + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] + self._exit_waves = None + + # main loop + for a0 in tqdmnd( + max_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if a0 == switch_object_iter: + if self._object_type == "potential": + self._object_type = "complex" + self._object = xp.exp(1j * self._object) + elif self._object_type == "complex": + self._object_type = "potential" + self._object = xp.angle(self._object) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( + self._num_diffraction_patterns + ) + positions_px = self._positions_px.copy()[shuffled_indices] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[shuffled_indices[start:end]] + + # forward operator + ( + propagated_probes, + object_patches, + self._transmitted_probes, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + self._probe, + amplitudes, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + propagated_probes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=a0 < fix_probe_iter, + ) + + # position correction + if a0 >= fix_positions_iter: + positions_px[start:end] = self._position_correction( + self._object, + self._probe[0], + self._transmitted_probes[:, 0], + amplitudes, + self._positions_px, + positions_step_size, + constrain_position_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._positions_px = positions_px.copy()[unshuffled_indices] + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=a0 < kz_regularization_filter_iter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma, + identical_slices=a0 < identical_slices_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _visualize_last_iteration_figax( + self, + fig, + object_ax, + convergence_ax, + cbar: bool, + padding: int = 0, + **kwargs, + ): + """ + Displays last reconstructed object on a given fig/ax. + + Parameters + -------- + fig: Figure + Matplotlib figure object_ax lives in + object_ax: Axes + Matplotlib axes to plot reconstructed object in + convergence_ax: Axes, optional + Matplotlib axes to plot convergence plot in + cbar: bool, optional + If true, displays a colorbar + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + cmap = kwargs.pop("cmap", "magma") + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + im = object_ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(object_ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if convergence_ax is not None and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = self.error_iterations + + convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + padding: int, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + """ + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + invert = kwargs.pop("invert", False) + hue_start = kwargs.pop("hue_start", 0) + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe_array = Complex2RGB( + self.probe_fourier[0], hue_start=hue_start, invert=invert + ) + ax.set_title("Reconstructed Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + self.probe[0], hue_start=hue_start, invert=invert + ) + ax.set_title("Reconstructed probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + + else: + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = np.array(self.error_iterations) + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration Number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + iterations_grid: Tuple[int, int], + padding: int, + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." + ) + ) + + if iterations_grid == "auto": + num_iter = len(self.error_iterations) + + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + invert = kwargs.pop("invert", False) + hue_start = kwargs.pop("hue_start", 0) + + errors = np.array(self.error_iterations) + + objects = [] + object_type = [] + + for obj in self.object_iterations: + if np.iscomplexobj(obj): + obj = np.angle(obj) + object_type.append("phase") + else: + object_type.append("potential") + objects.append( + self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) + ) + + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + probes = self.probe_iterations + else: + total_grids = np.prod(iterations_grid) + max_iter = len(objects) - 1 + grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + im = ax.imshow( + objects[grid_range[n]], + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = Complex2RGB( + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0] + ) + ), + hue_start=hue_start, + invert=invert, + ) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + probes[grid_range[n]][0], hue_start=hue_start, invert=invert + ) + ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + **kwargs, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], hue_start=hue_start, invert=invert + ) + + if plot_convergence: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration Number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + cbar: bool = True, + padding: int = 0, + **kwargs, + ): + """ + Displays reconstructed object and probe. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + return self + + def show_fourier_probe( + self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if probe is None: + probe = list(self.probe_fourier) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + show_complex( + probe if len(probe) > 1 else probe[0], + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + + def show_transmitted_probe( + self, + plot_fourier_probe: bool = False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. + + Parameters + ---------- + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + kwargs: + Passed to show_complex + """ + + xp = self._xp + asnumpy = self._asnumpy + + transmitted_probe_intensities = xp.sum( + xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) + ) + min_intensity_transmitted = self._transmitted_probes[ + xp.argmin(transmitted_probe_intensities), 0 + ] + max_intensity_transmitted = self._transmitted_probes[ + xp.argmax(transmitted_probe_intensities), 0 + ] + mean_transmitted = self._transmitted_probes[:, 0].mean(0) + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean Transmitted Probe", + "Min Intensity Transmitted Probe", + "Max Intensity Transmitted Probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy(self._return_fourier_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean Transmitted Fourier Probe", + "Min Intensity Transmitted Fourier Probe", + "Max Intensity Transmitted Fourier Probe", + ] + + title = kwargs.get("title", title) + show_complex( + probes, + title=title, + **kwargs, + ) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + vmin = np.min(rotated_object) if common_color_scale else None + vmax = np.max(rotated_object) if common_color_scale else None + vmin = kwargs.pop("vmin", vmin) + vmax = kwargs.pop("vmax", vmax) + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + if not plot_line_profile: + fig, ax = plt.subplots() + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[0] * ms_obj.shape[1], + self.sampling[1] * ms_obj.shape[2], + 0, + ] + fig, ax = plt.subplots(2, 1) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + + def tune_num_slices_and_thicknesses( + self, + num_slices_guess=None, + thicknesses_guess=None, + num_slices_step_size=1, + thicknesses_step_size=20, + num_slices_values=3, + num_thicknesses_values=3, + update_defocus=False, + max_iter=5, + plot_reconstructions=True, + plot_convergence=True, + return_values=False, + **kwargs, + ): + """ + Run reconstructions over a parameters space of number of slices + and slice thicknesses. Should be run after the preprocess step. + + Parameters + ---------- + num_slices_guess: float, optional + initial starting guess for number of slices, rounds to nearest integer + if None, uses current initialized values + thicknesses_guess: float (A), optional + initial starting guess for thicknesses of slices assuming same + thickness for each slice + if None, uses current initialized values + num_slices_step_size: float, optional + size of change of number of slices for each step in parameter space + thicknesses_step_size: float (A), optional + size of change of slice thicknesses for each step in parameter space + num_slices_values: int, optional + number of number of slice values to test, must be >= 1 + num_thicknesses_values: int,optional + number of thicknesses values to test, must be >= 1 + update_defocus: bool, optional + if True, updates defocus based on estimated total thickness + max_iter: int, optional + number of iterations to run in ptychographic reconstruction + plot_reconstructions: bool, optional + if True, plot phase of reconstructed objects + plot_convergence: bool, optional + if True, plots error for each iteration for each reconstruction + return_values: bool, optional + if True, returns objects, convergence + + Returns + ------- + objects: list + reconstructed objects + convergence: np.ndarray + array of convergence values from reconstructions + """ + + # calculate number of slices and thicknesses values to test + if num_slices_guess is None: + num_slices_guess = self._num_slices + if thicknesses_guess is None: + thicknesses_guess = np.mean(self._slice_thicknesses) + + if num_slices_values == 1: + num_slices_step_size = 0 + + if num_thicknesses_values == 1: + thicknesses_step_size = 0 + + num_slices = np.linspace( + num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_values, + ) + + thicknesses = np.linspace( + thicknesses_guess + - thicknesses_step_size * (num_thicknesses_values - 1) / 2, + thicknesses_guess + + thicknesses_step_size * (num_thicknesses_values - 1) / 2, + num_thicknesses_values, + ) + + if return_values: + convergence = [] + objects = [] + + # current initialized values + current_verbose = self._verbose + current_num_slices = self._num_slices + current_thicknesses = self._slice_thicknesses + current_rotation_deg = self._rotation_best_rad * 180 / np.pi + current_transpose = self._rotation_best_transpose + current_defocus = -self._polar_parameters["C10"] + + # Gridspec to plot on + if plot_reconstructions: + if plot_convergence: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values * 2, + height_ratios=[1, 1 / 4] * num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) + ) + else: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) + ) + + fig = plt.figure(figsize=figsize) + + progress_bar = kwargs.pop("progress_bar", False) + # run loop and plot along the way + self._verbose = False + for flat_index, (slices, thickness) in enumerate( + tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") + ): + slices = int(slices) + self._num_slices = slices + self._slice_thicknesses = np.tile(thickness, slices - 1) + self._probe = None + self._object = None + if update_defocus: + defocus = current_defocus + slices / 2 * thickness + self._polar_parameters["C10"] = -defocus + + self.preprocess( + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + ) + self.reconstruct( + reset=True, + store_iterations=True if plot_convergence else False, + max_iter=max_iter, + progress_bar=progress_bar, + **kwargs, + ) + + if plot_reconstructions: + row_index, col_index = np.unravel_index( + flat_index, (num_slices_values, num_thicknesses_values) + ) + + if plot_convergence: + object_ax = fig.add_subplot(spec[row_index * 2, col_index]) + convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=convergence_ax, + cbar=True, + ) + convergence_ax.yaxis.tick_right() + else: + object_ax = fig.add_subplot(spec[row_index, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=None, + cbar=True, + ) + + object_ax.set_title( + f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" + ) + object_ax.set_xticks([]) + object_ax.set_yticks([]) + + if return_values: + objects.append(self.object) + convergence.append(self.error_iterations.copy()) + + # initialize back to pre-tuning values + self._probe = None + self._object = None + self._num_slices = current_num_slices + self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) + self._polar_parameters["C10"] = -current_defocus + self.preprocess( + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + ) + self._verbose = current_verbose + + if plot_reconstructions: + spec.tight_layout(fig) + + if return_values: + return objects, convergence + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + """ + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + obj = asnumpy(obj) + if np.iscomplexobj(obj): + obj = np.angle(obj) + + obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) From 807ac1503da5bedf5f8e761827fd239a08192db3 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 22 Sep 2023 14:25:45 -0700 Subject: [PATCH 036/176] small reset bug --- .../phase/iterative_mixedstate_multislice_ptychography.py | 2 ++ py4DSTEM/process/phase/iterative_mixedstate_ptychography.py | 2 ++ py4DSTEM/process/phase/iterative_multislice_ptychography.py | 2 ++ .../process/phase/iterative_overlap_magnetic_tomography.py | 3 ++- py4DSTEM/process/phase/iterative_overlap_tomography.py | 3 ++- py4DSTEM/process/phase/iterative_simultaneous_ptychography.py | 2 ++ py4DSTEM/process/phase/iterative_singleslice_ptychography.py | 2 ++ 7 files changed, 14 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index acb9f12a2..6620cf71f 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -2249,6 +2249,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index d066c7f3f..6acbf7fc3 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1598,6 +1598,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 3e36978e4..f663b9905 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -2118,6 +2118,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 712a35647..3aac7edc7 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -2326,12 +2326,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index fd9af0bb2..79a477da2 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -2143,12 +2143,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index a19fc82d3..584edfa6c 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -2775,6 +2775,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = (None,) * self._num_sim_measurements self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 97c7a3e5d..97e607cb6 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -1509,6 +1509,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( From 8b35dbb3ff1955bfec50abc4b932a12ed62f211a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 14:46:25 -0700 Subject: [PATCH 037/176] fixed NaN bug --- .../process/phase/iterative_base_class.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index ae4c92d4b..b7aa61af0 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -484,9 +484,14 @@ def _calculate_intensities_center_of_mass( ) if com_shifts is None: + com_measured_x_np = asnumpy(com_measured_x) + com_measured_y_np = asnumpy(com_measured_y) + finite_mask = np.isfinite(com_measured_x_np) + com_shifts = fit_origin( - (asnumpy(com_measured_x), asnumpy(com_measured_y)), + (com_measured_x_np, com_measured_y_np), fitfunction=fit_function, + mask=finite_mask, ) # Fit function to center of mass @@ -494,12 +499,12 @@ def _calculate_intensities_center_of_mass( com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # fix CoM units - com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[ - 0 - ] - com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[ - 1 - ] + com_normalized_x = ( + xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + ) + com_normalized_y = ( + xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + ) return ( com_measured_x, From 6cb62f1baac4e6e9bf6b773ffb43419d3ff5ffd4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 15:53:47 -0700 Subject: [PATCH 038/176] changed complex plotting --- py4DSTEM/visualize/vis_special.py | 122 +++++++++++++----------------- setup.py | 3 +- 2 files changed, 55 insertions(+), 70 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 43cf7fff8..d46788472 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -1,6 +1,5 @@ from matplotlib import cm, colors as mcolors, pyplot as plt import numpy as np -from matplotlib.colors import hsv_to_rgb from matplotlib.patches import Wedge from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.spatial import Voronoi @@ -18,9 +17,7 @@ from py4DSTEM.visualize.vis_grid import show_image_grid from py4DSTEM.visualize.vis_RQ import ax_addaxes,ax_addaxes_QtoR - - - +from colorspacious import cspace_convert def show_elliptical_fit(ar,fitradii,p_ellipse,fill=True, color_ann='y',color_ell='r',alpha_ann=0.2,alpha_ell=0.7, @@ -717,15 +714,21 @@ def show_selected_dps(datacube,positions,im,bragg_pos=None, get_pointcolors=lambda i:colors[i], **kwargs) -def Complex2RGB(complex_data, vmin=None, vmax = None, hue_start = 0, invert=False): +def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float) : power to raise amplitude to """ - amp = np.abs(complex_data) + if power is None: + norm = mcolors.Normalize() + else: + norm = mcolors.PowerNorm(power) + + amp = norm(np.abs(complex_data)).data + phase = np.angle(complex_data) + if np.isclose(np.max(amp),np.min(amp)): if vmin is None: vmin = 0 @@ -746,35 +749,37 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, hue_start = 0, invert=Fals amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) + + J = amp*100 + C = np.where(J<61.5,98*J/123,1400/11-14*J/11) + h = np.rad2deg(phase)+180 - phase = np.angle(complex_data) + np.deg2rad(hue_start) - amp /= np.max(amp) - rgb = np.zeros(phase.shape +(3,)) - rgb[...,0] = 0.5*(np.sin(phase)+1)*amp - rgb[...,1] = 0.5*(np.sin(phase+np.pi/2)+1)*amp - rgb[...,2] = 0.5*(-np.sin(phase)+1)*amp + JCh = np.stack((J,C,h), axis=-1) + rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) - return 1-rgb if invert else rgb + return rgb -def add_colorbar_arg(cax, vmin = None, vmax = None, hue_start = 0, invert = False): +def add_colorbar_arg(cax, c = 49, j = 61.5): """ - cax : axis to add cbar too - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + cax : axis to add cbar to + c : constant chroma value + j : constant luminance value """ - z = np.exp(1j * np.linspace(-np.pi, np.pi, 200)) - rgb_vals = Complex2RGB(z, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert) + + h = np.linspace(0, 360, 256,endpoint=False) + J = np.full_like(h,j) + C = np.full_like(h,c) + JCh = np.stack((J,C,h), axis=-1) + rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) - cb1 = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) - cb1.set_label("arg", rotation=0, ha="center", va="bottom") - cb1.ax.yaxis.set_label_coords(0.5, 1.01) - cb1.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) - cb1.set_ticklabels( + cb.set_label("arg", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) + cb.set_ticklabels( [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) @@ -787,8 +792,7 @@ def show_complex( pixelunits="pixels", pixelsize=1, returnfig=False, - hue_start = 0, - invert=False, + power=None, **kwargs ): """ @@ -801,13 +805,12 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels - cbar (bool, optional) : if True, include color wheel + cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - hue_start (float, optional) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float,optional) : power to raise amplitude to Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -817,12 +820,12 @@ def show_complex( ar_complex = ar_complex[0] if (isinstance(ar_complex,list) and len(ar_complex) == 1) else ar_complex if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): - rgb = [Complex2RGB(ar, vmin, vmax, hue_start = hue_start, invert=invert) for sublist in ar_complex for ar in sublist] + rgb = [Complex2RGB(ar, vmin, vmax, power=power) for sublist in ar_complex for ar in sublist] H = len(ar_complex) W = len(ar_complex[0]) else: - rgb = [Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) for ar in ar_complex] + rgb = [Complex2RGB(ar, vmin, vmax, power=power) for ar in ar_complex] if len(rgb[0].shape) == 4: H = len(ar_complex) W = rgb[0].shape[0] @@ -831,7 +834,7 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, hue_start=hue_start, invert=invert) + rgb = Complex2RGB(ar_complex, vmin, vmax, power=power) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -882,37 +885,18 @@ def show_complex( add_scalebar(ax, scalebar) # add color bar - if cbar == True: - ax0 = fig.add_axes([1, 0.35, 0.3, 0.3]) - - # create wheel - AA = 1000 - kx = np.fft.fftshift(np.fft.fftfreq(AA)) - ky = np.fft.fftshift(np.fft.fftfreq(AA)) - kya, kxa = np.meshgrid(ky, kx) - kra = (kya**2 + kxa**2) ** 0.5 - ktheta = np.arctan2(-kxa, kya) - ktheta = kra * np.exp(1j * ktheta) - - # convert to hsv - rgb = Complex2RGB(ktheta, 0, 0.4, hue_start = hue_start, invert=invert) - ind = kra > 0.4 - rgb[ind] = [1, 1, 1] - - # plot - ax0.imshow(rgb) - - # add axes - ax0.axhline(AA / 2, 0, AA, color="k") - ax0.axvline(AA / 2, 0, AA, color="k") - ax0.axis("off") - - label_size = 16 - - ax0.text(AA, AA / 2, 1, fontsize=label_size) - ax0.text(AA / 2, 0, "i", fontsize=label_size) - ax0.text(AA / 2, AA, "-i", fontsize=label_size) - ax0.text(0, AA / 2, -1, fontsize=label_size) - - if returnfig == True: + if cbar: + if is_grid: + for ax_flat in ax.flatten(): + divider = make_axes_locatable(ax_flat) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb) + else: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb) + + fig.tight_layout() + + if returnfig: return fig, ax diff --git a/setup.py b/setup.py index d8baff354..40255d5bf 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,8 @@ 'dask >= 2.3.0', 'distributed >= 2.3.0', 'emdfile >= 0.0.10', - 'pylops >= 2.1.0' + 'pylops >= 2.1.0', + 'colorspacious >= 1.1.2', ], extras_require={ 'ipyparallel': ['ipyparallel >= 6.2.4', 'dill >= 0.3.3'], From ad79416f4545659ddf39bcad05e6d94cbc5e78e1 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 17:15:24 -0700 Subject: [PATCH 039/176] updated complex plotting phase calls --- ...tive_mixedstate_multislice_ptychography.py | 85 ++++++------------- .../iterative_mixedstate_ptychography.py | 48 ++++------- .../iterative_multislice_ptychography.py | 64 ++++---------- .../iterative_overlap_magnetic_tomography.py | 32 ++----- .../phase/iterative_overlap_tomography.py | 61 +++++-------- .../iterative_simultaneous_ptychography.py | 37 +++----- .../iterative_singleslice_ptychography.py | 54 ++++-------- py4DSTEM/visualize/vis_special.py | 8 +- 8 files changed, 124 insertions(+), 265 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 6620cf71f..ea10050dd 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -609,19 +609,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered[0], - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -633,10 +625,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -658,38 +647,33 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe[0]") + ax1.set_title("Initial probe[0] intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax2) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe[0]") + ax2.set_title("Propagated probe[0] intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -701,7 +685,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1126,23 +1110,17 @@ def _projection_sets_adjoint( ) if self._object_type == "potential": - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.real( - -1j - * xp.conj(obj) - * xp.conj(probe[:, i_probe]) - * exit_waves_copy[:, i_probe] - ) + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] ) ) else: - object_update += ( - step_size - * self._sum_overlapping_patches_bincounts( - xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] - ) + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] ) probe_normalization = 1 / xp.sqrt( @@ -2519,8 +2497,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -2615,30 +2591,25 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert - ) + probe_array = Complex2RGB(self.probe_fourier[0]) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert - ) - ax.set_title("Reconstructed probe[0]") + probe_array = Complex2RGB(self.probe[0], power=2) + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -2671,10 +2642,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2746,8 +2717,6 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) errors = np.array(self.error_iterations) @@ -2852,15 +2821,14 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, ) ax.set_title(f"Iter: {grid_range[n]} probe[0]") ax.set_ylabel("x [A]") @@ -2869,12 +2837,11 @@ def _visualize_all_iterations( im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2886,7 +2853,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 6acbf7fc3..6dbccff06 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -505,19 +505,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -540,23 +532,19 @@ def preprocess( axs[i].imshow( complex_probe_rgb[i], extent=probe_extent, - **kwargs, ) axs[i].set_ylabel("x [A]") axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial Probe[{i}]") + axs[i].set_title(f"Initial probe[{i}] intensity") divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax) axs[-1].imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) axs[-1].scatter( self.positions[:, 1], @@ -568,7 +556,7 @@ def preprocess( axs[-1].set_xlabel("y [A]") axs[-1].set_xlim((extent[0], extent[1])) axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object Field of View") + axs[-1].set_title("Object field of view") fig.tight_layout() @@ -1849,8 +1837,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -1943,29 +1929,29 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert + self.probe_fourier[0], ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert + self.probe[0], + power=2, ) - ax.set_title("Reconstructed probe[0]") + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -1998,10 +1984,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2177,24 +2163,22 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: @@ -2211,7 +2195,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index f663b9905..b3614c0ad 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -577,19 +577,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -601,10 +593,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -626,38 +615,33 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -669,7 +653,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -2387,8 +2371,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -2484,29 +2466,26 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert - ) - ax.set_title("Reconstructed probe") + probe_array = Complex2RGB(self.probe, power=2) + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -2539,10 +2518,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2614,8 +2593,6 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) errors = np.array(self.error_iterations) @@ -2720,29 +2697,24 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert - ) - ax.set_title(f"Iter: {grid_range[n]} probe") + probe_array = Complex2RGB(probes[grid_range[n]], power=2) + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2754,7 +2726,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 3aac7edc7..d2934497c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -808,19 +808,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -832,10 +824,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -857,38 +846,35 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -900,7 +886,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -3092,7 +3078,7 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[-1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 79a477da2..3d5982e9e 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -748,19 +748,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) # propagated @@ -772,10 +764,7 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -797,38 +786,35 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -840,7 +826,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -2593,8 +2579,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) asnumpy = self._asnumpy @@ -2696,16 +2680,17 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2718,7 +2703,9 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg( + ax_cb, + ) else: ax = fig.add_subplot(spec[0]) im = ax.imshow( @@ -2747,10 +2734,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2834,8 +2821,6 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) errors = np.array(self.error_iterations) @@ -2950,29 +2935,27 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2984,7 +2967,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 584edfa6c..85e9a0b18 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -746,19 +746,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -780,23 +772,21 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax2.scatter( self.positions[:, 1], @@ -808,7 +798,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -3061,8 +3051,6 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = kwargs.pop("cmap_m", "PuOr") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj_e = np.angle(self.object[0]) @@ -3188,29 +3176,26 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert - ) - ax.set_title("Reconstructed probe") + probe_array = Complex2RGB(self.probe, power=2) + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: # Electrostatic Object @@ -3261,10 +3246,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 97e607cb6..3843da983 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -474,19 +474,11 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, ) extent = [ @@ -508,23 +500,19 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="gray", ) ax2.scatter( self.positions[:, 1], @@ -536,7 +524,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -1762,8 +1750,6 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj = np.angle(self.object) @@ -1856,29 +1842,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb) else: ax = fig.add_subplot(spec[0]) @@ -1911,10 +1897,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -1985,9 +1971,7 @@ def _visualize_all_iterations( else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "inferno") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + cmap = kwargs.pop("cmap", "magma") errors = np.array(self.error_iterations) @@ -2091,8 +2075,6 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2100,21 +2082,21 @@ def _visualize_all_iterations( else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], ) if plot_convergence: @@ -2126,7 +2108,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index d46788472..722b55800 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -728,7 +728,7 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = norm(np.abs(complex_data)).data phase = np.angle(complex_data) - + if np.isclose(np.max(amp),np.min(amp)): if vmin is None: vmin = 0 @@ -736,9 +736,9 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): vmax = np.max(amp) else: if vmin is None: - vmin = 0.02 + vmin = 0.025 if vmax is None: - vmax = 0.98 + vmax = 0.975 vals = np.sort(amp[~np.isnan(amp)]) ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") @@ -750,7 +750,7 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) - J = amp*100 + J = amp*61.5 # Note we restrict luminance to 61.5 C = np.where(J<61.5,98*J/123,1400/11-14*J/11) h = np.rad2deg(phase)+180 From 356a3a180e4699f2c1db703b250b0d856b76de6a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 23 Sep 2023 17:58:08 -0700 Subject: [PATCH 040/176] adding complex CoM plotting and various dpc plotting bugs --- .../process/phase/iterative_base_class.py | 81 +++++++++++++++---- py4DSTEM/process/phase/iterative_dpc.py | 48 ++++++----- py4DSTEM/visualize/vis_special.py | 6 +- 3 files changed, 96 insertions(+), 39 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index b7aa61af0..96b3d5088 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -295,13 +295,14 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "real-space calibrations in 'A'" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "real-space calibrations in 'A'" + ), + UserWarning, + ) self._scan_sampling = (1.0, 1.0) self._scan_units = ("pixels",) * 2 @@ -359,13 +360,14 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "appropriate reciprocal-space calibrations" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "appropriate reciprocal-space calibrations" + ), + UserWarning, + ) self._angular_sampling = (1.0, 1.0) self._angular_units = ("pixels",) * 2 @@ -1134,6 +1136,57 @@ def _normalize_diffraction_intensities( return amplitudes, mean_intensity + def show_complex_CoM( + self, + com=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot complex-valued CoM image + + Parameters + ---------- + + com = (CoM_x, CoM_y) tuple + If None is specified, uses (self.com_x, self.com_y) instead + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A + pixelsize: float, optional + default is scan sampling + """ + + if com is None: + com = (self.com_x, self.com_y) + + if pixelsize is None: + pixelsize = self._scan_sampling[0] + if pixelunits is None: + pixelunits = r"$\AA$" + + figsize = kwargs.pop("figsize", (6, 6)) + fig, ax = plt.subplots(figsize=figsize) + + complex_com = com[0] + 1j * com[1] + + show_complex( + complex_com, + cbar=cbar, + figax=(fig, ax), + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): """ diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 4c80ed177..20796160a 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -718,24 +718,26 @@ def reconstruct( xp = self._xp asnumpy = self._asnumpy - if reset is None and hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - self.error_iterations = [] if reset: self.error = np.inf + self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] self.error = getattr(self, "error", np.inf) @@ -770,7 +772,8 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - print(f"Iteration {a0}, step reduced to {self._step_size}") + if self._verbose: + print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -807,10 +810,11 @@ def reconstruct( self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: - warnings.warn( - f"Step-size has decreased below stopping criterion {stopping_criterion}.", - UserWarning, - ) + if self._verbose: + warnings.warn( + f"Step-size has decreased below stopping criterion {stopping_criterion}.", + UserWarning, + ) # crop result self._object_phase = self._padded_object_phase[ @@ -840,7 +844,7 @@ def _visualize_last_iteration( If true, the NMSE error plot is displayed """ - figsize = kwargs.pop("figsize", (8, 8)) + figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") if plot_convergence: @@ -862,7 +866,7 @@ def _visualize_last_iteration( im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC Phase Reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") if cbar: divider = make_axes_locatable(ax1) @@ -870,11 +874,11 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "_error_iterations"): - errors = self._error_iterations + if plot_convergence: + errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -979,7 +983,7 @@ def _visualize_all_iterations( if plot_convergence: ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -990,7 +994,7 @@ def visualize( fig=None, iterations_grid: Tuple[int, int] = None, plot_convergence: bool = True, - cbar: bool = False, + cbar: bool = True, **kwargs, ): """ diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 722b55800..0792ee3c8 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -736,9 +736,9 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): vmax = np.max(amp) else: if vmin is None: - vmin = 0.025 + vmin = 0.0 if vmax is None: - vmax = 0.975 + vmax = 1.0 vals = np.sort(amp[~np.isnan(amp)]) ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") @@ -751,7 +751,7 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = np.where(amp > vmax, vmax, amp) J = amp*61.5 # Note we restrict luminance to 61.5 - C = np.where(J<61.5,98*J/123,1400/11-14*J/11) + C = np.where(J<61.5,98*J/123,1400/11-14*J/11) # Min uniform chroma h = np.rad2deg(phase)+180 JCh = np.stack((J,C,h), axis=-1) From 2bc9da98a6ca3bb0388cd88a7a4a3dd9f35d2dd5 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 29 Sep 2023 17:47:49 -0700 Subject: [PATCH 041/176] parallax descan correct --- py4DSTEM/process/phase/iterative_parallax.py | 36 ++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index b23fe2cae..0d2f4cfc9 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -11,6 +11,7 @@ from emdfile import Custom, tqdmnd from matplotlib.gridspec import GridSpec from py4DSTEM import DataCube +from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom @@ -112,6 +113,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, + descan_correct: bool = False, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, @@ -134,6 +136,8 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + descan_correct: float, optional + If True, aligns bright field stack based on measured descan rotation_guess: float, optional Initial guess of defocus value in degrees If None, first iteration assumed to be 0 @@ -180,6 +184,38 @@ def preprocess( raise ValueError( "dp_mean must match the datacube shape. Try setting dp_mean = None." ) + # descan correct + if descan_correct: + from py4DSTEM.process.phase import DPCReconstruction + + dpc = DPCReconstruction( + energy=self._energy, + datacube=self._datacube, + verbose=False, + ).preprocess( + force_com_rotation=0, + force_com_transpose=False, + plot_center_of_mass=False, + ) + + intensities_shifted = self._intensities.copy() + + center_x = np.mean(dpc._com_measured_x) + center_y = np.mean(dpc._com_measured_y) + for rx in range(intensities_shifted.shape[0]): + for ry in range(intensities_shifted.shape[1]): + intensity_shifted = get_shifted_ar( + self._intensities[rx, ry], + -dpc._com_measured_x[rx, ry] + center_x, + -dpc._com_measured_y[rx, ry] + center_y, + bilinear=True, + device="cpu", + ) + + intensities_shifted[rx, ry] = intensity_shifted + + self._intensities = intensities_shifted + self._dp_mean = intensities_shifted.mean((0, 1)) # select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) From 3a99d5ae0cadf48d5beb8e829797bab1195557a4 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 29 Sep 2023 21:01:55 -0700 Subject: [PATCH 042/176] complex plotting improvements, formatting --- py4DSTEM/visualize/vis_special.py | 921 +++++++++++++++++++----------- 1 file changed, 577 insertions(+), 344 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 0792ee3c8..6dd980bce 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -15,13 +15,25 @@ add_scalebar, ) from py4DSTEM.visualize.vis_grid import show_image_grid -from py4DSTEM.visualize.vis_RQ import ax_addaxes,ax_addaxes_QtoR +from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR from colorspacious import cspace_convert -def show_elliptical_fit(ar,fitradii,p_ellipse,fill=True, - color_ann='y',color_ell='r',alpha_ann=0.2,alpha_ell=0.7, - linewidth_ann=2,linewidth_ell=2,returnfig=False,**kwargs): + +def show_elliptical_fit( + ar, + fitradii, + p_ellipse, + fill=True, + color_ann="y", + color_ell="r", + alpha_ann=0.2, + alpha_ell=0.7, + linewidth_ann=2, + linewidth_ell=2, + returnfig=False, + **kwargs +): """ Plots an elliptical curve over its annular fit region. @@ -39,35 +51,55 @@ def show_elliptical_fit(ar,fitradii,p_ellipse,fill=True, linewidth_ann: linewidth_ell: """ - Ri,Ro = fitradii - qx0,qy0,a,b,theta = p_ellipse - fig,ax = show(ar, - annulus={'center':(qx0,qy0), - 'radii':(Ri,Ro), - 'fill':fill, - 'color':color_ann, - 'alpha':alpha_ann, - 'linewidth':linewidth_ann}, - ellipse={'center':(qx0,qy0), - 'a':a, - 'b':b, - 'theta':theta, - 'color':color_ell, - 'alpha':alpha_ell, - 'linewidth':linewidth_ell}, - returnfig=True,**kwargs) + Ri, Ro = fitradii + qx0, qy0, a, b, theta = p_ellipse + fig, ax = show( + ar, + annulus={ + "center": (qx0, qy0), + "radii": (Ri, Ro), + "fill": fill, + "color": color_ann, + "alpha": alpha_ann, + "linewidth": linewidth_ann, + }, + ellipse={ + "center": (qx0, qy0), + "a": a, + "b": b, + "theta": theta, + "color": color_ell, + "alpha": alpha_ell, + "linewidth": linewidth_ell, + }, + returnfig=True, + **kwargs, + ) if not returnfig: plt.show() return else: - return fig,ax + return fig, ax -def show_amorphous_ring_fit(dp,fitradii,p_dsg,N=12,cmap=('gray','gray'), - fitborder=True,fitbordercolor='k',fitborderlw=0.5, - scaling='log',ellipse=False,ellipse_color='r', - ellipse_alpha=0.7,ellipse_lw=2,returnfig=False,**kwargs): +def show_amorphous_ring_fit( + dp, + fitradii, + p_dsg, + N=12, + cmap=("gray", "gray"), + fitborder=True, + fitbordercolor="k", + fitborderlw=0.5, + scaling="log", + ellipse=False, + ellipse_color="r", + ellipse_alpha=0.7, + ellipse_lw=2, + returnfig=False, + **kwargs +): """ Display a diffraction pattern with a fit to its amorphous ring, interleaving the data and the fit in a pinwheel pattern. @@ -90,75 +122,112 @@ def show_amorphous_ring_fit(dp,fitradii,p_dsg,N=12,cmap=('gray','gray'), """ from py4DSTEM.process.calibration import double_sided_gaussian from py4DSTEM.process.utils import convert_ellipse_params - assert(len(p_dsg)==11) - assert(isinstance(N,(int,np.integer))) - if isinstance(cmap,tuple): - cmap_data,cmap_fit = cmap[0],cmap[1] + + assert len(p_dsg) == 11 + assert isinstance(N, (int, np.integer)) + if isinstance(cmap, tuple): + cmap_data, cmap_fit = cmap[0], cmap[1] else: - cmap_data,cmap_fit = cmap,cmap - Q_Nx,Q_Ny = dp.shape - qmin,qmax = fitradii + cmap_data, cmap_fit = cmap, cmap + Q_Nx, Q_Ny = dp.shape + qmin, qmax = fitradii # Make coords - qx0,qy0 = p_dsg[6],p_dsg[7] - qyy,qxx = np.meshgrid(np.arange(Q_Ny),np.arange(Q_Nx)) - qx,qy = qxx-qx0,qyy-qy0 - q = np.hypot(qx,qy) - theta = np.arctan2(qy,qx) + qx0, qy0 = p_dsg[6], p_dsg[7] + qyy, qxx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx)) + qx, qy = qxx - qx0, qyy - qy0 + q = np.hypot(qx, qy) + theta = np.arctan2(qy, qx) # Make mask - thetas = np.linspace(-np.pi,np.pi,2*N+1) - pinwheel = np.zeros((Q_Nx,Q_Ny),dtype=bool) + thetas = np.linspace(-np.pi, np.pi, 2 * N + 1) + pinwheel = np.zeros((Q_Nx, Q_Ny), dtype=bool) for i in range(N): - pinwheel += (theta>thetas[2*i]) * (theta<=thetas[2*i+1]) - mask = pinwheel * (q>qmin) * (q<=qmax) + pinwheel += (theta > thetas[2 * i]) * (theta <= thetas[2 * i + 1]) + mask = pinwheel * (q > qmin) * (q <= qmax) # Get fit data fit = double_sided_gaussian(p_dsg, qxx, qyy) # Show - (fig,ax),(vmin,vmax) = show(dp,scaling=scaling,cmap=cmap_data, - mask=np.logical_not(mask),mask_color='empty', - returnfig=True,returnclipvals=True,**kwargs) - show(fit,scaling=scaling,figax=(fig,ax),clipvals='manual',min=vmin,max=vmax, - cmap=cmap_fit,mask=mask,mask_color='empty',**kwargs) + (fig, ax), (vmin, vmax) = show( + dp, + scaling=scaling, + cmap=cmap_data, + mask=np.logical_not(mask), + mask_color="empty", + returnfig=True, + returnclipvals=True, + **kwargs, + ) + show( + fit, + scaling=scaling, + figax=(fig, ax), + clipvals="manual", + min=vmin, + max=vmax, + cmap=cmap_fit, + mask=mask, + mask_color="empty", + **kwargs, + ) if fitborder: - if N%2==1: thetas += (thetas[1]-thetas[0])/2 - if (N//2%2)==0: thetas = np.roll(thetas,-1) + if N % 2 == 1: + thetas += (thetas[1] - thetas[0]) / 2 + if (N // 2 % 2) == 0: + thetas = np.roll(thetas, -1) for i in range(N): - ax.add_patch(Wedge((qy0,qx0),qmax,np.degrees(thetas[2*i]), - np.degrees(thetas[2*i+1]),width=qmax-qmin,fill=None, - color=fitbordercolor,lw=fitborderlw)) + ax.add_patch( + Wedge( + (qy0, qx0), + qmax, + np.degrees(thetas[2 * i]), + np.degrees(thetas[2 * i + 1]), + width=qmax - qmin, + fill=None, + color=fitbordercolor, + lw=fitborderlw, + ) + ) # Add ellipse overlay if ellipse: - A,B,C = p_dsg[8],p_dsg[9],p_dsg[10] - a,b,theta = convert_ellipse_params(A,B,C) - ellipse={'center':(qx0,qy0),'a':a,'b':b,'theta':theta, - 'color':ellipse_color,'alpha':ellipse_alpha,'linewidth':ellipse_lw} - add_ellipses(ax,ellipse) + A, B, C = p_dsg[8], p_dsg[9], p_dsg[10] + a, b, theta = convert_ellipse_params(A, B, C) + ellipse = { + "center": (qx0, qy0), + "a": a, + "b": b, + "theta": theta, + "color": ellipse_color, + "alpha": ellipse_alpha, + "linewidth": ellipse_lw, + } + add_ellipses(ax, ellipse) if not returnfig: plt.show() return else: - return fig,ax + return fig, ax def show_qprofile( q, intensity, ymax=None, - figsize=(12,4), + figsize=(12, 4), returnfig=False, - color='k', - xlabel='q (pixels)', - ylabel='Intensity (A.U.)', + color="k", + xlabel="q (pixels)", + ylabel="Intensity (A.U.)", labelsize=16, ticklabelsize=14, grid=True, label=None, - **kwargs): + **kwargs +): """ Plots a diffraction space radial profile. Params: @@ -174,148 +243,167 @@ def show_qprofile( label a legend label for the plotted curve """ if ymax is None: - ymax = np.max(intensity)*1.05 + ymax = np.max(intensity) * 1.05 - fig,ax = plt.subplots(figsize=figsize) - ax.plot(q,intensity,color=color,label=label) + fig, ax = plt.subplots(figsize=figsize) + ax.plot(q, intensity, color=color, label=label) ax.grid(grid) - ax.set_ylim(0,ymax) - ax.tick_params(axis='x',labelsize=ticklabelsize) + ax.set_ylim(0, ymax) + ax.tick_params(axis="x", labelsize=ticklabelsize) ax.set_yticklabels([]) - ax.set_xlabel(xlabel,size=labelsize) - ax.set_ylabel(ylabel,size=labelsize) + ax.set_xlabel(xlabel, size=labelsize) + ax.set_ylabel(ylabel, size=labelsize) if not returnfig: plt.show() return else: - return fig,ax + return fig, ax -def show_kernel( - kernel, - R, - L, - W, - figsize=(12,6), - returnfig=False, - **kwargs): + +def show_kernel(kernel, R, L, W, figsize=(12, 6), returnfig=False, **kwargs): """ Plots, side by side, the probe kernel and its line profile. R is the kernel plot's window size. L and W are the length and width of the lineprofile. """ - lineprofile_1 = np.concatenate([ - np.sum(kernel[-L:,:W],axis=1), - np.sum(kernel[:L,:W],axis=1) - ]) - lineprofile_2 = np.concatenate([ - np.sum(kernel[:W,-L:],axis=0), - np.sum(kernel[:W,:L],axis=0) - ]) - - im_kernel = np.vstack([ - np.hstack([ - kernel[-int(R):,-int(R):], - kernel[-int(R):,:int(R)] - ]), - np.hstack([ - kernel[:int(R),-int(R):], - kernel[:int(R),:int(R)] - ]), - ]) - - fig,axs = plt.subplots(1,2,figsize=figsize) - axs[0].matshow(im_kernel,cmap='gray') - axs[0].plot( - np.ones(2*R)*R, - np.arange(2*R), - c='r') - axs[0].plot( - np.arange(2*R), - np.ones(2*R)*R, - c='c') - - - axs[1].plot( - np.arange(len(lineprofile_1)), - lineprofile_1, - c='r') - axs[1].plot( - np.arange(len(lineprofile_2)), - lineprofile_2, - c='c') + lineprofile_1 = np.concatenate( + [np.sum(kernel[-L:, :W], axis=1), np.sum(kernel[:L, :W], axis=1)] + ) + lineprofile_2 = np.concatenate( + [np.sum(kernel[:W, -L:], axis=0), np.sum(kernel[:W, :L], axis=0)] + ) + + im_kernel = np.vstack( + [ + np.hstack([kernel[-int(R) :, -int(R) :], kernel[-int(R) :, : int(R)]]), + np.hstack([kernel[: int(R), -int(R) :], kernel[: int(R), : int(R)]]), + ] + ) + + fig, axs = plt.subplots(1, 2, figsize=figsize) + axs[0].matshow(im_kernel, cmap="gray") + axs[0].plot(np.ones(2 * R) * R, np.arange(2 * R), c="r") + axs[0].plot(np.arange(2 * R), np.ones(2 * R) * R, c="c") + + axs[1].plot(np.arange(len(lineprofile_1)), lineprofile_1, c="r") + axs[1].plot(np.arange(len(lineprofile_2)), lineprofile_2, c="c") if not returnfig: plt.show() return else: - return fig,axs + return fig, axs -def show_voronoi(ar,x,y,color_points='r',color_lines='w',max_dist=None, - returnfig=False,**kwargs): + +def show_voronoi( + ar, + x, + y, + color_points="r", + color_lines="w", + max_dist=None, + returnfig=False, + **kwargs +): """ words """ from py4DSTEM.process.utils import get_voronoi_vertices - Nx,Ny = ar.shape - points = np.vstack((x,y)).T + + Nx, Ny = ar.shape + points = np.vstack((x, y)).T voronoi = Voronoi(points) - vertices = get_voronoi_vertices(voronoi,Nx,Ny) + vertices = get_voronoi_vertices(voronoi, Nx, Ny) if max_dist is None: - fig,ax = show(ar,returnfig=True,**kwargs) + fig, ax = show(ar, returnfig=True, **kwargs) else: - centers = [(x[i],y[i]) for i in range(len(x))] - fig,ax = show(ar,returnfig=True,**kwargs, - circle={'center':centers,'R':max_dist,'fill':False,'color':color_points}) + centers = [(x[i], y[i]) for i in range(len(x))] + fig, ax = show( + ar, + returnfig=True, + **kwargs, + circle={ + "center": centers, + "R": max_dist, + "fill": False, + "color": color_points, + }, + ) - ax.scatter(voronoi.points[:,1],voronoi.points[:,0],color=color_points) + ax.scatter(voronoi.points[:, 1], voronoi.points[:, 0], color=color_points) for region in range(len(vertices)): vertices_curr = vertices[region] for i in range(len(vertices_curr)): - x0,y0 = vertices_curr[i,:] - xf,yf = vertices_curr[(i+1)%len(vertices_curr),:] - ax.plot((y0,yf),(x0,xf),color=color_lines) - ax.set_xlim([0,Ny]) - ax.set_ylim([0,Nx]) + x0, y0 = vertices_curr[i, :] + xf, yf = vertices_curr[(i + 1) % len(vertices_curr), :] + ax.plot((y0, yf), (x0, xf), color=color_lines) + ax.set_xlim([0, Ny]) + ax.set_ylim([0, Nx]) plt.gca().invert_yaxis() if not returnfig: plt.show() return else: - return fig,ax + return fig, ax + -def show_class_BPs(ar,x,y,s,s2,color='r',color2='y',**kwargs): +def show_class_BPs(ar, x, y, s, s2, color="r", color2="y", **kwargs): """ words """ N = len(x) - assert(N==len(y)==len(s)) + assert N == len(y) == len(s) - fig,ax = show(ar,returnfig=True,**kwargs) - ax.scatter(y,x,s=s2,color=color2) - ax.scatter(y,x,s=s,color=color) + fig, ax = show(ar, returnfig=True, **kwargs) + ax.scatter(y, x, s=s2, color=color2) + ax.scatter(y, x, s=s, color=color) plt.show() return -def show_class_BPs_grid(ar,H,W,x,y,get_s,s2,color='r',color2='y',returnfig=False, - axsize=(6,6),titlesize=0,get_bordercolor=None,**kwargs): + +def show_class_BPs_grid( + ar, + H, + W, + x, + y, + get_s, + s2, + color="r", + color2="y", + returnfig=False, + axsize=(6, 6), + titlesize=0, + get_bordercolor=None, + **kwargs +): """ words """ - fig,axs = show_image_grid(lambda i:ar,H,W,axsize=axsize,titlesize=titlesize, - get_bordercolor=get_bordercolor,returnfig=True,**kwargs) + fig, axs = show_image_grid( + lambda i: ar, + H, + W, + axsize=axsize, + titlesize=titlesize, + get_bordercolor=get_bordercolor, + returnfig=True, + **kwargs, + ) for i in range(H): for j in range(W): - ax = axs[i,j] - N = i*W+j + ax = axs[i, j] + N = i * W + j s = get_s(N) - ax.scatter(y,x,s=s2,color=color2) - ax.scatter(y,x,s=s,color=color) + ax.scatter(y, x, s=s2, color=color2) + ax.scatter(y, x, s=s, color=color) if not returnfig: plt.show() return else: - return fig,axs + return fig, axs + def show_strain( strainmap, @@ -323,10 +411,10 @@ def show_strain( vrange_theta, vrange_exy=None, vrange_eyy=None, - flip_theta = False, + flip_theta=False, bkgrd=True, - show_cbars=('exx','eyy','exy','theta'), - bordercolor='k', + show_cbars=("exx", "eyy", "exy", "theta"), + bordercolor="k", borderwidth=1, titlesize=24, ticklabelsize=16, @@ -339,20 +427,21 @@ def show_strain( xaxis_y=0, axes_length=10, axes_width=1, - axes_color='r', - xaxis_space='Q', + axes_color="r", + xaxis_space="Q", labelaxes=True, QR_rotation=0, axes_labelsize=12, - axes_labelcolor='r', - axes_plots=('exx'), - cmap='RdBu_r', + axes_labelcolor="r", + axes_plots=("exx"), + cmap="RdBu_r", layout=0, - figsize=(12,12), - returnfig=False): + figsize=(12, 12), + returnfig=False, +): """ Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') + masking each image with strainmap.get_slice('mask') Args: strainmap (RealSlice): @@ -360,7 +449,7 @@ def show_strain( vrange_theta (length 2 list or tuple): vrange_exy (length 2 list or tuple): vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle + flip_theta (bool): if True, take negative of angle bkgrd (bool): show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a tuple containing any, all, or none of ('exx','eyy','exy','theta'). @@ -394,11 +483,11 @@ def show_strain( returnfig (bool): """ # Lookup table for different layouts - assert(layout in (0,1,2)) + assert layout in (0, 1, 2) layout_lookup = { - 0:['left','right','left','right'], - 1:['bottom','bottom','bottom','bottom'], - 2:['right','right','right','right'], + 0: ["left", "right", "left", "right"], + 1: ["bottom", "bottom", "bottom", "bottom"], + 2: ["right", "right", "right", "right"], } layout_p = layout_lookup[layout] @@ -407,141 +496,204 @@ def show_strain( vrange_exy = vrange_exx if vrange_eyy is None: vrange_eyy = vrange_exx - for vrange in (vrange_exx,vrange_eyy,vrange_exy,vrange_theta): - assert(len(vrange)==2), 'vranges must have length 2' - vmin_exx,vmax_exx = vrange_exx[0]/100.,vrange_exx[1]/100. - vmin_eyy,vmax_eyy = vrange_eyy[0]/100.,vrange_eyy[1]/100. - vmin_exy,vmax_exy = vrange_exy[0]/100.,vrange_exy[1]/100. + for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): + assert len(vrange) == 2, "vranges must have length 2" + vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 + vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 + vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 # theta is plotted in units of degrees - vmin_theta,vmax_theta = vrange_theta[0]/(180.0/np.pi),vrange_theta[1]/(180.0/np.pi) + vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( + 180.0 / np.pi + ) # Get images - e_xx = np.ma.array(strainmap.get_slice('e_xx').data,mask=strainmap.get_slice('mask').data==False) - e_yy = np.ma.array(strainmap.get_slice('e_yy').data,mask=strainmap.get_slice('mask').data==False) - e_xy = np.ma.array(strainmap.get_slice('e_xy').data,mask=strainmap.get_slice('mask').data==False) - theta = np.ma.array(strainmap.get_slice('theta').data,mask=strainmap.get_slice('mask').data==False) - if flip_theta == True: - theta = - theta + e_xx = np.ma.array( + strainmap.get_slice("e_xx").data, mask=strainmap.get_slice("mask").data == False + ) + e_yy = np.ma.array( + strainmap.get_slice("e_yy").data, mask=strainmap.get_slice("mask").data == False + ) + e_xy = np.ma.array( + strainmap.get_slice("e_xy").data, mask=strainmap.get_slice("mask").data == False + ) + theta = np.ma.array( + strainmap.get_slice("theta").data, + mask=strainmap.get_slice("mask").data == False, + ) + if flip_theta == True: + theta = -theta # Plot - if layout==0: - fig,((ax11,ax12),(ax21,ax22)) = plt.subplots(2,2,figsize=figsize) - elif layout==1: - fig,(ax11,ax12,ax21,ax22) = plt.subplots(1,4,figsize=figsize) + if layout == 0: + fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) + elif layout == 1: + fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) else: - fig,(ax11,ax12,ax21,ax22) = plt.subplots(4,1,figsize=figsize) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) cax11 = show( e_xx, - figax=(fig,ax11), + figax=(fig, ax11), vmin=vmin_exx, vmax=vmax_exx, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) + returncax=True, + ) cax12 = show( e_yy, - figax=(fig,ax12), + figax=(fig, ax12), vmin=vmin_eyy, vmax=vmax_eyy, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) + returncax=True, + ) cax21 = show( e_xy, - figax=(fig,ax21), + figax=(fig, ax21), vmin=vmin_exy, vmax=vmax_exy, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) + returncax=True, + ) cax22 = show( theta, - figax=(fig,ax22), + figax=(fig, ax22), vmin=vmin_theta, vmax=vmax_theta, - intensity_range='absolute', + intensity_range="absolute", cmap=cmap, - returncax=True) - ax11.set_title(r'$\epsilon_{xx}$',size=titlesize) - ax12.set_title(r'$\epsilon_{yy}$',size=titlesize) - ax21.set_title(r'$\epsilon_{xy}$',size=titlesize) - ax22.set_title(r'$\theta$',size=titlesize) + returncax=True, + ) + ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) + ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) + ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) + ax22.set_title(r"$\theta$", size=titlesize) # Add black background if bkgrd: mask = np.ma.masked_where( - strainmap.get_slice('mask').data.astype(bool), - np.zeros_like(strainmap.get_slice('mask').data)) - ax11.matshow(mask,cmap='gray') - ax12.matshow(mask,cmap='gray') - ax21.matshow(mask,cmap='gray') - ax22.matshow(mask,cmap='gray') + strainmap.get_slice("mask").data.astype(bool), + np.zeros_like(strainmap.get_slice("mask").data), + ) + ax11.matshow(mask, cmap="gray") + ax12.matshow(mask, cmap="gray") + ax21.matshow(mask, cmap="gray") + ax22.matshow(mask, cmap="gray") # Colorbars - show_cbars = np.array(['exx' in show_cbars,'eyy' in show_cbars, - 'exy' in show_cbars,'theta' in show_cbars]) + show_cbars = np.array( + [ + "exx" in show_cbars, + "eyy" in show_cbars, + "exy" in show_cbars, + "theta" in show_cbars, + ] + ) if np.any(show_cbars): divider11 = make_axes_locatable(ax11) divider12 = make_axes_locatable(ax12) divider21 = make_axes_locatable(ax21) divider22 = make_axes_locatable(ax22) - cbax11 = divider11.append_axes(layout_p[0],size="4%",pad=0.15) - cbax12 = divider12.append_axes(layout_p[1],size="4%",pad=0.15) - cbax21 = divider21.append_axes(layout_p[2],size="4%",pad=0.15) - cbax22 = divider22.append_axes(layout_p[3],size="4%",pad=0.15) - for (ind,show_cbar,cax,cbax,vmin,vmax,tickside,tickunits) in zip( + cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) + cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) + cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) + cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) + for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( range(4), show_cbars, - (cax11,cax12,cax21,cax22), - (cbax11,cbax12,cbax21,cbax22), - (vmin_exx,vmin_eyy,vmin_exy,vmin_theta), - (vmax_exx,vmax_eyy,vmax_exy,vmax_theta), - (layout_p[0],layout_p[1],layout_p[2],layout_p[3]), - ('% ',' %','% ',r' $^\circ$')): + (cax11, cax12, cax21, cax22), + (cbax11, cbax12, cbax21, cbax22), + (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), + (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), + (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), + ("% ", " %", "% ", r" $^\circ$"), + ): if show_cbar: - ticks = np.linspace(vmin,vmax,ticknumber,endpoint=True) + ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) if ind < 3: - ticklabels = np.round(np.linspace( - 100*vmin,100*vmax,ticknumber,endpoint=True),decimals=2).astype(str) + ticklabels = np.round( + np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), + decimals=2, + ).astype(str) else: - ticklabels = np.round(np.linspace( - (180/np.pi)*vmin,(180/np.pi)*vmax,ticknumber,endpoint=True),decimals=2).astype(str) - - if tickside in ('left','right'): - cb = plt.colorbar(cax,cax=cbax,ticks=ticks,orientation='vertical') - cb.ax.set_yticklabels(ticklabels,size=ticklabelsize) + ticklabels = np.round( + np.linspace( + (180 / np.pi) * vmin, + (180 / np.pi) * vmax, + ticknumber, + endpoint=True, + ), + decimals=2, + ).astype(str) + + if tickside in ("left", "right"): + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="vertical" + ) + cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) cbax.yaxis.set_ticks_position(tickside) - cbax.set_ylabel(tickunits,size=unitlabelsize,rotation=0) + cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) cbax.yaxis.set_label_position(tickside) else: - cb = plt.colorbar(cax,cax=cbax,ticks=ticks,orientation='horizontal') - cb.ax.set_xticklabels(ticklabels,size=ticklabelsize) + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="horizontal" + ) + cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) cbax.xaxis.set_ticks_position(tickside) - cbax.set_xlabel(tickunits,size=unitlabelsize,rotation=0) + cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) cbax.xaxis.set_label_position(tickside) else: - cbax.axis('off') + cbax.axis("off") # Add coordinate axes if show_axes: - assert(xaxis_space in ('R','Q')), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array(['exx' in axes_plots,'eyy' in axes_plots, - 'exy' in axes_plots,'theta' in axes_plots]) - for _show,_ax in zip(show_which_axes,(ax11,ax12,ax21,ax22)): + assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" + show_which_axes = np.array( + [ + "exx" in axes_plots, + "eyy" in axes_plots, + "exy" in axes_plots, + "theta" in axes_plots, + ] + ) + for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): if _show: - if xaxis_space=='R': - ax_addaxes(_ax,xaxis_x,xaxis_y,axes_length,axes_x0,axes_y0, - width=axes_width,color=axes_color,labelaxes=labelaxes, - labelsize=axes_labelsize,labelcolor=axes_labelcolor) + if xaxis_space == "R": + ax_addaxes( + _ax, + xaxis_x, + xaxis_y, + axes_length, + axes_x0, + axes_y0, + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) else: - ax_addaxes_QtoR(_ax,xaxis_x,xaxis_y,axes_length,axes_x0,axes_y0,QR_rotation, - width=axes_width,color=axes_color,labelaxes=labelaxes, - labelsize=axes_labelsize,labelcolor=axes_labelcolor) + ax_addaxes_QtoR( + _ax, + xaxis_x, + xaxis_y, + axes_length, + axes_x0, + axes_y0, + QR_rotation, + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) # Add borders if bordercolor is not None: - for ax in (ax11,ax12,ax21,ax22): - for s in ['bottom','top','left','right']: + for ax in (ax11, ax12, ax21, ax22): + for s in ["bottom", "top", "left", "right"]: ax.spines[s].set_color(bordercolor) ax.spines[s].set_linewidth(borderwidth) ax.set_xticks([]) @@ -551,54 +703,87 @@ def show_strain( plt.show() return else: - axs = ((ax11,ax12),(ax21,ax22)) - return fig,axs + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs -def show_pointlabels(ar,x,y,color='lightblue',size=20,alpha=1,returnfig=False,**kwargs): +def show_pointlabels( + ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs +): """ Show enumerated index labels for a set of points """ - fig,ax = show(ar,returnfig=True,**kwargs) - d = {'x':x,'y':y,'size':size,'color':color,'alpha':alpha} - add_pointlabels(ax,d) + fig, ax = show(ar, returnfig=True, **kwargs) + d = {"x": x, "y": y, "size": size, "color": color, "alpha": alpha} + add_pointlabels(ax, d) if returnfig: - return fig,ax + return fig, ax else: plt.show() return -def select_point(ar,x,y,i,color='lightblue',color_selected='r',size=20,returnfig=False,**kwargs): +def select_point( + ar, + x, + y, + i, + color="lightblue", + color_selected="r", + size=20, + returnfig=False, + **kwargs +): """ Show enumerated index labels for a set of points, with one selected point highlighted """ - fig,ax = show(ar,returnfig=True,**kwargs) - d1 = {'x':x,'y':y,'size':size,'color':color} - d2 = {'x':x[i],'y':y[i],'size':size,'color':color_selected,'fontweight':'bold'} - add_pointlabels(ax,d1) - add_pointlabels(ax,d2) + fig, ax = show(ar, returnfig=True, **kwargs) + d1 = {"x": x, "y": y, "size": size, "color": color} + d2 = { + "x": x[i], + "y": y[i], + "size": size, + "color": color_selected, + "fontweight": "bold", + } + add_pointlabels(ax, d1) + add_pointlabels(ax, d2) if returnfig: - return fig,ax + return fig, ax else: plt.show() return -def show_max_peak_spacing(ar,spacing,braggdirections,color='g',lw=2,returnfig=False,**kwargs): - """ Show a circle of radius `spacing` about each Bragg direction - """ - centers = [(braggdirections.data['qx'][i],braggdirections.data['qy'][i]) for i in range(braggdirections.length)] - fig,ax = show(ar,circle={'center':centers,'R':spacing,'color':color,'fill':False,'lw':lw}, - returnfig=True,**kwargs) +def show_max_peak_spacing( + ar, spacing, braggdirections, color="g", lw=2, returnfig=False, **kwargs +): + """Show a circle of radius `spacing` about each Bragg direction""" + centers = [ + (braggdirections.data["qx"][i], braggdirections.data["qy"][i]) + for i in range(braggdirections.length) + ] + fig, ax = show( + ar, + circle={ + "center": centers, + "R": spacing, + "color": color, + "fill": False, + "lw": lw, + }, + returnfig=True, + **kwargs, + ) if returnfig: - return fig,ax + return fig, ax else: plt.show() return + def show_origin_meas(data): """ Show the measured positions of the origin. @@ -608,17 +793,19 @@ def show_origin_meas(data): """ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube - if isinstance(data,tuple): - assert len(data)==2 - qx,qy = data - elif isinstance(data,DataCube): - qx,qy = data.calibration.get_origin_meas() - elif isinstance(data,Calibration): - qx,qy = data.get_origin_meas() + + if isinstance(data, tuple): + assert len(data) == 2 + qx, qy = data + elif isinstance(data, DataCube): + qx, qy = data.calibration.get_origin_meas() + elif isinstance(data, Calibration): + qx, qy = data.get_origin_meas() else: raise Exception("data must be of type Datacube or Calibration or tuple") - show_image_grid(get_ar = lambda i:[qx,qy][i],H=1,W=2,cmap='RdBu') + show_image_grid(get_ar=lambda i: [qx, qy][i], H=1, W=2, cmap="RdBu") + def show_origin_fit(data): """ @@ -630,29 +817,49 @@ def show_origin_fit(data): """ from py4DSTEM.data import Calibration from py4DSTEM.datacube import DataCube - if isinstance(data,tuple): - assert len(data)==3 - qx0_meas,qy_meas = data[0] - qx0_fit,qy0_fit = data[1] - qx0_residuals,qy0_residuals = data[2] - elif isinstance(data,DataCube): - qx0_meas,qy0_meas = data.calibration.get_origin_meas() - qx0_fit,qy0_fit = data.calibration.get_origin() - qx0_residuals,qy0_residuals = data.calibration.get_origin_residuals() - elif isinstance(data,Calibration): - qx0_meas,qy0_meas = data.get_origin_meas() - qx0_fit,qy0_fit = data.get_origin() - qx0_residuals,qy0_residuals = data.get_origin_residuals() + + if isinstance(data, tuple): + assert len(data) == 3 + qx0_meas, qy_meas = data[0] + qx0_fit, qy0_fit = data[1] + qx0_residuals, qy0_residuals = data[2] + elif isinstance(data, DataCube): + qx0_meas, qy0_meas = data.calibration.get_origin_meas() + qx0_fit, qy0_fit = data.calibration.get_origin() + qx0_residuals, qy0_residuals = data.calibration.get_origin_residuals() + elif isinstance(data, Calibration): + qx0_meas, qy0_meas = data.get_origin_meas() + qx0_fit, qy0_fit = data.get_origin() + qx0_residuals, qy0_residuals = data.get_origin_residuals() else: raise Exception("data must be of type Datacube or Calibration or tuple") - show_image_grid(get_ar = lambda i:[qx0_meas,qx0_fit,qx0_residuals, - qy0_meas,qy0_fit,qy0_residuals][i], - H=2,W=3,cmap='RdBu') + show_image_grid( + get_ar=lambda i: [ + qx0_meas, + qx0_fit, + qx0_residuals, + qy0_meas, + qy0_fit, + qy0_residuals, + ][i], + H=2, + W=3, + cmap="RdBu", + ) -def show_selected_dps(datacube,positions,im,bragg_pos=None, - colors=None,HW=None,figsize_im=(6,6),figsize_dp=(4,4), - **kwargs): + +def show_selected_dps( + datacube, + positions, + im, + bragg_pos=None, + colors=None, + HW=None, + figsize_im=(6, 6), + figsize_dp=(4, 4), + **kwargs +): """ Shows two plots: first, a real space image overlaid with colored dots at the specified positions; second, a grid of diffraction patterns @@ -673,72 +880,87 @@ def show_selected_dps(datacube,positions,im,bragg_pos=None, *diffraction patterns*. Default is `scaling='log'` """ from py4DSTEM.datacube import DataCube - assert isinstance(datacube,DataCube) + + assert isinstance(datacube, DataCube) N = len(positions) - assert(all([len(x)==2 for x in positions])), "Improperly formated argument `positions`" + assert all( + [len(x) == 2 for x in positions] + ), "Improperly formated argument `positions`" if bragg_pos is not None: show_disk_pos = True - assert(len(bragg_pos)==N) + assert len(bragg_pos) == N else: show_disk_pos = False if colors is None: from matplotlib.cm import gist_ncar - linsp = np.linspace(0,1,N,endpoint=False) + + linsp = np.linspace(0, 1, N, endpoint=False) colors = [gist_ncar(i) for i in linsp] - assert(len(colors)==N), "Number of positions and colors don't match" + assert len(colors) == N, "Number of positions and colors don't match" from matplotlib.colors import is_color_like - assert([is_color_like(i) for i in colors]) + + assert [is_color_like(i) for i in colors] if HW is None: W = int(np.ceil(np.sqrt(N))) - if W<3: W=3 - H = int(np.ceil(N/W)) + if W < 3: + W = 3 + H = int(np.ceil(N / W)) else: - H,W = HW - assert(all([isinstance(x,(int,np.integer)) for x in (H,W)])) + H, W = HW + assert all([isinstance(x, (int, np.integer)) for x in (H, W)]) x = [i[0] for i in positions] y = [i[1] for i in positions] - if 'scaling' not in kwargs.keys(): - kwargs['scaling'] = 'log' + if "scaling" not in kwargs.keys(): + kwargs["scaling"] = "log" if not show_disk_pos: - fig,ax = show(im,figsize=figsize_im,returnfig=True) - add_points(ax,d = {'x':x,'y':y,'pointcolor':colors}) - show_image_grid(get_ar=lambda i:datacube.data[x[i],y[i],:,:],H=H,W=W, - get_bordercolor=lambda i:colors[i],axsize=figsize_dp, - **kwargs) + fig, ax = show(im, figsize=figsize_im, returnfig=True) + add_points(ax, d={"x": x, "y": y, "pointcolor": colors}) + show_image_grid( + get_ar=lambda i: datacube.data[x[i], y[i], :, :], + H=H, + W=W, + get_bordercolor=lambda i: colors[i], + axsize=figsize_dp, + **kwargs, + ) else: - show_image_grid(get_ar=lambda i:datacube.data[x[i],y[i],:,:],H=H,W=W, - get_bordercolor=lambda i:colors[i],axsize=figsize_dp, - get_x=lambda i:bragg_pos[i].data['qx'], - get_y=lambda i:bragg_pos[i].data['qy'], - get_pointcolors=lambda i:colors[i], - **kwargs) + show_image_grid( + get_ar=lambda i: datacube.data[x[i], y[i], :, :], + H=H, + W=W, + get_bordercolor=lambda i: colors[i], + axsize=figsize_dp, + get_x=lambda i: bragg_pos[i].data["qx"], + get_y=lambda i: bragg_pos[i].data["qy"], + get_pointcolors=lambda i: colors[i], + **kwargs, + ) + -def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): """ complex_data (array): complex array to plot - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value + vmin (float) : minimum absolute value + vmax (float) : maximum absolute value power (float) : power to raise amplitude to """ - if power is None: - norm = mcolors.Normalize() - else: - norm = mcolors.PowerNorm(power) - - amp = norm(np.abs(complex_data)).data + amp = np.abs(complex_data) phase = np.angle(complex_data) - if np.isclose(np.max(amp),np.min(amp)): + if power is not None: + amp = amp**power + + if np.isclose(np.max(amp), np.min(amp)): if vmin is None: vmin = 0 if vmax is None: vmax = np.max(amp) else: if vmin is None: - vmin = 0.0 + vmin = 0.02 if vmax is None: - vmax = 1.0 + vmax = 0.98 vals = np.sort(amp[~np.isnan(amp)]) ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") @@ -749,27 +971,29 @@ def Complex2RGB(complex_data, vmin=None, vmax = None, power = None): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) - - J = amp*61.5 # Note we restrict luminance to 61.5 - C = np.where(J<61.5,98*J/123,1400/11-14*J/11) # Min uniform chroma - h = np.rad2deg(phase)+180 + amp = ((amp - vmin) / vmax).clip(1e-16, 1) - JCh = np.stack((J,C,h), axis=-1) + J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff + C = np.where(J < 61.5, 98 * J / 123, 1400 / 11 - 14 * J / 11) # Min uniform chroma + h = np.rad2deg(phase) + 180 + + JCh = np.stack((J, C, h), axis=-1) rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) - + return rgb -def add_colorbar_arg(cax, c = 49, j = 61.5): + +def add_colorbar_arg(cax, c=49, j=61.5): """ cax : axis to add cbar to c : constant chroma value j : constant luminance value """ - h = np.linspace(0, 360, 256,endpoint=False) - J = np.full_like(h,j) - C = np.full_like(h,c) - JCh = np.stack((J,C,h), axis=-1) + h = np.linspace(0, 360, 256, endpoint=False) + J = np.full_like(h, j) + C = np.full_like(h, c) + JCh = np.stack((J, C, h), axis=-1) rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) @@ -783,6 +1007,7 @@ def add_colorbar_arg(cax, c = 49, j = 61.5): [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) + def show_complex( ar_complex, vmin=None, @@ -803,7 +1028,7 @@ def show_complex( such as [array1, array2], then arrays are horizonally plotted in one figure vmin (float, optional) : minimum absolute value vmax (float, optional) : maximum absolute value - if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, + if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar @@ -811,16 +1036,24 @@ def show_complex( pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) power (float,optional) : power to raise amplitude to - + Returns: if returnfig==False (default), the figure is plotted and nothing is returned. if returnfig==True, return the figure and the axis. """ # convert to complex colors - ar_complex = ar_complex[0] if (isinstance(ar_complex,list) and len(ar_complex) == 1) else ar_complex + ar_complex = ( + ar_complex[0] + if (isinstance(ar_complex, list) and len(ar_complex) == 1) + else ar_complex + ) if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): - rgb = [Complex2RGB(ar, vmin, vmax, power=power) for sublist in ar_complex for ar in sublist] + rgb = [ + Complex2RGB(ar, vmin, vmax, power=power) + for sublist in ar_complex + for ar in sublist + ] H = len(ar_complex) W = len(ar_complex[0]) @@ -843,7 +1076,7 @@ def show_complex( is_grid = True H = rgb.shape[0] W = rgb.shape[1] - rgb = rgb.reshape((-1,)+rgb.shape[-3:]) + rgb = rgb.reshape((-1,) + rgb.shape[-3:]) else: is_grid = False # plot From d4099ea28f5c8579f9ca549644b8c6029bcacb1b Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 3 Oct 2023 17:10:47 -0700 Subject: [PATCH 043/176] change to fitted intensities --- py4DSTEM/process/phase/iterative_parallax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 0d2f4cfc9..3882fb1e4 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -200,14 +200,14 @@ def preprocess( intensities_shifted = self._intensities.copy() - center_x = np.mean(dpc._com_measured_x) - center_y = np.mean(dpc._com_measured_y) + center_x = np.mean(dpc._com_fitted_x) + center_y = np.mean(dpc._com_fitted_y) for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): intensity_shifted = get_shifted_ar( self._intensities[rx, ry], - -dpc._com_measured_x[rx, ry] + center_x, - -dpc._com_measured_y[rx, ry] + center_y, + -dpc._com_fitted_x[rx, ry] + center_x, + -dpc._com_fitted_y[rx, ry] + center_y, bilinear=True, device="cpu", ) From 466343b54838eee691ff9206be5e16118ce6c508 Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 4 Oct 2023 14:20:52 -0700 Subject: [PATCH 044/176] Adding docstrings and more description --- py4DSTEM/process/diffraction/crystal.py | 41 +++++++++++++++++++++---- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 12309ac4d..9b6976460 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1106,7 +1106,10 @@ def generate_moire( returnfig=False, ): """ - Calculate a Moire lattice from 2 parent diffraction patterns. + Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated + and strained with respect to the original lattice. Note that this strain is applied in real space, + and so the inverse of the calculated infinitestimal strain tensor is applied. + Parameters -------- @@ -1115,30 +1118,54 @@ def generate_moire( bragg_peaks_1: BraggVector Bragg vectors for parent lattice 1. thresh_0: float + Intensity threshold for structure factors from lattice 0. thresh_1: float + Intensity threshold for structure factors from lattice 1. int_range: (float, float) + Plotting intensity range for the Moire peaks. exx_1: float + Strain of lattice 1 in x direction (vertical) in real space. eyy_1: float + Strain of lattice 1 in y direction (horizontal) in real space. exy_1: float + Shear strain of lattice 1 in (x,y) direction (diagonal) in real space. phi_1: float + Rotation of lattice 1 in real space. power: float + Plotting power law (default is amplitude**2.0, i.e. intensity). k_max: float + Max k value of the calculated (and plotted) Moire lattice. plot_result: bool + Plot the resulting Moire lattice. plot_subpixel: bool + Apply subpixel corrections to the Bragg spot positions. + Matplotlib default scatter plot rounds to the nearest pixel. labels: list List of text labels for parent lattices marker_size_parent: float + Size of plot markers for the two parent lattices. marker_size_moire: float + Size of plot markers for the Moire lattice. text_size_parent: float + Label text size for parent lattice. text_size_moire: float + Label text size for Moire lattice. add_labels_parent: bool + Plot the parent lattice index labels. add_labels_moire: bool + Plot the parent lattice index labels for the Moire spots. dist_labels: float + Distance to move the labels off the spots. dist_check: float + Set to some distance to "push" the labels away from each other if they are within this distance. sep_labels: float + Separation distance for labels which are "pushed" apart. figsize: (float,float) + Size of output figure. return_moire: bool + Return the moire lattice as a pointlist. returnfig: bool + Return the (fix,ax) handles of the plot. Returns -------- @@ -1200,11 +1227,13 @@ def overline(x): [np.cos(phi_1), -np.sin(phi_1)], [np.sin(phi_1), np.cos(phi_1)], ] - ) @ np.array( - [ - [1 + exx_1, exy_1 * 0.5], - [exy_1 * 0.5, 1 + eyy_1], - ] + ) @ np.linalg.inv( + np.array( + [ + [1 + exx_1, exy_1 * 0.5], + [exy_1 * 0.5, 1 + eyy_1], + ] + ) ) qx1 = m[0, 0] * qx1_init + m[0, 1] * qy1_init qy1 = m[1, 0] * qx1_init + m[1, 1] * qy1_init From 4f97b81b477bab5f72bcbb7f3c51895042e6512b Mon Sep 17 00:00:00 2001 From: cophus Date: Wed, 4 Oct 2023 14:23:30 -0700 Subject: [PATCH 045/176] Black formatting --- py4DSTEM/process/diffraction/crystal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 9b6976460..1c43f89bc 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1138,7 +1138,7 @@ def generate_moire( plot_result: bool Plot the resulting Moire lattice. plot_subpixel: bool - Apply subpixel corrections to the Bragg spot positions. + Apply subpixel corrections to the Bragg spot positions. Matplotlib default scatter plot rounds to the nearest pixel. labels: list List of text labels for parent lattices From f2c21d5581febf3fddb44a020305d89c55b38cca Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sun, 8 Oct 2023 08:56:07 -0700 Subject: [PATCH 046/176] preprocessing dtype bug --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- py4DSTEM/process/phase/iterative_mixedstate_ptychography.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 96b3d5088..4ccb21226 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1108,7 +1108,7 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros_like(diffraction_intensities) + amplitudes = xp.zeros(diffraction_intensities.shape, dtype=xp.float32) region_of_interest_shape = diffraction_intensities.shape[-2:] com_fitted_x = self._asnumpy(com_fitted_x) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 6dbccff06..2a482faf0 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -1283,6 +1283,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, + constrain_position_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -1358,6 +1359,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1680,6 +1683,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, + constrain_position_distance, ) error += batch_error From d6f4799224728c86a551b9a2b739b01147b3698f Mon Sep 17 00:00:00 2001 From: Steve Zeltmann Date: Mon, 9 Oct 2023 09:41:29 -0400 Subject: [PATCH 047/176] update CUDA source file location in setup.py --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 069bf1600..c3cbbd151 100644 --- a/setup.py +++ b/setup.py @@ -57,8 +57,8 @@ package_data={ "py4DSTEM": [ "process/utils/scattering_factors.txt", - "process/diskdetection/multicorr_row_kernel.cu", - "process/diskdetection/multicorr_col_kernel.cu", + "braggvectors/multicorr_row_kernel.cu", + "braggvectors/multicorr_col_kernel.cu", ] }, ) From 74c60d40b34563a78a6c696ebb22c916362c20b2 Mon Sep 17 00:00:00 2001 From: Steve Zeltmann Date: Mon, 9 Oct 2023 10:05:56 -0400 Subject: [PATCH 048/176] update cupy import try statements --- py4DSTEM/braggvectors/diskdetection_aiml_cuda.py | 4 ++-- py4DSTEM/preprocess/utils.py | 4 ++-- py4DSTEM/process/phase/iterative_base_class.py | 4 ++-- py4DSTEM/process/phase/iterative_dpc.py | 4 ++-- py4DSTEM/process/phase/iterative_mixedstate_ptychography.py | 4 ++-- py4DSTEM/process/phase/iterative_multislice_ptychography.py | 4 ++-- .../process/phase/iterative_overlap_magnetic_tomography.py | 4 ++-- py4DSTEM/process/phase/iterative_overlap_tomography.py | 4 ++-- py4DSTEM/process/phase/iterative_parallax.py | 4 ++-- py4DSTEM/process/phase/iterative_simultaneous_ptychography.py | 4 ++-- py4DSTEM/process/phase/iterative_singleslice_ptychography.py | 4 ++-- py4DSTEM/process/utils/cross_correlate.py | 4 ++-- py4DSTEM/process/utils/multicorr.py | 4 ++-- py4DSTEM/process/utils/utils.py | 4 ++-- 14 files changed, 28 insertions(+), 28 deletions(-) diff --git a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py index d0f550dcc..c5f89b9fd 100644 --- a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py @@ -17,8 +17,8 @@ try: import cupy as cp -except: - raise ImportError("Import Error: Please install cupy before proceeding") +except ModuleNotFoundError: + raise ImportError("AIML CUDA Requires cupy") try: import tensorflow as tf diff --git a/py4DSTEM/preprocess/utils.py b/py4DSTEM/preprocess/utils.py index 0c76f35a7..752e2f81c 100644 --- a/py4DSTEM/preprocess/utils.py +++ b/py4DSTEM/preprocess/utils.py @@ -5,8 +5,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def bin2D(array, factor, dtype=np.float64): diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index ae4c92d4b..6d7967550 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -13,8 +13,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 4c80ed177..02138d738 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -13,8 +13,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 56fec1004..ceae66cd8 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index a352502d0..aee383675 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 8691a121d..b09d18ca7 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -16,8 +16,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index d6bee12fd..1f6be1c38 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -16,8 +16,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 80cdd8cd8..7c5896b6a 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -19,8 +19,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np warnings.simplefilter(action="always", category=UserWarning) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 8881d021c..e3713cde1 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0480bae8a..df0ef5e1c 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM.datacube import DataCube diff --git a/py4DSTEM/process/utils/cross_correlate.py b/py4DSTEM/process/utils/cross_correlate.py index f9aac1312..50de91e33 100644 --- a/py4DSTEM/process/utils/cross_correlate.py +++ b/py4DSTEM/process/utils/cross_correlate.py @@ -6,8 +6,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def get_cross_correlation(ar, template, corrPower=1, _returnval="real"): diff --git a/py4DSTEM/process/utils/multicorr.py b/py4DSTEM/process/utils/multicorr.py index 8523c8e62..bc07390bb 100644 --- a/py4DSTEM/process/utils/multicorr.py +++ b/py4DSTEM/process/utils/multicorr.py @@ -15,8 +15,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def upsampled_correlation(imageCorr, upsampleFactor, xyShift, device="cpu"): diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 03d3d07a0..4ef2e1d8a 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -24,8 +24,8 @@ def clear_output(wait=True): try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def radial_reduction(ar, x0, y0, binsize=1, fn=np.mean, coords=None): From 7c3a0d8ebd8bb3b12cc421b4c1ce2f3d9c88fca6 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Tue, 10 Oct 2023 08:25:12 -0700 Subject: [PATCH 049/176] adding tilt to propagators --- .../iterative_multislice_ptychography.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index b3614c0ad..365bd8b8f 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -80,6 +80,10 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -111,6 +115,8 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, object_type: str = "complex", verbose: bool = True, device: str = "cpu", @@ -191,6 +197,8 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -198,6 +206,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -213,6 +223,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in angles) + theta_y: float + y tilt of propagator (in angles) Returns ------- @@ -232,6 +246,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -239,6 +257,10 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp(1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))) return propagators @@ -561,6 +583,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps @@ -1859,6 +1883,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional From 23204aa5cee2243a879f9d5a5602f9e26dbb17af Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 12 Oct 2023 14:22:41 -0400 Subject: [PATCH 050/176] single slice crop patterns --- .../process/phase/iterative_base_class.py | 49 +++++++++++++++++-- .../iterative_multislice_ptychography.py | 4 +- .../iterative_singleslice_ptychography.py | 12 +++-- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 4ccb21226..2e6e0a917 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1084,6 +1084,7 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, + crop_patterns, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1096,6 +1097,8 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns Returns ------- @@ -1108,13 +1111,46 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros(diffraction_intensities.shape, dtype=xp.float32) - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities = self._asnumpy(diffraction_intensities) + if crop_patterns: + crop_x = int( + np.minimum( + diffraction_intensities.shape[2] - com_fitted_x.max(), + com_fitted_x.min(), + ) + ) + crop_y = int( + np.minimum( + diffraction_intensities.shape[3] - com_fitted_y.max(), + com_fitted_y.min(), + ) + ) + + crop_w = np.minimum(crop_y, crop_x) + region_of_interest_shape = (crop_w * 2, crop_w * 2) + amplitudes = np.zeros( + ( + diffraction_intensities.shape[0], + diffraction_intensities.shape[1], + crop_w * 2, + crop_w * 2, + ), + dtype=np.float32, + ) + + crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask[:crop_w, :crop_w] = True + crop_mask[-crop_w:, :crop_w] = True + crop_mask[:crop_w:, -crop_w:] = True + crop_mask[-crop_w:, -crop_w:] = True + self._crop_mask = crop_mask + + else: + region_of_interest_shape = diffraction_intensities.shape[-2:] + amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) - diffraction_intensities = self._asnumpy(diffraction_intensities) - amplitudes = self._asnumpy(amplitudes) for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): @@ -1126,6 +1162,11 @@ def _normalize_diffraction_intensities( device="cpu", ) + if crop_patterns: + intensities = intensities[crop_mask].reshape( + region_of_interest_shape + ) + mean_intensity += np.sum(intensities) amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 365bd8b8f..be6cbd6ed 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -260,7 +260,9 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) ) - propagators[i] *= xp.exp(1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 3843da983..0cc6b65d5 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -188,6 +188,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -245,6 +246,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns Returns -------- @@ -330,9 +333,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -412,6 +413,11 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + self._probe = ( ComplexProbe( gpts=self._region_of_interest_shape, From 871e0a508cd53815715c975f470e9ffaacc88254 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 12 Oct 2023 14:57:20 -0700 Subject: [PATCH 051/176] Formatting --- py4DSTEM/process/polar/polar_analysis.py | 321 +++++++++++------------ 1 file changed, 160 insertions(+), 161 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 0c2454289..829bbde59 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -11,15 +11,15 @@ def calculate_radial_statistics( self, - median_local = False, - median_global = False, - plot_results_mean = False, - plot_results_var = False, - figsize = (8,4), - returnval = False, - returnfig = False, - progress_bar = True, - ): + median_local=False, + median_global=False, + plot_results_mean=False, + plot_results_var=False, + figsize=(8, 4), + returnval=False, + returnfig=False, + progress_bar=True, +): """ Calculate fluctuation electron microscopy (FEM) statistics, including radial mean, variance, and normalized variance. This function uses the original FEM definitions, @@ -49,16 +49,20 @@ def calculate_radial_statistics( self.scattering_vector_units = self.calibration.get_Q_pixel_units() # init radial data arrays - self.radial_all = np.zeros(( - self._datacube.shape[0], - self._datacube.shape[1], - self.polar_shape[1], - )) - self.radial_all_std = np.zeros(( - self._datacube.shape[0], - self._datacube.shape[1], - self.polar_shape[1], - )) + self.radial_all = np.zeros( + ( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + ) + ) + self.radial_all_std = np.zeros( + ( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + ) + ) # Compute the radial mean for each probe position for rx, ry in tqdmnd( @@ -66,28 +70,26 @@ def calculate_radial_statistics( self._datacube.shape[1], desc="Radial statistics", unit=" probe positions", - disable=not progress_bar): - - self.radial_all[rx,ry] = np.mean( - self.data[rx,ry], - axis=0) - self.radial_all_std[rx,ry] = np.sqrt(np.mean( - (self.data[rx,ry] - self.radial_all[rx,ry][None])**2, - axis=0)) - - self.radial_mean = np.mean(self.radial_all, axis=(0,1)) + disable=not progress_bar, + ): + self.radial_all[rx, ry] = np.mean(self.data[rx, ry], axis=0) + self.radial_all_std[rx, ry] = np.sqrt( + np.mean((self.data[rx, ry] - self.radial_all[rx, ry][None]) ** 2, axis=0) + ) + + self.radial_mean = np.mean(self.radial_all, axis=(0, 1)) self.radial_var = np.mean( - (self.radial_all - self.radial_mean[None,None])**2, - axis=(0,1)) + (self.radial_all - self.radial_mean[None, None]) ** 2, axis=(0, 1) + ) - self.radial_var_norm = self.radial_var + self.radial_var_norm = self.radial_var sub = self.radial_mean > 0.0 - self.radial_var_norm[sub] /= self.radial_mean[sub]**2 + self.radial_var_norm[sub] /= self.radial_mean[sub] ** 2 # plot results if plot_results_mean: if returnfig: - fig,ax = plot_radial_mean( + fig, ax = plot_radial_mean( self, figsize=figsize, returnfig=True, @@ -95,20 +97,20 @@ def calculate_radial_statistics( else: plot_radial_mean( self, - figsize = figsize, - ) + figsize=figsize, + ) elif plot_results_var: if returnfig: - fig,ax = plot_radial_var_norm( + fig, ax = plot_radial_var_norm( self, - figsize = figsize, - returnfig = True, - ) + figsize=figsize, + returnfig=True, + ) else: plot_radial_var_norm( self, - figsize = figsize, - ) + figsize=figsize, + ) # Return values if returnval: @@ -125,31 +127,31 @@ def calculate_radial_statistics( def plot_radial_mean( self, - log_x = False, - log_y = False, - figsize = (8,4), - returnfig = False, - ): + log_x=False, + log_y=False, + figsize=(8, 4), + returnfig=False, +): """ Plot radial mean """ - fig,ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) ax.plot( self.scattering_vector, self.radial_mean, - ) + ) if log_x: - ax.set_xscale('log') + ax.set_xscale("log") if log_y: - ax.set_yscale('log') + ax.set_yscale("log") - ax.set_xlabel('Scattering Vector (' + self.scattering_vector_units + ')') - ax.set_ylabel('Radial Mean') + ax.set_xlabel("Scattering Vector (" + self.scattering_vector_units + ")") + ax.set_ylabel("Radial Mean") if log_x and self.scattering_vector[0] == 0.0: - ax.set_xlim((self.scattering_vector[1],self.scattering_vector[-1])) + ax.set_xlim((self.scattering_vector[1], self.scattering_vector[-1])) else: - ax.set_xlim((self.scattering_vector[0],self.scattering_vector[-1])) + ax.set_xlim((self.scattering_vector[0], self.scattering_vector[-1])) if returnfig: return fig, ax @@ -157,9 +159,9 @@ def plot_radial_mean( def plot_radial_var_norm( self, - figsize = (8,4), - returnfig = False, - ): + figsize=(8, 4), + returnfig=False, +): """ Plotting function for the global FEM. """ @@ -179,27 +181,27 @@ def plot_radial_var_norm( def calculate_pair_dist_function( self, - k_min = 0.05, - k_max = None, - k_width = 0.25, - k_lowpass = None, - k_highpass = None, + k_min=0.05, + k_max=None, + k_width=0.25, + k_lowpass=None, + k_highpass=None, # k_pad_max = 10.0, - r_min = 0.0, - r_max = 20.0, - r_step = 0.02, - damp_origin_fluctuations = False, + r_min=0.0, + r_max=20.0, + r_step=0.02, + damp_origin_fluctuations=False, # poly_background_order = 2, # iterative_pdf_refine = True, # num_iter = 10, - dens = None, - plot_fits = False, - plot_sf_estimate = False, - plot_reduced_pdf = True, - plot_pdf = False, - figsize = (8,4), - maxfev = None, - ): + dens=None, + plot_fits=False, + plot_sf_estimate=False, + plot_reduced_pdf=True, + plot_pdf=False, + figsize=(8, 4), + maxfev=None, +): """ Calculate the pair distribution function (PDF). @@ -218,7 +220,7 @@ def calculate_pair_dist_function( int0 = np.median(self.radial_mean) / int_mean - const_bg sigma0 = np.mean(k) coefs = [const_bg, int0, sigma0, int0, sigma0] - lb = [0,0,0,0,0] + lb = [0, 0, 0, 0, 0] ub = [np.inf, np.inf, np.inf, np.inf, np.inf] # Weight the fit towards high k values noise_est = k[-1] - k + dk @@ -226,31 +228,30 @@ def calculate_pair_dist_function( # Estimate the mean atomic form factor + background if maxfev is None: coefs = curve_fit( - scattering_model, - k2[sub_fit], - Ik[sub_fit] / int_mean, - sigma = noise_est[sub_fit], + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma=noise_est[sub_fit], p0=coefs, - xtol = 1e-8, - bounds = (lb,ub), + xtol=1e-8, + bounds=(lb, ub), )[0] else: coefs = curve_fit( - scattering_model, - k2[sub_fit], - Ik[sub_fit] / int_mean, - sigma = noise_est[sub_fit], + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma=noise_est[sub_fit], p0=coefs, - xtol = 1e-8, - bounds = (lb,ub), - maxfev = maxfev, + xtol=1e-8, + bounds=(lb, ub), + maxfev=maxfev, )[0] coefs[0] *= int_mean coefs[1] *= int_mean coefs[3] *= int_mean - # Calculate the mean atomic form factor wthout any background coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) fk = scattering_model(k2, coefs_fk) @@ -263,40 +264,40 @@ def calculate_pair_dist_function( # (k - k_min) / k_width, # (k_max - k) / k_width, # ),0,1) - mask = np.clip(np.minimum( - (k - 0.0) / k_width, - (k_max - k) / k_width, - ),0,1) - mask = np.sin(mask*(np.pi/2)) + mask = np.clip( + np.minimum( + (k - 0.0) / k_width, + (k_max - k) / k_width, + ), + 0, + 1, + ) + mask = np.sin(mask * (np.pi / 2)) # Estimate the reduced structure factor S(k) Sk = (Ik - bg) * k / fk # Masking edges of S(k) mask_sum = np.sum(mask) - Sk = (Sk - np.sum(Sk*mask)/mask_sum) * mask + Sk = (Sk - np.sum(Sk * mask) / mask_sum) * mask # Filtering of S(k) if k_lowpass is not None and k_lowpass > 0.0: - Sk = gaussian_filter( - Sk, - sigma=k_lowpass / dk, - mode = 'nearest') + Sk = gaussian_filter(Sk, sigma=k_lowpass / dk, mode="nearest") if k_highpass is not None: - Sk_lowpass = gaussian_filter( - Sk, - sigma=k_highpass / dk, - mode = 'nearest') + Sk_lowpass = gaussian_filter(Sk, sigma=k_highpass / dk, mode="nearest") Sk -= Sk_lowpass # Calculate the real space PDF r = np.arange(r_min, r_max, r_step) - ra,ka = np.meshgrid(r,k) - pdf_reduced = (2/np.pi)*dk*np.sum( - np.sin( - 2*np.pi*ra*ka - ) * Sk[:,None], - axis=0, + ra, ka = np.meshgrid(r, k) + pdf_reduced = ( + (2 / np.pi) + * dk + * np.sum( + np.sin(2 * np.pi * ra * ka) * Sk[:, None], + axis=0, + ) ) # Damp the unphysical fluctuations at the PDF origin @@ -304,7 +305,7 @@ def calculate_pair_dist_function( ind_max = np.argmax(pdf_reduced) r_ind_max = r[ind_max] r_mask = np.minimum(r / r_ind_max, 1.0) - r_mask = np.sin(r_mask*np.pi/2)**2 + r_mask = np.sin(r_mask * np.pi / 2) ** 2 pdf_reduced *= r_mask # Store results @@ -314,8 +315,8 @@ def calculate_pair_dist_function( # if density is provided, we can estimate the full PDF if dens is not None: pdf = pdf_reduced.copy() - pdf[1:] /= (4*np.pi*dens*r[1:]*(r[1]-r[0])) - pdf *= (2/np.pi) + pdf[1:] /= 4 * np.pi * dens * r[1:] * (r[1] - r[0]) + pdf *= 2 / np.pi pdf += 1 if damp_origin_fluctuations: @@ -323,68 +324,70 @@ def calculate_pair_dist_function( pdf = np.maximum(pdf, 0.0) - - # Plots if plot_fits: - fig,ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) ax.plot( self.scattering_vector, self.radial_mean, - color = 'k', - ) + color="k", + ) ax.plot( k, bg, - color = 'r', - ) - ax.set_xlabel('Scattering Vector (' + self.scattering_vector_units + ')') - ax.set_ylabel('Radial Mean') - ax.set_xlim((self.scattering_vector[0],self.scattering_vector[-1])) + color="r", + ) + ax.set_xlabel("Scattering Vector (" + self.scattering_vector_units + ")") + ax.set_ylabel("Radial Mean") + ax.set_xlim((self.scattering_vector[0], self.scattering_vector[-1])) # ax.set_ylim((0,2e-5)) - ax.set_xlabel('Scattering Vector [A^-1]') - ax.set_ylabel('I(k) and Fit Estimates') + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("I(k) and Fit Estimates") - ax.set_ylim((np.min(self.radial_mean[self.radial_mean>0])*0.8, - np.max(self.radial_mean*mask)*1.25)) - ax.set_yscale('log') + ax.set_ylim( + ( + np.min(self.radial_mean[self.radial_mean > 0]) * 0.8, + np.max(self.radial_mean * mask) * 1.25, + ) + ) + ax.set_yscale("log") if plot_sf_estimate: - fig,ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) ax.plot( k, Sk, - color = 'r', + color="r", + ) + yr = (np.min(Sk), np.max(Sk)) + ax.set_ylim( + ( + yr[0] - 0.05 * (yr[1] - yr[0]), + yr[1] + 0.05 * (yr[1] - yr[0]), ) - yr = (np.min(Sk),np.max(Sk)) - ax.set_ylim(( - yr[0]-0.05*(yr[1]-yr[0]), - yr[1]+0.05*(yr[1]-yr[0]), - )) - ax.set_xlabel('Scattering Vector [A^-1]') - ax.set_ylabel('Reduced Structure Factor') + ) + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("Reduced Structure Factor") if plot_reduced_pdf: - fig,ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) ax.plot( r, pdf_reduced, - color = 'r', - ) - ax.set_xlabel('Radius [A]') - ax.set_ylabel('Reduced Pair Distribution Function') + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Reduced Pair Distribution Function") if plot_pdf: - fig,ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(figsize=figsize) ax.plot( r, pdf, - color = 'r', - ) - ax.set_xlabel('Radius [A]') - ax.set_ylabel('Pair Distribution Function') - - + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Pair Distribution Function") # functions for inverting from reduced PDF back to S(k) @@ -403,7 +406,6 @@ def calculate_pair_dist_function( # ) - def calculate_FEM_local( self, figsize=(8, 6), @@ -428,13 +430,10 @@ def calculate_FEM_local( """ - pass - - def scattering_model(k2, *coefs): coefs = np.squeeze(np.array(coefs)) @@ -444,14 +443,14 @@ def scattering_model(k2, *coefs): int1 = coefs[3] sigma1 = coefs[4] - int_model = const_bg + \ - int0*np.exp(k2/(-2*sigma0**2)) + \ - int1*np.exp(k2**2/(-2*sigma1**4)) - - # (int1*sigma1)/(k2 + sigma1**2) - # int1*np.exp(k2/(-2*sigma1**2)) - # int1*np.exp(k2/(-2*sigma1**2)) + int_model = ( + const_bg + + int0 * np.exp(k2 / (-2 * sigma0**2)) + + int1 * np.exp(k2**2 / (-2 * sigma1**4)) + ) + # (int1*sigma1)/(k2 + sigma1**2) + # int1*np.exp(k2/(-2*sigma1**2)) + # int1*np.exp(k2/(-2*sigma1**2)) return int_model - From 84a2067238f0d512856b139866050a46c774ac21 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 13 Oct 2023 03:17:38 -0700 Subject: [PATCH 052/176] tv_denoise typo --- py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py | 2 +- py4DSTEM/process/phase/iterative_overlap_tomography.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index d2934497c..5a1c5dde3 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -2678,7 +2678,7 @@ def reconstruct( else None, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, - v_denoise_inner_iter=tv_denoise_inner_iter, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 3d5982e9e..0157fa422 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -2398,7 +2398,7 @@ def reconstruct( else None, tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, tv_denoise_weights=tv_denoise_weights, - v_denoise_inner_iter=tv_denoise_inner_iter, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) From 0978c4b4ff928acc42487e7f91a8cd41afda50d8 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 13 Oct 2023 04:28:09 -0700 Subject: [PATCH 053/176] revisting casting inconsistencies --- .../process/phase/iterative_base_class.py | 10 ++-- py4DSTEM/process/phase/iterative_parallax.py | 46 +++++++++++++------ 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 4ccb21226..4b9d905d1 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -278,7 +278,9 @@ def _extract_intensities_and_calibrations_from_datacube( """ # Copies intensities to device casting to float32 - intensities = datacube.data + xp = self._xp + + intensities = xp.asarray(datacube.data, dtype=xp.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -450,8 +452,6 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - intensities = xp.asarray(intensities, dtype=xp.float32) - # for ptycho if com_measured: com_measured_x, com_measured_y = com_measured @@ -1108,7 +1108,7 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros(diffraction_intensities.shape, dtype=xp.float32) + amplitudes = xp.zeros_like(diffraction_intensities) region_of_interest_shape = diffraction_intensities.shape[-2:] com_fitted_x = self._asnumpy(com_fitted_x) @@ -1129,8 +1129,6 @@ def _normalize_diffraction_intensities( mean_intensity += np.sum(intensities) amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) - amplitudes = xp.asarray(amplitudes, dtype=xp.float32) - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) mean_intensity /= amplitudes.shape[0] diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 3882fb1e4..aa2a4a6b0 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -176,7 +176,7 @@ def preprocess( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities, dtype=xp.float32) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -224,14 +224,16 @@ def preprocess( # diffraction space coordinates self._xy_inds = np.argwhere(self._dp_mask) - self._kxy = (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) * xp.array( - self._reciprocal_sampling - )[None] + self._kxy = xp.asarray( + (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) + * xp.array(self._reciprocal_sampling)[None], + dtype=xp.float32, + ) self._probe_angles = self._kxy * self._wavelength self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1)) # Window function - x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1)[1:] + x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:] x -= (x[1] - x[0]) / 2 wx = ( xp.sin( @@ -242,7 +244,7 @@ def preprocess( ) ** 2 ) - y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1)[1:] + y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1, dtype=xp.float32)[1:] y -= (y[1] - y[0]) / 2 wy = ( xp.sin( @@ -259,7 +261,8 @@ def preprocess( ( self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1], - ) + ), + dtype=xp.float32, ) self._window_pad[ self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -282,8 +285,8 @@ def preprocess( self._grid_scan_shape[1] + self._object_padding_px[1], ) if normalize_images: - self._stack_BF = xp.ones(stack_shape) - self._stack_BF_no_window = xp.ones(stack_shape) + self._stack_BF = xp.ones(stack_shape, dtype=xp.float32) + self._stack_BF_no_window = xp.ones(stack_shape, xp.float32) if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] @@ -306,12 +309,12 @@ def preprocess( ] = all_bfs elif normalize_order == 1: - x = xp.linspace(-0.5, 0.5, all_bfs.shape[1]) - y = xp.linspace(-0.5, 0.5, all_bfs.shape[2]) + x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32) + y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], xp.float32) ya, xa = xp.meshgrid(y, x) basis = np.vstack( ( - xp.ones(xa.size), + xp.ones_like(xa), xa.ravel(), ya.ravel(), ) @@ -364,7 +367,11 @@ def preprocess( # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) + qx = xp.asarray(qx, dtype=xp.float32) + qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) + qy = xp.asarray(qy, dtype=xp.float32) + qxa, qya = xp.meshgrid(qx, qy, indexing="ij") self._qx_shift = -2j * xp.pi * qxa self._qy_shift = -2j * xp.pi * qya @@ -399,7 +406,7 @@ def preprocess( del Gs else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2)) + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) self._stack_mean = xp.mean(self._stack_BF) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images @@ -686,7 +693,8 @@ def reconstruct( ( self._num_bf_images, (regularizer_matrix_size[0] + 1) * (regularizer_matrix_size[1] + 1), - ) + ), + dtype=xp.float32, ) for ii in np.arange(regularizer_matrix_size[0] + 1): Bi = ( @@ -771,7 +779,7 @@ def reconstruct( # Sort by radial order, from center to outer edge inds_order = xp.argsort(xp.sum(xy_vals**2, axis=1)) - shifts_update = xp.zeros((self._num_bf_images, 2)) + shifts_update = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) for a1 in tqdmnd( xy_vals.shape[0], @@ -840,11 +848,19 @@ def reconstruct( self._qx_shift[None] * dx[:, None, None] + self._qy_shift[None] * dy[:, None, None] ) + self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) + self._stack_BF = xp.asarray( + self._stack_BF, dtype=xp.float32 + ) # numpy fft upcasts? + self._stack_mask = xp.asarray( + self._stack_mask, dtype=xp.float32 + ) # numpy fft upcasts? + del Gs # Center the shifts From c9ac5db8de4221b24448a3615b83826e13f04862 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 14 Oct 2023 08:02:20 -0400 Subject: [PATCH 054/176] crop pattern option for all classes --- py4DSTEM/process/phase/iterative_base_class.py | 1 + .../iterative_mixedstate_multislice_ptychography.py | 12 ++++++++---- .../phase/iterative_mixedstate_ptychography.py | 11 ++++++++--- .../phase/iterative_multislice_ptychography.py | 11 ++++++++--- .../phase/iterative_overlap_magnetic_tomography.py | 11 ++++++++--- .../process/phase/iterative_overlap_tomography.py | 11 ++++++++--- .../phase/iterative_simultaneous_ptychography.py | 13 ++++++++++--- .../phase/iterative_singleslice_ptychography.py | 2 +- 8 files changed, 52 insertions(+), 20 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 209d33436..6ddfea643 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1099,6 +1099,7 @@ def _normalize_diffraction_intensities( Best fit vertical center of mass gradient crop_patterns: bool if True, crop patterns to avoid wrap around of patterns + when centering Returns ------- diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index ea10050dd..306f47f77 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -300,6 +300,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -357,6 +358,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -442,9 +445,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -525,7 +526,10 @@ def preprocess( bilinear=True, device=self._device, ) - + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( gpts=self._region_of_interest_shape, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 2a482faf0..658079c3e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -204,6 +204,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -261,6 +262,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -346,9 +349,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -429,6 +430,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index be6cbd6ed..382efedcd 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -305,6 +305,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -362,6 +363,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -447,9 +450,7 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, - self._com_fitted_x, - self._com_fitted_y, + self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns ) # explicitly delete namespace @@ -529,6 +530,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 5a1c5dde3..459b0ae8c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -431,6 +431,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -475,6 +476,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -592,9 +595,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -685,6 +686,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 0157fa422..bb3ee09c2 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -372,6 +372,7 @@ def preprocess( force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -416,6 +417,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -532,9 +535,7 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, - com_fitted_x, - com_fitted_y, + intensities, com_fitted_x, com_fitted_y, crop_patterns ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -625,6 +626,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 85e9a0b18..084a6fcb8 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -192,6 +192,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -246,6 +247,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -401,9 +404,7 @@ def preprocess( amplitudes_0, mean_diffraction_intensity_0, ) = self._normalize_diffraction_intensities( - intensities_0, - com_fitted_x_0, - com_fitted_y_0, + intensities_0, com_fitted_x_0, com_fitted_y_0, crop_patterns ) # explicitly delete namescapes @@ -487,6 +488,7 @@ def preprocess( intensities_1, com_fitted_x_1, com_fitted_y_1, + crop_patterns ) # explicitly delete namescapes @@ -571,6 +573,7 @@ def preprocess( intensities_2, com_fitted_x_2, com_fitted_y_2, + crop_patterns ) # explicitly delete namescapes @@ -683,6 +686,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0cc6b65d5..0dc2cd053 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -247,7 +247,7 @@ def preprocess( Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded crop_patterns: bool - if True, crop patterns to avoid wrap around of patterns + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- From 00991c99638f86c85cb1bd1571385b51b59cdc6a Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 14 Oct 2023 08:28:20 -0400 Subject: [PATCH 055/176] fix for gpu --- py4DSTEM/process/phase/iterative_base_class.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 6ddfea643..62cf3a3a1 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1172,6 +1172,7 @@ def _normalize_diffraction_intensities( amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) + amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity From 31af42990caa970e0fd36adf2f2df5076d9c491a Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 00:42:19 -0700 Subject: [PATCH 056/176] cleaned up parallax descan --- py4DSTEM/process/phase/iterative_parallax.py | 44 ++++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index aa2a4a6b0..67815cd14 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -184,38 +184,46 @@ def preprocess( raise ValueError( "dp_mean must match the datacube shape. Try setting dp_mean = None." ) + # descan correct if descan_correct: - from py4DSTEM.process.phase import DPCReconstruction - - dpc = DPCReconstruction( - energy=self._energy, - datacube=self._datacube, - verbose=False, - ).preprocess( - force_com_rotation=0, - force_com_transpose=False, - plot_center_of_mass=False, + ( + _, + _, + com_fitted_x, + com_fitted_y, + _, + _, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=None, + fit_function="plane", + com_shifts=None, + com_measured=None, ) - intensities_shifted = self._intensities.copy() + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + intensities = asnumpy(self._intensities) + intensities_shifted = np.zeros_like(intensities) + + center_x = np.mean(com_fitted_x) + center_y = np.mean(com_fitted_y) - center_x = np.mean(dpc._com_fitted_x) - center_y = np.mean(dpc._com_fitted_y) for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): intensity_shifted = get_shifted_ar( - self._intensities[rx, ry], - -dpc._com_fitted_x[rx, ry] + center_x, - -dpc._com_fitted_y[rx, ry] + center_y, + intensities[rx, ry], + -com_fitted_x[rx, ry] + center_x, + -com_fitted_y[rx, ry] + center_y, bilinear=True, device="cpu", ) intensities_shifted[rx, ry] = intensity_shifted - self._intensities = intensities_shifted - self._dp_mean = intensities_shifted.mean((0, 1)) + self._intensities = xp.asarray(intensities_shifted, xp.float32) + self._dp_mean = self._intensities.mean((0, 1)) # select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) From 43289e3fc4c82419f477822e459f7bbf26fb1101 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 01:47:13 -0700 Subject: [PATCH 057/176] added support for float upsampling --- py4DSTEM/process/phase/iterative_parallax.py | 67 ++++++++++++++++---- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 67815cd14..5cec2e95a 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -177,6 +177,9 @@ def preprocess( require_calibrations=True, ) + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) + self._scan_shape = np.array(self._intensities.shape[:2]) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -207,8 +210,10 @@ def preprocess( intensities = asnumpy(self._intensities) intensities_shifted = np.zeros_like(intensities) - center_x = np.mean(com_fitted_x) - center_y = np.mean(com_fitted_y) + # center_x = np.mean(com_fitted_x) + # center_y = np.mean(com_fitted_y) + + center_x, center_y = self._region_of_interest_shape / 2 for rx in range(intensities_shifted.shape[0]): for ry in range(intensities_shifted.shape[1]): @@ -926,7 +931,7 @@ def reconstruct( def subpixel_alignment( self, - kde_upsample_factor=4, + kde_upsample_factor=None, kde_sigma=0.125, plot_upsampled_BF_comparison: bool = True, plot_upsampled_FFT_comparison: bool = False, @@ -955,8 +960,42 @@ def subpixel_alignment( xy_shifts = self._xy_shifts BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + self._DF_upsample_limit = np.max( + self._region_of_interest_shape / self._scan_shape + ) + self._BF_upsample_limit = ( + 2 * self._kr.max() / self._reciprocal_sampling[0] + ) / self._scan_shape.max() + if self._device == "gpu": + self._BF_upsample_limit = self._BF_upsample_limit.item() + + if kde_upsample_factor is None: + kde_upsample_factor = np.minimum( + self._BF_upsample_limit * 3 / 2, self._DF_upsample_limit + ) + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + + if kde_upsample_factor < 1: + raise ValueError("kde_upsample_factor must be larger than 1") + + if kde_upsample_factor > self._DF_upsample_limit: + warnings.warn( + ( + "Requested upsampling factor exceeds " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f}." + ), + UserWarning, + ) + self._kde_upsample_factor = kde_upsample_factor - pixel_output = BF_size * self._kde_upsample_factor + pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int") pixel_size = pixel_output.prod() # shifted coordinates @@ -1031,12 +1070,12 @@ def subpixel_alignment( cmap = kwargs.pop("cmap", "magma") cropped_object = self._crop_padded_object(self._recon_BF) - upsampled_pad_x = ( - self._object_padding_px[0] * self._kde_upsample_factor // 2 - ) - upsampled_pad_y = ( - self._object_padding_px[1] * self._kde_upsample_factor // 2 - ) + upsampled_pad_x = np.round( + self._object_padding_px[0] * self._kde_upsample_factor / 2 + ).astype("int") + upsampled_pad_y = np.round( + self._object_padding_px[1] * self._kde_upsample_factor / 2 + ).astype("int") cropped_object_aligned = self.recon_BF_subpixel_aligned[ upsampled_pad_x:-upsampled_pad_x, upsampled_pad_y:-upsampled_pad_y, @@ -1072,8 +1111,12 @@ def subpixel_alignment( if plot_upsampled_FFT_comparison: recon_fft = xp.fft.fft2(self._recon_BF) recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) - pad_x = BF_size[0] * (self._kde_upsample_factor - 1) // 2 - pad_y = BF_size[1] * (self._kde_upsample_factor - 1) // 2 + pad_x = np.round( + BF_size[0] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_y = np.round( + BF_size[1] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") pad_recon_fft = asnumpy( xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) ) From 35d076fb05b0131a3aa012c885847ad4a601724e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 01:50:47 -0700 Subject: [PATCH 058/176] making descan correction the default --- py4DSTEM/process/phase/iterative_parallax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 5cec2e95a..ba751cf9b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -113,7 +113,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, - descan_correct: bool = False, + descan_correct: bool = True, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, From c06ca467a2628339d7a2b95d74c74aea2580b9f3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 02:10:17 -0700 Subject: [PATCH 059/176] removing redundant if statement --- py4DSTEM/process/phase/iterative_parallax.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index ba751cf9b..825366a5e 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1225,12 +1225,8 @@ def aberration_fit( f"{self.aberration_A1y:.0f}) Ang" ) ) - if self.aberration_C1 > 0: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - else: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") # Plot the CTF comparison between experiment and fit if plot_CTF_compare: From 9be2e9a03368436f20e5eb14e06f218c828d263d Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 03:37:43 -0700 Subject: [PATCH 060/176] removed separate ctf corrections and other subpixel improvements --- py4DSTEM/process/phase/iterative_parallax.py | 229 +++++-------------- 1 file changed, 63 insertions(+), 166 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 825366a5e..a8cbd0998 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1070,16 +1070,9 @@ def subpixel_alignment( cmap = kwargs.pop("cmap", "magma") cropped_object = self._crop_padded_object(self._recon_BF) - upsampled_pad_x = np.round( - self._object_padding_px[0] * self._kde_upsample_factor / 2 - ).astype("int") - upsampled_pad_y = np.round( - self._object_padding_px[1] * self._kde_upsample_factor / 2 - ).astype("int") - cropped_object_aligned = self.recon_BF_subpixel_aligned[ - upsampled_pad_x:-upsampled_pad_x, - upsampled_pad_y:-upsampled_pad_y, - ] + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) extent = [ 0, @@ -1109,7 +1102,6 @@ def subpixel_alignment( ax.set_xlabel("y [A]") if plot_upsampled_FFT_comparison: - recon_fft = xp.fft.fft2(self._recon_BF) recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) pad_x = np.round( BF_size[0] * (self._kde_upsample_factor - 1) / 2 @@ -1128,10 +1120,10 @@ def subpixel_alignment( ) reciprocal_extent = [ - 0, - self._reciprocal_sampling[1] * cropped_object_aligned.shape[1], - self._reciprocal_sampling[0] * cropped_object_aligned.shape[0], - 0, + -self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, + self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, + self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, + -self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, ] show( @@ -1312,8 +1304,9 @@ def aberration_correct( k_info_limit: float = None, k_info_power: float = 1.0, Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, + Wiener_signal_noise_ratio: float = 1.0, + Wiener_filter_low_only: bool = False, + upsampled: bool = True, **kwargs, ): """ @@ -1346,9 +1339,19 @@ def aberration_correct( ) ) + if upsampled and hasattr(self, "_kde_upsample_factor"): + im = self._recon_BF_subpixel_aligned + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + else: + upsampled = False + im = self._recon_BF + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + # Fourier coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + kx = xp.fft.fftfreq(im.shape[0], sx) + ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 # CTF @@ -1371,7 +1374,7 @@ def aberration_correct( CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr else: # CTF without tilt correction (beyond the parallax operator) @@ -1379,7 +1382,7 @@ def aberration_correct( CTF_corr[0, 0] = 0 # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr # if needed, add low pass filter output image if k_info_limit is not None: @@ -1391,131 +1394,6 @@ def aberration_correct( self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) self.recon_phase_corrected = asnumpy(self._recon_phase_corrected) - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() - - # plotting - if plot_corrected_phase: - figsize = kwargs.pop("figsize", (6, 6)) - cmap = kwargs.pop("cmap", "magma") - - fig, ax = plt.subplots(figsize=figsize) - - cropped_object = self._crop_padded_object(self._recon_phase_corrected) - - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] - - ax.imshow( - cropped_object, - extent=extent, - cmap=cmap, - **kwargs, - ) - - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Parallax-Corrected Phase Image") - - def subpixel_aberration_correct( - self, - plot_corrected_phase: bool = True, - k_info_limit: float = None, - k_info_power: float = 1.0, - Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, - **kwargs, - ): - """ - CTF correction of the phase image using the measured defocus aberration. - - Parameters - ---------- - plot_corrected_phase: bool, optional - If True, the CTF-corrected phase is plotted - k_info_limit: float, optional - maximum allowed frequency in butterworth filter - k_info_power: float, optional - power of butterworth filter - Wiener_filter: bool, optional - Use Wiener filtering instead of CTF sign correction. - Wiener_signal_noise_ratio: float, optional - Signal to noise radio at k = 0 for Wiener filter - Wiener_filter_low_only: bool, optional - Apply Wiener filtering only to the CTF portions before the 1st CTF maxima. - """ - - xp = self._xp - asnumpy = self._asnumpy - - if not hasattr(self, "aberration_C1"): - raise ValueError( - ( - "CTF correction is meant to be ran after alignment and aberration fitting. " - "Please run the `reconstruct()` and `aberration_fit()` functions first." - ) - ) - - # Fourier coordinates - kx = xp.fft.fftfreq( - self._recon_BF_subpixel_aligned.shape[0], - self._scan_sampling[0] / self._kde_upsample_factor, - ) - ky = xp.fft.fftfreq( - self._recon_BF_subpixel_aligned.shape[1], - self._scan_sampling[1] / self._kde_upsample_factor, - ) - kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) - - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio - ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr - - else: - # CTF without tilt correction (beyond the parallax operator) - CTF_corr = xp.sign(sin_chi) - CTF_corr[0, 0] = 0 - - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF_subpixel_aligned) * CTF_corr - # if needed, add low pass filter output image - if k_info_limit is not None: - im_fft_corr /= 1 + (kra2**k_info_power) / ( - (k_info_limit) ** (2 * k_info_power) - ) - - # Output phase image - self._recon_phase_corrected_subpixel_aligned = xp.real( - xp.fft.ifft2(im_fft_corr) - ) - self.recon_phase_corrected_subpixel_aligned = asnumpy( - self._recon_phase_corrected_subpixel_aligned - ) - if self._device == "gpu": xp._default_memory_pool.free_all_blocks() xp.clear_memo() @@ -1528,17 +1406,13 @@ def subpixel_aberration_correct( fig, ax = plt.subplots(figsize=figsize) cropped_object = self._crop_padded_object( - self._recon_BF_subpixel_aligned, upsampled=True + self._recon_phase_corrected, upsampled=upsampled ) extent = [ 0, - self._scan_sampling[1] - / self._kde_upsample_factor - * cropped_object.shape[1], - self._scan_sampling[0] - / self._kde_upsample_factor - * cropped_object.shape[0], + sy * cropped_object.shape[1], + sx * cropped_object.shape[0], 0, ] @@ -1551,7 +1425,7 @@ def subpixel_aberration_correct( ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") - ax.set_title("Parallax-Corrected Phase Image Subpixel Aligned") + ax.set_title("Parallax-Corrected Phase Image") def depth_section( self, @@ -1716,12 +1590,19 @@ def _crop_padded_object( asnumpy = self._asnumpy - pad_x = self._object_padding_px[0] // 2 - remaining_padding - pad_y = self._object_padding_px[1] // 2 - remaining_padding - if upsampled: - pad_x *= self._kde_upsample_factor - pad_y *= self._kde_upsample_factor + pad_x = np.round( + self._object_padding_px[0] / 2 * self._kde_upsample_factor + ).astype("int") + pad_y = np.round( + self._object_padding_px[1] / 2 * self._kde_upsample_factor + ).astype("int") + else: + pad_x = self._object_padding_px[0] // 2 + pad_y = self._object_padding_px[1] // 2 + + pad_x -= remaining_padding + pad_y -= remaining_padding return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) @@ -1730,6 +1611,7 @@ def _visualize_figax( fig, ax, remaining_padding: int = 0, + upsampled: bool = False, **kwargs, ): """ @@ -1748,14 +1630,29 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + cropped_object = self._crop_padded_object( + self._recon_BF, remaining_padding, upsampled + ) + + if upsampled: + extent = [ + 0, + self._scan_sampling[1] + * cropped_object.shape[1] + / self._kde_upsample_factor, + self._scan_sampling[0] + * cropped_object.shape[0] + / self._kde_upsample_factor, + 0, + ] - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + else: + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] ax.imshow( cropped_object, From 73b4fe798431c0b1aae1ec07f6233d5041e40690 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 16 Oct 2023 07:05:13 -0400 Subject: [PATCH 061/176] bug fix for calibrated strain --- py4DSTEM/process/strain/latticevectors.py | 12 +++++++++++- py4DSTEM/process/strain/strain.py | 5 ++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py index 90f7f938d..26c8d66a5 100644 --- a/py4DSTEM/process/strain/latticevectors.py +++ b/py4DSTEM/process/strain/latticevectors.py @@ -116,10 +116,19 @@ def add_indices_to_braggvectors( shape=braggpeaks.Rshape, ) + calstate = braggpeaks.calstate + # loop over all the scan positions for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): if mask[Rx, Ry]: - pl = braggpeaks.cal[Rx, Ry] + pl = braggpeaks.get_vectors( + Rx, + Ry, + center=True, + ellipse=calstate["ellipse"], + rotate=calstate["rotate"], + pixel=False, + ) for i in range(pl.data.shape[0]): r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( pl.data["qy"][i] - lattice.data["qy"] + qy_shift @@ -378,6 +387,7 @@ def get_strain_from_reference_g1g2(g1g2_map, g1, g2): ] return strain_map + def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): """ Starting from a strain map defined with respect to the xy coordinate system of diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 751016a89..47545c04b 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -25,11 +25,10 @@ class StrainMap(RealSlice, Data): """ Storage and processing methods for 4D-STEM datasets. - + """ def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): - """ Accepts: braggvectors (BraggVectors): BraggVectors for Strain Map @@ -95,7 +94,7 @@ def braggvectors(self, x): ), f".braggvectors must be BraggVectors, not type {type(x)}" assert ( x.calibration.origin is not None - ), f"braggvectors must have a calibrated origin" + ), "braggvectors must have a calibrated origin" self._braggvectors = x self._braggvectors.tree(self, force=True) From de2ea8ba8bf51da9c493a422d2bdbb13567275c4 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 12:10:13 +0100 Subject: [PATCH 062/176] removes unused imports --- py4DSTEM/process/polar/polar_analysis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 829bbde59..e2fca17a8 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -3,7 +3,6 @@ import numpy as np import matplotlib.pyplot as plt from scipy.optimize import curve_fit -from scipy.special import comb, erf from scipy.ndimage import gaussian_filter from emdfile import tqdmnd From 9d7616590d5421f2623cbb2dcb743300d4ccc7cc Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 12:14:06 +0100 Subject: [PATCH 063/176] removes unused method args --- py4DSTEM/process/polar/polar_analysis.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index e2fca17a8..9619f6b30 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -10,8 +10,6 @@ def calculate_radial_statistics( self, - median_local=False, - median_global=False, plot_results_mean=False, plot_results_var=False, figsize=(8, 4), From 6abf29e038634adc51d3f888ae5b02515bdd2f15 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 12:34:31 +0100 Subject: [PATCH 064/176] fixes scattering_vector bug --- py4DSTEM/process/polar/polar_analysis.py | 62 +++++++++++++----------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 9619f6b30..ce478887f 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -18,33 +18,37 @@ def calculate_radial_statistics( progress_bar=True, ): """ - Calculate fluctuation electron microscopy (FEM) statistics, including radial mean, - variance, and normalized variance. This function uses the original FEM definitions, - where the signal is computed pattern-by-pattern. - - TODO - finish docstrings, add median statistics. + Calculate the radial statistics used in fluctuation electron microscopy (FEM) + and as an initial step in radial distribution function (RDF) calculation. + The computed quantities are the radial mean, variance, and normalized variance. + Each signal is calculated using the original FEM definitions + [[TODO: add reference]], i.e. pattern-by-pattern. Parameters -------- - self: PolarDatacube - Polar datacube used for measuring FEM properties. + plot_results_mean: bool + Toggles plotting the computed radial means + plot_results_var: bool + Toggles plotting the computed radial variances + figsize: 2-tuple + Size of output figures + returnval: bool + Toggles returning the answer. Answers are always stored internally. + returnfig: bool + Toggles returning figures Returns -------- radial_avg: np.array - Average radial intensity + Optional - returned iff returnval is True. The average radial intensity. radial_var: np.array - Variance in the radial dimension - - + Optional - returned iff returnval is True. The radial variance. + fig_means: 2-tuple (fig,ax) + Optional - returned iff returnfig is True. Plot of the radial means. + fig_var: 2-tuple (fig,ax) + Optional - returned iff returnfig is True. Plot of the radial variances. """ - # Get the dimensioned radial bins - self.scattering_vector = ( - self.radial_bins * self.qstep * self.calibration.get_Q_pixel_size() - ) - self.scattering_vector_units = self.calibration.get_Q_pixel_units() - # init radial data arrays self.radial_all = np.zeros( ( @@ -134,7 +138,7 @@ def plot_radial_mean( """ fig, ax = plt.subplots(figsize=figsize) ax.plot( - self.scattering_vector, + self.qq, self.radial_mean, ) @@ -143,12 +147,12 @@ def plot_radial_mean( if log_y: ax.set_yscale("log") - ax.set_xlabel("Scattering Vector (" + self.scattering_vector_units + ")") + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") ax.set_ylabel("Radial Mean") - if log_x and self.scattering_vector[0] == 0.0: - ax.set_xlim((self.scattering_vector[1], self.scattering_vector[-1])) + if log_x and self.qq[0] == 0.0: + ax.set_xlim((self.qq[1], self.qq[-1])) else: - ax.set_xlim((self.scattering_vector[0], self.scattering_vector[-1])) + ax.set_xlim((self.qq[0], self.qq[-1])) if returnfig: return fig, ax @@ -164,13 +168,13 @@ def plot_radial_var_norm( """ fig, ax = plt.subplots(figsize=figsize) ax.plot( - self.scattering_vector, + self.qq, self.radial_var_norm, ) - ax.set_xlabel("Scattering Vector (" + self.scattering_vector_units + ")") + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") ax.set_ylabel("Normalized Variance") - ax.set_xlim((self.scattering_vector[0], self.scattering_vector[-1])) + ax.set_xlim((self.qq[0], self.qq[-1])) if returnfig: return fig, ax @@ -205,7 +209,7 @@ def calculate_pair_dist_function( """ # init - k = self.scattering_vector + k = self.qq dk = k[1] - k[0] k2 = k**2 Ik = self.radial_mean @@ -325,7 +329,7 @@ def calculate_pair_dist_function( if plot_fits: fig, ax = plt.subplots(figsize=figsize) ax.plot( - self.scattering_vector, + self.qq, self.radial_mean, color="k", ) @@ -334,9 +338,9 @@ def calculate_pair_dist_function( bg, color="r", ) - ax.set_xlabel("Scattering Vector (" + self.scattering_vector_units + ")") + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") ax.set_ylabel("Radial Mean") - ax.set_xlim((self.scattering_vector[0], self.scattering_vector[-1])) + ax.set_xlim((self.qq[0], self.qq[-1])) # ax.set_ylim((0,2e-5)) ax.set_xlabel("Scattering Vector [A^-1]") ax.set_ylabel("I(k) and Fit Estimates") From ab76946469b565a1bb1c868a21f869a84999971f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 16 Oct 2023 04:38:00 -0700 Subject: [PATCH 065/176] added read-write functionality to parralax --- py4DSTEM/process/phase/iterative_parallax.py | 134 +++++++++++++++++-- 1 file changed, 125 insertions(+), 9 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index a8cbd0998..ff6fb52af 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from emdfile import Custom, tqdmnd +from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec from py4DSTEM import DataCube from py4DSTEM.preprocess.utils import get_shifted_ar @@ -75,6 +75,8 @@ def __init__( else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_save_defaults() + # Data self._datacube = datacube @@ -88,9 +90,68 @@ def __init__( def to_h5(self, group): """ Wraps datasets and metadata to write in emdfile classes, - notably ... + notably the (subpixel-)aligned BF. """ - raise NotImplementedError() + # instantiation metadata + self.metadata = Metadata( + name="instantiation_metadata", + data={ + "energy": self._energy, + "verbose": self._verbose, + "device": self._device, + "object_padding_px": self._object_padding_px, + "name": self.name, + }, + ) + + # preprocessing metadata + self.metadata = Metadata( + name="preprocess_metadata", + data={ + "scan_sampling": self._scan_sampling, + "wavelength": self._wavelength, + }, + ) + + # reconstruction metadata + recon_metadata = {"reconstruction_error": float(self._recon_error)} + + if hasattr(self, "aberration_C1"): + recon_metadata |= { + "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_C1": self.aberration_C1, + "aberration_A1x": self.aberration_A1x, + "aberration_A1y": self.aberration_A1y, + } + + if hasattr(self, "_kde_upsample_factor"): + recon_metadata |= { + "kde_upsample_factor": self._kde_upsample_factor, + } + self._subpixel_aligned_BF_emd = Array( + name="subpixel_aligned_BF", + data=self._asnumpy(self._recon_BF_subpixel_aligned), + ) + + self.metadata = Metadata( + name="reconstruction_metadata", + data=recon_metadata, + ) + + self._aligned_BF_emd = Array( + name="aligned_BF", + data=self._asnumpy(self._recon_BF), + ) + + # datacube + if self._save_datacube: + self.metadata = self._datacube.calibration + Custom.to_h5(self, group) + else: + dc = self._datacube + self._datacube = None + Custom.to_h5(self, group) + self._datacube = dc @classmethod def _get_constructor_args(cls, group): @@ -98,14 +159,67 @@ def _get_constructor_args(cls, group): Returns a dictionary of arguments/values to pass to the class' __init__ function """ - raise NotImplementedError() + # Get data + dict_data = cls._get_emd_attr_data(cls, group) + + # Get metadata dictionaries + instance_md = _read_metadata(group, "instantiation_metadata") + + # Fix calibrations bug + if "_datacube" in dict_data: + calibrations_dict = _read_metadata(group, "calibration")._params + cal = Calibration() + cal._params.update(calibrations_dict) + dc = dict_data["_datacube"] + dc.calibration = cal + else: + dc = None + + # Populate args and return + kwargs = { + "datacube": dc, + "energy": instance_md["energy"], + "verbose": instance_md["verbose"], + "device": instance_md["device"], + "object_padding_px": instance_md["object_padding_px"], + "name": instance_md["name"], + } + + return kwargs def _populate_instance(self, group): """ Sets post-initialization properties, notably some preprocessing meta optional; during read, this method is run after object instantiation. """ - raise NotImplementedError() + + xp = self._xp + + # Preprocess metadata + preprocess_md = _read_metadata(group, "preprocess_metadata") + self._scan_sampling = preprocess_md["scan_sampling"] + self._wavelength = preprocess_md["wavelength"] + + # Reconstruction metadata + reconstruction_md = _read_metadata(group, "reconstruction_metadata") + self._recon_error = reconstruction_md["reconstruction_error"] + + # Data + dict_data = Custom._get_emd_attr_data(Custom, group) + + if "aberration_C1" in reconstruction_md.keys: + self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.aberration_C1 = reconstruction_md["aberration_C1"] + self.aberration_A1x = reconstruction_md["aberration_A1x"] + self.aberration_A1y = reconstruction_md["aberration_A1y"] + + if "kde_upsample_factor" in reconstruction_md.keys: + self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"] + self._recon_BF_subpixel_aligned = xp.asarray( + dict_data["_subpixel_aligned_BF_emd"].data, dtype=xp.float32 + ) + + self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32) def preprocess( self, @@ -1630,11 +1744,11 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object( - self._recon_BF, remaining_padding, upsampled - ) - if upsampled: + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, remaining_padding, upsampled + ) + extent = [ 0, self._scan_sampling[1] @@ -1647,6 +1761,8 @@ def _visualize_figax( ] else: + cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + extent = [ 0, self._scan_sampling[1] * cropped_object.shape[1], From c08f84ca753d5aba414923ff7bf272bab2917d8c Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 12:53:22 +0100 Subject: [PATCH 066/176] adds documentation --- py4DSTEM/process/polar/polar_analysis.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index ce478887f..01ba57a2f 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -21,8 +21,23 @@ def calculate_radial_statistics( Calculate the radial statistics used in fluctuation electron microscopy (FEM) and as an initial step in radial distribution function (RDF) calculation. The computed quantities are the radial mean, variance, and normalized variance. - Each signal is calculated using the original FEM definitions - [[TODO: add reference]], i.e. pattern-by-pattern. + + There are several ways the means and variances can be computed. Here we first + compute the mean and standard deviation pattern by pattern, i.e. for + diffraction signal d(x,y; q,theta) we take + + d_mean_all(x,y; q) = \int_{0}^{2\pi} d(x,y; q,\theta) d\theta + d_var_all(x,y; q) = \int_{0}^{2\pi} + \( d(x,y; q,\theta) - d_mean_all(x,y; q,\theta) \)^2 d\theta + + Then we find the mean and variance profiles by taking the means of these + quantities over all scan positions: + + d_mean(q) = \sum_{x,y} d_mean_all(x,y; q) + d_var(q) = \sum_{x,y} d_var_all(x,y; q) + + and the normalized variance is d_var/d_mean. + Parameters -------- @@ -65,7 +80,7 @@ def calculate_radial_statistics( ) ) - # Compute the radial mean for each probe position + # Compute the radial mean and standard deviation for each probe position for rx, ry in tqdmnd( self._datacube.shape[0], self._datacube.shape[1], From b4dbba4455acbf9a3d65337d353b0ba2e23f95e3 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 12:56:29 +0100 Subject: [PATCH 067/176] fixes variance normalization bug --- py4DSTEM/process/polar/polar_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 01ba57a2f..f88c57aa2 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -98,7 +98,7 @@ def calculate_radial_statistics( (self.radial_all - self.radial_mean[None, None]) ** 2, axis=(0, 1) ) - self.radial_var_norm = self.radial_var + self.radial_var_norm = np.copy(self.radial_var) sub = self.radial_mean > 0.0 self.radial_var_norm[sub] /= self.radial_mean[sub] ** 2 From 986c1218d3c95e3a3173909150a9863d2401668a Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 13:14:09 +0100 Subject: [PATCH 068/176] corrects calculate_statistics return behavior --- py4DSTEM/process/polar/polar_analysis.py | 57 ++++++++++-------------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index f88c57aa2..322c67820 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -50,7 +50,8 @@ def calculate_radial_statistics( returnval: bool Toggles returning the answer. Answers are always stored internally. returnfig: bool - Toggles returning figures + Toggles returning figures that have been plotted. Only figures for + which `plot_results_*` is True are returned. Returns -------- @@ -102,43 +103,33 @@ def calculate_radial_statistics( sub = self.radial_mean > 0.0 self.radial_var_norm[sub] /= self.radial_mean[sub] ** 2 + # prepare answer + statistics = self.radial_mean, self.radial_var, self.radial_var_norm + if returnval: + ans = statistics if not returnfig else [statistics] + else: + ans = None if not returnfig else [] + # plot results if plot_results_mean: + fig, ax = plot_radial_mean( + self, + figsize=figsize, + returnfig=True, + ) if returnfig: - fig, ax = plot_radial_mean( - self, - figsize=figsize, - returnfig=True, - ) - else: - plot_radial_mean( - self, - figsize=figsize, - ) - elif plot_results_var: + ans.append((fig,ax)) + if plot_results_var: + fig, ax = plot_radial_var_norm( + self, + figsize=figsize, + returnfig=True, + ) if returnfig: - fig, ax = plot_radial_var_norm( - self, - figsize=figsize, - returnfig=True, - ) - else: - plot_radial_var_norm( - self, - figsize=figsize, - ) + ans.append((fig,ax)) - # Return values - if returnval: - if returnfig: - return self.radial_mean, self.radial_var, fig, ax - else: - return self.radial_mean, self.radial_var - else: - if returnfig: - return fig, ax - else: - pass + # return + return ans def plot_radial_mean( From 3cac12c27b2da981ebf9870c7b4c421b0c1787c8 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 13:19:27 +0100 Subject: [PATCH 069/176] adds documentation --- py4DSTEM/process/polar/polar_analysis.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 322c67820..4efe25f5e 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -140,7 +140,18 @@ def plot_radial_mean( returnfig=False, ): """ - Plot radial mean + Plot the radial means. + + Parameters + ---------- + log_x : bool + Toggle log scaling of the x-axis + log_y : bool + Toggle log scaling of the y-axis + figsize : 2-tuple + Size of the output figure + returnfig : bool + Toggle returning the figure """ fig, ax = plt.subplots(figsize=figsize) ax.plot( @@ -170,7 +181,15 @@ def plot_radial_var_norm( returnfig=False, ): """ - Plotting function for the global FEM. + Plot the radial variances. + + Parameters + ---------- + figsize : 2-tuple + Size of the output figure + returnfig : bool + Toggle returning the figure + """ fig, ax = plt.subplots(figsize=figsize) ax.plot( From 9bcb4702033fee27f05d67a598f6f0a606c652fb Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 13:43:41 +0100 Subject: [PATCH 070/176] adds documentation --- py4DSTEM/process/polar/polar_analysis.py | 42 +++++++++++++++++------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 4efe25f5e..59f54a4f4 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -212,14 +212,10 @@ def calculate_pair_dist_function( k_width=0.25, k_lowpass=None, k_highpass=None, - # k_pad_max = 10.0, r_min=0.0, r_max=20.0, r_step=0.02, damp_origin_fluctuations=False, - # poly_background_order = 2, - # iterative_pdf_refine = True, - # num_iter = 10, dens=None, plot_fits=False, plot_sf_estimate=False, @@ -231,9 +227,21 @@ def calculate_pair_dist_function( """ Calculate the pair distribution function (PDF). + Parameters + ---------- + k_min : number + minimum scattering vector to include in the calculation + k_max : number or None + maximum scattering vector to include in the calculation. Note that + this cutoff is *not* used when estimating the background and single + atom scattering factor, which is best estimated from high scattering + lengths. + k_width : number + xxx + """ - # init + # set up coordinates and scaling k = self.qq dk = k[1] - k[0] k2 = k**2 @@ -241,7 +249,7 @@ def calculate_pair_dist_function( int_mean = np.mean(Ik) sub_fit = k >= k_min - # initial coefs + # initial guesses for background coefs const_bg = np.min(self.radial_mean) / int_mean int0 = np.median(self.radial_mean) / int_mean - const_bg sigma0 = np.mean(k) @@ -278,7 +286,7 @@ def calculate_pair_dist_function( coefs[1] *= int_mean coefs[3] *= int_mean - # Calculate the mean atomic form factor wthout any background + # Calculate the mean atomic form factor without a constant offset coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) fk = scattering_model(k2, coefs_fk) bg = scattering_model(k2, coefs) @@ -286,10 +294,6 @@ def calculate_pair_dist_function( # mask for structure factor estimate if k_max is None: k_max = np.max(k) - # mask = np.clip(np.minimum( - # (k - k_min) / k_width, - # (k_max - k) / k_width, - # ),0,1) mask = np.clip( np.minimum( (k - 0.0) / k_width, @@ -461,6 +465,21 @@ def calculate_FEM_local( def scattering_model(k2, *coefs): + """ + The scattering model used to fit the PDF background. The fit + function is a constant plus two exponentials - one in k^2 and one + in k^4: + + f(k; c,i0,s0,i1,s1) = + c + i0*exp(k^2/-2*s0^2) + i1*exp(k^4/-2*s1^4) + + Parameters + ---------- + k2 : 1d array + the scattering vector squared + coefs : 5-tuple + Initial guesses at the parameters (c,i0,s0,i1,s1) + """ coefs = np.squeeze(np.array(coefs)) const_bg = coefs[0] @@ -480,3 +499,4 @@ def scattering_model(k2, *coefs): # int1*np.exp(k2/(-2*sigma1**2)) return int_model + From 8a009c416cfb44653a2c7138e4bf776a23dbaff5 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 14:06:45 +0100 Subject: [PATCH 071/176] adds documentation --- py4DSTEM/process/polar/polar_analysis.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 59f54a4f4..2d73d0395 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -230,15 +230,20 @@ def calculate_pair_dist_function( Parameters ---------- k_min : number - minimum scattering vector to include in the calculation + Minimum scattering vector to include in the calculation k_max : number or None - maximum scattering vector to include in the calculation. Note that - this cutoff is *not* used when estimating the background and single - atom scattering factor, which is best estimated from high scattering - lengths. + Maximum scattering vector to include in the calculation. Note that + this cutoff is used when calculating the structure factor - however it + is *not* used when estimating the background / single atom scattering + factor, which is best estimated from high scattering lengths. k_width : number - xxx - + The fitting window for the structure factor calculation [k_min,k_max] + includes a damped region at its edges, i.e. the signal is smoothly dampled + to zero in the regions [k_min, k_min+k_width] and [k_max-k_width,k_max] + k_lowpass : number or None + Lowpass filter, in units the scattering vector stepsize (i.e. self.qstep) + k_highpass : number or None + Highpass filter, in units the scattering vector stepsize (i.e. self.qstep) """ # set up coordinates and scaling @@ -290,6 +295,9 @@ def calculate_pair_dist_function( coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) fk = scattering_model(k2, coefs_fk) bg = scattering_model(k2, coefs) + # @cophus: + # can we eliminate recalculating the model with a modified offset by + # just subtracting off the constant offset from bg ? # mask for structure factor estimate if k_max is None: From 9c5bfc8eacce708262acd0adf63ea94a7de810fa Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 14:19:10 +0100 Subject: [PATCH 072/176] simplifies + removes unneeded computation --- py4DSTEM/process/polar/polar_analysis.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 2d73d0395..5523d8c19 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -292,12 +292,10 @@ def calculate_pair_dist_function( coefs[3] *= int_mean # Calculate the mean atomic form factor without a constant offset - coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) - fk = scattering_model(k2, coefs_fk) + #coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) + #fk = scattering_model(k2, coefs_fk) bg = scattering_model(k2, coefs) - # @cophus: - # can we eliminate recalculating the model with a modified offset by - # just subtracting off the constant offset from bg ? + fk = bg - coefs[0] # mask for structure factor estimate if k_max is None: @@ -322,7 +320,7 @@ def calculate_pair_dist_function( # Filtering of S(k) if k_lowpass is not None and k_lowpass > 0.0: Sk = gaussian_filter(Sk, sigma=k_lowpass / dk, mode="nearest") - if k_highpass is not None: + if k_highpass is not None and k_highpass > 0.0: Sk_lowpass = gaussian_filter(Sk, sigma=k_highpass / dk, mode="nearest") Sk -= Sk_lowpass From 1f41a1e3d3683a661a84d2e3a4743de397c4ca73 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 17:12:37 +0100 Subject: [PATCH 073/176] adds documentation --- py4DSTEM/process/polar/polar_analysis.py | 47 +++++++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 5523d8c19..38e221fdd 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -227,6 +227,30 @@ def calculate_pair_dist_function( """ Calculate the pair distribution function (PDF). + First a background is calculated using primarily the signal at the highest + scattering vectors available, given by a sum of two exponentials ~exp(-k^2) + and ~exp(-k^4) modelling the single atom scattering factor plus a constant + offset. Next, the structure factor is computed as + + S(k) = (I(k) - bg(k)) * k / f(k) + + where k is the magnitude of the scattering vector, I(k) is the mean radial + signal, f(k) is the single atom scattering factor, and bg(k) is the total + background signal (i.e. f(k) plus a constant offset). S(k) is masked outside + of the selected fitting region of k-values [k_min,k_max] and low/high pass + filters are optionally applied. The structure factor is then inverted into + the reduced pair distribution function g(r) using + + g(r) = \frac{2}{\pi) \int sin( 2\pi r k ) S(k) dk + + The value of the integral is (optionally) damped to zero at the origin to + match the physical requirement that this condition holds. Finally, the + full PDF G(r) is computed if a known dens is provided, using + + G(r) = 1 + [ \frac{2}{\pi} * g(r) / ( 4\pi * D * r dr ) ] + + + Parameters ---------- k_min : number @@ -244,6 +268,25 @@ def calculate_pair_dist_function( Lowpass filter, in units the scattering vector stepsize (i.e. self.qstep) k_highpass : number or None Highpass filter, in units the scattering vector stepsize (i.e. self.qstep) + r_min,r_max,r_step : numbers + Define the real space coordinates r that the PDF g(r) will be computed in. + The coordinates will be np.arange(r_min,r_max,r_step), given in units + inverse to the scattering vector units. + damp_origin_fluctuations : bool + The value of the PDF approaching the origin should be zero, however numerical + instability may result in non-physical finite values there. This flag toggles + damping the value of the PDF to zero near the origin. + dens : number or None + The dens of the sample, if known. If this is not provided, only the + reduced PDF is calculated. If this value is provided, the PDF is also + calculated. + plot_fits : bool + plot_sf_estimate : bool + plot_reduced_pdf=True : bool + plot_pdf : bool + figsize : 2-tuple + maxfev : integer or None + Max number of iterations to use when fitting the background """ # set up coordinates and scaling @@ -324,7 +367,7 @@ def calculate_pair_dist_function( Sk_lowpass = gaussian_filter(Sk, sigma=k_highpass / dk, mode="nearest") Sk -= Sk_lowpass - # Calculate the real space PDF + # Calculate the PDF r = np.arange(r_min, r_max, r_step) ra, ka = np.meshgrid(r, k) pdf_reduced = ( @@ -348,7 +391,7 @@ def calculate_pair_dist_function( self.pdf_r = r self.pdf_reduced = pdf_reduced - # if density is provided, we can estimate the full PDF + # if dens is provided, we can estimate the full PDF if dens is not None: pdf = pdf_reduced.copy() pdf[1:] /= 4 * np.pi * dens * r[1:] * (r[1] - r[0]) From 7675945b7ba2c236a2ee31c6cbd6427e208d8a02 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 17:13:09 +0100 Subject: [PATCH 074/176] 'dens'->'density' --- py4DSTEM/process/polar/polar_analysis.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 38e221fdd..71033475b 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -216,7 +216,7 @@ def calculate_pair_dist_function( r_max=20.0, r_step=0.02, damp_origin_fluctuations=False, - dens=None, + density=None, plot_fits=False, plot_sf_estimate=False, plot_reduced_pdf=True, @@ -245,7 +245,7 @@ def calculate_pair_dist_function( The value of the integral is (optionally) damped to zero at the origin to match the physical requirement that this condition holds. Finally, the - full PDF G(r) is computed if a known dens is provided, using + full PDF G(r) is computed if a known density is provided, using G(r) = 1 + [ \frac{2}{\pi} * g(r) / ( 4\pi * D * r dr ) ] @@ -276,8 +276,8 @@ def calculate_pair_dist_function( The value of the PDF approaching the origin should be zero, however numerical instability may result in non-physical finite values there. This flag toggles damping the value of the PDF to zero near the origin. - dens : number or None - The dens of the sample, if known. If this is not provided, only the + density : number or None + The density of the sample, if known. If this is not provided, only the reduced PDF is calculated. If this value is provided, the PDF is also calculated. plot_fits : bool @@ -391,10 +391,10 @@ def calculate_pair_dist_function( self.pdf_r = r self.pdf_reduced = pdf_reduced - # if dens is provided, we can estimate the full PDF - if dens is not None: + # if density is provided, we can estimate the full PDF + if density is not None: pdf = pdf_reduced.copy() - pdf[1:] /= 4 * np.pi * dens * r[1:] * (r[1] - r[0]) + pdf[1:] /= 4 * np.pi * density * r[1:] * (r[1] - r[0]) pdf *= 2 / np.pi pdf += 1 From 51574e010dc45e05b80716043d896fc1eeebef77 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 17:13:42 +0100 Subject: [PATCH 075/176] makes damping origin fluctuations default behaavior --- py4DSTEM/process/polar/polar_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 71033475b..e301b56cf 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -215,7 +215,7 @@ def calculate_pair_dist_function( r_min=0.0, r_max=20.0, r_step=0.02, - damp_origin_fluctuations=False, + damp_origin_fluctuations=True, density=None, plot_fits=False, plot_sf_estimate=False, From 76a8644e46b627b3a04fe9eaa4c8f26384cdb09b Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 17:19:38 +0100 Subject: [PATCH 076/176] stores S(k),f(k),other intermediate vals --- py4DSTEM/process/polar/polar_analysis.py | 28 +++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index e301b56cf..fc6e71cf2 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -287,6 +287,11 @@ def calculate_pair_dist_function( figsize : 2-tuple maxfev : integer or None Max number of iterations to use when fitting the background + returnval: bool + Toggles returning the answer. Answers are always stored internally. + returnfig: bool + Toggles returning figures that have been plotted. Only figures for + which `plot_*` is True are returned. """ # set up coordinates and scaling @@ -391,6 +396,11 @@ def calculate_pair_dist_function( self.pdf_r = r self.pdf_reduced = pdf_reduced + self.Sk = Sk + self.fk = fk + self.bg = bg + self.offset = coefs[0] + # if density is provided, we can estimate the full PDF if density is not None: pdf = pdf_reduced.copy() @@ -398,11 +408,27 @@ def calculate_pair_dist_function( pdf *= 2 / np.pi pdf += 1 + # damp and clip values below zero if damp_origin_fluctuations: pdf *= r_mask - pdf = np.maximum(pdf, 0.0) + # store results + self.pdf = pdf + + + # prepare answer + if density is None: + return_values = self.pdf_r, self.pdf_reduced + else: + return_values = self.pdf_r, self.pdf_reduced + if returnval: + ans = statistics if not returnfig else [statistics] + else: + ans = None if not returnfig else [] + + + # Plots if plot_fits: fig, ax = plt.subplots(figsize=figsize) From 7be77be3a577858d78792650463c634c74827551 Mon Sep 17 00:00:00 2001 From: Steve Zeltmann <37132012+sezelt@users.noreply.github.com> Date: Mon, 16 Oct 2023 12:32:51 -0400 Subject: [PATCH 077/176] Apply suggestions from code review Co-authored-by: Georgios Varnavides --- py4DSTEM/process/diffraction/crystal.py | 757 +++++++++--------- py4DSTEM/process/diffraction/crystal_phase.py | 1 - 2 files changed, 400 insertions(+), 358 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 1c43f89bc..041b111f9 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1075,42 +1075,23 @@ def calculate_bragg_peak_histogram( int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power) int_exp /= np.max(int_exp) return k, int_exp - - -def generate_moire( +def generate_moire_diffraction_pattern( bragg_peaks_0, bragg_peaks_1, thresh_0=0.0002, thresh_1=0.0002, - int_range=(0, 5e-3), exx_1=0.0, eyy_1=0.0, exy_1=0.0, phi_1=0.0, power=2.0, k_max=1.0, - plot_result=True, - plot_subpixel=True, - labels=None, - marker_size_parent=16, - marker_size_moire=4, - text_size_parent=10, - text_size_moire=6, - add_labels_parent=False, - add_labels_moire=False, - dist_labels=0.03, - dist_check=0.06, - sep_labels=0.03, - figsize=(8, 6), - return_moire=False, - returnfig=False, ): """ Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated and strained with respect to the original lattice. Note that this strain is applied in real space, and so the inverse of the calculated infinitestimal strain tensor is applied. - - + Parameters -------- bragg_peaks_0: BraggVector @@ -1121,8 +1102,6 @@ def generate_moire( Intensity threshold for structure factors from lattice 0. thresh_1: float Intensity threshold for structure factors from lattice 1. - int_range: (float, float) - Plotting intensity range for the Moire peaks. exx_1: float Strain of lattice 1 in x direction (vertical) in real space. eyy_1: float @@ -1135,51 +1114,14 @@ def generate_moire( Plotting power law (default is amplitude**2.0, i.e. intensity). k_max: float Max k value of the calculated (and plotted) Moire lattice. - plot_result: bool - Plot the resulting Moire lattice. - plot_subpixel: bool - Apply subpixel corrections to the Bragg spot positions. - Matplotlib default scatter plot rounds to the nearest pixel. - labels: list - List of text labels for parent lattices - marker_size_parent: float - Size of plot markers for the two parent lattices. - marker_size_moire: float - Size of plot markers for the Moire lattice. - text_size_parent: float - Label text size for parent lattice. - text_size_moire: float - Label text size for Moire lattice. - add_labels_parent: bool - Plot the parent lattice index labels. - add_labels_moire: bool - Plot the parent lattice index labels for the Moire spots. - dist_labels: float - Distance to move the labels off the spots. - dist_check: float - Set to some distance to "push" the labels away from each other if they are within this distance. - sep_labels: float - Separation distance for labels which are "pushed" apart. - figsize: (float,float) - Size of output figure. - return_moire: bool - Return the moire lattice as a pointlist. - returnfig: bool - Return the (fix,ax) handles of the plot. - + Returns -------- - bragg_peaksMoire: BraggVector (optjonal) + bragg_moire: BraggVector Bragg vectors for moire lattice. - fig, ax: matplotlib handles (optional) - Figure and axes handles for the moire plot. - + """ - - # peak labels - if labels is None: - labels = ("crystal 0", "crystal 1") - + # get intenties of all peaks int0 = bragg_peaks_0["intensity"] ** (power / 2.0) int1 = bragg_peaks_1["intensity"] ** (power / 2.0) @@ -1203,24 +1145,21 @@ def generate_moire( qy1_init = bragg_peaks_1["qy"][sub1] # peak labels - if add_labels_parent or add_labels_moire or return_moire: - - def overline(x): - return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") - - h0 = bragg_peaks_0["h"][sub0] - k0 = bragg_peaks_0["k"][sub0] - l0 = bragg_peaks_0["l"][sub0] - h1 = bragg_peaks_1["h"][sub1] - k1 = bragg_peaks_1["k"][sub1] - l1 = bragg_peaks_1["l"][sub1] + h0 = bragg_peaks_0["h"][sub0] + k0 = bragg_peaks_0["k"][sub0] + l0 = bragg_peaks_0["l"][sub0] + h1 = bragg_peaks_1["h"][sub1] + k1 = bragg_peaks_1["k"][sub1] + l1 = bragg_peaks_1["l"][sub1] # apply strain tensor to lattice 1 + # infinitesimal # m = np.array([ # [1 + exx_1, (exy_1 - phi_1)*0.5], - # [(exy_1 _ phi_1)*0.5, 1 + eyy_1], + # [(exy_1 - phi_1)*0.5, 1 + eyy_1], # ]) + # finite rotation m = np.array( [ @@ -1252,308 +1191,412 @@ def overline(x): int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5 # moire labels - if add_labels_moire or return_moire: - m_h0 = h0[ind0] - m_k0 = k0[ind0] - m_l0 = l0[ind0] - m_h1 = h1[ind1] - m_k1 = k1[ind1] - m_l1 = l1[ind1] - - # If needed, convert moire peaks to BraggVector class - if return_moire: - pl_dtype = np.dtype( - [ - ("qx", "float"), - ("qy", "float"), - ("intensity", "float"), - ("h0", "int"), - ("k0", "int"), - ("l0", "int"), - ("h1", "int"), - ("k1", "int"), - ("l1", "int"), - ] - ) - bragg_moire = PointList(np.array([], dtype=pl_dtype)) - bragg_moire.add_data_by_field( - [ - qx.ravel(), - qy.ravel(), - int_moire.ravel(), - m_h0.ravel(), - m_k0.ravel(), - m_l0.ravel(), - m_h1.ravel(), - m_k1.ravel(), - m_l1.ravel(), - ] - ) - - # plot outputs - if plot_result: - fig = plt.figure(figsize=figsize) - ax = fig.add_axes([0.09, 0.09, 0.65, 0.9]) - ax_labels = fig.add_axes([0.75, 0, 0.25, 1]) - - text_params_parent = { - "ha": "center", - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "size": text_size_parent, - } - text_params_moire = { - "ha": "center", - "va": "center", - "family": "sans-serif", - "fontweight": "normal", - "size": text_size_moire, - } - - if plot_subpixel is False: - # moire - ax.scatter( - qy, - qx, - # color = (0,0,0,1), - c=int_moire, - s=marker_size_moire, - cmap="gray_r", - vmin=int_range[0], - vmax=int_range[1], - antialiased=True, - ) - - # parent lattices - ax.scatter( - qy0, - qx0, - color=(1, 0, 0, 1), - s=marker_size_parent, - antialiased=True, - ) - ax.scatter( - qy1, - qx1, - color=(0, 0.7, 1, 1), - s=marker_size_parent, - antialiased=True, - ) + m_h0 = h0[ind0] + m_k0 = k0[ind0] + m_l0 = l0[ind0] + m_h1 = h1[ind1] + m_k1 = k1[ind1] + m_l1 = l1[ind1] + + # Convert thresholded and moire peaks to BraggVector class + + pl_dtype_parent = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h", "int"), + ("k", "int"), + ("l", "int"), + ] + ) + + bragg_parent_0 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_0.add_data_by_field( + [ + qx0.ravel(), + qy0.ravel(), + int0_sub.ravel(), + h0.ravel(), + k0.ravel(), + l0.ravel(), + ] + ) - # origin - ax.scatter( - 0, - 0, - color=(0, 0, 0, 1), - s=marker_size_parent, - antialiased=True, - ) + bragg_parent_1 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_1.add_data_by_field( + [ + qx1.ravel(), + qy1.ravel(), + int1_sub.ravel(), + h1.ravel(), + k1.ravel(), + l1.ravel(), + ] + ) + + pl_dtype = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h0", "int"), + ("k0", "int"), + ("l0", "int"), + ("h1", "int"), + ("k1", "int"), + ("l1", "int"), + ] + ) + bragg_moire = PointList(np.array([], dtype=pl_dtype)) + bragg_moire.add_data_by_field( + [ + qx.ravel(), + qy.ravel(), + int_moire.ravel(), + m_h0.ravel(), + m_k0.ravel(), + m_l0.ravel(), + m_h1.ravel(), + m_k1.ravel(), + m_l1.ravel(), + ] + ) + + return bragg_parent_0, bragg_parent_1, bragg_moire - else: - # moire peaks - int_all = np.clip( - (int_moire - int_range[0]) / (int_range[1] - int_range[0]), 0, 1 - ) - keep = np.logical_and.reduce( - (qx >= -k_max, qx <= k_max, qy >= -k_max, qy <= k_max) - ) - for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): - ax.add_artist( - Circle( - xy=(y, x), - radius=np.sqrt(marker_size_moire) / 800.0, - color=(1 - int_marker, 1 - int_marker, 1 - int_marker), - ) - ) - if add_labels_moire: - for a0 in range(qx.size): - if keep.ravel()[a0]: - x0 = qx.ravel()[a0] - y0 = qy.ravel()[a0] - d2 = (qx.ravel() - x0) ** 2 + (qy.ravel() - y0) ** 2 - sub = d2 < dist_check**2 - xc = np.mean(qx.ravel()[sub]) - yc = np.mean(qy.ravel()[sub]) - xp = x0 - xc - yp = y0 - yc - if xp == 0 and yp == 0.0: - xp = x0 - dist_labels - yp = y0 - else: - leng = np.linalg.norm((xp, yp)) - xp = x0 + xp * dist_labels / leng - yp = y0 + yp * dist_labels / leng - - ax.text( - yp, - xp - sep_labels, - "$" - + overline(m_h0.ravel()[a0]) - + overline(m_k0.ravel()[a0]) - + overline(m_l0.ravel()[a0]) - + "$", - c="r", - **text_params_moire, - ) - ax.text( - yp, - xp, - "$" - + overline(m_h1.ravel()[a0]) - + overline(m_k1.ravel()[a0]) - + overline(m_l1.ravel()[a0]) - + "$", - c=(0, 0.7, 1.0), - **text_params_moire, - ) - - keep = np.logical_and.reduce( - (qx0 >= -k_max, qx0 <= k_max, qy0 >= -k_max, qy0 <= k_max) - ) - for x, y in zip(qx0[keep], qy0[keep]): - ax.add_artist( - Circle( - xy=(y, x), - radius=np.sqrt(marker_size_parent) / 800.0, - color=(1, 0, 0), - ) - ) - if add_labels_parent: - for a0 in range(qx0.size): - if keep.ravel()[a0]: - xp = qx0.ravel()[a0] - dist_labels - yp = qy0.ravel()[a0] - ax.text( - yp, - xp, - "$" - + overline(h0.ravel()[a0]) - + overline(k0.ravel()[a0]) - + overline(l0.ravel()[a0]) - + "$", - c="k", - **text_params_parent, - ) - - keep = np.logical_and.reduce( - (qx1 >= -k_max, qx1 <= k_max, qy1 >= -k_max, qy1 <= k_max) - ) - for x, y in zip(qx1[keep], qy1[keep]): - ax.add_artist( - Circle( - xy=(y, x), - radius=np.sqrt(marker_size_parent) / 800.0, - color=(0, 0.7, 1), - ) - ) - if add_labels_parent: - for a0 in range(qx1.size): - if keep.ravel()[a0]: - xp = qx1.ravel()[a0] - dist_labels - yp = qy1.ravel()[a0] - ax.text( - yp, - xp, - "$" - + overline(h1.ravel()[a0]) - + overline(k1.ravel()[a0]) - + overline(l1.ravel()[a0]) - + "$", - c="k", - **text_params_parent, - ) - - # origin - ax.add_artist( - Circle( - xy=(0, 0), - radius=np.sqrt(marker_size_parent) / 800.0, - color=(0, 0, 0), - ) - ) +def plot_moire_diffraction_pattern( + bragg_parent_0, + bragg_parent_1, + bragg_moire, + int_range=(0, 5e-3), + k_max=1.0, + plot_subpixel=True, + labels=None, + marker_size_parent=16, + marker_size_moire=4, + text_size_parent=10, + text_size_moire=6, + add_labels_parent=False, + add_labels_moire=False, + dist_labels=0.03, + dist_check=0.06, + sep_labels=0.03, + figsize=(8, 6), + returnfig=False, +): + """ + Plot Moire lattice and parent lattices. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + bragg_moire: BraggVector + Bragg vectors for moire lattice. + int_range: (float, float) + Plotting intensity range for the Moire peaks. + k_max: float + Max k value of the plotted Moire lattice. + plot_subpixel: bool + Apply subpixel corrections to the Bragg spot positions. + Matplotlib default scatter plot rounds to the nearest pixel. + labels: list + List of text labels for parent lattices + marker_size_parent: float + Size of plot markers for the two parent lattices. + marker_size_moire: float + Size of plot markers for the Moire lattice. + text_size_parent: float + Label text size for parent lattice. + text_size_moire: float + Label text size for Moire lattice. + add_labels_parent: bool + Plot the parent lattice index labels. + add_labels_moire: bool + Plot the parent lattice index labels for the Moire spots. + dist_labels: float + Distance to move the labels off the spots. + dist_check: float + Set to some distance to "push" the labels away from each other if they are within this distance. + sep_labels: float + Separation distance for labels which are "pushed" apart. + figsize: (float,float) + Size of output figure. + returnfig: bool + Return the (fix,ax) handles of the plot. + + Returns + -------- + fig, ax: matplotlib handles (optional) + Figure and axes handles for the moire plot. + """ - ax.set_xlim((-k_max, k_max)) - ax.set_ylim((-k_max, k_max)) - ax.set_ylabel("$q_x$ (1/A)") - ax.set_xlabel("$q_y$ (1/A)") - ax.invert_yaxis() + # peak labels + + if labels is None: + labels = ("crystal 0", "crystal 1") + + def overline(x): + return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") + + # parent 1 + qx0 = bragg_parent_0["qx"] + qy0 = bragg_parent_0["qy"] + h0 = bragg_parent_0["h"] + k0 = bragg_parent_0["k"] + l0 = bragg_parent_0["l"] + + # parent 2 + qx1 = bragg_parent_1["qx"] + qy1 = bragg_parent_1["qy"] + h1 = bragg_parent_1["h"] + k1 = bragg_parent_1["k"] + l1 = bragg_parent_1["l"] + + # moire + qx = bragg_moire["qx"] + qy = bragg_moire["qy"] + m_h0 = bragg_moire["h0"] + m_k0 = bragg_moire["k0"] + m_l0 = bragg_moire["l0"] + m_h1 = bragg_moire["h1"] + m_k1 = bragg_moire["k1"] + m_l1 = bragg_moire["l1"] + int_moire = bragg_moire["intensity"] + + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0.09, 0.09, 0.65, 0.9]) + ax_labels = fig.add_axes([0.75, 0, 0.25, 1]) + + text_params_parent = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_parent, + } + text_params_moire = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_moire, + } + + if plot_subpixel is False: + + # moire + ax.scatter( + qy, + qx, + # color = (0,0,0,1), + c=int_moire, + s=marker_size_moire, + cmap="gray_r", + vmin=int_range[0], + vmax=int_range[1], + antialiased=True, + ) - # labels - ax_labels.scatter( - 0, - 0, + # parent lattices + ax.scatter( + qy0, + qx0, color=(1, 0, 0, 1), s=marker_size_parent, + antialiased=True, ) - ax_labels.scatter( - 0, - -1, + ax.scatter( + qy1, + qx1, color=(0, 0.7, 1, 1), s=marker_size_parent, + antialiased=True, ) - ax_labels.scatter( + + # origin + ax.scatter( + 0, 0, - -2, color=(0, 0, 0, 1), - s=marker_size_moire, - ) - ax_labels.text( - 0.4, - -0.2, - labels[0], - fontsize=14, + s=marker_size_parent, + antialiased=True, ) - ax_labels.text( - 0.4, - -1.2, - labels[1], - fontsize=14, + + else: + # moire peaks + int_all = np.clip( + (int_moire - int_range[0]) / (int_range[1] - int_range[0]), 0, 1 ) - ax_labels.text( - 0.4, - -2.2, - "Moiré lattice", - fontsize=14, + keep = np.logical_and.reduce( + (qx >= -k_max, qx <= k_max, qy >= -k_max, qy <= k_max) ) + for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_moire) / 800.0, + color=(1 - int_marker, 1 - int_marker, 1 - int_marker), + ) + ) + if add_labels_moire: + for a0 in range(qx.size): + if keep.ravel()[a0]: + x0 = qx.ravel()[a0] + y0 = qy.ravel()[a0] + d2 = (qx.ravel() - x0) ** 2 + (qy.ravel() - y0) ** 2 + sub = d2 < dist_check**2 + xc = np.mean(qx.ravel()[sub]) + yc = np.mean(qy.ravel()[sub]) + xp = x0 - xc + yp = y0 - yc + if xp == 0 and yp == 0.0: + xp = x0 - dist_labels + yp = y0 + else: + leng = np.linalg.norm((xp, yp)) + xp = x0 + xp * dist_labels / leng + yp = y0 + yp * dist_labels / leng + + ax.text( + yp, + xp - sep_labels, + "$" + + overline(m_h0.ravel()[a0]) + + overline(m_k0.ravel()[a0]) + + overline(m_l0.ravel()[a0]) + + "$", + c="r", + **text_params_moire, + ) + ax.text( + yp, + xp, + "$" + + overline(m_h1.ravel()[a0]) + + overline(m_k1.ravel()[a0]) + + overline(m_l1.ravel()[a0]) + + "$", + c=(0, 0.7, 1.0), + **text_params_moire, + ) - ax_labels.text( - 0, - -4.2, - labels[1] + " $\epsilon_{xx}$ = " + str(np.round(exx_1 * 100, 2)) + "%", - fontsize=14, - ) - ax_labels.text( - 0, - -5.2, - labels[1] + " $\epsilon_{yy}$ = " + str(np.round(eyy_1 * 100, 2)) + "%", - fontsize=14, + keep = np.logical_and.reduce( + (qx0 >= -k_max, qx0 <= k_max, qy0 >= -k_max, qy0 <= k_max) ) - ax_labels.text( - 0, - -6.2, - labels[1] + " $\epsilon_{xy}$ = " + str(np.round(exy_1 * 100, 2)) + "%", - fontsize=14, + for x, y in zip(qx0[keep], qy0[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(1, 0, 0), + ) + ) + if add_labels_parent: + for a0 in range(qx0.size): + if keep.ravel()[a0]: + xp = qx0.ravel()[a0] - dist_labels + yp = qy0.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h0.ravel()[a0]) + + overline(k0.ravel()[a0]) + + overline(l0.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + keep = np.logical_and.reduce( + (qx1 >= -k_max, qx1 <= k_max, qy1 >= -k_max, qy1 <= k_max) ) - ax_labels.text( - 0, - -7.2, - labels[1] - + " $\phi$ = " - + str(np.round(phi_1 * 180 / np.pi, 2)) - + "$^\circ$", - fontsize=14, + for x, y in zip(qx1[keep], qy1[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0.7, 1), + ) + ) + if add_labels_parent: + for a0 in range(qx1.size): + if keep.ravel()[a0]: + xp = qx1.ravel()[a0] - dist_labels + yp = qy1.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h1.ravel()[a0]) + + overline(k1.ravel()[a0]) + + overline(l1.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + # origin + ax.add_artist( + Circle( + xy=(0, 0), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0, 0), + ) ) - ax_labels.set_xlim((-1, 4)) - ax_labels.set_ylim((-21, 1)) + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + ax.set_ylabel("$q_x$ (1/A)") + ax.set_xlabel("$q_y$ (1/A)") + ax.invert_yaxis() + + # labels + ax_labels.scatter( + 0, + 0, + color=(1, 0, 0, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -1, + color=(0, 0.7, 1, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -2, + color=(0, 0, 0, 1), + s=marker_size_moire, + ) + ax_labels.text( + 0.4, + -0.2, + labels[0], + fontsize=14, + ) + ax_labels.text( + 0.4, + -1.2, + labels[1], + fontsize=14, + ) + ax_labels.text( + 0.4, + -2.2, + "Moiré lattice", + fontsize=14, + ) + + ax_labels.set_xlim((-1, 4)) + ax_labels.set_ylim((-21, 1)) - ax_labels.axis("off") + ax_labels.axis("off") - if return_moire: - if returnfig: - return bragg_moire, fig, ax - else: - return bragg_moire if returnfig: return fig, ax diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index b0cb1fe16..84824fe63 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -4,7 +4,6 @@ import matplotlib as mpl import matplotlib.pyplot as plt -from dataclasses import dataclass, field from emdfile import tqdmnd, PointListArray from py4DSTEM.visualize import show, show_image_grid from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern From c7ad07ed7cc6edf02a800ac0e179116d2020f22f Mon Sep 17 00:00:00 2001 From: Steve Zeltmann Date: Mon, 16 Oct 2023 12:34:19 -0400 Subject: [PATCH 078/176] format with black --- py4DSTEM/process/diffraction/crystal.py | 38 +++++++++++++------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 041b111f9..a797bd166 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -1075,6 +1075,8 @@ def calculate_bragg_peak_histogram( int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power) int_exp /= np.max(int_exp) return k, int_exp + + def generate_moire_diffraction_pattern( bragg_peaks_0, bragg_peaks_1, @@ -1091,7 +1093,7 @@ def generate_moire_diffraction_pattern( Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated and strained with respect to the original lattice. Note that this strain is applied in real space, and so the inverse of the calculated infinitestimal strain tensor is applied. - + Parameters -------- bragg_peaks_0: BraggVector @@ -1114,14 +1116,14 @@ def generate_moire_diffraction_pattern( Plotting power law (default is amplitude**2.0, i.e. intensity). k_max: float Max k value of the calculated (and plotted) Moire lattice. - + Returns -------- bragg_moire: BraggVector Bragg vectors for moire lattice. - + """ - + # get intenties of all peaks int0 = bragg_peaks_0["intensity"] ** (power / 2.0) int1 = bragg_peaks_1["intensity"] ** (power / 2.0) @@ -1153,13 +1155,13 @@ def generate_moire_diffraction_pattern( l1 = bragg_peaks_1["l"][sub1] # apply strain tensor to lattice 1 - + # infinitesimal # m = np.array([ # [1 + exx_1, (exy_1 - phi_1)*0.5], # [(exy_1 - phi_1)*0.5, 1 + eyy_1], # ]) - + # finite rotation m = np.array( [ @@ -1199,7 +1201,7 @@ def generate_moire_diffraction_pattern( m_l1 = l1[ind1] # Convert thresholded and moire peaks to BraggVector class - + pl_dtype_parent = np.dtype( [ ("qx", "float"), @@ -1210,7 +1212,7 @@ def generate_moire_diffraction_pattern( ("l", "int"), ] ) - + bragg_parent_0 = PointList(np.array([], dtype=pl_dtype_parent)) bragg_parent_0.add_data_by_field( [ @@ -1221,7 +1223,7 @@ def generate_moire_diffraction_pattern( k0.ravel(), l0.ravel(), ] - ) + ) bragg_parent_1 = PointList(np.array([], dtype=pl_dtype_parent)) bragg_parent_1.add_data_by_field( @@ -1233,8 +1235,8 @@ def generate_moire_diffraction_pattern( k1.ravel(), l1.ravel(), ] - ) - + ) + pl_dtype = np.dtype( [ ("qx", "float"), @@ -1262,9 +1264,10 @@ def generate_moire_diffraction_pattern( m_l1.ravel(), ] ) - + return bragg_parent_0, bragg_parent_1, bragg_moire + def plot_moire_diffraction_pattern( bragg_parent_0, bragg_parent_1, @@ -1287,7 +1290,7 @@ def plot_moire_diffraction_pattern( ): """ Plot Moire lattice and parent lattices. - + Parameters -------- bragg_peaks_0: BraggVector @@ -1327,7 +1330,7 @@ def plot_moire_diffraction_pattern( Size of output figure. returnfig: bool Return the (fix,ax) handles of the plot. - + Returns -------- fig, ax: matplotlib handles (optional) @@ -1335,10 +1338,10 @@ def plot_moire_diffraction_pattern( """ # peak labels - + if labels is None: labels = ("crystal 0", "crystal 1") - + def overline(x): return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") @@ -1366,7 +1369,7 @@ def overline(x): m_k1 = bragg_moire["k1"] m_l1 = bragg_moire["l1"] int_moire = bragg_moire["intensity"] - + fig = plt.figure(figsize=figsize) ax = fig.add_axes([0.09, 0.09, 0.65, 0.9]) ax_labels = fig.add_axes([0.75, 0, 0.25, 1]) @@ -1387,7 +1390,6 @@ def overline(x): } if plot_subpixel is False: - # moire ax.scatter( qy, From 5de83bf9401564b53140838757602fd5c7f3fdac Mon Sep 17 00:00:00 2001 From: Steve Zeltmann Date: Mon, 16 Oct 2023 12:45:53 -0400 Subject: [PATCH 079/176] clean up moire --- py4DSTEM/process/diffraction/crystal.py | 32 +++----------------- py4DSTEM/process/diffraction/crystal_ACOM.py | 2 -- 2 files changed, 5 insertions(+), 29 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index d3e3cebd6..b508d589e 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -5,21 +5,12 @@ from matplotlib.patches import Circle from fractions import Fraction from typing import Union, Optional -from scipy.optimize import curve_fit import sys -from emdfile import tqdmnd, PointList, PointListArray +from emdfile import PointList from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom -from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -from py4DSTEM.process.diffraction.crystal_viz import plot_ring_pattern -from py4DSTEM.process.diffraction.utils import Orientation, calc_1D_profile - -try: - from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - from pymatgen.core.structure import Structure -except ImportError: - pass +from py4DSTEM.process.diffraction.utils import Orientation class Crystal: @@ -1091,7 +1082,6 @@ def generate_moire_diffraction_pattern( exy_1=0.0, phi_1=0.0, power=2.0, - k_max=1.0, ): """ Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated @@ -1118,13 +1108,12 @@ def generate_moire_diffraction_pattern( Rotation of lattice 1 in real space. power: float Plotting power law (default is amplitude**2.0, i.e. intensity). - k_max: float - Max k value of the calculated (and plotted) Moire lattice. Returns -------- - bragg_moire: BraggVector - Bragg vectors for moire lattice. + parent_peaks_0, parent_peaks_1, moire_peaks: BraggVectors + Bragg vectors for the rotated & strained parent lattices + and the moire lattice """ @@ -1159,14 +1148,6 @@ def generate_moire_diffraction_pattern( l1 = bragg_peaks_1["l"][sub1] # apply strain tensor to lattice 1 - - # infinitesimal - # m = np.array([ - # [1 + exx_1, (exy_1 - phi_1)*0.5], - # [(exy_1 - phi_1)*0.5, 1 + eyy_1], - # ]) - - # finite rotation m = np.array( [ [np.cos(phi_1), -np.sin(phi_1)], @@ -1189,11 +1170,8 @@ def generate_moire_diffraction_pattern( np.arange(np.sum(sub1)), indexing="ij", ) - # ind0 = ind0.ravel() - # ind1 = ind1.ravel() qx = qx0[ind0] + qx1[ind1] qy = qy0[ind0] + qy1[ind1] - # int_moire = int0_sub[ind0] + int1_sub[ind1] int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5 # moire labels diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index da553456f..5722f3f38 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1,8 +1,6 @@ import numpy as np import matplotlib.pyplot as plt -import os from typing import Union, Optional -import time, sys from tqdm import tqdm from emdfile import tqdmnd, PointList, PointListArray From 7c653fefbd18f70a92a4cfd0aee8d7510940ec39 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 16 Oct 2023 09:55:19 -0700 Subject: [PATCH 080/176] Fix for plotting bug --- py4DSTEM/process/diffraction/crystal_viz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index e17e87b93..9f9336155 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -149,7 +149,7 @@ def plot_structure( zs=xyz[sub, 2], # + d[2], s=size_marker, linewidth=2, - color=atomic_colors(ID_plot), + facecolors=atomic_colors(ID_plot), edgecolor=[0, 0, 0], ) From 486639f0f476a134e4da266a33d6a89b74aabdbb Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 17:57:53 +0100 Subject: [PATCH 081/176] restructures return and plotting behavior more flexibly --- py4DSTEM/process/polar/polar_analysis.py | 195 ++++++++++++++++------- py4DSTEM/process/polar/polar_datacube.py | 8 +- 2 files changed, 143 insertions(+), 60 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index fc6e71cf2..8d2d92585 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -217,12 +217,14 @@ def calculate_pair_dist_function( r_step=0.02, damp_origin_fluctuations=True, density=None, - plot_fits=False, + plot_background_fits=False, plot_sf_estimate=False, plot_reduced_pdf=True, plot_pdf=False, figsize=(8, 4), maxfev=None, + returnval=False, + returnfig=False, ): """ Calculate the pair distribution function (PDF). @@ -280,7 +282,7 @@ def calculate_pair_dist_function( The density of the sample, if known. If this is not provided, only the reduced PDF is calculated. If this value is provided, the PDF is also calculated. - plot_fits : bool + plot_background_fits : bool plot_sf_estimate : bool plot_reduced_pdf=True : bool plot_pdf : bool @@ -400,6 +402,7 @@ def calculate_pair_dist_function( self.fk = fk self.bg = bg self.offset = coefs[0] + self.Sk_mask = mask # if density is provided, we can estimate the full PDF if density is not None: @@ -421,78 +424,154 @@ def calculate_pair_dist_function( if density is None: return_values = self.pdf_r, self.pdf_reduced else: - return_values = self.pdf_r, self.pdf_reduced + return_values = self.pdf_r, self.pdf_reduced, self.pdf if returnval: - ans = statistics if not returnfig else [statistics] + ans = return_values if not returnfig else [return_values] else: ans = None if not returnfig else [] - # Plots - if plot_fits: - fig, ax = plt.subplots(figsize=figsize) - ax.plot( - self.qq, - self.radial_mean, - color="k", - ) - ax.plot( - k, - bg, - color="r", + if plot_background_fits: + fig,ax = self.plot_background_fits( + figsize = figsize, + returnfig = True ) - ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") - ax.set_ylabel("Radial Mean") - ax.set_xlim((self.qq[0], self.qq[-1])) - # ax.set_ylim((0,2e-5)) - ax.set_xlabel("Scattering Vector [A^-1]") - ax.set_ylabel("I(k) and Fit Estimates") - - ax.set_ylim( - ( - np.min(self.radial_mean[self.radial_mean > 0]) * 0.8, - np.max(self.radial_mean * mask) * 1.25, - ) - ) - ax.set_yscale("log") + if returnfig: + ans.append((fig,ax)) if plot_sf_estimate: - fig, ax = plt.subplots(figsize=figsize) - ax.plot( - k, - Sk, - color="r", + fig,ax = self.plot_sf_estimate( + figsize = figsize, + returnfig = True ) - yr = (np.min(Sk), np.max(Sk)) - ax.set_ylim( - ( - yr[0] - 0.05 * (yr[1] - yr[0]), - yr[1] + 0.05 * (yr[1] - yr[0]), - ) - ) - ax.set_xlabel("Scattering Vector [A^-1]") - ax.set_ylabel("Reduced Structure Factor") + if returnfig: + ans.append((fig,ax)) if plot_reduced_pdf: - fig, ax = plt.subplots(figsize=figsize) - ax.plot( - r, - pdf_reduced, - color="r", + fig,ax = self.plot_reduced_pdf( + figsize = figsize, + returnfig = True ) - ax.set_xlabel("Radius [A]") - ax.set_ylabel("Reduced Pair Distribution Function") + if returnfig: + ans.append((fig,ax)) if plot_pdf: - fig, ax = plt.subplots(figsize=figsize) - ax.plot( - r, - pdf, - color="r", + fig,ax = self.plot_pdf( + figsize = figsize, + returnfig = True + ) + if returnfig: + ans.append((fig,ax)) + + # return + return ans + + +def plot_background_fits( + self, + figsize=(8, 4), + returnfig=False, + ): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.radial_mean, + color="k", + ) + ax.plot( + self.qq, + self.bg, + color="r", + ) + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") + ax.set_ylabel("Radial Mean") + ax.set_xlim((self.qq[0], self.qq[-1])) + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("I(k) and Background Fit Estimates") + ax.set_ylim( + ( + np.min(self.radial_mean[self.radial_mean > 0]) * 0.8, + np.max(self.radial_mean * self.Sk_mask) * 1.25, ) - ax.set_xlabel("Radius [A]") - ax.set_ylabel("Pair Distribution Function") + ) + ax.set_yscale("log") + if returnfig: + return fig,ax + plt.show() + +def plot_sf_estimate( + self, + figsize=(8, 4), + returnfig=False, + ): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.Sk, + color="r", + ) + yr = (np.min(self.Sk), np.max(self.Sk)) + ax.set_ylim( + ( + yr[0] - 0.05 * (yr[1] - yr[0]), + yr[1] + 0.05 * (yr[1] - yr[0]), + ) + ) + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("Reduced Structure Factor") + if returnfig: + return fig,ax + plt.show() + + +def plot_reduced_pdf( + self, + figsize=(8, 4), + returnfig=False, + ): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.pdf_r, + self.pdf_reduced, + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Reduced Pair Distribution Function") + if returnfig: + return fig,ax + plt.show() + +def plot_pdf( + self, + figsize=(8, 4), + returnfig=False, + ): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.pdf_r, + self.pdf, + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Pair Distribution Function") + if returnfig: + return fig,ax + plt.show() + + # functions for inverting from reduced PDF back to S(k) diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index a5d48c99e..56071c534 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -95,10 +95,14 @@ def __init__( from py4DSTEM.process.polar.polar_analysis import ( calculate_radial_statistics, - plot_radial_mean, - plot_radial_var_norm, calculate_pair_dist_function, calculate_FEM_local, + plot_radial_mean, + plot_radial_var_norm, + plot_background_fits, + plot_sf_estimate, + plot_reduced_pdf, + plot_pdf, ) from py4DSTEM.process.polar.polar_peaks import ( find_peaks_single_pattern, From 06e18376b820776e9f8a46f4f80095a94f474743 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 16 Oct 2023 12:54:58 -0700 Subject: [PATCH 082/176] Starting on CTF fitting --- py4DSTEM/process/phase/iterative_parallax.py | 208 ++++++++++++------- 1 file changed, 128 insertions(+), 80 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index ff6fb52af..c855c1451 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1267,9 +1267,11 @@ def subpixel_alignment( def aberration_fit( self, + fit_thon_rings = True, + fit_upsampled_fft = True, plot_CTF_compare: bool = False, - plot_dk: float = 0.005, - plot_k_sigma: float = 0.02, + # plot_dk: float = 0.005, + # plot_k_sigma: float = 0.02, ): """ Fit aberrations to the measured image shifts. @@ -1277,17 +1279,27 @@ def aberration_fit( Parameters ---------- plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies + If True, the fitted CTF is plotted against the reconstructed frequencies. + fit_thon_rings: bool + Set to True to directly fit aberrations in the FFT of the upsampled BF + image (if available). Note that this method relies on visible zero + crossings in the FFT, and will not work if they are not present. + fit_upsampled_fft: bool + If True, we aberration fit is performed on the upsampled BF image. + This option does nothing if fit_thon_rings is not True. plot_dk: float, optional Reciprocal bin-size for polar-averaged FFT plot_k_sigma: float, optional sigma to gaussian blur polar-averaged FFT by + """ xp = self._xp asnumpy = self._asnumpy gaussian_filter = self._gaussian_filter + # initial aberration fit + # Convert real space shifts to Angstroms self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) @@ -1316,6 +1328,42 @@ def aberration_fit( xp._default_memory_pool.free_all_blocks() xp.clear_memo() + # Refinement using Thon rings + if fit_thon_rings: + if fit_upsampled_fft: + # Get mean FFT of BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + print(self._kde_upsample_factor) + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor + else: + # Get mean FFT of upsampled BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + + # FFT coordinates + qx = fft + + # weights for fits + + # #zero origin pixel + # im_fft[0,0] = 0 + + + print(im_fft.shape) + + fig,ax = plt.subplots(figsize=(6,6)) + ax.imshow( + np.fft.fftshift(im_fft)**0.5, + ) + + + # Print results if self._verbose: print( @@ -1334,83 +1382,83 @@ def aberration_fit( print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: - # Get polar mean from FFT of BF reconstruction - im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - k_max = xp.max(kra) / np.sqrt(2.0) - k_num_bins = int(xp.ceil(k_max / plot_dk)) - k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # histogram - k_ind = kra / plot_dk - kf = np.floor(k_ind).astype("int") - dk = k_ind - kf - sub = kf <= k_num_bins - hist_exp = xp.bincount( - kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins - ) - hist_norm = xp.bincount( - kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins - ) - sub = kf <= k_num_bins - 1 - - hist_exp += xp.bincount( - kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins - ) - hist_norm += xp.bincount( - kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins - ) - - # KDE and normalizing - k_sigma = plot_dk / plot_k_sigma - hist_exp[0] = 0.0 - hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - hist_exp /= hist_norm - - # CTF comparison - CTF_fit = xp.sin( - (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 - ) - - # plotting input - log scale - min_hist_val = xp.max(hist_exp) * 1e-3 - hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - hist_plot -= xp.min(hist_plot) - hist_plot /= xp.max(hist_plot) - - hist_plot = asnumpy(hist_plot) - k_bins = asnumpy(k_bins) - CTF_fit = asnumpy(CTF_fit) - - fig, ax = plt.subplots(figsize=(8, 4)) - - ax.fill_between( - k_bins, - hist_plot, - color=(0.7, 0.7, 0.7, 1), - ) - - ax.plot( - k_bins, - np.clip(CTF_fit, 0.0, np.inf), - color=(1, 0, 0, 1), - linewidth=2, - ) - ax.plot( - k_bins, - np.clip(-CTF_fit, 0.0, np.inf), - color=(0, 0.5, 1, 1), - linewidth=2, - ) - ax.set_xlim([0, k_bins[-1]]) - ax.set_ylim([0, 1.05]) + # # Plot the CTF comparison between experiment and fit + # if plot_CTF_compare: + # # Get polar mean from FFT of BF reconstruction + # im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) + + # # coordinates + # kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) + # ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + # kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) + # k_max = xp.max(kra) / np.sqrt(2.0) + # k_num_bins = int(xp.ceil(k_max / plot_dk)) + # k_bins = xp.arange(k_num_bins + 1) * plot_dk + + # # histogram + # k_ind = kra / plot_dk + # kf = np.floor(k_ind).astype("int") + # dk = k_ind - kf + # sub = kf <= k_num_bins + # hist_exp = xp.bincount( + # kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins + # ) + # hist_norm = xp.bincount( + # kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins + # ) + # sub = kf <= k_num_bins - 1 + + # hist_exp += xp.bincount( + # kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins + # ) + # hist_norm += xp.bincount( + # kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins + # ) + + # # KDE and normalizing + # k_sigma = plot_dk / plot_k_sigma + # hist_exp[0] = 0.0 + # hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") + # hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") + # hist_exp /= hist_norm + + # # CTF comparison + # CTF_fit = xp.sin( + # (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 + # ) + + # # plotting input - log scale + # min_hist_val = xp.max(hist_exp) * 1e-3 + # hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) + # hist_plot -= xp.min(hist_plot) + # hist_plot /= xp.max(hist_plot) + + # hist_plot = asnumpy(hist_plot) + # k_bins = asnumpy(k_bins) + # CTF_fit = asnumpy(CTF_fit) + + # fig, ax = plt.subplots(figsize=(8, 4)) + + # ax.fill_between( + # k_bins, + # hist_plot, + # color=(0.7, 0.7, 0.7, 1), + # ) + + # ax.plot( + # k_bins, + # np.clip(CTF_fit, 0.0, np.inf), + # color=(1, 0, 0, 1), + # linewidth=2, + # ) + # ax.plot( + # k_bins, + # np.clip(-CTF_fit, 0.0, np.inf), + # color=(0, 0.5, 1, 1), + # linewidth=2, + # ) + # ax.set_xlim([0, k_bins[-1]]) + # ax.set_ylim([0, 1.05]) def aberration_correct( self, From 5013d3e0f9923161a828350f786b22a14a522521 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 16 Oct 2023 23:59:24 +0100 Subject: [PATCH 083/176] autoformats --- py4DSTEM/process/polar/polar_analysis.py | 59 +++++++++--------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 8d2d92585..5bca4a331 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -118,7 +118,7 @@ def calculate_radial_statistics( returnfig=True, ) if returnfig: - ans.append((fig,ax)) + ans.append((fig, ax)) if plot_results_var: fig, ax = plot_radial_var_norm( self, @@ -126,7 +126,7 @@ def calculate_radial_statistics( returnfig=True, ) if returnfig: - ans.append((fig,ax)) + ans.append((fig, ax)) # return return ans @@ -342,8 +342,8 @@ def calculate_pair_dist_function( coefs[3] *= int_mean # Calculate the mean atomic form factor without a constant offset - #coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) - #fk = scattering_model(k2, coefs_fk) + # coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) + # fk = scattering_model(k2, coefs_fk) bg = scattering_model(k2, coefs) fk = bg - coefs[0] @@ -419,7 +419,6 @@ def calculate_pair_dist_function( # store results self.pdf = pdf - # prepare answer if density is None: return_values = self.pdf_r, self.pdf_reduced @@ -430,39 +429,26 @@ def calculate_pair_dist_function( else: ans = None if not returnfig else [] - # Plots if plot_background_fits: - fig,ax = self.plot_background_fits( - figsize = figsize, - returnfig = True - ) + fig, ax = self.plot_background_fits(figsize=figsize, returnfig=True) if returnfig: - ans.append((fig,ax)) + ans.append((fig, ax)) if plot_sf_estimate: - fig,ax = self.plot_sf_estimate( - figsize = figsize, - returnfig = True - ) + fig, ax = self.plot_sf_estimate(figsize=figsize, returnfig=True) if returnfig: - ans.append((fig,ax)) + ans.append((fig, ax)) if plot_reduced_pdf: - fig,ax = self.plot_reduced_pdf( - figsize = figsize, - returnfig = True - ) + fig, ax = self.plot_reduced_pdf(figsize=figsize, returnfig=True) if returnfig: - ans.append((fig,ax)) + ans.append((fig, ax)) if plot_pdf: - fig,ax = self.plot_pdf( - figsize = figsize, - returnfig = True - ) + fig, ax = self.plot_pdf(figsize=figsize, returnfig=True) if returnfig: - ans.append((fig,ax)) + ans.append((fig, ax)) # return return ans @@ -472,7 +458,7 @@ def plot_background_fits( self, figsize=(8, 4), returnfig=False, - ): +): """ TODO """ @@ -500,14 +486,15 @@ def plot_background_fits( ) ax.set_yscale("log") if returnfig: - return fig,ax + return fig, ax plt.show() + def plot_sf_estimate( self, figsize=(8, 4), returnfig=False, - ): +): """ TODO """ @@ -527,7 +514,7 @@ def plot_sf_estimate( ax.set_xlabel("Scattering Vector [A^-1]") ax.set_ylabel("Reduced Structure Factor") if returnfig: - return fig,ax + return fig, ax plt.show() @@ -535,7 +522,7 @@ def plot_reduced_pdf( self, figsize=(8, 4), returnfig=False, - ): +): """ TODO """ @@ -548,14 +535,15 @@ def plot_reduced_pdf( ax.set_xlabel("Radius [A]") ax.set_ylabel("Reduced Pair Distribution Function") if returnfig: - return fig,ax + return fig, ax plt.show() + def plot_pdf( self, figsize=(8, 4), returnfig=False, - ): +): """ TODO """ @@ -568,11 +556,9 @@ def plot_pdf( ax.set_xlabel("Radius [A]") ax.set_ylabel("Pair Distribution Function") if returnfig: - return fig,ax + return fig, ax plt.show() - - # functions for inverting from reduced PDF back to S(k) # # invert @@ -653,4 +639,3 @@ def scattering_model(k2, *coefs): # int1*np.exp(k2/(-2*sigma1**2)) return int_model - From 593f07d4c7417167c681f84e628b937c29c36e67 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 16 Oct 2023 18:00:11 -0700 Subject: [PATCH 084/176] Adding more parts of parallax CTF fitting --- py4DSTEM/process/phase/iterative_parallax.py | 178 +++++++++++++++++-- 1 file changed, 160 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index c855c1451..f7618c749 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -18,6 +18,8 @@ from py4DSTEM.visualize import show from scipy.linalg import polar from scipy.special import comb +from scipy.optimize import curve_fit +from scipy.signal import medfilt2d try: import cupy as cp @@ -1269,6 +1271,10 @@ def aberration_fit( self, fit_thon_rings = True, fit_upsampled_fft = True, + aber_order_max = 2, + q_power_fit = 0.0, + medfilt_size = None, + maxfev = None, plot_CTF_compare: bool = False, # plot_dk: float = 0.005, # plot_k_sigma: float = 0.02, @@ -1287,10 +1293,14 @@ def aberration_fit( fit_upsampled_fft: bool If True, we aberration fit is performed on the upsampled BF image. This option does nothing if fit_thon_rings is not True. + aber_order_max: int + Max radial order for fitting of aberrations. + q_power_fit: float + q power fitting weight. plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT + Reciprocal bin-size for polar-averaged FFT. plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by + sigma to gaussian blur polar-averaged FFT by. """ @@ -1335,7 +1345,6 @@ def aberration_fit( im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) # coordinates - print(self._kde_upsample_factor) q_pixel_size = np.array(self._reciprocal_sampling) \ / self._kde_upsample_factor else: @@ -1347,22 +1356,130 @@ def aberration_fit( # FFT coordinates - qx = fft + qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) + qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha2 = qr2 * self._wavelength**2 + self.theta = np.arctan2(qy[None,:],qx[:,None]) # weights for fits - - # #zero origin pixel - # im_fft[0,0] = 0 - - - print(im_fft.shape) - - fig,ax = plt.subplots(figsize=(6,6)) - ax.imshow( - np.fft.fftshift(im_fft)**0.5, + self.q_weight = qr2 ** (q_power_fit/2) + + # Aberration coefs + mn = [] + for m in range(0,aber_order_max//2+1): + n_max = np.floor(aber_order_max-2*m).astype('int') + + for n in range(0,n_max+1): + if m + n > 1 or (m > 0 and n == 0): + if n == 0: + mn.append([m,n,0]) + else: + mn.append([m,n,0]) + mn.append([m,n,1]) + self.aber_mn = np.array(mn) + + # Aberration basis + self.aber_num = self.aber_mn.shape[0] + self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) + # self.aber_basis[:,0] = self.alpha2.ravel() + for a0 in range(self.aber_num): + if self.aber_mn[a0,1] == 0: + # Radially symmetric basis + self.aber_basis[:,a0] = self.alpha2.ravel()**self.aber_mn[a0,0] + elif self.aber_mn[a0,2] == 0: + # cos coef + self.aber_basis[:,a0] = \ + self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ + np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + else: + # sin coef + self.aber_basis[:,a0] = \ + self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ + np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # fitting image + im_fit = im_fft * self.q_weight + if medfilt_size is not None: + im_fit = np.fft.ifftshift(medfilt2d( + np.fft.fftshift(im_fit), + medfilt_size)) + + # initial coefs + int_max = np.max(im_fit) + sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) + coefs = np.zeros(5 + self.aber_num) + lb = np.zeros(5 + self.aber_num) + ub = np.ones(5 + self.aber_num) * np.inf + coefs[0] = 1e-3 + coefs[1] = int_max * 0.1 + coefs[2] = sigma_init + coefs[3] = int_max * 0.9 + coefs[4] = sigma_init + lb[5:] = -np.inf + # initial C1 value (defocus) + ind = np.argmin( + np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + ) + C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength + coefs[ind + 5] = C1_dimensionless + + # Fitting mask + fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) + basis_masked = self.aber_basis[fit_mask.ravel(),:] + + # Define fitting functions + + def calc_CTF_mag(alpha2, *coefs): + int0 = coefs[0] + int1 = coefs[1] + sigma1 = coefs[2] + int_env = coefs[3] + sigma_env = coefs[4] + + im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + chi = np.zeros_like(im_CTF_mag) + for a0 in range(5,len(coefs)): + chi += coefs[a0] * self.aber_basis[:,a0-5] + return im_CTF_mag + np.abs(np.sin(chi)) * env + + def calc_CTF_mag_masked(alpha2, *coefs): + int0 = coefs[0] + int1 = coefs[1] + sigma1 = coefs[2] + int_env = coefs[3] + sigma_env = coefs[4] + + im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + chi = np.zeros_like(im_CTF_mag) + for a0 in range(5,len(coefs)): + chi += coefs[a0] * basis_masked[:,a0-5] + return im_CTF_mag + np.abs(np.sin(chi)) * env + + # Refine aberration coefficients + if maxfev is None: + coefs = np.array( + curve_fit( + calc_CTF_mag_masked, + self.alpha2[fit_mask], + im_fit[fit_mask], + p0 = tuple(coefs), + bounds = (lb,ub), + )[0] + ) + else: + coefs = np.array( + curve_fit( + calc_CTF_mag_masked, + self.alpha2[fit_mask], + im_fit[fit_mask], + p0 = tuple(coefs), + bounds = (lb,ub), + maxfev = maxfev, + )[0] ) - - # Print results if self._verbose: @@ -1382,8 +1499,33 @@ def aberration_fit( print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - # # Plot the CTF comparison between experiment and fit - # if plot_CTF_compare: + # Plot the CTF comparison between experiment and fit + if plot_CTF_compare: + im_test = np.reshape(calc_CTF_mag(self.alpha2.ravel(), *coefs), im_fit.shape) + + fig,ax = plt.subplots(figsize=(12,6)) + ax.imshow( + np.hstack(( + np.fft.fftshift(im_fit), + np.fft.fftshift(im_test), + )), + vmin = np.min(im_test[fit_mask]), + vmax = np.max(im_test[fit_mask]), + cmap = 'gray', + ) + + # ax.imshow( + # im_plot / np.max(im_plot), + # vmin = 0, + # vmax = 1, + # cmap = 'gray', + # ) + # ax.imshow( \ + # np.fft.fftshift( + # np.mod(np.reshape(self.aber_basis[:,2],im_fft.shape)+np.pi,2*np.pi)-np.pi + # )) + + # # Get polar mean from FFT of BF reconstruction # im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) From e5d7425b29fcb613057ca7d8a90e7671d2c04c38 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Tue, 17 Oct 2023 03:13:26 -0700 Subject: [PATCH 085/176] added chroma_boost for show_complex --- py4DSTEM/visualize/vis_special.py | 33 ++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 6dd980bce..f7beec241 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -938,12 +938,13 @@ def show_selected_dps( ) -def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value power (float) : power to raise amplitude to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.5) """ amp = np.abs(complex_data) phase = np.angle(complex_data) @@ -974,7 +975,7 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): amp = ((amp - vmin) / vmax).clip(1e-16, 1) J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff - C = np.where(J < 61.5, 98 * J / 123, 1400 / 11 - 14 * J / 11) # Min uniform chroma + C = np.minimum(chroma_boost * 98 * J / 123, 110) h = np.rad2deg(phase) + 180 JCh = np.stack((J, C, h), axis=-1) @@ -983,16 +984,17 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, power=None): return rgb -def add_colorbar_arg(cax, c=49, j=61.5): +def add_colorbar_arg(cax, chroma_boost=1, c=49, j=61.5): """ cax : axis to add cbar to - c : constant chroma value - j : constant luminance value + chroma_boost (float): boosts chroma for higher-contrast (~1-2.25) + c (float) : constant chroma value + j (float) : constant luminance value """ h = np.linspace(0, 360, 256, endpoint=False) J = np.full_like(h, j) - C = np.full_like(h, c) + C = np.full_like(h, np.minimum(c * chroma_boost, 110)) JCh = np.stack((J, C, h), axis=-1) rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) @@ -1012,12 +1014,13 @@ def show_complex( ar_complex, vmin=None, vmax=None, + power=None, + chroma_boost=1, cbar=True, scalebar=False, pixelunits="pixels", pixelsize=1, returnfig=False, - power=None, **kwargs ): """ @@ -1030,12 +1033,13 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels + power (float,optional) : power to raise amplitude to + chroma_boost (float) : boosts chroma for higher-contrast (~1-2.25) cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - power (float,optional) : power to raise amplitude to Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -1050,7 +1054,7 @@ def show_complex( if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): rgb = [ - Complex2RGB(ar, vmin, vmax, power=power) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for sublist in ar_complex for ar in sublist ] @@ -1058,7 +1062,10 @@ def show_complex( W = len(ar_complex[0]) else: - rgb = [Complex2RGB(ar, vmin, vmax, power=power) for ar in ar_complex] + rgb = [ + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) + for ar in ar_complex + ] if len(rgb[0].shape) == 4: H = len(ar_complex) W = rgb[0].shape[0] @@ -1067,7 +1074,9 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, power=power) + rgb = Complex2RGB( + ar_complex, vmin, vmax, power=power, chroma_boost=chroma_boost + ) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -1127,7 +1136,7 @@ def show_complex( else: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) fig.tight_layout() From 7167706e7d9a244f256de01562f16f074a0fe10a Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 17 Oct 2023 13:12:13 +0100 Subject: [PATCH 086/176] adds placeholder for citations --- py4DSTEM/process/polar/polar_analysis.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 5bca4a331..4f053dcfa 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -38,6 +38,8 @@ def calculate_radial_statistics( and the normalized variance is d_var/d_mean. + This follows the methods described in [@cophus TODO ADD CITATION]. + Parameters -------- @@ -251,6 +253,7 @@ def calculate_pair_dist_function( G(r) = 1 + [ \frac{2}{\pi} * g(r) / ( 4\pi * D * r dr ) ] + This follows the methods described in [@cophus TODO ADD CITATION]. Parameters From 62c4cae86ad5ad63242b4f22f1e742a0b80746bc Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 10:46:11 -0700 Subject: [PATCH 087/176] Working on CTF --- py4DSTEM/process/phase/iterative_parallax.py | 383 ++++++++++--------- 1 file changed, 209 insertions(+), 174 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index f7618c749..4bfd265f9 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1269,15 +1269,12 @@ def subpixel_alignment( def aberration_fit( self, - fit_thon_rings = True, - fit_upsampled_fft = True, - aber_order_max = 2, - q_power_fit = 0.0, - medfilt_size = None, - maxfev = None, + fit_CTF_FFT = True, + fit_CTF_threshold = 0.1, + fit_upsampled_FFT = True, + fit_aber_order_max = 2, + fit_maxfev = None, plot_CTF_compare: bool = False, - # plot_dk: float = 0.005, - # plot_k_sigma: float = 0.02, ): """ Fit aberrations to the measured image shifts. @@ -1286,21 +1283,17 @@ def aberration_fit( ---------- plot_CTF_compare: bool, optional If True, the fitted CTF is plotted against the reconstructed frequencies. - fit_thon_rings: bool + fit_CTF_FFT: bool Set to True to directly fit aberrations in the FFT of the upsampled BF image (if available). Note that this method relies on visible zero crossings in the FFT, and will not work if they are not present. - fit_upsampled_fft: bool + fit_upsampled_FFT: bool If True, we aberration fit is performed on the upsampled BF image. This option does nothing if fit_thon_rings is not True. - aber_order_max: int + fit_aber_order_max: int Max radial order for fitting of aberrations. - q_power_fit: float - q power fitting weight. - plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT. - plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by. + ctf_threshold: float + CTF fitting minimizes value at CTF zero crossings (Thon ring minima). """ @@ -1338,182 +1331,204 @@ def aberration_fit( xp._default_memory_pool.free_all_blocks() xp.clear_memo() - # Refinement using Thon rings - if fit_thon_rings: - if fit_upsampled_fft: - # Get mean FFT of BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor + # Aberration coefs + mn = [] + for m in range(0,fit_aber_order_max//2+1): + n_max = np.floor(fit_aber_order_max-2*m).astype('int') + + for n in range(0,n_max+1): + if m + n > 1 or (m > 0 and n == 0): + if n == 0: + mn.append([m,n,0]) + else: + mn.append([m,n,0]) + mn.append([m,n,1]) + self.aber_mn = np.array(mn) + + # Aberration basis + self.aber_num = self.aber_mn.shape[0] + self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) + for a0 in range(self.aber_num): + if self.aber_mn[a0,1] == 0: + # Radially symmetric basis + self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) + elif self.aber_mn[a0,2] == 0: + # cos coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.cos(self.aber_mn[a0,1] * self.theta.ravel()) else: - # Get mean FFT of upsampled BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF)) + # sin coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # CTF function + def calc_CTF(alpha, *coefs): + chi = np.zeros_like(alpha.ravel()) + for a0 in range(len(coefs)): + chi += coefs[a0] * self.aber_basis[:,a0] + return np.reshape(chi, alpha.shape) + + if fit_upsampled_FFT: + # Get mean FFT of BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor + else: + # Get mean FFT of upsampled BF reconstruction + im_fft = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + # FFT coordinates + qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) + qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha = np.sqrt(qr2) * self._wavelength + self.theta = np.arctan2(qy[None,:],qx[:,None]) + + # initial coefficients and plotting intensity range mask + C1_dimensionless = self.aberration_C1 * 0.5 * self._wavelength + coefs = np.zeros(self.aber_num) + ind = np.argmin( + np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + ) + coefs[ind] = C1_dimensionless + plot_mask = self.alpha > np.sqrt(np.pi/np.abs(C1_dimensionless)) + angular_mask = np.cos(4.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.5 - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + # Refinement using CTF fitting / Thon rings + if fit_CTF_FFT: + pass - # FFT coordinates - qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) - qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) - qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha2 = qr2 * self._wavelength**2 - self.theta = np.arctan2(qy[None,:],qx[:,None]) - - # weights for fits - self.q_weight = qr2 ** (q_power_fit/2) - - # Aberration coefs - mn = [] - for m in range(0,aber_order_max//2+1): - n_max = np.floor(aber_order_max-2*m).astype('int') - - for n in range(0,n_max+1): - if m + n > 1 or (m > 0 and n == 0): - if n == 0: - mn.append([m,n,0]) - else: - mn.append([m,n,0]) - mn.append([m,n,1]) - self.aber_mn = np.array(mn) - - # Aberration basis - self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) - # self.aber_basis[:,0] = self.alpha2.ravel() - for a0 in range(self.aber_num): - if self.aber_mn[a0,1] == 0: - # Radially symmetric basis - self.aber_basis[:,a0] = self.alpha2.ravel()**self.aber_mn[a0,0] - elif self.aber_mn[a0,2] == 0: - # cos coef - self.aber_basis[:,a0] = \ - self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ - np.cos(self.aber_mn[a0,1] * self.theta.ravel()) - else: - # sin coef - self.aber_basis[:,a0] = \ - self.alpha2.ravel()**(self.aber_mn[a0,0] + self.aber_mn[a0,1]/2.0) * \ - np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # fitting image - im_fit = im_fft * self.q_weight - if medfilt_size is not None: - im_fit = np.fft.ifftshift(medfilt2d( - np.fft.fftshift(im_fit), - medfilt_size)) - - # initial coefs - int_max = np.max(im_fit) - sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) - coefs = np.zeros(5 + self.aber_num) - lb = np.zeros(5 + self.aber_num) - ub = np.ones(5 + self.aber_num) * np.inf - coefs[0] = 1e-3 - coefs[1] = int_max * 0.1 - coefs[2] = sigma_init - coefs[3] = int_max * 0.9 - coefs[4] = sigma_init - lb[5:] = -np.inf - # initial C1 value (defocus) - ind = np.argmin( - np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - ) - C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength - coefs[ind + 5] = C1_dimensionless + # im_fit = im_fft * self.q_weight + # if medfilt_size is not None: + # im_fit = np.fft.ifftshift(medfilt2d( + # np.fft.fftshift(im_fit), + # medfilt_size)) + + # # initial coefs + # int_max = np.max(im_fit) + # sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) + # coefs = np.zeros(5 + self.aber_num) + # lb = np.zeros(5 + self.aber_num) + # ub = np.ones(5 + self.aber_num) * np.inf + # coefs[0] = 1e-3 + # coefs[1] = int_max * 0.1 + # coefs[2] = sigma_init + # coefs[3] = int_max * 0.9 + # coefs[4] = sigma_init + # lb[5:] = -np.inf + # # initial C1 value (defocus) + # ind = np.argmin( + # np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + # ) + # C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength + # coefs[ind + 5] = C1_dimensionless # Fitting mask - fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) - basis_masked = self.aber_basis[fit_mask.ravel(),:] + # fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) + # basis_masked = self.aber_basis[fit_mask.ravel(),:] # Define fitting functions - def calc_CTF_mag(alpha2, *coefs): - int0 = coefs[0] - int1 = coefs[1] - sigma1 = coefs[2] - int_env = coefs[3] - sigma_env = coefs[4] - - im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - chi = np.zeros_like(im_CTF_mag) - for a0 in range(5,len(coefs)): - chi += coefs[a0] * self.aber_basis[:,a0-5] - return im_CTF_mag + np.abs(np.sin(chi)) * env - - def calc_CTF_mag_masked(alpha2, *coefs): - int0 = coefs[0] - int1 = coefs[1] - sigma1 = coefs[2] - int_env = coefs[3] - sigma_env = coefs[4] - - im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - chi = np.zeros_like(im_CTF_mag) - for a0 in range(5,len(coefs)): - chi += coefs[a0] * basis_masked[:,a0-5] - return im_CTF_mag + np.abs(np.sin(chi)) * env - - # Refine aberration coefficients - if maxfev is None: - coefs = np.array( - curve_fit( - calc_CTF_mag_masked, - self.alpha2[fit_mask], - im_fit[fit_mask], - p0 = tuple(coefs), - bounds = (lb,ub), - )[0] - ) - else: - coefs = np.array( - curve_fit( - calc_CTF_mag_masked, - self.alpha2[fit_mask], - im_fit[fit_mask], - p0 = tuple(coefs), - bounds = (lb,ub), - maxfev = maxfev, - )[0] - ) - - # Print results - if self._verbose: - print( - ( - "Rotation of Q w.r.t. R = " - f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" - ) - ) - print( - ( - "Astigmatism (A1x,A1y) = (" - f"{self.aberration_A1x:.0f}," - f"{self.aberration_A1y:.0f}) Ang" - ) - ) - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + # def calc_CTF_mag(alpha2, *coefs): + # int0 = coefs[0] + # int1 = coefs[1] + # sigma1 = coefs[2] + # int_env = coefs[3] + # sigma_env = coefs[4] + + # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + # chi = np.zeros_like(im_CTF_mag) + # for a0 in range(5,len(coefs)): + # chi += coefs[a0] * self.aber_basis[:,a0-5] + # return im_CTF_mag + np.abs(np.sin(chi)) * env + + # def calc_CTF_mag_masked(alpha2, *coefs): + # int0 = coefs[0] + # int1 = coefs[1] + # sigma1 = coefs[2] + # int_env = coefs[3] + # sigma_env = coefs[4] + + # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) + # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) + # chi = np.zeros_like(im_CTF_mag) + # for a0 in range(5,len(coefs)): + # chi += coefs[a0] * basis_masked[:,a0-5] + # return im_CTF_mag + np.abs(np.sin(chi)) * env + + # # Refine aberration coefficients + # if maxfev is None: + # coefs = np.array( + # curve_fit( + # calc_CTF_mag_masked, + # self.alpha2[fit_mask], + # im_fit[fit_mask], + # p0 = tuple(coefs), + # bounds = (lb,ub), + # )[0] + # ) + # else: + # coefs = np.array( + # curve_fit( + # calc_CTF_mag_masked, + # self.alpha2[fit_mask], + # im_fit[fit_mask], + # p0 = tuple(coefs), + # bounds = (lb,ub), + # maxfev = maxfev, + # )[0] + # ) # Plot the CTF comparison between experiment and fit if plot_CTF_compare: - im_test = np.reshape(calc_CTF_mag(self.alpha2.ravel(), *coefs), im_fit.shape) + # Generate FFT plotting image + int_range = (np.min(im_fft[plot_mask]),np.max(im_fft[plot_mask])) + int_range = (int_range[0],(int_range[1]-int_range[0])*0.5 + int_range[0]) + im_scale = np.clip( + (np.fft.fftshift(im_fft) - int_range[0]) / (int_range[1] - int_range[0]), + 0,1) + # im_scale = im_scale**0.5 + im_plot = np.tile(im_scale[:,:,None],(1,1,3)) + + # Add CTF zero crossings + im_CTF = calc_CTF(self.alpha,*coefs) + # im_CTF = np.sin(im_CTF)**2 + # im_CTF = np.fft.fftshift(im_CTF) + # print(np.max(im_CTF)) + im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold + im_CTF[np.logical_not(plot_mask)] = 0 + im_CTF = np.fft.fftshift(im_CTF * angular_mask) + im_plot[:,:,0] += im_CTF + im_plot[:,:,1] -= im_CTF + im_plot[:,:,2] -= im_CTF + im_plot = np.clip(im_plot,0,1) fig,ax = plt.subplots(figsize=(12,6)) ax.imshow( - np.hstack(( - np.fft.fftshift(im_fit), - np.fft.fftshift(im_test), - )), - vmin = np.min(im_test[fit_mask]), - vmax = np.max(im_test[fit_mask]), - cmap = 'gray', + im_plot, + # np.fft.fftshift(np.reshape(self.aber_basis[:,1],im_CTF.shape)) + # angular_mask, + # np.hstack(( + # im_scale, + # im_CTF + # )) + # im_ctf ) + # ax.imshow( # im_plot / np.max(im_plot), # vmin = 0, @@ -1602,6 +1617,26 @@ def calc_CTF_mag_masked(alpha2, *coefs): # ax.set_xlim([0, k_bins[-1]]) # ax.set_ylim([0, 1.05]) + + # Print results + if self._verbose: + print( + ( + "Rotation of Q w.r.t. R = " + f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg" + ) + ) + print( + ( + "Astigmatism (A1x,A1y) = (" + f"{self.aberration_A1x:.0f}," + f"{self.aberration_A1y:.0f}) Ang" + ) + ) + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + + def aberration_correct( self, plot_corrected_phase: bool = True, From ea070f85b8196828ff597f8479dad4fb910fb87c Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 15:52:13 -0700 Subject: [PATCH 088/176] It works! --- py4DSTEM/process/phase/iterative_parallax.py | 120 +++++++++++++------ 1 file changed, 85 insertions(+), 35 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 4bfd265f9..243a8ec5b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -18,7 +18,7 @@ from py4DSTEM.visualize import show from scipy.linalg import polar from scipy.special import comb -from scipy.optimize import curve_fit +from scipy.optimize import curve_fit, minimize from scipy.signal import medfilt2d try: @@ -1270,10 +1270,12 @@ def subpixel_alignment( def aberration_fit( self, fit_CTF_FFT = True, - fit_CTF_threshold = 0.1, + fit_CTF_threshold = 0.25, fit_upsampled_FFT = True, fit_aber_order_max = 2, - fit_maxfev = None, + fit_max_num_rings = 6, + fit_power_alpha = 2.0, + # fit_maxfev = None, plot_CTF_compare: bool = False, ): """ @@ -1301,7 +1303,7 @@ def aberration_fit( asnumpy = self._asnumpy gaussian_filter = self._gaussian_filter - # initial aberration fit + ### initial aberration fit ### # Convert real space shifts to Angstroms self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) @@ -1331,6 +1333,30 @@ def aberration_fit( xp._default_memory_pool.free_all_blocks() xp.clear_memo() + + ### FFT fitting / plotting code ### + + if fit_upsampled_FFT: + # Get mean FFT of BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor + else: + # Get mean FFT of upsampled BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + # FFT coordinates + qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) + qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha = np.sqrt(qr2) * self._wavelength + self.theta = np.arctan2(qy[None,:],qx[:,None]) + # Aberration coefs mn = [] for m in range(0,fit_aber_order_max//2+1): @@ -1347,7 +1373,7 @@ def aberration_fit( # Aberration basis self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha2.size,self.aber_num)) + self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) for a0 in range(self.aber_num): if self.aber_mn[a0,1] == 0: # Radially symmetric basis @@ -1370,40 +1396,56 @@ def calc_CTF(alpha, *coefs): chi += coefs[a0] * self.aber_basis[:,a0] return np.reshape(chi, alpha.shape) - if fit_upsampled_FFT: - # Get mean FFT of BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor - else: - # Get mean FFT of upsampled BF reconstruction - im_fft = np.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) - - # FFT coordinates - qx = np.fft.fftfreq(im_fft.shape[0],q_pixel_size[0]) - qy = np.fft.fftfreq(im_fft.shape[1],q_pixel_size[1]) - qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha = np.sqrt(qr2) * self._wavelength - self.theta = np.arctan2(qy[None,:],qx[:,None]) - # initial coefficients and plotting intensity range mask - C1_dimensionless = self.aberration_C1 * 0.5 * self._wavelength + C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength coefs = np.zeros(self.aber_num) ind = np.argmin( np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) ) - coefs[ind] = C1_dimensionless - plot_mask = self.alpha > np.sqrt(np.pi/np.abs(C1_dimensionless)) - angular_mask = np.cos(4.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.5 + coefs[ind] = C10_dimensionless + plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) + # plot_mask[:] = True + angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: - pass + # scoring function to minimize - mean value of zero crossing regions of FFT + def score_CTF(coefs): + im_CTF = np.abs(calc_CTF(self.alpha,*coefs)) + mask = np.logical_and( + im_CTF > 0.5*np.pi, + im_CTF < (max_num_rings+0.5)*np.pi, + ) + if np.any(mask): + weights = np.cos(im_CTF[mask])**4 + return np.sum(weights*im_FFT[mask]*self.alpha[mask]**fit_power_alpha) / np.sum(weights) + else: + return np.inf + + for max_num_rings in range(1,fit_max_num_rings+1): + # minimization + res = minimize( + score_CTF, + coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method = 'BFGS', + tol = 1e-8, + ) + coefs = res.x + + # basis = np.vstack(( + # self.alpha.ravel(), + # im_FFT.ravel() + # )) + # print(basis.shape) + # score = score_CTF(self.alpha,coefs*1) + # print(score) + + + # im_CTF = calc_CTF(self.alpha,*coefs) + # im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; + # im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold @@ -1495,20 +1537,28 @@ def calc_CTF(alpha, *coefs): # Plot the CTF comparison between experiment and fit if plot_CTF_compare: # Generate FFT plotting image - int_range = (np.min(im_fft[plot_mask]),np.max(im_fft[plot_mask])) - int_range = (int_range[0],(int_range[1]-int_range[0])*0.5 + int_range[0]) + im_scale = im_FFT * self.alpha**fit_power_alpha + # int_range = (np.min(im_scale[plot_mask]),np.max(im_scale[plot_mask])) + int_vals = np.sort(im_scale.ravel()) + int_range = ( + int_vals[np.round(0.02*im_scale.size).astype('int')], + int_vals[np.round(0.98*im_scale.size).astype('int')], + ) + + int_range = (int_range[0],(int_range[1]-int_range[0])*1.0 + int_range[0]) im_scale = np.clip( - (np.fft.fftshift(im_fft) - int_range[0]) / (int_range[1] - int_range[0]), + (np.fft.fftshift(im_scale) - int_range[0]) / (int_range[1] - int_range[0]), 0,1) # im_scale = im_scale**0.5 im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings im_CTF = calc_CTF(self.alpha,*coefs) + im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; # im_CTF = np.sin(im_CTF)**2 # im_CTF = np.fft.fftshift(im_CTF) # print(np.max(im_CTF)) - im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold + im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold im_CTF[np.logical_not(plot_mask)] = 0 im_CTF = np.fft.fftshift(im_CTF * angular_mask) im_plot[:,:,0] += im_CTF From b2cbede265a97cf20b3c36d2e8d3edb9936b87b9 Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 16:27:00 -0700 Subject: [PATCH 089/176] Updating outputs --- py4DSTEM/process/phase/iterative_parallax.py | 404 ++++++------------- 1 file changed, 121 insertions(+), 283 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 243a8ec5b..d4ed2a80a 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1335,77 +1335,85 @@ def aberration_fit( ### FFT fitting / plotting code ### - - if fit_upsampled_FFT: - # Get mean FFT of BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) - - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor - else: - # Get mean FFT of upsampled BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) - - # FFT coordinates - qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) - qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) - qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha = np.sqrt(qr2) * self._wavelength - self.theta = np.arctan2(qy[None,:],qx[:,None]) - - # Aberration coefs - mn = [] - for m in range(0,fit_aber_order_max//2+1): - n_max = np.floor(fit_aber_order_max-2*m).astype('int') - - for n in range(0,n_max+1): - if m + n > 1 or (m > 0 and n == 0): - if n == 0: - mn.append([m,n,0]) - else: - mn.append([m,n,0]) - mn.append([m,n,1]) - self.aber_mn = np.array(mn) - - # Aberration basis - self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) - for a0 in range(self.aber_num): - if self.aber_mn[a0,1] == 0: - # Radially symmetric basis - self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) - elif self.aber_mn[a0,2] == 0: - # cos coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + if fit_CTF_FFT or plot_CTF_compare: + if fit_upsampled_FFT: + # Get mean FFT of BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + + # coordinates + q_pixel_size = np.array(self._reciprocal_sampling) \ + / self._kde_upsample_factor else: - # sin coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.sin(self.aber_mn[a0,1] * self.theta.ravel()) - - # CTF function - def calc_CTF(alpha, *coefs): - chi = np.zeros_like(alpha.ravel()) - for a0 in range(len(coefs)): - chi += coefs[a0] * self.aber_basis[:,a0] - return np.reshape(chi, alpha.shape) - - # initial coefficients and plotting intensity range mask - C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength - coefs = np.zeros(self.aber_num) - ind = np.argmin( - np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - ) - coefs[ind] = C10_dimensionless - plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) - # plot_mask[:] = True - angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 + # Get mean FFT of upsampled BF reconstruction + im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) + + # coordinates + q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + + # FFT coordinates + qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) + qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) + qr2 = qx[:,None]**2 + qy[None,:]**2 + self.alpha = np.sqrt(qr2) * self._wavelength + self.theta = np.arctan2(qy[None,:],qx[:,None]) + + # Aberration coefs + mn = [] + for m in range(0,fit_aber_order_max//2+1): + n_max = np.floor(fit_aber_order_max-2*m).astype('int') + + for n in range(0,n_max+1): + if m + n > 1 or (m > 0 and n == 0): + if n == 0: + mn.append([m,n,0]) + else: + mn.append([m,n,0]) + mn.append([m,n,1]) + self.aber_mn = np.array(mn) + self.aber_mn = self.aber_mn[np.argsort(self.aber_mn[:,1]),:] + # self.aber_mn = self.aber_mn[np.lexsort(( + # self.aber_mn[:,0], + # self.aber_mn[:,2], + # self.aber_mn[:,1], + # ))] + sub = self.aber_mn[:,1] > 0 + self.aber_mn[sub,:] = self.aber_mn[sub,:][np.argsort(self.aber_mn[sub,0]),:] + + # Aberration basis + self.aber_num = self.aber_mn.shape[0] + self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) + for a0 in range(self.aber_num): + if self.aber_mn[a0,1] == 0: + # Radially symmetric basis + self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) + elif self.aber_mn[a0,2] == 0: + # cos coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + else: + # sin coef + self.aber_basis[:,a0] = \ + self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ + np.sin(self.aber_mn[a0,1] * self.theta.ravel()) + + # CTF function + def calc_CTF(alpha, *coefs): + chi = np.zeros_like(alpha.ravel()) + for a0 in range(len(coefs)): + chi += coefs[a0] * self.aber_basis[:,a0] + return np.reshape(chi, alpha.shape) + + # initial coefficients and plotting intensity range mask + C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength + coefs = np.zeros(self.aber_num) + ind = np.argmin( + np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) + ) + coefs[ind] = C10_dimensionless + plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) + # plot_mask[:] = True + angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: @@ -1422,142 +1430,36 @@ def score_CTF(coefs): else: return np.inf - for max_num_rings in range(1,fit_max_num_rings+1): - # minimization - res = minimize( - score_CTF, - coefs, - # method = 'Nelder-Mead', - # method = 'CG', - method = 'BFGS', - tol = 1e-8, - ) - coefs = res.x - - # basis = np.vstack(( - # self.alpha.ravel(), - # im_FFT.ravel() - # )) - # print(basis.shape) - # score = score_CTF(self.alpha,coefs*1) - # print(score) - - - # im_CTF = calc_CTF(self.alpha,*coefs) - # im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; - # im_CTF = np.sin(im_CTF)**2 < fit_CTF_threshold - - - - - - # fitting image - # im_fit = im_fft * self.q_weight - # if medfilt_size is not None: - # im_fit = np.fft.ifftshift(medfilt2d( - # np.fft.fftshift(im_fit), - # medfilt_size)) - - # # initial coefs - # int_max = np.max(im_fit) - # sigma_init = np.sqrt(np.max(self.alpha2) / 8.0) - # coefs = np.zeros(5 + self.aber_num) - # lb = np.zeros(5 + self.aber_num) - # ub = np.ones(5 + self.aber_num) * np.inf - # coefs[0] = 1e-3 - # coefs[1] = int_max * 0.1 - # coefs[2] = sigma_init - # coefs[3] = int_max * 0.9 - # coefs[4] = sigma_init - # lb[5:] = -np.inf - # # initial C1 value (defocus) - # ind = np.argmin( - # np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - # ) - # C1_dimensionless = self.aberration_C1 / np.pi * self._wavelength - # coefs[ind + 5] = C1_dimensionless - - # Fitting mask - # fit_mask = self.alpha2 > np.sqrt(np.pi/2/np.abs(C1_dimensionless)) - # basis_masked = self.aber_basis[fit_mask.ravel(),:] - - # Define fitting functions - - # def calc_CTF_mag(alpha2, *coefs): - # int0 = coefs[0] - # int1 = coefs[1] - # sigma1 = coefs[2] - # int_env = coefs[3] - # sigma_env = coefs[4] - - # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - # chi = np.zeros_like(im_CTF_mag) - # for a0 in range(5,len(coefs)): - # chi += coefs[a0] * self.aber_basis[:,a0-5] - # return im_CTF_mag + np.abs(np.sin(chi)) * env - - # def calc_CTF_mag_masked(alpha2, *coefs): - # int0 = coefs[0] - # int1 = coefs[1] - # sigma1 = coefs[2] - # int_env = coefs[3] - # sigma_env = coefs[4] - - # im_CTF_mag = int0 + int1 * np.exp(alpha2/(-2.0*sigma1**2)) - # env = int_env * np.exp(alpha2/(-2.0*sigma_env**2)) - # chi = np.zeros_like(im_CTF_mag) - # for a0 in range(5,len(coefs)): - # chi += coefs[a0] * basis_masked[:,a0-5] - # return im_CTF_mag + np.abs(np.sin(chi)) * env - - # # Refine aberration coefficients - # if maxfev is None: - # coefs = np.array( - # curve_fit( - # calc_CTF_mag_masked, - # self.alpha2[fit_mask], - # im_fit[fit_mask], - # p0 = tuple(coefs), - # bounds = (lb,ub), - # )[0] - # ) - # else: - # coefs = np.array( - # curve_fit( - # calc_CTF_mag_masked, - # self.alpha2[fit_mask], - # im_fit[fit_mask], - # p0 = tuple(coefs), - # bounds = (lb,ub), - # maxfev = maxfev, - # )[0] + # for max_num_rings in range(1,fit_max_num_rings+1): + # # minimization + # res = minimize( + # score_CTF, + # coefs, + # # method = 'Nelder-Mead', + # # method = 'CG', + # method = 'BFGS', + # tol = 1e-8, # ) + # coefs = res.x # Plot the CTF comparison between experiment and fit if plot_CTF_compare: # Generate FFT plotting image im_scale = im_FFT * self.alpha**fit_power_alpha - # int_range = (np.min(im_scale[plot_mask]),np.max(im_scale[plot_mask])) int_vals = np.sort(im_scale.ravel()) int_range = ( int_vals[np.round(0.02*im_scale.size).astype('int')], int_vals[np.round(0.98*im_scale.size).astype('int')], ) - int_range = (int_range[0],(int_range[1]-int_range[0])*1.0 + int_range[0]) im_scale = np.clip( (np.fft.fftshift(im_scale) - int_range[0]) / (int_range[1] - int_range[0]), 0,1) - # im_scale = im_scale**0.5 im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings im_CTF = calc_CTF(self.alpha,*coefs) im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; - # im_CTF = np.sin(im_CTF)**2 - # im_CTF = np.fft.fftshift(im_CTF) - # print(np.max(im_CTF)) im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold im_CTF[np.logical_not(plot_mask)] = 0 im_CTF = np.fft.fftshift(im_CTF * angular_mask) @@ -1569,107 +1471,13 @@ def score_CTF(coefs): fig,ax = plt.subplots(figsize=(12,6)) ax.imshow( im_plot, - # np.fft.fftshift(np.reshape(self.aber_basis[:,1],im_CTF.shape)) - # angular_mask, - # np.hstack(( - # im_scale, - # im_CTF - # )) - # im_ctf ) - - # ax.imshow( - # im_plot / np.max(im_plot), - # vmin = 0, - # vmax = 1, - # cmap = 'gray', - # ) - # ax.imshow( \ - # np.fft.fftshift( - # np.mod(np.reshape(self.aber_basis[:,2],im_fft.shape)+np.pi,2*np.pi)-np.pi - # )) - - - # # Get polar mean from FFT of BF reconstruction - # im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # # coordinates - # kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - # ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - # kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - # k_max = xp.max(kra) / np.sqrt(2.0) - # k_num_bins = int(xp.ceil(k_max / plot_dk)) - # k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # # histogram - # k_ind = kra / plot_dk - # kf = np.floor(k_ind).astype("int") - # dk = k_ind - kf - # sub = kf <= k_num_bins - # hist_exp = xp.bincount( - # kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins - # ) - # hist_norm = xp.bincount( - # kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins - # ) - # sub = kf <= k_num_bins - 1 - - # hist_exp += xp.bincount( - # kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins - # ) - # hist_norm += xp.bincount( - # kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins - # ) - - # # KDE and normalizing - # k_sigma = plot_dk / plot_k_sigma - # hist_exp[0] = 0.0 - # hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - # hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - # hist_exp /= hist_norm - - # # CTF comparison - # CTF_fit = xp.sin( - # (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 - # ) - - # # plotting input - log scale - # min_hist_val = xp.max(hist_exp) * 1e-3 - # hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - # hist_plot -= xp.min(hist_plot) - # hist_plot /= xp.max(hist_plot) - - # hist_plot = asnumpy(hist_plot) - # k_bins = asnumpy(k_bins) - # CTF_fit = asnumpy(CTF_fit) - - # fig, ax = plt.subplots(figsize=(8, 4)) - - # ax.fill_between( - # k_bins, - # hist_plot, - # color=(0.7, 0.7, 0.7, 1), - # ) - - # ax.plot( - # k_bins, - # np.clip(CTF_fit, 0.0, np.inf), - # color=(1, 0, 0, 1), - # linewidth=2, - # ) - # ax.plot( - # k_bins, - # np.clip(-CTF_fit, 0.0, np.inf), - # color=(0, 0.5, 1, 1), - # linewidth=2, - # ) - # ax.set_xlim([0, k_bins[-1]]) - # ax.set_ylim([0, 1.05]) - - # Print results if self._verbose: + if fit_CTF_FFT: + print('Initial Aberration coefficients') + print('-------------------------------') print( ( "Rotation of Q w.r.t. R = " @@ -1686,6 +1494,36 @@ def score_CTF(coefs): print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + if fit_CTF_FFT: + # radial_order = 2 * self.aber_mn[a0,0] + + + print() + print('Refined Aberration coefficients') + print('-------------------------------') + print('radial annular dir. coefs') + print('order order ') + print('------ ------- ---- -----') + + for a0 in range(self.aber_mn.shape[0]): + if self.aber_mn[a0,1] == 0: + print( + str(self.aber_mn[a0,0]) + \ + ' 0 - ' + \ + str(np.round(coefs[a0]).astype('int')) ) + elif self.aber_mn[a0,2] == 0: + print( + str(self.aber_mn[a0,0]) + \ + ' ' + \ + str(self.aber_mn[a0,1]) + \ + ' x ' + \ + str(np.round(coefs[a0]).astype('int')) ) + else: + print( + str(self.aber_mn[a0,0]) + \ + ' ' + \ + str(self.aber_mn[a0,1]) + \ + ' y ' + \ + str(np.round(coefs[a0]).astype('int')) ) def aberration_correct( self, From 270551459904b4ccb23c07bf003dd0c1b99fec3c Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 17 Oct 2023 16:44:47 -0700 Subject: [PATCH 090/176] Adding outputs, plotting --- py4DSTEM/process/phase/iterative_parallax.py | 107 +++++++++++-------- 1 file changed, 63 insertions(+), 44 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index d4ed2a80a..25bca20e8 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1403,14 +1403,15 @@ def calc_CTF(alpha, *coefs): for a0 in range(len(coefs)): chi += coefs[a0] * self.aber_basis[:,a0] return np.reshape(chi, alpha.shape) + self.calc_CTF = calc_CTF # initial coefficients and plotting intensity range mask C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength - coefs = np.zeros(self.aber_num) + self.aber_coefs = np.zeros(self.aber_num) ind = np.argmin( np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) ) - coefs[ind] = C10_dimensionless + self.aber_coefs[ind] = C10_dimensionless plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) # plot_mask[:] = True angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 @@ -1430,17 +1431,17 @@ def score_CTF(coefs): else: return np.inf - # for max_num_rings in range(1,fit_max_num_rings+1): - # # minimization - # res = minimize( - # score_CTF, - # coefs, - # # method = 'Nelder-Mead', - # # method = 'CG', - # method = 'BFGS', - # tol = 1e-8, - # ) - # coefs = res.x + for max_num_rings in range(1,fit_max_num_rings+1): + # minimization + res = minimize( + score_CTF, + self.aber_coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method = 'BFGS', + tol = 1e-8, + ) + self.aber_coefs = res.x # Plot the CTF comparison between experiment and fit if plot_CTF_compare: @@ -1458,7 +1459,7 @@ def score_CTF(coefs): im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings - im_CTF = calc_CTF(self.alpha,*coefs) + im_CTF = calc_CTF(self.alpha,*self.aber_coefs) im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold im_CTF[np.logical_not(plot_mask)] = 0 @@ -1495,7 +1496,7 @@ def score_CTF(coefs): print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") if fit_CTF_FFT: - # radial_order = 2 * self.aber_mn[a0,0] + + radial_order = 2 * self.aber_mn[:,0] + self.aber_mn[:,1] print() print('Refined Aberration coefficients') @@ -1507,26 +1508,27 @@ def score_CTF(coefs): for a0 in range(self.aber_mn.shape[0]): if self.aber_mn[a0,1] == 0: print( - str(self.aber_mn[a0,0]) + \ + str(radial_order[a0]) + \ ' 0 - ' + \ - str(np.round(coefs[a0]).astype('int')) ) + str(np.round(self.aber_coefs[a0]).astype('int')) ) elif self.aber_mn[a0,2] == 0: print( - str(self.aber_mn[a0,0]) + \ + str(radial_order[a0]) + \ ' ' + \ str(self.aber_mn[a0,1]) + \ ' x ' + \ - str(np.round(coefs[a0]).astype('int')) ) + str(np.round(self.aber_coefs[a0]).astype('int')) ) else: print( - str(self.aber_mn[a0,0]) + \ + str(radial_order[a0]) + \ ' ' + \ str(self.aber_mn[a0,1]) + \ ' y ' + \ - str(np.round(coefs[a0]).astype('int')) ) + str(np.round(self.aber_coefs[a0]).astype('int')) ) def aberration_correct( self, + use_FFT_fit = True, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, @@ -1541,6 +1543,8 @@ def aberration_correct( Parameters ---------- + use_FFT_fit: bool + Use the CTF fitted to the zero crossings of the FFT. plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional @@ -1581,30 +1585,9 @@ def aberration_correct( ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) - - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio - ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(im) * CTF_corr + if use_FFT_fit: + sin_chi = np.sin(self.calc_CTF(self.alpha,*self.aber_coefs)) - else: - # CTF without tilt correction (beyond the parallax operator) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 @@ -1616,6 +1599,42 @@ def aberration_correct( im_fft_corr /= 1 + (kra2**k_info_power) / ( (k_info_limit) ** (2 * k_info_power) ) + else: + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) From 7eae9484b4e46bf6ce0bc2af5519aef4324fff9e Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 19 Oct 2023 10:19:02 -0700 Subject: [PATCH 091/176] finally works --- py4DSTEM/process/phase/iterative_parallax.py | 386 +++++++++++++------ 1 file changed, 264 insertions(+), 122 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 25bca20e8..ef6ae9f5b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -13,6 +13,7 @@ from py4DSTEM import DataCube from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom from py4DSTEM.visualize import show @@ -1236,10 +1237,10 @@ def subpixel_alignment( ) reciprocal_extent = [ - -self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, - self._reciprocal_sampling[1] * cropped_object_aligned.shape[1] / 2, - self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, - -self._reciprocal_sampling[0] * cropped_object_aligned.shape[0] / 2, + -0.5/(self._scan_sampling[1]/self._kde_upsample_factor), + 0.5/(self._scan_sampling[1]/self._kde_upsample_factor), + 0.5/(self._scan_sampling[0]/self._kde_upsample_factor), + -0.5/(self._scan_sampling[0]/self._kde_upsample_factor), ] show( @@ -1269,22 +1270,23 @@ def subpixel_alignment( def aberration_fit( self, - fit_CTF_FFT = True, - fit_CTF_threshold = 0.25, - fit_upsampled_FFT = True, - fit_aber_order_max = 2, - fit_max_num_rings = 6, - fit_power_alpha = 2.0, - # fit_maxfev = None, - plot_CTF_compare: bool = False, + fit_BF_shifts:bool = True, + fit_CTF_FFT:bool = False, + fit_aberrations_max_radial_order:int=3, + fit_aberrations_max_angular_order:int=4, + fit_aberrations_min_radial_order:int=1, + fit_aberrations_min_angular_order:int=0, + fit_max_thon_rings:int = 6, + fit_power_alpha:float = 2.0, + plot_CTF_comparison: bool = None, + plot_BF_shifts_comparison: bool = None, + upsampled:bool=True, ): """ Fit aberrations to the measured image shifts. Parameters ---------- - plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies. fit_CTF_FFT: bool Set to True to directly fit aberrations in the FFT of the upsampled BF image (if available). Note that this method relies on visible zero @@ -1296,14 +1298,14 @@ def aberration_fit( Max radial order for fitting of aberrations. ctf_threshold: float CTF fitting minimizes value at CTF zero crossings (Thon ring minima). - + plot_CTF_compare: bool, optional + If True, the fitted CTF is plotted against the reconstructed frequencies. """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter - ### initial aberration fit ### + ### First pass # Convert real space shifts to Angstroms self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) @@ -1326,127 +1328,210 @@ def aberration_fit( self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 self.aberration_A1x = ( m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 # factor /2 for A1 astigmatism? /4? + ) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + ### Second pass + # Aberration coefs + mn = [] - ### FFT fitting / plotting code ### - if fit_CTF_FFT or plot_CTF_compare: - if fit_upsampled_FFT: - # Get mean FFT of BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + for m in range(fit_aberrations_min_radial_order,fit_aberrations_max_radial_order+1): + n_max = np.minimum(fit_aberrations_max_angular_order,m+1) + for n in range(fit_aberrations_min_angular_order,n_max+1): + if (m+n) % 2: + mn.append([m,n,0]) + if n > 0: + mn.append([m,n,1]) + + self._aberrations_mn = np.array(mn) + self._aberrations_mn = self._aberrations_mn[np.argsort(self._aberrations_mn[:,1]),:] + + sub = self._aberrations_mn[:,1] > 0 + self._aberrations_mn[sub,:] = self._aberrations_mn[sub,:][np.argsort(self._aberrations_mn[sub,0]),:] + self._aberrations_num = self._aberrations_mn.shape[0] + + if plot_CTF_comparison is None: + if fit_CTF_FFT: + plot_CTF_comparison = True + + if plot_BF_shifts_comparison is None: + if fit_BF_shifts: + plot_BF_shifts_comparison = True + + # Thon Rings Fitting + if fit_CTF_FFT or plot_CTF_comparison: + if upsampled and hasattr(self,"_kde_upsample_factor"): + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor - # coordinates - q_pixel_size = np.array(self._reciprocal_sampling) \ - / self._kde_upsample_factor else: - # Get mean FFT of upsampled BF reconstruction - im_FFT = np.abs(xp.fft.fft2(self._recon_BF)) + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + upsampled=False - # coordinates - q_pixel_size = np.array(parallax_recon._reciprocal_sampling) + # FFT coordinates + qx = xp.fft.fftfreq(im_FFT.shape[0],sx) + qy = xp.fft.fftfreq(im_FFT.shape[1],sy) + qr2 = qx[:,None]**2 + qy[None,:]**2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None,:],qx[:,None]) + + # Aberration basis + self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) + for a0 in range(self._aberrations_num): + m,n,a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() + elif a == 0: + # cos coef + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() + else: + # sin coef + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() + + # global scaling + self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_surface_shape = alpha.shape + plot_mask = qr2 > np.pi**2/4/np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta)**2 < 0.25 + + # Direct Shifts Fitting + elif fit_BF_shifts: + # FFT coordinates - qx = np.fft.fftfreq(im_FFT.shape[0],1/im_FFT.shape[0]/q_pixel_size[0]) - qy = np.fft.fftfreq(im_FFT.shape[1],1/im_FFT.shape[1]/q_pixel_size[1]) + sx = 1/(self._reciprocal_sampling[0]*self._region_of_interest_shape[0]) + sy = 1/(self._reciprocal_sampling[1]*self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0],sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1],sy) qr2 = qx[:,None]**2 + qy[None,:]**2 - self.alpha = np.sqrt(qr2) * self._wavelength - self.theta = np.arctan2(qy[None,:],qx[:,None]) - - # Aberration coefs - mn = [] - for m in range(0,fit_aber_order_max//2+1): - n_max = np.floor(fit_aber_order_max-2*m).astype('int') - - for n in range(0,n_max+1): - if m + n > 1 or (m > 0 and n == 0): - if n == 0: - mn.append([m,n,0]) - else: - mn.append([m,n,0]) - mn.append([m,n,1]) - self.aber_mn = np.array(mn) - self.aber_mn = self.aber_mn[np.argsort(self.aber_mn[:,1]),:] - # self.aber_mn = self.aber_mn[np.lexsort(( - # self.aber_mn[:,0], - # self.aber_mn[:,2], - # self.aber_mn[:,1], - # ))] - sub = self.aber_mn[:,1] > 0 - self.aber_mn[sub,:] = self.aber_mn[sub,:][np.argsort(self.aber_mn[sub,0]),:] + + u = qx[:,None]*self._wavelength + v = qy[None,:]*self._wavelength + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None,:],qx[:,None]) # Aberration basis - self.aber_num = self.aber_mn.shape[0] - self.aber_basis = np.zeros((self.alpha.size,self.aber_num)) - for a0 in range(self.aber_num): - if self.aber_mn[a0,1] == 0: + self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size,self._aberrations_num)) + for a0 in range(self._aberrations_num): + m,n,a = self._aberrations_mn[a0] + + if n == 0: # Radially symmetric basis - self.aber_basis[:,a0] = self.alpha.ravel()**(2*self.aber_mn[a0,0]) - elif self.aber_mn[a0,2] == 0: + self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() + self._aberrations_basis_du[:,a0] = (u*alpha**(m-1)).ravel() + self._aberrations_basis_dv[:,a0] = (v*alpha**(m-1)).ravel() + + elif a == 0: # cos coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.cos(self.aber_mn[a0,1] * self.theta.ravel()) + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() + self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.cos(n*theta) + n*v*xp.sin(n*theta))/(m+1)).ravel() + self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.cos(n*theta) - n*u*xp.sin(n*theta))/(m+1)).ravel() + else: # sin coef - self.aber_basis[:,a0] = \ - self.alpha.ravel()**(2.0*self.aber_mn[a0,0] + self.aber_mn[a0,1]) * \ - np.sin(self.aber_mn[a0,1] * self.theta.ravel()) - - # CTF function - def calc_CTF(alpha, *coefs): - chi = np.zeros_like(alpha.ravel()) - for a0 in range(len(coefs)): - chi += coefs[a0] * self.aber_basis[:,a0] - return np.reshape(chi, alpha.shape) - self.calc_CTF = calc_CTF - - # initial coefficients and plotting intensity range mask - C10_dimensionless = self.aberration_C1 * np.pi / 4 / self._wavelength - self.aber_coefs = np.zeros(self.aber_num) - ind = np.argmin( - np.abs(self.aber_mn[:,0] - 1.0) + np.abs(self.aber_mn[:,1]) - ) - self.aber_coefs[ind] = C10_dimensionless - plot_mask = self.alpha > np.sqrt(np.pi/2/np.abs(C10_dimensionless)) - # plot_mask[:] = True - angular_mask = np.cos(8.0 * np.arctan2(qy[:,None],qx[None,:]))**2 < 0.25 + self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() + self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.sin(n*theta) - n*v*xp.cos(n*theta))/(m+1)).ravel() + self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.sin(n*theta) + n*u*xp.cos(n*theta))/(m+1)).ravel() + + # global scaling + self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_surface_shape = alpha.shape + + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:,0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:,a0] + return xp.reshape(chi, alpha_shape) + self._calculate_CTF = calculate_CTF + + # initial coefficients and plotting intensity range mask + self._aberrations_coefs = np.zeros(self._aberrations_num) + ind = np.argmin( + np.abs(self._aberrations_mn[:,0] - 1.0) + self._aberrations_mn[:,1] + ) + self._aberrations_coefs[ind] = self.aberration_C1 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: + # scoring function to minimize - mean value of zero crossing regions of FFT def score_CTF(coefs): - im_CTF = np.abs(calc_CTF(self.alpha,*coefs)) - mask = np.logical_and( + im_CTF = xp.abs(self._calculate_CTF(self._aberrations_surface_shape,*coefs)) + mask = xp.logical_and( im_CTF > 0.5*np.pi, im_CTF < (max_num_rings+0.5)*np.pi, ) if np.any(mask): - weights = np.cos(im_CTF[mask])**4 - return np.sum(weights*im_FFT[mask]*self.alpha[mask]**fit_power_alpha) / np.sum(weights) + weights = xp.cos(im_CTF[mask])**4 + return asnumpy(xp.sum(weights*im_FFT[mask]*alpha[mask]**fit_power_alpha) / xp.sum(weights)) else: return np.inf - for max_num_rings in range(1,fit_max_num_rings+1): + for max_num_rings in range(1,fit_max_thon_rings+1): # minimization res = minimize( score_CTF, - self.aber_coefs, + self._aberrations_coefs, # method = 'Nelder-Mead', # method = 'CG', method = 'BFGS', tol = 1e-8, ) - self.aber_coefs = res.x + self._aberrations_coefs = res.x + + # Refinement using CTF fitting / Thon rings + elif fit_BF_shifts: + + # Gradient basis + corner_indices = self._xy_inds-xp.asarray(self._region_of_interest_shape//2) + raveled_indices = np.ravel_multi_index(corner_indices.T,self._region_of_interest_shape,mode='wrap') + gradients = xp.vstack(( + self._aberrations_basis_du[raveled_indices,:], + self._aberrations_basis_dv[raveled_indices,:] + )) + + # Untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang,xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq(gradients, rotated_shifts,rcond=None)[:2] + + # Transposed fit + transposed_shifts = xp.flip(self._xy_shifts_Ang,axis=1) + m_T = asnumpy(xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0]) + m_rotation_T, _ = polar(m_T, side="right") + rotation_Q_to_R_rads_T = -1 * np.arctan2(m_rotation_T[1, 0], m_rotation_T[0, 0]) + if np.abs(np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi) > ( + np.pi * 0.5 + ): + rotation_Q_to_R_rads_T = ( + np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + ) + + tf = AffineTransform(angle=rotation_Q_to_R_rads_T) + rotated_shifts_T = tf(transposed_shifts,xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq(gradients, rotated_shifts_T,rcond=None)[:2] + if res_T.sum() < res.sum(): + self._aberrations_coefs = asnumpy(aberrations_coefs_T) + self._rotated_shifts = rotated_shifts_T + else: + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts + + # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: + if plot_CTF_comparison: # Generate FFT plotting image - im_scale = im_FFT * self.alpha**fit_power_alpha + im_scale = asnumpy(im_FFT * alpha**fit_power_alpha) int_vals = np.sort(im_scale.ravel()) int_range = ( int_vals[np.round(0.02*im_scale.size).astype('int')], @@ -1459,24 +1544,77 @@ def score_CTF(coefs): im_plot = np.tile(im_scale[:,:,None],(1,1,3)) # Add CTF zero crossings - im_CTF = calc_CTF(self.alpha,*self.aber_coefs) - im_CTF[np.abs(im_CTF) > (fit_max_num_rings+0.5)*np.pi] = np.pi/2; - im_CTF = np.abs(np.sin(im_CTF)) < fit_CTF_threshold - im_CTF[np.logical_not(plot_mask)] = 0 - im_CTF = np.fft.fftshift(im_CTF * angular_mask) + im_CTF = self._calculate_CTF(self._aberrations_surface_shape,*self._aberrations_coefs) + im_CTF_cos = xp.cos(xp.abs(im_CTF))**4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings+0.5)*np.pi] = np.pi/2 + im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 + im_CTF[xp.logical_not(plot_mask)] = 0 + + im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) im_plot[:,:,0] += im_CTF im_plot[:,:,1] -= im_CTF im_plot[:,:,2] -= im_CTF im_plot = np.clip(im_plot,0,1) - fig,ax = plt.subplots(figsize=(12,6)) - ax.imshow( + fig,(ax1,ax2) = plt.subplots(1,2,figsize=(12,6)) + ax1.imshow( im_plot, + vmin=int_range[0], + vmax=int_range[1] + ) + + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_cos)), + cmap='gray' + ) + + fig.tight_layout() + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + + if not fit_BF_shifts: + raise ValueError() + + measured_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + measured_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[:self._xy_inds.shape[0]] + + measured_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + measured_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[self._xy_inds.shape[0]:] + + fitted_shifts = xp.tensordot(gradients,xp.array(self._aberrations_coefs),axes=1) + + fitted_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + fitted_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[:self._xy_inds.shape[0]] + + fitted_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) + fitted_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[self._xy_inds.shape[0]:] + + max_shift = xp.max(xp.array([xp.abs(measured_shifts_sx).max(),xp.abs(measured_shifts_sy).max(),xp.abs(fitted_shifts_sx).max(),xp.abs(fitted_shifts_sy).max()])) + + show( + [ + [ + asnumpy(measured_shifts_sx), + asnumpy(measured_shifts_sy) + ], + [ + asnumpy(fitted_shifts_sx), + asnumpy(fitted_shifts_sy) + ], + ], + cmap='PiYG', + vmin=-max_shift, + vmax=max_shift, + intensity_range='absolute', + axsize=(4,4), + ticks=False, + title=["Measured Vertical Shifts","Measured Horizontal Shifts","Fitted Vertical Shifts","Fitted Horizontal Shifts"] ) # Print results if self._verbose: - if fit_CTF_FFT: + if fit_CTF_FFT or fit_BF_shifts: print('Initial Aberration coefficients') print('-------------------------------') print( @@ -1495,36 +1633,40 @@ def score_CTF(coefs): print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - if fit_CTF_FFT: - radial_order = 2 * self.aber_mn[:,0] + self.aber_mn[:,1] + if fit_CTF_FFT or fit_BF_shifts: print() print('Refined Aberration coefficients') print('-------------------------------') - print('radial annular dir. coefs') + print('radial angular dir. coefs') print('order order ') print('------ ------- ---- -----') - for a0 in range(self.aber_mn.shape[0]): - if self.aber_mn[a0,1] == 0: + for a0 in range(self._aberrations_mn.shape[0]): + m, n, a = self._aberrations_mn[a0] + if n == 0: print( - str(radial_order[a0]) + \ + str(m) + \ ' 0 - ' + \ - str(np.round(self.aber_coefs[a0]).astype('int')) ) - elif self.aber_mn[a0,2] == 0: + str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + elif a == 0: print( - str(radial_order[a0]) + \ + str(m) + \ ' ' + \ - str(self.aber_mn[a0,1]) + \ + str(n) + \ ' x ' + \ - str(np.round(self.aber_coefs[a0]).astype('int')) ) + str(np.round(self._aberrations_coefs[a0]).astype('int')) ) else: print( - str(radial_order[a0]) + \ + str(m) + \ ' ' + \ - str(self.aber_mn[a0,1]) + \ + str(n) + \ ' y ' + \ - str(np.round(self.aber_coefs[a0]).astype('int')) ) + str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() def aberration_correct( self, @@ -1909,7 +2051,7 @@ def _visualize_figax( **kwargs, ) - def _visualize_shifts( + def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, From d1f6efb6d4bce315a7596d4f4f5e6a21ad26845f Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 19 Oct 2023 10:31:37 -0700 Subject: [PATCH 092/176] some support for aberration correct --- py4DSTEM/process/phase/iterative_parallax.py | 442 +++++++++++-------- 1 file changed, 264 insertions(+), 178 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index ef6ae9f5b..5f7e1c25e 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1237,10 +1237,10 @@ def subpixel_alignment( ) reciprocal_extent = [ - -0.5/(self._scan_sampling[1]/self._kde_upsample_factor), - 0.5/(self._scan_sampling[1]/self._kde_upsample_factor), - 0.5/(self._scan_sampling[0]/self._kde_upsample_factor), - -0.5/(self._scan_sampling[0]/self._kde_upsample_factor), + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), ] show( @@ -1270,17 +1270,17 @@ def subpixel_alignment( def aberration_fit( self, - fit_BF_shifts:bool = True, - fit_CTF_FFT:bool = False, - fit_aberrations_max_radial_order:int=3, - fit_aberrations_max_angular_order:int=4, - fit_aberrations_min_radial_order:int=1, - fit_aberrations_min_angular_order:int=0, - fit_max_thon_rings:int = 6, - fit_power_alpha:float = 2.0, + fit_BF_shifts: bool = True, + fit_CTF_FFT: bool = False, + fit_aberrations_max_radial_order: int = 3, + fit_aberrations_max_angular_order: int = 4, + fit_aberrations_min_radial_order: int = 1, + fit_aberrations_min_angular_order: int = 0, + fit_max_thon_rings: int = 6, + fit_power_alpha: float = 2.0, plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, - upsampled:bool=True, + upsampled: bool = True, ): """ Fit aberrations to the measured image shifts. @@ -1288,8 +1288,8 @@ def aberration_fit( Parameters ---------- fit_CTF_FFT: bool - Set to True to directly fit aberrations in the FFT of the upsampled BF - image (if available). Note that this method relies on visible zero + Set to True to directly fit aberrations in the FFT of the upsampled BF + image (if available). Note that this method relies on visible zero crossings in the FFT, and will not work if they are not present. fit_upsampled_FFT: bool If True, we aberration fit is performed on the upsampled BF image. @@ -1326,9 +1326,7 @@ def aberration_fit( ) m_aberration = -1.0 * m_aberration self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = ( - m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 ### Second pass @@ -1336,32 +1334,38 @@ def aberration_fit( # Aberration coefs mn = [] - for m in range(fit_aberrations_min_radial_order,fit_aberrations_max_radial_order+1): - n_max = np.minimum(fit_aberrations_max_angular_order,m+1) - for n in range(fit_aberrations_min_angular_order,n_max+1): - if (m+n) % 2: - mn.append([m,n,0]) + for m in range( + fit_aberrations_min_radial_order, fit_aberrations_max_radial_order + 1 + ): + n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) + for n in range(fit_aberrations_min_angular_order, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) if n > 0: - mn.append([m,n,1]) + mn.append([m, n, 1]) self._aberrations_mn = np.array(mn) - self._aberrations_mn = self._aberrations_mn[np.argsort(self._aberrations_mn[:,1]),:] - - sub = self._aberrations_mn[:,1] > 0 - self._aberrations_mn[sub,:] = self._aberrations_mn[sub,:][np.argsort(self._aberrations_mn[sub,0]),:] + self._aberrations_mn = self._aberrations_mn[ + np.argsort(self._aberrations_mn[:, 1]), : + ] + + sub = self._aberrations_mn[:, 1] > 0 + self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ + np.argsort(self._aberrations_mn[sub, 0]), : + ] self._aberrations_num = self._aberrations_mn.shape[0] if plot_CTF_comparison is None: if fit_CTF_FFT: plot_CTF_comparison = True - + if plot_BF_shifts_comparison is None: if fit_BF_shifts: plot_BF_shifts_comparison = True # Thon Rings Fitting if fit_CTF_FFT or plot_CTF_comparison: - if upsampled and hasattr(self,"_kde_upsample_factor"): + if upsampled and hasattr(self, "_kde_upsample_factor"): im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) sx = self._scan_sampling[0] / self._kde_upsample_factor sy = self._scan_sampling[1] / self._kde_upsample_factor @@ -1370,145 +1374,188 @@ def aberration_fit( im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) sx = self._scan_sampling[0] sy = self._scan_sampling[1] - upsampled=False + upsampled = False # FFT coordinates - qx = xp.fft.fftfreq(im_FFT.shape[0],sx) - qy = xp.fft.fftfreq(im_FFT.shape[1],sy) - qr2 = qx[:,None]**2 + qy[None,:]**2 + qx = xp.fft.fftfreq(im_FFT.shape[0], sx) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None,:],qx[:,None]) + theta = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) for a0 in range(self._aberrations_num): - m,n,a = self._aberrations_mn[a0] + m, n, a = self._aberrations_mn[a0] if n == 0: # Radially symmetric basis - self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() elif a == 0: # cos coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() else: # sin coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() - + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + # global scaling - self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_basis *= 2 * np.pi / self._wavelength self._aberrations_surface_shape = alpha.shape - plot_mask = qr2 > np.pi**2/4/np.abs(self.aberration_C1) - angular_mask = np.cos(8.0 * theta)**2 < 0.25 + plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta) ** 2 < 0.25 # Direct Shifts Fitting elif fit_BF_shifts: - # FFT coordinates - sx = 1/(self._reciprocal_sampling[0]*self._region_of_interest_shape[0]) - sy = 1/(self._reciprocal_sampling[1]*self._region_of_interest_shape[1]) - qx = xp.fft.fftfreq(self._region_of_interest_shape[0],sx) - qy = xp.fft.fftfreq(self._region_of_interest_shape[1],sy) - qr2 = qx[:,None]**2 + qy[None,:]**2 - - u = qx[:,None]*self._wavelength - v = qy[None,:]*self._wavelength + sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) + sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + u = qx[:, None] * self._wavelength + v = qy[None, :] * self._wavelength alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None,:],qx[:,None]) + theta = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size,self._aberrations_num)) - self._aberrations_basis_du = xp.zeros((alpha.size,self._aberrations_num)) - self._aberrations_basis_dv = xp.zeros((alpha.size,self._aberrations_num)) + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) for a0 in range(self._aberrations_num): - m,n,a = self._aberrations_mn[a0] + m, n, a = self._aberrations_mn[a0] if n == 0: # Radially symmetric basis - self._aberrations_basis[:,a0] = (alpha**(m+1) / (m+1)).ravel() - self._aberrations_basis_du[:,a0] = (u*alpha**(m-1)).ravel() - self._aberrations_basis_dv[:,a0] = (v*alpha**(m-1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() elif a == 0: # cos coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.cos(n * theta) / (m+1)).ravel() - self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.cos(n*theta) + n*v*xp.sin(n*theta))/(m+1)).ravel() - self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.cos(n*theta) - n*u*xp.sin(n*theta))/(m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() else: # sin coef - self._aberrations_basis[:,a0] = (alpha**(m+1) * xp.sin(n * theta) / (m+1)).ravel() - self._aberrations_basis_du[:,a0] = (alpha**(m-1)*((m+1)*u*xp.sin(n*theta) - n*v*xp.cos(n*theta))/(m+1)).ravel() - self._aberrations_basis_dv[:,a0] = (alpha**(m-1)*((m+1)*v*xp.sin(n*theta) + n*u*xp.cos(n*theta))/(m+1)).ravel() + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() # global scaling - self._aberrations_basis *= 2*np.pi/self._wavelength + self._aberrations_basis *= 2 * np.pi / self._wavelength self._aberrations_surface_shape = alpha.shape # CTF function def calculate_CTF(alpha_shape, *coefs): - chi = xp.zeros_like(self._aberrations_basis[:,0]) + chi = xp.zeros_like(self._aberrations_basis[:, 0]) for a0 in range(len(coefs)): - chi += coefs[a0] * self._aberrations_basis[:,a0] + chi += coefs[a0] * self._aberrations_basis[:, a0] return xp.reshape(chi, alpha_shape) + self._calculate_CTF = calculate_CTF # initial coefficients and plotting intensity range mask self._aberrations_coefs = np.zeros(self._aberrations_num) ind = np.argmin( - np.abs(self._aberrations_mn[:,0] - 1.0) + self._aberrations_mn[:,1] + np.abs(self._aberrations_mn[:, 0] - 1.0) + self._aberrations_mn[:, 1] ) self._aberrations_coefs[ind] = self.aberration_C1 # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: - # scoring function to minimize - mean value of zero crossing regions of FFT def score_CTF(coefs): - im_CTF = xp.abs(self._calculate_CTF(self._aberrations_surface_shape,*coefs)) + im_CTF = xp.abs( + self._calculate_CTF(self._aberrations_surface_shape, *coefs) + ) mask = xp.logical_and( - im_CTF > 0.5*np.pi, - im_CTF < (max_num_rings+0.5)*np.pi, + im_CTF > 0.5 * np.pi, + im_CTF < (max_num_rings + 0.5) * np.pi, ) if np.any(mask): - weights = xp.cos(im_CTF[mask])**4 - return asnumpy(xp.sum(weights*im_FFT[mask]*alpha[mask]**fit_power_alpha) / xp.sum(weights)) + weights = xp.cos(im_CTF[mask]) ** 4 + return asnumpy( + xp.sum(weights * im_FFT[mask] * alpha[mask] ** fit_power_alpha) + / xp.sum(weights) + ) else: return np.inf - for max_num_rings in range(1,fit_max_thon_rings+1): + for max_num_rings in range(1, fit_max_thon_rings + 1): # minimization res = minimize( - score_CTF, - self._aberrations_coefs, - # method = 'Nelder-Mead', + score_CTF, + self._aberrations_coefs, + # method = 'Nelder-Mead', # method = 'CG', - method = 'BFGS', - tol = 1e-8, + method="BFGS", + tol=1e-8, ) self._aberrations_coefs = res.x - + # Refinement using CTF fitting / Thon rings elif fit_BF_shifts: - # Gradient basis - corner_indices = self._xy_inds-xp.asarray(self._region_of_interest_shape//2) - raveled_indices = np.ravel_multi_index(corner_indices.T,self._region_of_interest_shape,mode='wrap') - gradients = xp.vstack(( - self._aberrations_basis_du[raveled_indices,:], - self._aberrations_basis_dv[raveled_indices,:] - )) + corner_indices = self._xy_inds - xp.asarray( + self._region_of_interest_shape // 2 + ) + raveled_indices = np.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.vstack( + ( + self._aberrations_basis_du[raveled_indices, :], + self._aberrations_basis_dv[raveled_indices, :], + ) + ) # Untransposed fit tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang,xp=xp).T.ravel() - aberrations_coefs, res = xp.linalg.lstsq(gradients, rotated_shifts,rcond=None)[:2] + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] # Transposed fit - transposed_shifts = xp.flip(self._xy_shifts_Ang,axis=1) - m_T = asnumpy(xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0]) + transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) + m_T = asnumpy( + xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0] + ) m_rotation_T, _ = polar(m_T, side="right") - rotation_Q_to_R_rads_T = -1 * np.arctan2(m_rotation_T[1, 0], m_rotation_T[0, 0]) + rotation_Q_to_R_rads_T = -1 * np.arctan2( + m_rotation_T[1, 0], m_rotation_T[0, 0] + ) if np.abs(np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi) > ( np.pi * 0.5 ): @@ -1517,8 +1564,10 @@ def score_CTF(coefs): ) tf = AffineTransform(angle=rotation_Q_to_R_rads_T) - rotated_shifts_T = tf(transposed_shifts,xp=xp).T.ravel() - aberrations_coefs_T, res_T = xp.linalg.lstsq(gradients, rotated_shifts_T,rcond=None)[:2] + rotated_shifts_T = tf(transposed_shifts, xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq( + gradients, rotated_shifts_T, rcond=None + )[:2] if res_T.sum() < res.sum(): self._aberrations_coefs = asnumpy(aberrations_coefs_T) @@ -1526,97 +1575,122 @@ def score_CTF(coefs): else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts - - + # Plot the CTF comparison between experiment and fit if plot_CTF_comparison: # Generate FFT plotting image im_scale = asnumpy(im_FFT * alpha**fit_power_alpha) int_vals = np.sort(im_scale.ravel()) int_range = ( - int_vals[np.round(0.02*im_scale.size).astype('int')], - int_vals[np.round(0.98*im_scale.size).astype('int')], - ) - int_range = (int_range[0],(int_range[1]-int_range[0])*1.0 + int_range[0]) + int_vals[np.round(0.02 * im_scale.size).astype("int")], + int_vals[np.round(0.98 * im_scale.size).astype("int")], + ) + int_range = ( + int_range[0], + (int_range[1] - int_range[0]) * 1.0 + int_range[0], + ) im_scale = np.clip( - (np.fft.fftshift(im_scale) - int_range[0]) / (int_range[1] - int_range[0]), - 0,1) - im_plot = np.tile(im_scale[:,:,None],(1,1,3)) + (np.fft.fftshift(im_scale) - int_range[0]) + / (int_range[1] - int_range[0]), + 0, + 1, + ) + im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) # Add CTF zero crossings - im_CTF = self._calculate_CTF(self._aberrations_surface_shape,*self._aberrations_coefs) - im_CTF_cos = xp.cos(xp.abs(im_CTF))**4 - im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings+0.5)*np.pi] = np.pi/2 + im_CTF = self._calculate_CTF( + self._aberrations_surface_shape, *self._aberrations_coefs + ) + im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 im_CTF[xp.logical_not(plot_mask)] = 0 im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) - im_plot[:,:,0] += im_CTF - im_plot[:,:,1] -= im_CTF - im_plot[:,:,2] -= im_CTF - im_plot = np.clip(im_plot,0,1) - - fig,(ax1,ax2) = plt.subplots(1,2,figsize=(12,6)) - ax1.imshow( - im_plot, - vmin=int_range[0], - vmax=int_range[1] - ) - - ax2.imshow( - np.fft.fftshift(asnumpy(im_CTF_cos)), - cmap='gray' - ) + im_plot[:, :, 0] += im_CTF + im_plot[:, :, 1] -= im_CTF + im_plot[:, :, 2] -= im_CTF + im_plot = np.clip(im_plot, 0, 1) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) + ax1.imshow(im_plot, vmin=int_range[0], vmax=int_range[1]) + + ax2.imshow(np.fft.fftshift(asnumpy(im_CTF_cos)), cmap="gray") fig.tight_layout() - + # Plot the measured/fitted shifts comparison if plot_BF_shifts_comparison: - if not fit_BF_shifts: raise ValueError() - - measured_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - measured_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[:self._xy_inds.shape[0]] - - measured_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - measured_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = self._rotated_shifts[self._xy_inds.shape[0]:] - fitted_shifts = xp.tensordot(gradients,xp.array(self._aberrations_coefs),axes=1) + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[: self._xy_inds.shape[0]] + + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._rotated_shifts[self._xy_inds.shape[0] :] + + fitted_shifts = xp.tensordot( + gradients, xp.array(self._aberrations_coefs), axes=1 + ) - fitted_shifts_sx = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - fitted_shifts_sx[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[:self._xy_inds.shape[0]] + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + : self._xy_inds.shape[0] + ] - fitted_shifts_sy = xp.zeros(self._region_of_interest_shape,dtype=xp.float32) - fitted_shifts_sy[self._xy_inds[:,0],self._xy_inds[:,1]] = fitted_shifts[self._xy_inds.shape[0]:] + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ + self._xy_inds.shape[0] : + ] - max_shift = xp.max(xp.array([xp.abs(measured_shifts_sx).max(),xp.abs(measured_shifts_sy).max(),xp.abs(fitted_shifts_sx).max(),xp.abs(fitted_shifts_sy).max()])) + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] + ) + ) show( [ - [ - asnumpy(measured_shifts_sx), - asnumpy(measured_shifts_sy) - ], - [ - asnumpy(fitted_shifts_sx), - asnumpy(fitted_shifts_sy) - ], + [asnumpy(measured_shifts_sx), asnumpy(measured_shifts_sy)], + [asnumpy(fitted_shifts_sx), asnumpy(fitted_shifts_sy)], ], - cmap='PiYG', + cmap="PiYG", vmin=-max_shift, vmax=max_shift, - intensity_range='absolute', - axsize=(4,4), + intensity_range="absolute", + axsize=(4, 4), ticks=False, - title=["Measured Vertical Shifts","Measured Horizontal Shifts","Fitted Vertical Shifts","Fitted Horizontal Shifts"] + title=[ + "Measured Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Vertical Shifts", + "Fitted Horizontal Shifts", + ], ) # Print results if self._verbose: if fit_CTF_FFT or fit_BF_shifts: - print('Initial Aberration coefficients') - print('-------------------------------') + print("Initial Aberration coefficients") + print("-------------------------------") print( ( "Rotation of Q w.r.t. R = " @@ -1634,35 +1708,37 @@ def score_CTF(coefs): print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") if fit_CTF_FFT or fit_BF_shifts: - print() - print('Refined Aberration coefficients') - print('-------------------------------') - print('radial angular dir. coefs') - print('order order ') - print('------ ------- ---- -----') + print("Refined Aberration coefficients") + print("-------------------------------") + print("radial angular dir. coefs") + print("order order ") + print("------ ------- ---- -----") for a0 in range(self._aberrations_mn.shape[0]): m, n, a = self._aberrations_mn[a0] if n == 0: print( - str(m) + \ - ' 0 - ' + \ - str(np.round(self._aberrations_coefs[a0]).astype('int')) ) - elif a == 0: + str(m) + + " 0 - " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + elif a == 0: print( - str(m) + \ - ' ' + \ - str(n) + \ - ' x ' + \ - str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + str(m) + + " " + + str(n) + + " x " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) else: print( - str(m) + \ - ' ' + \ - str(n) + \ - ' y ' + \ - str(np.round(self._aberrations_coefs[a0]).astype('int')) ) + str(m) + + " " + + str(n) + + " y " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) if self._device == "gpu": xp._default_memory_pool.free_all_blocks() @@ -1670,7 +1746,7 @@ def score_CTF(coefs): def aberration_correct( self, - use_FFT_fit = True, + use_CTF_fit=None, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, @@ -1727,8 +1803,16 @@ def aberration_correct( ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - if use_FFT_fit: - sin_chi = np.sin(self.calc_CTF(self.alpha,*self.aber_coefs)) + if use_CTF_fit is None: + if hasattr(self, "_aberrations_surface_shape"): + use_CTF_fit = True + + if use_CTF_fit: + sin_chi = np.sin( + self._calculate_CTF( + self._aberrations_surface_shape, *self._aberrations_coefs + ) + ) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 @@ -1748,7 +1832,9 @@ def aberration_correct( if Wiener_filter: SNR_inv = ( xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) + 1 + + (kra2**k_info_power) + / ((k_info_limit) ** (2 * k_info_power)) ) / Wiener_signal_noise_ratio ) From de75b517f89be4f1976cc7bdd29dd2908d58bedf Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 19 Oct 2023 15:03:35 -0700 Subject: [PATCH 093/176] small bug fixes --- py4DSTEM/process/phase/iterative_parallax.py | 106 ++++++++++++++----- 1 file changed, 77 insertions(+), 29 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 5f7e1c25e..1a03f4fa3 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1381,38 +1381,47 @@ def aberration_fit( qy = xp.fft.fftfreq(im_FFT.shape[1], sy) qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 - alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None, :], qx[:, None]) + alpha_FFT = xp.sqrt(qr2) * self._wavelength + theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_FFT = xp.zeros( + (alpha_FFT.size, self._aberrations_num) + ) for a0 in range(self._aberrations_num): m, n, a = self._aberrations_mn[a0] if n == 0: # Radially symmetric basis - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) / (m + 1) + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) / (m + 1) ).ravel() elif a == 0: # cos coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) ).ravel() else: # sin coef - self._aberrations_basis[:, a0] = ( - alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) ).ravel() # global scaling - self._aberrations_basis *= 2 * np.pi / self._wavelength - self._aberrations_surface_shape = alpha.shape + self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape_FFT = alpha_FFT.shape plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) - angular_mask = np.cos(8.0 * theta) ** 2 < 0.25 + angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 + + # CTF function + def calculate_CTF_FFT(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis_FFT[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis_FFT[:, a0] + return xp.reshape(chi, alpha_shape) # Direct Shifts Fitting - elif fit_BF_shifts: + if fit_BF_shifts: # FFT coordinates sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) @@ -1476,14 +1485,12 @@ def aberration_fit( self._aberrations_basis *= 2 * np.pi / self._wavelength self._aberrations_surface_shape = alpha.shape - # CTF function - def calculate_CTF(alpha_shape, *coefs): - chi = xp.zeros_like(self._aberrations_basis[:, 0]) - for a0 in range(len(coefs)): - chi += coefs[a0] * self._aberrations_basis[:, a0] - return xp.reshape(chi, alpha_shape) - - self._calculate_CTF = calculate_CTF + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:, a0] + return xp.reshape(chi, alpha_shape) # initial coefficients and plotting intensity range mask self._aberrations_coefs = np.zeros(self._aberrations_num) @@ -1497,7 +1504,7 @@ def calculate_CTF(alpha_shape, *coefs): # scoring function to minimize - mean value of zero crossing regions of FFT def score_CTF(coefs): im_CTF = xp.abs( - self._calculate_CTF(self._aberrations_surface_shape, *coefs) + calculate_CTF_FFT(self._aberrations_surface_shape_FFT, *coefs) ) mask = xp.logical_and( im_CTF > 0.5 * np.pi, @@ -1506,7 +1513,9 @@ def score_CTF(coefs): if np.any(mask): weights = xp.cos(im_CTF[mask]) ** 4 return asnumpy( - xp.sum(weights * im_FFT[mask] * alpha[mask] ** fit_power_alpha) + xp.sum( + weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha + ) / xp.sum(weights) ) else: @@ -1579,7 +1588,7 @@ def score_CTF(coefs): # Plot the CTF comparison between experiment and fit if plot_CTF_comparison: # Generate FFT plotting image - im_scale = asnumpy(im_FFT * alpha**fit_power_alpha) + im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha) int_vals = np.sort(im_scale.ravel()) int_range = ( int_vals[np.round(0.02 * im_scale.size).astype("int")], @@ -1598,8 +1607,8 @@ def score_CTF(coefs): im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) # Add CTF zero crossings - im_CTF = self._calculate_CTF( - self._aberrations_surface_shape, *self._aberrations_coefs + im_CTF = calculate_CTF_FFT( + self._aberrations_surface_shape_FFT, *self._aberrations_coefs ) im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 @@ -1744,6 +1753,47 @@ def score_CTF(coefs): xp._default_memory_pool.free_all_blocks() xp.clear_memo() + def _calculate_CTF(self, alpha_shape, sampling, *coefs): + xp = self._xp + + # FFT coordinates + sx, sy = sampling + qx = xp.fft.fftfreq(alpha_shape[0], sx) + qy = xp.fft.fftfreq(alpha_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + + # global scaling + aberrations_basis *= 2 * np.pi / self._wavelength + + chi = xp.zeros_like(aberrations_basis[:, 0]) + + for a0 in range(len(coefs)): + chi += coefs[a0] * aberrations_basis[:, a0] + + return xp.reshape(chi, alpha_shape) + def aberration_correct( self, use_CTF_fit=None, @@ -1809,9 +1859,7 @@ def aberration_correct( if use_CTF_fit: sin_chi = np.sin( - self._calculate_CTF( - self._aberrations_surface_shape, *self._aberrations_coefs - ) + self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs) ) CTF_corr = xp.sign(sin_chi) From 52ee427f9403839c02eb4a7de83719fa8c19585c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 04:09:47 -0700 Subject: [PATCH 094/176] cleaned up parallax --- py4DSTEM/process/phase/iterative_parallax.py | 110 +++++++++++++++---- 1 file changed, 86 insertions(+), 24 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 1a03f4fa3..34631256c 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -10,6 +10,7 @@ import numpy as np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM import DataCube from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction @@ -29,6 +30,23 @@ warnings.simplefilter(action="always", category=UserWarning) +_aberration_names = { + (1, 0): "-defocus ", + (1, 2): "stig ", + (2, 1): "coma ", + (2, 3): "trefoil ", + (3, 0): "Cs ", + (3, 2): "stig2 ", + (3, 4): "quadfoil ", + (4, 1): "coma2 ", + (4, 3): "trefoil2 ", + (4, 5): "pentafoil ", + (5, 0): "C5 ", + (5, 2): "stig3 ", + (5, 4): "quadfoil2 ", + (5, 6): "hexafoil ", +} + class ParallaxReconstruction(PhaseReconstruction): """ @@ -40,9 +58,6 @@ class ParallaxReconstruction(PhaseReconstruction): Input 4D diffraction pattern intensities energy: float The electron energy of the wave functions in eV - dp_mean: ndarray, optional - Mean diffraction pattern - If None, get_dp_mean() is used verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -122,6 +137,7 @@ def to_h5(self, group): if hasattr(self, "aberration_C1"): recon_metadata |= { "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_transpose": self.transpose_detected, "aberration_C1": self.aberration_C1, "aberration_A1x": self.aberration_A1x, "aberration_A1y": self.aberration_A1y, @@ -136,6 +152,15 @@ def to_h5(self, group): data=self._asnumpy(self._recon_BF_subpixel_aligned), ) + if hasattr(self, "aberration_dict"): + self.metadata = Metadata( + name="aberrations_metadata", + data={ + v["common name"]: v["value [Ang]"] + for k, v in self.aberration_dict.items() + }, + ) + self.metadata = Metadata( name="reconstruction_metadata", data=recon_metadata, @@ -212,6 +237,7 @@ def _populate_instance(self, group): if "aberration_C1" in reconstruction_md.keys: self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.transpose_detected = reconstruction_md["aberration_transpose"] self.aberration_C1 = reconstruction_md["aberration_C1"] self.aberration_A1x = reconstruction_md["aberration_A1x"] self.aberration_A1y = reconstruction_md["aberration_A1y"] @@ -327,9 +353,6 @@ def preprocess( intensities = asnumpy(self._intensities) intensities_shifted = np.zeros_like(intensities) - # center_x = np.mean(com_fitted_x) - # center_y = np.mean(com_fitted_y) - center_x, center_y = self._region_of_interest_shape / 2 for rx in range(intensities_shifted.shape[0]): @@ -706,8 +729,6 @@ def tune_angle_and_defocus( convergence.append(asnumpy(self._recon_error[0])) if plot_convergence: - from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable - fig, ax = plt.subplots() ax.set_title("convergence") im = ax.imshow( @@ -1287,19 +1308,29 @@ def aberration_fit( Parameters ---------- + fit_BF_shifts: bool + Set to True to fit aberrations to the measured BF shifts directly. fit_CTF_FFT: bool - Set to True to directly fit aberrations in the FFT of the upsampled BF - image (if available). Note that this method relies on visible zero - crossings in the FFT, and will not work if they are not present. - fit_upsampled_FFT: bool - If True, we aberration fit is performed on the upsampled BF image. - This option does nothing if fit_thon_rings is not True. - fit_aber_order_max: int + Set to True to fit aberrations in the FFT of the (upsampled) BF + image. Note that this method relies on visible zero crossings in the FFT. + fit_aberrations_max_radial_order: int Max radial order for fitting of aberrations. - ctf_threshold: float - CTF fitting minimizes value at CTF zero crossings (Thon ring minima). - plot_CTF_compare: bool, optional + fit_aberrations_max_angular_order: int + Max angular order for fitting of aberrations. + fit_aberrations_min_radial_order: int + Min radial order for fitting of aberrations. + fit_aberrations_min_angular_order: int + Min angular order for fitting of aberrations. + fit_max_thon_rings: int + Max number of Thon rings to search for during CTF FFT fitting. + fit_power_alpha: int + Power to raise FFT alpha weighting during CTF FFT fitting. + plot_CTF_comparison: bool, optional If True, the fitted CTF is plotted against the reconstructed frequencies. + plot_BF_shifts_comparison: bool, optional + If True, the measured vs fitted BF shifts are plotted. + upsampled: bool + If True, and upsampled BF is available, uses that for CTF FFT fitting. """ xp = self._xp @@ -1328,6 +1359,7 @@ def aberration_fit( self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + self.transpose_detected = False ### Second pass @@ -1353,6 +1385,9 @@ def aberration_fit( self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ np.argsort(self._aberrations_mn[sub, 0]), : ] + self._aberrations_mn[~sub, :] = self._aberrations_mn[~sub, :][ + np.argsort(self._aberrations_mn[~sub, 0]), : + ] self._aberrations_num = self._aberrations_mn.shape[0] if plot_CTF_comparison is None: @@ -1579,8 +1614,18 @@ def score_CTF(coefs): )[:2] if res_T.sum() < res.sum(): + self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T + self.transpose_detected = True self._aberrations_coefs = asnumpy(aberrations_coefs_T) self._rotated_shifts = rotated_shifts_T + + warnings.warn( + ( + "Data transpose detected. " + f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" + ), + UserWarning, + ) else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts @@ -1695,6 +1740,16 @@ def score_CTF(coefs): ], ) + self.aberration_dict = { + tuple(self._aberrations_mn[a0]): { + "common name": _aberration_names.get( + tuple(self._aberrations_mn[a0, :2]), "-" + ).strip(), + "value [Ang]": self._aberrations_coefs[a0], + } + for a0 in range(self._aberrations_num) + } + # Print results if self._verbose: if fit_CTF_FFT or fit_BF_shifts: @@ -1720,21 +1775,26 @@ def score_CTF(coefs): print() print("Refined Aberration coefficients") print("-------------------------------") - print("radial angular dir. coefs") - print("order order ") - print("------ ------- ---- -----") + print("common radial angular dir. coefs") + print("name order order Ang ") + print("---------- ------- ------- ---- -----") for a0 in range(self._aberrations_mn.shape[0]): m, n, a = self._aberrations_mn[a0] + name = _aberration_names.get((m, n), " -- ") if n == 0: print( - str(m) + name + + " " + + str(m) + " 0 - " + str(np.round(self._aberrations_coefs[a0]).astype("int")) ) elif a == 0: print( - str(m) + name + + " " + + str(m) + " " + str(n) + " x " @@ -1742,7 +1802,9 @@ def score_CTF(coefs): ) else: print( - str(m) + name + + " " + + str(m) + " " + str(n) + " y " From 54d5859f9c58c78d2dc97aa75728081ac04d02f5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 07:15:51 -0700 Subject: [PATCH 095/176] ptycho new aberration formalism --- .../iterative_ptychographic_constraints.py | 2 + py4DSTEM/process/phase/utils.py | 80 ++++++++++++++----- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 4721ed12b..3eebdb068 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -555,10 +555,12 @@ def _probe_aberration_fitting_constraint( fourier_probe = xp.fft.fft2(current_probe) fourier_probe_abs = xp.abs(fourier_probe) sampling = self.sampling + energy = self._energy fitted_angle, _ = fit_aberration_surface( fourier_probe, sampling, + energy, max_angular_order, max_radial_order, xp=xp, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index d06db111c..cc5fa8cb4 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1543,39 +1543,75 @@ def step_model(radius, sig_0, rad_0, width): def aberrations_basis_function( probe_size, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ + # mn = [[0,0,0]] + mn = [] + + for m in range(1, max_radial_order + 1): + n_max = np.minimum(max_angular_order, m + 1) + for n in range(0, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + aberrations_mn = np.array(mn) + aberrations_mn = aberrations_mn[np.argsort(aberrations_mn[:, 1]), :] + + sub = aberrations_mn[:, 1] > 0 + aberrations_mn[sub, :] = aberrations_mn[sub, :][ + np.argsort(aberrations_mn[sub, 0]), : + ] + aberrations_mn[~sub, :] = aberrations_mn[~sub, :][ + np.argsort(aberrations_mn[~sub, 0]), : + ] + aberrations_num = aberrations_mn.shape[0] + sx, sy = probe_size dx, dy = probe_sampling + wavelength = electron_wavelength_angstrom(energy) + qx = xp.fft.fftfreq(sx, dx) qy = xp.fft.fftfreq(sy, dy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, aberrations_num)) + + for a0 in range(aberrations_num): + m, n, a = aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - qxa, qya = xp.meshgrid(qx, qy, indexing="ij") - q2 = qxa**2 + qya**2 - theta = xp.arctan2(qya, qxa) - - basis = [] - index = [] - - for n in range(max_angular_order + 1): - for m in range((max_radial_order - n) // 2 + 1): - basis.append((q2 ** (m + n / 2) * np.cos(n * theta))) - index.append((m, n, 0)) - if n > 0: - basis.append((q2 ** (m + n / 2) * np.sin(n * theta))) - index.append((m, n, 1)) - - basis = xp.array(basis) + # global scaling + aberrations_basis *= 2 * np.pi / wavelength - return basis, index + return aberrations_basis, aberrations_mn def fit_aberration_surface( complex_probe, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, @@ -1592,22 +1628,22 @@ def fit_aberration_surface( unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) - basis, _ = aberrations_basis_function( + raveled_basis, _ = aberrations_basis_function( complex_probe.shape, probe_sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - raveled_basis = basis.reshape((basis.shape[0], -1)) raveled_weights = probe_amp.ravel() - Aw = raveled_basis.T * raveled_weights[:, None] + Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights - coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] + coeff = -xp.linalg.lstsq(Aw, bw, rcond=None)[0] - fitted_angle = xp.tensordot(coeff, basis, axes=1) + fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) return fitted_angle, coeff From 9865f39a8475dc155a30fc6c9f1faafff6111ccf Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 07:58:47 -0700 Subject: [PATCH 096/176] adding chroma_boost defaults --- .../process/phase/iterative_base_class.py | 2 ++ ...tive_mixedstate_multislice_ptychography.py | 28 ++++++++++++++++--- .../iterative_mixedstate_ptychography.py | 28 +++++++++++++++---- .../iterative_multislice_ptychography.py | 25 ++++++++++++++--- .../iterative_overlap_magnetic_tomography.py | 5 ++++ .../phase/iterative_overlap_tomography.py | 21 ++++++++++++++ .../iterative_simultaneous_ptychography.py | 13 +++++++-- .../iterative_singleslice_ptychography.py | 21 ++++++++++++-- 8 files changed, 126 insertions(+), 17 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 62cf3a3a1..4dced4291 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2208,6 +2208,7 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" figsize = kwargs.pop("figsize", (6, 6)) + chroma_boost = kwargs.pop("chroma_boost", 2) fig, ax = plt.subplots(figsize=figsize) show_complex( @@ -2218,6 +2219,7 @@ def show_fourier_probe( pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost = chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 306f47f77..2747fe601 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -613,11 +613,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered[0], power=2, + chroma_boost = chroma_boost, ) # propagated @@ -630,6 +632,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -657,6 +660,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -669,7 +673,7 @@ def preprocess( divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax2) + add_colorbar_arg(cax2, chroma_boost=chroma_boost) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") ax2.set_title("Propagated probe[0] intensity") @@ -2502,6 +2506,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + if self._object_type == "complex": obj = np.angle(self.object) else: @@ -2595,12 +2604,12 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = Complex2RGB(self.probe_fourier[0]) + probe_array = Complex2RGB(self.probe_fourier[0],chroma_boost=chroma_boost) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe[0], power=2) + probe_array = Complex2RGB(self.probe[0], power=2, chroma_boost=chroma_boost) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2613,7 +2622,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2722,6 +2731,11 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + errors = np.array(self.error_iterations) objects = [] @@ -2825,6 +2839,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2833,6 +2848,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0]") ax.set_ylabel("x [A]") @@ -2846,6 +2862,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: @@ -2953,12 +2970,15 @@ def show_fourier_probe( if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 2) + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost = chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 658079c3e..07b1fe9aa 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -510,11 +510,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -544,7 +546,7 @@ def preprocess( divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax) + add_colorbar_arg(cax, chroma_boost = chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), @@ -1847,6 +1849,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + if self._object_type == "complex": obj = np.angle(self.object) else: @@ -1939,6 +1946,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier[0], + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -1947,6 +1955,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe[0], power=2, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") @@ -1960,7 +1969,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost = chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2068,8 +2077,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2172,6 +2184,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2180,6 +2193,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") @@ -2192,7 +2206,8 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: @@ -2301,11 +2316,14 @@ def show_fourier_probe( if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 2) + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost = chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 382efedcd..823c71ca0 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -608,11 +608,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) # propagated @@ -625,6 +627,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -650,7 +653,7 @@ def preprocess( divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") ax1.set_title("Initial probe intensity") @@ -664,6 +667,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") @@ -2404,6 +2408,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2500,12 +2509,13 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2) + probe_array = Complex2RGB(self.probe, power=2, chroma_boost = chroma_boost) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2518,7 +2528,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb, chroma_boost = chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2627,6 +2637,11 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + errors = np.array(self.error_iterations) objects = [] @@ -2730,12 +2745,13 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(probes[grid_range[n]], power=2) + probe_array = Complex2RGB(probes[grid_range[n]], power=2, chroma_boost = chroma_boost) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2748,6 +2764,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 459b0ae8c..5035ced81 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -813,11 +813,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) # propagated @@ -830,6 +832,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -857,6 +860,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -871,6 +875,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, + chroma_boost = chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index bb3ee09c2..193c4f5eb 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -753,11 +753,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) # propagated @@ -770,6 +772,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -797,6 +800,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -811,6 +815,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, + chroma_boost = chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") @@ -2585,6 +2590,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + asnumpy = self._asnumpy if projection_angle_deg is not None: @@ -2686,6 +2696,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2694,6 +2705,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -2710,6 +2722,7 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( ax_cb, + chroma_boost = chroma_boost, ) else: ax = fig.add_subplot(spec[0]) @@ -2827,6 +2840,11 @@ def _visualize_all_iterations( figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + errors = np.array(self.error_iterations) if projection_angle_deg is not None: @@ -2940,6 +2958,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2948,6 +2967,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2961,6 +2981,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 084a6fcb8..8af804325 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -753,11 +753,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -785,6 +787,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, + chroma_boost = chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -3078,6 +3081,11 @@ def _visualize_last_iteration( vmax_e = kwargs.pop("vmax_e", max_e) vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) extent = [ 0, @@ -3184,12 +3192,13 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2) + probe_array = Complex2RGB(self.probe, power=2,chroma_boost=chroma_boost) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -3202,7 +3211,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) else: # Electrostatic Object diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0dc2cd053..26020e8de 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -480,11 +480,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, + chroma_boost = chroma_boost, ) extent = [ @@ -510,7 +512,7 @@ def preprocess( divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1) + add_colorbar_arg(cax1,chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") ax1.set_title("Initial probe intensity") @@ -1757,6 +1759,11 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + if self._object_type == "complex": obj = np.angle(self.object) else: @@ -1849,6 +1856,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -1857,6 +1865,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, + chroma_boost = chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -1870,7 +1879,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1978,6 +1987,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2081,6 +2095,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2090,6 +2105,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, + chroma_boost = chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2103,6 +2119,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], + chroma_boost = chroma_boost, ) if plot_convergence: From 020e170b57428fd48f6704dfd68a274d83f6f952 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 08:01:18 -0700 Subject: [PATCH 097/176] formatted, linted, isorted --- .../process/phase/iterative_base_class.py | 2 +- ...tive_mixedstate_multislice_ptychography.py | 26 +++++++++++-------- .../iterative_mixedstate_ptychography.py | 22 ++++++++-------- .../iterative_multislice_ptychography.py | 22 +++++++++------- .../iterative_overlap_magnetic_tomography.py | 8 +++--- .../phase/iterative_overlap_tomography.py | 20 +++++++------- py4DSTEM/process/phase/iterative_parallax.py | 5 ++-- .../iterative_simultaneous_ptychography.py | 24 +++++++---------- .../iterative_singleslice_ptychography.py | 18 ++++++------- 9 files changed, 75 insertions(+), 72 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 4dced4291..6e02dd598 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2219,7 +2219,7 @@ def show_fourier_probe( pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 2747fe601..fca48b38c 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -619,7 +619,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered[0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -632,7 +632,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -660,7 +660,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -2604,12 +2604,16 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: - probe_array = Complex2RGB(self.probe_fourier[0],chroma_boost=chroma_boost) + probe_array = Complex2RGB( + self.probe_fourier[0], chroma_boost=chroma_boost + ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe[0], power=2, chroma_boost=chroma_boost) + probe_array = Complex2RGB( + self.probe[0], power=2, chroma_boost=chroma_boost + ) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2622,7 +2626,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2839,7 +2843,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2848,7 +2852,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0]") ax.set_ylabel("x [A]") @@ -2862,7 +2866,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2971,14 +2975,14 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" chroma_boost = kwargs.pop("chroma_boost", 2) - + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 07b1fe9aa..21a29b0b1 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -516,7 +516,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -546,7 +546,7 @@ def preprocess( divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax, chroma_boost = chroma_boost) + add_colorbar_arg(cax, chroma_boost=chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), @@ -1946,7 +1946,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier[0], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -1955,7 +1955,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe[0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") @@ -1969,7 +1969,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost = chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2077,7 +2077,7 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -2184,7 +2184,7 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") @@ -2193,7 +2193,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]][0], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") @@ -2207,7 +2207,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2317,13 +2317,13 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" chroma_boost = kwargs.pop("chroma_boost", 2) - + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 823c71ca0..a22fad715 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -614,7 +614,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -627,7 +627,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -2408,7 +2408,7 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -2509,13 +2509,15 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2, chroma_boost = chroma_boost) + probe_array = Complex2RGB( + self.probe, power=2, chroma_boost=chroma_boost + ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2528,7 +2530,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, chroma_boost = chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2745,13 +2747,15 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(probes[grid_range[n]], power=2, chroma_boost = chroma_boost) + probe_array = Complex2RGB( + probes[grid_range[n]], power=2, chroma_boost=chroma_boost + ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2764,7 +2768,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 5035ced81..af665baac 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -819,7 +819,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -832,7 +832,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -860,7 +860,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -875,7 +875,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 193c4f5eb..d1c323a5d 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -759,7 +759,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) # propagated @@ -772,7 +772,7 @@ def preprocess( complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -800,7 +800,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -815,7 +815,7 @@ def preprocess( cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") @@ -2696,7 +2696,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2705,7 +2705,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -2722,7 +2722,7 @@ def _visualize_last_iteration( ax_cb = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( ax_cb, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) else: ax = fig.add_subplot(spec[0]) @@ -2958,7 +2958,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2967,7 +2967,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2981,7 +2981,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 34631256c..828fd12a2 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -11,7 +11,7 @@ from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable -from py4DSTEM import DataCube +from py4DSTEM import Calibration, DataCube from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import AffineTransform @@ -19,9 +19,8 @@ from py4DSTEM.process.utils.utils import electron_wavelength_angstrom from py4DSTEM.visualize import show from scipy.linalg import polar +from scipy.optimize import minimize from scipy.special import comb -from scipy.optimize import curve_fit, minimize -from scipy.signal import medfilt2d try: import cupy as cp diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 8af804325..eb900d5d0 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -485,10 +485,7 @@ def preprocess( amplitudes_1, mean_diffraction_intensity_1, ) = self._normalize_diffraction_intensities( - intensities_1, - com_fitted_x_1, - com_fitted_y_1, - crop_patterns + intensities_1, com_fitted_x_1, com_fitted_y_1, crop_patterns ) # explicitly delete namescapes @@ -570,10 +567,7 @@ def preprocess( amplitudes_2, mean_diffraction_intensity_2, ) = self._normalize_diffraction_intensities( - intensities_2, - com_fitted_x_2, - com_fitted_y_2, - crop_patterns + intensities_2, com_fitted_x_2, com_fitted_y_2, crop_patterns ) # explicitly delete namescapes @@ -759,7 +753,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -787,7 +781,7 @@ def preprocess( cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( cax1, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") @@ -3081,7 +3075,7 @@ def _visualize_last_iteration( vmax_e = kwargs.pop("vmax_e", max_e) vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -3192,13 +3186,15 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: - probe_array = Complex2RGB(self.probe, power=2,chroma_boost=chroma_boost) + probe_array = Complex2RGB( + self.probe, power=2, chroma_boost=chroma_boost + ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -3211,7 +3207,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: # Electrostatic Object diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 26020e8de..547117f8d 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -486,7 +486,7 @@ def preprocess( complex_probe_rgb = Complex2RGB( self.probe_centered, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) extent = [ @@ -512,7 +512,7 @@ def preprocess( divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(cax1,chroma_boost=chroma_boost) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") ax1.set_title("Initial probe intensity") @@ -1856,7 +1856,7 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( self.probe_fourier, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") @@ -1865,7 +1865,7 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe, power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") @@ -1879,7 +1879,7 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb,chroma_boost=chroma_boost) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1987,7 +1987,7 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - + if plot_fourier_probe: chroma_boost = kwargs.pop("chroma_boost", 2) else: @@ -2095,7 +2095,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2105,7 +2105,7 @@ def _visualize_all_iterations( probe_array = Complex2RGB( probes[grid_range[n]], power=2, - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") @@ -2119,7 +2119,7 @@ def _visualize_all_iterations( if cbar: add_colorbar_arg( grid.cbar_axes[n], - chroma_boost = chroma_boost, + chroma_boost=chroma_boost, ) if plot_convergence: From 9806e277b7892f1a3eb8f2b54f390b0c1e326cf5 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 20 Oct 2023 16:15:08 -0700 Subject: [PATCH 098/176] fixing radial order accounting --- py4DSTEM/process/phase/iterative_parallax.py | 18 +++++++++--------- py4DSTEM/process/phase/utils.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 8b2b007f5..716b84342 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -34,7 +34,7 @@ (1, 2): "stig ", (2, 1): "coma ", (2, 3): "trefoil ", - (3, 0): "Cs ", + (3, 0): "C3 ", (3, 2): "stig2 ", (3, 4): "quadfoil ", (4, 1): "coma2 ", @@ -155,7 +155,7 @@ def to_h5(self, group): self.metadata = Metadata( name="aberrations_metadata", data={ - v["common name"]: v["value [Ang]"] + v["aberration name"]: v["value [Ang]"] for k, v in self.aberration_dict.items() }, ) @@ -1294,7 +1294,7 @@ def aberration_fit( fit_CTF_FFT: bool = False, fit_aberrations_max_radial_order: int = 3, fit_aberrations_max_angular_order: int = 4, - fit_aberrations_min_radial_order: int = 1, + fit_aberrations_min_radial_order: int = 2, fit_aberrations_min_angular_order: int = 0, fit_max_thon_rings: int = 6, fit_power_alpha: float = 2.0, @@ -1366,7 +1366,7 @@ def aberration_fit( mn = [] for m in range( - fit_aberrations_min_radial_order, fit_aberrations_max_radial_order + 1 + fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order ): n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) for n in range(fit_aberrations_min_angular_order, n_max + 1): @@ -1741,7 +1741,7 @@ def score_CTF(coefs): self.aberration_dict = { tuple(self._aberrations_mn[a0]): { - "common name": _aberration_names.get( + "aberration name": _aberration_names.get( tuple(self._aberrations_mn[a0, :2]), "-" ).strip(), "value [Ang]": self._aberrations_coefs[a0], @@ -1774,7 +1774,7 @@ def score_CTF(coefs): print() print("Refined Aberration coefficients") print("-------------------------------") - print("common radial angular dir. coefs") + print("aberration radial angular dir. coefs") print("name order order Ang ") print("---------- ------- ------- ---- -----") @@ -1785,7 +1785,7 @@ def score_CTF(coefs): print( name + " " - + str(m) + + str(m + 1) + " 0 - " + str(np.round(self._aberrations_coefs[a0]).astype("int")) ) @@ -1793,7 +1793,7 @@ def score_CTF(coefs): print( name + " " - + str(m) + + str(m + 1) + " " + str(n) + " x " @@ -1803,7 +1803,7 @@ def score_CTF(coefs): print( name + " " - + str(m) + + str(m + 1) + " " + str(n) + " y " diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index cc5fa8cb4..7e348826e 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1552,7 +1552,7 @@ def aberrations_basis_function( # mn = [[0,0,0]] mn = [] - for m in range(1, max_radial_order + 1): + for m in range(1, max_radial_order): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: From 2eca4b7eb1d5f74947148cd87f745af0d05e9c38 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 20 Oct 2023 16:48:36 -0700 Subject: [PATCH 099/176] make lint happy I hope! --- py4DSTEM/visualize/vis_special.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 19e0c5c7a..cfa017299 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -18,6 +18,7 @@ from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR from colorspacious import cspace_convert + def show_elliptical_fit( ar, fitradii, From 71b8f6d9039bce7d5177341b74516c62cd475b5a Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 20 Oct 2023 17:28:32 -0700 Subject: [PATCH 100/176] fix extent for ms depth sectioning --- .../process/phase/iterative_multislice_ptychography.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 576b431bf..4515590fe 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3095,9 +3095,9 @@ def show_depth( self._slice_thicknesses[0] * plot_im.shape[0], 0, ] - + figsize = kwargs.pop("figsize", (6, 6)) if not plot_line_profile: - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=figsize) im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax.set_aspect(aspect) @@ -3112,11 +3112,12 @@ def show_depth( else: extent2 = [ 0, - self.sampling[0] * ms_obj.shape[1], self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], 0, ] - fig, ax = plt.subplots(2, 1) + + fig, ax = plt.subplots(2, 1, figsize=figsize) ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) ax[0].plot( [y1 * self.sampling[0], y2 * self.sampling[1]], From dfc312b190b8144655a1f49005aacbfb3510684e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 21 Oct 2023 10:28:32 -0700 Subject: [PATCH 101/176] small fixes --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- .../phase/iterative_mixedstate_multislice_ptychography.py | 7 ++++--- .../process/phase/iterative_overlap_magnetic_tomography.py | 2 +- py4DSTEM/process/phase/iterative_overlap_tomography.py | 2 +- py4DSTEM/process/phase/iterative_parallax.py | 1 + 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 66dc3a8d6..906e9add1 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2243,7 +2243,7 @@ def show_object_fft(self, obj=None, **kwargs): vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index fca48b38c..3eeb07814 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3237,8 +3237,9 @@ def show_depth( 0, ] + figsize = kwargs.pop("figsize", (6, 6)) if not plot_line_profile: - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=figsize) im = ax.imshow(plot_im, cmap="magma", extent=extent) if aspect is not None: ax.set_aspect(aspect) @@ -3253,11 +3254,11 @@ def show_depth( else: extent2 = [ 0, - self.sampling[0] * ms_obj.shape[1], self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], 0, ] - fig, ax = plt.subplots(2, 1) + fig, ax = plt.subplots(2, 1, figsize=figsize) ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) ax[0].plot( [y1 * self.sampling[0], y2 * self.sampling[1]], diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 57e42a366..32b0f6fd4 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3292,7 +3292,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index ab37dfad5..66cf46487 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3172,7 +3172,7 @@ def show_object_fft( vmax = kwargs.pop("vmax", 1) power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 716b84342..a69dece3b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1874,6 +1874,7 @@ def aberration_correct( ---------- use_FFT_fit: bool Use the CTF fitted to the zero crossings of the FFT. + Default is True plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional From b0e2c4244b1241505315b5a7bd2e555fff4b07d2 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 21 Oct 2023 11:46:35 -0700 Subject: [PATCH 102/176] fix for ptycho aberration fit --- py4DSTEM/process/phase/utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 7e348826e..93428f5bb 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1549,10 +1549,10 @@ def aberrations_basis_function( xp=np, ): """ """ - # mn = [[0,0,0]] - mn = [] + mn = [[0,0,0]] + # mn = [] - for m in range(1, max_radial_order): + for m in range(max_radial_order+1): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: @@ -1583,9 +1583,9 @@ def aberrations_basis_function( theta = xp.arctan2(qy[None, :], qx[:, None]) # Aberration basis - aberrations_basis = xp.zeros((alpha.size, aberrations_num)) + aberrations_basis = xp.ones((alpha.size, aberrations_num)) - for a0 in range(aberrations_num): + for a0 in range(1,aberrations_num): m, n, a = aberrations_mn[a0] if n == 0: # Radially symmetric basis @@ -1641,7 +1641,7 @@ def fit_aberration_surface( Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights - coeff = -xp.linalg.lstsq(Aw, bw, rcond=None)[0] + coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) From 17dd9a2ef212fadf8f282b97ba5934fea21cc042 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sat, 21 Oct 2023 11:48:32 -0700 Subject: [PATCH 103/176] black format --- py4DSTEM/process/phase/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 93428f5bb..374e3fc15 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1549,10 +1549,10 @@ def aberrations_basis_function( xp=np, ): """ """ - mn = [[0,0,0]] + mn = [[0, 0, 0]] # mn = [] - for m in range(max_radial_order+1): + for m in range(max_radial_order + 1): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: @@ -1585,7 +1585,7 @@ def aberrations_basis_function( # Aberration basis aberrations_basis = xp.ones((alpha.size, aberrations_num)) - for a0 in range(1,aberrations_num): + for a0 in range(1, aberrations_num): m, n, a = aberrations_mn[a0] if n == 0: # Radially symmetric basis From ada4d4d8a2149b35c4ef03011477a024ee4a7ced Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 21 Oct 2023 17:17:10 -0700 Subject: [PATCH 104/176] fixed ptycho fitting, added transpose flag in parallax --- py4DSTEM/process/phase/iterative_parallax.py | 98 ++++++++++++-------- py4DSTEM/process/phase/utils.py | 8 +- 2 files changed, 65 insertions(+), 41 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index a69dece3b..dcfd8f504 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -30,7 +30,7 @@ warnings.simplefilter(action="always", category=UserWarning) _aberration_names = { - (1, 0): "-defocus ", + (1, 0): "C1 ", (1, 2): "stig ", (2, 1): "coma ", (2, 3): "trefoil ", @@ -1290,7 +1290,7 @@ def subpixel_alignment( def aberration_fit( self, - fit_BF_shifts: bool = True, + fit_BF_shifts: bool = False, fit_CTF_FFT: bool = False, fit_aberrations_max_radial_order: int = 3, fit_aberrations_max_angular_order: int = 4, @@ -1301,6 +1301,7 @@ def aberration_fit( plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, upsampled: bool = True, + force_transpose: bool = None, ): """ Fit aberrations to the measured image shifts. @@ -1330,6 +1331,8 @@ def aberration_fit( If True, the measured vs fitted BF shifts are plotted. upsampled: bool If True, and upsampled BF is available, uses that for CTF FFT fitting. + force_transpose: bool + If True, and fit_BF_shifts is True, flips the measured x and y shifts """ xp = self._xp @@ -1358,7 +1361,11 @@ def aberration_fit( self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - self.transpose_detected = False + + if force_transpose is None: + self.transpose_detected = False + else: + self.transpose_detected = force_transpose ### Second pass @@ -1583,48 +1590,63 @@ def score_CTF(coefs): ) ) - # Untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() - aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None - )[:2] - - # Transposed fit - transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) - m_T = asnumpy( - xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[0] - ) - m_rotation_T, _ = polar(m_T, side="right") - rotation_Q_to_R_rads_T = -1 * np.arctan2( - m_rotation_T[1, 0], m_rotation_T[0, 0] - ) - if np.abs(np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi) > ( - np.pi * 0.5 - ): - rotation_Q_to_R_rads_T = ( - np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + if force_transpose is None or force_transpose is True: + # Transposed fit + transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) + m_T = asnumpy( + xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ + 0 + ] ) + m_rotation_T, _ = polar(m_T, side="right") + rotation_Q_to_R_rads_T = -1 * np.arctan2( + m_rotation_T[1, 0], m_rotation_T[0, 0] + ) + if np.abs( + np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi + ) > (np.pi * 0.5): + rotation_Q_to_R_rads_T = ( + np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi + ) - tf = AffineTransform(angle=rotation_Q_to_R_rads_T) - rotated_shifts_T = tf(transposed_shifts, xp=xp).T.ravel() - aberrations_coefs_T, res_T = xp.linalg.lstsq( - gradients, rotated_shifts_T, rcond=None - )[:2] + tf_T = AffineTransform(angle=rotation_Q_to_R_rads_T) + rotated_shifts_T = tf_T(transposed_shifts, xp=xp).T.ravel() + aberrations_coefs_T, res_T = xp.linalg.lstsq( + gradients, rotated_shifts_T, rcond=None + )[:2] + + if force_transpose is None or force_transpose is False: + # Untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] + + if force_transpose is None: + # Compare fits + if res_T.sum() < res.sum(): + self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T + self.transpose_detected = True + self._aberrations_coefs = asnumpy(aberrations_coefs_T) + self._rotated_shifts = rotated_shifts_T + + warnings.warn( + ( + "Data transpose detected. " + f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" + ), + UserWarning, + ) + else: + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts - if res_T.sum() < res.sum(): + elif force_transpose is True: self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = True self._aberrations_coefs = asnumpy(aberrations_coefs_T) self._rotated_shifts = rotated_shifts_T - warnings.warn( - ( - "Data transpose detected. " - f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" - ), - UserWarning, - ) else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 374e3fc15..d29765d04 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1549,10 +1549,11 @@ def aberrations_basis_function( xp=np, ): """ """ - mn = [[0, 0, 0]] - # mn = [] - for m in range(max_radial_order + 1): + # Add constant phase shift in basis + mn = [[-1, 0, 0]] + + for m in range(1, max_radial_order): n_max = np.minimum(max_angular_order, m + 1) for n in range(0, n_max + 1): if (m + n) % 2: @@ -1585,6 +1586,7 @@ def aberrations_basis_function( # Aberration basis aberrations_basis = xp.ones((alpha.size, aberrations_num)) + # Skip constant to avoid dividing by zero in normalization for a0 in range(1, aberrations_num): m, n, a = aberrations_mn[a0] if n == 0: From faac2b20d37947bb7a53ba8f200ebd2745a40bf6 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Sun, 22 Oct 2023 12:43:57 +0100 Subject: [PATCH 105/176] versions emdfile --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c3cbbd151..a6b1ce061 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ "gdown >= 4.7.1", "dask >= 2.3.0", "distributed >= 2.3.0", - "emdfile >= 0.0.13", + "emdfile >= 0.0.14", "mpire >= 2.7.1", "threadpoolctl >= 3.1.0", ], From 9529945ee4f899277b6338a5e9de484235e88b70 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Sun, 22 Oct 2023 11:39:59 -0700 Subject: [PATCH 106/176] added force_transpose option for other two aberration fit methods --- py4DSTEM/process/phase/iterative_parallax.py | 49 +++++++++----------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index dcfd8f504..6ebb9962e 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1341,7 +1341,18 @@ def aberration_fit( ### First pass # Convert real space shifts to Angstroms - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + if force_transpose is None: + self.transpose_detected = False + else: + self.transpose_detected = force_transpose + + if force_transpose is True: + self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( + self._scan_sampling + ) + else: + self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) # Solve affine transformation m = asnumpy( @@ -1362,11 +1373,6 @@ def aberration_fit( self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if force_transpose is None: - self.transpose_detected = False - else: - self.transpose_detected = force_transpose - ### Second pass # Aberration coefs @@ -1590,8 +1596,15 @@ def score_CTF(coefs): ) ) - if force_transpose is None or force_transpose is True: - # Transposed fit + # (Relative) untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] + + if force_transpose is None: + # (Relative) transposed fit transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) m_T = asnumpy( xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ @@ -1615,19 +1628,10 @@ def score_CTF(coefs): gradients, rotated_shifts_T, rcond=None )[:2] - if force_transpose is None or force_transpose is False: - # Untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() - aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None - )[:2] - - if force_transpose is None: # Compare fits if res_T.sum() < res.sum(): self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = True + self.transpose_detected = not self.transpose_detected self._aberrations_coefs = asnumpy(aberrations_coefs_T) self._rotated_shifts = rotated_shifts_T @@ -1638,15 +1642,6 @@ def score_CTF(coefs): ), UserWarning, ) - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts - - elif force_transpose is True: - self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self._aberrations_coefs = asnumpy(aberrations_coefs_T) - self._rotated_shifts = rotated_shifts_T - else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts From 43220d0b68705ec8d59ad94b7ff68446d7738255 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sun, 22 Oct 2023 12:47:16 -0700 Subject: [PATCH 107/176] read-write device bugfix --- .../process/phase/iterative_base_class.py | 51 ++++++++++++++++++- py4DSTEM/process/phase/iterative_dpc.py | 4 +- py4DSTEM/process/phase/iterative_parallax.py | 4 +- 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 906e9add1..04cfd6a60 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -56,6 +56,53 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self + def reinitialize_parameters(self, device: str = None, verbose: bool = None): + """ + Reinitializes common parameters. This is useful when loading a previously-saved + reconstruction (which set device='cpu' and verbose=True for compatibility) , + using different initialization parameters. + + Parameters + ---------- + device: str, optional + If not None, imports and assigns appropriate device modules + verbose: bool, optional + If not None, sets the verbosity to verbose + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if device is not None: + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self._device = device + + if verbose is not None: + self._verbose = verbose + + return self + def set_save_defaults( self, save_datacube: bool = False, @@ -1408,10 +1455,10 @@ def _get_constructor_args(cls, group): "object_type": instance_md["object_type"], "semiangle_cutoff": instance_md["semiangle_cutoff"], "rolloff": instance_md["rolloff"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], "polar_parameters": polar_params, + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } class_specific_kwargs = {} diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 4ca2c170f..af3cbbb45 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -195,9 +195,9 @@ def _get_constructor_args(cls, group): "datacube": dc, "initial_object_guess": np.asarray(obj), "energy": instance_md["energy"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 6ebb9962e..74688fa0b 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -206,10 +206,10 @@ def _get_constructor_args(cls, group): kwargs = { "datacube": dc, "energy": instance_md["energy"], - "verbose": instance_md["verbose"], - "device": instance_md["device"], "object_padding_px": instance_md["object_padding_px"], "name": instance_md["name"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs From 2493514195980705180ea597929a667fdcdf14e4 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 03:25:12 -0400 Subject: [PATCH 108/176] bugfix to update_version.py --- .github/scripts/update_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/update_version.py b/.github/scripts/update_version.py index 2aaaa07af..635cf8268 100644 --- a/.github/scripts/update_version.py +++ b/.github/scripts/update_version.py @@ -8,7 +8,7 @@ lines = f.readlines() line_split = lines[0].split(".") -patch_number = line_split[2].split("'")[0] +patch_number = line_split[2].split("'")[0].split('"')[0] # Increment patch number patch_number = str(int(patch_number) + 1) + "'" From 18c1b8c900e4edd94ff80abd8a0c4da431edc10c Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 03:28:44 -0400 Subject: [PATCH 109/176] Manually update version.py --- py4DSTEM/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 224f1fb74..141826d55 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.4" +__version__ = "0.14.5" From b134f848f7b502a776ecd4c4257404c23257b27b Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 12:03:06 +0100 Subject: [PATCH 110/176] bugfix --- py4DSTEM/visualize/vis_special.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index cfa017299..ba0ee024a 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -156,16 +156,16 @@ def show_amorphous_ring_fit( mask=np.logical_not(mask), mask_color="empty", returnfig=True, - returnclipvals=True, + return_intensity_range=True, **kwargs, ) show( fit, scaling=scaling, figax=(fig, ax), - clipvals="manual", - min=vmin, - max=vmax, + intensity_range="absolute", + vmin=vmin, + vmax=vmax, cmap=cmap_fit, mask=mask, mask_color="empty", From e76bd2a39e3a29524021db2adf5f2dc269067bbb Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 12:11:49 +0100 Subject: [PATCH 111/176] rms import * --- py4DSTEM/process/strain/__init__.py | 12 +++++++++++- py4DSTEM/process/strain/latticevectors.py | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/strain/__init__.py b/py4DSTEM/process/strain/__init__.py index b47682aa4..213d5e812 100644 --- a/py4DSTEM/process/strain/__init__.py +++ b/py4DSTEM/process/strain/__init__.py @@ -1,2 +1,12 @@ from py4DSTEM.process.strain.strain import StrainMap -from py4DSTEM.process.strain.latticevectors import * +from py4DSTEM.process.strain.latticevectors import ( + index_bragg_directions, + add_indices_to_braggvectors, + fit_lattice_vectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_strain_from_reference_g1g2, + get_rotated_strain_map, + +) + diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py index 26c8d66a5..30e5cc989 100644 --- a/py4DSTEM/process/strain/latticevectors.py +++ b/py4DSTEM/process/strain/latticevectors.py @@ -456,3 +456,4 @@ def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data return rotated_strain_map + From 78050bd7ea010306f7f1c257e6a09af2b001e452 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 13:58:29 +0100 Subject: [PATCH 112/176] bugfixes to strainmapping --- py4DSTEM/process/strain/strain.py | 155 ++++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 40 deletions(-) diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 47545c04b..d0848182f 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import numpy as np -from py4DSTEM import PointList +from py4DSTEM import PointList, PointListArray, tqdmnd from py4DSTEM.braggvectors import BraggVectors from py4DSTEM.data import Data, RealSlice from py4DSTEM.preprocess.utils import get_maxima_2D @@ -73,7 +73,7 @@ def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap" assert self.calstate["center"], "braggvectors must be centered" if self.calstate["rotate"] == False: warnings.warn( - ("Real to reciprocal space rotaiton not calibrated"), + ("Real to reciprocal space rotation not calibrated"), UserWarning, ) @@ -98,6 +98,19 @@ def braggvectors(self, x): self._braggvectors = x self._braggvectors.tree(self, force=True) + @property + def rshape(self): + return self._braggvectors.Rshape + + @property + def qshape(self): + return self._braggvectors.Qshape + + @property + def origin(self): + return self.calibration.get_origin_mean() + + def reset_calstate(self): """ Resets the calibration state. This recomputes the BVM, and removes any computations @@ -117,9 +130,9 @@ def reset_calstate(self): def choose_lattice_vectors( self, - index_g0, - index_g1, - index_g2, + index_g1 = None, + index_g2 = None, + index_origin = None, subpixel="multicorr", upsample_factor=16, sigma=0, @@ -155,12 +168,12 @@ def choose_lattice_vectors( Parameters ---------- - index_g0 : int - selected index for the origin index_g1 : int selected index for g1 index_g2 :int selected index for g2 + index_origin : int + selected index for the origin subpixel : str in ('pixel','poly','multicorr') See the docstring for py4DSTEM.preprocess.get_maxima_2D upsample_factor : int @@ -211,8 +224,8 @@ def choose_lattice_vectors( (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter """ # validate inputs - for i in (index_g0, index_g1, index_g2): - assert isinstance(i, (int, np.integer)), "indices must be integers!" + for i in (index_origin, index_g1, index_g2): + assert(isinstance(i, (int, np.integer)) or (i is None)), "indices must be integers!" # check the calstate assert ( self.calstate == self.braggvectors.calstate @@ -233,31 +246,43 @@ def choose_lattice_vectors( maxNumPeaks=maxNumPeaks, ) + # guess the origin and g1 g2 vectors if indices aren't provided + if np.any([x is None for x in (index_g1,index_g2,index_origin)]): + + # get distances and angles from calibrated origin + g_dists = np.hypot(g['x']-self.origin[0], g['y']-self.origin[1]) + g_angles = np.angle(g['x']-self.origin[0] + 1j*(g['y']-self.origin[1])) + + # guess the origin + if index_origin is None: + index_origin = np.argmin(g_dists) + g_dists[index_origin] = 2*np.max(g_dists) + + # guess g1 + if index_g1 is None: + index_g1 = np.argmin(g_dists) + g_dists[index_g1] = 2*np.max(g_dists) + + # guess g2 + if index_g2 is None: + angle_scaling = np.cos(g_angles - g_angles[index_g1])**2 + index_g2 = np.argmin(g_dists*(angle_scaling+0.1)) + + # get the lattice vectors gx, gy = g["x"], g["y"] - g0 = gx[index_g0], gy[index_g0] + g0 = gx[index_origin], gy[index_origin] g1x = gx[index_g1] - g0[0] g1y = gy[index_g1] - g0[1] g2x = gx[index_g2] - g0[0] g2y = gy[index_g2] - g0[1] g1, g2 = (g1x, g1y), (g2x, g2y) - # if x0 is None: - # x0 = self.braggvectors.Qshape[0] / 2 - # if y0 is None: - # y0 = self.braggvectors.Qshape[0] / 2 - - # index braggvectors - # _, _, braggdirections = index_bragg_directions( - # x0, y0, g["x"], g["y"], g1, g2 - # ) - + # index the lattice vectors _, _, braggdirections = index_bragg_directions( g0[0], g0[1], g["x"], g["y"], g1, g2 ) - self.braggdirections = braggdirections - # make the figure fig, ax = plt.subplots(1, 3, figsize=figsize) show(self.bvm.data, figax=(fig, ax[0]), **vis_params) @@ -274,12 +299,12 @@ def choose_lattice_vectors( # Add indices to left panel d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} d0 = { - "x": gx[index_g0], - "y": gy[index_g0], + "x": gx[index_origin], + "y": gy[index_origin], "size": size_indices, "color": c0, "fontweight": "bold", - "labels": [str(index_g0)], + "labels": [str(index_origin)], } d1 = { "x": gx[index_g1], @@ -304,8 +329,8 @@ def choose_lattice_vectors( # Add vectors to right panel dg1 = { - "x0": gx[index_g0], - "y0": gy[index_g0], + "x0": gx[index_origin], + "y0": gy[index_origin], "vx": g1[0], "vy": g1[1], "width": width_vectors, @@ -315,8 +340,8 @@ def choose_lattice_vectors( "labelcolor": c_vectorlabels, } dg2 = { - "x0": gx[index_g0], - "y0": gy[index_g0], + "x0": gx[index_origin], + "y0": gy[index_origin], "vx": g2[0], "vy": g2[1], "width": width_vectors, @@ -334,6 +359,11 @@ def choose_lattice_vectors( self.g1 = g1 self.g2 = g2 + # center the bragg directions and store + braggdirections.data['qx'] -= self.origin[0] + braggdirections.data['qy'] -= self.origin[1] + self.braggdirections = braggdirections + # return if returncalc and returnfig: return (g0, g1, g2), (fig, ax) @@ -381,23 +411,68 @@ def fit_lattice_vectors( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - bragg_vectors_indexed = add_indices_to_braggvectors( - self.braggvectors, - self.braggdirections, - maxPeakSpacing=max_peak_spacing, - qx_shift=self.braggvectors.Qshape[0] / 2, - qy_shift=self.braggvectors.Qshape[1] / 2, - mask=mask, - ) - self.bragg_vectors_indexed = bragg_vectors_indexed + ### add indices to the bragg vectors - # fit bragg vectors + # validate mask + if mask is None: + mask = np.ones(self.braggvectors.Rshape, dtype=bool) + assert ( + mask.shape == self.braggvectors.Rshape + ), "mask must have same shape as pointlistarray" + assert mask.dtype == bool, "mask must be boolean" + + # set up new braggpeaks PLA + indexed_braggpeaks = PointListArray( + dtype = [ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ], + shape=self.braggvectors.Rshape, + ) + calstate = self.braggvectors.calstate + + # loop over all the scan positions + for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + if mask[Rx, Ry]: + pl = self.braggvectors.get_vectors( + Rx, + Ry, + center=True, + ellipse=calstate["ellipse"], + rotate=calstate["rotate"], + pixel=False, + ) + for i in range(pl.data.shape[0]): + r = np.hypot( + pl.data["qx"][i]-self.braggdirections.data["qx"], + pl.data["qy"][i]-self.braggdirections.data["qy"] + ) + ind = np.argmin(r) + if r[ind] <= max_peak_spacing: + indexed_braggpeaks[Rx, Ry].add_data_by_field( + ( + pl.data["qx"][i], + pl.data["qy"][i], + pl.data["intensity"][i], + self.braggdirections.data["h"][ind], + self.braggdirections.data["k"][ind], + ) + ) + self.bragg_vectors_indexed = indexed_braggpeaks + + + ### fit bragg vectors g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) self.g1g2_map = g1g2_map + + # return if returncalc: - self.braggdirections, self.bragg_vectors_indexed, self.g1g2_map + return self.braggdirections, self.bragg_vectors_indexed, self.g1g2_map def get_strain( self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs From 85e214174c3c993f41092ae9471f1317d0fb13fc Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 16:09:41 +0100 Subject: [PATCH 113/176] strain map updates --- py4DSTEM/process/strain/strain.py | 362 ++++++++++++++++++++++++++++-- py4DSTEM/visualize/vis_special.py | 301 ------------------------- 2 files changed, 341 insertions(+), 322 deletions(-) diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index d0848182f..cd662b8f1 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -4,6 +4,7 @@ from typing import Optional import matplotlib.pyplot as plt +from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np from py4DSTEM import PointList, PointListArray, tqdmnd from py4DSTEM.braggvectors import BraggVectors @@ -18,6 +19,7 @@ index_bragg_directions, ) from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show +from py4DSTEM.visualize import ax_addaxes, ax_addaxes_QtoR warnings.simplefilter(action="always", category=UserWarning) @@ -110,6 +112,12 @@ def qshape(self): def origin(self): return self.calibration.get_origin_mean() + @property + def mask(self): + try: + return self.g1g2_map['mask'].data.astype('bool') + except: + return np.ones(self.rshape, dtype=bool) def reset_calstate(self): """ @@ -366,9 +374,9 @@ def choose_lattice_vectors( # return if returncalc and returnfig: - return (g0, g1, g2), (fig, ax) + return (self.g0, self.g1, self.g2, self.braggdirections), (fig, ax) elif returncalc: - return (g0, g1, g2) + return (self.g0, self.g1, self.g2, self.braggdirections) elif returnfig: return (fig, ax) else: @@ -378,8 +386,6 @@ def fit_lattice_vectors( self, max_peak_spacing=2, mask=None, - plot=True, - vis_params={}, returncalc=False, ): """ @@ -388,10 +394,6 @@ def fit_lattice_vectors( reciprocal lattice directions. Args: - x0 : floagt - x-coord of origin - y0 : float - y-coord of origin max_peak_spacing: float Maximum distance from the ideal lattice points to include a peak for indexing @@ -399,10 +401,6 @@ def fit_lattice_vectors( Boolean mask, same shape as the pointlistarray, indicating which locations should be indexed. This can be used to index different regions of the scan with different lattices - plot:bool - plot results if tru - vis_params : dict - additional visualization parameters passed to `show` returncalc : bool if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map """ @@ -472,7 +470,7 @@ def fit_lattice_vectors( # return if returncalc: - return self.braggdirections, self.bragg_vectors_indexed, self.g1g2_map + return self.bragg_vectors_indexed, self.g1g2_map def get_strain( self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs @@ -497,8 +495,8 @@ def get_strain( ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." if mask is None: - mask = np.ones(self.g1g2_map.shape, dtype="bool") - + mask = self.mask + #mask = np.ones(self.g1g2_map.shape, dtype="bool") # strainmap_g1g2 = get_strain_from_reference_region( # self.g1g2_map, # mask=mask, @@ -524,9 +522,12 @@ def get_strain( flip_theta=flip_theta, ) - self.strainmap_rotated = strainmap_rotated - - from py4DSTEM.visualize import show_strain + self.data[0] = strainmap_rotated['e_xx'].data + self.data[1] = strainmap_rotated['e_yy'].data + self.data[2] = strainmap_rotated['e_xy'].data + self.data[3] = strainmap_rotated['theta'].data + self.data[4] = strainmap_rotated['mask'].data + self.g_reference = g_reference figsize = kwargs.pop("figsize", (14, 4)) vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) @@ -535,8 +536,7 @@ def get_strain( bkgrd = kwargs.pop("bkgrd", False) axes_plots = kwargs.pop("axes_plots", ()) - fig, ax = show_strain( - self.strainmap_rotated, + fig, ax = self.show_strain( vrange_exx=vrange_exx, vrange_theta=vrange_theta, ticknumber=ticknumber, @@ -554,7 +554,327 @@ def get_strain( ax[1][1].imshow(mask, alpha=0.2, cmap="binary") if returncalc: - return self.strainmap_rotated + return self.strainmap + + + def show_strain( + self, + vrange_exx, + vrange_theta, + vrange_exy=None, + vrange_eyy=None, + flip_theta=False, + bkgrd=True, + show_cbars=("exx", "eyy", "exy", "theta"), + bordercolor="k", + borderwidth=1, + titlesize=24, + ticklabelsize=16, + ticknumber=5, + unitlabelsize=24, + show_axes=False, + axes_position = (0,0), + axes_length=10, + axes_width=1, + axes_color="w", + xaxis_space="Q", + labelaxes=True, + QR_rotation=0, + axes_labelsize=12, + axes_labelcolor="r", + axes_plots=("exx"), + cmap="RdBu_r", + mask_color = 'k', + layout=0, + figsize=(12, 12), + returnfig=False, + ): + """ + Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and + masking each image with strainmap.get_slice('mask') + + Args: + vrange_exx (length 2 list or tuple): + vrange_theta (length 2 list or tuple): + vrange_exy (length 2 list or tuple): + vrange_eyy (length 2 list or tuple): + flip_theta (bool): if True, take negative of angle + bkgrd (bool): + show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a + tuple containing any, all, or none of ('exx','eyy','exy','theta'). + bordercolor (color): + borderwidth (number): + titlesize (number): + ticklabelsize (number): + ticknumber (number): number of ticks on colorbars + unitlabelsize (number): + show_axes (bool): + axes_x0 (number): + axes_y0 (number): + xaxis_x (number): + xaxis_y (number): + axes_length (number): + axes_width (number): + axes_color (color): + xaxis_space (string): must be 'Q' or 'R' + labelaxes (bool): + QR_rotation (number): + axes_labelsize (number): + axes_labelcolor (color): + axes_plots (tuple of strings): controls if coordinate axes showing the + orientation of the strain matrices are overlaid over any of the plots. + Must be a tuple of strings containing any, all, or none of + ('exx','eyy','exy','theta'). + cmap (colormap): + layout=0 (int): determines the layout of the grid which the strain components + will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). + figsize (length 2 tuple of numbers): + returnfig (bool): + """ + # Lookup table for different layouts + assert layout in (0, 1, 2) + layout_lookup = { + 0: ["left", "right", "left", "right"], + 1: ["bottom", "bottom", "bottom", "bottom"], + 2: ["right", "right", "right", "right"], + } + layout_p = layout_lookup[layout] + + # Contrast limits + if vrange_exy is None: + vrange_exy = vrange_exx + if vrange_eyy is None: + vrange_eyy = vrange_exx + for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): + assert len(vrange) == 2, "vranges must have length 2" + vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 + vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 + vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 + # theta is plotted in units of degrees + vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( + 180.0 / np.pi + ) + + # Get images + e_xx = np.ma.array( + self.get_slice("exx").data, mask=self.get_slice("mask").data == False + ) + e_yy = np.ma.array( + self.get_slice("eyy").data, mask=self.get_slice("mask").data == False + ) + e_xy = np.ma.array( + self.get_slice("exy").data, mask=self.get_slice("mask").data == False + ) + theta = np.ma.array( + self.get_slice("theta").data, + mask=self.get_slice("mask").data == False, + ) + if flip_theta == True: + theta = -theta + + ## Plot + + # modify the figsize according to the image aspect ratio + ratio = np.sqrt(self.rshape[1]/self.rshape[0]) + figsize_mean = np.mean(figsize) + figsize = (figsize_mean*ratio, figsize_mean/ratio) + + # set up layout + if layout == 0: + fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) + elif layout == 1: + figsize = (figsize[0]*np.sqrt(2),figsize[1]/np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) + else: + figsize = (figsize[0]/np.sqrt(2),figsize[1]*np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) + + # display images, returning cbar axis references + cax11 = show( + e_xx, + figax=(fig, ax11), + vmin=vmin_exx, + vmax=vmax_exx, + intensity_range="absolute", + cmap=cmap, + mask = self.mask, + mask_color = mask_color, + returncax=True, + ) + cax12 = show( + e_yy, + figax=(fig, ax12), + vmin=vmin_eyy, + vmax=vmax_eyy, + intensity_range="absolute", + cmap=cmap, + mask = self.mask, + mask_color = mask_color, + returncax=True, + ) + cax21 = show( + e_xy, + figax=(fig, ax21), + vmin=vmin_exy, + vmax=vmax_exy, + intensity_range="absolute", + cmap=cmap, + mask = self.mask, + mask_color = mask_color, + returncax=True, + ) + cax22 = show( + theta, + figax=(fig, ax22), + vmin=vmin_theta, + vmax=vmax_theta, + intensity_range="absolute", + cmap=cmap, + mask = self.mask, + mask_color = mask_color, + returncax=True, + ) + ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) + ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) + ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) + ax22.set_title(r"$\theta$", size=titlesize) + + # Add black background + if bkgrd: + mask = np.ma.masked_where( + self.get_slice("mask").data.astype(bool), + np.zeros_like(self.get_slice("mask").data), + ) + ax11.matshow(mask, cmap="gray") + ax12.matshow(mask, cmap="gray") + ax21.matshow(mask, cmap="gray") + ax22.matshow(mask, cmap="gray") + + # add colorbars + show_cbars = np.array( + [ + "exx" in show_cbars, + "eyy" in show_cbars, + "exy" in show_cbars, + "theta" in show_cbars, + ] + ) + if np.any(show_cbars): + divider11 = make_axes_locatable(ax11) + divider12 = make_axes_locatable(ax12) + divider21 = make_axes_locatable(ax21) + divider22 = make_axes_locatable(ax22) + cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) + cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) + cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) + cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) + for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( + range(4), + show_cbars, + (cax11, cax12, cax21, cax22), + (cbax11, cbax12, cbax21, cbax22), + (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), + (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), + (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), + ("% ", " %", "% ", r" $^\circ$"), + ): + if show_cbar: + ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) + if ind < 3: + ticklabels = np.round( + np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), + decimals=2, + ).astype(str) + else: + ticklabels = np.round( + np.linspace( + (180 / np.pi) * vmin, + (180 / np.pi) * vmax, + ticknumber, + endpoint=True, + ), + decimals=2, + ).astype(str) + + if tickside in ("left", "right"): + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="vertical" + ) + cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) + cbax.yaxis.set_ticks_position(tickside) + cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) + cbax.yaxis.set_label_position(tickside) + else: + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="horizontal" + ) + cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) + cbax.xaxis.set_ticks_position(tickside) + cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) + cbax.xaxis.set_label_position(tickside) + else: + cbax.axis("off") + + # Add coordinate axes + if show_axes: + assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" + show_which_axes = np.array( + [ + "exx" in axes_plots, + "eyy" in axes_plots, + "exy" in axes_plots, + "theta" in axes_plots, + ] + ) + for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): + if _show: + if xaxis_space == "R": + ax_addaxes( + _ax, + self.g_reference[0], + self.g_reference[1], + axes_length, + axes_position[0], + axes_position[1], + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) + else: + ax_addaxes_QtoR( + _ax, + self.g_reference[0], + self.g_reference[1], + axes_length, + axes_position[0], + axes_position[1], + QR_rotation, + width=axes_width, + color=axes_color, + labelaxes=labelaxes, + labelsize=axes_labelsize, + labelcolor=axes_labelcolor, + ) + + # Add borders + if bordercolor is not None: + for ax in (ax11, ax12, ax21, ax22): + for s in ["bottom", "top", "left", "right"]: + ax.spines[s].set_color(bordercolor) + ax.spines[s].set_linewidth(borderwidth) + ax.set_xticks([]) + ax.set_yticks([]) + + if not returnfig: + plt.show() + return + else: + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs + + def show_lattice_vectors( ar, diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index ba0ee024a..8612d65b8 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -404,307 +404,6 @@ def show_class_BPs_grid( return fig, axs -def show_strain( - strainmap, - vrange_exx, - vrange_theta, - vrange_exy=None, - vrange_eyy=None, - flip_theta=False, - bkgrd=True, - show_cbars=("exx", "eyy", "exy", "theta"), - bordercolor="k", - borderwidth=1, - titlesize=24, - ticklabelsize=16, - ticknumber=5, - unitlabelsize=24, - show_axes=True, - axes_x0=0, - axes_y0=0, - xaxis_x=1, - xaxis_y=0, - axes_length=10, - axes_width=1, - axes_color="r", - xaxis_space="Q", - labelaxes=True, - QR_rotation=0, - axes_labelsize=12, - axes_labelcolor="r", - axes_plots=("exx"), - cmap="RdBu_r", - layout=0, - figsize=(12, 12), - returnfig=False, -): - """ - Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') - - Args: - strainmap (RealSlice): - vrange_exx (length 2 list or tuple): - vrange_theta (length 2 list or tuple): - vrange_exy (length 2 list or tuple): - vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle - bkgrd (bool): - show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a - tuple containing any, all, or none of ('exx','eyy','exy','theta'). - bordercolor (color): - borderwidth (number): - titlesize (number): - ticklabelsize (number): - ticknumber (number): number of ticks on colorbars - unitlabelsize (number): - show_axes (bool): - axes_x0 (number): - axes_y0 (number): - xaxis_x (number): - xaxis_y (number): - axes_length (number): - axes_width (number): - axes_color (color): - xaxis_space (string): must be 'Q' or 'R' - labelaxes (bool): - QR_rotation (number): - axes_labelsize (number): - axes_labelcolor (color): - axes_plots (tuple of strings): controls if coordinate axes showing the - orientation of the strain matrices are overlaid over any of the plots. - Must be a tuple of strings containing any, all, or none of - ('exx','eyy','exy','theta'). - cmap (colormap): - layout=0 (int): determines the layout of the grid which the strain components - will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). - figsize (length 2 tuple of numbers): - returnfig (bool): - """ - # Lookup table for different layouts - assert layout in (0, 1, 2) - layout_lookup = { - 0: ["left", "right", "left", "right"], - 1: ["bottom", "bottom", "bottom", "bottom"], - 2: ["right", "right", "right", "right"], - } - layout_p = layout_lookup[layout] - - # Contrast limits - if vrange_exy is None: - vrange_exy = vrange_exx - if vrange_eyy is None: - vrange_eyy = vrange_exx - for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): - assert len(vrange) == 2, "vranges must have length 2" - vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 - vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 - vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 - # theta is plotted in units of degrees - vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( - 180.0 / np.pi - ) - - # Get images - e_xx = np.ma.array( - strainmap.get_slice("e_xx").data, mask=strainmap.get_slice("mask").data == False - ) - e_yy = np.ma.array( - strainmap.get_slice("e_yy").data, mask=strainmap.get_slice("mask").data == False - ) - e_xy = np.ma.array( - strainmap.get_slice("e_xy").data, mask=strainmap.get_slice("mask").data == False - ) - theta = np.ma.array( - strainmap.get_slice("theta").data, - mask=strainmap.get_slice("mask").data == False, - ) - if flip_theta == True: - theta = -theta - - # Plot - if layout == 0: - fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) - elif layout == 1: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) - else: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) - cax11 = show( - e_xx, - figax=(fig, ax11), - vmin=vmin_exx, - vmax=vmax_exx, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax12 = show( - e_yy, - figax=(fig, ax12), - vmin=vmin_eyy, - vmax=vmax_eyy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax21 = show( - e_xy, - figax=(fig, ax21), - vmin=vmin_exy, - vmax=vmax_exy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax22 = show( - theta, - figax=(fig, ax22), - vmin=vmin_theta, - vmax=vmax_theta, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) - ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) - ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) - ax22.set_title(r"$\theta$", size=titlesize) - - # Add black background - if bkgrd: - mask = np.ma.masked_where( - strainmap.get_slice("mask").data.astype(bool), - np.zeros_like(strainmap.get_slice("mask").data), - ) - ax11.matshow(mask, cmap="gray") - ax12.matshow(mask, cmap="gray") - ax21.matshow(mask, cmap="gray") - ax22.matshow(mask, cmap="gray") - - # Colorbars - show_cbars = np.array( - [ - "exx" in show_cbars, - "eyy" in show_cbars, - "exy" in show_cbars, - "theta" in show_cbars, - ] - ) - if np.any(show_cbars): - divider11 = make_axes_locatable(ax11) - divider12 = make_axes_locatable(ax12) - divider21 = make_axes_locatable(ax21) - divider22 = make_axes_locatable(ax22) - cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) - cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) - cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) - cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) - for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( - range(4), - show_cbars, - (cax11, cax12, cax21, cax22), - (cbax11, cbax12, cbax21, cbax22), - (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), - (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), - (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), - ("% ", " %", "% ", r" $^\circ$"), - ): - if show_cbar: - ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) - if ind < 3: - ticklabels = np.round( - np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), - decimals=2, - ).astype(str) - else: - ticklabels = np.round( - np.linspace( - (180 / np.pi) * vmin, - (180 / np.pi) * vmax, - ticknumber, - endpoint=True, - ), - decimals=2, - ).astype(str) - - if tickside in ("left", "right"): - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="vertical" - ) - cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) - cbax.yaxis.set_ticks_position(tickside) - cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) - cbax.yaxis.set_label_position(tickside) - else: - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="horizontal" - ) - cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) - cbax.xaxis.set_ticks_position(tickside) - cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) - cbax.xaxis.set_label_position(tickside) - else: - cbax.axis("off") - - # Add coordinate axes - if show_axes: - assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array( - [ - "exx" in axes_plots, - "eyy" in axes_plots, - "exy" in axes_plots, - "theta" in axes_plots, - ] - ) - for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): - if _show: - if xaxis_space == "R": - ax_addaxes( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - else: - ax_addaxes_QtoR( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - QR_rotation, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - - # Add borders - if bordercolor is not None: - for ax in (ax11, ax12, ax21, ax22): - for s in ["bottom", "top", "left", "right"]: - ax.spines[s].set_color(bordercolor) - ax.spines[s].set_linewidth(borderwidth) - ax.set_xticks([]) - ax.set_yticks([]) - - if not returnfig: - plt.show() - return - else: - axs = ((ax11, ax12), (ax21, ax22)) - return fig, axs - def show_pointlabels( ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs From ce0e1da615dbf5e09f07eb72272c6e7e4df65351 Mon Sep 17 00:00:00 2001 From: Steven Zeltmann Date: Mon, 23 Oct 2023 11:35:31 -0400 Subject: [PATCH 114/176] add pyright to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6c008b0ff..24587a3b3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.swp *.ipynb_checkpoints* .vscode/ +pyrightconfig.json # Folders # .idea/ From 143f373a9e99564505f17e1394d177e052fc1038 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 16:39:35 +0100 Subject: [PATCH 115/176] autoformats --- py4DSTEM/process/latticevectors/index.py | 2 - py4DSTEM/process/strain/__init__.py | 2 - py4DSTEM/process/strain/latticevectors.py | 1 - py4DSTEM/process/strain/strain.py | 90 +++++++++++------------ 4 files changed, 44 insertions(+), 51 deletions(-) diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py index 2d243cd0c..03cdf07ce 100644 --- a/py4DSTEM/process/latticevectors/index.py +++ b/py4DSTEM/process/latticevectors/index.py @@ -108,8 +108,6 @@ def generate_lattice(ux, uy, vx, vy, x0, y0, Q_Nx, Q_Ny, h_max=None, k_max=None) return ideal_lattice - - def bragg_vector_intensity_map_by_index(braggpeaks, h, k, symmetric=False): """ Returns a correlation intensity map for an indexed (h,k) Bragg vector diff --git a/py4DSTEM/process/strain/__init__.py b/py4DSTEM/process/strain/__init__.py index 213d5e812..b487c916b 100644 --- a/py4DSTEM/process/strain/__init__.py +++ b/py4DSTEM/process/strain/__init__.py @@ -7,6 +7,4 @@ get_reference_g1g2, get_strain_from_reference_g1g2, get_rotated_strain_map, - ) - diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py index 30e5cc989..26c8d66a5 100644 --- a/py4DSTEM/process/strain/latticevectors.py +++ b/py4DSTEM/process/strain/latticevectors.py @@ -456,4 +456,3 @@ def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data return rotated_strain_map - diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index cd662b8f1..538c90825 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -115,7 +115,7 @@ def origin(self): @property def mask(self): try: - return self.g1g2_map['mask'].data.astype('bool') + return self.g1g2_map["mask"].data.astype("bool") except: return np.ones(self.rshape, dtype=bool) @@ -138,9 +138,9 @@ def reset_calstate(self): def choose_lattice_vectors( self, - index_g1 = None, - index_g2 = None, - index_origin = None, + index_g1=None, + index_g2=None, + index_origin=None, subpixel="multicorr", upsample_factor=16, sigma=0, @@ -233,7 +233,9 @@ def choose_lattice_vectors( """ # validate inputs for i in (index_origin, index_g1, index_g2): - assert(isinstance(i, (int, np.integer)) or (i is None)), "indices must be integers!" + assert isinstance(i, (int, np.integer)) or ( + i is None + ), "indices must be integers!" # check the calstate assert ( self.calstate == self.braggvectors.calstate @@ -255,27 +257,27 @@ def choose_lattice_vectors( ) # guess the origin and g1 g2 vectors if indices aren't provided - if np.any([x is None for x in (index_g1,index_g2,index_origin)]): - + if np.any([x is None for x in (index_g1, index_g2, index_origin)]): # get distances and angles from calibrated origin - g_dists = np.hypot(g['x']-self.origin[0], g['y']-self.origin[1]) - g_angles = np.angle(g['x']-self.origin[0] + 1j*(g['y']-self.origin[1])) + g_dists = np.hypot(g["x"] - self.origin[0], g["y"] - self.origin[1]) + g_angles = np.angle( + g["x"] - self.origin[0] + 1j * (g["y"] - self.origin[1]) + ) # guess the origin if index_origin is None: index_origin = np.argmin(g_dists) - g_dists[index_origin] = 2*np.max(g_dists) + g_dists[index_origin] = 2 * np.max(g_dists) # guess g1 if index_g1 is None: index_g1 = np.argmin(g_dists) - g_dists[index_g1] = 2*np.max(g_dists) + g_dists[index_g1] = 2 * np.max(g_dists) # guess g2 if index_g2 is None: - angle_scaling = np.cos(g_angles - g_angles[index_g1])**2 - index_g2 = np.argmin(g_dists*(angle_scaling+0.1)) - + angle_scaling = np.cos(g_angles - g_angles[index_g1]) ** 2 + index_g2 = np.argmin(g_dists * (angle_scaling + 0.1)) # get the lattice vectors gx, gy = g["x"], g["y"] @@ -368,8 +370,8 @@ def choose_lattice_vectors( self.g2 = g2 # center the bragg directions and store - braggdirections.data['qx'] -= self.origin[0] - braggdirections.data['qy'] -= self.origin[1] + braggdirections.data["qx"] -= self.origin[0] + braggdirections.data["qy"] -= self.origin[1] self.braggdirections = braggdirections # return @@ -409,7 +411,6 @@ def fit_lattice_vectors( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - ### add indices to the bragg vectors # validate mask @@ -422,7 +423,7 @@ def fit_lattice_vectors( # set up new braggpeaks PLA indexed_braggpeaks = PointListArray( - dtype = [ + dtype=[ ("qx", float), ("qy", float), ("intensity", float), @@ -446,8 +447,8 @@ def fit_lattice_vectors( ) for i in range(pl.data.shape[0]): r = np.hypot( - pl.data["qx"][i]-self.braggdirections.data["qx"], - pl.data["qy"][i]-self.braggdirections.data["qy"] + pl.data["qx"][i] - self.braggdirections.data["qx"], + pl.data["qy"][i] - self.braggdirections.data["qy"], ) ind = np.argmin(r) if r[ind] <= max_peak_spacing: @@ -462,12 +463,10 @@ def fit_lattice_vectors( ) self.bragg_vectors_indexed = indexed_braggpeaks - ### fit bragg vectors g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) self.g1g2_map = g1g2_map - # return if returncalc: return self.bragg_vectors_indexed, self.g1g2_map @@ -496,7 +495,7 @@ def get_strain( if mask is None: mask = self.mask - #mask = np.ones(self.g1g2_map.shape, dtype="bool") + # mask = np.ones(self.g1g2_map.shape, dtype="bool") # strainmap_g1g2 = get_strain_from_reference_region( # self.g1g2_map, # mask=mask, @@ -522,11 +521,11 @@ def get_strain( flip_theta=flip_theta, ) - self.data[0] = strainmap_rotated['e_xx'].data - self.data[1] = strainmap_rotated['e_yy'].data - self.data[2] = strainmap_rotated['e_xy'].data - self.data[3] = strainmap_rotated['theta'].data - self.data[4] = strainmap_rotated['mask'].data + self.data[0] = strainmap_rotated["e_xx"].data + self.data[1] = strainmap_rotated["e_yy"].data + self.data[2] = strainmap_rotated["e_xy"].data + self.data[3] = strainmap_rotated["theta"].data + self.data[4] = strainmap_rotated["mask"].data self.g_reference = g_reference figsize = kwargs.pop("figsize", (14, 4)) @@ -556,7 +555,6 @@ def get_strain( if returncalc: return self.strainmap - def show_strain( self, vrange_exx, @@ -573,7 +571,7 @@ def show_strain( ticknumber=5, unitlabelsize=24, show_axes=False, - axes_position = (0,0), + axes_position=(0, 0), axes_length=10, axes_width=1, axes_color="w", @@ -584,7 +582,7 @@ def show_strain( axes_labelcolor="r", axes_plots=("exx"), cmap="RdBu_r", - mask_color = 'k', + mask_color="k", layout=0, figsize=(12, 12), returnfig=False, @@ -675,18 +673,18 @@ def show_strain( ## Plot # modify the figsize according to the image aspect ratio - ratio = np.sqrt(self.rshape[1]/self.rshape[0]) + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) figsize_mean = np.mean(figsize) - figsize = (figsize_mean*ratio, figsize_mean/ratio) + figsize = (figsize_mean * ratio, figsize_mean / ratio) # set up layout if layout == 0: fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) elif layout == 1: - figsize = (figsize[0]*np.sqrt(2),figsize[1]/np.sqrt(2)) + figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) else: - figsize = (figsize[0]/np.sqrt(2),figsize[1]*np.sqrt(2)) + figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) # display images, returning cbar axis references @@ -697,8 +695,8 @@ def show_strain( vmax=vmax_exx, intensity_range="absolute", cmap=cmap, - mask = self.mask, - mask_color = mask_color, + mask=self.mask, + mask_color=mask_color, returncax=True, ) cax12 = show( @@ -708,8 +706,8 @@ def show_strain( vmax=vmax_eyy, intensity_range="absolute", cmap=cmap, - mask = self.mask, - mask_color = mask_color, + mask=self.mask, + mask_color=mask_color, returncax=True, ) cax21 = show( @@ -719,8 +717,8 @@ def show_strain( vmax=vmax_exy, intensity_range="absolute", cmap=cmap, - mask = self.mask, - mask_color = mask_color, + mask=self.mask, + mask_color=mask_color, returncax=True, ) cax22 = show( @@ -730,8 +728,8 @@ def show_strain( vmax=vmax_theta, intensity_range="absolute", cmap=cmap, - mask = self.mask, - mask_color = mask_color, + mask=self.mask, + mask_color=mask_color, returncax=True, ) ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) @@ -782,7 +780,9 @@ def show_strain( ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) if ind < 3: ticklabels = np.round( - np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), + np.linspace( + 100 * vmin, 100 * vmax, ticknumber, endpoint=True + ), decimals=2, ).astype(str) else: @@ -874,8 +874,6 @@ def show_strain( axs = ((ax11, ax12), (ax21, ax22)) return fig, axs - - def show_lattice_vectors( ar, x0, From c926d8e21a85f8a4bd3b7a602596aebbb88afb26 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 17:06:17 +0100 Subject: [PATCH 116/176] autoformats --- py4DSTEM/visualize/vis_special.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 8612d65b8..d1efbd023 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -404,7 +404,6 @@ def show_class_BPs_grid( return fig, axs - def show_pointlabels( ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs ): From 3048ebf49eb8d2d0d4d93bda9cf07e3348ed352b Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 18:06:13 +0100 Subject: [PATCH 117/176] rms deprecated latticevectors module --- py4DSTEM/process/__init__.py | 1 - py4DSTEM/process/latticevectors/__init__.py | 3 - py4DSTEM/process/latticevectors/fit.py | 71 ------ py4DSTEM/process/latticevectors/index.py | 147 ----------- .../process/latticevectors/initialguess.py | 229 ------------------ 5 files changed, 451 deletions(-) delete mode 100644 py4DSTEM/process/latticevectors/__init__.py delete mode 100644 py4DSTEM/process/latticevectors/fit.py delete mode 100644 py4DSTEM/process/latticevectors/index.py delete mode 100644 py4DSTEM/process/latticevectors/initialguess.py diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index 0df11ef01..6f0019019 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,7 +1,6 @@ from py4DSTEM.process.polar import PolarDatacube from py4DSTEM.process.strain.strain import StrainMap -from py4DSTEM.process import latticevectors from py4DSTEM.process import phase from py4DSTEM.process import calibration from py4DSTEM.process import utils diff --git a/py4DSTEM/process/latticevectors/__init__.py b/py4DSTEM/process/latticevectors/__init__.py deleted file mode 100644 index cda4f91e5..000000000 --- a/py4DSTEM/process/latticevectors/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from py4DSTEM.process.latticevectors.initialguess import * -from py4DSTEM.process.latticevectors.index import * -from py4DSTEM.process.latticevectors.fit import * diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py deleted file mode 100644 index d36b10bca..000000000 --- a/py4DSTEM/process/latticevectors/fit.py +++ /dev/null @@ -1,71 +0,0 @@ -# Functions for fitting lattice vectors to measured Bragg peak positions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray -from py4DSTEM.data import RealSlice - - -def fit_lattice_vectors_masked(braggpeaks, mask, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks corresponding - to a scan position for which mask==True. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - mask (boolean array): real space shaped (R_Nx,R_Ny); fit lattice vectors where - mask is True - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((braggpeaks.shape[0], braggpeaks.shape[1], 8)), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - if mask[Rx, Ry]: - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - return g1g2_map diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py deleted file mode 100644 index 03cdf07ce..000000000 --- a/py4DSTEM/process/latticevectors/index.py +++ /dev/null @@ -1,147 +0,0 @@ -# Functions for indexing the Bragg directions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray - - -def get_selected_lattice_vectors(gx, gy, i0, i1, i2): - """ - From a set of reciprocal lattice points (gx,gy), and indices in those arrays which - specify the center beam, the first basis lattice vector, and the second basis lattice - vector, computes and returns the lattice vectors g1 and g2. - - Args: - gx (1d array): the reciprocal lattice points x-coords - gy (1d array): the reciprocal lattice points y-coords - i0 (int): index in the (gx,gy) arrays specifying the center beam - i1 (int): index in the (gx,gy) arrays specifying the first basis lattice vector - i2 (int): index in the (gx,gy) arrays specifying the second basis lattice vector - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing - - * **g1**: *(2-tuple)* the first lattice vector, (g1x,g1y) - * **g2**: *(2-tuple)* the second lattice vector, (g2x,g2y) - """ - for i in (i0, i1, i2): - assert isinstance(i, (int, np.integer)) - g1x = gx[i1] - gx[i0] - g1y = gy[i1] - gy[i0] - g2x = gx[i2] - gx[i0] - g2y = gy[i2] - gy[i0] - return (g1x, g1y), (g2x, g2y) - - -def generate_lattice(ux, uy, vx, vy, x0, y0, Q_Nx, Q_Ny, h_max=None, k_max=None): - """ - Returns a full reciprocal lattice stretching to the limits of the diffraction pattern - by making linear combinations of the lattice vectors up to (±h_max,±k_max). - - This can be useful when there are false peaks or missing peaks in the braggvectormap, - which can cause errors in the strain finding routines that rely on those peaks for - indexing. This allows us to create a reference lattice that has all combinations of - the lattice vectors all the way out to the edges of the frame, and excluding any - erroneous intermediate peaks. - - Args: - ux (float): x-coord of the u lattice vector - uy (float): y-coord of the u lattice vector - vx (float): x-coord of the v lattice vector - vy (float): y-coord of the v lattice vector - x0 (float): x-coord of the lattice origin - y0 (float): y-coord of the lattice origin - Q_Nx (int): diffraction pattern size in the x-direction - Q_Ny (int): diffraction pattern size in the y-direction - h_max, k_max (int): maximal indices for generating the lattice (the lattive is - always trimmed to fit inside the pattern so you can overestimate these, or - leave unspecified and they will be automatically found) - - Returns: - (PointList): A 4-coordinate PointList, ('qx','qy','h','k'), containing points - corresponding to linear combinations of the u and v vectors, with associated - indices - """ - - # Matrix of lattice vectors - beta = np.array([[ux, uy], [vx, vy]]) - - # If no max index is specified, (over)estimate based on image size - if (h_max is None) or (k_max is None): - (y, x) = np.mgrid[0:Q_Ny, 0:Q_Nx] - x = x - x0 - y = y - y0 - h_max = np.max(np.ceil(np.abs((x / ux, y / uy)))) - k_max = np.max(np.ceil(np.abs((x / vx, y / vy)))) - - (hlist, klist) = np.meshgrid( - np.arange(-h_max, h_max + 1), np.arange(-k_max, k_max + 1) - ) - - M_ideal = np.vstack((hlist.ravel(), klist.ravel())).T - ideal_peaks = np.matmul(M_ideal, beta) - - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - - ideal_data = np.zeros(len(ideal_peaks[:, 0]), dtype=coords) - ideal_data["qx"] = ideal_peaks[:, 0] - ideal_data["qy"] = ideal_peaks[:, 1] - ideal_data["h"] = M_ideal[:, 0] - ideal_data["k"] = M_ideal[:, 1] - - ideal_lattice = PointList(data=ideal_data) - - # shift to the DP center - ideal_lattice.data["qx"] += x0 - ideal_lattice.data["qy"] += y0 - - # trim peaks outside the image - deletePeaks = ( - (ideal_lattice.data["qx"] > Q_Nx) - | (ideal_lattice.data["qx"] < 0) - | (ideal_lattice.data["qy"] > Q_Ny) - | (ideal_lattice.data["qy"] < 0) - ) - ideal_lattice.remove(deletePeaks) - - return ideal_lattice - - -def bragg_vector_intensity_map_by_index(braggpeaks, h, k, symmetric=False): - """ - Returns a correlation intensity map for an indexed (h,k) Bragg vector - Used to obtain a darkfield image corresponding to the (h,k) reflection - or a bightfield image when h=k=0 - - Args: - braggpeaks (PointListArray): must contain the coordinates 'h','k', and - 'intensity' - h, k (int): indices for the reflection to generate an intensity map from - symmetric (bool): if set to true, returns sum of intensity of (h,k), (-h,k), - (h,-k), (-h,-k) - - Returns: - (numpy array): a map of the intensity of the (h,k) Bragg vector correlation. - Same shape as the pointlistarray. - """ - assert isinstance(braggpeaks, PointListArray), "braggpeaks must be a PointListArray" - assert np.all([name in braggpeaks.dtype.names for name in ("h", "k", "intensity")]) - intensity_map = np.zeros(braggpeaks.shape, dtype=float) - - for Rx in range(braggpeaks.shape[0]): - for Ry in range(braggpeaks.shape[1]): - pl = braggpeaks.get_pointlist(Rx, Ry) - if pl.length > 0: - if symmetric: - matches = np.logical_and( - np.abs(pl.data["h"]) == np.abs(h), - np.abs(pl.data["k"]) == np.abs(k), - ) - else: - matches = np.logical_and(pl.data["h"] == h, pl.data["k"] == k) - - if len(matches) > 0: - intensity_map[Rx, Ry] = np.sum(pl.data["intensity"][matches]) - - return intensity_map diff --git a/py4DSTEM/process/latticevectors/initialguess.py b/py4DSTEM/process/latticevectors/initialguess.py deleted file mode 100644 index d8054143f..000000000 --- a/py4DSTEM/process/latticevectors/initialguess.py +++ /dev/null @@ -1,229 +0,0 @@ -# Obtain an initial guess at the lattice vectors - -import numpy as np -from scipy.ndimage import gaussian_filter -from skimage.transform import radon - -from py4DSTEM.process.utils import get_maxima_1D - - -def get_radon_scores( - braggvectormap, - mask=None, - N_angles=200, - sigma=2, - minSpacing=2, - minRelativeIntensity=0.05, -): - """ - Calculates a score function, score(angle), representing the likelihood that angle is - a principle lattice direction of the lattice in braggvectormap. - - The procedure is as follows: - If mask is not None, ignore any data in braggvectormap where mask is False. Useful - for removing the unscattered beam, which can dominate the results. - Take the Radon transform of the (masked) Bragg vector map. - For each angle, get the corresponding slice of the sinogram, and calculate its score. - If we let R_theta(r) be the sinogram slice at angle theta, and where r is the - sinogram position coordinate, then the score of the slice is given by - score(theta) = sum_i(R_theta(r_i)) / N_i - Here, r_i are the positions r of all local maxima in R_theta(r), and N_i is the - number of such maxima. Thus the score is large when there are few maxima which are - high intensity. - - Args: - braggvectormap (ndarray): the Bragg vector map - mask (ndarray of bools): ignore data in braggvectormap wherever mask==False - N_angles (int): the number of angles at which to calculate the score - sigma (float): smoothing parameter for local maximum identification - minSpacing (float): if two maxima are found in a radon slice closer than - minSpacing, the dimmer of the two is removed - minRelativeIntensity (float): maxima in each radon slice dimmer than - minRelativeIntensity compared to the most intense maximum are removed - - Returns: - (3-tuple) A 3-tuple containing: - - * **scores**: *(ndarray, len N_angles, floats)* the scores for each angle - * **thetas**: *(ndarray, len N_angles, floats)* the angles, in radians - * **sinogram**: *(ndarray)* the radon transform of braggvectormap*mask - """ - # Get sinogram - thetas = np.linspace(0, 180, N_angles) - if mask is not None: - sinogram = radon(braggvectormap * mask, theta=thetas, circle=False) - else: - sinogram = radon(braggvectormap, theta=thetas, circle=False) - - # Get scores - N_maxima = np.empty_like(thetas) - total_intensity = np.empty_like(thetas) - for i in range(len(thetas)): - theta = thetas[i] - - # Get radon transform slice - ind = np.argmin(np.abs(thetas - theta)) - sinogram_theta = sinogram[:, ind] - sinogram_theta = gaussian_filter(sinogram_theta, 2) - - # Get maxima - maxima = get_maxima_1D(sinogram_theta, sigma, minSpacing, minRelativeIntensity) - - # Calculate metrics - N_maxima[i] = len(maxima) - total_intensity[i] = np.sum(sinogram_theta[maxima]) - scores = total_intensity / N_maxima - - return scores, np.radians(thetas), sinogram - - -def get_lattice_directions_from_scores( - thetas, scores, sigma=2, minSpacing=2, minRelativeIntensity=0.05, index1=0, index2=0 -): - """ - Get the lattice directions from the scores of the radon transform slices. - - Args: - thetas (ndarray): the angles, in radians - scores (ndarray): the scores - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - index1 (int): specifies which local maximum to use for the first lattice - direction, in order of maximum intensity - index2 (int): specifies the local maximum for the second lattice direction - - Returns: - (2-tuple) A 2-tuple containing: - - * **theta1**: *(float)* the first lattice direction, in radians - * **theta2**: *(float)* the second lattice direction, in radians - """ - assert len(thetas) == len(scores), "Size of thetas and scores must match" - - # Get first lattice direction - maxima1 = get_maxima_1D( - scores, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max1 = thetas[maxima1] - scores_max1 = scores[maxima1] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max1), dtype=dtype) - ar_structured["thetas"] = thetas_max1 - ar_structured["scores"] = scores_max1 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta1 = ar_structured["thetas"][index1] # Get direction 1 - - # Apply sin**2 damping - scores_damped = scores * np.sin(thetas - theta1) ** 2 - - # Get second lattice direction - maxima2 = get_maxima_1D( - scores_damped, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max2 = thetas[maxima2] - scores_max2 = scores[maxima2] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max2), dtype=dtype) - ar_structured["thetas"] = thetas_max2 - ar_structured["scores"] = scores_max2 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta2 = ar_structured["thetas"][index2] # Get direction 2 - - return theta1, theta2 - - -def get_lattice_vector_lengths( - u_theta, - v_theta, - thetas, - sinogram, - spacing_thresh=1.5, - sigma=1, - minSpacing=2, - minRelativeIntensity=0.1, -): - """ - Gets the lengths of the two lattice vectors from their angles and the sinogram. - - First, finds the spacing between peaks in the sinogram slices projected down the u- - and v- directions, u_proj and v_proj. Then, finds the lengths by taking:: - - |u| = v_proj/sin(u_theta-v_theta) - |v| = u_proj/sin(u_theta-v_theta) - - The most important thresholds for this function are spacing_thresh, which discards - any detected spacing between adjacent radon projection peaks which deviate from the - median spacing by more than this fraction, and minRelativeIntensity, which discards - detected maxima (from which spacings are then calculated) below this threshold - relative to the brightest maximum. - - Args: - u_theta (float): the angle of u, in radians - v_theta (float): the angle of v, in radians - thetas (ndarray): the angles corresponding to the sinogram - sinogram (ndarray): the sinogram - spacing_thresh (float): ignores spacings which are greater than spacing_thresh - times the median spacing - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - - Returns: - (2-tuple) A 2-tuple containing: - - * **u_length**: *(float)* the length of u, in pixels - * **v_length**: *(float)* the length of v, in pixels - """ - assert ( - len(thetas) == sinogram.shape[1] - ), "thetas must corresponding to the number of sinogram projection directions." - - # Get u projected spacing - ind = np.argmin(np.abs(thetas - u_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - u_projected_spacing = np.mean(spacings) - - # Get v projected spacing - ind = np.argmin(np.abs(thetas - v_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - v_projected_spacing = np.mean(spacings) - - # Get u and v lengths - sin_uv = np.sin(np.abs(u_theta - v_theta)) - u_length = v_projected_spacing / sin_uv - v_length = u_projected_spacing / sin_uv - - return u_length, v_length From 9d2e36c14d479c13c597026b585ffe8523de8a2c Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 18:08:29 +0100 Subject: [PATCH 118/176] rms deprecated latticevectors module --- py4DSTEM/process/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index 6f0019019..0509d181e 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -5,6 +5,5 @@ from py4DSTEM.process import calibration from py4DSTEM.process import utils from py4DSTEM.process import classification -from py4DSTEM.process import latticevectors from py4DSTEM.process import diffraction from py4DSTEM.process import wholepatternfit From 1e13932e897a5d3ee7811780ee1da0212d8320cb Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 23 Oct 2023 18:26:53 +0100 Subject: [PATCH 119/176] versions to 0.14.6 --- py4DSTEM/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 141826d55..f2c260cc1 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.5" +__version__ = "0.14.6" From 2ce737aface54dc819c70961287b06059937e543 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 24 Oct 2023 05:39:48 +0100 Subject: [PATCH 120/176] change default figsize --- py4DSTEM/visualize/show.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index fb99de5ae..4e99c0de5 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -25,7 +25,7 @@ def show( ar, - figsize=(8, 8), + figsize=(5, 5), cmap="gray", scaling="none", intensity_range="ordered", From d2b7e8ee65ce5fb127ecef287d5007a337ef39f9 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 24 Oct 2023 05:41:01 +0100 Subject: [PATCH 121/176] versions to 0.14.7 --- py4DSTEM/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index f2c260cc1..e1130451f 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.6" +__version__ = "0.14.7" From a802d29232bc3abbfd395b59e1da7a7c38924abd Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 24 Oct 2023 07:22:56 +0100 Subject: [PATCH 122/176] ellipse display bugfix --- py4DSTEM/visualize/overlay.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index 996bb89b3..4f46bdb4b 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -408,7 +408,7 @@ def add_ellipses(ax, d): (cent[1], cent[0]), 2 * _b, 2 * _a, - -np.degrees(_theta), + angle = -np.degrees(_theta), color=col, fill=f, alpha=_alpha, From d1e1435f0b119e792e2f092717655f4f06a371d2 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 24 Oct 2023 07:23:19 +0100 Subject: [PATCH 123/176] versions to 0.14.8 --- py4DSTEM/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index e1130451f..6a57210e9 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.7" +__version__ = "0.14.8" From 9a2675b37c56a6e6faa0657d4aca84b77f5b4a35 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 24 Oct 2023 09:46:36 +0100 Subject: [PATCH 124/176] autoformats --- py4DSTEM/visualize/overlay.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index 4f46bdb4b..8421dd35b 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -408,7 +408,7 @@ def add_ellipses(ax, d): (cent[1], cent[0]), 2 * _b, 2 * _a, - angle = -np.degrees(_theta), + angle=-np.degrees(_theta), color=col, fill=f, alpha=_alpha, From 1fc9c52ae199148f02663a539de68981fa035e0c Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Tue, 24 Oct 2023 15:25:02 -0700 Subject: [PATCH 125/176] silly parallax bug --- py4DSTEM/process/phase/iterative_parallax.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 74688fa0b..094209e7a 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1642,6 +1642,9 @@ def score_CTF(coefs): ), UserWarning, ) + else: + self._aberrations_coefs = asnumpy(aberrations_coefs) + self._rotated_shifts = rotated_shifts else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts From c41c386e64fc28bb8863f10c8289f94ee03c7d18 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 25 Oct 2023 15:17:51 -0700 Subject: [PATCH 126/176] some helpful deets --- py4DSTEM/process/phase/iterative_parallax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 094209e7a..21af22a37 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1789,6 +1789,7 @@ def score_CTF(coefs): ) print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + print(f"Transpose = {self.transpose_detected}") if fit_CTF_FFT or fit_BF_shifts: print() From adae309beb32903a56c415e704eae4994dfd305e Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Thu, 26 Oct 2023 15:16:51 +0100 Subject: [PATCH 127/176] adds QR_rotation vis method --- py4DSTEM/process/calibration/rotation.py | 177 +++++++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/py4DSTEM/process/calibration/rotation.py b/py4DSTEM/process/calibration/rotation.py index 2c3e7bb43..21134a352 100644 --- a/py4DSTEM/process/calibration/rotation.py +++ b/py4DSTEM/process/calibration/rotation.py @@ -2,6 +2,183 @@ import numpy as np from typing import Optional +import matplotlib.pyplot as plt +from py4DSTEM import show + + +def compare_QR_rotation( + im_R, + im_Q, + QR_rotation, + R_rotation = 0, + R_position = None, + Q_position = None, + R_pos_anchor = 'center', + Q_pos_anchor = 'center', + R_length = 0.33, + Q_length = 0.33, + R_width = 0.001, + Q_width = 0.001, + R_head_length_adjust = 1, + Q_head_length_adjust = 1, + R_head_width_adjust = 1, + Q_head_width_adjust = 1, + R_color = 'r', + Q_color = 'r', + figsize = (10,5), + returnfig = False + ): + """ + Visualize a rotational offset between an image in real space, e.g. a STEM + virtual image, and an image in diffraction space, e.g. a defocused CBED + shadow image of the same region, by displaying an arrow overlaid over each + of these two images with the specified QR rotation applied. The QR rotation + is defined as the counter-clockwise rotation from real space to diffraction + space, in degrees. + + Parameters + ---------- + im_R : numpy array or other 2D image-like object (e.g. a VirtualImage) + A real space image, e.g. a STEM virtual image + im_Q : numpy array or other 2D image-like object + A diffraction space image, e.g. a defocused CBED image + QR_rotation : number + The counterclockwise rotation from real space to diffraction space, + in degrees + R_rotation : number + The orientation of the arrow drawn in real space, in degrees + R_position : None or 2-tuple + The position of the anchor point for the R-space arrow. If None, defaults + to the center of the image + Q_position : None or 2-tuple + The position of the anchor point for the Q-space arrow. If None, defaults + to the center of the image + R_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the R-space arrow, i.e. the point being specified by + the `R_position` parameter + Q_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the Q-space arrow, i.e. the point being specified by + the `Q_position` parameter + R_length : number or None + The length of the R-space arrow, as a fraction of the mean size of the + image + Q_length : number or None + The length of the Q-space arrow, as a fraction of the mean size of the + image + R_width : number + The width of the R-space arrow + Q_width : number + The width of the R-space arrow + R_head_length_adjust : number + Scaling factor for the R-space arrow head length + Q_head_length_adjust : number + Scaling factor for the Q-space arrow head length + R_head_width_adjust : number + Scaling factor for the R-space arrow head width + Q_head_width_adjust : number + Scaling factor for the Q-space arrow head width + R_color : color + Color of the R-space arrow + Q_color : color + Color of the Q-space arrow + figsize : 2-tuple + The figure size + returnfig : bool + Toggles returning the figure and axes + """ + # parse inputs + if R_position is None: + R_position = ( + im_R.shape[0]/2, + im_R.shape[1]/2, + ) + if Q_position is None: + Q_position = ( + im_Q.shape[0]/2, + im_Q.shape[1]/2, + ) + R_length = np.mean(im_R.shape) * R_length + Q_length = np.mean(im_Q.shape) * Q_length + assert R_pos_anchor in ('center','tail','head') + assert Q_pos_anchor in ('center','tail','head') + + # compute positions + rpos_x,rpos_y = R_position + qpos_x,qpos_y = Q_position + R_rot_rad = np.radians(R_rotation) + Q_rot_rad = np.radians(R_rotation+QR_rotation) + rvecx = np.cos(R_rot_rad) + rvecy = np.sin(R_rot_rad) + qvecx = np.cos(Q_rot_rad) + qvecy = np.sin(Q_rot_rad) + if R_pos_anchor == 'center': + x0_r = rpos_x - rvecx*R_length/2 + y0_r = rpos_y - rvecy*R_length/2 + x1_r = rpos_x + rvecx*R_length/2 + y1_r = rpos_y + rvecy*R_length/2 + elif R_pos_anchor == 'tail': + x0_r = rpos_x + y0_r = rpos_y + x1_r = rpos_x + rvecx*R_length + y1_r = rpos_y + rvecy*R_length + elif R_pos_anchor == 'head': + x0_r = rpos_x - rvecx*R_length + y0_r = rpos_y - rvecy*R_length + x1_r = rpos_x + y1_r = rpos_y + else: + raise Exception(f"Invalid value for R_pos_anchor {R_pos_anchor}") + if Q_pos_anchor == 'center': + x0_q = qpos_x - qvecx*Q_length/2 + y0_q = qpos_y - qvecy*Q_length/2 + x1_q = qpos_x + qvecx*Q_length/2 + y1_q = qpos_y + qvecy*Q_length/2 + elif Q_pos_anchor == 'tail': + x0_q = qpos_x + y0_q = qpos_y + x1_q = qpos_x + qvecx*Q_length + y1_q = qpos_y + qvecy*Q_length + elif Q_pos_anchor == 'head': + x0_q = qpos_x - qvecx*Q_length + y0_q = qpos_y - qvecy*Q_length + x1_q = qpos_x + y1_q = qpos_y + else: + raise Exception(f"Invalid value for Q_pos_anchor {Q_pos_anchor}") + + # make the figure + axsize = (figsize[0]/2,figsize[1]) + fig,axs = show( + [im_R,im_Q], + returnfig = True, + axsize = axsize + ) + axs[0,0].arrow( + x = y0_r, + y = x0_r, + dx = y1_r - y0_r, + dy = x1_r - x0_r, + color = R_color, + length_includes_head = True, + width = R_width, + head_width = R_length*R_head_width_adjust*0.072, + head_length = R_length*R_head_length_adjust*0.1 + ) + axs[0,1].arrow( + x = y0_q, + y = x0_q, + dx = y1_q - y0_q, + dy = x1_q - x0_q, + color = Q_color, + length_includes_head = True, + width = Q_width, + head_width = Q_length*Q_head_width_adjust*0.072, + head_length = Q_length*Q_head_length_adjust*0.1 + ) + if returnfig: + return fig,axs + else: + plt.show() def get_Qvector_from_Rvector(vx, vy, QR_rotation): From 257862ae97d77d8e83b6def4b20e45e63f6f22fc Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Thu, 26 Oct 2023 15:22:09 +0100 Subject: [PATCH 128/176] adds QR_rot in rad and degrees in Calibration --- py4DSTEM/data/calibration.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py index a31f098d4..ffdbfa410 100644 --- a/py4DSTEM/data/calibration.py +++ b/py4DSTEM/data/calibration.py @@ -205,6 +205,7 @@ def __init__( self["R_pixel_size"] = 1 self["Q_pixel_units"] = "pixels" self["R_pixel_units"] = "pixels" + self["QR_flip"] = False # EMD root property @property @@ -666,8 +667,17 @@ def ellipse(self, x): # Q/R-space rotation and flip + @call_calibrate + def set_QR_rotation(self, x): + self._params["QR_rotation"] = x + self._params["QR_rotation_degrees"] = np.degrees(x) + + def get_QR_rotation(self): + return self._get_value("QR_rotation") + @call_calibrate def set_QR_rotation_degrees(self, x): + self._params["QR_rotation"] = np.radians(x) self._params["QR_rotation_degrees"] = x def get_QR_rotation_degrees(self): @@ -689,10 +699,31 @@ def set_QR_rotflip(self, rot_flip): flip (bool): True indicates a Q/R axes flip """ rot, flip = rot_flip + self._params["QR_rotation"] = rot + self._params["QR_rotation_degrees"] = np.degrees(rot) + self._params["QR_flip"] = flip + + @call_calibrate + def set_QR_rotflip_degrees(self, rot_flip): + """ + Args: + rot_flip (tuple), (rot, flip) where: + rot (number): rotation in degrees + flip (bool): True indicates a Q/R axes flip + """ + rot, flip = rot_flip + self._params["QR_rotation"] = np.radians(rot) self._params["QR_rotation_degrees"] = rot self._params["QR_flip"] = flip def get_QR_rotflip(self): + rot = self.get_QR_rotation() + flip = self.get_QR_flip() + if rot is None or flip is None: + return None + return (rot, flip) + + def get_QR_rotflip_degrees(self): rot = self.get_QR_rotation_degrees() flip = self.get_QR_flip() if rot is None or flip is None: From 6695d8e677fea067775c742571faffad060b8f2e Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Thu, 26 Oct 2023 16:06:19 +0100 Subject: [PATCH 129/176] adds QR_rot in rad and degrees in Calibration --- py4DSTEM/braggvectors/braggvectors.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index 45c08b9c9..3a9ccb1ea 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -200,7 +200,7 @@ def setcal( if pixel is None: pixel = False if c.get_Q_pixel_size() == 1 else True if rotate is None: - rotate = False if c.get_QR_rotflip() is None else True + rotate = False if c.get_QR_rotation() is None else True # validate requested state if center: @@ -210,7 +210,7 @@ def setcal( if pixel: assert c.get_Q_pixel_size() is not None, "Requested calibration not found" if rotate: - assert c.get_QR_rotflip() is not None, "Requested calibration not found" + assert c.get_QR_rotation() is not None, "Requested calibration not found" # set the calibrations self._calstate = { @@ -478,10 +478,10 @@ def _transform( # Q/R rotation if rotate: - flip = cal.get_QR_flip() - theta = cal.get_QR_rotation_degrees() - assert flip is not None, "Requested calibration was not found!" + theta = cal.get_QR_rotation() assert theta is not None, "Requested calibration was not found!" + flip = cal.get_QR_flip() + flip = False if flip is None else flip # rotation matrix R = np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] From 20bc04157e1f48ea5f359ae3c36fd0854b40fc94 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 26 Oct 2023 11:23:56 -0700 Subject: [PATCH 130/176] middle focus for multislice --- .../iterative_multislice_ptychography.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 4515590fe..77a5c69ea 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -81,9 +81,11 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): Probe positions in Å for each diffraction intensity If None, initialized to a grid scan theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in angles) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in angles) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -117,6 +119,7 @@ def __init__( initial_scan_positions: np.ndarray = None, theta_x: float = 0, theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", verbose: bool = True, device: str = "cpu", @@ -150,6 +153,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: From 3ec7b0bc425c2a20094c3d96237ed7df8fa19818 Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 27 Oct 2023 13:47:28 -0700 Subject: [PATCH 131/176] Fixing scale bar being plotted as slightly too long --- py4DSTEM/visualize/overlay.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index 8421dd35b..3bec9eaee 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -832,7 +832,14 @@ def add_scalebar(ax, d): labelpos_y = y0 # Add line - ax.plot((yi, yf), (xi, xf), lw=width, color=color, alpha=alpha) + ax.plot( + (yi, yf), + (xi, xf), + color=color, + alpha=alpha, + lw=width, + solid_capstyle = 'butt', + ) # Add label if label: From 77762d4f1386e49327640361fbf49fdd2324479e Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 27 Oct 2023 13:53:03 -0700 Subject: [PATCH 132/176] Add option to skip calculating correlation array in crystal.orientation_plot() --- py4DSTEM/process/diffraction/crystal_ACOM.py | 129 +++++++++---------- 1 file changed, 64 insertions(+), 65 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 5722f3f38..78aa577ea 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -29,6 +29,7 @@ def orientation_plan( corr_kernel_size: float = 0.08, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling + calculate_correlation_array = True, tol_peak_delete=None, tol_distance: float = 0.01, fiber_axis=None, @@ -61,6 +62,8 @@ def orientation_plan( corr_kernel_size (float): Correlation kernel size length in Angstroms radial_power (float): Power for scaling the correlation intensity as a function of the peak radius intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity + calculate_correlation_array (bool): Set to false to skip calculating the correlation array. + This is useful when we only want the angular range / rotation matrices. tol_peak_delete (float): Distance to delete peaks for multiple matches. Default is kernel_size * 0.5 tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] @@ -598,21 +601,6 @@ def orientation_plan( # init storage arrays self.orientation_rotation_angles = np.zeros((self.orientation_num_zones, 3)) self.orientation_rotation_matrices = np.zeros((self.orientation_num_zones, 3, 3)) - self.orientation_ref = np.zeros( - ( - self.orientation_num_zones, - np.size(self.orientation_shell_radii), - self.orientation_in_plane_steps, - ), - dtype="float", - ) - # self.orientation_ref_1D = np.zeros( - # ( - # self.orientation_num_zones, - # np.size(self.orientation_shell_radii), - # ), - # dtype="float", - # ) # If possible, Get symmetry operations for this spacegroup, store in matrix form if self.pymatgen_available: @@ -697,65 +685,76 @@ def orientation_plan( k0 = np.array([0.0, 0.0, -1.0 / self.wavelength]) n = np.array([0.0, 0.0, -1.0]) - for a0 in tqdmnd( - np.arange(self.orientation_num_zones), - desc="Orientation plan", - unit=" zone axes", - disable=not progress_bar, - ): - # reciprocal lattice spots and excitation errors - g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all - sg = self.excitation_errors(g) - - # Keep only points that will contribute to this orientation plan slice - keep = np.abs(sg) < self.orientation_kernel_size - - # in-plane rotation angle - phi = np.arctan2(g[1, :], g[0, :]) - - # Loop over all peaks - for a1 in np.arange(self.g_vec_all.shape[1]): - ind_radial = self.orientation_shell_index[a1] + if calculate_correlation_array: + # initialize empty correlation array + self.orientation_ref = np.zeros( + ( + self.orientation_num_zones, + np.size(self.orientation_shell_radii), + self.orientation_in_plane_steps, + ), + dtype="float", + ) - if keep[a1] and ind_radial >= 0: - # 2D orientation plan - self.orientation_ref[a0, ind_radial, :] += ( - np.power(self.orientation_shell_radii[ind_radial], radial_power) - * np.power(self.struct_factors_int[a1], intensity_power) - * np.maximum( - 1 - - np.sqrt( - sg[a1] ** 2 - + ( - ( - np.mod( - self.orientation_gamma - phi[a1] + np.pi, - 2 * np.pi, + for a0 in tqdmnd( + np.arange(self.orientation_num_zones), + desc="Orientation plan", + unit=" zone axes", + disable=not progress_bar, + ): + # reciprocal lattice spots and excitation errors + g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all + sg = self.excitation_errors(g) + + # Keep only points that will contribute to this orientation plan slice + keep = np.abs(sg) < self.orientation_kernel_size + + # in-plane rotation angle + phi = np.arctan2(g[1, :], g[0, :]) + + # Loop over all peaks + for a1 in np.arange(self.g_vec_all.shape[1]): + ind_radial = self.orientation_shell_index[a1] + + if keep[a1] and ind_radial >= 0: + # 2D orientation plan + self.orientation_ref[a0, ind_radial, :] += ( + np.power(self.orientation_shell_radii[ind_radial], radial_power) + * np.power(self.struct_factors_int[a1], intensity_power) + * np.maximum( + 1 + - np.sqrt( + sg[a1] ** 2 + + ( + ( + np.mod( + self.orientation_gamma - phi[a1] + np.pi, + 2 * np.pi, + ) + - np.pi ) - - np.pi + * self.orientation_shell_radii[ind_radial] ) - * self.orientation_shell_radii[ind_radial] + ** 2 ) - ** 2 + / self.orientation_kernel_size, + 0, ) - / self.orientation_kernel_size, - 0, ) - ) - orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) - if orientation_ref_norm > 0: - self.orientation_ref[a0, :, :] /= orientation_ref_norm + orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) + if orientation_ref_norm > 0: + self.orientation_ref[a0, :, :] /= orientation_ref_norm - # Maximum value - self.orientation_ref_max = np.max(np.real(self.orientation_ref)) + # Maximum value + self.orientation_ref_max = np.max(np.real(self.orientation_ref)) - # Fourier domain along angular axis - if self.CUDA: - self.orientation_ref = cp.asarray(self.orientation_ref) - self.orientation_ref = cp.conj(cp.fft.fft(self.orientation_ref)) - else: - self.orientation_ref = np.conj(np.fft.fft(self.orientation_ref)) + # Fourier domain along angular axis + if self.CUDA: + self.orientation_ref = cp.asarray(self.orientation_ref) + self.orientation_ref = cp.conj(cp.fft.fft(self.orientation_ref)) + else: + self.orientation_ref = np.conj(np.fft.fft(self.orientation_ref)) def match_orientations( From 0c84848e9f1152e0441a4aa228cbb7bfdd1ea4fa Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 27 Oct 2023 13:54:39 -0700 Subject: [PATCH 133/176] Black formatting --- py4DSTEM/process/diffraction/crystal_ACOM.py | 2 +- py4DSTEM/visualize/overlay.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 78aa577ea..60c284263 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -29,7 +29,7 @@ def orientation_plan( corr_kernel_size: float = 0.08, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling - calculate_correlation_array = True, + calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, fiber_axis=None, diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index 3bec9eaee..32baff443 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -833,12 +833,12 @@ def add_scalebar(ax, d): # Add line ax.plot( - (yi, yf), - (xi, xf), - color=color, + (yi, yf), + (xi, xf), + color=color, alpha=alpha, lw=width, - solid_capstyle = 'butt', + solid_capstyle="butt", ) # Add label From 958642e25a56b84a28e4b6390c5550948794a747 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Fri, 27 Oct 2023 15:05:18 -0700 Subject: [PATCH 134/176] adding assert statement --- py4DSTEM/process/diffraction/crystal_ACOM.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 60c284263..0c01dc70d 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -900,9 +900,13 @@ def match_single_pattern( Orientation class containing all outputs fig, ax: handles Figure handles for the plotting output - """ + """ + + # adding assert statement for checking self.orientation_ref is present + assert hasattr( + self, "orientation_ref" + ), "orientation_plan must be run with 'calculate_correlation_array=True'" - # init orientation output orientation = Orientation(num_matches=num_matches_return) if bragg_peaks.data.shape[0] < min_number_peaks: return orientation From 329034299cb3f676233a9117d343d3238afaa611 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Fri, 27 Oct 2023 15:09:17 -0700 Subject: [PATCH 135/176] black --- py4DSTEM/process/diffraction/crystal_ACOM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 0c01dc70d..cba68d8fb 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -900,7 +900,7 @@ def match_single_pattern( Orientation class containing all outputs fig, ax: handles Figure handles for the plotting output - """ + """ # adding assert statement for checking self.orientation_ref is present assert hasattr( From f9e2423084daea1602142f5d869ae983f6d4f4e5 Mon Sep 17 00:00:00 2001 From: gvarnavi Date: Fri, 27 Oct 2023 15:18:06 -0700 Subject: [PATCH 136/176] re-introducing probe intensity normalizations into constraints --- .../iterative_ptychographic_constraints.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 3eebdb068..0760087b4 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -433,8 +433,8 @@ def _probe_amplitude_constraint( xp = self._xp erf = self._erf - # probe_intensity = xp.abs(current_probe) ** 2 - # current_probe_sum = xp.sum(probe_intensity) + probe_intensity = xp.abs(current_probe) ** 2 + current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] Y = xp.fft.fftfreq(current_probe.shape[1])[None] @@ -444,10 +444,10 @@ def _probe_amplitude_constraint( tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_fourier_amplitude_constraint( self, @@ -476,7 +476,7 @@ def _probe_fourier_amplitude_constraint( xp = self._xp asnumpy = self._asnumpy - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) updated_probe_fft, _, _, _ = regularize_probe_amplitude( @@ -489,10 +489,10 @@ def _probe_fourier_amplitude_constraint( updated_probe_fft = xp.asarray(updated_probe_fft) updated_probe = xp.fft.ifft2(updated_probe_fft) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aperture_constraint( self, @@ -514,16 +514,16 @@ def _probe_aperture_constraint( """ xp = self._xp - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) updated_probe = xp.fft.ifft2( xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture ) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aberration_fitting_constraint( self, From 210ef749662ee7edc9bee88f1abece294c8f1037 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Sat, 28 Oct 2023 09:30:40 +0100 Subject: [PATCH 137/176] origin plotting and strain legend updates --- py4DSTEM/braggvectors/braggvector_methods.py | 148 ++++--- py4DSTEM/process/strain/strain.py | 385 ++++++++++++------- 2 files changed, 337 insertions(+), 196 deletions(-) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 267f81e5f..3ca898609 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -7,6 +7,7 @@ from emdfile import Array, Metadata, tqdmnd, _read_metadata from py4DSTEM.datacube import VirtualImage +from py4DSTEM import show class BraggVectorMethods: @@ -518,6 +519,7 @@ def fit_origin( mask_check_data=True, plot=True, plot_range=None, + cmap = 'RdBu_r', returncalc=True, **kwargs, ): @@ -537,6 +539,7 @@ def fit_origin( mask_check_data (bool): Get mask from origin measurements equal to zero. (TODO - replace) plot (bool, optional): plot results plot_range (float): min and max color range for plot (pixels) + cmap (colormap): plotting colormap Returns: (variable): Return value depends on returnfitp. If ``returnfitp==False`` @@ -561,75 +564,106 @@ def fit_origin( else: qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(tuple(q_meas)) - # try to add to calibration + # try to add update calibration metadata try: - self.calibration.set_origin([qx0_fit, qy0_fit]) + self.calibration.set_origin((qx0_fit, qy0_fit)) + self.setcal() except AttributeError: warn( "No calibration found on this datacube - fit values are not being stored" ) pass - if plot: - from py4DSTEM.visualize import show_image_grid - if mask is None: - qx0_meas, qy0_meas = q_meas - qx0_res_plot = qx0_residuals - qy0_res_plot = qy0_residuals - else: - qx0_meas = np.ma.masked_array(q_meas[0], mask=np.logical_not(mask)) - qy0_meas = np.ma.masked_array(q_meas[1], mask=np.logical_not(mask)) - qx0_res_plot = np.ma.masked_array( - qx0_residuals, mask=np.logical_not(mask) - ) - qy0_res_plot = np.ma.masked_array( - qy0_residuals, mask=np.logical_not(mask) - ) - qx0_mean = np.mean(qx0_fit) - qy0_mean = np.mean(qy0_fit) - - if plot_range is None: - plot_range = 2 * np.max(qx0_fit - qx0_mean) - - cmap = kwargs.get("cmap", "RdBu_r") - kwargs.pop("cmap", None) - axsize = kwargs.get("axsize", (6, 2)) - kwargs.pop("axsize", None) - - show_image_grid( - lambda i: [ - qx0_meas - qx0_mean, - qx0_fit - qx0_mean, - qx0_res_plot, - qy0_meas - qy0_mean, - qy0_fit - qy0_mean, - qy0_res_plot, - ][i], - H=2, - W=3, - cmap=cmap, - axsize=axsize, - title=[ - "measured origin, x", - "fitorigin, x", - "residuals, x", - "measured origin, y", - "fitorigin, y", - "residuals, y", - ], - vmin=-1 * plot_range, - vmax=1 * plot_range, - intensity_range="absolute", - **kwargs, + # show + if plot: + self.show_origin_fit( + q_meas[0], + q_meas[1], + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask = mask, + plot_range = plot_range, + cmap = cmap, + **kwargs ) - # update calibration metadata - self.calibration.set_origin((qx0_fit, qy0_fit)) - self.setcal() - + # return if returncalc: return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals + + def show_origin_fit( + self, + qx0_meas, + qy0_meas, + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask = None, + plot_range = None, + cmap = 'RdBu_r', + **kwargs + ): + + # apply mask + if mask is not None: + qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask)) + qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask)) + qx0_residuals = np.ma.masked_array( + qx0_residuals, mask=np.logical_not(mask) + ) + qy0_residuals = np.ma.masked_array( + qy0_residuals, mask=np.logical_not(mask) + ) + qx0_mean = np.mean(qx0_fit) + qy0_mean = np.mean(qy0_fit) + + # set range + if plot_range is None: + plot_range = max(( + 1.5 * np.max(np.abs(qx0_fit - qx0_mean)), + 1.5 * np.max(np.abs(qy0_fit - qy0_mean)) + )) + + # set figsize + imsize_ratio = np.sqrt(qx0_meas.shape[1]/qx0_meas.shape[0]) + axsize = (3*imsize_ratio, 3/imsize_ratio) + + # plot + show( + [[qx0_meas - qx0_mean, + qx0_fit - qx0_mean, + qx0_residuals + ],[ + qy0_meas - qy0_mean, + qy0_fit - qy0_mean, + qy0_residuals + ]], + cmap = cmap, + axsize = axsize, + title = [ + "measured origin, x", + "fitorigin, x", + "residuals, x", + "measured origin, y", + "fitorigin, y", + "residuals, y", + ], + vmin = -1 * plot_range, + vmax = 1 * plot_range, + intensity_range="absolute", + **kwargs + ) + + return + + + + + def fit_p_ellipse( self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs ): diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 538c90825..5ec3b85f6 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -472,27 +472,28 @@ def fit_lattice_vectors( return self.bragg_vectors_indexed, self.g1g2_map def get_strain( - self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs + self, mask=None, coordinate_rotation=0, returncalc=False, **kwargs ): """ - mask: nd.array (bool) + Parameters + ---------- + mask : nd.array (bool) Use lattice vectors from g1g2_map scan positions wherever mask==True. If mask is None gets median strain map from entire field of view. If mask is not None, gets reference g1 and g2 from region and then calculates strain. - g_reference: nd.array of form [x,y] - G_reference (tupe): reference coordinate system for - xaxis_x and xaxis_y - flip_theta: bool - If True, flips rotation coordinate system - returncal: bool + coordinate_rotation : number + Rotate the reference coordinate system counterclockwise by this + amount, in degrees + returncal : bool It True, returns rotated map """ - # check the calstate + # confirm that the calstate hasn't changed assert ( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + # get the mask if mask is None: mask = self.mask # mask = np.ones(self.g1g2_map.shape, dtype="bool") @@ -505,129 +506,116 @@ def get_strain( # strain_map = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) # else: + # get the reference g1/g2 vectors g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + # find the strain strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) - self.strainmap_g1g2 = strainmap_g1g2 - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) + # get the reference coordinate system + theta = np.radians(coordinate_rotation) + xaxis_x = np.cos(theta) + xaxis_y = np.sin(theta) + # get the strain in the reference coordinates strainmap_rotated = get_rotated_strain_map( self.strainmap_g1g2, - xaxis_x=g_reference[0], - xaxis_y=g_reference[1], - flip_theta=flip_theta, + xaxis_x = xaxis_x, + xaxis_y = xaxis_y, + flip_theta = False, ) + # store the data self.data[0] = strainmap_rotated["e_xx"].data self.data[1] = strainmap_rotated["e_yy"].data self.data[2] = strainmap_rotated["e_xy"].data self.data[3] = strainmap_rotated["theta"].data self.data[4] = strainmap_rotated["mask"].data - self.g_reference = g_reference - - figsize = kwargs.pop("figsize", (14, 4)) - vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) - vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) - ticknumber = kwargs.pop("ticknumber", 3) - bkgrd = kwargs.pop("bkgrd", False) - axes_plots = kwargs.pop("axes_plots", ()) + self.coordinate_rotation = coordinate_rotation + # plot the results fig, ax = self.show_strain( - vrange_exx=vrange_exx, - vrange_theta=vrange_theta, - ticknumber=ticknumber, - axes_plots=axes_plots, - bkgrd=bkgrd, - figsize=figsize, **kwargs, returnfig=True, ) + # modify masking if not np.all(mask == True): ax[0][0].imshow(mask, alpha=0.2, cmap="binary") ax[0][1].imshow(mask, alpha=0.2, cmap="binary") ax[1][0].imshow(mask, alpha=0.2, cmap="binary") ax[1][1].imshow(mask, alpha=0.2, cmap="binary") + # return if returncalc: return self.strainmap + def show_strain( self, - vrange_exx, - vrange_theta, + vrange = [-3,3], + vrange_theta = [-3,3], + vrange_exx=None, vrange_exy=None, vrange_eyy=None, - flip_theta=False, bkgrd=True, - show_cbars=("exx", "eyy", "exy", "theta"), + show_cbars=("eyy", "theta"), bordercolor="k", borderwidth=1, titlesize=24, ticklabelsize=16, ticknumber=5, unitlabelsize=24, - show_axes=False, - axes_position=(0, 0), - axes_length=10, - axes_width=1, - axes_color="w", - xaxis_space="Q", - labelaxes=True, - QR_rotation=0, - axes_labelsize=12, - axes_labelcolor="r", - axes_plots=("exx"), cmap="RdBu_r", + cmap_theta="PRGn", mask_color="k", + color_axes="k", + show_gvects=False, + color_gvects="r", + legend_camera_length = 1.6, layout=0, - figsize=(12, 12), + figsize=None, returnfig=False, ): """ - Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') + Display a strain map, showing the 4 strain components + (e_xx,e_yy,e_xy,theta), and masking each image with + strainmap.get_slice('mask') - Args: - vrange_exx (length 2 list or tuple): - vrange_theta (length 2 list or tuple): - vrange_exy (length 2 list or tuple): - vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle - bkgrd (bool): - show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a - tuple containing any, all, or none of ('exx','eyy','exy','theta'). - bordercolor (color): - borderwidth (number): - titlesize (number): - ticklabelsize (number): - ticknumber (number): number of ticks on colorbars - unitlabelsize (number): - show_axes (bool): - axes_x0 (number): - axes_y0 (number): - xaxis_x (number): - xaxis_y (number): - axes_length (number): - axes_width (number): - axes_color (color): - xaxis_space (string): must be 'Q' or 'R' - labelaxes (bool): - QR_rotation (number): - axes_labelsize (number): - axes_labelcolor (color): - axes_plots (tuple of strings): controls if coordinate axes showing the - orientation of the strain matrices are overlaid over any of the plots. - Must be a tuple of strings containing any, all, or none of - ('exx','eyy','exy','theta'). - cmap (colormap): - layout=0 (int): determines the layout of the grid which the strain components - will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). - figsize (length 2 tuple of numbers): - returnfig (bool): + Parameters + ---------- + vrange : length 2 list or tuple + vrange_theta : length 2 list or tuple + vrange_exx : length 2 list or tuple + vrange_exy : length 2 list or tuple + vrange_eyy :length 2 list or tuple + bkgrd : bool + show_cbars :tuple of strings + Show colorbars for the specified axes. Must be a tuple + containing any, all, or none of ('exx','eyy','exy','theta') + bordercolor : color + borderwidth : number + titlesize : number + ticklabelsize : number + ticknumber : number + number of ticks on colorbars + unitlabelsize : number + cmap : colormap + cmap_theta : colormap + mask_color : color + color_axes : color + show_gvects : bool + Toggles displaying the g-vectors in the legend + color_gvects : color + legend_camera_length : number + The distance the legend is viewed from; a smaller number yields + a larger legend + layout : int + determines the layout of the grid which the strain components + will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). + figsize : length 2 tuple of numbers + returnfig : bool """ # Lookup table for different layouts assert layout in (0, 1, 2) @@ -639,10 +627,12 @@ def show_strain( layout_p = layout_lookup[layout] # Contrast limits + if vrange_exx is None: + vrange_exx = vrange if vrange_exy is None: - vrange_exy = vrange_exx + vrange_exy = vrange if vrange_eyy is None: - vrange_eyy = vrange_exx + vrange_eyy = vrange for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): assert len(vrange) == 2, "vranges must have length 2" vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 @@ -667,25 +657,33 @@ def show_strain( self.get_slice("theta").data, mask=self.get_slice("mask").data == False, ) - if flip_theta == True: - theta = -theta ## Plot - # modify the figsize according to the image aspect ratio - ratio = np.sqrt(self.rshape[1] / self.rshape[0]) - figsize_mean = np.mean(figsize) - figsize = (figsize_mean * ratio, figsize_mean / ratio) + # if figsize hasn't been set, set it based on the + # chosen layout and the image shape + if figsize is None: + ratio = np.sqrt(self.rshape[1]/self.rshape[0]) + if layout == 0: + figsize = (13*ratio,8/ratio) + elif layout == 1: + figsize = (10*ratio,4/ratio) + else: + figsize = (4*ratio,10/ratio) + # set up layout if layout == 0: - fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) + fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) =\ + plt.subplots(2, 3, figsize=figsize) elif layout == 1: figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) + fig, (ax11, ax12, ax21, ax22, ax_legend) =\ + plt.subplots(1, 5, figsize=figsize) else: figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) + fig, (ax11, ax12, ax21, ax22, ax_legend) =\ + plt.subplots(5, 1, figsize=figsize) # display images, returning cbar axis references cax11 = show( @@ -727,7 +725,7 @@ def show_strain( vmin=vmin_theta, vmax=vmax_theta, intensity_range="absolute", - cmap=cmap, + cmap=cmap_theta, mask=self.mask, mask_color=mask_color, returncax=True, @@ -815,49 +813,6 @@ def show_strain( else: cbax.axis("off") - # Add coordinate axes - if show_axes: - assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array( - [ - "exx" in axes_plots, - "eyy" in axes_plots, - "exy" in axes_plots, - "theta" in axes_plots, - ] - ) - for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): - if _show: - if xaxis_space == "R": - ax_addaxes( - _ax, - self.g_reference[0], - self.g_reference[1], - axes_length, - axes_position[0], - axes_position[1], - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - else: - ax_addaxes_QtoR( - _ax, - self.g_reference[0], - self.g_reference[1], - axes_length, - axes_position[0], - axes_position[1], - QR_rotation, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - # Add borders if bordercolor is not None: for ax in (ax11, ax12, ax21, ax22): @@ -867,6 +822,158 @@ def show_strain( ax.set_xticks([]) ax.set_yticks([]) + # Legend + + # for layout 0, combine vertical plots on the right end + if layout == 0: + # get gridspec object + gs = ax_legend1.get_gridspec() + # remove last two axes + ax_legend1.remove() + ax_legend2.remove() + # make new axis + ax_legend = fig.add_subplot(gs[:,-1]) + + # get the coordinate axes' directions + QRrot = self.calibration.get_QR_rotation() + rotation = np.sum([ + np.radians(self.coordinate_rotation), + QRrot + ]) + xaxis_vectx = np.cos(rotation) + xaxis_vecty = np.sin(rotation) + yaxis_vectx = np.cos(rotation+np.pi/2) + yaxis_vecty = np.sin(rotation+np.pi/2) + + # make the coordinate axes + ax_legend.arrow( + x = 0, + y = 0, + dx = xaxis_vecty, + dy = xaxis_vectx, + color = color_axes, + length_includes_head = True, + width = 0.01, + head_width = 0.1, + ) + ax_legend.arrow( + x = 0, + y = 0, + dx = yaxis_vecty, + dy = yaxis_vectx, + color = color_axes, + length_includes_head = True, + width = 0.01, + head_width = 0.1, + ) + ax_legend.text( + x = xaxis_vecty*1.12, + y = xaxis_vectx*1.12, + s = 'x', + fontsize = 14, + color = color_axes, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + ax_legend.text( + x = yaxis_vecty*1.12, + y = yaxis_vectx*1.12, + s = 'y', + fontsize = 14, + color = color_axes, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + + # make the g-vectors + if show_gvects: + + # get the g-vectors directions + g1q = np.array(self.g1) + g2q = np.array(self.g2) + g1norm = np.linalg.norm(g1q) + g2norm = np.linalg.norm(g2q) + g1q /= np.linalg.norm(g1norm) + g2q /= np.linalg.norm(g2norm) + # set the lengths + g_ratio = g2norm/g1norm + if g_ratio > 1: + g1q /= g_ratio + else: + g2q *= g_ratio + # rotate + R = np.array( + [ + [ np.cos(QRrot), np.sin(QRrot)], + [-np.sin(QRrot), np.cos(QRrot)] + ] + ) + g1_x,g1_y = np.matmul(g1q,R) + g2_x,g2_y = np.matmul(g2q,R) + + # draw the g vectors + ax_legend.arrow( + x = 0, + y = 0, + dx = g1_y*0.8, + dy = g1_x*0.8, + color = color_gvects, + length_includes_head = True, + width = 0.005, + head_width = 0.05, + ) + ax_legend.arrow( + x = 0, + y = 0, + dx = g2_y*0.8, + dy = g2_x*0.8, + color = color_gvects, + length_includes_head = True, + width = 0.005, + head_width = 0.05, + ) + ax_legend.text( + x = g1_y*0.96, + y = g1_x*0.96, + s = r'$g_1$', + fontsize = 12, + color = color_gvects, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + ax_legend.text( + x = g2_y*0.96, + y = g2_x*0.96, + s = r'$g_2$', + fontsize = 12, + color = color_gvects, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + + # find center and extent + xmin = np.min([0,0,xaxis_vectx,yaxis_vectx]) + xmax = np.max([0,0,xaxis_vectx,yaxis_vectx]) + ymin = np.min([0,0,xaxis_vecty,yaxis_vecty]) + ymax = np.max([0,0,xaxis_vecty,yaxis_vecty]) + if show_gvects: + xmin = np.min([xmin,g1_x,g2_x]) + xmax = np.max([xmax,g1_x,g2_x]) + ymin = np.min([ymin,g1_y,g2_y]) + ymax = np.max([ymax,g1_y,g2_y]) + x0 = np.mean([xmin,xmax]) + y0 = np.mean([ymin,ymax]) + xL = (xmax-x0) * legend_camera_length + yL = (ymax-y0) * legend_camera_length + + # set the extent and aspect + ax_legend.set_xlim([y0-yL,y0+yL]) + ax_legend.set_ylim([x0-xL,x0+xL]) + ax_legend.invert_yaxis() + ax_legend.set_aspect("equal") + ax_legend.axis('off') + + # show/return if not returnfig: plt.show() return From d01d7e6224110f95956b8507e9b5c77a5b6620a5 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Sat, 28 Oct 2023 09:59:28 +0100 Subject: [PATCH 138/176] updates --- py4DSTEM/braggvectors/braggvectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index 3a9ccb1ea..26f8eb8f4 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -486,7 +486,7 @@ def _transform( R = np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ) - # apply + # rotate and flip if flip: positions = R @ np.vstack((ans["qy"], ans["qx"])) else: From 27a1c962335e757bcd2a6ebfcd0bff175cdfeedc Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Mon, 30 Oct 2023 09:03:51 -0700 Subject: [PATCH 139/176] bug in depth plotting --- py4DSTEM/process/phase/iterative_multislice_ptychography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 77a5c69ea..6bcacd934 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3097,7 +3097,7 @@ def show_depth( rotated_object = np.roll( rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), - int(x1_0), + -int(x1_0), axis=1, ) From a23839e5ca0c9c21d3802f9df302b12dd7d5522e Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 30 Oct 2023 10:56:33 -0700 Subject: [PATCH 140/176] changing assert to warning --- py4DSTEM/process/diffraction/crystal_ACOM.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index cba68d8fb..0dcf5dad0 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -903,10 +903,12 @@ def match_single_pattern( """ # adding assert statement for checking self.orientation_ref is present - assert hasattr( + # adding assert statement for checking self.orientation_ref is present + if not hasattr( self, "orientation_ref" - ), "orientation_plan must be run with 'calculate_correlation_array=True'" - + ): + raise Warning("orientation_plan must be run with 'calculate_correlation_array=True'") + orientation = Orientation(num_matches=num_matches_return) if bragg_peaks.data.shape[0] < min_number_peaks: return orientation From 7ca7edb2a4259e1d44efd8b671b4c561bb9767f6 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Mon, 30 Oct 2023 10:57:11 -0700 Subject: [PATCH 141/176] black --- py4DSTEM/process/diffraction/crystal_ACOM.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 0dcf5dad0..650fe0583 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -903,12 +903,12 @@ def match_single_pattern( """ # adding assert statement for checking self.orientation_ref is present - # adding assert statement for checking self.orientation_ref is present - if not hasattr( - self, "orientation_ref" - ): - raise Warning("orientation_plan must be run with 'calculate_correlation_array=True'") - + # adding assert statement for checking self.orientation_ref is present + if not hasattr(self, "orientation_ref"): + raise Warning( + "orientation_plan must be run with 'calculate_correlation_array=True'" + ) + orientation = Orientation(num_matches=num_matches_return) if bragg_peaks.data.shape[0] < min_number_peaks: return orientation From 78f54905cdbbaabf713b1da90dd2d4cfd13257d6 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Mon, 30 Oct 2023 19:23:21 +0000 Subject: [PATCH 142/176] adds show_reference_directions --- py4DSTEM/process/strain/strain.py | 368 +++++++++++++++++++++++++++++- 1 file changed, 366 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 5ec3b85f6..aad317160 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -517,6 +517,8 @@ def get_strain( theta = np.radians(coordinate_rotation) xaxis_x = np.cos(theta) xaxis_y = np.sin(theta) + self.coordinate_rotation_degrees = coordinate_rotation + self.coordinate_rotation_radians = theta # get the strain in the reference coordinates strainmap_rotated = get_rotated_strain_map( @@ -532,7 +534,6 @@ def get_strain( self.data[2] = strainmap_rotated["e_xy"].data self.data[3] = strainmap_rotated["theta"].data self.data[4] = strainmap_rotated["mask"].data - self.coordinate_rotation = coordinate_rotation # plot the results fig, ax = self.show_strain( @@ -837,7 +838,7 @@ def show_strain( # get the coordinate axes' directions QRrot = self.calibration.get_QR_rotation() rotation = np.sum([ - np.radians(self.coordinate_rotation), + self.coordinate_rotation_radians, QRrot ]) xaxis_vectx = np.cos(rotation) @@ -981,6 +982,369 @@ def show_strain( axs = ((ax11, ax12), (ax21, ax22)) return fig, axs + + def show_reference_directions( + self, + im_uncal = None, + im_cal = None, + color_axes="y", + color_gvects="r", + origin_uncal = None, + origin_cal = None, + camera_length = 1.8, + visp_uncal = {'scaling' : 'log'}, + visp_cal = {'scaling' : 'log'}, + layout = "horizontal", + titlesize = 16, + size_labels = 14, + #ticklabelsize=16, + #ticknumber=5, + #unitlabelsize=24, + figsize = None, + returnfig = False, + ): + """ + Show the reference coordinate system used to compute the strain + overlaid over calibrated and uncalibrated diffraction space images. + + The diffraction images used can be specificied with the `im_uncal` + and `im_cal` arguments, and default to the uncalibrated and calibrated + Bragg vector maps. The `rotate_cal` argument causes the `im_cal` array + to be rotated by -QR rotation from the calibration metadata, so that an + uncalibrated image (like a raw diffraction image or mean or max + diffraction pattern) can be passed to the `im_cal` argument. + + Parameters + ---------- + im_uncal : 2d array or None + Uncalibrated diffraction space image to dispay; defaults to + the maximal diffraction image. + im_cal : 2d array or None + Calibrated diffraction space image to display; defaults to + the calibrated Bragg vector map. + color_axes : color + The color of the overlaid coordinate axes + color_gvects : color + The color of the g-vectors + origin_uncal : 2-tuple or None + Where to place the origin of the coordinate system overlaid on + the uncalibrated diffraction image. Defaults to the mean origin + from the calibration metadata. + origin_cal : 2-tuple or None + Where to place the origin of the coordinate system overlaid on + the calibrated diffraction image. Defaults to the mean origin + from the calibration metadata. + camera_length : number + Determines the length of the overlaid coordinate axes; a smaller + number yields larger axes. + visp_uncal : dict + Visualization parameters for the uncalibrated diffraction image. + visp_cal : dict + Visualization parameters for the calibrated diffraction image. + layout : str; either "horizontal" or "vertical" + Determines the layout of the visualization. + titlesize : number + The size of the plot titles + size_labels : number + The size of the axis labels + figsize : length 2 tuple of numbers or None + Size of the figure + returnfig : bool + Toggles returning the figure + """ + # Set up the figure + assert layout in ("horizontal", "vertical") + + # Set the figsize + if figsize is None: + ratio = np.sqrt(self.rshape[1]/self.rshape[0]) + if layout == "horizontal": + figsize = (10*ratio,8/ratio) + else: + figsize = (8*ratio,12/ratio) + + # Create the figure + if layout == "horizontal": + fig, (ax1, ax2) =\ + plt.subplots(1, 2, figsize=figsize) + else: + fig, (ax1, ax2) =\ + plt.subplots(2, 1, figsize=figsize) + + # prepare images + if im_uncal is None: + im_uncal = self.braggvectors.histogram( mode='raw' ) + if im_cal is None: + im_cal = self.braggvectors.histogram( mode='cal' ) + + # display images + show( + im_cal, + figax=(fig, ax1), + **visp_cal + ) + show( + im_uncal, + figax=(fig, ax2), + **visp_uncal + ) + ax1.set_title("Calibrated", size=titlesize) + ax2.set_title("Uncalibrated", size=titlesize) + + + # Get the coordinate axes + + # get the directions + + # calibrated + QRrot = self.calibration.get_QR_rotation() + rotation = np.sum([ + self.coordinate_rotation_radians, + QRrot + ]) + xaxis_cal = np.array([ + np.cos(rotation), + np.sin(rotation) + ]) + yaxis_cal = np.array([ + np.cos(rotation+np.pi/2), + np.sin(rotation+np.pi/2) + ]) + + # uncalibrated + xaxis_uncal = np.array([ + np.cos(self.coordinate_rotation_radians), + np.sin(self.coordinate_rotation_radians) + ]) + yaxis_uncal = np.array([ + np.cos(self.coordinate_rotation_radians+np.pi/2), + np.sin(self.coordinate_rotation_radians+np.pi/2) + ]) + # inversion + if self.calibration.get_QR_flip(): + xaxis_uncal = np.array([ + xaxis_uncal[1], + xaxis_uncal[0] + ]) + yaxis_uncal = np.array([ + yaxis_uncal[1], + yaxis_uncal[0] + ]) + + # set the lengths + Lmean = np.mean([im_cal.shape[0],im_cal.shape[1]])/2 + xaxis_cal *= Lmean/camera_length + yaxis_cal *= Lmean/camera_length + xaxis_uncal *= Lmean/camera_length + yaxis_uncal *= Lmean/camera_length + + # Get the g-vectors + + # calibrated + g1_cal = np.array(self.g1) + g2_cal = np.array(self.g2) + + # uncalibrated + R = np.array( + [ + [np.cos(QRrot), -np.sin(QRrot)], + [np.sin(QRrot), np.cos(QRrot)] + ] + ) + g1_uncal = np.matmul(g1_cal,R) + g2_uncal = np.matmul(g2_cal,R) + # inversion + if self.calibration.get_QR_flip(): + g1_uncal = np.array([ + g1_uncal[1], + g1_uncal[0] + ]) + g2_uncal = np.array([ + g2_uncal[1], + g2_uncal[0] + ]) + + + # Set origin positions + if origin_uncal is None: + origin_uncal = self.calibration.get_origin_mean() + if origin_cal is None: + origin_cal = self.calibration.get_origin_mean() + + # Draw calibrated coordinate axes + ax1.arrow( + x = origin_cal[1], + y = origin_cal[0], + dx = xaxis_cal[1], + dy = xaxis_cal[0], + color = color_axes, + length_includes_head = True, + width = 0.01, + head_width = 0.1, + ) + ax1.arrow( + x = origin_cal[1], + y = origin_cal[0], + dx = yaxis_cal[1], + dy = yaxis_cal[0], + color = color_axes, + length_includes_head = True, + width = 0.01, + head_width = 0.1, + ) + ax1.text( + x = origin_cal[1] + xaxis_cal[1]*1.12, + y = origin_cal[0] + xaxis_cal[0]*1.12, + s = 'x', + fontsize = size_labels, + color = color_axes, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + ax1.text( + x = origin_cal[1] + yaxis_cal[1]*1.12, + y = origin_cal[0] + yaxis_cal[0]*1.12, + s = 'y', + fontsize = size_labels, + color = color_axes, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + + # Draw uncalibrated coordinate axes + ax2.arrow( + x = origin_uncal[1], + y = origin_uncal[0], + dx = xaxis_uncal[1], + dy = xaxis_uncal[0], + color = color_axes, + length_includes_head = True, + width = 0.01, + head_width = 0.1, + ) + ax2.arrow( + x = origin_uncal[1], + y = origin_uncal[0], + dx = yaxis_uncal[1], + dy = yaxis_uncal[0], + color = color_axes, + length_includes_head = True, + width = 0.01, + head_width = 0.1, + ) + ax2.text( + x = origin_uncal[1] + xaxis_uncal[1]*1.12, + y = origin_uncal[0] + xaxis_uncal[0]*1.12, + s = 'x', + fontsize = size_labels, + color = color_axes, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + ax2.text( + x = origin_uncal[1] + yaxis_uncal[1]*1.12, + y = origin_uncal[0] + yaxis_uncal[0]*1.12, + s = 'y', + fontsize = size_labels, + color = color_axes, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + + + # Draw the calibrated g-vectors + + # draw the g vectors + ax1.arrow( + x = origin_cal[1], + y = origin_cal[0], + dx = g1_cal[1], + dy = g1_cal[0], + color = color_gvects, + length_includes_head = True, + width = 0.005, + head_width = 0.05, + ) + ax1.arrow( + x = origin_cal[1], + y = origin_cal[0], + dx = g2_cal[1], + dy = g2_cal[0], + color = color_gvects, + length_includes_head = True, + width = 0.005, + head_width = 0.05, + ) + ax1.text( + x = origin_cal[1] + g1_cal[1]*1.08, + y = origin_cal[0] + g1_cal[0]*1.08, + s = r'$g_1$', + fontsize = size_labels*0.88, + color = color_gvects, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + ax1.text( + x = origin_cal[1] + g2_cal[1]*1.08, + y = origin_cal[0] + g2_cal[0]*1.08, + s = r'$g_2$', + fontsize = size_labels*0.88, + color = color_gvects, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + + # Draw the uncalibrated g-vectors + + # draw the g vectors + ax2.arrow( + x = origin_uncal[1], + y = origin_uncal[0], + dx = g1_uncal[1], + dy = g1_uncal[0], + color = color_gvects, + length_includes_head = True, + width = 0.005, + head_width = 0.05, + ) + ax2.arrow( + x = origin_uncal[1], + y = origin_uncal[0], + dx = g2_uncal[1], + dy = g2_uncal[0], + color = color_gvects, + length_includes_head = True, + width = 0.005, + head_width = 0.05, + ) + ax2.text( + x = origin_uncal[1] + g1_uncal[1]*1.08, + y = origin_uncal[0] + g1_uncal[0]*1.08, + s = r'$g_1$', + fontsize = size_labels*0.88, + color = color_gvects, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + ax2.text( + x = origin_uncal[1] + g2_uncal[1]*1.08, + y = origin_uncal[0] + g2_uncal[0]*1.08, + s = r'$g_2$', + fontsize = size_labels*0.88, + color = color_gvects, + horizontalalignment = 'center', + verticalalignment = 'center', + ) + + + # show/return + if not returnfig: + plt.show() + return + else: + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs + def show_lattice_vectors( ar, x0, From 7bd2ceb7bdc4296038ed74917653f8487a8f60d4 Mon Sep 17 00:00:00 2001 From: Steven Zeltmann Date: Mon, 30 Oct 2023 16:55:35 -0400 Subject: [PATCH 143/176] change Warning to ValueError --- py4DSTEM/process/diffraction/crystal_ACOM.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 650fe0583..49be73b99 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -905,7 +905,7 @@ def match_single_pattern( # adding assert statement for checking self.orientation_ref is present # adding assert statement for checking self.orientation_ref is present if not hasattr(self, "orientation_ref"): - raise Warning( + raise ValueError( "orientation_plan must be run with 'calculate_correlation_array=True'" ) From 701b2755f76e21262cd74523e92402bf6d4ea176 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 30 Oct 2023 14:05:40 -0700 Subject: [PATCH 144/176] minor dpc bugfixes --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- py4DSTEM/process/phase/iterative_dpc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 04cfd6a60..f04a3c552 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1257,7 +1257,7 @@ def show_complex_CoM( if pixelsize is None: pixelsize = self._scan_sampling[0] if pixelunits is None: - pixelunits = r"$\AA$" + pixelunits = self._scan_units[0] figsize = kwargs.pop("figsize", (6, 6)) fig, ax = plt.subplots(figsize=figsize) diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index af3cbbb45..b390ce46d 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -799,6 +799,7 @@ def reconstruct( anti_gridding=anti_gridding, ) + self.error_iterations.append(self.error.item()) if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -807,7 +808,6 @@ def reconstruct( ].copy() ) ) - self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: if self._verbose: From dcc62a3dcc9edeac9f0e7f4daa97a047e2ce9ed0 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 30 Oct 2023 14:06:15 -0700 Subject: [PATCH 145/176] parallax DF limit bug, cropped property --- py4DSTEM/process/phase/iterative_parallax.py | 62 ++++++++++++++++---- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 21af22a37..daab204a0 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1098,26 +1098,46 @@ def subpixel_alignment( BF_size = np.array(self._stack_BF_no_window.shape[-2:]) self._DF_upsample_limit = np.max( - self._region_of_interest_shape / self._scan_shape + 2 * self._region_of_interest_shape / self._scan_shape ) self._BF_upsample_limit = ( - 2 * self._kr.max() / self._reciprocal_sampling[0] + 4 * self._kr.max() / self._reciprocal_sampling[0] ) / self._scan_shape.max() if self._device == "gpu": self._BF_upsample_limit = self._BF_upsample_limit.item() if kde_upsample_factor is None: - kde_upsample_factor = np.minimum( - self._BF_upsample_limit * 3 / 2, self._DF_upsample_limit - ) + if self._BF_upsample_limit * 3 / 2 > self._DF_upsample_limit: + kde_upsample_factor = self._DF_upsample_limit - warnings.warn( - ( - f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " - f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." - ), - UserWarning, - ) + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (the " + f"dark-field upsampling limit)." + ), + UserWarning, + ) + + elif self._BF_upsample_limit * 3 / 2 > 1: + kde_upsample_factor = self._BF_upsample_limit * 3 / 2 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + else: + kde_upsample_factor = self._DF_upsample_limit * 2 / 3 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (2/3 times the " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f})." + ), + UserWarning, + ) if kde_upsample_factor < 1: raise ValueError("kde_upsample_factor must be larger than 1") @@ -2349,3 +2369,21 @@ def visualize( ax.set_title("Reconstructed Bright Field Image") return self + + @property + def object_cropped(self): + """cropped object""" + if hasattr(self, "_recon_phase_corrected"): + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_phase_corrected, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_phase_corrected) + else: + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_BF) From b7a7a5f15b589f573bad8491e12888585f813ada Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 30 Oct 2023 14:06:56 -0700 Subject: [PATCH 146/176] complex grid scalebar bug --- py4DSTEM/visualize/vis_special.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index d1efbd023..388b57e0a 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -829,7 +829,7 @@ def show_complex( for ax_flat in ax.flatten(): divider = make_axes_locatable(ax_flat) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") From b760b260715a6dbd738af2d7b0c457355b13c8d8 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 31 Oct 2023 11:07:22 +0000 Subject: [PATCH 147/176] updates --- py4DSTEM/process/strain/strain.py | 231 +++++++++++++++++------------- 1 file changed, 133 insertions(+), 98 deletions(-) diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index aad317160..36edf47f7 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -557,27 +557,28 @@ def show_strain( self, vrange = [-3,3], vrange_theta = [-3,3], - vrange_exx=None, - vrange_exy=None, - vrange_eyy=None, - bkgrd=True, - show_cbars=("eyy", "theta"), - bordercolor="k", - borderwidth=1, - titlesize=24, - ticklabelsize=16, - ticknumber=5, - unitlabelsize=24, - cmap="RdBu_r", - cmap_theta="PRGn", - mask_color="k", - color_axes="k", - show_gvects=False, - color_gvects="r", + vrange_exx = None, + vrange_exy = None, + vrange_eyy = None, + bkgrd = True, + show_cbars = None, + bordercolor = "k", + borderwidth = 1, + titlesize = 18, + ticklabelsize = 10, + ticknumber = 5, + unitlabelsize = 16, + cmap = "RdBu_r", + cmap_theta = "PRGn", + mask_color = "k", + color_axes = "k", + show_gvects = True, + color_gvects = "r", legend_camera_length = 1.6, - layout=0, - figsize=None, - returnfig=False, + scale_gvects = 0.6, + layout = 0, + figsize = None, + returnfig = False, ): """ Display a strain map, showing the 4 strain components @@ -587,36 +588,59 @@ def show_strain( Parameters ---------- vrange : length 2 list or tuple + The colorbar intensity range for exx,eyy, and exy. vrange_theta : length 2 list or tuple + The colorbar intensity range for theta. vrange_exx : length 2 list or tuple + The colorbar intensity range for exx; overrides `vrange` + for exx vrange_exy : length 2 list or tuple - vrange_eyy :length 2 list or tuple + The colorbar intensity range for exy; overrides `vrange` + for exy + vrange_eyy : length 2 list or tuple + The colorbar intensity range for eyy; overrides `vrange` + for eyy bkgrd : bool - show_cbars :tuple of strings - Show colorbars for the specified axes. Must be a tuple - containing any, all, or none of ('exx','eyy','exy','theta') + Overlay a mask over background pixels + show_cbars : None or a tuple of strings + Show colorbars for the specified axes. Valid strings are + 'exx', 'eyy', 'exy', and 'theta'. bordercolor : color + Color for the image borders borderwidth : number + Width of the image borders titlesize : number + Size of the image titles ticklabelsize : number + Size of the colorbar ticks ticknumber : number - number of ticks on colorbars + Number of ticks on colorbars unitlabelsize : number + Size of the units label on the colorbars cmap : colormap + Colormap for exx, exy, and eyy cmap_theta : colormap + Colormap for theta mask_color : color + Color for the background mask color_axes : color + Color for the legend coordinate axes show_gvects : bool Toggles displaying the g-vectors in the legend color_gvects : color + Color for the legend g-vectors legend_camera_length : number The distance the legend is viewed from; a smaller number yields a larger legend + scale_gvects : number + Scaling for the legend g-vectors relative to the coordinate axes layout : int - determines the layout of the grid which the strain components + Determines the layout of the grid which the strain components will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). figsize : length 2 tuple of numbers + Size of the figure returnfig : bool + Toggles returning the figure """ # Lookup table for different layouts assert layout in (0, 1, 2) @@ -627,6 +651,20 @@ def show_strain( } layout_p = layout_lookup[layout] + # Set which colorbars to display + if show_cbars is None: + if np.all([v is None for v in ( + vrange_exx, + vrange_eyy, + vrange_exy, + )]): + show_cbars = ('eyy','theta') + else: + show_cbars = ('exx','eyy','exy','theta') + else: + assert np.all([v in ('exx','eyy','exy','theta') for v in show_cbars]) + + # Contrast limits if vrange_exx is None: vrange_exx = vrange @@ -836,11 +874,7 @@ def show_strain( ax_legend = fig.add_subplot(gs[:,-1]) # get the coordinate axes' directions - QRrot = self.calibration.get_QR_rotation() - rotation = np.sum([ - self.coordinate_rotation_radians, - QRrot - ]) + rotation = self.coordinate_rotation_radians xaxis_vectx = np.cos(rotation) xaxis_vecty = np.sin(rotation) yaxis_vectx = np.cos(rotation+np.pi/2) @@ -868,8 +902,8 @@ def show_strain( head_width = 0.1, ) ax_legend.text( - x = xaxis_vecty*1.12, - y = xaxis_vectx*1.12, + x = xaxis_vecty*1.16, + y = xaxis_vectx*1.16, s = 'x', fontsize = 14, color = color_axes, @@ -877,8 +911,8 @@ def show_strain( verticalalignment = 'center', ) ax_legend.text( - x = yaxis_vecty*1.12, - y = yaxis_vectx*1.12, + x = yaxis_vecty*1.16, + y = yaxis_vectx*1.16, s = 'y', fontsize = 14, color = color_axes, @@ -894,8 +928,8 @@ def show_strain( g2q = np.array(self.g2) g1norm = np.linalg.norm(g1q) g2norm = np.linalg.norm(g2q) - g1q /= np.linalg.norm(g1norm) - g2q /= np.linalg.norm(g2norm) + g1q /= g1norm + g2q /= g2norm # set the lengths g_ratio = g2norm/g1norm if g_ratio > 1: @@ -903,21 +937,23 @@ def show_strain( else: g2q *= g_ratio # rotate - R = np.array( - [ - [ np.cos(QRrot), np.sin(QRrot)], - [-np.sin(QRrot), np.cos(QRrot)] - ] - ) - g1_x,g1_y = np.matmul(g1q,R) - g2_x,g2_y = np.matmul(g2q,R) + #R = np.array( + # [ + # [ np.cos(QRrot), np.sin(QRrot)], + # [-np.sin(QRrot), np.cos(QRrot)] + # ] + #) + #g1_x,g1_y = np.matmul(g1q,R) + #g2_x,g2_y = np.matmul(g2q,R) + g1_x,g1_y = g1q + g2_x,g2_y = g2q # draw the g vectors ax_legend.arrow( x = 0, y = 0, - dx = g1_y*0.8, - dy = g1_x*0.8, + dx = g1_y*scale_gvects, + dy = g1_x*scale_gvects, color = color_gvects, length_includes_head = True, width = 0.005, @@ -926,16 +962,16 @@ def show_strain( ax_legend.arrow( x = 0, y = 0, - dx = g2_y*0.8, - dy = g2_x*0.8, + dx = g2_y*scale_gvects, + dy = g2_x*scale_gvects, color = color_gvects, length_includes_head = True, width = 0.005, head_width = 0.05, ) ax_legend.text( - x = g1_y*0.96, - y = g1_x*0.96, + x = g1_y*scale_gvects*1.2, + y = g1_x*scale_gvects*1.2, s = r'$g_1$', fontsize = 12, color = color_gvects, @@ -943,8 +979,8 @@ def show_strain( verticalalignment = 'center', ) ax_legend.text( - x = g2_y*0.96, - y = g2_x*0.96, + x = g2_y*scale_gvects*1.2, + y = g2_x*scale_gvects*1.2, s = r'$g_2$', fontsize = 12, color = color_gvects, @@ -987,7 +1023,7 @@ def show_reference_directions( self, im_uncal = None, im_cal = None, - color_axes="y", + color_axes="linen", color_gvects="r", origin_uncal = None, origin_cal = None, @@ -997,9 +1033,6 @@ def show_reference_directions( layout = "horizontal", titlesize = 16, size_labels = 14, - #ticklabelsize=16, - #ticknumber=5, - #unitlabelsize=24, figsize = None, returnfig = False, ): @@ -1097,11 +1130,7 @@ def show_reference_directions( # get the directions # calibrated - QRrot = self.calibration.get_QR_rotation() - rotation = np.sum([ - self.coordinate_rotation_radians, - QRrot - ]) + rotation = self.coordinate_rotation_radians xaxis_cal = np.array([ np.cos(rotation), np.sin(rotation) @@ -1112,13 +1141,18 @@ def show_reference_directions( ]) # uncalibrated + QRrot = self.calibration.get_QR_rotation() + rotation = np.sum([ + self.coordinate_rotation_radians, + -QRrot + ]) xaxis_uncal = np.array([ - np.cos(self.coordinate_rotation_radians), - np.sin(self.coordinate_rotation_radians) + np.cos(rotation), + np.sin(rotation) ]) yaxis_uncal = np.array([ - np.cos(self.coordinate_rotation_radians+np.pi/2), - np.sin(self.coordinate_rotation_radians+np.pi/2) + np.cos(rotation+np.pi/2), + np.sin(rotation+np.pi/2) ]) # inversion if self.calibration.get_QR_flip(): @@ -1172,6 +1206,7 @@ def show_reference_directions( origin_cal = self.calibration.get_origin_mean() # Draw calibrated coordinate axes + coordax_width = Lmean*2/100 ax1.arrow( x = origin_cal[1], y = origin_cal[0], @@ -1179,8 +1214,8 @@ def show_reference_directions( dy = xaxis_cal[0], color = color_axes, length_includes_head = True, - width = 0.01, - head_width = 0.1, + width = coordax_width, + head_width = coordax_width * 5, ) ax1.arrow( x = origin_cal[1], @@ -1189,12 +1224,12 @@ def show_reference_directions( dy = yaxis_cal[0], color = color_axes, length_includes_head = True, - width = 0.01, - head_width = 0.1, + width = coordax_width, + head_width = coordax_width * 5, ) ax1.text( - x = origin_cal[1] + xaxis_cal[1]*1.12, - y = origin_cal[0] + xaxis_cal[0]*1.12, + x = origin_cal[1] + xaxis_cal[1]*1.16, + y = origin_cal[0] + xaxis_cal[0]*1.16, s = 'x', fontsize = size_labels, color = color_axes, @@ -1202,8 +1237,8 @@ def show_reference_directions( verticalalignment = 'center', ) ax1.text( - x = origin_cal[1] + yaxis_cal[1]*1.12, - y = origin_cal[0] + yaxis_cal[0]*1.12, + x = origin_cal[1] + yaxis_cal[1]*1.16, + y = origin_cal[0] + yaxis_cal[0]*1.16, s = 'y', fontsize = size_labels, color = color_axes, @@ -1219,8 +1254,8 @@ def show_reference_directions( dy = xaxis_uncal[0], color = color_axes, length_includes_head = True, - width = 0.01, - head_width = 0.1, + width = coordax_width, + head_width = coordax_width * 5, ) ax2.arrow( x = origin_uncal[1], @@ -1229,12 +1264,12 @@ def show_reference_directions( dy = yaxis_uncal[0], color = color_axes, length_includes_head = True, - width = 0.01, - head_width = 0.1, + width = coordax_width, + head_width = coordax_width * 5, ) ax2.text( - x = origin_uncal[1] + xaxis_uncal[1]*1.12, - y = origin_uncal[0] + xaxis_uncal[0]*1.12, + x = origin_uncal[1] + xaxis_uncal[1]*1.16, + y = origin_uncal[0] + xaxis_uncal[0]*1.16, s = 'x', fontsize = size_labels, color = color_axes, @@ -1242,8 +1277,8 @@ def show_reference_directions( verticalalignment = 'center', ) ax2.text( - x = origin_uncal[1] + yaxis_uncal[1]*1.12, - y = origin_uncal[0] + yaxis_uncal[0]*1.12, + x = origin_uncal[1] + yaxis_uncal[1]*1.16, + y = origin_uncal[0] + yaxis_uncal[0]*1.16, s = 'y', fontsize = size_labels, color = color_axes, @@ -1262,8 +1297,8 @@ def show_reference_directions( dy = g1_cal[0], color = color_gvects, length_includes_head = True, - width = 0.005, - head_width = 0.05, + width = coordax_width * 0.5, + head_width = coordax_width * 2.5, ) ax1.arrow( x = origin_cal[1], @@ -1272,12 +1307,12 @@ def show_reference_directions( dy = g2_cal[0], color = color_gvects, length_includes_head = True, - width = 0.005, - head_width = 0.05, + width = coordax_width * 0.5, + head_width = coordax_width * 2.5, ) ax1.text( - x = origin_cal[1] + g1_cal[1]*1.08, - y = origin_cal[0] + g1_cal[0]*1.08, + x = origin_cal[1] + g1_cal[1]*1.16, + y = origin_cal[0] + g1_cal[0]*1.16, s = r'$g_1$', fontsize = size_labels*0.88, color = color_gvects, @@ -1285,8 +1320,8 @@ def show_reference_directions( verticalalignment = 'center', ) ax1.text( - x = origin_cal[1] + g2_cal[1]*1.08, - y = origin_cal[0] + g2_cal[0]*1.08, + x = origin_cal[1] + g2_cal[1]*1.16, + y = origin_cal[0] + g2_cal[0]*1.16, s = r'$g_2$', fontsize = size_labels*0.88, color = color_gvects, @@ -1304,8 +1339,8 @@ def show_reference_directions( dy = g1_uncal[0], color = color_gvects, length_includes_head = True, - width = 0.005, - head_width = 0.05, + width = coordax_width * 0.5, + head_width = coordax_width * 2.5, ) ax2.arrow( x = origin_uncal[1], @@ -1314,12 +1349,12 @@ def show_reference_directions( dy = g2_uncal[0], color = color_gvects, length_includes_head = True, - width = 0.005, - head_width = 0.05, + width = coordax_width * 0.5, + head_width = coordax_width * 2.5, ) ax2.text( - x = origin_uncal[1] + g1_uncal[1]*1.08, - y = origin_uncal[0] + g1_uncal[0]*1.08, + x = origin_uncal[1] + g1_uncal[1]*1.16, + y = origin_uncal[0] + g1_uncal[0]*1.16, s = r'$g_1$', fontsize = size_labels*0.88, color = color_gvects, @@ -1327,8 +1362,8 @@ def show_reference_directions( verticalalignment = 'center', ) ax2.text( - x = origin_uncal[1] + g2_uncal[1]*1.08, - y = origin_uncal[0] + g2_uncal[0]*1.08, + x = origin_uncal[1] + g2_uncal[1]*1.16, + y = origin_uncal[0] + g2_uncal[0]*1.16, s = r'$g_2$', fontsize = size_labels*0.88, color = color_gvects, From 6785b06ed129c3b8aeab7040f2c75a797484ee3b Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Tue, 31 Oct 2023 12:47:52 +0000 Subject: [PATCH 148/176] updates --- py4DSTEM/process/strain/strain.py | 204 ++++++++++++++++++------------ 1 file changed, 126 insertions(+), 78 deletions(-) diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 36edf47f7..c22e2331a 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -18,8 +18,14 @@ get_strain_from_reference_g1g2, index_bragg_directions, ) -from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show -from py4DSTEM.visualize import ax_addaxes, ax_addaxes_QtoR +from py4DSTEM.visualize import ( + show, + add_bragg_index_labels, + add_pointlabels, + add_vector, + ax_addaxes, + ax_addaxes_QtoR +) warnings.simplefilter(action="always", category=UserWarning) @@ -30,13 +36,22 @@ class StrainMap(RealSlice, Data): """ - def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): + def __init__( + self, + braggvectors: BraggVectors, + name: Optional[str] = "strainmap" + ): """ - Accepts: - braggvectors (BraggVectors): BraggVectors for Strain Map - name (str): the name of the strainmap - Returns: - A new StrainMap instance. + Parameters + ---------- + braggvectors : BraggVectors + The Bragg vectors + name : str + The name of the strainmap + + Returns + ------- + A new StrainMap instance. """ assert isinstance( braggvectors, BraggVectors @@ -112,13 +127,6 @@ def qshape(self): def origin(self): return self.calibration.get_origin_mean() - @property - def mask(self): - try: - return self.g1g2_map["mask"].data.astype("bool") - except: - return np.ones(self.rshape, dtype=bool) - def reset_calstate(self): """ Resets the calibration state. This recomputes the BVM, and removes any computations @@ -170,9 +178,11 @@ def choose_lattice_vectors( Choose which lattice vectors to use for strain mapping. Overlays the bvm with the points detected via local 2D - maxima detection, plus an index for each point. User selects - 3 points using the overlaid indices, which are identified as - the origin and the termini of the lattice vectors g1 and g2. + maxima detection, plus an index for each point. Three points + are selected which correspond to the origin, and the basis + reciprocal lattice vectors g1 and g2. By default these are + automatically located; the user can override and select these + manually using the `index_*` arguments. Parameters ---------- @@ -386,25 +396,27 @@ def choose_lattice_vectors( def fit_lattice_vectors( self, - max_peak_spacing=2, - mask=None, - returncalc=False, + max_peak_spacing = 2, + mask = None, + returncalc = False ): """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - Args: - max_peak_spacing: float - Maximum distance from the ideal lattice points - to include a peak for indexing - mask: bool - Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - returncalc : bool - if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map + Fit the basis lattice vectors g1 and g2 to the detected Bragg peaks + in each pattern. The fit uses all detected peaks which are within + a distance of `max_peak_spacing` of the indexed peaks determined + in `choose_lattice_vectors`. Bragg peaks used in the fit are weighted + by their intensity. + + Parameters + ---------- + max_peak_spacing : float + Maximum distance from the ideal lattice points to include a peak + for indexing + mask : 2d boolean array + A real space shaped Boolean mask indicating scan positions at which + to fit the lattice vectors. + returncalc : bool + if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map """ # check the calstate assert ( @@ -420,6 +432,7 @@ def fit_lattice_vectors( mask.shape == self.braggvectors.Rshape ), "mask must have same shape as pointlistarray" assert mask.dtype == bool, "mask must be boolean" + self.mask = mask # set up new braggpeaks PLA indexed_braggpeaks = PointListArray( @@ -467,21 +480,36 @@ def fit_lattice_vectors( g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) self.g1g2_map = g1g2_map + # update the mask + g1g2_mask = self.g1g2_map["mask"].data.astype("bool") + self.mask = np.logical_and(self.mask, g1g2_mask) + # return if returncalc: return self.bragg_vectors_indexed, self.g1g2_map def get_strain( - self, mask=None, coordinate_rotation=0, returncalc=False, **kwargs + self, + gvects = None, + coordinate_rotation = 0, + returncalc = False, + **kwargs ): """ + Compute the strain as the deviation of the basis reciprocal lattice + vectors which have been fit at each scan position with respect to a + pair of reference lattice vectors, determined by the argument `gvects`. + Parameters ---------- - mask : nd.array (bool) - Use lattice vectors from g1g2_map scan positions - wherever mask==True. If mask is None gets median strain - map from entire field of view. If mask is not None, gets - reference g1 and g2 from region and then calculates strain. + gvects : None or 2d-array or tuple + Specifies how to select the reference lattice vectors. If None, + use the median of the fit lattice vectors over the whole dataset. + If a 2d array is passed, it should be real space shaped and boolean. + In this case, uses the median of the fit lattice vectors in all scan + positions where this array is True. Otherwise, should be a length 2 + tuple of length 2 array/list/tuples, which are used directly as + g1 and g2. coordinate_rotation : number Rotate the reference coordinate system counterclockwise by this amount, in degrees @@ -493,24 +521,29 @@ def get_strain( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - # get the mask - if mask is None: - mask = self.mask - # mask = np.ones(self.g1g2_map.shape, dtype="bool") - # strainmap_g1g2 = get_strain_from_reference_region( - # self.g1g2_map, - # mask=mask, - # ) - - # g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) - # strain_map = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) - # else: - - # get the reference g1/g2 vectors - g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + # get the reference g-vectors + if gvects is None: + g1_ref, g2_ref = get_reference_g1g2( + self.g1g2_map, + self.mask + ) + elif isinstance(gvects, np.ndarray): + assert(gvects.shape == self.rshape) + assert(gvects.dtype == bool) + g1_ref, g2_ref = get_reference_g1g2( + self.g1g2_map, + np.logical_and(gvects, self.mask) + ) + else: + g1_ref = np.array(gvects[0]) + g2_ref = np.array(gvects[1]) # find the strain - strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) + strainmap_g1g2 = get_strain_from_reference_g1g2( + self.g1g2_map, + g1_ref, + g2_ref + ) self.strainmap_g1g2 = strainmap_g1g2 # get the reference coordinate system @@ -541,18 +574,36 @@ def get_strain( returnfig=True, ) - # modify masking - if not np.all(mask == True): - ax[0][0].imshow(mask, alpha=0.2, cmap="binary") - ax[0][1].imshow(mask, alpha=0.2, cmap="binary") - ax[1][0].imshow(mask, alpha=0.2, cmap="binary") - ax[1][1].imshow(mask, alpha=0.2, cmap="binary") - # return if returncalc: return self.strainmap + def get_reference_g1g2( + self, + ROI + ): + """ + Get reference g1,g2 vectors by taking the median fit vectors + in the specified ROI. + + Parameters + ---------- + ROI : real space shaped 2d boolean ndarray + Use scan positions where ROI is True + + Returns + ------- + g1_ref,g2_ref : 2 tuple of length 2 ndarrays + """ + g1_ref, g2_ref = get_reference_g1g2( + self.g1g2_map, + ROI + ) + return g1_ref, g2_ref + + + def show_strain( self, vrange = [-3,3], @@ -936,15 +987,6 @@ def show_strain( g1q /= g_ratio else: g2q *= g_ratio - # rotate - #R = np.array( - # [ - # [ np.cos(QRrot), np.sin(QRrot)], - # [-np.sin(QRrot), np.cos(QRrot)] - # ] - #) - #g1_x,g1_y = np.matmul(g1q,R) - #g2_x,g2_y = np.matmul(g2q,R) g1_x,g1_y = g1q g2_x,g2_y = g2q @@ -1393,7 +1435,10 @@ def show_lattice_vectors( returnfig=False, **kwargs, ): - """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" + """ + Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). + g1 and g2 are 2-tuples (gx,gy). + """ fig, ax = show(ar, returnfig=True, **kwargs) # Add vectors @@ -1446,10 +1491,13 @@ def show_bragg_indexing( """ Shows an array with an overlay describing the Bragg directions - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. + Parameters + ---------- + ar : np.ndarray + The display image + bragg_directions : PointList + The Bragg scattering directions. Must have coordinates + 'qx','qy','h', and 'k'. Optionally may also have 'l'. """ assert isinstance(bragg_directions, PointList) for k in ("qx", "qy", "h", "k"): From 2b3c251a9ecdbda8b67c32cf6cae7d5a81933d5e Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Wed, 1 Nov 2023 11:37:23 +0000 Subject: [PATCH 149/176] updates --- py4DSTEM/process/strain/latticevectors.py | 61 +++++----- py4DSTEM/process/strain/strain.py | 131 ++++++++++++++++++---- 2 files changed, 146 insertions(+), 46 deletions(-) diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py index 26c8d66a5..ba9bb4fcf 100644 --- a/py4DSTEM/process/strain/latticevectors.py +++ b/py4DSTEM/process/strain/latticevectors.py @@ -258,7 +258,13 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): ) # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): + for Rx, Ry in tqdmnd( + braggpeaks.shape[0], + braggpeaks.shape[1], + desc="Fitting lattice vectors", + unit="DP", + unit_scale=True, + ): braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( braggpeaks_curr, x0, y0, minNumPeaks @@ -359,32 +365,37 @@ def get_strain_from_reference_g1g2(g1g2_map, g1, g2): g2x, g2y = g2 M = np.array([[g1x, g1y], [g2x, g2y]]) - for Rx in range(R_Nx): - for Ry in range(R_Ny): - # Get lattice vectors for DP at Rx,Ry - alpha = np.array( + for Rx, Ry in tqdmnd( + R_Nx, + R_Ny, + desc="Calculating strain", + unit="DP", + unit_scale=True, + ): + # Get lattice vectors for DP at Rx,Ry + alpha = np.array( + [ [ - [ - g1g2_map.get_slice("g1x").data[Rx, Ry], - g1g2_map.get_slice("g1y").data[Rx, Ry], - ], - [ - g1g2_map.get_slice("g2x").data[Rx, Ry], - g1g2_map.get_slice("g2y").data[Rx, Ry], - ], - ] - ) - # Get transformation matrix - beta = lstsq(M, alpha, rcond=None)[0].T - - # Get the infinitesimal strain matrix - strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] - strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] - strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 - strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 - strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ - Rx, Ry + g1g2_map.get_slice("g1x").data[Rx, Ry], + g1g2_map.get_slice("g1y").data[Rx, Ry], + ], + [ + g1g2_map.get_slice("g2x").data[Rx, Ry], + g1g2_map.get_slice("g2y").data[Rx, Ry], + ], ] + ) + # Get transformation matrix + beta = lstsq(M, alpha, rcond=None)[0].T + + # Get the infinitesimal strain matrix + strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] + strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] + strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 + strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 + strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ + Rx, Ry + ] return strain_map diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index c22e2331a..dbbba5881 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -4,6 +4,8 @@ from typing import Optional import matplotlib.pyplot as plt +from matplotlib.patches import Circle +from matplotlib.collections import PatchCollection from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np from py4DSTEM import PointList, PointListArray, tqdmnd @@ -27,8 +29,6 @@ ax_addaxes_QtoR ) -warnings.simplefilter(action="always", category=UserWarning) - class StrainMap(RealSlice, Data): """ @@ -144,7 +144,7 @@ def reset_calstate(self): # Class methods - def choose_lattice_vectors( + def choose_basis_vectors( self, index_g1=None, index_g2=None, @@ -175,7 +175,7 @@ def choose_lattice_vectors( returnfig=False, ): """ - Choose which lattice vectors to use for strain mapping. + Choose basis lattice vectors g1 and g2 for strain mapping. Overlays the bvm with the points detected via local 2D maxima detection, plus an index for each point. Three points @@ -239,7 +239,7 @@ def choose_lattice_vectors( Returns ------- - (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter + (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or the latter two """ # validate inputs for i in (index_origin, index_g1, index_g2): @@ -394,27 +394,99 @@ def choose_lattice_vectors( else: return - def fit_lattice_vectors( + def set_max_peak_spacing( + self, + max_peak_spacing, + returnfig = False, + **vis_params, + ): + """ + Set the size of the regions of diffraction space in which detected Bragg + peaks will be indexed and included in subsequent fitting of basis + vectors, and visualize those regions. + + Parameters + ---------- + max_peak_spacing : number + The maximum allowable distance between a detected Bragg peak and + the indexed maxima found in `choose_basis_vectors` for the detected + peak to be indexed + returnfig : bool + Toggles returning the figure + vis_params : dict + Any additional arguments are passed to the `show` function when + visualization the BVM + """ + # set the max peak spacing + self.max_peak_spacing = max_peak_spacing + + # make the figure + fig,ax = show( + self.bvm.data, + returnfig=True, + **vis_params, + ) + + # make the circle patch collection + patches = [] + qx = self.braggdirections['qx'] + qy = self.braggdirections['qy'] + origin = self.origin + for idx in range(len(qx)): + c = Circle( + xy = ( + qy[idx] + origin[1], + qx[idx] + origin[0] + ), + radius = self.max_peak_spacing, + edgecolor = 'r', + fill = False + ) + patches.append(c) + pc = PatchCollection(patches, match_original=True) + + # draw the circles + ax.add_collection(pc) + + # return + if returnfig: + return fig,ax + else: + plt.show() + + + def fit_basis_vectors( self, - max_peak_spacing = 2, mask = None, + max_peak_spacing = None, + vis_params = {}, returncalc = False ): """ - Fit the basis lattice vectors g1 and g2 to the detected Bragg peaks - in each pattern. The fit uses all detected peaks which are within - a distance of `max_peak_spacing` of the indexed peaks determined - in `choose_lattice_vectors`. Bragg peaks used in the fit are weighted - by their intensity. + Fit the basis lattice vectors to the detected Bragg peaks at each + scan position. + + First, the lattice vectors at each scan position are indexed using the + basis vectors g1 and g2 specified previously with `choose_basis_vectors` + Detected Bragg peaks which are farther from the set of lattice vectors + found in `choose_basis vectors` than the maximum peak spacing are + ignored; the maximum peak spacing can be set previously by calling + `set_max_peak_spacing` or by specifying the `max_peak_spacing` argument + here. A fit is then performed to refine the values of g1 and g2 at each + scan position, fitting the basis vectors to all detected and indexed + peaks, weighting the peaks according to their intensity. Parameters ---------- - max_peak_spacing : float - Maximum distance from the ideal lattice points to include a peak - for indexing mask : 2d boolean array A real space shaped Boolean mask indicating scan positions at which to fit the lattice vectors. + max_peak_spacing : float + Maximum distance from the ideal lattice points to include a peak + for indexing + vis_params : dict + Visualization parameters for showing the max peak spacing; ignored + if `max_peak_spacing` is not set returncalc : bool if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map """ @@ -423,9 +495,17 @@ def fit_lattice_vectors( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - ### add indices to the bragg vectors + # handle the max peak spacing + if max_peak_spacing is not None: + self.set_max_peak_spacing( + max_peak_spacing, + **vis_params + ) + assert(hasattr(self,'max_peak_spacing')), "Set the maximum peak spacing!" + + # index the bragg vectors - # validate mask + # handle the mask if mask is None: mask = np.ones(self.braggvectors.Rshape, dtype=bool) assert ( @@ -445,10 +525,17 @@ def fit_lattice_vectors( ], shape=self.braggvectors.Rshape, ) - calstate = self.braggvectors.calstate # loop over all the scan positions - for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + # and perform indexing, excluding peaks outside of max_peak_spacing + calstate = self.braggvectors.calstate + for Rx, Ry in tqdmnd( + mask.shape[0], + mask.shape[1], + desc="Indexing Bragg scattering", + unit="DP", + unit_scale=True, + ): if mask[Rx, Ry]: pl = self.braggvectors.get_vectors( Rx, @@ -464,7 +551,7 @@ def fit_lattice_vectors( pl.data["qy"][i] - self.braggdirections.data["qy"], ) ind = np.argmin(r) - if r[ind] <= max_peak_spacing: + if r[ind] <= self.max_peak_spacing: indexed_braggpeaks[Rx, Ry].add_data_by_field( ( pl.data["qx"][i], @@ -476,7 +563,7 @@ def fit_lattice_vectors( ) self.bragg_vectors_indexed = indexed_braggpeaks - ### fit bragg vectors + # fit bragg vectors g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) self.g1g2_map = g1g2_map @@ -488,6 +575,7 @@ def fit_lattice_vectors( if returncalc: return self.bragg_vectors_indexed, self.g1g2_map + def get_strain( self, gvects = None, @@ -1563,3 +1651,4 @@ def _get_constructor_args(cls, group): "name": ar_constr_args["name"], } return args + From fa64065ba6300c6c024043c5e20c12575c230336 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Wed, 1 Nov 2023 16:44:06 +0000 Subject: [PATCH 150/176] autoformats --- py4DSTEM/braggvectors/braggvector_methods.py | 72 +- py4DSTEM/process/calibration/rotation.py | 148 ++-- py4DSTEM/process/strain/strain.py | 720 +++++++++---------- 3 files changed, 424 insertions(+), 516 deletions(-) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 3ca898609..669817788 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -519,7 +519,7 @@ def fit_origin( mask_check_data=True, plot=True, plot_range=None, - cmap = 'RdBu_r', + cmap="RdBu_r", returncalc=True, **kwargs, ): @@ -583,17 +583,16 @@ def fit_origin( qy0_fit, qx0_residuals, qy0_residuals, - mask = mask, - plot_range = plot_range, - cmap = cmap, - **kwargs + mask=mask, + plot_range=plot_range, + cmap=cmap, + **kwargs, ) # return if returncalc: return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals - def show_origin_fit( self, qx0_meas, @@ -602,49 +601,42 @@ def show_origin_fit( qy0_fit, qx0_residuals, qy0_residuals, - mask = None, - plot_range = None, - cmap = 'RdBu_r', - **kwargs - ): - + mask=None, + plot_range=None, + cmap="RdBu_r", + **kwargs, + ): # apply mask if mask is not None: qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask)) qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask)) - qx0_residuals = np.ma.masked_array( - qx0_residuals, mask=np.logical_not(mask) - ) - qy0_residuals = np.ma.masked_array( - qy0_residuals, mask=np.logical_not(mask) - ) + qx0_residuals = np.ma.masked_array(qx0_residuals, mask=np.logical_not(mask)) + qy0_residuals = np.ma.masked_array(qy0_residuals, mask=np.logical_not(mask)) qx0_mean = np.mean(qx0_fit) qy0_mean = np.mean(qy0_fit) # set range if plot_range is None: - plot_range = max(( - 1.5 * np.max(np.abs(qx0_fit - qx0_mean)), - 1.5 * np.max(np.abs(qy0_fit - qy0_mean)) - )) + plot_range = max( + ( + 1.5 * np.max(np.abs(qx0_fit - qx0_mean)), + 1.5 * np.max(np.abs(qy0_fit - qy0_mean)), + ) + ) # set figsize - imsize_ratio = np.sqrt(qx0_meas.shape[1]/qx0_meas.shape[0]) - axsize = (3*imsize_ratio, 3/imsize_ratio) + imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0]) + axsize = (3 * imsize_ratio, 3 / imsize_ratio) # plot show( - [[qx0_meas - qx0_mean, - qx0_fit - qx0_mean, - qx0_residuals - ],[ - qy0_meas - qy0_mean, - qy0_fit - qy0_mean, - qy0_residuals - ]], - cmap = cmap, - axsize = axsize, - title = [ + [ + [qx0_meas - qx0_mean, qx0_fit - qx0_mean, qx0_residuals], + [qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals], + ], + cmap=cmap, + axsize=axsize, + title=[ "measured origin, x", "fitorigin, x", "residuals, x", @@ -652,18 +644,14 @@ def show_origin_fit( "fitorigin, y", "residuals, y", ], - vmin = -1 * plot_range, - vmax = 1 * plot_range, + vmin=-1 * plot_range, + vmax=1 * plot_range, intensity_range="absolute", - **kwargs + **kwargs, ) return - - - - def fit_p_ellipse( self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs ): diff --git a/py4DSTEM/process/calibration/rotation.py b/py4DSTEM/process/calibration/rotation.py index 21134a352..aaf8a49ce 100644 --- a/py4DSTEM/process/calibration/rotation.py +++ b/py4DSTEM/process/calibration/rotation.py @@ -10,24 +10,24 @@ def compare_QR_rotation( im_R, im_Q, QR_rotation, - R_rotation = 0, - R_position = None, - Q_position = None, - R_pos_anchor = 'center', - Q_pos_anchor = 'center', - R_length = 0.33, - Q_length = 0.33, - R_width = 0.001, - Q_width = 0.001, - R_head_length_adjust = 1, - Q_head_length_adjust = 1, - R_head_width_adjust = 1, - Q_head_width_adjust = 1, - R_color = 'r', - Q_color = 'r', - figsize = (10,5), - returnfig = False - ): + R_rotation=0, + R_position=None, + Q_position=None, + R_pos_anchor="center", + Q_pos_anchor="center", + R_length=0.33, + Q_length=0.33, + R_width=0.001, + Q_width=0.001, + R_head_length_adjust=1, + Q_head_length_adjust=1, + R_head_width_adjust=1, + Q_head_width_adjust=1, + R_color="r", + Q_color="r", + figsize=(10, 5), + returnfig=False, +): """ Visualize a rotational offset between an image in real space, e.g. a STEM virtual image, and an image in diffraction space, e.g. a defocused CBED @@ -89,94 +89,90 @@ def compare_QR_rotation( # parse inputs if R_position is None: R_position = ( - im_R.shape[0]/2, - im_R.shape[1]/2, + im_R.shape[0] / 2, + im_R.shape[1] / 2, ) if Q_position is None: Q_position = ( - im_Q.shape[0]/2, - im_Q.shape[1]/2, + im_Q.shape[0] / 2, + im_Q.shape[1] / 2, ) R_length = np.mean(im_R.shape) * R_length Q_length = np.mean(im_Q.shape) * Q_length - assert R_pos_anchor in ('center','tail','head') - assert Q_pos_anchor in ('center','tail','head') + assert R_pos_anchor in ("center", "tail", "head") + assert Q_pos_anchor in ("center", "tail", "head") # compute positions - rpos_x,rpos_y = R_position - qpos_x,qpos_y = Q_position + rpos_x, rpos_y = R_position + qpos_x, qpos_y = Q_position R_rot_rad = np.radians(R_rotation) - Q_rot_rad = np.radians(R_rotation+QR_rotation) + Q_rot_rad = np.radians(R_rotation + QR_rotation) rvecx = np.cos(R_rot_rad) rvecy = np.sin(R_rot_rad) qvecx = np.cos(Q_rot_rad) qvecy = np.sin(Q_rot_rad) - if R_pos_anchor == 'center': - x0_r = rpos_x - rvecx*R_length/2 - y0_r = rpos_y - rvecy*R_length/2 - x1_r = rpos_x + rvecx*R_length/2 - y1_r = rpos_y + rvecy*R_length/2 - elif R_pos_anchor == 'tail': + if R_pos_anchor == "center": + x0_r = rpos_x - rvecx * R_length / 2 + y0_r = rpos_y - rvecy * R_length / 2 + x1_r = rpos_x + rvecx * R_length / 2 + y1_r = rpos_y + rvecy * R_length / 2 + elif R_pos_anchor == "tail": x0_r = rpos_x y0_r = rpos_y - x1_r = rpos_x + rvecx*R_length - y1_r = rpos_y + rvecy*R_length - elif R_pos_anchor == 'head': - x0_r = rpos_x - rvecx*R_length - y0_r = rpos_y - rvecy*R_length + x1_r = rpos_x + rvecx * R_length + y1_r = rpos_y + rvecy * R_length + elif R_pos_anchor == "head": + x0_r = rpos_x - rvecx * R_length + y0_r = rpos_y - rvecy * R_length x1_r = rpos_x y1_r = rpos_y else: raise Exception(f"Invalid value for R_pos_anchor {R_pos_anchor}") - if Q_pos_anchor == 'center': - x0_q = qpos_x - qvecx*Q_length/2 - y0_q = qpos_y - qvecy*Q_length/2 - x1_q = qpos_x + qvecx*Q_length/2 - y1_q = qpos_y + qvecy*Q_length/2 - elif Q_pos_anchor == 'tail': + if Q_pos_anchor == "center": + x0_q = qpos_x - qvecx * Q_length / 2 + y0_q = qpos_y - qvecy * Q_length / 2 + x1_q = qpos_x + qvecx * Q_length / 2 + y1_q = qpos_y + qvecy * Q_length / 2 + elif Q_pos_anchor == "tail": x0_q = qpos_x y0_q = qpos_y - x1_q = qpos_x + qvecx*Q_length - y1_q = qpos_y + qvecy*Q_length - elif Q_pos_anchor == 'head': - x0_q = qpos_x - qvecx*Q_length - y0_q = qpos_y - qvecy*Q_length + x1_q = qpos_x + qvecx * Q_length + y1_q = qpos_y + qvecy * Q_length + elif Q_pos_anchor == "head": + x0_q = qpos_x - qvecx * Q_length + y0_q = qpos_y - qvecy * Q_length x1_q = qpos_x y1_q = qpos_y else: raise Exception(f"Invalid value for Q_pos_anchor {Q_pos_anchor}") # make the figure - axsize = (figsize[0]/2,figsize[1]) - fig,axs = show( - [im_R,im_Q], - returnfig = True, - axsize = axsize + axsize = (figsize[0] / 2, figsize[1]) + fig, axs = show([im_R, im_Q], returnfig=True, axsize=axsize) + axs[0, 0].arrow( + x=y0_r, + y=x0_r, + dx=y1_r - y0_r, + dy=x1_r - x0_r, + color=R_color, + length_includes_head=True, + width=R_width, + head_width=R_length * R_head_width_adjust * 0.072, + head_length=R_length * R_head_length_adjust * 0.1, ) - axs[0,0].arrow( - x = y0_r, - y = x0_r, - dx = y1_r - y0_r, - dy = x1_r - x0_r, - color = R_color, - length_includes_head = True, - width = R_width, - head_width = R_length*R_head_width_adjust*0.072, - head_length = R_length*R_head_length_adjust*0.1 - ) - axs[0,1].arrow( - x = y0_q, - y = x0_q, - dx = y1_q - y0_q, - dy = x1_q - x0_q, - color = Q_color, - length_includes_head = True, - width = Q_width, - head_width = Q_length*Q_head_width_adjust*0.072, - head_length = Q_length*Q_head_length_adjust*0.1 + axs[0, 1].arrow( + x=y0_q, + y=x0_q, + dx=y1_q - y0_q, + dy=x1_q - x0_q, + color=Q_color, + length_includes_head=True, + width=Q_width, + head_width=Q_length * Q_head_width_adjust * 0.072, + head_length=Q_length * Q_head_length_adjust * 0.1, ) if returnfig: - return fig,axs + return fig, axs else: plt.show() diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index dbbba5881..d4636b84c 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -26,7 +26,7 @@ add_pointlabels, add_vector, ax_addaxes, - ax_addaxes_QtoR + ax_addaxes_QtoR, ) @@ -36,11 +36,7 @@ class StrainMap(RealSlice, Data): """ - def __init__( - self, - braggvectors: BraggVectors, - name: Optional[str] = "strainmap" - ): + def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): """ Parameters ---------- @@ -397,7 +393,7 @@ def choose_basis_vectors( def set_max_peak_spacing( self, max_peak_spacing, - returnfig = False, + returnfig=False, **vis_params, ): """ @@ -421,7 +417,7 @@ def set_max_peak_spacing( self.max_peak_spacing = max_peak_spacing # make the figure - fig,ax = show( + fig, ax = show( self.bvm.data, returnfig=True, **vis_params, @@ -429,18 +425,15 @@ def set_max_peak_spacing( # make the circle patch collection patches = [] - qx = self.braggdirections['qx'] - qy = self.braggdirections['qy'] + qx = self.braggdirections["qx"] + qy = self.braggdirections["qy"] origin = self.origin for idx in range(len(qx)): c = Circle( - xy = ( - qy[idx] + origin[1], - qx[idx] + origin[0] - ), - radius = self.max_peak_spacing, - edgecolor = 'r', - fill = False + xy=(qy[idx] + origin[1], qx[idx] + origin[0]), + radius=self.max_peak_spacing, + edgecolor="r", + fill=False, ) patches.append(c) pc = PatchCollection(patches, match_original=True) @@ -450,17 +443,12 @@ def set_max_peak_spacing( # return if returnfig: - return fig,ax + return fig, ax else: plt.show() - def fit_basis_vectors( - self, - mask = None, - max_peak_spacing = None, - vis_params = {}, - returncalc = False + self, mask=None, max_peak_spacing=None, vis_params={}, returncalc=False ): """ Fit the basis lattice vectors to the detected Bragg peaks at each @@ -497,11 +485,8 @@ def fit_basis_vectors( # handle the max peak spacing if max_peak_spacing is not None: - self.set_max_peak_spacing( - max_peak_spacing, - **vis_params - ) - assert(hasattr(self,'max_peak_spacing')), "Set the maximum peak spacing!" + self.set_max_peak_spacing(max_peak_spacing, **vis_params) + assert hasattr(self, "max_peak_spacing"), "Set the maximum peak spacing!" # index the bragg vectors @@ -575,13 +560,8 @@ def fit_basis_vectors( if returncalc: return self.bragg_vectors_indexed, self.g1g2_map - def get_strain( - self, - gvects = None, - coordinate_rotation = 0, - returncalc = False, - **kwargs + self, gvects=None, coordinate_rotation=0, returncalc=False, **kwargs ): """ Compute the strain as the deviation of the basis reciprocal lattice @@ -611,27 +591,19 @@ def get_strain( # get the reference g-vectors if gvects is None: - g1_ref, g2_ref = get_reference_g1g2( - self.g1g2_map, - self.mask - ) + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, self.mask) elif isinstance(gvects, np.ndarray): - assert(gvects.shape == self.rshape) - assert(gvects.dtype == bool) + assert gvects.shape == self.rshape + assert gvects.dtype == bool g1_ref, g2_ref = get_reference_g1g2( - self.g1g2_map, - np.logical_and(gvects, self.mask) + self.g1g2_map, np.logical_and(gvects, self.mask) ) else: g1_ref = np.array(gvects[0]) g2_ref = np.array(gvects[1]) # find the strain - strainmap_g1g2 = get_strain_from_reference_g1g2( - self.g1g2_map, - g1_ref, - g2_ref - ) + strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) self.strainmap_g1g2 = strainmap_g1g2 # get the reference coordinate system @@ -644,9 +616,9 @@ def get_strain( # get the strain in the reference coordinates strainmap_rotated = get_rotated_strain_map( self.strainmap_g1g2, - xaxis_x = xaxis_x, - xaxis_y = xaxis_y, - flip_theta = False, + xaxis_x=xaxis_x, + xaxis_y=xaxis_y, + flip_theta=False, ) # store the data @@ -666,11 +638,7 @@ def get_strain( if returncalc: return self.strainmap - - def get_reference_g1g2( - self, - ROI - ): + def get_reference_g1g2(self, ROI): """ Get reference g1,g2 vectors by taking the median fit vectors in the specified ROI. @@ -684,40 +652,35 @@ def get_reference_g1g2( ------- g1_ref,g2_ref : 2 tuple of length 2 ndarrays """ - g1_ref, g2_ref = get_reference_g1g2( - self.g1g2_map, - ROI - ) + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, ROI) return g1_ref, g2_ref - - def show_strain( self, - vrange = [-3,3], - vrange_theta = [-3,3], - vrange_exx = None, - vrange_exy = None, - vrange_eyy = None, - bkgrd = True, - show_cbars = None, - bordercolor = "k", - borderwidth = 1, - titlesize = 18, - ticklabelsize = 10, - ticknumber = 5, - unitlabelsize = 16, - cmap = "RdBu_r", - cmap_theta = "PRGn", - mask_color = "k", - color_axes = "k", - show_gvects = True, - color_gvects = "r", - legend_camera_length = 1.6, - scale_gvects = 0.6, - layout = 0, - figsize = None, - returnfig = False, + vrange=[-3, 3], + vrange_theta=[-3, 3], + vrange_exx=None, + vrange_exy=None, + vrange_eyy=None, + bkgrd=True, + show_cbars=None, + bordercolor="k", + borderwidth=1, + titlesize=18, + ticklabelsize=10, + ticknumber=5, + unitlabelsize=16, + cmap="RdBu_r", + cmap_theta="PRGn", + mask_color="k", + color_axes="k", + show_gvects=True, + color_gvects="r", + legend_camera_length=1.6, + scale_gvects=0.6, + layout=0, + figsize=None, + returnfig=False, ): """ Display a strain map, showing the 4 strain components @@ -792,17 +755,21 @@ def show_strain( # Set which colorbars to display if show_cbars is None: - if np.all([v is None for v in ( - vrange_exx, - vrange_eyy, - vrange_exy, - )]): - show_cbars = ('eyy','theta') + if np.all( + [ + v is None + for v in ( + vrange_exx, + vrange_eyy, + vrange_exy, + ) + ] + ): + show_cbars = ("eyy", "theta") else: - show_cbars = ('exx','eyy','exy','theta') + show_cbars = ("exx", "eyy", "exy", "theta") else: - assert np.all([v in ('exx','eyy','exy','theta') for v in show_cbars]) - + assert np.all([v in ("exx", "eyy", "exy", "theta") for v in show_cbars]) # Contrast limits if vrange_exx is None: @@ -841,27 +808,29 @@ def show_strain( # if figsize hasn't been set, set it based on the # chosen layout and the image shape if figsize is None: - ratio = np.sqrt(self.rshape[1]/self.rshape[0]) + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) if layout == 0: - figsize = (13*ratio,8/ratio) + figsize = (13 * ratio, 8 / ratio) elif layout == 1: - figsize = (10*ratio,4/ratio) + figsize = (10 * ratio, 4 / ratio) else: - figsize = (4*ratio,10/ratio) - + figsize = (4 * ratio, 10 / ratio) # set up layout if layout == 0: - fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) =\ - plt.subplots(2, 3, figsize=figsize) + fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots( + 2, 3, figsize=figsize + ) elif layout == 1: figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22, ax_legend) =\ - plt.subplots(1, 5, figsize=figsize) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 1, 5, figsize=figsize + ) else: figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22, ax_legend) =\ - plt.subplots(5, 1, figsize=figsize) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 5, 1, figsize=figsize + ) # display images, returning cbar axis references cax11 = show( @@ -1010,58 +979,57 @@ def show_strain( ax_legend1.remove() ax_legend2.remove() # make new axis - ax_legend = fig.add_subplot(gs[:,-1]) + ax_legend = fig.add_subplot(gs[:, -1]) # get the coordinate axes' directions rotation = self.coordinate_rotation_radians xaxis_vectx = np.cos(rotation) xaxis_vecty = np.sin(rotation) - yaxis_vectx = np.cos(rotation+np.pi/2) - yaxis_vecty = np.sin(rotation+np.pi/2) + yaxis_vectx = np.cos(rotation + np.pi / 2) + yaxis_vecty = np.sin(rotation + np.pi / 2) # make the coordinate axes ax_legend.arrow( - x = 0, - y = 0, - dx = xaxis_vecty, - dy = xaxis_vectx, - color = color_axes, - length_includes_head = True, - width = 0.01, - head_width = 0.1, + x=0, + y=0, + dx=xaxis_vecty, + dy=xaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, ) ax_legend.arrow( - x = 0, - y = 0, - dx = yaxis_vecty, - dy = yaxis_vectx, - color = color_axes, - length_includes_head = True, - width = 0.01, - head_width = 0.1, + x=0, + y=0, + dx=yaxis_vecty, + dy=yaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, ) ax_legend.text( - x = xaxis_vecty*1.16, - y = xaxis_vectx*1.16, - s = 'x', - fontsize = 14, - color = color_axes, - horizontalalignment = 'center', - verticalalignment = 'center', + x=xaxis_vecty * 1.16, + y=xaxis_vectx * 1.16, + s="x", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", ) ax_legend.text( - x = yaxis_vecty*1.16, - y = yaxis_vectx*1.16, - s = 'y', - fontsize = 14, - color = color_axes, - horizontalalignment = 'center', - verticalalignment = 'center', + x=yaxis_vecty * 1.16, + y=yaxis_vectx * 1.16, + s="y", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", ) # make the g-vectors if show_gvects: - # get the g-vectors directions g1q = np.array(self.g1) g2q = np.array(self.g2) @@ -1070,75 +1038,75 @@ def show_strain( g1q /= g1norm g2q /= g2norm # set the lengths - g_ratio = g2norm/g1norm + g_ratio = g2norm / g1norm if g_ratio > 1: g1q /= g_ratio else: g2q *= g_ratio - g1_x,g1_y = g1q - g2_x,g2_y = g2q + g1_x, g1_y = g1q + g2_x, g2_y = g2q # draw the g vectors ax_legend.arrow( - x = 0, - y = 0, - dx = g1_y*scale_gvects, - dy = g1_x*scale_gvects, - color = color_gvects, - length_includes_head = True, - width = 0.005, - head_width = 0.05, + x=0, + y=0, + dx=g1_y * scale_gvects, + dy=g1_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, ) ax_legend.arrow( - x = 0, - y = 0, - dx = g2_y*scale_gvects, - dy = g2_x*scale_gvects, - color = color_gvects, - length_includes_head = True, - width = 0.005, - head_width = 0.05, + x=0, + y=0, + dx=g2_y * scale_gvects, + dy=g2_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, ) ax_legend.text( - x = g1_y*scale_gvects*1.2, - y = g1_x*scale_gvects*1.2, - s = r'$g_1$', - fontsize = 12, - color = color_gvects, - horizontalalignment = 'center', - verticalalignment = 'center', + x=g1_y * scale_gvects * 1.2, + y=g1_x * scale_gvects * 1.2, + s=r"$g_1$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", ) ax_legend.text( - x = g2_y*scale_gvects*1.2, - y = g2_x*scale_gvects*1.2, - s = r'$g_2$', - fontsize = 12, - color = color_gvects, - horizontalalignment = 'center', - verticalalignment = 'center', + x=g2_y * scale_gvects * 1.2, + y=g2_x * scale_gvects * 1.2, + s=r"$g_2$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", ) # find center and extent - xmin = np.min([0,0,xaxis_vectx,yaxis_vectx]) - xmax = np.max([0,0,xaxis_vectx,yaxis_vectx]) - ymin = np.min([0,0,xaxis_vecty,yaxis_vecty]) - ymax = np.max([0,0,xaxis_vecty,yaxis_vecty]) + xmin = np.min([0, 0, xaxis_vectx, yaxis_vectx]) + xmax = np.max([0, 0, xaxis_vectx, yaxis_vectx]) + ymin = np.min([0, 0, xaxis_vecty, yaxis_vecty]) + ymax = np.max([0, 0, xaxis_vecty, yaxis_vecty]) if show_gvects: - xmin = np.min([xmin,g1_x,g2_x]) - xmax = np.max([xmax,g1_x,g2_x]) - ymin = np.min([ymin,g1_y,g2_y]) - ymax = np.max([ymax,g1_y,g2_y]) - x0 = np.mean([xmin,xmax]) - y0 = np.mean([ymin,ymax]) - xL = (xmax-x0) * legend_camera_length - yL = (ymax-y0) * legend_camera_length + xmin = np.min([xmin, g1_x, g2_x]) + xmax = np.max([xmax, g1_x, g2_x]) + ymin = np.min([ymin, g1_y, g2_y]) + ymax = np.max([ymax, g1_y, g2_y]) + x0 = np.mean([xmin, xmax]) + y0 = np.mean([ymin, ymax]) + xL = (xmax - x0) * legend_camera_length + yL = (ymax - y0) * legend_camera_length # set the extent and aspect - ax_legend.set_xlim([y0-yL,y0+yL]) - ax_legend.set_ylim([x0-xL,x0+xL]) + ax_legend.set_xlim([y0 - yL, y0 + yL]) + ax_legend.set_ylim([x0 - xL, x0 + xL]) ax_legend.invert_yaxis() ax_legend.set_aspect("equal") - ax_legend.axis('off') + ax_legend.axis("off") # show/return if not returnfig: @@ -1148,23 +1116,22 @@ def show_strain( axs = ((ax11, ax12), (ax21, ax22)) return fig, axs - def show_reference_directions( self, - im_uncal = None, - im_cal = None, + im_uncal=None, + im_cal=None, color_axes="linen", color_gvects="r", - origin_uncal = None, - origin_cal = None, - camera_length = 1.8, - visp_uncal = {'scaling' : 'log'}, - visp_cal = {'scaling' : 'log'}, - layout = "horizontal", - titlesize = 16, - size_labels = 14, - figsize = None, - returnfig = False, + origin_uncal=None, + origin_cal=None, + camera_length=1.8, + visp_uncal={"scaling": "log"}, + visp_cal={"scaling": "log"}, + layout="horizontal", + titlesize=16, + size_labels=14, + figsize=None, + returnfig=False, ): """ Show the reference coordinate system used to compute the strain @@ -1220,87 +1187,59 @@ def show_reference_directions( # Set the figsize if figsize is None: - ratio = np.sqrt(self.rshape[1]/self.rshape[0]) + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) if layout == "horizontal": - figsize = (10*ratio,8/ratio) + figsize = (10 * ratio, 8 / ratio) else: - figsize = (8*ratio,12/ratio) + figsize = (8 * ratio, 12 / ratio) # Create the figure if layout == "horizontal": - fig, (ax1, ax2) =\ - plt.subplots(1, 2, figsize=figsize) + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) else: - fig, (ax1, ax2) =\ - plt.subplots(2, 1, figsize=figsize) + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize) # prepare images if im_uncal is None: - im_uncal = self.braggvectors.histogram( mode='raw' ) + im_uncal = self.braggvectors.histogram(mode="raw") if im_cal is None: - im_cal = self.braggvectors.histogram( mode='cal' ) + im_cal = self.braggvectors.histogram(mode="cal") # display images - show( - im_cal, - figax=(fig, ax1), - **visp_cal - ) - show( - im_uncal, - figax=(fig, ax2), - **visp_uncal - ) + show(im_cal, figax=(fig, ax1), **visp_cal) + show(im_uncal, figax=(fig, ax2), **visp_uncal) ax1.set_title("Calibrated", size=titlesize) ax2.set_title("Uncalibrated", size=titlesize) - # Get the coordinate axes # get the directions - # calibrated + # calibrated rotation = self.coordinate_rotation_radians - xaxis_cal = np.array([ - np.cos(rotation), - np.sin(rotation) - ]) - yaxis_cal = np.array([ - np.cos(rotation+np.pi/2), - np.sin(rotation+np.pi/2) - ]) - - # uncalibrated + xaxis_cal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_cal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) + + # uncalibrated QRrot = self.calibration.get_QR_rotation() - rotation = np.sum([ - self.coordinate_rotation_radians, - -QRrot - ]) - xaxis_uncal = np.array([ - np.cos(rotation), - np.sin(rotation) - ]) - yaxis_uncal = np.array([ - np.cos(rotation+np.pi/2), - np.sin(rotation+np.pi/2) - ]) + rotation = np.sum([self.coordinate_rotation_radians, -QRrot]) + xaxis_uncal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_uncal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) # inversion if self.calibration.get_QR_flip(): - xaxis_uncal = np.array([ - xaxis_uncal[1], - xaxis_uncal[0] - ]) - yaxis_uncal = np.array([ - yaxis_uncal[1], - yaxis_uncal[0] - ]) + xaxis_uncal = np.array([xaxis_uncal[1], xaxis_uncal[0]]) + yaxis_uncal = np.array([yaxis_uncal[1], yaxis_uncal[0]]) # set the lengths - Lmean = np.mean([im_cal.shape[0],im_cal.shape[1]])/2 - xaxis_cal *= Lmean/camera_length - yaxis_cal *= Lmean/camera_length - xaxis_uncal *= Lmean/camera_length - yaxis_uncal *= Lmean/camera_length + Lmean = np.mean([im_cal.shape[0], im_cal.shape[1]]) / 2 + xaxis_cal *= Lmean / camera_length + yaxis_cal *= Lmean / camera_length + xaxis_uncal *= Lmean / camera_length + yaxis_uncal *= Lmean / camera_length # Get the g-vectors @@ -1309,25 +1248,13 @@ def show_reference_directions( g2_cal = np.array(self.g2) # uncalibrated - R = np.array( - [ - [np.cos(QRrot), -np.sin(QRrot)], - [np.sin(QRrot), np.cos(QRrot)] - ] - ) - g1_uncal = np.matmul(g1_cal,R) - g2_uncal = np.matmul(g2_cal,R) + R = np.array([[np.cos(QRrot), -np.sin(QRrot)], [np.sin(QRrot), np.cos(QRrot)]]) + g1_uncal = np.matmul(g1_cal, R) + g2_uncal = np.matmul(g2_cal, R) # inversion if self.calibration.get_QR_flip(): - g1_uncal = np.array([ - g1_uncal[1], - g1_uncal[0] - ]) - g2_uncal = np.array([ - g2_uncal[1], - g2_uncal[0] - ]) - + g1_uncal = np.array([g1_uncal[1], g1_uncal[0]]) + g2_uncal = np.array([g2_uncal[1], g2_uncal[0]]) # Set origin positions if origin_uncal is None: @@ -1336,172 +1263,170 @@ def show_reference_directions( origin_cal = self.calibration.get_origin_mean() # Draw calibrated coordinate axes - coordax_width = Lmean*2/100 + coordax_width = Lmean * 2 / 100 ax1.arrow( - x = origin_cal[1], - y = origin_cal[0], - dx = xaxis_cal[1], - dy = xaxis_cal[0], - color = color_axes, - length_includes_head = True, - width = coordax_width, - head_width = coordax_width * 5, + x=origin_cal[1], + y=origin_cal[0], + dx=xaxis_cal[1], + dy=xaxis_cal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, ) ax1.arrow( - x = origin_cal[1], - y = origin_cal[0], - dx = yaxis_cal[1], - dy = yaxis_cal[0], - color = color_axes, - length_includes_head = True, - width = coordax_width, - head_width = coordax_width * 5, + x=origin_cal[1], + y=origin_cal[0], + dx=yaxis_cal[1], + dy=yaxis_cal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, ) ax1.text( - x = origin_cal[1] + xaxis_cal[1]*1.16, - y = origin_cal[0] + xaxis_cal[0]*1.16, - s = 'x', - fontsize = size_labels, - color = color_axes, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_cal[1] + xaxis_cal[1] * 1.16, + y=origin_cal[0] + xaxis_cal[0] * 1.16, + s="x", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", ) ax1.text( - x = origin_cal[1] + yaxis_cal[1]*1.16, - y = origin_cal[0] + yaxis_cal[0]*1.16, - s = 'y', - fontsize = size_labels, - color = color_axes, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_cal[1] + yaxis_cal[1] * 1.16, + y=origin_cal[0] + yaxis_cal[0] * 1.16, + s="y", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", ) # Draw uncalibrated coordinate axes ax2.arrow( - x = origin_uncal[1], - y = origin_uncal[0], - dx = xaxis_uncal[1], - dy = xaxis_uncal[0], - color = color_axes, - length_includes_head = True, - width = coordax_width, - head_width = coordax_width * 5, + x=origin_uncal[1], + y=origin_uncal[0], + dx=xaxis_uncal[1], + dy=xaxis_uncal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, ) ax2.arrow( - x = origin_uncal[1], - y = origin_uncal[0], - dx = yaxis_uncal[1], - dy = yaxis_uncal[0], - color = color_axes, - length_includes_head = True, - width = coordax_width, - head_width = coordax_width * 5, + x=origin_uncal[1], + y=origin_uncal[0], + dx=yaxis_uncal[1], + dy=yaxis_uncal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, ) ax2.text( - x = origin_uncal[1] + xaxis_uncal[1]*1.16, - y = origin_uncal[0] + xaxis_uncal[0]*1.16, - s = 'x', - fontsize = size_labels, - color = color_axes, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_uncal[1] + xaxis_uncal[1] * 1.16, + y=origin_uncal[0] + xaxis_uncal[0] * 1.16, + s="x", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", ) ax2.text( - x = origin_uncal[1] + yaxis_uncal[1]*1.16, - y = origin_uncal[0] + yaxis_uncal[0]*1.16, - s = 'y', - fontsize = size_labels, - color = color_axes, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_uncal[1] + yaxis_uncal[1] * 1.16, + y=origin_uncal[0] + yaxis_uncal[0] * 1.16, + s="y", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", ) - # Draw the calibrated g-vectors # draw the g vectors ax1.arrow( - x = origin_cal[1], - y = origin_cal[0], - dx = g1_cal[1], - dy = g1_cal[0], - color = color_gvects, - length_includes_head = True, - width = coordax_width * 0.5, - head_width = coordax_width * 2.5, + x=origin_cal[1], + y=origin_cal[0], + dx=g1_cal[1], + dy=g1_cal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, ) ax1.arrow( - x = origin_cal[1], - y = origin_cal[0], - dx = g2_cal[1], - dy = g2_cal[0], - color = color_gvects, - length_includes_head = True, - width = coordax_width * 0.5, - head_width = coordax_width * 2.5, + x=origin_cal[1], + y=origin_cal[0], + dx=g2_cal[1], + dy=g2_cal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, ) ax1.text( - x = origin_cal[1] + g1_cal[1]*1.16, - y = origin_cal[0] + g1_cal[0]*1.16, - s = r'$g_1$', - fontsize = size_labels*0.88, - color = color_gvects, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_cal[1] + g1_cal[1] * 1.16, + y=origin_cal[0] + g1_cal[0] * 1.16, + s=r"$g_1$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", ) ax1.text( - x = origin_cal[1] + g2_cal[1]*1.16, - y = origin_cal[0] + g2_cal[0]*1.16, - s = r'$g_2$', - fontsize = size_labels*0.88, - color = color_gvects, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_cal[1] + g2_cal[1] * 1.16, + y=origin_cal[0] + g2_cal[0] * 1.16, + s=r"$g_2$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", ) # Draw the uncalibrated g-vectors # draw the g vectors ax2.arrow( - x = origin_uncal[1], - y = origin_uncal[0], - dx = g1_uncal[1], - dy = g1_uncal[0], - color = color_gvects, - length_includes_head = True, - width = coordax_width * 0.5, - head_width = coordax_width * 2.5, + x=origin_uncal[1], + y=origin_uncal[0], + dx=g1_uncal[1], + dy=g1_uncal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, ) ax2.arrow( - x = origin_uncal[1], - y = origin_uncal[0], - dx = g2_uncal[1], - dy = g2_uncal[0], - color = color_gvects, - length_includes_head = True, - width = coordax_width * 0.5, - head_width = coordax_width * 2.5, + x=origin_uncal[1], + y=origin_uncal[0], + dx=g2_uncal[1], + dy=g2_uncal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, ) ax2.text( - x = origin_uncal[1] + g1_uncal[1]*1.16, - y = origin_uncal[0] + g1_uncal[0]*1.16, - s = r'$g_1$', - fontsize = size_labels*0.88, - color = color_gvects, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_uncal[1] + g1_uncal[1] * 1.16, + y=origin_uncal[0] + g1_uncal[0] * 1.16, + s=r"$g_1$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", ) ax2.text( - x = origin_uncal[1] + g2_uncal[1]*1.16, - y = origin_uncal[0] + g2_uncal[0]*1.16, - s = r'$g_2$', - fontsize = size_labels*0.88, - color = color_gvects, - horizontalalignment = 'center', - verticalalignment = 'center', + x=origin_uncal[1] + g2_uncal[1] * 1.16, + y=origin_uncal[0] + g2_uncal[0] * 1.16, + s=r"$g_2$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", ) - # show/return if not returnfig: plt.show() @@ -1651,4 +1576,3 @@ def _get_constructor_args(cls, group): "name": ar_constr_args["name"], } return args - From 0347217edc2c491e463738b97ef890b885fdb3d2 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Wed, 1 Nov 2023 16:55:18 +0000 Subject: [PATCH 151/176] bugfix --- py4DSTEM/process/strain/strain.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index d4636b84c..0b27d6562 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -1432,8 +1432,7 @@ def show_reference_directions( plt.show() return else: - axs = ((ax11, ax12), (ax21, ax22)) - return fig, axs + return fig, (ax1,ax2) def show_lattice_vectors( ar, From 3d220b7da7553c5d0dbdfd357d19f05785191de8 Mon Sep 17 00:00:00 2001 From: bsavitzky Date: Wed, 1 Nov 2023 16:57:52 +0000 Subject: [PATCH 152/176] autoformats --- py4DSTEM/process/strain/strain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 0b27d6562..c6b245192 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -1432,7 +1432,7 @@ def show_reference_directions( plt.show() return else: - return fig, (ax1,ax2) + return fig, (ax1, ax2) def show_lattice_vectors( ar, From b67a06422130c884cb4d6c951fa9b65f49d3f069 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Wed, 1 Nov 2023 12:00:36 -0700 Subject: [PATCH 153/176] adding self_consistency_errors property. not implemented for 3D yet --- .../process/phase/iterative_base_class.py | 31 ++++++++++++++++++- ...tive_mixedstate_multislice_ptychography.py | 28 +++++++++++++++++ .../iterative_mixedstate_ptychography.py | 28 +++++++++++++++++ .../iterative_overlap_magnetic_tomography.py | 5 +++ .../phase/iterative_overlap_tomography.py | 5 +++ .../iterative_simultaneous_ptychography.py | 29 +++++++++++++++++ 6 files changed, 125 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index f04a3c552..13c64d79d 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2366,6 +2366,35 @@ def positions(self): @property def object_cropped(self): - """cropped and rotated object""" + """Cropped and rotated object""" return self._crop_rotate_object_fov(self._object) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + + # Normalized mean-squared errors + error = xp.sum( + xp.abs(self._amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + error /= self._mean_diffraction_intensity + + return asnumpy(error) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 3eeb07814..82155219a 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3509,3 +3509,31 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) + error /= self._mean_diffraction_intensity + + return asnumpy(error) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 2e9fbd076..25bee346c 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -2327,3 +2327,31 @@ def show_fourier_probe( chroma_boost=chroma_boost, **kwargs, ) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) + error /= self._mean_diffraction_intensity + + return asnumpy(error) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 32b0f6fd4..cde84907c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3327,3 +3327,8 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 66cf46487..e92211301 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3207,3 +3207,8 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 37438852f..757b2ffae 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -3357,3 +3357,32 @@ def visualize( ) return self + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + error = xp.sum( + xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + error /= self._mean_diffraction_intensity + + return asnumpy(error) From c1bbb9103e0b621a0e2928a5f3ba5543f8099bce Mon Sep 17 00:00:00 2001 From: Steven Zeltmann Date: Wed, 1 Nov 2023 15:15:12 -0400 Subject: [PATCH 154/176] add robustness back to fit_origin --- py4DSTEM/braggvectors/braggvector_methods.py | 21 ++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 267f81e5f..932056608 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -552,14 +552,19 @@ def fit_origin( from py4DSTEM.process.calibration import fit_origin if mask_check_data is True: - # TODO - replace this bad hack for the mask for the origin fit - mask = np.logical_not(q_meas[0] == 0) - qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin( - tuple(q_meas), - mask=mask, - ) - else: - qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(tuple(q_meas)) + data_mask = np.logical_not(q_meas[0] == 0) + if mask is None: + mask = data_mask + else: + mask = np.logical_and(mask, data_mask) + + qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin( + tuple(q_meas), + mask=mask, + robust=robust, + robust_steps=robust_steps, + robust_thresh=robust_thresh, + ) # try to add to calibration try: From fac36a7f99378713e9af691bb9f2991aa1c0db09 Mon Sep 17 00:00:00 2001 From: Steven Zeltmann Date: Wed, 1 Nov 2023 15:19:24 -0400 Subject: [PATCH 155/176] add fit function option back --- py4DSTEM/braggvectors/braggvector_methods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 932056608..a47a242c5 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -561,6 +561,7 @@ def fit_origin( qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin( tuple(q_meas), mask=mask, + fitfunction=fitfunction, robust=robust, robust_steps=robust_steps, robust_thresh=robust_thresh, From 9f82c20fb4a44c158270c286c092d48fb220053e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 14:00:54 -0700 Subject: [PATCH 156/176] real space mask for positions to ignore --- py4DSTEM/process/phase/iterative_base_class.py | 15 ++++++++++++++- ...terative_mixedstate_multislice_ptychography.py | 6 +++++- .../phase/iterative_mixedstate_ptychography.py | 6 +++++- .../phase/iterative_multislice_ptychography.py | 6 +++++- .../iterative_overlap_magnetic_tomography.py | 6 +++++- .../process/phase/iterative_overlap_tomography.py | 6 +++++- .../phase/iterative_simultaneous_ptychography.py | 6 +++++- .../phase/iterative_singleslice_ptychography.py | 6 +++++- 8 files changed, 49 insertions(+), 8 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 13c64d79d..476216f79 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1535,7 +1535,9 @@ def _set_polar_parameters(self, parameters: dict): else: raise ValueError("{} not a recognized parameter".format(symbol)) - def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): + def _calculate_scan_positions_in_pixels( + self, positions: np.ndarray, positions_mask + ): """ Method to compute the initial guess of scan positions in pixels. @@ -1544,6 +1546,8 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions: (J,2) np.ndarray or None Input probe positions in Ã…. If None, a raster scan using experimental parameters is constructed. + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1592,6 +1596,15 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions = np.array([x.ravel(), y.ravel()]).T positions -= np.min(positions, axis=0) + if positions_mask is not None: + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converged to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + positions = positions[positions_mask.ravel()] + if self._object_padding_px is None: float_padding = self._region_of_interest_shape / 2 self._object_padding_px = (float_padding, float_padding) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 82155219a..98967ba89 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -85,6 +85,8 @@ class MixedstateMultislicePtychographicReconstruction(PtychographicReconstructio object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -115,6 +117,7 @@ def __init__( initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -201,6 +204,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -454,7 +458,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 25bee346c..195dace86 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -74,6 +74,8 @@ class MixedstatePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -102,6 +104,7 @@ def __init__( initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "mixed-state_ptychographic_reconstruction", @@ -178,6 +181,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -358,7 +362,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 6bcacd934..a137bbeb9 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -89,6 +89,8 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -121,6 +123,7 @@ def __init__( theta_y: float = 0, middle_focus: bool = False, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -211,6 +214,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -481,7 +485,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index cde84907c..b4501d012 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -93,6 +93,8 @@ class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -115,6 +117,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -179,6 +182,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -615,7 +619,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index e92211301..759b12602 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -88,6 +88,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions to ignore in reconstruction name: str, optional Class name kwargs: @@ -111,6 +113,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -188,6 +191,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -555,7 +559,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 757b2ffae..35b2bb9ef 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -66,6 +66,8 @@ class SimultaneousPtychographicReconstruction(PtychographicReconstruction): object_padding_px: Tuple[int,int], optional Pixel dimensions to pad objects with If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction initial_object_guess: np.ndarray, optional Initial guess for complex-valued object of dimensions (Px,Py) If None, initialized to 1.0j @@ -102,6 +104,7 @@ def __init__( vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, @@ -167,6 +170,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -607,7 +611,7 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 5dd19d7bd..8e66639b2 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -79,6 +79,8 @@ class SingleslicePtychographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -102,6 +104,7 @@ def __init__( initial_scan_positions: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "ptychographic_reconstruction", @@ -163,6 +166,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -342,7 +346,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels From 67e15e7002234c0540cc9a69d8a6c60ff0d4c471 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 14:40:00 -0700 Subject: [PATCH 157/176] amplitudes update for real space mask --- .../process/phase/iterative_base_class.py | 12 +++++----- ...tive_mixedstate_multislice_ptychography.py | 13 +++++++++- .../iterative_mixedstate_ptychography.py | 12 +++++++++- .../iterative_multislice_ptychography.py | 12 +++++++++- .../iterative_overlap_magnetic_tomography.py | 13 +++++++++- .../phase/iterative_overlap_tomography.py | 13 +++++++++- .../iterative_simultaneous_ptychography.py | 24 ++++++++++++++++--- .../iterative_singleslice_ptychography.py | 13 +++++++++- 8 files changed, 97 insertions(+), 15 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 476216f79..73021d8a9 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1132,6 +1132,7 @@ def _normalize_diffraction_intensities( com_fitted_x, com_fitted_y, crop_patterns, + positions_mask, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1147,6 +1148,8 @@ def _normalize_diffraction_intensities( crop_patterns: bool if True, crop patterns to avoid wrap around of patterns when centering + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1220,6 +1223,9 @@ def _normalize_diffraction_intensities( amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) amplitudes = xp.asarray(amplitudes) + if positions_mask is not None: + amplitudes = amplitudes[positions_mask.ravel()] + mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity @@ -1597,12 +1603,6 @@ def _calculate_scan_positions_in_pixels( positions -= np.min(positions, axis=0) if positions_mask is not None: - if positions_mask.dtype != "bool": - warnings.warn( - ("`positions_mask` converged to `bool` array"), - UserWarning, - ) - positions_mask = np.asarray(positions_mask, dtype="bool") positions = positions[positions_mask.ravel()] if self._object_padding_px is None: diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 98967ba89..2915acccb 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -189,6 +189,13 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -449,7 +456,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 195dace86..01d70bf71 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -164,6 +164,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -353,7 +359,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index a137bbeb9..be24f067d 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -198,6 +198,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -476,7 +482,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index b4501d012..810352ce8 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -166,6 +166,13 @@ def __init__( if object_type != "potential": raise NotImplementedError() + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -599,7 +606,11 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, com_fitted_x, com_fitted_y, crop_patterns + intensities, + com_fitted_x, + com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 759b12602..701267e81 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -175,6 +175,13 @@ def __init__( if object_type != "potential": raise NotImplementedError() + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -539,7 +546,11 @@ def preprocess( ], mean_diffraction_intensity_temp, ) = self._normalize_diffraction_intensities( - intensities, com_fitted_x, com_fitted_y, crop_patterns + intensities, + com_fitted_x, + com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 35b2bb9ef..ae1a3ecac 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -153,6 +153,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -408,7 +414,11 @@ def preprocess( amplitudes_0, mean_diffraction_intensity_0, ) = self._normalize_diffraction_intensities( - intensities_0, com_fitted_x_0, com_fitted_y_0, crop_patterns + intensities_0, + com_fitted_x_0, + com_fitted_y_0, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -489,7 +499,11 @@ def preprocess( amplitudes_1, mean_diffraction_intensity_1, ) = self._normalize_diffraction_intensities( - intensities_1, com_fitted_x_1, com_fitted_y_1, crop_patterns + intensities_1, + com_fitted_x_1, + com_fitted_y_1, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -571,7 +585,11 @@ def preprocess( amplitudes_2, mean_diffraction_intensity_2, ) = self._normalize_diffraction_intensities( - intensities_2, com_fitted_x_2, com_fitted_y_2, crop_patterns + intensities_2, + com_fitted_x_2, + com_fitted_y_2, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 8e66639b2..ab16330da 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -150,6 +150,13 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -337,7 +344,11 @@ def preprocess( self._amplitudes, self._mean_diffraction_intensity, ) = self._normalize_diffraction_intensities( - self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace From 2d48616c7e5a0e83ad2f038c97c35fb6d4ddad24 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 15:51:05 -0700 Subject: [PATCH 158/176] Thnks fr th Mmr(s) --- .../process/phase/iterative_base_class.py | 24 ++++++++++++------- ...tive_mixedstate_multislice_ptychography.py | 2 +- .../iterative_mixedstate_ptychography.py | 2 +- .../iterative_multislice_ptychography.py | 2 +- .../iterative_overlap_magnetic_tomography.py | 2 +- .../phase/iterative_overlap_tomography.py | 2 +- .../iterative_simultaneous_ptychography.py | 2 +- .../iterative_singleslice_ptychography.py | 2 +- 8 files changed, 23 insertions(+), 15 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 73021d8a9..497c7ae1c 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1163,6 +1163,12 @@ def _normalize_diffraction_intensities( mean_intensity = 0 diffraction_intensities = self._asnumpy(diffraction_intensities) + if positions_mask is not None: + number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) + sx, sy = np.where(~self._positions_mask) + else: + number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + if crop_patterns: crop_x = int( np.minimum( @@ -1181,8 +1187,7 @@ def _normalize_diffraction_intensities( region_of_interest_shape = (crop_w * 2, crop_w * 2) amplitudes = np.zeros( ( - diffraction_intensities.shape[0], - diffraction_intensities.shape[1], + number_of_patterns, crop_w * 2, crop_w * 2, ), @@ -1198,13 +1203,19 @@ def _normalize_diffraction_intensities( else: region_of_interest_shape = diffraction_intensities.shape[-2:] - amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32) + amplitudes = np.zeros( + (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + ) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) + counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): + if positions_mask is not None: + if rx in sx and ry in sy: + continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], -com_fitted_x[rx, ry], @@ -1219,13 +1230,10 @@ def _normalize_diffraction_intensities( ) mean_intensity += np.sum(intensities) - amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) + amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) + counter += 1 - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) amplitudes = xp.asarray(amplitudes) - if positions_mask is not None: - amplitudes = amplitudes[positions_mask.ravel()] - mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 2915acccb..26b0d8cff 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -189,7 +189,7 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 01d70bf71..ebc40928d 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -164,7 +164,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index be24f067d..73f83558e 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -198,7 +198,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 810352ce8..582eea772 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -166,7 +166,7 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 701267e81..f4dfe5022 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -175,7 +175,7 @@ def __init__( if object_type != "potential": raise NotImplementedError() - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index ae1a3ecac..866ff0a89 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -153,7 +153,7 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index ab16330da..350d0a3cb 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -150,7 +150,7 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) - if positions_mask.dtype != "bool": + if positions_mask is not None and positions_mask.dtype != "bool": warnings.warn( ("`positions_mask` converted to `bool` array"), UserWarning, From 3da6fdc3cda11baba1289abbd167ffa2d42627e5 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 16:32:24 -0700 Subject: [PATCH 159/176] one more bug --- py4DSTEM/process/phase/iterative_base_class.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 497c7ae1c..1aa03559a 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1165,7 +1165,6 @@ def _normalize_diffraction_intensities( diffraction_intensities = self._asnumpy(diffraction_intensities) if positions_mask is not None: number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) - sx, sy = np.where(~self._positions_mask) else: number_of_patterns = np.prod(diffraction_intensities.shape[:2]) @@ -1214,7 +1213,7 @@ def _normalize_diffraction_intensities( for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): if positions_mask is not None: - if rx in sx and ry in sy: + if not self._positions_mask[rx,ry]: continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], From 7a4e7a43e926c48aa0643b60bb1d0202e8aa65ea Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 16:34:08 -0700 Subject: [PATCH 160/176] black format --- py4DSTEM/process/phase/iterative_base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 1aa03559a..7437679a2 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -1213,7 +1213,7 @@ def _normalize_diffraction_intensities( for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): if positions_mask is not None: - if not self._positions_mask[rx,ry]: + if not self._positions_mask[rx, ry]: continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], From 54932e9e9591b8891ece51396e9ae0b23e200d95 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Wed, 1 Nov 2023 17:31:28 -0700 Subject: [PATCH 161/176] colorbars for fit origin --- py4DSTEM/braggvectors/braggvector_methods.py | 1 + py4DSTEM/visualize/show.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 669817788..33de8f8c4 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -647,6 +647,7 @@ def show_origin_fit( vmin=-1 * plot_range, vmax=1 * plot_range, intensity_range="absolute", + show_cbar=True, **kwargs, ) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 4e99c0de5..1c316f091 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -75,6 +75,7 @@ def show( theta=None, title=None, show_fft=False, + show_cbar=False, **kwargs ): """ @@ -302,7 +303,8 @@ def show( does not add a scalebar. If a dict is passed, it is propagated to the add_scalebar function which will attempt to use it to overlay a scalebar. If True, uses calibraiton or pixelsize/pixelunits for scalebar. If False, no scalebar is added. - show_fft (Bool): if True, plots 2D-fft of array + show_fft (bool): if True, plots 2D-fft of array + show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() Returns: @@ -605,6 +607,8 @@ def show( ax.matshow( mask_display, cmap=cmap, alpha=mask_alpha, vmin=vmin, vmax=vmax ) + if show_cbar: + fig.colorbar(cax, ax=ax) # ...or, plot its histogram else: hist, bin_edges = np.histogram( From eabd74257d553caf633e7dae2ecd1bd535e9c84f Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 11:39:43 -0700 Subject: [PATCH 162/176] I've been plotting to update this function --- .../process/phase/iterative_base_class.py | 6 -- ...tive_mixedstate_multislice_ptychography.py | 70 ++++++++++++++++++- .../iterative_multislice_ptychography.py | 34 +++++++-- .../iterative_overlap_magnetic_tomography.py | 6 -- .../phase/iterative_overlap_tomography.py | 6 -- 5 files changed, 96 insertions(+), 26 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 7437679a2..56be2784a 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2306,22 +2306,16 @@ def show_object_fft(self, obj=None, **kwargs): figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 26b0d8cff..6cbdca19e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -82,6 +82,12 @@ class MixedstateMultislicePtychographicReconstruction(PtychographicReconstructio initial_scan_positions: np.ndarray, optional Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') @@ -116,6 +122,9 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", positions_mask: np.ndarray = None, verbose: bool = True, @@ -165,6 +174,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: @@ -221,6 +249,8 @@ def __init__( self._num_probes = num_probes self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -243,6 +273,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) Returns ------- @@ -262,6 +296,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -269,6 +307,12 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators @@ -3075,6 +3119,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -3090,12 +3135,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -3113,8 +3166,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 73f83558e..4b0d6881c 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -81,9 +81,9 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): Probe positions in Ã… for each diffraction intensity If None, initialized to a grid scan theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in degrees) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in degrees) middle_focus: bool if True, adds half the sample thickness to the defocus object_type: str, optional @@ -256,9 +256,9 @@ def _precompute_propagator_arrays( slice_thicknesses: Sequence[float] Array of slice thicknesses in A theta_x: float - x tilt of propagator (in angles) + x tilt of propagator (in degrees) theta_y: float - y tilt of propagator (in angles) + y tilt of propagator (in degrees) Returns ------- @@ -2955,6 +2955,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -2970,12 +2971,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -2993,8 +3002,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.mean(0).ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 582eea772..7c96cb34c 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3303,22 +3303,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index f4dfe5022..54b94010a 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3183,22 +3183,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) From 3a6ee5a80c70cd9cf4c9d94b662c01bb82df8ce7 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 11:46:23 -0700 Subject: [PATCH 163/176] correct propagation of arguments --- .../phase/iterative_mixedstate_multislice_ptychography.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 6cbdca19e..6cd74828e 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -258,6 +258,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -656,6 +658,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps From a51594c80b9c22344760ab287b3d8f2a36492cb0 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 11:56:39 -0700 Subject: [PATCH 164/176] one more bug fix --- py4DSTEM/process/phase/iterative_multislice_ptychography.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 4b0d6881c..764f0b4a0 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3004,7 +3004,7 @@ def show_slices( cmap = kwargs.pop("cmap", "magma") if common_color_scale: - vals = np.sort(rotated_object.mean(0).ravel()) + vals = np.sort(rotated_object.ravel()) ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") ind_vmin = np.max([0, ind_vmin]) From 8323bc8f30d35a98de32e7521bfc3616a0d95706 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Thu, 2 Nov 2023 17:51:30 -0700 Subject: [PATCH 165/176] fft hanning window --- py4DSTEM/visualize/show.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 4e99c0de5..b6077c412 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -366,7 +366,9 @@ def show( from py4DSTEM.visualize import show if show_fft: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) for a0 in range(num_images): im = show( ar[a0], From 9185aa3b40bd1039d92e9731caa7aabc2ee0f495 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 3 Nov 2023 14:44:49 -0700 Subject: [PATCH 166/176] mostly formatting changes --- py4DSTEM/braggvectors/braggvector_methods.py | 12 ++++++---- py4DSTEM/process/strain/strain.py | 25 ++++++++++---------- py4DSTEM/visualize/show.py | 5 +++- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 33de8f8c4..99648ab16 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -1,13 +1,14 @@ # BraggVectors methods -import numpy as np -from scipy.ndimage import gaussian_filter -from warnings import warn import inspect +from warnings import warn -from emdfile import Array, Metadata, tqdmnd, _read_metadata -from py4DSTEM.datacube import VirtualImage +import matplotlib.pyplot as plt +import numpy as np +from emdfile import Array, Metadata, _read_metadata, tqdmnd from py4DSTEM import show +from py4DSTEM.datacube import VirtualImage +from scipy.ndimage import gaussian_filter class BraggVectorMethods: @@ -627,6 +628,7 @@ def show_origin_fit( # set figsize imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0]) axsize = (3 * imsize_ratio, 3 / imsize_ratio) + axsize = kwargs.pop("axsize", axsize) # plot show( diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index c6b245192..df9e03d19 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -404,7 +404,7 @@ def set_max_peak_spacing( Parameters ---------- max_peak_spacing : number - The maximum allowable distance between a detected Bragg peak and + The maximum allowable distance in pixels between a detected Bragg peak and the indexed maxima found in `choose_basis_vectors` for the detected peak to be indexed returnfig : bool @@ -678,7 +678,7 @@ def show_strain( color_gvects="r", legend_camera_length=1.6, scale_gvects=0.6, - layout=0, + layout="square", figsize=None, returnfig=False, ): @@ -745,12 +745,13 @@ def show_strain( Toggles returning the figure """ # Lookup table for different layouts - assert layout in (0, 1, 2) + assert layout in ("square", "horizontal", "vertical") layout_lookup = { - 0: ["left", "right", "left", "right"], - 1: ["bottom", "bottom", "bottom", "bottom"], - 2: ["right", "right", "right", "right"], + "square": ["left", "right", "left", "right"], + "horizontal": ["bottom", "bottom", "bottom", "bottom"], + "vertical": ["right", "right", "right", "right"], } + layout_p = layout_lookup[layout] # Set which colorbars to display @@ -809,19 +810,19 @@ def show_strain( # chosen layout and the image shape if figsize is None: ratio = np.sqrt(self.rshape[1] / self.rshape[0]) - if layout == 0: + if layout == "square": figsize = (13 * ratio, 8 / ratio) - elif layout == 1: + elif layout == "horizontal": figsize = (10 * ratio, 4 / ratio) else: figsize = (4 * ratio, 10 / ratio) # set up layout - if layout == 0: + if layout == "square": fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots( 2, 3, figsize=figsize ) - elif layout == 1: + elif layout == "horizontal": figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( 1, 5, figsize=figsize @@ -971,8 +972,8 @@ def show_strain( # Legend - # for layout 0, combine vertical plots on the right end - if layout == 0: + # for layout "square", combine vertical plots on the right end + if layout == "square": # get gridspec object gs = ax_legend1.get_gridspec() # remove last two axes diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 1c316f091..6531fd741 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -8,6 +8,7 @@ from matplotlib.axes import Axes from matplotlib.colors import is_color_like from matplotlib.figure import Figure +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM.data import Calibration, DiffractionSlice, RealSlice from py4DSTEM.visualize.overlay import ( add_annuli, @@ -608,7 +609,9 @@ def show( mask_display, cmap=cmap, alpha=mask_alpha, vmin=vmin, vmax=vmax ) if show_cbar: - fig.colorbar(cax, ax=ax) + ax_divider = make_axes_locatable(ax) + c_axis = ax_divider.append_axes("right", size="7%") + fig.colorbar(cax, cax=c_axis) # ...or, plot its histogram else: hist, bin_edges = np.histogram( From 3e5fcf9b94041f0d8886d63637ec837c519f0b50 Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 3 Nov 2023 15:41:12 -0700 Subject: [PATCH 167/176] small bug fixes --- py4DSTEM/braggvectors/braggvector_methods.py | 4 +++- py4DSTEM/process/strain/strain.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 99648ab16..2bd6ee8c8 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -631,7 +631,7 @@ def show_origin_fit( axsize = kwargs.pop("axsize", axsize) # plot - show( + fig, ax = show( [ [qx0_meas - qx0_mean, qx0_fit - qx0_mean, qx0_residuals], [qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals], @@ -650,8 +650,10 @@ def show_origin_fit( vmax=1 * plot_range, intensity_range="absolute", show_cbar=True, + returnfig=True, **kwargs, ) + plt.tight_layout() return diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index df9e03d19..ab8a46a9a 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -1554,6 +1554,7 @@ def copy(self, name=None): "g1g2_map", "strainmap_g1g2", "strainmap_rotated", + "mask", ): if hasattr(self, attr): setattr(strainmap_copy, attr, getattr(self, attr)) From 9d5e83d14a6414d153014d29936e28ddd140e7c3 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 Nov 2023 16:19:54 -0700 Subject: [PATCH 168/176] ctf transpose bugfix - tested mostly for stig --- py4DSTEM/process/phase/iterative_parallax.py | 364 +++++++++++-------- 1 file changed, 209 insertions(+), 155 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index daab204a0..3758a64e8 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -136,7 +136,7 @@ def to_h5(self, group): if hasattr(self, "aberration_C1"): recon_metadata |= { "aberration_rotation_QR": self.rotation_Q_to_R_rads, - "aberration_transpose": self.transpose_detected, + "aberration_transpose": self.transpose, "aberration_C1": self.aberration_C1, "aberration_A1x": self.aberration_A1x, "aberration_A1y": self.aberration_A1y, @@ -236,7 +236,7 @@ def _populate_instance(self, group): if "aberration_C1" in reconstruction_md.keys: self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] - self.transpose_detected = reconstruction_md["aberration_transpose"] + self.transpose = reconstruction_md["aberration_transpose"] self.aberration_C1 = reconstruction_md["aberration_C1"] self.aberration_A1x = reconstruction_md["aberration_A1x"] self.aberration_A1y = reconstruction_md["aberration_A1y"] @@ -1321,7 +1321,7 @@ def aberration_fit( plot_CTF_comparison: bool = None, plot_BF_shifts_comparison: bool = None, upsampled: bool = True, - force_transpose: bool = None, + force_transpose: bool = False, ): """ Fit aberrations to the measured image shifts. @@ -1362,17 +1362,13 @@ def aberration_fit( # Convert real space shifts to Angstroms - if force_transpose is None: - self.transpose_detected = False - else: - self.transpose_detected = force_transpose - if force_transpose is True: self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( self._scan_sampling ) else: self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + self.transpose = force_transpose # Solve affine transformation m = asnumpy( @@ -1389,9 +1385,15 @@ def aberration_fit( np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi ) m_aberration = -1.0 * m_aberration + self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 - self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + + if self.transpose: + self.aberration_A1x = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + else: + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 ### Second pass @@ -1437,12 +1439,26 @@ def aberration_fit( sx = self._scan_sampling[0] / self._kde_upsample_factor sy = self._scan_sampling[1] / self._kde_upsample_factor + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + else: im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) sx = self._scan_sampling[0] sy = self._scan_sampling[1] upsampled = False + reciprocal_extent = [ + -0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[0], + -0.5 / self._scan_sampling[0], + ] + # FFT coordinates qx = xp.fft.fftfreq(im_FFT.shape[0], sx) qy = xp.fft.fftfreq(im_FFT.shape[1], sy) @@ -1494,12 +1510,19 @@ def calculate_CTF_FFT(alpha_shape, *coefs): sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) - qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + qx, qy = np.meshgrid(qx, qy, indexing="ij") + + # passive rotation basis by -theta + rotation_angle = -self.rotation_Q_to_R_rads + qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( + rotation_angle + ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) - u = qx[:, None] * self._wavelength - v = qy[None, :] * self._wavelength + qr2 = qx**2 + qy**2 + u = qx * self._wavelength + v = qy * self._wavelength alpha = xp.sqrt(qr2) * self._wavelength - theta = xp.arctan2(qy[None, :], qx[:, None]) + theta = xp.arctan2(qy, qx) # Aberration basis self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) @@ -1561,10 +1584,17 @@ def calculate_CTF(alpha_shape, *coefs): # initial coefficients and plotting intensity range mask self._aberrations_coefs = np.zeros(self._aberrations_num) - ind = np.argmin( - np.abs(self._aberrations_mn[:, 0] - 1.0) + self._aberrations_mn[:, 1] - ) - self._aberrations_coefs[ind] = self.aberration_C1 + + aberrations_mn_list = self._aberrations_mn.tolist() + if [1, 0, 0] in aberrations_mn_list: + ind_C1 = aberrations_mn_list.index([1, 0, 0]) + self._aberrations_coefs[ind_C1] = self.aberration_C1 + + if [1, 2, 0] in aberrations_mn_list: + ind_A1x = aberrations_mn_list.index([1, 2, 0]) + ind_A1y = aberrations_mn_list.index([1, 2, 1]) + self._aberrations_coefs[ind_A1x] = self.aberration_A1x + self._aberrations_coefs[ind_A1y] = self.aberration_A1y # Refinement using CTF fitting / Thon rings if fit_CTF_FFT: @@ -1617,57 +1647,84 @@ def score_CTF(coefs): ) # (Relative) untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + raveled_shifts = self._xy_shifts_Ang.T.ravel() aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None + gradients, raveled_shifts, rcond=None )[:2] - if force_transpose is None: - # (Relative) transposed fit - transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) - m_T = asnumpy( - xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ - 0 - ] + self._aberrations_coefs = asnumpy(aberrations_coefs) + + if self.transpose: + aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & ( + self._aberrations_mn[:, 2] == 0 ) - m_rotation_T, _ = polar(m_T, side="right") - rotation_Q_to_R_rads_T = -1 * np.arctan2( - m_rotation_T[1, 0], m_rotation_T[0, 0] + self._aberrations_coefs[aberrations_to_flip] *= -1 + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 ) - if np.abs( - np.mod(rotation_Q_to_R_rads_T + np.pi, 2.0 * np.pi) - np.pi - ) > (np.pi * 0.5): - rotation_Q_to_R_rads_T = ( - np.mod(rotation_Q_to_R_rads_T, 2.0 * np.pi) - np.pi - ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] + + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] + + fitted_shifts = ( + xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) + .reshape((2, -1)) + .T + ) + + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] - tf_T = AffineTransform(angle=rotation_Q_to_R_rads_T) - rotated_shifts_T = tf_T(transposed_shifts, xp=xp).T.ravel() - aberrations_coefs_T, res_T = xp.linalg.lstsq( - gradients, rotated_shifts_T, rcond=None - )[:2] - - # Compare fits - if res_T.sum() < res.sum(): - self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = not self.transpose_detected - self._aberrations_coefs = asnumpy(aberrations_coefs_T) - self._rotated_shifts = rotated_shifts_T - - warnings.warn( - ( - "Data transpose detected. " - f"Overwriting rotation value to {np.rad2deg(rotation_Q_to_R_rads_T):.3f}" - ), - UserWarning, + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] + + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] ) - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts + ) + + show( + [ + [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], + [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], + ], + cmap="PiYG", + vmin=-max_shift, + vmax=max_shift, + intensity_range="absolute", + axsize=(4, 4), + ticks=False, + title=[ + "Measured Vertical Shifts", + "Fitted Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Horizontal Shifts", + ], + ) # Plot the CTF comparison between experiment and fit if plot_CTF_comparison: @@ -1705,79 +1762,24 @@ def score_CTF(coefs): im_plot[:, :, 2] -= im_CTF im_plot = np.clip(im_plot, 0, 1) - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) - ax1.imshow(im_plot, vmin=int_range[0], vmax=int_range[1]) - - ax2.imshow(np.fft.fftshift(asnumpy(im_CTF_cos)), cmap="gray") - - fig.tight_layout() - - # Plot the measured/fitted shifts comparison - if plot_BF_shifts_comparison: - if not fit_BF_shifts: - raise ValueError() - - measured_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sx[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._rotated_shifts[: self._xy_inds.shape[0]] - - measured_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - measured_shifts_sy[ - self._xy_inds[:, 0], self._xy_inds[:, 1] - ] = self._rotated_shifts[self._xy_inds.shape[0] :] - - fitted_shifts = xp.tensordot( - gradients, xp.array(self._aberrations_coefs), axes=1 + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) + ax1.imshow( + im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent ) - - fitted_shifts_sx = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_cos)), + cmap="gray", + extent=reciprocal_extent, ) - fitted_shifts_sx[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ - : self._xy_inds.shape[0] - ] - fitted_shifts_sy = xp.zeros( - self._region_of_interest_shape, dtype=xp.float32 - ) - fitted_shifts_sy[self._xy_inds[:, 0], self._xy_inds[:, 1]] = fitted_shifts[ - self._xy_inds.shape[0] : - ] + for ax in (ax1, ax2): + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") - max_shift = xp.max( - xp.array( - [ - xp.abs(measured_shifts_sx).max(), - xp.abs(measured_shifts_sy).max(), - xp.abs(fitted_shifts_sx).max(), - xp.abs(fitted_shifts_sy).max(), - ] - ) - ) + ax1.set_title("Aligned Bright Field FFT") + ax2.set_title("Fitted CTF Zero-Crossings") - show( - [ - [asnumpy(measured_shifts_sx), asnumpy(measured_shifts_sy)], - [asnumpy(fitted_shifts_sx), asnumpy(fitted_shifts_sy)], - ], - cmap="PiYG", - vmin=-max_shift, - vmax=max_shift, - intensity_range="absolute", - axsize=(4, 4), - ticks=False, - title=[ - "Measured Vertical Shifts", - "Measured Horizontal Shifts", - "Fitted Vertical Shifts", - "Fitted Horizontal Shifts", - ], - ) + fig.tight_layout() self.aberration_dict = { tuple(self._aberrations_mn[a0]): { @@ -1809,7 +1811,7 @@ def score_CTF(coefs): ) print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - print(f"Transpose = {self.transpose_detected}") + print(f"Transpose = {self.transpose}") if fit_CTF_FFT or fit_BF_shifts: print() @@ -2292,6 +2294,7 @@ def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, + plot_rotated_shifts=True, **kwargs, ): """ @@ -2308,10 +2311,22 @@ def show_shifts( xp = self._xp asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (6, 6)) color = kwargs.pop("color", (1, 0, 0, 1)) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + scaling_factor = ( + xp.array(self._reciprocal_sampling) + / xp.array(self._scan_sampling) + * scale_arrows + ) + rotated_shifts = self._xy_shifts_Ang * scaling_factor - fig, ax = plt.subplots(figsize=figsize) + else: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + + shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -2321,29 +2336,68 @@ def show_shifts( masked_ind = xp.logical_and(freq_mask, self._dp_mask) plot_ind = masked_ind[dp_mask_ind] - ax.quiver( - asnumpy(self._kxy[plot_ind, 1]), - asnumpy(self._kxy[plot_ind, 0]), - asnumpy( - self._xy_shifts[plot_ind, 1] - * scale_arrows - * self._reciprocal_sampling[0] - ), - asnumpy( - self._xy_shifts[plot_ind, 0] - * scale_arrows - * self._reciprocal_sampling[1] - ), - color=color, - angles="xy", - scale_units="xy", - scale=1, - **kwargs, - ) - kr_max = xp.max(self._kr) - ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) - ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + ax[0].quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[0].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_title("Measured Bright Field Shifts") + ax[0].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[0].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[0].set_aspect("equal") + + # passive coordinate rotation + tf_T = AffineTransform(angle=-self.rotation_Q_to_R_rads) + rotated_kxy = tf_T(self._kxy[plot_ind], xp=xp) + ax[1].quiver( + asnumpy(rotated_kxy[:, 1]), + asnumpy(rotated_kxy[:, 0]), + asnumpy(rotated_shifts[plot_ind, 1]), + asnumpy(rotated_shifts[plot_ind, 0]), + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[1].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_title("Rotated Bright Field Shifts") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[1].set_aspect("equal") + else: + ax.quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_title("Measured BF Shifts") + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.set_aspect("equal") + + fig.tight_layout() def visualize( self, From 2e59d1c7b52c509e5eb5164f6af5a58da5457975 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Fri, 3 Nov 2023 16:20:11 -0700 Subject: [PATCH 169/176] making ptycho aberration fitting convention consistent --- py4DSTEM/process/phase/iterative_ptychographic_constraints.py | 2 +- py4DSTEM/process/phase/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 0760087b4..d29aa1747 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -566,7 +566,7 @@ def _probe_aberration_fitting_constraint( xp=xp, ) - fourier_probe = fourier_probe_abs * xp.exp(1.0j * fitted_angle) + fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) current_probe = xp.fft.ifft2(fourier_probe) return current_probe diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index d29765d04..a1eb54c80 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1620,7 +1620,7 @@ def fit_aberration_surface( ): """ """ probe_amp = xp.abs(complex_probe) - probe_angle = xp.angle(complex_probe) + probe_angle = -xp.angle(complex_probe) if xp is np: probe_angle = probe_angle.astype(np.float64) From e23878c8fe25cb8fa6bbd7df72b3f26ab5781f63 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Fri, 3 Nov 2023 16:31:40 -0700 Subject: [PATCH 170/176] adding to_strainmap method --- py4DSTEM/braggvectors/braggvector_methods.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 2bd6ee8c8..065eb7d75 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -792,6 +792,21 @@ def mask_in_R(self, mask, update_inplace=False, returncalc=True): else: return + def to_strainmap(self, name: str = None): + """ + Generate a StrainMap object from the BraggVectors + equivalent to py4DSTEM.StrainMap(braggvectors=braggvectors) + + Args: + name (str, optional): The name of the strainmap. Defaults to None which reverts to default name 'strainmap'. + + Returns: + py4DSTEM.StrainMap: A py4DSTEM StrainMap object generated from the BraggVectors + """ + from py4DSTEM.process.strain import StrainMap + + return StrainMap(self, name) if name else StrainMap(self) + ######### END BraggVectorMethods CLASS ######## From d32b18dc3151c8cf5c457c206c094d804f7b84b7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 12:37:30 -0700 Subject: [PATCH 171/176] update uncertainty viz --- .../process/phase/iterative_base_class.py | 257 +++++++++++++++--- py4DSTEM/visualize/vis_special.py | 31 +++ 2 files changed, 257 insertions(+), 31 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 56be2784a..0f342d5c8 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -8,7 +8,7 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import show, show_complex +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import rotate try: @@ -23,7 +23,11 @@ from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( PtychographicConstraints, ) -from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases +from py4DSTEM.process.phase.utils import ( + AffineTransform, + generate_batches, + polar_aliases, +) from py4DSTEM.process.utils import ( electron_wavelength_angstrom, fourier_resample, @@ -2237,6 +2241,226 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(asnumpy(obj)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(self.object_cropped.shape) + asnumpy(2 * padding) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(xp.asarray(errors), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") + pix_count[pix_count == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") + pix_output /= pix_count + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + pix_output.get(), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5.25) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized Squared Error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + cropped_object_angle, vmin, vmax = return_scaled_histogram_ordering( + np.angle(self.object_cropped), + vmin=vmin, + vmax=vmax, + ) + + extent = [ + 0, + self.sampling[1] * cropped_object_angle.shape[1], + self.sampling[0] * cropped_object_angle.shape[0], + 0, + ] + + ax.imshow( + cropped_object_angle, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) + def show_fourier_probe( self, probe=None, @@ -2383,32 +2607,3 @@ def object_cropped(self): """Cropped and rotated object""" return self._crop_rotate_object_fov(self._object) - - @property - def self_consistency_errors(self): - """Compute the self-consistency errors for each probe position""" - - xp = self._xp - asnumpy = self._asnumpy - - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) - - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() - - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - - # Normalized mean-squared errors - error = xp.sum( - xp.abs(self._amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) - ) - error /= self._mean_diffraction_intensity - - return asnumpy(error) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 388b57e0a..da501c746 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -839,3 +839,34 @@ def show_complex( if returnfig: return fig, ax + + +def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False): + if vmin is None: + vmin = 0.02 + if vmax is None: + vmax = 0.98 + + vals = np.sort(array.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + + scaled_array = array.copy() + scaled_array = np.where(scaled_array < vmin, vmin, scaled_array) + scaled_array = np.where(scaled_array > vmax, vmax, scaled_array) + + if normalize: + scaled_array -= scaled_array.min() + scaled_array /= scaled_array.max() + vmin = 0 + vmax = 1 + + return scaled_array, vmin, vmax From f93576a5dad24f7816c9f4bd72010bd393b4cbff Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 12:50:13 -0700 Subject: [PATCH 172/176] generalizing to accommodate other classes easier --- .../process/phase/iterative_base_class.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 0f342d5c8..772f6b133 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -2287,10 +2287,22 @@ def _return_self_consistency_errors( return asnumpy(errors) + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped) + else: + projected_cropped_potential = self.object_cropped + + return projected_cropped_potential + def show_uncertainty_visualization( self, errors=None, max_batch_size=None, + projected_cropped_potential=None, kde_sigma=None, plot_histogram=True, plot_contours=False, @@ -2301,6 +2313,9 @@ def show_uncertainty_visualization( if errors is None: errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + if kde_sigma is None: kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] @@ -2323,7 +2338,9 @@ def show_uncertainty_visualization( padding = xp.min(rotated_points, axis=0).astype("int") # bilinear sampling - pixel_output = np.array(self.object_cropped.shape) + asnumpy(2 * padding) + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) pixel_size = pixel_output.prod() xa = rotated_points[:, 0] @@ -2415,21 +2432,21 @@ def show_uncertainty_visualization( vmin = kwargs.pop("vmin", None) vmax = kwargs.pop("vmax", None) - cropped_object_angle, vmin, vmax = return_scaled_histogram_ordering( - np.angle(self.object_cropped), + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, vmin=vmin, vmax=vmax, ) extent = [ 0, - self.sampling[1] * cropped_object_angle.shape[1], - self.sampling[0] * cropped_object_angle.shape[0], + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], 0, ] ax.imshow( - cropped_object_angle, + projected_cropped_potential, vmin=vmin, vmax=vmax, extent=extent, From 71cde33e67b5d4828e5b78456c7dbfd6af7c932b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 13:47:19 -0700 Subject: [PATCH 173/176] add uncertainty viz to all classes except OT --- ...tive_mixedstate_multislice_ptychography.py | 66 +++++++++++++----- .../iterative_mixedstate_ptychography.py | 55 ++++++++++----- .../iterative_multislice_ptychography.py | 11 +++ .../iterative_overlap_magnetic_tomography.py | 25 ++++++- .../phase/iterative_overlap_tomography.py | 25 ++++++- .../iterative_simultaneous_ptychography.py | 69 +++++++++++++++++++ 6 files changed, 211 insertions(+), 40 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 6cd74828e..f4c10cb13 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3595,30 +3595,60 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" xp = self._xp asnumpy = self._asnumpy - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) - # Normalized mean-squared errors - error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) - error /= self._mean_diffraction_intensity + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) - return asnumpy(error) + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index ebc40928d..d68291143 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -2342,30 +2342,49 @@ def show_fourier_probe( **kwargs, ) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" xp = self._xp asnumpy = self._asnumpy - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) - # Normalized mean-squared errors - error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) - error /= self._mean_diffraction_intensity + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity - return asnumpy(error) + return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 764f0b4a0..93e32b079 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3426,3 +3426,14 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 7c96cb34c..c49a1faac 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3337,7 +3337,28 @@ def positions(self): return np.asarray(positions_all) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 54b94010a..ddd13ac58 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3217,7 +3217,28 @@ def positions(self): return np.asarray(positions_all) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 866ff0a89..233d34e45 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -351,6 +351,9 @@ def preprocess( ) ) + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + # 1st measurement sets rotation angle and transposition ( measurement_0, @@ -3408,3 +3411,69 @@ def self_consistency_errors(self): error /= self._mean_diffraction_intensity return asnumpy(error) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[0][start:end] + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped[0]) + else: + projected_cropped_potential = self.object_cropped[0] + + return projected_cropped_potential + + @property + def object_cropped(self): + """Cropped and rotated object""" + + obj_e, obj_m = self._object + obj_e = self._crop_rotate_object_fov(obj_e) + obj_m = self._crop_rotate_object_fov(obj_m) + return (obj_e, obj_m) From 4b227dc9bd8ebaa9d53b3792e489b741ee97cad9 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 13:47:58 -0700 Subject: [PATCH 174/176] small kde parallax bug --- py4DSTEM/process/phase/iterative_parallax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 3758a64e8..9f690c434 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1207,7 +1207,7 @@ def subpixel_alignment( # kernel density estimate sigma = kde_sigma * self._kde_upsample_factor pix_count = gaussian_filter(pix_count, sigma) - pix_count[pix_output == 0.0] = np.inf + pix_count[pix_count == 0.0] = np.inf pix_output = gaussian_filter(pix_output, sigma) pix_output /= pix_count From bab740671565f471474602c98fff1fc4f63251dc Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sun, 5 Nov 2023 09:17:23 -0800 Subject: [PATCH 175/176] more parallax plotting fun(ctionality) --- py4DSTEM/process/phase/iterative_parallax.py | 23 +++++++++++++----- py4DSTEM/visualize/vis_special.py | 25 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 9f690c434..716e1d782 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -587,16 +587,27 @@ def preprocess( self.recon_BF = asnumpy(self._recon_BF) if plot_average_bf: - figsize = kwargs.pop("figsize", (6, 6)) + figsize = kwargs.pop("figsize", (6, 12)) - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(1, 2, figsize=figsize) - self._visualize_figax(fig, ax, **kwargs) + self._visualize_figax(fig, ax[0], **kwargs) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Average Bright Field Image") + ax[0].set_ylabel("x [A]") + ax[0].set_xlabel("y [A]") + ax[0].set_title("Average Bright Field Image") + reciprocal_extent = [ + -0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + -0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + ] + ax[1].imshow(self._dp_mask, extent=reciprocal_extent, cmap="gray") + ax[1].set_title("DP mask") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + plt.tight_layout() self._preprocessed = True if self._device == "gpu": diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index da501c746..1d46ebf44 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -842,6 +842,31 @@ def show_complex( def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False): + """ + Utility function for calculating min and max values for plotting array + based on distribution of pixel values + + Parameters + ---------- + array: np.array + array to be plotted + vmin: float + lower fraction cut off of pixel values + vmax: float + upper fraction cut off of pixel values + normalize: bool + if True, rescales from 0 to 1 + + Returns + ---------- + scaled_array: np.array + array clipped outside vmin and vmax + vmin: float + lower value to be plotted + vmax: float + upper value to be plotted + """ + if vmin is None: vmin = 0.02 if vmax is None: From dd09924a14433ab7e86bd726b649440421d97e5e Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Sun, 5 Nov 2023 09:20:27 -0800 Subject: [PATCH 176/176] black formatting --- py4DSTEM/visualize/vis_special.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 1d46ebf44..acacb6184 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -850,10 +850,10 @@ def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=Fals ---------- array: np.array array to be plotted - vmin: float - lower fraction cut off of pixel values + vmin: float + lower fraction cut off of pixel values vmax: float - upper fraction cut off of pixel values + upper fraction cut off of pixel values normalize: bool if True, rescales from 0 to 1