Skip to content

Commit

Permalink
adding virtual detector support for mixed-probe fourier projection
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Sep 14, 2024
1 parent 8a08b5f commit 0a4ecf5
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion py4DSTEM/process/phase/ptychographic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2766,7 +2766,9 @@ def _return_farfield_amplitudes(self, fourier_overlap):
xp = self._xp
return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))

def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask):
def _gradient_descent_fourier_projection(
self, amplitudes, overlap, fourier_mask, virtual_detector_masks
):
"""
Ptychographic fourier projection method for GD method.
Expand All @@ -2779,6 +2781,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_mask: np.ndarray
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
Returns
--------
Expand All @@ -2802,8 +2807,19 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
xp=xp,
)

if virtual_detector_masks is not None:
masked_values = xp.sum(
fourier_overlap[:, :, None, :, :]
* virtual_detector_masks[None, None, :, :, :],
axis=(-1, -2),
).transpose(2, 0, 1)
fourier_overlap = xp.zeros_like(fourier_overlap)
for mask, value in zip(virtual_detector_masks, masked_values):
fourier_overlap[..., mask] = value[:, :, None] / xp.sum(mask)

if fourier_mask is not None:
fourier_overlap *= fourier_mask

farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap)
error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2)

Expand All @@ -2813,6 +2829,7 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask
fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap

fourier_modified_overlap = fourier_modified_overlap - fourier_overlap

if fourier_mask is not None:
fourier_modified_overlap *= fourier_mask

Expand All @@ -2836,6 +2853,7 @@ def _projection_sets_fourier_projection(
overlap,
exit_waves,
fourier_mask,
virtual_detector_masks,
projection_a,
projection_b,
projection_c,
Expand Down Expand Up @@ -2867,6 +2885,9 @@ def _projection_sets_fourier_projection(
Mask to apply at the detector-plane for zeroing-out unreliable gradients
Useful when detector has artifacts such as dead-pixels
Currently not implemented for projection sets
virtual_detector_masks: np.ndarray
List of corner-centered boolean masks for binning forward model exit waves,
to allow comparison with arbitrary geometry detector datasets.
projection_a: float
projection_b: float
projection_c: float
Expand All @@ -2882,6 +2903,9 @@ def _projection_sets_fourier_projection(
if fourier_mask is not None:
raise NotImplementedError()

if virtual_detector_masks is not None:
raise NotImplementedError()

xp = self._xp
projection_x = 1 - projection_a - projection_b
projection_y = 1 - projection_c
Expand Down

0 comments on commit 0a4ecf5

Please sign in to comment.