Skip to content

Commit

Permalink
add uncertainty viz to all classes except OT
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 4, 2023
1 parent f93576a commit 71cde33
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3595,30 +3595,60 @@ def _return_object_fft(
obj = self._crop_rotate_object_fov(np.sum(obj, axis=0))
return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj))))

@property
def self_consistency_errors(self):
def _return_self_consistency_errors(
self,
max_batch_size=None,
):
"""Compute the self-consistency errors for each probe position"""

xp = self._xp
asnumpy = self._asnumpy

# Re-initialize fractional positions and vector patches, max_batch_size = None
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
# Batch-size
if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()
# Re-initialize fractional positions and vector patches
errors = np.array([])
positions_px = self._positions_px.copy()

# Overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap)
intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))
for start, end in generate_batches(
self._num_diffraction_patterns, max_batch=max_batch_size
):
# batch indices
self._positions_px = positions_px[start:end]
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()
amplitudes = self._amplitudes[start:end]

# Overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap)
intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))

# Normalized mean-squared errors
batch_errors = xp.sum(
xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1)
)
errors = np.hstack((errors, batch_errors))

# Normalized mean-squared errors
error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1))
error /= self._mean_diffraction_intensity
self._positions_px = positions_px.copy()
errors /= self._mean_diffraction_intensity

return asnumpy(errors)

def _return_projected_cropped_potential(
self,
):
"""Utility function to accommodate multiple classes"""
if self._object_type == "complex":
projected_cropped_potential = np.angle(self.object_cropped).sum(0)
else:
projected_cropped_potential = self.object_cropped.sum(0)

return asnumpy(error)
return projected_cropped_potential
55 changes: 37 additions & 18 deletions py4DSTEM/process/phase/iterative_mixedstate_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -2342,30 +2342,49 @@ def show_fourier_probe(
**kwargs,
)

@property
def self_consistency_errors(self):
def _return_self_consistency_errors(
self,
max_batch_size=None,
):
"""Compute the self-consistency errors for each probe position"""

xp = self._xp
asnumpy = self._asnumpy

# Re-initialize fractional positions and vector patches, max_batch_size = None
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
# Batch-size
if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()
# Re-initialize fractional positions and vector patches
errors = np.array([])
positions_px = self._positions_px.copy()

# Overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap)
intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))
for start, end in generate_batches(
self._num_diffraction_patterns, max_batch=max_batch_size
):
# batch indices
self._positions_px = positions_px[start:end]
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()
amplitudes = self._amplitudes[start:end]

# Overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap)
intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))

# Normalized mean-squared errors
batch_errors = xp.sum(
xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1)
)
errors = np.hstack((errors, batch_errors))

# Normalized mean-squared errors
error = xp.sum(xp.abs(self._amplitudes - intensity_norm) ** 2, axis=(-2, -1))
error /= self._mean_diffraction_intensity
self._positions_px = positions_px.copy()
errors /= self._mean_diffraction_intensity

return asnumpy(error)
return asnumpy(errors)
11 changes: 11 additions & 0 deletions py4DSTEM/process/phase/iterative_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -3426,3 +3426,14 @@ def _return_object_fft(

obj = self._crop_rotate_object_fov(np.sum(obj, axis=0))
return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj))))

def _return_projected_cropped_potential(
self,
):
"""Utility function to accommodate multiple classes"""
if self._object_type == "complex":
projected_cropped_potential = np.angle(self.object_cropped).sum(0)
else:
projected_cropped_potential = self.object_cropped.sum(0)

return projected_cropped_potential
25 changes: 23 additions & 2 deletions py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -3337,7 +3337,28 @@ def positions(self):

return np.asarray(positions_all)

@property
def self_consistency_errors(self):
def _return_self_consistency_errors(
self,
max_batch_size=None,
):
"""Compute the self-consistency errors for each probe position"""
raise NotImplementedError()

def _return_projected_cropped_potential(
self,
):
"""Utility function to accommodate multiple classes"""
raise NotImplementedError()

def show_uncertainty_visualization(
self,
errors=None,
max_batch_size=None,
projected_cropped_potential=None,
kde_sigma=None,
plot_histogram=True,
plot_contours=False,
**kwargs,
):
"""Plot uncertainty visualization using self-consistency errors"""
raise NotImplementedError()
25 changes: 23 additions & 2 deletions py4DSTEM/process/phase/iterative_overlap_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -3217,7 +3217,28 @@ def positions(self):

return np.asarray(positions_all)

@property
def self_consistency_errors(self):
def _return_self_consistency_errors(
self,
max_batch_size=None,
):
"""Compute the self-consistency errors for each probe position"""
raise NotImplementedError()

def _return_projected_cropped_potential(
self,
):
"""Utility function to accommodate multiple classes"""
raise NotImplementedError()

def show_uncertainty_visualization(
self,
errors=None,
max_batch_size=None,
projected_cropped_potential=None,
kde_sigma=None,
plot_histogram=True,
plot_contours=False,
**kwargs,
):
"""Plot uncertainty visualization using self-consistency errors"""
raise NotImplementedError()
69 changes: 69 additions & 0 deletions py4DSTEM/process/phase/iterative_simultaneous_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,9 @@ def preprocess(
)
)

# Ensure plot_center_of_mass is not in kwargs
kwargs.pop("plot_center_of_mass", None)

# 1st measurement sets rotation angle and transposition
(
measurement_0,
Expand Down Expand Up @@ -3408,3 +3411,69 @@ def self_consistency_errors(self):
error /= self._mean_diffraction_intensity

return asnumpy(error)

def _return_self_consistency_errors(
self,
max_batch_size=None,
):
"""Compute the self-consistency errors for each probe position"""

xp = self._xp
asnumpy = self._asnumpy

# Batch-size
if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

# Re-initialize fractional positions and vector patches
errors = np.array([])
positions_px = self._positions_px.copy()

for start, end in generate_batches(
self._num_diffraction_patterns, max_batch=max_batch_size
):
# batch indices
self._positions_px = positions_px[start:end]
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()
amplitudes = self._amplitudes[0][start:end]

# Overlaps
_, _, overlap = self._warmup_overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap[0])

# Normalized mean-squared errors
batch_errors = xp.sum(
xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1)
)
errors = np.hstack((errors, batch_errors))

self._positions_px = positions_px.copy()
errors /= self._mean_diffraction_intensity

return asnumpy(errors)

def _return_projected_cropped_potential(
self,
):
"""Utility function to accommodate multiple classes"""
if self._object_type == "complex":
projected_cropped_potential = np.angle(self.object_cropped[0])
else:
projected_cropped_potential = self.object_cropped[0]

return projected_cropped_potential

@property
def object_cropped(self):
"""Cropped and rotated object"""

obj_e, obj_m = self._object
obj_e = self._crop_rotate_object_fov(obj_e)
obj_m = self._crop_rotate_object_fov(obj_m)
return (obj_e, obj_m)

0 comments on commit 71cde33

Please sign in to comment.