From e2b4191d3603d5227f4273d019c6cd07aa1761f7 Mon Sep 17 00:00:00 2001 From: Georgios Varnavides Date: Mon, 4 Nov 2024 15:54:32 -0800 Subject: [PATCH] multi-processing tweaks --- py4DSTEM/process/phase/direct_ptychography.py | 398 ++++++++++-------- 1 file changed, 225 insertions(+), 173 deletions(-) diff --git a/py4DSTEM/process/phase/direct_ptychography.py b/py4DSTEM/process/phase/direct_ptychography.py index a8ae08b0c..183ce4916 100644 --- a/py4DSTEM/process/phase/direct_ptychography.py +++ b/py4DSTEM/process/phase/direct_ptychography.py @@ -3,6 +3,7 @@ namely single-sideband and Wigner-distribution deconvolution. """ +import itertools import warnings from typing import Mapping, Sequence, Union @@ -1452,24 +1453,43 @@ def __init__(self, *args, **kwargs): def _reconstruct_single_frequency( self, - intensities_FFT, - Qx, - Qy, - Kx, - Ky, - probe, - probe_conj, - aperture, - probe_kwargs, - phase_compensation: bool = True, - virtual_detector_masks: Sequence[np.ndarray] = None, + shared_objects, + ind_x, + ind_y, + phase_compensation=True, + virtual_detector_masks=None, xp=np, ): """ """ + ( + input_array, + output_array, + Qx_array, + Qy_array, + Kx, + Ky, + probe, + probe_conj, + aperture, + probe_kwargs, + ) = shared_objects + + sx, sy = Qx_array.shape + + # 2 stride is for complex values + ind_real = ind_x * sy * 2 + ind_y * 2 + 0 + ind_imag = ind_x * sy * 2 + ind_y * 2 + 1 + + intensities_FFT = input_array[ind_x, ind_y] + Qx = Qx_array[ind_x, ind_y] + Qy = Qy_array[ind_x, ind_y] + threshold = 1e-3 G = xp.asarray(intensities_FFT) - if Qx == 0.0 and Qy == 0.0: - return xp.abs(G).sum() + if ind_x == 0 and ind_y == 0: + val = xp.abs(G).sum() + output_array[ind_real] = val.real + output_array[ind_imag] = val.imag else: Kx_plus_Qx = Kx + Qx Ky_plus_Qy = Ky + Qy @@ -1512,7 +1532,9 @@ def _reconstruct_single_frequency( gamma_ind = gamma_abs > threshold normalization = gamma_abs[gamma_ind] - return (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum() + val = (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum() + output_array[ind_real] = val.real + output_array[ind_imag] = val.imag else: aperture_plus = aperture_plus > threshold @@ -1521,13 +1543,15 @@ def _reconstruct_single_frequency( aperture_solo = xp.logical_and( xp.logical_and(aperture, aperture_minus), ~aperture_plus ) - return G[aperture_solo].sum() * 2 # factor of 2 since using single band + + val = G[aperture_solo].sum() * 2 # factor of 2 since using single band + output_array[ind_real] = val.real + output_array[ind_imag] = val.imag def reconstruct( self, phase_compensation=True, num_jobs=None, - threads_per_job=None, virtual_detector_masks: Sequence[np.ndarray] = None, progress_bar: bool = True, polar_parameters: Mapping[str, float] = None, @@ -1570,7 +1594,7 @@ def reconstruct( ) sx, sy = self._grid_scan_shape - psi = xp.empty((sx, sy), dtype=xp.complex64) + probe = self._fourier_probe probe_conj = xp.conj(self._fourier_probe) aperture = xp.abs(self._fourier_probe) > 1e-3 @@ -1586,11 +1610,31 @@ def reconstruct( } Kx, Ky = self._spatial_frequencies - Qx, Qy = self._scan_frequencies + Qx_array, Qy_array = self._scan_frequencies if virtual_detector_masks is not None: virtual_detector_masks = xp.asarray(virtual_detector_masks).astype(xp.bool_) + if num_jobs == 1: + output_array = xp.empty(sx * sy * 2, dtype=xp.float32) + else: + from multiprocessing import Array as mp_Array + + output_array = mp_Array("f", sx * sy * 2, lock=False) + + shared_objects = ( + self._intensities_FFT, + output_array, + Qx_array, + Qy_array, + Kx, + Ky, + probe, + probe_conj, + aperture, + probe_kwargs, + ) + # main loop if num_jobs == 1: @@ -1601,63 +1645,41 @@ def reconstruct( unit="freq.", disable=not progress_bar, ): - psi[ind_x, ind_y] = self._reconstruct_single_frequency( - self._intensities_FFT[ind_x, ind_y], - Qx[ind_x, ind_y], - Qy[ind_x, ind_y], - Kx, - Ky, - self._fourier_probe, - probe_conj, - aperture, - probe_kwargs, + self._reconstruct_single_frequency( + shared_objects, + ind_x, + ind_y, phase_compensation=phase_compensation, virtual_detector_masks=virtual_detector_masks, xp=xp, ) + + psi = output_array.view(xp.complex64).reshape((sx, sy)) else: if self._device == "gpu": raise NotImplementedError() from mpire import WorkerPool, cpu_count - from threadpoolctl import threadpool_limits num_jobs = num_jobs or cpu_count() - if threads_per_job is not None: - num_jobs = num_jobs // threads_per_job - - map_inputs = [ - { - "intensities_FFT": self._intensities_FFT[ind_x, ind_y], - "Qx": Qx[ind_x, ind_y], - "Qy": Qy[ind_x, ind_y], - } - for ind_x in range(sx) - for ind_y in range(sy) - ] - def wrapper_function(**kwargs): - with threadpool_limits(limits=threads_per_job): - return self._reconstruct_single_frequency( - **kwargs, - Kx=Kx, - Ky=Ky, - probe=self._fourier_probe, - probe_conj=probe_conj, - aperture=aperture, - probe_kwargs=probe_kwargs, - phase_compensation=phase_compensation, - virtual_detector_masks=virtual_detector_masks, - xp=xp, - ) - - with WorkerPool(n_jobs=num_jobs) as pool: - flat_results = pool.map( - wrapper_function, map_inputs, progress_bar=progress_bar + def wrapper_function(*args): + return self._reconstruct_single_frequency( + *args, + phase_compensation=phase_compensation, + virtual_detector_masks=virtual_detector_masks, + xp=xp, ) - for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results): - psi[ind_x, ind_y] = res + with WorkerPool(n_jobs=num_jobs, shared_objects=shared_objects) as pool: + pool.map( + wrapper_function, + itertools.product(range(sx), range(sy)), + iterable_len=sx * sy, + n_splits=num_jobs, + progress_bar=progress_bar, + ) + psi = xp.frombuffer(output_array, dtype=xp.complex64).reshape((sx, sy)) self._object = xp.fft.ifft2(psi) / self._mean_diffraction_intensity @@ -1683,24 +1705,42 @@ def __init__(self, *args, **kwargs): def _reconstruct_single_frequency( self, - intensities_FFT, - Qx, - Qy, - Kx, - Ky, - probe, - probe_conj, - probe_normalization, - probe_kwargs, + shared_objects, + ind_x, + ind_y, virtual_detector_masks: Sequence[np.ndarray] = None, xp=np, ): """ """ - threshold = 1e-3 + ( + input_array, + output_array, + Qx_array, + Qy_array, + Kx, + Ky, + probe, + probe_conj, + probe_normalization, + probe_kwargs, + ) = shared_objects + + sx, sy = Qx_array.shape + + # 2 stride is for complex values + ind_real = ind_x * sy * 2 + ind_y * 2 + 0 + ind_imag = ind_x * sy * 2 + ind_y * 2 + 1 + + intensities_FFT = input_array[ind_x, ind_y] + Qx = Qx_array[ind_x, ind_y] + Qy = Qy_array[ind_x, ind_y] + threshold = 1e-3 G = xp.asarray(intensities_FFT) - if Qx == 0.0 and Qy == 0.0: - return xp.abs(G).sum() + if ind_x == 0 and ind_y == 0: + val = xp.abs(G).sum() + output_array[ind_real] = val.real + output_array[ind_imag] = val.imag else: Kx_plus_Qx = Kx + Qx Ky_plus_Qy = Ky + Qy @@ -1732,12 +1772,13 @@ def _reconstruct_single_frequency( d = probe_normalization[gamma_ind] normalization = d * xp.sqrt(xp.sum(normalization**2 / d)) - return (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum() + val = (G[gamma_ind] * gamma[gamma_ind].conj() / normalization).sum() + output_array[ind_real] = val.real + output_array[ind_imag] = val.imag def reconstruct( self, num_jobs=None, - threads_per_job=None, virtual_detector_masks: Sequence[np.ndarray] = None, progress_bar: bool = True, polar_parameters: Mapping[str, float] = None, @@ -1778,8 +1819,9 @@ def reconstruct( ) sx, sy = self._grid_scan_shape - psi = xp.empty((sx, sy), dtype=xp.complex64) + probe = self._fourier_probe probe_conj = xp.conj(self._fourier_probe) + probe_kwargs = { "energy": self._energy, "gpts": self._intensities_shape, @@ -1792,7 +1834,7 @@ def reconstruct( } Kx, Ky = self._spatial_frequencies - Qx, Qy = self._scan_frequencies + Qx_array, Qy_array = self._scan_frequencies probe_normalization = xp.abs(self._fourier_probe) ** 2 probe_normalization /= probe_normalization.sum() @@ -1803,6 +1845,26 @@ def reconstruct( probe_normalization, virtual_detector_masks, in_place=True ) + if num_jobs == 1: + output_array = xp.empty(sx * sy * 2, dtype=xp.float32) + else: + from multiprocessing import Array as mp_Array + + output_array = mp_Array("f", sx * sy * 2, lock=False) + + shared_objects = ( + self._intensities_FFT, + output_array, + Qx_array, + Qy_array, + Kx, + Ky, + probe, + probe_conj, + probe_normalization, + probe_kwargs, + ) + # main loop if num_jobs == 1: @@ -1813,61 +1875,39 @@ def reconstruct( unit="freq.", disable=not progress_bar, ): - psi[ind_x, ind_y] = self._reconstruct_single_frequency( - self._intensities_FFT[ind_x, ind_y], - Qx[ind_x, ind_y], - Qy[ind_x, ind_y], - Kx, - Ky, - self._fourier_probe, - probe_conj, - probe_normalization, - probe_kwargs, + self._reconstruct_single_frequency( + shared_objects, + ind_x, + ind_y, virtual_detector_masks=virtual_detector_masks, xp=xp, ) + + psi = output_array.view(xp.complex64).reshape((sx, sy)) else: if self._device == "gpu": raise NotImplementedError() from mpire import WorkerPool, cpu_count - from threadpoolctl import threadpool_limits num_jobs = num_jobs or cpu_count() - if threads_per_job is not None: - num_jobs = num_jobs // threads_per_job - - map_inputs = [ - { - "intensities_FFT": self._intensities_FFT[ind_x, ind_y], - "Qx": Qx[ind_x, ind_y], - "Qy": Qy[ind_x, ind_y], - } - for ind_x in range(sx) - for ind_y in range(sy) - ] - - def wrapper_function(**kwargs): - with threadpool_limits(limits=threads_per_job): - return self._reconstruct_single_frequency( - **kwargs, - Kx=Kx, - Ky=Ky, - probe=self._fourier_probe, - probe_conj=probe_conj, - probe_normalization=probe_normalization, - probe_kwargs=probe_kwargs, - virtual_detector_masks=virtual_detector_masks, - xp=xp, - ) - with WorkerPool(n_jobs=num_jobs) as pool: - flat_results = pool.map( - wrapper_function, map_inputs, progress_bar=progress_bar + def wrapper_function(*args): + return self._reconstruct_single_frequency( + *args, + virtual_detector_masks=virtual_detector_masks, + xp=xp, ) - for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results): - psi[ind_x, ind_y] = res + with WorkerPool(n_jobs=num_jobs, shared_objects=shared_objects) as pool: + pool.map( + wrapper_function, + itertools.product(range(sx), range(sy)), + iterable_len=sx * sy, + n_splits=num_jobs, + progress_bar=progress_bar, + ) + psi = xp.frombuffer(output_array, dtype=xp.complex64).reshape((sx, sy)) self._object = xp.fft.ifft2(psi) / self._mean_diffraction_intensity # no idea why this is necessary.. @@ -1889,17 +1929,33 @@ def __init__(self, *args, **kwargs): def _reconstruct_single_frequency( self, - intensities_FFT, - Qx, - Qy, - Kx, - Ky, - probe, - epsilon, - probe_kwargs, + shared_objects, + ind_x, + ind_y, xp=np, ): """ """ + ( + input_array, + output_array, + Qx_array, + Qy_array, + Kx, + Ky, + probe, + epsilon, + probe_kwargs, + ) = shared_objects + + sx, sy = Qx_array.shape + + # 2 stride is for complex values + ind_real = ind_x * sy * 2 + ind_y * 2 + 0 + ind_imag = ind_x * sy * 2 + ind_y * 2 + 1 + + intensities_FFT = input_array[ind_x, ind_y] + Qx = Qx_array[ind_x, ind_y] + Qy = Qy_array[ind_x, ind_y] array_G = xp.asarray(intensities_FFT) array_H = xp.fft.ifft2(array_G) @@ -1918,14 +1974,14 @@ def _reconstruct_single_frequency( array_D = wdd_probe_conj * array_H / (wdd_probe * wdd_probe_conj + epsilon) array_D_FFT = xp.fft.fft2(array_D) - return array_D_FFT[0, 0] + val = array_D_FFT[0, 0] + output_array[ind_real] = val.real + output_array[ind_imag] = val.imag def reconstruct( self, relative_wiener_epsilon, num_jobs=None, - threads_per_job=None, - virtual_detector_masks: Sequence[np.ndarray] = None, progress_bar: bool = True, polar_parameters: Mapping[str, float] = None, device: str = None, @@ -1962,7 +2018,7 @@ def reconstruct( ) sx, sy = self._grid_scan_shape - psi = xp.empty((sx, sy), dtype=xp.complex64) + probe = self._fourier_probe wdd_probe_0 = xp.fft.ifft2(self._fourier_probe * self._fourier_probe.conj()) wdd_probe_0 = xp.abs(wdd_probe_0[0, 0]) epsilon = relative_wiener_epsilon * wdd_probe_0 @@ -1979,10 +2035,26 @@ def reconstruct( } Kx, Ky = self._spatial_frequencies - Qx, Qy = self._scan_frequencies + Qx_array, Qy_array = self._scan_frequencies - if virtual_detector_masks is not None: - virtual_detector_masks = xp.asarray(virtual_detector_masks).astype(xp.bool_) + if num_jobs == 1: + output_array = xp.empty(sx * sy * 2, dtype=xp.float32) + else: + from multiprocessing import Array as mp_Array + + output_array = mp_Array("f", sx * sy * 2, lock=False) + + shared_objects = ( + self._intensities_FFT, + output_array, + Qx_array, + Qy_array, + Kx, + Ky, + probe, + epsilon, + probe_kwargs, + ) # main loop @@ -1994,57 +2066,37 @@ def reconstruct( unit="freq.", disable=not progress_bar, ): - psi[ind_x, ind_y] = self._reconstruct_single_frequency( - self._intensities_FFT[ind_x, ind_y], - Qx[ind_x, ind_y], - Qy[ind_x, ind_y], - Kx, - Ky, - self._fourier_probe, - epsilon, - probe_kwargs, + self._reconstruct_single_frequency( + shared_objects, + ind_x, + ind_y, xp=xp, ) + + psi = output_array.view(xp.complex64).reshape((sx, sy)) else: if self._device == "gpu": raise NotImplementedError() from mpire import WorkerPool, cpu_count - from threadpoolctl import threadpool_limits num_jobs = num_jobs or cpu_count() - if threads_per_job is not None: - num_jobs = num_jobs // threads_per_job - - map_inputs = [ - { - "intensities_FFT": self._intensities_FFT[ind_x, ind_y], - "Qx": Qx[ind_x, ind_y], - "Qy": Qy[ind_x, ind_y], - } - for ind_x in range(sx) - for ind_y in range(sy) - ] - - def wrapper_function(**kwargs): - with threadpool_limits(limits=threads_per_job): - return self._reconstruct_single_frequency( - **kwargs, - Kx=Kx, - Ky=Ky, - probe=self._fourier_probe, - epsilon=epsilon, - probe_kwargs=probe_kwargs, - xp=xp, - ) - with WorkerPool(n_jobs=num_jobs) as pool: - flat_results = pool.map( - wrapper_function, map_inputs, progress_bar=progress_bar + def wrapper_function(*args): + return self._reconstruct_single_frequency( + *args, + xp=xp, ) - for (ind_x, ind_y), res in zip(np.ndindex((sx, sy)), flat_results): - psi[ind_x, ind_y] = res + with WorkerPool(n_jobs=num_jobs, shared_objects=shared_objects) as pool: + pool.map( + wrapper_function, + itertools.product(range(sx), range(sy)), + iterable_len=sx * sy, + n_splits=num_jobs, + progress_bar=progress_bar, + ) + psi = xp.frombuffer(output_array, dtype=xp.complex64).reshape((sx, sy)) normalization = xp.abs(psi[0, 0]) / sx / sy psi /= normalization