diff --git a/py4DSTEM/tomography/tomography.py b/py4DSTEM/tomography/tomography.py index 54d16f9bd..b34ee36e1 100644 --- a/py4DSTEM/tomography/tomography.py +++ b/py4DSTEM/tomography/tomography.py @@ -48,8 +48,8 @@ def __init__( clear_fft_cache: bool = True, name: str = "tomography", ): - """ - Nanobeam tomography! + """ + Nanobeam tomography! """ @@ -235,25 +235,25 @@ def reconstruct( progress_bar: bool = True, zero_edges: bool = True, ): - """ + """ Main loop for reconstruct Parameters ---------- num_iter: int - Number of iterations + Number of iterations store_iterations: bool - if True, stores number of iterations - reset: bool + if True, stores number of iterations + reset: bool if True, resets object - step_size: float - from 0 to 1, step size for update + step_size: float + from 0 to 1, step size for update num_points: int number of points for bilinear interpolation in real space - progres_bar: bool + progres_bar: bool if True, shows progress bar zero_edges: bool - If True, zero edges along y and z + If True, zero edges along y and z """ device = self._device @@ -1108,28 +1108,31 @@ def _forward( "clip", ) - bincount_x = ( xp.tile( - (xp.tile(self._ind_diffraction_ravel, (1, 4))), - (1, s[1]), + (xp.tile(self._ind_diffraction_ravel, 4)), + (s[1]), ) + xp.repeat(xp.arange(s[1]), ind_diff.shape[0]) * self._q_length ) - bincount_x = xp.asarray(bincount_x[0], dtype="int") - obj_projected = xp.bincount( - bincount_x, + ind_real = xp.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip") + self.ind_real = ind_real + + obj_projected = ( ( - obj[xp.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip"),] - * weights_real[:, :, None] + xp.bincount( + bincount_x, + ( + (obj[ind_real] * weights_real[:, :, None]).mean(1)[:, ind_diff] + ).ravel() + * xp.tile(weights_diff, s[1]).ravel(), + minlength=self._q_length * s[1], + ).reshape(s[1], self._q_length)[:, self._circular_mask_bincount] ) - .mean(1)[:, ind_diff] - .ravel() - * xp.tile(weights_diff, s[1]).ravel(), - minlength=self._q_length * s[1], - ).reshape(s[1], self._q_length)[:, self._circular_mask_bincount] * s[2] - + * s[2] + * 4 + ) self._ind0 = ind0 self._ind1 = ind1 @@ -1242,11 +1245,8 @@ def _calculate_update( s = self._object_shape_6D - - - ind0 = self._positions_vox_F[datacube_number][0] == x_index - ind1 = self._positions_vox_F[datacube_number][0] == x_index + 1 - + ind0 = self._positions_vox_F[datacube_number][0] == x_index + ind1 = self._positions_vox_F[datacube_number][0] == x_index + 1 dp_length = diffraction_patterns_projected.shape[1] @@ -1255,45 +1255,35 @@ def _calculate_update( diffraction_patterns_projected[ind0].ravel(), diffraction_patterns_projected[ind0].ravel(), diffraction_patterns_projected[ind1].ravel(), - diffraction_patterns_projected[ind1].ravel() + diffraction_patterns_projected[ind1].ravel(), ] ) weights = np.hstack( [ np.repeat( - ( - 1-self._positions_vox_dF[datacube_number][0][ind0] - )*( - 1-self._positions_vox_dF[datacube_number][1][ind0] - ), dp_length + (1 - self._positions_vox_dF[datacube_number][0][ind0]) + * (1 - self._positions_vox_dF[datacube_number][1][ind0]), + dp_length, ), np.repeat( - ( - 1-self._positions_vox_dF[datacube_number][0][ind0] - )*( - self._positions_vox_dF[datacube_number][1][ind0] - ), dp_length + (1 - self._positions_vox_dF[datacube_number][0][ind0]) + * (self._positions_vox_dF[datacube_number][1][ind0]), + dp_length, ), np.repeat( - ( - self._positions_vox_dF[datacube_number][0][ind1] - )*( - 1-self._positions_vox_dF[datacube_number][1][ind1] - ), dp_length + (self._positions_vox_dF[datacube_number][0][ind1]) + * (1 - self._positions_vox_dF[datacube_number][1][ind1]), + dp_length, ), np.repeat( - ( - self._positions_vox_dF[datacube_number][0][ind1] - )*( - self._positions_vox_dF[datacube_number][1][ind1] - ), dp_length - ) + (self._positions_vox_dF[datacube_number][0][ind1]) + * (self._positions_vox_dF[datacube_number][1][ind1]), + dp_length, + ), ] - ) - positions_y = xp.clip( xp.hstack( [ @@ -1302,26 +1292,25 @@ def _calculate_update( self._positions_vox[datacube_number][1][ind1], self._positions_vox[datacube_number][1][ind1] + 1, ], - ), 0, s[1]-1, + ), + 0, + s[1] - 1, ) - bincount_x = xp.tile( - xp.arange(dp_length), dp_patterns.shape[0]//dp_length - ) + xp.repeat(positions_y, dp_length) * dp_length + bincount_x = ( + xp.tile(xp.arange(dp_length), dp_patterns.shape[0] // dp_length) + + xp.repeat(positions_y, dp_length) * dp_length + ) - bincount_x = xp.asarray(bincount_x, dtype = "int") + bincount_x = xp.asarray(bincount_x, dtype="int") dp_patterns_counted = xp.bincount( - bincount_x, - weights = dp_patterns * weights, - minlength = s[1] * dp_length + bincount_x, weights=dp_patterns * weights, minlength=s[1] * dp_length ).reshape((s[1], dp_length)) - update = dp_patterns_counted-object_sliced + update = dp_patterns_counted - object_sliced - error = xp.mean(update.ravel() ** 2) / xp.mean( - dp_patterns_counted.ravel() ** 2 - ) + error = xp.mean(update.ravel() ** 2) / xp.mean(dp_patterns_counted.ravel() ** 2) error = copy_to_device(error, "cpu") @@ -1423,14 +1412,14 @@ def _constraints( self, zero_edges: bool, ): - """ + """ Constrains for object - TODO: add constrains and break into multiple functions possibly - + TODO: add constrains and break into multiple functions possibly + Parameters ---------- zero_edges: bool - If True, zero edges along y and z + If True, zero edges along y and z """ if zero_edges: