Skip to content

Commit

Permalink
Merge pull request #1 from py4dstem/pr_683
Browse files Browse the repository at this point in the history
Extending functionality to other ptycho classes
  • Loading branch information
juliedactyl authored Sep 15, 2024
2 parents 8a08b5f + 76a5da2 commit 16e0956
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 1 deletion.
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/magnetic_ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
tv_denoise: bool = True,
tv_denoise_weights=None,
tv_denoise_inner_iter=40,
Expand Down Expand Up @@ -984,6 +985,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
store_iterations: bool, optional
If True, reconstructed objects and probes are stored at each iteration
progress_bar: bool, optional
Expand Down Expand Up @@ -1073,6 +1077,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

if gaussian_filter_sigma_m is None:
gaussian_filter_sigma_m = gaussian_filter_sigma_e

Expand Down Expand Up @@ -1180,6 +1187,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
store_iterations: bool = False,
collective_measurement_updates: bool = True,
progress_bar: bool = True,
Expand Down Expand Up @@ -1309,6 +1310,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
store_iterations: bool, optional
If True, reconstructed objects and probes are stored at each iteration
collective_measurement_updates: bool
Expand Down Expand Up @@ -1404,6 +1408,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

if gaussian_filter_sigma_m is None:
gaussian_filter_sigma_m = gaussian_filter_sigma_e

Expand Down Expand Up @@ -1486,6 +1493,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme=use_projection_scheme,
projection_a=projection_a,
projection_b=projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/mixedstate_multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
pure_phase_object: bool = False,
tv_denoise_chambolle: bool = True,
tv_denoise_weight_chambolle=None,
Expand Down Expand Up @@ -881,6 +882,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
pure_phase_object: bool, optional
If True, object amplitude is set to unity
tv_denoise_chambolle: bool
Expand Down Expand Up @@ -978,6 +982,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1025,6 +1032,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/mixedstate_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
store_iterations: bool = False,
progress_bar: bool = True,
reset: bool = None,
Expand Down Expand Up @@ -787,6 +788,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
store_iterations: bool, optional
If True, reconstructed objects and probes are stored at each iteration
progress_bar: bool, optional
Expand Down Expand Up @@ -872,6 +876,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -919,6 +926,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/multislice_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
pure_phase_object: bool = False,
tv_denoise_chambolle: bool = True,
tv_denoise_weight_chambolle=None,
Expand Down Expand Up @@ -853,6 +854,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
pure_phase_object: bool, optional
If True, object amplitude is set to unity
tv_denoise_chambolle: bool
Expand Down Expand Up @@ -954,6 +958,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1001,6 +1008,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
26 changes: 25 additions & 1 deletion py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,7 +2766,9 @@ def _return_farfield_amplitudes(self, fourier_overlap):
xp = self._xp
return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))

def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask):
def _gradient_descent_fourier_projection(
self, amplitudes, overlap, fourier_mask, virtual_detector_masks
):
"""
Ptychographic fourier projection method for GD method.
Expand All @@ -2779,6 +2781,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
Returns
--------
Expand All @@ -2802,8 +2807,19 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, :, None, :, :]
* virtual_detector_masks[None, None, :, :, :],
axis=(-1, -2),
).transpose(2, 0, 1)
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[..., mask] = value[:, :, None] / xp.sum(mask)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)

Expand All @@ -2813,6 +2829,7 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap

fourier_modified_overlap = fourier_modified_overlap - fourier_overlap

if fourier_mask is not None:
fourier_modified_overlap *= fourier_mask

Expand All @@ -2836,6 +2853,7 @@ def _projection_sets_fourier_projection(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand Down Expand Up @@ -2867,6 +2885,9 @@ def _projection_sets_fourier_projection(
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
Currently not implemented for projection sets
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
projection_a: float
projection_b: float
projection_c: float
Expand All @@ -2882,6 +2903,9 @@ def _projection_sets_fourier_projection(
if fourier_mask is not None:
raise NotImplementedError()

if virtual_detector_masks is not None:
raise NotImplementedError()

xp = self._xp
projection_x = 1 - projection_a - projection_b
projection_y = 1 - projection_c
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/ptychographic_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ def reconstruct(
shrinkage_rad: float = 0.0,
fix_potential_baseline: bool = True,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
tv_denoise: bool = True,
tv_denoise_weights: float = None,
tv_denoise_inner_iter=40,
Expand Down Expand Up @@ -901,6 +902,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to apply at the detector-plane for zeroing-out unreliable gradients.
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
store_iterations: bool, optional
If True, reconstructed objects and probes are stored at each iteration
progress_bar: bool, optional
Expand Down Expand Up @@ -979,6 +983,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

# main loop
for a0 in tqdmnd(
num_iter,
Expand Down Expand Up @@ -1071,6 +1078,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme,
projection_a,
projection_b,
Expand Down
8 changes: 8 additions & 0 deletions py4DSTEM/process/phase/xray_magnetic_ptychography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,7 @@ def reconstruct(
tv_denoise_weight: float = None,
tv_denoise_inner_iter: float = 40,
detector_fourier_mask: np.ndarray = None,
virtual_detector_masks: Sequence[np.ndarray] = None,
store_iterations: bool = False,
collective_measurement_updates: bool = True,
progress_bar: bool = True,
Expand Down Expand Up @@ -1281,6 +1282,9 @@ def reconstruct(
detector_fourier_mask: np.ndarray
Corner-centered mask to multiply the detector-plane gradients with (a value of zero supresses those pixels).
Useful when detector has artifacts such as dead-pixels. Usually binary.
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
store_iterations: bool, optional
If True, reconstructed objects and probes are stored at each iteration
collective_measurement_updates: bool
Expand Down Expand Up @@ -1376,6 +1380,9 @@ def reconstruct(
if detector_fourier_mask is not None:
detector_fourier_mask = xp.asarray(detector_fourier_mask)

if virtual_detector_masks is not None:
virtual_detector_masks = xp.asarray(virtual_detector_masks)

if gaussian_filter_sigma_m is None:
gaussian_filter_sigma_m = gaussian_filter_sigma_e

Expand Down Expand Up @@ -1458,6 +1465,7 @@ def reconstruct(
amplitudes_device,
self._exit_waves,
detector_fourier_mask,
virtual_detector_masks,
use_projection_scheme=use_projection_scheme,
projection_a=projection_a,
projection_b=projection_b,
Expand Down

0 comments on commit 16e0956

Please sign in to comment.