Skip to content

Commit

Permalink
refactor for update
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Sep 6, 2024
1 parent bf2184c commit 035c57d
Showing 1 changed file with 147 additions and 67 deletions.
214 changes: 147 additions & 67 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,84 +1165,164 @@ def _calculate_update(
"""
xp = self._xp

ind = self._positions_vox_F[0][0] == x_index
# ind = self._positions_vox_F[0][0] == x_index

# diffraction_patterns_resampled = xp.zeros(
# (self._positions_vox_dF[0][0].shape[0], object_sliced.shape[-1])
# )
# diffraction_patterns_resampled[
# xp.ravel_multi_index(
# (
# self._positions_vox_F[datacube_number][0][ind],
# self._positions_vox_F[datacube_number][1][ind],
# ),
# self._initial_datacube_shape[0:2],
# mode="clip",
# )
# ] += (
# diffraction_patterns_projected[ind]
# * (
# (1 - self._positions_vox_dF[datacube_number][0][ind])
# * (1 - self._positions_vox_dF[datacube_number][1][ind])
# )[:, None]
# )

# diffraction_patterns_resampled[
# xp.ravel_multi_index(
# (
# self._positions_vox_F[datacube_number][0][ind] + 1,
# self._positions_vox_F[datacube_number][1][ind],
# ),
# self._initial_datacube_shape[0:2],
# mode="clip",
# )
# ] += (
# diffraction_patterns_projected[ind]
# * (
# (self._positions_vox_dF[datacube_number][0][ind])
# * (1 - self._positions_vox_dF[datacube_number][1][ind])
# )[:, None]
# )

# diffraction_patterns_resampled[
# xp.ravel_multi_index(
# (
# self._positions_vox_F[datacube_number][0][ind],
# self._positions_vox_F[datacube_number][1][ind] + 1,
# ),
# self._initial_datacube_shape[0:2],
# mode="clip",
# )
# ] += (
# diffraction_patterns_projected[ind]
# * (
# (1 - self._positions_vox_dF[datacube_number][0][ind])
# * (self._positions_vox_dF[datacube_number][1][ind])
# )[:, None]
# )

# diffraction_patterns_resampled[
# xp.ravel_multi_index(
# (
# self._positions_vox_F[datacube_number][0][ind] + 1,
# self._positions_vox_F[datacube_number][1][ind] + 1,
# ),
# self._initial_datacube_shape[0:2],
# mode="clip",
# )
# ] += (
# diffraction_patterns_projected[ind]
# * (
# (self._positions_vox_dF[datacube_number][0][ind])
# * (self._positions_vox_dF[datacube_number][1][ind])
# )[:, None]
# )
# diffraction_patterns_resampled = diffraction_patterns_resampled[ind]
# update = diffraction_patterns_resampled - object_sliced

diffraction_patterns_resampled = xp.zeros(
(self._positions_vox_dF[0][0].shape[0], object_sliced.shape[-1])
)
diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind],
self._positions_vox_F[datacube_number][1][ind],
),
self._initial_datacube_shape[0:2],
mode="clip",
)
] += (
diffraction_patterns_projected[ind]
* (
(1 - self._positions_vox_dF[datacube_number][0][ind])
* (1 - self._positions_vox_dF[datacube_number][1][ind])
)[:, None]
)
s = self._object_shape_6D

diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind] + 1,
self._positions_vox_F[datacube_number][1][ind],
),
self._initial_datacube_shape[0:2],
mode="clip",
)
] += (
diffraction_patterns_projected[ind]
* (
(self._positions_vox_dF[datacube_number][0][ind])
* (1 - self._positions_vox_dF[datacube_number][1][ind])
)[:, None]


ind0 = self._positions_vox_F[datacube_number][0] == x_index
ind1 = self._positions_vox_F[datacube_number][0] == x_index + 1


dp_length = diffraction_patterns_projected.shape[1]

dp_patterns = np.hstack(
[
diffraction_patterns_projected[ind0].ravel(),
diffraction_patterns_projected[ind0].ravel(),
diffraction_patterns_projected[ind1].ravel(),
diffraction_patterns_projected[ind1].ravel()
]
)

diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind],
self._positions_vox_F[datacube_number][1][ind] + 1,
weights = np.hstack(
[
np.repeat(
(
1-self._positions_vox_dF[datacube_number][0][ind0]
)*(
1-self._positions_vox_dF[datacube_number][1][ind0]
), dp_length
),
self._initial_datacube_shape[0:2],
mode="clip",
)
] += (
diffraction_patterns_projected[ind]
* (
(1 - self._positions_vox_dF[datacube_number][0][ind])
* (self._positions_vox_dF[datacube_number][1][ind])
)[:, None]
np.repeat(
(
1-self._positions_vox_dF[datacube_number][0][ind0]
)*(
self._positions_vox_dF[datacube_number][1][ind0]
), dp_length
),
np.repeat(
(
self._positions_vox_dF[datacube_number][0][ind1]
)*(
1-self._positions_vox_dF[datacube_number][1][ind1]
), dp_length
),
np.repeat(
(
self._positions_vox_dF[datacube_number][0][ind1]
)*(
self._positions_vox_dF[datacube_number][1][ind1]
), dp_length
)
]

)

diffraction_patterns_resampled[
xp.ravel_multi_index(
(
self._positions_vox_F[datacube_number][0][ind] + 1,
self._positions_vox_F[datacube_number][1][ind] + 1,
),
self._initial_datacube_shape[0:2],
mode="clip",
)
] += (
diffraction_patterns_projected[ind]
* (
(self._positions_vox_dF[datacube_number][0][ind])
* (self._positions_vox_dF[datacube_number][1][ind])
)[:, None]

positions_y = xp.clip(
xp.hstack(
[
self._positions_vox[datacube_number][1][ind0],
self._positions_vox[datacube_number][1][ind0] + 1,
self._positions_vox[datacube_number][1][ind1],
self._positions_vox[datacube_number][1][ind1] + 1,
],
), 0, s[1]-1,
)
diffraction_patterns_resampled = diffraction_patterns_resampled[ind]
update = diffraction_patterns_resampled - object_sliced

bincount_x = xp.tile(
xp.arange(dp_length), dp_patterns.shape[0]//dp_length
) + xp.repeat(positions_y, dp_length) * dp_length

bincount_x = xp.asarray(bincount_x, dtype = "int")

dp_patterns_counted = xp.bincount(
bincount_x,
weights = dp_patterns * weights,
minlength = s[1] * dp_length
).reshape((s[1], dp_length))

update = dp_patterns_counted-object_sliced

error = xp.mean(update.ravel() ** 2) / xp.mean(
diffraction_patterns_projected.ravel() ** 2
dp_patterns_counted.ravel() ** 2
)

error = copy_to_device(error, "cpu")

return update, error
Expand Down

0 comments on commit 035c57d

Please sign in to comment.