From 71cde33e67b5d4828e5b78456c7dbfd6af7c932b Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Sat, 4 Nov 2023 13:47:19 -0700 Subject: [PATCH] add uncertainty viz to all classes except OT --- ...tive_mixedstate_multislice_ptychography.py | 66 +++++++++++++----- .../iterative_mixedstate_ptychography.py | 55 ++++++++++----- .../iterative_multislice_ptychography.py | 11 +++ .../iterative_overlap_magnetic_tomography.py | 25 ++++++- .../phase/iterative_overlap_tomography.py | 25 ++++++- .../iterative_simultaneous_ptychography.py | 69 +++++++++++++++++++ 6 files changed, 211 insertions(+), 40 deletions(-) diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py index 6cd74828e..f4c10cb13 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -3595,30 +3595,60 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" xp = self._xp asnumpy = self._asnumpy - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) - # Normalized mean-squared errors - error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) - error /= self._mean_diffraction_intensity + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) - return asnumpy(error) + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index ebc40928d..d68291143 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -2342,30 +2342,49 @@ def show_fourier_probe( **kwargs, ) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" xp = self._xp asnumpy = self._asnumpy - # Re-initialize fractional positions and vector patches, max_batch_size = None - self._positions_px_fractional = self._positions_px - xp.round( - self._positions_px - ) + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns - ( - self._vectorized_patch_indices_row, - self._vectorized_patch_indices_col, - ) = self._extract_vectorized_patch_indices() + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() - # Overlaps - _, _, overlap = self._overlap_projection(self._object, self._probe) - fourier_overlap = xp.fft.fft2(overlap) - intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) - # Normalized mean-squared errors - error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1)) - error /= self._mean_diffraction_intensity + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity - return asnumpy(error) + return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index 764f0b4a0..93e32b079 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -3426,3 +3426,14 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 7c96cb34c..c49a1faac 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -3337,7 +3337,28 @@ def positions(self): return np.asarray(positions_all) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 54b94010a..ddd13ac58 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -3217,7 +3217,28 @@ def positions(self): return np.asarray(positions_all) - @property - def self_consistency_errors(self): + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): """Compute the self-consistency errors for each probe position""" raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 866ff0a89..233d34e45 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -351,6 +351,9 @@ def preprocess( ) ) + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + # 1st measurement sets rotation angle and transposition ( measurement_0, @@ -3408,3 +3411,69 @@ def self_consistency_errors(self): error /= self._mean_diffraction_intensity return asnumpy(error) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[0][start:end] + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped[0]) + else: + projected_cropped_potential = self.object_cropped[0] + + return projected_cropped_potential + + @property + def object_cropped(self): + """Cropped and rotated object""" + + obj_e, obj_m = self._object + obj_e = self._crop_rotate_object_fov(obj_e) + obj_m = self._crop_rotate_object_fov(obj_m) + return (obj_e, obj_m)