diff --git a/.github/scripts/update_version.py b/.github/scripts/update_version.py index 2aaaa07af..635cf8268 100644 --- a/.github/scripts/update_version.py +++ b/.github/scripts/update_version.py @@ -8,7 +8,7 @@ lines = f.readlines() line_split = lines[0].split(".") -patch_number = line_split[2].split("'")[0] +patch_number = line_split[2].split("'")[0].split('"')[0] # Increment patch number patch_number = str(int(patch_number) + 1) + "'" diff --git a/.github/workflows/check_install_dev.yml b/.github/workflows/check_install_dev.yml index 6e44a6334..82701d50d 100644 --- a/.github/workflows/check_install_dev.yml +++ b/.github/workflows/check_install_dev.yml @@ -17,10 +17,10 @@ jobs: runs-on: [ubuntu-latest] architecture: [x86_64] python-version: ["3.9", "3.10", "3.11",] - include: - - python-version: "3.12.0-beta.4" - runs-on: ubuntu-latest - allow_failure: true + # include: + # - python-version: "3.12.0-beta.4" + # runs-on: ubuntu-latest + # allow_failure: true # Currently no public runners available for this but this or arm64 should work next time # include: # - python-version: "3.10" diff --git a/.gitignore b/.gitignore index 6c008b0ff..24587a3b3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.swp *.ipynb_checkpoints* .vscode/ +pyrightconfig.json # Folders # .idea/ diff --git a/README.md b/README.md index 0561f098a..aa102542a 100644 --- a/README.md +++ b/README.md @@ -46,42 +46,50 @@ First, download and install Anaconda: www.anaconda.com/download. If you prefer a more lightweight conda client, you can instead install Miniconda: https://docs.conda.io/en/latest/miniconda.html. Then open a conda terminal and run one of the following sets of commands to ensure everything is up-to-date and create a new environment for your py4DSTEM installation: - ``` conda update conda conda create -n py4dstem conda activate py4dstem +conda install -c conda-forge py4dstem pymatgen jupyterlab ``` -Next, install py4DSTEM. To simultaneously install py4DSTEM with `pymatgen` (used in some crystal structure workflows) and `jupyterlab` (providing an interface for running Python notebooks like those provided in the [py4DSTEM tutorials repository](https://github.com/py4dstem/py4DSTEM_tutorials)) run: +In order, these commands +- ensure your installation of anaconda is up-to-date +- make a virtual environment (see below) +- enter the environment +- install py4DSTEM, as well as pymatgen (used for crystal structure calculations) and JupyterLab (an interface for running Python notebooks like those in the [py4DSTEM tutorials repository](https://github.com/py4dstem/py4DSTEM_tutorials)) + + +We've had some recent reports install of `conda` getting stuck trying to solve the environment using the above installation. If you run into this problem, you can install py4DSTEM using `pip` instead of `conda` by running: ``` -conda install -c conda-forge py4dstem pymatgen jupyterlab +conda update conda +conda create -n py4dstem python=3.10 +conda activate py4dstem +pip install py4dstem pymatgen ``` -Or if you would prefer to install only the base modules of **py4DSTEM**, you can instead run: +Both `conda` and `pip` are programs which manage package installations, i.e. make sure different codes you're installing which depend on one another are using mutually compatible versions. Each has advantages and disadvantages; `pip` is a little more bare-bones, and we've seen this install work when `conda` doesn't. If you also want to use Jupyterlab you can then use either `pip install jupyterlab` or `conda install jupyterlab`. + +If you would prefer to install only the base modules of **py4DSTEM**, and skip pymategen and Jupterlab, you can instead run: ``` conda install -c conda-forge py4dstem ``` -In Windows you should then also run: +Finally, regardless of which of the above approaches you used, in Windows you should then also run: ``` conda install pywin32 ``` -In order, these commands -- ensure your installation of anaconda is up-to-date -- make a virtual environment (see below) -- enter the environment -- install py4DSTEM, and optionally also pymatgen and JupyterLab -- on Windows, enable python to talk to the windows API +which enables Python to talk to the Windows API. Please note that virtual environments are used in the instructions above in order to make sure packages that have different dependencies don't conflict with one another. Because these directions install py4DSTEM to its own virtual environment, each time you want to use py4DSTEM you'll need to activate this environment. You can do this in the command line by running `conda activate py4dstem`, or, if you're using the Anaconda Navigator, by clicking on the Environments tab and then clicking on `py4dstem`. +Last - as of the version 0.14.4 update, we've had a few reports of problems upgrading to the newest version. We're not sure what's causing the issue yet, but have found the new version can be installed successfully in these cases using a fresh Anaconda installation. diff --git a/docs/requirements.txt b/docs/requirements.txt index 43dbc0817..03ecc7e26 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,3 @@ emdfile -# py4dstem \ No newline at end of file +sphinx_rtd_theme +# py4dstem diff --git a/docs/source/conf.py b/docs/source/conf.py index 30ee084fe..6da66611e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -36,7 +36,12 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.intersphinx"] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx_rtd_theme", +] # Other useful extensions # sphinx_copybutton diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index dcb6a861d..adf757d1b 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -53,7 +53,7 @@ ) # strain -from py4DSTEM.process import StrainMap +from py4DSTEM.process.strain.strain import StrainMap # TODO - crystal # TODO - ptycho diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index f4a9d96e1..02692e455 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -1,12 +1,14 @@ # BraggVectors methods -import numpy as np -from scipy.ndimage import gaussian_filter -from warnings import warn import inspect +from warnings import warn -from emdfile import Array, Metadata, tqdmnd, _read_metadata +import matplotlib.pyplot as plt +import numpy as np +from emdfile import Array, Metadata, _read_metadata, tqdmnd +from py4DSTEM import show from py4DSTEM.datacube import VirtualImage +from scipy.ndimage import gaussian_filter class BraggVectorMethods: @@ -517,6 +519,7 @@ def fit_origin( robust_thresh=2, plot=True, plot_range=None, + cmap="RdBu_r", returncalc=True, **kwargs, ): @@ -541,6 +544,8 @@ def fit_origin( Plot the results plot_range : float Min and max color range for plot (pixels) + cmap : colormap + plotting colormap returncalc : bool Toggles returning the answer @@ -571,75 +576,107 @@ def fit_origin( tuple(q_meas) ) - # try to add to calibration + qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin( + tuple(q_meas), + mask=mask, + fitfunction=fitfunction, + robust=robust, + robust_steps=robust_steps, + robust_thresh=robust_thresh, + ) + + # try to add update calibration metadata try: - self.calibration.set_origin([qx0_fit, qy0_fit]) + self.calibration.set_origin((qx0_fit, qy0_fit)) + self.setcal() except AttributeError: warn( "No calibration found on this datacube - fit values are not being stored" ) pass - if plot: - from py4DSTEM.visualize import show_image_grid - if mask is None: - qx0_meas, qy0_meas = q_meas - qx0_res_plot = qx0_residuals - qy0_res_plot = qy0_residuals - else: - qx0_meas = np.ma.masked_array(q_meas[0], mask=np.logical_not(mask)) - qy0_meas = np.ma.masked_array(q_meas[1], mask=np.logical_not(mask)) - qx0_res_plot = np.ma.masked_array( - qx0_residuals, mask=np.logical_not(mask) - ) - qy0_res_plot = np.ma.masked_array( - qy0_residuals, mask=np.logical_not(mask) - ) - qx0_mean = np.mean(qx0_fit) - qy0_mean = np.mean(qy0_fit) - - if plot_range is None: - plot_range = 2 * np.max(qx0_fit - qx0_mean) - - cmap = kwargs.get("cmap", "RdBu_r") - kwargs.pop("cmap", None) - axsize = kwargs.get("axsize", (6, 2)) - kwargs.pop("axsize", None) - - show_image_grid( - lambda i: [ - qx0_meas - qx0_mean, - qx0_fit - qx0_mean, - qx0_res_plot, - qy0_meas - qy0_mean, - qy0_fit - qy0_mean, - qy0_res_plot, - ][i], - H=2, - W=3, + # show + if plot: + self.show_origin_fit( + q_meas[0], + q_meas[1], + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask=mask, + plot_range=plot_range, cmap=cmap, - axsize=axsize, - title=[ - "measured origin, x", - "fitorigin, x", - "residuals, x", - "measured origin, y", - "fitorigin, y", - "residuals, y", - ], - vmin=-1 * plot_range, - vmax=1 * plot_range, - intensity_range="absolute", **kwargs, ) - # update calibration metadata - self.calibration.set_origin((qx0_fit, qy0_fit)) - self.setcal() - + # return if returncalc: return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals + def show_origin_fit( + self, + qx0_meas, + qy0_meas, + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask=None, + plot_range=None, + cmap="RdBu_r", + **kwargs, + ): + # apply mask + if mask is not None: + qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask)) + qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask)) + qx0_residuals = np.ma.masked_array(qx0_residuals, mask=np.logical_not(mask)) + qy0_residuals = np.ma.masked_array(qy0_residuals, mask=np.logical_not(mask)) + qx0_mean = np.mean(qx0_fit) + qy0_mean = np.mean(qy0_fit) + + # set range + if plot_range is None: + plot_range = max( + ( + 1.5 * np.max(np.abs(qx0_fit - qx0_mean)), + 1.5 * np.max(np.abs(qy0_fit - qy0_mean)), + ) + ) + + # set figsize + imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0]) + axsize = (3 * imsize_ratio, 3 / imsize_ratio) + axsize = kwargs.pop("axsize", axsize) + + # plot + fig, ax = show( + [ + [qx0_meas - qx0_mean, qx0_fit - qx0_mean, qx0_residuals], + [qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals], + ], + cmap=cmap, + axsize=axsize, + title=[ + "measured origin, x", + "fitorigin, x", + "residuals, x", + "measured origin, y", + "fitorigin, y", + "residuals, y", + ], + vmin=-1 * plot_range, + vmax=1 * plot_range, + intensity_range="absolute", + show_cbar=True, + returnfig=True, + **kwargs, + ) + plt.tight_layout() + + return + def fit_p_ellipse( self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs ): @@ -775,6 +812,21 @@ def mask_in_R(self, mask, update_inplace=False, returncalc=True): else: return + def to_strainmap(self, name: str = None): + """ + Generate a StrainMap object from the BraggVectors + equivalent to py4DSTEM.StrainMap(braggvectors=braggvectors) + + Args: + name (str, optional): The name of the strainmap. Defaults to None which reverts to default name 'strainmap'. + + Returns: + py4DSTEM.StrainMap: A py4DSTEM StrainMap object generated from the BraggVectors + """ + from py4DSTEM.process.strain import StrainMap + + return StrainMap(self, name) if name else StrainMap(self) + ######### END BraggVectorMethods CLASS ######## diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index 28570ca3a..d293259c9 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -200,7 +200,7 @@ def setcal( if pixel is None: pixel = False if c.get_Q_pixel_size() == 1 else True if rotate is None: - rotate = False if c.get_QR_rotflip() is None else True + rotate = False if c.get_QR_rotation() is None else True # validate requested state if center: @@ -210,7 +210,7 @@ def setcal( if pixel: assert c.get_Q_pixel_size() is not None, "Requested calibration not found" if rotate: - assert c.get_QR_rotflip() is not None, "Requested calibration not found" + assert c.get_QR_rotation() is not None, "Requested calibration not found" # set the calibrations self._calstate = { @@ -272,6 +272,7 @@ def copy(self, name=None): braggvector_copy.set_raw_vectors(self._v_uncal.copy()) for k in self.metadata.keys(): braggvector_copy.metadata = self.metadata[k].copy() + braggvector_copy.setcal() return braggvector_copy # write @@ -479,14 +480,16 @@ def _transform( # Q/R rotation if rotate: flip = cal.get_QR_flip() - theta = np.radians(cal.get_QR_rotation_degrees()) + theta = cal.get_QR_rotation() assert flip is not None, "Requested calibration was not found!" assert theta is not None, "Requested calibration was not found!" + flip = cal.get_QR_flip() + flip = False if flip is None else flip # rotation matrix R = np.array( [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]] ) - # apply + # rotate and flip if flip: positions = R @ np.vstack((ans["qy"], ans["qx"])) else: diff --git a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py index d0f550dcc..c5f89b9fd 100644 --- a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py @@ -17,8 +17,8 @@ try: import cupy as cp -except: - raise ImportError("Import Error: Please install cupy before proceeding") +except ModuleNotFoundError: + raise ImportError("AIML CUDA Requires cupy") try: import tensorflow as tf diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py index f9b963ad4..bfd6a1b76 100644 --- a/py4DSTEM/data/calibration.py +++ b/py4DSTEM/data/calibration.py @@ -205,6 +205,7 @@ def __init__( self["R_pixel_size"] = 1 self["Q_pixel_units"] = "pixels" self["R_pixel_units"] = "pixels" + self["QR_flip"] = False # EMD root property @property @@ -682,8 +683,17 @@ def ellipse(self, x): # Q/R-space rotation and flip + @call_calibrate + def set_QR_rotation(self, x): + self._params["QR_rotation"] = x + self._params["QR_rotation_degrees"] = np.degrees(x) + + def get_QR_rotation(self): + return self._get_value("QR_rotation") + @call_calibrate def set_QR_rotation_degrees(self, x): + self._params["QR_rotation"] = np.radians(x) self._params["QR_rotation_degrees"] = x def get_QR_rotation_degrees(self): @@ -705,10 +715,31 @@ def set_QR_rotflip(self, rot_flip): flip (bool): True indicates a Q/R axes flip """ rot, flip = rot_flip + self._params["QR_rotation"] = rot + self._params["QR_rotation_degrees"] = np.degrees(rot) + self._params["QR_flip"] = flip + + @call_calibrate + def set_QR_rotflip_degrees(self, rot_flip): + """ + Args: + rot_flip (tuple), (rot, flip) where: + rot (number): rotation in degrees + flip (bool): True indicates a Q/R axes flip + """ + rot, flip = rot_flip + self._params["QR_rotation"] = np.radians(rot) self._params["QR_rotation_degrees"] = rot self._params["QR_flip"] = flip def get_QR_rotflip(self): + rot = self.get_QR_rotation() + flip = self.get_QR_flip() + if rot is None or flip is None: + return None + return (rot, flip) + + def get_QR_rotflip_degrees(self): rot = self.get_QR_rotation_degrees() flip = self.get_QR_flip() if rot is None or flip is None: diff --git a/py4DSTEM/preprocess/electroncount.py b/py4DSTEM/preprocess/electroncount.py index 7a498a061..d3c2edd9a 100644 --- a/py4DSTEM/preprocess/electroncount.py +++ b/py4DSTEM/preprocess/electroncount.py @@ -1,8 +1,9 @@ # Electron counting # -# Includes functions for electron counting on either the CPU (electron_count) or the GPU -# (electron_count_GPU). For GPU electron counting, pytorch is used to interface between -# numpy and the GPU, and the datacube is expected in numpy.memmap (memory mapped) form. +# Includes functions for electron counting on either the CPU (electron_count) +# or the GPU (electron_count_GPU). For GPU electron counting, pytorch is used +# to interface between numpy and the GPU, and the datacube is expected in +# numpy.memmap (memory mapped) form. import numpy as np from scipy import optimize @@ -25,17 +26,17 @@ def electron_count( Performs electron counting. The algorithm is as follows: - From a random sampling of frames, calculate an x-ray and background threshold value. - In each frame, subtract the dark reference, then apply the two thresholds. Find all - local maxima with respect to the nearest neighbor pixels. These are considered - electron strike events. - - Thresholds are specified in units of standard deviations, either of a gaussian fit to - the histogram background noise (for thresh_bkgrnd) or of the histogram itself (for - thresh_xray). The background (lower) threshold is more important; we will always be - missing some real electron counts and incorrectly counting some noise as electron - strikes - this threshold controls their relative balance. The x-ray threshold may be - set fairly high. + From a random sampling of frames, calculate an x-ray and background + threshold value. In each frame, subtract the dark reference, then apply the + two thresholds. Find all local maxima with respect to the nearest neighbor + pixels. These are considered electron strike events. + + Thresholds are specified in units of standard deviations, either of a + gaussian fit to the histogram background noise (for thresh_bkgrnd) or of + the histogram itself (for thresh_xray). The background (lower) threshold is + more important; we will always be missing some real electron counts and + incorrectly counting some noise as electron strikes - this threshold + controls their relative balance. The x-ray threshold may be set fairly high. Args: datacube: a 4D numpy.ndarray pointing to the datacube. Note: the R/Q axes are diff --git a/py4DSTEM/preprocess/utils.py b/py4DSTEM/preprocess/utils.py index 0c76f35a7..752e2f81c 100644 --- a/py4DSTEM/preprocess/utils.py +++ b/py4DSTEM/preprocess/utils.py @@ -5,8 +5,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def bin2D(array, factor, dtype=np.float64): diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index e711e907d..0509d181e 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,11 +1,9 @@ from py4DSTEM.process.polar import PolarDatacube -from py4DSTEM.process.strain import StrainMap +from py4DSTEM.process.strain.strain import StrainMap -from py4DSTEM.process import latticevectors from py4DSTEM.process import phase from py4DSTEM.process import calibration from py4DSTEM.process import utils from py4DSTEM.process import classification -from py4DSTEM.process import latticevectors from py4DSTEM.process import diffraction from py4DSTEM.process import wholepatternfit diff --git a/py4DSTEM/process/calibration/rotation.py b/py4DSTEM/process/calibration/rotation.py index 2c3e7bb43..aaf8a49ce 100644 --- a/py4DSTEM/process/calibration/rotation.py +++ b/py4DSTEM/process/calibration/rotation.py @@ -2,6 +2,179 @@ import numpy as np from typing import Optional +import matplotlib.pyplot as plt +from py4DSTEM import show + + +def compare_QR_rotation( + im_R, + im_Q, + QR_rotation, + R_rotation=0, + R_position=None, + Q_position=None, + R_pos_anchor="center", + Q_pos_anchor="center", + R_length=0.33, + Q_length=0.33, + R_width=0.001, + Q_width=0.001, + R_head_length_adjust=1, + Q_head_length_adjust=1, + R_head_width_adjust=1, + Q_head_width_adjust=1, + R_color="r", + Q_color="r", + figsize=(10, 5), + returnfig=False, +): + """ + Visualize a rotational offset between an image in real space, e.g. a STEM + virtual image, and an image in diffraction space, e.g. a defocused CBED + shadow image of the same region, by displaying an arrow overlaid over each + of these two images with the specified QR rotation applied. The QR rotation + is defined as the counter-clockwise rotation from real space to diffraction + space, in degrees. + + Parameters + ---------- + im_R : numpy array or other 2D image-like object (e.g. a VirtualImage) + A real space image, e.g. a STEM virtual image + im_Q : numpy array or other 2D image-like object + A diffraction space image, e.g. a defocused CBED image + QR_rotation : number + The counterclockwise rotation from real space to diffraction space, + in degrees + R_rotation : number + The orientation of the arrow drawn in real space, in degrees + R_position : None or 2-tuple + The position of the anchor point for the R-space arrow. If None, defaults + to the center of the image + Q_position : None or 2-tuple + The position of the anchor point for the Q-space arrow. If None, defaults + to the center of the image + R_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the R-space arrow, i.e. the point being specified by + the `R_position` parameter + Q_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the Q-space arrow, i.e. the point being specified by + the `Q_position` parameter + R_length : number or None + The length of the R-space arrow, as a fraction of the mean size of the + image + Q_length : number or None + The length of the Q-space arrow, as a fraction of the mean size of the + image + R_width : number + The width of the R-space arrow + Q_width : number + The width of the R-space arrow + R_head_length_adjust : number + Scaling factor for the R-space arrow head length + Q_head_length_adjust : number + Scaling factor for the Q-space arrow head length + R_head_width_adjust : number + Scaling factor for the R-space arrow head width + Q_head_width_adjust : number + Scaling factor for the Q-space arrow head width + R_color : color + Color of the R-space arrow + Q_color : color + Color of the Q-space arrow + figsize : 2-tuple + The figure size + returnfig : bool + Toggles returning the figure and axes + """ + # parse inputs + if R_position is None: + R_position = ( + im_R.shape[0] / 2, + im_R.shape[1] / 2, + ) + if Q_position is None: + Q_position = ( + im_Q.shape[0] / 2, + im_Q.shape[1] / 2, + ) + R_length = np.mean(im_R.shape) * R_length + Q_length = np.mean(im_Q.shape) * Q_length + assert R_pos_anchor in ("center", "tail", "head") + assert Q_pos_anchor in ("center", "tail", "head") + + # compute positions + rpos_x, rpos_y = R_position + qpos_x, qpos_y = Q_position + R_rot_rad = np.radians(R_rotation) + Q_rot_rad = np.radians(R_rotation + QR_rotation) + rvecx = np.cos(R_rot_rad) + rvecy = np.sin(R_rot_rad) + qvecx = np.cos(Q_rot_rad) + qvecy = np.sin(Q_rot_rad) + if R_pos_anchor == "center": + x0_r = rpos_x - rvecx * R_length / 2 + y0_r = rpos_y - rvecy * R_length / 2 + x1_r = rpos_x + rvecx * R_length / 2 + y1_r = rpos_y + rvecy * R_length / 2 + elif R_pos_anchor == "tail": + x0_r = rpos_x + y0_r = rpos_y + x1_r = rpos_x + rvecx * R_length + y1_r = rpos_y + rvecy * R_length + elif R_pos_anchor == "head": + x0_r = rpos_x - rvecx * R_length + y0_r = rpos_y - rvecy * R_length + x1_r = rpos_x + y1_r = rpos_y + else: + raise Exception(f"Invalid value for R_pos_anchor {R_pos_anchor}") + if Q_pos_anchor == "center": + x0_q = qpos_x - qvecx * Q_length / 2 + y0_q = qpos_y - qvecy * Q_length / 2 + x1_q = qpos_x + qvecx * Q_length / 2 + y1_q = qpos_y + qvecy * Q_length / 2 + elif Q_pos_anchor == "tail": + x0_q = qpos_x + y0_q = qpos_y + x1_q = qpos_x + qvecx * Q_length + y1_q = qpos_y + qvecy * Q_length + elif Q_pos_anchor == "head": + x0_q = qpos_x - qvecx * Q_length + y0_q = qpos_y - qvecy * Q_length + x1_q = qpos_x + y1_q = qpos_y + else: + raise Exception(f"Invalid value for Q_pos_anchor {Q_pos_anchor}") + + # make the figure + axsize = (figsize[0] / 2, figsize[1]) + fig, axs = show([im_R, im_Q], returnfig=True, axsize=axsize) + axs[0, 0].arrow( + x=y0_r, + y=x0_r, + dx=y1_r - y0_r, + dy=x1_r - x0_r, + color=R_color, + length_includes_head=True, + width=R_width, + head_width=R_length * R_head_width_adjust * 0.072, + head_length=R_length * R_head_length_adjust * 0.1, + ) + axs[0, 1].arrow( + x=y0_q, + y=x0_q, + dx=y1_q - y0_q, + dy=x1_q - x0_q, + color=Q_color, + length_includes_head=True, + width=Q_width, + head_width=Q_length * Q_head_width_adjust * 0.072, + head_length=Q_length * Q_head_length_adjust * 0.1, + ) + if returnfig: + return fig, axs + else: + plt.show() def get_Qvector_from_Rvector(vx, vy, QR_rotation): diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index fa735438b..b508d589e 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -2,23 +2,15 @@ import numpy as np import matplotlib.pyplot as plt +from matplotlib.patches import Circle from fractions import Fraction from typing import Union, Optional -from scipy.optimize import curve_fit import sys -from emdfile import tqdmnd, PointList, PointListArray +from emdfile import PointList from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom -from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -from py4DSTEM.process.diffraction.crystal_viz import plot_ring_pattern -from py4DSTEM.process.diffraction.utils import Orientation, calc_1D_profile - -try: - from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - from pymatgen.core.structure import Structure -except ImportError: - pass +from py4DSTEM.process.diffraction.utils import Orientation class Crystal: @@ -36,6 +28,8 @@ class Crystal: orientation_plan, match_orientations, match_single_pattern, + cluster_grains, + cluster_orientation_map, calculate_strain, save_ang_file, symmetry_reduce_directions, @@ -51,6 +45,8 @@ class Crystal: plot_orientation_plan, plot_orientation_maps, plot_fiber_orientation_maps, + plot_clusters, + plot_cluster_size, ) from py4DSTEM.process.diffraction.crystal_calibrate import ( @@ -1074,3 +1070,517 @@ def calculate_bragg_peak_histogram( int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power) int_exp /= np.max(int_exp) return k, int_exp + + +def generate_moire_diffraction_pattern( + bragg_peaks_0, + bragg_peaks_1, + thresh_0=0.0002, + thresh_1=0.0002, + exx_1=0.0, + eyy_1=0.0, + exy_1=0.0, + phi_1=0.0, + power=2.0, +): + """ + Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated + and strained with respect to the original lattice. Note that this strain is applied in real space, + and so the inverse of the calculated infinitestimal strain tensor is applied. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + thresh_0: float + Intensity threshold for structure factors from lattice 0. + thresh_1: float + Intensity threshold for structure factors from lattice 1. + exx_1: float + Strain of lattice 1 in x direction (vertical) in real space. + eyy_1: float + Strain of lattice 1 in y direction (horizontal) in real space. + exy_1: float + Shear strain of lattice 1 in (x,y) direction (diagonal) in real space. + phi_1: float + Rotation of lattice 1 in real space. + power: float + Plotting power law (default is amplitude**2.0, i.e. intensity). + + Returns + -------- + parent_peaks_0, parent_peaks_1, moire_peaks: BraggVectors + Bragg vectors for the rotated & strained parent lattices + and the moire lattice + + """ + + # get intenties of all peaks + int0 = bragg_peaks_0["intensity"] ** (power / 2.0) + int1 = bragg_peaks_1["intensity"] ** (power / 2.0) + + # peaks above threshold + sub0 = int0 >= thresh_0 + sub1 = int1 >= thresh_1 + + # Remove origin (assuming brightest peak) + ind0_or = np.argmax(bragg_peaks_0["intensity"]) + ind1_or = np.argmax(bragg_peaks_1["intensity"]) + sub0[ind0_or] = False + sub1[ind1_or] = False + int0_sub = int0[sub0] + int1_sub = int1[sub1] + + # Get peaks + qx0 = bragg_peaks_0["qx"][sub0] + qy0 = bragg_peaks_0["qy"][sub0] + qx1_init = bragg_peaks_1["qx"][sub1] + qy1_init = bragg_peaks_1["qy"][sub1] + + # peak labels + h0 = bragg_peaks_0["h"][sub0] + k0 = bragg_peaks_0["k"][sub0] + l0 = bragg_peaks_0["l"][sub0] + h1 = bragg_peaks_1["h"][sub1] + k1 = bragg_peaks_1["k"][sub1] + l1 = bragg_peaks_1["l"][sub1] + + # apply strain tensor to lattice 1 + m = np.array( + [ + [np.cos(phi_1), -np.sin(phi_1)], + [np.sin(phi_1), np.cos(phi_1)], + ] + ) @ np.linalg.inv( + np.array( + [ + [1 + exx_1, exy_1 * 0.5], + [exy_1 * 0.5, 1 + eyy_1], + ] + ) + ) + qx1 = m[0, 0] * qx1_init + m[0, 1] * qy1_init + qy1 = m[1, 0] * qx1_init + m[1, 1] * qy1_init + + # Generate moire lattice + ind0, ind1 = np.meshgrid( + np.arange(np.sum(sub0)), + np.arange(np.sum(sub1)), + indexing="ij", + ) + qx = qx0[ind0] + qx1[ind1] + qy = qy0[ind0] + qy1[ind1] + int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5 + + # moire labels + m_h0 = h0[ind0] + m_k0 = k0[ind0] + m_l0 = l0[ind0] + m_h1 = h1[ind1] + m_k1 = k1[ind1] + m_l1 = l1[ind1] + + # Convert thresholded and moire peaks to BraggVector class + + pl_dtype_parent = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h", "int"), + ("k", "int"), + ("l", "int"), + ] + ) + + bragg_parent_0 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_0.add_data_by_field( + [ + qx0.ravel(), + qy0.ravel(), + int0_sub.ravel(), + h0.ravel(), + k0.ravel(), + l0.ravel(), + ] + ) + + bragg_parent_1 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_1.add_data_by_field( + [ + qx1.ravel(), + qy1.ravel(), + int1_sub.ravel(), + h1.ravel(), + k1.ravel(), + l1.ravel(), + ] + ) + + pl_dtype = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h0", "int"), + ("k0", "int"), + ("l0", "int"), + ("h1", "int"), + ("k1", "int"), + ("l1", "int"), + ] + ) + bragg_moire = PointList(np.array([], dtype=pl_dtype)) + bragg_moire.add_data_by_field( + [ + qx.ravel(), + qy.ravel(), + int_moire.ravel(), + m_h0.ravel(), + m_k0.ravel(), + m_l0.ravel(), + m_h1.ravel(), + m_k1.ravel(), + m_l1.ravel(), + ] + ) + + return bragg_parent_0, bragg_parent_1, bragg_moire + + +def plot_moire_diffraction_pattern( + bragg_parent_0, + bragg_parent_1, + bragg_moire, + int_range=(0, 5e-3), + k_max=1.0, + plot_subpixel=True, + labels=None, + marker_size_parent=16, + marker_size_moire=4, + text_size_parent=10, + text_size_moire=6, + add_labels_parent=False, + add_labels_moire=False, + dist_labels=0.03, + dist_check=0.06, + sep_labels=0.03, + figsize=(8, 6), + returnfig=False, +): + """ + Plot Moire lattice and parent lattices. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + bragg_moire: BraggVector + Bragg vectors for moire lattice. + int_range: (float, float) + Plotting intensity range for the Moire peaks. + k_max: float + Max k value of the plotted Moire lattice. + plot_subpixel: bool + Apply subpixel corrections to the Bragg spot positions. + Matplotlib default scatter plot rounds to the nearest pixel. + labels: list + List of text labels for parent lattices + marker_size_parent: float + Size of plot markers for the two parent lattices. + marker_size_moire: float + Size of plot markers for the Moire lattice. + text_size_parent: float + Label text size for parent lattice. + text_size_moire: float + Label text size for Moire lattice. + add_labels_parent: bool + Plot the parent lattice index labels. + add_labels_moire: bool + Plot the parent lattice index labels for the Moire spots. + dist_labels: float + Distance to move the labels off the spots. + dist_check: float + Set to some distance to "push" the labels away from each other if they are within this distance. + sep_labels: float + Separation distance for labels which are "pushed" apart. + figsize: (float,float) + Size of output figure. + returnfig: bool + Return the (fix,ax) handles of the plot. + + Returns + -------- + fig, ax: matplotlib handles (optional) + Figure and axes handles for the moire plot. + """ + + # peak labels + + if labels is None: + labels = ("crystal 0", "crystal 1") + + def overline(x): + return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") + + # parent 1 + qx0 = bragg_parent_0["qx"] + qy0 = bragg_parent_0["qy"] + h0 = bragg_parent_0["h"] + k0 = bragg_parent_0["k"] + l0 = bragg_parent_0["l"] + + # parent 2 + qx1 = bragg_parent_1["qx"] + qy1 = bragg_parent_1["qy"] + h1 = bragg_parent_1["h"] + k1 = bragg_parent_1["k"] + l1 = bragg_parent_1["l"] + + # moire + qx = bragg_moire["qx"] + qy = bragg_moire["qy"] + m_h0 = bragg_moire["h0"] + m_k0 = bragg_moire["k0"] + m_l0 = bragg_moire["l0"] + m_h1 = bragg_moire["h1"] + m_k1 = bragg_moire["k1"] + m_l1 = bragg_moire["l1"] + int_moire = bragg_moire["intensity"] + + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0.09, 0.09, 0.65, 0.9]) + ax_labels = fig.add_axes([0.75, 0, 0.25, 1]) + + text_params_parent = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_parent, + } + text_params_moire = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_moire, + } + + if plot_subpixel is False: + # moire + ax.scatter( + qy, + qx, + # color = (0,0,0,1), + c=int_moire, + s=marker_size_moire, + cmap="gray_r", + vmin=int_range[0], + vmax=int_range[1], + antialiased=True, + ) + + # parent lattices + ax.scatter( + qy0, + qx0, + color=(1, 0, 0, 1), + s=marker_size_parent, + antialiased=True, + ) + ax.scatter( + qy1, + qx1, + color=(0, 0.7, 1, 1), + s=marker_size_parent, + antialiased=True, + ) + + # origin + ax.scatter( + 0, + 0, + color=(0, 0, 0, 1), + s=marker_size_parent, + antialiased=True, + ) + + else: + # moire peaks + int_all = np.clip( + (int_moire - int_range[0]) / (int_range[1] - int_range[0]), 0, 1 + ) + keep = np.logical_and.reduce( + (qx >= -k_max, qx <= k_max, qy >= -k_max, qy <= k_max) + ) + for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_moire) / 800.0, + color=(1 - int_marker, 1 - int_marker, 1 - int_marker), + ) + ) + if add_labels_moire: + for a0 in range(qx.size): + if keep.ravel()[a0]: + x0 = qx.ravel()[a0] + y0 = qy.ravel()[a0] + d2 = (qx.ravel() - x0) ** 2 + (qy.ravel() - y0) ** 2 + sub = d2 < dist_check**2 + xc = np.mean(qx.ravel()[sub]) + yc = np.mean(qy.ravel()[sub]) + xp = x0 - xc + yp = y0 - yc + if xp == 0 and yp == 0.0: + xp = x0 - dist_labels + yp = y0 + else: + leng = np.linalg.norm((xp, yp)) + xp = x0 + xp * dist_labels / leng + yp = y0 + yp * dist_labels / leng + + ax.text( + yp, + xp - sep_labels, + "$" + + overline(m_h0.ravel()[a0]) + + overline(m_k0.ravel()[a0]) + + overline(m_l0.ravel()[a0]) + + "$", + c="r", + **text_params_moire, + ) + ax.text( + yp, + xp, + "$" + + overline(m_h1.ravel()[a0]) + + overline(m_k1.ravel()[a0]) + + overline(m_l1.ravel()[a0]) + + "$", + c=(0, 0.7, 1.0), + **text_params_moire, + ) + + keep = np.logical_and.reduce( + (qx0 >= -k_max, qx0 <= k_max, qy0 >= -k_max, qy0 <= k_max) + ) + for x, y in zip(qx0[keep], qy0[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(1, 0, 0), + ) + ) + if add_labels_parent: + for a0 in range(qx0.size): + if keep.ravel()[a0]: + xp = qx0.ravel()[a0] - dist_labels + yp = qy0.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h0.ravel()[a0]) + + overline(k0.ravel()[a0]) + + overline(l0.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + keep = np.logical_and.reduce( + (qx1 >= -k_max, qx1 <= k_max, qy1 >= -k_max, qy1 <= k_max) + ) + for x, y in zip(qx1[keep], qy1[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0.7, 1), + ) + ) + if add_labels_parent: + for a0 in range(qx1.size): + if keep.ravel()[a0]: + xp = qx1.ravel()[a0] - dist_labels + yp = qy1.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h1.ravel()[a0]) + + overline(k1.ravel()[a0]) + + overline(l1.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + # origin + ax.add_artist( + Circle( + xy=(0, 0), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0, 0), + ) + ) + + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + ax.set_ylabel("$q_x$ (1/A)") + ax.set_xlabel("$q_y$ (1/A)") + ax.invert_yaxis() + + # labels + ax_labels.scatter( + 0, + 0, + color=(1, 0, 0, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -1, + color=(0, 0.7, 1, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -2, + color=(0, 0, 0, 1), + s=marker_size_moire, + ) + ax_labels.text( + 0.4, + -0.2, + labels[0], + fontsize=14, + ) + ax_labels.text( + 0.4, + -1.2, + labels[1], + fontsize=14, + ) + ax_labels.text( + 0.4, + -2.2, + "Moiré lattice", + fontsize=14, + ) + + ax_labels.set_xlim((-1, 4)) + ax_labels.set_ylim((-21, 1)) + + ax_labels.axis("off") + + if returnfig: + return fig, ax diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index 523260e61..4926bd445 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1,7 +1,7 @@ import numpy as np import matplotlib.pyplot as plt -import os from typing import Union, Optional +from tqdm import tqdm from emdfile import tqdmnd, PointList, PointListArray from py4DSTEM.data import RealSlice @@ -14,7 +14,7 @@ try: import cupy as cp -except: +except ModuleNotFoundError: cp = None @@ -29,6 +29,7 @@ def orientation_plan( corr_kernel_size: float = 0.08, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling + calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, fiber_axis=None, @@ -61,6 +62,8 @@ def orientation_plan( corr_kernel_size (float): Correlation kernel size length in Angstroms radial_power (float): Power for scaling the correlation intensity as a function of the peak radius intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity + calculate_correlation_array (bool): Set to false to skip calculating the correlation array. + This is useful when we only want the angular range / rotation matrices. tol_peak_delete (float): Distance to delete peaks for multiple matches. Default is kernel_size * 0.5 tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] @@ -598,21 +601,6 @@ def orientation_plan( # init storage arrays self.orientation_rotation_angles = np.zeros((self.orientation_num_zones, 3)) self.orientation_rotation_matrices = np.zeros((self.orientation_num_zones, 3, 3)) - self.orientation_ref = np.zeros( - ( - self.orientation_num_zones, - np.size(self.orientation_shell_radii), - self.orientation_in_plane_steps, - ), - dtype="float", - ) - # self.orientation_ref_1D = np.zeros( - # ( - # self.orientation_num_zones, - # np.size(self.orientation_shell_radii), - # ), - # dtype="float", - # ) # If possible, Get symmetry operations for this spacegroup, store in matrix form if self.pymatgen_available: @@ -697,79 +685,110 @@ def orientation_plan( k0 = np.array([0.0, 0.0, -1.0 / self.wavelength]) n = np.array([0.0, 0.0, -1.0]) - for a0 in tqdmnd( - np.arange(self.orientation_num_zones), - desc="Orientation plan", - unit=" zone axes", - disable=not progress_bar, - ): - # reciprocal lattice spots and excitation errors - g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all - sg = self.excitation_errors(g) - - # Keep only points that will contribute to this orientation plan slice - keep = np.abs(sg) < self.orientation_kernel_size - - # in-plane rotation angle - phi = np.arctan2(g[1, :], g[0, :]) - - # Loop over all peaks - for a1 in np.arange(self.g_vec_all.shape[1]): - ind_radial = self.orientation_shell_index[a1] + if calculate_correlation_array: + # initialize empty correlation array + self.orientation_ref = np.zeros( + ( + self.orientation_num_zones, + np.size(self.orientation_shell_radii), + self.orientation_in_plane_steps, + ), + dtype="float", + ) - if keep[a1] and ind_radial >= 0: - # 2D orientation plan - self.orientation_ref[a0, ind_radial, :] += ( - np.power(self.orientation_shell_radii[ind_radial], radial_power) - * np.power(self.struct_factors_int[a1], intensity_power) - * np.maximum( - 1 - - np.sqrt( - sg[a1] ** 2 - + ( - ( - np.mod( - self.orientation_gamma - phi[a1] + np.pi, - 2 * np.pi, + for a0 in tqdmnd( + np.arange(self.orientation_num_zones), + desc="Orientation plan", + unit=" zone axes", + disable=not progress_bar, + ): + # reciprocal lattice spots and excitation errors + g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all + sg = self.excitation_errors(g) + + # Keep only points that will contribute to this orientation plan slice + keep = np.abs(sg) < self.orientation_kernel_size + + # in-plane rotation angle + phi = np.arctan2(g[1, :], g[0, :]) + + # Loop over all peaks + for a1 in np.arange(self.g_vec_all.shape[1]): + ind_radial = self.orientation_shell_index[a1] + + if keep[a1] and ind_radial >= 0: + # 2D orientation plan + self.orientation_ref[a0, ind_radial, :] += ( + np.power(self.orientation_shell_radii[ind_radial], radial_power) + * np.power(self.struct_factors_int[a1], intensity_power) + * np.maximum( + 1 + - np.sqrt( + sg[a1] ** 2 + + ( + ( + np.mod( + self.orientation_gamma - phi[a1] + np.pi, + 2 * np.pi, + ) + - np.pi ) - - np.pi + * self.orientation_shell_radii[ind_radial] ) - * self.orientation_shell_radii[ind_radial] + ** 2 ) - ** 2 + / self.orientation_kernel_size, + 0, ) - / self.orientation_kernel_size, - 0, ) - ) - orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) - if orientation_ref_norm > 0: - self.orientation_ref[a0, :, :] /= orientation_ref_norm + orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) + if orientation_ref_norm > 0: + self.orientation_ref[a0, :, :] /= orientation_ref_norm - # Maximum value - self.orientation_ref_max = np.max(np.real(self.orientation_ref)) + # Maximum value + self.orientation_ref_max = np.max(np.real(self.orientation_ref)) - # Fourier domain along angular axis - if self.CUDA: - self.orientation_ref = cp.asarray(self.orientation_ref) - self.orientation_ref = cp.conj(cp.fft.fft(self.orientation_ref)) - else: - self.orientation_ref = np.conj(np.fft.fft(self.orientation_ref)) + # Fourier domain along angular axis + if self.CUDA: + self.orientation_ref = cp.asarray(self.orientation_ref) + self.orientation_ref = cp.conj(cp.fft.fft(self.orientation_ref)) + else: + self.orientation_ref = np.conj(np.fft.fft(self.orientation_ref)) def match_orientations( self, bragg_peaks_array: PointListArray, num_matches_return: int = 1, - min_number_peaks=3, - inversion_symmetry=True, - multiple_corr_reset=True, - progress_bar: bool = True, + min_angle_between_matches_deg=None, + min_number_peaks: int = 3, + inversion_symmetry: bool = True, + multiple_corr_reset: bool = True, return_orientation: bool = True, + progress_bar: bool = True, ): """ - This function computes the orientation of any number of PointLists stored in a PointListArray, and returns an OrienationMap. + Parameters + -------- + bragg_peaks_array: PointListArray + PointListArray containing the Bragg peaks and intensities, with calibrations applied + num_matches_return: int + return these many matches as 3th dim of orient (matrix) + min_angle_between_matches_deg: int + Minimum angle between zone axis of multiple matches, in degrees. + Note that I haven't thought how to handle in-plane rotations, since multiple matches are possible. + min_number_peaks: int + Minimum number of peaks required to perform ACOM matching + inversion_symmetry: bool + check for inversion symmetry in the matches + multiple_corr_reset: bool + keep original correlation score for multiple matches + return_orientation: bool + Return orientation map from function for inspection. + The map is always stored in the Crystal object. + progress_bar: bool + Show or hide the progress bar """ orientation_map = OrientationMap( @@ -808,6 +827,7 @@ def match_orientations( orientation = self.match_single_pattern( bragg_peaks=vectors, num_matches_return=num_matches_return, + min_angle_between_matches_deg=min_angle_between_matches_deg, min_number_peaks=min_number_peaks, inversion_symmetry=inversion_symmetry, multiple_corr_reset=multiple_corr_reset, @@ -816,6 +836,8 @@ def match_orientations( ) orientation_map.set_orientation(orientation, rx, ry) + + # assign and return self.orientation_map = orientation_map if return_orientation: @@ -828,6 +850,7 @@ def match_single_pattern( self, bragg_peaks: PointList, num_matches_return: int = 1, + min_angle_between_matches_deg=None, min_number_peaks=3, inversion_symmetry=True, multiple_corr_reset=True, @@ -842,26 +865,51 @@ def match_single_pattern( """ Solve for the best fit orientation of a single diffraction pattern. - Args: - bragg_peaks (PointList): numpy array containing the Bragg positions and intensities ('qx', 'qy', 'intensity') - num_matches_return (int): return these many matches as 3th dim of orient (matrix) - min_number_peaks (int): Minimum number of peaks required to perform ACOM matching - inversion_symmetry (bool): check for inversion symmetry in the matches - multiple_corr_reset (bool): keep original correlation score for multiple matches - subpixel_tilt (bool): set to false for faster matching, returning the nearest corr point - plot_polar (bool): set to true to plot the polar transform of the diffraction pattern - plot_corr (bool): set to true to plot the resulting correlogram - returnfig (bool): Return figure handles - figsize (list): size of figure - verbose (bool): Print the fitted zone axes, correlation scores - CUDA (bool): Enable CUDA for the FFT steps - - Returns: - orientation (Orientation): Orientation class containing all outputs - fig, ax (handles): Figure handles for the plotting output + Parameters + -------- + bragg_peaks: PointList + numpy array containing the Bragg positions and intensities ('qx', 'qy', 'intensity') + num_matches_return: int + return these many matches as 3th dim of orient (matrix) + min_angle_between_matches_deg: int + Minimum angle between zone axis of multiple matches, in degrees. + Note that I haven't thought how to handle in-plane rotations, since multiple matches are possible. + min_number_peaks: int + Minimum number of peaks required to perform ACOM matching + inversion_symmetry bool + check for inversion symmetry in the matches + multiple_corr_reset bool + keep original correlation score for multiple matches + subpixel_tilt: bool + set to false for faster matching, returning the nearest corr point + plot_polar: bool + set to true to plot the polar transform of the diffraction pattern + plot_corr: bool + set to true to plot the resulting correlogram + returnfig: bool + return figure handles + figsize: list + size of figure + verbose: bool + Print the fitted zone axes, correlation scores + CUDA: bool + Enable CUDA for the FFT steps + + Returns + -------- + orientation: Orientation + Orientation class containing all outputs + fig, ax: handles + Figure handles for the plotting output """ - # init orientation output + # adding assert statement for checking self.orientation_ref is present + # adding assert statement for checking self.orientation_ref is present + if not hasattr(self, "orientation_ref"): + raise ValueError( + "orientation_plan must be run with 'calculate_correlation_array=True'" + ) + orientation = Orientation(num_matches=num_matches_return) if bragg_peaks.data.shape[0] < min_number_peaks: return orientation @@ -1029,6 +1077,25 @@ def match_single_pattern( 0, ) + # If minimum angle is specified and we're on a match later than the first, + # we zero correlation values within the given range. + if min_angle_between_matches_deg is not None: + if match_ind > 0: + inds_previous = orientation.inds[:match_ind, 0] + for a0 in range(inds_previous.size): + mask_zero = np.arccos( + np.clip( + np.sum( + self.orientation_vecs + * self.orientation_vecs[inds_previous[a0], :], + axis=1, + ), + -1, + 1, + ) + ) < np.deg2rad(min_angle_between_matches_deg) + corr_full[mask_zero, :] = 0.0 + # Get maximum (non inverted) correlation value ind_phi = np.argmax(corr_full, axis=1) @@ -1096,6 +1163,26 @@ def match_single_pattern( ), 0, ) + + # If minimum angle is specified and we're on a match later than the first, + # we zero correlation values within the given range. + if min_angle_between_matches_deg is not None: + if match_ind > 0: + inds_previous = orientation.inds[:match_ind, 0] + for a0 in range(inds_previous.size): + mask_zero = np.arccos( + np.clip( + np.sum( + self.orientation_vecs + * self.orientation_vecs[inds_previous[a0], :], + axis=1, + ), + -1, + 1, + ) + ) < np.deg2rad(min_angle_between_matches_deg) + corr_full_inv[mask_zero, :] = 0.0 + ind_phi_inv = np.argmax(corr_full_inv, axis=1) corr_inv = np.zeros(self.orientation_num_zones, dtype="bool") @@ -1686,6 +1773,250 @@ def match_single_pattern( return orientation +def cluster_grains( + self, + threshold_add=1.0, + threshold_grow=0.1, + angle_tolerance_deg=5.0, + progress_bar=True, +): + """ + Cluster grains using rotation criterion, and correlation values. + + Parameters + -------- + threshold_add: float + Minimum signal required for a probe position to initialize a cluster. + threshold_grow: float + Minimum signal required for a probe position to be added to a cluster. + angle_tolerance_deg: float + Rotation rolerance for clustering grains. + progress_bar: bool + Turns on the progress bar for the polar transformation + + """ + + # symmetry operators + sym = self.symmetry_operators + + # Get data + # Correlation data = signal to cluster with + sig = self.orientation_map.corr.copy() + sig_init = sig.copy() + mark = sig >= threshold_grow + sig[np.logical_not(mark)] = 0 + # orientation matrix used for angle tolerance + matrix = self.orientation_map.matrix.copy() + + # init + self.cluster_sizes = np.array((), dtype="int") + self.cluster_sig = np.array(()) + self.cluster_inds = [] + self.cluster_orientation = [] + inds_all = np.zeros_like(sig, dtype="int") + inds_all.ravel()[:] = np.arange(inds_all.size) + + # Tolerance + tol = np.deg2rad(angle_tolerance_deg) + + # Main loop + search = True + comp = 0.0 + mark_total = np.sum(np.max(mark, axis=2)) + pbar = tqdm(total=mark_total, disable=not progress_bar) + while search is True: + inds_grain = np.argmax(sig) + + val = sig.ravel()[inds_grain] + + if val < threshold_add: + search = False + + else: + # Start cluster + x, y, z = np.unravel_index(inds_grain, sig.shape) + mark[x, y, z] = False + sig[x, y, z] = 0 + matrix_cluster = matrix[x, y, z] + orientation_cluster = self.orientation_map.get_orientation_single(x, y, z) + + # Neighbors to search + xr = np.clip(x + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1) + yr = np.clip(y + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1) + inds_cand = inds_all[xr[:, None], yr[None], :].ravel() + inds_cand = np.delete(inds_cand, mark.ravel()[inds_cand] == False) + + if inds_cand.size == 0: + grow = False + else: + grow = True + + # grow the cluster + while grow is True: + inds_new = np.array((), dtype="int") + + keep = np.zeros(inds_cand.size, dtype="bool") + for a0 in range(inds_cand.size): + xc, yc, zc = np.unravel_index(inds_cand[a0], sig.shape) + + # Angle test between orientation matrices + dphi = np.min( + np.arccos( + np.clip( + ( + np.trace( + self.symmetry_operators + @ matrix[xc, yc, zc] + @ np.transpose(matrix_cluster), + axis1=1, + axis2=2, + ) + - 1 + ) + / 2, + -1, + 1, + ) + ) + ) + + if np.abs(dphi) < tol: + keep[a0] = True + + sig[xc, yc, zc] = 0 + mark[xc, yc, zc] = False + + xr = np.clip( + xc + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1 + ) + yr = np.clip( + yc + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1 + ) + inds_add = inds_all[xr[:, None], yr[None], :].ravel() + inds_new = np.append(inds_new, inds_add) + + inds_grain = np.append(inds_grain, inds_cand[keep]) + inds_cand = np.unique( + np.delete(inds_new, mark.ravel()[inds_new] == False) + ) + + if inds_cand.size == 0: + grow = False + + # convert grain to x,y coordinates, add = list + xg, yg, zg = np.unravel_index(inds_grain, sig.shape) + xyg = np.unique(np.vstack((xg, yg)), axis=1) + sig_mean = np.mean(sig_init.ravel()[inds_grain]) + self.cluster_sizes = np.append(self.cluster_sizes, xyg.shape[1]) + self.cluster_sig = np.append(self.cluster_sig, sig_mean) + self.cluster_orientation.append(orientation_cluster) + self.cluster_inds.append(xyg) + + # update progressbar + new_marks = mark_total - np.sum(np.max(mark, axis=2)) + pbar.update(new_marks) + mark_total -= new_marks + + pbar.close() + + +def cluster_orientation_map( + self, + stripe_width=(2, 2), + area_min=2, +): + """ + Produce a new orientation map from the clustered grains. + Use a stripe pattern for the overlapping grains. + + Parameters + -------- + stripe_width: (int,int) + Width of stripes for plotting maps with overlapping grains + area_min: (int) + Minimum size of grains to include + + Returns + -------- + + orientation_map + The clustered orientation map + + """ + + # init + orientation_map = OrientationMap( + num_x=self.orientation_map.num_x, + num_y=self.orientation_map.num_y, + num_matches=1, + ) + im_grain = np.zeros( + (self.orientation_map.num_x, self.orientation_map.num_y), dtype="bool" + ) + im_count = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y)) + im_mark = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y)) + + # Loop over grains to determine number in each pixel + for a0 in range(self.cluster_sizes.shape[0]): + if self.cluster_sizes[a0] >= area_min: + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + im_count += im_grain + im_stripe = im_count >= 2 + im_single = np.logical_not(im_stripe) + + # prefactor for stripes + if stripe_width[0] == 0: + dx = 0 + else: + dx = 1 / stripe_width[0] + if stripe_width[1] == 0: + dy = 0 + else: + dy = 1 / stripe_width[1] + + # loop over grains + for a0 in range(self.cluster_sizes.shape[0]): + if self.cluster_sizes[a0] >= area_min: + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + + # non-overlapping grains + sub = np.logical_and(im_grain, im_single) + x, y = np.unravel_index(np.where(sub.ravel()), im_grain.shape) + x = np.atleast_1d(np.squeeze(x)) + y = np.atleast_1d(np.squeeze(y)) + for a1 in range(x.size): + orientation_map.set_orientation( + self.cluster_orientation[a0], x[a1], y[a1] + ) + + # overlapping grains + sub = np.logical_and(im_grain, im_stripe) + x, y = np.unravel_index(np.where(sub.ravel()), im_grain.shape) + x = np.atleast_1d(np.squeeze(x)) + y = np.atleast_1d(np.squeeze(y)) + for a1 in range(x.size): + d = np.mod( + x[a1] * dx + y[a1] * dy + im_mark[x[a1], y[a1]] + +0.5, + im_count[x[a1], y[a1]], + ) + + if d < 1.0: + orientation_map.set_orientation( + self.cluster_orientation[a0], x[a1], y[a1] + ) + im_mark[x[a1], y[a1]] += 1 + + return orientation_map + + def calculate_strain( self, bragg_peaks_array: PointListArray, diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 8e87b322f..a86454b58 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -5,6 +5,8 @@ from mpl_toolkits.mplot3d import Axes3D, art3d from scipy.signal import medfilt from scipy.ndimage import gaussian_filter +from scipy.ndimage.morphology import distance_transform_edt +from skimage.morphology import dilation, erosion import warnings import numpy as np @@ -147,7 +149,7 @@ def plot_structure( zs=xyz[sub, 2], # + d[2], s=size_marker, linewidth=2, - color=atomic_colors(ID_plot), + facecolors=atomic_colors(ID_plot), edgecolor=[0, 0, 0], ) @@ -1005,7 +1007,7 @@ def overline(x): def plot_orientation_maps( self, - orientation_map, + orientation_map=None, orientation_ind: int = 0, dir_in_plane_degrees: float = 0.0, corr_range: np.ndarray = np.array([0, 5]), @@ -1026,6 +1028,7 @@ def plot_orientation_maps( Args: orientation_map (OrientationMap): Class containing orientation matrices, correlation values, etc. + Optional - can reference internally stored OrientationMap. orientation_ind (int): Which orientation match to plot if num_matches > 1 dir_in_plane_degrees (float): In-plane angle to plot in degrees. Default is 0 / x-axis / vertical down. corr_range (np.ndarray): Correlation intensity range for the plot @@ -1053,6 +1056,9 @@ def plot_orientation_maps( """ # Inputs + if orientation_map is None: + orientation_map = self.orientation_map + # Legend size leg_size = np.array([300, 300], dtype="int") @@ -1736,6 +1742,205 @@ def plot_fiber_orientation_maps( return images_orientation +def plot_clusters( + self, + area_min=2, + outline_grains=True, + outline_thickness=1, + fill_grains=0.25, + smooth_grains=1.0, + cmap="viridis", + figsize=(8, 8), + returnfig=False, +): + """ + Plot the clusters as an image. + + Parameters + -------- + area_min: int (optional) + Min cluster size to include, in units of probe positions. + outline_grains: bool (optional) + Set to True to draw grains with outlines + outline_thickness: int (optional) + Thickenss of the grain outline + fill_grains: float (optional) + Outlined grains are filled with this value in pixels. + smooth_grains: float (optional) + Grain boundaries are smoothed by this value in pixels. + figsize: tuple + Size of the figure panel + returnfig: bool + Setting this to true returns the figure and axis handles + + Returns + -------- + fig, ax (optional) + Figure and axes handles + + """ + + # init + im_plot = np.zeros( + ( + self.orientation_map.num_x, + self.orientation_map.num_y, + ) + ) + im_grain = np.zeros( + ( + self.orientation_map.num_x, + self.orientation_map.num_y, + ), + dtype="bool", + ) + + # make plotting image + + for a0 in range(self.cluster_sizes.shape[0]): + if self.cluster_sizes[a0] >= area_min: + if outline_grains: + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + + im_dist = distance_transform_edt( + erosion( + np.invert(im_grain), footprint=np.ones((3, 3), dtype="bool") + ) + ) - distance_transform_edt(im_grain) + im_dist = gaussian_filter(im_dist, sigma=smooth_grains, mode="nearest") + im_add = np.exp(im_dist**2 / (-0.5 * outline_thickness**2)) + + if fill_grains > 0: + im_dist = distance_transform_edt( + erosion( + np.invert(im_grain), footprint=np.ones((3, 3), dtype="bool") + ) + ) + im_dist = gaussian_filter( + im_dist, sigma=smooth_grains, mode="nearest" + ) + im_add += fill_grains * np.exp( + im_dist**2 / (-0.5 * outline_thickness**2) + ) + + # im_add = 1 - np.exp( + # distance_transform_edt(im_grain)**2 \ + # / (-2*outline_thickness**2)) + im_plot += im_add + # im_plot = np.minimum(im_plot, im_add) + else: + # xg,yg = np.unravel_index(self.cluster_inds[a0], im_plot.shape) + im_grain[:] = False + im_grain[ + self.cluster_inds[a0][0, :], + self.cluster_inds[a0][1, :], + ] = True + im_plot += gaussian_filter( + im_grain.astype("float"), sigma=smooth_grains, mode="nearest" + ) + + # im_plot[ + # self.cluster_inds[a0][0,:], + # self.cluster_inds[a0][1,:], + # ] += 1 + + if outline_grains: + im_plot = np.clip(im_plot, 0, 2) + + # plotting + fig, ax = plt.subplots(figsize=figsize) + ax.imshow( + im_plot, + # vmin = -3, + # vmax = 3, + cmap=cmap, + ) + + +def plot_cluster_size( + self, + area_min=None, + area_max=None, + area_step=1, + weight_intensity=False, + pixel_area=1.0, + pixel_area_units="px^2", + figsize=(8, 6), + returnfig=False, +): + """ + Plot the cluster sizes + + Parameters + -------- + area_min: int (optional) + Min area to include in pixels^2 + area_max: int (optional) + Max area bin in pixels^2 + area_step: int (optional) + Step size of the histogram bin in pixels^2 + weight_intensity: bool + Weight histogram by the peak intensity. + pixel_area: float + Size of pixel area unit square + pixel_area_units: string + Units of the pixel area + figsize: tuple + Size of the figure panel + returnfig: bool + Setting this to true returns the figure and axis handles + + Returns + -------- + fig, ax (optional) + Figure and axes handles + + """ + + if area_max is None: + area_max = np.max(self.cluster_sizes) + area = np.arange(0, area_max, area_step) + if area_min is None: + sub = self.cluster_sizes.astype("int") < area_max + else: + sub = np.logical_and( + self.cluster_sizes.astype("int") >= area_min, + self.cluster_sizes.astype("int") < area_max, + ) + if weight_intensity: + hist = np.bincount( + self.cluster_sizes[sub] // area_step, + weights=self.cluster_sig[sub], + minlength=area.shape[0], + ) + else: + hist = np.bincount( + self.cluster_sizes[sub] // area_step, + minlength=area.shape[0], + ) + + # plotting + fig, ax = plt.subplots(figsize=figsize) + ax.bar( + area * pixel_area, + hist, + width=0.8 * pixel_area * area_step, + ) + ax.set_xlim((0, area_max * pixel_area)) + ax.set_xlabel("Grain Area [" + pixel_area_units + "]") + if weight_intensity: + ax.set_ylabel("Total Signal [arb. units]") + else: + ax.set_ylabel("Number of Grains") + + if returnfig: + return fig, ax + + def axisEqual3D(ax): extents = np.array([getattr(ax, "get_{}lim".format(dim))() for dim in "xyz"]) sz = extents[:, 1] - extents[:, 0] diff --git a/py4DSTEM/process/diffraction/utils.py b/py4DSTEM/process/diffraction/utils.py index 09bd09f7c..cfb11f044 100644 --- a/py4DSTEM/process/diffraction/utils.py +++ b/py4DSTEM/process/diffraction/utils.py @@ -67,6 +67,16 @@ def get_orientation(self, ind_x, ind_y): orientation.angles = self.angles[ind_x, ind_y] return orientation + def get_orientation_single(self, ind_x, ind_y, ind_match): + orientation = Orientation(num_matches=1) + orientation.matrix = self.matrix[ind_x, ind_y, ind_match] + orientation.family = self.family[ind_x, ind_y, ind_match] + orientation.corr = self.corr[ind_x, ind_y, ind_match] + orientation.inds = self.inds[ind_x, ind_y, ind_match] + orientation.mirror = self.mirror[ind_x, ind_y, ind_match] + orientation.angles = self.angles[ind_x, ind_y, ind_match] + return orientation + # def __copy__(self): # return OrientationMap(self.name) # def __deepcopy__(self, memo): diff --git a/py4DSTEM/process/latticevectors/__init__.py b/py4DSTEM/process/latticevectors/__init__.py deleted file mode 100644 index 560a3b7e6..000000000 --- a/py4DSTEM/process/latticevectors/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from py4DSTEM.process.latticevectors.initialguess import * -from py4DSTEM.process.latticevectors.index import * -from py4DSTEM.process.latticevectors.fit import * -from py4DSTEM.process.latticevectors.strain import * diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py deleted file mode 100644 index 659bc8940..000000000 --- a/py4DSTEM/process/latticevectors/fit.py +++ /dev/null @@ -1,200 +0,0 @@ -# Functions for fitting lattice vectors to measured Bragg peak positions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray -from py4DSTEM.data import RealSlice - - -def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (7-tuple) A 7-tuple containing: - - * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. - * **y0**: *(float)* the y-coord of the origin - * **g1x**: *(float)* x-coord of the first lattice vector - * **g1y**: *(float)* y-coord of the first lattice vector - * **g2x**: *(float)* x-coord of the second lattice vector - * **g2y**: *(float)* y-coord of the second lattice vector - * **error**: *(float)* the fit error - """ - assert isinstance(braggpeaks, PointList) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - braggpeaks = braggpeaks.copy() - - # Remove unindexed peaks - if "index_mask" in braggpeaks.dtype.names: - deletemask = braggpeaks.data["index_mask"] == False - braggpeaks.remove(deletemask) - - # Check to ensure enough peaks are present - if braggpeaks.length < minNumPeaks: - return None, None, None, None, None, None, None - - # Get M, the matrix of (h,k) indices - h, k = braggpeaks.data["h"], braggpeaks.data["k"] - M = np.vstack((np.ones_like(h, dtype=int), h, k)).T - - # Get alpha, the matrix of measured Bragg peak positions - alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T - - # Get weighted matrices - weights = braggpeaks.data["intensity"] - weighted_M = M * weights[:, np.newaxis] - weighted_alpha = alpha * weights[:, np.newaxis] - - # Solve for lattice vectors - beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] - x0, y0 = beta[0, 0], beta[0, 1] - g1x, g1y = beta[1, 0], beta[1, 1] - g2x, g2y = beta[2, 0], beta[2, 1] - - # Calculate the error - alpha_calculated = np.matmul(M, beta) - error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) - error = np.sum(error * weights) / np.sum(weights) - - return x0, y0, g1x, g1y, g2x, g2y, error - - -def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some - known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - - return g1g2_map - - -def fit_lattice_vectors_masked(braggpeaks, mask, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks corresponding - to a scan position for which mask==True. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - mask (boolean array): real space shaped (R_Nx,R_Ny); fit lattice vectors where - mask is True - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((braggpeaks.shape[0], braggpeaks.shape[1], 8)), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - if mask[Rx, Ry]: - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - return g1g2_map diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py deleted file mode 100644 index 4ac7939e7..000000000 --- a/py4DSTEM/process/latticevectors/index.py +++ /dev/null @@ -1,280 +0,0 @@ -# Functions for indexing the Bragg directions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray - - -def get_selected_lattice_vectors(gx, gy, i0, i1, i2): - """ - From a set of reciprocal lattice points (gx,gy), and indices in those arrays which - specify the center beam, the first basis lattice vector, and the second basis lattice - vector, computes and returns the lattice vectors g1 and g2. - - Args: - gx (1d array): the reciprocal lattice points x-coords - gy (1d array): the reciprocal lattice points y-coords - i0 (int): index in the (gx,gy) arrays specifying the center beam - i1 (int): index in the (gx,gy) arrays specifying the first basis lattice vector - i2 (int): index in the (gx,gy) arrays specifying the second basis lattice vector - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing - - * **g1**: *(2-tuple)* the first lattice vector, (g1x,g1y) - * **g2**: *(2-tuple)* the second lattice vector, (g2x,g2y) - """ - for i in (i0, i1, i2): - assert isinstance(i, (int, np.integer)) - g1x = gx[i1] - gx[i0] - g1y = gy[i1] - gy[i0] - g2x = gx[i2] - gx[i0] - g2y = gy[i2] - gy[i0] - return (g1x, g1y), (g2x, g2y) - - -def index_bragg_directions(x0, y0, gx, gy, g1, g2): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - The approach is to solve the matrix equation - ``alpha = beta * M`` - where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, - beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the - h,k indices. - - Args: - x0 (float): x-coord of origin - y0 (float): y-coord of origin - gx (1d array): x-coord of the reciprocal lattice vectors - gy (1d array): y-coord of the reciprocal lattice vectors - g1 (2-tuple of floats): g1x,g1y - g2 (2-tuple of floats): g2x,g2y - - Returns: - (3-tuple) A 3-tuple containing: - - * **h**: *(ndarray of ints)* first index of the bragg directions - * **k**: *(ndarray of ints)* second index of the bragg directions - * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the - indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y - coords 'h' and 'k' contain h and k. - """ - # Get beta, the matrix of lattice vectors - beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) - - # Get alpha, the matrix of measured bragg angles - alpha = np.vstack([gx - x0, gy - y0]) - - # Calculate M, the matrix of peak positions - M = lstsq(beta, alpha, rcond=None)[0].T - M = np.round(M).astype(int) - - # Get h,k - h = M[:, 0] - k = M[:, 1] - - # Store in a PointList - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - temp_array = np.zeros([], dtype=coords) - bragg_directions = PointList(data=temp_array) - bragg_directions.add_data_by_field((gx, gy, h, k)) - mask = np.zeros(bragg_directions["qx"].shape[0]) - mask[0] = 1 - bragg_directions.remove(mask) - - return h, k, bragg_directions - - -def generate_lattice(ux, uy, vx, vy, x0, y0, Q_Nx, Q_Ny, h_max=None, k_max=None): - """ - Returns a full reciprocal lattice stretching to the limits of the diffraction pattern - by making linear combinations of the lattice vectors up to (±h_max,±k_max). - - This can be useful when there are false peaks or missing peaks in the braggvectormap, - which can cause errors in the strain finding routines that rely on those peaks for - indexing. This allows us to create a reference lattice that has all combinations of - the lattice vectors all the way out to the edges of the frame, and excluding any - erroneous intermediate peaks. - - Args: - ux (float): x-coord of the u lattice vector - uy (float): y-coord of the u lattice vector - vx (float): x-coord of the v lattice vector - vy (float): y-coord of the v lattice vector - x0 (float): x-coord of the lattice origin - y0 (float): y-coord of the lattice origin - Q_Nx (int): diffraction pattern size in the x-direction - Q_Ny (int): diffraction pattern size in the y-direction - h_max, k_max (int): maximal indices for generating the lattice (the lattive is - always trimmed to fit inside the pattern so you can overestimate these, or - leave unspecified and they will be automatically found) - - Returns: - (PointList): A 4-coordinate PointList, ('qx','qy','h','k'), containing points - corresponding to linear combinations of the u and v vectors, with associated - indices - """ - - # Matrix of lattice vectors - beta = np.array([[ux, uy], [vx, vy]]) - - # If no max index is specified, (over)estimate based on image size - if (h_max is None) or (k_max is None): - (y, x) = np.mgrid[0:Q_Ny, 0:Q_Nx] - x = x - x0 - y = y - y0 - h_max = np.max(np.ceil(np.abs((x / ux, y / uy)))) - k_max = np.max(np.ceil(np.abs((x / vx, y / vy)))) - - (hlist, klist) = np.meshgrid( - np.arange(-h_max, h_max + 1), np.arange(-k_max, k_max + 1) - ) - - M_ideal = np.vstack((hlist.ravel(), klist.ravel())).T - ideal_peaks = np.matmul(M_ideal, beta) - - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - - ideal_data = np.zeros(len(ideal_peaks[:, 0]), dtype=coords) - ideal_data["qx"] = ideal_peaks[:, 0] - ideal_data["qy"] = ideal_peaks[:, 1] - ideal_data["h"] = M_ideal[:, 0] - ideal_data["k"] = M_ideal[:, 1] - - ideal_lattice = PointList(data=ideal_data) - - # shift to the DP center - ideal_lattice.data["qx"] += x0 - ideal_lattice.data["qy"] += y0 - - # trim peaks outside the image - deletePeaks = ( - (ideal_lattice.data["qx"] > Q_Nx) - | (ideal_lattice.data["qx"] < 0) - | (ideal_lattice.data["qy"] > Q_Ny) - | (ideal_lattice.data["qy"] < 0) - ) - ideal_lattice.remove(deletePeaks) - - return ideal_lattice - - -def add_indices_to_braggvectors( - braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None -): - """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, - identify the indices for each peak in the PointListArray braggpeaks. - Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed - or not with the bool index_mask. If `mask` is specified, only the locations where - mask is True are indexed. - - Args: - braggpeaks (PointListArray): the braggpeaks to index. Must contain - the coordinates 'qx', 'qy', and 'intensity' - lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. - Must contain the coordinates 'qx', 'qy', 'h', and 'k' - maxPeakSpacing (float): Maximum distance from the ideal lattice points - to include a peak for indexing - qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList - relative to the `braggpeaks` PointListArray - mask (bool): Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - - Returns: - (PointListArray): The original braggpeaks pointlistarray, with new coordinates - 'h', 'k', containing the indices of each indexable peak. - """ - - # assert isinstance(braggpeaks,BraggVectors) - # assert isinstance(lattice, PointList) - # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) - - if mask is None: - mask = np.ones(braggpeaks.Rshape, dtype=bool) - - assert ( - mask.shape == braggpeaks.Rshape - ), "mask must have same shape as pointlistarray" - assert mask.dtype == bool, "mask must be boolean" - - coords = [ - ("qx", float), - ("qy", float), - ("intensity", float), - ("h", int), - ("k", int), - ] - - indexed_braggpeaks = PointListArray( - dtype=coords, - shape=braggpeaks.Rshape, - ) - - # loop over all the scan positions - for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): - if mask[Rx, Ry]: - pl = braggpeaks.cal[Rx, Ry] - for i in range(pl.data.shape[0]): - r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( - pl.data["qy"][i] - lattice.data["qy"] + qy_shift - ) ** 2 - ind = np.argmin(r2) - if r2[ind] <= maxPeakSpacing**2: - indexed_braggpeaks[Rx, Ry].add_data_by_field( - ( - pl.data["qx"][i], - pl.data["qy"][i], - pl.data["intensity"][i], - lattice.data["h"][ind], - lattice.data["k"][ind], - ) - ) - - return indexed_braggpeaks - - -def bragg_vector_intensity_map_by_index(braggpeaks, h, k, symmetric=False): - """ - Returns a correlation intensity map for an indexed (h,k) Bragg vector - Used to obtain a darkfield image corresponding to the (h,k) reflection - or a bightfield image when h=k=0 - - Args: - braggpeaks (PointListArray): must contain the coordinates 'h','k', and - 'intensity' - h, k (int): indices for the reflection to generate an intensity map from - symmetric (bool): if set to true, returns sum of intensity of (h,k), (-h,k), - (h,-k), (-h,-k) - - Returns: - (numpy array): a map of the intensity of the (h,k) Bragg vector correlation. - Same shape as the pointlistarray. - """ - assert isinstance(braggpeaks, PointListArray), "braggpeaks must be a PointListArray" - assert np.all([name in braggpeaks.dtype.names for name in ("h", "k", "intensity")]) - intensity_map = np.zeros(braggpeaks.shape, dtype=float) - - for Rx in range(braggpeaks.shape[0]): - for Ry in range(braggpeaks.shape[1]): - pl = braggpeaks.get_pointlist(Rx, Ry) - if pl.length > 0: - if symmetric: - matches = np.logical_and( - np.abs(pl.data["h"]) == np.abs(h), - np.abs(pl.data["k"]) == np.abs(k), - ) - else: - matches = np.logical_and(pl.data["h"] == h, pl.data["k"] == k) - - if len(matches) > 0: - intensity_map[Rx, Ry] = np.sum(pl.data["intensity"][matches]) - - return intensity_map diff --git a/py4DSTEM/process/latticevectors/initialguess.py b/py4DSTEM/process/latticevectors/initialguess.py deleted file mode 100644 index d8054143f..000000000 --- a/py4DSTEM/process/latticevectors/initialguess.py +++ /dev/null @@ -1,229 +0,0 @@ -# Obtain an initial guess at the lattice vectors - -import numpy as np -from scipy.ndimage import gaussian_filter -from skimage.transform import radon - -from py4DSTEM.process.utils import get_maxima_1D - - -def get_radon_scores( - braggvectormap, - mask=None, - N_angles=200, - sigma=2, - minSpacing=2, - minRelativeIntensity=0.05, -): - """ - Calculates a score function, score(angle), representing the likelihood that angle is - a principle lattice direction of the lattice in braggvectormap. - - The procedure is as follows: - If mask is not None, ignore any data in braggvectormap where mask is False. Useful - for removing the unscattered beam, which can dominate the results. - Take the Radon transform of the (masked) Bragg vector map. - For each angle, get the corresponding slice of the sinogram, and calculate its score. - If we let R_theta(r) be the sinogram slice at angle theta, and where r is the - sinogram position coordinate, then the score of the slice is given by - score(theta) = sum_i(R_theta(r_i)) / N_i - Here, r_i are the positions r of all local maxima in R_theta(r), and N_i is the - number of such maxima. Thus the score is large when there are few maxima which are - high intensity. - - Args: - braggvectormap (ndarray): the Bragg vector map - mask (ndarray of bools): ignore data in braggvectormap wherever mask==False - N_angles (int): the number of angles at which to calculate the score - sigma (float): smoothing parameter for local maximum identification - minSpacing (float): if two maxima are found in a radon slice closer than - minSpacing, the dimmer of the two is removed - minRelativeIntensity (float): maxima in each radon slice dimmer than - minRelativeIntensity compared to the most intense maximum are removed - - Returns: - (3-tuple) A 3-tuple containing: - - * **scores**: *(ndarray, len N_angles, floats)* the scores for each angle - * **thetas**: *(ndarray, len N_angles, floats)* the angles, in radians - * **sinogram**: *(ndarray)* the radon transform of braggvectormap*mask - """ - # Get sinogram - thetas = np.linspace(0, 180, N_angles) - if mask is not None: - sinogram = radon(braggvectormap * mask, theta=thetas, circle=False) - else: - sinogram = radon(braggvectormap, theta=thetas, circle=False) - - # Get scores - N_maxima = np.empty_like(thetas) - total_intensity = np.empty_like(thetas) - for i in range(len(thetas)): - theta = thetas[i] - - # Get radon transform slice - ind = np.argmin(np.abs(thetas - theta)) - sinogram_theta = sinogram[:, ind] - sinogram_theta = gaussian_filter(sinogram_theta, 2) - - # Get maxima - maxima = get_maxima_1D(sinogram_theta, sigma, minSpacing, minRelativeIntensity) - - # Calculate metrics - N_maxima[i] = len(maxima) - total_intensity[i] = np.sum(sinogram_theta[maxima]) - scores = total_intensity / N_maxima - - return scores, np.radians(thetas), sinogram - - -def get_lattice_directions_from_scores( - thetas, scores, sigma=2, minSpacing=2, minRelativeIntensity=0.05, index1=0, index2=0 -): - """ - Get the lattice directions from the scores of the radon transform slices. - - Args: - thetas (ndarray): the angles, in radians - scores (ndarray): the scores - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - index1 (int): specifies which local maximum to use for the first lattice - direction, in order of maximum intensity - index2 (int): specifies the local maximum for the second lattice direction - - Returns: - (2-tuple) A 2-tuple containing: - - * **theta1**: *(float)* the first lattice direction, in radians - * **theta2**: *(float)* the second lattice direction, in radians - """ - assert len(thetas) == len(scores), "Size of thetas and scores must match" - - # Get first lattice direction - maxima1 = get_maxima_1D( - scores, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max1 = thetas[maxima1] - scores_max1 = scores[maxima1] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max1), dtype=dtype) - ar_structured["thetas"] = thetas_max1 - ar_structured["scores"] = scores_max1 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta1 = ar_structured["thetas"][index1] # Get direction 1 - - # Apply sin**2 damping - scores_damped = scores * np.sin(thetas - theta1) ** 2 - - # Get second lattice direction - maxima2 = get_maxima_1D( - scores_damped, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max2 = thetas[maxima2] - scores_max2 = scores[maxima2] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max2), dtype=dtype) - ar_structured["thetas"] = thetas_max2 - ar_structured["scores"] = scores_max2 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta2 = ar_structured["thetas"][index2] # Get direction 2 - - return theta1, theta2 - - -def get_lattice_vector_lengths( - u_theta, - v_theta, - thetas, - sinogram, - spacing_thresh=1.5, - sigma=1, - minSpacing=2, - minRelativeIntensity=0.1, -): - """ - Gets the lengths of the two lattice vectors from their angles and the sinogram. - - First, finds the spacing between peaks in the sinogram slices projected down the u- - and v- directions, u_proj and v_proj. Then, finds the lengths by taking:: - - |u| = v_proj/sin(u_theta-v_theta) - |v| = u_proj/sin(u_theta-v_theta) - - The most important thresholds for this function are spacing_thresh, which discards - any detected spacing between adjacent radon projection peaks which deviate from the - median spacing by more than this fraction, and minRelativeIntensity, which discards - detected maxima (from which spacings are then calculated) below this threshold - relative to the brightest maximum. - - Args: - u_theta (float): the angle of u, in radians - v_theta (float): the angle of v, in radians - thetas (ndarray): the angles corresponding to the sinogram - sinogram (ndarray): the sinogram - spacing_thresh (float): ignores spacings which are greater than spacing_thresh - times the median spacing - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - - Returns: - (2-tuple) A 2-tuple containing: - - * **u_length**: *(float)* the length of u, in pixels - * **v_length**: *(float)* the length of v, in pixels - """ - assert ( - len(thetas) == sinogram.shape[1] - ), "thetas must corresponding to the number of sinogram projection directions." - - # Get u projected spacing - ind = np.argmin(np.abs(thetas - u_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - u_projected_spacing = np.mean(spacings) - - # Get v projected spacing - ind = np.argmin(np.abs(thetas - v_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - v_projected_spacing = np.mean(spacings) - - # Get u and v lengths - sin_uv = np.sin(np.abs(u_theta - v_theta)) - u_length = v_projected_spacing / sin_uv - v_length = u_projected_spacing / sin_uv - - return u_length, v_length diff --git a/py4DSTEM/process/latticevectors/strain.py b/py4DSTEM/process/latticevectors/strain.py deleted file mode 100644 index 6f4000449..000000000 --- a/py4DSTEM/process/latticevectors/strain.py +++ /dev/null @@ -1,231 +0,0 @@ -# Functions for calculating strain from lattice vector maps - -import numpy as np -from numpy.linalg import lstsq - -from py4DSTEM.data import RealSlice - - -def get_reference_g1g2(g1g2_map, mask): - """ - Gets a pair of reference lattice vectors from a region of real space specified by - mask. Takes the median of the lattice vectors in g1g2_map within the specified - region. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever - mask==True - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing: - - * **g1**: *(2-tuple)* first reference lattice vector (x,y) - * **g2**: *(2-tuple)* second reference lattice vector (x,y) - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] - ) - assert mask.dtype == bool - g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) - g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) - g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) - g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) - return (g1x, g1y), (g2x, g2y) - - -def get_strain_from_reference_g1g2(g1g2_map, g1, g2): - """ - Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map - g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real and - diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - g1 (2-tuple): first reference lattice vector (x,y) - g2 (2-tuple): second reference lattice vector (x,y) - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - - # Get RealSlice for output storage - R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape - strain_map = RealSlice( - data=np.zeros((5, R_Nx, R_Ny)), - slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), - name="strain_map", - ) - - # Get reference lattice matrix - g1x, g1y = g1 - g2x, g2y = g2 - M = np.array([[g1x, g1y], [g2x, g2y]]) - - for Rx in range(R_Nx): - for Ry in range(R_Ny): - # Get lattice vectors for DP at Rx,Ry - alpha = np.array( - [ - [ - g1g2_map.get_slice("g1x").data[Rx, Ry], - g1g2_map.get_slice("g1y").data[Rx, Ry], - ], - [ - g1g2_map.get_slice("g2x").data[Rx, Ry], - g1g2_map.get_slice("g2y").data[Rx, Ry], - ], - ] - ) - # Get transformation matrix - beta = lstsq(M, alpha, rcond=None)[0].T - - # Get the infinitesimal strain matrix - strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] - strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] - strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 - strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 - strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ - Rx, Ry - ] - return strain_map - - -def get_strain_from_reference_region(g1g2_map, mask): - """ - Gets a strain map from the reference region of real space specified by mask and the - lattice vector map g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real - and diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - assert mask.dtype == bool - - g1, g2 = get_reference_g1g2(g1g2_map, mask) - strain_map = get_strain_from_reference_g1g2(g1g2_map, g1, g2) - return strain_map - - -def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): - """ - Starting from a strain map defined with respect to the xy coordinate system of - diffraction space, i.e. where exx and eyy are the compression/tension along the Qx - and Qy directions, respectively, get a strain map defined with respect to some other - right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, - xaxis_y). - - Args: - xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector - along the new x-axis - unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the - infinitessimal strain matrix elements, stored at - * unrotated_strain_map.get_slice('e_xx') - * unrotated_strain_map.get_slice('e_xy') - * unrotated_strain_map.get_slice('e_yy') - * unrotated_strain_map.get_slice('theta') - - Returns: - (RealSlice) the rotated counterpart to unrotated_strain_map, with the - rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate - system - """ - assert isinstance(unrotated_strain_map, RealSlice) - assert np.all( - [ - key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] - for key in unrotated_strain_map.slicelabels - ] - ) - theta = -np.arctan2(xaxis_y, xaxis_x) - cost = np.cos(theta) - sint = np.sin(theta) - cost2 = cost**2 - sint2 = sint**2 - - Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape - rotated_strain_map = RealSlice( - data=np.zeros((5, Rx, Ry)), - slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], - name=unrotated_strain_map.name + "_rotated".format(np.degrees(theta)), - ) - - rotated_strain_map.data[0, :, :] = ( - cost2 * unrotated_strain_map.get_slice("e_xx").data - - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + sint2 * unrotated_strain_map.get_slice("e_yy").data - ) - rotated_strain_map.data[1, :, :] = ( - cost - * sint - * ( - unrotated_strain_map.get_slice("e_xx").data - - unrotated_strain_map.get_slice("e_yy").data - ) - + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data - ) - rotated_strain_map.data[2, :, :] = ( - sint2 * unrotated_strain_map.get_slice("e_xx").data - + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + cost2 * unrotated_strain_map.get_slice("e_yy").data - ) - if flip_theta == True: - rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data - else: - rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data - rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data - return rotated_strain_map diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 178079349..1005a619d 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,28 +3,14 @@ _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import ( - MixedstatePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_multislice_ptychography import ( - MultislicePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import ( - OverlapMagneticTomographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_tomography import ( - OverlapTomographicReconstruction, -) +from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction +from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from py4DSTEM.process.phase.iterative_simultaneous_ptychography import ( - SimultaneousPtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_singleslice_ptychography import ( - SingleslicePtychographicReconstruction, -) -from py4DSTEM.process.phase.parameter_optimize import ( - OptimizationParameter, - PtychographyOptimizer, -) +from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction +from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index ae4c92d4b..767789df2 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -8,13 +8,13 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import show, show_complex +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import rotate try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration @@ -23,7 +23,11 @@ from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( PtychographicConstraints, ) -from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases +from py4DSTEM.process.phase.utils import ( + AffineTransform, + generate_batches, + polar_aliases, +) from py4DSTEM.process.utils import ( electron_wavelength_angstrom, fourier_resample, @@ -56,6 +60,53 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self + def reinitialize_parameters(self, device: str = None, verbose: bool = None): + """ + Reinitializes common parameters. This is useful when loading a previously-saved + reconstruction (which set device='cpu' and verbose=True for compatibility) , + using different initialization parameters. + + Parameters + ---------- + device: str, optional + If not None, imports and assigns appropriate device modules + verbose: bool, optional + If not None, sets the verbosity to verbose + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if device is not None: + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self._device = device + + if verbose is not None: + self._verbose = verbose + + return self + def set_save_defaults( self, save_datacube: bool = False, @@ -278,7 +329,9 @@ def _extract_intensities_and_calibrations_from_datacube( """ # Copies intensities to device casting to float32 - intensities = datacube.data + xp = self._xp + + intensities = xp.asarray(datacube.data, dtype=xp.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -295,13 +348,14 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "real-space calibrations in 'A'" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "real-space calibrations in 'A'" + ), + UserWarning, + ) self._scan_sampling = (1.0, 1.0) self._scan_units = ("pixels",) * 2 @@ -359,13 +413,14 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "appropriate reciprocal-space calibrations" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "appropriate reciprocal-space calibrations" + ), + UserWarning, + ) self._angular_sampling = (1.0, 1.0) self._angular_units = ("pixels",) * 2 @@ -448,8 +503,6 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - intensities = xp.asarray(intensities, dtype=xp.float32) - # for ptycho if com_measured: com_measured_x, com_measured_y = com_measured @@ -484,9 +537,14 @@ def _calculate_intensities_center_of_mass( ) if com_shifts is None: + com_measured_x_np = asnumpy(com_measured_x) + com_measured_y_np = asnumpy(com_measured_y) + finite_mask = np.isfinite(com_measured_x_np) + com_shifts = fit_origin( - (asnumpy(com_measured_x), asnumpy(com_measured_y)), + (com_measured_x_np, com_measured_y_np), fitfunction=fit_function, + mask=finite_mask, ) # Fit function to center of mass @@ -494,12 +552,12 @@ def _calculate_intensities_center_of_mass( com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # fix CoM units - com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[ - 0 - ] - com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[ - 1 - ] + com_normalized_x = ( + xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + ) + com_normalized_y = ( + xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1] + ) return ( com_measured_x, @@ -1077,6 +1135,8 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, + crop_patterns, + positions_mask, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1089,6 +1149,11 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns + when centering + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1101,16 +1166,59 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros_like(diffraction_intensities) - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities = self._asnumpy(diffraction_intensities) + if positions_mask is not None: + number_of_patterns = np.count_nonzero(self._positions_mask.ravel()) + else: + number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + + if crop_patterns: + crop_x = int( + np.minimum( + diffraction_intensities.shape[2] - com_fitted_x.max(), + com_fitted_x.min(), + ) + ) + crop_y = int( + np.minimum( + diffraction_intensities.shape[3] - com_fitted_y.max(), + com_fitted_y.min(), + ) + ) + + crop_w = np.minimum(crop_y, crop_x) + region_of_interest_shape = (crop_w * 2, crop_w * 2) + amplitudes = np.zeros( + ( + number_of_patterns, + crop_w * 2, + crop_w * 2, + ), + dtype=np.float32, + ) + + crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask[:crop_w, :crop_w] = True + crop_mask[-crop_w:, :crop_w] = True + crop_mask[:crop_w:, -crop_w:] = True + crop_mask[-crop_w:, -crop_w:] = True + self._crop_mask = crop_mask + + else: + region_of_interest_shape = diffraction_intensities.shape[-2:] + amplitudes = np.zeros( + (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + ) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) - diffraction_intensities = self._asnumpy(diffraction_intensities) - amplitudes = self._asnumpy(amplitudes) + counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): + if positions_mask is not None: + if not self._positions_mask[rx, ry]: + continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], -com_fitted_x[rx, ry], @@ -1119,16 +1227,71 @@ def _normalize_diffraction_intensities( device="cpu", ) - mean_intensity += np.sum(intensities) - amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) + if crop_patterns: + intensities = intensities[crop_mask].reshape( + region_of_interest_shape + ) - amplitudes = xp.asarray(amplitudes, dtype=xp.float32) + mean_intensity += np.sum(intensities) + amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) + counter += 1 - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) + amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity + def show_complex_CoM( + self, + com=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot complex-valued CoM image + + Parameters + ---------- + + com = (CoM_x, CoM_y) tuple + If None is specified, uses (self.com_x, self.com_y) instead + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A + pixelsize: float, optional + default is scan sampling + """ + + if com is None: + com = (self.com_x, self.com_y) + + if pixelsize is None: + pixelsize = self._scan_sampling[0] + if pixelunits is None: + pixelunits = self._scan_units[0] + + figsize = kwargs.pop("figsize", (6, 6)) + fig, ax = plt.subplots(figsize=figsize) + + complex_com = com[0] + 1j * com[1] + + show_complex( + complex_com, + cbar=cbar, + figax=(fig, ax), + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): """ @@ -1196,6 +1359,7 @@ def to_h5(self, group): "num_diffraction_patterns": self._num_diffraction_patterns, "sampling": self.sampling, "angular_sampling": self.angular_sampling, + "positions_mask": self._positions_mask, }, ) @@ -1309,10 +1473,10 @@ def _get_constructor_args(cls, group): "object_type": instance_md["object_type"], "semiangle_cutoff": instance_md["semiangle_cutoff"], "rolloff": instance_md["rolloff"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], "polar_parameters": polar_params, + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } class_specific_kwargs = {} @@ -1338,6 +1502,7 @@ def _populate_instance(self, group): self._angular_sampling = preprocess_md["angular_sampling"] self._region_of_interest_shape = preprocess_md["region_of_interest_shape"] self._num_diffraction_patterns = preprocess_md["num_diffraction_patterns"] + self._positions_mask = preprocess_md["positions_mask"] # Reconstruction metadata reconstruction_md = _read_metadata(group, "reconstruction_metadata") @@ -1389,7 +1554,9 @@ def _set_polar_parameters(self, parameters: dict): else: raise ValueError("{} not a recognized parameter".format(symbol)) - def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): + def _calculate_scan_positions_in_pixels( + self, positions: np.ndarray, positions_mask + ): """ Method to compute the initial guess of scan positions in pixels. @@ -1398,6 +1565,8 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions: (J,2) np.ndarray or None Input probe positions in Å. If None, a raster scan using experimental parameters is constructed. + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1429,7 +1598,9 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): x = (x - np.ptp(x) / 2) / self.sampling[0] y = (y - np.ptp(y) / 2) / self.sampling[1] x, y = np.meshgrid(x, y, indexing="ij") - + if positions_mask is not None: + x = x[positions_mask] + y = y[positions_mask] else: positions -= np.mean(positions, axis=0) x = positions[:, 0] / self.sampling[1] @@ -2071,6 +2242,243 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(asnumpy(obj)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped) + else: + projected_cropped_potential = self.object_cropped + + return projected_cropped_potential + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(xp.asarray(errors), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") + pix_count[pix_count == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") + pix_output /= pix_count + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + pix_output.get(), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5.25) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized Squared Error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + ) + + extent = [ + 0, + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], + 0, + ] + + ax.imshow( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) + def show_fourier_probe( self, probe=None, @@ -2109,6 +2517,7 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" figsize = kwargs.pop("figsize", (6, 6)) + chroma_boost = kwargs.pop("chroma_boost", 2) fig, ax = plt.subplots(figsize=figsize) show_complex( @@ -2119,6 +2528,7 @@ def show_fourier_probe( pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) @@ -2138,22 +2548,16 @@ def show_object_fft(self, obj=None, **kwargs): figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -2218,6 +2622,6 @@ def positions(self): @property def object_cropped(self): - """cropped and rotated object""" + """Cropped and rotated object""" return self._crop_rotate_object_fov(self._object) diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 4c80ed177..b390ce46d 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -13,8 +13,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from py4DSTEM.data import Calibration @@ -195,9 +195,9 @@ def _get_constructor_args(cls, group): "datacube": dc, "initial_object_guess": np.asarray(obj), "energy": instance_md["energy"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs @@ -718,24 +718,26 @@ def reconstruct( xp = self._xp asnumpy = self._asnumpy - if reset is None and hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - self.error_iterations = [] if reset: self.error = np.inf + self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] self.error = getattr(self, "error", np.inf) @@ -770,7 +772,8 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - print(f"Iteration {a0}, step reduced to {self._step_size}") + if self._verbose: + print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -796,6 +799,7 @@ def reconstruct( anti_gridding=anti_gridding, ) + self.error_iterations.append(self.error.item()) if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -804,13 +808,13 @@ def reconstruct( ].copy() ) ) - self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: - warnings.warn( - f"Step-size has decreased below stopping criterion {stopping_criterion}.", - UserWarning, - ) + if self._verbose: + warnings.warn( + f"Step-size has decreased below stopping criterion {stopping_criterion}.", + UserWarning, + ) # crop result self._object_phase = self._padded_object_phase[ @@ -840,7 +844,7 @@ def _visualize_last_iteration( If true, the NMSE error plot is displayed """ - figsize = kwargs.pop("figsize", (8, 8)) + figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") if plot_convergence: @@ -862,7 +866,7 @@ def _visualize_last_iteration( im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC Phase Reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") if cbar: divider = make_axes_locatable(ax1) @@ -870,11 +874,11 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "_error_iterations"): - errors = self._error_iterations + if plot_convergence: + errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -979,7 +983,7 @@ def _visualize_all_iterations( if plot_convergence: ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -990,7 +994,7 @@ def visualize( fig=None, iterations_grid: Tuple[int, int] = None, plot_convergence: bool = True, - cbar: bool = False, + cbar: bool = True, **kwargs, ): """ diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..f4c10cb13 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -0,0 +1,3654 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. +""" + +import warnings +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import pylops +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex + +try: + import cupy as cp +except ImportError: + cp = None + +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate + +warnings.simplefilter(action="always", category=UserWarning) + + +class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). If None, + initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations + initial_scan_positions: np.ndarray, optional + Probe positions in Å for each diffraction intensity + If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) + middle_focus: bool + if True, adds half the sample thickness to the defocus + object_type: str, optional + The object can be reconstructed as a real potential ('potential') or a complex + object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction + verbose: bool, optional + If True, class methods will inherit this and print additional information + device: str, optional + Calculation device will be perfomed on. Must be 'cpu' or 'gpu' + name: str, optional + Class name + kwargs: + Provide the aberration coefficients as keyword arguments. + """ + + # Class-specific Metadata + _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses") + + def __init__( + self, + energy: float, + num_slices: int, + slice_thicknesses: Union[float, Sequence[float]], + num_probes: int = None, + datacube: DataCube = None, + semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, + rolloff: float = 2.0, + vacuum_probe_intensity: np.ndarray = None, + polar_parameters: Mapping[str, float] = None, + object_padding_px: Tuple[int, int] = None, + initial_object_guess: np.ndarray = None, + initial_probe_guess: np.ndarray = None, + initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, + object_type: str = "complex", + positions_mask: np.ndarray = None, + verbose: bool = True, + device: str = "cpu", + name: str = "multi-slice_ptychographic_reconstruction", + **kwargs, + ): + Custom.__init__(self, name=name) + + if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe): + if num_probes is None: + raise ValueError( + ( + "If initial_probe_guess is None, or a ComplexProbe object, " + "num_probes must be specified." + ) + ) + else: + if len(initial_probe_guess.shape) != 3: + raise ValueError( + "Specified initial_probe_guess must have dimensions (N,Sx,Sy)." + ) + num_probes = initial_probe_guess.shape[0] + + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." + ) + ) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._positions_mask = positions_mask + self._object_padding_px = object_padding_px + self._verbose = verbose + self._device = device + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y + + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. + + Parameters + ---------- + gpts: Tuple[int,int] + Wavefunction pixel dimensions + sampling: Tuple[float,float] + Wavefunction sampling in A + energy: float + The electron energy of the wave functions in eV + slice_thicknesses: Sequence[float] + Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) + + Returns + ------- + propagator_arrays: np.ndarray + (T,Sx,Sy) shape array storing propagator arrays + """ + xp = self._xp + + # Frequencies + kx, ky = spatial_frequencies(gpts, sampling) + kx = xp.asarray(kx, dtype=xp.float32) + ky = xp.asarray(ky, dtype=xp.float32) + + # Propagators + wavelength = electron_wavelength_angstrom(energy) + num_slices = slice_thicknesses.shape[0] + propagators = xp.empty( + (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 + ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + + for i, dz in enumerate(slice_thicknesses): + propagators[i] = xp.exp( + 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) + + return propagators + + def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): + """ + Propagates array by Fourier convolving array with propagator_array. + + Parameters + ---------- + array: np.ndarray + Wavefunction array to be convolved + propagator_array: np.ndarray + Propagator array to convolve array with + + Returns + ------- + propagated_array: np.ndarray + Fourier-convolved array + """ + xp = self._xp + + return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) + + def preprocess( + self, + diffraction_intensities_shape: Tuple[int, int] = None, + reshaping_method: str = "fourier", + probe_roi_shape: Tuple[int, int] = None, + dp_mask: np.ndarray = None, + fit_function: str = "plane", + plot_center_of_mass: str = "default", + plot_rotation: bool = True, + maximize_divergence: bool = False, + rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0), + plot_probe_overlaps: bool = True, + force_com_rotation: float = None, + force_com_transpose: float = None, + force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, + object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, + **kwargs, + ): + """ + Ptychographic preprocessing step. + Calls the base class methods: + + _extract_intensities_and_calibrations_from_datacube, + _compute_center_of_mass(), + _solve_CoM_rotation(), + _normalize_diffraction_intensities() + _calculate_scan_positions_in_px() + + Additionally, it initializes an (T,Px,Py) array of 1.0j + and a complex probe using the specified polar parameters. + + Parameters + ---------- + diffraction_intensities_shape: Tuple[int,int], optional + Pixel dimensions (Qx',Qy') of the resampled diffraction intensities + If None, no resampling of diffraction intenstities is performed + reshaping_method: str, optional + Method to use for reshaping, either 'bin, 'bilinear', or 'fourier' (default) + probe_roi_shape, (int,int), optional + Padded diffraction intensities shape. + If None, no padding is performed + dp_mask: ndarray, optional + Mask for datacube intensities (Qx,Qy) + fit_function: str, optional + 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two' + plot_center_of_mass: str, optional + If 'default', the corrected CoM arrays will be displayed + If 'all', the computed and fitted CoM arrays will be displayed + plot_rotation: bool, optional + If True, the CoM curl minimization search result will be displayed + maximize_divergence: bool, optional + If True, the divergence of the CoM gradient vector field is maximized + rotation_angles_deg: np.darray, optional + Array of angles in degrees to perform curl minimization over + plot_probe_overlaps: bool, optional + If True, initial probe overlaps scanned over the object will be displayed + force_com_rotation: float (degrees), optional + Force relative rotation angle between real and reciprocal space + force_com_transpose: bool, optional + Force whether diffraction intensities need to be transposed. + force_com_shifts: tuple of ndarrays (CoMx, CoMy) + Amplitudes come from diffraction patterns shifted with + the CoM in the upper left corner for each probe unless + shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 + object_fov_mask: np.ndarray (boolean) + Boolean mask of FOV. Used to calculate additional shrinkage of object + If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering + + Returns + -------- + self: MixedstateMultislicePtychographicReconstruction + Self to accommodate chaining + """ + xp = self._xp + asnumpy = self._asnumpy + + # set additional metadata + self._diffraction_intensities_shape = diffraction_intensities_shape + self._reshaping_method = reshaping_method + self._probe_roi_shape = probe_roi_shape + self._dp_mask = dp_mask + + if self._datacube is None: + raise ValueError( + ( + "The preprocess() method requires a DataCube. " + "Please run ptycho.attach_datacube(DataCube) first." + ) + ) + + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + self._intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + ) + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + self.com_x, + self.com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + ( + self._amplitudes, + self._mean_diffraction_intensity, + ) = self._normalize_diffraction_intensities( + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, + ) + + # explicitly delete namespace + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + del self._intensities + + self._positions_px = self._calculate_scan_positions_in_pixels( + self._scan_positions, self._positions_mask + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + if self._object is None: + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) + p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( + "int" + ) + if self._object_type == "potential": + self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) + else: + if self._object_type == "potential": + self._object = xp.asarray(self._object, dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.asarray(self._object, dtype=xp.complex64) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # Vectorized Patches + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Probe Initialization + if self._probe is None or isinstance(self._probe, ComplexProbe): + if self._probe is None: + if self._vacuum_probe_intensity is not None: + self._semiangle_cutoff = np.inf + self._vacuum_probe_intensity = xp.asarray( + self._vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + self._vacuum_probe_intensity, + device=self._device, + ) + self._vacuum_probe_intensity = get_shifted_ar( + self._vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=self._device, + ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + _probe = ( + ComplexProbe( + gpts=self._region_of_interest_shape, + sampling=self.sampling, + energy=self._energy, + semiangle_cutoff=self._semiangle_cutoff, + rolloff=self._rolloff, + vacuum_probe_intensity=self._vacuum_probe_intensity, + parameters=self._polar_parameters, + device=self._device, + ) + .build() + ._array + ) + + else: + if self._probe._gpts != self._region_of_interest_shape: + raise ValueError() + if hasattr(self._probe, "_array"): + _probe = self._probe._array + else: + self._probe._xp = xp + _probe = self._probe.build()._array + + self._probe = xp.zeros( + (self._num_probes,) + tuple(self._region_of_interest_shape), + dtype=xp.complex64, + ) + sx, sy = self._region_of_interest_shape + self._probe[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, self._num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + self._probe[i_probe] = ( + self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) + self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) + + else: + self._probe = xp.asarray(self._probe, dtype=xp.complex64) + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + self._theta_x, + self._theta_y, + ) + + # overlaps + shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) + probe_intensities = xp.abs(shifted_probes) ** 2 + probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = self._gaussian_filter(probe_overlap, 1.0) + + if object_fov_mask is None: + self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=2, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=2, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe[0] intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax2, chroma_boost=chroma_boost) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe[0] intensity") + + ax3.imshow( + asnumpy(probe_overlap), + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + num_probe_positions = object_patches.shape[1] + + propagated_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) + propagated_probes[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes = ( + xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes[s + 1] = self._propagate_array( + transmitted_probes, self._propagator_arrays[s] + ) + + return propagated_probes, object_patches, transmitted_probes + + def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + + Returns + -------- + exit_waves:np.ndarray + Exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + intensity_norm[intensity_norm == 0.0] = np.inf + amplitude_modification = amplitudes / intensity_norm + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves + modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_exit_wave - transmitted_probes + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. + Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = transmitted_probes.copy() + + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + factor_to_be_projected = ( + projection_c * transmitted_probes + projection_y * exit_waves + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + intensity_norm_projected = xp.sqrt( + xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) + ) + intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + + amplitude_modification = amplitudes / intensity_norm_projected + fourier_projected_factor *= amplitude_modification[:, None] + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * transmitted_probes + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + ( + propagated_probes, + object_patches, + transmitted_probes, + ) = self._overlap_projection(current_object, current_probe) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, transmitted_probes + ) + + return propagated_probes, object_patches, transmitted_probes, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ) + ) + else: + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims( + xp.conj(obj), axis=1 + ) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + current_probe, + transmitted_probes, + amplitudes, + current_positions, + positions_step_size, + constrain_position_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe:np.ndarray + fractionally-shifted probes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + constrain_position_distance: float + Distance to constrain position correction within original + field of view in A + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + + # Intensity gradient + exit_waves_fft = xp.fft.fft2(transmitted_probes) + exit_waves_fft_conj = xp.conj(exit_waves_fft) + estimated_intensity = xp.abs(exit_waves_fft) ** 2 + measured_intensity = amplitudes**2 + + flat_shape = (transmitted_probes.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # Computing perturbed exit waves one at a time to save on memory + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + # dx + obj_rolled_patches = complex_object[ + :, + (self._vectorized_patch_indices_row + 1) % self._object_shape[0], + self._vectorized_patch_indices_col, + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + # dy + obj_rolled_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + (self._vectorized_patch_indices_col + 1) % self._object_shape[1], + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + partial_intensity_dx = 2 * xp.real( + exit_waves_dx_fft * exit_waves_fft_conj + ).reshape(flat_shape) + partial_intensity_dy = 2 * xp.real( + exit_waves_dy_fft * exit_waves_fft_conj + ).reshape(flat_shape) + + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + + # positions_update = xp.einsum( + # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity + # ) + + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + if constrain_position_distance is not None: + constrain_position_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + x1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 0 + ] + y1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 1 + ] + x0 = self._positions_px_initial[:, 0] + y0 = self._positions_px_initial[:, 1] + if self._rotation_best_transpose: + x0, y0 = xp.array([y0, x0]) + x1, y1 = xp.array([y1, x1]) + + if self._rotation_best_rad is not None: + rotation_angle = self._rotation_best_rad + x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( + -rotation_angle + ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) + x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( + -rotation_angle + ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) + + outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( + x1 < (xp.min(x0) - constrain_position_distance) + ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( + y1 < (xp.min(y0) - constrain_position_distance) + ) > 0 + + positions_update[..., 0][outlier_ind] = 0 + + current_positions -= positions_step_size * positions_update[..., 0] + + return current_positions + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. + Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + 2D Butterworth filter + Used for low/high-pass filtering object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + current_object = xp.pad( + current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" + ) + + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[1:] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + + def _constraints( + self, + current_object, + current_probe, + current_positions, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, + fix_positions, + global_affine_transformation, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + q_lowpass, + q_highpass, + butterworth_order, + kz_regularization_filter, + kz_regularization_gamma, + identical_slices, + object_positivity, + shrinkage_rad, + object_mask, + pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + orthogonalize_probe, + ): + """ + Ptychographic constraints operator. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + current_positions: np.ndarray + Current positions estimate + fix_com: bool + If True, probe CoM is fixed to the center + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool + If True, probe amplitude is constrained by top hat function + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool + If True, probe Fourier amplitude is replaced by initial_probe_aperture + initial_probe_aperture: np.ndarray + Initial probe aperture to use in replacing probe Fourier amplitude + fix_positions: bool + If True, positions are not updated + gaussian_filter: bool + If True, applies real-space gaussian filter in A + gaussian_filter_sigma: float + Standard deviation of gaussian kernel + butterworth_filter: bool + If True, applies fourier-space butterworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool + If True, applies fourier-space arctan regularization filter + kz_regularization_gamma: float + Slice regularization strength + identical_slices: bool + If True, forces all object slices to be identical + object_positivity: bool + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + pure_phase_object: bool + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True, performs TV denoising along z + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + orthogonalize_probe: bool + If True, probe will be orthogonalized + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + constrained_probe: np.ndarray + Constrained probe estimate + constrained_positions: np.ndarray + Constrained positions estimate + """ + + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, kz_regularization_gamma + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + pad_object=tv_denoise_pad_chambolle, + ) + + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # These constraints don't _really_ make sense for mixed-state + if fix_probe_aperture: + raise NotImplementedError() + elif constrain_probe_fourier_amplitude: + raise NotImplementedError() + if fit_probe_aberrations: + raise NotImplementedError() + if constrain_probe_amplitude: + raise NotImplementedError() + + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + max_iter: int = 64, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe_iter: int = 0, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions_iter: int = np.inf, + constrain_position_distance: float = None, + global_affine_transformation: bool = True, + gaussian_filter_sigma: float = None, + gaussian_filter_iter: int = np.inf, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + butterworth_filter_iter: int = np.inf, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter_iter: int = np.inf, + kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices_iter: int = 0, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + switch_object_iter: int = np.inf, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + max_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_com: bool, optional + If True, fixes center of mass of probe + fix_probe_iter: int, optional + Number of iterations to run with a fixed probe before updating probe estimate + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_positions_iter: int, optional + Number of iterations to run with fixed positions before updating positions estimate + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter_iter: int, optional + Number of iterations to run using object smoothness constraint + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + butterworth_filter_iter: int, optional + Number of iterations to run using high-pass butteworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter_iter: int, optional + Number of iterations to run using kz regularization filter + kz_regularization_gamma, float, optional + kz regularization strength + identical_slices_iter: int, optional + Number of iterations to run using identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If true, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object_iter: int, optional + Number of iterations where object amplitude is set to unity + tv_denoise_iter_chambolle: bool + Number of iterations with TV denoisining + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + switch_object_iter: int, optional + Iteration to switch object type between 'complex' and 'potential' or between + 'potential' and 'complex' + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + asnumpy = self._asnumpy + xp = self._xp + + # Reconstruction method + + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." + ) + ) + + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) + + if self._verbose: + if switch_object_iter > max_iter: + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " + else: + switch_object_type = ( + "complex" if self._object_type == "potential" else "potential" + ) + first_line = ( + f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " + f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " + ) + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." + ) + ) + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ) + ) + + else: + if reconstruction_parameter is not None: + if np.array(reconstruction_parameter).shape == (3,): + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ) + ) + else: + if step_size is not None: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ) + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) + + if max_batch_size is not None: + xp.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + if reset: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + self._exit_waves = None + self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] + self._exit_waves = None + + # main loop + for a0 in tqdmnd( + max_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if a0 == switch_object_iter: + if self._object_type == "potential": + self._object_type = "complex" + self._object = xp.exp(1j * self._object) + elif self._object_type == "complex": + self._object_type = "potential" + self._object = xp.angle(self._object) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( + self._num_diffraction_patterns + ) + positions_px = self._positions_px.copy()[shuffled_indices] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[shuffled_indices[start:end]] + + # forward operator + ( + propagated_probes, + object_patches, + self._transmitted_probes, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + self._probe, + amplitudes, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + propagated_probes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=a0 < fix_probe_iter, + ) + + # position correction + if a0 >= fix_positions_iter: + positions_px[start:end] = self._position_correction( + self._object, + self._probe[0], + self._transmitted_probes[:, 0], + amplitudes, + self._positions_px, + positions_step_size, + constrain_position_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._positions_px = positions_px.copy()[unshuffled_indices] + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=a0 < kz_regularization_filter_iter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma, + identical_slices=a0 < identical_slices_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _visualize_last_iteration_figax( + self, + fig, + object_ax, + convergence_ax, + cbar: bool, + padding: int = 0, + **kwargs, + ): + """ + Displays last reconstructed object on a given fig/ax. + + Parameters + -------- + fig: Figure + Matplotlib figure object_ax lives in + object_ax: Axes + Matplotlib axes to plot reconstructed object in + convergence_ax: Axes, optional + Matplotlib axes to plot convergence plot in + cbar: bool, optional + If true, displays a colorbar + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + cmap = kwargs.pop("cmap", "magma") + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + im = object_ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(object_ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if convergence_ax is not None and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = self.error_iterations + + convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + padding: int, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + """ + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + probe_array = Complex2RGB( + self.probe_fourier[0], chroma_boost=chroma_boost + ) + ax.set_title("Reconstructed Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + self.probe[0], power=2, chroma_boost=chroma_boost + ) + ax.set_title("Reconstructed probe[0] intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = np.array(self.error_iterations) + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + iterations_grid: Tuple[int, int], + padding: int, + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." + ) + ) + + if iterations_grid == "auto": + num_iter = len(self.error_iterations) + + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + + errors = np.array(self.error_iterations) + + objects = [] + object_type = [] + + for obj in self.object_iterations: + if np.iscomplexobj(obj): + obj = np.angle(obj) + object_type.append("phase") + else: + object_type.append("potential") + objects.append( + self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) + ) + + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + probes = self.probe_iterations + else: + total_grids = np.prod(iterations_grid) + max_iter = len(objects) - 1 + grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + im = ax.imshow( + objects[grid_range[n]], + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = Complex2RGB( + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0] + ) + ), + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], + chroma_boost=chroma_boost, + ) + + if plot_convergence: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + cbar: bool = True, + padding: int = 0, + **kwargs, + ): + """ + Displays reconstructed object and probe. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + return self + + def show_fourier_probe( + self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if probe is None: + probe = list(self.probe_fourier) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + chroma_boost = kwargs.pop("chroma_boost", 2) + + show_complex( + probe if len(probe) > 1 else probe[0], + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + chroma_boost=chroma_boost, + **kwargs, + ) + + def show_transmitted_probe( + self, + plot_fourier_probe: bool = False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. + + Parameters + ---------- + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + kwargs: + Passed to show_complex + """ + + xp = self._xp + asnumpy = self._asnumpy + + transmitted_probe_intensities = xp.sum( + xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) + ) + min_intensity_transmitted = self._transmitted_probes[ + xp.argmin(transmitted_probe_intensities), 0 + ] + max_intensity_transmitted = self._transmitted_probes[ + xp.argmax(transmitted_probe_intensities), 0 + ] + mean_transmitted = self._transmitted_probes[:, 0].mean(0) + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean Transmitted Probe", + "Min Intensity Transmitted Probe", + "Max Intensity Transmitted Probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy(self._return_fourier_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean Transmitted Fourier Probe", + "Min Intensity Transmitted Fourier Probe", + "Max Intensity Transmitted Fourier Probe", + ] + + title = kwargs.get("title", title) + show_complex( + probes, + title=title, + **kwargs, + ) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + show_fft: bool = False, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None + vmin = kwargs.pop("vmin", vmin) + vmax = kwargs.pop("vmax", vmax) + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + + def tune_num_slices_and_thicknesses( + self, + num_slices_guess=None, + thicknesses_guess=None, + num_slices_step_size=1, + thicknesses_step_size=20, + num_slices_values=3, + num_thicknesses_values=3, + update_defocus=False, + max_iter=5, + plot_reconstructions=True, + plot_convergence=True, + return_values=False, + **kwargs, + ): + """ + Run reconstructions over a parameters space of number of slices + and slice thicknesses. Should be run after the preprocess step. + + Parameters + ---------- + num_slices_guess: float, optional + initial starting guess for number of slices, rounds to nearest integer + if None, uses current initialized values + thicknesses_guess: float (A), optional + initial starting guess for thicknesses of slices assuming same + thickness for each slice + if None, uses current initialized values + num_slices_step_size: float, optional + size of change of number of slices for each step in parameter space + thicknesses_step_size: float (A), optional + size of change of slice thicknesses for each step in parameter space + num_slices_values: int, optional + number of number of slice values to test, must be >= 1 + num_thicknesses_values: int,optional + number of thicknesses values to test, must be >= 1 + update_defocus: bool, optional + if True, updates defocus based on estimated total thickness + max_iter: int, optional + number of iterations to run in ptychographic reconstruction + plot_reconstructions: bool, optional + if True, plot phase of reconstructed objects + plot_convergence: bool, optional + if True, plots error for each iteration for each reconstruction + return_values: bool, optional + if True, returns objects, convergence + + Returns + ------- + objects: list + reconstructed objects + convergence: np.ndarray + array of convergence values from reconstructions + """ + + # calculate number of slices and thicknesses values to test + if num_slices_guess is None: + num_slices_guess = self._num_slices + if thicknesses_guess is None: + thicknesses_guess = np.mean(self._slice_thicknesses) + + if num_slices_values == 1: + num_slices_step_size = 0 + + if num_thicknesses_values == 1: + thicknesses_step_size = 0 + + num_slices = np.linspace( + num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_values, + ) + + thicknesses = np.linspace( + thicknesses_guess + - thicknesses_step_size * (num_thicknesses_values - 1) / 2, + thicknesses_guess + + thicknesses_step_size * (num_thicknesses_values - 1) / 2, + num_thicknesses_values, + ) + + if return_values: + convergence = [] + objects = [] + + # current initialized values + current_verbose = self._verbose + current_num_slices = self._num_slices + current_thicknesses = self._slice_thicknesses + current_rotation_deg = self._rotation_best_rad * 180 / np.pi + current_transpose = self._rotation_best_transpose + current_defocus = -self._polar_parameters["C10"] + + # Gridspec to plot on + if plot_reconstructions: + if plot_convergence: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values * 2, + height_ratios=[1, 1 / 4] * num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) + ) + else: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) + ) + + fig = plt.figure(figsize=figsize) + + progress_bar = kwargs.pop("progress_bar", False) + # run loop and plot along the way + self._verbose = False + for flat_index, (slices, thickness) in enumerate( + tqdmnd(num_slices, thicknesses, desc="Tuning angle and defocus") + ): + slices = int(slices) + self._num_slices = slices + self._slice_thicknesses = np.tile(thickness, slices - 1) + self._probe = None + self._object = None + if update_defocus: + defocus = current_defocus + slices / 2 * thickness + self._polar_parameters["C10"] = -defocus + + self.preprocess( + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + ) + self.reconstruct( + reset=True, + store_iterations=True if plot_convergence else False, + max_iter=max_iter, + progress_bar=progress_bar, + **kwargs, + ) + + if plot_reconstructions: + row_index, col_index = np.unravel_index( + flat_index, (num_slices_values, num_thicknesses_values) + ) + + if plot_convergence: + object_ax = fig.add_subplot(spec[row_index * 2, col_index]) + convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=convergence_ax, + cbar=True, + ) + convergence_ax.yaxis.tick_right() + else: + object_ax = fig.add_subplot(spec[row_index, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=None, + cbar=True, + ) + + object_ax.set_title( + f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}" + ) + object_ax.set_xticks([]) + object_ax.set_yticks([]) + + if return_values: + objects.append(self.object) + convergence.append(self.error_iterations.copy()) + + # initialize back to pre-tuning values + self._probe = None + self._object = None + self._num_slices = current_num_slices + self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) + self._polar_parameters["C10"] = -current_defocus + self.preprocess( + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + ) + self._verbose = current_verbose + + if plot_reconstructions: + spec.tight_layout(fig) + + if return_values: + return objects, convergence + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns obj fft shifted to center of array + + Parameters + ---------- + obj: array, optional + if None is specified, uses self._object + """ + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + obj = asnumpy(obj) + if np.iscomplexobj(obj): + obj = np.angle(obj) + + obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index 56fec1004..d68291143 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -74,6 +74,8 @@ class MixedstatePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -102,6 +104,7 @@ def __init__( initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "mixed-state_ptychographic_reconstruction", @@ -161,6 +164,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -178,6 +187,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -204,6 +214,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -261,6 +272,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -349,6 +362,8 @@ def preprocess( self._intensities, self._com_fitted_x, self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -357,7 +372,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -429,6 +444,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( @@ -505,19 +524,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -540,23 +553,19 @@ def preprocess( axs[i].imshow( complex_probe_rgb[i], extent=probe_extent, - **kwargs, ) axs[i].set_ylabel("x [A]") axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial Probe[{i}]") + axs[i].set_title(f"Initial probe[{i}] intensity") divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax, chroma_boost=chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) axs[-1].scatter( self.positions[:, 1], @@ -568,7 +577,7 @@ def preprocess( axs[-1].set_xlabel("y [A]") axs[-1].set_xlim((extent[0], extent[1])) axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object Field of View") + axs[-1].set_title("Object field of view") fig.tight_layout() @@ -1125,6 +1134,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, orthogonalize_probe, object_positivity, shrinkage_rad, @@ -1183,6 +1195,12 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter orthogonalize_probe: bool If True, probe will be orthogonalized + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1213,6 +1231,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1281,6 +1304,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, + constrain_position_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -1290,6 +1314,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1353,6 +1380,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1373,6 +1402,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1575,6 +1610,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1667,6 +1704,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, + constrain_position_distance, ) error += batch_error @@ -1707,6 +1745,9 @@ def reconstruct( q_highpass=q_highpass, butterworth_order=butterworth_order, orthogonalize_probe=orthogonalize_probe, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1821,8 +1862,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1915,29 +1959,31 @@ def _visualize_last_iteration( if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert + self.probe_fourier[0], + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert + self.probe[0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe[0]") + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1970,10 +2016,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2045,8 +2091,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2149,29 +2198,30 @@ def _visualize_all_iterations( probes[grid_range[n]][0] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2183,7 +2233,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2280,11 +2330,61 @@ def show_fourier_probe( if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 2) + show_complex( probe if len(probe) > 1 else probe[0], scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index a352502d0..93e32b079 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -8,14 +8,15 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -29,6 +30,7 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate warnings.simplefilter(action="always", category=UserWarning) @@ -78,9 +80,17 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -109,7 +119,11 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -142,6 +156,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: @@ -165,6 +198,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -181,6 +220,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -189,6 +229,8 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -196,6 +238,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -211,6 +255,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) Returns ------- @@ -230,6 +278,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz) @@ -237,6 +289,12 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators @@ -279,6 +337,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -336,6 +395,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -424,6 +485,8 @@ def preprocess( self._intensities, self._com_fitted_x, self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -432,7 +495,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -503,6 +566,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -559,6 +626,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps @@ -575,19 +644,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -599,10 +662,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -624,38 +685,34 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -667,7 +724,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1449,6 +1506,111 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1481,9 +1643,12 @@ def _constraints( shrinkage_rad, object_mask, pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, tv_denoise, - tv_denoise_weight, - tv_denoise_pad, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1548,12 +1713,19 @@ def _constraints( If not None, used to calculate additional shrinkage using masked-mean of object pure_phase_object: bool If True, object amplitude is set to unity - tv_denoise: bool + tv_denoise_chambolle: bool If True, performs TV denoising along z - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1585,13 +1757,17 @@ def _constraints( current_object, kz_regularization_gamma ) elif tv_denoise: - if self._object_type == "complex": - raise NotImplementedError() + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, - tv_denoise_weight, + tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad, + pad_object=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1690,9 +1866,12 @@ def reconstruct( shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, - tv_denoise_weight=None, - tv_denoise_pad=True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1751,6 +1930,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1785,12 +1966,19 @@ def reconstruct( If true, the potential mean outside the FOV is forced to zero at each iteration pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - tv_denoise_iter: bool + tv_denoise_iter_chambolle: bool Number of iterations with TV denoisining - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -1987,6 +2175,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2123,7 +2313,7 @@ def reconstruct( and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma[a0] if kz_regularization_gamma is not None - and type(kz_regularization_gamma) == np.ndarray + and isinstance(kz_regularization_gamma, np.ndarray) else kz_regularization_gamma, identical_slices=a0 < identical_slices_iter, object_positivity=object_positivity, @@ -2133,9 +2323,13 @@ def reconstruct( else None, pure_phase_object=a0 < pure_phase_object_iter and self._object_type == "complex", - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_pad=tv_denoise_pad, + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2250,8 +2444,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2347,29 +2544,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2402,10 +2599,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2477,8 +2674,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2583,29 +2783,28 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], power=2, chroma_boost=chroma_boost ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2617,7 +2816,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2756,6 +2955,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -2771,12 +2971,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -2794,8 +3002,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) @@ -2841,6 +3062,143 @@ def show_slices( spec.tight_layout(fig) + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + -int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[0], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[1]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + def tune_num_slices_and_thicknesses( self, num_slices_guess=None, @@ -3068,3 +3426,14 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 8691a121d..c49a1faac 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import make_axes_locatable from py4DSTEM.visualize import show @@ -16,8 +17,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -92,6 +93,8 @@ class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -114,6 +117,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -162,6 +166,13 @@ def __init__( if object_type != "potential": raise NotImplementedError() + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -178,6 +189,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -430,6 +442,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -474,6 +487,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -594,6 +609,8 @@ def preprocess( intensities, com_fitted_x, com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -613,7 +630,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels @@ -684,6 +701,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -807,19 +828,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -831,10 +846,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -856,38 +869,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -899,7 +911,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1679,6 +1691,111 @@ def _divergence_free_constraint(self, vector_field): return vector_field + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1710,6 +1827,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1771,6 +1891,15 @@ def _constraints( If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1822,6 +1951,31 @@ def _constraints( butterworth_order, ) + elif tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[1] = self._object_denoise_tv_pylops( + current_object[1], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[2] = self._object_denoise_tv_pylops( + current_object[2], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[3] = self._object_denoise_tv_pylops( + current_object[3], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object[0] = self._object_shrinkage_constraint( current_object[0], @@ -1913,6 +2067,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1998,6 +2155,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -2171,12 +2337,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2477,6 +2644,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) # Normalize Error Over Tilts @@ -2530,6 +2701,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2929,7 +3103,7 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[-1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() spec.tight_layout(fig) @@ -3129,22 +3303,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -3168,3 +3336,29 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index d6bee12fd..ddd13ac58 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np +import pylops from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable from py4DSTEM.visualize import show @@ -16,8 +17,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -55,8 +56,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): The electron energy of the wave functions in eV num_slices: int Number of slices to use in the forward model - tilt_angles_deg: Sequence[float] - List of tilt angles in degrees, + tilt_orientation_matrices: Sequence[np.ndarray] + List of orientation matrices for each tilt semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional @@ -87,6 +88,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions to ignore in reconstruction name: str, optional Class name kwargs: @@ -94,13 +97,14 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): """ # Class-specific Metadata - _class_specific_metadata = ("_num_slices", "_tilt_angles_deg") + _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices") + _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]) def __init__( self, energy: float, num_slices: int, - tilt_angles_deg: Sequence[float], + tilt_orientation_matrices: Sequence[np.ndarray], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None, @@ -109,6 +113,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -122,22 +127,29 @@ def __init__( if device == "cpu": self._xp = np self._asnumpy = np.asarray - from scipy.ndimage import gaussian_filter, rotate, zoom + from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from scipy.special import erf self._erf = erf elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy - from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom + from cupyx.scipy.ndimage import ( + affine_transform, + gaussian_filter, + rotate, + zoom, + ) self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate + self._affine_transform = affine_transform from cupyx.scipy.special import erf self._erf = erf @@ -156,13 +168,20 @@ def __init__( polar_parameters.update(kwargs) self._set_polar_parameters(polar_parameters) - num_tilts = len(tilt_angles_deg) + num_tilts = len(tilt_orientation_matrices) if initial_scan_positions is None: initial_scan_positions = [None] * num_tilts if object_type != "potential": raise NotImplementedError() + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -179,13 +198,14 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False # Class-specific Metadata self._num_slices = num_slices - self._tilt_angles_deg = tuple(tilt_angles_deg) + self._tilt_orientation_matrices = tuple(tilt_orientation_matrices) self._num_tilts = num_tilts def _precompute_propagator_arrays( @@ -323,6 +343,29 @@ def _expand_sliced_object(self, array: np.ndarray, output_z): normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z] + def _rotate_zxy_volume( + self, + volume_array, + rot_matrix, + ): + """ """ + + xp = self._xp + affine_transform = self._affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume, tf, offset=offset, order=3) + + return volume + def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -340,6 +383,7 @@ def preprocess( force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -384,6 +428,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -503,6 +549,8 @@ def preprocess( intensities, com_fitted_x, com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -522,7 +570,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels @@ -593,6 +641,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -663,15 +715,14 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - current_angle_deg = self._tilt_angles_deg[tilt_index] - probe_overlap_3D = self._rotate( + rot_matrix = self._tilt_orientation_matrices[tilt_index] + + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, + rot_matrix @ old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -691,14 +742,12 @@ def preprocess( ) probe_overlap_3D += probe_overlap[None] + old_rot_matrix = rot_matrix - probe_overlap_3D = self._rotate( - probe_overlap_3D, - -current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, - ) + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( @@ -719,19 +768,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -743,10 +786,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -768,38 +809,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -811,7 +851,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1527,6 +1567,111 @@ def _object_butterworth_constraint( current_object += current_object_mean return xp.real(current_object) + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1555,6 +1700,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1611,6 +1759,13 @@ def _constraints( Phase shift in radians to be subtracted from the potential at each iteration object_mask: np.ndarray (boolean) If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1634,6 +1789,12 @@ def _constraints( q_highpass, butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( @@ -1723,6 +1884,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1806,6 +1970,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + collective_tilt_updates: bool + if True perform collective tilt updates shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional @@ -1981,12 +2154,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy() + if hasattr(self, "_tf"): + del self._tf if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None - elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2018,17 +2192,17 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices) + old_rot_matrix = np.eye(3) # identity + for tilt_index in tilt_indices: self._active_tilt_index = tilt_index tilt_error = 0.0 - self._object = self._rotate( + rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index] + self._object = self._rotate_zxy_volume( self._object, - self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + rot_matrix @ old_rot_matrix.T, ) object_sliced = self._project_sliced_object( @@ -2132,23 +2306,13 @@ def reconstruct( ) if collective_tilt_updates: - collective_object += self._rotate( - object_update, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, + collective_object += self._rotate_zxy_volume( + object_update, rot_matrix.T ) else: self._object += object_update - self._object = self._rotate( - self._object, - -self._tilt_angles_deg[self._active_tilt_index], - axes=(0, 2), - reshape=False, - order=3, - ) + old_rot_matrix = rot_matrix # Normalize Error tilt_error /= ( @@ -2203,8 +2367,14 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter + and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) + self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T) + # Normalize Error Over Tilts error /= self._num_tilts @@ -2251,6 +2421,9 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2431,8 +2604,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) asnumpy = self._asnumpy @@ -2534,16 +2710,19 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2556,7 +2735,10 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg( + ax_cb, + chroma_boost=chroma_boost, + ) else: ax = fig.add_subplot(spec[0]) im = ax.imshow( @@ -2585,10 +2767,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2672,8 +2854,11 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2788,29 +2973,30 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2822,7 +3008,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2997,22 +3183,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -3036,3 +3216,29 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 80cdd8cd8..716e1d782 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -8,22 +8,44 @@ import matplotlib.pyplot as plt import numpy as np -from emdfile import Custom, tqdmnd +from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec -from py4DSTEM import DataCube +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable +from py4DSTEM import Calibration, DataCube +from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.visualize import show from scipy.linalg import polar +from scipy.optimize import minimize from scipy.special import comb try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np warnings.simplefilter(action="always", category=UserWarning) +_aberration_names = { + (1, 0): "C1 ", + (1, 2): "stig ", + (2, 1): "coma ", + (2, 3): "trefoil ", + (3, 0): "C3 ", + (3, 2): "stig2 ", + (3, 4): "quadfoil ", + (4, 1): "coma2 ", + (4, 3): "trefoil2 ", + (4, 5): "pentafoil ", + (5, 0): "C5 ", + (5, 2): "stig3 ", + (5, 4): "quadfoil2 ", + (5, 6): "hexafoil ", +} + class ParallaxReconstruction(PhaseReconstruction): """ @@ -35,9 +57,6 @@ class ParallaxReconstruction(PhaseReconstruction): Input 4D diffraction pattern intensities energy: float The electron energy of the wave functions in eV - dp_mean: ndarray, optional - Mean diffraction pattern - If None, get_dp_mean() is used verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -73,6 +92,8 @@ def __init__( else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_save_defaults() + # Data self._datacube = datacube @@ -86,9 +107,78 @@ def __init__( def to_h5(self, group): """ Wraps datasets and metadata to write in emdfile classes, - notably ... + notably the (subpixel-)aligned BF. """ - raise NotImplementedError() + # instantiation metadata + self.metadata = Metadata( + name="instantiation_metadata", + data={ + "energy": self._energy, + "verbose": self._verbose, + "device": self._device, + "object_padding_px": self._object_padding_px, + "name": self.name, + }, + ) + + # preprocessing metadata + self.metadata = Metadata( + name="preprocess_metadata", + data={ + "scan_sampling": self._scan_sampling, + "wavelength": self._wavelength, + }, + ) + + # reconstruction metadata + recon_metadata = {"reconstruction_error": float(self._recon_error)} + + if hasattr(self, "aberration_C1"): + recon_metadata |= { + "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_transpose": self.transpose, + "aberration_C1": self.aberration_C1, + "aberration_A1x": self.aberration_A1x, + "aberration_A1y": self.aberration_A1y, + } + + if hasattr(self, "_kde_upsample_factor"): + recon_metadata |= { + "kde_upsample_factor": self._kde_upsample_factor, + } + self._subpixel_aligned_BF_emd = Array( + name="subpixel_aligned_BF", + data=self._asnumpy(self._recon_BF_subpixel_aligned), + ) + + if hasattr(self, "aberration_dict"): + self.metadata = Metadata( + name="aberrations_metadata", + data={ + v["aberration name"]: v["value [Ang]"] + for k, v in self.aberration_dict.items() + }, + ) + + self.metadata = Metadata( + name="reconstruction_metadata", + data=recon_metadata, + ) + + self._aligned_BF_emd = Array( + name="aligned_BF", + data=self._asnumpy(self._recon_BF), + ) + + # datacube + if self._save_datacube: + self.metadata = self._datacube.calibration + Custom.to_h5(self, group) + else: + dc = self._datacube + self._datacube = None + Custom.to_h5(self, group) + self._datacube = dc @classmethod def _get_constructor_args(cls, group): @@ -96,14 +186,68 @@ def _get_constructor_args(cls, group): Returns a dictionary of arguments/values to pass to the class' __init__ function """ - raise NotImplementedError() + # Get data + dict_data = cls._get_emd_attr_data(cls, group) + + # Get metadata dictionaries + instance_md = _read_metadata(group, "instantiation_metadata") + + # Fix calibrations bug + if "_datacube" in dict_data: + calibrations_dict = _read_metadata(group, "calibration")._params + cal = Calibration() + cal._params.update(calibrations_dict) + dc = dict_data["_datacube"] + dc.calibration = cal + else: + dc = None + + # Populate args and return + kwargs = { + "datacube": dc, + "energy": instance_md["energy"], + "object_padding_px": instance_md["object_padding_px"], + "name": instance_md["name"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility + } + + return kwargs def _populate_instance(self, group): """ Sets post-initialization properties, notably some preprocessing meta optional; during read, this method is run after object instantiation. """ - raise NotImplementedError() + + xp = self._xp + + # Preprocess metadata + preprocess_md = _read_metadata(group, "preprocess_metadata") + self._scan_sampling = preprocess_md["scan_sampling"] + self._wavelength = preprocess_md["wavelength"] + + # Reconstruction metadata + reconstruction_md = _read_metadata(group, "reconstruction_metadata") + self._recon_error = reconstruction_md["reconstruction_error"] + + # Data + dict_data = Custom._get_emd_attr_data(Custom, group) + + if "aberration_C1" in reconstruction_md.keys: + self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.transpose = reconstruction_md["aberration_transpose"] + self.aberration_C1 = reconstruction_md["aberration_C1"] + self.aberration_A1x = reconstruction_md["aberration_A1x"] + self.aberration_A1y = reconstruction_md["aberration_A1y"] + + if "kde_upsample_factor" in reconstruction_md.keys: + self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"] + self._recon_BF_subpixel_aligned = xp.asarray( + dict_data["_subpixel_aligned_BF_emd"].data, dtype=xp.float32 + ) + + self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32) def preprocess( self, @@ -111,6 +255,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, + descan_correct: bool = True, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, @@ -133,6 +278,8 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + descan_correct: float, optional + If True, aligns bright field stack based on measured descan rotation_guess: float, optional Initial guess of defocus value in degrees If None, first iteration assumed to be 0 @@ -171,7 +318,10 @@ def preprocess( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities, dtype=xp.float32) + + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) + self._scan_shape = np.array(self._intensities.shape[:2]) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -180,6 +330,45 @@ def preprocess( "dp_mean must match the datacube shape. Try setting dp_mean = None." ) + # descan correct + if descan_correct: + ( + _, + _, + com_fitted_x, + com_fitted_y, + _, + _, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=None, + fit_function="plane", + com_shifts=None, + com_measured=None, + ) + + com_fitted_x = asnumpy(com_fitted_x) + com_fitted_y = asnumpy(com_fitted_y) + intensities = asnumpy(self._intensities) + intensities_shifted = np.zeros_like(intensities) + + center_x, center_y = self._region_of_interest_shape / 2 + + for rx in range(intensities_shifted.shape[0]): + for ry in range(intensities_shifted.shape[1]): + intensity_shifted = get_shifted_ar( + intensities[rx, ry], + -com_fitted_x[rx, ry] + center_x, + -com_fitted_y[rx, ry] + center_y, + bilinear=True, + device="cpu", + ) + + intensities_shifted[rx, ry] = intensity_shifted + + self._intensities = xp.asarray(intensities_shifted, xp.float32) + self._dp_mean = self._intensities.mean((0, 1)) + # select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) self._num_bf_images = int(xp.count_nonzero(self._dp_mask)) @@ -187,14 +376,16 @@ def preprocess( # diffraction space coordinates self._xy_inds = np.argwhere(self._dp_mask) - self._kxy = (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) * xp.array( - self._reciprocal_sampling - )[None] + self._kxy = xp.asarray( + (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) + * xp.array(self._reciprocal_sampling)[None], + dtype=xp.float32, + ) self._probe_angles = self._kxy * self._wavelength self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1)) # Window function - x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1)[1:] + x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:] x -= (x[1] - x[0]) / 2 wx = ( xp.sin( @@ -205,7 +396,7 @@ def preprocess( ) ** 2 ) - y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1)[1:] + y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1, dtype=xp.float32)[1:] y -= (y[1] - y[0]) / 2 wy = ( xp.sin( @@ -222,7 +413,8 @@ def preprocess( ( self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1], - ) + ), + dtype=xp.float32, ) self._window_pad[ self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -245,7 +437,8 @@ def preprocess( self._grid_scan_shape[1] + self._object_padding_px[1], ) if normalize_images: - self._stack_BF = xp.ones(stack_shape) + self._stack_BF = xp.ones(stack_shape, dtype=xp.float32) + self._stack_BF_no_window = xp.ones(stack_shape, xp.float32) if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None] @@ -259,13 +452,21 @@ def preprocess( self._window_inv[None] + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + elif normalize_order == 1: - x = xp.linspace(-0.5, 0.5, all_bfs.shape[1]) - y = xp.linspace(-0.5, 0.5, all_bfs.shape[2]) + x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32) + y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], xp.float32) ya, xa = xp.meshgrid(y, x) basis = np.vstack( ( - xp.ones(xa.size), + xp.ones_like(xa), xa.ravel(), ya.ravel(), ) @@ -285,9 +486,18 @@ def preprocess( basis @ coefs[0], all_bfs.shape[1:3] ) + self._stack_BF_no_window[ + a0, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3]) + else: all_means = xp.mean(all_bfs, axis=(1, 2)) self._stack_BF = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None]) self._stack_BF[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -299,9 +509,21 @@ def preprocess( + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) + qx = xp.asarray(qx, dtype=xp.float32) + qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) + qy = xp.asarray(qy, dtype=xp.float32) + qxa, qya = xp.meshgrid(qx, qy, indexing="ij") self._qx_shift = -2j * xp.pi * qxa self._qy_shift = -2j * xp.pi * qya @@ -336,7 +558,7 @@ def preprocess( del Gs else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2)) + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) self._stack_mean = xp.mean(self._stack_BF) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images @@ -365,16 +587,27 @@ def preprocess( self.recon_BF = asnumpy(self._recon_BF) if plot_average_bf: - figsize = kwargs.pop("figsize", (6, 6)) + figsize = kwargs.pop("figsize", (6, 12)) - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(1, 2, figsize=figsize) - self._visualize_figax(fig, ax, **kwargs) + self._visualize_figax(fig, ax[0], **kwargs) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Average Bright Field Image") + ax[0].set_ylabel("x [A]") + ax[0].set_xlabel("y [A]") + ax[0].set_title("Average Bright Field Image") + reciprocal_extent = [ + -0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + -0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + ] + ax[1].imshow(self._dp_mask, extent=reciprocal_extent, cmap="gray") + ax[1].set_title("DP mask") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + plt.tight_layout() self._preprocessed = True if self._device == "gpu": @@ -506,8 +739,6 @@ def tune_angle_and_defocus( convergence.append(asnumpy(self._recon_error[0])) if plot_convergence: - from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable - fig, ax = plt.subplots() ax.set_title("convergence") im = ax.imshow( @@ -533,9 +764,9 @@ def tune_angle_and_defocus( divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) + fig.colorbar(im, cax=cax) - plt.tight_layout() + fig.tight_layout() if return_values: convergence = np.array(convergence).reshape( @@ -548,7 +779,7 @@ def reconstruct( max_alignment_bin: int = None, min_alignment_bin: int = 1, max_iter_at_min_bin: int = 2, - upsample_factor: int = 8, + cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = True, running_average: bool = True, @@ -570,7 +801,7 @@ def reconstruct( Minimum bin size for bright field alignment max_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size - upsample_factor: int, optional + cross_correlation_upsample_factor: int, optional DFT upsample factor for subpixel alignment regularizer_matrix_size: Tuple[int,int], optional Bernstein basis degree used for regularizing shifts @@ -623,7 +854,8 @@ def reconstruct( ( self._num_bf_images, (regularizer_matrix_size[0] + 1) * (regularizer_matrix_size[1] + 1), - ) + ), + dtype=xp.float32, ) for ii in np.arange(regularizer_matrix_size[0] + 1): Bi = ( @@ -708,7 +940,7 @@ def reconstruct( # Sort by radial order, from center to outer edge inds_order = xp.argsort(xp.sum(xy_vals**2, axis=1)) - shifts_update = xp.zeros((self._num_bf_images, 2)) + shifts_update = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) for a1 in tqdmnd( xy_vals.shape[0], @@ -730,7 +962,7 @@ def reconstruct( xy_shift = align_images_fourier( G_ref, G, - upsample_factor=upsample_factor, + upsample_factor=cross_correlation_upsample_factor, device=self._device, ) @@ -777,11 +1009,19 @@ def reconstruct( self._qx_shift[None] * dx[:, None, None] + self._qy_shift[None] * dy[:, None, None] ) + self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) + self._stack_BF = xp.asarray( + self._stack_BF, dtype=xp.float32 + ) # numpy fft upcasts? + self._stack_mask = xp.asarray( + self._stack_mask, dtype=xp.float32 + ) # numpy fft upcasts? + del Gs # Center the shifts @@ -837,31 +1077,309 @@ def reconstruct( return self + def subpixel_alignment( + self, + kde_upsample_factor=None, + kde_sigma=0.125, + plot_upsampled_BF_comparison: bool = True, + plot_upsampled_FFT_comparison: bool = False, + **kwargs, + ): + """ + Upsample and subpixel-align BFs using the measured image shifts. + Uses kernel density estimation (KDE) to align upsampled BFs. + + Parameters + ---------- + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma: float, optional + KDE gaussian kernel bandwidth + plot_upsampled_BF_comparison: bool, optional + If True, the pre/post alignment BF images are plotted for comparison + plot_upsampled_FFT_comparison: bool, optional + If True, the pre/post alignment BF FFTs are plotted for comparison + + """ + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + xy_shifts = self._xy_shifts + BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + + self._DF_upsample_limit = np.max( + 2 * self._region_of_interest_shape / self._scan_shape + ) + self._BF_upsample_limit = ( + 4 * self._kr.max() / self._reciprocal_sampling[0] + ) / self._scan_shape.max() + if self._device == "gpu": + self._BF_upsample_limit = self._BF_upsample_limit.item() + + if kde_upsample_factor is None: + if self._BF_upsample_limit * 3 / 2 > self._DF_upsample_limit: + kde_upsample_factor = self._DF_upsample_limit + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (the " + f"dark-field upsampling limit)." + ), + UserWarning, + ) + + elif self._BF_upsample_limit * 3 / 2 > 1: + kde_upsample_factor = self._BF_upsample_limit * 3 / 2 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + else: + kde_upsample_factor = self._DF_upsample_limit * 2 / 3 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (2/3 times the " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f})." + ), + UserWarning, + ) + + if kde_upsample_factor < 1: + raise ValueError("kde_upsample_factor must be larger than 1") + + if kde_upsample_factor > self._DF_upsample_limit: + warnings.warn( + ( + "Requested upsampling factor exceeds " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f}." + ), + UserWarning, + ) + + self._kde_upsample_factor = kde_upsample_factor + pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int") + pixel_size = pixel_output.prod() + + # shifted coordinates + x = xp.arange(BF_size[0]) + y = xp.arange(BF_size[1]) + + xa, ya = xp.meshgrid(x, y, indexing="ij") + xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel() + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + sigma = kde_sigma * self._kde_upsample_factor + pix_count = gaussian_filter(pix_count, sigma) + pix_count[pix_count == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, sigma) + pix_output /= pix_count + + self._recon_BF_subpixel_aligned = pix_output + self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) + + # plotting + if plot_upsampled_BF_comparison: + if plot_upsampled_FFT_comparison: + figsize = kwargs.pop("figsize", (8, 8)) + fig, axs = plt.subplots(2, 2, figsize=figsize) + else: + figsize = kwargs.pop("figsize", (8, 4)) + fig, axs = plt.subplots(1, 2, figsize=figsize) + + axs = axs.flat + cmap = kwargs.pop("cmap", "magma") + + cropped_object = self._crop_padded_object(self._recon_BF) + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[0].set_title("Aligned Bright Field") + + axs[1].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[1].set_title("Upsampled Bright Field") + + for ax in axs[:2]: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if plot_upsampled_FFT_comparison: + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + pad_x = np.round( + BF_size[0] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_y = np.round( + BF_size[1] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_recon_fft = asnumpy( + xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + ) + + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + + show( + pad_recon_fft, + figax=(fig, axs[2]), + extent=reciprocal_extent, + cmap="gray", + title="Aligned Bright Field FFT", + **kwargs, + ) + + show( + upsampled_fft, + figax=(fig, axs[3]), + extent=reciprocal_extent, + cmap="gray", + title="Upsampled Bright Field FFT", + **kwargs, + ) + + for ax in axs[2:]: + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.xaxis.set_ticks_position("bottom") + + fig.tight_layout() + def aberration_fit( self, - plot_CTF_compare: bool = False, - plot_dk: float = 0.005, - plot_k_sigma: float = 0.02, + fit_BF_shifts: bool = False, + fit_CTF_FFT: bool = False, + fit_aberrations_max_radial_order: int = 3, + fit_aberrations_max_angular_order: int = 4, + fit_aberrations_min_radial_order: int = 2, + fit_aberrations_min_angular_order: int = 0, + fit_max_thon_rings: int = 6, + fit_power_alpha: float = 2.0, + plot_CTF_comparison: bool = None, + plot_BF_shifts_comparison: bool = None, + upsampled: bool = True, + force_transpose: bool = False, ): """ Fit aberrations to the measured image shifts. Parameters ---------- - plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies - plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT - plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by + fit_BF_shifts: bool + Set to True to fit aberrations to the measured BF shifts directly. + fit_CTF_FFT: bool + Set to True to fit aberrations in the FFT of the (upsampled) BF + image. Note that this method relies on visible zero crossings in the FFT. + fit_aberrations_max_radial_order: int + Max radial order for fitting of aberrations. + fit_aberrations_max_angular_order: int + Max angular order for fitting of aberrations. + fit_aberrations_min_radial_order: int + Min radial order for fitting of aberrations. + fit_aberrations_min_angular_order: int + Min angular order for fitting of aberrations. + fit_max_thon_rings: int + Max number of Thon rings to search for during CTF FFT fitting. + fit_power_alpha: int + Power to raise FFT alpha weighting during CTF FFT fitting. + plot_CTF_comparison: bool, optional + If True, the fitted CTF is plotted against the reconstructed frequencies. + plot_BF_shifts_comparison: bool, optional + If True, the measured vs fitted BF shifts are plotted. + upsampled: bool + If True, and upsampled BF is available, uses that for CTF FFT fitting. + force_transpose: bool + If True, and fit_BF_shifts is True, flips the measured x and y shifts """ xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + + ### First pass # Convert real space shifts to Angstroms - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + if force_transpose is True: + self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( + self._scan_sampling + ) + else: + self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + self.transpose = force_transpose # Solve affine transformation m = asnumpy( @@ -878,18 +1396,417 @@ def aberration_fit( np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi ) m_aberration = -1.0 * m_aberration + self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = ( - m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 # factor /2 for A1 astigmatism? /4? - self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + if self.transpose: + self.aberration_A1x = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + else: + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + + ### Second pass + + # Aberration coefs + mn = [] + + for m in range( + fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order + ): + n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) + for n in range(fit_aberrations_min_angular_order, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + self._aberrations_mn = np.array(mn) + self._aberrations_mn = self._aberrations_mn[ + np.argsort(self._aberrations_mn[:, 1]), : + ] + + sub = self._aberrations_mn[:, 1] > 0 + self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ + np.argsort(self._aberrations_mn[sub, 0]), : + ] + self._aberrations_mn[~sub, :] = self._aberrations_mn[~sub, :][ + np.argsort(self._aberrations_mn[~sub, 0]), : + ] + self._aberrations_num = self._aberrations_mn.shape[0] + + if plot_CTF_comparison is None: + if fit_CTF_FFT: + plot_CTF_comparison = True + + if plot_BF_shifts_comparison is None: + if fit_BF_shifts: + plot_BF_shifts_comparison = True + + # Thon Rings Fitting + if fit_CTF_FFT or plot_CTF_comparison: + if upsampled and hasattr(self, "_kde_upsample_factor"): + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + + else: + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + upsampled = False + + reciprocal_extent = [ + -0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[0], + -0.5 / self._scan_sampling[0], + ] + + # FFT coordinates + qx = xp.fft.fftfreq(im_FFT.shape[0], sx) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha_FFT = xp.sqrt(qr2) * self._wavelength + theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + self._aberrations_basis_FFT = xp.zeros( + (alpha_FFT.size, self._aberrations_num) + ) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) / (m + 1) + ).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) + ).ravel() + else: + # sin coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape_FFT = alpha_FFT.shape + plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 + + # CTF function + def calculate_CTF_FFT(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis_FFT[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis_FFT[:, a0] + return xp.reshape(chi, alpha_shape) + + # Direct Shifts Fitting + if fit_BF_shifts: + # FFT coordinates + sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) + sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) + qx, qy = np.meshgrid(qx, qy, indexing="ij") + + # passive rotation basis by -theta + rotation_angle = -self.rotation_Q_to_R_rads + qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( + rotation_angle + ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) + + qr2 = qx**2 + qy**2 + u = qx * self._wavelength + v = qy * self._wavelength + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy, qx) + + # Aberration basis + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + + if n == 0: + # Radially symmetric basis + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() + + else: + # sin coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape = alpha.shape + + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:, a0] + return xp.reshape(chi, alpha_shape) + + # initial coefficients and plotting intensity range mask + self._aberrations_coefs = np.zeros(self._aberrations_num) + + aberrations_mn_list = self._aberrations_mn.tolist() + if [1, 0, 0] in aberrations_mn_list: + ind_C1 = aberrations_mn_list.index([1, 0, 0]) + self._aberrations_coefs[ind_C1] = self.aberration_C1 + + if [1, 2, 0] in aberrations_mn_list: + ind_A1x = aberrations_mn_list.index([1, 2, 0]) + ind_A1y = aberrations_mn_list.index([1, 2, 1]) + self._aberrations_coefs[ind_A1x] = self.aberration_A1x + self._aberrations_coefs[ind_A1y] = self.aberration_A1y + + # Refinement using CTF fitting / Thon rings + if fit_CTF_FFT: + # scoring function to minimize - mean value of zero crossing regions of FFT + def score_CTF(coefs): + im_CTF = xp.abs( + calculate_CTF_FFT(self._aberrations_surface_shape_FFT, *coefs) + ) + mask = xp.logical_and( + im_CTF > 0.5 * np.pi, + im_CTF < (max_num_rings + 0.5) * np.pi, + ) + if np.any(mask): + weights = xp.cos(im_CTF[mask]) ** 4 + return asnumpy( + xp.sum( + weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha + ) + / xp.sum(weights) + ) + else: + return np.inf + + for max_num_rings in range(1, fit_max_thon_rings + 1): + # minimization + res = minimize( + score_CTF, + self._aberrations_coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method="BFGS", + tol=1e-8, + ) + self._aberrations_coefs = res.x + + # Refinement using CTF fitting / Thon rings + elif fit_BF_shifts: + # Gradient basis + corner_indices = self._xy_inds - xp.asarray( + self._region_of_interest_shape // 2 + ) + raveled_indices = np.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.vstack( + ( + self._aberrations_basis_du[raveled_indices, :], + self._aberrations_basis_dv[raveled_indices, :], + ) + ) + + # (Relative) untransposed fit + raveled_shifts = self._xy_shifts_Ang.T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, raveled_shifts, rcond=None + )[:2] + + self._aberrations_coefs = asnumpy(aberrations_coefs) + + if self.transpose: + aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & ( + self._aberrations_mn[:, 2] == 0 + ) + self._aberrations_coefs[aberrations_to_flip] *= -1 + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] + + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] + + fitted_shifts = ( + xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) + .reshape((2, -1)) + .T + ) + + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] + + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] + + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] + ) + ) + + show( + [ + [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], + [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], + ], + cmap="PiYG", + vmin=-max_shift, + vmax=max_shift, + intensity_range="absolute", + axsize=(4, 4), + ticks=False, + title=[ + "Measured Vertical Shifts", + "Fitted Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Horizontal Shifts", + ], + ) + + # Plot the CTF comparison between experiment and fit + if plot_CTF_comparison: + # Generate FFT plotting image + im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha) + int_vals = np.sort(im_scale.ravel()) + int_range = ( + int_vals[np.round(0.02 * im_scale.size).astype("int")], + int_vals[np.round(0.98 * im_scale.size).astype("int")], + ) + int_range = ( + int_range[0], + (int_range[1] - int_range[0]) * 1.0 + int_range[0], + ) + im_scale = np.clip( + (np.fft.fftshift(im_scale) - int_range[0]) + / (int_range[1] - int_range[0]), + 0, + 1, + ) + im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) + + # Add CTF zero crossings + im_CTF = calculate_CTF_FFT( + self._aberrations_surface_shape_FFT, *self._aberrations_coefs + ) + im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 + im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 + im_CTF[xp.logical_not(plot_mask)] = 0 + + im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) + im_plot[:, :, 0] += im_CTF + im_plot[:, :, 1] -= im_CTF + im_plot[:, :, 2] -= im_CTF + im_plot = np.clip(im_plot, 0, 1) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) + ax1.imshow( + im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent + ) + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_cos)), + cmap="gray", + extent=reciprocal_extent, + ) + + for ax in (ax1, ax2): + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + + ax1.set_title("Aligned Bright Field FFT") + ax2.set_title("Fitted CTF Zero-Crossings") + + fig.tight_layout() + + self.aberration_dict = { + tuple(self._aberrations_mn[a0]): { + "aberration name": _aberration_names.get( + tuple(self._aberrations_mn[a0, :2]), "-" + ).strip(), + "value [Ang]": self._aberrations_coefs[a0], + } + for a0 in range(self._aberrations_num) + } # Print results if self._verbose: + if fit_CTF_FFT or fit_BF_shifts: + print("Initial Aberration coefficients") + print("-------------------------------") print( ( "Rotation of Q w.r.t. R = " @@ -903,99 +1820,105 @@ def aberration_fit( f"{self.aberration_A1y:.0f}) Ang" ) ) - if self.aberration_C1 > 0: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - else: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - - # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: - # Get polar mean from FFT of BF reconstruction - im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - k_max = xp.max(kra) / np.sqrt(2.0) - k_num_bins = int(xp.ceil(k_max / plot_dk)) - k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # histogram - k_ind = kra / plot_dk - kf = np.floor(k_ind).astype("int") - dk = k_ind - kf - sub = kf <= k_num_bins - hist_exp = xp.bincount( - kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins - ) - hist_norm = xp.bincount( - kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins - ) - sub = kf <= k_num_bins - 1 - - hist_exp += xp.bincount( - kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins - ) - hist_norm += xp.bincount( - kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins - ) + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + print(f"Transpose = {self.transpose}") + + if fit_CTF_FFT or fit_BF_shifts: + print() + print("Refined Aberration coefficients") + print("-------------------------------") + print("aberration radial angular dir. coefs") + print("name order order Ang ") + print("---------- ------- ------- ---- -----") + + for a0 in range(self._aberrations_mn.shape[0]): + m, n, a = self._aberrations_mn[a0] + name = _aberration_names.get((m, n), " -- ") + if n == 0: + print( + name + + " " + + str(m + 1) + + " 0 - " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + elif a == 0: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " x " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + else: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " y " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) - # KDE and normalizing - k_sigma = plot_dk / plot_k_sigma - hist_exp[0] = 0.0 - hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - hist_exp /= hist_norm + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() - # CTF comparison - CTF_fit = xp.sin( - (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 - ) + def _calculate_CTF(self, alpha_shape, sampling, *coefs): + xp = self._xp - # plotting input - log scale - min_hist_val = xp.max(hist_exp) * 1e-3 - hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - hist_plot -= xp.min(hist_plot) - hist_plot /= xp.max(hist_plot) + # FFT coordinates + sx, sy = sampling + qx = xp.fft.fftfreq(alpha_shape[0], sx) + qy = xp.fft.fftfreq(alpha_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - hist_plot = asnumpy(hist_plot) - k_bins = asnumpy(k_bins) - CTF_fit = asnumpy(CTF_fit) + # global scaling + aberrations_basis *= 2 * np.pi / self._wavelength - fig, ax = plt.subplots(figsize=(8, 4)) + chi = xp.zeros_like(aberrations_basis[:, 0]) - ax.fill_between( - k_bins, - hist_plot, - color=(0.7, 0.7, 0.7, 1), - ) + for a0 in range(len(coefs)): + chi += coefs[a0] * aberrations_basis[:, a0] - ax.plot( - k_bins, - np.clip(CTF_fit, 0.0, np.inf), - color=(1, 0, 0, 1), - linewidth=2, - ) - ax.plot( - k_bins, - np.clip(-CTF_fit, 0.0, np.inf), - color=(0, 0.5, 1, 1), - linewidth=2, - ) - ax.set_xlim([0, k_bins[-1]]) - ax.set_ylim([0, 1.05]) + return xp.reshape(chi, alpha_shape) def aberration_correct( self, + use_CTF_fit=None, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, + Wiener_signal_noise_ratio: float = 1.0, + Wiener_filter_low_only: bool = False, + upsampled: bool = True, **kwargs, ): """ @@ -1003,6 +1926,9 @@ def aberration_correct( Parameters ---------- + use_FFT_fit: bool + Use the CTF fitted to the zero crossings of the FFT. + Default is True plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional @@ -1028,46 +1954,79 @@ def aberration_correct( ) ) + if upsampled and hasattr(self, "_kde_upsample_factor"): + im = self._recon_BF_subpixel_aligned + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + else: + upsampled = False + im = self._recon_BF + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + # Fourier coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + kx = xp.fft.fftfreq(im.shape[0], sx) + ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + if use_CTF_fit is None: + if hasattr(self, "_aberrations_surface_shape"): + use_CTF_fit = True - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio + if use_CTF_fit: + sin_chi = np.sin( + self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs) ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr - - else: - # CTF without tilt correction (beyond the parallax operator) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr # if needed, add low pass filter output image if k_info_limit is not None: im_fft_corr /= 1 + (kra2**k_info_power) / ( (k_info_limit) ** (2 * k_info_power) ) + else: + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + + (kra2**k_info_power) + / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) @@ -1084,12 +2043,14 @@ def aberration_correct( fig, ax = plt.subplots(figsize=figsize) - cropped_object = self._crop_padded_object(self._recon_phase_corrected) + cropped_object = self._crop_padded_object( + self._recon_phase_corrected, upsampled=upsampled + ) extent = [ 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], + sy * cropped_object.shape[1], + sx * cropped_object.shape[0], 0, ] @@ -1246,6 +2207,7 @@ def _crop_padded_object( self, padded_object: np.ndarray, remaining_padding: int = 0, + upsampled: bool = False, ): """ Utility function to crop padded object @@ -1266,8 +2228,19 @@ def _crop_padded_object( asnumpy = self._asnumpy - pad_x = self._object_padding_px[0] // 2 - remaining_padding - pad_y = self._object_padding_px[1] // 2 - remaining_padding + if upsampled: + pad_x = np.round( + self._object_padding_px[0] / 2 * self._kde_upsample_factor + ).astype("int") + pad_y = np.round( + self._object_padding_px[1] / 2 * self._kde_upsample_factor + ).astype("int") + else: + pad_x = self._object_padding_px[0] // 2 + pad_y = self._object_padding_px[1] // 2 + + pad_x -= remaining_padding + pad_y -= remaining_padding return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) @@ -1276,6 +2249,7 @@ def _visualize_figax( fig, ax, remaining_padding: int = 0, + upsampled: bool = False, **kwargs, ): """ @@ -1294,14 +2268,31 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + if upsampled: + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, remaining_padding, upsampled + ) - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + extent = [ + 0, + self._scan_sampling[1] + * cropped_object.shape[1] + / self._kde_upsample_factor, + self._scan_sampling[0] + * cropped_object.shape[0] + / self._kde_upsample_factor, + 0, + ] + + else: + cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] ax.imshow( cropped_object, @@ -1310,10 +2301,11 @@ def _visualize_figax( **kwargs, ) - def _visualize_shifts( + def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, + plot_rotated_shifts=True, **kwargs, ): """ @@ -1330,10 +2322,22 @@ def _visualize_shifts( xp = self._xp asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (6, 6)) color = kwargs.pop("color", (1, 0, 0, 1)) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + scaling_factor = ( + xp.array(self._reciprocal_sampling) + / xp.array(self._scan_sampling) + * scale_arrows + ) + rotated_shifts = self._xy_shifts_Ang * scaling_factor - fig, ax = plt.subplots(figsize=figsize) + else: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + + shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -1343,29 +2347,68 @@ def _visualize_shifts( masked_ind = xp.logical_and(freq_mask, self._dp_mask) plot_ind = masked_ind[dp_mask_ind] - ax.quiver( - asnumpy(self._kxy[plot_ind, 1]), - asnumpy(self._kxy[plot_ind, 0]), - asnumpy( - self._xy_shifts[plot_ind, 1] - * scale_arrows - * self._reciprocal_sampling[0] - ), - asnumpy( - self._xy_shifts[plot_ind, 0] - * scale_arrows - * self._reciprocal_sampling[1] - ), - color=color, - angles="xy", - scale_units="xy", - scale=1, - **kwargs, - ) - kr_max = xp.max(self._kr) - ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) - ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + ax[0].quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[0].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_title("Measured Bright Field Shifts") + ax[0].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[0].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[0].set_aspect("equal") + + # passive coordinate rotation + tf_T = AffineTransform(angle=-self.rotation_Q_to_R_rads) + rotated_kxy = tf_T(self._kxy[plot_ind], xp=xp) + ax[1].quiver( + asnumpy(rotated_kxy[:, 1]), + asnumpy(rotated_kxy[:, 0]), + asnumpy(rotated_shifts[plot_ind, 1]), + asnumpy(rotated_shifts[plot_ind, 0]), + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[1].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_title("Rotated Bright Field Shifts") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[1].set_aspect("equal") + else: + ax.quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_title("Measured BF Shifts") + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.set_aspect("equal") + + fig.tight_layout() def visualize( self, @@ -1391,3 +2434,21 @@ def visualize( ax.set_title("Reconstructed Bright Field Image") return self + + @property + def object_cropped(self): + """cropped object""" + if hasattr(self, "_recon_phase_corrected"): + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_phase_corrected, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_phase_corrected) + else: + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_BF) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 67dba6115..d29aa1747 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,4 +1,7 @@ +import warnings + import numpy as np +import pylops from py4DSTEM.process.phase.utils import ( array_slice, estimate_global_transformation_ransac, @@ -183,6 +186,63 @@ def _object_butterworth_constraint( return current_object + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + def _object_denoise_tv_chambolle( self, current_object, @@ -229,90 +289,100 @@ def _object_denoise_tv_chambolle( Adapted skimage.restoration.denoise_tv_chambolle. """ xp = self._xp - - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) + if xp.iscomplexobj(current_object): + updated_object = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" + current_object_sum = xp.sum(current_object) + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if pad_object: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (1, 1) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() - p = xp.zeros( - (current_object.ndim,) + current_object.shape, dtype=current_object.dtype - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ + # g stores the gradients of updated_object along each axis + # e.g. g[0] is the first order finite difference along axis 0 + slices_g = [ slice(None), ] * (current_object.ndim + 1) for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] - E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] + E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E E_previous = E - i += 1 + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] + if pad_object: + for ax in range(len(ndim)): + slices = array_slice(ndim[ax], current_object.ndim, 1, -1) + updated_object = updated_object[slices] + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) - return updated_object / xp.sum(updated_object) * current_object_sum + return updated_object def _probe_center_of_mass_constraint(self, current_probe): """ @@ -364,7 +434,7 @@ def _probe_amplitude_constraint( erf = self._erf probe_intensity = xp.abs(current_probe) ** 2 - # current_probe_sum = xp.sum(probe_intensity) + current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] Y = xp.fft.fftfreq(current_probe.shape[1])[None] @@ -374,10 +444,10 @@ def _probe_amplitude_constraint( tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_fourier_amplitude_constraint( self, @@ -406,7 +476,7 @@ def _probe_fourier_amplitude_constraint( xp = self._xp asnumpy = self._asnumpy - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) updated_probe_fft, _, _, _ = regularize_probe_amplitude( @@ -419,10 +489,10 @@ def _probe_fourier_amplitude_constraint( updated_probe_fft = xp.asarray(updated_probe_fft) updated_probe = xp.fft.ifft2(updated_probe_fft) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aperture_constraint( self, @@ -444,16 +514,16 @@ def _probe_aperture_constraint( """ xp = self._xp - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) updated_probe = xp.fft.ifft2( xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture ) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aberration_fitting_constraint( self, @@ -485,16 +555,18 @@ def _probe_aberration_fitting_constraint( fourier_probe = xp.fft.fft2(current_probe) fourier_probe_abs = xp.abs(fourier_probe) sampling = self.sampling + energy = self._energy fitted_angle, _ = fit_aberration_surface( fourier_probe, sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - fourier_probe = fourier_probe_abs * xp.exp(1.0j * fitted_angle) + fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) current_probe = xp.fft.ifft2(fourier_probe) return current_probe diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index 8881d021c..233d34e45 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -66,6 +66,8 @@ class SimultaneousPtychographicReconstruction(PtychographicReconstruction): object_padding_px: Tuple[int,int], optional Pixel dimensions to pad objects with If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction initial_object_guess: np.ndarray, optional Initial guess for complex-valued object of dimensions (Px,Py) If None, initialized to 1.0j @@ -102,6 +104,7 @@ def __init__( vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, @@ -150,6 +153,12 @@ def __init__( raise ValueError( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") self.set_save_defaults() @@ -167,6 +176,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -192,6 +202,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -246,6 +257,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -338,6 +351,9 @@ def preprocess( ) ) + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + # 1st measurement sets rotation angle and transposition ( measurement_0, @@ -404,6 +420,8 @@ def preprocess( intensities_0, com_fitted_x_0, com_fitted_y_0, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -487,6 +505,8 @@ def preprocess( intensities_1, com_fitted_x_1, com_fitted_y_1, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -571,6 +591,8 @@ def preprocess( intensities_2, com_fitted_x_2, com_fitted_y_2, + crop_patterns, + self._positions_mask, ) # explicitly delete namescapes @@ -610,7 +632,7 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -683,6 +705,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -746,19 +772,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -780,23 +800,22 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax2.scatter( self.positions[:, 1], @@ -808,7 +827,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -2232,6 +2251,9 @@ def _constraints( q_highpass_e, q_highpass_m, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, warmup_iteration, object_positivity, shrinkage_rad, @@ -2300,6 +2322,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising warmup_iteration: bool If True, constraints electrostatic object only object_positivity: bool @@ -2349,6 +2377,15 @@ def _constraints( if self._object_type == "complex": magnetic_obj = magnetic_obj.real + if tv_denoise: + electrostatic_obj = self._object_denoise_tv_pylops( + electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) + + if not warmup_iteration: + magnetic_obj = self._object_denoise_tv_pylops( + magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) if shrinkage_rad > 0.0 or object_mask is not None: electrostatic_obj = self._object_shrinkage_constraint( @@ -2446,6 +2483,9 @@ def reconstruct( q_highpass_e: float = None, q_highpass_m: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -2538,6 +2578,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -2748,6 +2794,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = (None,) * self._num_sim_measurements self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2899,6 +2947,9 @@ def reconstruct( q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -3029,8 +3080,6 @@ def _visualize_last_iteration( figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = kwargs.pop("cmap_m", "PuOr") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj_e = np.angle(self.object[0]) @@ -3052,6 +3101,11 @@ def _visualize_last_iteration( vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) + extent = [ 0, self.sampling[1] * rotated_shape[1], @@ -3156,29 +3210,29 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: # Electrostatic Object @@ -3229,10 +3283,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -3328,3 +3382,98 @@ def visualize( ) return self + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + error = xp.sum( + xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + error /= self._mean_diffraction_intensity + + return asnumpy(error) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[0][start:end] + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped[0]) + else: + projected_cropped_potential = self.object_cropped[0] + + return projected_cropped_potential + + @property + def object_cropped(self): + """Cropped and rotated object""" + + obj_e, obj_m = self._object + obj_e = self._crop_rotate_object_fov(obj_e) + obj_m = self._crop_rotate_object_fov(obj_m) + return (obj_e, obj_m) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 0480bae8a..350d0a3cb 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -14,8 +14,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np from emdfile import Custom, tqdmnd from py4DSTEM.datacube import DataCube @@ -79,6 +79,8 @@ class SingleslicePtychographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -102,6 +104,7 @@ def __init__( initial_scan_positions: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "ptychographic_reconstruction", @@ -147,6 +150,13 @@ def __init__( f"object_type must be either 'potential' or 'complex', not {object_type}" ) + if positions_mask is not None and positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + positions_mask = np.asarray(positions_mask, dtype="bool") + self.set_save_defaults() # Data @@ -163,6 +173,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -188,6 +199,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -245,6 +257,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -333,6 +347,8 @@ def preprocess( self._intensities, self._com_fitted_x, self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -341,7 +357,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -412,6 +428,11 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + self._probe = ( ComplexProbe( gpts=self._region_of_interest_shape, @@ -474,19 +495,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -508,23 +523,19 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="gray", ) ax2.scatter( self.positions[:, 1], @@ -536,7 +547,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -1023,6 +1034,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, @@ -1078,6 +1092,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1108,6 +1128,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1198,6 +1223,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1284,6 +1312,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1486,6 +1520,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1618,6 +1654,9 @@ def reconstruct( q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1734,8 +1773,11 @@ def _visualize_last_iteration( """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1828,29 +1870,31 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + self.probe_fourier, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1883,10 +1927,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -1957,9 +2001,12 @@ def _visualize_all_iterations( else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "inferno") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + cmap = kwargs.pop("cmap", "magma") + + if plot_fourier_probe: + chroma_boost = kwargs.pop("chroma_boost", 2) + else: + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2063,8 +2110,7 @@ def _visualize_all_iterations( probes[grid_range[n]] ) ), - hue_start=hue_start, - invert=invert, + chroma_boost=chroma_boost, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") @@ -2072,21 +2118,23 @@ def _visualize_all_iterations( else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2098,7 +2146,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index c2e1d3b77..a1eb54c80 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1543,46 +1543,84 @@ def step_model(radius, sig_0, rad_0, width): def aberrations_basis_function( probe_size, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ + + # Add constant phase shift in basis + mn = [[-1, 0, 0]] + + for m in range(1, max_radial_order): + n_max = np.minimum(max_angular_order, m + 1) + for n in range(0, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + aberrations_mn = np.array(mn) + aberrations_mn = aberrations_mn[np.argsort(aberrations_mn[:, 1]), :] + + sub = aberrations_mn[:, 1] > 0 + aberrations_mn[sub, :] = aberrations_mn[sub, :][ + np.argsort(aberrations_mn[sub, 0]), : + ] + aberrations_mn[~sub, :] = aberrations_mn[~sub, :][ + np.argsort(aberrations_mn[~sub, 0]), : + ] + aberrations_num = aberrations_mn.shape[0] + sx, sy = probe_size dx, dy = probe_sampling + wavelength = electron_wavelength_angstrom(energy) + qx = xp.fft.fftfreq(sx, dx) qy = xp.fft.fftfreq(sy, dy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.ones((alpha.size, aberrations_num)) + + # Skip constant to avoid dividing by zero in normalization + for a0 in range(1, aberrations_num): + m, n, a = aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - qxa, qya = xp.meshgrid(qx, qy, indexing="ij") - q2 = qxa**2 + qya**2 - theta = xp.arctan2(qya, qxa) - - basis = [] - index = [] - - for n in range(max_angular_order + 1): - for m in range((max_radial_order - n) // 2 + 1): - basis.append((q2 ** (m + n / 2) * np.cos(n * theta))) - index.append((m, n, 0)) - if n > 0: - basis.append((q2 ** (m + n / 2) * np.sin(n * theta))) - index.append((m, n, 1)) - - basis = xp.array(basis) + # global scaling + aberrations_basis *= 2 * np.pi / wavelength - return basis, index + return aberrations_basis, aberrations_mn def fit_aberration_surface( complex_probe, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ probe_amp = xp.abs(complex_probe) - probe_angle = xp.angle(complex_probe) + probe_angle = -xp.angle(complex_probe) if xp is np: probe_angle = probe_angle.astype(np.float64) @@ -1592,21 +1630,47 @@ def fit_aberration_surface( unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) - basis, _ = aberrations_basis_function( + raveled_basis, _ = aberrations_basis_function( complex_probe.shape, probe_sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - raveled_basis = basis.reshape((basis.shape[0], -1)) raveled_weights = probe_amp.ravel() - Aw = raveled_basis.T * raveled_weights[:, None] + Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] - fitted_angle = xp.tensordot(coeff, basis, axes=1) + fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) return fitted_angle, coeff + + +def rotate_point(origin, point, angle): + """ + Rotate a point (x1, y1) counterclockwise by a given angle around + a given origin (x0, y0). + + Parameters + -------- + origin: 2-tuple of floats + (x0, y0) + point: 2-tuple of floats + (x1, y1) + angle: float (radians) + + Returns + -------- + rotated points (2-tuple) + + """ + ox, oy = origin + px, py = point + + qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy) + qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy) + return qx, qy diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 3ed00ec13..4f053dcfa 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -2,49 +2,72 @@ import numpy as np import matplotlib.pyplot as plt +from scipy.optimize import curve_fit +from scipy.ndimage import gaussian_filter from emdfile import tqdmnd -def calculate_FEM_global( +def calculate_radial_statistics( self, - use_median_local=False, - use_median_global=False, - plot_results=False, + plot_results_mean=False, + plot_results_var=False, figsize=(8, 4), returnval=False, returnfig=False, progress_bar=True, ): """ - Calculate fluctuation electron microscopy (FEM) statistics, including radial mean, - variance, and normalized variance. This function uses the original FEM definitions, - where the signal is computed pattern-by-pattern. + Calculate the radial statistics used in fluctuation electron microscopy (FEM) + and as an initial step in radial distribution function (RDF) calculation. + The computed quantities are the radial mean, variance, and normalized variance. + + There are several ways the means and variances can be computed. Here we first + compute the mean and standard deviation pattern by pattern, i.e. for + diffraction signal d(x,y; q,theta) we take + + d_mean_all(x,y; q) = \int_{0}^{2\pi} d(x,y; q,\theta) d\theta + d_var_all(x,y; q) = \int_{0}^{2\pi} + \( d(x,y; q,\theta) - d_mean_all(x,y; q,\theta) \)^2 d\theta + + Then we find the mean and variance profiles by taking the means of these + quantities over all scan positions: + + d_mean(q) = \sum_{x,y} d_mean_all(x,y; q) + d_var(q) = \sum_{x,y} d_var_all(x,y; q) + + and the normalized variance is d_var/d_mean. + + This follows the methods described in [@cophus TODO ADD CITATION]. - TODO - finish docstrings, add median statistics. Parameters -------- - self: PolarDatacube - Polar datacube used for measuring FEM properties. + plot_results_mean: bool + Toggles plotting the computed radial means + plot_results_var: bool + Toggles plotting the computed radial variances + figsize: 2-tuple + Size of output figures + returnval: bool + Toggles returning the answer. Answers are always stored internally. + returnfig: bool + Toggles returning figures that have been plotted. Only figures for + which `plot_results_*` is True are returned. Returns -------- radial_avg: np.array - Average radial intensity + Optional - returned iff returnval is True. The average radial intensity. radial_var: np.array - Variance in the radial dimension - - + Optional - returned iff returnval is True. The radial variance. + fig_means: 2-tuple (fig,ax) + Optional - returned iff returnfig is True. Plot of the radial means. + fig_var: 2-tuple (fig,ax) + Optional - returned iff returnfig is True. Plot of the radial variances. """ - # Get the dimensioned radial bins - self.scattering_vector = ( - self.radial_bins * self.qstep * self.calibration.get_Q_pixel_size() - ) - self.scattering_vector_units = self.calibration.get_Q_pixel_units() - - # init radial data array + # init radial data arrays self.radial_all = np.zeros( ( self._datacube.shape[0], @@ -52,72 +75,510 @@ def calculate_FEM_global( self.polar_shape[1], ) ) + self.radial_all_std = np.zeros( + ( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + ) + ) - # Compute the radial mean for each probe position + # Compute the radial mean and standard deviation for each probe position for rx, ry in tqdmnd( self._datacube.shape[0], self._datacube.shape[1], - desc="Global FEM", + desc="Radial statistics", unit=" probe positions", disable=not progress_bar, ): self.radial_all[rx, ry] = np.mean(self.data[rx, ry], axis=0) + self.radial_all_std[rx, ry] = np.sqrt( + np.mean((self.data[rx, ry] - self.radial_all[rx, ry][None]) ** 2, axis=0) + ) - self.radial_avg = np.mean(self.radial_all, axis=(0, 1)) + self.radial_mean = np.mean(self.radial_all, axis=(0, 1)) self.radial_var = np.mean( - (self.radial_all - self.radial_avg[None, None]) ** 2, axis=(0, 1) + (self.radial_all - self.radial_mean[None, None]) ** 2, axis=(0, 1) ) - self.radial_var_norm = self.radial_var / self.radial_avg**2 + + self.radial_var_norm = np.copy(self.radial_var) + sub = self.radial_mean > 0.0 + self.radial_var_norm[sub] /= self.radial_mean[sub] ** 2 + + # prepare answer + statistics = self.radial_mean, self.radial_var, self.radial_var_norm + if returnval: + ans = statistics if not returnfig else [statistics] + else: + ans = None if not returnfig else [] # plot results - if plot_results: + if plot_results_mean: + fig, ax = plot_radial_mean( + self, + figsize=figsize, + returnfig=True, + ) if returnfig: - fig, ax = plot_FEM_global( - self, - figsize=figsize, - returnfig=True, - ) - else: - plot_FEM_global( - self, - figsize=figsize, - ) - - # Return values - if returnval: + ans.append((fig, ax)) + if plot_results_var: + fig, ax = plot_radial_var_norm( + self, + figsize=figsize, + returnfig=True, + ) if returnfig: - return self.radial_avg, self.radial_var, fig, ax - else: - return self.radial_avg, self.radial_var + ans.append((fig, ax)) + + # return + return ans + + +def plot_radial_mean( + self, + log_x=False, + log_y=False, + figsize=(8, 4), + returnfig=False, +): + """ + Plot the radial means. + + Parameters + ---------- + log_x : bool + Toggle log scaling of the x-axis + log_y : bool + Toggle log scaling of the y-axis + figsize : 2-tuple + Size of the output figure + returnfig : bool + Toggle returning the figure + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.radial_mean, + ) + + if log_x: + ax.set_xscale("log") + if log_y: + ax.set_yscale("log") + + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") + ax.set_ylabel("Radial Mean") + if log_x and self.qq[0] == 0.0: + ax.set_xlim((self.qq[1], self.qq[-1])) else: - if returnfig: - return fig, ax - else: - pass + ax.set_xlim((self.qq[0], self.qq[-1])) + + if returnfig: + return fig, ax -def plot_FEM_global( +def plot_radial_var_norm( self, figsize=(8, 4), returnfig=False, ): """ - Plotting function for the global FEM. + Plot the radial variances. + + Parameters + ---------- + figsize : 2-tuple + Size of the output figure + returnfig : bool + Toggle returning the figure + """ fig, ax = plt.subplots(figsize=figsize) ax.plot( - self.scattering_vector, + self.qq, self.radial_var_norm, ) - ax.set_xlabel("Scattering Vector (" + self.scattering_vector_units + ")") + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") ax.set_ylabel("Normalized Variance") - ax.set_xlim((self.scattering_vector[0], self.scattering_vector[-1])) + ax.set_xlim((self.qq[0], self.qq[-1])) if returnfig: return fig, ax +def calculate_pair_dist_function( + self, + k_min=0.05, + k_max=None, + k_width=0.25, + k_lowpass=None, + k_highpass=None, + r_min=0.0, + r_max=20.0, + r_step=0.02, + damp_origin_fluctuations=True, + density=None, + plot_background_fits=False, + plot_sf_estimate=False, + plot_reduced_pdf=True, + plot_pdf=False, + figsize=(8, 4), + maxfev=None, + returnval=False, + returnfig=False, +): + """ + Calculate the pair distribution function (PDF). + + First a background is calculated using primarily the signal at the highest + scattering vectors available, given by a sum of two exponentials ~exp(-k^2) + and ~exp(-k^4) modelling the single atom scattering factor plus a constant + offset. Next, the structure factor is computed as + + S(k) = (I(k) - bg(k)) * k / f(k) + + where k is the magnitude of the scattering vector, I(k) is the mean radial + signal, f(k) is the single atom scattering factor, and bg(k) is the total + background signal (i.e. f(k) plus a constant offset). S(k) is masked outside + of the selected fitting region of k-values [k_min,k_max] and low/high pass + filters are optionally applied. The structure factor is then inverted into + the reduced pair distribution function g(r) using + + g(r) = \frac{2}{\pi) \int sin( 2\pi r k ) S(k) dk + + The value of the integral is (optionally) damped to zero at the origin to + match the physical requirement that this condition holds. Finally, the + full PDF G(r) is computed if a known density is provided, using + + G(r) = 1 + [ \frac{2}{\pi} * g(r) / ( 4\pi * D * r dr ) ] + + This follows the methods described in [@cophus TODO ADD CITATION]. + + + Parameters + ---------- + k_min : number + Minimum scattering vector to include in the calculation + k_max : number or None + Maximum scattering vector to include in the calculation. Note that + this cutoff is used when calculating the structure factor - however it + is *not* used when estimating the background / single atom scattering + factor, which is best estimated from high scattering lengths. + k_width : number + The fitting window for the structure factor calculation [k_min,k_max] + includes a damped region at its edges, i.e. the signal is smoothly dampled + to zero in the regions [k_min, k_min+k_width] and [k_max-k_width,k_max] + k_lowpass : number or None + Lowpass filter, in units the scattering vector stepsize (i.e. self.qstep) + k_highpass : number or None + Highpass filter, in units the scattering vector stepsize (i.e. self.qstep) + r_min,r_max,r_step : numbers + Define the real space coordinates r that the PDF g(r) will be computed in. + The coordinates will be np.arange(r_min,r_max,r_step), given in units + inverse to the scattering vector units. + damp_origin_fluctuations : bool + The value of the PDF approaching the origin should be zero, however numerical + instability may result in non-physical finite values there. This flag toggles + damping the value of the PDF to zero near the origin. + density : number or None + The density of the sample, if known. If this is not provided, only the + reduced PDF is calculated. If this value is provided, the PDF is also + calculated. + plot_background_fits : bool + plot_sf_estimate : bool + plot_reduced_pdf=True : bool + plot_pdf : bool + figsize : 2-tuple + maxfev : integer or None + Max number of iterations to use when fitting the background + returnval: bool + Toggles returning the answer. Answers are always stored internally. + returnfig: bool + Toggles returning figures that have been plotted. Only figures for + which `plot_*` is True are returned. + """ + + # set up coordinates and scaling + k = self.qq + dk = k[1] - k[0] + k2 = k**2 + Ik = self.radial_mean + int_mean = np.mean(Ik) + sub_fit = k >= k_min + + # initial guesses for background coefs + const_bg = np.min(self.radial_mean) / int_mean + int0 = np.median(self.radial_mean) / int_mean - const_bg + sigma0 = np.mean(k) + coefs = [const_bg, int0, sigma0, int0, sigma0] + lb = [0, 0, 0, 0, 0] + ub = [np.inf, np.inf, np.inf, np.inf, np.inf] + # Weight the fit towards high k values + noise_est = k[-1] - k + dk + + # Estimate the mean atomic form factor + background + if maxfev is None: + coefs = curve_fit( + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma=noise_est[sub_fit], + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + )[0] + else: + coefs = curve_fit( + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma=noise_est[sub_fit], + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + maxfev=maxfev, + )[0] + + coefs[0] *= int_mean + coefs[1] *= int_mean + coefs[3] *= int_mean + + # Calculate the mean atomic form factor without a constant offset + # coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) + # fk = scattering_model(k2, coefs_fk) + bg = scattering_model(k2, coefs) + fk = bg - coefs[0] + + # mask for structure factor estimate + if k_max is None: + k_max = np.max(k) + mask = np.clip( + np.minimum( + (k - 0.0) / k_width, + (k_max - k) / k_width, + ), + 0, + 1, + ) + mask = np.sin(mask * (np.pi / 2)) + + # Estimate the reduced structure factor S(k) + Sk = (Ik - bg) * k / fk + + # Masking edges of S(k) + mask_sum = np.sum(mask) + Sk = (Sk - np.sum(Sk * mask) / mask_sum) * mask + + # Filtering of S(k) + if k_lowpass is not None and k_lowpass > 0.0: + Sk = gaussian_filter(Sk, sigma=k_lowpass / dk, mode="nearest") + if k_highpass is not None and k_highpass > 0.0: + Sk_lowpass = gaussian_filter(Sk, sigma=k_highpass / dk, mode="nearest") + Sk -= Sk_lowpass + + # Calculate the PDF + r = np.arange(r_min, r_max, r_step) + ra, ka = np.meshgrid(r, k) + pdf_reduced = ( + (2 / np.pi) + * dk + * np.sum( + np.sin(2 * np.pi * ra * ka) * Sk[:, None], + axis=0, + ) + ) + + # Damp the unphysical fluctuations at the PDF origin + if damp_origin_fluctuations: + ind_max = np.argmax(pdf_reduced) + r_ind_max = r[ind_max] + r_mask = np.minimum(r / r_ind_max, 1.0) + r_mask = np.sin(r_mask * np.pi / 2) ** 2 + pdf_reduced *= r_mask + + # Store results + self.pdf_r = r + self.pdf_reduced = pdf_reduced + + self.Sk = Sk + self.fk = fk + self.bg = bg + self.offset = coefs[0] + self.Sk_mask = mask + + # if density is provided, we can estimate the full PDF + if density is not None: + pdf = pdf_reduced.copy() + pdf[1:] /= 4 * np.pi * density * r[1:] * (r[1] - r[0]) + pdf *= 2 / np.pi + pdf += 1 + + # damp and clip values below zero + if damp_origin_fluctuations: + pdf *= r_mask + pdf = np.maximum(pdf, 0.0) + + # store results + self.pdf = pdf + + # prepare answer + if density is None: + return_values = self.pdf_r, self.pdf_reduced + else: + return_values = self.pdf_r, self.pdf_reduced, self.pdf + if returnval: + ans = return_values if not returnfig else [return_values] + else: + ans = None if not returnfig else [] + + # Plots + if plot_background_fits: + fig, ax = self.plot_background_fits(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + if plot_sf_estimate: + fig, ax = self.plot_sf_estimate(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + if plot_reduced_pdf: + fig, ax = self.plot_reduced_pdf(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + if plot_pdf: + fig, ax = self.plot_pdf(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + # return + return ans + + +def plot_background_fits( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.radial_mean, + color="k", + ) + ax.plot( + self.qq, + self.bg, + color="r", + ) + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") + ax.set_ylabel("Radial Mean") + ax.set_xlim((self.qq[0], self.qq[-1])) + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("I(k) and Background Fit Estimates") + ax.set_ylim( + ( + np.min(self.radial_mean[self.radial_mean > 0]) * 0.8, + np.max(self.radial_mean * self.Sk_mask) * 1.25, + ) + ) + ax.set_yscale("log") + if returnfig: + return fig, ax + plt.show() + + +def plot_sf_estimate( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.Sk, + color="r", + ) + yr = (np.min(self.Sk), np.max(self.Sk)) + ax.set_ylim( + ( + yr[0] - 0.05 * (yr[1] - yr[0]), + yr[1] + 0.05 * (yr[1] - yr[0]), + ) + ) + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("Reduced Structure Factor") + if returnfig: + return fig, ax + plt.show() + + +def plot_reduced_pdf( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.pdf_r, + self.pdf_reduced, + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Reduced Pair Distribution Function") + if returnfig: + return fig, ax + plt.show() + + +def plot_pdf( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.pdf_r, + self.pdf, + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Pair Distribution Function") + if returnfig: + return fig, ax + plt.show() + + # functions for inverting from reduced PDF back to S(k) + + # # invert + # ind_max = np.argmax(pdf_reduced* np.sqrt(r)) + # r_ind_max = r[ind_max-1] + # r_mask = np.minimum(r / (r_ind_max), 1.0) + # r_mask = np.sin(r_mask*np.pi/2)**2 + + # Sk_back_proj = (0.5*r_step)*np.sum( + # np.sin( + # 2*np.pi*ra*ka + # ) * pdf_corr[None,:],# * r_mask[None,:], + # # ) * pdf_corr[None,:],# * r_mask[None,:], + # axis=1, + # ) + + def calculate_FEM_local( self, figsize=(8, 6), @@ -143,4 +604,41 @@ def calculate_FEM_local( """ - 1 + 1 + pass + + +def scattering_model(k2, *coefs): + """ + The scattering model used to fit the PDF background. The fit + function is a constant plus two exponentials - one in k^2 and one + in k^4: + + f(k; c,i0,s0,i1,s1) = + c + i0*exp(k^2/-2*s0^2) + i1*exp(k^4/-2*s1^4) + + Parameters + ---------- + k2 : 1d array + the scattering vector squared + coefs : 5-tuple + Initial guesses at the parameters (c,i0,s0,i1,s1) + """ + coefs = np.squeeze(np.array(coefs)) + + const_bg = coefs[0] + int0 = coefs[1] + sigma0 = coefs[2] + int1 = coefs[3] + sigma1 = coefs[4] + + int_model = ( + const_bg + + int0 * np.exp(k2 / (-2 * sigma0**2)) + + int1 * np.exp(k2**2 / (-2 * sigma1**4)) + ) + + # (int1*sigma1)/(k2 + sigma1**2) + # int1*np.exp(k2/(-2*sigma1**2)) + # int1*np.exp(k2/(-2*sigma1**2)) + + return int_model diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index 2bfb205fe..56071c534 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -94,9 +94,15 @@ def __init__( pass from py4DSTEM.process.polar.polar_analysis import ( - calculate_FEM_global, - plot_FEM_global, + calculate_radial_statistics, + calculate_pair_dist_function, calculate_FEM_local, + plot_radial_mean, + plot_radial_var_norm, + plot_background_fits, + plot_sf_estimate, + plot_reduced_pdf, + plot_pdf, ) from py4DSTEM.process.polar.polar_peaks import ( find_peaks_single_pattern, diff --git a/py4DSTEM/process/strain.py b/py4DSTEM/process/strain.py deleted file mode 100644 index db252f75b..000000000 --- a/py4DSTEM/process/strain.py +++ /dev/null @@ -1,601 +0,0 @@ -# Defines the Strain class - -from typing import Optional - -import matplotlib.pyplot as plt -import numpy as np -from py4DSTEM import PointList -from py4DSTEM.braggvectors import BraggVectors -from py4DSTEM.data import Data, RealSlice -from py4DSTEM.preprocess.utils import get_maxima_2D -from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show - - -class StrainMap(RealSlice, Data): - """ - Stores strain map. - - TODO add docs - - """ - - def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): - """ - TODO - """ - assert isinstance( - braggvectors, BraggVectors - ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" - - # initialize as a RealSlice - RealSlice.__init__( - self, - name=name, - data=np.empty( - ( - 6, - braggvectors.Rshape[0], - braggvectors.Rshape[1], - ) - ), - slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], - ) - - # set up braggvectors - # this assigns the bvs, ensures the origin is calibrated, - # and adds the strainmap to the bvs' tree - self.braggvectors = braggvectors - - # initialize as Data - Data.__init__(self) - - # set calstate - # this property is used only to check to make sure that - # the braggvectors being used throughout a workflow are - # the same. The state of calibration of the vectors is noted - # here, and then checked each time the vectors are used - - # if they differ, an error message and instructions for - # re-calibration are issued - self.calstate = self.braggvectors.calstate - assert self.calstate["center"], "braggvectors must be centered" - # get the BVM - # a new BVM using the current calstate is computed - self.bvm = self.braggvectors.histogram(mode="cal") - - # braggvector properties - - @property - def braggvectors(self): - return self._braggvectors - - @braggvectors.setter - def braggvectors(self, x): - assert isinstance( - x, BraggVectors - ), f".braggvectors must be BraggVectors, not type {type(x)}" - assert ( - x.calibration.origin is not None - ), f"braggvectors must have a calibrated origin" - self._braggvectors = x - self._braggvectors.tree(self, force=True) - - def reset_calstate(self): - """ - Resets the calibration state. This recomputes the BVM, and removes any computations - this StrainMap instance has stored, which will need to be recomputed. - """ - for attr in ( - "g0", - "g1", - "g2", - ): - if hasattr(self, attr): - delattr(self, attr) - self.calstate = self.braggvectors.calstate - pass - - # Class methods - - def choose_lattice_vectors( - self, - index_g0, - index_g1, - index_g2, - subpixel="multicorr", - upsample_factor=16, - sigma=0, - minAbsoluteIntensity=0, - minRelativeIntensity=0, - relativeToPeak=0, - minSpacing=0, - edgeBoundary=1, - maxNumPeaks=10, - figsize=(12, 6), - c_indices="lightblue", - c0="g", - c1="r", - c2="r", - c_vectors="r", - c_vectorlabels="w", - size_indices=20, - width_vectors=1, - size_vectorlabels=20, - vis_params={}, - returncalc=False, - returnfig=False, - ): - """ - Choose which lattice vectors to use for strain mapping. - - Overlays the bvm with the points detected via local 2D - maxima detection, plus an index for each point. User selects - 3 points using the overlaid indices, which are identified as - the origin and the termini of the lattice vectors g1 and g2. - - Parameters - ---------- - index_g0 : int - selected index for the origin - index_g1 : int - selected index for g1 - index_g2 :int - selected index for g2 - subpixel : str in ('pixel','poly','multicorr') - See the docstring for py4DSTEM.preprocess.get_maxima_2D - upsample_factor : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - sigma : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - minAbsoluteIntensity : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - minRelativeIntensity : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - relativeToPeak : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - minSpacing : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - edgeBoundary : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - maxNumPeaks : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - figsize : 2-tuple - the size of the figure - c_indices : color - color of the maxima - c0 : color - color of the origin - c1 : color - color of g1 point - c2 : color - color of g2 point - c_vectors : color - color of the g1/g2 vectors - c_vectorlabels : color - color of the vector labels - size_indices : number - size of the indices - width_vectors : number - width of the vectors - size_vectorlabels : number - size of the vector labels - vis_params : dict - additional visualization parameters passed to `show` - returncalc : bool - toggles returning the answer - returnfig : bool - toggles returning the figure - - Returns - ------- - (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter - """ - # validate inputs - for i in (index_g0, index_g1, index_g2): - assert isinstance(i, (int, np.integer)), "indices must be integers!" - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - - # find the maxima - g = get_maxima_2D( - self.bvm.data, - subpixel=subpixel, - upsample_factor=upsample_factor, - sigma=sigma, - minAbsoluteIntensity=minAbsoluteIntensity, - minRelativeIntensity=minRelativeIntensity, - relativeToPeak=relativeToPeak, - minSpacing=minSpacing, - edgeBoundary=edgeBoundary, - maxNumPeaks=maxNumPeaks, - ) - - # get the lattice vectors - gx, gy = g["x"], g["y"] - g0 = gx[index_g0], gy[index_g0] - g1x = gx[index_g1] - g0[0] - g1y = gy[index_g1] - g0[1] - g2x = gx[index_g2] - g0[0] - g2y = gy[index_g2] - g0[1] - g1, g2 = (g1x, g1y), (g2x, g2y) - - # make the figure - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - show(self.bvm.data, figax=(fig, ax1), **vis_params) - show(self.bvm.data, figax=(fig, ax2), **vis_params) - - # Add indices to left panel - d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} - d0 = { - "x": gx[index_g0], - "y": gy[index_g0], - "size": size_indices, - "color": c0, - "fontweight": "bold", - "labels": [str(index_g0)], - } - d1 = { - "x": gx[index_g1], - "y": gy[index_g1], - "size": size_indices, - "color": c1, - "fontweight": "bold", - "labels": [str(index_g1)], - } - d2 = { - "x": gx[index_g2], - "y": gy[index_g2], - "size": size_indices, - "color": c2, - "fontweight": "bold", - "labels": [str(index_g2)], - } - add_pointlabels(ax1, d) - add_pointlabels(ax1, d0) - add_pointlabels(ax1, d1) - add_pointlabels(ax1, d2) - - # Add vectors to right panel - dg1 = { - "x0": gx[index_g0], - "y0": gy[index_g0], - "vx": g1[0], - "vy": g1[1], - "width": width_vectors, - "color": c_vectors, - "label": r"$g_1$", - "labelsize": size_vectorlabels, - "labelcolor": c_vectorlabels, - } - dg2 = { - "x0": gx[index_g0], - "y0": gy[index_g0], - "vx": g2[0], - "vy": g2[1], - "width": width_vectors, - "color": c_vectors, - "label": r"$g_2$", - "labelsize": size_vectorlabels, - "labelcolor": c_vectorlabels, - } - add_vector(ax2, dg1) - add_vector(ax2, dg2) - - # store vectors - self.g = g - self.g0 = g0 - self.g1 = g1 - self.g2 = g2 - - # return - if returncalc and returnfig: - return (g0, g1, g2), (fig, (ax1, ax2)) - elif returncalc: - return (g0, g1, g2) - elif returnfig: - return (fig, (ax1, ax2)) - else: - return - - def fit_lattice_vectors( - self, - x0=None, - y0=None, - max_peak_spacing=2, - mask=None, - plot=True, - vis_params={}, - returncalc=False, - ): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - Args: - x0 : floagt - x-coord of origin - y0 : float - y-coord of origin - max_peak_spacing: float - Maximum distance from the ideal lattice points - to include a peak for indexing - mask: bool - Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - plot:bool - plot results if tru - vis_params : dict - additional visualization parameters passed to `show` - returncalc : bool - if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map - """ - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - - if x0 is None: - x0 = self.braggvectors.Qshape[0] / 2 - if y0 is None: - y0 = self.braggvectors.Qshape[0] / 2 - - # index braggvectors - from py4DSTEM.process.latticevectors import index_bragg_directions - - _, _, braggdirections = index_bragg_directions( - x0, y0, self.g["x"], self.g["y"], self.g1, self.g2 - ) - - self.braggdirections = braggdirections - - if plot: - self.show_bragg_indexing( - self.bvm, - bragg_directions=braggdirections, - points=True, - **vis_params, - ) - - # add indicies to braggvectors - from py4DSTEM.process.latticevectors import add_indices_to_braggvectors - - bragg_vectors_indexed = add_indices_to_braggvectors( - self.braggvectors, - self.braggdirections, - maxPeakSpacing=max_peak_spacing, - qx_shift=self.braggvectors.Qshape[0] / 2, - qy_shift=self.braggvectors.Qshape[1] / 2, - mask=mask, - ) - - self.bragg_vectors_indexed = bragg_vectors_indexed - - # fit bragg vectors - from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs - - g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) - self.g1g2_map = g1g2_map - - if returncalc: - braggdirections, bragg_vectors_indexed, g1g2_map - - def get_strain( - self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs - ): - """ - mask: nd.array (bool) - Use lattice vectors from g1g2_map scan positions - wherever mask==True. If mask is None gets median strain - map from entire field of view. If mask is not None, gets - reference g1 and g2 from region and then calculates strain. - g_reference: nd.array of form [x,y] - G_reference (tupe): reference coordinate system for - xaxis_x and xaxis_y - flip_theta: bool - If True, flips rotation coordinate system - returncal: bool - It True, returns rotated map - """ - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - - if mask is None: - mask = np.ones(self.g1g2_map.shape, dtype="bool") - - from py4DSTEM.process.latticevectors import get_strain_from_reference_region - - strainmap_g1g2 = get_strain_from_reference_region( - self.g1g2_map, - mask=mask, - ) - else: - from py4DSTEM.process.latticevectors import get_reference_g1g2 - - g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) - - from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 - - strainmap_g1g2 = get_strain_from_reference_g1g2( - self.g1g2_map, g1_ref, g2_ref - ) - - self.strainmap_g1g2 = strainmap_g1g2 - - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) - - from py4DSTEM.process.latticevectors import get_rotated_strain_map - - strainmap_rotated = get_rotated_strain_map( - self.strainmap_g1g2, - xaxis_x=g_reference[0], - xaxis_y=g_reference[1], - flip_theta=flip_theta, - ) - - self.strainmap_rotated = strainmap_rotated - - from py4DSTEM.visualize import show_strain - - figsize = kwargs.pop("figsize", (14, 4)) - vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) - vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) - ticknumber = kwargs.pop("ticknumber", 3) - bkgrd = kwargs.pop("bkgrd", False) - axes_plots = kwargs.pop("axes_plots", ()) - - fig, ax = show_strain( - self.strainmap_rotated, - vrange_exx=vrange_exx, - vrange_theta=vrange_theta, - ticknumber=ticknumber, - axes_plots=axes_plots, - bkgrd=bkgrd, - figsize=figsize, - **kwargs, - returnfig=True, - ) - - if not np.all(mask == True): - ax[0][0].imshow(mask, alpha=0.2, cmap="binary") - ax[0][1].imshow(mask, alpha=0.2, cmap="binary") - ax[1][0].imshow(mask, alpha=0.2, cmap="binary") - ax[1][1].imshow(mask, alpha=0.2, cmap="binary") - - if returncalc: - return self.strainmap_rotated - - def show_lattice_vectors( - ar, - x0, - y0, - g1, - g2, - color="r", - width=1, - labelsize=20, - labelcolor="w", - returnfig=False, - **kwargs, - ): - """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" - fig, ax = show(ar, returnfig=True, **kwargs) - - # Add vectors - dg1 = { - "x0": x0, - "y0": y0, - "vx": g1[0], - "vy": g1[1], - "width": width, - "color": color, - "label": r"$g_1$", - "labelsize": labelsize, - "labelcolor": labelcolor, - } - dg2 = { - "x0": x0, - "y0": y0, - "vx": g2[0], - "vy": g2[1], - "width": width, - "color": color, - "label": r"$g_2$", - "labelsize": labelsize, - "labelcolor": labelcolor, - } - add_vector(ax, dg1) - add_vector(ax, dg2) - - if returnfig: - return fig, ax - else: - plt.show() - return - - def show_bragg_indexing( - self, - ar, - bragg_directions, - voffset=5, - hoffset=0, - color="w", - size=20, - points=True, - pointcolor="r", - pointsize=50, - returnfig=False, - **kwargs, - ): - """ - Shows an array with an overlay describing the Bragg directions - - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. - """ - assert isinstance(bragg_directions, PointList) - for k in ("qx", "qy", "h", "k"): - assert k in bragg_directions.data.dtype.fields - - fig, ax = show(ar, returnfig=True, **kwargs) - d = { - "bragg_directions": bragg_directions, - "voffset": voffset, - "hoffset": hoffset, - "color": color, - "size": size, - "points": points, - "pointsize": pointsize, - "pointcolor": pointcolor, - } - add_bragg_index_labels(ax, d) - - if returnfig: - return fig, ax - else: - plt.show() - return - - def copy(self, name=None): - name = name if name is not None else self.name + "_copy" - strainmap_copy = StrainMap(self.braggvectors) - for attr in ( - "g", - "g0", - "g1", - "g2", - "calstate", - "bragg_directions", - "bragg_vectors_indexed", - "g1g2_map", - "strainmap_g1g2", - "strainmap_rotated", - ): - if hasattr(self, attr): - setattr(strainmap_copy, attr, getattr(self, attr)) - - for k in self.metadata.keys(): - strainmap_copy.metadata = self.metadata[k].copy() - return strainmap_copy - - # IO methods - - # read - @classmethod - def _get_constructor_args(cls, group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = RealSlice._get_constructor_args(group) - args = { - "data": ar_constr_args["data"], - "name": ar_constr_args["name"], - } - return args diff --git a/py4DSTEM/process/strain/__init__.py b/py4DSTEM/process/strain/__init__.py new file mode 100644 index 000000000..b487c916b --- /dev/null +++ b/py4DSTEM/process/strain/__init__.py @@ -0,0 +1,10 @@ +from py4DSTEM.process.strain.strain import StrainMap +from py4DSTEM.process.strain.latticevectors import ( + index_bragg_directions, + add_indices_to_braggvectors, + fit_lattice_vectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_strain_from_reference_g1g2, + get_rotated_strain_map, +) diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py new file mode 100644 index 000000000..ba9bb4fcf --- /dev/null +++ b/py4DSTEM/process/strain/latticevectors.py @@ -0,0 +1,469 @@ +# Functions for indexing the Bragg directions + +import numpy as np +from emdfile import PointList, PointListArray, tqdmnd +from numpy.linalg import lstsq +from py4DSTEM.data import RealSlice + + +def index_bragg_directions(x0, y0, gx, gy, g1, g2): + """ + From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of + lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the + reciprocal lattice directions. + + The approach is to solve the matrix equation + ``alpha = beta * M`` + where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, + beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the + h,k indices. + + Args: + x0 (float): x-coord of origin + y0 (float): y-coord of origin + gx (1d array): x-coord of the reciprocal lattice vectors + gy (1d array): y-coord of the reciprocal lattice vectors + g1 (2-tuple of floats): g1x,g1y + g2 (2-tuple of floats): g2x,g2y + + Returns: + (3-tuple) A 3-tuple containing: + + * **h**: *(ndarray of ints)* first index of the bragg directions + * **k**: *(ndarray of ints)* second index of the bragg directions + * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the + indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y + coords 'h' and 'k' contain h and k. + """ + # Get beta, the matrix of lattice vectors + beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) + + # Get alpha, the matrix of measured bragg angles + alpha = np.vstack([gx - x0, gy - y0]) + + # Calculate M, the matrix of peak positions + M = lstsq(beta, alpha, rcond=None)[0].T + M = np.round(M).astype(int) + + # Get h,k + h = M[:, 0] + k = M[:, 1] + + # Store in a PointList + coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] + temp_array = np.zeros([], dtype=coords) + bragg_directions = PointList(data=temp_array) + bragg_directions.add_data_by_field((gx, gy, h, k)) + mask = np.zeros(bragg_directions["qx"].shape[0]) + mask[0] = 1 + bragg_directions.remove(mask) + + return h, k, bragg_directions + + +def add_indices_to_braggvectors( + braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None +): + """ + Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, + identify the indices for each peak in the PointListArray braggpeaks. + Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus + three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak + indices with the ints (h,k) and indicating whether the peak was successfully indexed + or not with the bool index_mask. If `mask` is specified, only the locations where + mask is True are indexed. + + Args: + braggpeaks (PointListArray): the braggpeaks to index. Must contain + the coordinates 'qx', 'qy', and 'intensity' + lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. + Must contain the coordinates 'qx', 'qy', 'h', and 'k' + maxPeakSpacing (float): Maximum distance from the ideal lattice points + to include a peak for indexing + qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList + relative to the `braggpeaks` PointListArray + mask (bool): Boolean mask, same shape as the pointlistarray, indicating which + locations should be indexed. This can be used to index different regions of + the scan with different lattices + + Returns: + (PointListArray): The original braggpeaks pointlistarray, with new coordinates + 'h', 'k', containing the indices of each indexable peak. + """ + + # assert isinstance(braggpeaks,BraggVectors) + # assert isinstance(lattice, PointList) + # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) + + if mask is None: + mask = np.ones(braggpeaks.Rshape, dtype=bool) + + assert ( + mask.shape == braggpeaks.Rshape + ), "mask must have same shape as pointlistarray" + assert mask.dtype == bool, "mask must be boolean" + + coords = [ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ] + + indexed_braggpeaks = PointListArray( + dtype=coords, + shape=braggpeaks.Rshape, + ) + + calstate = braggpeaks.calstate + + # loop over all the scan positions + for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + if mask[Rx, Ry]: + pl = braggpeaks.get_vectors( + Rx, + Ry, + center=True, + ellipse=calstate["ellipse"], + rotate=calstate["rotate"], + pixel=False, + ) + for i in range(pl.data.shape[0]): + r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( + pl.data["qy"][i] - lattice.data["qy"] + qy_shift + ) ** 2 + ind = np.argmin(r2) + if r2[ind] <= maxPeakSpacing**2: + indexed_braggpeaks[Rx, Ry].add_data_by_field( + ( + pl.data["qx"][i], + pl.data["qy"][i], + pl.data["intensity"][i], + lattice.data["h"][ind], + lattice.data["k"][ind], + ) + ) + + return indexed_braggpeaks + + +def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. + x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (7-tuple) A 7-tuple containing: + + * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. + * **y0**: *(float)* the y-coord of the origin + * **g1x**: *(float)* x-coord of the first lattice vector + * **g1y**: *(float)* y-coord of the first lattice vector + * **g2x**: *(float)* x-coord of the second lattice vector + * **g2y**: *(float)* y-coord of the second lattice vector + * **error**: *(float)* the fit error + """ + assert isinstance(braggpeaks, PointList) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + braggpeaks = braggpeaks.copy() + + # Remove unindexed peaks + if "index_mask" in braggpeaks.dtype.names: + deletemask = braggpeaks.data["index_mask"] == False + braggpeaks.remove(deletemask) + + # Check to ensure enough peaks are present + if braggpeaks.length < minNumPeaks: + return None, None, None, None, None, None, None + + # Get M, the matrix of (h,k) indices + h, k = braggpeaks.data["h"], braggpeaks.data["k"] + M = np.vstack((np.ones_like(h, dtype=int), h, k)).T + + # Get alpha, the matrix of measured Bragg peak positions + alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T + + # Get weighted matrices + weights = braggpeaks.data["intensity"] + weighted_M = M * weights[:, np.newaxis] + weighted_alpha = alpha * weights[:, np.newaxis] + + # Solve for lattice vectors + beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] + x0, y0 = beta[0, 0], beta[0, 1] + g1x, g1y = beta[1, 0], beta[1, 1] + g2x, g2y = beta[2, 0], beta[2, 1] + + # Calculate the error + alpha_calculated = np.matmul(M, beta) + error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) + error = np.sum(error * weights) / np.sum(weights) + + return x0, y0, g1x, g1y, g2x, g2y, error + + +def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some + known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. + x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: + + * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice + * ``g1g2_map.get_slice('y0')`` y-coord of the origin + * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector + * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector + * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector + * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector + * ``g1g2_map.get_slice('error')`` the fit error + * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful + fits + """ + assert isinstance(braggpeaks, PointListArray) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + + # Make RealSlice to contain outputs + slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") + g1g2_map = RealSlice( + data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), + slicelabels=slicelabels, + name="g1g2_map", + ) + + # Fit lattice vectors + for Rx, Ry in tqdmnd( + braggpeaks.shape[0], + braggpeaks.shape[1], + desc="Fitting lattice vectors", + unit="DP", + unit_scale=True, + ): + braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) + qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( + braggpeaks_curr, x0, y0, minNumPeaks + ) + # Store data + if g1x is not None: + g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 + g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 + g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x + g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y + g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x + g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y + g1g2_map.get_slice("error").data[Rx, Ry] = error + g1g2_map.get_slice("mask").data[Rx, Ry] = 1 + + return g1g2_map + + +def get_reference_g1g2(g1g2_map, mask): + """ + Gets a pair of reference lattice vectors from a region of real space specified by + mask. Takes the median of the lattice vectors in g1g2_map within the specified + region. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever + mask==True + + Returns: + (2-tuple of 2-tuples) A 2-tuple containing: + + * **g1**: *(2-tuple)* first reference lattice vector (x,y) + * **g2**: *(2-tuple)* second reference lattice vector (x,y) + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] + ) + assert mask.dtype == bool + g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) + g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) + g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) + g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) + return (g1x, g1y), (g2x, g2y) + + +def get_strain_from_reference_g1g2(g1g2_map, g1, g2): + """ + Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map + g1g2_map. + + Note that this function will return the strain map oriented with respect to the x/y + axes of diffraction space - to rotate the coordinate system, use + get_rotated_strain_map(). Calibration of the rotational misalignment between real and + diffraction space may also be necessary. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + g1 (2-tuple): first reference lattice vector (x,y) + g2 (2-tuple): second reference lattice vector (x,y) + + Returns: + (RealSlice) the strain map; contains the elements of the infinitessimal strain + matrix, in the following 5 arrays: + + * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect + to x + * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect + to y + * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect + to y + * ``strain_map.get_slice('theta')``: rotation of lattice with respect to + reference + * ``strain_map.get_slice('mask')``: 0/False indicates unknown values + + Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] + ) + + # Get RealSlice for output storage + R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape + strain_map = RealSlice( + data=np.zeros((5, R_Nx, R_Ny)), + slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), + name="strain_map", + ) + + # Get reference lattice matrix + g1x, g1y = g1 + g2x, g2y = g2 + M = np.array([[g1x, g1y], [g2x, g2y]]) + + for Rx, Ry in tqdmnd( + R_Nx, + R_Ny, + desc="Calculating strain", + unit="DP", + unit_scale=True, + ): + # Get lattice vectors for DP at Rx,Ry + alpha = np.array( + [ + [ + g1g2_map.get_slice("g1x").data[Rx, Ry], + g1g2_map.get_slice("g1y").data[Rx, Ry], + ], + [ + g1g2_map.get_slice("g2x").data[Rx, Ry], + g1g2_map.get_slice("g2y").data[Rx, Ry], + ], + ] + ) + # Get transformation matrix + beta = lstsq(M, alpha, rcond=None)[0].T + + # Get the infinitesimal strain matrix + strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] + strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] + strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 + strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 + strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ + Rx, Ry + ] + return strain_map + + +def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): + """ + Starting from a strain map defined with respect to the xy coordinate system of + diffraction space, i.e. where exx and eyy are the compression/tension along the Qx + and Qy directions, respectively, get a strain map defined with respect to some other + right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, + xaxis_y). + + Args: + xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector + along the new x-axis + unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the + infinitessimal strain matrix elements, stored at + * unrotated_strain_map.get_slice('e_xx') + * unrotated_strain_map.get_slice('e_xy') + * unrotated_strain_map.get_slice('e_yy') + * unrotated_strain_map.get_slice('theta') + + Returns: + (RealSlice) the rotated counterpart to unrotated_strain_map, with the + rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate + system + """ + assert isinstance(unrotated_strain_map, RealSlice) + assert np.all( + [ + key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] + for key in unrotated_strain_map.slicelabels + ] + ) + theta = -np.arctan2(xaxis_y, xaxis_x) + cost = np.cos(theta) + sint = np.sin(theta) + cost2 = cost**2 + sint2 = sint**2 + + Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape + rotated_strain_map = RealSlice( + data=np.zeros((5, Rx, Ry)), + slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], + name=unrotated_strain_map.name + "_rotated".format(np.degrees(theta)), + ) + + rotated_strain_map.data[0, :, :] = ( + cost2 * unrotated_strain_map.get_slice("e_xx").data + - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + sint2 * unrotated_strain_map.get_slice("e_yy").data + ) + rotated_strain_map.data[1, :, :] = ( + cost + * sint + * ( + unrotated_strain_map.get_slice("e_xx").data + - unrotated_strain_map.get_slice("e_yy").data + ) + + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data + ) + rotated_strain_map.data[2, :, :] = ( + sint2 * unrotated_strain_map.get_slice("e_xx").data + + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + cost2 * unrotated_strain_map.get_slice("e_yy").data + ) + if flip_theta == True: + rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data + else: + rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data + rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data + return rotated_strain_map diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py new file mode 100644 index 000000000..ab8a46a9a --- /dev/null +++ b/py4DSTEM/process/strain/strain.py @@ -0,0 +1,1579 @@ +# Defines the Strain class + +import warnings +from typing import Optional + +import matplotlib.pyplot as plt +from matplotlib.patches import Circle +from matplotlib.collections import PatchCollection +from mpl_toolkits.axes_grid1 import make_axes_locatable +import numpy as np +from py4DSTEM import PointList, PointListArray, tqdmnd +from py4DSTEM.braggvectors import BraggVectors +from py4DSTEM.data import Data, RealSlice +from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.process.strain.latticevectors import ( + add_indices_to_braggvectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_rotated_strain_map, + get_strain_from_reference_g1g2, + index_bragg_directions, +) +from py4DSTEM.visualize import ( + show, + add_bragg_index_labels, + add_pointlabels, + add_vector, + ax_addaxes, + ax_addaxes_QtoR, +) + + +class StrainMap(RealSlice, Data): + """ + Storage and processing methods for 4D-STEM datasets. + + """ + + def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): + """ + Parameters + ---------- + braggvectors : BraggVectors + The Bragg vectors + name : str + The name of the strainmap + + Returns + ------- + A new StrainMap instance. + """ + assert isinstance( + braggvectors, BraggVectors + ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" + + # initialize as a RealSlice + RealSlice.__init__( + self, + name=name, + data=np.empty( + ( + 6, + braggvectors.Rshape[0], + braggvectors.Rshape[1], + ) + ), + slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], + ) + + # set up braggvectors + # this assigns the bvs, ensures the origin is calibrated, + # and adds the strainmap to the bvs' tree + self.braggvectors = braggvectors + + # initialize as Data + Data.__init__(self) + + # set calstate + # this property is used only to check to make sure that + # the braggvectors being used throughout a workflow are + # the same. The state of calibration of the vectors is noted + # here, and then checked each time the vectors are used - + # if they differ, an error message and instructions for + # re-calibration are issued + self.calstate = self.braggvectors.calstate + assert self.calstate["center"], "braggvectors must be centered" + if self.calstate["rotate"] == False: + warnings.warn( + ("Real to reciprocal space rotation not calibrated"), + UserWarning, + ) + + # get the BVM + # a new BVM using the current calstate is computed + self.bvm = self.braggvectors.histogram(mode="cal") + + # braggvector properties + + @property + def braggvectors(self): + return self._braggvectors + + @braggvectors.setter + def braggvectors(self, x): + assert isinstance( + x, BraggVectors + ), f".braggvectors must be BraggVectors, not type {type(x)}" + assert ( + x.calibration.origin is not None + ), "braggvectors must have a calibrated origin" + self._braggvectors = x + self._braggvectors.tree(self, force=True) + + @property + def rshape(self): + return self._braggvectors.Rshape + + @property + def qshape(self): + return self._braggvectors.Qshape + + @property + def origin(self): + return self.calibration.get_origin_mean() + + def reset_calstate(self): + """ + Resets the calibration state. This recomputes the BVM, and removes any computations + this StrainMap instance has stored, which will need to be recomputed. + """ + for attr in ( + "g0", + "g1", + "g2", + ): + if hasattr(self, attr): + delattr(self, attr) + self.calstate = self.braggvectors.calstate + pass + + # Class methods + + def choose_basis_vectors( + self, + index_g1=None, + index_g2=None, + index_origin=None, + subpixel="multicorr", + upsample_factor=16, + sigma=0, + minAbsoluteIntensity=0, + minRelativeIntensity=0, + relativeToPeak=0, + minSpacing=0, + edgeBoundary=1, + maxNumPeaks=10, + x0=None, + y0=None, + figsize=(14, 9), + c_indices="lightblue", + c0="g", + c1="r", + c2="r", + c_vectors="r", + c_vectorlabels="w", + size_indices=15, + width_vectors=1, + size_vectorlabels=15, + vis_params={}, + returncalc=False, + returnfig=False, + ): + """ + Choose basis lattice vectors g1 and g2 for strain mapping. + + Overlays the bvm with the points detected via local 2D + maxima detection, plus an index for each point. Three points + are selected which correspond to the origin, and the basis + reciprocal lattice vectors g1 and g2. By default these are + automatically located; the user can override and select these + manually using the `index_*` arguments. + + Parameters + ---------- + index_g1 : int + selected index for g1 + index_g2 :int + selected index for g2 + index_origin : int + selected index for the origin + subpixel : str in ('pixel','poly','multicorr') + See the docstring for py4DSTEM.preprocess.get_maxima_2D + upsample_factor : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + sigma : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minAbsoluteIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minRelativeIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + relativeToPeak : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + minSpacing : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + edgeBoundary : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + maxNumPeaks : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + figsize : 2-tuple + the size of the figure + c_indices : color + color of the maxima + c0 : color + color of the origin + c1 : color + color of g1 point + c2 : color + color of g2 point + c_vectors : color + color of the g1/g2 vectors + c_vectorlabels : color + color of the vector labels + size_indices : number + size of the indices + width_vectors : number + width of the vectors + size_vectorlabels : number + size of the vector labels + vis_params : dict + additional visualization parameters passed to `show` + returncalc : bool + toggles returning the answer + returnfig : bool + toggles returning the figure + + Returns + ------- + (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or the latter two + """ + # validate inputs + for i in (index_origin, index_g1, index_g2): + assert isinstance(i, (int, np.integer)) or ( + i is None + ), "indices must be integers!" + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + # find the maxima + + g = get_maxima_2D( + self.bvm.data, + subpixel=subpixel, + upsample_factor=upsample_factor, + sigma=sigma, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + ) + + # guess the origin and g1 g2 vectors if indices aren't provided + if np.any([x is None for x in (index_g1, index_g2, index_origin)]): + # get distances and angles from calibrated origin + g_dists = np.hypot(g["x"] - self.origin[0], g["y"] - self.origin[1]) + g_angles = np.angle( + g["x"] - self.origin[0] + 1j * (g["y"] - self.origin[1]) + ) + + # guess the origin + if index_origin is None: + index_origin = np.argmin(g_dists) + g_dists[index_origin] = 2 * np.max(g_dists) + + # guess g1 + if index_g1 is None: + index_g1 = np.argmin(g_dists) + g_dists[index_g1] = 2 * np.max(g_dists) + + # guess g2 + if index_g2 is None: + angle_scaling = np.cos(g_angles - g_angles[index_g1]) ** 2 + index_g2 = np.argmin(g_dists * (angle_scaling + 0.1)) + + # get the lattice vectors + gx, gy = g["x"], g["y"] + g0 = gx[index_origin], gy[index_origin] + g1x = gx[index_g1] - g0[0] + g1y = gy[index_g1] - g0[1] + g2x = gx[index_g2] - g0[0] + g2y = gy[index_g2] - g0[1] + g1, g2 = (g1x, g1y), (g2x, g2y) + + # index the lattice vectors + _, _, braggdirections = index_bragg_directions( + g0[0], g0[1], g["x"], g["y"], g1, g2 + ) + + # make the figure + fig, ax = plt.subplots(1, 3, figsize=figsize) + show(self.bvm.data, figax=(fig, ax[0]), **vis_params) + show(self.bvm.data, figax=(fig, ax[1]), **vis_params) + self.show_bragg_indexing( + self.bvm.data, + bragg_directions=braggdirections, + points=True, + figax=(fig, ax[2]), + size=size_indices, + **vis_params, + ) + + # Add indices to left panel + d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} + d0 = { + "x": gx[index_origin], + "y": gy[index_origin], + "size": size_indices, + "color": c0, + "fontweight": "bold", + "labels": [str(index_origin)], + } + d1 = { + "x": gx[index_g1], + "y": gy[index_g1], + "size": size_indices, + "color": c1, + "fontweight": "bold", + "labels": [str(index_g1)], + } + d2 = { + "x": gx[index_g2], + "y": gy[index_g2], + "size": size_indices, + "color": c2, + "fontweight": "bold", + "labels": [str(index_g2)], + } + add_pointlabels(ax[0], d) + add_pointlabels(ax[0], d0) + add_pointlabels(ax[0], d1) + add_pointlabels(ax[0], d2) + + # Add vectors to right panel + dg1 = { + "x0": gx[index_origin], + "y0": gy[index_origin], + "vx": g1[0], + "vy": g1[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_1$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + dg2 = { + "x0": gx[index_origin], + "y0": gy[index_origin], + "vx": g2[0], + "vy": g2[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_2$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + add_vector(ax[1], dg1) + add_vector(ax[1], dg2) + + # store vectors + self.g = g + self.g0 = g0 + self.g1 = g1 + self.g2 = g2 + + # center the bragg directions and store + braggdirections.data["qx"] -= self.origin[0] + braggdirections.data["qy"] -= self.origin[1] + self.braggdirections = braggdirections + + # return + if returncalc and returnfig: + return (self.g0, self.g1, self.g2, self.braggdirections), (fig, ax) + elif returncalc: + return (self.g0, self.g1, self.g2, self.braggdirections) + elif returnfig: + return (fig, ax) + else: + return + + def set_max_peak_spacing( + self, + max_peak_spacing, + returnfig=False, + **vis_params, + ): + """ + Set the size of the regions of diffraction space in which detected Bragg + peaks will be indexed and included in subsequent fitting of basis + vectors, and visualize those regions. + + Parameters + ---------- + max_peak_spacing : number + The maximum allowable distance in pixels between a detected Bragg peak and + the indexed maxima found in `choose_basis_vectors` for the detected + peak to be indexed + returnfig : bool + Toggles returning the figure + vis_params : dict + Any additional arguments are passed to the `show` function when + visualization the BVM + """ + # set the max peak spacing + self.max_peak_spacing = max_peak_spacing + + # make the figure + fig, ax = show( + self.bvm.data, + returnfig=True, + **vis_params, + ) + + # make the circle patch collection + patches = [] + qx = self.braggdirections["qx"] + qy = self.braggdirections["qy"] + origin = self.origin + for idx in range(len(qx)): + c = Circle( + xy=(qy[idx] + origin[1], qx[idx] + origin[0]), + radius=self.max_peak_spacing, + edgecolor="r", + fill=False, + ) + patches.append(c) + pc = PatchCollection(patches, match_original=True) + + # draw the circles + ax.add_collection(pc) + + # return + if returnfig: + return fig, ax + else: + plt.show() + + def fit_basis_vectors( + self, mask=None, max_peak_spacing=None, vis_params={}, returncalc=False + ): + """ + Fit the basis lattice vectors to the detected Bragg peaks at each + scan position. + + First, the lattice vectors at each scan position are indexed using the + basis vectors g1 and g2 specified previously with `choose_basis_vectors` + Detected Bragg peaks which are farther from the set of lattice vectors + found in `choose_basis vectors` than the maximum peak spacing are + ignored; the maximum peak spacing can be set previously by calling + `set_max_peak_spacing` or by specifying the `max_peak_spacing` argument + here. A fit is then performed to refine the values of g1 and g2 at each + scan position, fitting the basis vectors to all detected and indexed + peaks, weighting the peaks according to their intensity. + + Parameters + ---------- + mask : 2d boolean array + A real space shaped Boolean mask indicating scan positions at which + to fit the lattice vectors. + max_peak_spacing : float + Maximum distance from the ideal lattice points to include a peak + for indexing + vis_params : dict + Visualization parameters for showing the max peak spacing; ignored + if `max_peak_spacing` is not set + returncalc : bool + if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map + """ + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + # handle the max peak spacing + if max_peak_spacing is not None: + self.set_max_peak_spacing(max_peak_spacing, **vis_params) + assert hasattr(self, "max_peak_spacing"), "Set the maximum peak spacing!" + + # index the bragg vectors + + # handle the mask + if mask is None: + mask = np.ones(self.braggvectors.Rshape, dtype=bool) + assert ( + mask.shape == self.braggvectors.Rshape + ), "mask must have same shape as pointlistarray" + assert mask.dtype == bool, "mask must be boolean" + self.mask = mask + + # set up new braggpeaks PLA + indexed_braggpeaks = PointListArray( + dtype=[ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ], + shape=self.braggvectors.Rshape, + ) + + # loop over all the scan positions + # and perform indexing, excluding peaks outside of max_peak_spacing + calstate = self.braggvectors.calstate + for Rx, Ry in tqdmnd( + mask.shape[0], + mask.shape[1], + desc="Indexing Bragg scattering", + unit="DP", + unit_scale=True, + ): + if mask[Rx, Ry]: + pl = self.braggvectors.get_vectors( + Rx, + Ry, + center=True, + ellipse=calstate["ellipse"], + rotate=calstate["rotate"], + pixel=False, + ) + for i in range(pl.data.shape[0]): + r = np.hypot( + pl.data["qx"][i] - self.braggdirections.data["qx"], + pl.data["qy"][i] - self.braggdirections.data["qy"], + ) + ind = np.argmin(r) + if r[ind] <= self.max_peak_spacing: + indexed_braggpeaks[Rx, Ry].add_data_by_field( + ( + pl.data["qx"][i], + pl.data["qy"][i], + pl.data["intensity"][i], + self.braggdirections.data["h"][ind], + self.braggdirections.data["k"][ind], + ) + ) + self.bragg_vectors_indexed = indexed_braggpeaks + + # fit bragg vectors + g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) + self.g1g2_map = g1g2_map + + # update the mask + g1g2_mask = self.g1g2_map["mask"].data.astype("bool") + self.mask = np.logical_and(self.mask, g1g2_mask) + + # return + if returncalc: + return self.bragg_vectors_indexed, self.g1g2_map + + def get_strain( + self, gvects=None, coordinate_rotation=0, returncalc=False, **kwargs + ): + """ + Compute the strain as the deviation of the basis reciprocal lattice + vectors which have been fit at each scan position with respect to a + pair of reference lattice vectors, determined by the argument `gvects`. + + Parameters + ---------- + gvects : None or 2d-array or tuple + Specifies how to select the reference lattice vectors. If None, + use the median of the fit lattice vectors over the whole dataset. + If a 2d array is passed, it should be real space shaped and boolean. + In this case, uses the median of the fit lattice vectors in all scan + positions where this array is True. Otherwise, should be a length 2 + tuple of length 2 array/list/tuples, which are used directly as + g1 and g2. + coordinate_rotation : number + Rotate the reference coordinate system counterclockwise by this + amount, in degrees + returncal : bool + It True, returns rotated map + """ + # confirm that the calstate hasn't changed + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." + + # get the reference g-vectors + if gvects is None: + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, self.mask) + elif isinstance(gvects, np.ndarray): + assert gvects.shape == self.rshape + assert gvects.dtype == bool + g1_ref, g2_ref = get_reference_g1g2( + self.g1g2_map, np.logical_and(gvects, self.mask) + ) + else: + g1_ref = np.array(gvects[0]) + g2_ref = np.array(gvects[1]) + + # find the strain + strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) + self.strainmap_g1g2 = strainmap_g1g2 + + # get the reference coordinate system + theta = np.radians(coordinate_rotation) + xaxis_x = np.cos(theta) + xaxis_y = np.sin(theta) + self.coordinate_rotation_degrees = coordinate_rotation + self.coordinate_rotation_radians = theta + + # get the strain in the reference coordinates + strainmap_rotated = get_rotated_strain_map( + self.strainmap_g1g2, + xaxis_x=xaxis_x, + xaxis_y=xaxis_y, + flip_theta=False, + ) + + # store the data + self.data[0] = strainmap_rotated["e_xx"].data + self.data[1] = strainmap_rotated["e_yy"].data + self.data[2] = strainmap_rotated["e_xy"].data + self.data[3] = strainmap_rotated["theta"].data + self.data[4] = strainmap_rotated["mask"].data + + # plot the results + fig, ax = self.show_strain( + **kwargs, + returnfig=True, + ) + + # return + if returncalc: + return self.strainmap + + def get_reference_g1g2(self, ROI): + """ + Get reference g1,g2 vectors by taking the median fit vectors + in the specified ROI. + + Parameters + ---------- + ROI : real space shaped 2d boolean ndarray + Use scan positions where ROI is True + + Returns + ------- + g1_ref,g2_ref : 2 tuple of length 2 ndarrays + """ + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, ROI) + return g1_ref, g2_ref + + def show_strain( + self, + vrange=[-3, 3], + vrange_theta=[-3, 3], + vrange_exx=None, + vrange_exy=None, + vrange_eyy=None, + bkgrd=True, + show_cbars=None, + bordercolor="k", + borderwidth=1, + titlesize=18, + ticklabelsize=10, + ticknumber=5, + unitlabelsize=16, + cmap="RdBu_r", + cmap_theta="PRGn", + mask_color="k", + color_axes="k", + show_gvects=True, + color_gvects="r", + legend_camera_length=1.6, + scale_gvects=0.6, + layout="square", + figsize=None, + returnfig=False, + ): + """ + Display a strain map, showing the 4 strain components + (e_xx,e_yy,e_xy,theta), and masking each image with + strainmap.get_slice('mask') + + Parameters + ---------- + vrange : length 2 list or tuple + The colorbar intensity range for exx,eyy, and exy. + vrange_theta : length 2 list or tuple + The colorbar intensity range for theta. + vrange_exx : length 2 list or tuple + The colorbar intensity range for exx; overrides `vrange` + for exx + vrange_exy : length 2 list or tuple + The colorbar intensity range for exy; overrides `vrange` + for exy + vrange_eyy : length 2 list or tuple + The colorbar intensity range for eyy; overrides `vrange` + for eyy + bkgrd : bool + Overlay a mask over background pixels + show_cbars : None or a tuple of strings + Show colorbars for the specified axes. Valid strings are + 'exx', 'eyy', 'exy', and 'theta'. + bordercolor : color + Color for the image borders + borderwidth : number + Width of the image borders + titlesize : number + Size of the image titles + ticklabelsize : number + Size of the colorbar ticks + ticknumber : number + Number of ticks on colorbars + unitlabelsize : number + Size of the units label on the colorbars + cmap : colormap + Colormap for exx, exy, and eyy + cmap_theta : colormap + Colormap for theta + mask_color : color + Color for the background mask + color_axes : color + Color for the legend coordinate axes + show_gvects : bool + Toggles displaying the g-vectors in the legend + color_gvects : color + Color for the legend g-vectors + legend_camera_length : number + The distance the legend is viewed from; a smaller number yields + a larger legend + scale_gvects : number + Scaling for the legend g-vectors relative to the coordinate axes + layout : int + Determines the layout of the grid which the strain components + will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). + figsize : length 2 tuple of numbers + Size of the figure + returnfig : bool + Toggles returning the figure + """ + # Lookup table for different layouts + assert layout in ("square", "horizontal", "vertical") + layout_lookup = { + "square": ["left", "right", "left", "right"], + "horizontal": ["bottom", "bottom", "bottom", "bottom"], + "vertical": ["right", "right", "right", "right"], + } + + layout_p = layout_lookup[layout] + + # Set which colorbars to display + if show_cbars is None: + if np.all( + [ + v is None + for v in ( + vrange_exx, + vrange_eyy, + vrange_exy, + ) + ] + ): + show_cbars = ("eyy", "theta") + else: + show_cbars = ("exx", "eyy", "exy", "theta") + else: + assert np.all([v in ("exx", "eyy", "exy", "theta") for v in show_cbars]) + + # Contrast limits + if vrange_exx is None: + vrange_exx = vrange + if vrange_exy is None: + vrange_exy = vrange + if vrange_eyy is None: + vrange_eyy = vrange + for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): + assert len(vrange) == 2, "vranges must have length 2" + vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 + vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 + vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 + # theta is plotted in units of degrees + vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( + 180.0 / np.pi + ) + + # Get images + e_xx = np.ma.array( + self.get_slice("exx").data, mask=self.get_slice("mask").data == False + ) + e_yy = np.ma.array( + self.get_slice("eyy").data, mask=self.get_slice("mask").data == False + ) + e_xy = np.ma.array( + self.get_slice("exy").data, mask=self.get_slice("mask").data == False + ) + theta = np.ma.array( + self.get_slice("theta").data, + mask=self.get_slice("mask").data == False, + ) + + ## Plot + + # if figsize hasn't been set, set it based on the + # chosen layout and the image shape + if figsize is None: + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) + if layout == "square": + figsize = (13 * ratio, 8 / ratio) + elif layout == "horizontal": + figsize = (10 * ratio, 4 / ratio) + else: + figsize = (4 * ratio, 10 / ratio) + + # set up layout + if layout == "square": + fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots( + 2, 3, figsize=figsize + ) + elif layout == "horizontal": + figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 1, 5, figsize=figsize + ) + else: + figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 5, 1, figsize=figsize + ) + + # display images, returning cbar axis references + cax11 = show( + e_xx, + figax=(fig, ax11), + vmin=vmin_exx, + vmax=vmax_exx, + intensity_range="absolute", + cmap=cmap, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + cax12 = show( + e_yy, + figax=(fig, ax12), + vmin=vmin_eyy, + vmax=vmax_eyy, + intensity_range="absolute", + cmap=cmap, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + cax21 = show( + e_xy, + figax=(fig, ax21), + vmin=vmin_exy, + vmax=vmax_exy, + intensity_range="absolute", + cmap=cmap, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + cax22 = show( + theta, + figax=(fig, ax22), + vmin=vmin_theta, + vmax=vmax_theta, + intensity_range="absolute", + cmap=cmap_theta, + mask=self.mask, + mask_color=mask_color, + returncax=True, + ) + ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) + ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) + ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) + ax22.set_title(r"$\theta$", size=titlesize) + + # Add black background + if bkgrd: + mask = np.ma.masked_where( + self.get_slice("mask").data.astype(bool), + np.zeros_like(self.get_slice("mask").data), + ) + ax11.matshow(mask, cmap="gray") + ax12.matshow(mask, cmap="gray") + ax21.matshow(mask, cmap="gray") + ax22.matshow(mask, cmap="gray") + + # add colorbars + show_cbars = np.array( + [ + "exx" in show_cbars, + "eyy" in show_cbars, + "exy" in show_cbars, + "theta" in show_cbars, + ] + ) + if np.any(show_cbars): + divider11 = make_axes_locatable(ax11) + divider12 = make_axes_locatable(ax12) + divider21 = make_axes_locatable(ax21) + divider22 = make_axes_locatable(ax22) + cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) + cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) + cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) + cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) + for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( + range(4), + show_cbars, + (cax11, cax12, cax21, cax22), + (cbax11, cbax12, cbax21, cbax22), + (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), + (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), + (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), + ("% ", " %", "% ", r" $^\circ$"), + ): + if show_cbar: + ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) + if ind < 3: + ticklabels = np.round( + np.linspace( + 100 * vmin, 100 * vmax, ticknumber, endpoint=True + ), + decimals=2, + ).astype(str) + else: + ticklabels = np.round( + np.linspace( + (180 / np.pi) * vmin, + (180 / np.pi) * vmax, + ticknumber, + endpoint=True, + ), + decimals=2, + ).astype(str) + + if tickside in ("left", "right"): + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="vertical" + ) + cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) + cbax.yaxis.set_ticks_position(tickside) + cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) + cbax.yaxis.set_label_position(tickside) + else: + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="horizontal" + ) + cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) + cbax.xaxis.set_ticks_position(tickside) + cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) + cbax.xaxis.set_label_position(tickside) + else: + cbax.axis("off") + + # Add borders + if bordercolor is not None: + for ax in (ax11, ax12, ax21, ax22): + for s in ["bottom", "top", "left", "right"]: + ax.spines[s].set_color(bordercolor) + ax.spines[s].set_linewidth(borderwidth) + ax.set_xticks([]) + ax.set_yticks([]) + + # Legend + + # for layout "square", combine vertical plots on the right end + if layout == "square": + # get gridspec object + gs = ax_legend1.get_gridspec() + # remove last two axes + ax_legend1.remove() + ax_legend2.remove() + # make new axis + ax_legend = fig.add_subplot(gs[:, -1]) + + # get the coordinate axes' directions + rotation = self.coordinate_rotation_radians + xaxis_vectx = np.cos(rotation) + xaxis_vecty = np.sin(rotation) + yaxis_vectx = np.cos(rotation + np.pi / 2) + yaxis_vecty = np.sin(rotation + np.pi / 2) + + # make the coordinate axes + ax_legend.arrow( + x=0, + y=0, + dx=xaxis_vecty, + dy=xaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.arrow( + x=0, + y=0, + dx=yaxis_vecty, + dy=yaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.text( + x=xaxis_vecty * 1.16, + y=xaxis_vectx * 1.16, + s="x", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=yaxis_vecty * 1.16, + y=yaxis_vectx * 1.16, + s="y", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # make the g-vectors + if show_gvects: + # get the g-vectors directions + g1q = np.array(self.g1) + g2q = np.array(self.g2) + g1norm = np.linalg.norm(g1q) + g2norm = np.linalg.norm(g2q) + g1q /= g1norm + g2q /= g2norm + # set the lengths + g_ratio = g2norm / g1norm + if g_ratio > 1: + g1q /= g_ratio + else: + g2q *= g_ratio + g1_x, g1_y = g1q + g2_x, g2_y = g2q + + # draw the g vectors + ax_legend.arrow( + x=0, + y=0, + dx=g1_y * scale_gvects, + dy=g1_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.arrow( + x=0, + y=0, + dx=g2_y * scale_gvects, + dy=g2_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.text( + x=g1_y * scale_gvects * 1.2, + y=g1_x * scale_gvects * 1.2, + s=r"$g_1$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=g2_y * scale_gvects * 1.2, + y=g2_x * scale_gvects * 1.2, + s=r"$g_2$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # find center and extent + xmin = np.min([0, 0, xaxis_vectx, yaxis_vectx]) + xmax = np.max([0, 0, xaxis_vectx, yaxis_vectx]) + ymin = np.min([0, 0, xaxis_vecty, yaxis_vecty]) + ymax = np.max([0, 0, xaxis_vecty, yaxis_vecty]) + if show_gvects: + xmin = np.min([xmin, g1_x, g2_x]) + xmax = np.max([xmax, g1_x, g2_x]) + ymin = np.min([ymin, g1_y, g2_y]) + ymax = np.max([ymax, g1_y, g2_y]) + x0 = np.mean([xmin, xmax]) + y0 = np.mean([ymin, ymax]) + xL = (xmax - x0) * legend_camera_length + yL = (ymax - y0) * legend_camera_length + + # set the extent and aspect + ax_legend.set_xlim([y0 - yL, y0 + yL]) + ax_legend.set_ylim([x0 - xL, x0 + xL]) + ax_legend.invert_yaxis() + ax_legend.set_aspect("equal") + ax_legend.axis("off") + + # show/return + if not returnfig: + plt.show() + return + else: + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs + + def show_reference_directions( + self, + im_uncal=None, + im_cal=None, + color_axes="linen", + color_gvects="r", + origin_uncal=None, + origin_cal=None, + camera_length=1.8, + visp_uncal={"scaling": "log"}, + visp_cal={"scaling": "log"}, + layout="horizontal", + titlesize=16, + size_labels=14, + figsize=None, + returnfig=False, + ): + """ + Show the reference coordinate system used to compute the strain + overlaid over calibrated and uncalibrated diffraction space images. + + The diffraction images used can be specificied with the `im_uncal` + and `im_cal` arguments, and default to the uncalibrated and calibrated + Bragg vector maps. The `rotate_cal` argument causes the `im_cal` array + to be rotated by -QR rotation from the calibration metadata, so that an + uncalibrated image (like a raw diffraction image or mean or max + diffraction pattern) can be passed to the `im_cal` argument. + + Parameters + ---------- + im_uncal : 2d array or None + Uncalibrated diffraction space image to dispay; defaults to + the maximal diffraction image. + im_cal : 2d array or None + Calibrated diffraction space image to display; defaults to + the calibrated Bragg vector map. + color_axes : color + The color of the overlaid coordinate axes + color_gvects : color + The color of the g-vectors + origin_uncal : 2-tuple or None + Where to place the origin of the coordinate system overlaid on + the uncalibrated diffraction image. Defaults to the mean origin + from the calibration metadata. + origin_cal : 2-tuple or None + Where to place the origin of the coordinate system overlaid on + the calibrated diffraction image. Defaults to the mean origin + from the calibration metadata. + camera_length : number + Determines the length of the overlaid coordinate axes; a smaller + number yields larger axes. + visp_uncal : dict + Visualization parameters for the uncalibrated diffraction image. + visp_cal : dict + Visualization parameters for the calibrated diffraction image. + layout : str; either "horizontal" or "vertical" + Determines the layout of the visualization. + titlesize : number + The size of the plot titles + size_labels : number + The size of the axis labels + figsize : length 2 tuple of numbers or None + Size of the figure + returnfig : bool + Toggles returning the figure + """ + # Set up the figure + assert layout in ("horizontal", "vertical") + + # Set the figsize + if figsize is None: + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) + if layout == "horizontal": + figsize = (10 * ratio, 8 / ratio) + else: + figsize = (8 * ratio, 12 / ratio) + + # Create the figure + if layout == "horizontal": + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + else: + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize) + + # prepare images + if im_uncal is None: + im_uncal = self.braggvectors.histogram(mode="raw") + if im_cal is None: + im_cal = self.braggvectors.histogram(mode="cal") + + # display images + show(im_cal, figax=(fig, ax1), **visp_cal) + show(im_uncal, figax=(fig, ax2), **visp_uncal) + ax1.set_title("Calibrated", size=titlesize) + ax2.set_title("Uncalibrated", size=titlesize) + + # Get the coordinate axes + + # get the directions + + # calibrated + rotation = self.coordinate_rotation_radians + xaxis_cal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_cal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) + + # uncalibrated + QRrot = self.calibration.get_QR_rotation() + rotation = np.sum([self.coordinate_rotation_radians, -QRrot]) + xaxis_uncal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_uncal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) + # inversion + if self.calibration.get_QR_flip(): + xaxis_uncal = np.array([xaxis_uncal[1], xaxis_uncal[0]]) + yaxis_uncal = np.array([yaxis_uncal[1], yaxis_uncal[0]]) + + # set the lengths + Lmean = np.mean([im_cal.shape[0], im_cal.shape[1]]) / 2 + xaxis_cal *= Lmean / camera_length + yaxis_cal *= Lmean / camera_length + xaxis_uncal *= Lmean / camera_length + yaxis_uncal *= Lmean / camera_length + + # Get the g-vectors + + # calibrated + g1_cal = np.array(self.g1) + g2_cal = np.array(self.g2) + + # uncalibrated + R = np.array([[np.cos(QRrot), -np.sin(QRrot)], [np.sin(QRrot), np.cos(QRrot)]]) + g1_uncal = np.matmul(g1_cal, R) + g2_uncal = np.matmul(g2_cal, R) + # inversion + if self.calibration.get_QR_flip(): + g1_uncal = np.array([g1_uncal[1], g1_uncal[0]]) + g2_uncal = np.array([g2_uncal[1], g2_uncal[0]]) + + # Set origin positions + if origin_uncal is None: + origin_uncal = self.calibration.get_origin_mean() + if origin_cal is None: + origin_cal = self.calibration.get_origin_mean() + + # Draw calibrated coordinate axes + coordax_width = Lmean * 2 / 100 + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=xaxis_cal[1], + dy=xaxis_cal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=yaxis_cal[1], + dy=yaxis_cal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax1.text( + x=origin_cal[1] + xaxis_cal[1] * 1.16, + y=origin_cal[0] + xaxis_cal[0] * 1.16, + s="x", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax1.text( + x=origin_cal[1] + yaxis_cal[1] * 1.16, + y=origin_cal[0] + yaxis_cal[0] * 1.16, + s="y", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # Draw uncalibrated coordinate axes + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=xaxis_uncal[1], + dy=xaxis_uncal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=yaxis_uncal[1], + dy=yaxis_uncal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax2.text( + x=origin_uncal[1] + xaxis_uncal[1] * 1.16, + y=origin_uncal[0] + xaxis_uncal[0] * 1.16, + s="x", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax2.text( + x=origin_uncal[1] + yaxis_uncal[1] * 1.16, + y=origin_uncal[0] + yaxis_uncal[0] * 1.16, + s="y", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # Draw the calibrated g-vectors + + # draw the g vectors + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=g1_cal[1], + dy=g1_cal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=g2_cal[1], + dy=g2_cal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax1.text( + x=origin_cal[1] + g1_cal[1] * 1.16, + y=origin_cal[0] + g1_cal[0] * 1.16, + s=r"$g_1$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax1.text( + x=origin_cal[1] + g2_cal[1] * 1.16, + y=origin_cal[0] + g2_cal[0] * 1.16, + s=r"$g_2$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # Draw the uncalibrated g-vectors + + # draw the g vectors + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=g1_uncal[1], + dy=g1_uncal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=g2_uncal[1], + dy=g2_uncal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax2.text( + x=origin_uncal[1] + g1_uncal[1] * 1.16, + y=origin_uncal[0] + g1_uncal[0] * 1.16, + s=r"$g_1$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax2.text( + x=origin_uncal[1] + g2_uncal[1] * 1.16, + y=origin_uncal[0] + g2_uncal[0] * 1.16, + s=r"$g_2$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # show/return + if not returnfig: + plt.show() + return + else: + return fig, (ax1, ax2) + + def show_lattice_vectors( + ar, + x0, + y0, + g1, + g2, + color="r", + width=1, + labelsize=20, + labelcolor="w", + returnfig=False, + **kwargs, + ): + """ + Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). + g1 and g2 are 2-tuples (gx,gy). + """ + fig, ax = show(ar, returnfig=True, **kwargs) + + # Add vectors + dg1 = { + "x0": x0, + "y0": y0, + "vx": g1[0], + "vy": g1[1], + "width": width, + "color": color, + "label": r"$g_1$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + dg2 = { + "x0": x0, + "y0": y0, + "vx": g2[0], + "vy": g2[1], + "width": width, + "color": color, + "label": r"$g_2$", + "labelsize": labelsize, + "labelcolor": labelcolor, + } + add_vector(ax, dg1) + add_vector(ax, dg2) + + if returnfig: + return fig, ax + else: + plt.show() + return + + def show_bragg_indexing( + self, + ar, + bragg_directions, + voffset=5, + hoffset=0, + color="w", + size=20, + points=True, + pointcolor="r", + pointsize=50, + figax=None, + returnfig=False, + **kwargs, + ): + """ + Shows an array with an overlay describing the Bragg directions + + Parameters + ---------- + ar : np.ndarray + The display image + bragg_directions : PointList + The Bragg scattering directions. Must have coordinates + 'qx','qy','h', and 'k'. Optionally may also have 'l'. + """ + assert isinstance(bragg_directions, PointList) + for k in ("qx", "qy", "h", "k"): + assert k in bragg_directions.data.dtype.fields + + if figax is None: + fig, ax = show(ar, returnfig=True, **kwargs) + else: + fig = figax[0] + ax = figax[1] + show(ar, figax=figax, **kwargs) + + d = { + "bragg_directions": bragg_directions, + "voffset": voffset, + "hoffset": hoffset, + "color": color, + "size": size, + "points": points, + "pointsize": pointsize, + "pointcolor": pointcolor, + } + add_bragg_index_labels(ax, d) + + if returnfig: + return fig, ax + else: + return + + def copy(self, name=None): + name = name if name is not None else self.name + "_copy" + strainmap_copy = StrainMap(self.braggvectors) + for attr in ( + "g", + "g0", + "g1", + "g2", + "calstate", + "bragg_directions", + "bragg_vectors_indexed", + "g1g2_map", + "strainmap_g1g2", + "strainmap_rotated", + "mask", + ): + if hasattr(self, attr): + setattr(strainmap_copy, attr, getattr(self, attr)) + + for k in self.metadata.keys(): + strainmap_copy.metadata = self.metadata[k].copy() + return strainmap_copy + + # TODO IO methods + + # read + @classmethod + def _get_constructor_args(cls, group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = RealSlice._get_constructor_args(group) + args = { + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], + } + return args diff --git a/py4DSTEM/process/utils/cross_correlate.py b/py4DSTEM/process/utils/cross_correlate.py index f9aac1312..50de91e33 100644 --- a/py4DSTEM/process/utils/cross_correlate.py +++ b/py4DSTEM/process/utils/cross_correlate.py @@ -6,8 +6,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def get_cross_correlation(ar, template, corrPower=1, _returnval="real"): diff --git a/py4DSTEM/process/utils/multicorr.py b/py4DSTEM/process/utils/multicorr.py index 8523c8e62..bc07390bb 100644 --- a/py4DSTEM/process/utils/multicorr.py +++ b/py4DSTEM/process/utils/multicorr.py @@ -15,8 +15,8 @@ try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def upsampled_correlation(imageCorr, upsampleFactor, xyShift, device="cpu"): diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 03d3d07a0..4ef2e1d8a 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -24,8 +24,8 @@ def clear_output(wait=True): try: import cupy as cp -except ImportError: - cp = None +except ModuleNotFoundError: + cp = np def radial_reduction(ar, x0, y0, binsize=1, fn=np.mean, coords=None): diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 224f1fb74..103751b29 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.4" +__version__ = "0.14.9" diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index 996bb89b3..32baff443 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -408,7 +408,7 @@ def add_ellipses(ax, d): (cent[1], cent[0]), 2 * _b, 2 * _a, - -np.degrees(_theta), + angle=-np.degrees(_theta), color=col, fill=f, alpha=_alpha, @@ -832,7 +832,14 @@ def add_scalebar(ax, d): labelpos_y = y0 # Add line - ax.plot((yi, yf), (xi, xf), lw=width, color=color, alpha=alpha) + ax.plot( + (yi, yf), + (xi, xf), + color=color, + alpha=alpha, + lw=width, + solid_capstyle="butt", + ) # Add label if label: diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index fb99de5ae..00309ec36 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -8,6 +8,7 @@ from matplotlib.axes import Axes from matplotlib.colors import is_color_like from matplotlib.figure import Figure +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM.data import Calibration, DiffractionSlice, RealSlice from py4DSTEM.visualize.overlay import ( add_annuli, @@ -25,7 +26,7 @@ def show( ar, - figsize=(8, 8), + figsize=(5, 5), cmap="gray", scaling="none", intensity_range="ordered", @@ -75,6 +76,7 @@ def show( theta=None, title=None, show_fft=False, + show_cbar=False, **kwargs ): """ @@ -302,7 +304,8 @@ def show( does not add a scalebar. If a dict is passed, it is propagated to the add_scalebar function which will attempt to use it to overlay a scalebar. If True, uses calibraiton or pixelsize/pixelunits for scalebar. If False, no scalebar is added. - show_fft (Bool): if True, plots 2D-fft of array + show_fft (bool): if True, plots 2D-fft of array + show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() Returns: @@ -366,7 +369,9 @@ def show( from py4DSTEM.visualize import show if show_fft: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) for a0 in range(num_images): im = show( ar[a0], @@ -605,6 +610,10 @@ def show( ax.matshow( mask_display, cmap=cmap, alpha=mask_alpha, vmin=vmin, vmax=vmax ) + if show_cbar: + ax_divider = make_axes_locatable(ax) + c_axis = ax_divider.append_axes("right", size="7%") + fig.colorbar(cax, cax=c_axis) # ...or, plot its histogram else: hist, bin_edges = np.histogram( diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 423146892..acacb6184 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -1,6 +1,5 @@ from matplotlib import cm, colors as mcolors, pyplot as plt import numpy as np -from matplotlib.colors import hsv_to_rgb from matplotlib.patches import Wedge from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.spatial import Voronoi @@ -17,6 +16,7 @@ ) from py4DSTEM.visualize.vis_grid import show_image_grid from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR +from colorspacious import cspace_convert def show_elliptical_fit( @@ -404,308 +404,6 @@ def show_class_BPs_grid( return fig, axs -def show_strain( - strainmap, - vrange_exx, - vrange_theta, - vrange_exy=None, - vrange_eyy=None, - flip_theta=False, - bkgrd=True, - show_cbars=("exx", "eyy", "exy", "theta"), - bordercolor="k", - borderwidth=1, - titlesize=24, - ticklabelsize=16, - ticknumber=5, - unitlabelsize=24, - show_axes=True, - axes_x0=0, - axes_y0=0, - xaxis_x=1, - xaxis_y=0, - axes_length=10, - axes_width=1, - axes_color="r", - xaxis_space="Q", - labelaxes=True, - QR_rotation=0, - axes_labelsize=12, - axes_labelcolor="r", - axes_plots=("exx"), - cmap="RdBu_r", - layout=0, - figsize=(12, 12), - returnfig=False, -): - """ - Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') - - Args: - strainmap (RealSlice): - vrange_exx (length 2 list or tuple): - vrange_theta (length 2 list or tuple): - vrange_exy (length 2 list or tuple): - vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle - bkgrd (bool): - show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a - tuple containing any, all, or none of ('exx','eyy','exy','theta'). - bordercolor (color): - borderwidth (number): - titlesize (number): - ticklabelsize (number): - ticknumber (number): number of ticks on colorbars - unitlabelsize (number): - show_axes (bool): - axes_x0 (number): - axes_y0 (number): - xaxis_x (number): - xaxis_y (number): - axes_length (number): - axes_width (number): - axes_color (color): - xaxis_space (string): must be 'Q' or 'R' - labelaxes (bool): - QR_rotation (number): - axes_labelsize (number): - axes_labelcolor (color): - axes_plots (tuple of strings): controls if coordinate axes showing the - orientation of the strain matrices are overlaid over any of the plots. - Must be a tuple of strings containing any, all, or none of - ('exx','eyy','exy','theta'). - cmap (colormap): - layout=0 (int): determines the layout of the grid which the strain components - will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). - figsize (length 2 tuple of numbers): - returnfig (bool): - """ - # Lookup table for different layouts - assert layout in (0, 1, 2) - layout_lookup = { - 0: ["left", "right", "left", "right"], - 1: ["bottom", "bottom", "bottom", "bottom"], - 2: ["right", "right", "right", "right"], - } - layout_p = layout_lookup[layout] - - # Contrast limits - if vrange_exy is None: - vrange_exy = vrange_exx - if vrange_eyy is None: - vrange_eyy = vrange_exx - for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): - assert len(vrange) == 2, "vranges must have length 2" - vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 - vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 - vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 - # theta is plotted in units of degrees - vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( - 180.0 / np.pi - ) - - # Get images - e_xx = np.ma.array( - strainmap.get_slice("e_xx").data, mask=strainmap.get_slice("mask").data == False - ) - e_yy = np.ma.array( - strainmap.get_slice("e_yy").data, mask=strainmap.get_slice("mask").data == False - ) - e_xy = np.ma.array( - strainmap.get_slice("e_xy").data, mask=strainmap.get_slice("mask").data == False - ) - theta = np.ma.array( - strainmap.get_slice("theta").data, - mask=strainmap.get_slice("mask").data == False, - ) - if flip_theta == True: - theta = -theta - - # Plot - if layout == 0: - fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) - elif layout == 1: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) - else: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) - cax11 = show( - e_xx, - figax=(fig, ax11), - vmin=vmin_exx, - vmax=vmax_exx, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax12 = show( - e_yy, - figax=(fig, ax12), - vmin=vmin_eyy, - vmax=vmax_eyy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax21 = show( - e_xy, - figax=(fig, ax21), - vmin=vmin_exy, - vmax=vmax_exy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax22 = show( - theta, - figax=(fig, ax22), - vmin=vmin_theta, - vmax=vmax_theta, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) - ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) - ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) - ax22.set_title(r"$\theta$", size=titlesize) - - # Add black background - if bkgrd: - mask = np.ma.masked_where( - strainmap.get_slice("mask").data.astype(bool), - np.zeros_like(strainmap.get_slice("mask").data), - ) - ax11.matshow(mask, cmap="gray") - ax12.matshow(mask, cmap="gray") - ax21.matshow(mask, cmap="gray") - ax22.matshow(mask, cmap="gray") - - # Colorbars - show_cbars = np.array( - [ - "exx" in show_cbars, - "eyy" in show_cbars, - "exy" in show_cbars, - "theta" in show_cbars, - ] - ) - if np.any(show_cbars): - divider11 = make_axes_locatable(ax11) - divider12 = make_axes_locatable(ax12) - divider21 = make_axes_locatable(ax21) - divider22 = make_axes_locatable(ax22) - cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) - cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) - cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) - cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) - for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( - range(4), - show_cbars, - (cax11, cax12, cax21, cax22), - (cbax11, cbax12, cbax21, cbax22), - (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), - (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), - (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), - ("% ", " %", "% ", r" $^\circ$"), - ): - if show_cbar: - ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) - if ind < 3: - ticklabels = np.round( - np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), - decimals=2, - ).astype(str) - else: - ticklabels = np.round( - np.linspace( - (180 / np.pi) * vmin, - (180 / np.pi) * vmax, - ticknumber, - endpoint=True, - ), - decimals=2, - ).astype(str) - - if tickside in ("left", "right"): - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="vertical" - ) - cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) - cbax.yaxis.set_ticks_position(tickside) - cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) - cbax.yaxis.set_label_position(tickside) - else: - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="horizontal" - ) - cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) - cbax.xaxis.set_ticks_position(tickside) - cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) - cbax.xaxis.set_label_position(tickside) - else: - cbax.axis("off") - - # Add coordinate axes - if show_axes: - assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array( - [ - "exx" in axes_plots, - "eyy" in axes_plots, - "exy" in axes_plots, - "theta" in axes_plots, - ] - ) - for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): - if _show: - if xaxis_space == "R": - ax_addaxes( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - else: - ax_addaxes_QtoR( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - QR_rotation, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - - # Add borders - if bordercolor is not None: - for ax in (ax11, ax12, ax21, ax22): - for s in ["bottom", "top", "left", "right"]: - ax.spines[s].set_color(bordercolor) - ax.spines[s].set_linewidth(borderwidth) - ax.set_xticks([]) - ax.set_yticks([]) - - if not returnfig: - plt.show() - return - else: - axs = ((ax11, ax12), (ax21, ax22)) - return fig, axs - - def show_pointlabels( ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs ): @@ -937,15 +635,20 @@ def show_selected_dps( ) -def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float) : power to raise amplitude to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.5) """ amp = np.abs(complex_data) + phase = np.angle(complex_data) + + if power is not None: + amp = amp**power + if np.isclose(np.max(amp), np.min(amp)): if vmin is None: vmin = 0 @@ -966,36 +669,40 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) + amp = ((amp - vmin) / vmax).clip(1e-16, 1) + + J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff + C = np.minimum(chroma_boost * 98 * J / 123, 110) + h = np.rad2deg(phase) + 180 - phase = np.angle(complex_data) + np.deg2rad(hue_start) - amp /= np.max(amp) - rgb = np.zeros(phase.shape + (3,)) - rgb[..., 0] = 0.5 * (np.sin(phase) + 1) * amp - rgb[..., 1] = 0.5 * (np.sin(phase + np.pi / 2) + 1) * amp - rgb[..., 2] = 0.5 * (-np.sin(phase) + 1) * amp + JCh = np.stack((J, C, h), axis=-1) + rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) - return 1 - rgb if invert else rgb + return rgb -def add_colorbar_arg(cax, vmin=None, vmax=None, hue_start=0, invert=False): +def add_colorbar_arg(cax, chroma_boost=1, c=49, j=61.5): """ - cax : axis to add cbar too - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + cax : axis to add cbar to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.25) + c (float) : constant chroma value + j (float) : constant luminance value """ - z = np.exp(1j * np.linspace(-np.pi, np.pi, 200)) - rgb_vals = Complex2RGB(z, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert) + + h = np.linspace(0, 360, 256, endpoint=False) + J = np.full_like(h, j) + C = np.full_like(h, np.minimum(c * chroma_boost, 110)) + JCh = np.stack((J, C, h), axis=-1) + rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) - cb1 = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) - cb1.set_label("arg", rotation=0, ha="center", va="bottom") - cb1.ax.yaxis.set_label_coords(0.5, 1.01) - cb1.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) - cb1.set_ticklabels( + cb.set_label("arg", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) + cb.set_ticklabels( [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) @@ -1004,13 +711,13 @@ def show_complex( ar_complex, vmin=None, vmax=None, + power=None, + chroma_boost=1, cbar=True, scalebar=False, pixelunits="pixels", pixelsize=1, returnfig=False, - hue_start=0, - invert=False, **kwargs ): """ @@ -1023,13 +730,13 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels - cbar (bool, optional) : if True, include color wheel + power (float,optional) : power to raise amplitude to + chroma_boost (float) : boosts chroma for higher-contrast (~1-2.25) + cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - hue_start (float, optional) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -1044,7 +751,7 @@ def show_complex( if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for sublist in ar_complex for ar in sublist ] @@ -1053,7 +760,7 @@ def show_complex( else: rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for ar in ar_complex ] if len(rgb[0].shape) == 4: @@ -1064,7 +771,9 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, hue_start=hue_start, invert=invert) + rgb = Complex2RGB( + ar_complex, vmin, vmax, power=power, chroma_boost=chroma_boost + ) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -1115,37 +824,74 @@ def show_complex( add_scalebar(ax, scalebar) # add color bar - if cbar == True: - ax0 = fig.add_axes([1, 0.35, 0.3, 0.3]) - - # create wheel - AA = 1000 - kx = np.fft.fftshift(np.fft.fftfreq(AA)) - ky = np.fft.fftshift(np.fft.fftfreq(AA)) - kya, kxa = np.meshgrid(ky, kx) - kra = (kya**2 + kxa**2) ** 0.5 - ktheta = np.arctan2(-kxa, kya) - ktheta = kra * np.exp(1j * ktheta) - - # convert to hsv - rgb = Complex2RGB(ktheta, 0, 0.4, hue_start=hue_start, invert=invert) - ind = kra > 0.4 - rgb[ind] = [1, 1, 1] - - # plot - ax0.imshow(rgb) - - # add axes - ax0.axhline(AA / 2, 0, AA, color="k") - ax0.axvline(AA / 2, 0, AA, color="k") - ax0.axis("off") - - label_size = 16 - - ax0.text(AA, AA / 2, 1, fontsize=label_size) - ax0.text(AA / 2, 0, "i", fontsize=label_size) - ax0.text(AA / 2, AA, "-i", fontsize=label_size) - ax0.text(0, AA / 2, -1, fontsize=label_size) - - if returnfig == True: + if cbar: + if is_grid: + for ax_flat in ax.flatten(): + divider = make_axes_locatable(ax_flat) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + else: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + fig.tight_layout() + + if returnfig: return fig, ax + + +def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False): + """ + Utility function for calculating min and max values for plotting array + based on distribution of pixel values + + Parameters + ---------- + array: np.array + array to be plotted + vmin: float + lower fraction cut off of pixel values + vmax: float + upper fraction cut off of pixel values + normalize: bool + if True, rescales from 0 to 1 + + Returns + ---------- + scaled_array: np.array + array clipped outside vmin and vmax + vmin: float + lower value to be plotted + vmax: float + upper value to be plotted + """ + + if vmin is None: + vmin = 0.02 + if vmax is None: + vmax = 0.98 + + vals = np.sort(array.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + + scaled_array = array.copy() + scaled_array = np.where(scaled_array < vmin, vmin, scaled_array) + scaled_array = np.where(scaled_array > vmax, vmax, scaled_array) + + if normalize: + scaled_array -= scaled_array.min() + scaled_array /= scaled_array.max() + vmin = 0 + vmax = 1 + + return scaled_array, vmin, vmax diff --git a/setup.py b/setup.py index 069bf1600..fad5913ea 100644 --- a/setup.py +++ b/setup.py @@ -37,9 +37,11 @@ "gdown >= 4.7.1", "dask >= 2.3.0", "distributed >= 2.3.0", - "emdfile >= 0.0.13", + "emdfile >= 0.0.14", "mpire >= 2.7.1", "threadpoolctl >= 3.1.0", + "pylops >= 2.1.0", + "colorspacious >= 1.1.2", ], extras_require={ "ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"], @@ -57,8 +59,8 @@ package_data={ "py4DSTEM": [ "process/utils/scattering_factors.txt", - "process/diskdetection/multicorr_row_kernel.cu", - "process/diskdetection/multicorr_col_kernel.cu", + "braggvectors/multicorr_row_kernel.cu", + "braggvectors/multicorr_col_kernel.cu", ] }, )