From 6d2efd3cc31dfb78db38d897f477616416765fb4 Mon Sep 17 00:00:00 2001 From: smribet Date: Thu, 26 Sep 2024 10:41:02 -0700 Subject: [PATCH] a bit of reorg --- py4DSTEM/tomography/tomography.py | 466 +++++++++++++++--------------- 1 file changed, 233 insertions(+), 233 deletions(-) diff --git a/py4DSTEM/tomography/tomography.py b/py4DSTEM/tomography/tomography.py index fe2a81ff2..2992d65b3 100644 --- a/py4DSTEM/tomography/tomography.py +++ b/py4DSTEM/tomography/tomography.py @@ -768,7 +768,7 @@ def _make_diffraction_masks(self, q_max_inv_A): ind_diffraction_rotate_transpose = np.clip( rotate( ind_diffraction_rotate_transpose, - -self._force_q_to_r_rotation_deg, #negative makes this rotation consistant with phase contrast module rotation + -self._force_q_to_r_rotation_deg, # negative makes this rotation consistant with phase contrast module rotation reshape=False, order=0, ), @@ -1552,6 +1552,46 @@ def set_storage(self, storage): return self + def set_device(self, device, clear_fft_cache): + """ + Sets calculation device. + + Parameters + ---------- + device: str + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if clear_fft_cache is not None: + self._clear_fft_cache = clear_fft_cache + + if device is None: + return self + + if device == "cpu": + import scipy + + self._xp = np + self._scipy = scipy + + elif device == "gpu": + from cupyx import scipy + + self._xp = cp + self._scipy = scipy + + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + self._device = device + + return self + def visualize(self, plot_convergence=True, figsize=(10, 10)): """ vis @@ -1606,235 +1646,195 @@ def object_6D(self): return self._object.reshape(self._object_shape_6D) #### Code for sims, To be removed later - def _make_test_object( - self, - sx: int, - sy: int, - sz: int, - sq: int, - q_max: float, - r: int, - num: int, - ): - """ - Make test object with 3D gold cubes at random orientations - - Parameters - ---------- - sx: int - x size (pixels) - sy: int - y size (pixels) - sz: int - z size (pixels) - sq: int - q size (pixels) - q_max: float - maximum scattering angle (A^-1) - r: int - length of 3D gold cubes - num: int - number of cubes - - Returns - -------- - test_object: np.ndarray - 6D test object - """ - xp_storage = self._xp_storage - storage = self._storage - - test_object = xp_storage.zeros((sx, sy, sz, sq, sq, sq)) - - diffraction_cloud = self._make_diffraction_cloud(sq, q_max, [0, 0, 0]) - - test_object[:, :, :, 0, 0, 0] = copy_to_device(diffraction_cloud.sum(), storage) - - for a0 in range(num): - s1 = xp_storage.random.randint(r, sx - r) - s2 = xp_storage.random.randint(r, sy - r) - h = xp_storage.random.randint(r, sz - r, size=1) - t = xp_storage.random.randint(0, 360, size=3) - - cloud = copy_to_device(self._make_diffraction_cloud(sq, q_max, t), storage) - - test_object[s1 - r : s1 + r, s2 - r : s2 + r, h[0] - r : h[0] + r] = cloud - - return test_object - - def _forward_simulation( - self, - current_object: np.ndarray, - tilt_deg: int, - x_index: int, - num_points: np.ndarray = 60, - ): - """ - Forward projection of object for simulation of diffraction data - - Parameters - ---------- - current_object: np.ndarray - current object estimate - tilt_deg: float - tilt of object in degrees - x_index: int - x slice of object to be sliced - num_points: float - number of points for bilinear interpolation - - Returns - -------- - current_object_sliced: np.ndarray - projection of current object sliced in diffraciton space - """ - current_object_projected = self._real_space_radon( - current_object, - tilt_deg, - x_index, - num_points, - ) - - current_object_sliced = self._diffraction_space_slice( - current_object_projected, - tilt_deg, - ) - - return current_object_sliced - - def _make_diffraction_cloud( - self, - sq, - q_max, - rot, - ): - """ - Make 3D diffraction cloud - - Parameters - ---------- - sq: int - q size (pixels) - q_max: float - maximum scattering angle (A^-1) - rot: 3-tuple - rotation of cloud - - Returns - -------- - diffraction_cloud: np.ndarray - 3D structure factor - - """ - xp = self._xp - - gold = self._make_gold(q_max) - - diffraction_cloud = xp.zeros((sq, sq, sq)) - - q_step = q_max * 2 / (sq - 1) - - qz = xp.fft.ifftshift(xp.arange(sq) * q_step - q_step * (sq - 1) / 2) - qx = xp.fft.ifftshift(xp.arange(sq) * q_step - q_step * (sq - 1) / 2) - qy = xp.fft.ifftshift(xp.arange(sq) * q_step - q_step * (sq - 1) / 2) - - qxa, qya, qza = xp.meshgrid(qx, qy, qz, indexing="ij") - - g_vecs = gold.g_vec_all.copy() - r = R.from_euler("zxz", [rot[0], rot[1], rot[2]]) - g_vecs = r.as_matrix() @ g_vecs - - cut_off = 0.1 - - for a0 in range(gold.g_vec_all.shape[1]): - bragg_spot = g_vecs[:, a0] - distance = xp.sqrt( - (qxa - bragg_spot[0]) ** 2 - + (qya - bragg_spot[1]) ** 2 - + (qza - bragg_spot[2]) ** 2 - ) - - update_index = distance < cut_off - update = xp.zeros((distance.shape)) - update[update_index] = cut_off - distance[update_index] - update -= xp.min(update) - update /= xp.sum(update) - update *= gold.struct_factors_int[a0] - diffraction_cloud += update - - return diffraction_cloud - - def _make_gold( - self, - q_max, - ): - """ - Calculate structure factor for gold up to q_max - - Parameters - ---------- - q_max: float - maximum scattering angle (A^-1) - - Returns - -------- - crystal: Crystal - gold crystal with structure factor calculated to q_max - - """ - - pos = [ - [0.0, 0.0, 0.0], - [0.0, 0.5, 0.5], - [0.5, 0.0, 0.5], - [0.5, 0.5, 0.0], - ] - atom_num = 79 - a = 4.08 - cell = a - - crystal = Crystal(pos, atom_num, cell) - - crystal.calculate_structure_factors(q_max) - - return crystal - - def set_device(self, device, clear_fft_cache): - """ - Sets calculation device. - - Parameters - ---------- - device: str - Calculation device will be perfomed on. Must be 'cpu' or 'gpu' - - Returns - -------- - self: PhaseReconstruction - Self to enable chaining - """ - - if clear_fft_cache is not None: - self._clear_fft_cache = clear_fft_cache - - if device is None: - return self - - if device == "cpu": - import scipy - - self._xp = np - self._scipy = scipy - - elif device == "gpu": - from cupyx import scipy - - self._xp = cp - self._scipy = scipy - - else: - raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") - - self._device = device - - return self + # def _make_test_object( + # self, + # sx: int, + # sy: int, + # sz: int, + # sq: int, + # q_max: float, + # r: int, + # num: int, + # ): + # """ + # Make test object with 3D gold cubes at random orientations + + # Parameters + # ---------- + # sx: int + # x size (pixels) + # sy: int + # y size (pixels) + # sz: int + # z size (pixels) + # sq: int + # q size (pixels) + # q_max: float + # maximum scattering angle (A^-1) + # r: int + # length of 3D gold cubes + # num: int + # number of cubes + + # Returns + # -------- + # test_object: np.ndarray + # 6D test object + # """ + # xp_storage = self._xp_storage + # storage = self._storage + + # test_object = xp_storage.zeros((sx, sy, sz, sq, sq, sq)) + + # diffraction_cloud = self._make_diffraction_cloud(sq, q_max, [0, 0, 0]) + + # test_object[:, :, :, 0, 0, 0] = copy_to_device(diffraction_cloud.sum(), storage) + + # for a0 in range(num): + # s1 = xp_storage.random.randint(r, sx - r) + # s2 = xp_storage.random.randint(r, sy - r) + # h = xp_storage.random.randint(r, sz - r, size=1) + # t = xp_storage.random.randint(0, 360, size=3) + + # cloud = copy_to_device(self._make_diffraction_cloud(sq, q_max, t), storage) + + # test_object[s1 - r : s1 + r, s2 - r : s2 + r, h[0] - r : h[0] + r] = cloud + + # return test_object + + # def _forward_simulation( + # self, + # current_object: np.ndarray, + # tilt_deg: int, + # x_index: int, + # num_points: np.ndarray = 60, + # ): + # """ + # Forward projection of object for simulation of diffraction data + + # Parameters + # ---------- + # current_object: np.ndarray + # current object estimate + # tilt_deg: float + # tilt of object in degrees + # x_index: int + # x slice of object to be sliced + # num_points: float + # number of points for bilinear interpolation + + # Returns + # -------- + # current_object_sliced: np.ndarray + # projection of current object sliced in diffraciton space + # """ + # current_object_projected = self._real_space_radon( + # current_object, + # tilt_deg, + # x_index, + # num_points, + # ) + + # current_object_sliced = self._diffraction_space_slice( + # current_object_projected, + # tilt_deg, + # ) + + # return current_object_sliced + + # def _make_diffraction_cloud( + # self, + # sq, + # q_max, + # rot, + # ): + # """ + # Make 3D diffraction cloud + + # Parameters + # ---------- + # sq: int + # q size (pixels) + # q_max: float + # maximum scattering angle (A^-1) + # rot: 3-tuple + # rotation of cloud + + # Returns + # -------- + # diffraction_cloud: np.ndarray + # 3D structure factor + + # """ + # xp = self._xp + + # gold = self._make_gold(q_max) + + # diffraction_cloud = xp.zeros((sq, sq, sq)) + + # q_step = q_max * 2 / (sq - 1) + + # qz = xp.fft.ifftshift(xp.arange(sq) * q_step - q_step * (sq - 1) / 2) + # qx = xp.fft.ifftshift(xp.arange(sq) * q_step - q_step * (sq - 1) / 2) + # qy = xp.fft.ifftshift(xp.arange(sq) * q_step - q_step * (sq - 1) / 2) + + # qxa, qya, qza = xp.meshgrid(qx, qy, qz, indexing="ij") + + # g_vecs = gold.g_vec_all.copy() + # r = R.from_euler("zxz", [rot[0], rot[1], rot[2]]) + # g_vecs = r.as_matrix() @ g_vecs + + # cut_off = 0.1 + + # for a0 in range(gold.g_vec_all.shape[1]): + # bragg_spot = g_vecs[:, a0] + # distance = xp.sqrt( + # (qxa - bragg_spot[0]) ** 2 + # + (qya - bragg_spot[1]) ** 2 + # + (qza - bragg_spot[2]) ** 2 + # ) + + # update_index = distance < cut_off + # update = xp.zeros((distance.shape)) + # update[update_index] = cut_off - distance[update_index] + # update -= xp.min(update) + # update /= xp.sum(update) + # update *= gold.struct_factors_int[a0] + # diffraction_cloud += update + + # return diffraction_cloud + + # def _make_gold( + # self, + # q_max, + # ): + # """ + # Calculate structure factor for gold up to q_max + + # Parameters + # ---------- + # q_max: float + # maximum scattering angle (A^-1) + + # Returns + # -------- + # crystal: Crystal + # gold crystal with structure factor calculated to q_max + + # """ + + # pos = [ + # [0.0, 0.0, 0.0], + # [0.0, 0.5, 0.5], + # [0.5, 0.0, 0.5], + # [0.5, 0.5, 0.0], + # ] + # atom_num = 79 + # a = 4.08 + # cell = a + + # crystal = Crystal(pos, atom_num, cell) + + # crystal.calculate_structure_factors(q_max) + + # return crystal