diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index d5df63f5e..a6e0a6557 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -2,9 +2,10 @@ from emdfile import tqdmnd -### io +### Utility functions +from py4DSTEM.utils import * -# substructure +### IO substructure from emdfile import ( Node, Root, @@ -15,17 +16,13 @@ Custom, print_h5_tree, ) - _emd_hook = True -# structure +# IO structure from py4DSTEM import io from py4DSTEM.io import import_file, read, save - -### basic data classes - -# data +### Basic data classes from py4DSTEM.data import ( Data, Calibration, @@ -34,60 +31,29 @@ QPoints, ) -# datacube -from py4DSTEM.datacube import DataCube, VirtualImage, VirtualDiffraction - - -### visualization - +### Visualization from py4DSTEM import visualize from py4DSTEM.visualize import show, show_complex -### analysis classes - -# braggvectors -from py4DSTEM.braggvectors import ( - Probe, - BraggVectors, - BraggVectorMap, -) - +# Analysis classes +from py4DSTEM.datacube import DataCube +from py4DSTEM.datacube import VirtualImage, VirtualDiffraction +from py4DSTEM.datacube.diskdetection import Probe +from py4DSTEM.braggvectors import BraggVectors, BraggVectorMap from py4DSTEM.process import classification - - -# diffraction from py4DSTEM.process.diffraction import Crystal, Orientation - - -# ptycho from py4DSTEM.process import phase - - -# polar from py4DSTEM.process.polar import PolarDatacube - - -# strain from py4DSTEM.process.strain.strain import StrainMap - from py4DSTEM.process import wholepatternfit -### more submodules -# TODO - -from py4DSTEM import preprocess -from py4DSTEM import process - - -### utilities - -# config +### Config from py4DSTEM.utils.configuration_checker import check_config - # TODO - config .toml # testing from os.path import dirname, join _TESTPATH = join(dirname(__file__), "../test/unit_test_data") + diff --git a/py4DSTEM/braggvectors/__init__.py b/py4DSTEM/braggvectors/__init__.py index 482b1f31e..0556f3583 100644 --- a/py4DSTEM/braggvectors/__init__.py +++ b/py4DSTEM/braggvectors/__init__.py @@ -1,8 +1,5 @@ -from py4DSTEM.braggvectors.probe import Probe from py4DSTEM.braggvectors.braggvectors import BraggVectors from py4DSTEM.braggvectors.braggvector_methods import BraggVectorMap -from py4DSTEM.braggvectors.diskdetection import * -from py4DSTEM.braggvectors.probe import * # from .diskdetection_aiml import * # from .diskdetection_parallel_new import * diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 30ead79a6..30fe77f17 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -2,13 +2,15 @@ from __future__ import annotations import inspect from warnings import warn - +from scipy.ndimage import gaussian_filter import matplotlib.pyplot as plt import numpy as np + from emdfile import Array, Metadata, _read_metadata, tqdmnd from py4DSTEM import show +from py4DSTEM.utils import get_CoM from py4DSTEM.datacube import VirtualImage -from scipy.ndimage import gaussian_filter + class BraggVectorMethods: @@ -384,8 +386,6 @@ def measure_origin( np.argmax(gaussian_filter(bvm, 10)), (Q_Nx, Q_Ny) ) else: - from py4DSTEM.process.utils import get_CoM - x0, y0 = get_CoM(bvm) else: x0, y0 = center_guess diff --git a/py4DSTEM/datacube/__init__.py b/py4DSTEM/datacube/__init__.py index 883961fcb..3e9cc8331 100644 --- a/py4DSTEM/datacube/__init__.py +++ b/py4DSTEM/datacube/__init__.py @@ -3,3 +3,5 @@ from py4DSTEM.datacube.datacube import DataCube 
from py4DSTEM.datacube.virtualimage import VirtualImage from py4DSTEM.datacube.virtualdiffraction import VirtualDiffraction +from py4DSTEM.datacube.diskdetection.probe import Probe + diff --git a/py4DSTEM/datacube/datacube.py b/py4DSTEM/datacube/datacube.py index 930dd4c13..49feea517 100644 --- a/py4DSTEM/datacube/datacube.py +++ b/py4DSTEM/datacube/datacube.py @@ -8,21 +8,27 @@ distance_transform_edt, binary_fill_holes, gaussian_filter1d, - gaussian_filter, -) + gaussian_filter,) from typing import Optional, Union from emdfile import Array, Metadata, Node, Root, tqdmnd from py4DSTEM.data import Data, Calibration -from py4DSTEM.datacube.virtualimage import DataCubeVirtualImager -from py4DSTEM.datacube.virtualdiffraction import DataCubeVirtualDiffraction +from py4DSTEM.datacube.preprocess import Preprocessor +from py4DSTEM.datacube.virtualimage import VirtualImager +from py4DSTEM.datacube.virtualdiffraction import VirtualDiffractioner +from py4DSTEM.datacube.diskdetection import BraggFinder +from py4DSTEM.datacube.diskdetection import ProbeMaker + class DataCube( Array, Data, - DataCubeVirtualImager, - DataCubeVirtualDiffraction, + Preprocessor, + VirtualImager, + VirtualDiffractioner, + BraggFinder, + ProbeMaker, ): """ Storage and processing methods for 4D-STEM datasets. @@ -304,16 +310,6 @@ def add(self, data, name=""): if isinstance(data, np.ndarray): data = Array(data=data, name=name) self.attach(data) - - def set_scan_shape(self, Rshape): - """ - Reshape the data given the real space scan shape. - - Accepts: - Rshape (2-tuple) - """ - from py4DSTEM.preprocess import set_scan_shape - assert len(Rshape) == 2, "Rshape must have a length of 2" d = set_scan_shape(self, Rshape[0], Rshape[1]) return d @@ -1342,3 +1338,4 @@ def get_braggmask(self, braggvectors, rx, ry, radius): qr = np.hypot(self.qxx_raw - vects.qx[idx], self.qyy_raw - vects.qy[idx]) mask = np.logical_and(mask, qr > radius) return mask + diff --git a/py4DSTEM/datacube/diskdetection/__init__.py b/py4DSTEM/datacube/diskdetection/__init__.py new file mode 100644 index 000000000..e31ca3122 --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/__init__.py @@ -0,0 +1,5 @@ + +from py4DSTEM.datacube.diskdetection.probe import Probe,ProbeMaker +from py4DSTEM.datacube.diskdetection.braggfinder import BraggFinder + + diff --git a/py4DSTEM/datacube/diskdetection/braggfinder.py b/py4DSTEM/datacube/diskdetection/braggfinder.py new file mode 100644 index 000000000..1ad9f9c78 --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/braggfinder.py @@ -0,0 +1,791 @@ +import numpy as np +from typing import Optional +from scipy.ndimage import gaussian_filter + +from emdfile import tqdmnd +from emdfile import Metadata +from py4DSTEM.utils import get_maxima_2D, get_cross_correlation_FT +from py4DSTEM.data import QPoints + + + +class BraggFinder(object): + """ + Handles disk detection. + """ + + def __init__( + self, + ): + pass + + + def find_bragg_vectors( + self, + template, + data=None, + preprocess=None, + corr=None, + thresh=None, + device=None, + ML=None, + return_cc=False, + name = 'braggvectors', + returncalc = True, + _return_cc = False, + ): + + """ + Finds Bragg scattering vectors. + + In normal operation, localizes Bragg scattering using template matching, + by (1) optional preprocessing, (2) cross-correlating with the template, + and (3) finding local maxima, thresholding and returning. See + `preprocess`, `corr`, and `thresh` below. Accelration is handle with + `device`. 
+ + Invoking `ML` makes use of a custom neural network called FCU-net + instead of template matching. If you use FCU-net in your work, + please reference "Munshi, Joydeep, et al. npj Computational Materials + 8.1 (2022): 254". + + + Examples (CPU + cross-correlation) + ---------------------------------- + + >>> datacube.get_bragg_vectors( template ) + + will find bragg scattering for the entire datacube using cross- + correlative template matching on the CPU with a correlation power of 1, + gaussian blurring on each correlagram of 2 pixels, polynomial subpixel + refinement, and the default thresholding parameters. + + >>> datacube.get_bragg_vectors( + >>> template, + >>> corr = { + >>> 'corrPower' : 1, + >>> 'sigma' : 2, + >>> 'subpixel' : 'multicorr', + >>> 'upsample_factor' : 16 + >>> }, + >>> ) + + will perform the same computation but use Fourier upsampling for + subpixel refinement, and + + >>> datacube.get_bragg_vectors( + >>> template, + >>> thresh = { + >>> 'minAboluteIntensity' : 100, + >>> 'minPeakSpacing' : 18, + >>> 'edgeBoundary' : 10, + >>> 'maxNumPeaks' : 100, + >>> }, + >>> ) + + will perform the same computation but threshold the detected + maxima using an absolute rather than relative intensity threshhold, + and modify the other threshold params as above. + + Using + + >>> datacube.get_bragg_vectors( + >>> template, + >>> data = (5,6) + >>> ) + + will perform template matching against the diffraction image at + scan position (5,6), and using + + >>> datacube.get_bragg_vectors( + >>> template, + >>> data = (np.array([4,5,6]),np.array([10,11,12])) + >>> ) + + will perform template matching against the 3 diffraction images at + scan positions (4,10), (5,11), (6,12). + + Using + + >>> datacube.fing_bragg_vectors( + >>> template = None, + >>> corr = { + >>> 'sigma' : 5, + >>> 'subpixel' : 'poly' + >>> }, + >>> ) + + will not cross-correlate at all, and will instead perform maximum + detection on the raw data of the entire datacube, after applying + a gaussian blur to each diffraction image of 5 pixels, and using + polynomial subpixel refinement. + + Using + + >>> datacube.find_bragg_vectors( + >>> template, + >>> preprocess = { + >>> 'sigma' : 2, + >>> }, + >>> corr = { + >>> 'sigma' : 4, + >>> } + >>> ) + + will apply a 2-pixel gaussian blur to the diffraction image, then + cross correlate, then apply a 4-pixel gaussian blur to the cross- + correlation before finding maxima. Using + + >>> datacube.find_bragg_vectors( + >>> template, + >>> preprocess = { + >>> 'radial_bkgrd' : True, + >>> 'localave' : True, + >>> }, + >>> ) + + will subtract the radial median from each diffraction image, then + obtain the weighted average diffraction image with a 3x3 gaussian + footprint in real space (i.e. + + [[ 1, 2, 1 ], + [ 2, 4, 2 ], * (1/16) + [ 1, 2, 1 ]] + + ) and perform template matching against the resulting images. + Using + + >>> def preprocess_fn(data): + >>> return py4DSTEM.utils.bin2D(data,2) + >>> datacube.find_bragg_vectors( + >>> template, + >>> preprocess = { + >>> 'filter_function' : preprocess_fn + >>> }, + >>> ) + + will bin each diffraction image by 2 before cross-correlating. + Note that the template shape must match the final, preprocessed + data shape. + + + Examples (GPU acceleration, cluster acceleration, and ML) + ------------------------------------------------------- + # TODO! + + + Parameters + ---------- + template : qshape'd 2d np.ndarray or Probe or None + The matching template. If an ndarray is passed, must be centered + about the origin. 
If a Probe is passed, probe.kernel must be + populated. If None is passed, cross correlation is skipped and + the maxima are taken directly from the (possibly preprocessed) + diffraction data + data : None or 2-tuple or 2D numpy ndarray + Specifies the data in which to find the Bragg scattering. Valid + entries and their behavoirs are: + * None: use the entire DataCube, and return a BraggVectors + instance + * 2-tuple of ints: use the diffraction pattern at scan + position (rx,ry), and return a QPoints instance + * 2-tuple of arrays of ints: use the diffraction patterns + at scan positions (rxs,rys), and return a list of QPoints + instances + * 2D numpy array, real-space shaped, boolean: run on the + diffraction images specified by the True pixels in the + input array, and return a list of QPoints instances + * 2D numpy array, diffraction-space shaped: run on the + input array, and return a QPoints instance + preprocess : None or dict + If None, no preprocessing is performed. Otherwise, should be a + dictionary with the following valid keys: + * radial_bkgrd (bool): if True, finds and subtracts the local + median of each diffraction image. Origin must be calibrated. + * localave (bool): if True, takes the local 3x3 gaussian + average of each diffraction image + * sigma (number): if >0, applies a gaussian blur to the data + before cross correlating + * filter_function (callable): function applied to each + diffraction image before peak finding. Must be a function of + only one argument (the diffr image) which returns the pre- + processed image. The shape of the returned DP must match + the shape of the probe. If using distributed disk detection, + the function must be able to be pickled with dill. + If more than one key is passed then all requested preprocessing + steps are performed, in the order they're listed here. + corr : None or dict + If None, no cross correlation is performed, and maximum detection + is performed on the (possibly preprocessed) diffraction data with + no subpixel refinement applied. Otherwise, should be a dictionary + with valid keys: + * corrPower (number, 0-1): type of correlation to perform, + where 1 is a cross correlation, 0 is a phase correlation, + and values in between are hybrid correlations. Pure cross + correlation is recommend to minimize noise + * sigma (number): if >0, apply a gaussian blur to the cross + correlation before detecting maxima + * subpixel ('none' or 'poly' or 'multicorr'): controls + subpixel refinement of maxima. 'none' returns the values + to pixel precision. 'poly' performs polynomial (2D + parabolic) numerical refinement. 'multicorr' performs + Fourier upsampling subpixel refinement, and requires + the `upsample_factor` keyword also be specified. 'poly' + is fast; 'multicorr' is much slower but allows greater + precision. + * upsample_factor (int): the upsampling factor used for + 'multicorr' subpixel refinement and defining the precision + of the refinement. Ignored if `subpixel` is not 'multicorr' + Note that passing `template=None` skips cross-correlation but not + maxmimum detection - in this case, `corrPower` is ignored, but all + parameters in this dictionary are used. + Note also that if this dictionary is specified (i.e. is not None) + but corrPower or sigma or subpixel or upsample_factor are not + specified, their default values (corrPower=1, sigma=2, + subpixel='poly', upsample_factor=16) are used. + thresh : None or dict + If None, no thresholding is performed (not recommended!). 
Otherwise, + should be a dictionary with valid keys: + * minAbsoluteIntensity (number): maxima with intensities below + `this value are removed. Ignored if set to 0. + * minRelativeIntensity (number): maxima with intensities below a + reference maximum * (this value) are removed. The refernce + maximum is selected for each diffraction image according to + the `relativeToPeak` argument: 0 specifies the brightest + maximum, 1 specifies the second brightest, etc. + * relativeToPeak (int): specifies the reference maximum used + int the `minRelativeIntensity` threshold + * minPeakSpacing (number): if two maxima are closer together than + this number of pixels, the dimmer maximum is removed + * edgeBoundary (number): maxima closer to the edge of the + diffraction image than this value are removed + * maxNumPeaks (int): only the brightest `maxNumPeaks` maxima + are kept + device : None or dict + If None, uses the CPU. Otherwise, should be a dictionary with + valid keys: + * CUDA (bool): enable GPU acceleration + * CUDA_batched (bool): enable batched GPU computation + * ipyparallel (bool): enable multinode distribution using + ipyparallel. Must also specify the "client" and "data_file", + and optionally "cluster_path" keywords + * dask (bool): enable multinode distribution using + dask. Must also specify the "client" and "data_file", + and optionally "cluster_path" keywords + * client (str or obj): when used with ipyparralel, must be a + path to a client json for connecting to your existing + IPyParallel cluster. When used with dask, must be a dask + client that connects to your existing Dask cluster + * data_file (str): the absolute path to your original data + file containing the datacube + * cluster_path (str): working directory for cluster processing, + defaults to current directory + Note that only one of 'CUDA' or 'ipyparallel' or 'dask' may be set + to True, and that 'client' and 'data_file' and optionally + 'cluster_path' must be specified if either 'ipyparallel' or 'dask' + computation is selected. Note also that preprocessing is currently + not performed for any device accelerated computations. + ML : None or dict + If None, does cross correlative template matching. Otherwise, should + be a dictionary with valid keys: + * num_attempts (int): Number of attempts to predict the Bragg + disks. More attempts (ideally) results in a more confident + prediction, as FCU-net uses Monte Carlo dropout to estimate + the model uncertainty. Note that increasing num_attempts will + increase the compute time and it is adviced to use GPU (CUDA) + acceleration for num_attempts > 1. + # TODO: @alex-rakowski - you had "Recommended: 5" but also + # had the default set to 1, and this comment that using >1 + # is not recommended without CUDA. Can we clarify? + * batch_size (int): number of diffraction images to send to + the model at once for prediction. For CPU a batch size of 1 + is recommended. For GPU the batch size may be selected based + on the available GPU RAM and the size of the diffraction + images, with larger batch sizes accelerating the computation + and increasing the required memory. + * model_path (None or str): if None, py4DSTEM will check if a + model is available locally, then download and update the model + if one is not available or the local model is not up-to-date. + Otherwise, must be a filepath to a Tensorflow model of weights + Note that to use GPU / batched GPU computation, the "CUDA" and + "CUDA_batched" flags should be set to True in the `device` arguement. 
+ return_cc : bool + If True, returns the cross correlation in addition to the detected + peaks. + name : str + Name for the output BraggVectors instance + returncalc : bool + If True, return the answer + """ + + # use ML? + if ML: + raise Exception("ML isn't implemented here yet, please use find_Bragg_disks") + # use device? + if device: + raise Exception("Hardware acceleration isn't implemented here yet, please use find_Bragg_disks") + + # parse inputs + corr_defaults = { + 'sigma' : 0, + 'corr_power' : 1, + } + thresh_defaults = { + 'min_intensity' : 0, + 'min_spacing' : 5, + 'subpixel' : 'poly', + 'upsample_factor' : 16, + 'edge' : 0, + 'sigma' : 0, + 'n_peaks_max' : 10000, + 'min_prominence' : 0, + 'prominence_kernel_size' : 3, + 'min_rel_intensity' : 0, + 'ref_peak' : 0, + } + corr = corr_defaults if corr is None else corr_defaults | corr + thresh = thresh_defaults if thresh is None else thresh_defaults | thresh + + ## Set up metamethods (preprocess, crosscorr, thresholding) + + # preprocess + preprocess_options = [ + "bs", + "radial_background_subtraction", + "la", + "local_averaging" + ] + # validate inputs + if isinstance(preprocess,list): + for el in preprocess: + assert(isinstance(el,str) or callable(el)) + if isinstance(el,str): + assert(el in preprocess_options) + def _preprocess(dp,x,y): + """ dp = _preprocess_pattern(datacube.data[x,y]) + """ + if preprocess is None: + return dp + elif callable(preprocess): + return preprocess(dp) + # for dicts, element 'f' is a callable and + # all others are arguments to pass to it + elif isinstance(preprocess,dict): + f = preprocess.pop('f') + # if x+y are keys in preprocess, the callable f should + # recieve scan positions 'x' and 'y' as inputs + if 'x' in preprocess.keys() and 'y' in preprocess.keys(): + preprocess['x'] = x + preprocess['y'] = y + return f(dp,**preprocess) + else: + for el in preprocess: + if callable(el): + dp = el(dp) + else: + if el in ('bs','radial_background_subtraction'): + dp = self.get_radial_bksb_dp(x,y,sigma=0) + pass + elif el in ('la','local_averaging'): + dp = self.get_local_ave_dp(x,y) + else: + raise Exception("How did you get here?? 
A preprocess option may have been added incorrectly") + return dp + + # cross correlate + def _cross_correlate(dp): + """ cc = _cross_correlate(dp) + """ + if corr['sigma'] > 0: + dp = gaussian_filter(dp, corr['sigma']) + if template is not None: + cc = get_cross_correlation_FT( + dp, + template_FT, + corr['corr_power'], + "fourier", + ) + else: + cc = dp + return cc + + # threshold + def _threshold(cc): + """ vec = _threshold(cc) + """ + cc_real = np.maximum(np.real(np.fft.ifft2(cc)),0) + if thresh['subpixel'] == 'multicorr': + return get_maxima_2D( + cc_real, + subpixel='multicorr', + upsample_factor=thresh['upsample_factor'], + sigma=thresh['sigma'], + minAbsoluteIntensity=thresh['min_intensity'], + minProminence=thresh['min_prominence'], + prominenceKernelSize=thresh['prominence_kernel_size'], + minRelativeIntensity=thresh['min_rel_intensity'], + relativeToPeak=thresh['ref_peak'], + minSpacing=thresh['min_spacing'], + edgeBoundary=thresh['edge'], + maxNumPeaks=thresh['n_peaks_max'], + _ar_FT=cc, + ) + else: + return get_maxima_2D( + cc_real, + subpixel=thresh['subpixel'], + sigma=thresh['sigma'], + minAbsoluteIntensity=thresh['min_intensity'], + minProminence=thresh['min_prominence'], + prominenceKernelSize=thresh['prominence_kernel_size'], + minRelativeIntensity=thresh['min_rel_intensity'], + relativeToPeak=thresh['ref_peak'], + minSpacing=thresh['min_spacing'], + edgeBoundary=thresh['edge'], + maxNumPeaks=thresh['n_peaks_max'], + ) + + + # prepare the template + if template is not None: + template_FT = np.conj(np.fft.fft2(template)) + + + # prepare the data and output container for... + if data is None: + # ...all indices + rxs = np.tile(np.arange(self.Q_Nx),self.Q_Ny) + rys = np.tile(np.arange(self.Q_Ny),(self.Q_Nx,1)).T.reshape(datacube.Q_N) + N = len(rx) + vectors = BraggVectors(datacube.Rshape, datacube.Qshape) + elif isinstance(data,(tuple,list)): + # ...specified indices + rxs,rys = data + N = len(rxs) + vectors = [] + if _return_cc: + ccs = [] + else: + raise Exception(f"Invalid specification of data, {data}") + + + # Compute + for idx in tqdmnd( + N, + desc="Finding Bragg Disks", + unit="DP", + unit_scale=True, + ): + # get a diffraction pattern + rx,ry = rxs[idx],rys[idx] + dp = self.data[rx,ry] + + # preprocess + dp = _preprocess(dp,rx,ry) + + # cross correlate + cc = _cross_correlate(dp) + + # threshold + peaks = _threshold(cc) + + # store results + peaks = QPoints(peaks) + if data is None: + vectors._v_uncal[rx,ry] = peaks + else: + vectors.append(peaks) + if _return_cc: + ccs.append(cc) + + + # Return + if _return_cc is True: + return vectors, ccs + else: + return vectors + + + + + + + def get_beamstop_mask( + self, + threshold=0.25, + distance_edge=2.0, + include_edges=True, + sigma=0, + use_max_dp=False, + scale_radial=None, + name="mask_beamstop", + returncalc=True, + ): + """ + This function uses the mean diffraction pattern plus a threshold to + create a beamstop mask. + + Args: + threshold (float): Value from 0 to 1 defining initial threshold for + beamstop mask, taken from the sorted intensity values - 0 is the + dimmest pixel, while 1 uses the brighted pixels. + distance_edge (float): How many pixels to expand the mask. + include_edges (bool): If set to True, edge pixels will be included + in the mask. + sigma (float): + Gaussain blur std to apply to image before thresholding. + use_max_dp (bool): + Use the max DP instead of the mean DP. 
+ scale_radial (float): + Scale from center of image by this factor (can help with edge) + name (string): Name of the output array. + returncalc (bool): Set to true to return the result. + + Returns: + (Optional): if returncalc is True, returns the beamstop mask + + """ + + if scale_radial is not None: + x = np.arange(self.data.shape[2]) * 2.0 / self.data.shape[2] + y = np.arange(self.data.shape[3]) * 2.0 / self.data.shape[3] + ya, xa = np.meshgrid(y - np.mean(y), x - np.mean(x)) + im_scale = 1.0 + np.sqrt(xa**2 + ya**2) * scale_radial + + # Get image for beamstop mask + if use_max_dp: + # if not "dp_mean" in self.tree.keys(): + # self.get_dp_max(); + # im = self.tree["dp_max"].data.astype('float') + if not "dp_max" in self._branch.keys(): + self.get_dp_max() + im = self.tree("dp_max").data.copy().astype("float") + else: + if not "dp_mean" in self._branch.keys(): + self.get_dp_mean() + im = self.tree("dp_mean").data.copy() + + # if not "dp_mean" in self.tree.keys(): + # self.get_dp_mean(); + # im = self.tree["dp_mean"].data.astype('float') + + # smooth and scale if needed + if sigma > 0.0: + im = gaussian_filter(im, sigma, mode="nearest") + if scale_radial is not None: + im *= im_scale + + # Calculate beamstop mask + int_sort = np.sort(im.ravel()) + ind = np.round( + np.clip(int_sort.shape[0] * threshold, 0, int_sort.shape[0]) + ).astype("int") + intensity_threshold = int_sort[ind] + mask_beamstop = im >= intensity_threshold + + # clean up mask + mask_beamstop = np.logical_not(binary_fill_holes(np.logical_not(mask_beamstop))) + mask_beamstop = binary_fill_holes(mask_beamstop) + + # Edges + if include_edges: + mask_beamstop[0, :] = False + mask_beamstop[:, 0] = False + mask_beamstop[-1, :] = False + mask_beamstop[:, -1] = False + + # Expand mask + mask_beamstop = distance_transform_edt(mask_beamstop) < distance_edge + + # Wrap beamstop mask in a class + x = Array(data=mask_beamstop, name=name) + + # Add metadata + x.metadata = Metadata( + name="gen_params", + data={ + #'gen_func' : + "threshold": threshold, + "distance_edge": distance_edge, + "include_edges": include_edges, + "name": "mask_beamstop", + "returncalc": returncalc, + }, + ) + + # Add to tree + self.tree(x) + + # return + if returncalc: + return mask_beamstop + + + + + + + + + + +# ### OLD CODE +# +# elif mode == "datacube": +# if distributed is None and CUDA == False: +# mode = "dc_CPU" +# elif distributed is None and CUDA == True: +# if CUDA_batched == False: +# mode = "dc_GPU" +# else: +# mode = "dc_GPU_batched" +# else: +# x = _parse_distributed(distributed) +# connect, data_file, cluster_path, distributed_mode = x +# if distributed_mode == "dask": +# mode = "dc_dask" +# elif distributed_mode == "ipyparallel": +# mode = "dc_ipyparallel" +# else: +# er = f"unrecognized distributed mode {distributed_mode}" +# raise Exception(er) +# # overwrite if ML selected +# +# # select a function +# fn_dict = { +# "dp": _find_Bragg_disks_single, +# "dp_stack": _find_Bragg_disks_stack, +# "dc_CPU": _find_Bragg_disks_CPU, +# "dc_GPU": _find_Bragg_disks_CUDA_unbatched, +# "dc_GPU_batched": _find_Bragg_disks_CUDA_batched, +# "dc_dask": _find_Bragg_disks_dask, +# "dc_ipyparallel": _find_Bragg_disks_ipp, +# "dc_ml": find_Bragg_disks_aiml, +# } +# fn = fn_dict[mode] +# +# # prepare kwargs +# kws = {} +# # distributed kwargs +# if distributed is not None: +# kws["connect"] = connect +# kws["data_file"] = data_file +# kws["cluster_path"] = cluster_path +# # ML arguments +# if ML == True: +# kws["CUDA"] = CUDA +# kws["model_path"] = ml_model_path +# 
kws["num_attempts"] = ml_num_attempts +# kws["batch_size"] = ml_batch_size +# +# # if radial background subtraction is requested, add to args +# if radial_bksb and mode == "dc_CPU": +# kws["radial_bksb"] = radial_bksb +# +# # run and return +# ans = fn( +# data, +# template, +# filter_function=filter_function, +# corrPower=corrPower, +# sigma_dp=sigma_dp, +# sigma_cc=sigma_cc, +# subpixel=subpixel, +# upsample_factor=upsample_factor, +# minAbsoluteIntensity=minAbsoluteIntensity, +# minRelativeIntensity=minRelativeIntensity, +# relativeToPeak=relativeToPeak, +# minPeakSpacing=minPeakSpacing, +# edgeBoundary=edgeBoundary, +# maxNumPeaks=maxNumPeaks, +# **kws, +# ) +# return ans +# +# +# +# +# +# # parse args +# if data is None: +# x = self +# elif isinstance(data, tuple): +# x = self, data[0], data[1] +# elif isinstance(data, np.ndarray): +# assert data.dtype == bool, "array must be boolean" +# assert data.shape == self.Rshape, "array must be Rspace shaped" +# x = self.data[data, :, :] +# else: +# raise Exception(f"unexpected type for `data` {type(data)}") +# +# +# +# +# # compute +# peaks = find_Bragg_disks( +# data=x, +# template=template, +# radial_bksb=radial_bksb, +# filter_function=filter_function, +# corrPower=corrPower, +# sigma_dp=sigma_dp, +# sigma_cc=sigma_cc, +# subpixel=subpixel, +# upsample_factor=upsample_factor, +# minAbsoluteIntensity=minAbsoluteIntensity, +# minRelativeIntensity=minRelativeIntensity, +# relativeToPeak=relativeToPeak, +# minPeakSpacing=minPeakSpacing, +# edgeBoundary=edgeBoundary, +# maxNumPeaks=maxNumPeaks, +# CUDA=CUDA, +# CUDA_batched=CUDA_batched, +# distributed=distributed, +# ML=ML, +# ml_model_path=ml_model_path, +# ml_num_attempts=ml_num_attempts, +# ml_batch_size=ml_batch_size, +# ) +# +# if isinstance(peaks, Node): +# # add metadata +# peaks.name = name +# peaks.metadata = Metadata( +# name="gen_params", +# data={ +# #'gen_func' : +# "template": template, +# "filter_function": filter_function, +# "corrPower": corrPower, +# "sigma_dp": sigma_dp, +# "sigma_cc": sigma_cc, +# "subpixel": subpixel, +# "upsample_factor": upsample_factor, +# "minAbsoluteIntensity": minAbsoluteIntensity, +# "minRelativeIntensity": minRelativeIntensity, +# "relativeToPeak": relativeToPeak, +# "minPeakSpacing": minPeakSpacing, +# "edgeBoundary": edgeBoundary, +# "maxNumPeaks": maxNumPeaks, +# "CUDA": CUDA, +# "CUDA_batched": CUDA_batched, +# "distributed": distributed, +# "ML": ML, +# "ml_model_path": ml_model_path, +# "ml_num_attempts": ml_num_attempts, +# "ml_batch_size": ml_batch_size, +# }, +# ) +# +# # add to tree +# if data is None: +# self.attach(peaks) +# +# # return +# if returncalc: +# return peaks +# +# # aliases +# find_disks = find_bragg = find_bragg_disks = find_bragg_scattering = find_bragg_vectors +# +# diff --git a/py4DSTEM/datacube/diskdetection/diskdetection.py b/py4DSTEM/datacube/diskdetection/diskdetection.py new file mode 100644 index 000000000..47c7ffed7 --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/diskdetection.py @@ -0,0 +1,1330 @@ +# Functions for finding Bragg scattering by cross correlative template matching +# with a vacuum probe. 
+ +import numpy as np +from scipy.ndimage import gaussian_filter + +from emdfile import tqdmnd +from py4DSTEM.braggvectors.braggvectors import BraggVectors +from py4DSTEM.data import QPoints +from py4DSTEM.datacube import DataCube +from py4DSTEM.utils import get_maxima_2D, get_cross_correlation_FT +from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml + + + +def find_bragg_vectors( + data, + template, + corr=None, + thresh=None, + preprocess=None, + preprocess_args=None, + device=None, + ML=None, + return_cc = False, + name = 'braggvectors', + returncalc = True +): + """ + Finds Bragg scattering vectors. + + The method is template matching unless the ML argument is specified. + In normal operation, the sequence is (1) optional preprocessing, + (2) cross-correlating with the template, and (3) finding local maxima, + thresholding and returning. See `preprocess`, `corr`, `thresh`, and + `ML` below. Accelration is handle with `device`. + + Invoking `ML` makes use of a custom neural network called FCU-net to + localize the Bragg scattering. If you use FCU-net in your work, + please reference "Munshi, Joydeep, et al. npj Computational Materials + 8.1 (2022): 254". + + + Examples + -------- + + >>> datacube.find_bragg_vectors( template ) + + will find bragg scattering for the entire datacube using cross- + correlative template matching on the CPU. Calling + + >>> datacube.find_bragg_vectors( + >>> template, + >>> data = (x,y) + >>> ) + + finds and returns bragg scattering at scan position(s) (x,y). + + The cross-correlation by default blurs each correlagra + + By default gaussian blurring on each correlagram of 2 pixels, polynomial subpixel + refinement, and the default thresholding parameters. + + >>> datacube.get_bragg_vectors( + >>> template, + >>> corr = { + >>> 'sigma' : 2, + >>> 'subpixel' : 'poly', + >>> }, + >>> ) + + will perform the same computation but use Fourier upsampling for + subpixel refinement, and + + >>> datacube.get_bragg_vectors( + >>> template, + >>> thresh = { + >>> 'minAboluteIntensity' : 100, + >>> 'minPeakSpacing' : 18, + >>> 'edgeBoundary' : 10, + >>> 'maxNumPeaks' : 100, + >>> }, + >>> ) + + will perform the same computation but threshold the detected + maxima using an absolute rather than relative intensity threshhold, + and modify the other threshold params as above. + + Using + + >>> datacube.get_bragg_vectors( + >>> template, + >>> data = (5,6) + >>> ) + + will perform template matching against the diffraction image at + scan position (5,6), and using + + >>> datacube.get_bragg_vectors( + >>> template, + >>> data = (np.array([4,5,6]),np.array([10,11,12])) + >>> ) + + will perform template matching against the 3 diffraction images at + scan positions (4,10), (5,11), (6,12). + + Using + + >>> datacube.fing_bragg_vectors( + >>> template = None, + >>> corr = { + >>> 'sigma' : 5, + >>> 'subpixel' : 'poly' + >>> }, + >>> ) + + will not cross-correlate at all, and will instead perform maximum + detection on the raw data of the entire datacube, after applying + a gaussian blur to each diffraction image of 5 pixels, and using + polynomial subpixel refinement. + + Using + + >>> datacube.find_bragg_vectors( + >>> template, + >>> preprocess = { + >>> 'sigma' : 2, + >>> }, + >>> corr = { + >>> 'sigma' : 4, + >>> } + >>> ) + + will apply a 2-pixel gaussian blur to the diffraction image, then + cross correlate, then apply a 4-pixel gaussian blur to the cross- + correlation before finding maxima. 
Using + + >>> datacube.find_bragg_vectors( + >>> template, + >>> preprocess = { + >>> 'radial_bkgrd' : True, + >>> 'localave' : True, + >>> }, + >>> ) + + will subtract the radial median from each diffraction image, then + obtain the weighted average diffraction image with a 3x3 gaussian + footprint in real space (i.e. + + [[ 1, 2, 1 ], + [ 2, 4, 2 ], * (1/16) + [ 1, 2, 1 ]] + + ) and perform template matching against the resulting images. + Using + + >>> def preprocess_fn(data): + >>> return py4DSTEM.utils.bin2D(data,2) + >>> datacube.find_bragg_vectors( + >>> template, + >>> preprocess = { + >>> 'filter_function' : preprocess_fn + >>> }, + >>> ) + + will bin each diffraction image by 2 before cross-correlating. + Note that the template shape must match the final, preprocessed + data shape. + + + Examples (GPU acceleration, cluster acceleration, and ML) + ------------------------------------------------------- + # TODO! + + + Parameters + ---------- + template : qshape'd 2d np.ndarray or Probe or None + The matching template. If an ndarray is passed, must be centered + about the origin. If a Probe is passed, probe.kernel must be + populated. If None is passed, cross correlation is skipped and + the maxima are taken directly from the (possibly preprocessed) + diffraction data + data : None or 2-tuple or 2D numpy ndarray + Specifies the data in which to find the Bragg scattering. Valid + entries and their behavoirs are: + * None: use the entire DataCube, and return a BraggVectors + instance + * 2-tuple of ints: use the diffraction pattern at scan + position (rx,ry), and return a QPoints instance + * 2-tuple of arrays of ints: use the diffraction patterns + at scan positions (rxs,rys), and return a list of QPoints + instances + * 2D numpy array, real-space shaped, boolean: run on the + diffraction images specified by the True pixels in the + input array, and return a list of QPoints instances + * 2D numpy array, diffraction-space shaped: run on the + input array, and return a QPoints instance + preprocess : None or dict + If None, no preprocessing is performed. Otherwise, should be a + dictionary with the following valid keys: + * radial_bkgrd (bool): if True, finds and subtracts the local + median of each diffraction image. Origin must be calibrated. + * localave (bool): if True, takes the local 3x3 gaussian + average of each diffraction image + * sigma (number): if >0, applies a gaussian blur to the data + before cross correlating + * filter_function (callable): function applied to each + diffraction image before peak finding. Must be a function of + only one argument (the diffr image) which returns the pre- + processed image. The shape of the returned DP must match + the shape of the probe. If using distributed disk detection, + the function must be able to be pickled with dill. + If more than one key is passed then all requested preprocessing + steps are performed, in the order they're listed here. + corr : None or dict + If None, no cross correlation is performed, and maximum detection + is performed on the (possibly preprocessed) diffraction data with + no subpixel refinement applied. Otherwise, should be a dictionary + with valid keys: + * corrPower (number, 0-1): type of correlation to perform, + where 1 is a cross correlation, 0 is a phase correlation, + and values in between are hybrid correlations. 
Pure cross + correlation is recommend to minimize noise + * sigma (number): if >0, apply a gaussian blur to the cross + correlation before detecting maxima + * subpixel ('none' or 'poly' or 'multicorr'): controls + subpixel refinement of maxima. 'none' returns the values + to pixel precision. 'poly' performs polynomial (2D + parabolic) numerical refinement. 'multicorr' performs + Fourier upsampling subpixel refinement, and requires + the `upsample_factor` keyword also be specified. 'poly' + is fast; 'multicorr' is much slower but allows greater + precision. + * upsample_factor (int): the upsampling factor used for + 'multicorr' subpixel refinement and defining the precision + of the refinement. Ignored if `subpixel` is not 'multicorr' + Note that passing `template=None` skips cross-correlation but not + maxmimum detection - in this case, `corrPower` is ignored, but all + parameters in this dictionary are used. + Note also that if this dictionary is specified (i.e. is not None) + but corrPower or sigma or subpixel or upsample_factor are not + specified, their default values (corrPower=1, sigma=2, + subpixel='poly', upsample_factor=16) are used. + thresh : None or dict + If None, no thresholding is performed (not recommended!). Otherwise, + should be a dictionary with valid keys: + * minAbsoluteIntensity (number): maxima with intensities below + `this value are removed. Ignored if set to 0. + * minRelativeIntensity (number): maxima with intensities below a + reference maximum * (this value) are removed. The refernce + maximum is selected for each diffraction image according to + the `relativeToPeak` argument: 0 specifies the brightest + maximum, 1 specifies the second brightest, etc. + * relativeToPeak (int): specifies the reference maximum used + int the `minRelativeIntensity` threshold + * minPeakSpacing (number): if two maxima are closer together than + this number of pixels, the dimmer maximum is removed + * edgeBoundary (number): maxima closer to the edge of the + diffraction image than this value are removed + * maxNumPeaks (int): only the brightest `maxNumPeaks` maxima + are kept + device : None or dict + If None, uses the CPU. Otherwise, should be a dictionary with + valid keys: + * CUDA (bool): enable GPU acceleration + * CUDA_batched (bool): enable batched GPU computation + * ipyparallel (bool): enable multinode distribution using + ipyparallel. Must also specify the "client" and "data_file", + and optionally "cluster_path" keywords + * dask (bool): enable multinode distribution using + dask. Must also specify the "client" and "data_file", + and optionally "cluster_path" keywords + * client (str or obj): when used with ipyparralel, must be a + path to a client json for connecting to your existing + IPyParallel cluster. When used with dask, must be a dask + client that connects to your existing Dask cluster + * data_file (str): the absolute path to your original data + file containing the datacube + * cluster_path (str): working directory for cluster processing, + defaults to current directory + Note that only one of 'CUDA' or 'ipyparallel' or 'dask' may be set + to True, and that 'client' and 'data_file' and optionally + 'cluster_path' must be specified if either 'ipyparallel' or 'dask' + computation is selected. Note also that preprocessing is currently + not performed for any device accelerated computations. + ML : None or dict + If None, does cross correlative template matching. 
Otherwise, should + be a dictionary with valid keys: + * num_attempts (int): Number of attempts to predict the Bragg + disks. More attempts (ideally) results in a more confident + prediction, as FCU-net uses Monte Carlo dropout to estimate + the model uncertainty. Note that increasing num_attempts will + increase the compute time and it is adviced to use GPU (CUDA) + acceleration for num_attempts > 1. + # TODO: @alex-rakowski - you had "Recommended: 5" but also + # had the default set to 1, and this comment that using >1 + # is not recommended without CUDA. Can we clarify? + * batch_size (int): number of diffraction images to send to + the model at once for prediction. For CPU a batch size of 1 + is recommended. For GPU the batch size may be selected based + on the available GPU RAM and the size of the diffraction + images, with larger batch sizes accelerating the computation + and increasing the required memory. + * model_path (None or str): if None, py4DSTEM will check if a + model is available locally, then download and update the model + if one is not available or the local model is not up-to-date. + Otherwise, must be a filepath to a Tensorflow model of weights + Note that to use GPU / batched GPU computation, the "CUDA" and + "CUDA_batched" flags should be set to True in the `device` arguement. + return_cc : bool + If True, returns the cross correlation in addition to the detected + peaks. + name : str + Name for the output BraggVectors instance + returncalc : bool + If True, return the answer + """ + + # TODO TODO TODO + + # Set defaults + corr_default = { + 'corrPower' : 1, + 'subpixel' : 'poly', + 'upsample_factor' : 16, + } + thresh_default = { + 'sigma' : 2, + 'minAbsoluteIntensity' : 0, + 'minRelativeIntensity' : 0.005, + 'relativeToPeak' : 0, + 'minPeakSpacing' : 60, + 'edgeBoundary' : 20, + 'maxNumPeaks' : 70, + }, + ML_defaults = { + 'CUDA' : None, + 'ml_model_path' : None, + 'ml_num_attempts' : 1, + 'ml_batch_size' : 8, + } + preprocess_keys = ( + 'rbs', + 'radial_background_subtraction', + 'rds' + 'remove_disks', + 'bin', + ) + + + + + # Parse arguments + + # parse `data` + if isinstance(data, DataCube): + datamode = "datacube" + elif isinstance(data, np.ndarray): + if data.ndim == 2: + datamode = "dp" + elif data.ndim == 3: + datamode = "dp_stack" + else: + er = f"if `data` is an array, must be 2- or 3-D, not {data.ndim}-D" + raise Exception(er) + else: + try: + # for positions (rx,ry) + dc, rx, ry = data[0], data[1], data[2] + + # extra logic for HDF5 data + if "h5py" not in str(type(dc.data)): + data = dc.data[np.array(rx), np.array(ry), :, :] + else: + # h5py datasets have different rules for slicing than + # numpy arrays, so we have to do this manually + data = np.zeros((len(rx), dc.Q_Nx, dc.Q_Ny)) + # no background subtraction + for i, (x, y) in enumerate(zip(rx, ry)): + data[i] = dc.data[x, y] + except: + er = f"entry {data} for `data` could not be parsed" + raise Exception(er) + + + # Use the ML method + if ML: + + raise Exception("ML methods are currently accessible in the find_Bragg_disks method. 
Thanks for you patience!") + + kws["CUDA"] = CUDA + kws["model_path"] = ml_model_path + kws["num_attempts"] = ml_num_attempts + kws["batch_size"] = ml_batch_size + + find_Bragg_disks_aiml( + data, + template, + filter_function=filter_function, + corrPower=corrPower, + sigma_dp=sigma_dp, + sigma_cc=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + **kws, + ) + + + # Use template matching + + # Preprocess + if preprocess is not None: + preprocess = [preprocess] if not isinstance(preprocess,list) else preprocess + if preprocess_args is not None: + preprocess = [preprocess] if not isinstance(preprocess,list) else preprocess + for p in preprocess: + er = f"preprocess must be callable or one of the keys {preprocess_keys}" + assert p in preprocess_keys or callable(p), er + if preprocess_args is not None: + er = "number of preprocessing steps and number of arg dictionaries don't match" + assert len(preprocess_args)==len(preprocess), er + + + # if radial background subtraction is requested, add to args + if radial_bksb and mode == "dc_CPU": + kws["radial_bksb"] = radial_bksb + + # run and return + ans = fn( + data, + template, + filter_function=filter_function, + corrPower=corrPower, + sigma_dp=sigma_dp, + sigma_cc=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + **kws, + ) + return ans + + + + + + # parse args + sigma_cc = sigma if sigma is not None else sigma_cc + + # Radial background subtraction + # no background subtraction + if not radial_bksb: + data = dc.data[np.array(rx), np.array(ry), :, :] + # with bksubtr + else: + data = np.zeros((len(rx), dc.Q_Nx, dc.Q_Ny)) + for i, (x, y) in enumerate(zip(rx, ry)): + data[i] = dc.get_radial_bksb_dp(x, y) + + + + elif mode == "datacube": + if distributed is None and CUDA is False: + mode = "dc_CPU" + elif distributed is None and CUDA is True: + if CUDA_batched is False: + mode = "dc_GPU" + else: + mode = "dc_GPU_batched" + else: + x = _parse_distributed(distributed) + connect, data_file, cluster_path, distributed_mode = x + if distributed_mode == "dask": + mode = "dc_dask" + elif distributed_mode == "ipyparallel": + mode = "dc_ipyparallel" + else: + er = f"unrecognized distributed mode {distributed_mode}" + raise Exception(er) + # overwrite if ML selected + + # select a function + fn_dict = { + "dp": _find_Bragg_disks_single, + "dp_stack": _find_Bragg_disks_stack, + "dc_CPU": _find_Bragg_disks_CPU, + "dc_GPU": _find_Bragg_disks_CUDA_unbatched, + "dc_GPU_batched": _find_Bragg_disks_CUDA_batched, + "dc_dask": _find_Bragg_disks_dask, + "dc_ipyparallel": _find_Bragg_disks_ipp, + "dc_ml": find_Bragg_disks_aiml, + } + fn = fn_dict[mode] + + # prepare kwargs + kws = {} + # distributed kwargs + if distributed is not None: + kws["connect"] = connect + kws["data_file"] = data_file + kws["cluster_path"] = cluster_path + # ML arguments + if ML is True: + kws["CUDA"] = CUDA + kws["model_path"] = ml_model_path + kws["num_attempts"] = ml_num_attempts + kws["batch_size"] = ml_batch_size + + # if radial background subtraction is requested, add to args + if radial_bksb and mode == "dc_CPU": + kws["radial_bksb"] = 
radial_bksb + + # run and return + ans = fn( + data, + template, + filter_function=filter_function, + corrPower=corrPower, + sigma_dp=sigma_dp, + sigma_cc=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + **kws, + ) + return ans + + + + + + + + +def find_Bragg_disks( + data, + template, + radial_bksb=False, + filter_function=None, + corrPower=1, + sigma=None, + sigma_dp=0, + sigma_cc=2, + subpixel="multicorr", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0.005, + relativeToPeak=0, + minPeakSpacing=60, + edgeBoundary=20, + maxNumPeaks=70, + CUDA=False, + CUDA_batched=True, + distributed=None, + ML=False, + ml_model_path=None, + ml_num_attempts=1, + ml_batch_size=8, +): + """ + Finds the Bragg disks in the diffraction patterns represented by `data` by + cross/phase correlatin with `template`. + + Behavior depends on `data`. If it is + + - a DataCube: runs on all its diffraction patterns, and returns a + BraggVectors instance + - a 2D array: runs on this array, and returns a QPoints instance + - a 3D array: runs slice the ar[i,:,:] slices of this array, and returns + a len(ar.shape[0]) list of QPoints instances. + - a 3-tuple (DataCube, rx, ry), for numbers or length-N arrays (rx,ry): + runs on the diffraction patterns in DataCube at positions (rx,ry), + and returns a instance or length N list of instances of QPoints + + For disk detection on a full DataCube, the calculation can be performed + on the CPU, GPU or a cluster. By default the CPU is used. If `CUDA` is set + to True, tries to use the GPU. If `CUDA_batched` is also set to True, + batches the FFT/IFFT computations on the GPU. For distribution to a cluster, + distributed must be set to a dictionary, with contents describing how + distributed processing should be performed - see below for details. + + + For each diffraction pattern, the algorithm works in 4 steps: + + (1) any pre-processing is performed to the diffraction image. This is + accomplished by passing a callable function to the argument + `filter_function`, a bool to the argument `radial_bksb`, or a value >0 + to `sigma_dp`. If none of these are passed, this step is skipped. + (2) the diffraction image is cross correlated with the template. + Phase/hybrid correlations can be used instead by setting the + `corrPower` argument. Cross correlation can be skipped entirely, + and the subsequent steps performed directly on the diffraction + image instead of the cross correlation, by passing None to + `template`. + (3) the maxima of the cross correlation are located and their + positions and intensities stored. The cross correlation may be + passed through a gaussian filter first by passing the `sigma_cc` + argument. The method for maximum detection can be set with + the `subpixel` parameter. Options, from something like fastest/least + precise to slowest/most precise are 'pixel', 'poly', and 'multicorr'. + (4) filtering is applied to remove untrusted or undesired positive counts, + based on their intensity (`minRelativeIntensity`,`relativeToPeak`, + `minAbsoluteIntensity`) their proximity to one another or the + image edge (`minPeakSpacing`, `edgeBoundary`), and the total + number of peaks per pattern (`maxNumPeaks`). 
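As a sketch of how these four steps map onto the arguments (the values are illustrative, not defaults; `datacube` and `probe` are assumed to be an existing DataCube and Probe):

>>> braggvectors = find_Bragg_disks(
>>>     data=datacube,
>>>     template=probe.kernel,
>>>     radial_bksb=False, sigma_dp=0,               # step (1): no preprocessing
>>>     corrPower=1,                                 # step (2): pure cross correlation
>>>     sigma_cc=2, subpixel='multicorr',            # step (3): smoothed, upsampled maxima
>>>     upsample_factor=16,
>>>     minAbsoluteIntensity=10, minPeakSpacing=60,  # step (4): thresholding
>>>     edgeBoundary=20, maxNumPeaks=70,
>>> )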
+ + + Parameters + ---------- + data : variable + see above + template : 2D array + the vacuum probe template, in real space. For Probe instances, + this is `probe.kernel`. If None, does not perform a cross + correlation. + radial_bksb : bool + if True, computes a radial background given by the median of the + (circular) polar transform of each each diffraction pattern, and + subtracts this background from the pattern before applying any + filter function and computing the cross correlation. The origin + position must be set in the datacube's calibrations. Currently + only supported for full datacubes on the CPU. + filter_function : callable + filtering function to apply to each diffraction pattern before + peak finding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. + The shape of the returned DP must match the shape of the probe + kernel (but does not need to match the shape of the input + diffraction pattern, e.g. the filter can be used to bin the + diffraction pattern). If using distributed disk detection, the + function must be able to be pickled with by dill. + corrPower : float between 0 and 1, inclusive + the cross correlation power. A value of 1 corresponds to a cross + correlation, 0 corresponds to a phase correlation, and intermediate + values correspond to hybrid correlations. + sigma : float + alias for `sigma_cc` + sigma_dp : float + if >0, a gaussian smoothing filter with this standard deviation + is applied to the diffraction pattern before maxima are detected + sigma_cc : float + if >0, a gaussian smoothing filter with this standard deviation + is applied to the cross correlation before maxima are detected + subpixel : str + Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor : int + upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + minAbsoluteIntensity : float + the minimum acceptable correlation peak intensity, on an absolute scale + minRelativeIntensity : float + the minimum acceptable correlation peak intensity, relative to the + intensity of the brightest peak + relativeToPeak : int + specifies the peak against which the minimum relative intensity is + measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing : float + the minimum acceptable spacing between detected peaks + edgeBoundary (int): minimum acceptable distance for detected peaks from + the diffraction image edge, in pixels. + maxNumPeaks : int + the maximum number of peaks to return + CUDA : bool + If True, import cupy and use an NVIDIA GPU to perform disk detection + CUDA_batched : bool + If True, and CUDA is selected, the FFT and IFFT steps of disk detection + are performed in batches to better utilize GPU resources. + distributed : dict + contains information for parallel processing using an IPyParallel or + Dask distributed cluster. 
Valid keys are: + * ipyparallel (dict): + * client_file (str): path to client json for connecting to your + existing IPyParallel cluster + * dask (dict): client (object): a dask client that connects to + your existing Dask cluster + * data_file (str): the absolute path to your original data + file containing the datacube + * cluster_path (str): defaults to the working directory during + processing + if distributed is None, which is the default, processing will be in + serial + + Returns + ------- + variable + the Bragg peak positions and correlation intensities. If `data` is: + * a DataCube, returns a BraggVectors instance + * a 2D array, returns a QPoints instance + * a 3D array, returns a list of QPoints instances + * a (DataCube,rx,ry) 3-tuple, returns a list of QPoints + instances + """ + + # parse args + sigma_cc = sigma if sigma is not None else sigma_cc + + # `data` type + if isinstance(data, DataCube): + mode = "datacube" + elif isinstance(data, np.ndarray): + if data.ndim == 2: + mode = "dp" + elif data.ndim == 3: + mode = "dp_stack" + else: + er = f"if `data` is an array, must be 2- or 3-D, not {data.ndim}-D" + raise Exception(er) + else: + try: + # when a position (rx,ry) is passed, get those patterns + # and put them in a stack + dc, rx, ry = data[0], data[1], data[2] + + # h5py datasets have different rules for slicing than + # numpy arrays, so we have to do this manually + if "h5py" in str(type(dc.data)): + data = np.zeros((len(rx), dc.Q_Nx, dc.Q_Ny)) + # no background subtraction + if not radial_bksb: + for i, (x, y) in enumerate(zip(rx, ry)): + data[i] = dc.data[x, y] + # with bksubtr + else: + for i, (x, y) in enumerate(zip(rx, ry)): + data[i] = dc.get_radial_bksb_dp(rx, ry) + else: + # no background subtraction + if not radial_bksb: + data = dc.data[np.array(rx), np.array(ry), :, :] + # with bksubtr + else: + data = np.zeros((len(rx), dc.Q_Nx, dc.Q_Ny)) + for i, (x, y) in enumerate(zip(rx, ry)): + data[i] = dc.get_radial_bksb_dp(x, y) + if data.ndim == 2: + mode = "dp" + elif data.ndim == 3: + mode = "dp_stack" + except: + er = f"entry {data} for `data` could not be parsed" + raise Exception(er) + + # CPU/GPU/cluster/ML-AI + + if ML: + mode = "dc_ml" + + elif mode == "datacube": + if distributed is None and CUDA is False: + mode = "dc_CPU" + elif distributed is None and CUDA is True: + if CUDA_batched is False: + mode = "dc_GPU" + else: + mode = "dc_GPU_batched" + else: + x = _parse_distributed(distributed) + connect, data_file, cluster_path, distributed_mode = x + if distributed_mode == "dask": + mode = "dc_dask" + elif distributed_mode == "ipyparallel": + mode = "dc_ipyparallel" + else: + er = f"unrecognized distributed mode {distributed_mode}" + raise Exception(er) + # overwrite if ML selected + + # select a function + fn_dict = { + "dp": _find_Bragg_disks_single, + "dp_stack": _find_Bragg_disks_stack, + "dc_CPU": _find_Bragg_disks_CPU, + "dc_GPU": _find_Bragg_disks_CUDA_unbatched, + "dc_GPU_batched": _find_Bragg_disks_CUDA_batched, + "dc_dask": _find_Bragg_disks_dask, + "dc_ipyparallel": _find_Bragg_disks_ipp, + "dc_ml": find_Bragg_disks_aiml, + } + fn = fn_dict[mode] + + # prepare kwargs + kws = {} + # distributed kwargs + if distributed is not None: + kws["connect"] = connect + kws["data_file"] = data_file + kws["cluster_path"] = cluster_path + # ML arguments + if ML is True: + kws["CUDA"] = CUDA + kws["model_path"] = ml_model_path + kws["num_attempts"] = ml_num_attempts + kws["batch_size"] = ml_batch_size + + # if radial background subtraction is requested, add to 
args + if radial_bksb and mode == "dc_CPU": + kws["radial_bksb"] = radial_bksb + + # run and return + ans = fn( + data, + template, + filter_function=filter_function, + corrPower=corrPower, + sigma_dp=sigma_dp, + sigma_cc=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + **kws, + ) + return ans + + +# Single diffraction pattern + + +def _find_Bragg_disks_single( + DP, + template, + filter_function=None, + corrPower=1, + sigma_dp=0, + sigma_cc=2, + subpixel="poly", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0, + relativeToPeak=0, + minPeakSpacing=0, + edgeBoundary=1, + maxNumPeaks=100, + _return_cc=False, + _template_space="real", +): + # apply filter function + er = "filter_function must be callable" + if filter_function: + assert callable(filter_function), er + DP = DP if filter_function is None else filter_function(DP) + + # check for a template + if template is None: + cc = DP + else: + # fourier transform the template + assert _template_space in ("real", "fourier") + if _template_space == "real": + template_FT = np.conj(np.fft.fft2(template)) + else: + template_FT = template + + # apply any smoothing to the data + if sigma_dp > 0: + DP = gaussian_filter(DP, sigma_dp) + + # Compute cross correlation + # _returnval = 'fourier' if subpixel == 'multicorr' else 'real' + cc = get_cross_correlation_FT( + DP, + template_FT, + corrPower, + "fourier", + ) + + # Get maxima + maxima = get_maxima_2D( + np.maximum(np.real(np.fft.ifft2(cc)), 0), + subpixel=subpixel, + upsample_factor=upsample_factor, + sigma=sigma_cc, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + _ar_FT=cc, + ) + + # Wrap as QPoints instance + maxima = QPoints(maxima) + + # Return + if _return_cc is True: + return maxima, cc + return maxima + + +# def _get_cross_correlation_FT( +# DP, +# template_FT, +# corrPower = 1, +# _returnval = 'real' +# ): +# """ +# if _returnval is 'real', returns the real-valued cross-correlation. +# otherwise, returns the complex valued result. 
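#     Concretely, with m = np.fft.fft2(DP) * template_FT, the hybrid correlation
#     computed below is cc = |m|**corrPower * exp(1j*angle(m)); corrPower=1 gives
#     the ordinary cross correlation and corrPower=0 the pure phase correlation.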
+# """ +# +# m = np.fft.fft2(DP) * template_FT +# cc = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m)) +# if _returnval == 'real': +# cc = np.maximum(np.real(np.fft.ifft2(cc)),0) +# return cc + + +# 3D stack of DPs + + +def _find_Bragg_disks_stack( + dp_stack, + template, + filter_function=None, + corrPower=1, + sigma_dp=0, + sigma_cc=2, + subpixel="poly", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0, + relativeToPeak=0, + minPeakSpacing=0, + edgeBoundary=1, + maxNumPeaks=100, + _template_space="real", +): + ans = [] + + for idx in range(dp_stack.shape[0]): + dp = dp_stack[idx, :, :] + peaks = _find_Bragg_disks_single( + dp, + template, + filter_function=filter_function, + corrPower=corrPower, + sigma_dp=sigma_dp, + sigma_cc=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + _template_space=_template_space, + _return_cc=False, + ) + ans.append(peaks) + + return ans + + +# Whole datacube, CPU + + +def _find_Bragg_disks_CPU( + datacube, + probe, + filter_function=None, + corrPower=1, + sigma_dp=0, + sigma_cc=2, + subpixel="multicorr", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0.005, + relativeToPeak=0, + minPeakSpacing=60, + edgeBoundary=20, + maxNumPeaks=70, + radial_bksb=False, +): + # Make the BraggVectors instance + braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape) + + # Get the template's Fourier Transform + probe_kernel_FT = np.conj(np.fft.fft2(probe)) if probe is not None else None + + # Loop over all diffraction patterns + # Compute and populate BraggVectors data + for rx, ry in tqdmnd( + datacube.R_Nx, + datacube.R_Ny, + desc="Finding Bragg Disks", + unit="DP", + unit_scale=True, + ): + # Get a diffraction pattern + + # without background subtraction + if not radial_bksb: + dp = datacube.data[rx, ry, :, :] + # and with + else: + dp = datacube.get_radial_bksb_dp(rx, ry) + + # Compute + peaks = _find_Bragg_disks_single( + dp, + template=probe_kernel_FT, + filter_function=filter_function, + corrPower=corrPower, + sigma_dp=sigma_dp, + sigma_cc=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + _return_cc=False, + _template_space="fourier", + ) + + # Populate data + braggvectors._v_uncal[rx, ry] = peaks + + # Return + return braggvectors + + +# CUDA - unbatched + + +def _find_Bragg_disks_CUDA_unbatched( + datacube, + probe, + filter_function=None, + corrPower=1, + sigma_dp=0, + sigma_cc=2, + subpixel="multicorr", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0.005, + relativeToPeak=0, + minPeakSpacing=60, + edgeBoundary=20, + maxNumPeaks=70, +): + # compute + from py4DSTEM.braggvectors.diskdetection_cuda import find_Bragg_disks_CUDA + + peaks = find_Bragg_disks_CUDA( + datacube, + probe, + filter_function=filter_function, + corrPower=corrPower, + sigma=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + batching=False, + ) 
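    # A note on the flag above: batching=False hands the GPU one diffraction
    # pattern at a time, while the batched variant below sets batching=True so
    # patterns are transferred and correlated in chunks, which typically reduces
    # host-to-device transfer overhead.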
+ + # Populate a BraggVectors instance and return + braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape) + braggvectors._v_uncal = peaks + braggvectors._set_raw_vector_getter() + braggvectors._set_cal_vector_getter() + return braggvectors + + +# CUDA - batched + + +def _find_Bragg_disks_CUDA_batched( + datacube, + probe, + filter_function=None, + corrPower=1, + sigma_dp=0, + sigma_cc=2, + subpixel="multicorr", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0.005, + relativeToPeak=0, + minPeakSpacing=60, + edgeBoundary=20, + maxNumPeaks=70, +): + # compute + from py4DSTEM.braggvectors.diskdetection_cuda import find_Bragg_disks_CUDA + + peaks = find_Bragg_disks_CUDA( + datacube, + probe, + filter_function=filter_function, + corrPower=corrPower, + sigma=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + batching=True, + ) + + # Populate a BraggVectors instance and return + braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape) + braggvectors._v_uncal = peaks + braggvectors._set_raw_vector_getter() + braggvectors._set_cal_vector_getter() + return braggvectors + + +# Distributed - ipyparallel + + +def _find_Bragg_disks_ipp( + datacube, + probe, + connect, + data_file, + cluster_path, + filter_function=None, + corrPower=1, + sigma_dp=0, + sigma_cc=2, + subpixel="multicorr", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0.005, + relativeToPeak=0, + minPeakSpacing=60, + edgeBoundary=20, + maxNumPeaks=70, +): + # compute + from py4DSTEM.braggvectors.diskdetection_parallel import find_Bragg_disks_ipp + + peaks = find_Bragg_disks_ipp( + datacube, + probe, + filter_function=filter_function, + corrPower=corrPower, + sigma=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + ipyparallel_client_file=connect, + data_file=data_file, + cluster_path=cluster_path, + ) + + # Populate a BraggVectors instance and return + braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape) + braggvectors._v_uncal = peaks + braggvectors._set_raw_vector_getter() + braggvectors._set_cal_vector_getter() + return braggvectors + + +# Distributed - dask + + +def _find_Bragg_disks_dask( + datacube, + probe, + connect, + data_file, + cluster_path, + filter_function=None, + corrPower=1, + sigma_dp=0, + sigma_cc=2, + subpixel="multicorr", + upsample_factor=16, + minAbsoluteIntensity=0, + minRelativeIntensity=0.005, + relativeToPeak=0, + minPeakSpacing=60, + edgeBoundary=20, + maxNumPeaks=70, +): + # compute + from py4DSTEM.braggvectors.diskdetection_parallel import find_Bragg_disks_dask + + peaks = find_Bragg_disks_dask( + datacube, + probe, + filter_function=filter_function, + corrPower=corrPower, + sigma=sigma_cc, + subpixel=subpixel, + upsample_factor=upsample_factor, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + dask_client_file=connect, + data_file=data_file, + cluster_path=cluster_path, + ) + + # Populate a BraggVectors instance and return + 
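    # (same pattern as the other whole-datacube paths: the uncalibrated peaks are
    # stored on _v_uncal, and the raw/calibrated vector getters are attached)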
braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape) + braggvectors._v_uncal = peaks + braggvectors._set_raw_vector_getter() + braggvectors._set_cal_vector_getter() + return braggvectors + + +def _parse_distributed(distributed): + """ + Parse the `distributed` dict argument to determine distribution behavior + """ + import os + + # parse mode (ipyparallel or dask) + if "ipyparallel" in distributed: + mode = "ipyparallel" + if "client_file" in distributed["ipyparallel"]: + connect = distributed["ipyparallel"]["client_file"] + else: + er = 'Within distributed["ipyparallel"], ' + er += 'missing key for "client_file"' + raise KeyError(er) + + try: + import ipyparallel as ipp + + c = ipp.Client(url_file=connect, timeout=30) + + if len(c.ids) == 0: + er = "No IPyParallel engines attached to cluster!" + raise RuntimeError(er) + except ImportError: + raise ImportError("Unable to import module ipyparallel!") + + elif "dask" in distributed: + mode = "dask" + if "client" in distributed["dask"]: + connect = distributed["dask"]["client"] + else: + er = 'Within distributed["dask"], missing key for "client"' + raise KeyError(er) + + else: + er = "Within distributed, you must specify 'ipyparallel' or 'dask'!" + raise KeyError(er) + + # parse data file + if "data_file" not in distributed: + er = "Missing input data file path to distributed! " + er += "Required key 'data_file'" + raise KeyError(er) + + data_file = distributed["data_file"] + + if not isinstance(data_file, str): + er = "Expected string for distributed key 'data_file', " + er += f"received {type(data_file)}" + raise TypeError(er) + if len(data_file.strip()) == 0: + er = "Empty data file path from distributed key 'data_file'" + raise ValueError(er) + elif not os.path.exists(data_file): + raise FileNotFoundError("File not found") + + # parse cluster path + if "cluster_path" in distributed: + cluster_path = distributed["cluster_path"] + + if not isinstance(cluster_path, str): + er = "distributed key 'cluster_path' must be of type str, " + er += f"received {type(cluster_path)}" + raise TypeError(er) + + if len(cluster_path.strip()) == 0: + er = "distributed key 'cluster_path' cannot be an empty string!" 
+ raise ValueError(er) + elif not os.path.exists(cluster_path): + er = f"distributed key 'cluster_path' does not exist: {cluster_path}" + raise FileNotFoundError(er) + elif not os.path.isdir(cluster_path): + er = "distributed key 'cluster_path' is not a directory: " + er += f"{cluster_path}" + raise NotADirectoryError(er) + else: + cluster_path = None + + # return + return connect, data_file, cluster_path, mode diff --git a/py4DSTEM/datacube/diskdetection/diskdetection_aiml.py b/py4DSTEM/datacube/diskdetection/diskdetection_aiml.py new file mode 100644 index 000000000..82be36d2b --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/diskdetection_aiml.py @@ -0,0 +1,949 @@ +# Functions for finding Bragg disks using AI/ML pipeline +""" +Functions for finding Braggdisks using AI/ML method using tensorflow +""" + +import os +import glob +import json +import shutil +import numpy as np +from pathlib import Path + + +from scipy.ndimage import gaussian_filter +from time import time +from numbers import Number + +from emdfile import tqdmnd, PointList, PointListArray +from py4DSTEM.braggvectors.braggvectors import BraggVectors +from py4DSTEM.data import QPoints +from py4DSTEM.utils import get_maxima_2D + +# from py4DSTEM.braggvectors import universal_threshold + + +def find_Bragg_disks_aiml_single_DP( + DP, + probe, + num_attempts=5, + int_window_radius=1, + predict=True, + sigma=0, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + filter_function=None, + peaks=None, + model_path=None, +): + """ + Finds the Bragg disks in single DP by AI/ML method. This method utilizes FCU-Net + to predict Bragg disks from diffraction images. + + The input DP and Probes need to be aligned before the prediction. Detected peaks within + edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks + with intensities less than minRelativeIntensity of the brightest peak in the + correlation are discarded. Then peaks which are within a distance of minPeakSpacing + of their nearest neighbor peak are found, and in each such pair the peak with the + lesser correlation intensities is removed. Finally, if the number of peaks remaining + exceeds maxNumPeaks, only the maxNumPeaks peaks with the highest correlation + intensity are retained. + + Args: + DP (ndarray): a diffraction pattern + probe (ndarray): the vacuum probe template + num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5 + Ideally, the more num_attempts the better (confident) the prediction will be + as the ML prediction utilizes Monte Carlo Dropout technique to estimate model + uncertainty using Bayesian approach. Note: increasing num_attempts will increase + the compute time significantly and it is advised to use GPU (CUDA) enabled environment + for fast prediction with num_attempts > 1 + int_window_radius (int): window radius (in pixels) for disk intensity integration over the + predicted atomic potentials array + predict (bool): Flag to determine if ML prediction is opted. 
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the relativeToPeak'th peak + minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity, + on an absolute scale + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The shape + of the returned DP must match the shape of the probe kernel (but does not + need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + peaks (PointList): For internal use. If peaks is None, the PointList of peak + positions is created here. If peaks is not None, it is the PointList that + detected peaks are added to, and must have the appropriate coords + ('qx','qy','intensity'). + model_path (str): filepath for the model weights (Tensorflow model) to load from. + By default, if the model_path is not provided, py4DSTEM will search for the + latest model stored on cloud using metadata json file. It is not recommeded to + keep track of the model path and advised to keep this argument unchanged (None) + to always search for the latest updated training model weights. 
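
        A minimal, illustrative call (assumes `DP` is a 2D diffraction pattern and
        `probe` a vacuum probe template of the same shape, and that crystal4D and
        tensorflow are installed; shown only as a sketch):

            peaks = find_Bragg_disks_aiml_single_DP(
                DP,
                probe,
                num_attempts=5,
                predict=True,
                maxNumPeaks=70,
            )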
+ + Returns: + (PointList): the Bragg peak positions and correlation intensities + """ + try: + import crystal4D + except: + raise ImportError("Import Error: Please install crystal4D before proceeding") + try: + import tensorflow as tf + except: + raise ImportError( + "Please install tensorflow before proceeding - please check " + + "https://www.tensorflow.org/install" + + "for more information" + ) + + assert subpixel in [ + "none", + "poly", + "multicorr", + ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format( + subpixel + ) + + # Perform any prefiltering + if filter_function: + assert callable(filter_function), "filter_function must be callable" + DP = DP if filter_function is None else filter_function(DP) + + if predict: + assert ( + len(DP.shape) == 2 + ), "Dimension of single diffraction should be 2 (Qx, Qy)" + assert len(probe.shape) == 2, "Dimension of probe should be 2 (Qx, Qy)" + model = _get_latest_model(model_path=model_path) + DP = tf.expand_dims(tf.expand_dims(DP, axis=0), axis=-1) + probe = tf.expand_dims(tf.expand_dims(probe, axis=0), axis=-1) + prediction = np.zeros(shape=(1, DP.shape[1], DP.shape[2], 1)) + + for i in tqdmnd( + num_attempts, + desc="Neural network is predicting atomic potential", + unit="ATTEMPTS", + unit_scale=True, + ): + prediction += model.predict([DP, probe]) + print("Averaging over {} attempts \n".format(num_attempts)) + pred = prediction[0, :, :, 0] / num_attempts + else: + assert ( + len(DP.shape) == 2 + ), "Dimension of single diffraction should be 2 (Qx, Qy)" + pred = DP + + maxima = get_maxima_2D( + pred, + sigma=sigma, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + edgeBoundary=edgeBoundary, + relativeToPeak=relativeToPeak, + maxNumPeaks=maxNumPeaks, + minSpacing=minPeakSpacing, + subpixel=subpixel, + upsample_factor=upsample_factor, + ) + + # maxima_x, maxima_y, maxima_int = _integrate_disks(pred, maxima_x,maxima_y,maxima_int,int_window_radius=int_window_radius) + + # # Make peaks PointList + # if peaks is None: + # coords = [('qx',float),('qy',float),('intensity',float)] + # peaks = PointList(coordinates=coords) + # else: + # assert(isinstance(peaks,PointList)) + # peaks.add_tuple_of_nparrays((maxima_x,maxima_y,maxima_int)) + maxima = QPoints(maxima) + return maxima + + +def find_Bragg_disks_aiml_selected( + datacube, + probe, + Rx, + Ry, + num_attempts=5, + int_window_radius=1, + batch_size=1, + predict=True, + sigma=0, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + filter_function=None, + model_path=None, +): + """ + Finds the Bragg disks in the diffraction patterns of datacube at scan positions + (Rx,Ry) by AI/ML method. This method utilizes FCU-Net to predict Bragg + disks from diffraction images. + + Args: + datacube (datacube): a diffraction datacube + probe (ndarray): the vacuum probe template + num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5 + Ideally, the more num_attempts the better (confident) the prediction will be + as the ML prediction utilizes Monte Carlo Dropout technique to estimate model + uncertainty using Bayesian approach. 
Note: increasing num_attempts will increase + the compute time significantly and it is advised to use GPU (CUDA) enabled environment + for fast prediction with num_attempts > 1 + int_window_radius (int): window radius (in pixels) for disk intensity integration over the + predicted atomic potentials array + predict (bool): Flag to determine if ML prediction is opted. + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the relativeToPeak'th peak + minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity, + on an absolute scale + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The shape + of the returned DP must match the shape of the probe kernel (but does not + need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + peaks (PointList): For internal use. If peaks is None, the PointList of peak + positions is created here. If peaks is not None, it is the PointList that + detected peaks are added to, and must have the appropriate coords + ('qx','qy','intensity'). + model_path (str): filepath for the model weights (Tensorflow model) to load from. + By default, if the model_path is not provided, py4DSTEM will search for the + latest model stored on cloud using metadata json file. It is not recommended to + keep track of the model path and advised to keep this argument unchanged (None) + to always search for the latest updated training model weights. + + Returns: + (n-tuple of PointLists, n=len(Rx)): the Bragg peak positions and + correlation intensities at each scan position (Rx,Ry). 
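
        A minimal, illustrative call (assumes `datacube` is a DataCube and `probe`
        a matching vacuum probe template; the scan positions below are arbitrary):

            peaks = find_Bragg_disks_aiml_selected(
                datacube,
                probe,
                Rx=[0, 5, 10],
                Ry=[0, 5, 10],
                num_attempts=5,
            )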
+ """ + + try: + import crystal4D + except: + raise ImportError("Import Error: Please install crystal4D before proceeding") + + assert len(Rx) == len(Ry) + peaks = [] + + if predict: + model = _get_latest_model(model_path=model_path) + t0 = time() + probe = np.expand_dims( + np.repeat(np.expand_dims(probe, axis=0), len(Rx), axis=0), axis=-1 + ) + DP = np.expand_dims( + np.expand_dims(datacube.data[Rx[0], Ry[0], :, :], axis=0), axis=-1 + ) + total_DP = len(Rx) + for i in range(1, len(Rx)): + DP_ = np.expand_dims( + np.expand_dims(datacube.data[Rx[i], Ry[i], :, :], axis=0), axis=-1 + ) + DP = np.concatenate([DP, DP_], axis=0) + + prediction = np.zeros(shape=(total_DP, datacube.Q_Nx, datacube.Q_Ny, 1)) + + image_num = len(Rx) + batch_num = int(image_num // batch_size) + + for att in tqdmnd( + num_attempts, + desc="Neural network is predicting structure factors", + unit="ATTEMPTS", + unit_scale=True, + ): + for i in range(batch_num): + prediction[i * batch_size : (i + 1) * batch_size] += model.predict( + [ + DP[i * batch_size : (i + 1) * batch_size], + probe[i * batch_size : (i + 1) * batch_size], + ], + verbose=0, + ) + if (i + 1) * batch_size < image_num: + prediction[(i + 1) * batch_size :] += model.predict( + [DP[(i + 1) * batch_size :], probe[(i + 1) * batch_size :]], + verbose=0, + ) + + prediction = prediction / num_attempts + + # Loop over selected diffraction patterns + for Rx in tqdmnd( + image_num, desc="Finding Bragg Disks using AI/ML", unit="DP", unit_scale=True + ): + DP = prediction[Rx, :, :, 0] + _peaks = find_Bragg_disks_aiml_single_DP( + DP, + probe, + int_window_radius=int_window_radius, + predict=False, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + filter_function=filter_function, + model_path=model_path, + ) + peaks.append(_peaks) + t2 = time() - t0 + print( + "Analyzed {} diffraction patterns in {}h {}m {}s".format( + image_num, int(t2 / 3600), int(t2 / 60), int(t2 % 60) + ) + ) + + peaks = tuple(peaks) + return peaks + + +def find_Bragg_disks_aiml_serial( + datacube, + probe, + num_attempts=5, + int_window_radius=1, + predict=True, + batch_size=2, + sigma=0, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + global_threshold=False, + minGlobalIntensity=0.005, + metric="mean", + filter_function=None, + name="braggpeaks_raw", + model_path=None, +): + """ + Finds the Bragg disks in all diffraction patterns of datacube from AI/ML method. + When hist = True, returns histogram of intensities in the entire datacube. + + Args: + datacube (datacube): a diffraction datacube + probe (ndarray): the vacuum probe template + num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5. + Ideally, the more num_attempts the better (confident) the prediction will be + as the ML prediction utilizes Monte Carlo Dropout technique to estimate model + uncertainty using Bayesian approach. 
Note: increasing num_attempts will increase + the compute time significantly and it is advised to use GPU (CUDA) enabled environment + for fast prediction with num_attempts > 1 + int_window_radius (int): window radius (in pixels) for disk intensity integration over the + predicted atomic potentials array + predict (bool): Flag to determine if ML prediction is opted. + batch_size (int): batch size for Tensorflow model.predict() function, by default batch_size = 2, + Note: if you are using CPU for model.predict(), please use batch_size < 2. Future version + will implement Dask parrlelization implementation of the serial function to boost up the + performance of Tensorflow CPU predictions. Keep in mind that this funciton will take + significant amount of time to predict for all the DPs in a datacube. + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the relativeToPeak'th peak + minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity, + on an absolute scale + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + global_threshold (bool): if True, applies global threshold based on + minGlobalIntensity and metric + minGlobalThreshold (float): the minimum allowed peak intensity, relative to the + selected metric (0-1), except in the case of 'manual' metric, in which the + threshold value based on the minimum intensity that you want thresholder + out should be set. + metric (string): the metric used to compare intensities. 'average' compares peak + intensity relative to the average of the maximum intensity in each + diffraction pattern. 'max' compares peak intensity relative to the maximum + intensity value out of all the diffraction patterns. 'median' compares peak + intensity relative to the median of the maximum intensity peaks in each + diffraction pattern. 'manual' Allows the user to threshold based on a + predetermined intensity value manually determined. In this case, + minIntensity should be an int. + name (str): name for the returned PointListArray + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The + shape of the returned DP must match the shape of the probe kernel (but does + not need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + model_path (str): filepath for the model weights (Tensorflow model) to load from. + By default, if the model_path is not provided, py4DSTEM will search for the + latest model stored on cloud using metadata json file. 
It is not recommended to + keep track of the model path and advised to keep this argument unchanged (None) + to always search for the latest updated training model weights. + + Returns: + (PointListArray): the Bragg peak positions and correlation intensities + """ + + try: + import crystal4D + except: + raise ImportError("Import Error: Please install crystal4D before proceeding") + + # Make the peaks PointListArray + dtype = [("qx", float), ("qy", float), ("intensity", float)] + # peaks = BraggVectors(datacube.Rshape, datacube.Qshape) + peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny)) + # check that the filtered DP is the right size for the probe kernel: + if filter_function: + assert callable(filter_function), "filter_function must be callable" + DP = ( + datacube.data[0, 0, :, :] + if filter_function is None + else filter_function(datacube.data[0, 0, :, :]) + ) + # assert np.all(DP.shape == probe.shape), 'Probe kernel shape must match filtered DP shape' + + if predict: + t0 = time() + model = _get_latest_model(model_path=model_path) + probe = np.expand_dims( + np.repeat(np.expand_dims(probe, axis=0), datacube.R_N, axis=0), axis=-1 + ) + DP = np.expand_dims( + np.reshape(datacube.data, (datacube.R_N, datacube.Q_Nx, datacube.Q_Ny)), + axis=-1, + ) + + prediction = np.zeros(shape=(datacube.R_N, datacube.Q_Nx, datacube.Q_Ny, 1)) + + image_num = datacube.R_N + batch_num = int(image_num // batch_size) + + for att in tqdmnd( + num_attempts, + desc="Neural network is predicting structure factors", + unit="ATTEMPTS", + unit_scale=True, + ): + for i in range(batch_num): + prediction[i * batch_size : (i + 1) * batch_size] += model.predict( + [ + DP[i * batch_size : (i + 1) * batch_size], + probe[i * batch_size : (i + 1) * batch_size], + ], + verbose=0, + ) + if (i + 1) * batch_size < image_num: + prediction[(i + 1) * batch_size :] += model.predict( + [DP[(i + 1) * batch_size :], probe[(i + 1) * batch_size :]], + verbose=0, + ) + + prediction = prediction / num_attempts + + prediction = np.reshape( + np.transpose(prediction, (0, 3, 1, 2)), + (datacube.R_Nx, datacube.R_Ny, datacube.Q_Nx, datacube.Q_Ny), + ) + + # Loop over all diffraction patterns + for Rx, Ry in tqdmnd( + datacube.R_Nx, + datacube.R_Ny, + desc="Finding Bragg Disks using AI/ML", + unit="DP", + unit_scale=True, + ): + DP_ = prediction[Rx, Ry, :, :] + find_Bragg_disks_aiml_single_DP( + DP_, + probe, + num_attempts=num_attempts, + int_window_radius=int_window_radius, + predict=False, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + filter_function=filter_function, + peaks=peaks.get_pointlist(Rx, Ry), + model_path=model_path, + ) + t2 = time() - t0 + print( + "Analyzed {} diffraction patterns in {}h {}m {}s".format( + datacube.R_N, int(t2 / 3600), int(t2 / 60), int(t2 % 60) + ) + ) + + if global_threshold is True: + from py4DSTEM.braggvectors import universal_threshold + + peaks = universal_threshold( + peaks, minGlobalIntensity, metric, minPeakSpacing, maxNumPeaks + ) + peaks.name = name + return peaks + + +def find_Bragg_disks_aiml( + datacube, + probe, + num_attempts=5, + int_window_radius=1, + predict=True, + batch_size=8, + sigma=0, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + 
subpixel="multicorr", + upsample_factor=16, + name="braggpeaks_raw", + filter_function=None, + model_path=None, + distributed=None, + CUDA=True, + **kwargs, +): + """ + Finds the Bragg disks in all diffraction patterns of datacube by AI/ML method. This method + utilizes FCU-Net to predict Bragg disks from diffraction images. + + datacube (datacube): a diffraction datacube + probe (ndarray): the vacuum probe template + num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5. + Ideally, the more num_attempts the better (confident) the prediction will be + as the ML prediction utilizes Monte Carlo Dropout technique to estimate model + uncertainty using Bayesian approach. Note: increasing num_attempts will increase + the compute time significantly and it is advised to use GPU (CUDA) enabled environment + for fast prediction with num_attempts > 1 + int_window_radius (int): window radius (in pixels) for disk intensity integration over the + predicted atomic potentials array + predict (bool): Flag to determine if ML prediction is opted. + batch_size (int): batch size for Tensorflow model.predict() function, by default batch_size = 2, + Note: if you are using CPU for model.predict(), please use batch_size < 2. Future version + will implement Dask parrlelization implementation of the serial function to boost up the + performance of Tensorflow CPU predictions. Keep in mind that this funciton will take + significant amount of time to predict for all the DPs in a datacube. + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the relativeToPeak'th peak + minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity, + on an absolute scale + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + global_threshold (bool): if True, applies global threshold based on + minGlobalIntensity and metric + minGlobalThreshold (float): the minimum allowed peak intensity, relative to the + selected metric (0-1), except in the case of 'manual' metric, in which the + threshold value based on the minimum intensity that you want thresholder + out should be set. + metric (string): the metric used to compare intensities. 'average' compares peak + intensity relative to the average of the maximum intensity in each + diffraction pattern. 'max' compares peak intensity relative to the maximum + intensity value out of all the diffraction patterns. 'median' compares peak + intensity relative to the median of the maximum intensity peaks in each + diffraction pattern. 'manual' Allows the user to threshold based on a + predetermined intensity value manually determined. In this case, + minIntensity should be an int. 
+ name (str): name for the returned PointListArray + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The + shape of the returned DP must match the shape of the probe kernel (but does + not need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + model_path (str): filepath for the model weights (Tensorflow model) to load from. + By default, if the model_path is not provided, py4DSTEM will search for the + latest model stored on cloud using metadata json file. It is not recommended to + keep track of the model path and advised to keep this argument unchanged (None) + to always search for the latest updated training model weights. + distributed (dict): contains information for parallelprocessing using an + IPyParallel or Dask distributed cluster. Valid keys are: + * ipyparallel (dict): + * client_file (str): path to client json for connecting to your + existing IPyParallel cluster + * dask (dict): + client (object): a dask client that connects to your + existing Dask cluster + * data_file (str): the absolute path to your original data + file containing the datacube + * cluster_path (str): defaults to the working directory during processing + if distributed is None, which is the default, processing will be in serial + CUDA (bool): When True, py4DSTEM will use CUDA-enabled disk_detection_aiml function + + Returns: + (PointListArray): the Bragg peak positions and correlation intensities + """ + try: + import crystal4D + except: + raise ImportError("Please install crystal4D before proceeding") + + def _parse_distributed(distributed): + import os + + if "ipyparallel" in distributed: + if "client_file" in distributed["ipyparallel"]: + connect = distributed["ipyparallel"]["client_file"] + else: + raise KeyError( + 'Within distributed["ipyparallel"], missing key for "client_file"' + ) + + try: + import ipyparallel as ipp + + c = ipp.Client(url_file=connect, timeout=30) + + if len(c.ids) == 0: + raise RuntimeError("No IPyParallel engines attached to cluster!") + except ImportError: + raise ImportError("Unable to import module ipyparallel!") + elif "dask" in distributed: + if "client" in distributed["dask"]: + connect = distributed["dask"]["client"] + else: + raise KeyError('Within distributed["dask"], missing key for "client"') + else: + raise KeyError( + "Within distributed, you must specify 'ipyparallel' or 'dask'!" + ) + + if "data_file" not in distributed: + raise KeyError( + "Missing input data file path to distributed! 
Required key 'data_file'" + ) + + data_file = distributed["data_file"] + + if not isinstance(data_file, str): + raise TypeError( + "Expected string for distributed key 'data_file', received {}".format( + type(data_file) + ) + ) + if len(data_file.strip()) == 0: + raise ValueError("Empty data file path from distributed key 'data_file'") + elif not os.path.exists(data_file): + raise FileNotFoundError("File not found") + + if "cluster_path" in distributed: + cluster_path = distributed["cluster_path"] + + if not isinstance(cluster_path, str): + raise TypeError( + "distributed key 'cluster_path' must be of type str, received {}".format( + type(cluster_path) + ) + ) + + if len(cluster_path.strip()) == 0: + raise ValueError( + "distributed key 'cluster_path' cannot be an empty string!" + ) + elif not os.path.exists(cluster_path): + raise FileNotFoundError( + "distributed key 'cluster_path' does not exist: {}".format( + cluster_path + ) + ) + elif not os.path.isdir(cluster_path): + raise NotADirectoryError( + "distributed key 'cluster_path' is not a directory: {}".format( + cluster_path + ) + ) + else: + cluster_path = None + + return connect, data_file, cluster_path + + if distributed is None: + import warnings + + if not CUDA: + if _check_cuda_device_available(): + warnings.warn( + "WARNING: CUDA = False is selected but py4DSTEM found available CUDA device to speed up. Going ahead anyway with non-CUDA mode (CPU only). You may want to abort and switch to CUDA = True to speed things up... \n" + ) + if num_attempts > 1: + warnings.warn( + "WARNING: num_attempts > 1 will take significant amount of time with Non-CUDA mode ..." + ) + return find_Bragg_disks_aiml_serial( + datacube, + probe, + num_attempts=num_attempts, + int_window_radius=int_window_radius, + predict=predict, + batch_size=batch_size, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + model_path=model_path, + name=name, + filter_function=filter_function, + ) + elif _check_cuda_device_available(): + from py4DSTEM.braggvectors.diskdetection_aiml_cuda import ( + find_Bragg_disks_aiml_CUDA, + ) + + return find_Bragg_disks_aiml_CUDA( + datacube, + probe, + num_attempts=num_attempts, + int_window_radius=int_window_radius, + predict=predict, + batch_size=batch_size, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + model_path=model_path, + name=name, + filter_function=filter_function, + ) + else: + import warnings + + warnings.warn( + "WARNING: py4DSTEM attempted to speed up the process using GPUs but no CUDA enabled devices are found. Switching back to Non-CUDA (CPU only) mode (Note it will take significant amount of time to get AIML predictions for disk detection using CPUs!!!!) \n" + ) + if num_attempts > 1: + warnings.warn( + "WARNING: num_attempts > 1 will take significant amount of time with Non-CUDA mode ..." 
+ ) + return find_Bragg_disks_aiml_serial( + datacube, + probe, + num_attempts=num_attempts, + int_window_radius=int_window_radius, + predict=predict, + batch_size=batch_size, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + model_path=model_path, + name=name, + filter_function=filter_function, + ) + + elif isinstance(distributed, dict): + raise Exception( + "{} is not yet implemented for aiml pipeline".format(type(distributed)) + ) + else: + raise Exception( + "Expected type dict or None for distributed, instead found : {}".format( + type(distributed) + ) + ) + + +def _integrate_disks(DP, maxima_x, maxima_y, maxima_int, int_window_radius=1): + """ + Integrate DP over the circular patch of pixel with radius + """ + disks = [] + img_size = DP.shape[0] + for x, y, i in zip(maxima_x, maxima_y, maxima_int): + r1, r2 = np.ogrid[-x : img_size - x, -y : img_size - y] + mask = r1**2 + r2**2 <= int_window_radius**2 + mask_arr = np.zeros((img_size, img_size)) + mask_arr[mask] = 1 + disk = DP * mask_arr + disks.append(np.average(disk)) + try: + disks = disks / max(disks) + except: + pass + return (maxima_x, maxima_y, disks) + + +def _check_cuda_device_available(): + """ + Check if GPU is available to use by python/tensorflow. + """ + + import tensorflow as tf + + tf_recog_gpus = tf.config.experimental.list_physical_devices("GPU") + + if len(tf_recog_gpus) > 0: + return True + else: + return False + + +def _get_latest_model(model_path=None): + """ + get the latest tensorflow model and model weights for disk detection + + Args: + model_path (filepath string): File path for the tensorflow models stored in local system, + if provided, disk detection will be performed loading the model provided by user. + By default, there is no need to provide any file path unless specifically required for + development/debug purpose. If None, _get_latest_model() will look up the latest model + from cloud and download and load them. + + Returns: + model: Trained tensorflow model for disk detection + """ + import crystal4D + + try: + import tensorflow as tf + except: + raise ImportError( + "Please install tensorflow before proceeding - please check " + + "https://www.tensorflow.org/install" + + "for more information" + ) + from py4DSTEM.io.google_drive_downloader import gdrive_download + + tf.keras.backend.clear_session() + + if model_path is None: + try: + os.mkdir("./tmp") + except: + pass + # download the json file with the meta data + gdrive_download( + "FCU-Net", + destination="./tmp/", + filename="model_metadata.json", + overwrite=True, + ) + with open("./tmp/model_metadata.json") as f: + metadata = json.load(f) + file_id = metadata["file_id"] + file_path = metadata["file_path"] + file_type = metadata["file_type"] + + try: + with open("./tmp/model_metadata_old.json") as f_old: + metaold = json.load(f_old) + file_id_old = metaold["file_id"] + except: + file_id_old = file_id + + if os.path.exists(file_path) and file_id == file_id_old: + print( + "Latest model weight is already available in the local system. Loading the model... \n" + ) + model_path = file_path + os.remove("./tmp/model_metadata_old.json") + os.rename("./tmp/model_metadata.json", "./tmp/model_metadata_old.json") + else: + print("Checking the latest model on the cloud... 
\n") + filename = file_path + file_type + filename = Path(filename) + gdrive_download(file_id, destination="./tmp", filename=filename.name) + try: + shutil.unpack_archive(filename, "./tmp", format="zip") + except: + pass + model_path = file_path + os.rename("./tmp/model_metadata.json", "./tmp/model_metadata_old.json") + print("Loading the model... \n") + + model = tf.keras.models.load_model( + model_path, + custom_objects={"lrScheduler": crystal4D.utils.utils.lrScheduler(128)}, + ) + else: + print("Loading the user provided model... \n") + model = tf.keras.models.load_model( + model_path, + custom_objects={"lrScheduler": crystal4D.utils.utils.lrScheduler(128)}, + ) + + return model diff --git a/py4DSTEM/datacube/diskdetection/diskdetection_aiml_cuda.py b/py4DSTEM/datacube/diskdetection/diskdetection_aiml_cuda.py new file mode 100644 index 000000000..bd2736719 --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/diskdetection_aiml_cuda.py @@ -0,0 +1,738 @@ +# Functions for finding Bragg disks using AI/ML pipeline (CUDA version) +""" +Functions for finding Braggdisks (AI/ML) using cupy and tensorflow-gpu +""" + +import numpy as np +from time import time + +from emdfile import tqdmnd +from py4DSTEM.braggvectors.braggvectors import BraggVectors +from emdfile import PointList, PointListArray +from py4DSTEM.data import QPoints +from py4DSTEM.braggvectors.kernels import kernels +from py4DSTEM.braggvectors.diskdetection_aiml import _get_latest_model + +# from py4DSTEM.braggvectors.diskdetection import universal_threshold + +try: + import cupy as cp + from cupyx.scipy.ndimage import gaussian_filter +except (ModuleNotFoundError, ImportError) as e: + raise ImportError("AIML CUDA Requires cupy") from e + +try: + import tensorflow as tf +except: + raise ImportError( + "Please install tensorflow before proceeding - please check " + + "https://www.tensorflow.org/install" + + "for more information" + ) + + +def find_Bragg_disks_aiml_CUDA( + datacube, + probe, + num_attempts=5, + int_window_radius=1, + predict=True, + batch_size=8, + sigma=0, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + global_threshold=False, + minGlobalIntensity=0.005, + metric="mean", + filter_function=None, + name="braggpeaks_raw", + model_path=None, +): + """ + Finds the Bragg disks in all diffraction patterns of datacube by AI/ML method (CUDA version) + This method utilizes FCU-Net to predict Bragg disks from diffraction images. + + Args: + datacube (datacube): a diffraction datacube + probe (ndarray): the vacuum probe template + num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5. + Ideally, the more num_attempts the better (confident) the prediction will be + as the ML prediction utilizes Monte Carlo Dropout technique to estimate model + uncertainty using Bayesian approach. Note: increasing num_attempts will increase + the compute time significantly and it is advised to use GPU (CUDA) enabled environment + for fast prediction with num_attempts > 1 + int_window_radius (int): window radius (in pixels) for disk intensity integration over the + predicted atomic potentials array + predict (bool): Flag to determine if ML prediction is opted. + batch_size (int): batch size for Tensorflow model.predict() function, by default batch_size = 2, + Note: if you are using CPU for model.predict(), please use batch_size < 2. 
Future version + will implement Dask parrlelization implementation of the serial function to boost up the + performance of Tensorflow CPU predictions. Keep in mind that this funciton will take + significant amount of time to predict for all the DPs in a datacube. + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the relativeToPeak'th peak + minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity, + on an absolute scale + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + global_threshold (bool): if True, applies global threshold based on + minGlobalIntensity and metric + minGlobalThreshold (float): the minimum allowed peak intensity, relative to the + selected metric (0-1), except in the case of 'manual' metric, in which the + threshold value based on the minimum intensity that you want thresholder + out should be set. + metric (string): the metric used to compare intensities. 'average' compares peak + intensity relative to the average of the maximum intensity in each + diffraction pattern. 'max' compares peak intensity relative to the maximum + intensity value out of all the diffraction patterns. 'median' compares peak + intensity relative to the median of the maximum intensity peaks in each + diffraction pattern. 'manual' Allows the user to threshold based on a + predetermined intensity value manually determined. In this case, + minIntensity should be an int. + name (str): name for the returned PointListArray + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The + shape of the returned DP must match the shape of the probe kernel (but does + not need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + model_path (str): filepath for the model weights (Tensorflow model) to load from. + By default, if the model_path is not provided, py4DSTEM will search for the + latest model stored on cloud using metadata json file. It is not recommended to + keep track of the model path and advised to keep this argument unchanged (None) + to always search for the latest updated training model weights. 
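
        A minimal, illustrative call (assumes a CUDA-capable GPU with cupy and
        tensorflow installed, a DataCube `datacube`, and a matching vacuum probe
        template `probe`; shown only as a sketch):

            peaks = find_Bragg_disks_aiml_CUDA(
                datacube,
                probe,
                num_attempts=5,
                batch_size=8,
            )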
+ + Returns: + (PointListArray): the Bragg peak positions and correlation intensities + """ + + # Make the peaks PointListArray + dtype = [("qx", float), ("qy", float), ("intensity", float)] + # peaks = BraggVectors(datacube.Rshape, datacube.Qshape) + peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny)) + + # check that the filtered DP is the right size for the probe kernel: + if filter_function: + assert callable(filter_function), "filter_function must be callable" + DP = ( + datacube.data[0, 0, :, :] + if filter_function is None + else filter_function(datacube.data[0, 0, :, :]) + ) + assert np.all( + DP.shape == probe.shape + ), "Probe kernel shape must match filtered DP shape" + + get_maximal_points = kernels["maximal_pts_float64"] + + if get_maximal_points.max_threads_per_block < DP.shape[1]: + blocks = ((np.prod(DP.shape) // get_maximal_points.max_threads_per_block + 1),) + threads = get_maximal_points.max_threads_per_block + else: + blocks = (DP.shape[0],) + threads = (DP.shape[1],) + + if predict: + t0 = time() + model = _get_latest_model(model_path=model_path) + prediction = np.zeros(shape=(datacube.R_N, datacube.Q_Nx, datacube.Q_Ny, 1)) + + image_num = datacube.R_N + batch_num = int(image_num // batch_size) + + datacube_flattened = datacube.data.view() + datacube_flattened = datacube_flattened.reshape( + datacube.R_N, datacube.Q_Nx, datacube.Q_Ny + ) + + for att in tqdmnd( + num_attempts, + desc="Neural network is predicting structure factors", + unit="ATTEMPTS", + unit_scale=True, + ): + for batch_idx in range(batch_num): + # the final batch may be smaller than the other ones: + probes_remaining = datacube.R_N - (batch_idx * batch_size) + this_batch_size = ( + probes_remaining if probes_remaining < batch_size else batch_size + ) + DP = tf.expand_dims( + datacube_flattened[ + batch_idx * batch_size : batch_idx * batch_size + + this_batch_size + ], + axis=-1, + ) + _probe = tf.expand_dims( + tf.repeat(tf.expand_dims(probe, axis=0), this_batch_size, axis=0), + axis=-1, + ) + prediction[ + batch_idx * batch_size : batch_idx * batch_size + this_batch_size + ] += model.predict([DP, _probe]) + + print("Averaging over {} attempts \n".format(num_attempts)) + prediction = prediction / num_attempts + + prediction = np.reshape( + np.transpose(prediction, (0, 3, 1, 2)), + (datacube.R_Nx, datacube.R_Ny, datacube.Q_Nx, datacube.Q_Ny), + ) + + # Loop over all diffraction patterns + for Rx, Ry in tqdmnd( + datacube.R_Nx, + datacube.R_Ny, + desc="Finding Bragg Disks using AI/ML CUDA", + unit="DP", + unit_scale=True, + ): + DP = prediction[Rx, Ry, :, :] + _find_Bragg_disks_aiml_single_DP_CUDA( + DP, + probe, + num_attempts=num_attempts, + int_window_radius=int_window_radius, + predict=False, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + filter_function=filter_function, + peaks=peaks.get_pointlist(Rx, Ry), + get_maximal_points=get_maximal_points, + blocks=blocks, + threads=threads, + ) + t2 = time() - t0 + print( + "Analyzed {} diffraction patterns in {}h {}m {}s".format( + datacube.R_N, int(t2 / 3600), int(t2 / 60), int(t2 % 60) + ) + ) + if global_threshold is True: + from py4DSTEM.braggvectors import universal_threshold + + peaks = universal_threshold( + peaks, minGlobalIntensity, metric, minPeakSpacing, maxNumPeaks + ) + peaks.name = name + 
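    # the peaks are returned as an uncalibrated PointListArray; unlike the
    # template-matching routines in braggfinder.py, no BraggVectors instance is
    # constructed here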
return peaks + + +def _find_Bragg_disks_aiml_single_DP_CUDA( + DP, + probe, + num_attempts=5, + int_window_radius=1, + predict=True, + sigma=0, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + filter_function=None, + return_cc=False, + peaks=None, + get_maximal_points=None, + blocks=None, + threads=None, + model_path=None, + **kwargs, +): + """ + Finds the Bragg disks in single DP by AI/ML method. This method utilizes FCU-Net + to predict Bragg disks from diffraction images. + + The input DP and Probes need to be aligned before the prediction. Detected peaks within + edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks + with intensities less than minRelativeIntensity of the brightest peak in the + correlation are discarded. Then peaks which are within a distance of minPeakSpacing + of their nearest neighbor peak are found, and in each such pair the peak with the + lesser correlation intensities is removed. Finally, if the number of peaks remaining + exceeds maxNumPeaks, only the maxNumPeaks peaks with the highest correlation + intensity are retained. + + Args: + DP (ndarray): a diffraction pattern + probe (ndarray): the vacuum probe template + num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5 + Ideally, the more num_attempts the better (confident) the prediction will be + as the ML prediction utilizes Monte Carlo Dropout technique to estimate model + uncertainty using Bayesian approach. Note: increasing num_attempts will increase + the compute time significantly and it is advised to use GPU (CUDA) enabled environment + for fast prediction with num_attempts > 1 + int_window_radius (int): window radius (in pixels) for disk intensity integration over the + predicted atomic potentials array + predict (bool): Flag to determine if ML prediction is opted. + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the relativeToPeak'th peak + minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity, + on an absolute scale + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The shape + of the returned DP must match the shape of the probe kernel (but does not + need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + peaks (PointList): For internal use. 
If peaks is None, the PointList of peak + positions is created here. If peaks is not None, it is the PointList that + detected peaks are added to, and must have the appropriate coords + ('qx','qy','intensity'). + model_path (str): filepath for the model weights (Tensorflow model) to load from. + By default, if the model_path is not provided, py4DSTEM will search for the + latest model stored on cloud using metadata json file. It is not recommeded to + keep track of the model path and advised to keep this argument unchanged (None) + to always search for the latest updated training model weights. + + Returns: + peaks (PointList) the Bragg peak positions and correlation intensities + """ + assert subpixel in [ + "none", + "poly", + "multicorr", + ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format( + subpixel + ) + + if predict: + assert ( + len(DP.shape) == 2 + ), "Dimension of single diffraction should be 2 (Qx, Qy)" + assert len(probe.shape) == 2, "Dimension of Probe should be 2 (Qx, Qy)" + + model = _get_latest_model(model_path=model_path) + DP = tf.expand_dims(tf.expand_dims(DP, axis=0), axis=-1) + probe = tf.expand_dims(tf.expand_dims(probe, axis=0), axis=-1) + prediction = np.zeros(shape=(1, DP.shape[1], DP.shape[2], 1)) + + for att in tqdmnd( + num_attempts, + desc="Neural network is predicting structure factors", + unit="ATTEMPTS", + unit_scale=True, + ): + print("attempt {} \n".format(att + 1)) + prediction += model.predict([DP, probe]) + print("Averaging over {} attempts \n".format(num_attempts)) + pred = cp.array(prediction[0, :, :, 0] / num_attempts, dtype="float64") + else: + assert ( + len(DP.shape) == 2 + ), "Dimension of single diffraction should be 2 (Qx, Qy)" + pred = cp.array( + DP if filter_function is None else filter_function(DP), dtype="float64" + ) + + # Find the maxima + maxima_x, maxima_y, maxima_int = get_maxima_2D_cp( + pred, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + get_maximal_points=get_maximal_points, + blocks=blocks, + threads=threads, + ) + + maxima_x, maxima_y, maxima_int = _integrate_disks_cp( + pred, maxima_x, maxima_y, maxima_int, int_window_radius=int_window_radius + ) + + # Make peaks PointList + if peaks is None: + coords = [("qx", float), ("qy", float), ("intensity", float)] + peaks = PointList(coordinates=coords) + else: + assert isinstance(peaks, PointList) + peaks.add_data_by_field((maxima_x, maxima_y, maxima_int)) + + return peaks + + +def get_maxima_2D_cp( + ar, + sigma=0, + edgeBoundary=0, + minSpacing=0, + minRelativeIntensity=0, + minAbsoluteIntensity=0, + relativeToPeak=0, + maxNumPeaks=0, + subpixel="poly", + ar_FT=None, + upsample_factor=16, + get_maximal_points=None, + blocks=None, + threads=None, +): + """ + Finds the indices where the 2D array ar is a local maximum. + Optional parameters allow blurring of the array and filtering of the output; + setting each of these to 0 (default) turns off these functions. 
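+    This is the cupy/CUDA variant: `ar` is expected to be a cupy array, and the
+    local-maximum search runs in the `get_maximal_points` CUDA kernel supplied by
+    the caller.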
+ + Accepts: + ar (ndarray) a 2D array + sigma (float) guassian blur std to apply to ar before finding the maxima + edgeBoundary (int) ignore maxima within edgeBoundary of the array edge + minSpacing (float) if two maxima are found within minSpacing, the dimmer one + is removed + minRelativeIntensity (float) maxima dimmer than minRelativeIntensity compared to the + relativeToPeak'th brightest maximum are removed + minAbsoluteIntensity (float) the minimum acceptable correlation peak intensity, + on an absolute scale + relativeToPeak (int) 0=brightest maximum. 1=next brightest, etc. + maxNumPeaks (int) return only the first maxNumPeaks maxima + subpixel (str) 'none': no subpixel fitting + (default) 'poly': polynomial interpolation of correlogram peaks + (fairly fast but not very accurate) + 'multicorr': uses the multicorr algorithm with + DFT upsampling + ar_FT (None or complex array) if subpixel=='multicorr' the + fourier transform of the image is required. It may be + passed here as a complex array. Otherwise, if ar_FT is None, + it is computed + upsample_factor (int) required iff subpixel=='multicorr' + + Returns + maxima_x (ndarray) x-coords of the local maximum, sorted by intensity. + maxima_y (ndarray) y-coords of the local maximum, sorted by intensity. + maxima_intensity (ndarray) intensity of the local maxima + """ + assert subpixel in [ + "none", + "poly", + "multicorr", + ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format( + subpixel + ) + + # Get maxima + ar = gaussian_filter(ar, sigma) + maxima_bool = cp.zeros_like(ar, dtype=bool) + sizex = ar.shape[0] + sizey = ar.shape[1] + N = sizex * sizey + get_maximal_points( + blocks, threads, (ar, maxima_bool, minAbsoluteIntensity, sizex, sizey, N) + ) + # get_maximal_points(blocks,threads,(ar,maxima_bool,sizex,sizey,N)) + + # Remove edges + if edgeBoundary > 0: + assert isinstance(edgeBoundary, (int, np.integer)) + maxima_bool[:edgeBoundary, :] = False + maxima_bool[-edgeBoundary:, :] = False + maxima_bool[:, :edgeBoundary] = False + maxima_bool[:, -edgeBoundary:] = False + elif subpixel is True: + maxima_bool[:1, :] = False + maxima_bool[-1:, :] = False + maxima_bool[:, :1] = False + maxima_bool[:, -1:] = False + + # Get indices, sorted by intensity + maxima_x, maxima_y = cp.nonzero(maxima_bool) + maxima_x = maxima_x.get() + maxima_y = maxima_y.get() + dtype = np.dtype([("x", float), ("y", float), ("intensity", float)]) + maxima = np.zeros(len(maxima_x), dtype=dtype) + maxima["x"] = maxima_x + maxima["y"] = maxima_y + + ar = ar.get() + maxima["intensity"] = ar[maxima_x, maxima_y] + maxima = np.sort(maxima, order="intensity")[::-1] + + if len(maxima) > 0: + # Remove maxima which are too close + if minSpacing > 0: + deletemask = np.zeros(len(maxima), dtype=bool) + for i in range(len(maxima)): + if deletemask[i] == False: # noqa: E712 + tooClose = ( + (maxima["x"] - maxima["x"][i]) ** 2 + + (maxima["y"] - maxima["y"][i]) ** 2 + ) < minSpacing**2 + tooClose[: i + 1] = False + deletemask[tooClose] = True + maxima = np.delete(maxima, np.nonzero(deletemask)[0]) + + # Remove maxima which are too dim + if (minRelativeIntensity > 0) & (len(maxima) > relativeToPeak): + assert isinstance(relativeToPeak, (int, np.integer)) + deletemask = ( + maxima["intensity"] / maxima["intensity"][relativeToPeak] + < minRelativeIntensity + ) + maxima = np.delete(maxima, np.nonzero(deletemask)[0]) + + # Remove maxima which are too dim, absolute scale + if minAbsoluteIntensity > 0: + deletemask = maxima["intensity"] < 
minAbsoluteIntensity + maxima = np.delete(maxima, np.nonzero(deletemask)[0]) + + # Remove maxima in excess of maxNumPeaks + if maxNumPeaks > 0: + assert isinstance(maxNumPeaks, (int, np.integer)) + if len(maxima) > maxNumPeaks: + maxima = maxima[:maxNumPeaks] + + # Subpixel fitting + # For all subpixel fitting, first fit 1D parabolas in x and y to 3 points (maximum, +/- 1 pixel) + if subpixel != "none": + for i in range(len(maxima)): + Ix1_ = ar[int(maxima["x"][i]) - 1, int(maxima["y"][i])] + Ix0 = ar[int(maxima["x"][i]), int(maxima["y"][i])] + Ix1 = ar[int(maxima["x"][i]) + 1, int(maxima["y"][i])] + Iy1_ = ar[int(maxima["x"][i]), int(maxima["y"][i]) - 1] + Iy0 = ar[int(maxima["x"][i]), int(maxima["y"][i])] + Iy1 = ar[int(maxima["x"][i]), int(maxima["y"][i]) + 1] + deltax = (Ix1 - Ix1_) / (4 * Ix0 - 2 * Ix1 - 2 * Ix1_) + deltay = (Iy1 - Iy1_) / (4 * Iy0 - 2 * Iy1 - 2 * Iy1_) + maxima["x"][i] += deltax + maxima["y"][i] += deltay + maxima["intensity"][i] = linear_interpolation_2D_cp( + ar, maxima["x"][i], maxima["y"][i] + ) + # Further refinement with fourier upsampling + if subpixel == "multicorr": + if ar_FT is None: + ar_FT = cp.conj(cp.fft.fft2(cp.array(ar))) + else: + ar_FT = cp.conj(ar_FT) + for ipeak in range(len(maxima["x"])): + xyShift = np.array((maxima["x"][ipeak], maxima["y"][ipeak])) + # we actually have to lose some precision and go down to half-pixel + # accuracy. this could also be done by a single upsampling at factor 2 + # instead of get_maxima_2D_cp. + xyShift[0] = np.round(xyShift[0] * 2) / 2 + xyShift[1] = np.round(xyShift[1] * 2) / 2 + + subShift = upsampled_correlation_cp(ar_FT, upsample_factor, xyShift) + maxima["x"][ipeak] = subShift[0] + maxima["y"][ipeak] = subShift[1] + + return maxima["x"], maxima["y"], maxima["intensity"] + + +def upsampled_correlation_cp(imageCorr, upsampleFactor, xyShift): + """ + Refine the correlation peak of imageCorr around xyShift by DFT upsampling using cupy. + + Args: + imageCorr (complex valued ndarray): + Complex product of the FFTs of the two images to be registered + i.e. m = np.fft.fft2(DP) * probe_kernel_FT; + imageCorr = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m)) + upsampleFactor (int): + Upsampling factor. Must be greater than 2. (To do upsampling + with factor 2, use upsampleFFT, which is faster.) + xyShift: + Location in original image coordinates around which to upsample the + FT. This should be given to exactly half-pixel precision to + replicate the initial FFT step that this implementation skips + + Returns: + (2-element np array): Refined location of the peak in image coordinates. + """ + + # ------------------------------------------------------------------------------------- + # There are two approaches to Fourier upsampling for subpixel refinement: (a) one + # can pad an (appropriately shifted) FFT with zeros and take the inverse transform, + # or (b) one can compute the DFT by matrix multiplication using modified + # transformation matrices. The former approach is straightforward but requires + # performing the FFT algorithm (which is fast) on very large data. The latter method + # trades one speedup for a slowdown elsewhere: the matrix multiply steps are expensive + # but we operate on smaller matrices. Since we are only interested in a very small + # region of the FT around a peak of interest, we use the latter method to get + # a substantial speedup and enormous decrease in memory requirement. 
This + # "DFT upsampling" approach computes the transformation matrices for the matrix- + # multiply DFT around a small 1.5px wide region in the original `imageCorr`. + + # Following the matrix multiply DFT we use parabolic subpixel fitting to + # get even more precision! (below 1/upsampleFactor pixels) + + # NOTE: previous versions of multiCorr operated in two steps: using the zero- + # padding upsample method for a first-pass factor-2 upsampling, followed by the + # DFT upsampling (at whatever user-specified factor). I have implemented it + # differently, to better support iterating over multiple peaks. **The DFT is always + # upsampled around xyShift, which MUST be specified to HALF-PIXEL precision + # (no more, no less) to replicate the behavior of the factor-2 step.** + # (It is possible to refactor this so that peak detection is done on a Fourier + # upsampled image rather than using the parabolic subpixel and rounding as now... + # I like keeping it this way because all of the parameters and logic will be identical + # to the other subpixel methods.) + # ------------------------------------------------------------------------------------- + + assert upsampleFactor > 2 + + xyShift[0] = np.round(xyShift[0] * upsampleFactor) / upsampleFactor + xyShift[1] = np.round(xyShift[1] * upsampleFactor) / upsampleFactor + + globalShift = np.fix(np.ceil(upsampleFactor * 1.5) / 2) + + upsampleCenter = globalShift - upsampleFactor * xyShift + + imageCorrUpsample = cp.conj( + dftUpsample_cp(imageCorr, upsampleFactor, upsampleCenter) + ).get() + + xySubShift = np.unravel_index(imageCorrUpsample.argmax(), imageCorrUpsample.shape) + + # add a subpixel shift via parabolic fitting + try: + icc = np.real( + imageCorrUpsample[ + xySubShift[0] - 1 : xySubShift[0] + 2, + xySubShift[1] - 1 : xySubShift[1] + 2, + ] + ) + dx = (icc[2, 1] - icc[0, 1]) / (4 * icc[1, 1] - 2 * icc[2, 1] - 2 * icc[0, 1]) + dy = (icc[1, 2] - icc[1, 0]) / (4 * icc[1, 1] - 2 * icc[1, 2] - 2 * icc[1, 0]) + except: + dx, dy = ( + 0, + 0, + ) # this is the case when the peak is near the edge and one of the above values does not exist + + xySubShift = xySubShift - globalShift + + xyShift = xyShift + (xySubShift + np.array([dx, dy])) / upsampleFactor + + return xyShift + + +def dftUpsample_cp(imageCorr, upsampleFactor, xyShift): + """ + This performs a matrix multiply DFT around a small neighboring region of the inital + correlation peak. By using the matrix multiply DFT to do the Fourier upsampling, the + efficiency is greatly improved. This is adapted from the subfuction dftups found in + the dftregistration function on the Matlab File Exchange. + + https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation + + The matrix multiplication DFT is from: + + Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, "Efficient subpixel + image registration algorithms," Opt. Lett. 33, 156-158 (2008). + http://www.sciencedirect.com/science/article/pii/S0045790612000778 + + Args: + imageCorr (complex valued ndarray): + Correlation image between two images in Fourier space. + upsampleFactor (int): + Scalar integer of how much to upsample. + xyShift (list of 2 floats): + Coordinates in the UPSAMPLED GRID around which to upsample. + These must be single-pixel IN THE UPSAMPLED GRID + + Returns: + (ndarray): + Upsampled image from region around correlation peak. 
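+
+    Example:
+        A minimal sketch (names are illustrative); `imageCorr` is the conjugated
+        Fourier-space correlation and `center` is a peak location already
+        expressed in the upsampled grid, as prepared in upsampled_correlation_cp:
+
+        >>> patch = dftUpsample_cp(imageCorr, upsampleFactor=16, xyShift=center)
+        >>> ix, iy = np.unravel_index(int(patch.argmax()), patch.shape)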
+ """ + imageSize = imageCorr.shape + pixelRadius = 1.5 + numRow = np.ceil(pixelRadius * upsampleFactor) + numCol = numRow + + colKern = cp.exp( + (-1j * 2 * cp.pi / (imageSize[1] * upsampleFactor)) + * cp.outer( + (cp.fft.ifftshift((cp.arange(imageSize[1]))) - cp.floor(imageSize[1] / 2)), + (cp.arange(numCol) - xyShift[1]), + ) + ) + + rowKern = cp.exp( + (-1j * 2 * cp.pi / (imageSize[0] * upsampleFactor)) + * cp.outer( + (cp.arange(numRow) - xyShift[0]), + (cp.fft.ifftshift(cp.arange(imageSize[0])) - cp.floor(imageSize[0] / 2)), + ) + ) + + imageUpsample = cp.real(rowKern @ imageCorr @ colKern) + return imageUpsample + + +def linear_interpolation_2D_cp(ar, x, y): + """ + Calculates the 2D linear interpolation of array ar at position x,y using the four + nearest array elements. + """ + x0, x1 = int(np.floor(x)), int(np.ceil(x)) + y0, y1 = int(np.floor(y)), int(np.ceil(y)) + dx = x - x0 + dy = y - y0 + return ( + (1 - dx) * (1 - dy) * ar[x0, y0] + + (1 - dx) * dy * ar[x0, y1] + + dx * (1 - dy) * ar[x1, y0] + + dx * dy * ar[x1, y1] + ) + + +def _integrate_disks_cp(DP, maxima_x, maxima_y, maxima_int, int_window_radius=1): + disks = [] + DP = cp.asnumpy(DP) + img_size = DP.shape[0] + for x, y, i in zip(maxima_x, maxima_y, maxima_int): + r1, r2 = np.ogrid[-x : img_size - x, -y : img_size - y] + mask = r1**2 + r2**2 <= int_window_radius**2 + mask_arr = np.zeros((img_size, img_size)) + mask_arr[mask] = 1 + disk = DP * mask_arr + disks.append(np.average(disk)) + try: + disks = disks / max(disks) + except: + pass + return (maxima_x, maxima_y, disks) diff --git a/py4DSTEM/datacube/diskdetection/diskdetection_cuda.py b/py4DSTEM/datacube/diskdetection/diskdetection_cuda.py new file mode 100644 index 000000000..ddea4d9ad --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/diskdetection_cuda.py @@ -0,0 +1,718 @@ +""" +Functions for finding Braggdisks using cupy + +""" + +import numpy as np + +import cupy as cp +from cupyx.scipy.ndimage import gaussian_filter +import cupyx.scipy.fft as cufft +from time import time +import numba + +from emdfile import tqdmnd +from py4DSTEM import PointList, PointListArray +from py4DSTEM.braggvectors.kernels import kernels + + +def find_Bragg_disks_CUDA( + datacube, + probe, + corrPower=1, + sigma=2, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0.0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + filter_function=None, + name="braggpeaks_raw", + batching=True, +): + """ + Finds the Bragg disks in all diffraction patterns of datacube by cross, hybrid, or + phase correlation with probe. When hist = True, returns histogram of intensities in + the entire datacube. + + Args: + DP (ndarray): a diffraction pattern + probe (ndarray): the vacuum probe template, in real space. + corrPower (float between 0 and 1, inclusive): the cross correlation power. A + value of 1 corresponds to a cross correaltion, and 0 corresponds to a + phase correlation, with intermediate values giving various hybrids. + sigma (float): the standard deviation for the gaussian smoothing applied to + the cross correlation + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the brightest peak + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. 
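+        minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+            on an absolute scale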
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + global_threshold (bool): if True, applies global threshold based on + minGlobalIntensity and metric + minGlobalThreshold (float): the minimum allowed peak intensity, relative to the + selected metric (0-1), except in the case of 'manual' metric, in which the + threshold value based on the minimum intensity that you want thresholder + out should be set. + metric (string): the metric used to compare intensities. 'average' compares peak + intensity relative to the average of the maximum intensity in each + diffraction pattern. 'max' compares peak intensity relative to the maximum + intensity value out of all the diffraction patterns. 'median' compares peak + intensity relative to the median of the maximum intensity peaks in each + diffraction pattern. 'manual' Allows the user to threshold based on a + predetermined intensity value manually determined. In this case, + minIntensity should be an int. + name (str): name for the returned PointListArray + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The + shape of the returned DP must match the shape of the probe kernel (but does + not need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + batching (bool): Whether to batch the FFT cross correlation steps. 
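+            When True, the batch size is set automatically from the free GPU memory
+            reported by cupy, and the FFTs are computed on whole batches at once.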
+ + Returns: + (PointListArray): the Bragg peak positions and correlation intensities + """ + + # Make the peaks PointListArray + coords = [("qx", float), ("qy", float), ("intensity", float)] + peaks = PointListArray(dtype=coords, shape=(datacube.R_Nx, datacube.R_Ny)) + + # check that the filtered DP is the right size for the probe kernel: + if filter_function: + assert callable(filter_function), "filter_function must be callable" + DP = ( + datacube.data[0, 0, :, :] + if filter_function is None + else filter_function(datacube.data[0, 0, :, :]) + ) + assert np.all( + DP.shape == probe.shape + ), "Probe kernel shape must match filtered DP shape" + + # Get the probe kernel FT as a cupy array + probe_kernel_FT = cp.conj(cp.fft.fft2(cp.array(probe))).astype(cp.complex64) + bytes_per_pattern = probe_kernel_FT.nbytes + + # get the maximal array kernel + # if probe_kernel_FT.dtype == 'float64': + # get_maximal_points = kernels['maximal_pts_float64'] + # elif probe_kernel_FT.dtype == 'float32': + # get_maximal_points = kernels['maximal_pts_float32'] + # else: + # raise TypeError("Maximal kernel only valid for float32 and float64 types...") + get_maximal_points = kernels["maximal_pts_float32"] + + if get_maximal_points.max_threads_per_block < DP.shape[1]: + # naive blocks/threads will not work, figure out an OK distribution + blocks = ((np.prod(DP.shape) // get_maximal_points.max_threads_per_block + 1),) + threads = (get_maximal_points.max_threads_per_block,) + else: + blocks = (DP.shape[0],) + threads = (DP.shape[1],) + + t0 = time() + if batching: + # compute the batch size based on available VRAM: + max_num_bytes = cp.cuda.Device().mem_info[0] + # use a fudge factor to leave room for the fourier transformed data + # I have set this at 10, which results in underutilization of + # VRAM, because this yielded better performance in my testing + batch_size = max_num_bytes // (bytes_per_pattern * 10) + num_batches = datacube.R_N // batch_size + 1 + + print(f"Using {num_batches} batches of {batch_size} patterns each...") + + # allocate array for batch of DPs + batched_subcube = cp.zeros( + (batch_size, datacube.Q_Nx, datacube.Q_Ny), dtype=cp.float32 + ) + + for batch_idx in tqdmnd( + range(num_batches), desc="Finding Bragg disks in batches", unit="batch" + ): + # the final batch may be smaller than the other ones: + probes_remaining = datacube.R_N - (batch_idx * batch_size) + this_batch_size = ( + probes_remaining if probes_remaining < batch_size else batch_size + ) + + # fill in diffraction patterns, with filtering + for subbatch_idx in range(this_batch_size): + patt_idx = batch_idx * batch_size + subbatch_idx + rx, ry = np.unravel_index(patt_idx, (datacube.R_Nx, datacube.R_Ny)) + batched_subcube[subbatch_idx, :, :] = cp.array( + ( + datacube.data[rx, ry, :, :] + if filter_function is None + else filter_function(datacube.data[rx, ry, :, :]) + ), + dtype=cp.float32, + ) + + # Perform the FFT and multiplication by probe_kernel on the batched array + batched_crosscorr = ( + cufft.fft2(batched_subcube, overwrite_x=True) + * probe_kernel_FT[None, :, :] + ) + + # Iterate over the patterns in the batch and do the Bragg disk stuff + for subbatch_idx in range(this_batch_size): + patt_idx = batch_idx * batch_size + subbatch_idx + rx, ry = np.unravel_index(patt_idx, (datacube.R_Nx, datacube.R_Ny)) + + subFFT = batched_crosscorr[subbatch_idx] + ccc = cp.abs(subFFT) ** corrPower * cp.exp(1j * cp.angle(subFFT)) + cc = cp.maximum(cp.real(cp.fft.ifft2(ccc)), 0) + + _find_Bragg_disks_single_DP_FK_CUDA( + None, + None, 
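+                    # DP and probe_kernel_FT are None here: in batched mode the
+                    # Fourier-space (ccc) and real-space (cc) cross correlations
+                    # were already computed above for the whole batch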
+ ccc=ccc, + cc=cc, + corrPower=corrPower, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + filter_function=filter_function, + peaks=peaks.get_pointlist(rx, ry), + get_maximal_points=get_maximal_points, + blocks=blocks, + threads=threads, + ) + + # clean up + del batched_subcube, batched_crosscorr, subFFT, cc, ccc + cp.get_default_memory_pool().free_all_blocks() + + else: + # Loop over all diffraction patterns + for Rx, Ry in tqdmnd( + datacube.R_Nx, + datacube.R_Ny, + desc="Finding Bragg Disks", + unit="DP", + unit_scale=True, + ): + DP = datacube.data[Rx, Ry, :, :] + _find_Bragg_disks_single_DP_FK_CUDA( + DP, + probe_kernel_FT, + corrPower=corrPower, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minPeakSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + upsample_factor=upsample_factor, + filter_function=filter_function, + peaks=peaks.get_pointlist(Rx, Ry), + get_maximal_points=get_maximal_points, + blocks=blocks, + threads=threads, + ) + t = time() - t0 + print( + f"Analyzed {datacube.R_N} diffraction patterns in {t//3600}h {t % 3600 // 60}m {t % 60:.2f}s\n(avg. speed {datacube.R_N/t:0.4f} patterns per second)".format() + ) + peaks.name = name + return peaks + + +def _find_Bragg_disks_single_DP_FK_CUDA( + DP, + probe_kernel_FT, + corrPower=1, + sigma=2, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0.0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + filter_function=None, + return_cc=False, + peaks=None, + get_maximal_points=None, + blocks=None, + threads=None, + ccc=None, + cc=None, +): + """ + Finds the Bragg disks in DP by cross, hybrid, or phase correlation with probe_kernel_FT. + + After taking the cross/hybrid/phase correlation, a gaussian smoothing is applied + with standard deviation sigma, and all local maxima are found. Detected peaks within + edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks with + intensities less than minRelativeIntensity of the brightest peak in the correaltion are + discarded. Then peaks which are within a distance of minPeakSpacing of their nearest neighbor + peak are found, and in each such pair the peak with the lesser correlation intensities is + removed. Finally, if the number of peaks remaining exceeds maxNumPeaks, only the maxNumPeaks + peaks with the highest correlation intensity are retained. + + IMPORTANT NOTE: the argument probe_kernel_FT is related to the probe kernels generated by + functions like get_probe_kernel() by: + + probe_kernel_FT = np.conj(np.fft.fft2(probe_kernel)) + + if this function is simply passed a probe kernel, the results will not be meaningful! To run + on a single DP while passing the real space probe kernel as an argument, use + find_Bragg_disks_single_DP(). + + Accepts: + DP (ndarray) a diffraction pattern + probe_kernel_FT (cparray) the vacuum probe template, in Fourier space. Related to the + real space probe kernel by probe_kernel_FT = F(probe_kernel)*, where F + indicates a Fourier Transform and * indicates complex conjugation. + corrPower (float between 0 and 1, inclusive) the cross correlation power. 
A + value of 1 corresponds to a cross correaltion, and 0 corresponds to a + phase correlation, with intermediate values giving various hybrids. + sigma (float) the standard deviation for the gaussian smoothing applied to + the cross correlation + edgeBoundary (int) minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float) the minimum acceptable correlation peak intensity, relative to + the intensity of the relativeToPeak'th peak + relativeToPeak (int) specifies the peak against which the minimum relative intensity + is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float) the minimum acceptable spacing between detected peaks + maxNumPeaks (int) the maximum number of peaks to return + subpixel (str) 'none': no subpixel fitting + (default) 'poly': polynomial interpolation of correlogram peaks + (fairly fast but not very accurate) + 'multicorr': uses the multicorr algorithm with + DFT upsampling + upsample_factor (int) upsampling factor for subpixel fitting (only used when subpixel='multicorr') + filter_function (callable) filtering function to apply to each diffraction pattern before peakfinding. + Must be a function of only one argument (the diffraction pattern) and return + the filtered diffraction pattern. + The shape of the returned DP must match the shape of the probe kernel (but does + not need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk detection, + the function must be able to be pickled with by dill. + return_cc (bool) if True, return the cross correlation + peaks (PointList) For internal use. + If peaks is None, the PointList of peak positions is created here. + If peaks is not None, it is the PointList that detected peaks are added + to, and must have the appropriate coords ('qx','qy','intensity'). + ccc and cc: Precomputed complex and real-IFFT cross correlations. Used when called + in batched mode only, causing local calculation of those to be skipped + + Returns: + peaks (PointList) the Bragg peak positions and correlation intensities + """ + + # if we are in batching mode, cc and ccc will be provided. 
else, compute it + if ccc is None: + # Perform any prefiltering + DP = cp.array( + DP if filter_function is None else filter_function(DP), dtype=cp.float32 + ) + + # Get the cross correlation + if subpixel in ("none", "poly"): + cc = get_cross_correlation_fk(DP, probe_kernel_FT, corrPower) + ccc = None + # for multicorr subpixel fitting, we need both the real and complex cross correlation + else: + ccc = get_cross_correlation_fk( + DP, probe_kernel_FT, corrPower, returnval="fourier" + ) + cc = cp.maximum(cp.real(cp.fft.ifft2(ccc)), 0) + + # Find the maxima + maxima_x, maxima_y, maxima_int = get_maxima_2D( + cc, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=subpixel, + ar_FT=ccc, + upsample_factor=upsample_factor, + get_maximal_points=get_maximal_points, + blocks=blocks, + threads=threads, + ) + + # Make peaks PointList + if peaks is None: + coords = [("qx", float), ("qy", float), ("intensity", float)] + peaks = PointList(coordinates=coords) + peaks.add_data_by_field((maxima_x, maxima_y, maxima_int)) + + if return_cc: + return peaks, gaussian_filter(cc, sigma) + else: + return peaks + + +def get_cross_correlation_fk(ar, fourierkernel, corrPower=1, returnval="cc"): + """ + Calculates the cross correlation of ar with fourierkernel. + Here, fourierkernel = np.conj(np.fft.fft2(kernel)); speeds up computation when the same + kernel is to be used for multiple cross correlations. + corrPower specifies the correlation type, where 1 is a cross correlation, 0 is a phase + correlation, and values in between are hybrids. + + The return value depends on the argument `returnval`: + if return=='cc' (default), returns the real part of the cross correlation in real + space. + if return=='fourier', returns the output in Fourier space, before taking the + inverse transform. + """ + m = cp.fft.fft2(ar) * fourierkernel + ccc = cp.abs(m) ** (corrPower) * cp.exp(1j * cp.angle(m)) + if returnval == "fourier": + return ccc + else: + return cp.real(cp.fft.ifft2(ccc)) + + +def get_maxima_2D( + ar, + sigma=0, + edgeBoundary=0, + minSpacing=0, + minRelativeIntensity=0, + minAbsoluteIntensity=0, + relativeToPeak=0, + maxNumPeaks=0, + subpixel="poly", + ar_FT=None, + upsample_factor=16, + get_maximal_points=None, + blocks=None, + threads=None, +): + """ + Finds the indices where the 2D array ar is a local maximum. + Optional parameters allow blurring of the array and filtering of the output; + setting each of these to 0 (default) turns off these functions. + + Accepts: + ar (ndarray) a 2D array + sigma (float) guassian blur std to applyu to ar before finding the maxima + edgeBoundary (int) ignore maxima within edgeBoundary of the array edge + minSpacing (float) if two maxima are found within minSpacing, the dimmer one + is removed + minRelativeIntensity (float) maxima dimmer than minRelativeIntensity compared to the + relativeToPeak'th brightest maximum are removed + relativeToPeak (int) 0=brightest maximum. 1=next brightest, etc. + maxNumPeaks (int) return only the first maxNumPeaks maxima + subpixel (str) 'none': no subpixel fitting + (default) 'poly': polynomial interpolation of correlogram peaks + (fairly fast but not very accurate) + 'multicorr': uses the multicorr algorithm with + DFT upsampling + ar_FT (None or complex array) if subpixel=='multicorr' the + fourier transform of the image is required. 
It may be + passed here as a complex array. Otherwise, if ar_FT is None, + it is computed + upsample_factor (int) required iff subpixel=='multicorr' + + Returns + maxima_x (ndarray) x-coords of the local maximum, sorted by intensity. + maxima_y (ndarray) y-coords of the local maximum, sorted by intensity. + maxima_intensity (ndarray) intensity of the local maxima + """ + + # Get maxima + ar = gaussian_filter(ar, sigma) + maxima_bool = cp.zeros_like(ar, dtype=bool) + sizex = ar.shape[0] + sizey = ar.shape[1] + N = sizex * sizey + get_maximal_points( + blocks, threads, (ar, maxima_bool, minAbsoluteIntensity, sizex, sizey, N) + ) + + # Remove edges + if edgeBoundary > 0: + maxima_bool[:edgeBoundary, :] = False + maxima_bool[-edgeBoundary:, :] = False + maxima_bool[:, :edgeBoundary] = False + maxima_bool[:, -edgeBoundary:] = False + elif subpixel is True: + maxima_bool[:1, :] = False + maxima_bool[-1:, :] = False + maxima_bool[:, :1] = False + maxima_bool[:, -1:] = False + + # Get indices, sorted by intensity + maxima_x, maxima_y = cp.nonzero(maxima_bool) + maxima_x = maxima_x.get() + maxima_y = maxima_y.get() + dtype = np.dtype([("x", float), ("y", float), ("intensity", float)]) + maxima = np.zeros(len(maxima_x), dtype=dtype) + maxima["x"] = maxima_x + maxima["y"] = maxima_y + + ar = ar.get() + maxima["intensity"] = ar[maxima_x, maxima_y] + maxima = np.sort(maxima, order="intensity")[::-1] + + if len(maxima) > 0: + # Remove maxima which are too close + if minSpacing > 0: + deletemask = np.zeros(len(maxima), dtype=bool) + for i in range(len(maxima)): + if deletemask[i] == False: # noqa: E712 + tooClose = ( + (maxima["x"] - maxima["x"][i]) ** 2 + + (maxima["y"] - maxima["y"][i]) ** 2 + ) < minSpacing**2 + tooClose[: i + 1] = False + deletemask[tooClose] = True + maxima = np.delete(maxima, np.nonzero(deletemask)[0]) + + # Remove maxima which are too dim + if (minRelativeIntensity > 0) & (len(maxima) > relativeToPeak): + deletemask = ( + maxima["intensity"] / maxima["intensity"][relativeToPeak] + < minRelativeIntensity + ) + maxima = np.delete(maxima, np.nonzero(deletemask)[0]) + + # Remove maxima which are too dim, absolute scale + if minAbsoluteIntensity > 0: + deletemask = maxima["intensity"] < minAbsoluteIntensity + maxima = np.delete(maxima, np.nonzero(deletemask)[0]) + + # Remove maxima in excess of maxNumPeaks + if maxNumPeaks is not None and maxNumPeaks > 0: + if len(maxima) > maxNumPeaks: + maxima = maxima[:maxNumPeaks] + + # Subpixel fitting + # For all subpixel fitting, first fit 1D parabolas in x and y to 3 points (maximum, +/- 1 pixel) + if subpixel != "none": + for i in range(len(maxima)): + Ix1_ = ar[int(maxima["x"][i]) - 1, int(maxima["y"][i])] + Ix0 = ar[int(maxima["x"][i]), int(maxima["y"][i])] + Ix1 = ar[int(maxima["x"][i]) + 1, int(maxima["y"][i])] + Iy1_ = ar[int(maxima["x"][i]), int(maxima["y"][i]) - 1] + Iy0 = ar[int(maxima["x"][i]), int(maxima["y"][i])] + Iy1 = ar[int(maxima["x"][i]), int(maxima["y"][i]) + 1] + deltax = (Ix1 - Ix1_) / (4 * Ix0 - 2 * Ix1 - 2 * Ix1_) + deltay = (Iy1 - Iy1_) / (4 * Iy0 - 2 * Iy1 - 2 * Iy1_) + maxima["x"][i] += deltax if np.abs(deltax) <= 1.0 else 0.0 + maxima["y"][i] += deltay if np.abs(deltay) <= 1.0 else 0.0 + maxima["intensity"][i] = linear_interpolation_2D( + ar, maxima["x"][i], maxima["y"][i] + ) + # Further refinement with fourier upsampling + if subpixel == "multicorr": + ar_FT = cp.conj(ar_FT) + + xyShift = np.vstack((maxima["x"], maxima["y"])).T + # we actually have to lose some precision and go down to half-pixel + # accuracy. 
this could also be done by a single upsampling at factor 2 + # instead of get_maxima_2D. + xyShift = cp.array(np.round(xyShift * 2.0) / 2) + + subShift = upsampled_correlation(ar_FT, upsample_factor, xyShift).get() + maxima["x"] = subShift[:, 0] + maxima["y"] = subShift[:, 1] + + return maxima["x"], maxima["y"], maxima["intensity"] + + +def upsampled_correlation(imageCorr, upsampleFactor, xyShift): + """ + Refine the correlation peak of imageCorr around xyShift by DFT upsampling. + + There are two approaches to Fourier upsampling for subpixel refinement: (a) one + can pad an (appropriately shifted) FFT with zeros and take the inverse transform, + or (b) one can compute the DFT by matrix multiplication using modified + transformation matrices. The former approach is straightforward but requires + performing the FFT algorithm (which is fast) on very large data. The latter method + trades one speedup for a slowdown elsewhere: the matrix multiply steps are expensive + but we operate on smaller matrices. Since we are only interested in a very small + region of the FT around a peak of interest, we use the latter method to get + a substantial speedup and enormous decrease in memory requirement. This + "DFT upsampling" approach computes the transformation matrices for the matrix- + multiply DFT around a small 1.5px wide region in the original `imageCorr`. + + Following the matrix multiply DFT we use parabolic subpixel fitting to + get even more precision! (below 1/upsampleFactor pixels) + + NOTE: previous versions of multiCorr operated in two steps: using the zero- + padding upsample method for a first-pass factor-2 upsampling, followed by the + DFT upsampling (at whatever user-specified factor). I have implemented it + differently, to better support iterating over multiple peaks. **The DFT is always + upsampled around xyShift, which MUST be specified to HALF-PIXEL precision + (no more, no less) to replicate the behavior of the factor-2 step.** + (It is possible to refactor this so that peak detection is done on a Fourier + upsampled image rather than using the parabolic subpixel and rounding as now... + I like keeping it this way because all of the parameters and logic will be identical + to the other subpixel methods.) + + + Args: + imageCorr (complex valued ndarray): + Complex product of the FFTs of the two images to be registered + i.e. m = np.fft.fft2(DP) * probe_kernel_FT; + imageCorr = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m)) + upsampleFactor (int): + Upsampling factor. Must be greater than 2. (To do upsampling + with factor 2, use upsampleFFT, which is faster.) + xyShift: + Array of points around which to upsample, with shape [N-points, 2] + + Returns: + (N_points, 2) cupy ndarray: Refined locations of the peaks in image coordinates. 
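+
+    Example:
+        A minimal sketch (names are illustrative); `ccc` is the conjugated
+        Fourier-space cross correlation and `xy` an (N, 2) cupy array of peak
+        positions already rounded to half-pixel precision, as prepared in
+        get_maxima_2D:
+
+        >>> refined = upsampled_correlation(ccc, upsampleFactor=16, xyShift=xy)
+        >>> refined = refined.get()  # back to numpy, shape (N, 2)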
+ """ + + xyShift = (cp.round(xyShift * upsampleFactor) / upsampleFactor).astype(cp.float32) + + globalShift = np.fix(np.ceil(upsampleFactor * 1.5) / 2) + + upsampleCenter = globalShift - upsampleFactor * xyShift + + imageCorrUpsample = dftUpsample(imageCorr, upsampleFactor, upsampleCenter).get() + + xSubShift, ySubShift = np.unravel_index( + imageCorrUpsample.reshape(imageCorrUpsample.shape[0], -1).argmax(axis=1), + imageCorrUpsample.shape[1:3], + ) + + # add a subpixel shift via parabolic fitting, serially for each peak + for idx in range(xSubShift.shape[0]): + try: + icc = np.real( + imageCorrUpsample[ + idx, + xSubShift[idx] - 1 : xSubShift[idx] + 2, + ySubShift[idx] - 1 : ySubShift[idx] + 2, + ] + ) + dx = (icc[2, 1] - icc[0, 1]) / ( + 4 * icc[1, 1] - 2 * icc[2, 1] - 2 * icc[0, 1] + ) + dy = (icc[1, 2] - icc[1, 0]) / ( + 4 * icc[1, 1] - 2 * icc[1, 2] - 2 * icc[1, 0] + ) + except: + dx, dy = ( + 0, + 0, + ) # this is the case when the peak is near the edge and one of the above values does not exist + + xyShift[idx] = ( + xyShift[idx] + + (cp.array([xSubShift[idx] + dx, ySubShift[idx] + dy]) - globalShift) + / upsampleFactor + ) + + return xyShift + + +def dftUpsample(imageCorr, upsampleFactor, xyShift): + """ + This performs a matrix multiply DFT around a small neighboring region of the inital + correlation peak. By using the matrix multiply DFT to do the Fourier upsampling, the + efficiency is greatly improved. This is adapted from the subfuction dftups found in + the dftregistration function on the Matlab File Exchange. + + https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation + + The matrix multiplication DFT is from: + + Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, "Efficient subpixel + image registration algorithms," Opt. Lett. 33, 156-158 (2008). + http://www.sciencedirect.com/science/article/pii/S0045790612000778 + + Args: + imageCorr (complex valued ndarray): + Correlation image between two images in Fourier space. + upsampleFactor (int): + Scalar integer of how much to upsample. + xyShift (N_points,2) cp.ndarray, locations to upsample around: + Coordinates in the UPSAMPLED GRID around which to upsample. + These must be single-pixel IN THE UPSAMPLED GRID + + Returns: + (ndarray): + Stack of upsampled images from region around correlation peak. 
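+            Each entry in the stack is a (kernel_size, kernel_size) patch, with
+            kernel_size = ceil(1.5 * upsampleFactor), computed around the
+            corresponding row of xyShift.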
+ """ + N_pts = xyShift.shape[0] + imageSize = imageCorr.shape + pixelRadius = 1.5 + kernel_size = int(np.ceil(pixelRadius * upsampleFactor)) + + colKern = cp.zeros( + (N_pts, imageSize[1], kernel_size), dtype=cp.complex64 + ) # N_pts * image_size[1] * kernel_size + rowKern = cp.zeros( + (N_pts, kernel_size, imageSize[0]), dtype=cp.complex64 + ) # N_pts * kernel_size * image_size[0] + + # Fill in the DFT arrays using the CUDA kernels + multicorr_col_kernel = kernels["multicorr_col_kernel"] + blocks = ( + (np.prod(colKern.shape) // multicorr_col_kernel.max_threads_per_block + 1), + ) + threads = (multicorr_col_kernel.max_threads_per_block,) + multicorr_col_kernel( + blocks, threads, (colKern, xyShift, N_pts, *imageSize, upsampleFactor) + ) + + multicorr_row_kernel = kernels["multicorr_row_kernel"] + blocks = ( + (np.prod(rowKern.shape) // multicorr_row_kernel.max_threads_per_block + 1), + ) + threads = (multicorr_row_kernel.max_threads_per_block,) + multicorr_row_kernel( + blocks, threads, (rowKern, xyShift, N_pts, *imageSize, upsampleFactor) + ) + + # Apply the DFT arrays to the correlation image + imageUpsample = cp.real(rowKern @ imageCorr @ colKern) + return imageUpsample + + +@numba.jit(nopython=True) +def linear_interpolation_2D(ar, x, y): + """ + Calculates the 2D linear interpolation of array ar at position x,y using the four + nearest array elements. + """ + x0, x1 = int(np.floor(x)), int(np.ceil(x)) + y0, y1 = int(np.floor(y)), int(np.ceil(y)) + dx = x - x0 + dy = y - y0 + return ( + (1 - dx) * (1 - dy) * ar[x0, y0] + + (1 - dx) * dy * ar[x0, y1] + + dx * (1 - dy) * ar[x1, y0] + + dx * dy * ar[x1, y1] + ) diff --git a/py4DSTEM/datacube/diskdetection/diskdetection_parallel.py b/py4DSTEM/datacube/diskdetection/diskdetection_parallel.py new file mode 100644 index 000000000..32ff1c520 --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/diskdetection_parallel.py @@ -0,0 +1,577 @@ +# stdlib +import os +import tempfile +from time import time + +# 3rd party +import numpy as np +import dill + +# local +import py4DSTEM +from emdfile import PointListArray + + +def _find_Bragg_disks_single_DP_FK( + DP, + probe_kernel_FT, + corrPower=1, + sigma=2, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="multicorr", + upsample_factor=16, + filter_function=None, + return_cc=False, + peaks=None, +): + """ + Mirror of diskdetection.find_Bragg_disks_single_DP_FK with explicit imports for + remote execution. + + Finds the Bragg disks in DP by cross, hybrid, or phase correlation with + probe_kernel_FT. + + After taking the cross/hybrid/phase correlation, a gaussian smoothing is applied + with standard deviation sigma, and all local maxima are found. Detected peaks within + edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks + with intensities less than minRelativeIntensity of the brightest peak in the + correaltion are discarded. Then peaks which are within a distance of minPeakSpacing + of their nearest neighbor peak are found, and in each such pair the peak with the + lesser correlation intensities is removed. Finally, if the number of peaks remaining + exceeds maxNumPeaks, only the maxNumPeaks peaks with the highest correlation + intensity are retained. 
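+    In the distributed workflows below, this function is applied to one diffraction
+    pattern at a time by _process_chunk, which unpacks the remaining arguments from
+    the serialized inputs list.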
+ + IMPORTANT NOTE: the argument probe_kernel_FT is related to the probe kernels + generated by functions like get_probe_kernel() by: + + >>> probe_kernel_FT = np.conj(np.fft.fft2(probe_kernel)) + + if this function is simply passed a probe kernel, the results will not be meaningful! + To run on a single DP while passing the real space probe kernel as an argument, use + find_Bragg_disks_single_DP(). + + Args: + DP (ndarray): a diffraction pattern + probe_kernel_FT (ndarray): the vacuum probe template, in Fourier space. Related + to the real space probe kernel by probe_kernel_FT = F(probe_kernel)*, where + F indicates a Fourier Transform and * indicates complex conjugation. + corrPower (float between 0 and 1, inclusive): the cross correlation power. A + value of 1 corresponds to a cross correaltion, and 0 corresponds to a + phase correlation, with intermediate values giving various hybrids. + sigma (float): the standard deviation for the gaussian smoothing applied to + the cross correlation + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the relativeToPeak'th peak + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The shape + of the returned DP must match the shape of the probe kernel (but does not + need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + return_cc (bool): if True, return the cross correlation + peaks (PointList): For internal use. If peaks is None, the PointList of peak + positions is created here. If peaks is not None, it is the PointList that + detected peaks are added to, and must have the appropriate coords + ('qx','qy','intensity'). 
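+        minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+            on an absolute scale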
+ + Returns: + (PointList) the Bragg peak positions and correlation intensities + """ + assert subpixel in [ + "none", + "poly", + "multicorr", + ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format( + subpixel + ) + + import numpy + import scipy.ndimage.filters + import py4DSTEM.utils.multicorr + + # apply filter function: + DP = DP if filter_function is None else filter_function(DP) + + if subpixel == "none": + cc = py4DSTEM.utils.get_cross_correlation_fk( + DP, probe_kernel_FT, corrPower + ) + cc = numpy.maximum(cc, 0) + maxima_x, maxima_y, maxima_int = py4DSTEM.utils.get_maxima_2D( + cc, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=False, + ) + elif subpixel == "poly": + cc = py4DSTEM.utils.get_cross_correlation_fk( + DP, probe_kernel_FT, corrPower + ) + cc = numpy.maximum(cc, 0) + maxima_x, maxima_y, maxima_int = py4DSTEM.utils.get_maxima_2D( + cc, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=True, + ) + else: + # Multicorr subpixel: + m = numpy.fft.fft2(DP) * probe_kernel_FT + ccc = numpy.abs(m) ** corrPower * numpy.exp(1j * numpy.angle(m)) + + cc = numpy.maximum(numpy.real(numpy.fft.ifft2(ccc)), 0) + + maxima_x, maxima_y, maxima_int = py4DSTEM.utils.get_maxima_2D( + cc, + sigma=sigma, + edgeBoundary=edgeBoundary, + minRelativeIntensity=minRelativeIntensity, + minAbsoluteIntensity=minAbsoluteIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minPeakSpacing, + maxNumPeaks=maxNumPeaks, + subpixel=True, + ) + + # Use the DFT upsample to refine the detected peaks (but not the intensity) + for ipeak in range(len(maxima_x)): + xyShift = numpy.array((maxima_x[ipeak], maxima_y[ipeak])) + # we actually have to lose some precision and go down to half-pixel + # accuracy. this could also be done by a single upsampling at factor 2 + # instead of get_maxima_2D. 
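+            # (upsampled_correlation requires the starting point at exactly
+            # half-pixel precision)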
+ xyShift[0] = numpy.round(xyShift[0] * 2) / 2 + xyShift[1] = numpy.round(xyShift[1] * 2) / 2 + + subShift = py4DSTEM.utils.multicorr.upsampled_correlation( + ccc, upsample_factor, xyShift + ) + maxima_x[ipeak] = subShift[0] + maxima_y[ipeak] = subShift[1] + + # Make peaks PointList + if peaks is None: + coords = [("qx", float), ("qy", float), ("intensity", float)] + peaks = py4DSTEM.PointList(coordinates=coords) + else: + assert isinstance(peaks, py4DSTEM.PointList) + peaks.add_tuple_of_nparrays((maxima_x, maxima_y, maxima_int)) + + if return_cc: + return peaks, scipy.ndimage.filters.gaussian_filter(cc, sigma) + else: + return peaks + + +def _process_chunk(_f, start, end, path_to_static, coords, path_to_data, cluster_path): + import os + import dill + + with open(path_to_static, "rb") as infile: + inputs = dill.load(infile) + + # Always try to memory map the data file, if possible + if path_to_data.rsplit(".", 1)[-1].startswith("dm"): + datacube = py4DSTEM.read(path_to_data, load="dmmmap") + elif path_to_data.rsplit(".", 1)[-1].startswith("gt"): + datacube = py4DSTEM.read(path_to_data, load="gatan_bin") + else: + datacube = py4DSTEM.read(path_to_data) + + results = [] + for x in coords: + results.append((x[0], x[1], _f(datacube.data[x[0], x[1], :, :], *inputs).data)) + + # release memory + datacube = None + + path_to_output = os.path.join(cluster_path, "{}_{}.data".format(start, end)) + with open(path_to_output, "wb") as data_file: + dill.dump(results, data_file) + + return path_to_output + + +def find_Bragg_disks_ipp( + DP, + probe, + corrPower=1, + sigma=2, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="poly", + upsample_factor=4, + filter_function=None, + ipyparallel_client_file=None, + data_file=None, + cluster_path=None, +): + """ + Distributed compute using IPyParallel. + + Finds the Bragg disks in all diffraction patterns of datacube by cross, hybrid, or + phase correlation with probe. + + Args: + DP (ndarray): a diffraction pattern + probe (ndarray): the vacuum probe template, in real space. + corrPower (float between 0 and 1, inclusive): the cross correlation power. A + value of 1 corresponds to a cross correaltion, and 0 corresponds to a + phase correlation, with intermediate values giving various hybrids. + sigma (float): the standard deviation for the gaussian smoothing applied to + the cross correlation + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the brightest peak + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. 
Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The shape + of the returned DP must match the shape of the probe kernel (but does not + need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. + ipyparallel_client_file (str): absolute path to ipyparallel client JSON file for + connecting to a cluster + data_file (str): absolute path to the data file containing the datacube for + processing remotely + cluster_path (str): working directory for cluster processing, defaults to current + directory + + Returns: + (PointListArray): the Bragg peak positions and correlation intensities + """ + import ipyparallel as ipp + + R_Nx = DP.R_Nx + R_Ny = DP.R_Ny + R_N = DP.R_N + DP = None + + # Make the peaks PointListArray + coords = [("qx", float), ("qy", float), ("intensity", float)] + peaks = PointListArray(coordinates=coords, shape=(R_Nx, R_Ny)) + + # Get the probe kernel FT + probe_kernel_FT = np.conj(np.fft.fft2(probe)) + + if ipyparallel_client_file is None: + raise RuntimeError("ipyparallel_client_file is None, no IPyParallel cluster") + elif data_file is None: + raise RuntimeError("data_file is None, needs path to datacube") + + t0 = time() + c = ipp.Client(url_file=ipyparallel_client_file, timeout=30) + + inputs_list = [ + probe_kernel_FT, + corrPower, + sigma, + edgeBoundary, + minRelativeIntensity, + minAbsoluteIntensity, + relativeToPeak, + minPeakSpacing, + maxNumPeaks, + subpixel, + upsample_factor, + filter_function, + ] + + if cluster_path is None: + cluster_path = os.getcwd() + + tmpdir = tempfile.TemporaryDirectory(dir=cluster_path) + + t_00 = time() + # write out static inputs + path_to_inputs = os.path.join(tmpdir.name, "inputs") + with open(path_to_inputs, "wb") as inputs_file: + dill.dump(inputs_list, inputs_file) + t_inputs_save = time() - t_00 + print("Serialize input values : {}".format(t_inputs_save)) + + results = [] + t1 = time() + total = int(R_Nx * R_Ny) + chunkSize = int(total / len(c.ids)) + + while chunkSize * len(c.ids) < total: + chunkSize += 1 + + indices = [(Rx, Ry) for Rx in range(R_Nx) for Ry in range(R_Ny)] + + start = 0 + for engine in c.ids: + if start + chunkSize < total - 1: + end = start + chunkSize + else: + end = total + + results.append( + c[engine].apply( + _process_chunk, + _find_Bragg_disks_single_DP_FK, + start, + end, + path_to_inputs, + indices[start:end], + data_file, + tmpdir.name, + ) + ) + + if end == total: + break + else: + start = end + t_submit = time() - t1 + print("Submit phase : {}".format(t_submit)) + + t2 = time() + c.wait(jobs=results) + t_wait = time() - t2 + print("Gather phase : {}".format(t_wait)) + + t3 = time() + for i in range(len(results)): + with open(results[i].get(), "rb") as f: + data_chunk = dill.load(f) + + for Rx, Ry, data in data_chunk: + peaks.get_pointlist(Rx, Ry).add_dataarray(data) + t_copy = time() - t3 + print("Copy results : {}".format(t_copy)) + + # clean up temp files + try: + tmpdir.cleanup() + except OSError as e: + print("Error when cleaning up temporary files: {}".format(e)) + + t = time() - t0 + print( + "Analyzed {} diffraction patterns in {}h {}m {}s".format( + R_N, int(t / 3600), int(t / 60), int(t % 60) + ) + ) + + return peaks + + +def find_Bragg_disks_dask( + DP, + probe, + corrPower=1, + sigma=2, + edgeBoundary=20, + minRelativeIntensity=0.005, + minAbsoluteIntensity=0, + 
relativeToPeak=0, + minPeakSpacing=60, + maxNumPeaks=70, + subpixel="poly", + upsample_factor=4, + filter_function=None, + dask_client=None, + data_file=None, + cluster_path=None, +): + """ + Distributed compute using Dask. + + Finds the Bragg disks in all diffraction patterns of datacube by cross, hybrid, or + phase correlation with probe. + + Args: + DP (ndarray): a diffraction pattern + probe (darray): the vacuum probe template, in real space. + corrPower (float between 0 and 1, inclusive): the cross correlation power. A + value of 1 corresponds to a cross correaltion, and 0 corresponds to a + phase correlation, with intermediate values giving various hybrids. + sigma (float): the standard deviation for the gaussian smoothing applied to + the cross correlation + edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels + minRelativeIntensity (float): the minimum acceptable correlation peak intensity, + relative to the intensity of the brightest peak + relativeToPeak (int): specifies the peak against which the minimum relative + intensity is measured -- 0=brightest maximum. 1=next brightest, etc. + minPeakSpacing (float): the minimum acceptable spacing between detected peaks + maxNumPeaks (int): the maximum number of peaks to return + subpixel (str): Whether to use subpixel fitting, and which algorithm to use. + Must be in ('none','poly','multicorr'). + * 'none': performs no subpixel fitting + * 'poly': polynomial interpolation of correlogram peaks (default) + * 'multicorr': uses the multicorr algorithm with DFT upsampling + upsample_factor (int): upsampling factor for subpixel fitting (only used when + subpixel='multicorr') + filter_function (callable): filtering function to apply to each diffraction + pattern before peakfinding. Must be a function of only one argument (the + diffraction pattern) and return the filtered diffraction pattern. The shape + of the returned DP must match the shape of the probe kernel (but does not + need to match the shape of the input diffraction pattern, e.g. the filter + can be used to bin the diffraction pattern). If using distributed disk + detection, the function must be able to be pickled with by dill. 
+ dask_client (obj): dask client for connecting to a cluster + data_file (str): absolute path to the data file containing the datacube for + processing remotely + cluster_path (str): working directory for cluster processing, defaults to current + directory + + Returns: + (PointListArray) the Bragg peak positions and correlation intensities + """ + import distributed + + R_Nx = DP.R_Nx + R_Ny = DP.R_Ny + R_N = DP.R_N + DP = None + + # Make the peaks PointListArray + coords = [("qx", float), ("qy", float), ("intensity", float)] + peaks = PointListArray(coordinates=coords, shape=(R_Nx, R_Ny)) + + # Get the probe kernel FT + probe_kernel_FT = np.conj(np.fft.fft2(probe)) + + if dask_client is None: + raise RuntimeError("dask_client is None, no Dask cluster!") + elif data_file is None: + raise RuntimeError("data_file is None, needs path to datacube") + + t0 = time() + + inputs_list = [ + probe_kernel_FT, + corrPower, + sigma, + edgeBoundary, + minRelativeIntensity, + minAbsoluteIntensity, + relativeToPeak, + minPeakSpacing, + maxNumPeaks, + subpixel, + upsample_factor, + filter_function, + ] + + if cluster_path is None: + cluster_path = os.getcwd() + + tmpdir = tempfile.TemporaryDirectory(dir=cluster_path) + + # write out static inputs + path_to_inputs = os.path.join(tmpdir.name, "{}.inputs".format(dask_client.id)) + with open(path_to_inputs, "wb") as inputs_file: + dill.dump(inputs_list, inputs_file) + t_inputs_save = time() - t0 + print("Serialize input values : {}".format(t_inputs_save)) + + cores = len(dask_client.ncores()) + + submits = [] + t1 = time() + total = int(R_Nx * R_Ny) + chunkSize = int(total / cores) + + while (chunkSize * cores) < total: + chunkSize += 1 + + indices = [(Rx, Ry) for Rx in range(R_Nx) for Ry in range(R_Ny)] + + start = 0 + for engine in range(cores): + if start + chunkSize < total - 1: + end = start + chunkSize + else: + end = total + + submits.append( + dask_client.submit( + _process_chunk, + _find_Bragg_disks_single_DP_FK, + start, + end, + path_to_inputs, + indices[start:end], + data_file, + tmpdir.name, + ) + ) + + if end == total: + break + else: + start = end + t_submit = time() - t1 + print("Submit phase : {}".format(t_submit)) + + t2 = time() + # collect results + for batch in distributed.as_completed(submits, with_results=True).batches(): + for future, result in batch: + with open(result, "rb") as f: + data_chunk = dill.load(f) + + for Rx, Ry, data in data_chunk: + peaks.get_pointlist(Rx, Ry).add_dataarray(data) + t_copy = time() - t2 + print("Gather phase : {}".format(t_copy)) + + # clean up temp files + try: + tmpdir.cleanup() + except OSError as e: + print("Error when cleaning up temporary files: {}".format(e)) + + t = time() - t0 + print( + "Analyzed {} diffraction patterns in {}h {}m {}s".format( + R_N, int(t / 3600), int(t / 60), int(t % 60) + ) + ) + + return peaks diff --git a/py4DSTEM/datacube/diskdetection/diskdetection_parallel_new.py b/py4DSTEM/datacube/diskdetection/diskdetection_parallel_new.py new file mode 100644 index 000000000..dccc0dd4b --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/diskdetection_parallel_new.py @@ -0,0 +1,272 @@ +import numpy as np +import matplotlib.pyplot as plt +import h5py +import time +import dill + +import dask +import dask.array as da +import dask.config +from dask import delayed +from dask.distributed import Client, LocalCluster +from dask.diagnostics import ProgressBar + +# import dask.bag as db + +# import distributed +from distributed.protocol.serialize import register_serialization_family +import 
distributed + +import py4DSTEM +from emdfile import PointListArray, PointList +from py4DSTEM.braggvectors.diskdetection import _find_Bragg_disks_single_DP_FK + + +#### SERIALISERS #### +# Define Serialiser +# these are functions which allow the hdf5 objects to be passed. May not be required anymore + + +def dill_dumps(x): + header = {"serializer": "dill"} + frames = [dill.dumps(x)] + return header, frames + + +def dill_loads(header, frames): + if len(frames) > 1: + frame = "".join(frames) + else: + frame = frames[0] + + return dill.loads(frame) + + +# register the serialization method +# register_serialization_family('dill', dill_dumps, dill_loads) + + +def register_dill_serializer(): + """ + This function registers the dill serializer allowing dask to work on h5py objects. + Not sure if this needs to be run and how often this need to be run. Keeping this in for now. + Args: + None + Returns: + None + """ + register_serialization_family("dill", dill_dumps, dill_loads) + return None + + +#### END OF SERAILISERS #### + + +#### DASK WRAPPER FUNCTION #### + + +# Each delayed objected is passed a 4D array, currently implementing only on 2D slices. +# TODO add batching with fancy indexing - needs to run a for loop over the batch of arrays +# TODO add cuda accelerated version +# TODO add ML-AI version +def _find_Bragg_disks_single_DP_FK_dask_wrapper(arr, *args, **kwargs): + # THis is needed as _find_Bragg_disks_single_DP_FK takes 2D array these arrays have the wrong shape + return _find_Bragg_disks_single_DP_FK(arr[0, 0], *args, **kwargs) + + +#### END OF DASK WRAPPER FUNCTIONS #### + + +#### MAIN FUNCTION +# TODO add batching with fancy indexing - needs batch size, fancy indexing method +# TODO add cuda accelerated function - needs dask GPU cluster. + + +def beta_parallel_disk_detection( + dataset, + probe, + # rxmin=None, # these would allow selecting a sub section + # rxmax=None, + # rymin=None, + # rymax=None, + # qxmin=None, + # qxmax=None, + # qymin=None, + # qymax=None, + probe_type="FT", + dask_client=None, + dask_client_params: dict = None, + restart_dask_client=True, + close_dask_client=False, + return_dask_client=True, + *args, + **kwargs, +): + """ + This is not fully validated currently so may not work, please report bugs on the py4DSTEM github page. + + This parallellises the disk detetection for all probe posistions. This can operate on either in memory or out of memory datasets + + There is an asumption that unless specifying otherwise you are parallelising on a single Local Machine. + If this is not the case its probably best to pass the dask_client into the function, although you can just pass the required arguments to dask_client_params. + If no dask_client arguments are passed it will create a dask_client for a local machine + + Note: + Do not pass "peaks" argument as a kwarg, like you might in "_find_Bragg_disks_single_DP_FK", as the results will be unreliable and may cause the calculation to crash. + Args: + dataset (py4dSTEM datacube): 4DSTEM dataset + probe (ndarray): can be regular probe kernel or fourier transormed + probe_type (str): "FT" or None + dask_client (distributed.client.Client): dask client + dask_client_params (dict): parameters to pass to dask client or dask cluster + restart_dask_client (bool): if True, function will attempt to restart the dask_client. + close_dask_client (bool): if True, function will attempt to close the dask_client. + return_dask_client (bool): if True, function will return the dask_client. 
+ *args,kwargs will be passed to "_find_Bragg_disks_single_DP_FK" e.g. corrPower, sigma, edgeboundary... + + Returns: + peaks (PointListArray): the Bragg peak positions and the correlenation intensities + dask_client(optional) (distributed.client.Client): dask_client for use later. + """ + # TODO add asserts abotu peaks not being passed + # Dask Client stuff + # TODO how to guess at default params for client, sqrt no.cores. Something to do with the size of the diffraction patterm + # write a function which can do this. + # TODO replace dask part with a with statement for easier clean up e.g. + # with LocalCluser(params) as cluster, Client(cluster) as client: + # ... dask stuff. + # TODO add assert statements and other checks. Think about reordering opperations + + if dask_client is None: + if dask_client_params is not None: + dask.config.set( + { + "distributed.worker.memory.spill": False, + "distributed.worker.memory.target": False, + } + ) + cluster = LocalCluster(**dask_client_params) + dask_client = Client(cluster, **dask_client_params) + else: + # AUTO MAGICALLY SET? + # LET DASK SET? + # HAVE A FUNCTION WHICH RUNS ON A SUBSET OF THE DATA TO PICK OPTIMIAL VALUE? + # psutil could be used to count cores. + dask.config.set( + { + "distributed.worker.memory.spill": False, # stops spilling to disk + "distributed.worker.memory.target": False, + } + ) # stops spilling to disk and erroring out + cluster = LocalCluster() + dask_client = Client(cluster) + + else: + assert type(dask_client) == distributed.client.Client + if restart_dask_client: + try: + dask_client.restart() + except Exception as e: + print( + 'Could not restart dask client. Try manually restarting outside or passing "restart_dask_client=False"' + ) # WARNING STATEMENT + return e + else: + pass + + # Probe stuff + assert ( + probe.shape == dataset.data.shape[2:] + ), "Probe and Diffraction Pattern Shapes are Mismatched" + if probe_type != "FT": + # TODO clean up and pull out redudant parts + # if probe.dtype != (np.complex128 or np.complex64 or np.complex256): + # DO FFT SHIFT THING + probe_kernel_FT = np.conj(np.fft.fft2(probe)) + dask_probe_array = da.from_array( + probe_kernel_FT, chunks=(dataset.Q_Nx, dataset.Q_Ny) + ) + dask_probe_delayed = dask_probe_array.to_delayed() + # delayed_probe_kernel_FT = delayed(probe_kernel_FT) + else: + probe_kernel_FT = probe + dask_probe_array = da.from_array( + probe_kernel_FT, chunks=(dataset.Q_Nx, dataset.Q_Ny) + ) + dask_probe_delayed = dask_probe_array.to_delayed() + + # GET DATA + # TODO add another elif if it is a dask array then pass + if type(dataset.data) == np.ndarray: + dask_data = da.from_array( + dataset.data, chunks=(1, 1, dataset.Q_Nx, dataset.Q_Ny) + ) + elif dataset.stack_pointer is not None: + dask_data = da.from_array( + dataset.stack_pointer, chunks=(1, 1, dataset.Q_Nx, dataset.Q_Ny) + ) + else: + print("Couldn't access the data") + return None + + # Convert the data to delayed + dataset_delayed = dask_data.to_delayed() + # TODO Trim data e.g. rx,ry,qx,qy + # I can pass the index values in here I should trim the probe and diffraction pattern first + + # Into the meat of the function + + # create an empty list to which we will append the dealyed functions to. 
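+    # (each delayed object wraps disk detection on a single diffraction pattern;
+    #  the list is handed to dask_client.compute below, and the gathered results
+    #  are copied into a PointListArray at the end)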
+ res = [] + # loop over the dataset_delayed and create a delayed function of + for x in np.ndindex(dataset_delayed.shape): + temp = delayed(_find_Bragg_disks_single_DP_FK_dask_wrapper)( + dataset_delayed[x], + probe_kernel_FT=dask_probe_delayed[0, 0], + # probe_kernel_FT=delayed_probe_kernel_FT, + *args, + **kwargs, + ) # passing through args from earlier or should I use + # corrPower=corrPower, + # sigma=sigma_gaussianFilter, + # edgeBoundary=edgeBoundary, + # minRelativeIntensity=minRelativeIntensity, + # minPeakSpacing=minPeakSpacing, + # maxNumPeaks=maxNumPeaks, + # subpixel='poly') + res.append(temp) + _temp_peaks = dask_client.compute( + res, optimize_graph=True + ) # creates futures and starts computing + + output = dask_client.gather(_temp_peaks) # gather the future objects + + coords = [("qx", float), ("qy", float), ("intensity", float)] + peaks = PointListArray(coordinates=coords, shape=dataset.data.shape[:-2]) + + # temp_peaks[0][0] + + # operating over a list so we need the size (0->count) and re-create the probe positions (0->rx,0->ry), + for count, (rx, ry) in zip( + [i for i in range(dataset.data[..., 0, 0].size)], + np.ndindex(dataset.data.shape[:-2]), + ): + # peaks.get_pointlist(rx, ry).add_pointlist(temp_peaks[0][count]) + # peaks.get_pointlist(rx, ry).add_pointlist(output[count][0]) + peaks.get_pointlist(rx, ry).add_pointlist(output[count]) + + # Clean up + dask_client.cancel(_temp_peaks) # removes from the dask workers + del _temp_peaks # deletes the object + if close_dask_client: + dask_client.close() + return peaks + elif close_dask_client is False and return_dask_client is True: + return peaks, dask_client + elif close_dask_client and return_dask_client is False: + return peaks + else: + print( + "Dask Client in unknown state, this may result in unpredicitable behaviour later" + ) + return peaks diff --git a/py4DSTEM/datacube/diskdetection/kernels.py b/py4DSTEM/datacube/diskdetection/kernels.py new file mode 100644 index 000000000..d36ae172b --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/kernels.py @@ -0,0 +1,97 @@ +import cupy as cp + +__all__ = ["kernels"] + +kernels = {} + +############################# multicorr kernels ################################# + +import os + +with open(os.path.join(os.path.dirname(__file__), "multicorr_row_kernel.cu"), "r") as f: + kernels["multicorr_row_kernel"] = cp.RawKernel(f.read(), "multicorr_row_kernel") + +with open(os.path.join(os.path.dirname(__file__), "multicorr_col_kernel.cu"), "r") as f: + kernels["multicorr_col_kernel"] = cp.RawKernel(f.read(), "multicorr_col_kernel") + + +############################# get_maximal_points ################################ + +""" +These kernels are approximately 50x faster than the np.roll approach used in the CPU version, +per my testing with 1024x1024 pixels and float64 on a Jetson Xavier NX. +The boundary conditions are slightly different in this version, in that pixels on the edge +of the frame are always false. This simplifies the indexing, and since in the Braggdisk +detection application an edgeBoundary is always applied in the case of subpixel detection, +this is not considered a problem. 
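+The compiled kernels are collected in the module-level kernels dict. Since they
+are cp.RawKernel objects, they are launched with CuPy's raw-kernel call convention,
+e.g. (illustrative arguments only):
+    kernels["maximal_pts_float64"]((blocks,), (threads,), (ar, out, minAbsoluteIntensity, sizex, sizey, N))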
+""" + +maximal_pts_float32 = r""" +extern "C" __global__ +void maximal_pts(const float *ar, bool *out, const double minAbsoluteIntensity, const long long sizex, const long long sizey, const long long N){ + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int x = tid / sizey; + int y = tid % sizey; + bool res = false; + if (tid < N && x>0 && x<(sizex-1) && y>0 && y<(sizey-1)) { + float val = ar[tid]; + + out[tid] = ( val > ar[tid + sizey]) && + (val > ar[tid - sizey]) && + (val > ar[tid + 1]) && + (val > ar[tid - 1]) && + (val > ar[tid - sizey - 1]) && + (val > ar[tid - sizey + 1]) && + (val > ar[tid + sizey - 1]) && + (val > ar[tid+sizey + 1] && + (val >= minAbsoluteIntensity)); + } +} +""" + +kernels["maximal_pts_float32"] = cp.RawKernel(maximal_pts_float32, "maximal_pts") + +maximal_pts_float64 = r""" +extern "C" __global__ +void maximal_pts(const double *ar, bool *out, const double minAbsoluteIntensity, const long long sizex, const long long sizey, const long long N){ + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int x = tid / sizey; + int y = tid % sizey; + bool res = false; + if (tid < N && x>0 && x<(sizex-1) && y>0 && y<(sizey-1)) { + double val = ar[tid]; + + out[tid] = ( val > ar[tid + sizey]) && + (val > ar[tid - sizey]) && + (val > ar[tid + 1]) && + (val > ar[tid - 1]) && + (val > ar[tid - sizey - 1]) && + (val > ar[tid - sizey + 1]) && + (val > ar[tid + sizey - 1]) && + (val > ar[tid+sizey + 1] && + (val >= minAbsoluteIntensity)); + } +} +""" + +kernels["maximal_pts_float64"] = cp.RawKernel(maximal_pts_float64, "maximal_pts") + + +################################ edge_boundary ###################################### + +edge_boundary = r""" +extern "C" __global__ +void edge_boundary(bool *ar, const long long edgeBoundary, + const long long sizex, const long long sizey, const long long N){ + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int x = tid % sizex; + int y = tid / sizey; // Floor divide + if (tid < N) { + if (x(sizex-1-edgeBoundary) || y(sizey-1-edgeBoundary)){ + ar[tid] = false; + } + } +} +""" + +kernels["edge_boundary"] = cp.RawKernel(edge_boundary, "edge_boundary") diff --git a/py4DSTEM/datacube/diskdetection/multicorr_col_kernel.cu b/py4DSTEM/datacube/diskdetection/multicorr_col_kernel.cu new file mode 100644 index 000000000..d2c9cc4ca --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/multicorr_col_kernel.cu @@ -0,0 +1,61 @@ +#include +#define PI 3.14159265359 +extern "C" __global__ +void multicorr_col_kernel( + complex *ar, + const float *xyShifts, + const long long N_pts, + const long long image_size_x, + const long long image_size_y, + const long long upsample_factor) { + /* + Fill in the entries of the multicorr row kernel. 
+ Inputs (C++ type/Python type): + ar (complex* / cp.complex64): Array of size N_pts x image_size[1] x kernel_size + to hold the row kernels + xyShifts (const float* / cp.float32): (N_pts x 2) array of center points to build kernels for + N_pts (const long long/int) number of center points we are + building kernels for + image_size_x (const long long/int): x size of the correlation image + image_size_y (const long long/int): y size of correlation image + upsample_factor (const long long/int): note, kernel_width = ceil(1.5*upsample_factor) + */ + int kernel_size = ceil(1.5 * upsample_factor); + + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + // Using strides to compute indices: + int stride_0 = image_size_y * kernel_size; // Stride along 0-th dimension of ar + int stride_1 = kernel_size; // Stride along 1-th dimension of ar + + // Which kernel in the stack (first index of ar) + int kernel_idx = tid / stride_0; + // Which row in the kernel (second index of ar) + int row_idx = (tid % stride_0) / stride_1; + // Which column in the kernel (last index of ar) + int col_idx = (tid % stride_0) % stride_1; + + complex prefactor = complex(0,-2.0 * PI) / float(image_size_y * upsample_factor); + + // Now do the actual calculation + if (tid < N_pts * image_size_y * kernel_size) { + // np.fft.ifftshift(np.arange(imageSize[1])) - np.floor(imageSize[1]/2) + // modresult is necessary to get the Pythonic behavior of mod of negative numbers + int modresult = int(row_idx - ceil((float)image_size_y / 2.)) % image_size_y; + modresult = modresult < 0 ? modresult + image_size_y : modresult; + float columnEntry = float(modresult) - floor((float)image_size_y/2.) ; + + + // np.arange(numColumns) - xyShift[idx,0] + float rowEntry = (float)col_idx - xyShifts[kernel_idx*2 + 1]; + + ar[tid] = exp(prefactor * columnEntry * rowEntry); + + // Use these for testing the indexing: + // ar[tid] = complex(0,(float)tid); + // ar[tid] = complex(0,(float)kernel_idx); + // ar[tid] = complex(0,(float)row_idx); + // ar[tid] = complex(0,(float)col_idx); + } + +} diff --git a/py4DSTEM/datacube/diskdetection/multicorr_row_kernel.cu b/py4DSTEM/datacube/diskdetection/multicorr_row_kernel.cu new file mode 100644 index 000000000..a5a0f352f --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/multicorr_row_kernel.cu @@ -0,0 +1,60 @@ +#include +#define PI 3.14159265359 +extern "C" __global__ +void multicorr_row_kernel( + complex *ar, + const float *xyShifts, + const long long N_pts, + const long long image_size_x, + const long long image_size_y, + const long long upsample_factor) { + /* + Fill in the entries of the multicorr row kernel. 
+ Inputs (C++ type/Python type): + ar (complex* / cp.complex64): Array of size N_pts x kernel_size x image_size[0] + to hold the row kernels + xyShifts (const float* / cp.float32): (N_pts x 2) array of center points to build kernels for + N_pts (const long long/int) number of center points we are + building kernels for + image_size_x (const long long/int): x size of the correlation image + image_size_y (const long long/int): y size of correlation image + upsample_factor (const long long/int): note, kernel_width = ceil(1.5*upsample_factor) + */ + int kernel_size = ceil(1.5 * upsample_factor); + + int tid = blockDim.x * blockIdx.x + threadIdx.x; + + // Using strides to compute indices: + int stride_0 = image_size_x * kernel_size; // Stride along 0-th dimension of ar + int stride_1 = image_size_x; // Stride along 1-th dimension of ar + + // Which kernel in the stack (first index of ar) + int kernel_idx = tid / stride_0; + // Which row in the kernel (second index of ar) + int row_idx = (tid % stride_0) / stride_1; + // Which column in the kernel (last index of ar) + int col_idx = (tid % stride_0) % stride_1; + + complex prefactor = complex(0,-2.0 * PI) / float(image_size_x * upsample_factor); + + // Now do the actual calculation + if (tid < N_pts * image_size_x * kernel_size) { + // np.arange(numColumns) - xyShift[idx,0] + float columnEntry = (float)row_idx - xyShifts[kernel_idx*2]; + + // np.fft.ifftshift(np.arange(imageSize[0])) - np.floor(imageSize[0]/2) + // modresult is necessary to get the Pythonic behavior of mod of negative numbers + int modresult = int(col_idx - ceil((float)image_size_x / 2.)) % image_size_x; + modresult = modresult < 0 ? modresult + image_size_x : modresult; + float rowEntry = float(modresult) - floor((float)image_size_x/2.) ; + + ar[tid] = exp(prefactor * columnEntry * rowEntry); + + // Use these for testing the indexing: + //ar[tid] = complex(0,(float)tid); + //ar[tid] = complex(0,(float)kernel_idx); + //ar[tid] = complex(0,(float)row_idx); + //ar[tid] = complex(0,(float)col_idx); + } + +} \ No newline at end of file diff --git a/py4DSTEM/datacube/diskdetection/probe.py b/py4DSTEM/datacube/diskdetection/probe.py new file mode 100644 index 000000000..7d32ea1ca --- /dev/null +++ b/py4DSTEM/datacube/diskdetection/probe.py @@ -0,0 +1,767 @@ +# Defines the Probe class + +import numpy as np +from typing import Optional +from warnings import warn +from scipy.ndimage import binary_opening, binary_dilation, distance_transform_edt + +from py4DSTEM.utils import get_shifted_ar, get_shift, get_CoM +from py4DSTEM.data import DiffractionSlice, Data + + + + + +# DataCube methods + + +class ProbeMaker: + def __init__(self): + pass + + + def get_vacuum_probe( + self, + ROI=None, + align=True, + mask=None, + threshold=0.2, + expansion=12, + opening=3, + verbose=False, + returncalc=True, + ): + """ + Computes a vacuum probe. + + Which diffraction patterns are included in the calculation is specified + by the `ROI` parameter. Diffraction patterns are aligned before averaging + if `align` is True (default). A global mask is applied to each diffraction + pattern before aligning/averaging if `mask` is specified. After averaging, + a final masking step is applied according to the parameters `threshold`, + `expansion`, and `opening`. + + Parameters + ---------- + ROI : optional, boolean array or len 4 list/tuple + If unspecified, uses the whole datacube. If a boolean array is + passed must be real-space shaped, and True pixels are used. 
If a + 4-tuple is passed, uses the region inside the limits + (rx_min,rx_max,ry_min,ry_max) + align : optional, bool + if True, aligns the probes before averaging + mask : optional, array + mask applied to each diffraction pattern before alignment and + averaging + threshold : float + in the final masking step, values less than max(probe)*threshold + are considered outside the probe + expansion : int + number of pixels by which the final mask is expanded after + thresholding + opening : int + size of binary opening applied to the final mask to eliminate stray + bright pixels + verbose : bool + toggles verbose output + returncalc : bool + if True, returns the answer + + Returns + ------- + probe : Probe, optional + the vacuum probe + """ + from py4DSTEM.braggvectors import Probe + + # parse region to use + if ROI is None: + ROI = np.ones(self.Rshape, dtype=bool) + elif isinstance(ROI, tuple): + assert len(ROI) == 4, "if ROI is a tuple must be length 4" + _ROI = np.ones(self.Rshape, dtype=bool) + ROI = _ROI[ROI[0] : ROI[1], ROI[2] : ROI[3]] + else: + assert isinstance(ROI, np.ndarray) + assert ROI.shape == self.Rshape + xy = np.vstack(np.nonzero(ROI)) + length = xy.shape[1] + + # setup global mask + if mask is None: + mask = 1 + else: + assert mask.shape == self.Qshape + + # compute average probe + probe = self.data[xy[0, 0], xy[1, 0], :, :] + for n in tqdmnd(range(1, length)): + curr_DP = self.data[xy[0, n], xy[1, n], :, :] * mask + if align: + xshift, yshift = get_shift(probe, curr_DP) + curr_DP = get_shifted_ar(curr_DP, xshift, yshift) + probe = probe * (n - 1) / n + curr_DP / n + + # mask + mask = probe > np.max(probe) * threshold + mask = binary_opening(mask, iterations=opening) + mask = binary_dilation(mask, iterations=1) + mask = ( + np.cos( + (np.pi / 2) + * np.minimum( + distance_transform_edt(np.logical_not(mask)) / expansion, 1 + ) + ) + ** 2 + ) + probe *= mask + + # make a probe, add to tree, and return + probe = Probe(probe) + self.attach(probe) + if returncalc: + return probe + + def get_probe_size( + self, + dp=None, + thresh_lower=0.01, + thresh_upper=0.99, + N=100, + plot=False, + returncal=True, + write_to_cal=True, + **kwargs, + ): + """ + Gets the center and radius of the probe in the diffraction plane. + + The algorithm is as follows: + First, create a series of N binary masks, by thresholding the diffraction + pattern DP with a linspace of N thresholds from thresh_lower to + thresh_upper, measured relative to the maximum intensity in DP. + Using the area of each binary mask, calculate the radius r of a circular + probe. Because the central disk is typically very intense relative to + the rest of the DP, r should change very little over a wide range of + intermediate values of the threshold. The range in which r is trustworthy + is found by taking the derivative of r(thresh) and finding identifying + where it is small. The radius is taken to be the mean of these r values. + Using the threshold corresponding to this r, a mask is created and the + CoM of the DP times this mask it taken. This is taken to be the origin + x0,y0. + + Args: + dp (str or array): specifies the diffraction pattern in which to + find the central disk. A position averaged, or shift-corrected + and averaged, DP works best. If mode is None, the diffraction + pattern stored in the tree from 'get_dp_mean' is used. If mode + is a string it specifies the name of another virtual diffraction + pattern in the tree. If mode is an array, the array is used to + calculate probe size. 
+ thresh_lower (float, 0 to 1): the lower limit of threshold values + thresh_upper (float, 0 to 1): the upper limit of threshold values + N (int): the number of thresholds / masks to use + plot (bool): if True plots results + plot_params(dict): dictionary to modify defaults in plot + return_calc (bool): if True returns 3-tuple described below + write_to_cal (bool): if True, looks for a Calibration instance + and writes the measured probe radius there + + Returns: + (3-tuple): A 3-tuple containing: + + * **r**: *(float)* the central disk radius, in pixels + * **x0**: *(float)* the x position of the central disk center + * **y0**: *(float)* the y position of the central disk center + """ + # perform computation + from py4DSTEM.process.calibration import get_probe_size + + if dp is None: + assert ( + "dp_mean" in self.treekeys + ), "calculate .get_dp_mean() or pass a `dp` arg" + DP = self.tree("dp_mean").data + elif type(dp) == str: + assert dp in self.treekeys, f"mode {dp} not found in the tree" + DP = self.tree(dp) + elif type(dp) == np.ndarray: + assert dp.shape == self.Qshape, "must be a diffraction space shape 2D array" + DP = dp + + x = get_probe_size( + DP, + thresh_lower=thresh_lower, + thresh_upper=thresh_upper, + N=N, + ) + + # try to add to calibration + if write_to_cal: + try: + self.calibration.set_probe_param(x) + except AttributeError: + raise Exception( + "writing to calibrations were requested, but could not be completed" + ) + + # plot results + if plot: + from py4DSTEM.visualize import show_circles + + show_circles(DP, (x[1], x[2]), x[0], vmin=0, vmax=1, **kwargs) + + # return + if returncal: + return x + + + + + + + +# Container class + +class Probe(DiffractionSlice, Data): + """ + Stores a vacuum probe. + + Both a vacuum probe and a kernel for cross-correlative template matching + derived from that probe are stored and can be accessed at + + >>> p.probe + >>> p.kernel + + respectively, for some Probe instance `p`. If a kernel has not been computed + the latter expression returns None. + + + """ + + def __init__(self, data: np.ndarray, name: Optional[str] = "probe"): + """ + Accepts: + data (2D or 3D np.ndarray): the vacuum probe, or + the vacuum probe + kernel + name (str): a name + + Returns: + (Probe) + """ + # if only the probe is passed, make space for the kernel + if data.ndim == 2: + data = np.stack([data, np.zeros_like(data)]) + + # initialize as a DiffractionSlice + DiffractionSlice.__init__( + self, name=name, data=data, slicelabels=["probe", "kernel"] + ) + + ## properties + + @property + def probe(self): + return self.get_slice("probe").data + + @probe.setter + def probe(self, x): + assert x.shape == (self.data.shape[1:]) + self.data[0, :, :] = x + + @property + def kernel(self): + return self.get_slice("kernel").data + + @kernel.setter + def kernel(self, x): + assert x.shape == (self.data.shape[1:]) + self.data[1, :, :] = x + + # read + @classmethod + def _get_constructor_args(cls, group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = DiffractionSlice._get_constructor_args(group) + args = { + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], + } + return args + + # generation methods + + @classmethod + def from_vacuum_data(cls, data, mask=None, threshold=0.2, expansion=12, opening=3): + """ + Generates and returns a vacuum probe Probe instance from either a + 2D vacuum image or a 3D stack of vacuum diffraction patterns. + + The probe is multiplied by `mask`, if it's passed. 
An additional + masking step zeros values outside of a mask determined by `threshold`, + `expansion`, and `opening`, generated by first computing the binary image + probe < max(probe)*threshold, then applying a binary expansion and + then opening to this image. No alignment is performed - i.e. it is assumed + that the beam was stationary during acquisition of the stack. To align + the images, use the DataCube .get_vacuum_probe method. + + Parameters + ---------- + data : 2D or 3D array + the vacuum diffraction data. For 3D stacks, use shape (N,Q_Nx,Q_Ny) + mask : boolean array, optional + mask applied to the probe + threshold : float + threshold determining mask which zeros values outside of probe + expansion : int + number of pixels by which the zeroing mask is expanded to capture + the full probe + opening : int + size of binary opening used to eliminate stray bright pixels + + Returns + ------- + probe : Probe + the vacuum probe + """ + assert isinstance(data, np.ndarray) + if data.ndim == 3: + probe = np.average(data, axis=0) + elif data.ndim == 2: + probe = data + else: + raise Exception(f"data must be 2- or 3-D, not {data.ndim}-D") + + if mask is not None: + probe *= mask + + mask = probe > np.max(probe) * threshold + mask = binary_opening(mask, iterations=opening) + mask = binary_dilation(mask, iterations=1) + mask = ( + np.cos( + (np.pi / 2) + * np.minimum( + distance_transform_edt(np.logical_not(mask)) / expansion, 1 + ) + ) + ** 2 + ) + + probe = cls(probe * mask) + return probe + + @classmethod + def generate_synthetic_probe(cls, radius, width, Qshape): + """ + Makes a synthetic probe, with the functional form of a disk blurred by a + sigmoid (a logistic function). + + Parameters + ---------- + radius : float + the probe radius + width : float + the blurring of the probe edge. width represents the + full width of the blur, with x=-w/2 to x=+w/2 about the edge + spanning values of ~0.12 to 0.88 + Qshape : 2 tuple + the diffraction plane dimensions + + Returns + ------- + probe : Probe + the probe + """ + # Make coords + Q_Nx, Q_Ny = Qshape + qy, qx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx)) + qy, qx = qy - Q_Ny / 2.0, qx - Q_Nx / 2.0 + qr = np.sqrt(qx**2 + qy**2) + + # Shift zero to disk edge + qr = qr - radius + + # Calculate logistic function + probe = 1 / (1 + np.exp(4 * qr / width)) + + return cls(probe) + + # calibration methods + + def measure_disk( + self, + thresh_lower=0.01, + thresh_upper=0.99, + N=100, + returncalc=True, + data=None, + ): + """ + Finds the center and radius of an average probe image. + + A naive algorithm. Creates a series of N binary masks by thresholding + the probe image a linspace of N thresholds from thresh_lower to + thresh_upper, relative to the image max/min. For each mask, we find the + square root of the number of True valued pixels divided by pi to + estimate a radius. Because the central disk is intense relative to the + remainder of the image, the computed radii are expected to vary very + little over a wider range threshold values. A range of r values + considered trustworthy is estimated by taking the derivative + r(thresh)/dthresh identifying where it is small, and the mean of this + range is returned as the radius. A center is estimated using a binary + thresholded image in combination with the center of mass operator. 
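+        As a rough numerical illustration (not taken from real data): a threshold
+        mask containing 1257 True-valued pixels gives r = sqrt(1257 / pi) ~ 20 pixels.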
+ + Parameters + ---------- + thresh_lower : float, 0 to 1 + the lower limit of threshold values + thresh_upper : float, 0 to 1) + the upper limit of threshold values + N : int + the number of thresholds / masks to use + returncalc : True + toggles returning the answer + data : 2d array, optional + if passed, uses this 2D array in place of the probe image when + performing the computation. This also supresses storing the + results in the Probe's calibration metadata + + Returns + ------- + r, x0, y0 : (3-tuple) + the radius and origin + """ + # set the image + im = self.probe if data is None else data + + # define the thresholds + thresh_vals = np.linspace(thresh_lower, thresh_upper, N) + r_vals = np.zeros(N) + + # get binary images and compute a radius for each + immax = np.max(im) + for i, val in enumerate(thresh_vals): + mask = im > immax * val + r_vals[i] = np.sqrt(np.sum(mask) / np.pi) + + # Get derivative and determine trustworthy r-values + dr_dtheta = np.gradient(r_vals) + mask = (dr_dtheta <= 0) * (dr_dtheta >= 2 * np.median(dr_dtheta)) + r = np.mean(r_vals[mask]) + + # Get origin + thresh = np.mean(thresh_vals[mask]) + mask = im > immax * thresh + x0, y0 = get_CoM(im * mask) + + # Store metadata and return + ans = r, x0, y0 + if data is None: + try: + self.calibration.set_probe_param(ans) + except AttributeError: + warn( + f"Couldn't store the probe parameters in metadata as no calibration was found for this Probe instance, {self}" + ) + pass + if returncalc: + return ans + + # Kernel generation methods + + def get_kernel( + self, mode="flat", origin=None, data=None, returncalc=True, **kwargs + ): + """ + Creates a cross-correlation kernel from the vacuum probe. + + Specific behavior and valid keyword arguments depend on the `mode` + specified. In each case, the center of the probe is shifted to the + origin and the kernel normalized such that it sums to 1. This is the + only processing performed if mode is 'flat'. Otherwise, a centrosymmetric + region of negative intensity is added around the probe intended to promote + edge-filtering-like behavior during cross correlation, with the + functional form of the subtracted region defined by `mode` and the + relevant **kwargs. For normalization, flat probes integrate to 1, and the + remaining probes integrate to 1 before subtraction and 0 after. Required + keyword arguments are: + + - 'flat': No required arguments. This mode is recommended for bullseye + or other structured probes + - 'gaussian': Required arg `sigma` (number), the width (standard + deviation) of a centered gaussian to be subtracted. + - 'sigmoid': Required arg `radii` (2-tuple), the inner and outer radii + (ri,ro) of an annular region with a sine-squared sigmoidal radial + profile to be subtracted. + - 'sigmoid_log': Required arg `radii` (2-tuple), the inner and outer radii + (ri,ro) of an annular region with a logistic sigmoidal radial + profile to be subtracted. + + Parameters + ---------- + mode : str + must be in 'flat','gaussian','sigmoid','sigmoid_log' + origin : 2-tuple, optional + specify the origin. If not passed, looks for a value for the probe + origin in metadata. If not found there, calls .measure_disk. + data : 2d array, optional + if specified, uses this array instead of the probe image to compute + the kernel + **kwargs + see descriptions above + + Returns + ------- + kernel : 2D array + """ + + modes = ["flat", "gaussian", "sigmoid", "sigmoid_log"] + + # parse args + assert mode in modes, f"mode must be in {modes}. 
Received {mode}" + + # get function + function_dict = { + "flat": self.get_probe_kernel_flat, + "gaussian": self.get_probe_kernel_edge_gaussian, + "sigmoid": self._get_probe_kernel_edge_sigmoid_sine_squared, + "sigmoid_log": self._get_probe_kernel_edge_sigmoid_sine_squared, + } + fn = function_dict[mode] + + # check for the origin + if origin is None: + try: + x = self.calibration.get_probe_params() + except AttributeError: + x = None + finally: + if x is None: + origin = None + else: + r, x, y = x + origin = (x, y) + + # get the data + probe = data if data is not None else self.probe + + # compute + kern = fn(probe, origin=origin, **kwargs) + + # add to the Probe + self.kernel = kern + + # return + if returncalc: + return kern + + @staticmethod + def get_probe_kernel_flat(probe, origin=None, bilinear=False): + """ + Creates a cross-correlation kernel from the vacuum probe by normalizing + and shifting the center. + + Parameters + ---------- + probe : 2d array + the vacuum probe + origin : 2-tuple (optional) + the origin of diffraction space. If not specified, finds the origin + using get_probe_radius. + bilinear : bool (optional) + By default probe is shifted via a Fourier transform. Setting this to + True overrides it and uses bilinear shifting. Not recommended! + + Returns + ------- + kernel : ndarray + the cross-correlation kernel corresponding to the probe, in real + space + """ + Q_Nx, Q_Ny = probe.shape + + # Get CoM + if origin is None: + from py4DSTEM.process.calibration import get_probe_size + + _, xCoM, yCoM = get_probe_size(probe) + else: + xCoM, yCoM = origin + + # Normalize + probe = probe / np.sum(probe) + + # Shift center to corners of array + probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear) + + # Return + return probe_kernel + + @staticmethod + def get_probe_kernel_edge_gaussian( + probe, + sigma, + origin=None, + bilinear=True, + ): + """ + Creates a cross-correlation kernel from the probe, subtracting a + gaussian from the normalized probe such that the kernel integrates to + zero, then shifting the center of the probe to the array corners. + + Parameters + ---------- + probe : ndarray + the diffraction pattern corresponding to the probe over vacuum + sigma : float + the width of the gaussian to subtract, relative to the standard + deviation of the probe + origin : 2-tuple (optional) + the origin of diffraction space. If not specified, finds the origin + using get_probe_radius. + bilinear : bool + By default probe is shifted via a Fourier transform. Setting this to + True overrides it and uses bilinear shifting. Not recommended! 
+ + Returns + ------- + kernel : ndarray + the cross-correlation kernel + """ + Q_Nx, Q_Ny = probe.shape + + # Get CoM + if origin is None: + from py4DSTEM.process.calibration import get_probe_size + + _, xCoM, yCoM = get_probe_size(probe) + else: + xCoM, yCoM = origin + + # Shift probe to origin + probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear) + + # Generate normalization kernel + # Coordinates + qy, qx = np.meshgrid( + np.mod(np.arange(Q_Ny) + Q_Ny // 2, Q_Ny) - Q_Ny // 2, + np.mod(np.arange(Q_Nx) + Q_Nx // 2, Q_Nx) - Q_Nx // 2, + ) + qr2 = qx**2 + qy**2 + # Calculate Gaussian normalization kernel + qstd2 = np.sum(qr2 * probe_kernel) / np.sum(probe_kernel) + kernel_norm = np.exp(-qr2 / (2 * qstd2 * sigma**2)) + + # Output normalized kernel + probe_kernel = probe_kernel / np.sum(probe_kernel) - kernel_norm / np.sum( + kernel_norm + ) + + return probe_kernel + + @staticmethod + def get_probe_kernel_edge_sigmoid( + probe, + radii, + origin=None, + type="sine_squared", + bilinear=True, + ): + """ + Creates a convolution kernel from an average probe, subtracting an annular + trench about the probe such that the kernel integrates to zero, then + shifting the center of the probe to the array corners. + + Parameters + ---------- + probe : ndarray + the diffraction pattern corresponding to the probe over vacuum + radii : 2-tuple + the sigmoid inner and outer radii + origin : 2-tuple (optional) + the origin of diffraction space. If not specified, finds the origin + using get_probe_radius. + type : string + must be 'logistic' or 'sine_squared' + bilinear : bool + By default probe is shifted via a Fourier transform. Setting this to + True overrides it and uses bilinear shifting. Not recommended! + + Returns + ------- + kernel : 2d array + the cross-correlation kernel + """ + # parse inputs + if isinstance(probe, Probe): + probe = probe.probe + + valid_types = ("logistic", "sine_squared") + assert type in valid_types, "type must be in {}".format(valid_types) + Q_Nx, Q_Ny = probe.shape + ri, ro = radii + + # Get CoM + if origin is None: + from py4DSTEM.process.calibration import get_probe_size + + _, xCoM, yCoM = get_probe_size(probe) + else: + xCoM, yCoM = origin + + # Shift probe to origin + probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear) + + # Generate normalization kernel + # Coordinates + qy, qx = np.meshgrid( + np.mod(np.arange(Q_Ny) + Q_Ny // 2, Q_Ny) - Q_Ny // 2, + np.mod(np.arange(Q_Nx) + Q_Nx // 2, Q_Nx) - Q_Nx // 2, + ) + qr = np.sqrt(qx**2 + qy**2) + # Calculate sigmoid + if type == "logistic": + r0 = 0.5 * (ro + ri) + sigma = 0.25 * (ro - ri) + sigmoid = 1 / (1 + np.exp((qr - r0) / sigma)) + elif type == "sine_squared": + sigmoid = (qr - ri) / (ro - ri) + sigmoid = np.minimum(np.maximum(sigmoid, 0.0), 1.0) + sigmoid = np.cos((np.pi / 2) * sigmoid) ** 2 + else: + raise Exception("type must be in {}".format(valid_types)) + + # Output normalized kernel + probe_kernel = probe_kernel / np.sum(probe_kernel) - sigmoid / np.sum(sigmoid) + + return probe_kernel + + def _get_probe_kernel_edge_sigmoid_sine_squared( + self, + probe, + radii, + origin=None, + **kwargs, + ): + return self.get_probe_kernel_edge_sigmoid( + probe, + radii, + origin=origin, + type="sine_squared", + **kwargs, + ) + + def _get_probe_kernel_edge_sigmoid_logistic( + self, + probe, + radii, + origin=None, + **kwargs, + ): + return self.get_probe_kernel_edge_sigmoid( + probe, radii, origin=origin, type="logistic", **kwargs + ) diff --git a/py4DSTEM/datacube/preprocess/__init__.py 
b/py4DSTEM/datacube/preprocess/__init__.py new file mode 100644 index 000000000..2bb70086a --- /dev/null +++ b/py4DSTEM/datacube/preprocess/__init__.py @@ -0,0 +1,5 @@ +from py4DSTEM.datacube.preprocess.preprocess import * +from py4DSTEM.datacube.preprocess.darkreference import * +from py4DSTEM.datacube.preprocess.electroncount import * +from py4DSTEM.datacube.preprocess.radialbkgrd import * + diff --git a/py4DSTEM/datacube/preprocess/darkreference.py b/py4DSTEM/datacube/preprocess/darkreference.py new file mode 100644 index 000000000..a23d4271b --- /dev/null +++ b/py4DSTEM/datacube/preprocess/darkreference.py @@ -0,0 +1,200 @@ +# Functions for background fitting and subtraction. + +import numpy as np + +#### Subtrack darkreference from datacube frame at (Rx,Ry) #### + + +def get_bksbtr_DP(datacube, darkref, Rx, Ry): + """ + Returns a background subtracted diffraction pattern. + + Args: + datacube (DataCube): data to background subtract + darkref (ndarray): dark reference. must have shape (datacube.Q_Nx, datacube.Q_Ny) + Rx,Ry (int): the scan position of the diffraction pattern of interest + + Returns: + (ndarray) the background subtracted diffraction pattern + """ + assert darkref.shape == ( + datacube.Q_Nx, + datacube.Q_Ny, + ), "background must have shape (datacube.Q_Nx, datacube.Q_Ny)" + return datacube.data[Rx, Ry, :, :].astype(float) - darkref.astype(float) + + +#### Get dark reference #### + + +def get_darkreference( + datacube, N_frames, width_x=0, width_y=0, side_x="end", side_y="end" +): + """ + Gets a dark reference image. + + Select N_frames random frames (DPs) from datacube. Find streaking noise in the + horizontal and vertical directions, by finding the average values along a thin strip + of width_x/width_y pixels along the detector edges. Which edges are used is + controlled by side_x/side_y, which must be 'start' or 'end'. Streaks along only one + direction can be used by setting width_x or width_y to 0, which disables correcting + streaks in this direction. + + Note that the data is cast to float before computing the background, and should + similarly be cast to float before performing a subtraction. This avoids integer + clipping and wraparound errors. + + Args: + datacube (DataCube): data to background subtract + N_frames (int): number of random diffraction patterns to use + width_x (int): width of the ROI strip for finding streaking in x + width_y (int): see above + side_x (str): use a strip from the start or end of the array. Must be 'start' or + 'end', defaults to 'end' + side_y (str): see above + + Returns: + (ndarray): a 2D ndarray of shape (datacube.Q_Nx, datacube.Ny) giving the + background. + """ + if width_x == 0 and width_y == 0: + print( + "Warning: either width_x or width_y should be a positive integer. Returning an empty dark reference." 
+ ) + return np.zeros((datacube.Q_Nx, datacube.Q_Ny)) + elif width_x == 0: + return get_background_streaks_y( + datacube=datacube, N_frames=N_frames, width=width_y, side=side_y + ) + elif width_y == 0: + return get_background_streaks_x( + datacube=datacube, N_frames=N_frames, width=width_x, side=side_x + ) + else: + darkref_x = get_background_streaks_x( + datacube=datacube, N_frames=N_frames, width=width_x, side=side_x + ) + darkref_y = get_background_streaks_y( + datacube=datacube, N_frames=N_frames, width=width_y, side=side_y + ) + return ( + darkref_x + + darkref_y + - (np.mean(darkref_x) * width_x + np.mean(darkref_y) * width_y) + / (width_x + width_y) + ) + # Mean has been added twice; subtract one off + + +def get_background_streaks(datacube, N_frames, width, side="end", direction="x"): + """ + Gets background streaking in either the x- or y-direction, by finding the average of + a strip of pixels along the edge of the detector over a random selection of + diffraction patterns, and returns a dark reference array. + + Note that the data is cast to float before computing the background, and should + similarly be cast to float before performing a subtraction. This avoids integer + clipping and wraparound errors. + + Args: + datacube (DataCube): data to background subtract + N_frames (int): number of random frames to use + width (int): width of the ROI strip for background identification + side (str, optional): use a strip from the start or end of the array. Must be + 'start' or 'end', defaults to 'end' + directions (str): the direction of background streaks to find. Must be either + 'x' or 'y' defaults to 'x' + + Returns: + (ndarray): a 2D ndarray of shape (datacube.Q_Nx,datacube.Q_Ny), giving the + the x- or y-direction background streaking. + """ + assert (direction == "x") or (direction == "y"), "direction must be 'x' or 'y'." + if direction == "x": + return get_background_streaks_x( + datacube=datacube, N_frames=N_frames, width=width, side=side + ) + else: + return get_background_streaks_y( + datacube=datacube, N_frames=N_frames, width=width, side=side + ) + + +def get_background_streaks_x(datacube, width, N_frames, side="start"): + """ + Gets background streaking, by finding the average of a strip of pixels along the + y-edge of the detector over a random selection of diffraction patterns. + + See docstring for get_background_streaks() for more info. + """ + assert ( + N_frames <= datacube.R_Nx * datacube.R_Ny + ), "N_frames must be less than or equal to the total number of diffraction patterns." + assert (side == "start") or (side == "end"), "side must be 'start' or 'end'." 
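+    # Sketch of the approach: sum a `width`-pixel strip along the detector edge
+    # over N_frames randomly chosen patterns, reduce it to a per-column average,
+    # and broadcast that 1D profile over the full detector as the dark reference.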
+ + # Get random subset of DPs + indices = np.arange(datacube.R_Nx * datacube.R_Ny) + np.random.shuffle(indices) + indices = indices[:N_frames] + indices_x, indices_y = np.unravel_index(indices, (datacube.R_Nx, datacube.R_Ny)) + + # Make a reference strip array + refstrip = np.zeros((width, datacube.Q_Ny)) + if side == "start": + for i in range(N_frames): + refstrip += datacube.data[indices_x[i], indices_y[i], :width, :].astype( + float + ) + else: + for i in range(N_frames): + refstrip += datacube.data[indices_x[i], indices_y[i], -width:, :].astype( + float + ) + + # Calculate mean and return 1D array of streaks + bkgrnd_streaks = np.sum(refstrip, axis=0) // width // N_frames + + # Broadcast to 2D array + darkref = np.zeros((datacube.Q_Nx, datacube.Q_Ny)) + darkref += bkgrnd_streaks[np.newaxis, :] + return darkref + + +def get_background_streaks_y(datacube, N_frames, width, side="start"): + """ + Gets background streaking, by finding the average of a strip of pixels along the + x-edge of the detector over a random selection of diffraction patterns. + + See docstring for get_background_streaks_1D() for more info. + """ + assert ( + N_frames <= datacube.R_Nx * datacube.R_Ny + ), "N_frames must be less than or equal to the total number of diffraction patterns." + assert (side == "start") or (side == "end"), "side must be 'start' or 'end'." + + # Get random subset of DPs + indices = np.arange(datacube.R_Nx * datacube.R_Ny) + np.random.shuffle(indices) + indices = indices[:N_frames] + indices_x, indices_y = np.unravel_index(indices, (datacube.R_Nx, datacube.R_Ny)) + + # Make a reference strip array + refstrip = np.zeros((datacube.Q_Nx, width)) + if side == "start": + for i in range(N_frames): + refstrip += datacube.data[indices_x[i], indices_y[i], :, :width].astype( + float + ) + else: + for i in range(N_frames): + refstrip += datacube.data[indices_x[i], indices_y[i], :, -width:].astype( + float + ) + + # Calculate mean and return 1D array of streaks + bkgrnd_streaks = np.sum(refstrip, axis=1) // width // N_frames + + # Broadcast to 2D array + darkref = np.zeros((datacube.Q_Nx, datacube.Q_Ny)) + darkref += bkgrnd_streaks[:, np.newaxis] + return darkref diff --git a/py4DSTEM/datacube/preprocess/electroncount.py b/py4DSTEM/datacube/preprocess/electroncount.py new file mode 100644 index 000000000..ea505ebc5 --- /dev/null +++ b/py4DSTEM/datacube/preprocess/electroncount.py @@ -0,0 +1,460 @@ +# Electron counting +# +# Includes functions for electron counting on either the CPU (electron_count) +# or the GPU (electron_count_GPU). For GPU electron counting, pytorch is used +# to interface between numpy and the GPU, and the datacube is expected in +# numpy.memmap (memory mapped) form. + +import numpy as np +from scipy import optimize + +from emdfile import PointListArray +from py4DSTEM.utils import get_maxima_2D, bin2D + + +def electron_count( + datacube, + darkreference, + Nsamples=40, + thresh_bkgrnd_Nsigma=4, + thresh_xray_Nsigma=10, + binfactor=1, + sub_pixel=True, + output="pointlist", +): + """ + Performs electron counting. + + The algorithm is as follows: + From a random sampling of frames, calculate an x-ray and background + threshold value. In each frame, subtract the dark reference, then apply the + two thresholds. Find all local maxima with respect to the nearest neighbor + pixels. These are considered electron strike events. 
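+    As a sketch of the per-pixel rule (writing `w` for a dark-subtracted frame):
+    a pixel is counted when thresh_bkgrnd < w[qx, qy] < thresh_xray and
+    w[qx, qy] exceeds all of its nearest-neighbor pixels.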
+ + Thresholds are specified in units of standard deviations, either of a + gaussian fit to the histogram background noise (for thresh_bkgrnd) or of + the histogram itself (for thresh_xray). The background (lower) threshold is + more important; we will always be missing some real electron counts and + incorrectly counting some noise as electron strikes - this threshold + controls their relative balance. The x-ray threshold may be set fairly high. + + Args: + datacube: a 4D numpy.ndarray pointing to the datacube. Note: the R/Q axes are + flipped with respect to py4DSTEM DataCube objects + darkreference: a 2D numpy.ndarray with the dark reference + Nsamples: the number of frames to use in dark reference and threshold + calculation. + thresh_bkgrnd_Nsigma: the background threshold is + ``mean(guassian fit) + (this #)*std(gaussian fit)`` + where the gaussian fit is to the background noise. + thresh_xray_Nsigma: the X-ray threshold is + ``mean(hist) +/- (this #)*std(hist)`` + where hist is the histogram of all pixel values in the Nsamples random frames + binfactor: the binnning factor + sub_pixel (bool): controls whether subpixel refinement is performed + output (str): controls output format; must be 'datacube' or 'pointlist' + + Returns: + (variable) if output=='pointlist', returns a PointListArray of all electron + counts in each frame. If output=='datacube', returns a 4D array of bools, with + True indicating electron strikes + """ + assert isinstance(output, str), "output must be a str" + assert output in [ + "pointlist", + "datacube", + ], "output must be 'pointlist' or 'datacube'" + + # Get dimensions + R_Nx, R_Ny, Q_Nx, Q_Ny = np.shape(datacube) + + # Get threshholds + print("Calculating threshholds") + thresh_bkgrnd, thresh_xray = calculate_thresholds( + datacube, + darkreference, + Nsamples=Nsamples, + thresh_bkgrnd_Nsigma=thresh_bkgrnd_Nsigma, + thresh_xray_Nsigma=thresh_xray_Nsigma, + ) + + # Save to a new datacube + if output == "datacube": + counted = np.ones((R_Nx, R_Ny, Q_Nx // binfactor, Q_Ny // binfactor)) + # Loop through frames + for Rx in range(R_Nx): + for Ry in range(R_Ny): + frame = datacube[Rx, Ry, :, :].astype(np.int16) # Get frame from file + workingarray = frame - darkreference # Subtract dark ref from frame + events = workingarray > thresh_bkgrnd # Threshold electron events + events *= thresh_xray > workingarray + + ## Keep events which are greater than all NN pixels ## + events = get_maxima_2D(workingarray * events) + + if binfactor > 1: + # Perform binning + counted[Rx, Ry, :, :] = bin2D(events, factor=binfactor) + else: + counted[Rx, Ry, :, :] = events + return counted + + # Save to a PointListArray + else: + coordinates = [("qx", int), ("qy", int)] + pointlistarray = PointListArray(coordinates=coordinates, shape=(R_Nx, R_Ny)) + # Loop through frames + for Rx in range(R_Nx): + for Ry in range(R_Ny): + frame = datacube[Rx, Ry, :, :].astype(np.int16) # Get frame from file + workingarray = frame - darkreference # Subtract dark ref from frame + events = workingarray > thresh_bkgrnd # Threshold electron events + events *= thresh_xray > workingarray + + ## Keep events which are greater than all NN pixels ## + events = get_maxima_2D(workingarray * events) + + # Perform binning + if binfactor > 1: + events = bin2D(events, factor=binfactor) + + # Save to PointListArray + x, y = np.nonzero(events) + pointlist = pointlistarray.get_pointlist(Rx, Ry) + pointlist.add_tuple_of_nparrays((x, y)) + + return pointlistarray + + +def electron_count_GPU( + datacube, + darkreference, + 
Nsamples=40, + thresh_bkgrnd_Nsigma=4, + thresh_xray_Nsigma=10, + binfactor=1, + sub_pixel=True, + output="pointlist", +): + """ + Performs electron counting on the GPU. + + Uses pytorch to interface between numpy and cuda. Requires cuda and pytorch. + This function expects datacube to be a np.memmap object. + See electron_count() for additional documentation. + """ + import torch + import dm + + assert isinstance(output, str), "output must be a str" + assert output in [ + "pointlist", + "datacube", + ], "output must be 'pointlist' or 'datacube'" + + # Get dimensions + R_Nx, R_Ny, Q_Nx, Q_Ny = np.shape(datacube) + + # Get threshholds + print("Calculating threshholds") + thresh_bkgrnd, thresh_xray = calculate_thresholds( + datacube, + darkreference, + Nsamples=Nsamples, + thresh_bkgrnd_Nsigma=thresh_bkgrnd_Nsigma, + thresh_xray_Nsigma=thresh_xray_Nsigma, + ) + + # Make a torch device object, to interface numpy with the GPU + # Put a few arrays on it - dark reference, counted image + device = torch.device("cuda") + darkref = torch.from_numpy(darkreference.astype(np.int16)).to(device) + counted = torch.ones( + R_Nx, R_Ny, Q_Nx // binfactor, Q_Ny // binfactor, dtype=torch.short + ).to(device) + + # Loop through frames + for Rx in range(R_Nx): + for Ry in range(R_Ny): + frame = datacube[Rx, Ry, :, :].astype(np.int16) # Get frame from file + gframe = torch.from_numpy(frame).to(device) # Move frame to GPU + workingarray = gframe - darkref # Subtract dark ref from frame + events = workingarray > thresh_bkgrnd # Threshold electron events + events = thresh_xray > workingarray + + ## Keep events which are greater than all NN pixels ## + + # Check pixel is greater than all adjacent pixels + log = workingarray[1:-1, :] > workingarray[0:-2, :] + events[1:-1, :] = events[1:-1, :] & log + log = workingarray[0:-2, :] > workingarray[1:-1, :] + events[0:-2, :] = events[0:-2, :] & log + log = workingarray[:, 1:-1] > workingarray[:, 0:-2] + events[:, 1:-1] = events[:, 1:-1] & log + log = workingarray[:, 0:-2] > workingarray[:, 1:-1] + events[:, 0:-2] = events[:, 0:-2] & log + # Check pixel is greater than adjacent diagonal pixels + log = workingarray[1:-1, 1:-1] > workingarray[0:-2, 0:-2] + events[1:-1, 1:-1] = events[1:-1, 1:-1] & log + log = workingarray[0:-2, 1:-1] > workingarray[1:-1, 0:-2] + events[0:-2, 1:-1] = events[0:-2, 1:-1] & log + log = workingarray[1:-1, 0:-2] > workingarray[0:-2, 1:-1] + events[2:-1, 0:-2] = events[1:-1, 0:-2] & log + log = workingarray[0:-2, 0:-2] > workingarray[1:-1, 1:-1] + events[0:-2, 0:-2] = events[0:-2, 0:-2] & log + + if binfactor > 1: + # Perform binning on GPU in torch_bin function + counted[Rx, Ry, :, :] = ( + torch.transpose( + torch_bin( + events.type(torch.cuda.ShortTensor), + device, + factor=binfactor, + ), + 0, + 1, + ) + .flip(0) + .flip(1) + ) + else: + # I'm not sure I understand this - we're flipping coordinates to match what? + # TODO: check array flipping - may vary by camera + counted[Rx, Ry, :, :] = ( + torch.transpose(events.type(torch.cuda.ShortTensor), 0, 1) + .flip(0) + .flip(1) + ) + + if output == "datacube": + return counted.cpu().numpy() + else: + return counted_datacube_to_pointlistarray(counted) + + +####### Support functions ######## + + +def calculate_thresholds( + datacube, + darkreference, + Nsamples=20, + thresh_bkgrnd_Nsigma=4, + thresh_xray_Nsigma=10, + return_params=False, +): + """ + Calculate the upper and lower thresholds for thresholding what to register as + an electron count. 
+ + Both thresholds are determined from the histogram of detector pixel values summed + over Nsamples frames. The thresholds are set to:: + + thresh_xray_Nsigma = mean(histogram) + thresh_upper * std(histogram) + thresh_bkgrnd_N_sigma = mean(guassian fit) + thresh_lower * std(gaussian fit) + + For more info, see the electron_count docstring. + + Args: + datacube: a 4D numpy.ndarrau pointing to the datacube + darkreference: a 2D numpy.ndarray with the dark reference + Nsamples: the number of frames to use in dark reference and threshold + calculation. + thresh_bkgrnd_Nsigma: the background threshold is + ``mean(guassian fit) + (this #)*std(gaussian fit)`` + where the gaussian fit is to the background noise. + thresh_xray_Nsigma: the X-ray threshold is + ``mean(hist) + (this #)*std(hist)`` + where hist is the histogram of all pixel values in the Nsamples random frames + return_params: bool, if True return n,hist of the histogram and popt of the + gaussian fit + + Returns: + (5-tuple): A 5-tuple containing: + + * **thresh_bkgrnd**: the background threshold + * **thresh_xray**: the X-ray threshold + * **n**: returned iff return_params==True. The histogram values + * **hist**: returned iff return_params==True. The histogram bin edges + * **popt**: returned iff return_params==True. The fit gaussian parameters, + (A, mu, sigma). + """ + R_Nx, R_Ny, Q_Nx, Q_Ny = datacube.shape + + # Select random set of frames + nframes = R_Nx * R_Ny + samples = np.arange(nframes) + np.random.shuffle(samples) + samples = samples[:Nsamples] + + # Get frames and subtract dark references + sample = np.zeros((Q_Nx, Q_Ny, Nsamples), dtype=np.int16) + for i in range(Nsamples): + sample[:, :, i] = datacube[samples[i] // R_Nx, samples[i] % R_Ny, :, :] + sample[:, :, i] -= darkreference + sample = np.ravel(sample) # Flatten array + + # Get upper (X-ray) threshold + mean = np.mean(sample) + stddev = np.std(sample) + thresh_xray = mean + thresh_xray_Nsigma * stddev + + # Make a histogram + binmax = min(int(np.ceil(np.amax(sample))), int(mean + thresh_xray * stddev)) + binmin = max(int(np.ceil(np.amin(sample))), int(mean - thresh_xray * stddev)) + step = max(1, (binmax - binmin) // 1000) + bins = np.arange(binmin, binmax, step=step, dtype=np.int16) + n, bins = np.histogram(sample, bins=bins) + + # Define Guassian to fit to, with parameters p: + # p[0] is amplitude + # p[1] is the mean + # p[2] is std deviation + fitfunc = lambda p, x: p[0] * np.exp(-0.5 * np.square((x - p[1]) / p[2])) + errfunc = lambda p, x, y: fitfunc(p, x) - y # Error for scipy's optimize routine + + # Get initial guess + p0 = [n.max(), (bins[n.argmax() + 1] - bins[n.argmax()]) / 2, np.std(sample)] + p1, success = optimize.leastsq( + errfunc, p0[:], args=(bins[:-1], n) + ) # Use the scipy optimize routine + p1[1] += 0.5 # Add a half to account for integer bin width + + # Set lower threshhold for electron counts to count + thresh_bkgrnd = p1[1] + p1[2] * thresh_bkgrnd_Nsigma + + if return_params: + return thresh_bkgrnd, thresh_xray, n, bins, p1 + else: + return thresh_bkgrnd, thresh_xray + + +def torch_bin(array, device, factor=2): + """ + Bin data on the GPU using torch. 
+ + Args: + array: a 2D numpy array + device: a torch device class instance + factor (int): the binning factor + + Returns: + (array): the binned array + """ + + import torch + + x, y = array.shape + binx, biny = x // factor, y // factor + xx, yy = binx * factor, biny * factor + + # Make a binned array on the device + binned_ar = torch.zeros(biny, binx, device=device, dtype=array.dtype) + + # Collect pixel sums into new bins + for ix in range(factor): + for iy in range(factor): + binned_ar += array[0 + ix : xx + ix : factor, 0 + iy : yy + iy : factor] + return binned_ar + + +def counted_datacube_to_pointlistarray(counted_datacube, subpixel=False): + """ + Converts an electron counted datacube to PointListArray. + + Args: + counted_datacube: a 4D array of bools, with true indicating an electron strike. + subpixel (bool): controls if subpixel electron strike positions are expected + + Returns: + (PointListArray): a PointListArray of electron strike events + """ + # Get shape, initialize PointListArray + R_Nx, R_Ny, Q_Nx, Q_Ny = counted_datacube.shape + if subpixel: + coordinates = [("qx", float), ("qy", float)] + else: + coordinates = [("qx", int), ("qy", int)] + pointlistarray = PointListArray(coordinates=coordinates, shape=(R_Nx, R_Ny)) + + # Loop through frames, adding electron counts to the PointListArray for each. + for Rx in range(R_Nx): + for Ry in range(R_Ny): + frame = counted_datacube[Rx, Ry, :, :] + x, y = np.nonzero(frame) + pointlist = pointlistarray.get_pointlist(Rx, Ry) + pointlist.add_tuple_of_nparrays((x, y)) + + return pointlistarray + + +def counted_pointlistarray_to_datacube(counted_pointlistarray, shape, subpixel=False): + """ + Converts an electron counted PointListArray to a datacube. + + Args: + counted_pointlistarray (PointListArray): a PointListArray of electron strike + events + shape (4-tuple): a length 4 tuple of ints containing (R_Nx,R_Ny,Q_Nx,Q_Ny) + subpixel (bool): controls if subpixel electron strike positions are expected + + Returns: + (4D array of bools): a 4D array of bools, with true indicating an electron strike. + """ + assert len(shape) == 4 + assert subpixel is False, "subpixel mode not presently supported." + R_Nx, R_Ny, Q_Nx, Q_Ny = shape + counted_datacube = np.zeros((R_Nx, R_Nx, Q_Nx, Q_Ny), dtype=bool) + + # Loop through frames, adding electron counts to the datacube for each. + for Rx in range(R_Nx): + for Ry in range(R_Ny): + pointlist = counted_pointlistarray.get_pointlist(Rx, Ry) + counted_datacube[Rx, Ry, pointlist.data["qx"], pointlist.data["qy"]] = True + + return counted_datacube + + +if __name__ == "__main__": + from py4DSTEM.process.preprocess import get_darkreference + from py4DSTEM.io import DataCube, save + from ncempy.io import dm + + dm4_filepath = "Capture25.dm4" + + # Parameters for dark reference determination + drwidth = 100 + + # Parameters for electron counting + Nsamples = 40 + thresh_bkgrnd_Nsigma = 4 + thresh_xray_Nsigma = 30 + binfactor = 1 + subpixel = False + output = "pointlist" + + # Get memory mapped 4D datacube from dm file + datacube = dm.dmReader(dm4_filepath, dSetNum=0, verbose=False)["data"] + datacube = np.moveaxis(datacube, (0, 1), (2, 3)) + + # Get dark reference + darkreference = 1 # TODO: get_darkreference(datacube = ...! 
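
The dark reference above is left as a stub. As a stop-gap, a median over a random subset of frames works; the helper below is illustrative only and is not the library's get_darkreference.

import numpy as np

def median_dark_reference(datacube, n_samples=40, seed=0):
    # Median of randomly sampled frames, using the same (Rx, Ry, Qx, Qy)
    # axis ordering as electron_count above. Hot-pixel rejection is omitted.
    rng = np.random.default_rng(seed)
    R_Nx, R_Ny = datacube.shape[:2]
    idx = rng.choice(R_Nx * R_Ny, size=min(n_samples, R_Nx * R_Ny), replace=False)
    frames = np.stack(
        [datacube[i // R_Ny, i % R_Ny].astype(np.float32) for i in idx]
    )
    return np.median(frames, axis=0)

# e.g. darkreference = median_dark_reference(datacube, Nsamples)
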
+ + electron_counted_data = electron_count( + datacube, + darkreference, + Nsamples=Nsamples, + thresh_bkgrnd_Nsigma=thresh_bkgrnd_Nsigma, + thresh_xray_Nsigma=thresh_xray_Nsigma, + binfactor=binfactor, + sub_pixel=True, + output="pointlist", + ) + + # For outputting datacubes, wrap counted into a py4DSTEM DataCube + if output == "datacube": + electron_counted_data = DataCube(data=electron_counted_data) + + output_path = dm4_filepath.replace(".dm4", ".h5") + save(electron_counted_data, output_path) diff --git a/py4DSTEM/datacube/preprocess/preprocess.py b/py4DSTEM/datacube/preprocess/preprocess.py new file mode 100644 index 000000000..88894dd6a --- /dev/null +++ b/py4DSTEM/datacube/preprocess/preprocess.py @@ -0,0 +1,944 @@ +import numpy as np +from typing import Optional +from scipy.ndimage import median_filter +import warnings + +from emdfile import tqdmnd, Metadata +from py4DSTEM.utils import bin2D, get_shifted_ar, fourier_resample +from py4DSTEM.data import DiffractionSlice, Data + + + +# DataCube methods + +class Preprocessor: + def __init__(self): + pass + + + ### Set & manipulate dimensions, bin, crop + + def set_scan_shape(self, Rshape): + """ + Reshape the data given the real space scan shape. Accepts: Rshape (2-tuple) + """ + assert len(Rshape) == 2, "Rshape must have a length of 2" + try: + # reshape + self.data = self.data.reshape( + Rshape[0], Rshape[1], self.Q_Nx, self.Q_Ny) + + # TODO - restruct + # set dim vectors + Rpixsize = self.calibration.get_R_pixel_size() + Rpixunits = self.calibration.get_R_pixel_units() + self.set_dim(0, [0, Rpixsize], units=Rpixunits) + self.set_dim(1, [0, Rpixsize], units=Rpixunits) + + # return + return self + + except ValueError: + print(f"Can't reshape {self.R_N} scan positions into a {Rshape} shaped array. Returning") + return self + except AttributeError: + print(f"Can't reshape self.") + return self + + def swap_RQ(self): + """ + Swaps the first and last two dimensions of the 4D self. + """ + self.data = np.transpose(self.data, axes=(2, 3, 0, 1)) + + # TODO + # set dim vectors + Rpixsize = self.calibration.get_R_pixel_size() + Rpixunits = self.calibration.get_R_pixel_units() + Qpixsize = self.calibration.get_Q_pixel_size() + Qpixunits = self.calibration.get_Q_pixel_units() + self.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx") + self.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry") + self.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx") + self.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy") + + # return + return self + + def swap_Rxy(self): + """ + Swaps the real space x and y coordinates. + """ + # swap + self.data = np.moveaxis(self.data, 1, 0) + + # TODO + # set dim vectors + Rpixsize = self.calibration.get_R_pixel_size() + Rpixunits = self.calibration.get_R_pixel_units() + self.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx") + self.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry") + + # return + return self + + def swap_Qxy(self): + """ + Swaps the diffraction space x and y coordinates. + """ + self.data = np.moveaxis(self.data, 3, 2) + return self + + def crop_Q(self, ROI): + """ + Crops the data in diffraction space about the region specified by ROI. 
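
Because these preprocessing methods operate in place and return self, they can be applied back to back on a DataCube. A typical reshape-and-crop sequence, with illustrative shapes, looks like this:

# Shapes and crop window here are placeholders for a real dataset.
datacube.set_scan_shape((128, 130))      # reshape R_N scan positions to (128, 130)
datacube.swap_Rxy()                      # swap real-space x and y if needed
datacube.crop_Q((16, 240, 16, 240))      # keep Qx and Qy in [16, 240)
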
+ + Accepts: + ROI (4-tuple): Specifies (Qx_min,Qx_max,Qy_min,Qy_max) + """ + assert len(ROI) == 4, "Crop region `ROI` must have length 4" + self.data = self.data[ :, :, ROI[0]:ROI[1], ROI[2]:ROI[3]] + + # TODO + # set dim vectors + Qpixsize = self.calibration.get_Q_pixel_size() + Qpixunits = self.calibration.get_Q_pixel_units() + self.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx") + self.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy") + + # return + return self + + def crop_R(self, ROI): + """ + Crops the data in real space about the region specified by ROI. + + Accepts: + ROI (4-tuple): Specifies (Rx_min,Rx_max,Ry_min,Ry_max) + """ + assert len(ROI) == 4, "Crop region `ROI` must have length 4" + self.data = self.data[ ROI[0]:ROI[1] , ROI[2]:ROI[3] ] + + # TODO + # set dim vectors + Rpixsize = self.calibration.get_R_pixel_size() + Rpixunits = self.calibration.get_R_pixel_units() + self.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx") + self.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry") + + # return + return self + + def bin_Q(self, N, dtype=None): + """ + Bins the data in diffraction space by bin factor N + + Parameters + ---------- + N : int + The binning factor + dtype : a datatype (optional) + Specify the datatype for the output. If not passed, the datatype + is left unchanged + + Returns + ------ + self : DataCube + """ + # validate inputs + assert type(bin_factor) is int, f"Error: binning factor {bin_factor} is not an int." + if bin_factor == 1: + return self + if dtype is None: + dtype = self.data.dtype + + # get shape + R_Nx, R_Ny, Q_Nx, Q_Ny = ( + self.R_Nx, + self.R_Ny, + self.Q_Nx, + self.Q_Ny, + ) + # crop edges if necessary + if (Q_Nx % bin_factor == 0) and (Q_Ny % bin_factor == 0): + pass + elif Q_Nx % bin_factor == 0: + self.data = self.data[:, :, :, : -(Q_Ny % bin_factor)] + elif Q_Ny % bin_factor == 0: + self.data = self.data[:, :, : -(Q_Nx % bin_factor), :] + else: + self.data = self.data[ + :, :, : -(Q_Nx % bin_factor), : -(Q_Ny % bin_factor) + ] + + # bin + self.data = ( + self.data.reshape( + R_Nx, + R_Ny, + int(Q_Nx / bin_factor), + bin_factor, + int(Q_Ny / bin_factor), + bin_factor, + ) + .sum(axis=(3, 5)) + .astype(dtype) + ) + + # TODO + # set dim vectors + Qpixsize = self.calibration.get_Q_pixel_size() * bin_factor + Qpixunits = self.calibration.get_Q_pixel_units() + self.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx") + self.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy") + # set calibration pixel size + self.calibration.set_Q_pixel_size(Qpixsize) + + # return + return self + + def pad_Q(self, N=None, output_size=None): + """ + Pads the data in diffraction space by pad factor N, or to match output_size. + + Accepts: + N (float, or Sequence[float]): the padding factor + output_size ((int,int)): the padded output size + """ + Qx, Qy = self.shape[-2:] + + if pad_factor is not None: + if output_size is not None: + raise ValueError( + "Only one of 'pad_factor' or 'output_size' can be specified." + ) + + pad_factor = np.array(pad_factor) + if pad_factor.shape == (): + pad_factor = np.tile(pad_factor, 2) + + if np.any(pad_factor < 1): + raise ValueError("'pad_factor' needs to be larger than 1.") + + pad_kx = np.round(Qx * (pad_factor[0] - 1) / 2).astype("int") + pad_kx = (pad_kx, pad_kx) + pad_ky = np.round(Qy * (pad_factor[1] - 1) / 2).astype("int") + pad_ky = (pad_ky, pad_ky) + + else: + if output_size is None: + raise ValueError( + "At-least one of 'pad_factor' or 'output_size' must be specified." 
+ ) + + if len(output_size) != 2: + raise ValueError( + f"'output_size' must have length 2, not {len(output_size)}" + ) + + Sx, Sy = output_size + + if Sx < Qx or Sy < Qy: + raise ValueError(f"'output_size' must be at-least as large as {(Qx,Qy)}.") + + pad_kx = Sx - Qx + pad_kx = (pad_kx // 2, pad_kx // 2 + pad_kx % 2) + + pad_ky = Sy - Qy + pad_ky = (pad_ky // 2, pad_ky // 2 + pad_ky % 2) + + pad_width = ( + (0, 0), + (0, 0), + pad_kx, + pad_ky, + ) + + self.data = np.pad(self.data, pad_width=pad_width, mode="constant") + + # TODO + Qpixsize = self.calibration.get_Q_pixel_size() + Qpixunits = self.calibration.get_Q_pixel_units() + self.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx") + self.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy") + + self.calibrate() + + return self + + def resample_Q(self, N=None, output_size=None, method="bilinear"): + """ + Resamples the data in diffraction space by resampling factor N, or to match output_size, + using either 'fourier' or 'bilinear' interpolation. + + Accepts: + N (float, or Sequence[float]): the resampling factor + output_size ((int,int)): the resampled output size + method (str): 'fourier' or 'bilinear' (default) + """ + if method == "fourier": + if np.size(resampling_factor) != 1: + warnings.warn( + ( + "Fourier resampling currently only accepts a scalar resampling_factor. " + f"'resampling_factor' set to {resampling_factor[0]}." + ), + UserWarning, + ) + resampling_factor = resampling_factor[0] + + old_size = self.data.shape + + self.data = fourier_resample( + self.data, scale=resampling_factor, output_size=output_size + ) + + if not resampling_factor: + resampling_factor = output_size[0] / old_size[2] + if self.calibration.get_Q_pixel_size() is not None: + self.calibration.set_Q_pixel_size( + self.calibration.get_Q_pixel_size() / resampling_factor + ) + + elif method == "bilinear": + from scipy.ndimage import zoom + + if resampling_factor is not None: + if output_size is not None: + raise ValueError( + "Only one of 'resampling_factor' or 'output_size' can be specified." + ) + + resampling_factor = np.array(resampling_factor) + if resampling_factor.shape == (): + resampling_factor = np.tile(resampling_factor, 2) + + else: + if output_size is None: + raise ValueError( + "At-least one of 'resampling_factor' or 'output_size' must be specified." + ) + + if len(output_size) != 2: + raise ValueError( + f"'output_size' must have length 2, not {len(output_size)}" + ) + + resampling_factor = np.array(output_size) / np.array(self.shape[-2:]) + + resampling_factor = np.concatenate(((1, 1), resampling_factor)) + self.data = zoom( + self.data, resampling_factor, order=1, mode="grid-wrap", grid_mode=True + ) + self.calibration.set_Q_pixel_size( + self.calibration.get_Q_pixel_size() / resampling_factor[2] + ) + + else: + raise ValueError( + f"'method' needs to be one of 'bilinear' or 'fourier', not {method}." + ) + + return self + + def bin_Q_mmap(self, N, dtype=np.float32): + """ + Bins the data in diffraction space by bin factor N for memory mapped data + + Accepts: + N (int): the binning factor + dtype: the data type + """ + # validate inputs + assert type(bin_factor) is int, f"Error: binning factor {bin_factor} is not an int." 
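
Stripped of the calibration bookkeeping, the reshape-and-sum pattern used by bin_Q above reduces to the following standalone sketch (not the DataCube method itself):

import numpy as np

def bin_diffraction(data, factor):
    # data: (R_Nx, R_Ny, Q_Nx, Q_Ny). Trailing rows/columns that do not divide
    # evenly by `factor` are cropped, then each factor x factor block is summed.
    R_Nx, R_Ny, Q_Nx, Q_Ny = data.shape
    qx, qy = Q_Nx - Q_Nx % factor, Q_Ny - Q_Ny % factor
    data = data[:, :, :qx, :qy]
    return data.reshape(
        R_Nx, R_Ny, qx // factor, factor, qy // factor, factor
    ).sum(axis=(3, 5))

# bin_diffraction(np.ones((4, 4, 128, 130)), 4).shape == (4, 4, 32, 32)
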
+ if bin_factor == 1: + return self + + # get shape + R_Nx, R_Ny, Q_Nx, Q_Ny = ( + self.R_Nx, + self.R_Ny, + self.Q_Nx, + self.Q_Ny, + ) + # allocate space + data = np.zeros( + ( + self.R_Nx, + self.R_Ny, + self.Q_Nx // bin_factor, + self.Q_Ny // bin_factor, + ), + dtype=dtype, + ) + # bin + for Rx, Ry in tqdmnd(self.R_Ny, self.R_Ny): + data[Rx, Ry, :, :] = bin2D(self.data[Rx, Ry, :, :], bin_factor, dtype=dtype) + self.data = data + + # TODO + # set dim vectors + Qpixsize = self.calibration.get_Q_pixel_size() * bin_factor + Qpixunits = self.calibration.get_Q_pixel_units() + self.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx") + self.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy") + # set calibration pixel size + self.calibration.set_Q_pixel_size(Qpixsize) + + # return + return self + + def bin_R(self, N): + """ + Bins the data in real space by bin factor N + + Accepts: + N (int): the binning factor + """ + # validate inputs + assert type(bin_factor) is int, f"Bin factor {bin_factor} is not an int." + if bin_factor <= 1: + return self + + # set shape + R_Nx, R_Ny, Q_Nx, Q_Ny = ( + self.R_Nx, + self.R_Ny, + self.Q_Nx, + self.Q_Ny, + ) + # crop edges if necessary + if (R_Nx % bin_factor == 0) and (R_Ny % bin_factor == 0): + pass + elif R_Nx % bin_factor == 0: + self.data = self.data[:, : -(R_Ny % bin_factor), :, :] + elif R_Ny % bin_factor == 0: + self.data = self.data[: -(R_Nx % bin_factor), :, :, :] + else: + self.data = self.data[ + : -(R_Nx % bin_factor), : -(R_Ny % bin_factor), :, : + ] + # bin + self.data = self.data.reshape( + int(R_Nx / bin_factor), + bin_factor, + int(R_Ny / bin_factor), + bin_factor, + Q_Nx, + Q_Ny, + ).sum(axis=(1, 3)) + + # TODO + # set dim vectors + Rpixsize = self.calibration.get_R_pixel_size() * bin_factor + Rpixunits = self.calibration.get_R_pixel_units() + self.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx") + self.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry") + # set calibration pixel size + self.calibration.set_R_pixel_size(Rpixsize) + + # return + return self + + def thin_R(self, N): + """ + Reduces the data in real space by skipping every N patterns in the x and y directions. + + Accepts: + N (int): the thinning factor + """ + # get shapes + Rshape0 = self.Rshape + Rshapef = tuple([x // thinning_factor for x in Rshape0]) + + # allocate memory + data = np.empty( + (Rshapef[0], Rshapef[1], self.Qshape[0], self.Qshape[1]), + dtype=self.data.dtype, + ) + + # populate data + for rx, ry in tqdmnd(Rshapef[0], Rshapef[1]): + rx0 = rx * thinning_factor + ry0 = ry * thinning_factor + data[rx, ry, :, :] = self[rx0, ry0, :, :] + + self.data = data + + # TODO + # set dim vectors + Rpixsize = self.calibration.get_R_pixel_size() * thinning_factor + Rpixunits = self.calibration.get_R_pixel_units() + self.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx") + self.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry") + # set calibration pixel size + self.calibration.set_R_pixel_size(Rpixsize) + + # return + return self + + + ### Denoising + + def filter_hot_pixels(self, thresh, ind_compare=1, return_mask=False): + """ + This function performs pixel filtering to remove hot / bright pixels. + A mean diffraction pattern is calculated, then a moving local ordering filter + is applied to it, finding and sorting the intensities of the 21 pixels nearest + each pixel (where 21 = (the pixel itself) + (nearest neighbors) + (next + nearest neighbors) = (1) + (8) + (12) = 21; the next nearest neighbors + exclude the corners of the NNN square of pixels). 
This filter then returns + a single value at each pixel given by the N'th highest value of these 21 + sorted values, where N is specified by `ind_compare`. ind_compare=0 + specifies the highest intensity, =1 is the second hightest, etc. Next, a mask + is generated which is True for all pixels which are least a value `thresh` + higher than the local ordering filter output. Thus for the default + `ind_compare` value of 1, the mask will be True wherever the mean diffraction + pattern is higher than the second brightest pixel in it's local window by + at least a value of `thresh`. Finally, we loop through all diffraction + images, and any pixels defined by mask are replaced by their 3x3 local + median. + + Parameters + ---------- + thresh : float + Threshold for replacing hot pixels, if pixel value minus local ordering + filter exceeds it. + ind_compare : int + Which median filter value to compare against. 0 = brightest pixel, + 1 = next brightest, etc. + return_mask : bool + If True, returns the filter mask + + Returns + ------- + self : Datacube + mask : bool + (optional) the bad pixel mask + """ + + # Mean image over all probe positions + diff_mean = np.mean(self.data, axis=(0, 1)) + shape = diff_mean.shape + + # Moving local ordered pixel values + diff_local_med = np.sort( + np.vstack( + [ + np.roll(diff_mean, (-1, -1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (0, -1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (1, -1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (-1, 0), axis=(0, 1)).ravel(), + np.roll(diff_mean, (0, 0), axis=(0, 1)).ravel(), + np.roll(diff_mean, (1, 0), axis=(0, 1)).ravel(), + np.roll(diff_mean, (-1, 1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (0, 1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (1, 1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (-1, -2), axis=(0, 1)).ravel(), + np.roll(diff_mean, (0, -2), axis=(0, 1)).ravel(), + np.roll(diff_mean, (1, -2), axis=(0, 1)).ravel(), + np.roll(diff_mean, (-1, 2), axis=(0, 1)).ravel(), + np.roll(diff_mean, (0, 2), axis=(0, 1)).ravel(), + np.roll(diff_mean, (1, 2), axis=(0, 1)).ravel(), + np.roll(diff_mean, (-2, -1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (-2, 0), axis=(0, 1)).ravel(), + np.roll(diff_mean, (-2, 1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (2, -1), axis=(0, 1)).ravel(), + np.roll(diff_mean, (2, 0), axis=(0, 1)).ravel(), + np.roll(diff_mean, (2, 1), axis=(0, 1)).ravel(), + ] + ), + axis=0, + ) + # arry of the ind_compare'th pixel intensity + diff_compare = np.reshape(diff_local_med[-ind_compare - 1, :], shape) + + # Generate mask + mask = diff_mean - diff_compare > thresh + + # If the mask is empty, return + if np.sum(mask) == 0: + print("No hot pixels detected") + if return_mask is True: + return self, mask + else: + return self + + # Otherwise, apply filtering + + # Get masked indices + x_ma, y_ma = np.nonzero(mask) + + # Get local windows for each masked pixel + xslices, yslices = [], [] + for xm, ym in zip(x_ma, y_ma): + xslice, yslice = slice(xm - 1, xm + 2), slice(ym - 1, ym + 2) + if xslice.start < 0: + xslice = slice(0, xslice.stop) + elif xslice.stop > shape[0]: + xslice = slice(xslice.start, shape[0]) + if yslice.start < 0: + yslice = slice(0, yslice.stop) + elif yslice.stop > shape[1]: + yslice = slice(yslice.start, shape[1]) + xslices.append(xslice) + yslices.append(yslice) + + # Loop and replace pixels + for ax, ay in tqdmnd( + *(self.R_Nx, self.R_Ny), desc="Cleaning pixels", unit=" images" + ): + for xm, ym, xs, ys in zip(x_ma, y_ma, xslices, yslices): + self.data[ax, ay, xm, ym] = 
np.median(self.data[ax, ay, xs, ys]) + + # Calculate local 3x3 median images + # im_med = median_filter(self.data[ax, ay, :, :], size=3, mode="nearest") + # self.data[ax, ay, :, :][mask] = im_med[mask] + + # Return + if return_mask is True: + return self, mask + else: + return self + + + ### Background subtraction + + def get_radial_bkgrnd(self, rx, ry, sigma=2): + """ + Computes and returns a background image for the diffraction + pattern at (rx,ry), populated by radial rings of constant intensity + about the origin, with the value of each ring given by the median + value of the diffraction pattern at that radial distance. + + Parameters + ---------- + rx : int + The x-coord of the beam position + ry : int + The y-coord of the beam position + sigma : number + If >0, applying a gaussian smoothing in the radial direction + before returning + + Returns + ------- + background : ndarray + The radial background + """ + # ensure a polar cube and origin exist + assert self.polar is not None, "No polar self found!" + assert self.calibration.get_origin() is not None, "No origin found!" + + # get the 1D median background + bkgrd_ma_1d = np.ma.median(self.polar.data[rx, ry], axis=0) + bkgrd_1d = bkgrd_ma_1d.data + bkgrd_1d[bkgrd_ma_1d.mask] = 0 + + # smooth + if sigma > 0: + bkgrd_1d = gaussian_filter1d(bkgrd_1d, sigma) + + # define the 2D cartesian coordinate system + origin = self.calibration.get_origin() + origin = origin[0][rx, ry], origin[1][rx, ry] + qxx, qyy = self.qxx_raw - origin[0], self.qyy_raw - origin[1] + + # get distance qr in polar-elliptical coords + ellipse = self.calibration.get_ellipse() + ellipse = (1, 1, 0) if ellipse is None else ellipse + a, b, theta = ellipse + + qrr = np.sqrt( + ((qxx * np.cos(theta)) + (qyy * np.sin(theta))) ** 2 + + ((qxx * np.sin(theta)) - (qyy * np.cos(theta))) ** 2 / (b / a) ** 2 + ) + + # make an interpolation function and get the 2D background + f = interp1d(self.polar.radial_bins, bkgrd_1d, fill_value="extrapolate") + background = f(qrr) + + # return + return background + + def get_radial_bksb_dp(self, rx, ry, sigma=2): + """ + Computes and returns the diffraction pattern at beam position (rx,ry) + with a radial background subtracted. See the docstring for + self.get_radial_background for more info. + + Parameters + ---------- + rx : int + The x-coord of the beam position + ry : int + The y-coord of the beam position + sigma : number + If >0, applying a gaussian smoothing in the radial direction + before returning + + Returns + ------- + data : ndarray + The radial background subtracted diffraction image + """ + # get 2D background + background = self.get_radial_bkgrnd(rx, ry, sigma) + + # subtract, zero negative values, return + ans = self.data[rx, ry] - background + ans[ans < 0] = 0 + return ans + + ### Local averaging + + def get_local_ave_dp( + self, + rx, + ry, + radial_bksb=False, + sigma=2, + braggmask=False, + braggvectors=None, + braggmask_radius=None, + ): + """ + Computes and returns the diffraction pattern at beam position (rx,ry) + after weighted local averaging with its nearest-neighbor patterns, + using a 3x3 gaussian kernel for the weightings. 
+ + Parameters + ---------- + rx : int + The x-coord of the beam position + ry : int + The y-coord of the beam position + radial_bksb : bool + It True, apply a radial background subtraction to each pattern + before averaging + sigma : number + If radial_bksb is True, use this sigma for radial smoothing of + the background + braggmask : bool + If True, masks bragg scattering at each scan position before + averaging. `braggvectors` and `braggmask_radius` must be + specified. + braggvectors : BraggVectors + The Bragg vectors to use for masking + braggmask_radius : number + The radius about each Bragg point to mask + + Returns + ------- + data : ndarray + The radial background subtracted diffraction image + """ + # define the kernel + kernel = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 16.0 + + # get shape and check for valid inputs + nx, ny = self.data.shape[:2] + assert rx >= 0 and rx < nx, "rx outside of scan range" + assert ry >= 0 and ry < ny, "ry outside of scan range" + + # get the subcube, checking for edge patterns + # and modifying the kernel as needed + if rx != 0 and rx != (nx - 1) and ry != 0 and ry != (ny - 1): + subcube = self.data[rx - 1 : rx + 2, ry - 1 : ry + 2, :, :] + elif rx == 0 and ry == 0: + subcube = self.data[:2, :2, :, :] + kernel = kernel[1:, 1:] + elif rx == 0 and ry == (ny - 1): + subcube = self.data[:2, -2:, :, :] + kernel = kernel[1:, :-1] + elif rx == (nx - 1) and ry == 0: + subcube = self.data[-2:, :2, :, :] + kernel = kernel[:-1, 1:] + elif rx == (nx - 1) and ry == (ny - 1): + subcube = self.data[-2:, -2:, :, :] + kernel = kernel[:-1, :-1] + elif rx == 0: + subcube = self.data[:2, ry - 1 : ry + 2, :, :] + kernel = kernel[1:, :] + elif rx == (nx - 1): + subcube = self.data[-2:, ry - 1 : ry + 2, :, :] + kernel = kernel[:-1, :] + elif ry == 0: + subcube = self.data[rx - 1 : rx + 2, :2, :, :] + kernel = kernel[:, 1:] + elif ry == (ny - 1): + subcube = self.data[rx - 1 : rx + 2, -2:, :, :] + kernel = kernel[:, :-1] + else: + raise Exception(f"Invalid (rx,ry) = ({rx},{ry})...") + + # normalize the kernel + kernel /= np.sum(kernel) + + # compute... + + # ...in the simple case + if not (radial_bksb) and not (braggmask): + ans = np.tensordot(subcube, kernel, axes=((0, 1), (0, 1))) + + # ...with radial background subtration + elif radial_bksb and not (braggmask): + # get position of (rx,ry) relative to kernel + _xs = 1 if rx != 0 else 0 + _ys = 1 if ry != 0 else 0 + x0 = rx - _xs + y0 = ry - _ys + # compute + ans = np.zeros(self.Qshape) + for (i, j), w in np.ndenumerate(kernel): + x = x0 + i + y = y0 + j + ans += self.get_radial_bksb_dp(x, y, sigma) * w + + # ...with bragg masking + elif not (radial_bksb) and braggmask: + assert ( + braggvectors is not None + ), "`braggvectors` must be specified or `braggmask` must be turned off!" + assert ( + braggmask_radius is not None + ), "`braggmask_radius` must be specified or `braggmask` must be turned off!" 
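
For the simple interior case handled above (no background subtraction, no masking), the weighted neighbor average is a single tensor contraction; a standalone sketch:

import numpy as np

def local_average_dp(data, rx, ry):
    # data: (R_Nx, R_Ny, Q_Nx, Q_Ny); (rx, ry) assumed at least one pixel from the edge.
    kernel = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=float)
    kernel /= kernel.sum()
    subcube = data[rx - 1 : rx + 2, ry - 1 : ry + 2]             # (3, 3, Q_Nx, Q_Ny)
    return np.tensordot(subcube, kernel, axes=((0, 1), (0, 1)))  # (Q_Nx, Q_Ny)
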
+ # get position of (rx,ry) relative to kernel + _xs = 1 if rx != 0 else 0 + _ys = 1 if ry != 0 else 0 + x0 = rx - _xs + y0 = ry - _ys + # compute + ans = np.zeros(self.Qshape) + weights = np.zeros(self.Qshape) + for (i, j), w in np.ndenumerate(kernel): + x = x0 + i + y = y0 + j + mask = self.get_braggmask(braggvectors, x, y, braggmask_radius) + weights_curr = mask * w + ans += self.data[x, y] * weights_curr + weights += weights_curr + # normalize + out = np.full_like(ans, np.nan) + ans_mask = weights > 0 + ans = np.divide(ans, weights, out=out, where=ans_mask) + # make masked array + ans = np.ma.array(data=ans, mask=np.logical_not(ans_mask)) + pass + + # ...with both radial background subtraction and bragg masking + else: + assert ( + braggvectors is not None + ), "`braggvectors` must be specified or `braggmask` must be turned off!" + assert ( + braggmask_radius is not None + ), "`braggmask_radius` must be specified or `braggmask` must be turned off!" + # get position of (rx,ry) relative to kernel + _xs = 1 if rx != 0 else 0 + _ys = 1 if ry != 0 else 0 + x0 = rx - _xs + y0 = ry - _ys + # compute + ans = np.zeros(self.Qshape) + weights = np.zeros(self.Qshape) + for (i, j), w in np.ndenumerate(kernel): + x = x0 + i + y = y0 + j + mask = self.get_braggmask(braggvectors, x, y, braggmask_radius) + weights_curr = mask * w + ans += self.get_radial_bksb_dp(x, y, sigma) * weights_curr + weights += weights_curr + # normalize + out = np.full_like(ans, np.nan) + ans_mask = weights > 0 + ans = np.divide(ans, weights, out=out, where=ans_mask) + # make masked array + ans = np.ma.array(data=ans, mask=np.logical_not(ans_mask)) + pass + + # return + return ans + + + ### Bragg masking + + def get_braggmask(self, braggvectors, rx, ry, radius): + """ + Returns a boolean mask which is False in a radius of `radius` around + each bragg scattering vector at scan position (rx,ry). + + Parameters + ---------- + braggvectors : BraggVectors + The bragg vectors + rx : int + The x-coord of the beam position + ry : int + The y-coord of the beam position + radius : number + mask pixels about each bragg vector to this radial distance + + Returns + ------- + mask : boolean ndarray + """ + # allocate space + mask = np.ones(self.Qshape, dtype=bool) + # get the vectors + vects = braggvectors.raw[rx, ry] + # loop + for idx in range(len(vects.data)): + qr = np.hypot(self.qxx_raw - vects.qx[idx], self.qyy_raw - vects.qy[idx]) + mask = np.logical_and(mask, qr > radius) + return mask + + ### Shift patterns + + def align_diffraction( + self, + xshifts, + yshifts, + periodic=True, + bilinear=False, + ): + """ + This function shifts each 2D diffraction image by the values defined by + (xshifts,yshifts). The shift values can be scalars (same shift for all + images) or arrays with the same dimensions as the probe positions in + self. + + Args: + self (DataCube): py4DSTEM DataCube + xshifts (float): Array or scalar value for the x dim shifts + yshifts (float): Array or scalar value for the y dim shifts + periodic (bool): Flag for periodic boundary conditions. If set to false, boundaries are assumed to be periodic. + bilinear (bool): Flag for bilinear image shifts. If set to False, Fourier shifting is used. 
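
A common use is recentering every pattern once per-position origins have been measured. Assuming qx0 and qy0 are (R_Nx, R_Ny) arrays of measured origin coordinates (hypothetical names), that looks like:

# Shift each pattern so its measured origin lands on the mean origin position.
xshifts = qx0.mean() - qx0
yshifts = qy0.mean() - qy0
datacube.align_diffraction(xshifts, yshifts, periodic=True, bilinear=False)
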
+ + Returns: + self (DataCube): py4DSTEM DataCube + """ + + # if the shift values are constant, expand to arrays + xshifts = np.array(xshifts) + yshifts = np.array(yshifts) + if xshifts.ndim == 0: + xshifts = xshifts * np.ones((self.R_Nx, self.R_Ny)) + if yshifts.ndim == 0: + yshifts = yshifts * np.ones((self.R_Nx, self.R_Ny)) + + # Loop over all images + for ax, ay in tqdmnd( + *(self.R_Nx, self.R_Ny), desc="Shifting images", unit=" images" + ): + self.data[ax, ay, :, :] = get_shifted_ar( + self.data[ax, ay, :, :], + xshifts[ax, ay], + yshifts[ax, ay], + periodic=periodic, + bilinear=bilinear, + ) + + return self + + diff --git a/py4DSTEM/datacube/preprocess/radialbkgrd.py b/py4DSTEM/datacube/preprocess/radialbkgrd.py new file mode 100644 index 000000000..ac207cb99 --- /dev/null +++ b/py4DSTEM/datacube/preprocess/radialbkgrd.py @@ -0,0 +1,174 @@ +""" +Functions for generating radially averaged backgrounds +""" + +import numpy as np +from scipy.interpolate import interp1d +from scipy.signal import savgol_filter + +from py4DSTEM.utils import cartesian_to_polarelliptical_transform + + +## Create look up table for background subtraction +def get_1D_polar_background( + data, + p_ellipse, + center=None, + maskUpdateIter=3, + min_relative_threshold=4, + smoothing=False, + smoothingWindowSize=3, + smoothingPolyOrder=4, + smoothing_log=True, + min_background_value=1e-3, + return_polararr=False, +): + """ + Gets the median polar background for a diffraction pattern + + Parameters + ---------- + data : ndarray + the data for which to find the polar eliptical background, + usually a diffraction pattern + p_ellipse : 5-tuple + the ellipse parameters (qx0,qy0,a,b,theta) + center : 2-tuple or None + if None, the center point from `p_ellipse` is used. Otherwise, + the center point in `p_ellipse` is ignored, and this argument + is used as (qx0,qy0) instead. 
+ maskUpdate_iter : integer + min_relative_threshold : float + smoothing : bool + if true, applies a Savitzky-Golay smoothing filter + smoothingWindowSize : integer + size of the smoothing window, must be odd number + smoothingPolyOrder : number + order of the polynomial smoothing to be applied + smoothing_log : bool + if true log smoothing is performed + min_background_value : float + if log smoothing is true, a zero value will be replaced with a + small nonzero float + return_polar_arr : bool + if True the polar transform with the masked high intensity peaks + will be returned + + Returns + ------- + 2- or 3-tuple of ndarrays + * **background1D**: 1D polar elliptical background + * **r_bins**: the elliptically transformed radius associated with + background1D + * **polarData** (optional): the masked polar transform from which the + background is computed, returned iff `return_polar_arr==True` + """ + # assert data is proper form + assert isinstance(smoothing, bool), "Smoothing must be bool" + assert smoothingWindowSize % 2 == 1, "Smoothing window must be odd" + assert isinstance(return_polararr, bool), "return_polararr must be bool" + + # Prepare ellipse params + if center is not None: + p_ellipse = tuple[ + center[0], center[1], p_ellipse[2], p_ellipse[3], p_ellipse[4] + ] + + # Compute Polar Transform + polarData, rr, tt = cartesian_to_polarelliptical_transform(data, p_ellipse) + + # Crop polar data to maximum distance which contains information from original image + if (polarData.mask.sum(axis=(0)) == polarData.shape[0]).any(): + ii = polarData.data.shape[1] - 1 + while polarData.mask[:, ii].all() == True: # noqa: E712 + ii = ii - 1 + maximalDistance = ii + polarData = polarData[:, 0:maximalDistance] + r_bins = rr[0, 0:maximalDistance] + else: + r_bins = rr[0, :] + + # Iteratively mask off high intensity peaks + maskPolar = np.copy(polarData.mask) + background1D = np.ma.median(polarData, axis=0) + for ii in range(maskUpdateIter + 1): + if ii > 0: + maskUpdate = np.logical_or( + maskPolar, polarData / background1D > min_relative_threshold + ) + # Prevent entire columns from being masked off + colMaskMin = np.all(maskUpdate, axis=0) # Detect columns that are empty + maskUpdate[:, colMaskMin] = polarData.mask[ + :, colMaskMin + ] # reset empty columns to values of previous iterations + polarData.mask = maskUpdate # Update Mask + + background1D = np.maximum(background1D, min_background_value) + + if smoothing is True: + if smoothing_log is True: + background1D = np.log(background1D) + + background1D = savgol_filter( + background1D, smoothingWindowSize, smoothingPolyOrder + ) + if smoothing_log is True: + background1D = np.exp(background1D) + if return_polararr is True: + return (background1D, r_bins, polarData) + else: + return (background1D, r_bins) + + +# Create 2D Background +def get_2D_polar_background(data, background1D, r_bins, p_ellipse, center=None): + """ + Gets 2D polar elliptical background from linear 1D background + + Parameters + ---------- + data : ndarray + the data for which to find the polar eliptical background, + usually a diffraction pattern + background1D : ndarray + a vector representing the radial elliptical background + r_bins : ndarray + a vector of the elliptically transformed radius associated with + background1D + p_ellipse : 5-tuple + the ellipse parameters (qx0,qy0,a,b,theta) + center : 2-tuple or None + if None, the center point from `p_ellipse` is used. 
Otherwise, + the center point in `p_ellipse` is ignored, and this argument + is used as (qx0,qy0) instead. + + Returns + ------- + ndarray + 2D polar elliptical median background image + """ + assert ( + r_bins.shape == background1D.shape + ), "1D background and r_bins must be same length" + + # Prepare ellipse params + qx0, qy0, a, b, theta = p_ellipse + if center is not None: + qx0, qy0 = center + + # Define centered 2D cartesian coordinate system + yc, xc = np.meshgrid( + np.arange(0, data.shape[1]) - qy0, np.arange(0, data.shape[0]) - qx0 + ) + + # Calculate the semimajor axis distance for each point in the 2D array + r = np.sqrt( + ((xc * np.cos(theta) + yc * np.sin(theta)) ** 2) + + (((xc * np.sin(theta) - yc * np.cos(theta)) ** 2) / ((b / a) ** 2)) + ) + + # Create a 2D eliptical background using linear interpolation + f = interp1d(r_bins, background1D, fill_value="extrapolate") + background2D = f(r) + + return background2D diff --git a/py4DSTEM/datacube/virtualdiffraction.py b/py4DSTEM/datacube/virtualdiffraction.py index 23b151d58..b37fcc320 100644 --- a/py4DSTEM/datacube/virtualdiffraction.py +++ b/py4DSTEM/datacube/virtualdiffraction.py @@ -1,6 +1,6 @@ -# Virtual diffraction from a self. Includes: -# * VirtualDiffraction - a container for virtual diffraction data + metadata -# * DataCubeVirtualDiffraction - methods inherited by DataCube for virt diffraction +# Contains- +# 1) Datacube methods and +# 2) a container class import numpy as np from typing import Optional @@ -8,55 +8,14 @@ from emdfile import tqdmnd, Metadata from py4DSTEM.data import DiffractionSlice, Data -from py4DSTEM.preprocess import get_shifted_ar +from py4DSTEM.utils import get_shifted_ar -# Virtual diffraction container class -class VirtualDiffraction(DiffractionSlice, Data): - """ - Stores a diffraction-space shaped 2D image with metadata - indicating how this image was generated from a self. - """ +# DataCube methods - def __init__( - self, - data: np.ndarray, - name: Optional[str] = "virtualdiffraction", - ): - """ - Args: - data (np.ndarray) : the 2D data - name (str) : the name - Returns: - A new VirtualDiffraction instance - """ - # initialize as a DiffractionSlice - DiffractionSlice.__init__( - self, - data=data, - name=name, - ) - - # read - @classmethod - def _get_constructor_args(cls, group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = DiffractionSlice._get_constructor_args(group) - args = { - "data": ar_constr_args["data"], - "name": ar_constr_args["name"], - } - return args - - -# DataCube virtual diffraction methods - - -class DataCubeVirtualDiffraction: +class VirtualDiffractioner: def __init__(self): pass @@ -391,3 +350,50 @@ def get_dp_median( name="dp_median", returncalc=True, ) + + + +# Container class + + +class VirtualDiffraction(DiffractionSlice, Data): + """ + Stores a diffraction-space shaped 2D image with metadata + indicating how this image was generated from a self. 
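
The two background helpers in radialbkgrd.py above chain into a 1D-then-2D workflow. A minimal sketch for the circular case follows; the import path is taken from the new file location, and the ellipse parameters are placeholders rather than measured calibrations.

import numpy as np
from py4DSTEM.datacube.preprocess.radialbkgrd import (
    get_1D_polar_background,
    get_2D_polar_background,
)

dp = np.random.poisson(5, (256, 256)).astype(float)   # stand-in diffraction pattern
p_ellipse = (128.0, 128.0, 1.0, 1.0, 0.0)             # (qx0, qy0, a, b, theta), circular

background1D, r_bins = get_1D_polar_background(dp, p_ellipse)
background2D = get_2D_polar_background(dp, background1D, r_bins, p_ellipse)
dp_subtracted = np.clip(dp - background2D, 0, None)
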
+ """ + + def __init__( + self, + data: np.ndarray, + name: Optional[str] = "virtualdiffraction", + ): + """ + Args: + data (np.ndarray) : the 2D data + name (str) : the name + + Returns: + A new VirtualDiffraction instance + """ + # initialize as a DiffractionSlice + DiffractionSlice.__init__( + self, + data=data, + name=name, + ) + + # read + @classmethod + def _get_constructor_args(cls, group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = DiffractionSlice._get_constructor_args(group) + args = { + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], + } + return args + + + diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index 627223d23..6c6114345 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -1,8 +1,7 @@ -# Virtual imaging from a datacube. Includes: -# * VirtualImage - a container for virtual image data + metadata -# * DataCubeVirtualImager - methods inherited by DataCube for virt imaging -# -# for bragg virtual imaging methods, goto diskdetection.virtualimage.py +# Contains- +# 1) Datacube methods, in VirtualImager, and +# 2) A container class, VirtualImage + import numpy as np import dask.array as da @@ -11,58 +10,15 @@ from emdfile import tqdmnd, Metadata from py4DSTEM.data import Calibration, RealSlice, Data, DiffractionSlice -from py4DSTEM.preprocess import get_shifted_ar +from py4DSTEM.utils import get_shifted_ar from py4DSTEM.visualize import show -# Virtual image container class +# DataCube methods -class VirtualImage(RealSlice, Data): - """ - A container for storing virtual image data and metadata, - including the real-space shaped 2D image and metadata - indicating how this image was generated from a datacube. - """ - def __init__( - self, - data: np.ndarray, - name: Optional[str] = "virtualimage", - ): - """ - Parameters - ---------- - data : np.ndarray - the 2D data - name : str - the name - """ - # initialize as a RealSlice - RealSlice.__init__( - self, - data=data, - name=name, - ) - - # read - @classmethod - def _get_constructor_args(cls, group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = RealSlice._get_constructor_args(group) - args = { - "data": ar_constr_args["data"], - "name": ar_constr_args["name"], - } - return args - - -# DataCube virtual imaging methods - - -class DataCubeVirtualImager: +class VirtualImager: def __init__(self): pass @@ -749,3 +705,55 @@ def make_bragg_mask( if return_sum: mask = np.sum(mask, axis=2) return mask + + + + + + +# Container class + + +class VirtualImage(RealSlice, Data): + """ + A container for storing virtual image data and metadata, + including the real-space shaped 2D image and metadata + indicating how this image was generated from a datacube. 
+ """ + + def __init__( + self, + data: np.ndarray, + name: Optional[str] = "virtualimage", + ): + """ + Parameters + ---------- + data : np.ndarray + the 2D data + name : str + the name + """ + # initialize as a RealSlice + RealSlice.__init__( + self, + data=data, + name=name, + ) + + # read + @classmethod + def _get_constructor_args(cls, group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = RealSlice._get_constructor_args(group) + args = { + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], + } + return args + + + + diff --git a/py4DSTEM/io/filereaders/empad.py b/py4DSTEM/io/filereaders/empad.py index 25c0a113b..686ee74cd 100644 --- a/py4DSTEM/io/filereaders/empad.py +++ b/py4DSTEM/io/filereaders/empad.py @@ -9,7 +9,7 @@ from pathlib import Path from emdfile import tqdmnd from py4DSTEM.datacube import DataCube -from py4DSTEM.preprocess.utils import bin2D +from py4DSTEM.utils import bin2D def read_empad(filename, mem="RAM", binfactor=1, metadata=False, **kwargs): diff --git a/py4DSTEM/io/filereaders/read_arina.py b/py4DSTEM/io/filereaders/read_arina.py index 832499d3f..c75679854 100644 --- a/py4DSTEM/io/filereaders/read_arina.py +++ b/py4DSTEM/io/filereaders/read_arina.py @@ -2,7 +2,7 @@ import hdf5plugin import numpy as np from py4DSTEM.datacube import DataCube -from py4DSTEM.preprocess.utils import bin2D +from py4DSTEM.utils import bin2D def read_arina( diff --git a/py4DSTEM/io/filereaders/read_dm.py b/py4DSTEM/io/filereaders/read_dm.py index 617529708..164a72d4c 100644 --- a/py4DSTEM/io/filereaders/read_dm.py +++ b/py4DSTEM/io/filereaders/read_dm.py @@ -6,7 +6,8 @@ from emdfile import tqdmnd, Array from py4DSTEM.datacube import DataCube -from py4DSTEM.preprocess.utils import bin2D +from py4DSTEM.utils import bin2D, electron_wavelength_angstrom + def read_dm(filepath, name="dm_dataset", mem="RAM", binfactor=1, **kwargs): @@ -79,8 +80,6 @@ def read_dm(filepath, name="dm_dataset", mem="RAM", binfactor=1, **kwargs): if "Microscope Info.Voltage" in t ] if len(voltage) >= 1: - from py4DSTEM.process.utils import electron_wavelength_angstrom - wavelength = electron_wavelength_angstrom(voltage[0]) Q_pixel_units = "A^-1" Q_pixel_size = ( diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index 0509d181e..878f60da9 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -3,7 +3,7 @@ from py4DSTEM.process import phase from py4DSTEM.process import calibration -from py4DSTEM.process import utils from py4DSTEM.process import classification from py4DSTEM.process import diffraction from py4DSTEM.process import wholepatternfit + diff --git a/py4DSTEM/process/calibration/ellipse.py b/py4DSTEM/process/calibration/ellipse.py index 2954de377..d01392fc0 100644 --- a/py4DSTEM/process/calibration/ellipse.py +++ b/py4DSTEM/process/calibration/ellipse.py @@ -12,15 +12,15 @@ to the x-axis, in radians More details about the elliptical parameterization used can be found in -the module docstring for process/utils/elliptical_coords.py. +the module docstring for utils/elliptical_coords.py. 
""" import numpy as np from scipy.optimize import leastsq from scipy.ndimage import gaussian_filter -from py4DSTEM.process.utils import convert_ellipse_params, convert_ellipse_params_r -from py4DSTEM.process.utils import get_CoM, radial_integral +from py4DSTEM.utils import convert_ellipse_params, convert_ellipse_params_r +from py4DSTEM.utils import get_CoM, radial_integral ###### Fitting a 1d elliptical curve to a 2d array, e.g. a Bragg vector map ###### diff --git a/py4DSTEM/process/calibration/origin.py b/py4DSTEM/process/calibration/origin.py index 7f0c07a81..957d1e40e 100644 --- a/py4DSTEM/process/calibration/origin.py +++ b/py4DSTEM/process/calibration/origin.py @@ -11,18 +11,7 @@ from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration.probe import get_probe_size from py4DSTEM.process.fit import plane, parabola, bezier_two, fit_2D -from py4DSTEM.process.utils import ( - get_CoM, - add_to_2D_array_from_floats, - get_maxima_2D, - upsampled_correlation, -) -from py4DSTEM.process.phase.utils import copy_to_device - -try: - import cupy as cp -except (ImportError, ModuleNotFoundError): - cp = np +from py4DSTEM.utils import get_CoM, add_to_2D_array_from_floats, get_maxima_2D # diff --git a/py4DSTEM/process/calibration/probe.py b/py4DSTEM/process/calibration/probe.py index dc0a38949..78f8816a6 100644 --- a/py4DSTEM/process/calibration/probe.py +++ b/py4DSTEM/process/calibration/probe.py @@ -1,6 +1,6 @@ import numpy as np -from py4DSTEM.process.utils import get_CoM +from py4DSTEM.utils import get_CoM def get_probe_size(DP, thresh_lower=0.01, thresh_upper=0.99, N=100): diff --git a/py4DSTEM/process/calibration/qpixelsize.py b/py4DSTEM/process/calibration/qpixelsize.py index d59d5a45c..f8aa6a3b5 100644 --- a/py4DSTEM/process/calibration/qpixelsize.py +++ b/py4DSTEM/process/calibration/qpixelsize.py @@ -5,7 +5,7 @@ from typing import Union, Optional from emdfile import tqdmnd -from py4DSTEM.process.utils import get_CoM +from py4DSTEM.utils import get_CoM def get_Q_pixel_size(q_meas, q_known, units="A"): diff --git a/py4DSTEM/process/classification/classutils.py b/py4DSTEM/process/classification/classutils.py index 51762a090..929bb37b0 100644 --- a/py4DSTEM/process/classification/classutils.py +++ b/py4DSTEM/process/classification/classutils.py @@ -4,7 +4,7 @@ from emdfile import tqdmnd, PointListArray from py4DSTEM.datacube import DataCube -from py4DSTEM.process.utils import get_shifted_ar +from py4DSTEM.utils import get_shifted_ar def get_class_DP( diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py index eb964de96..a816c770e 100644 --- a/py4DSTEM/process/diffraction/WK_scattering_factors.py +++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py @@ -3,7 +3,7 @@ # from functools import lru_cache -from py4DSTEM.process.utils import electron_wavelength_angstrom +from py4DSTEM.utils import electron_wavelength_angstrom """ Weickenmeier-Kohl absorptive scattering factors, adapted by SE Zeltmann from EMsoftLib/others.f90 diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index aa1eb8555..f114765ad 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -9,7 +9,7 @@ import warnings from emdfile import PointList -from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom +from py4DSTEM.utils import single_atom_scatter, electron_wavelength_angstrom from py4DSTEM.process.diffraction.utils import 
Orientation diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 83dff29e1..1eb1526fd 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -6,7 +6,7 @@ from emdfile import tqdmnd, PointList, PointListArray from py4DSTEM.data import RealSlice from py4DSTEM.process.diffraction.utils import Orientation, OrientationMap, axisEqual3D -from py4DSTEM.process.utils import electron_wavelength_angstrom +from py4DSTEM.utils import electron_wavelength_angstrom from warnings import warn diff --git a/py4DSTEM/process/diffraction/crystal_bloch.py b/py4DSTEM/process/diffraction/crystal_bloch.py index ce8bb8622..9c47889b0 100644 --- a/py4DSTEM/process/diffraction/crystal_bloch.py +++ b/py4DSTEM/process/diffraction/crystal_bloch.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from emdfile import PointList -from py4DSTEM.process.utils import electron_wavelength_angstrom, single_atom_scatter +from py4DSTEM.utils import electron_wavelength_angstrom, single_atom_scatter from py4DSTEM.process.diffraction.WK_scattering_factors import compute_WK_factor diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index 2ba71c8f1..1061ee95a 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -13,7 +13,7 @@ from matplotlib.ticker import PercentFormatter from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM import Calibration, DataCube -from py4DSTEM.preprocess.utils import get_shifted_ar +from py4DSTEM.utils import get_shifted_ar from py4DSTEM.process.phase.phase_base_class import PhaseReconstruction from py4DSTEM.process.phase.utils import ( AffineTransform, @@ -23,8 +23,8 @@ lanczos_kernel_density_estimate, pixel_rolling_kernel_density_estimate, ) -from py4DSTEM.process.utils.cross_correlate import align_images_fourier -from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.utils import align_images_fourier +from py4DSTEM.utils import electron_wavelength_angstrom from py4DSTEM.visualize import return_scaled_histogram_ordering, show from scipy.linalg import polar from scipy.ndimage import distance_transform_edt diff --git a/py4DSTEM/process/phase/phase_base_class.py b/py4DSTEM/process/phase/phase_base_class.py index a57568f48..014e39074 100644 --- a/py4DSTEM/process/phase/phase_base_class.py +++ b/py4DSTEM/process/phase/phase_base_class.py @@ -27,7 +27,7 @@ get_array_module, polar_aliases, ) -from py4DSTEM.process.utils import ( +from py4DSTEM.utils import ( electron_wavelength_angstrom, fourier_resample, get_shifted_ar, diff --git a/py4DSTEM/process/phase/ptychographic_constraints.py b/py4DSTEM/process/phase/ptychographic_constraints.py index 8a10b4df2..51508fec5 100644 --- a/py4DSTEM/process/phase/ptychographic_constraints.py +++ b/py4DSTEM/process/phase/ptychographic_constraints.py @@ -8,7 +8,7 @@ fit_aberration_surface, regularize_probe_amplitude, ) -from py4DSTEM.process.utils import get_CoM +from py4DSTEM.utils import get_CoM try: import cupy as cp diff --git a/py4DSTEM/process/phase/ptychographic_methods.py b/py4DSTEM/process/phase/ptychographic_methods.py index a1cef5c56..77bd4ee14 100644 --- a/py4DSTEM/process/phase/ptychographic_methods.py +++ b/py4DSTEM/process/phase/ptychographic_methods.py @@ -16,7 +16,7 @@ spatial_frequencies, vectorized_bilinear_resample, ) -from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from 
py4DSTEM.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import gaussian_filter, rotate diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index 98d8fdd49..6f2c57fad 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -19,9 +19,9 @@ def get_array_module(*args): return np -from py4DSTEM.process.utils import get_CoM -from py4DSTEM.process.utils.cross_correlate import align_and_shift_images -from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.utils import get_CoM +from py4DSTEM.utils import align_and_shift_images +from py4DSTEM.utils import electron_wavelength_angstrom from skimage.restoration import unwrap_phase # fmt: off diff --git a/py4DSTEM/process/preprocess/electron_counting.py b/py4DSTEM/process/preprocess/electron_counting.py new file mode 100644 index 000000000..ee4b5379e --- /dev/null +++ b/py4DSTEM/process/preprocess/electron_counting.py @@ -0,0 +1,220 @@ +import numpy as np, matplotlib.pyplot as plt,copy,sys,os +from scipy import optimize +from scipy.ndimage.measurements import label,center_of_mass + +def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, + length = 100, fill = '*'): + """ + Call in a loop to create terminal progress bar + @params: + iteration - Required : current iteration (Int) + total - Required : total iterations (Int) + prefix - Optional : prefix string (Str) + suffix - Optional : suffix string (Str) + decimals - Optional : positive number of decimals in percent complete (Int) + length - Optional : character length of bar (Int) + fill - Optional : bar fill character (Str) + """ + percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) + filledLength = int(length * iteration // total) + bar = fill * filledLength + '-' * (length - filledLength) + print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r') + # Print New Line on Complete + if iteration == total: + print() + + +def get_dark_reference(datacube,ndrsamples=20,upper_limit=5,drwidth=100): + '''For a 4D-STEM datacube subtract the dark reference''' + #Get dimensions + y,x,z,t = np.shape(datacube) + + #Sum a set of random frames from the data to create + #the background reference + rand = np.random.randint(0,min(z,t),(min(z*t,ndrsamples),2)) + refframe = np.zeros((y,x),dtype=np.int16) + mean = 0 + var=0 + for i in range(ndrsamples): + refframe+= datacube[:,:,rand[i,0],rand[i,1]] + mean += np.mean(datacube[:,:,rand[i,0],rand[i,1]]) + var += np.var(datacube[:,:,rand[i,0],rand[i,1]]) + mean/=ndrsamples + var/=ndrsamples + #TODO: make GUI to choose region to take dark reference from + #Create dark reference image by choosing a strip of the reference + #image near the boundary, projecting this horizontally (axis 1) + #and summing. + #Normalisation by the width of the reference region and the + #number of sample images necessary. 
+ dr = np.sum(refframe[:,:drwidth],axis=1)//drwidth//ndrsamples + + return dr,mean,np.sqrt(var) + +def electron_list_to_array(list,array_size): + '''For a list of electron positions as a fractional + if the original pixel dimensions of the camera + create a 2D image of the diffraction pattern for + arbitrary array size array_size.''' + nelectrons = np.shape(list)[0] + Ny,Nx = [np.amax(list[:,2]),np.amax(list[:,3])] + array_out = np.zeros(array_size+(Ny,Nx),dtype=np.int) + sze = np.asarray(array_size) + for i in range(nelectrons): + y,x = [int(x) for x in np.floor(list[i,:2]*sze)] + z,t = [int(x) for x in list[i,2:4]] + array_out[y,x,z,t] += 1 + return array_out + +def calculate_counting_threshhold(datacube,darkreference,sigmathresh=4,nsamples=20, + upper_limit=10,plot_histogram=False): + """For a datacube calculate the threshhold for an electron count to be + registered.""" + y,x,z,t = np.shape(datacube) + samples = np.random.randint(0,z*t,size=nsamples) + #This flattens the array, it shouldn't make a copy of it + sample = np.zeros((y,x,nsamples),dtype=np.int16) + for i in range(nsamples): + sample[:,:,i] = np.asarray(datacube[:,:,samples[i]//t,samples[i]%t],dtype=np.int16) + sample[:,:,i] -= darkreference[:,np.newaxis] + sample = np.ravel(sample) + #Remove X-rays + mean = np.mean(sample) + stddev = np.std(sample) + #zero all X-ray counts + sample[sample>(mean+upper_limit*stddev)]=mean + #Since the data is 16 bit integer choosing bins is easy + bins = np.arange(np.floor(np.amin(sample)) + ,np.ceil(np.amax(sample)), + dtype=np.int) + + #Fit background + n, bins = np.histogram(sample, bins=bins) + + #Now fit Gaussian function to the histogram + #First define Gaussian function, list p contains parameters + #p[0] is amplitude of Gaussian + #p[1] is centre of Gaussian + #p[2] is std deviation of Gaussian + fitfunc = lambda p, x: p[0]*np.exp(-0.5*np.square((x-p[1])/p[2])) + #Define the error routine for scipy's optimize routine + errfunc = lambda p, x, y: fitfunc(p, x) - y + + #Initial guess of Gaussian function parameters + p0 = [n.max(),(bins[n.argmax()+1]-bins[n.argmax()])/2,np.std(sample)] + #Use the scipy optimize routine + p1,success = optimize.leastsq(errfunc,p0[:],args=(bins[:-1],n)) + #Add a half to account for integer bin width + p1[1] += 0.5 + + #Threshhold for an electron count to be considered an electron count + thresh = p1[1]+p1[2]*sigmathresh + + #If desired, plot histogram with line of best fit + if(plot_histogram): + n, bins, patches = plt.hist(x=sample, bins=bins, + color='#0504aa', + alpha=0.7, rwidth=0.85) + plt.grid(axis='y', alpha=0.75) + plt.xlabel('Value') + plt.ylabel('Frequency') + plt.title('Histogram of K2 intensity readout') + maxfreq = n.max() + stddev =np.std(sample) + (plt.xlim( xmin = bins[n.argmax()] - 5*stddev, + xmax = bins[n.argmax()] + 5*stddev)) + ymax = np.ceil(maxfreq / 10) * 10 + plt.ylim(ymax=ymax) + plt.plot(bins[:-1],fitfunc(p1,bins[:-1]),color='red') + plt.plot([thresh,thresh],[0.0,ymax],'k--') + plt.text(thresh-0.5,ymax/2,'e$^{-}$ count\nthreshhold',ha='right') + plt.savefig('Intensity_histogram.pdf') + plt.show() + + return thresh + +def count_datacube(datacube,counted_shape,sigmathresh=4,nsamples=40,upper_limit=10, + drwidth=100,sub_pixel=True, + plot_histogram=False,plot_electrons=False): + data = datacube.data4D + + #Get dimensions + y,x,z,t = [datacube.Q_Ny, datacube.Q_Nx,datacube.R_Ny, datacube.R_Nx] + + print('Getting dark current reference') + #Remove dark background + dr,mean,stddev = 
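A minimal sketch of the dark-reference step on a synthetic stack, assuming the new electron_counting module is importable at the path created above and following its (detector_y, detector_x, scan_y, scan_x) axis convention; the array, count level, and drwidth are arbitrary:

import numpy as np
from py4DSTEM.process.preprocess.electron_counting import get_dark_reference

# synthetic 4D stack with ~2 counts/pixel of Poisson background
rng = np.random.default_rng(0)
stack = rng.poisson(2, size=(64, 64, 8, 8)).astype(np.int16)

dr, mean, stddev = get_dark_reference(stack, ndrsamples=10, drwidth=32)
print(dr.shape, round(float(mean), 1))   # (64,) and a mean near the simulated background of 2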
get_dark_reference(data,ndrsamples=nsamples,upper_limit=upper_limit, + drwidth=drwidth) + + print('Calculating threshhold') + #Get threshhold + thresh = calculate_counting_threshhold(data,dr,sigmathresh=sigmathresh + ,nsamples=nsamples,upper_limit=upper_limit + ,plot_histogram=plot_histogram) + + counted = np.zeros(counted_shape+(z,t),dtype=np.uint16) + + #S + total_electrons =0 + e_list = None + + for tt in range(t): + printProgressBar(tt, t-1, prefix = 'Counting:', suffix = 'Complete', length = 50) + for zz in range(z): + + #Background subtract and remove X-rays\hot pixels + workingarray = data[:,:,zz,tt]-dr[:,np.newaxis] + workingarray[workingarray>mean+upper_limit*stddev]=0 + + #Create a map of events (pixels with value greater than + #the threshhold value) + + events = np.greater(workingarray,thresh) + + #Now find local maxima by circular shift and comparison to neighbours + #if we want diagonal neighbours to be considered seperate electron + #counts then we would add a second for loop to do these extra + #comparisons + for i in range(4): + events = np.logical_and(np.greater(workingarray, + np.roll(workingarray,i%2*2-1,axis=i//2)),events) + events[ 0, :]=False + events[-1, :]=False + events[ :, 0]=False + events[ :,-1]=False + + electron_posn = np.asarray(np.argwhere(events),dtype=np.float) + num_electrons = np.shape(electron_posn)[0] + + if(sub_pixel): + #Now do center of mass in local region 3x3 region to refine position + # estimate + for i in range(num_electrons): + event = electron_posn[i,:] + electron_posn[i,:] += center_of_mass(workingarray[int(event[0]-1) + :int(event[0]+2),int(event[1]-1) + :int(event[1]+2)]) + electron_posn -= np.asarray([1,1])[np.newaxis,:] + + if(plot_electrons): + if(not os.path.exists(count_plots)):os.mkdir('count_plots') + figsize = max(np.shape(data[:,:,zz,tt])[:2])//200 + fig = plt.figure(figsize = (figsize,figsize)) + ax = fig.add_subplot(111) + ax.imshow(data[:,:,zz,tt],origin='lower',vmax=2*thresh) + ax.plot(electron_posn[:,1],electron_posn[:,0],'rx') + ax.set_title('Found {0} electrons'.format(num_electrons)) + plt.show() + fig.savefig('count_plots/Countplot_{0}_{1}.pdf'.format(tt,zz)) + #Update total number of electrons + total_electrons += num_electrons + #Put the electron_posn in fractional coordinates + #where the positions are fractions of the original array + electron_posn /= np.asarray([max(y,x),max(y,x)])[np.newaxis,:] + electron_posn = np.hstack((electron_posn + ,np.asarray([tt,zz],dtype=np.float32)[np.newaxis,:])) + if(e_list is None): e_list = electron_posn + else:e_list = np.vstack((e_list,electron_posn)) + return electron_posn + +if __name__=="__main__": pass diff --git a/py4DSTEM/process/rdf/amorph.py b/py4DSTEM/process/rdf/amorph.py index 3aaf63c45..fe3c6999e 100644 --- a/py4DSTEM/process/rdf/amorph.py +++ b/py4DSTEM/process/rdf/amorph.py @@ -1,6 +1,6 @@ import numpy as np import matplotlib.pyplot as plt -from py4DSTEM.process.utils.elliptical_coords import * ## What else is used here? These fns have +from py4DSTEM.utils.elliptical_coords import * ## What else is used here? These fns have ## moved around some. In general, specifying ## the fns is better practice. 
TODO: change diff --git a/py4DSTEM/process/rdf/rdf.py b/py4DSTEM/process/rdf/rdf.py index cee7eeee9..358985135 100644 --- a/py4DSTEM/process/rdf/rdf.py +++ b/py4DSTEM/process/rdf/rdf.py @@ -6,7 +6,7 @@ from scipy.special import erf from scipy.fftpack import dst, idst -from py4DSTEM.process.utils import single_atom_scatter +from py4DSTEM.utils import single_atom_scatter def get_radial_intensity(polar_img, polar_mask): diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 516820e04..49a8f5c2a 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -10,7 +10,7 @@ from py4DSTEM import PointList, PointListArray, tqdmnd from py4DSTEM.braggvectors import BraggVectors from py4DSTEM.data import Data, RealSlice -from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.utils import get_maxima_2D from py4DSTEM.process.strain.latticevectors import ( fit_lattice_vectors_all_DPs, get_reference_g1g2, @@ -185,6 +185,7 @@ def choose_basis_vectors( index_origin : int selected index for the origin subpixel : str in ('pixel','poly','multicorr') +<<<<<<< HEAD See the docstring for py4DSTEM.preprocess.get_maxima_2D upsample_factor : int See the py4DSTEM.preprocess.get_maxima_2D docstring @@ -202,6 +203,25 @@ def choose_basis_vectors( See the py4DSTEM.preprocess.get_maxima_2D docstring maxNumPeaks : int See the py4DSTEM.preprocess.get_maxima_2D docstring +======= + See the docstring for py4DSTEM.utils.get_maxima_2D + upsample_factor : int + See the py4DSTEM.utils.get_maxima_2D docstring + sigma : number + See the py4DSTEM.utils.get_maxima_2D docstring + minAbsoluteIntensity : number + See the py4DSTEM.utils.get_maxima_2D docstring + minRelativeIntensity : number + See the py4DSTEM.utils.get_maxima_2D docstring + relativeToPeak : int + See the py4DSTEM.utils.get_maxima_2D docstring + minSpacing : number + See the py4DSTEM.utils.get_maxima_2D docstring + edgeBoundary : number + See the py4DSTEM.utils.get_maxima_2D docstring + maxNumPeaks : int + See the py4DSTEM.utils.get_maxima_2D docstring +>>>>>>> v15 figsize : 2-tuple the size of the figure c_indices : color diff --git a/py4DSTEM/utils/__init__.py b/py4DSTEM/utils/__init__.py index b0c484e80..a8b3adfa2 100644 --- a/py4DSTEM/utils/__init__.py +++ b/py4DSTEM/utils/__init__.py @@ -1 +1,48 @@ +from py4DSTEM.utils.bin2d import bin2D from py4DSTEM.utils.configuration_checker import check_config +from py4DSTEM.utils.cross_correlate import ( + get_cross_correlation, + get_cross_correlation_FT, + get_shift, + align_images_fourier, + align_and_shift_images) +from py4DSTEM.utils.electron_conversions import ( + electron_wavelength_angstrom, + electron_interaction_parameter) +from py4DSTEM.utils.elliptical_coords import ( + convert_ellipse_params, + convert_ellipse_params_r, + cartesian_to_polarelliptical_transform, + elliptical_resample_datacube, + elliptical_resample, + radial_elliptical_integral, + radial_integral) +from py4DSTEM.utils.ewpc import get_ewpc_filter_function +from py4DSTEM.utils.get_CoM import get_CoM +from py4DSTEM.utils.get_maxima import ( + get_maxima_1D, + get_maxima_2D, + filter_2D_maxima) +from py4DSTEM.utils.get_shifted_ar import get_shifted_ar +from py4DSTEM.utils.linear_interpolation import ( + linear_interpolation_1D, + linear_interpolation_2D, + add_to_2D_array_from_floats) +from py4DSTEM.utils.make_fourier_coords import ( + make_Fourier_coords2D, + get_qx_qy_1d) +from py4DSTEM.utils.masks import ( + make_circular_mask, + get_beamstop_mask, + sector_mask) +from 
py4DSTEM.utils.multicorr import ( + upsampled_correlation, + upsampleFFT, + dftUpsample) +from py4DSTEM.utils.radial_reduction import radial_reduction +from py4DSTEM.utils.resample import fourier_resample +from py4DSTEM.utils.single_atom_scatter import single_atom_scatter +from py4DSTEM.utils.voronoi import get_voronoi_vertices + + + diff --git a/py4DSTEM/utils/_depr_utils.py b/py4DSTEM/utils/_depr_utils.py new file mode 100644 index 000000000..efedeeb1c --- /dev/null +++ b/py4DSTEM/utils/_depr_utils.py @@ -0,0 +1,110 @@ +# Defines utility functions used by other functions in the /process/ directory. + +#import numpy as np +#from numpy.fft import fftfreq, fftshift +#from scipy.ndimage import gaussian_filter +#import math as ma +#import matplotlib.pyplot as plt +#from mpl_toolkits.axes_grid1 import make_axes_locatable +#from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar +#import matplotlib.font_manager as fm +# +#from emdfile import tqdmnd +#from py4DSTEM.process.utils.multicorr import upsampled_correlation +#from py4DSTEM.preprocess.utils import make_Fourier_coords2D +# +#try: +# from IPython.display import clear_output +#except ImportError: +# +# def clear_output(wait=True): +# pass +# +# +#try: +# import cupy as cp +#except (ModuleNotFoundError, ImportError): +# cp = np + + + + +# import matplotlib.pyplot as plt +# from mpl_toolkits.axes_grid1 import make_axes_locatable +# from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar +# import matplotlib.font_manager as fm +# +# +# try: +# from IPython.display import clear_output +# except ImportError: +# def clear_output(wait=True): +# pass +# +# def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, +# figsize=(10, 10), scale=None): +# fig, ax = plt.subplots(figsize=figsize) +# im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax) +# divider = make_axes_locatable(ax) +# cax = divider.append_axes("right", size="5%", pad=0.05) +# plt.colorbar(im, cax=cax) +# ax.set_title(title) +# fontprops = fm.FontProperties(size=18) +# if scale is not None: +# scalebar = AnchoredSizeBar(ax.transData, +# scale[0], scale[1], 'lower right', +# pad=0.1, +# color='white', +# frameon=False, +# size_vertical=img.shape[0] / 40, +# fontproperties=fontprops) +# +# ax.add_artist(scalebar) +# ax.grid(False) +# if savePath is not None: +# fig.savefig(savePath + '.png', dpi=600) +# fig.savefig(savePath + '.eps', dpi=600) +# if show: +# plt.show() + + + +#def plot( +# img, +# title="Image", +# savePath=None, +# cmap="inferno", +# show=True, +# vmax=None, +# figsize=(10, 10), +# scale=None, +#): +# fig, ax = plt.subplots(figsize=figsize) +# im = ax.imshow(img, interpolation="nearest", cmap=plt.get_cmap(cmap), vmax=vmax) +# divider = make_axes_locatable(ax) +# cax = divider.append_axes("right", size="5%", pad=0.05) +# plt.colorbar(im, cax=cax) +# ax.set_title(title) +# fontprops = fm.FontProperties(size=18) +# if scale is not None: +# scalebar = AnchoredSizeBar( +# ax.transData, +# scale[0], +# scale[1], +# "lower right", +# pad=0.1, +# color="white", +# frameon=False, +# size_vertical=img.shape[0] / 40, +# fontproperties=fontprops, +# ) +# +# ax.add_artist(scalebar) +# ax.grid(False) +# if savePath is not None: +# fig.savefig(savePath + ".png", dpi=600) +# fig.savefig(savePath + ".eps", dpi=600) +# if show: +# plt.show() + + diff --git a/py4DSTEM/utils/bin2d.py b/py4DSTEM/utils/bin2d.py new file mode 100644 index 000000000..a74270d12 --- /dev/null +++ b/py4DSTEM/utils/bin2d.py @@ -0,0 
+1,40 @@ +import numpy as np +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + + +def bin2D(array, factor, dtype=np.float64): + """ + Bin a 2D ndarray by binfactor. + + Parameters + ---------- + array : 2D numpy array + factor : int + the binning factor + dtype : numpy dtype + datatype for binned array. default is numpy default for np.zeros() + + Returns + ------- + the binned array + """ + x, y = array.shape + binx, biny = x // factor, y // factor + xx, yy = binx * factor, biny * factor + + # Make a binned array on the device + binned_ar = np.zeros((binx, biny), dtype=dtype) + array = array.astype(dtype) + + # Collect pixel sums into new bins + for ix in range(factor): + for iy in range(factor): + binned_ar += array[0 + ix : xx + ix : factor, 0 + iy : yy + iy : factor] + return binned_ar + + + diff --git a/py4DSTEM/utils/cross_correlate.py b/py4DSTEM/utils/cross_correlate.py new file mode 100644 index 000000000..bd2d2eb39 --- /dev/null +++ b/py4DSTEM/utils/cross_correlate.py @@ -0,0 +1,178 @@ +# Cross correlation function + +import numpy as np +from py4DSTEM.utils.get_shifted_ar import get_shifted_ar +from py4DSTEM.utils.multicorr import upsampled_correlation + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +def get_cross_correlation(ar, template, corrPower=1, _returnval="real"): + """ + Get the cross/phase/hybrid correlation of `ar` with `template`, where + the latter is in real space. + + If _returnval is 'real', returns the real-valued cross-correlation. + Otherwise, returns the complex valued result. + """ + assert _returnval in ("real", "fourier") + template_FT = np.conj(np.fft.fft2(template)) + return get_cross_correlation_FT( + ar, template_FT, corrPower=corrPower, _returnval=_returnval + ) + + +def get_cross_correlation_FT(ar, template_FT, corrPower=1, _returnval="real"): + """ + Get the cross/phase/hybrid correlation of `ar` with `template_FT`, where + the latter is already in Fourier space (i.e. `template_FT` is + `np.conj(np.fft.fft2(template))`. + + If _returnval is 'real', returns the real-valued cross-correlation. + Otherwise, returns the complex valued result. + """ + assert _returnval in ("real", "fourier") + m = np.fft.fft2(ar) * template_FT + if corrPower != 1: + cc = np.abs(m) ** (corrPower) * np.exp(1j * np.angle(m)) + else: + cc = m + if _returnval == "real": + cc = np.maximum(np.real(np.fft.ifft2(cc)), 0) + return cc + + +def get_shift(ar1, ar2, corrPower=1): + """ + Determine the relative shift between a pair of arrays giving the best overlap. + + Shift determination uses the brightest pixel in the cross correlation, and is + thus limited to pixel resolution. corrPower specifies the cross correlation + power, with 1 corresponding to a cross correlation and 0 a phase correlation. + + Args: + ar1,ar2 (2D ndarrays): + corrPower (float between 0 and 1, inclusive): 1=cross correlation, 0=phase + correlation + + Returns: + (2-tuple): (shiftx,shifty) - the relative image shift, in pixels + """ + cc = get_cross_correlation(ar1, ar2, corrPower) + xshift, yshift = np.unravel_index(np.argmax(cc), ar1.shape) + return xshift, yshift + + +def align_images_fourier( + G1, + G2, + upsample_factor, + device="cpu", +): + """ + Alignment of two images using DFT upsampling of cross correlation. + + Parameters + ------- + G1: ndarray + fourier transform of image 1 + G2: ndarray + fourier transform of image 2 + upsample_factor: float + upsampling for correlation. Must be greater than 2. 
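A quick sanity check of bin2D and get_shift, assuming the re-exports set up in the new py4DSTEM/utils/__init__.py above; the test image and shift are arbitrary:

import numpy as np
from scipy.ndimage import gaussian_filter
from py4DSTEM.utils import bin2D, get_shift

im = np.zeros((256, 256))
im[100, 80] = 1.0
im = gaussian_filter(im, 3)
im_rolled = np.roll(im, (7, -5), axis=(0, 1))

print(bin2D(im, 4).shape)         # (64, 64)
print(get_shift(im_rolled, im))   # (7, 251), i.e. a shift of (7, -5) modulo the array size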
+ device: str, optional + calculation device will be perfomed on. Must be 'cpu' or 'gpu' + + Returns: + xy_shift [pixels] + """ + + if device == "cpu": + xp = np + elif device == "gpu": + xp = cp + + G1 = xp.asarray(G1) + G2 = xp.asarray(G2) + + # cross correlation + cc = G1 * xp.conj(G2) + cc_real = xp.real(xp.fft.ifft2(cc)) + + # local max + x0, y0 = xp.unravel_index(cc_real.argmax(), cc.shape) + + # half pixel shifts + x_inds = xp.mod(x0 + xp.arange(-1, 2), cc.shape[0]).astype("int") + y_inds = xp.mod(y0 + xp.arange(-1, 2), cc.shape[1]).astype("int") + + vx = cc_real[x_inds, y0] + vy = cc_real[x0, y_inds] + dx = (vx[2] - vx[0]) / (4 * vx[1] - 2 * vx[2] - 2 * vx[0]) + dy = (vy[2] - vy[0]) / (4 * vy[1] - 2 * vy[2] - 2 * vy[0]) + + x0 = xp.round((x0 + dx) * 2.0) / 2.0 + y0 = xp.round((y0 + dy) * 2.0) / 2.0 + + # subpixel shifts + xy_shift = upsampled_correlation( + cc, upsample_factor, xp.array([x0, y0]), device=device + ) + + return xy_shift + + +def align_and_shift_images( + image_1, + image_2, + upsample_factor, + device="cpu", +): + """ + Alignment of two images using DFT upsampling of cross correlation. + + Parameters + ------- + image_1: ndarray + image 1 + image_2: ndarray + image 2 + upsample_factor: float + upsampling for correlation. Must be greater than 2. + device: str, optional + calculation device will be perfomed on. Must be 'cpu' or 'gpu'. + + Returns: + shifted image [pixels] + """ + + if device == "cpu": + xp = np + + elif device == "gpu": + xp = cp + + image_1 = xp.asarray(image_1) + image_2 = xp.asarray(image_2) + + xy_shift = align_images_fourier( + xp.fft.fft2(image_1), + xp.fft.fft2(image_2), + upsample_factor=upsample_factor, + device=device, + ) + dx = ( + xp.mod(xy_shift[0] + image_1.shape[0] / 2, image_1.shape[0]) + - image_1.shape[0] / 2 + ) + dy = ( + xp.mod(xy_shift[1] + image_1.shape[1] / 2, image_1.shape[1]) + - image_1.shape[1] / 2 + ) + + image_2_shifted = get_shifted_ar(image_2, dx, dy, device=device) + + return image_2_shifted diff --git a/py4DSTEM/utils/electron_conversions.py b/py4DSTEM/utils/electron_conversions.py new file mode 100644 index 000000000..ef691357e --- /dev/null +++ b/py4DSTEM/utils/electron_conversions.py @@ -0,0 +1,35 @@ +import math as ma + + +def electron_wavelength_angstrom(E_eV): + m = 9.109383 * 10**-31 + e = 1.602177 * 10**-19 + c = 299792458 + h = 6.62607 * 10**-34 + + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) + return lam + + +def electron_interaction_parameter(E_eV): + m = 9.109383 * 10**-31 + e = 1.602177 * 10**-19 + c = 299792458 + h = 6.62607 * 10**-34 + lam = ( + h + / ma.sqrt(2 * m * e * E_eV) + / ma.sqrt(1 + e * E_eV / 2 / m / c**2) + * 10**10 + ) + sigma = ( + (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV) + ) + return sigma + + diff --git a/py4DSTEM/utils/elliptical_coords.py b/py4DSTEM/utils/elliptical_coords.py new file mode 100644 index 000000000..97291bc20 --- /dev/null +++ b/py4DSTEM/utils/elliptical_coords.py @@ -0,0 +1,434 @@ +""" +Contains functions relating to polar-elliptical calculations. + +This includes + - transforming data from cartesian to polar-elliptical coordinates + - converting between ellipse representations + - radial and polar-elliptical radial integration + +Functions for measuring/fitting elliptical distortions are found in +process/calibration/ellipse.py. Functions for computing radial and +polar-elliptical radial backgrounds are found in process/preprocess/ellipse.py. 
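The wavelength conversion above can be spot-checked against standard TEM operating voltages. Note that electron_interaction_parameter as written references np.pi, so electron_conversions.py also needs numpy imported for that second function to run:

from py4DSTEM.utils import electron_wavelength_angstrom

print(electron_wavelength_angstrom(80e3))    # ~0.0418 Angstrom
print(electron_wavelength_angstrom(200e3))   # ~0.0251 Angstrom
print(electron_wavelength_angstrom(300e3))   # ~0.0197 Angstrom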
+ +py4DSTEM uses 2 ellipse representations - one user-facing representation, and +one internal representation. The user-facing represenation is in terms of the +following 5 parameters: + + x0,y0 the center of the ellipse + a the semimajor axis length + b the semiminor axis length + theta the (positive, right handed) tilt of the a-axis + to the x-axis, in radians + +Internally, fits are performed using the canonical ellipse parameterization, +in terms of the parameters (x0,y0,A,B,C): + + A(x-x0)^2 + B(x-x0)(y-y0) C(y-y0)^2 = 1 + +It is possible to convert between (a,b,theta) <--> (A,B,C) using +the convert_ellipse_params() and convert_ellipse_params_r() methods. + +Transformation from cartesian to polar-elliptical space is done using + + x = x0 + a*r*cos(phi)*cos(theta) + b*r*sin(phi)*sin(theta) + y = y0 + a*r*cos(phi)*sin(theta) - b*r*sin(phi)*cos(theta) + +where (r,phi) are the polar-elliptical coordinates. All angular quantities are in +radians. +""" + +import numpy as np + +### Convert between representations + + +def convert_ellipse_params(A, B, C): + """ + Converts ellipse parameters from canonical form (A,B,C) into semi-axis lengths and + tilt (a,b,theta). + See module docstring for more info. + + Args: + A,B,C (floats): parameters of an ellipse in the form: + Ax^2 + Bxy + Cy^2 = 1 + + Returns: + (3-tuple): A 3-tuple consisting of: + + * **a**: (float) the semimajor axis length + * **b**: (float) the semiminor axis length + * **theta**: (float) the tilt of the ellipse semimajor axis with respect to + the x-axis, in radians + """ + val = np.sqrt((A - C) ** 2 + B**2) + b4a = B**2 - 4 * A * C + # Get theta + if B == 0: + if A < C: + theta = 0 + else: + theta = np.pi / 2.0 + else: + theta = np.arctan2((C - A - val), B) + # Get a,b + a = -np.sqrt(-2 * b4a * (A + C + val)) / b4a + b = -np.sqrt(-2 * b4a * (A + C - val)) / b4a + a, b = max(a, b), min(a, b) + return a, b, theta + + +def convert_ellipse_params_r(a, b, theta): + """ + Converts from ellipse parameters (a,b,theta) to (A,B,C). + See module docstring for more info. + + Args: + a,b,theta (floats): parameters of an ellipse, where `a`/`b` are the + semimajor/semiminor axis lengths, and theta is the tilt of the semimajor axis + with respect to the x-axis, in radians. + + Returns: + (3-tuple): A 3-tuple consisting of (A,B,C), the ellipse parameters in + canonical form. + """ + sin2, cos2 = np.sin(theta) ** 2, np.cos(theta) ** 2 + a2, b2 = a**2, b**2 + A = sin2 / b2 + cos2 / a2 + C = cos2 / b2 + sin2 / a2 + B = 2 * (b2 - a2) * np.sin(theta) * np.cos(theta) / (a2 * b2) + return A, B, C + + +### Polar elliptical transformation + + +def cartesian_to_polarelliptical_transform( + cartesianData, + p_ellipse, + dr=1, + dphi=np.radians(2), + r_range=None, + mask=None, + maskThresh=0.99, +): + """ + Transforms an array of data in cartesian coordinates into a data array in + polar-elliptical coordinates. + + Discussion of the elliptical parametrization used can be found in the docstring + for the process.utils.elliptical_coords module. + + Args: + cartesianData (2D float array): the data in cartesian coordinates + p_ellipse (5-tuple): specifies (qx0,qy0,a,b,theta), the parameters for the + transformation. These are the same 5 parameters which are outputs + of the elliptical fitting functions in the process.calibration + module, e.g. fit_ellipse_amorphous_ring and fit_ellipse_1D. 
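A round-trip sketch of the two conversions above, with arbitrary semi-axes and tilt; the recovered tilt is only defined modulo pi:

import numpy as np
from py4DSTEM.utils import convert_ellipse_params, convert_ellipse_params_r

A, B, C = convert_ellipse_params_r(3.0, 2.0, np.radians(20))
a, b, theta = convert_ellipse_params(A, B, C)
print(a, b, np.mod(theta, np.pi))   # ~3.0, ~2.0, ~0.349 rad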
For + more details, see the process.utils.elliptical_coords module docstring + dr (float): sampling of the (r,phi) coords: the width of the bins in r + dphi (float): sampling of the (r,phi) coords: the width of the bins in phi, + in radians + r_range (number or length 2 list/tuple or None): specifies the sampling of the + (r,theta) coords. Precise behavior which depends on the parameter type: + * if None, autoselects max r value + * if r_range is a number, specifies the maximum r value + * if r_range is a length 2 list/tuple, specifies the min/max r values + mask (2d array of bools): shape must match cartesianData; where mask==False, + ignore these datapoints in making the polarElliptical data array + maskThresh (float): the final data mask is calculated by converting mask (above) + from cartesian to polar elliptical coords. Due to interpolation, this + results in some non-boolean values - this is converted back to a boolean + array by taking polarEllipticalMask = polarTrans(mask) < maskThresh. Cells + where polarTrans is less than 1 (i.e. has at least one masked NN) should + generally be masked, hence the default value of 0.99. + + Returns: + (3-tuple): A 3-tuple, containing: + + * **polarEllipticalData**: *(2D masked array)* a masked array containing + the data and the data mask, in polarElliptical coordinates + * **rr**: *(2D array)* meshgrid of the r coordinates + * **pp**: *(2D array)* meshgrid of the phi coordinates + """ + if mask is None: + mask = np.ones_like(cartesianData.data, dtype=bool) + assert ( + cartesianData.shape == mask.shape + ), "Mask and cartesian data array shapes must match." + assert len(p_ellipse) == 5, "p_ellipse must have length 5" + + # Get params + qx0, qy0, a, b, theta = p_ellipse + Nx, Ny = cartesianData.shape + + # Define r_range: + if r_range is None: + # find corners of image + corners = np.array( + [ + [0, 0], + [0, cartesianData.shape[0]], + [0, cartesianData.shape[1]], + [cartesianData.shape[0], cartesianData.shape[1]], + ] + ) + # find maximum corner distance + r_min, r_max = 0, np.ceil( + np.max( + np.sqrt( + np.sum( + (corners - np.broadcast_to(np.array((qx0, qy0)), corners.shape)) + ** 2, + axis=1, + ) + ) + ) + ).astype(int) + else: + try: + r_min, r_max = r_range[0], r_range[1] + except TypeError: + r_min, r_max = 0, r_range + + # Define the r/phi coords + r_bins = np.arange(r_min + dr / 2.0, r_max + dr / 2.0, dr) # values are bin centers + p_bins = np.arange(-np.pi + dphi / 2.0, np.pi + dphi / 2.0, dphi) + rr, pp = np.meshgrid(r_bins, p_bins) + Nr, Np = rr.shape + + # Get (qx,qy) corresponding to each (r,phi) in the newly defined coords + xr = rr * np.cos(pp) + yr = rr * np.sin(pp) + qx = qx0 + xr * np.cos(theta) - yr * (b / a) * np.sin(theta) + qy = qy0 + xr * np.sin(theta) + yr * (b / a) * np.cos(theta) + + # qx,qy are now shape (Nr,Np) arrays, such that (qx[r,phi],qy[r,phi]) is the point + # in cartesian space corresponding to r,phi. We now get the values for the final + # polarEllipticalData array by interpolating values at these coords from the original + # cartesianData array. 
+ + transform_mask = (qx > 0) * (qy > 0) * (qx < Nx - 1) * (qy < Ny - 1) + + # Bilinear interpolation + xF = np.floor(qx[transform_mask]) + yF = np.floor(qy[transform_mask]) + dx = qx[transform_mask] - xF + dy = qy[transform_mask] - yF + x_inds = np.vstack((xF, xF + 1, xF, xF + 1)).astype(int) + y_inds = np.vstack((yF, yF, yF + 1, yF + 1)).astype(int) + weights = np.vstack( + ((1 - dx) * (1 - dy), (dx) * (1 - dy), (1 - dx) * (dy), (dx) * (dy)) + ) + transform_mask = transform_mask.ravel() + polarEllipticalData = np.zeros(Nr * Np) + polarEllipticalData[transform_mask] = np.sum( + cartesianData[x_inds, y_inds] * weights, axis=0 + ) + polarEllipticalData = np.reshape(polarEllipticalData, (Nr, Np)) + + # Transform mask + polarEllipticalMask = np.zeros(Nr * Np) + polarEllipticalMask[transform_mask] = np.sum(mask[x_inds, y_inds] * weights, axis=0) + polarEllipticalMask = np.reshape(polarEllipticalMask, (Nr, Np)) + + polarEllipticalData = np.ma.array( + data=polarEllipticalData, mask=polarEllipticalMask < maskThresh + ) + return polarEllipticalData, rr, pp + + +### Cartesian elliptical transform + + +def elliptical_resample_datacube( + datacube, + p_ellipse, + mask=None, + maskThresh=0.99, +): + """ + Perform elliptic resamplig on each diffraction pattern in a DataCube + Detailed description of the args is found in ``elliptical_resample``. + + NOTE: Only use this function if you need to resample the raw data. + If you only need for Bragg disk positions to be corrected, use the + BraggVector calibration routines, as it is much faster to perform + this on the peak positions than the entire datacube. + """ + + from emdfile import tqdmnd + + for rx, ry in tqdmnd(datacube.R_Nx, datacube.R_Ny): + datacube.data[rx, ry] = elliptical_resample( + datacube.data[rx, ry], p_ellipse, mask, maskThresh + ) + + return datacube + + +def elliptical_resample( + data, + p_ellipse, + mask=None, + maskThresh=0.99, +): + """ + Resamples data with elliptic distortion to correct distortion of the + input pattern. + + Discussion of the elliptical parametrization used can be found in the docstring + for the process.utils.elliptical_coords module. + + Args: + data (2D float array): the data in cartesian coordinates + p_ellipse (5-tuple): specifies (qx0,qy0,a,b,theta), the parameters for the + transformation. These are the same 5 parameters which are outputs + of the elliptical fitting functions in the process.calibration + module, e.g. fit_ellipse_amorphous_ring and fit_ellipse_1D. For + more details, see the process.utils.elliptical_coords module docstring + dr (float): sampling of the (r,phi) coords: the width of the bins in r + dphi (float): sampling of the (r,phi) coords: the width of the bins in phi, + in radians + r_range (number or length 2 list/tuple or None): specifies the sampling of the + (r,theta) coords. Precise behavior which depends on the parameter type: + * if None, autoselects max r value + * if r_range is a number, specifies the maximum r value + * if r_range is a length 2 list/tuple, specifies the min/max r values + mask (2d array of bools): shape must match cartesianData; where mask==False, + ignore these datapoints in making the polarElliptical data array + maskThresh (float): the final data mask is calculated by converting mask (above) + from cartesian to polar elliptical coords. Due to interpolation, this + results in some non-boolean values - this is converted back to a boolean + array by taking polarEllipticalMask = polarTrans(mask) < maskThresh. Cells + where polarTrans is less than 1 (i.e. 
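A sketch of the polar transform on a synthetic circular ring, using a = b = 1 so the ellipse parametrization reduces to ordinary polar coordinates; image size, ring radius, and width are arbitrary:

import numpy as np
from py4DSTEM.utils import cartesian_to_polarelliptical_transform

xx, yy = np.mgrid[0:256, 0:256]
ring = np.exp(-0.5 * ((np.hypot(xx - 128, yy - 128) - 50.0) / 2.0) ** 2)

polar, rr, pp = cartesian_to_polarelliptical_transform(ring, p_ellipse=(128, 128, 1, 1, 0))
# axis 0 runs over phi and axis 1 over r, so the ring collapses to a band near r = 50
profile = polar.mean(axis=0)
print(rr[0, np.argmax(profile)])   # ~50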
has at least one masked NN) should + generally be masked, hence the default value of 0.99. + + Returns: + (3-tuple): A 3-tuple, containing: + + * **resampled_data**: *(2D masked array)* a masked array containing + the data and the data mask, in polarElliptical coordinates + """ + if mask is None: + mask = np.ones_like(data, dtype=bool) + assert data.shape == mask.shape, "Mask and data array shapes must match." + assert len(p_ellipse) == 5, "p_ellipse must have length 5" + + # Expand params + qx0, qy0, a, b, theta = p_ellipse + Nx, Ny = data.shape + + # Get (qx,qy) corresponding to the coordinates distorted by the ellipse + xr, yr = np.mgrid[0:Nx, 0:Ny] + xr0 = xr.astype(np.float_) - qx0 + yr0 = yr.astype(np.float_) - qy0 + xr = xr0 * np.cos(-theta) - yr0 * np.sin(-theta) + yr = xr0 * np.sin(-theta) + yr0 * np.cos(-theta) + qx = qx0 + xr * np.cos(theta) - yr * (b / a) * np.sin(theta) + qy = qy0 + xr * np.sin(theta) + yr * (b / a) * np.cos(theta) + + # qx,qy are now shape (Nx,Ny) arrays, such that (qx[x,y],qy[x,y]) is the point + # in the distorted space corresponding to x,y. We now get the values for the final + # resampled_data array by interpolating values at these coords from the original + # data array. + + transform_mask = (qx > 0) * (qy > 0) * (qx < Nx - 1) * (qy < Ny - 1) + + # Bilinear interpolation + xF = np.floor(qx[transform_mask]) + yF = np.floor(qy[transform_mask]) + dx = qx[transform_mask] - xF + dy = qy[transform_mask] - yF + x_inds = np.vstack((xF, xF + 1, xF, xF + 1)).astype(int) + y_inds = np.vstack((yF, yF, yF + 1, yF + 1)).astype(int) + weights = np.vstack( + ((1 - dx) * (1 - dy), (dx) * (1 - dy), (1 - dx) * (dy), (dx) * (dy)) + ) + transform_mask = transform_mask.ravel() + resampled_data = np.zeros(Nx * Ny) + resampled_data[transform_mask] = np.sum(data[x_inds, y_inds] * weights, axis=0) + resampled_data = np.reshape(resampled_data, (Nx, Ny)) + + # Transform mask + data_mask = np.zeros(Nx * Ny) + data_mask[transform_mask] = np.sum(mask[x_inds, y_inds] * weights, axis=0) + data_mask = np.reshape(data_mask, (Nx, Ny)) + + resampled_data = np.ma.array(data=resampled_data, mask=data_mask < maskThresh) + return resampled_data + + +### Radial integration + + +def radial_elliptical_integral( + ar, + dr, + p_ellipse, + rmax=None, +): + """ + Computes the radial integral of array ar from center (x0,y0) with a step size in r of + dr. + + Args: + ar (2d array): the data + dr (number): the r sampling + p_ellipse (5-tuple): the parameters (x0,y0,a,b,theta) for the ellipse + r_max (float): maximum radial value + + Returns: + (2-tuple): A 2-tuple containing: + + * **rbin_centers**: *(1d array)* the bins centers of the radial integral + * **radial_integral**: *(1d array)* the radial integral + radial_integral (1d array) the radial integral + """ + x0, y0 = p_ellipse[0], p_ellipse[1] + if rmax is None: + rmax = int( + max( + ( + np.hypot(x0, y0), + np.hypot(x0, ar.shape[1] - y0), + np.hypot(ar.shape[0] - x0, y0), + np.hypot(ar.shape[0] - x0, ar.shape[1] - y0), + ) + ) + ) + + polarAr, rr, pp = cartesian_to_polarelliptical_transform( + ar, p_ellipse=p_ellipse, dr=dr, dphi=np.radians(2), r_range=rmax + ) + radial_integral = np.sum(polarAr, axis=0) + rbin_centers = rr[0, :] + return rbin_centers, radial_integral + + +def radial_integral(ar, x0=None, y0=None, dr=0.1, rmax=None): + """ + Computes the radial integral of array ar from center (x0,y0) with a step size in r of dr. 
+ + Args: + ar (2d array): the data + x0,y0 (floats): the origin + dr (number): radial step size + rmax (float): maximum radial dimension + + Returns: + (2-tuple): A 2-tuple containing: + + * **rbin_centers**: *(1d array)* the bins centers of the radial integral + * **radial_integral**: *(1d array)* the radial integral + """ + + # Default values + if x0 is None: + x0 = ar.shape[0] / 2 + if y0 is None: + y0 = ar.shape[1] / 2 + + if rmax is None: + return radial_elliptical_integral(ar, dr, (x0, y0, 1, 1, 0)) + else: + return radial_elliptical_integral(ar, dr, (x0, y0, 1, 1, 0), rmax=rmax) diff --git a/py4DSTEM/utils/ewpc.py b/py4DSTEM/utils/ewpc.py new file mode 100644 index 000000000..d6d525516 --- /dev/null +++ b/py4DSTEM/utils/ewpc.py @@ -0,0 +1,18 @@ +import numpy as np + + +def get_ewpc_filter_function(Q_Nx, Q_Ny): + """ + Returns a function for computing the exit wave power cepstrum of a diffraction + pattern using a Hanning window. This can be passed as the filter_function in the + Bragg disk detection functions (with the probe an array of ones) to find the lattice + vectors by the EWPC method (but be careful as the lengths are now in realspace + units!) See https://arxiv.org/abs/1911.00984 + """ + h = np.hanning(Q_Nx)[:, np.newaxis] * np.hanning(Q_Ny)[np.newaxis, :] + return ( + lambda x: np.abs(np.fft.fftshift(np.fft.fft2(h * np.log(np.maximum(x, 0.01))))) + ** 2 + ) + + diff --git a/py4DSTEM/utils/get_CoM.py b/py4DSTEM/utils/get_CoM.py new file mode 100644 index 000000000..608e0c118 --- /dev/null +++ b/py4DSTEM/utils/get_CoM.py @@ -0,0 +1,31 @@ +import numpy as np +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +def get_CoM(ar, device="cpu", corner_centered=False): + """ + Finds and returns the center of mass of array ar. + If corner_centered is True, uses fftfreq for indices. + """ + if device == "cpu": + xp = np + elif device == "gpu": + xp = cp + + ar = xp.asarray(ar) + nx, ny = ar.shape + + if corner_centered: + ry, rx = xp.meshgrid(xp.fft.fftfreq(ny, 1 / ny), xp.fft.fftfreq(nx, 1 / nx)) + else: + ry, rx = xp.meshgrid(xp.arange(ny), xp.arange(nx)) + + tot_intens = xp.sum(ar) + xCoM = xp.sum(rx * ar) / tot_intens + yCoM = xp.sum(ry * ar) / tot_intens + return xCoM, yCoM + + diff --git a/py4DSTEM/utils/get_maxima.py b/py4DSTEM/utils/get_maxima.py new file mode 100644 index 000000000..5eaf29509 --- /dev/null +++ b/py4DSTEM/utils/get_maxima.py @@ -0,0 +1,329 @@ +import numpy as np +from scipy.ndimage import gaussian_filter +#from scipy.signal import medfilt +from scipy.ndimage import median_filter +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + +from py4DSTEM.utils.multicorr import upsampled_correlation +from py4DSTEM.utils.linear_interpolation import linear_interpolation_2D + + +def get_maxima_1D(ar, sigma=0, minSpacing=0, minRelativeIntensity=0, relativeToPeak=0): + """ + Finds the indices where 1D array ar is a local maximum. + Optional parameters allow blurring the array and filtering the output; + setting each to 0 (default) turns off these functions. + + Parameters + ---------- + ar : 1D array + sigma : number + gaussian blur std to apply to ar before finding maxima + minSpacing : number + if two maxima are found within minSpacing, the dimmer one + is removed + minRelativeIntensity : number + maxima dimmer than minRelativeIntensity compared + to the relativeToPeak'th brightest maximum are removed + relativeToPeak : int + 0=brightest maximum, 1=next brightest, etc. 
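A minimal check of get_CoM above; the expected output is just the mean row and column index of the nonzero block:

import numpy as np
from py4DSTEM.utils import get_CoM

im = np.zeros((64, 64))
im[10:20, 30:40] = 1.0
print(get_CoM(im))   # (14.5, 34.5)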
+ + Returns + ------- + (array of ints): An array of indices where ar is a local maximum, sorted by intensity. + """ + assert len(ar.shape) == 1, "ar must be 1D" + assert isinstance( + relativeToPeak, (int, np.integer) + ), "relativeToPeak must be an int" + if sigma > 0: + ar = gaussian_filter(ar, sigma) + + # Get maxima and intensity arrays + maxima_bool = np.logical_and((ar > np.roll(ar, -1)), (ar >= np.roll(ar, +1))) + x = np.arange(len(ar))[maxima_bool] + intensity = ar[maxima_bool] + + # Sort by intensity + temp_ar = np.array( + [(x, inten) for inten, x in sorted(zip(intensity, x), reverse=True)] + ) + x, intensity = temp_ar[:, 0], temp_ar[:, 1] + + # Remove points which are too close + if minSpacing > 0: + deletemask = np.zeros(len(x), dtype=bool) + for i in range(len(x)): + if not deletemask[i]: + delete = np.abs(x[i] - x) < minSpacing + delete[: i + 1] = False + deletemask = deletemask | delete + x = np.delete(x, deletemask.nonzero()[0]) + intensity = np.delete(intensity, deletemask.nonzero()[0]) + + # Remove points which are too dim + if minRelativeIntensity > 0: + deletemask = intensity / intensity[relativeToPeak] < minRelativeIntensity + x = np.delete(x, deletemask.nonzero()[0]) + intensity = np.delete(intensity, deletemask.nonzero()[0]) + + return x.astype(int) + + + +def filter_2D_maxima( + maxima, + minAbsoluteIntensity=0, + minProminence=0, + prominenceKernelSize=3, + minRelativeIntensity=0, + relativeToPeak=0, + minSpacing=0, + edgeBoundary=1, + maxNumPeaks=1, + _ar = None, +): + """ + Parameters + ---------- + maxima : a numpy structured array with fields 'x', 'y', 'intensity' + minAbsoluteIntensity : number + delete counts with intensity below this value + minProminence : number + delete counts whose intensity above the local median is + below this value. 
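A small example for the 1D peak finder above, using two well-separated Gaussian bumps of different heights; positions and widths are arbitrary:

import numpy as np
from py4DSTEM.utils import get_maxima_1D

x = np.arange(200)
sig = np.exp(-0.5 * ((x - 60) / 4) ** 2) + 0.6 * np.exp(-0.5 * ((x - 140) / 4) ** 2)
print(get_maxima_1D(sig, minSpacing=5))   # [ 60 140 ], sorted brightest first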
_ar must be passed with this arg + prominenceKernelSize : odd number + kernel size for prominence local median comparison + minRelativeIntensity : number + delete counts with intensity below this value times + the intensity of the i'th peak, where i is given by `relativeToPeak` + relativeToPeak : int + 0=brightest peak, 1=second brightest peak, etc + minSpacing : number + if two peaks are within this euclidean distance from one + another, delete the less intense of the two + edgeBoundary : number + delete peaks within this distance of the image edge + maxNumPeaks : int + maximum number of peaks to return + _ar : array or None + if minProminence is passed, this must be the array + + Returns + ------- + a numpy structured array with fields 'x', 'y', 'intensity' + """ + + # Remove maxima which are too dim + if minAbsoluteIntensity > 0: + deletemask = maxima["intensity"] < minAbsoluteIntensity + maxima = maxima[~deletemask] + + # Remove maxima which are too dim relative to local median + if minProminence > 0: + assert(_ar is not None), "Array for median filter wasn't passed" + #med = medfilt(_ar,prominenceKernelSize) + assert(prominenceKernelSize%2==1), f"prominenceKernelSize must be odd, not {prominenceKernelSize}" + pks = prominenceKernelSize + footprint = np.ones((pks,pks), dtype=bool) + footprint[1:-1,1:-1] = 0 + med = median_filter( + _ar, + footprint = footprint + ) + compare = maxima["intensity"] - med[ + maxima['x'].astype(int), + maxima['y'].astype(int) + ] + deletemask = compare < minProminence + maxima = maxima[~deletemask] + + # Remove maxima which are too dim, compared to the n-th brightest + if (minRelativeIntensity > 0) & (len(maxima) > relativeToPeak): + assert isinstance(relativeToPeak, (int, np.integer)) + deletemask = ( + maxima["intensity"] / maxima["intensity"][relativeToPeak] + < minRelativeIntensity + ) + maxima = maxima[~deletemask] + + # Remove maxima which are too close + if minSpacing > 0: + deletemask = np.zeros(len(maxima), dtype=bool) + for i in range(len(maxima)): + if deletemask[i] == False: # noqa: E712 + tooClose = ( + (maxima["x"] - maxima["x"][i]) ** 2 + + (maxima["y"] - maxima["y"][i]) ** 2 + ) < minSpacing**2 + tooClose[: i + 1] = False + deletemask[tooClose] = True + maxima = maxima[~deletemask] + + # Remove maxima in excess of maxNumPeaks + if maxNumPeaks is not None: + if len(maxima) > maxNumPeaks: + maxima = maxima[:maxNumPeaks] + + return maxima + + + +def get_maxima_2D( + ar, + subpixel="poly", + upsample_factor=16, + sigma=0, + minAbsoluteIntensity=0, + minProminence=0, + prominenceKernelSize=3, + minRelativeIntensity=0, + relativeToPeak=0, + minSpacing=0, + edgeBoundary=1, + maxNumPeaks=1, + _ar_FT=None, +): + """ + Finds the maximal points of a 2D array. + + Parameters + ---------- + ar : array + the 2D array + subpixel : str + specifies the subpixel resolution algorithm to use. + must be in ('pixel','poly','multicorr'), which correspond + to pixel resolution, subpixel resolution by fitting a + parabola, and subpixel resultion by Fourier upsampling. 
+ upsample_factor : int, power of 2 + the upsampling factor for the 'multicorr' algorithm + sigma : number + if >0, applies a gaussian filter + maxNumPeaks : int + the maximum number of maxima to return + minAbsoluteIntensity : number + minimum intensity threshold + minProminence : number + intensity threshold, in absolute units, above the local + median + prominenceKernelSize : odd number + kernel size for prominence local median comparison + minRelativeIntensity : number + intensity threshold relative to the N'th brightest peak, + where N is given by `relativeToPeak` + relativeToPeak : int + 0=brightest peaks, 1=second brightest peak, etc. + minSpacing : number + Minimum permissible spacing between peaks + edgeBoundary : number + Peaks within this distance of the image edge are deleted + maxNumPeaks: filtering applied + after maximum detection and before subpixel refinement + _ar_FT (complex array) if 'multicorr' is used and this is not + None, uses this argument as the Fourier transform of `ar`, + instead of recomputing it + + Returns + ------- + a structured array with fields 'x','y','intensity' + """ + subpixel_modes = ("pixel", "poly", "multicorr") + er = f"Unrecognized subpixel option {subpixel}. Must be in {subpixel_modes}" + assert subpixel in subpixel_modes, er + + # gaussian filtering + ar = ar if sigma <= 0 else gaussian_filter(ar, sigma) + + # local pixelwise maxima + maxima_bool = ( + (ar >= np.roll(ar, (-1, 0), axis=(0, 1))) + & (ar > np.roll(ar, (1, 0), axis=(0, 1))) + & (ar >= np.roll(ar, (0, -1), axis=(0, 1))) + & (ar > np.roll(ar, (0, 1), axis=(0, 1))) + & (ar >= np.roll(ar, (-1, -1), axis=(0, 1))) + & (ar > np.roll(ar, (-1, 1), axis=(0, 1))) + & (ar >= np.roll(ar, (1, -1), axis=(0, 1))) + & (ar > np.roll(ar, (1, 1), axis=(0, 1))) + ) + + # remove edges + assert isinstance(edgeBoundary, (int, np.integer)) + if edgeBoundary < 1: + edgeBoundary = 1 + maxima_bool[:edgeBoundary, :] = False + maxima_bool[-edgeBoundary:, :] = False + maxima_bool[:, :edgeBoundary] = False + maxima_bool[:, -edgeBoundary:] = False + + # get indices + # sort by intensity + maxima_x, maxima_y = np.nonzero(maxima_bool) + dtype = np.dtype([("x", float), ("y", float), ("intensity", float)]) + maxima = np.zeros(len(maxima_x), dtype=dtype) + maxima["x"] = maxima_x + maxima["y"] = maxima_y + maxima["intensity"] = ar[maxima_x, maxima_y] + maxima = np.sort(maxima, order="intensity")[::-1] + + if len(maxima) == 0: + return maxima + + # filter + maxima = filter_2D_maxima( + maxima, + minAbsoluteIntensity=minAbsoluteIntensity, + minProminence=minProminence, + prominenceKernelSize=prominenceKernelSize, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + _ar = ar + ) + + if subpixel == "pixel": + return maxima + + # Parabolic subpixel refinement + for i in range(len(maxima)): + Ix1_ = ar[int(maxima["x"][i]) - 1, int(maxima["y"][i])].astype(np.float64) + Ix0 = ar[int(maxima["x"][i]), int(maxima["y"][i])].astype(np.float64) + Ix1 = ar[int(maxima["x"][i]) + 1, int(maxima["y"][i])].astype(np.float64) + Iy1_ = ar[int(maxima["x"][i]), int(maxima["y"][i]) - 1].astype(np.float64) + Iy0 = ar[int(maxima["x"][i]), int(maxima["y"][i])].astype(np.float64) + Iy1 = ar[int(maxima["x"][i]), int(maxima["y"][i]) + 1].astype(np.float64) + deltax = (Ix1 - Ix1_) / (4 * Ix0 - 2 * Ix1 - 2 * Ix1_) + deltay = (Iy1 - Iy1_) / (4 * Iy0 - 2 * Iy1 - 2 * Iy1_) + maxima["x"][i] += deltax + maxima["y"][i] += deltay + 
maxima["intensity"][i] = linear_interpolation_2D( + ar, maxima["x"][i], maxima["y"][i] + ) + + if subpixel == "poly": + return maxima + + # Fourier upsampling + if _ar_FT is None: + _ar_FT = np.fft.fft2(ar) + for ipeak in range(len(maxima["x"])): + xyShift = np.array((maxima["x"][ipeak], maxima["y"][ipeak])) + # we actually have to lose some precision and go down to half-pixel + # accuracy for multicorr + xyShift[0] = np.round(xyShift[0] * 2) / 2 + xyShift[1] = np.round(xyShift[1] * 2) / 2 + + subShift = upsampled_correlation(_ar_FT, upsample_factor, xyShift) + maxima["x"][ipeak] = subShift[0] + maxima["y"][ipeak] = subShift[1] + + maxima = np.sort(maxima, order="intensity")[::-1] + return maxima + + + diff --git a/py4DSTEM/utils/get_shifted_ar.py b/py4DSTEM/utils/get_shifted_ar.py new file mode 100644 index 000000000..a0f0c44ba --- /dev/null +++ b/py4DSTEM/utils/get_shifted_ar.py @@ -0,0 +1,81 @@ +import numpy as np +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + + +def get_shifted_ar(ar, xshift, yshift, periodic=True, bilinear=False, device="cpu"): + """ + Shifts array ar by the shift vector (xshift,yshift), using the either + the Fourier shift theorem (i.e. with sinc interpolation), or bilinear + resampling. Boundary conditions can be periodic or not. + + Parameters + ---------- + ar : float + input array + xshift : float + shift along axis 0 (x) in pixels + yshift : float + shift along axis 1 (y) in pixels + periodic : bool + flag for periodic boundary conditions + bilinear : bool + flag for bilinear image shifts + device : str + calculation device will be perfomed on. Must be 'cpu' or 'gpu' + + Returns + ------- + the shifted array + """ + if device == "cpu": + xp = np + + elif device == "gpu": + xp = cp + + ar = xp.asarray(ar) + + # Apply image shift + if bilinear is False: + nx, ny = xp.shape(ar) + qx, qy = make_Fourier_coords2D(nx, ny, 1) + qx = xp.asarray(qx) + qy = xp.asarray(qy) + + w = xp.exp(-(2j * xp.pi) * ((yshift * qy) + (xshift * qx))) + shifted_ar = xp.real(xp.fft.ifft2((xp.fft.fft2(ar)) * w)) + + else: + xF = xp.floor(xshift).astype(int).item() + yF = xp.floor(yshift).astype(int).item() + wx = xshift - xF + wy = yshift - yF + + shifted_ar = ( + xp.roll(ar, (xF, yF), axis=(0, 1)) * ((1 - wx) * (1 - wy)) + + xp.roll(ar, (xF + 1, yF), axis=(0, 1)) * ((wx) * (1 - wy)) + + xp.roll(ar, (xF, yF + 1), axis=(0, 1)) * ((1 - wx) * (wy)) + + xp.roll(ar, (xF + 1, yF + 1), axis=(0, 1)) * ((wx) * (wy)) + ) + + if periodic is False: + # Rounded coordinates for boundaries + xR = (xp.round(xshift)).astype(int) + yR = (xp.round(yshift)).astype(int) + + if xR > 0: + shifted_ar[0:xR, :] = 0 + elif xR < 0: + shifted_ar[xR:, :] = 0 + if yR > 0: + shifted_ar[:, 0:yR] = 0 + elif yR < 0: + shifted_ar[:, yR:] = 0 + + return shifted_ar + + diff --git a/py4DSTEM/utils/linear_interpolation.py b/py4DSTEM/utils/linear_interpolation.py new file mode 100644 index 000000000..53c27b52c --- /dev/null +++ b/py4DSTEM/utils/linear_interpolation.py @@ -0,0 +1,58 @@ +import numpy as np +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + + +def linear_interpolation_1D(ar, x): + """ + Calculates the 1D linear interpolation of array ar at position x using the two + nearest elements. + """ + x0, x1 = int(np.floor(x)), int(np.ceil(x)) + dx = x - x0 + return (1 - dx) * ar[x0] + dx * ar[x1] + + +def linear_interpolation_2D(ar, x, y): + """ + Calculates the 2D linear interpolation of array ar at position x,y using the four + nearest array elements. 
+ """ + x0, x1 = int(np.floor(x)), int(np.ceil(x)) + y0, y1 = int(np.floor(y)), int(np.ceil(y)) + dx = x - x0 + dy = y - y0 + return ( + (1 - dx) * (1 - dy) * ar[x0, y0] + + (1 - dx) * dy * ar[x0, y1] + + dx * (1 - dy) * ar[x1, y0] + + dx * dy * ar[x1, y1] + ) + + +def add_to_2D_array_from_floats(ar, x, y, I): + """ + Adds the values I to array ar, distributing the value between the four pixels nearest + (x,y) using linear interpolation. Inputs (x,y,I) may be floats or arrays of floats. + + Note that if the same [x,y] coordinate appears more than once in the input array, + only the *final* value of I at that coordinate will get added. + """ + Nx, Ny = ar.shape + x0, x1 = (np.floor(x)).astype(int), (np.ceil(x)).astype(int) + y0, y1 = (np.floor(y)).astype(int), (np.ceil(y)).astype(int) + mask = np.logical_and( + np.logical_and(np.logical_and((x0 >= 0), (y0 >= 0)), (x1 < Nx)), (y1 < Ny) + ) + dx = x - x0 + dy = y - y0 + ar[x0[mask], y0[mask]] += (1 - dx[mask]) * (1 - dy[mask]) * I[mask] + ar[x0[mask], y1[mask]] += (1 - dx[mask]) * (dy[mask]) * I[mask] + ar[x1[mask], y0[mask]] += (dx[mask]) * (1 - dy[mask]) * I[mask] + ar[x1[mask], y1[mask]] += (dx[mask]) * (dy[mask]) * I[mask] + return ar + + diff --git a/py4DSTEM/utils/make_fourier_coords.py b/py4DSTEM/utils/make_fourier_coords.py new file mode 100644 index 000000000..b18f34ef9 --- /dev/null +++ b/py4DSTEM/utils/make_fourier_coords.py @@ -0,0 +1,48 @@ +import numpy as np +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + + +def make_Fourier_coords2D(Nx, Ny, pixelSize=1): + """ + Generates Fourier coordinates for a (Nx,Ny)-shaped 2D array. + Specifying the pixelSize argument sets a unit size. + """ + if hasattr(pixelSize, "__len__"): + assert len(pixelSize) == 2, "pixelSize must either be a scalar or have length 2" + pixelSize_x = pixelSize[0] + pixelSize_y = pixelSize[1] + else: + pixelSize_x = pixelSize + pixelSize_y = pixelSize + + qx = np.fft.fftfreq(Nx, pixelSize_x) + qy = np.fft.fftfreq(Ny, pixelSize_y) + qy, qx = np.meshgrid(qy, qx) + return qx, qy + + + + +def get_qx_qy_1d(M, dx=[1, 1], fft_shifted=False): + """ + Generates 1D Fourier coordinates for a (Nx,Ny)-shaped 2D array. + Specifying the dx argument sets a unit size. + + Args: + M: (2,) shape of the returned array + dx: (2,) tuple, pixel size + fft_shifted: True if result should be fft_shifted to have the origin in the center of the array + """ + qxa = np.fft.fftfreq(M[0], dx[0]) + qya = np.fft.fftfreq(M[1], dx[1]) + if fft_shifted: + qxa = np.fft.fftshift(qxa) + qya = np.fft.fftshift(qya) + return qxa, qya + + + diff --git a/py4DSTEM/utils/masks.py b/py4DSTEM/utils/masks.py new file mode 100644 index 000000000..c6896516b --- /dev/null +++ b/py4DSTEM/utils/masks.py @@ -0,0 +1,120 @@ +# Functions for generating masks + +import numpy as np +from scipy.ndimage import binary_dilation + + + +def make_circular_mask(shape, qxy0, radius): + """ + Create a hard circular mask, for use in DPC integration or + or to use as a filter in diffraction or real space. + + Args: + shape (2-tuple of ints) image size, in pixels + qxy0 (2-tuple of floats) center coordinates, in pixels. Must be in (row, column) format. 
+ radius (float) radius of mask, in pixels + + Returns: + mask (2D boolean array) the mask + + """ + # coordinates + qx = np.arange(shape[0]) - qxy0[0] + qy = np.arange(shape[1]) - qxy0[1] + [qya, qxa] = np.meshgrid(qy, qx) + + # return circular mask + return qxa**2 + qya**2 < radius**2 + + + +def get_beamstop_mask(dp, qx0, qy0, theta, dtheta=1, w=10, r=10): + """ + Generates a beamstop shaped mask. + + Args: + dp (2d array): a diffraction pattern + qx0,qy0 (numbers): the center position of the beamstop + theta (number): the orientation of the beamstop, in degrees + dtheta (number): angular span of the wedge representing the beamstop, in degrees + w (integer): half the width of the beamstop arm, in pixels + r (number): the radius of a circle at the end of the beamstop, in pixels + + Returns: + (2d boolean array): the mask + """ + # Handle inputs + theta = np.mod(np.radians(theta), 2 * np.pi) + dtheta = np.abs(np.radians(dtheta)) + + # Get a meshgrid + Q_Nx, Q_Ny = dp.shape + qyy, qxx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx)) + qyy, qxx = qyy - qy0, qxx - qx0 + + # wedge handles + if dtheta > 0: + qzz = qxx + qyy * 1j + phi = np.mod(np.angle(qzz), 2 * np.pi) + # Handle the branch cut in the complex plane + if theta - dtheta < 0: + phi, theta = np.mod(phi + dtheta, 2 * np.pi), theta + dtheta + elif theta + dtheta > 2 * np.pi: + phi, theta = np.mod(phi - dtheta, 2 * np.pi), theta - dtheta + mask1 = np.abs(phi - theta) < dtheta + if w > 0: + mask1 = binary_dilation(mask1, iterations=w) + + # straight handles + else: + pass + + # circle mask + qrr = np.hypot(qxx, qyy) + mask2 = qrr < r + + # combine masks + mask = np.logical_or(mask1, mask2) + + return mask + + +def sector_mask(shape, centre, radius, angle_range=(0, 360)): + """ + Return a boolean mask for a circular sector. The start/stop angles in + `angle_range` should be given in clockwise order. 
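A quick check of the circular mask, plus a beamstop mask with an arbitrary wedge; note that the wedge branch above requires dtheta > 0, since the dtheta = 0 path leaves mask1 unset. The shape, center, angles, and radii are arbitrary:

import numpy as np
from py4DSTEM.utils import make_circular_mask, get_beamstop_mask

mask = make_circular_mask((128, 128), (64.0, 64.0), radius=20)
print(mask.sum())   # ~ pi * 20**2 ~ 1257 pixels

dp = np.random.rand(128, 128)
bs = get_beamstop_mask(dp, qx0=64, qy0=64, theta=30, dtheta=2, w=4, r=12)
print(bs.dtype, bs.mean())   # boolean mask covering a small fraction of the pattern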
+ + Args: + shape: 2D shape of the mask + centre: 2D center of the circular sector + radius: radius of the circular mask + angle_range: angular range of the circular mask + """ + x, y = np.ogrid[: shape[0], : shape[1]] + cx, cy = centre + tmin, tmax = np.deg2rad(angle_range) + + # ensure stop angle > start angle + if tmax < tmin: + tmax += 2 * np.pi + + # convert cartesian --> polar coordinates + r2 = (x - cx) * (x - cx) + (y - cy) * (y - cy) + theta = np.arctan2(x - cx, y - cy) - tmin + + # wrap angles between 0 and 2*pi + theta %= 2 * np.pi + + # circular mask + circmask = r2 <= radius * radius + + # print 'radius - ', radius + + # angular mask + anglemask = theta < (tmax - tmin) + + return circmask * anglemask + + + diff --git a/py4DSTEM/utils/multicorr.py b/py4DSTEM/utils/multicorr.py new file mode 100644 index 000000000..58c5fc051 --- /dev/null +++ b/py4DSTEM/utils/multicorr.py @@ -0,0 +1,200 @@ +""" +loosely based on multicorr.py found at: +https://github.com/ercius/openNCEM/blob/master/ncempy/algo/multicorr.py + +modified by SEZ, May 2019 to integrate with py4DSTEM utility functions + * rewrote upsampleFFT (previously did not work correctly) + * modified upsampled_correlation to accept xyShift, the point around which to + upsample the DFT + * eliminated the factor-2 FFT upsample step in favor of using parabolic + for first-pass subpixel (since parabolic is so fast) + * rewrote the matrix multiply DFT to be more pythonic +""" + +import numpy as np + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + + +def upsampled_correlation(imageCorr, upsampleFactor, xyShift, device="cpu"): + """ + Refine the correlation peak of imageCorr around xyShift by DFT upsampling. + + There are two approaches to Fourier upsampling for subpixel refinement: (a) one + can pad an (appropriately shifted) FFT with zeros and take the inverse transform, + or (b) one can compute the DFT by matrix multiplication using modified + transformation matrices. The former approach is straightforward but requires + performing the FFT algorithm (which is fast) on very large data. The latter method + trades one speedup for a slowdown elsewhere: the matrix multiply steps are expensive + but we operate on smaller matrices. Since we are only interested in a very small + region of the FT around a peak of interest, we use the latter method to get + a substantial speedup and enormous decrease in memory requirement. This + "DFT upsampling" approach computes the transformation matrices for the matrix- + multiply DFT around a small 1.5px wide region in the original `imageCorr`. + + Following the matrix multiply DFT we use parabolic subpixel fitting to + get even more precision! (below 1/upsampleFactor pixels) + + NOTE: previous versions of multiCorr operated in two steps: using the zero- + padding upsample method for a first-pass factor-2 upsampling, followed by the + DFT upsampling (at whatever user-specified factor). I have implemented it + differently, to better support iterating over multiple peaks. **The DFT is always + upsampled around xyShift, which MUST be specified to HALF-PIXEL precision + (no more, no less) to replicate the behavior of the factor-2 step.** + (It is possible to refactor this so that peak detection is done on a Fourier + upsampled image rather than using the parabolic subpixel and rounding as now... + I like keeping it this way because all of the parameters and logic will be identical + to the other subpixel methods.) 
+
+
+    Args:
+        imageCorr (complex valued ndarray):
+            Complex product of the FFTs of the two images to be registered
+            i.e. m = np.fft.fft2(DP) * probe_kernel_FT;
+            imageCorr = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m))
+        upsampleFactor (int):
+            Upsampling factor. Must be greater than 2. (To do upsampling
+            with factor 2, use upsampleFFT, which is faster.)
+        xyShift:
+            Location in original image coordinates around which to upsample the
+            FT. This should be given to exactly half-pixel precision to
+            replicate the initial FFT step that this implementation skips
+
+    Returns:
+        (2-element np array): Refined location of the peak in image coordinates.
+    """
+
+    if device == "cpu":
+        xp = np
+    elif device == "gpu":
+        xp = cp
+
+    assert upsampleFactor > 2
+
+    xyShift[0] = xp.round(xyShift[0] * upsampleFactor) / upsampleFactor
+    xyShift[1] = xp.round(xyShift[1] * upsampleFactor) / upsampleFactor
+
+    globalShift = xp.fix(xp.ceil(upsampleFactor * 1.5) / 2)
+
+    upsampleCenter = xp.asarray(globalShift - upsampleFactor * xyShift)
+
+    imageCorrUpsample = xp.conj(
+        dftUpsample(xp.conj(imageCorr), upsampleFactor, upsampleCenter, device=device)
+    )
+
+    xySubShift = xp.asarray(
+        xp.unravel_index(imageCorrUpsample.argmax(), imageCorrUpsample.shape)
+    )
+
+    # add a subpixel shift via parabolic fitting
+    try:
+        icc = xp.real(
+            imageCorrUpsample[
+                xySubShift[0] - 1 : xySubShift[0] + 2,
+                xySubShift[1] - 1 : xySubShift[1] + 2,
+            ]
+        )
+        dx = (icc[2, 1] - icc[0, 1]) / (4 * icc[1, 1] - 2 * icc[2, 1] - 2 * icc[0, 1])
+        dy = (icc[1, 2] - icc[1, 0]) / (4 * icc[1, 1] - 2 * icc[1, 2] - 2 * icc[1, 0])
+    except IndexError:
+        # the peak sits at the edge of the upsampled array, so the full 3x3
+        # neighborhood needed for the parabolic fit does not exist
+        dx, dy = 0, 0
+
+    xySubShift = xySubShift - globalShift
+
+    xyShift = xyShift + (xySubShift + xp.array([dx, dy])) / upsampleFactor
+
+    return xyShift
+
+
+def upsampleFFT(cc, device="cpu"):
+    """
+    Zero-padding FFT upsampling. Returns the real IFFT of the input with 2x
+    upsampling. This may have an error for matrices with an odd size. Takes
+    a complex np array as input.
+    """
+    if device == "cpu":
+        xp = np
+    elif device == "gpu":
+        xp = cp
+
+    sz = cc.shape
+    ups = xp.zeros((sz[0] * 2, sz[1] * 2), dtype=complex)
+
+    ups[: int(np.ceil(sz[0] / 2)), : int(np.ceil(sz[1] / 2))] = cc[
+        : int(np.ceil(sz[0] / 2)), : int(np.ceil(sz[1] / 2))
+    ]
+    ups[-int(np.ceil(sz[0] / 2)) :, : int(np.ceil(sz[1] / 2))] = cc[
+        -int(np.ceil(sz[0] / 2)) :, : int(np.ceil(sz[1] / 2))
+    ]
+    ups[: int(np.ceil(sz[0] / 2)), -int(np.ceil(sz[1] / 2)) :] = cc[
+        : int(np.ceil(sz[0] / 2)), -int(np.ceil(sz[1] / 2)) :
+    ]
+    ups[-int(np.ceil(sz[0] / 2)) :, -int(np.ceil(sz[1] / 2)) :] = cc[
+        -int(np.ceil(sz[0] / 2)) :, -int(np.ceil(sz[1] / 2)) :
+    ]
+
+    return xp.real(xp.fft.ifft2(ups))
+
+
+def dftUpsample(imageCorr, upsampleFactor, xyShift, device="cpu"):
+    """
+    This performs a matrix multiply DFT around a small neighboring region of the initial
+    correlation peak. By using the matrix multiply DFT to do the Fourier upsampling, the
+    efficiency is greatly improved. This is adapted from the subfunction dftups found in
+    the dftregistration function on the Matlab File Exchange.
+
+    https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation
+
+    The matrix multiplication DFT is from:
+
+    Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, "Efficient subpixel
+    image registration algorithms," Opt. Lett. 33, 156-158 (2008).
+    http://www.sciencedirect.com/science/article/pii/S0045790612000778
+
+    Args:
+        imageCorr (complex valued ndarray):
+            Correlation image between two images in Fourier space.
+        upsampleFactor (int):
+            Scalar integer of how much to upsample.
+        xyShift (list of 2 floats):
+            Coordinates in the upsampled grid around which to upsample.
+            These must be given to single-pixel precision in the upsampled grid.
+
+    Returns:
+        (ndarray):
+            Upsampled image from region around correlation peak.
+    """
+    if device == "cpu":
+        xp = np
+    elif device == "gpu":
+        xp = cp
+
+    imageSize = imageCorr.shape
+    pixelRadius = 1.5
+    numRow = np.ceil(pixelRadius * upsampleFactor)
+    numCol = numRow
+
+    colKern = xp.exp(
+        (-1j * 2 * np.pi / (imageSize[1] * upsampleFactor))
+        * xp.outer(
+            (xp.fft.ifftshift((xp.arange(imageSize[1]))) - xp.floor(imageSize[1] / 2)),
+            (xp.arange(numCol) - xyShift[1]),
+        )
+    )
+
+    rowKern = xp.exp(
+        (-1j * 2 * np.pi / (imageSize[0] * upsampleFactor))
+        * xp.outer(
+            (xp.arange(numRow) - xyShift[0]),
+            (xp.fft.ifftshift(xp.arange(imageSize[0])) - xp.floor(imageSize[0] / 2)),
+        )
+    )
+
+    imageUpsample = xp.real(rowKern @ imageCorr @ colKern)
+    return imageUpsample
diff --git a/py4DSTEM/utils/radial_reduction.py b/py4DSTEM/utils/radial_reduction.py
new file mode 100644
index 000000000..34b48c43e
--- /dev/null
+++ b/py4DSTEM/utils/radial_reduction.py
@@ -0,0 +1,35 @@
+import numpy as np
+
+
+def radial_reduction(ar, x0, y0, binsize=1, fn=np.mean, coords=None):
+    """
+    Evaluate a reduction function on pixels within annular rings centered on (x0,y0),
+    with a ring width of binsize.
+
+    By default, returns the mean value of pixels within each annulus.
+    Some other useful reductions include: np.sum, np.std, np.count_nonzero, np.median, ...
+
+    When running in a loop, pre-compute the pixel coordinates and pass them in
+    for improved performance, like so:
+        coords = np.mgrid[0:ar.shape[0], 0:ar.shape[1]]
+        radial_sums = radial_reduction(ar, x0, y0, coords=coords)
+    """
+    # coords is an ndarray, so test against None rather than using its truth value
+    qx, qy = coords if coords is not None else np.mgrid[0 : ar.shape[0], 0 : ar.shape[1]]
+
+    r = (
+        np.floor(np.hypot(qx - x0, qy - y0).ravel() / binsize).astype(np.int64)
+        * binsize
+    )
+    edges = np.cumsum(np.bincount(r)[::binsize])
+    slices = [slice(0, edges[0])] + [
+        slice(edges[i], edges[i + 1]) for i in range(len(edges) - 1)
+    ]
+    rargsort = np.argsort(r)
+    sorted_ar = ar.ravel()[rargsort]
+    reductions = np.array([fn(sorted_ar[s]) for s in slices])
+
+    return reductions
+
+
diff --git a/py4DSTEM/utils/resample.py b/py4DSTEM/utils/resample.py
new file mode 100644
index 000000000..55ecdcc4e
--- /dev/null
+++ b/py4DSTEM/utils/resample.py
@@ -0,0 +1,244 @@
+import numpy as np
+from emdfile import tqdmnd  # used by the 4D resampling loop below
+
+try:
+    import cupy as cp
+except (ModuleNotFoundError, ImportError):
+    cp = np
+
+
+def fourier_resample(
+    array,
+    scale=None,
+    output_size=None,
+    force_nonnegative=False,
+    bandlimit_nyquist=None,
+    bandlimit_power=2,
+    dtype=np.float32,
+):
+    """
+    Resize a 2D array along any dimension, using Fourier interpolation / extrapolation.
+    For 4D input arrays, only the final two axes can be resized.
+
+    The scaling of the array can be specified by passing either `scale`, which sets
+    the scaling factor along both axes to be scaled; or by passing `output_size`,
+    which specifies the final dimensions of the scaled axes (and allows for different
+    scaling along the x,y or kx,ky axes.)
+
+    Parameters
+    ----------
+    array : 2D/4D numpy array
+        Input array, or 4D stack of arrays, to be resized.
+ scale : float + scalar value giving the scaling factor for all dimensions + output_size : 2-tuple of ints + two values giving either the (x,y) output size for 2D, or (kx,ky) for 4D + force_nonnegative : bool + Force all outputs to be nonnegative, after filtering + bandlimit_nyquist : float + Gaussian filter information limit in Nyquist units (0.5 max in both directions) + bandlimit_power : float + Gaussian filter power law scaling (higher is sharper) + dtype : numpy dtype + datatype for binned array. default is single precision float + + Returns + ------- + the resized array (2D/4D numpy array) + """ + + # Verify input is 2D or 4D + if np.size(array.shape) != 2 and np.size(array.shape) != 4: + raise Exception( + "Function does not support arrays with " + + str(np.size(array.shape)) + + " dimensions" + ) + + # Get input size from last 2 dimensions + input__size = array.shape[-2:] + + if scale is not None: + assert ( + output_size is None + ), "Cannot specify both a scaling factor and output size" + assert np.size(scale) == 1, "scale should be a single value" + scale = np.asarray(scale) + output_size = (input__size * scale).astype("intp") + else: + assert scale is None, "Cannot specify both a scaling factor and output size" + assert np.size(output_size) == 2, "output_size must contain two values" + output_size = np.asarray(output_size) + + scale_output = np.prod(output_size) / np.prod(input__size) + + if bandlimit_nyquist is not None: + kx = np.fft.fftfreq(output_size[0]) + ky = np.fft.fftfreq(output_size[1]) + k2 = kx[:, None] ** 2 + ky[None, :] ** 2 + # Gaussian filter + k_filt = np.exp( + (k2 ** (bandlimit_power / 2)) / (-2 * bandlimit_nyquist**bandlimit_power) + ) + + # generate slices + # named as {dimension}_{corner}_{in_/out}, + # where corner is ul, ur, ll, lr for {upper/lower}{left/right} + + # x slices + if output_size[0] > input__size[0]: + # x dimension increases + x0 = int((input__size[0] + 1) // 2) + x1 = int(input__size[0] // 2) + + x_ul_out = slice(0, x0) + x_ul_in_ = slice(0, x0) + + x_ll_out = slice(0 - x1 + output_size[0], output_size[0]) + x_ll_in_ = slice(0 - x1 + input__size[0], input__size[0]) + + x_ur_out = slice(0, x0) + x_ur_in_ = slice(0, x0) + + x_lr_out = slice(0 - x1 + output_size[0], output_size[0]) + x_lr_in_ = slice(0 - x1 + input__size[0], input__size[0]) + + elif output_size[0] < input__size[0]: + # x dimension decreases + x0 = int((output_size[0] + 1) // 2) + x1 = int(output_size[0] // 2) + + x_ul_out = slice(0, x0) + x_ul_in_ = slice(0, x0) + + x_ll_out = slice(0 - x1 + output_size[0], output_size[0]) + x_ll_in_ = slice(0 - x1 + input__size[0], input__size[0]) + + x_ur_out = slice(0, x0) + x_ur_in_ = slice(0, x0) + + x_lr_out = slice(0 - x1 + output_size[0], output_size[0]) + x_lr_in_ = slice(0 - x1 + input__size[0], input__size[0]) + + else: + # x dimension does not change + x_ul_out = slice(None) + x_ul_in_ = slice(None) + + x_ll_out = slice(None) + x_ll_in_ = slice(None) + + x_ur_out = slice(None) + x_ur_in_ = slice(None) + + x_lr_out = slice(None) + x_lr_in_ = slice(None) + + # y slices + if output_size[1] > input__size[1]: + # y increases + y0 = int((input__size[1] + 1) // 2) + y1 = int(input__size[1] // 2) + + y_ul_out = slice(0, y0) + y_ul_in_ = slice(0, y0) + + y_ll_out = slice(0, y0) + y_ll_in_ = slice(0, y0) + + y_ur_out = slice(0 - y1 + output_size[1], output_size[1]) + y_ur_in_ = slice(0 - y1 + input__size[1], input__size[1]) + + y_lr_out = slice(0 - y1 + output_size[1], output_size[1]) + y_lr_in_ = slice(0 - y1 + input__size[1], 
input__size[1]) + + elif output_size[1] < input__size[1]: + # y decreases + y0 = int((output_size[1] + 1) // 2) + y1 = int(output_size[1] // 2) + + y_ul_out = slice(0, y0) + y_ul_in_ = slice(0, y0) + + y_ll_out = slice(0, y0) + y_ll_in_ = slice(0, y0) + + y_ur_out = slice(0 - y1 + output_size[1], output_size[1]) + y_ur_in_ = slice(0 - y1 + input__size[1], input__size[1]) + + y_lr_out = slice(0 - y1 + output_size[1], output_size[1]) + y_lr_in_ = slice(0 - y1 + input__size[1], input__size[1]) + + else: + # y dimension does not change + y_ul_out = slice(None) + y_ul_in_ = slice(None) + + y_ll_out = slice(None) + y_ll_in_ = slice(None) + + y_ur_out = slice(None) + y_ur_in_ = slice(None) + + y_lr_out = slice(None) + y_lr_in_ = slice(None) + + if len(array.shape) == 2: + # image array + array_resize = np.zeros(output_size, dtype=np.complex64) + array_fft = np.fft.fft2(array) + + # copy each quadrant into the resize array + array_resize[x_ul_out, y_ul_out] = array_fft[x_ul_in_, y_ul_in_] + array_resize[x_ll_out, y_ll_out] = array_fft[x_ll_in_, y_ll_in_] + array_resize[x_ur_out, y_ur_out] = array_fft[x_ur_in_, y_ur_in_] + array_resize[x_lr_out, y_lr_out] = array_fft[x_lr_in_, y_lr_in_] + + # Band limit if needed + if bandlimit_nyquist is not None: + array_resize *= k_filt + + # Back to real space + array_resize = np.real(np.fft.ifft2(array_resize)).astype(dtype) + + elif len(array.shape) == 4: + # This case is the same as the 2D case, but loops over the probe index arrays + + # init arrays + array_resize = np.zeros((*array.shape[:2], *output_size), dtype) + array_fft = np.zeros(input__size, dtype=np.complex64) + array_output = np.zeros(output_size, dtype=np.complex64) + + for Rx, Ry in tqdmnd( + array.shape[0], + array.shape[1], + desc="Resampling 4D datacube", + unit="DP", + unit_scale=True, + ): + array_fft[:, :] = np.fft.fft2(array[Rx, Ry, :, :]) + array_output[:, :] = 0 + + # copy each quadrant into the resize array + array_output[x_ul_out, y_ul_out] = array_fft[x_ul_in_, y_ul_in_] + array_output[x_ll_out, y_ll_out] = array_fft[x_ll_in_, y_ll_in_] + array_output[x_ur_out, y_ur_out] = array_fft[x_ur_in_, y_ur_in_] + array_output[x_lr_out, y_lr_out] = array_fft[x_lr_in_, y_lr_in_] + + # Band limit if needed + if bandlimit_nyquist is not None: + array_output *= k_filt + + # Back to real space + array_resize[Rx, Ry, :, :] = np.real(np.fft.ifft2(array_output)).astype( + dtype + ) + + # Enforce positivity if needed, after filtering + if force_nonnegative: + array_resize = np.maximum(array_resize, 0) + + # Normalization + array_resize = array_resize * scale_output + + return array_resize + + diff --git a/py4DSTEM/utils/scattering_factors.txt b/py4DSTEM/utils/scattering_factors.txt new file mode 100644 index 000000000..1e33df156 --- /dev/null +++ b/py4DSTEM/utils/scattering_factors.txt @@ -0,0 +1,103 @@ +6.47384848835291790e-03 2.78519885379148890e+00 -4.90192576780229040e-01 2.77620428330644750e+00 5.73284160390876480e-01 2.77538591050625130e+00 -3.79403301483990480e-01 2.76759302867258810e+00 5.54426474774079140e-01 2.76511897642927540e+00 +3.05745116099835460e+00 1.08967248726078810e+00 -6.20044779127325260e+01 9.39838798143121100e-01 6.40055537084614490e+01 9.25289034386265530e-01 -5.00132578542780590e+00 8.22947498708650580e-01 1.51798828700526440e-01 5.77393110675402220e-01 +3.92622272886147930e+00 8.14276013517280360e+00 -4.54861962639998030e+00 4.98941077007855770e+00 2.19335312878658510e+00 4.14428999239410880e+00 6.99451265033965710e-02 4.01922315065680210e-01 2.09864224851937560e-03 
1.56479034719823580e-01 +3.39824970557054050e+00 4.44270178622409520e+00 -1.90866886095696660e+00 3.32451542526422950e+00 3.90702117539227370e-02 1.89772880348214850e-01 -1.11631010210714520e-02 8.71918614644603550e-02 9.46204465357523140e-03 8.27809060041340560e-02 +1.47279248639329290e+00 3.74974048281819130e+00 -4.01933042199387140e-01 5.88066536139673750e-01 3.05998956982689360e-01 5.15639613103010990e-01 1.96144217173168000e-02 1.21377570080603680e-01 9.77177106088202540e-04 6.80982412160313910e-02 +1.24466088621343300e+02 2.42120849256005630e+00 -2.20352857078963780e+02 2.30537943752425800e+00 1.95235352280479130e+02 2.04851932106564230e+00 -9.81079361269799650e+01 1.93352552917547380e+00 1.42023041213623150e-02 7.68976818478339650e-02 +5.81327150702556140e+01 1.70044856413471050e+00 -1.47542409087812730e+02 1.55903852601740440e+00 1.30143065649639450e+02 1.41576827473146880e+00 -3.96195674084154280e+01 1.27841818205455790e+00 1.05957763331480850e-02 5.65587798474805540e-02 +2.99474045242362440e+01 1.30283987880010680e+00 -7.76101266255278260e+01 1.15794105258309530e+00 9.98817764623144200e+01 1.00988549338025120e+00 -5.12127005505673050e+01 9.43327971433265970e-01 8.19618954446032010e-03 4.33197611321825550e-02 +9.48984894503524750e-01 1.45882933198645910e+00 -3.01333923043554930e+01 6.88779993187680020e-01 5.27965078127338640e+01 6.54239869346695650e-01 -2.27062703795272430e+01 6.14836130811994290e-01 6.56997664531440950e-03 3.42837419495011160e-02 +5.82741192220907370e-01 1.28118573143877200e+00 3.70676561841054910e-01 4.44520897170477600e-01 -5.46744967350809240e-01 1.98650875510481020e-01 4.14052682480208050e-01 1.85477246656276460e-01 5.19903080863993140e-03 2.75738382033885800e-02 +2.36700603946792610e+01 8.45148773514603140e+00 -2.18531786159742470e+01 8.04096600474298210e+00 5.92499448108946390e-01 6.24996000526314990e-01 -2.44652290310244000e-02 1.32450394947296380e-01 4.83950221706535770e-03 2.33994362049878630e-02 +4.85501047687149520e+00 5.94639273842456450e+00 -2.66220906476843670e+00 4.17130312520697900e+00 4.78001236085108030e-01 3.98269808150374270e-01 -7.02307064692064830e-02 1.61886185837474190e-01 3.98905828104019820e-03 1.95345056363103450e-02 +2.83409561607507500e+00 6.66235023980533200e+00 -4.28004133378261020e+00 5.51294722224021430e-01 4.42191680548311440e+00 5.09328963445973780e-01 -3.45774471896400580e-02 1.11784837425331210e-01 3.52385941406079890e-03 1.67602351805257390e-02 +2.87189142611612350e+00 5.08487103642989610e+00 -2.06173501195173530e+00 4.29178185305126190e-01 2.17114024204478720e+00 3.66485434192162230e-01 -6.63073633058801900e-02 1.19710611296903400e-01 3.01070709670513740e-03 1.43994536128397540e-02 +2.79151840023151410e+00 3.90065961865466140e+00 -4.36506837823822110e+00 3.29825968377150000e-01 4.43558455516699010e+00 3.06089956505888550e-01 -8.09635773399473850e-02 1.08083232545972810e-01 2.67900017966440440e-03 1.25894495331186820e-02 +2.67971415610199460e+00 3.06889121199971140e+00 -4.74252822230755160e-01 3.78216702185809000e-01 5.14835948989687650e-01 1.88721811902548830e-01 -9.58360024990722450e-02 9.23370590031494100e-02 2.48871963818935240e-03 1.11920877241144090e-02 +2.56624839980020260e+00 2.41594920365612390e+00 -3.38876350828591740e-01 4.21414239310215990e-01 1.14584558755515010e+00 1.09592404975830310e-01 -9.23109316547079620e-01 9.90955458226752960e-02 2.29168002041042150e-03 9.99665948927521040e-03 +2.45981746414068510e+00 1.94004631988856560e+00 -3.64198177076995090e-01 3.99241067884398780e-01 2.50584477222474460e-01 
1.17472406274412200e-01 -5.77437029544345810e-02 5.67803726023621840e-02 2.30143866828583670e-03 9.15579832920707970e-03 +5.81107878601454790e+00 1.26691483399036820e+01 -5.02537096539422590e+01 3.95641039698166490e+00 4.88609412059842470e+01 3.68385059577154510e+00 7.40628592048382940e-02 1.07458517569562820e-01 7.27802738626221870e-04 6.65576789391501140e-03 +2.11781161524159530e+01 6.39608619431736170e+00 -3.39043824317468880e+02 3.74024713891749450e+00 3.22756958523296650e+02 3.64888449922605270e+00 6.50077673896996690e-02 9.45090634514673540e-02 6.55874366578593520e-04 5.98520619883758600e-03 +1.26035186572148630e+01 6.15625615363852940e+00 -2.76875382053702590e+02 3.08873554266679310e+00 2.68871603907342830e+02 3.02727663298548320e+00 5.56824178897458030e-02 8.18874748375215680e-02 5.77071255145399560e-04 5.38289832050797570e-03 +8.57595775238129930e+00 6.00780668875581810e+00 -2.10331563465304870e+02 2.60285856745213010e+00 2.06097172601541360e+02 2.55352345051105620e+00 4.77773948977261930e-02 7.11429484024153900e-02 5.05716484480320570e-04 4.85628439383726120e-03 +6.52768433234788950e+00 5.83552479352448120e+00 -2.00430576829172650e+02 2.23255952381127630e+00 1.98015053889994020e+02 2.19786018594228990e+00 4.13911518064005580e-02 6.23973876828547300e-02 4.47455024329856520e-04 4.40383649079958020e-03 +3.02831784843691310e+00 8.35911504314631770e+00 -9.55393933081432320e+01 1.80263790264103240e+00 9.61761562352198070e+01 1.77509488957047210e+00 3.59777315957987700e-02 5.48144412143113730e-02 3.91492890718422320e-04 3.99828968916034810e-03 +4.37417550633122690e+00 5.51031705504928220e+00 -1.60925510918779510e+02 1.68798216402433910e+00 1.60273308060322340e+02 1.66614047777702390e+00 3.12303861049162410e-02 4.83370390321277020e-02 3.46966021020938010e-04 3.64746960053068160e-03 +3.79810090836859620e+00 5.31712645899433060e+00 -9.16893549381687190e+01 1.49713094884748130e+00 9.14454252155429830e+01 1.46809241804410510e+00 2.72754344027604720e-02 4.27247850128954910e-02 3.03379854383771030e-04 3.32791855231874110e-03 +3.33037874467544270e+00 5.18135964579543180e+00 -7.70017596472967090e+01 1.32915122265139420e+00 7.70725221790516880e+01 1.30284928963228920e+00 2.39904669040569940e-02 3.80645486826370960e-02 2.68256665523942030e-04 3.05010007991570780e-03 +2.96908078725286370e+00 5.04180949109380010e+00 -7.57477069129040220e+01 1.18275507921629310e+00 7.60398287625307320e+01 1.16216545846629930e+00 2.10162113692966970e-02 3.37479087883035340e-02 2.31151751135153040e-04 2.78680862070740520e-03 +1.75207145212145600e+00 6.18750497986187130e+00 -4.30410523492124500e+01 1.00266263628976620e+00 4.40705915543573850e+01 9.85384311353030280e-01 1.86876154088144730e-02 3.02984703916117610e-02 2.01727324791624140e-04 2.55855598748879050e-03 +2.46637110499459130e+00 4.91028078493815910e+00 -6.14678541332537450e+01 9.67898520322992170e-01 6.20176945237481260e+01 9.51283834775355500e-01 1.64160173931416210e-02 2.69600967667566260e-02 1.72487117595380870e-04 2.34109611046253710e-03 +2.76010203108428340e+00 6.10128224537662690e+00 -3.44452614207467890e+01 7.65143313553464880e-01 3.52262267244016340e+01 7.51328623382859660e-01 1.32067196999419880e-02 2.24879634317251170e-02 1.25945560928245720e-04 2.06737374278712340e-03 +3.18241635260000020e+00 5.01719040860914680e+00 -5.24514037811166670e+01 7.12395764437798170e-01 5.29690827162218480e+01 7.02280192528193960e-01 1.14096168591864820e-02 1.96747295694015350e-02 9.50954358145057910e-05 1.84146614494011450e-03 +3.45642969119603950e+00 
4.01358016032945210e+00 -3.33176044431721220e+01 6.62355778050629060e-01 3.35712193855332190e+01 6.45771941056073830e-01 9.79002295640342070e-03 1.70919353231011460e-02 6.53434862265172480e-05 1.60301602839420070e-03 +3.64905047801926410e+00 3.25043267112593390e+00 -4.36851662221238610e+01 6.09666201650288950e-01 4.36920288601154780e+01 5.96971300802123240e-01 8.44902284199144410e-03 1.48554512740362250e-02 3.78611474110038160e-05 1.33625587356563020e-03 +3.83846312224289490e+00 2.61189470573232270e+00 -5.22723471011233870e+01 5.66195062874718210e-01 5.19861279495659620e+01 5.55279326699873900e-01 7.33955989338044760e-03 1.29746469432440690e-02 1.64694207951339920e-05 1.02986536855875730e-03 +4.02541030310827440e+00 2.13648381443739990e+00 -4.63042332078929770e+01 5.26591166454131180e-01 4.57213681904101180e+01 5.14136784464577560e-01 6.35337959625315720e-03 1.12807242242006810e-02 1.33477845255742950e-06 4.88089857940970360e-04 +4.77092509299825980e+00 1.33668881330412750e+01 1.47597850155283420e+00 1.33738379577196680e+00 3.04451355544188670e-01 1.77532369437782830e-01 3.59474981916728340e-03 7.79105033001826780e-03 3.00085549228278960e-07 2.82255139648537230e-04 +3.38975351595393980e+00 2.05744814368171130e+01 2.14348348679172410e+00 1.91079945218516370e+00 3.54322603510978160e-01 1.97410589396603890e-01 3.74009340085664200e-03 8.13459465343988920e-03 3.00342502342229660e-07 2.92685791056951060e-04 +4.60721019875234020e+00 1.08686905557151850e+01 1.42801851039840270e+00 1.31137455873124110e+00 2.95581045577753830e-01 1.68022870784711150e-01 3.38997840852895120e-03 7.35964545350478290e-03 2.66862974969022400e-07 2.62343281021387260e-04 +4.31175453406871600e+00 9.45896580517490190e+00 1.49331578039339450e+00 1.33063622845245620e+00 2.81236050128884140e-01 1.56506871444453550e-01 3.09337738455747140e-03 6.82410523811735180e-03 2.58024448512530600e-07 2.49818879110702270e-04 +3.11179039113495380e+00 1.06903141343023660e+01 2.20259060945803050e+00 1.65316356158933140e+00 2.70330774910020890e-01 1.45115185718969860e-01 2.68799194751098100e-03 6.13956351582850470e-03 2.32549483347779190e-07 2.32240235937672540e-04 +2.83105968432053560e+00 1.04357195758950730e+01 2.34858137489674460e+00 1.60482868674597220e+00 2.45105888429696440e-01 1.31696934774606920e-01 2.35282108218515860e-03 5.54977901383345900e-03 2.31270838343859220e-07 2.22747463764897690e-04 +2.57179859323352570e+00 1.01643117131377530e+01 2.45633741957203040e+00 1.53441919244550480e+00 2.20658440881371460e-01 1.19138619754110400e-01 2.05531418012438150e-03 5.01853240882988840e-03 2.32132947793692180e-07 2.14590379120861160e-04 +2.33230030353946560e+00 9.92167460159959450e+00 2.53578025489049620e+00 1.45585668808788140e+00 1.98208042136026000e-01 1.07682187873349280e-01 1.76120130460053530e-03 4.47243545654967340e-03 1.98129411546024490e-07 1.96574716232499390e-04 +2.11352534859470030e+00 9.65913725859762270e+00 2.58636316155013770e+00 1.37106656934329620e+00 1.77063913863686530e-01 9.70353027584648780e-02 1.49737051003056920e-03 3.97128509088792110e-03 2.05481445555622420e-07 1.91355190737798320e-04 +6.42159796182687260e-01 5.97479750263406120e+00 2.97914814426328920e+00 1.43359432541277740e+00 1.68154426004270800e-01 9.09868401172676950e-02 1.33744213878407760e-03 3.62410137116162160e-03 1.91410968246861570e-07 1.80688914416045000e-04 +1.55317216680389560e+00 8.15620235758956550e+00 2.63930364699988780e+00 1.21600887481801510e+00 1.42015486956788170e-01 7.90098864915873420e-02 1.00850460099760920e-03 2.96547901363948970e-03 
1.94638429973024380e-07 1.74593095919146200e-04 +6.15307851992860150e+01 3.11468102533247300e+00 -7.86016741201582080e+01 2.76016983388574480e+00 2.15501292602770460e+01 1.93551312324722380e+00 1.37685015664191560e-01 7.22468347260293160e-02 3.24644930946521010e-04 1.17001629624684460e-03 +4.22232177901524610e+00 6.07265510403227450e+00 -2.64121318353202560e+01 1.64550178959387900e+00 2.72852852708746920e+01 1.52257074939506750e+00 1.21617901884137960e-01 6.56507914695225600e-02 3.06883546371525800e-04 1.12093851473035640e-03 +5.14222074642053690e+00 5.27272636474484460e+00 -2.54945413762576720e+01 1.53194959209148300e+00 2.57414487548601870e+01 1.40257525040087170e+00 1.11778227592199600e-01 6.11685387263086880e-02 2.93647384737742710e-04 1.07560815394522340e-03 +6.24164031831883650e+00 4.26984108082687540e+00 -9.33868724419552190e+01 1.39407746140250150e+00 9.26332875829884300e+01 1.35566854910434480e+00 1.03412692221298720e-01 5.72667949652199830e-02 2.81848426894390730e-04 1.03307954282067370e-03 +7.37743301813403020e+00 3.46917757783897770e+00 -1.26025106931628140e+02 1.29759810886736600e+00 1.24128405004095550e+02 1.26771067646066030e+00 9.59978371818569760e-02 5.37771750179423110e-02 2.71072216398882020e-04 9.93084577029176810e-04 +9.64400666272123170e+00 2.72645545366544080e+00 -1.22924435350112520e+02 1.23723425850866020e+00 1.18682564816567250e+02 1.20062036872249790e+00 8.95025947025538950e-02 5.06678682285199920e-02 2.61276121490475030e-04 9.55437883383497270e-04 +1.55451749674860600e+01 2.10637340865492680e+00 -1.18241027856744580e+02 1.20860376129512240e+00 1.08009524963196970e+02 1.15395270567214010e+00 8.36259341992221660e-02 4.78189391274824250e-02 2.51991862429073870e-04 9.19886257592613110e-04 +4.28708739181692260e+00 2.26587870798541500e+01 3.23250665422964990e+00 2.23797386470064240e+00 6.74029533561706250e-01 3.68995568664316260e-01 6.18908083387697610e-02 4.02266575298484350e-02 2.35612052952190030e-04 8.83761890878035130e-04 +6.24475187390461530e+00 1.51431354190985630e+01 2.35172271416388460e+00 1.45379000604814030e+00 4.74279373222252000e-01 3.20835646387801770e-01 6.38113874286849870e-02 4.04354532094699790e-02 2.34651280562792400e-04 8.54081131280032660e-04 +6.09788179599509660e+00 1.24288544315854800e+01 2.19495164736675010e+00 1.50535992352023460e+00 5.48172791960022550e-01 3.33738039717124290e-01 6.16669573222662570e-02 3.87744553499998830e-02 2.26807355864714200e-04 8.24049884064866980e-04 +5.79526879647240540e+00 1.42801055058256010e+01 2.37022664107843270e+00 1.35969015719134510e+00 4.71398756901114880e-01 3.02017349663514120e-01 5.74368260587866800e-02 3.66436798147007710e-02 2.18979489259659270e-04 7.95434526518371140e-04 +5.60406255377525750e+00 1.39517490230574650e+01 2.35796259561812920e+00 1.31239754917439220e+00 4.76001098572882360e-01 2.94933701192956530e-01 5.44123374315247100e-02 3.48601542462324140e-02 2.11414602205149750e-04 7.68261682661809690e-04 +5.42908391969770320e+00 1.36503649427659970e+01 2.33687325360880300e+00 1.26759841390319910e+00 4.83373554104881860e-01 2.88610681576929930e-01 5.14654964557478970e-02 3.31291807446609310e-02 2.03776132863761090e-04 7.42348483801266290e-04 +5.26774445089444580e+00 1.33602096827394550e+01 2.30855813326304960e+00 1.22585856586941720e+00 4.93265479026012750e-01 2.82919632127866860e-01 4.86356279741696550e-02 3.14735397127137930e-02 1.96308842320675220e-04 7.17658025069615150e-04 +5.12680428517077580e+00 1.31581501026129040e+01 2.26925534008380580e+00 1.18129508243337060e+00 5.04209300453302480e-01 
2.77107038219062910e-01 4.56923992416221960e-02 2.97957604549544570e-02 1.88675050493886220e-04 6.94011099199098280e-04 +4.97962359749809200e+00 1.28392663078033800e+01 2.24183087455669480e+00 1.14705446420486030e+00 5.12933961402893710e-01 2.70387160935243730e-01 4.29801848498334440e-02 2.82418714952601960e-02 1.81381692485787560e-04 6.71486603173992420e-04 +5.07835830045611390e+00 1.05132725469047320e+01 1.95744027181658950e+00 1.11764941281524610e+00 5.92825983221375140e-01 2.84386741833675240e-01 4.19502034143925600e-02 2.72663327641347110e-02 1.75241091528316820e-04 6.50310860707303690e-04 +4.71161636657300060e+00 1.23109415948539130e+01 2.17261950786578910e+00 1.08743796767492460e+00 5.39717601342865730e-01 2.59804965509275510e-01 3.79796004547811110e-02 2.53289923895189670e-02 1.66923763564841650e-04 6.29278566961911160e-04 +4.59075504485146620e+00 1.20656740623369880e+01 2.13572373036567380e+00 1.05811737115864360e+00 5.51355560094152540e-01 2.53994438918661800e-01 3.56058406718513800e-02 2.39508739649649150e-02 1.59824016856789790e-04 6.09482748433288060e-04 +4.48400110626710810e+00 1.18749250691570630e+01 2.08904347078539800e+00 1.02828493708789590e+00 5.65932618316104640e-01 2.48907807908075960e-01 3.32703590211063040e-02 2.25844065532147840e-02 1.52445610282892090e-04 5.90374445159424470e-04 +4.37665140786963210e+00 1.16653080717139480e+01 2.04645150383634090e+00 1.00314769675495310e+00 5.80662860591587670e-01 2.44153292686312390e-01 3.12388528327766720e-02 2.13618580000731020e-02 1.45374869662535650e-04 5.72110338454766620e-04 +4.28308318247419530e+00 1.14961905956683400e+01 1.99538001453730710e+00 9.77039102018820600e-01 5.97053571332504250e-01 2.39429132975829070e-01 2.90951668869559870e-02 2.00912233327255780e-02 1.38064769037521620e-04 5.54377138492266260e-04 +4.19563840737123250e+00 1.14107750456652330e+01 1.94333285978645610e+00 9.49011924454877360e-01 6.12446617285465790e-01 2.34950326535304650e-01 2.71515207694932420e-02 1.89051576280975480e-02 1.30594787351929680e-04 5.37233649209711080e-04 +4.35692593296384260e+00 9.29434514718568930e+00 1.69589204777315850e+00 9.10500045957419960e-01 6.63904520068470450e-01 2.38746595911376420e-01 2.63020018760281370e-02 1.82098542546462290e-02 1.25497318499823790e-04 5.21559371963562160e-04 +4.33138405664923450e+00 7.87684433810288360e+00 1.52764864728680740e+00 9.42515642627748780e-01 7.35795922891237960e-01 2.41699480039806010e-01 2.49526323201402640e-02 1.72899894448508980e-02 1.18740852581057760e-04 5.05834631313525580e-04 +4.19726001319641990e+00 6.93674024894457200e+00 1.46854806774708810e+00 1.01725276618348250e+00 7.83931248211088170e-01 2.38948939715070610e-01 2.32494094864395090e-02 1.62332433075381040e-02 1.11261358607285950e-04 4.90239004021149030e-04 +3.97629715819077400e+00 6.29685779228774930e+00 1.52292684257341810e+00 1.11289951157691090e+00 7.97706246390561310e-01 2.31057040678473440e-01 2.13666741558685220e-02 1.51013545567204450e-02 1.03078689385717270e-04 4.74682477105025770e-04 +3.75144381439811880e+00 5.79754636082953830e+00 1.62768802977632740e+00 1.18223631077710680e+00 7.81756755454271810e-01 2.19913586024770210e-01 1.95170803912291560e-02 1.39810455481456640e-02 9.43199804293427490e-05 4.59127125829838610e-04 +3.48401517338711560e+00 5.43998846080953680e+00 1.79377920416965450e+00 1.22792134140128130e+00 7.44878356691920260e-01 2.06208856992909220e-01 1.75427785162434460e-02 1.27899401721146020e-02 8.44872351568859950e-05 4.43032226787478600e-04 +1.59956578198844190e+00 5.79244447385567530e+00 
2.97534452194120510e+00 1.55300982973225970e+00 6.95092678366882160e-01 1.88626335936648240e-01 1.48779627458687970e-02 1.11763470662453270e-02 6.90549579101583330e-05 4.22772361655551400e-04 +2.04021563975728260e+00 6.65819429609627460e+00 2.89922634624825460e+00 1.41337923778932420e+00 6.36344083815797660e-01 1.74000104502178950e-01 1.32071919595408320e-02 1.00687804530210510e-02 5.67382192500131220e-05 4.03077105692223100e-04 +1.67593467064870840e+00 5.52231093211402510e+00 3.00486602969729290e+00 1.38007223007196280e+00 5.95340013161635540e-01 1.62229237655945410e-01 1.17163186623094810e-02 9.01814890416575630e-03 4.29678297639817100e-05 3.79277667477667070e-04 +2.23522850443105270e+00 5.02030988960240340e+00 2.68276638651994890e+00 1.23077590583777650e+00 5.55194926212433270e-01 1.52248122992863640e-01 1.07273354366258700e-02 8.28399116906210380e-03 3.28473992046262910e-05 3.56241938931758870e-04 +2.80342737412509990e+00 6.55876872804534190e+00 2.71882787966020030e+00 1.16972422516667860e+00 5.22475915391999560e-01 1.43556691800178100e-01 9.84537836971042040e-03 7.61976526230039810e-03 2.34524526508316210e-05 3.29627673902985420e-04 +3.60861020997771350e+00 6.58162594621962780e+00 2.45056774737198340e+00 1.02772852660588470e+00 4.78639500107256980e-01 1.33533680634720870e-01 8.87214235169215240e-03 6.84861238948429770e-03 1.04001932909467290e-05 2.76388875456666010e-04 +4.24209901057315890e+00 5.75280115536079960e+00 2.09994342055827850e+00 8.73901489195461510e-01 4.32836631379922900e-01 1.23599956180524660e-01 8.02021570381059210e-03 6.17600333058570230e-03 7.21785030851340250e-07 1.41495177617230900e-04 +4.63620095668962320e+00 4.88825299780982900e+00 1.78063311426987080e+00 7.56310526880074720e-01 4.03730443952738490e-01 1.17302364452249240e-01 7.68539554668673950e-03 5.93034394242894160e-03 8.95415902807643600e-08 7.66668663873028710e-05 +4.96592250595070530e+00 4.09129378787496290e+00 1.43815561507033650e+00 6.29228996602551270e-01 3.71299200579298050e-01 1.11096678695535160e-01 7.47256428450201060e-03 5.77235010347761780e-03 1.14115284152585070e-07 8.08511893879382500e-05 +5.30615614475033230e+00 3.48800735481655670e+00 1.11733160236072120e+00 4.81190741149771170e-01 3.15858723110854330e-01 1.01874467945753100e-01 7.20342242310210040e-03 5.59245374417078650e-03 1.07355330545863460e-07 7.79506673540736650e-05 +4.52053399042335970e+00 1.94482234232948810e+01 4.10695397909124620e+00 1.89824673155996850e+00 7.13946878503728510e-01 1.69553563595341790e-01 1.69294027687453900e-02 1.14819567548934220e-02 8.57492129191724410e-05 3.46122038257982700e-04 +6.52401072001873320e+00 1.40092554298974860e+01 3.20787080745661290e+00 1.32635035961665260e+00 5.40478774349637650e-01 1.31408568386036120e-01 8.78278806898853160e-03 6.28647434520409560e-03 6.91010602949131460e-06 2.27839981458950950e-04 +6.89602853559619080e+00 1.10763825670398680e+01 2.83514154536523220e+00 1.17132616290503000e+00 5.03506881888815090e-01 1.23494651391799910e-01 8.02294510966706890e-03 5.73515595790497110e-03 9.20400954758739720e-08 7.19707541802523240e-05 +7.09374900162630160e+00 9.09473795165944670e+00 2.52912373903193990e+00 1.06391667033161210e+00 4.82107819888889010e-01 1.18619446621877890e-01 7.71936674145110980e-03 5.53945708895034650e-03 7.27114199425041410e-08 6.58494730528462700e-05 +6.43401324797284160e+00 1.02596851307297850e+01 2.97099970535788760e+00 1.13177452306163300e+00 4.60796651787848010e-01 1.12775935362382300e-01 7.24033125775455300e-03 5.27459583353132120e-03 6.36236664309127860e-08 
6.19263824088759710e-05 +6.21070826784019920e+00 1.00214105859162840e+01 3.03934453825623450e+00 1.10399880155883360e+00 4.37339984439162780e-01 1.07218973162453880e-01 6.80714764147316580e-03 5.02432130850682420e-03 6.18229327742851270e-08 6.00506994197195790e-05 +6.00431598308986470e+00 9.81793046719106140e+00 3.09431654531417970e+00 1.06978705068029440e+00 4.12808487716417040e-01 1.01635291448645790e-01 6.40891657880066860e-03 4.79524846888725040e-03 6.73007376216075430e-08 6.02806598726462900e-05 +5.20061716810304200e+00 1.12038310194406670e+01 3.49849440367144030e+00 1.12846369546928190e+00 4.05414931131320270e-01 9.89333472686131640e-02 6.22343700826529910e-03 4.65917431977950700e-03 6.00859329543440850e-08 5.72590638962334120e-05 +5.02533860036029090e+00 1.09772190697628730e+01 3.51843988285154060e+00 1.08477214832630780e+00 3.81950349446272210e-01 9.36746874853835870e-02 5.82110208096211020e-03 4.42682346873485640e-03 6.52609338829711920e-08 5.73856837050511520e-05 +5.34656100260619650e+00 9.23118379752449410e+00 3.22468466560910820e+00 9.72836229193835170e-01 3.49461752625445210e-01 8.70596220966562280e-02 5.29251756748815110e-03 4.13011964465032340e-03 6.15917630236237360e-08 5.49463894449675660e-05 +5.22582350334004710e+00 9.07135706618718270e+00 3.22818873995295300e+00 9.31124277456634840e-01 3.27098878823851360e-01 8.20720602059845510e-02 4.88882575592239010e-03 3.89029655993416560e-03 5.21272269419363250e-08 5.10367976636641620e-05 +4.58641247753547890e+00 1.03186102092335170e+01 3.51869595665101230e+00 9.57278837829232840e-01 3.19261714201205420e-01 7.96807431443880690e-02 4.72979647985625770e-03 3.77193199996848920e-03 5.51324480806204080e-08 5.09750860376011080e-05 +4.45799480675412600e+00 1.00890693383401260e+01 3.50867212637662320e+00 9.19446447361622620e-01 3.01375380528315420e-01 7.56419177786329250e-02 4.40763746214481880e-03 3.57026465124809290e-03 4.88787903228591640e-08 4.80495174395811320e-05 +4.33897576401170860e+00 9.96330937223836170e+00 3.49185098896403410e+00 8.78083956982974370e-01 2.81471099016286750e-01 7.11637115768777580e-02 4.00209271499046060e-03 3.32152404215077700e-03 5.52929808195514040e-08 4.85103179460372770e-05 +4.22729470403101180e+00 9.72400602040031360e+00 3.47249227510661560e+00 8.42873775960210400e-01 2.64822229496898600e-01 6.73534743974851110e-02 3.69072877833192920e-03 3.12364606263320840e-03 6.25871426200247310e-08 4.91217097017619240e-05 +4.10951702443020390e+00 9.67735994510120180e+00 3.45799132522750740e+00 8.06940042470817190e-01 2.47087351222386680e-01 6.32815436954167060e-02 3.30423920940995170e-03 2.87544639641209180e-03 5.99104926231931580e-08 4.70679153697598100e-05 +4.52147421198378830e+00 8.28309906861142050e+00 3.20212985587804380e+00 7.31918958125393870e-01 2.23028726956451650e-01 5.80942773018654980e-02 2.81716453892029730e-03 2.56168016047444900e-03 4.06427954038620530e-08 4.03816515529006540e-05 \ No newline at end of file diff --git a/py4DSTEM/utils/single_atom_scatter.py b/py4DSTEM/utils/single_atom_scatter.py new file mode 100644 index 000000000..8d6e2a891 --- /dev/null +++ b/py4DSTEM/utils/single_atom_scatter.py @@ -0,0 +1,90 @@ +import numpy as np +import os + + +class single_atom_scatter(object): + """ + This class calculates the composition averaged single atom scattering factor for a + material. The parameterization is based upon Lobato, Acta Cryst. (2014). A70, + 636–649. + + Elements is an 1D array of atomic numbers. 
+    Composition is a 1D array, same length as elements, describing the average atomic
+    composition of the sample. Q_coords is a 1D array of Fourier coordinates,
+    given in inverse Angstroms. Units is a string of 'VA' or 'A', which returns the
+    scattering factor in volt angstroms or in angstroms.
+    """
+
+    def __init__(self, elements=None, composition=None, q_coords=None, units=None):
+        self.elements = elements
+        self.composition = composition
+        self.q_coords = q_coords
+        self.units = units
+        path = os.path.join(os.path.dirname(__file__), "scattering_factors.txt")
+        self.e_scattering_factors = np.loadtxt(path, dtype=np.float64)
+
+        return
+
+    def electron_scattering_factor(self, Z, gsq, units="A"):
+        ai = self.e_scattering_factors[Z - 1, 0:10:2]
+        bi = self.e_scattering_factors[Z - 1, 1:10:2]
+
+        # Planck's constant in Js
+        h = 6.62607004e-34
+        # Electron rest mass in kg
+        me = 9.10938356e-31
+        # Electron charge in Coulomb
+        qe = 1.60217662e-19
+
+        fe = np.zeros_like(gsq)
+        for i in range(5):
+            fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2
+
+        # Result can be returned in units of Volt Angstrom³ ('VA') or Angstrom ('A')
+        if units == "VA":
+            return h**2 / (2 * np.pi * me * qe) * 1e18 * fe
+        elif units == "A":
+            return fe
+
+    def get_scattering_factor(
+        self, elements=None, composition=None, q_coords=None, units=None
+    ):
+        if elements is None:
+            assert (
+                self.elements is not None
+            ), "Must pass a list of atomic numbers in either class initialization or in call to get_scattering_factor()"
+            elements = self.elements
+
+        if composition is None:
+            assert (
+                self.composition is not None
+            ), "Must pass composition fractions in either class initialization or in call to get_scattering_factor()"
+            composition = self.composition
+
+        if q_coords is None:
+            assert (
+                self.q_coords is not None
+            ), "Must pass a q_space array in either class initialization or in call to get_scattering_factor()"
+            q_coords = self.q_coords
+
+        if units is None:
+            units = self.units
+            if self.units is None:
+                print("Setting output units to Angstroms")
+                units = "A"
+
+        assert len(elements) == len(
+            composition
+        ), "Each element must have an associated composition."
+
+        if np.sum(composition) > 1:
+            # normalize composition if passed as stoichiometry instead of atomic fractions
+            composition /= np.sum(composition)
+
+        fe = np.zeros_like(q_coords)
+        for i in range(len(elements)):
+            fe += composition[i] * self.electron_scattering_factor(
+                elements[i], np.square(q_coords), units
+            )
+
+        self.fe = fe
diff --git a/py4DSTEM/utils/voronoi.py b/py4DSTEM/utils/voronoi.py
new file mode 100644
index 000000000..7d4b5e9eb
--- /dev/null
+++ b/py4DSTEM/utils/voronoi.py
@@ -0,0 +1,122 @@
+import numpy as np
+from scipy.spatial import Voronoi
+
+
+def get_voronoi_vertices(voronoi, nx, ny, dist=10):
+    """
+    From a scipy.spatial.Voronoi instance, return a list of ndarrays, where each array
+    is shape (N,2) and contains the (x,y) positions of the vertices of a voronoi region.
+
+    The problem this function solves is that in a Voronoi instance, some vertices outside
+    the field of view of the tesselated region are left unspecified; only the existence
+    of a point beyond the field is referenced (which may or may not be 'at infinity').
+    This function specifies all points, such that the vertices and edges of the
+    tesselation may be directly laid over data.
+ + Args: + voronoi (scipy.spatial.Voronoi): the voronoi tesselation + nx (int): the x field-of-view of the tesselated region + ny (int): the y field-of-view of the tesselated region + dist (float, optional): place new vertices by extending new voronoi edges outside + the frame by a distance of this factor times the distance of its known vertex + from the frame edge + + Returns: + (list of ndarrays of shape (N,2)): the (x,y) coords of the vertices of each + voronoi region + """ + assert isinstance( + voronoi, Voronoi + ), "voronoi must be a scipy.spatial.Voronoi instance" + + vertex_list = [] + + # Get info about ridges containing an unknown vertex. Include: + # -the index of its known vertex, in voronoi.vertices, and + # -the indices of its regions, in voronoi.point_region + edgeridge_vertices_and_points = [] + for i in range(len(voronoi.ridge_vertices)): + ridge = voronoi.ridge_vertices[i] + if -1 in ridge: + edgeridge_vertices_and_points.append( + [max(ridge), voronoi.ridge_points[i, 0], voronoi.ridge_points[i, 1]] + ) + edgeridge_vertices_and_points = np.array(edgeridge_vertices_and_points) + + # Loop over all regions + for index in range(len(voronoi.regions)): + # Get the vertex indices + vertex_indices = voronoi.regions[index] + vertices = np.array([0, 0]) + # Loop over all vertices + for i in range(len(vertex_indices)): + index_current = vertex_indices[i] + if index_current != -1: + # For known vertices, just add to a running list + vertices = np.vstack((vertices, voronoi.vertices[index_current])) + else: + # For unknown vertices, get the first vertex it connects to, + # and the two voronoi points that this ridge divides + index_prev = vertex_indices[(i - 1) % len(vertex_indices)] + edgeridge_index = int( + np.argwhere(edgeridge_vertices_and_points[:, 0] == index_prev) + ) + index_vert, region0, region1 = edgeridge_vertices_and_points[ + edgeridge_index, : + ] + x, y = voronoi.vertices[index_vert] + # Only add new points for unknown vertices if the known index it connects to + # is inside the frame. Add points by finding the line segment starting at + # the known point which is perpendicular to the segment connecting the two + # voronoi points, and extending that line segment outside the frame. 
+ if (x > 0) and (x < nx) and (y > 0) and (y < ny): + x_r0, y_r0 = voronoi.points[region0] + x_r1, y_r1 = voronoi.points[region1] + m = -(x_r1 - x_r0) / (y_r1 - y_r0) + # Choose the direction to extend the ridge + ts = np.array([-x, -y / m, nx - x, (ny - y) / m]) + x_t = lambda t: x + t + y_t = lambda t: y + m * t + t = ts[np.argmin(np.hypot(x - x_t(ts), y - y_t(ts)))] + x_new, y_new = x_t(dist * t), y_t(dist * t) + vertices = np.vstack((vertices, np.array([x_new, y_new]))) + else: + # If handling unknown points connecting to points outside the frame is + # desired, add here + pass + + # Repeat for the second vertec the unknown vertex connects to + index_next = vertex_indices[(i + 1) % len(vertex_indices)] + edgeridge_index = int( + np.argwhere(edgeridge_vertices_and_points[:, 0] == index_next) + ) + index_vert, region0, region1 = edgeridge_vertices_and_points[ + edgeridge_index, : + ] + x, y = voronoi.vertices[index_vert] + if (x > 0) and (x < nx) and (y > 0) and (y < ny): + x_r0, y_r0 = voronoi.points[region0] + x_r1, y_r1 = voronoi.points[region1] + m = -(x_r1 - x_r0) / (y_r1 - y_r0) + # Choose the direction to extend the ridge + ts = np.array([-x, -y / m, nx - x, (ny - y) / m]) + x_t = lambda t: x + t + y_t = lambda t: y + m * t + t = ts[np.argmin(np.hypot(x - x_t(ts), y - y_t(ts)))] + x_new, y_new = x_t(dist * t), y_t(dist * t) + vertices = np.vstack((vertices, np.array([x_new, y_new]))) + else: + pass + + # Remove regions with insufficiently many vertices + if len(vertices) < 4: + vertices = np.array([]) + # Remove initial dummy point + else: + vertices = vertices[1:, :] + # Update vertex list with this region's vertices + vertex_list.append(vertices) + + return vertex_list + + diff --git a/py4DSTEM/utils_config/__init__.py b/py4DSTEM/utils_config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 2db48e371..2a611da45 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -3,6 +3,7 @@ from matplotlib.patches import Wedge from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.spatial import Voronoi +from colorspacious import cspace_convert from emdfile import PointList from py4DSTEM.visualize import show @@ -17,6 +18,8 @@ from py4DSTEM.visualize.vis_grid import show_image_grid from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR from colorspacious import cspace_convert +from py4DSTEM.utils import convert_ellipse_params +from py4DSTEM.utils import get_voronoi_vertices def show_elliptical_fit( @@ -120,7 +123,6 @@ def show_amorphous_ring_fit( returnfig (bool): if True, returns the figure """ from py4DSTEM.process.calibration import double_sided_gaussian - from py4DSTEM.process.utils import convert_ellipse_params assert len(p_dsg) == 11 assert isinstance(N, (int, np.integer)) @@ -307,8 +309,6 @@ def show_voronoi( """ words """ - from py4DSTEM.process.utils import get_voronoi_vertices - Nx, Ny = ar.shape points = np.vstack((x, y)).T voronoi = Voronoi(points)
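
A minimal usage sketch for two of the utilities introduced above, assuming the new modules py4DSTEM/utils/radial_reduction.py and py4DSTEM/utils/resample.py are importable as laid out in this patch; the synthetic ring pattern and all variable names are illustrative only, not part of the library.

import numpy as np

from py4DSTEM.utils.radial_reduction import radial_reduction
from py4DSTEM.utils.resample import fourier_resample

# synthetic diffraction pattern: a Gaussian ring of radius ~30 px on a 128x128 grid
qx, qy = np.mgrid[0:128, 0:128]
r = np.hypot(qx - 64, qy - 64)
dp = np.exp(-((r - 30) ** 2) / (2 * 2.0**2))

# mean radial profile about the pattern center, using 1 px wide annuli
profile = radial_reduction(dp, x0=64, y0=64, binsize=1, fn=np.mean)

# Fourier-upsample the same pattern to 256x256
dp_up = fourier_resample(dp, output_size=(256, 256))

print(profile.shape, dp_up.shape)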