Skip to content

Commit

Permalink
rotating and transposing
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Sep 17, 2024
1 parent 36ee9d2 commit 397be35
Showing 1 changed file with 42 additions and 36 deletions.
78 changes: 42 additions & 36 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from py4DSTEM.process.diffraction import Crystal
from py4DSTEM.process.phase.utils import copy_to_device
from py4DSTEM.utils import fourier_resample
from scipy.ndimage import zoom
from scipy.ndimage import rotate, zoom
from scipy.spatial.transform import Rotation as R

try:
Expand Down Expand Up @@ -127,8 +127,8 @@ def preprocess(
robust: bool = False,
robust_steps: int = 3,
robust_thresh: int = 2,
force_q_to_r_rotation_deg = 0,
force_q_to_r_transpose = False,
force_q_to_r_rotation_deg=None,
force_q_to_r_transpose=False,
):
"""
Preprocessing for nanobeam tomography
Expand Down Expand Up @@ -573,17 +573,16 @@ def _calculate_scan_positions(

## TODO: check sign
if self._tilt_rotation_axis_shift_px:
y += self._tilt_rotation_axis_shift_px

y += self._tilt_rotation_axis_shift_px

# remove data outside FOV
if mask_real_space is None:
mask_real_space = np.ones(x.shape, dtype="bool")

else:
if self._transpose_xy:
mask_real_space = mask_real_space.swapaxes((-1,-2))
mask_real_space = mask_real_space.swapaxes((-1, -2))

mask_real_space[x >= self._field_of_view_A[0]] = False
mask_real_space[x < 0] = False
mask_real_space[y >= self._field_of_view_A[1]] = False
Expand Down Expand Up @@ -711,12 +710,11 @@ def _reshape_diffraction_patterns(
if datacube_number == 0:
self._make_diffraction_masks(q_max_inv_A=q_max_inv_A)

# TODO deal with with diffraction patterns cut off by detector?

diffraction_patterns_reshaped = self._reshape_4D_array_to_2D(
data=datacube.data,
qx0_fit=qx0_fit,
qy0_fit=qy0_fit,
ind_diffraction_ravel=self._ind_diffraction_rotate_transpose_ravel,
)

del datacube
Expand Down Expand Up @@ -760,8 +758,29 @@ def _make_diffraction_masks(self, q_max_inv_A):

ind_diffraction[mask] = ind_diffraction_rot[mask]

ind_diffraction_rotate_transpose = ind_diffraction.copy()

if self._force_q_to_r_transpose:
ind_diffraction_rotate_transpose = (
ind_diffraction_rotate_transpose.swapaxes(-1, -2)
)
if self._force_q_to_r_rotation_deg is not None:
ind_diffraction_rotate_transpose = np.clip(
rotate(
ind_diffraction_rotate_transpose,
self._force_q_to_r_rotation_deg,
reshape=False,
),
0,
np.max(ind_diffraction),
)

self._ind_diffraction = ind_diffraction
self._ind_diffraction_ravel = ind_diffraction.ravel()
self._ind_diffraction_rotate_transpose = ind_diffraction_rotate_transpose
self._ind_diffraction_rotate_transpose_ravel = (
ind_diffraction_rotate_transpose.ravel()
)
self._q_length = np.unique(self._ind_diffraction).shape[0]

# pixels to remove
Expand All @@ -783,7 +802,9 @@ def _make_diffraction_masks(self, q_max_inv_A):
dtype="bool",
)

def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
def _reshape_4D_array_to_2D(
self, data, qx0_fit=None, qy0_fit=None, ind_diffraction_ravel=None
):
"""
reshape diffraction 4D-data to 2D ravelled patterns
Expand All @@ -795,6 +816,8 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
qx shifts
qy0_fit: int
qy shifts
ind_diffraction: np.ndarray
1D array (length of number of pixels in diffraciton space to project 4D array into)
Returns
Expand All @@ -808,11 +831,13 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
center = ((s[-1] - 1) / 2, (s[-1] - 1) / 2)
diffraction_patterns_reshaped = np.zeros((s[0] * s[1], self._q_length))

if ind_diffraction_ravel is None:
ind_diffraction_ravel = self._ind_diffraction_ravel

for a0 in range(s[0]):
for a1 in range(s[0]):
dp = data[a0, a1]
index = np.ravel_multi_index((a0, a1), (s[0], s[1]))

if qx0_fit is not None:
qx0 = center[0] - qx0_fit[a0, a1]
qy0 = center[1] - qy0_fit[a0, a1]
Expand All @@ -829,7 +854,7 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
(1 - wx)
* (1 - wy)
* np.bincount(
self._ind_diffraction_ravel,
ind_diffraction_ravel,
np.roll(dp, (xF, yF), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
Expand All @@ -839,7 +864,7 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
(wx)
* (1 - wy)
* np.bincount(
self._ind_diffraction_ravel,
ind_diffraction_ravel,
np.roll(dp, (xF + 1, yF), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
Expand All @@ -848,7 +873,7 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
(1 - wx)
* (wy)
* np.bincount(
self._ind_diffraction_ravel,
ind_diffraction_ravel,
np.roll(dp, (xF, yF + 1), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
Expand All @@ -857,7 +882,7 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
(wx)
* (wy)
* np.bincount(
self._ind_diffraction_ravel,
ind_diffraction_ravel,
np.roll(dp, (xF + 1, yF + 1), axis=(0, 1)).ravel(),
minlength=self._q_length,
)
Expand All @@ -866,7 +891,7 @@ def _reshape_4D_array_to_2D(self, data, qx0_fit=None, qy0_fit=None):
else:

diffraction_patterns_reshaped[index] = np.bincount(
self._ind_diffraction_ravel,
ind_diffraction_ravel,
dp.ravel(),
minlength=self._q_length,
)
Expand Down Expand Up @@ -1478,7 +1503,6 @@ def _constraints(
)[0]
self._object[:, ind_zero] = 0


def set_storage(self, storage):
"""
Sets storage device.
Expand Down Expand Up @@ -1563,23 +1587,7 @@ def object_6D(self):

return self._object.reshape(self._object_shape_6D)

















#### Code for sims, To be removed later
#### Code for sims, To be removed later
def _make_test_object(
self,
sx: int,
Expand Down Expand Up @@ -1812,5 +1820,3 @@ def set_device(self, device, clear_fft_cache):
self._device = device

return self


0 comments on commit 397be35

Please sign in to comment.