From 035c57d5f112786e98dccd3a34bffb254374da5c Mon Sep 17 00:00:00 2001 From: Stephanie Ribet Date: Fri, 6 Sep 2024 05:16:10 -0700 Subject: [PATCH] refactor for update --- py4DSTEM/tomography/tomography.py | 214 ++++++++++++++++++++---------- 1 file changed, 147 insertions(+), 67 deletions(-) diff --git a/py4DSTEM/tomography/tomography.py b/py4DSTEM/tomography/tomography.py index 6ded10784..8913ecbd1 100644 --- a/py4DSTEM/tomography/tomography.py +++ b/py4DSTEM/tomography/tomography.py @@ -1165,84 +1165,164 @@ def _calculate_update( """ xp = self._xp - ind = self._positions_vox_F[0][0] == x_index + # ind = self._positions_vox_F[0][0] == x_index + + # diffraction_patterns_resampled = xp.zeros( + # (self._positions_vox_dF[0][0].shape[0], object_sliced.shape[-1]) + # ) + # diffraction_patterns_resampled[ + # xp.ravel_multi_index( + # ( + # self._positions_vox_F[datacube_number][0][ind], + # self._positions_vox_F[datacube_number][1][ind], + # ), + # self._initial_datacube_shape[0:2], + # mode="clip", + # ) + # ] += ( + # diffraction_patterns_projected[ind] + # * ( + # (1 - self._positions_vox_dF[datacube_number][0][ind]) + # * (1 - self._positions_vox_dF[datacube_number][1][ind]) + # )[:, None] + # ) + + # diffraction_patterns_resampled[ + # xp.ravel_multi_index( + # ( + # self._positions_vox_F[datacube_number][0][ind] + 1, + # self._positions_vox_F[datacube_number][1][ind], + # ), + # self._initial_datacube_shape[0:2], + # mode="clip", + # ) + # ] += ( + # diffraction_patterns_projected[ind] + # * ( + # (self._positions_vox_dF[datacube_number][0][ind]) + # * (1 - self._positions_vox_dF[datacube_number][1][ind]) + # )[:, None] + # ) + + # diffraction_patterns_resampled[ + # xp.ravel_multi_index( + # ( + # self._positions_vox_F[datacube_number][0][ind], + # self._positions_vox_F[datacube_number][1][ind] + 1, + # ), + # self._initial_datacube_shape[0:2], + # mode="clip", + # ) + # ] += ( + # diffraction_patterns_projected[ind] + # * ( + # (1 - self._positions_vox_dF[datacube_number][0][ind]) + # * (self._positions_vox_dF[datacube_number][1][ind]) + # )[:, None] + # ) + + # diffraction_patterns_resampled[ + # xp.ravel_multi_index( + # ( + # self._positions_vox_F[datacube_number][0][ind] + 1, + # self._positions_vox_F[datacube_number][1][ind] + 1, + # ), + # self._initial_datacube_shape[0:2], + # mode="clip", + # ) + # ] += ( + # diffraction_patterns_projected[ind] + # * ( + # (self._positions_vox_dF[datacube_number][0][ind]) + # * (self._positions_vox_dF[datacube_number][1][ind]) + # )[:, None] + # ) + # diffraction_patterns_resampled = diffraction_patterns_resampled[ind] + # update = diffraction_patterns_resampled - object_sliced - diffraction_patterns_resampled = xp.zeros( - (self._positions_vox_dF[0][0].shape[0], object_sliced.shape[-1]) - ) - diffraction_patterns_resampled[ - xp.ravel_multi_index( - ( - self._positions_vox_F[datacube_number][0][ind], - self._positions_vox_F[datacube_number][1][ind], - ), - self._initial_datacube_shape[0:2], - mode="clip", - ) - ] += ( - diffraction_patterns_projected[ind] - * ( - (1 - self._positions_vox_dF[datacube_number][0][ind]) - * (1 - self._positions_vox_dF[datacube_number][1][ind]) - )[:, None] - ) + s = self._object_shape_6D - diffraction_patterns_resampled[ - xp.ravel_multi_index( - ( - self._positions_vox_F[datacube_number][0][ind] + 1, - self._positions_vox_F[datacube_number][1][ind], - ), - self._initial_datacube_shape[0:2], - mode="clip", - ) - ] += ( - diffraction_patterns_projected[ind] - * ( - (self._positions_vox_dF[datacube_number][0][ind]) - * (1 - self._positions_vox_dF[datacube_number][1][ind]) - )[:, None] + + + ind0 = self._positions_vox_F[datacube_number][0] == x_index + ind1 = self._positions_vox_F[datacube_number][0] == x_index + 1 + + + dp_length = diffraction_patterns_projected.shape[1] + + dp_patterns = np.hstack( + [ + diffraction_patterns_projected[ind0].ravel(), + diffraction_patterns_projected[ind0].ravel(), + diffraction_patterns_projected[ind1].ravel(), + diffraction_patterns_projected[ind1].ravel() + ] ) - diffraction_patterns_resampled[ - xp.ravel_multi_index( - ( - self._positions_vox_F[datacube_number][0][ind], - self._positions_vox_F[datacube_number][1][ind] + 1, + weights = np.hstack( + [ + np.repeat( + ( + 1-self._positions_vox_dF[datacube_number][0][ind0] + )*( + 1-self._positions_vox_dF[datacube_number][1][ind0] + ), dp_length ), - self._initial_datacube_shape[0:2], - mode="clip", - ) - ] += ( - diffraction_patterns_projected[ind] - * ( - (1 - self._positions_vox_dF[datacube_number][0][ind]) - * (self._positions_vox_dF[datacube_number][1][ind]) - )[:, None] + np.repeat( + ( + 1-self._positions_vox_dF[datacube_number][0][ind0] + )*( + self._positions_vox_dF[datacube_number][1][ind0] + ), dp_length + ), + np.repeat( + ( + self._positions_vox_dF[datacube_number][0][ind1] + )*( + 1-self._positions_vox_dF[datacube_number][1][ind1] + ), dp_length + ), + np.repeat( + ( + self._positions_vox_dF[datacube_number][0][ind1] + )*( + self._positions_vox_dF[datacube_number][1][ind1] + ), dp_length + ) + ] + ) - diffraction_patterns_resampled[ - xp.ravel_multi_index( - ( - self._positions_vox_F[datacube_number][0][ind] + 1, - self._positions_vox_F[datacube_number][1][ind] + 1, - ), - self._initial_datacube_shape[0:2], - mode="clip", - ) - ] += ( - diffraction_patterns_projected[ind] - * ( - (self._positions_vox_dF[datacube_number][0][ind]) - * (self._positions_vox_dF[datacube_number][1][ind]) - )[:, None] + + positions_y = xp.clip( + xp.hstack( + [ + self._positions_vox[datacube_number][1][ind0], + self._positions_vox[datacube_number][1][ind0] + 1, + self._positions_vox[datacube_number][1][ind1], + self._positions_vox[datacube_number][1][ind1] + 1, + ], + ), 0, s[1]-1, ) - diffraction_patterns_resampled = diffraction_patterns_resampled[ind] - update = diffraction_patterns_resampled - object_sliced + + bincount_x = xp.tile( + xp.arange(dp_length), dp_patterns.shape[0]//dp_length + ) + xp.repeat(positions_y, dp_length) * dp_length + + bincount_x = xp.asarray(bincount_x, dtype = "int") + + dp_patterns_counted = xp.bincount( + bincount_x, + weights = dp_patterns * weights, + minlength = s[1] * dp_length + ).reshape((s[1], dp_length)) + + update = dp_patterns_counted-object_sliced error = xp.mean(update.ravel() ** 2) / xp.mean( - diffraction_patterns_projected.ravel() ** 2 + dp_patterns_counted.ravel() ** 2 ) + error = copy_to_device(error, "cpu") return update, error