Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmented and Other Geometry Detector Ptychography #683

Merged
merged 7 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/magnetic_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
tv_denoise: bool = True,
tv_denoise_weights=None,
Expand Down Expand Up @@ -990,6 +991,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -1081,6 +1085,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

if gaussian_filter_sigma_m is None:
gaussian_filter_sigma_m = gaussian_filter_sigma_e

Expand Down Expand Up @@ -1188,6 +1195,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
store_iterations: bool = False,
collective_measurement_updates: bool = True,
Expand Down Expand Up @@ -1310,6 +1311,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -1407,6 +1411,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

if gaussian_filter_sigma_m is None:
gaussian_filter_sigma_m = gaussian_filter_sigma_e

Expand Down Expand Up @@ -1489,6 +1496,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme=use_projection_scheme,
projection_a=projection_a,
projection_b=projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/mixedstate_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ def reconstruct(
fix_potential_baseline: bool = True,
vacuum_mask: np.ndarray = None,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
pure_phase_object: bool = False,
tv_denoise_chambolle: bool = True,
Expand Down Expand Up @@ -885,6 +886,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
pure_phase_object: bool, optional
Expand Down Expand Up @@ -984,6 +988,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1031,6 +1038,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/mixedstate_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ def reconstruct(
fix_potential_baseline: bool = True,
vacuum_mask: np.ndarray = None,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
store_iterations: bool = False,
progress_bar: bool = True,
Expand Down Expand Up @@ -791,6 +792,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -878,6 +882,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -925,6 +932,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def reconstruct(
fix_potential_baseline: bool = True,
vacuum_mask: np.ndarray = None,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
pure_phase_object: bool = False,
tv_denoise_chambolle: bool = True,
Expand Down Expand Up @@ -857,6 +858,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
pure_phase_object: bool, optional
Expand Down Expand Up @@ -960,6 +964,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1007,6 +1014,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
55 changes: 53 additions & 2 deletions py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -1679,7 +1679,9 @@ def cross_correlate_amplitudes_to_probe_aperture(

return self

def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask):
def _gradient_descent_fourier_projection(
self, amplitudes, overlap, fourier_mask, virtual_detector_masks
):
"""
Ptychographic fourier projection method for GD method.

Expand All @@ -1692,6 +1694,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.

Returns
--------
Expand All @@ -1715,6 +1720,15 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, None, :, :] * virtual_detector_masks[None, :, :, :],
axis=(-1, -2),
).transpose()
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[:, mask] = value[:, None] / xp.sum(mask)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

Expand Down Expand Up @@ -1746,6 +1760,7 @@ def _projection_sets_fourier_projection(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand Down Expand Up @@ -1777,6 +1792,9 @@ def _projection_sets_fourier_projection(
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
Currently not implemented for projection-sets
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
projection_a: float
projection_b: float
projection_c: float
Expand All @@ -1792,6 +1810,9 @@ def _projection_sets_fourier_projection(
if fourier_mask is not None:
raise NotImplementedError()

if virtual_detector_masks is not None:
raise NotImplementedError()

xp = self._xp
projection_x = 1 - projection_a - projection_b
projection_y = 1 - projection_c
Expand Down Expand Up @@ -1849,6 +1870,7 @@ def _forward(
amplitudes,
exit_waves,
fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand All @@ -1871,6 +1893,9 @@ def _forward(
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
use_projection_scheme: bool,
If True, use generalized projection update
projection_a: float
Expand Down Expand Up @@ -1907,6 +1932,7 @@ def _forward(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand All @@ -1917,6 +1943,7 @@ def _forward(
amplitudes,
overlap,
fourier_mask,
virtual_detector_masks,
)

return shifted_probes, object_patches, overlap, exit_waves, error
Expand Down Expand Up @@ -2904,7 +2931,9 @@ def _return_farfield_amplitudes(self, fourier_overlap):
xp = self._xp
return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))

def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask):
def _gradient_descent_fourier_projection(
self, amplitudes, overlap, fourier_mask, virtual_detector_masks
):
"""
Ptychographic fourier projection method for GD method.

Expand All @@ -2917,6 +2946,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.

Returns
--------
Expand All @@ -2940,8 +2972,19 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, :, None, :, :]
* virtual_detector_masks[None, None, :, :, :],
axis=(-1, -2),
).transpose(2, 0, 1)
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[..., mask] = value[:, :, None] / xp.sum(mask)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)

Expand All @@ -2951,6 +2994,7 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap

fourier_modified_overlap = fourier_modified_overlap - fourier_overlap

if fourier_mask is not None:
fourier_modified_overlap *= fourier_mask

Expand All @@ -2974,6 +3018,7 @@ def _projection_sets_fourier_projection(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand Down Expand Up @@ -3005,6 +3050,9 @@ def _projection_sets_fourier_projection(
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
Currently not implemented for projection sets
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
projection_a: float
projection_b: float
projection_c: float
Expand All @@ -3020,6 +3068,9 @@ def _projection_sets_fourier_projection(
if fourier_mask is not None:
raise NotImplementedError()

if virtual_detector_masks is not None:
raise NotImplementedError()

xp = self._xp
projection_x = 1 - projection_a - projection_b
projection_y = 1 - projection_c
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
probe_real_space_support_mask: np.ndarray = None,
tv_denoise: bool = True,
tv_denoise_weights: float = None,
Expand Down Expand Up @@ -910,6 +911,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to apply at the detector-plane for zeroing-out unreliable gradients.
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
probe_real_space_support_mask: np.ndarray
Corner-centered boolean mask, outside of which the probe amplitude will be set to zero.
store_iterations: bool, optional
Expand Down Expand Up @@ -990,6 +994,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1082,6 +1089,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
Loading