From 72593092e731ee623bbe50f161a1b0344073dd9c Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Thu, 14 Nov 2024 21:36:12 -0800 Subject: [PATCH] silly fftftreq casting to float64 --- py4DSTEM/process/phase/dpc.py | 8 ++- py4DSTEM/process/phase/parallax.py | 24 ++++---- .../phase/ptychographic_constraints.py | 36 +++++++---- py4DSTEM/process/phase/utils.py | 60 +++++++++++-------- 4 files changed, 82 insertions(+), 46 deletions(-) diff --git a/py4DSTEM/process/phase/dpc.py b/py4DSTEM/process/phase/dpc.py index d35dd9af9..6747905c0 100644 --- a/py4DSTEM/process/phase/dpc.py +++ b/py4DSTEM/process/phase/dpc.py @@ -569,8 +569,12 @@ def _object_butterworth_constraint( Constrained object estimate """ xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1]) + qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0]).astype( + xp.float32 + ) + qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1]).astype( + xp.float32 + ) qya, qxa = xp.meshgrid(qy, qx) qra = xp.sqrt(qxa**2 + qya**2) diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 9255000e4..28a868d9a 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -1841,8 +1841,12 @@ def subpixel_alignment( position_correction_butterworth_q_lowpass is not None or position_correction_butterworth_q_highpass is not None ): - qx = xp.fft.fftfreq(BF_size[0], self._scan_sampling[0]) - qy = xp.fft.fftfreq(BF_size[1], self._scan_sampling[1]) + qx = xp.fft.fftfreq(BF_size[0], self._scan_sampling[0]).astype( + xp.float32 + ) + qy = xp.fft.fftfreq(BF_size[1], self._scan_sampling[1]).astype( + xp.float32 + ) qya, qxa = xp.meshgrid(qy, qx) qra = xp.sqrt(qxa**2 + qya**2) @@ -2042,8 +2046,8 @@ def subpixel_alignment( ] nx, ny = self._recon_BF_subpixel_aligned.shape - kx = xp.fft.fftfreq(nx, d=1) - ky = xp.fft.fftfreq(ny, d=1) + kx = xp.fft.fftfreq(nx, d=1).astype(xp.float32) + ky = xp.fft.fftfreq(ny, d=1).astype(xp.float32) k = xp.fft.fftshift(xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2)) recon_fft = xp.fft.fftshift( @@ -2587,8 +2591,8 @@ def calculate_CTF(alpha_shape, *coefs): ] # FFT coordinates - qx = xp.fft.fftfreq(im_FFT.shape[0], sx) - qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qx = xp.fft.fftfreq(im_FFT.shape[0], sx).astype(xp.float32) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy).astype(xp.float32) qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 alpha_FFT = xp.sqrt(qr2) * self._wavelength @@ -2824,8 +2828,8 @@ def aberration_correct( sy = self._scan_sampling[1] # Fourier coordinates - kx = xp.fft.fftfreq(im.shape[0], sx) - ky = xp.fft.fftfreq(im.shape[1], sy) + kx = xp.fft.fftfreq(im.shape[0], sx).astype(xp.float32) + ky = xp.fft.fftfreq(im.shape[1], sy).astype(xp.float32) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 if use_CTF_fit is None: @@ -2946,8 +2950,8 @@ def depth_section( # Fourier coordinates sx, sy = self._scan_sampling nx, ny = self._recon_BF.shape - kx = xp.fft.fftfreq(nx, sx) - ky = xp.fft.fftfreq(ny, sy) + kx = xp.fft.fftfreq(nx, sx).astype(xp.float32) + ky = xp.fft.fftfreq(ny, sy).astype(xp.float32) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 if use_CTF_fit: diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index b0a2c2dff..ef8566031 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -184,8 +184,12 @@ def _object_butterworth_constraint( Constrained object estimate """ xp = self._xp - qx = xp.fft.fftfreq(current_object.shape[-2], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[-1], self.sampling[1]) + qx = xp.fft.fftfreq(current_object.shape[-2], self.sampling[0]).astype( + xp.float32 + ) + qy = xp.fft.fftfreq(current_object.shape[-1], self.sampling[1]).astype( + xp.float32 + ) qya, qxa = xp.meshgrid(qy, qx) qra = xp.sqrt(qxa**2 + qya**2) @@ -620,9 +624,15 @@ def _object_kz_regularization_constraint( pad_width = ((z_padding, z_padding), (0, 0), (0, 0)) current_object = xp.pad(current_object, pad_width=pad_width, mode="constant") - qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]).astype( + xp.float32 + ) + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]).astype( + xp.float32 + ) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]).astype( + xp.float32 + ) kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] @@ -831,9 +841,15 @@ def _object_butterworth_constraint( Constrained object estimate """ xp = self._xp - qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]) - qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) - qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]).astype( + xp.float32 + ) + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]).astype( + xp.float32 + ) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]).astype( + xp.float32 + ) qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") qra = xp.sqrt(qza**2 + qxa**2 + qya**2) @@ -975,8 +991,8 @@ def _probe_amplitude_constraint( probe_intensity = xp.abs(current_probe) ** 2 current_probe_sum = xp.sum(probe_intensity) - X = xp.fft.fftfreq(current_probe.shape[0])[:, None] - Y = xp.fft.fftfreq(current_probe.shape[1])[None] + X = xp.fft.fftfreq(current_probe.shape[0]).astype(xp.float32)[:, None] + Y = xp.fft.fftfreq(current_probe.shape[1]).astype(xp.float32)[None] r = xp.hypot(X, Y) - relative_radius sigma = np.sqrt(np.pi) / relative_width diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 628722261..5208a4ae3 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1014,8 +1014,8 @@ def return_1D_profile( if pixel_size is None: pixel_size = (1, 1) - x = xp.fft.fftfreq(intensity.shape[0], pixel_size[0]) - y = xp.fft.fftfreq(intensity.shape[1], pixel_size[1]) + x = xp.fft.fftfreq(intensity.shape[0], pixel_size[0]).astype(xp.float32) + y = xp.fft.fftfreq(intensity.shape[1], pixel_size[1]).astype(xp.float32) q = xp.sqrt(x[:, None] ** 2 + y[None, :] ** 2) q = q.ravel() @@ -1094,8 +1094,8 @@ def fourier_rotate_real_volume(array, angle, axes=(0, 1), xp=np): rotation_ax = np.setdiff1d([0, 1, 2], axes)[0] plane_dims = array_shape[axes] - qx = xp.fft.fftfreq(plane_dims[0], 1) - qy = xp.fft.fftfreq(plane_dims[1], 1) + qx = xp.fft.fftfreq(plane_dims[0], 1).astype(xp.float32) + qy = xp.fft.fftfreq(plane_dims[1], 1).astype(xp.float32) qxa, qya = xp.meshgrid(qx, qy, indexing="ij") x = xp.arange(plane_dims[0]) - plane_dims[0] / 2 @@ -1374,11 +1374,11 @@ def polar_to_cartesian_transform_2Ddata( cx, cy = xy_center if corner_centered: - x = xp.fft.fftfreq(sx, d=1 / sx) - y = xp.fft.fftfreq(sy, d=1 / sy) + x = xp.fft.fftfreq(sx, d=1 / sx).astype(xp.float32) + y = xp.fft.fftfreq(sy, d=1 / sy).astype(xp.float32) else: - x = xp.arange(sx) - y = xp.arange(sy) + x = xp.arange(sx, dtype=xp.float32) + y = xp.arange(sy, dtype=xp.float32) xa, ya = xp.meshgrid(x, y, indexing="ij") ra = xp.hypot(xa - cx, ya - cy) @@ -1548,8 +1548,8 @@ def calculate_aberration_gradient_basis( """ """ sx, sy = sampling nx, ny = gpts - qx = xp.fft.fftfreq(nx, sx) - qy = xp.fft.fftfreq(ny, sy) + qx = xp.fft.fftfreq(nx, sx).astype(xp.float32) + qy = xp.fft.fftfreq(ny, sy).astype(xp.float32) qx, qy = xp.meshgrid(qx, qy, indexing="ij") # passive rotation @@ -1653,8 +1653,8 @@ def aberrations_basis_function( dx, dy = probe_sampling wavelength = electron_wavelength_angstrom(energy) - qx = xp.fft.fftfreq(sx, dx) - qy = xp.fft.fftfreq(sy, dy) + qx = xp.fft.fftfreq(sx, dx).astype(xp.float32) + qy = xp.fft.fftfreq(sy, dy).astype(xp.float32) qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 alpha = xp.sqrt(qr2) * wavelength theta = xp.arctan2(qy[None, :], qx[:, None]) @@ -2155,8 +2155,12 @@ def pixel_rolling_kernel_density_estimate( if lowpass_filter: pix_fft = xp.fft.fft2(pix_output) - pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None] - pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None] + pix_fft /= xp.sinc( + xp.fft.fftfreq(pix_output.shape[0], d=1.0).astype(xp.float32) + )[:, None] + pix_fft /= xp.sinc( + xp.fft.fftfreq(pix_output.shape[1], d=1.0).astype(xp.float32) + )[None] pix_output = xp.real(xp.fft.ifft2(pix_fft)) return pix_output @@ -2262,8 +2266,12 @@ def bilinear_kernel_density_estimate( if lowpass_filter: pix_fft = xp.fft.fft2(pix_output) - pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None] - pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None] + pix_fft /= xp.sinc( + xp.fft.fftfreq(pix_output.shape[0], d=1.0).astype(xp.float32) + )[:, None] + pix_fft /= xp.sinc( + xp.fft.fftfreq(pix_output.shape[1], d=1.0).astype(xp.float32) + )[None] pix_output = xp.real(xp.fft.ifft2(pix_fft)) return pix_output @@ -2364,8 +2372,12 @@ def lanczos_kernel_density_estimate( if lowpass_filter: pix_fft = xp.fft.fft2(pix_output) - pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None] - pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None] + pix_fft /= xp.sinc( + xp.fft.fftfreq(pix_output.shape[0], d=1.0).astype(xp.float32) + )[:, None] + pix_fft /= xp.sinc( + xp.fft.fftfreq(pix_output.shape[1], d=1.0).astype(xp.float32) + )[None] pix_output = xp.real(xp.fft.ifft2(pix_fft)) return pix_output @@ -2898,8 +2910,8 @@ def fourier_rotation_best_shears_combination(tf): def fourier_shear_Sx(array, a, b, xp=np): """ """ Nx, Ny, Nz = array.shape - nx = xp.arange(Nx) - (Nx - 1) / 2 - ny, nz = tuple(xp.fft.fftfreq(N, 1 / N) for N in [Ny, Nz]) + nx = xp.arange(Nx, dtype=xp.float32) - (Nx - 1) / 2 + ny, nz = tuple(xp.fft.fftfreq(N, 1 / N).astype(xp.float32) for N in [Ny, Nz]) nxa, nya, nza = xp.meshgrid(nx, ny, nz, indexing="ij") phase_shift = xp.exp(-2j * np.pi * (a * nya + b * nza) * nxa / Nx) @@ -2910,8 +2922,8 @@ def fourier_shear_Sx(array, a, b, xp=np): def fourier_shear_Sy(array, a, b, xp=np): """ """ Nx, Ny, Nz = array.shape - ny = xp.arange(Ny) - (Ny - 1) / 2 - nx, nz = tuple(xp.fft.fftfreq(N, 1 / N) for N in [Nx, Nz]) + ny = xp.arange(Ny, dtype=xp.float32) - (Ny - 1) / 2 + nx, nz = tuple(xp.fft.fftfreq(N, 1 / N).astype(xp.float32) for N in [Nx, Nz]) nxa, nya, nza = xp.meshgrid(nx, ny, nz, indexing="ij") phase_shift = xp.exp(-2j * np.pi * (a * nza + b * nxa) * nya / Ny) @@ -2922,8 +2934,8 @@ def fourier_shear_Sy(array, a, b, xp=np): def fourier_shear_Sz(array, a, b, xp=np): """ """ Nx, Ny, Nz = array.shape - nz = xp.arange(Nz) - (Nz - 1) / 2 - nx, ny = tuple(xp.fft.fftfreq(N, 1 / N) for N in [Nx, Ny]) + nz = xp.arange(Nz, dtype=xp.float32) - (Nz - 1) / 2 + nx, ny = tuple(xp.fft.fftfreq(N, 1 / N).astype(xp.float32) for N in [Nx, Ny]) nxa, nya, nza = xp.meshgrid(nx, ny, nz, indexing="ij") phase_shift = xp.exp(-2j * np.pi * (a * nxa + b * nya) * nza / Nz)