diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index 2389298cd..7a38ec0bb 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -2766,7 +2766,9 @@ def _return_farfield_amplitudes(self, fourier_overlap): xp = self._xp return xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) - def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask): + def _gradient_descent_fourier_projection( + self, amplitudes, overlap, fourier_mask, virtual_detector_masks + ): """ Ptychographic fourier projection method for GD method. @@ -2779,6 +2781,9 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask fourier_mask: np.ndarray Mask to apply at the detector-plane for zeroing-out unreliable gradients Useful when detector has artifacts such as dead-pixels + virtual_detector_masks: np.ndarray + List of corner-centered boolean masks for binning forward model exit waves, + to allow comparison with arbitrary geometry detector datasets. Returns -------- @@ -2802,8 +2807,19 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask xp=xp, ) + if virtual_detector_masks is not None: + masked_values = xp.sum( + fourier_overlap[:, :, None, :, :] + * virtual_detector_masks[None, None, :, :, :], + axis=(-1, -2), + ).transpose(2, 0, 1) + fourier_overlap = xp.zeros_like(fourier_overlap) + for mask, value in zip(virtual_detector_masks, masked_values): + fourier_overlap[..., mask] = value[:, :, None] / xp.sum(mask) + if fourier_mask is not None: fourier_overlap *= fourier_mask + farfield_amplitudes = self._return_farfield_amplitudes(fourier_overlap) error = xp.sum(xp.abs(amplitudes - farfield_amplitudes) ** 2) @@ -2813,6 +2829,7 @@ def _gradient_descent_fourier_projection(self, amplitudes, overlap, fourier_mask fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap fourier_modified_overlap = fourier_modified_overlap - fourier_overlap + if fourier_mask is not None: fourier_modified_overlap *= fourier_mask @@ -2836,6 +2853,7 @@ def _projection_sets_fourier_projection( overlap, exit_waves, fourier_mask, + virtual_detector_masks, projection_a, projection_b, projection_c, @@ -2867,6 +2885,9 @@ def _projection_sets_fourier_projection( Mask to apply at the detector-plane for zeroing-out unreliable gradients Useful when detector has artifacts such as dead-pixels Currently not implemented for projection sets + virtual_detector_masks: np.ndarray + List of corner-centered boolean masks for binning forward model exit waves, + to allow comparison with arbitrary geometry detector datasets. projection_a: float projection_b: float projection_c: float @@ -2882,6 +2903,9 @@ def _projection_sets_fourier_projection( if fourier_mask is not None: raise NotImplementedError() + if virtual_detector_masks is not None: + raise NotImplementedError() + xp = self._xp projection_x = 1 - projection_a - projection_b projection_y = 1 - projection_c