Skip to content

Commit

Permalink
silly fftftreq casting to float64
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 15, 2024
1 parent 2a41047 commit 7259309
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 46 deletions.
8 changes: 6 additions & 2 deletions py4DSTEM/process/phase/dpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,12 @@ def _object_butterworth_constraint(
Constrained object estimate
"""
xp = self._xp
qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0])
qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1])
qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0]).astype(
xp.float32
)
qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1]).astype(
xp.float32
)

qya, qxa = xp.meshgrid(qy, qx)
qra = xp.sqrt(qxa**2 + qya**2)
Expand Down
24 changes: 14 additions & 10 deletions py4DSTEM/process/phase/parallax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,8 +1841,12 @@ def subpixel_alignment(
position_correction_butterworth_q_lowpass is not None
or position_correction_butterworth_q_highpass is not None
):
qx = xp.fft.fftfreq(BF_size[0], self._scan_sampling[0])
qy = xp.fft.fftfreq(BF_size[1], self._scan_sampling[1])
qx = xp.fft.fftfreq(BF_size[0], self._scan_sampling[0]).astype(
xp.float32
)
qy = xp.fft.fftfreq(BF_size[1], self._scan_sampling[1]).astype(
xp.float32
)

qya, qxa = xp.meshgrid(qy, qx)
qra = xp.sqrt(qxa**2 + qya**2)
Expand Down Expand Up @@ -2042,8 +2046,8 @@ def subpixel_alignment(
]

nx, ny = self._recon_BF_subpixel_aligned.shape
kx = xp.fft.fftfreq(nx, d=1)
ky = xp.fft.fftfreq(ny, d=1)
kx = xp.fft.fftfreq(nx, d=1).astype(xp.float32)
ky = xp.fft.fftfreq(ny, d=1).astype(xp.float32)
k = xp.fft.fftshift(xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2))

recon_fft = xp.fft.fftshift(
Expand Down Expand Up @@ -2587,8 +2591,8 @@ def calculate_CTF(alpha_shape, *coefs):
]

# FFT coordinates
qx = xp.fft.fftfreq(im_FFT.shape[0], sx)
qy = xp.fft.fftfreq(im_FFT.shape[1], sy)
qx = xp.fft.fftfreq(im_FFT.shape[0], sx).astype(xp.float32)
qy = xp.fft.fftfreq(im_FFT.shape[1], sy).astype(xp.float32)
qr2 = qx[:, None] ** 2 + qy[None, :] ** 2

alpha_FFT = xp.sqrt(qr2) * self._wavelength
Expand Down Expand Up @@ -2824,8 +2828,8 @@ def aberration_correct(
sy = self._scan_sampling[1]

# Fourier coordinates
kx = xp.fft.fftfreq(im.shape[0], sx)
ky = xp.fft.fftfreq(im.shape[1], sy)
kx = xp.fft.fftfreq(im.shape[0], sx).astype(xp.float32)
ky = xp.fft.fftfreq(im.shape[1], sy).astype(xp.float32)
kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2

if use_CTF_fit is None:
Expand Down Expand Up @@ -2946,8 +2950,8 @@ def depth_section(
# Fourier coordinates
sx, sy = self._scan_sampling
nx, ny = self._recon_BF.shape
kx = xp.fft.fftfreq(nx, sx)
ky = xp.fft.fftfreq(ny, sy)
kx = xp.fft.fftfreq(nx, sx).astype(xp.float32)
ky = xp.fft.fftfreq(ny, sy).astype(xp.float32)
kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2

if use_CTF_fit:
Expand Down
36 changes: 26 additions & 10 deletions py4DSTEM/process/phase/ptychographic_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,12 @@ def _object_butterworth_constraint(
Constrained object estimate
"""
xp = self._xp
qx = xp.fft.fftfreq(current_object.shape[-2], self.sampling[0])
qy = xp.fft.fftfreq(current_object.shape[-1], self.sampling[1])
qx = xp.fft.fftfreq(current_object.shape[-2], self.sampling[0]).astype(
xp.float32
)
qy = xp.fft.fftfreq(current_object.shape[-1], self.sampling[1]).astype(
xp.float32
)

qya, qxa = xp.meshgrid(qy, qx)
qra = xp.sqrt(qxa**2 + qya**2)
Expand Down Expand Up @@ -620,9 +624,15 @@ def _object_kz_regularization_constraint(
pad_width = ((z_padding, z_padding), (0, 0), (0, 0))
current_object = xp.pad(current_object, pad_width=pad_width, mode="constant")

qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0])
qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]).astype(
xp.float32
)
qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]).astype(
xp.float32
)
qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]).astype(
xp.float32
)

kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0]

Expand Down Expand Up @@ -831,9 +841,15 @@ def _object_butterworth_constraint(
Constrained object estimate
"""
xp = self._xp
qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1])
qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1]).astype(
xp.float32
)
qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]).astype(
xp.float32
)
qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]).astype(
xp.float32
)
qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij")
qra = xp.sqrt(qza**2 + qxa**2 + qya**2)

Expand Down Expand Up @@ -975,8 +991,8 @@ def _probe_amplitude_constraint(
probe_intensity = xp.abs(current_probe) ** 2
current_probe_sum = xp.sum(probe_intensity)

X = xp.fft.fftfreq(current_probe.shape[0])[:, None]
Y = xp.fft.fftfreq(current_probe.shape[1])[None]
X = xp.fft.fftfreq(current_probe.shape[0]).astype(xp.float32)[:, None]
Y = xp.fft.fftfreq(current_probe.shape[1]).astype(xp.float32)[None]
r = xp.hypot(X, Y) - relative_radius

sigma = np.sqrt(np.pi) / relative_width
Expand Down
60 changes: 36 additions & 24 deletions py4DSTEM/process/phase/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,8 +1014,8 @@ def return_1D_profile(
if pixel_size is None:
pixel_size = (1, 1)

x = xp.fft.fftfreq(intensity.shape[0], pixel_size[0])
y = xp.fft.fftfreq(intensity.shape[1], pixel_size[1])
x = xp.fft.fftfreq(intensity.shape[0], pixel_size[0]).astype(xp.float32)
y = xp.fft.fftfreq(intensity.shape[1], pixel_size[1]).astype(xp.float32)
q = xp.sqrt(x[:, None] ** 2 + y[None, :] ** 2)
q = q.ravel()

Expand Down Expand Up @@ -1094,8 +1094,8 @@ def fourier_rotate_real_volume(array, angle, axes=(0, 1), xp=np):
rotation_ax = np.setdiff1d([0, 1, 2], axes)[0]
plane_dims = array_shape[axes]

qx = xp.fft.fftfreq(plane_dims[0], 1)
qy = xp.fft.fftfreq(plane_dims[1], 1)
qx = xp.fft.fftfreq(plane_dims[0], 1).astype(xp.float32)
qy = xp.fft.fftfreq(plane_dims[1], 1).astype(xp.float32)
qxa, qya = xp.meshgrid(qx, qy, indexing="ij")

x = xp.arange(plane_dims[0]) - plane_dims[0] / 2
Expand Down Expand Up @@ -1374,11 +1374,11 @@ def polar_to_cartesian_transform_2Ddata(
cx, cy = xy_center

if corner_centered:
x = xp.fft.fftfreq(sx, d=1 / sx)
y = xp.fft.fftfreq(sy, d=1 / sy)
x = xp.fft.fftfreq(sx, d=1 / sx).astype(xp.float32)
y = xp.fft.fftfreq(sy, d=1 / sy).astype(xp.float32)
else:
x = xp.arange(sx)
y = xp.arange(sy)
x = xp.arange(sx, dtype=xp.float32)
y = xp.arange(sy, dtype=xp.float32)

xa, ya = xp.meshgrid(x, y, indexing="ij")
ra = xp.hypot(xa - cx, ya - cy)
Expand Down Expand Up @@ -1548,8 +1548,8 @@ def calculate_aberration_gradient_basis(
""" """
sx, sy = sampling
nx, ny = gpts
qx = xp.fft.fftfreq(nx, sx)
qy = xp.fft.fftfreq(ny, sy)
qx = xp.fft.fftfreq(nx, sx).astype(xp.float32)
qy = xp.fft.fftfreq(ny, sy).astype(xp.float32)
qx, qy = xp.meshgrid(qx, qy, indexing="ij")

# passive rotation
Expand Down Expand Up @@ -1653,8 +1653,8 @@ def aberrations_basis_function(
dx, dy = probe_sampling
wavelength = electron_wavelength_angstrom(energy)

qx = xp.fft.fftfreq(sx, dx)
qy = xp.fft.fftfreq(sy, dy)
qx = xp.fft.fftfreq(sx, dx).astype(xp.float32)
qy = xp.fft.fftfreq(sy, dy).astype(xp.float32)
qr2 = qx[:, None] ** 2 + qy[None, :] ** 2
alpha = xp.sqrt(qr2) * wavelength
theta = xp.arctan2(qy[None, :], qx[:, None])
Expand Down Expand Up @@ -2155,8 +2155,12 @@ def pixel_rolling_kernel_density_estimate(

if lowpass_filter:
pix_fft = xp.fft.fft2(pix_output)
pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None]
pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None]
pix_fft /= xp.sinc(
xp.fft.fftfreq(pix_output.shape[0], d=1.0).astype(xp.float32)
)[:, None]
pix_fft /= xp.sinc(
xp.fft.fftfreq(pix_output.shape[1], d=1.0).astype(xp.float32)
)[None]
pix_output = xp.real(xp.fft.ifft2(pix_fft))

return pix_output
Expand Down Expand Up @@ -2262,8 +2266,12 @@ def bilinear_kernel_density_estimate(

if lowpass_filter:
pix_fft = xp.fft.fft2(pix_output)
pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None]
pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None]
pix_fft /= xp.sinc(
xp.fft.fftfreq(pix_output.shape[0], d=1.0).astype(xp.float32)
)[:, None]
pix_fft /= xp.sinc(
xp.fft.fftfreq(pix_output.shape[1], d=1.0).astype(xp.float32)
)[None]
pix_output = xp.real(xp.fft.ifft2(pix_fft))

return pix_output
Expand Down Expand Up @@ -2364,8 +2372,12 @@ def lanczos_kernel_density_estimate(

if lowpass_filter:
pix_fft = xp.fft.fft2(pix_output)
pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[0], d=1.0))[:, None]
pix_fft /= xp.sinc(xp.fft.fftfreq(pix_output.shape[1], d=1.0))[None]
pix_fft /= xp.sinc(
xp.fft.fftfreq(pix_output.shape[0], d=1.0).astype(xp.float32)
)[:, None]
pix_fft /= xp.sinc(
xp.fft.fftfreq(pix_output.shape[1], d=1.0).astype(xp.float32)
)[None]
pix_output = xp.real(xp.fft.ifft2(pix_fft))

return pix_output
Expand Down Expand Up @@ -2898,8 +2910,8 @@ def fourier_rotation_best_shears_combination(tf):
def fourier_shear_Sx(array, a, b, xp=np):
""" """
Nx, Ny, Nz = array.shape
nx = xp.arange(Nx) - (Nx - 1) / 2
ny, nz = tuple(xp.fft.fftfreq(N, 1 / N) for N in [Ny, Nz])
nx = xp.arange(Nx, dtype=xp.float32) - (Nx - 1) / 2
ny, nz = tuple(xp.fft.fftfreq(N, 1 / N).astype(xp.float32) for N in [Ny, Nz])
nxa, nya, nza = xp.meshgrid(nx, ny, nz, indexing="ij")
phase_shift = xp.exp(-2j * np.pi * (a * nya + b * nza) * nxa / Nx)

Expand All @@ -2910,8 +2922,8 @@ def fourier_shear_Sx(array, a, b, xp=np):
def fourier_shear_Sy(array, a, b, xp=np):
""" """
Nx, Ny, Nz = array.shape
ny = xp.arange(Ny) - (Ny - 1) / 2
nx, nz = tuple(xp.fft.fftfreq(N, 1 / N) for N in [Nx, Nz])
ny = xp.arange(Ny, dtype=xp.float32) - (Ny - 1) / 2
nx, nz = tuple(xp.fft.fftfreq(N, 1 / N).astype(xp.float32) for N in [Nx, Nz])
nxa, nya, nza = xp.meshgrid(nx, ny, nz, indexing="ij")
phase_shift = xp.exp(-2j * np.pi * (a * nza + b * nxa) * nya / Ny)

Expand All @@ -2922,8 +2934,8 @@ def fourier_shear_Sy(array, a, b, xp=np):
def fourier_shear_Sz(array, a, b, xp=np):
""" """
Nx, Ny, Nz = array.shape
nz = xp.arange(Nz) - (Nz - 1) / 2
nx, ny = tuple(xp.fft.fftfreq(N, 1 / N) for N in [Nx, Ny])
nz = xp.arange(Nz, dtype=xp.float32) - (Nz - 1) / 2
nx, ny = tuple(xp.fft.fftfreq(N, 1 / N).astype(xp.float32) for N in [Nx, Ny])
nxa, nya, nza = xp.meshgrid(nx, ny, nz, indexing="ij")
phase_shift = xp.exp(-2j * np.pi * (a * nxa + b * nya) * nza / Nz)

Expand Down

0 comments on commit 7259309

Please sign in to comment.