Skip to content

Commit

Permalink
reorg and positions errors
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Sep 30, 2024
1 parent 5ad901f commit 1ccf74a
Showing 1 changed file with 154 additions and 146 deletions.
300 changes: 154 additions & 146 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,27 +572,33 @@ def _calculate_scan_positions(

if self._tilt_rotation_axis_angle_deg is not None:
rotation_angle = np.deg2rad(self._tilt_rotation_axis_angle_deg)
x_mean = x.mean()
y_mean = y.mean()
x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * np.sin(
rotation_angle
) + y * np.cos(rotation_angle)

x -= x.mean()
y -= y.mean()

x += x_mean
y += y_mean

if self._transpose_xy:
x_temp = x.copy()
y_temp = y.copy()
x = y_temp.copy()
y = x_temp.copy()

## TODO: check sign
if self._tilt_rotation_axis_shift_px:
y += self._tilt_rotation_axis_shift_px

# remove data outside FOV
if mask_real_space is None:
mask_real_space = np.ones(x.shape, dtype="bool")

else:
if self._transpose_xy:
mask_real_space = mask_real_space.swapaxes((-1, -2))
mask_real_space = mask_real_space.swapaxes(-1, -2)

mask_real_space[x >= self._field_of_view_A[0]] = False
mask_real_space[x < 0] = False
Expand Down Expand Up @@ -1012,149 +1018,6 @@ def _reshape_2D_array_to_4D(self, data, xy_shape=None, positions=None):

return data_reshaped

def _real_space_radon(
self,
current_object: np.ndarray,
tilt_deg: int,
x_index: int,
num_points: int,
):
"""
Real space projection of current object
Parameters
----------
current_object: np.ndarray
current object estimate
tilt_deg: float
tilt of object in degrees
x_index: int
x slice of object to be sliced
num_points: float
number of points for bilinear interpolation
Returns
--------
current_object_projected: np.ndarray
projection of current object
"""
xp = self._xp
device = self._device

current_object = copy_to_device(current_object, device)

s = current_object.shape

tilt = xp.deg2rad(tilt_deg)

padding = int(xp.ceil(xp.abs(xp.tan(tilt) * s[2])))

line_z = xp.arange(0, 1, 1 / num_points) * (s[2] - 1)
line_y = line_z * xp.tan(tilt) + padding

offset = xp.arange(s[1], dtype="int")

current_object_reshape = xp.pad(
current_object[x_index],
((padding, padding), (0, 0), (0, 0), (0, 0), (0, 0)),
).reshape(((s[1] + padding * 2) * s[2], s[3], s[4], s[5]))

current_object_projected = xp.zeros((s[1], s[3], s[4], s[5]))

yF = xp.floor(line_y).astype("int")
zF = xp.floor(line_z).astype("int")
dy = line_y - yF
dz = line_z - zF

ind0 = np.hstack(
(
xp.tile(yF, (s[1], 1)) + offset[:, None],
xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
xp.tile(yF, (s[1], 1)) + offset[:, None],
xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
)
)

ind1 = np.hstack(
(
xp.tile(zF, (s[1], 1)),
xp.tile(zF, (s[1], 1)),
xp.tile(zF + 1, (s[1], 1)),
xp.tile(zF + 1, (s[1], 1)),
)
)

weights = np.hstack(
(
xp.tile(((1 - dy) * (1 - dz)), (s[1], 1)),
xp.tile(((dy) * (1 - dz)), (s[1], 1)),
xp.tile(((1 - dy) * (dz)), (s[1], 1)),
xp.tile(((dy) * (dz)), (s[1], 1)),
)
)

current_object_projected += (
current_object_reshape[
xp.ravel_multi_index(
(ind0, ind1), (s[1] + 2 * padding, s[2]), mode="clip"
)
]
* weights[:, :, None, None, None]
).sum(1)

return current_object_projected

def _diffraction_space_slice(
self,
current_object_projected: np.ndarray,
tilt_deg: int,
):
"""
Slicing of diffraction space for rotated object
Parameters
----------
current_object_rotated: np.ndarray
current object estimate projected
tilt_deg: float
tilt of object in degrees
Returns
--------
current_object_sliced: np.ndarray
projection of current object sliced in diffraciton space
"""
xp = self._xp

s = current_object_projected.shape

tilt = xp.deg2rad(tilt_deg)

line_y_diff = xp.fft.fftfreq(s[-1], 1 / s[-1]) * xp.cos(tilt)
line_z_diff = xp.fft.fftfreq(s[-1], 1 / s[-1]) * xp.sin(tilt)

yF_diff = xp.floor(line_y_diff).astype("int")
zF_diff = xp.floor(line_z_diff).astype("int")
dy_diff = line_y_diff - yF_diff
dz_diff = line_z_diff - zF_diff

current_object_sliced = xp.zeros((s[0], s[-1], s[-1]))

current_object_sliced = (
current_object_projected[:, :, yF_diff, zF_diff]
* ((1 - dy_diff) * (1 - dz_diff))[None, None, :]
+ current_object_projected[:, :, yF_diff + 1, zF_diff]
* ((dy_diff) * (1 - dz_diff))[None, None, :]
+ current_object_projected[:, :, yF_diff, zF_diff + 1]
* ((1 - dy_diff) * (dz_diff))[None, None, :]
+ current_object_projected[:, :, yF_diff + 1, zF_diff + 1]
* ((dy_diff) * (dz_diff))[None, None, :]
)

return self._asnumpy(current_object_sliced)

def _forward(
self,
x_index: int,
Expand Down Expand Up @@ -1311,6 +1174,8 @@ def _forward(
self._weights_real = weights_real
self._bincount_x = bincount_x
self._ind_diff = ind_diff
self._ind0_diff = ind0_diff
self._ind1_diff = ind1_diff
self._weights_diff = weights_diff

return obj_projected
Expand Down Expand Up @@ -1791,6 +1656,149 @@ def recovered_4D_scan(self, index):

# return current_object_sliced

# def _diffraction_space_slice(
# self,
# current_object_projected: np.ndarray,
# tilt_deg: int,
# ):
# """
# Slicing of diffraction space for rotated object

# Parameters
# ----------
# current_object_rotated: np.ndarray
# current object estimate projected
# tilt_deg: float
# tilt of object in degrees

# Returns
# --------
# current_object_sliced: np.ndarray
# projection of current object sliced in diffraciton space

# """
# xp = self._xp

# s = current_object_projected.shape

# tilt = xp.deg2rad(tilt_deg)

# line_y_diff = xp.fft.fftfreq(s[-1], 1 / s[-1]) * xp.cos(tilt)
# line_z_diff = xp.fft.fftfreq(s[-1], 1 / s[-1]) * xp.sin(tilt)

# yF_diff = xp.floor(line_y_diff).astype("int")
# zF_diff = xp.floor(line_z_diff).astype("int")
# dy_diff = line_y_diff - yF_diff
# dz_diff = line_z_diff - zF_diff

# current_object_sliced = xp.zeros((s[0], s[-1], s[-1]))

# current_object_sliced = (
# current_object_projected[:, :, yF_diff, zF_diff]
# * ((1 - dy_diff) * (1 - dz_diff))[None, None, :]
# + current_object_projected[:, :, yF_diff + 1, zF_diff]
# * ((dy_diff) * (1 - dz_diff))[None, None, :]
# + current_object_projected[:, :, yF_diff, zF_diff + 1]
# * ((1 - dy_diff) * (dz_diff))[None, None, :]
# + current_object_projected[:, :, yF_diff + 1, zF_diff + 1]
# * ((dy_diff) * (dz_diff))[None, None, :]
# )

# return self._asnumpy(current_object_sliced)

# def _real_space_radon(
# self,
# current_object: np.ndarray,
# tilt_deg: int,
# x_index: int,
# num_points: int,
# ):
# """
# Real space projection of current object

# Parameters
# ----------
# current_object: np.ndarray
# current object estimate
# tilt_deg: float
# tilt of object in degrees
# x_index: int
# x slice of object to be sliced
# num_points: float
# number of points for bilinear interpolation

# Returns
# --------
# current_object_projected: np.ndarray
# projection of current object

# """
# xp = self._xp
# device = self._device

# current_object = copy_to_device(current_object, device)

# s = current_object.shape

# tilt = xp.deg2rad(tilt_deg)

# padding = int(xp.ceil(xp.abs(xp.tan(tilt) * s[2])))

# line_z = xp.arange(0, 1, 1 / num_points) * (s[2] - 1)
# line_y = line_z * xp.tan(tilt) + padding

# offset = xp.arange(s[1], dtype="int")

# current_object_reshape = xp.pad(
# current_object[x_index],
# ((padding, padding), (0, 0), (0, 0), (0, 0), (0, 0)),
# ).reshape(((s[1] + padding * 2) * s[2], s[3], s[4], s[5]))

# current_object_projected = xp.zeros((s[1], s[3], s[4], s[5]))

# yF = xp.floor(line_y).astype("int")
# zF = xp.floor(line_z).astype("int")
# dy = line_y - yF
# dz = line_z - zF

# ind0 = np.hstack(
# (
# xp.tile(yF, (s[1], 1)) + offset[:, None],
# xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
# xp.tile(yF, (s[1], 1)) + offset[:, None],
# xp.tile(yF + 1, (s[1], 1)) + offset[:, None],
# )
# )

# ind1 = np.hstack(
# (
# xp.tile(zF, (s[1], 1)),
# xp.tile(zF, (s[1], 1)),
# xp.tile(zF + 1, (s[1], 1)),
# xp.tile(zF + 1, (s[1], 1)),
# )
# )

# weights = np.hstack(
# (
# xp.tile(((1 - dy) * (1 - dz)), (s[1], 1)),
# xp.tile(((dy) * (1 - dz)), (s[1], 1)),
# xp.tile(((1 - dy) * (dz)), (s[1], 1)),
# xp.tile(((dy) * (dz)), (s[1], 1)),
# )
# )

# current_object_projected += (
# current_object_reshape[
# xp.ravel_multi_index(
# (ind0, ind1), (s[1] + 2 * padding, s[2]), mode="clip"
# )
# ]
# * weights[:, :, None, None, None]
# ).sum(1)

# return current_object_projected

# def _make_diffraction_cloud(
# self,
# sq,
Expand Down

0 comments on commit 1ccf74a

Please sign in to comment.