Skip to content

Commit

Permalink
dealing with shifting of tilt axis
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Sep 16, 2024
1 parent 930d9fb commit c7b4f2d
Showing 1 changed file with 121 additions and 99 deletions.
220 changes: 121 additions & 99 deletions py4DSTEM/tomography/tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
tilt_rotation_axis_angle: float
Rotation angle of scan direction to tilt axis
tilt_rotation_axis_shift_px: float
Shift of rotation axis relative to center of field of view
Shift of centering of datacubes so that the axis is centered in the object
(in number of datacube pixels)
transpose_xy: bool
If True, swaps x and y
Expand Down Expand Up @@ -555,31 +555,33 @@ def _calculate_scan_positions(
rotation_angle
) + y * np.cos(rotation_angle)

if self._transpose_xy:
x_temp = x.copy()
y_temp = y.copy()
x = y_temp.copy()
y = x_temp.copy()

if self._tilt_rotation_axis_shift_px:
y += self._tilt_rotation_axis_shift_px


# remove data outside FOV
if mask_real_space is None:
mask_real_space = np.ones(x.shape, dtype="bool")

if self._transpose_xy:
mask_real_space[x >= self._field_of_view_A[1]] = False
mask_real_space[x < 0] = False
mask_real_space[y >= self._field_of_view_A[0]] = False
mask_real_space[y < 0] = False
else:
mask_real_space[x >= self._field_of_view_A[0]] = False
mask_real_space[x < 0] = False
mask_real_space[y >= self._field_of_view_A[1]] = False
mask_real_space[y < 0] = False
if self._transpose_xy:
mask_real_space = mask_real_space.swapaxes((-1,-2))

mask_real_space[x >= self._field_of_view_A[0]] = False
mask_real_space[x < 0] = False
mask_real_space[y >= self._field_of_view_A[1]] = False
mask_real_space[y < 0] = False

# calculate positions in voxels
x = x[mask_real_space].ravel()
y = y[mask_real_space].ravel()

if self._transpose_xy:
x_temp = x.copy()
y_temp = y.copy()
x = y_temp.copy()
y = x_temp.copy()

x_vox = x / self._voxel_size_A
y_vox = y / self._voxel_size_A

Expand Down Expand Up @@ -1465,6 +1467,108 @@ def _constraints(
)[0]
self._object[:, ind_zero] = 0


def set_storage(self, storage):
"""
Sets storage device.
Parameters
----------
storage: str
Device arrays will be stored on. Must be 'cpu' or 'gpu'
Returns
--------
self: PhaseReconstruction
Self to enable chaining
"""

if storage == "cpu":
self._xp_storage = np

elif storage == "gpu":
if self._xp is np:
raise ValueError("storage='gpu' and device='cpu' is not supported")
self._xp_storage = cp

else:
raise ValueError(f"storage must be either 'cpu' or 'gpu', not {storage}")

self._asnumpy = copy_to_device
self._storage = storage

return self

def visualize(self, plot_convergence=True, figsize=(10, 10)):
"""
vis
"""

if plot_convergence:
spec = GridSpec(
ncols=2,
nrows=2,
height_ratios=[4, 1],
hspace=0.15,
# width_ratios=[
# (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
# 1,
# ],
wspace=0.15,
)

else:
spec = GridSpec(ncols=2, nrows=1)

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(spec[0, 0])
show(
self.object_6D.mean((2, 3, 4, 5)),
figax=(fig, ax),
cmap="magma",
title="real space object",
)

ax = fig.add_subplot(spec[0, 1])
ind_diff = self._object_shape_6D[-1] // 2
show(
self.object_6D.mean((0, 1, 2))[:, :, ind_diff],
figax=(fig, ax),
cmap="magma",
title="diffraction space object",
)

if plot_convergence:
ax = fig.add_subplot(spec[1, :])
ax.plot(self.error_iterations, color="b")
ax.set_xlabel("iterations")
ax.set_ylabel("error")

return self

@property
def object_6D(self):
"""6D object"""

return self._object.reshape(self._object_shape_6D)

















#### Code for sims, To be removed later
def _make_test_object(
self,
sx: int,
Expand Down Expand Up @@ -1698,86 +1802,4 @@ def set_device(self, device, clear_fft_cache):

return self

def set_storage(self, storage):
"""
Sets storage device.
Parameters
----------
storage: str
Device arrays will be stored on. Must be 'cpu' or 'gpu'
Returns
--------
self: PhaseReconstruction
Self to enable chaining
"""

if storage == "cpu":
self._xp_storage = np

elif storage == "gpu":
if self._xp is np:
raise ValueError("storage='gpu' and device='cpu' is not supported")
self._xp_storage = cp

else:
raise ValueError(f"storage must be either 'cpu' or 'gpu', not {storage}")

self._asnumpy = copy_to_device
self._storage = storage

return self

def visualize(self, plot_convergence=True, figsize=(10, 10)):
"""
vis
"""

if plot_convergence:
spec = GridSpec(
ncols=2,
nrows=2,
height_ratios=[4, 1],
hspace=0.15,
# width_ratios=[
# (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
# 1,
# ],
wspace=0.15,
)

else:
spec = GridSpec(ncols=2, nrows=1)

fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(spec[0, 0])
show(
self.object_6D.mean((2, 3, 4, 5)),
figax=(fig, ax),
cmap="magma",
title="real space object",
)

ax = fig.add_subplot(spec[0, 1])
ind_diff = self._object_shape_6D[-1] // 2
show(
self.object_6D.mean((0, 1, 2))[:, :, ind_diff],
figax=(fig, ax),
cmap="magma",
title="diffraction space object",
)

if plot_convergence:
ax = fig.add_subplot(spec[1, :])
ax.plot(self.error_iterations, color="b")
ax.set_xlabel("iterations")
ax.set_ylabel("error")

return self

@property
def object_6D(self):
"""6D object"""

return self._object.reshape(self._object_shape_6D)

0 comments on commit c7b4f2d

Please sign in to comment.