Skip to content

Commit

Permalink
black and maybe fixed normalization error
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Sep 6, 2024
1 parent 1e63bfd commit adf112b
Showing 1 changed file with 58 additions and 69 deletions.
127 changes: 58 additions & 69 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(
clear_fft_cache: bool = True,
name: str = "tomography",
):
"""
Nanobeam tomography!
"""
Nanobeam tomography!
"""

Expand Down Expand Up @@ -235,25 +235,25 @@ def reconstruct(
progress_bar: bool = True,
zero_edges: bool = True,
):
"""
"""
Main loop for reconstruct
Parameters
----------
num_iter: int
Number of iterations
Number of iterations
store_iterations: bool
if True, stores number of iterations
reset: bool
if True, stores number of iterations
reset: bool
if True, resets object
step_size: float
from 0 to 1, step size for update
step_size: float
from 0 to 1, step size for update
num_points: int
number of points for bilinear interpolation in real space
progres_bar: bool
progres_bar: bool
if True, shows progress bar
zero_edges: bool
If True, zero edges along y and z
If True, zero edges along y and z
"""
device = self._device

Expand Down Expand Up @@ -1108,28 +1108,31 @@ def _forward(
"clip",
)


bincount_x = (
xp.tile(
(xp.tile(self._ind_diffraction_ravel, (1, 4))),
(1, s[1]),
(xp.tile(self._ind_diffraction_ravel, 4)),
(s[1]),
)
+ xp.repeat(xp.arange(s[1]), ind_diff.shape[0]) * self._q_length
)
bincount_x = xp.asarray(bincount_x[0], dtype="int")

obj_projected = xp.bincount(
bincount_x,
ind_real = xp.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip")
self.ind_real = ind_real

obj_projected = (
(
obj[xp.ravel_multi_index((ind0, ind1), (s[1], s[2]), mode="clip"),]
* weights_real[:, :, None]
xp.bincount(
bincount_x,
(
(obj[ind_real] * weights_real[:, :, None]).mean(1)[:, ind_diff]
).ravel()
* xp.tile(weights_diff, s[1]).ravel(),
minlength=self._q_length * s[1],
).reshape(s[1], self._q_length)[:, self._circular_mask_bincount]
)
.mean(1)[:, ind_diff]
.ravel()
* xp.tile(weights_diff, s[1]).ravel(),
minlength=self._q_length * s[1],
).reshape(s[1], self._q_length)[:, self._circular_mask_bincount] * s[2]

* s[2]
* 4
)

self._ind0 = ind0
self._ind1 = ind1
Expand Down Expand Up @@ -1242,11 +1245,8 @@ def _calculate_update(

s = self._object_shape_6D



ind0 = self._positions_vox_F[datacube_number][0] == x_index
ind1 = self._positions_vox_F[datacube_number][0] == x_index + 1

ind0 = self._positions_vox_F[datacube_number][0] == x_index
ind1 = self._positions_vox_F[datacube_number][0] == x_index + 1

dp_length = diffraction_patterns_projected.shape[1]

Expand All @@ -1255,45 +1255,35 @@ def _calculate_update(
diffraction_patterns_projected[ind0].ravel(),
diffraction_patterns_projected[ind0].ravel(),
diffraction_patterns_projected[ind1].ravel(),
diffraction_patterns_projected[ind1].ravel()
diffraction_patterns_projected[ind1].ravel(),
]
)

weights = np.hstack(
[
np.repeat(
(
1-self._positions_vox_dF[datacube_number][0][ind0]
)*(
1-self._positions_vox_dF[datacube_number][1][ind0]
), dp_length
(1 - self._positions_vox_dF[datacube_number][0][ind0])
* (1 - self._positions_vox_dF[datacube_number][1][ind0]),
dp_length,
),
np.repeat(
(
1-self._positions_vox_dF[datacube_number][0][ind0]
)*(
self._positions_vox_dF[datacube_number][1][ind0]
), dp_length
(1 - self._positions_vox_dF[datacube_number][0][ind0])
* (self._positions_vox_dF[datacube_number][1][ind0]),
dp_length,
),
np.repeat(
(
self._positions_vox_dF[datacube_number][0][ind1]
)*(
1-self._positions_vox_dF[datacube_number][1][ind1]
), dp_length
(self._positions_vox_dF[datacube_number][0][ind1])
* (1 - self._positions_vox_dF[datacube_number][1][ind1]),
dp_length,
),
np.repeat(
(
self._positions_vox_dF[datacube_number][0][ind1]
)*(
self._positions_vox_dF[datacube_number][1][ind1]
), dp_length
)
(self._positions_vox_dF[datacube_number][0][ind1])
* (self._positions_vox_dF[datacube_number][1][ind1]),
dp_length,
),
]

)


positions_y = xp.clip(
xp.hstack(
[
Expand All @@ -1302,26 +1292,25 @@ def _calculate_update(
self._positions_vox[datacube_number][1][ind1],
self._positions_vox[datacube_number][1][ind1] + 1,
],
), 0, s[1]-1,
),
0,
s[1] - 1,
)

bincount_x = xp.tile(
xp.arange(dp_length), dp_patterns.shape[0]//dp_length
) + xp.repeat(positions_y, dp_length) * dp_length
bincount_x = (
xp.tile(xp.arange(dp_length), dp_patterns.shape[0] // dp_length)
+ xp.repeat(positions_y, dp_length) * dp_length
)

bincount_x = xp.asarray(bincount_x, dtype = "int")
bincount_x = xp.asarray(bincount_x, dtype="int")

dp_patterns_counted = xp.bincount(
bincount_x,
weights = dp_patterns * weights,
minlength = s[1] * dp_length
bincount_x, weights=dp_patterns * weights, minlength=s[1] * dp_length
).reshape((s[1], dp_length))

update = dp_patterns_counted-object_sliced
update = dp_patterns_counted - object_sliced

error = xp.mean(update.ravel() ** 2) / xp.mean(
dp_patterns_counted.ravel() ** 2
)
error = xp.mean(update.ravel() ** 2) / xp.mean(dp_patterns_counted.ravel() ** 2)

error = copy_to_device(error, "cpu")

Expand Down Expand Up @@ -1423,14 +1412,14 @@ def _constraints(
self,
zero_edges: bool,
):
"""
"""
Constrains for object
TODO: add constrains and break into multiple functions possibly
TODO: add constrains and break into multiple functions possibly
Parameters
----------
zero_edges: bool
If True, zero edges along y and z
If True, zero edges along y and z
"""

if zero_edges:
Expand Down

0 comments on commit adf112b

Please sign in to comment.