diff --git a/.github/scripts/update_version.py b/.github/scripts/update_version.py index 2aaaa07af..635cf8268 100644 --- a/.github/scripts/update_version.py +++ b/.github/scripts/update_version.py @@ -8,7 +8,7 @@ lines = f.readlines() line_split = lines[0].split(".") -patch_number = line_split[2].split("'")[0] +patch_number = line_split[2].split("'")[0].split('"')[0] # Increment patch number patch_number = str(int(patch_number) + 1) + "'" diff --git a/.github/workflows/check_install_dev.yml b/.github/workflows/check_install_dev.yml index 82701d50d..4e9d16f77 100644 --- a/.github/workflows/check_install_dev.yml +++ b/.github/workflows/check_install_dev.yml @@ -16,7 +16,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest] architecture: [x86_64] - python-version: ["3.9", "3.10", "3.11",] + python-version: ["3.9", "3.10", "3.11", "3.12"] # include: # - python-version: "3.12.0-beta.4" # runs-on: ubuntu-latest diff --git a/.github/workflows/check_install_main.yml b/.github/workflows/check_install_main.yml index 2d1c8ed2a..a276cab17 100644 --- a/.github/workflows/check_install_main.yml +++ b/.github/workflows/check_install_main.yml @@ -16,7 +16,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest, windows-latest, macos-latest] architecture: [x86_64] - python-version: ["3.9", "3.10", "3.11",] + python-version: ["3.9", "3.10", "3.11", "3.12"] #include: # - python-version: "3.12.0-beta.4" # runs-on: ubuntu-latest diff --git a/.github/workflows/check_install_quick.yml b/.github/workflows/check_install_quick.yml index a36db34da..f83ee0b73 100644 --- a/.github/workflows/check_install_quick.yml +++ b/.github/workflows/check_install_quick.yml @@ -20,7 +20,7 @@ jobs: allow_failure: [false] runs-on: [ubuntu-latest] architecture: [x86_64] - python-version: ["3.10"] + python-version: ["3.9", "3.12"] # Currently no public runners available for this but this or arm64 should work next time # include: # - python-version: "3.10" diff --git a/.gitignore b/.gitignore index 6c008b0ff..24587a3b3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.swp *.ipynb_checkpoints* .vscode/ +pyrightconfig.json # Folders # .idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..a90d54b54 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: + # Using this mirror lets us use mypyc-compiled black, which is about 2x faster + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.11.0 + hooks: + - id: black + language_version: python3.9 + - repo: https://github.com/pycqa/flake8 + rev: '6.1.0' + hooks: + - id: flake8 \ No newline at end of file diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py index dcb6a861d..d5df63f5e 100644 --- a/py4DSTEM/__init__.py +++ b/py4DSTEM/__init__.py @@ -52,17 +52,25 @@ BraggVectorMap, ) -# strain -from py4DSTEM.process import StrainMap +from py4DSTEM.process import classification -# TODO - crystal -# TODO - ptycho -# TODO - others -# TODO - where -from py4DSTEM.process import ( - PolarDatacube, -) +# diffraction +from py4DSTEM.process.diffraction import Crystal, Orientation + + +# ptycho +from py4DSTEM.process import phase + + +# polar +from py4DSTEM.process.polar import PolarDatacube + + +# strain +from py4DSTEM.process.strain.strain import StrainMap + +from py4DSTEM.process import wholepatternfit ### more submodules diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index 267f81e5f..70a36dec1 100644 --- 
a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -1,12 +1,14 @@ # BraggVectors methods -import numpy as np -from scipy.ndimage import gaussian_filter -from warnings import warn import inspect +from warnings import warn -from emdfile import Array, Metadata, tqdmnd, _read_metadata +import matplotlib.pyplot as plt +import numpy as np +from emdfile import Array, Metadata, _read_metadata, tqdmnd +from py4DSTEM import show from py4DSTEM.datacube import VirtualImage +from scipy.ndimage import gaussian_filter class BraggVectorMethods: @@ -97,12 +99,12 @@ def histogram( # then scale by the sampling factor else: # get pixel calibration - if self.calstate["pixel"] == True: + if self.calstate["pixel"] is True: qpix = self.calibration.get_Q_pixel_size() qx /= qpix qy /= qpix # origin calibration - if self.calstate["center"] == True: + if self.calstate["center"] is True: origin = self.calibration.get_origin_mean() qx += origin[0] qy += origin[1] @@ -151,12 +153,12 @@ def histogram( ).reshape(Q_Nx, Q_Ny) # determine the resampled grid center and pixel size - if mode == "cal" and self.calstate["center"] == True: + if mode == "cal" and self.calstate["center"] is True: x0 = sampling * origin[0] y0 = sampling * origin[1] else: x0, y0 = 0, 0 - if mode == "cal" and self.calstate["pixel"] == True: + if mode == "cal" and self.calstate["pixel"] is True: pixelsize = qpix / sampling else: pixelsize = 1 / sampling @@ -518,6 +520,7 @@ def fit_origin( mask_check_data=True, plot=True, plot_range=None, + cmap="RdBu_r", returncalc=True, **kwargs, ): @@ -537,6 +540,7 @@ def fit_origin( mask_check_data (bool): Get mask from origin measurements equal to zero. (TODO - replace) plot (bool, optional): plot results plot_range (float): min and max color range for plot (pixels) + cmap (colormap): plotting colormap Returns: (variable): Return value depends on returnfitp. 
If ``returnfitp==False`` @@ -552,84 +556,113 @@ def fit_origin( from py4DSTEM.process.calibration import fit_origin if mask_check_data is True: - # TODO - replace this bad hack for the mask for the origin fit - mask = np.logical_not(q_meas[0] == 0) - qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin( - tuple(q_meas), - mask=mask, - ) - else: - qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(tuple(q_meas)) + data_mask = np.logical_not(q_meas[0] == 0) + if mask is None: + mask = data_mask + else: + mask = np.logical_and(mask, data_mask) + + qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin( + tuple(q_meas), + mask=mask, + fitfunction=fitfunction, + robust=robust, + robust_steps=robust_steps, + robust_thresh=robust_thresh, + ) - # try to add to calibration + # try to update calibration metadata try: - self.calibration.set_origin([qx0_fit, qy0_fit]) + self.calibration.set_origin((qx0_fit, qy0_fit)) + self.setcal() except AttributeError: warn( "No calibration found on this datacube - fit values are not being stored" ) pass - if plot: - from py4DSTEM.visualize import show_image_grid - if mask is None: - qx0_meas, qy0_meas = q_meas - qx0_res_plot = qx0_residuals - qy0_res_plot = qy0_residuals - else: - qx0_meas = np.ma.masked_array(q_meas[0], mask=np.logical_not(mask)) - qy0_meas = np.ma.masked_array(q_meas[1], mask=np.logical_not(mask)) - qx0_res_plot = np.ma.masked_array( - qx0_residuals, mask=np.logical_not(mask) - ) - qy0_res_plot = np.ma.masked_array( - qy0_residuals, mask=np.logical_not(mask) - ) - qx0_mean = np.mean(qx0_fit) - qy0_mean = np.mean(qy0_fit) - - if plot_range is None: - plot_range = 2 * np.max(qx0_fit - qx0_mean) - - cmap = kwargs.get("cmap", "RdBu_r") - kwargs.pop("cmap", None) - axsize = kwargs.get("axsize", (6, 2)) - kwargs.pop("axsize", None) - - show_image_grid( - lambda i: [ - qx0_meas - qx0_mean, - qx0_fit - qx0_mean, - qx0_res_plot, - qy0_meas - qy0_mean, - qy0_fit - qy0_mean, - qy0_res_plot, - ][i], - H=2, - W=3, + # show + if plot: + self.show_origin_fit( + q_meas[0], + q_meas[1], + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask=mask, + plot_range=plot_range, cmap=cmap, - axsize=axsize, - title=[ - "measured origin, x", - "fitorigin, x", - "residuals, x", - "measured origin, y", - "fitorigin, y", - "residuals, y", - ], - vmin=-1 * plot_range, - vmax=1 * plot_range, - intensity_range="absolute", **kwargs, ) - # update calibration metadata - self.calibration.set_origin((qx0_fit, qy0_fit)) - self.setcal() - + # return if returncalc: return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals + def show_origin_fit( + self, + qx0_meas, + qy0_meas, + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask=None, + plot_range=None, + cmap="RdBu_r", + **kwargs, + ): + # apply mask + if mask is not None: + qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask)) + qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask)) + qx0_residuals = np.ma.masked_array(qx0_residuals, mask=np.logical_not(mask)) + qy0_residuals = np.ma.masked_array(qy0_residuals, mask=np.logical_not(mask)) + qx0_mean = np.mean(qx0_fit) + qy0_mean = np.mean(qy0_fit) + + # set range + if plot_range is None: + plot_range = max( + ( + 1.5 * np.max(np.abs(qx0_fit - qx0_mean)), + 1.5 * np.max(np.abs(qy0_fit - qy0_mean)), + ) + ) + + # set figsize + imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0]) + axsize = (3 * imsize_ratio, 3 / imsize_ratio) + axsize = kwargs.pop("axsize", axsize) + + # plot + fig, ax = show( + [ + [qx0_meas -
qx0_mean, qx0_fit - qx0_mean, qx0_residuals], + [qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals], + ], + cmap=cmap, + axsize=axsize, + title=[ + "measured origin, x", + "fit origin, x", + "residuals, x", + "measured origin, y", + "fit origin, y", + "residuals, y", + ], + vmin=-1 * plot_range, + vmax=1 * plot_range, + intensity_range="absolute", + show_cbar=True, + returnfig=True, + **kwargs, + ) + plt.tight_layout() + + return + def fit_p_ellipse( self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs ): @@ -765,6 +798,21 @@ def mask_in_R(self, mask, update_inplace=False, returncalc=True): else: return + def to_strainmap(self, name: str = None): + """ + Generate a StrainMap object from the BraggVectors + equivalent to py4DSTEM.StrainMap(braggvectors=braggvectors) + + Args: + name (str, optional): The name of the strainmap. Defaults to None which reverts to default name 'strainmap'. + + Returns: + py4DSTEM.StrainMap: A py4DSTEM StrainMap object generated from the BraggVectors + """ + from py4DSTEM.process.strain import StrainMap + + return StrainMap(self, name) if name else StrainMap(self) + ######### END BraggVectorMethods CLASS ######## diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index 45c08b9c9..e81eeb62f 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -200,7 +200,7 @@ def setcal( if pixel is None: pixel = False if c.get_Q_pixel_size() == 1 else True if rotate is None: - rotate = False if c.get_QR_rotflip() is None else True + rotate = False if c.get_QR_rotation() is None else True # validate requested state if center: @@ -210,7 +210,7 @@ def setcal( if pixel: assert c.get_Q_pixel_size() is not None, "Requested calibration not found" if rotate: - assert c.get_QR_rotflip() is not None, "Requested calibration not found" + assert c.get_QR_rotation() is not None, "Requested calibration not found" # set the calibrations self._calstate = { @@ -272,6 +272,7 @@ def copy(self, name=None): braggvector_copy.set_raw_vectors(self._v_uncal.copy()) for k in self.metadata.keys(): braggvector_copy.metadata = self.metadata[k].copy() + braggvector_copy.setcal() return braggvector_copy # write @@ -384,7 +385,7 @@ def __getitem__(self, pos): def __repr__(self): space = " " * len(self.__class__.__name__) + " " string = f"{self.__class__.__name__}( " - string += f"Retrieves raw bragg vectors. Get vectors for scan position x,y with [x,y]. )" + string += "Retrieves raw bragg vectors. Get vectors for scan position x,y with [x,y]. )" return string @@ -478,15 +479,15 @@ def _transform( # Q/R rotation if rotate: - flip = cal.get_QR_flip() - theta = cal.get_QR_rotation_degrees() - assert flip is not None, "Requested calibration was not found!" + theta = cal.get_QR_rotation() assert theta is not None, "Requested calibration was not found!"
+ flip = cal.get_QR_flip() + flip = False if flip is None else flip # rotation matrix R = np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ) - # apply + # rotate and flip if flip: positions = R @ np.vstack((ans["qy"], ans["qx"])) else: diff --git a/py4DSTEM/braggvectors/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py index e23b10a15..99818b75e 100644 --- a/py4DSTEM/braggvectors/diskdetection.py +++ b/py4DSTEM/braggvectors/diskdetection.py @@ -231,10 +231,10 @@ def find_Bragg_disks( mode = "dc_ml" elif mode == "datacube": - if distributed is None and CUDA == False: + if distributed is None and CUDA is False: mode = "dc_CPU" - elif distributed is None and CUDA == True: - if CUDA_batched == False: + elif distributed is None and CUDA is True: + if CUDA_batched is False: mode = "dc_GPU" else: mode = "dc_GPU_batched" @@ -271,7 +271,7 @@ def find_Bragg_disks( kws["data_file"] = data_file kws["cluster_path"] = cluster_path # ML arguments - if ML == True: + if ML is True: kws["CUDA"] = CUDA kws["model_path"] = ml_model_path kws["num_attempts"] = ml_num_attempts @@ -759,7 +759,7 @@ def _parse_distributed(distributed): data_file = distributed["data_file"] if not isinstance(data_file, str): - er = f"Expected string for distributed key 'data_file', " + er = "Expected string for distributed key 'data_file', " er += f"received {type(data_file)}" raise TypeError(er) if len(data_file.strip()) == 0: @@ -773,7 +773,7 @@ def _parse_distributed(distributed): cluster_path = distributed["cluster_path"] if not isinstance(cluster_path, str): - er = f"distributed key 'cluster_path' must be of type str, " + er = "distributed key 'cluster_path' must be of type str, " er += f"received {type(cluster_path)}" raise TypeError(er) @@ -784,7 +784,7 @@ def _parse_distributed(distributed): er = f"distributed key 'cluster_path' does not exist: {cluster_path}" raise FileNotFoundError(er) elif not os.path.isdir(cluster_path): - er = f"distributed key 'cluster_path' is not a directory: " + er = "distributed key 'cluster_path' is not a directory: " er += f"{cluster_path}" raise NotADirectoryError(er) else: diff --git a/py4DSTEM/braggvectors/diskdetection_aiml.py b/py4DSTEM/braggvectors/diskdetection_aiml.py index 67df18074..4d23ebf6c 100644 --- a/py4DSTEM/braggvectors/diskdetection_aiml.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml.py @@ -528,7 +528,7 @@ def find_Bragg_disks_aiml_serial( ) ) - if global_threshold == True: + if global_threshold is True: from py4DSTEM.braggvectors import universal_threshold peaks = universal_threshold( @@ -559,7 +559,7 @@ def find_Bragg_disks_aiml( model_path=None, distributed=None, CUDA=True, - **kwargs + **kwargs, ): """ Finds the Bragg disks in all diffraction patterns of datacube by AI/ML method. 
This method diff --git a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py index c5f89b9fd..ffa6e891b 100644 --- a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py @@ -17,8 +17,9 @@ try: import cupy as cp -except ModuleNotFoundError: - raise ImportError("AIML CUDA Requires cupy") + from cupyx.scipy.ndimage import gaussian_filter +except (ModuleNotFoundError, ImportError) as e: + raise ImportError("AIML CUDA Requires cupy") from e try: import tensorflow as tf @@ -29,8 +30,6 @@ + "for more information" ) -from cupyx.scipy.ndimage import gaussian_filter - def find_Bragg_disks_aiml_CUDA( datacube, @@ -233,7 +232,7 @@ def find_Bragg_disks_aiml_CUDA( datacube.R_N, int(t2 / 3600), int(t2 / 60), int(t2 % 60) ) ) - if global_threshold == True: + if global_threshold is True: from py4DSTEM.braggvectors import universal_threshold peaks = universal_threshold( @@ -265,7 +264,7 @@ def _find_Bragg_disks_aiml_single_DP_CUDA( blocks=None, threads=None, model_path=None, - **kwargs + **kwargs, ): """ Finds the Bragg disks in single DP by AI/ML method. This method utilizes FCU-Net @@ -496,7 +495,7 @@ def get_maxima_2D_cp( if minSpacing > 0: deletemask = np.zeros(len(maxima), dtype=bool) for i in range(len(maxima)): - if deletemask[i] == False: + if deletemask[i] == False: # noqa: E712 tooClose = ( (maxima["x"] - maxima["x"][i]) ** 2 + (maxima["y"] - maxima["y"][i]) ** 2 diff --git a/py4DSTEM/braggvectors/diskdetection_cuda.py b/py4DSTEM/braggvectors/diskdetection_cuda.py index 4bbb7f488..870b8303d 100644 --- a/py4DSTEM/braggvectors/diskdetection_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_cuda.py @@ -4,6 +4,7 @@ """ import numpy as np + import cupy as cp from cupyx.scipy.ndimage import gaussian_filter import cupyx.scipy.fft as cufft @@ -482,7 +483,7 @@ def get_maxima_2D( if minSpacing > 0: deletemask = np.zeros(len(maxima), dtype=bool) for i in range(len(maxima)): - if deletemask[i] == False: + if deletemask[i] == False: # noqa: E712 tooClose = ( (maxima["x"] - maxima["x"][i]) ** 2 + (maxima["y"] - maxima["y"][i]) ** 2 diff --git a/py4DSTEM/braggvectors/diskdetection_parallel_new.py b/py4DSTEM/braggvectors/diskdetection_parallel_new.py index c15e41732..dccc0dd4b 100644 --- a/py4DSTEM/braggvectors/diskdetection_parallel_new.py +++ b/py4DSTEM/braggvectors/diskdetection_parallel_new.py @@ -100,7 +100,7 @@ def beta_parallel_disk_detection( close_dask_client=False, return_dask_client=True, *args, - **kwargs + **kwargs, ): """ This is not fully validated currently so may not work, please report bugs on the py4DSTEM github page. @@ -137,8 +137,8 @@ def beta_parallel_disk_detection( # ... dask stuff. # TODO add assert statements and other checks. 
Think about reordering opperations - if dask_client == None: - if dask_client_params != None: + if dask_client is None: + if dask_client_params is not None: dask.config.set( { "distributed.worker.memory.spill": False, @@ -201,7 +201,7 @@ def beta_parallel_disk_detection( dask_data = da.from_array( dataset.data, chunks=(1, 1, dataset.Q_Nx, dataset.Q_Ny) ) - elif dataset.stack_pointer != None: + elif dataset.stack_pointer is not None: dask_data = da.from_array( dataset.stack_pointer, chunks=(1, 1, dataset.Q_Nx, dataset.Q_Ny) ) @@ -225,7 +225,7 @@ def beta_parallel_disk_detection( probe_kernel_FT=dask_probe_delayed[0, 0], # probe_kernel_FT=delayed_probe_kernel_FT, *args, - **kwargs + **kwargs, ) # passing through args from earlier or should I use # corrPower=corrPower, # sigma=sigma_gaussianFilter, @@ -261,9 +261,9 @@ def beta_parallel_disk_detection( if close_dask_client: dask_client.close() return peaks - elif close_dask_client == False and return_dask_client == True: + elif close_dask_client is False and return_dask_client is True: return peaks, dask_client - elif close_dask_client and return_dask_client == False: + elif close_dask_client and return_dask_client is False: return peaks else: print( diff --git a/py4DSTEM/braggvectors/threshold.py b/py4DSTEM/braggvectors/threshold.py index c13b0a665..7e19404b1 100644 --- a/py4DSTEM/braggvectors/threshold.py +++ b/py4DSTEM/braggvectors/threshold.py @@ -1,7 +1,6 @@ # Bragg peaks thresholding fns import numpy as np - from emdfile import tqdmnd, PointListArray @@ -52,7 +51,7 @@ def threshold_Braggpeaks( r2 = minPeakSpacing**2 deletemask = np.zeros(pointlist.length, dtype=bool) for i in range(pointlist.length): - if deletemask[i] == False: + if deletemask[i] == False: # noqa: E712 tooClose = ( (pointlist.data["qx"] - pointlist.data["qx"][i]) ** 2 + (pointlist.data["qy"] - pointlist.data["qy"][i]) ** 2 @@ -160,7 +159,7 @@ def universal_threshold( r2 = minPeakSpacing**2 deletemask = np.zeros(pointlist.length, dtype=bool) for i in range(pointlist.length): - if deletemask[i] == False: + if deletemask[i] == False: # noqa: E712 tooClose = ( (pointlist.data["qx"] - pointlist.data["qx"][i]) ** 2 + (pointlist.data["qy"] - pointlist.data["qy"][i]) ** 2 diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py index a31f098d4..408f977cc 100644 --- a/py4DSTEM/data/calibration.py +++ b/py4DSTEM/data/calibration.py @@ -205,6 +205,7 @@ def __init__( self["R_pixel_size"] = 1 self["Q_pixel_units"] = "pixels" self["R_pixel_units"] = "pixels" + self["QR_flip"] = False # EMD root property @property @@ -233,7 +234,7 @@ def attach(self, data): """ from py4DSTEM.data import Data - assert isinstance(data, Data), f"data must be a Data instance" + assert isinstance(data, Data), "data must be a Data instance" self.root.attach(data) # Register for auto-calibration @@ -315,7 +316,7 @@ def set_Q_pixel_units(self, x): "pixels", "A^-1", "mrad", - ), f"Q pixel units must be 'A^-1', 'mrad' or 'pixels'." + ), "Q pixel units must be 'A^-1', 'mrad' or 'pixels'." 
self._params["Q_pixel_units"] = x def get_Q_pixel_units(self): @@ -666,8 +667,17 @@ def ellipse(self, x): # Q/R-space rotation and flip + @call_calibrate + def set_QR_rotation(self, x): + self._params["QR_rotation"] = x + self._params["QR_rotation_degrees"] = np.degrees(x) + + def get_QR_rotation(self): + return self._get_value("QR_rotation") + @call_calibrate def set_QR_rotation_degrees(self, x): + self._params["QR_rotation"] = np.radians(x) self._params["QR_rotation_degrees"] = x def get_QR_rotation_degrees(self): @@ -689,10 +699,31 @@ def set_QR_rotflip(self, rot_flip): flip (bool): True indicates a Q/R axes flip """ rot, flip = rot_flip + self._params["QR_rotation"] = rot + self._params["QR_rotation_degrees"] = np.degrees(rot) + self._params["QR_flip"] = flip + + @call_calibrate + def set_QR_rotflip_degrees(self, rot_flip): + """ + Args: + rot_flip (tuple), (rot, flip) where: + rot (number): rotation in degrees + flip (bool): True indicates a Q/R axes flip + """ + rot, flip = rot_flip + self._params["QR_rotation"] = np.radians(rot) self._params["QR_rotation_degrees"] = rot self._params["QR_flip"] = flip def get_QR_rotflip(self): + rot = self.get_QR_rotation() + flip = self.get_QR_flip() + if rot is None or flip is None: + return None + return (rot, flip) + + def get_QR_rotflip_degrees(self): rot = self.get_QR_rotation_degrees() flip = self.get_QR_flip() if rot is None or flip is None: diff --git a/py4DSTEM/datacube/virtualdiffraction.py b/py4DSTEM/datacube/virtualdiffraction.py index 65665728d..23b151d58 100644 --- a/py4DSTEM/datacube/virtualdiffraction.py +++ b/py4DSTEM/datacube/virtualdiffraction.py @@ -114,7 +114,7 @@ def get_virtual_diffraction( # Calculate # ...with no center shifting - if shift_center == False: + if shift_center is False: # ...for the whole pattern if mask is None: if method == "mean": diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py index 50a297914..87aeae8b1 100644 --- a/py4DSTEM/datacube/virtualimage.py +++ b/py4DSTEM/datacube/virtualimage.py @@ -197,13 +197,13 @@ def get_virtual_image( # Get mask mask = self.make_detector(self.Qshape, mode, g) # if return_mask is True, skip computation - if return_mask == True and shift_center == False: + if return_mask is True and shift_center is False: return mask # Calculate virtual image # no center shifting - if shift_center == False: + if shift_center is False: # single CPU if not dask: # allocate space @@ -220,7 +220,7 @@ def get_virtual_image( virtual_image[rx, ry] = np.sum(self.data[rx, ry] * mask) # dask - if dask == True: + if dask is True: # set up a generalized universal function for dask distribution def _apply_mask_dask(self, mask): virtual_image = np.sum( @@ -444,7 +444,7 @@ def position_detector( # shift center if shift_center is None: shift_center = False - elif shift_center == True: + elif shift_center is True: assert isinstance( data, tuple ), "If shift_center is set to True, `data` should be a 2-tuple (rx,ry). 
\ @@ -552,7 +552,7 @@ def get_calibrated_detector_geometry( # Convert units into detector pixels # Shift center - if centered == True: + if centered is True: if mode == "point": g = (g[0] + x0_mean, g[1] + y0_mean) if mode in ("circle", "circular", "annulus", "annular"): @@ -561,7 +561,7 @@ def get_calibrated_detector_geometry( g = (g[0] + x0_mean, g[1] + x0_mean, g[2] + y0_mean, g[3] + y0_mean) # Scale by the detector pixel size - if calibrated == True: + if calibrated is True: if mode == "point": g = (g[0] / unit_conversion, g[1] / unit_conversion) if mode in ("circle", "circular"): diff --git a/py4DSTEM/io/filereaders/read_K2.py b/py4DSTEM/io/filereaders/read_K2.py index 61405a437..5df91f2dc 100644 --- a/py4DSTEM/io/filereaders/read_K2.py +++ b/py4DSTEM/io/filereaders/read_K2.py @@ -336,9 +336,9 @@ def _find_offsets(self): for i in range(8): sync = False frame = 0 - while sync == False: + while sync == False: # noqa: E712 sync = self._bin_files[i][frame]["block"] == block_id - if sync == False: + if sync == False: # noqa: E712 frame += 1 self._shutter_offsets[i] += frame print("Offsets are currently ", self._shutter_offsets) @@ -358,7 +358,7 @@ def _find_offsets(self): sync = False next_frame = stripe[j]["frame"] - if sync == False: + if sync == False: # noqa: E712 # the first frame is incomplete, so we need to seek the next one print( f"First frame ({first_frame}) incomplete, seeking frame {next_frame}..." @@ -366,12 +366,12 @@ def _find_offsets(self): for i in range(8): sync = False frame = 0 - while sync == False: + while sync == False: # noqa: E712 sync = ( self._bin_files[i][self._shutter_offsets[i] + frame]["frame"] == next_frame ) - if sync == False: + if sync == False: # noqa: E712 frame += 1 self._shutter_offsets[i] += frame print("Offsets are now ", self._shutter_offsets) @@ -387,7 +387,7 @@ def _find_offsets(self): ] if np.any(stripe[:]["frame"] != first_frame): sync = False - if sync == True: + if sync == True: # noqa: E712 print("New frame is complete!") else: print("Next frame also incomplete!!!! Data may be corrupt?") @@ -397,7 +397,7 @@ def _find_offsets(self): for i in range(8): shutter = False frame = 0 - while shutter == False: + while shutter == False: # noqa: E712 offset = self._shutter_offsets[i] + (frame * 32) stripe = self._bin_files[i][offset : offset + 32] shutter = stripe[0]["shutter"] diff --git a/py4DSTEM/io/filereaders/read_mib.py b/py4DSTEM/io/filereaders/read_mib.py index 079c9d1bd..7456bd594 100644 --- a/py4DSTEM/io/filereaders/read_mib.py +++ b/py4DSTEM/io/filereaders/read_mib.py @@ -14,7 +14,7 @@ def load_mib( reshape=True, flip=True, scan=(256, 256), - **kwargs + **kwargs, ): """ Read a MIB file and return as py4DSTEM DataCube. diff --git a/py4DSTEM/preprocess/electroncount.py b/py4DSTEM/preprocess/electroncount.py index 7a498a061..e3fc68e05 100644 --- a/py4DSTEM/preprocess/electroncount.py +++ b/py4DSTEM/preprocess/electroncount.py @@ -1,8 +1,9 @@ # Electron counting # -# Includes functions for electron counting on either the CPU (electron_count) or the GPU -# (electron_count_GPU). For GPU electron counting, pytorch is used to interface between -# numpy and the GPU, and the datacube is expected in numpy.memmap (memory mapped) form. +# Includes functions for electron counting on either the CPU (electron_count) +# or the GPU (electron_count_GPU). For GPU electron counting, pytorch is used +# to interface between numpy and the GPU, and the datacube is expected in +# numpy.memmap (memory mapped) form. 
import numpy as np from scipy import optimize @@ -25,17 +26,17 @@ def electron_count( Performs electron counting. The algorithm is as follows: - From a random sampling of frames, calculate an x-ray and background threshold value. - In each frame, subtract the dark reference, then apply the two thresholds. Find all - local maxima with respect to the nearest neighbor pixels. These are considered - electron strike events. - - Thresholds are specified in units of standard deviations, either of a gaussian fit to - the histogram background noise (for thresh_bkgrnd) or of the histogram itself (for - thresh_xray). The background (lower) threshold is more important; we will always be - missing some real electron counts and incorrectly counting some noise as electron - strikes - this threshold controls their relative balance. The x-ray threshold may be - set fairly high. + From a random sampling of frames, calculate an x-ray and background + threshold value. In each frame, subtract the dark reference, then apply the + two thresholds. Find all local maxima with respect to the nearest neighbor + pixels. These are considered electron strike events. + + Thresholds are specified in units of standard deviations, either of a + gaussian fit to the histogram background noise (for thresh_bkgrnd) or of + the histogram itself (for thresh_xray). The background (lower) threshold is + more important; we will always be missing some real electron counts and + incorrectly counting some noise as electron strikes - this threshold + controls their relative balance. The x-ray threshold may be set fairly high. Args: datacube: a 4D numpy.ndarray pointing to the datacube. Note: the R/Q axes are @@ -402,7 +403,7 @@ def counted_pointlistarray_to_datacube(counted_pointlistarray, shape, subpixel=F (4D array of bools): a 4D array of bools, with true indicating an electron strike. """ assert len(shape) == 4 - assert subpixel == False, "subpixel mode not presently supported." + assert subpixel is False, "subpixel mode not presently supported." 
R_Nx, R_Ny, Q_Nx, Q_Ny = shape counted_datacube = np.zeros((R_Nx, R_Nx, Q_Nx, Q_Ny), dtype=bool) diff --git a/py4DSTEM/preprocess/radialbkgrd.py b/py4DSTEM/preprocess/radialbkgrd.py index e0d402fe3..1fcb859cd 100644 --- a/py4DSTEM/preprocess/radialbkgrd.py +++ b/py4DSTEM/preprocess/radialbkgrd.py @@ -80,7 +80,7 @@ def get_1D_polar_background( # Crop polar data to maximum distance which contains information from original image if (polarData.mask.sum(axis=(0)) == polarData.shape[0]).any(): ii = polarData.data.shape[1] - 1 - while polarData.mask[:, ii].all() == True: + while polarData.mask[:, ii].all() == True: # noqa: E712 ii = ii - 1 maximalDistance = ii polarData = polarData[:, 0:maximalDistance] @@ -105,16 +105,16 @@ def get_1D_polar_background( background1D = np.maximum(background1D, min_background_value) - if smoothing == True: - if smoothing_log == True: + if smoothing is True: + if smoothing_log is True: background1D = np.log(background1D) background1D = savgol_filter( background1D, smoothingWindowSize, smoothingPolyOrder ) - if smoothing_log == True: + if smoothing_log is True: background1D = np.exp(background1D) - if return_polararr == True: + if return_polararr is True: return (background1D, r_bins, polarData) else: return (background1D, r_bins) diff --git a/py4DSTEM/preprocess/utils.py b/py4DSTEM/preprocess/utils.py index 752e2f81c..0165a4753 100644 --- a/py4DSTEM/preprocess/utils.py +++ b/py4DSTEM/preprocess/utils.py @@ -5,7 +5,7 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = np @@ -293,7 +293,7 @@ def filter_2D_maxima( if minSpacing > 0: deletemask = np.zeros(len(maxima), dtype=bool) for i in range(len(maxima)): - if deletemask[i] == False: + if deletemask[i] == False: # noqa: E712 tooClose = ( (maxima["x"] - maxima["x"][i]) ** 2 + (maxima["y"] - maxima["y"][i]) ** 2 diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index e711e907d..0509d181e 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,11 +1,9 @@ from py4DSTEM.process.polar import PolarDatacube -from py4DSTEM.process.strain import StrainMap +from py4DSTEM.process.strain.strain import StrainMap -from py4DSTEM.process import latticevectors from py4DSTEM.process import phase from py4DSTEM.process import calibration from py4DSTEM.process import utils from py4DSTEM.process import classification -from py4DSTEM.process import latticevectors from py4DSTEM.process import diffraction from py4DSTEM.process import wholepatternfit diff --git a/py4DSTEM/process/calibration/ellipse.py b/py4DSTEM/process/calibration/ellipse.py index 8835aa95b..2954de377 100644 --- a/py4DSTEM/process/calibration/ellipse.py +++ b/py4DSTEM/process/calibration/ellipse.py @@ -63,7 +63,7 @@ def fit_ellipse_1D(ar, center=None, fitradii=None, mask=None): rr = np.sqrt((xx - x0) ** 2 + (yy - y0) ** 2) _mask = (rr > ri) * (rr <= ro) if mask is not None: - _mask *= mask == False + _mask *= mask == False # noqa: E712 xs, ys = np.nonzero(_mask) vals = ar[_mask] diff --git a/py4DSTEM/process/calibration/origin.py b/py4DSTEM/process/calibration/origin.py index 78a90fbef..a0717e321 100644 --- a/py4DSTEM/process/calibration/origin.py +++ b/py4DSTEM/process/calibration/origin.py @@ -154,7 +154,7 @@ def fit_origin( robust=robust, robust_steps=robust_steps, robust_thresh=robust_thresh, - data_mask=mask == True, + data_mask=mask == True, # noqa E712 ) popt_y, pcov_y, qy0_fit, _ = fit_2D( f, @@ -162,7 +162,7 @@ def fit_origin( robust=robust, 
robust_steps=robust_steps, robust_thresh=robust_thresh, - data_mask=mask == True, + data_mask=mask == True, # noqa E712 ) # Compute residuals diff --git a/py4DSTEM/process/calibration/qpixelsize.py b/py4DSTEM/process/calibration/qpixelsize.py index 2abefd54c..d59d5a45c 100644 --- a/py4DSTEM/process/calibration/qpixelsize.py +++ b/py4DSTEM/process/calibration/qpixelsize.py @@ -60,6 +60,6 @@ def get_dq_from_indexed_peaks(qs, hkl, a): # Get pixel size dq = 1 / (c * a) qs_fit = d_inv[mask] / a - hkl_fit = [hkl[i] for i in range(len(hkl)) if mask[i] == True] + hkl_fit = [hkl[i] for i in range(len(hkl)) if mask[i] == True] # noqa: E712 return dq, qs_fit, hkl_fit diff --git a/py4DSTEM/process/calibration/rotation.py b/py4DSTEM/process/calibration/rotation.py index 2c3e7bb43..aaf8a49ce 100644 --- a/py4DSTEM/process/calibration/rotation.py +++ b/py4DSTEM/process/calibration/rotation.py @@ -2,6 +2,179 @@ import numpy as np from typing import Optional +import matplotlib.pyplot as plt +from py4DSTEM import show + + +def compare_QR_rotation( + im_R, + im_Q, + QR_rotation, + R_rotation=0, + R_position=None, + Q_position=None, + R_pos_anchor="center", + Q_pos_anchor="center", + R_length=0.33, + Q_length=0.33, + R_width=0.001, + Q_width=0.001, + R_head_length_adjust=1, + Q_head_length_adjust=1, + R_head_width_adjust=1, + Q_head_width_adjust=1, + R_color="r", + Q_color="r", + figsize=(10, 5), + returnfig=False, +): + """ + Visualize a rotational offset between an image in real space, e.g. a STEM + virtual image, and an image in diffraction space, e.g. a defocused CBED + shadow image of the same region, by displaying an arrow overlaid over each + of these two images with the specified QR rotation applied. The QR rotation + is defined as the counter-clockwise rotation from real space to diffraction + space, in degrees. + + Parameters + ---------- + im_R : numpy array or other 2D image-like object (e.g. a VirtualImage) + A real space image, e.g. a STEM virtual image + im_Q : numpy array or other 2D image-like object + A diffraction space image, e.g. a defocused CBED image + QR_rotation : number + The counterclockwise rotation from real space to diffraction space, + in degrees + R_rotation : number + The orientation of the arrow drawn in real space, in degrees + R_position : None or 2-tuple + The position of the anchor point for the R-space arrow. If None, defaults + to the center of the image + Q_position : None or 2-tuple + The position of the anchor point for the Q-space arrow. If None, defaults + to the center of the image + R_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the R-space arrow, i.e. the point being specified by + the `R_position` parameter + Q_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the Q-space arrow, i.e. 
the point being specified by + the `Q_position` parameter + R_length : number or None + The length of the R-space arrow, as a fraction of the mean size of the + image + Q_length : number or None + The length of the Q-space arrow, as a fraction of the mean size of the + image + R_width : number + The width of the R-space arrow + Q_width : number + The width of the Q-space arrow + R_head_length_adjust : number + Scaling factor for the R-space arrow head length + Q_head_length_adjust : number + Scaling factor for the Q-space arrow head length + R_head_width_adjust : number + Scaling factor for the R-space arrow head width + Q_head_width_adjust : number + Scaling factor for the Q-space arrow head width + R_color : color + Color of the R-space arrow + Q_color : color + Color of the Q-space arrow + figsize : 2-tuple + The figure size + returnfig : bool + Toggles returning the figure and axes + """ + # parse inputs + if R_position is None: + R_position = ( + im_R.shape[0] / 2, + im_R.shape[1] / 2, + ) + if Q_position is None: + Q_position = ( + im_Q.shape[0] / 2, + im_Q.shape[1] / 2, + ) + R_length = np.mean(im_R.shape) * R_length + Q_length = np.mean(im_Q.shape) * Q_length + assert R_pos_anchor in ("center", "tail", "head") + assert Q_pos_anchor in ("center", "tail", "head") + + # compute positions + rpos_x, rpos_y = R_position + qpos_x, qpos_y = Q_position + R_rot_rad = np.radians(R_rotation) + Q_rot_rad = np.radians(R_rotation + QR_rotation) + rvecx = np.cos(R_rot_rad) + rvecy = np.sin(R_rot_rad) + qvecx = np.cos(Q_rot_rad) + qvecy = np.sin(Q_rot_rad) + if R_pos_anchor == "center": + x0_r = rpos_x - rvecx * R_length / 2 + y0_r = rpos_y - rvecy * R_length / 2 + x1_r = rpos_x + rvecx * R_length / 2 + y1_r = rpos_y + rvecy * R_length / 2 + elif R_pos_anchor == "tail": + x0_r = rpos_x + y0_r = rpos_y + x1_r = rpos_x + rvecx * R_length + y1_r = rpos_y + rvecy * R_length + elif R_pos_anchor == "head": + x0_r = rpos_x - rvecx * R_length + y0_r = rpos_y - rvecy * R_length + x1_r = rpos_x + y1_r = rpos_y + else: + raise Exception(f"Invalid value for R_pos_anchor {R_pos_anchor}") + if Q_pos_anchor == "center": + x0_q = qpos_x - qvecx * Q_length / 2 + y0_q = qpos_y - qvecy * Q_length / 2 + x1_q = qpos_x + qvecx * Q_length / 2 + y1_q = qpos_y + qvecy * Q_length / 2 + elif Q_pos_anchor == "tail": + x0_q = qpos_x + y0_q = qpos_y + x1_q = qpos_x + qvecx * Q_length + y1_q = qpos_y + qvecy * Q_length + elif Q_pos_anchor == "head": + x0_q = qpos_x - qvecx * Q_length + y0_q = qpos_y - qvecy * Q_length + x1_q = qpos_x + y1_q = qpos_y + else: + raise Exception(f"Invalid value for Q_pos_anchor {Q_pos_anchor}") + + # make the figure + axsize = (figsize[0] / 2, figsize[1]) + fig, axs = show([im_R, im_Q], returnfig=True, axsize=axsize) + axs[0, 0].arrow( + x=y0_r, + y=x0_r, + dx=y1_r - y0_r, + dy=x1_r - x0_r, + color=R_color, + length_includes_head=True, + width=R_width, + head_width=R_length * R_head_width_adjust * 0.072, + head_length=R_length * R_head_length_adjust * 0.1, + ) + axs[0, 1].arrow( + x=y0_q, + y=x0_q, + dx=y1_q - y0_q, + dy=x1_q - x0_q, + color=Q_color, + length_includes_head=True, + width=Q_width, + head_width=Q_length * Q_head_width_adjust * 0.072, + head_length=Q_length * Q_head_length_adjust * 0.1, + ) + if returnfig: + return fig, axs + else: + plt.show() def get_Qvector_from_Rvector(vx, vy, QR_rotation): diff --git a/py4DSTEM/process/classification/braggvectorclassification.py b/py4DSTEM/process/classification/braggvectorclassification.py index d5c2ac0fc..c3e36273f 100644 ---
a/py4DSTEM/process/classification/braggvectorclassification.py +++ b/py4DSTEM/process/classification/braggvectorclassification.py @@ -594,8 +594,8 @@ def merge_iterative(self, threshBPs=0.1, threshScanPosition=0.1): W_merge = W_merge[:, 1:] H_merge = H_merge[1:, :] - W_ = np.hstack((W_[:, merged == False], W_merge)) - H_ = np.vstack((H_[merged == False, :], H_merge)) + W_ = np.hstack((W_[:, merged == False], W_merge)) # noqa: E712 + H_ = np.vstack((H_[merged == False, :], H_merge)) # noqa: E712 Nc_ = W_.shape[1] if len(merge_candidates) == 0: diff --git a/py4DSTEM/process/classification/featurization.py b/py4DSTEM/process/classification/featurization.py index 38b4e1412..126583413 100644 --- a/py4DSTEM/process/classification/featurization.py +++ b/py4DSTEM/process/classification/featurization.py @@ -182,7 +182,7 @@ def from_braggvectors( (pointlist.data["qy"] / q_pixel_size) + Q_Ny / 2 ).astype(int), ] - == False + == False # noqa: E712 ), True, False, @@ -330,7 +330,7 @@ def MinMaxScaler(self, return_scaled=True): """ mms = MinMaxScaler() self.features = mms.fit_transform(self.features) - if return_scaled == True: + if return_scaled is True: return self.features else: return @@ -345,7 +345,7 @@ def RobustScaler(self, return_scaled=True): """ rs = RobustScaler() self.features = rs.fit_transform(self.features) - if return_scaled == True: + if return_scaled is True: return self.features else: return @@ -358,7 +358,7 @@ def shift_positive(self, return_scaled=True): return_scaled (bool): returns the scaled array """ self.features += np.abs(self.features.min()) - if return_scaled == True: + if return_scaled is True: return self.features else: return @@ -372,7 +372,7 @@ def PCA(self, components, return_results=False): """ pca = PCA(n_components=components) self.pca = pca.fit_transform(self.features) - if return_results == True: + if return_results is True: return self.pca return @@ -385,7 +385,7 @@ def ICA(self, components, return_results=True): """ ica = FastICA(n_components=components) self.ica = ica.fit_transform(self.features) - if return_results == True: + if return_results is True: return self.ica return @@ -434,7 +434,7 @@ def NMF( random_seed=random_seed, save_all_models=save_all_models, ) - if return_results == True: + if return_results is True: return self.W return @@ -455,7 +455,7 @@ def GMM(self, cv, components, num_models, random_seed=None, return_results=False num_models=num_models, random_seed=random_seed, ) - if return_results == True: + if return_results is True: return self.gmm return @@ -680,7 +680,7 @@ def spatial_separation(self, size, threshold=0, method=None, clean=True): ) else: large_labelled_image = labelled_image - elif method == None: + elif method is None: labelled_image = label(image) if np.sum(labelled_image) > size: large_labelled_image = remove_small_objects( @@ -707,7 +707,7 @@ def spatial_separation(self, size, threshold=0, method=None, clean=True): ) if len(separated_temp) > 0: - if clean == True: + if clean is True: data_ndarray = np.dstack(separated_temp) data_hard = ( data_ndarray.max(axis=2, keepdims=1) == data_ndarray @@ -875,16 +875,16 @@ def _nmf_single( """ # Prepare error, random seed err = np.inf - if random_seed == None: + if random_seed is None: rng = np.random.RandomState(seed=42) else: seed = random_seed - if save_all_models == True: + if save_all_models is True: W = [] # Big loop through all models for i in range(num_models): - if random_seed == None: + if random_seed is None: seed = rng.randint(5000) n_comps = max_components recon_error, counter 
= 0, 0 @@ -936,7 +936,7 @@ def _nmf_single( if n_comps <= 2: break - if save_all_models == True: + if save_all_models is True: W.append(nmf_temp) elif (recon_error / counter) < err: @@ -963,18 +963,18 @@ def _gmm_single(x, cv, components, num_models, random_seed=None, return_all=True gmm_labels OR best_gmm_labels: Label list for all models or labels for best model gmm_proba OR best_gmm_proba: Probability list of class belonging or probability for best model """ - if return_all == True: + if return_all is True: gmm_list = [] gmm_labels = [] gmm_proba = [] lowest_bic = np.infty bic_temp = 0 - if random_seed == None: + if random_seed is None: rng = np.random.RandomState(seed=42) else: seed = random_seed for n in range(num_models): - if random_seed == None: + if random_seed is None: seed = rng.randint(5000) for j in range(len(components)): for cv_type in cv: @@ -986,18 +986,18 @@ def _gmm_single(x, cv, components, num_models, random_seed=None, return_all=True labels = gmm.fit_predict(x) bic_temp = gmm.bic(x) - if return_all == True: + if return_all is True: gmm_list.append(gmm) gmm_labels.append(labels) gmm_proba.append(gmm.predict_proba(x)) - elif return_all == False: + elif return_all is False: if bic_temp < lowest_bic: lowest_bic = bic_temp best_gmm = gmm best_gmm_labels = labels best_gmm_proba = gmm.predict_proba(x) - if return_all == True: + if return_all is True: return gmm_list, gmm_labels, gmm_proba return best_gmm, best_gmm_labels, best_gmm_proba diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index f3e3b7739..a2bf36a02 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -2,24 +2,16 @@ import numpy as np import matplotlib.pyplot as plt +from matplotlib.patches import Circle from fractions import Fraction from typing import Union, Optional -from scipy.optimize import curve_fit import sys import warnings -from emdfile import tqdmnd, PointList, PointListArray +from emdfile import PointList from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom -from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern -from py4DSTEM.process.diffraction.crystal_viz import plot_ring_pattern -from py4DSTEM.process.diffraction.utils import Orientation, calc_1D_profile - -try: - from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - from pymatgen.core.structure import Structure -except ImportError: - pass +from py4DSTEM.process.diffraction.utils import Orientation class Crystal: @@ -963,12 +955,12 @@ def generate_ring_pattern( ) intensity_unique = np.bincount(inv, weights=intensity) - if plot_rings == True: + if plot_rings is True: from py4DSTEM.process.diffraction.crystal_viz import plot_ring_pattern plot_ring_pattern(radii_unique, intensity_unique, **plot_params) - if return_calc == True: + if return_calc is True: return radii_unique, intensity_unique # Vector conversions and other utilities for Crystal classes @@ -1165,3 +1157,517 @@ def calculate_bragg_peak_histogram( int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power) int_exp /= np.max(int_exp) return k, int_exp + + +def generate_moire_diffraction_pattern( + bragg_peaks_0, + bragg_peaks_1, + thresh_0=0.0002, + thresh_1=0.0002, + exx_1=0.0, + eyy_1=0.0, + exy_1=0.0, + phi_1=0.0, + power=2.0, +): + """ + Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated + and strained with respect to the original lattice. 
Note that this strain is applied in real space, + and so the inverse of the calculated infinitesimal strain tensor is applied. + + Parameters + -------- + bragg_peaks_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_peaks_1: BraggVector + Bragg vectors for parent lattice 1. + thresh_0: float + Intensity threshold for structure factors from lattice 0. + thresh_1: float + Intensity threshold for structure factors from lattice 1. + exx_1: float + Strain of lattice 1 in x direction (vertical) in real space. + eyy_1: float + Strain of lattice 1 in y direction (horizontal) in real space. + exy_1: float + Shear strain of lattice 1 in (x,y) direction (diagonal) in real space. + phi_1: float + Rotation of lattice 1 in real space. + power: float + Plotting power law (default is amplitude**2.0, i.e. intensity). + + Returns + -------- + parent_peaks_0, parent_peaks_1, moire_peaks: PointLists + Bragg vectors for the rotated & strained parent lattices + and the moire lattice + + """ + + # get intensities of all peaks + int0 = bragg_peaks_0["intensity"] ** (power / 2.0) + int1 = bragg_peaks_1["intensity"] ** (power / 2.0) + + # peaks above threshold + sub0 = int0 >= thresh_0 + sub1 = int1 >= thresh_1 + + # Remove origin (assuming brightest peak) + ind0_or = np.argmax(bragg_peaks_0["intensity"]) + ind1_or = np.argmax(bragg_peaks_1["intensity"]) + sub0[ind0_or] = False + sub1[ind1_or] = False + int0_sub = int0[sub0] + int1_sub = int1[sub1] + + # Get peaks + qx0 = bragg_peaks_0["qx"][sub0] + qy0 = bragg_peaks_0["qy"][sub0] + qx1_init = bragg_peaks_1["qx"][sub1] + qy1_init = bragg_peaks_1["qy"][sub1] + + # peak labels + h0 = bragg_peaks_0["h"][sub0] + k0 = bragg_peaks_0["k"][sub0] + l0 = bragg_peaks_0["l"][sub0] + h1 = bragg_peaks_1["h"][sub1] + k1 = bragg_peaks_1["k"][sub1] + l1 = bragg_peaks_1["l"][sub1] + + # apply strain tensor to lattice 1 + m = np.array( + [ + [np.cos(phi_1), -np.sin(phi_1)], + [np.sin(phi_1), np.cos(phi_1)], + ] + ) @ np.linalg.inv( + np.array( + [ + [1 + exx_1, exy_1 * 0.5], + [exy_1 * 0.5, 1 + eyy_1], + ] + ) + ) + qx1 = m[0, 0] * qx1_init + m[0, 1] * qy1_init + qy1 = m[1, 0] * qx1_init + m[1, 1] * qy1_init + + # Generate moire lattice + ind0, ind1 = np.meshgrid( + np.arange(np.sum(sub0)), + np.arange(np.sum(sub1)), + indexing="ij", + ) + qx = qx0[ind0] + qx1[ind1] + qy = qy0[ind0] + qy1[ind1] + int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5 + + # moire labels + m_h0 = h0[ind0] + m_k0 = k0[ind0] + m_l0 = l0[ind0] + m_h1 = h1[ind1] + m_k1 = k1[ind1] + m_l1 = l1[ind1] + + # Convert thresholded and moire peaks to PointList objects + + pl_dtype_parent = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h", "int"), + ("k", "int"), + ("l", "int"), + ] + ) + + bragg_parent_0 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_0.add_data_by_field( + [ + qx0.ravel(), + qy0.ravel(), + int0_sub.ravel(), + h0.ravel(), + k0.ravel(), + l0.ravel(), + ] + ) + + bragg_parent_1 = PointList(np.array([], dtype=pl_dtype_parent)) + bragg_parent_1.add_data_by_field( + [ + qx1.ravel(), + qy1.ravel(), + int1_sub.ravel(), + h1.ravel(), + k1.ravel(), + l1.ravel(), + ] + ) + + pl_dtype = np.dtype( + [ + ("qx", "float"), + ("qy", "float"), + ("intensity", "float"), + ("h0", "int"), + ("k0", "int"), + ("l0", "int"), + ("h1", "int"), + ("k1", "int"), + ("l1", "int"), + ] + ) + bragg_moire = PointList(np.array([], dtype=pl_dtype)) + bragg_moire.add_data_by_field( + [ + qx.ravel(), + qy.ravel(), + int_moire.ravel(), + m_h0.ravel(), + m_k0.ravel(), +
m_l0.ravel(), + m_h1.ravel(), + m_k1.ravel(), + m_l1.ravel(), + ] + ) + + return bragg_parent_0, bragg_parent_1, bragg_moire + + +def plot_moire_diffraction_pattern( + bragg_parent_0, + bragg_parent_1, + bragg_moire, + int_range=(0, 5e-3), + k_max=1.0, + plot_subpixel=True, + labels=None, + marker_size_parent=16, + marker_size_moire=4, + text_size_parent=10, + text_size_moire=6, + add_labels_parent=False, + add_labels_moire=False, + dist_labels=0.03, + dist_check=0.06, + sep_labels=0.03, + figsize=(8, 6), + returnfig=False, +): + """ + Plot Moire lattice and parent lattices. + + Parameters + -------- + bragg_parent_0: BraggVector + Bragg vectors for parent lattice 0. + bragg_parent_1: BraggVector + Bragg vectors for parent lattice 1. + bragg_moire: BraggVector + Bragg vectors for moire lattice. + int_range: (float, float) + Plotting intensity range for the Moire peaks. + k_max: float + Max k value of the plotted Moire lattice. + plot_subpixel: bool + Apply subpixel corrections to the Bragg spot positions. + Matplotlib default scatter plot rounds to the nearest pixel. + labels: list + List of text labels for parent lattices + marker_size_parent: float + Size of plot markers for the two parent lattices. + marker_size_moire: float + Size of plot markers for the Moire lattice. + text_size_parent: float + Label text size for parent lattice. + text_size_moire: float + Label text size for Moire lattice. + add_labels_parent: bool + Plot the parent lattice index labels. + add_labels_moire: bool + Plot the parent lattice index labels for the Moire spots. + dist_labels: float + Distance to move the labels off the spots. + dist_check: float + Set to some distance to "push" the labels away from each other if they are within this distance. + sep_labels: float + Separation distance for labels which are "pushed" apart. + figsize: (float,float) + Size of output figure. + returnfig: bool + Return the (fig,ax) handles of the plot. + + Returns + -------- + fig, ax: matplotlib handles (optional) + Figure and axes handles for the moire plot.
+ """ + + # peak labels + + if labels is None: + labels = ("crystal 0", "crystal 1") + + def overline(x): + return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}") + + # parent 1 + qx0 = bragg_parent_0["qx"] + qy0 = bragg_parent_0["qy"] + h0 = bragg_parent_0["h"] + k0 = bragg_parent_0["k"] + l0 = bragg_parent_0["l"] + + # parent 2 + qx1 = bragg_parent_1["qx"] + qy1 = bragg_parent_1["qy"] + h1 = bragg_parent_1["h"] + k1 = bragg_parent_1["k"] + l1 = bragg_parent_1["l"] + + # moire + qx = bragg_moire["qx"] + qy = bragg_moire["qy"] + m_h0 = bragg_moire["h0"] + m_k0 = bragg_moire["k0"] + m_l0 = bragg_moire["l0"] + m_h1 = bragg_moire["h1"] + m_k1 = bragg_moire["k1"] + m_l1 = bragg_moire["l1"] + int_moire = bragg_moire["intensity"] + + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0.09, 0.09, 0.65, 0.9]) + ax_labels = fig.add_axes([0.75, 0, 0.25, 1]) + + text_params_parent = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_parent, + } + text_params_moire = { + "ha": "center", + "va": "center", + "family": "sans-serif", + "fontweight": "normal", + "size": text_size_moire, + } + + if plot_subpixel is False: + # moire + ax.scatter( + qy, + qx, + # color = (0,0,0,1), + c=int_moire, + s=marker_size_moire, + cmap="gray_r", + vmin=int_range[0], + vmax=int_range[1], + antialiased=True, + ) + + # parent lattices + ax.scatter( + qy0, + qx0, + color=(1, 0, 0, 1), + s=marker_size_parent, + antialiased=True, + ) + ax.scatter( + qy1, + qx1, + color=(0, 0.7, 1, 1), + s=marker_size_parent, + antialiased=True, + ) + + # origin + ax.scatter( + 0, + 0, + color=(0, 0, 0, 1), + s=marker_size_parent, + antialiased=True, + ) + + else: + # moire peaks + int_all = np.clip( + (int_moire - int_range[0]) / (int_range[1] - int_range[0]), 0, 1 + ) + keep = np.logical_and.reduce( + (qx >= -k_max, qx <= k_max, qy >= -k_max, qy <= k_max) + ) + for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_moire) / 800.0, + color=(1 - int_marker, 1 - int_marker, 1 - int_marker), + ) + ) + if add_labels_moire: + for a0 in range(qx.size): + if keep.ravel()[a0]: + x0 = qx.ravel()[a0] + y0 = qy.ravel()[a0] + d2 = (qx.ravel() - x0) ** 2 + (qy.ravel() - y0) ** 2 + sub = d2 < dist_check**2 + xc = np.mean(qx.ravel()[sub]) + yc = np.mean(qy.ravel()[sub]) + xp = x0 - xc + yp = y0 - yc + if xp == 0 and yp == 0.0: + xp = x0 - dist_labels + yp = y0 + else: + leng = np.linalg.norm((xp, yp)) + xp = x0 + xp * dist_labels / leng + yp = y0 + yp * dist_labels / leng + + ax.text( + yp, + xp - sep_labels, + "$" + + overline(m_h0.ravel()[a0]) + + overline(m_k0.ravel()[a0]) + + overline(m_l0.ravel()[a0]) + + "$", + c="r", + **text_params_moire, + ) + ax.text( + yp, + xp, + "$" + + overline(m_h1.ravel()[a0]) + + overline(m_k1.ravel()[a0]) + + overline(m_l1.ravel()[a0]) + + "$", + c=(0, 0.7, 1.0), + **text_params_moire, + ) + + keep = np.logical_and.reduce( + (qx0 >= -k_max, qx0 <= k_max, qy0 >= -k_max, qy0 <= k_max) + ) + for x, y in zip(qx0[keep], qy0[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(1, 0, 0), + ) + ) + if add_labels_parent: + for a0 in range(qx0.size): + if keep.ravel()[a0]: + xp = qx0.ravel()[a0] - dist_labels + yp = qy0.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h0.ravel()[a0]) + + overline(k0.ravel()[a0]) + + overline(l0.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + keep = np.logical_and.reduce( + (qx1 
>= -k_max, qx1 <= k_max, qy1 >= -k_max, qy1 <= k_max) + ) + for x, y in zip(qx1[keep], qy1[keep]): + ax.add_artist( + Circle( + xy=(y, x), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0.7, 1), + ) + ) + if add_labels_parent: + for a0 in range(qx1.size): + if keep.ravel()[a0]: + xp = qx1.ravel()[a0] - dist_labels + yp = qy1.ravel()[a0] + ax.text( + yp, + xp, + "$" + + overline(h1.ravel()[a0]) + + overline(k1.ravel()[a0]) + + overline(l1.ravel()[a0]) + + "$", + c="k", + **text_params_parent, + ) + + # origin + ax.add_artist( + Circle( + xy=(0, 0), + radius=np.sqrt(marker_size_parent) / 800.0, + color=(0, 0, 0), + ) + ) + + ax.set_xlim((-k_max, k_max)) + ax.set_ylim((-k_max, k_max)) + ax.set_ylabel("$q_x$ (1/A)") + ax.set_xlabel("$q_y$ (1/A)") + ax.invert_yaxis() + + # labels + ax_labels.scatter( + 0, + 0, + color=(1, 0, 0, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -1, + color=(0, 0.7, 1, 1), + s=marker_size_parent, + ) + ax_labels.scatter( + 0, + -2, + color=(0, 0, 0, 1), + s=marker_size_moire, + ) + ax_labels.text( + 0.4, + -0.2, + labels[0], + fontsize=14, + ) + ax_labels.text( + 0.4, + -1.2, + labels[1], + fontsize=14, + ) + ax_labels.text( + 0.4, + -2.2, + "Moiré lattice", + fontsize=14, + ) + + ax_labels.set_xlim((-1, 4)) + ax_labels.set_ylim((-21, 1)) + + ax_labels.axis("off") + + if returnfig: + return fig, ax diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index da553456f..748837c61 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -1,8 +1,6 @@ import numpy as np import matplotlib.pyplot as plt -import os from typing import Union, Optional -import time, sys from tqdm import tqdm from emdfile import tqdmnd, PointList, PointListArray @@ -16,7 +14,7 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = None @@ -31,6 +29,7 @@ def orientation_plan( corr_kernel_size: float = 0.08, radial_power: float = 1.0, intensity_power: float = 0.25, # New default intensity power scaling + calculate_correlation_array=True, tol_peak_delete=None, tol_distance: float = 0.01, fiber_axis=None, @@ -63,6 +62,8 @@ def orientation_plan( corr_kernel_size (float): Correlation kernel size length in Angstroms radial_power (float): Power for scaling the correlation intensity as a function of the peak radius intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity + calculate_correlation_array (bool): Set to false to skip calculating the correlation array. + This is useful when we only want the angular range / rotation matrices. tol_peak_delete (float): Distance to delete peaks for multiple matches. 
Default is kernel_size * 0.5 tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms] @@ -600,21 +601,6 @@ def orientation_plan( # init storage arrays self.orientation_rotation_angles = np.zeros((self.orientation_num_zones, 3)) self.orientation_rotation_matrices = np.zeros((self.orientation_num_zones, 3, 3)) - self.orientation_ref = np.zeros( - ( - self.orientation_num_zones, - np.size(self.orientation_shell_radii), - self.orientation_in_plane_steps, - ), - dtype="float", - ) - # self.orientation_ref_1D = np.zeros( - # ( - # self.orientation_num_zones, - # np.size(self.orientation_shell_radii), - # ), - # dtype="float", - # ) # If possible, Get symmetry operations for this spacegroup, store in matrix form if self.pymatgen_available: @@ -699,65 +685,76 @@ def orientation_plan( k0 = np.array([0.0, 0.0, -1.0 / self.wavelength]) n = np.array([0.0, 0.0, -1.0]) - for a0 in tqdmnd( - np.arange(self.orientation_num_zones), - desc="Orientation plan", - unit=" zone axes", - disable=not progress_bar, - ): - # reciprocal lattice spots and excitation errors - g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all - sg = self.excitation_errors(g) - - # Keep only points that will contribute to this orientation plan slice - keep = np.abs(sg) < self.orientation_kernel_size - - # in-plane rotation angle - phi = np.arctan2(g[1, :], g[0, :]) - - # Loop over all peaks - for a1 in np.arange(self.g_vec_all.shape[1]): - ind_radial = self.orientation_shell_index[a1] + if calculate_correlation_array: + # initialize empty correlation array + self.orientation_ref = np.zeros( + ( + self.orientation_num_zones, + np.size(self.orientation_shell_radii), + self.orientation_in_plane_steps, + ), + dtype="float", + ) - if keep[a1] and ind_radial >= 0: - # 2D orientation plan - self.orientation_ref[a0, ind_radial, :] += ( - np.power(self.orientation_shell_radii[ind_radial], radial_power) - * np.power(self.struct_factors_int[a1], intensity_power) - * np.maximum( - 1 - - np.sqrt( - sg[a1] ** 2 - + ( - ( - np.mod( - self.orientation_gamma - phi[a1] + np.pi, - 2 * np.pi, + for a0 in tqdmnd( + np.arange(self.orientation_num_zones), + desc="Orientation plan", + unit=" zone axes", + disable=not progress_bar, + ): + # reciprocal lattice spots and excitation errors + g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all + sg = self.excitation_errors(g) + + # Keep only points that will contribute to this orientation plan slice + keep = np.abs(sg) < self.orientation_kernel_size + + # in-plane rotation angle + phi = np.arctan2(g[1, :], g[0, :]) + + # Loop over all peaks + for a1 in np.arange(self.g_vec_all.shape[1]): + ind_radial = self.orientation_shell_index[a1] + + if keep[a1] and ind_radial >= 0: + # 2D orientation plan + self.orientation_ref[a0, ind_radial, :] += ( + np.power(self.orientation_shell_radii[ind_radial], radial_power) + * np.power(self.struct_factors_int[a1], intensity_power) + * np.maximum( + 1 + - np.sqrt( + sg[a1] ** 2 + + ( + ( + np.mod( + self.orientation_gamma - phi[a1] + np.pi, + 2 * np.pi, + ) + - np.pi ) - - np.pi + * self.orientation_shell_radii[ind_radial] ) - * self.orientation_shell_radii[ind_radial] + ** 2 ) - ** 2 + / self.orientation_kernel_size, + 0, ) - / self.orientation_kernel_size, - 0, ) - ) - orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) - if orientation_ref_norm > 0: - self.orientation_ref[a0, :, :] /= orientation_ref_norm + orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2)) + if 
orientation_ref_norm > 0: + self.orientation_ref[a0, :, :] /= orientation_ref_norm - # Maximum value - self.orientation_ref_max = np.max(np.real(self.orientation_ref)) + # Maximum value + self.orientation_ref_max = np.max(np.real(self.orientation_ref)) - # Fourier domain along angular axis - if self.CUDA: - self.orientation_ref = cp.asarray(self.orientation_ref) - self.orientation_ref = cp.conj(cp.fft.fft(self.orientation_ref)) - else: - self.orientation_ref = np.conj(np.fft.fft(self.orientation_ref)) + # Fourier domain along angular axis + if self.CUDA: + self.orientation_ref = cp.asarray(self.orientation_ref) + self.orientation_ref = cp.conj(cp.fft.fft(self.orientation_ref)) + else: + self.orientation_ref = np.conj(np.fft.fft(self.orientation_ref)) def match_orientations( @@ -801,12 +798,12 @@ def match_orientations( ) # check cal state - if bragg_peaks_array.calstate["ellipse"] == False: + if bragg_peaks_array.calstate["ellipse"] is False: ellipse = False warn("Warning: bragg peaks not elliptically calibrated") else: ellipse = True - if bragg_peaks_array.calstate["rotate"] == False: + if bragg_peaks_array.calstate["rotate"] is False: rotate = False warn("bragg peaks not rotationally calibrated") else: @@ -905,7 +902,12 @@ def match_single_pattern( Figure handles for the plotting output """ - # init orientation output + # check that the orientation plan's correlation array was computed + if not hasattr(self, "orientation_ref"): + raise ValueError( + "orientation_plan must be run with 'calculate_correlation_array=True'" + ) + orientation = Orientation(num_matches=num_matches_return) if bragg_peaks.data.shape[0] < min_number_peaks: return orientation @@ -1837,7 +1840,9 @@ def cluster_grains( xr = np.clip(x + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1) yr = np.clip(y + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1) inds_cand = inds_all[xr[:, None], yr[None], :].ravel() - inds_cand = np.delete(inds_cand, mark.ravel()[inds_cand] == False) + inds_cand = np.delete( + inds_cand, mark.ravel()[inds_cand] == False # noqa: E712 + ) if inds_cand.size == 0: grow = False @@ -1890,7 +1895,7 @@ def cluster_grains( inds_grain = np.append(inds_grain, inds_cand[keep]) inds_cand = np.unique( - np.delete(inds_new, mark.ravel()[inds_new] == False) + np.delete(inds_new, mark.ravel()[inds_new] == False) # noqa: E712 ) if inds_cand.size == 0: @@ -2080,12 +2085,12 @@ def calculate_strain( radius_max_2 = corr_kernel_size**2 # check cal state - if bragg_peaks_array.calstate["ellipse"] == False: + if bragg_peaks_array.calstate["ellipse"] is False: ellipse = False warn("bragg peaks not elliptically calibrated") else: ellipse = True - if bragg_peaks_array.calstate["rotate"] == False: + if bragg_peaks_array.calstate["rotate"] is False: rotate = False warn("bragg peaks not rotationally calibrated") else: diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py index 84824fe63..d28616aa9 100644 --- a/py4DSTEM/process/diffraction/crystal_phase.py +++ b/py4DSTEM/process/diffraction/crystal_phase.py @@ -52,7 +52,7 @@ def plot_all_phase_maps(self, map_scale_values=None, index=0): map_scale_values (float): Value to scale correlations by """ phase_maps = [] - if map_scale_values == None: + if map_scale_values is None: map_scale_values = [1] * len(self.orientation_maps) corr_sum = np.sum( [ @@ -75,7 +75,7 @@ def plot_phase_map(self, index=0, cmap=None): for p in
range(len(self.orientation_maps)) ] - if cmap == None: + if cmap is None: cm = plt.get_cmap("rainbow") cmap = [ cm(1.0 * i / len(self.orientation_maps)) @@ -276,7 +276,7 @@ def quantify_phase_pointlist( if len(pointlist["qx"]) > 0: if mask_peaks is not None: for i in range(len(mask_peaks)): - if mask_peaks[i] == None: + if mask_peaks[i] == None: # noqa: E711 continue inds_mask = np.where( pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0 diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 642c3223b..da016b3ed 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -7,7 +7,7 @@ from scipy.signal import medfilt from scipy.ndimage import gaussian_filter -from scipy.ndimage.morphology import distance_transform_edt +from scipy.ndimage import distance_transform_edt from skimage.morphology import dilation, erosion import warnings @@ -621,7 +621,7 @@ def plot_orientation_zones( # x = r * np.sin(theta) # y = r * np.cos(theta) - warnings.filterwarnings("ignore", module="matplotlib\..*") + warnings.filterwarnings("ignore", module=r"matplotlib\..*") line_params = {"linewidth": 2, "alpha": 0.1, "c": "k"} for phi in np.arange(0, 180, 5): ax.plot3D( @@ -1806,11 +1806,11 @@ def plot_fiber_orientation_maps( np.round(leg_size * 1.0), ] labels = [ - str(np.round(self.orientation_fiber_angles[0] * 0.00)) + "$\degree$", - str(np.round(self.orientation_fiber_angles[0] * 0.25)) + "$\degree$", - str(np.round(self.orientation_fiber_angles[0] * 0.50)) + "$\degree$", - str(np.round(self.orientation_fiber_angles[0] * 0.75)) + "$\degree$", - str(np.round(self.orientation_fiber_angles[0] * 1.00)) + "$\degree$", + str(np.round(self.orientation_fiber_angles[0] * 0.00)) + "$\\degree$", + str(np.round(self.orientation_fiber_angles[0] * 0.25)) + "$\\degree$", + str(np.round(self.orientation_fiber_angles[0] * 0.50)) + "$\\degree$", + str(np.round(self.orientation_fiber_angles[0] * 0.75)) + "$\\degree$", + str(np.round(self.orientation_fiber_angles[0] * 1.00)) + "$\\degree$", ] ax_op_l.set_xticks(ticks) ax_op_l.set_xticklabels(labels) @@ -2195,7 +2195,7 @@ def plot_ring_pattern( figsize=(10, 10), returnfig=False, input_fig_handle=None, - **kwargs + **kwargs, ): """ 2D plot of diffraction rings @@ -2222,7 +2222,7 @@ def plot_ring_pattern( ax = ax_parent[0] for a1 in range(radii.shape[0]): - if intensity_constant == True: + if intensity_constant is True: ax.plot( radii[a1] * np.sin(theta), radii[a1] * np.cos(theta), diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index 349d88530..9973ff79f 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -86,7 +86,7 @@ def fit_2D( xy = np.vstack((rx_1D, ry_1D)) # if robust fitting is turned off, set number of robust iterations to 0 - if robust == False: + if robust is False: robust_steps = 0 # least squares fitting @@ -107,7 +107,7 @@ def fit_2D( fit_mean_square_error > np.mean(fit_mean_square_error) * robust_thresh**2 ) - mask[_mask] == False + mask[_mask] = False # perform fitting popt, pcov = curve_fit( diff --git a/py4DSTEM/process/latticevectors/__init__.py b/py4DSTEM/process/latticevectors/__init__.py deleted file mode 100644 index 560a3b7e6..000000000 --- a/py4DSTEM/process/latticevectors/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from py4DSTEM.process.latticevectors.initialguess import * -from py4DSTEM.process.latticevectors.index import * -from py4DSTEM.process.latticevectors.fit import * -from 
py4DSTEM.process.latticevectors.strain import * diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py deleted file mode 100644 index 659bc8940..000000000 --- a/py4DSTEM/process/latticevectors/fit.py +++ /dev/null @@ -1,200 +0,0 @@ -# Functions for fitting lattice vectors to measured Bragg peak positions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray -from py4DSTEM.data import RealSlice - - -def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (7-tuple) A 7-tuple containing: - - * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. - * **y0**: *(float)* the y-coord of the origin - * **g1x**: *(float)* x-coord of the first lattice vector - * **g1y**: *(float)* y-coord of the first lattice vector - * **g2x**: *(float)* x-coord of the second lattice vector - * **g2y**: *(float)* y-coord of the second lattice vector - * **error**: *(float)* the fit error - """ - assert isinstance(braggpeaks, PointList) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - braggpeaks = braggpeaks.copy() - - # Remove unindexed peaks - if "index_mask" in braggpeaks.dtype.names: - deletemask = braggpeaks.data["index_mask"] == False - braggpeaks.remove(deletemask) - - # Check to ensure enough peaks are present - if braggpeaks.length < minNumPeaks: - return None, None, None, None, None, None, None - - # Get M, the matrix of (h,k) indices - h, k = braggpeaks.data["h"], braggpeaks.data["k"] - M = np.vstack((np.ones_like(h, dtype=int), h, k)).T - - # Get alpha, the matrix of measured Bragg peak positions - alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T - - # Get weighted matrices - weights = braggpeaks.data["intensity"] - weighted_M = M * weights[:, np.newaxis] - weighted_alpha = alpha * weights[:, np.newaxis] - - # Solve for lattice vectors - beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] - x0, y0 = beta[0, 0], beta[0, 1] - g1x, g1y = beta[1, 0], beta[1, 1] - g2x, g2y = beta[2, 0], beta[2, 1] - - # Calculate the error - alpha_calculated = np.matmul(M, beta) - error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) - error = np.sum(error * weights) / np.sum(weights) - - return x0, y0, g1x, g1y, g2x, g2y, error - - -def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some - known (h,k) indexing. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. 
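For reference, the heart of the deleted fit_lattice_vectors is an intensity-weighted linear least-squares solve for the origin and both lattice vectors. A self-contained sketch of that computation, with made-up peak data:

import numpy as np
from numpy.linalg import lstsq

# toy indexed Bragg peaks: positions (qx, qy), intensities w, indices (h, k)
qx = np.array([12.0, 22.1, 32.2, 11.9, 42.0])
qy = np.array([5.0, 5.1, 5.2, 15.0, 5.3])
w = np.array([3.0, 2.0, 1.0, 2.5, 0.5])
h = np.array([1, 2, 3, 1, 4])
k = np.array([0, 0, 0, 1, 0])

M = np.vstack((np.ones_like(h), h, k)).T       # design matrix [1, h, k]
alpha = np.vstack((qx, qy)).T                  # measured positions
beta = lstsq(M * w[:, None], alpha * w[:, None], rcond=None)[0]
(x0, y0), g1, g2 = beta                        # origin row, then g1 and g2
error = np.sum(np.linalg.norm(alpha - M @ beta, axis=1) * w) / np.sum(w)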
- x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - - return g1g2_map - - -def fit_lattice_vectors_masked(braggpeaks, mask, x0=0, y0=0, minNumPeaks=5): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks corresponding - to a scan position for which mask==True. - - Args: - braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. - Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a - weighting factor when fitting), 'h','k' (indexing). May optionally also - contain 'index_mask' (bool), indicating which peaks have been successfully - indixed and should be used. 
- mask (boolean array): real space shaped (R_Nx,R_Ny); fit lattice vectors where - mask is True - x0 (float): x-coord of the origin - y0 (float): y-coord of the origin - minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks - which can be indexed, return None for all return parameters - - Returns: - (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays: - - * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice - * ``g1g2_map.get_slice('y0')`` y-coord of the origin - * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector - * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector - * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector - * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector - * ``g1g2_map.get_slice('error')`` the fit error - * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful - fits - """ - assert isinstance(braggpeaks, PointListArray) - assert np.all( - [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity")] - ) - - # Make RealSlice to contain outputs - slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") - g1g2_map = RealSlice( - data=np.zeros((braggpeaks.shape[0], braggpeaks.shape[1], 8)), - slicelabels=slicelabels, - name="g1g2_map", - ) - - # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): - if mask[Rx, Ry]: - braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) - qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( - braggpeaks_curr, x0, y0, minNumPeaks - ) - # Store data - if g1x is not None: - g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("y0").data[Rx, Ry] = qx0 - g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x - g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y - g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x - g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y - g1g2_map.get_slice("error").data[Rx, Ry] = error - g1g2_map.get_slice("mask").data[Rx, Ry] = 1 - return g1g2_map diff --git a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py deleted file mode 100644 index 4ac7939e7..000000000 --- a/py4DSTEM/process/latticevectors/index.py +++ /dev/null @@ -1,280 +0,0 @@ -# Functions for indexing the Bragg directions - -import numpy as np -from numpy.linalg import lstsq - -from emdfile import tqdmnd, PointList, PointListArray - - -def get_selected_lattice_vectors(gx, gy, i0, i1, i2): - """ - From a set of reciprocal lattice points (gx,gy), and indices in those arrays which - specify the center beam, the first basis lattice vector, and the second basis lattice - vector, computes and returns the lattice vectors g1 and g2. 
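Indexing in the deleted index_bragg_directions that follows reduces to solving alpha = beta . M for the (h, k) indices and rounding to integers. A compact sketch, assuming the lattice vectors g1 and g2 are already known (toy values):

import numpy as np
from numpy.linalg import lstsq

g1, g2 = (2.0, 0.1), (0.0, 1.9)            # lattice vectors (toy values)
gx = np.array([2.0, 0.0, 4.0, 2.0])        # peak coordinates relative
gy = np.array([0.1, 1.9, 0.2, 2.0])        # to the origin

beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]])
M = lstsq(beta, np.vstack([gx, gy]), rcond=None)[0].T
h, k = np.round(M).astype(int).T           # -> (1,0), (0,1), (2,0), (1,1)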
- - Args: - gx (1d array): the reciprocal lattice points x-coords - gy (1d array): the reciprocal lattice points y-coords - i0 (int): index in the (gx,gy) arrays specifying the center beam - i1 (int): index in the (gx,gy) arrays specifying the first basis lattice vector - i2 (int): index in the (gx,gy) arrays specifying the second basis lattice vector - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing - - * **g1**: *(2-tuple)* the first lattice vector, (g1x,g1y) - * **g2**: *(2-tuple)* the second lattice vector, (g2x,g2y) - """ - for i in (i0, i1, i2): - assert isinstance(i, (int, np.integer)) - g1x = gx[i1] - gx[i0] - g1y = gy[i1] - gy[i0] - g2x = gx[i2] - gx[i0] - g2y = gy[i2] - gy[i0] - return (g1x, g1y), (g2x, g2y) - - -def index_bragg_directions(x0, y0, gx, gy, g1, g2): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - The approach is to solve the matrix equation - ``alpha = beta * M`` - where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, - beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the - h,k indices. - - Args: - x0 (float): x-coord of origin - y0 (float): y-coord of origin - gx (1d array): x-coord of the reciprocal lattice vectors - gy (1d array): y-coord of the reciprocal lattice vectors - g1 (2-tuple of floats): g1x,g1y - g2 (2-tuple of floats): g2x,g2y - - Returns: - (3-tuple) A 3-tuple containing: - - * **h**: *(ndarray of ints)* first index of the bragg directions - * **k**: *(ndarray of ints)* second index of the bragg directions - * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the - indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y - coords 'h' and 'k' contain h and k. - """ - # Get beta, the matrix of lattice vectors - beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) - - # Get alpha, the matrix of measured bragg angles - alpha = np.vstack([gx - x0, gy - y0]) - - # Calculate M, the matrix of peak positions - M = lstsq(beta, alpha, rcond=None)[0].T - M = np.round(M).astype(int) - - # Get h,k - h = M[:, 0] - k = M[:, 1] - - # Store in a PointList - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - temp_array = np.zeros([], dtype=coords) - bragg_directions = PointList(data=temp_array) - bragg_directions.add_data_by_field((gx, gy, h, k)) - mask = np.zeros(bragg_directions["qx"].shape[0]) - mask[0] = 1 - bragg_directions.remove(mask) - - return h, k, bragg_directions - - -def generate_lattice(ux, uy, vx, vy, x0, y0, Q_Nx, Q_Ny, h_max=None, k_max=None): - """ - Returns a full reciprocal lattice stretching to the limits of the diffraction pattern - by making linear combinations of the lattice vectors up to (±h_max,±k_max). - - This can be useful when there are false peaks or missing peaks in the braggvectormap, - which can cause errors in the strain finding routines that rely on those peaks for - indexing. This allows us to create a reference lattice that has all combinations of - the lattice vectors all the way out to the edges of the frame, and excluding any - erroneous intermediate peaks. 
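The generation step itself is a meshgrid over (h, k) followed by a trim to the pattern bounds; a minimal sketch with illustrative values:

import numpy as np

ux, uy, vx, vy = 20.0, 0.0, 0.0, 20.0      # lattice vectors in pixels (toy)
x0, y0, Q_Nx, Q_Ny = 64.0, 64.0, 128, 128
h_max = k_max = 4

hh, kk = np.meshgrid(
    np.arange(-h_max, h_max + 1), np.arange(-k_max, k_max + 1)
)
qx = x0 + hh.ravel() * ux + kk.ravel() * vx
qy = y0 + hh.ravel() * uy + kk.ravel() * vy
keep = (qx >= 0) & (qx <= Q_Nx) & (qy >= 0) & (qy <= Q_Ny)
qx, qy = qx[keep], qy[keep]                # trimmed reference lattice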
- - Args: - ux (float): x-coord of the u lattice vector - uy (float): y-coord of the u lattice vector - vx (float): x-coord of the v lattice vector - vy (float): y-coord of the v lattice vector - x0 (float): x-coord of the lattice origin - y0 (float): y-coord of the lattice origin - Q_Nx (int): diffraction pattern size in the x-direction - Q_Ny (int): diffraction pattern size in the y-direction - h_max, k_max (int): maximal indices for generating the lattice (the lattive is - always trimmed to fit inside the pattern so you can overestimate these, or - leave unspecified and they will be automatically found) - - Returns: - (PointList): A 4-coordinate PointList, ('qx','qy','h','k'), containing points - corresponding to linear combinations of the u and v vectors, with associated - indices - """ - - # Matrix of lattice vectors - beta = np.array([[ux, uy], [vx, vy]]) - - # If no max index is specified, (over)estimate based on image size - if (h_max is None) or (k_max is None): - (y, x) = np.mgrid[0:Q_Ny, 0:Q_Nx] - x = x - x0 - y = y - y0 - h_max = np.max(np.ceil(np.abs((x / ux, y / uy)))) - k_max = np.max(np.ceil(np.abs((x / vx, y / vy)))) - - (hlist, klist) = np.meshgrid( - np.arange(-h_max, h_max + 1), np.arange(-k_max, k_max + 1) - ) - - M_ideal = np.vstack((hlist.ravel(), klist.ravel())).T - ideal_peaks = np.matmul(M_ideal, beta) - - coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] - - ideal_data = np.zeros(len(ideal_peaks[:, 0]), dtype=coords) - ideal_data["qx"] = ideal_peaks[:, 0] - ideal_data["qy"] = ideal_peaks[:, 1] - ideal_data["h"] = M_ideal[:, 0] - ideal_data["k"] = M_ideal[:, 1] - - ideal_lattice = PointList(data=ideal_data) - - # shift to the DP center - ideal_lattice.data["qx"] += x0 - ideal_lattice.data["qy"] += y0 - - # trim peaks outside the image - deletePeaks = ( - (ideal_lattice.data["qx"] > Q_Nx) - | (ideal_lattice.data["qx"] < 0) - | (ideal_lattice.data["qy"] > Q_Ny) - | (ideal_lattice.data["qy"] < 0) - ) - ideal_lattice.remove(deletePeaks) - - return ideal_lattice - - -def add_indices_to_braggvectors( - braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None -): - """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, - identify the indices for each peak in the PointListArray braggpeaks. - Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed - or not with the bool index_mask. If `mask` is specified, only the locations where - mask is True are indexed. - - Args: - braggpeaks (PointListArray): the braggpeaks to index. Must contain - the coordinates 'qx', 'qy', and 'intensity' - lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. - Must contain the coordinates 'qx', 'qy', 'h', and 'k' - maxPeakSpacing (float): Maximum distance from the ideal lattice points - to include a peak for indexing - qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList - relative to the `braggpeaks` PointListArray - mask (bool): Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - - Returns: - (PointListArray): The original braggpeaks pointlistarray, with new coordinates - 'h', 'k', containing the indices of each indexable peak. 
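Stripped of the PointList bookkeeping, the matching rule in the deleted function below is a nearest-neighbor test against the ideal lattice with a maxPeakSpacing cutoff; a one-peak sketch with illustrative numbers:

import numpy as np

lattice_qx = np.array([10.0, 20.0, 30.0])  # ideal lattice points
lattice_qy = np.zeros(3)
lattice_h = np.array([1, 2, 3])
lattice_k = np.zeros(3, dtype=int)
maxPeakSpacing = 1.5

peak_qx, peak_qy = 19.6, 0.3               # one measured peak
r2 = (peak_qx - lattice_qx) ** 2 + (peak_qy - lattice_qy) ** 2
ind = np.argmin(r2)
if r2[ind] <= maxPeakSpacing**2:
    h, k = lattice_h[ind], lattice_k[ind]  # indexed as (2, 0) here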
- """ - - # assert isinstance(braggpeaks,BraggVectors) - # assert isinstance(lattice, PointList) - # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) - - if mask is None: - mask = np.ones(braggpeaks.Rshape, dtype=bool) - - assert ( - mask.shape == braggpeaks.Rshape - ), "mask must have same shape as pointlistarray" - assert mask.dtype == bool, "mask must be boolean" - - coords = [ - ("qx", float), - ("qy", float), - ("intensity", float), - ("h", int), - ("k", int), - ] - - indexed_braggpeaks = PointListArray( - dtype=coords, - shape=braggpeaks.Rshape, - ) - - # loop over all the scan positions - for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): - if mask[Rx, Ry]: - pl = braggpeaks.cal[Rx, Ry] - for i in range(pl.data.shape[0]): - r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( - pl.data["qy"][i] - lattice.data["qy"] + qy_shift - ) ** 2 - ind = np.argmin(r2) - if r2[ind] <= maxPeakSpacing**2: - indexed_braggpeaks[Rx, Ry].add_data_by_field( - ( - pl.data["qx"][i], - pl.data["qy"][i], - pl.data["intensity"][i], - lattice.data["h"][ind], - lattice.data["k"][ind], - ) - ) - - return indexed_braggpeaks - - -def bragg_vector_intensity_map_by_index(braggpeaks, h, k, symmetric=False): - """ - Returns a correlation intensity map for an indexed (h,k) Bragg vector - Used to obtain a darkfield image corresponding to the (h,k) reflection - or a bightfield image when h=k=0 - - Args: - braggpeaks (PointListArray): must contain the coordinates 'h','k', and - 'intensity' - h, k (int): indices for the reflection to generate an intensity map from - symmetric (bool): if set to true, returns sum of intensity of (h,k), (-h,k), - (h,-k), (-h,-k) - - Returns: - (numpy array): a map of the intensity of the (h,k) Bragg vector correlation. - Same shape as the pointlistarray. - """ - assert isinstance(braggpeaks, PointListArray), "braggpeaks must be a PointListArray" - assert np.all([name in braggpeaks.dtype.names for name in ("h", "k", "intensity")]) - intensity_map = np.zeros(braggpeaks.shape, dtype=float) - - for Rx in range(braggpeaks.shape[0]): - for Ry in range(braggpeaks.shape[1]): - pl = braggpeaks.get_pointlist(Rx, Ry) - if pl.length > 0: - if symmetric: - matches = np.logical_and( - np.abs(pl.data["h"]) == np.abs(h), - np.abs(pl.data["k"]) == np.abs(k), - ) - else: - matches = np.logical_and(pl.data["h"] == h, pl.data["k"] == k) - - if len(matches) > 0: - intensity_map[Rx, Ry] = np.sum(pl.data["intensity"][matches]) - - return intensity_map diff --git a/py4DSTEM/process/latticevectors/initialguess.py b/py4DSTEM/process/latticevectors/initialguess.py deleted file mode 100644 index d8054143f..000000000 --- a/py4DSTEM/process/latticevectors/initialguess.py +++ /dev/null @@ -1,229 +0,0 @@ -# Obtain an initial guess at the lattice vectors - -import numpy as np -from scipy.ndimage import gaussian_filter -from skimage.transform import radon - -from py4DSTEM.process.utils import get_maxima_1D - - -def get_radon_scores( - braggvectormap, - mask=None, - N_angles=200, - sigma=2, - minSpacing=2, - minRelativeIntensity=0.05, -): - """ - Calculates a score function, score(angle), representing the likelihood that angle is - a principle lattice direction of the lattice in braggvectormap. - - The procedure is as follows: - If mask is not None, ignore any data in braggvectormap where mask is False. Useful - for removing the unscattered beam, which can dominate the results. - Take the Radon transform of the (masked) Bragg vector map. 
- For each angle, get the corresponding slice of the sinogram, and calculate its score. - If we let R_theta(r) be the sinogram slice at angle theta, and where r is the - sinogram position coordinate, then the score of the slice is given by - score(theta) = sum_i(R_theta(r_i)) / N_i - Here, r_i are the positions r of all local maxima in R_theta(r), and N_i is the - number of such maxima. Thus the score is large when there are few maxima which are - high intensity. - - Args: - braggvectormap (ndarray): the Bragg vector map - mask (ndarray of bools): ignore data in braggvectormap wherever mask==False - N_angles (int): the number of angles at which to calculate the score - sigma (float): smoothing parameter for local maximum identification - minSpacing (float): if two maxima are found in a radon slice closer than - minSpacing, the dimmer of the two is removed - minRelativeIntensity (float): maxima in each radon slice dimmer than - minRelativeIntensity compared to the most intense maximum are removed - - Returns: - (3-tuple) A 3-tuple containing: - - * **scores**: *(ndarray, len N_angles, floats)* the scores for each angle - * **thetas**: *(ndarray, len N_angles, floats)* the angles, in radians - * **sinogram**: *(ndarray)* the radon transform of braggvectormap*mask - """ - # Get sinogram - thetas = np.linspace(0, 180, N_angles) - if mask is not None: - sinogram = radon(braggvectormap * mask, theta=thetas, circle=False) - else: - sinogram = radon(braggvectormap, theta=thetas, circle=False) - - # Get scores - N_maxima = np.empty_like(thetas) - total_intensity = np.empty_like(thetas) - for i in range(len(thetas)): - theta = thetas[i] - - # Get radon transform slice - ind = np.argmin(np.abs(thetas - theta)) - sinogram_theta = sinogram[:, ind] - sinogram_theta = gaussian_filter(sinogram_theta, 2) - - # Get maxima - maxima = get_maxima_1D(sinogram_theta, sigma, minSpacing, minRelativeIntensity) - - # Calculate metrics - N_maxima[i] = len(maxima) - total_intensity[i] = np.sum(sinogram_theta[maxima]) - scores = total_intensity / N_maxima - - return scores, np.radians(thetas), sinogram - - -def get_lattice_directions_from_scores( - thetas, scores, sigma=2, minSpacing=2, minRelativeIntensity=0.05, index1=0, index2=0 -): - """ - Get the lattice directions from the scores of the radon transform slices. 
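A minimal sketch of that scoring loop, using skimage's radon and scipy's find_peaks as a stand-in for py4DSTEM's get_maxima_1D (toy input):

import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks
from skimage.transform import radon

bvm = np.zeros((64, 64))
bvm[16::16, 16::16] = 1.0                  # toy Bragg vector map
thetas = np.linspace(0, 180, 90)
sinogram = radon(bvm, theta=thetas, circle=False)

scores = np.empty_like(thetas)
for i in range(len(thetas)):
    slice_i = gaussian_filter1d(sinogram[:, i], 2)
    maxima, _ = find_peaks(slice_i)
    scores[i] = slice_i[maxima].sum() / max(len(maxima), 1)
# a high score means few, bright maxima: a likely lattice direction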
- - Args: - thetas (ndarray): the angles, in radians - scores (ndarray): the scores - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - index1 (int): specifies which local maximum to use for the first lattice - direction, in order of maximum intensity - index2 (int): specifies the local maximum for the second lattice direction - - Returns: - (2-tuple) A 2-tuple containing: - - * **theta1**: *(float)* the first lattice direction, in radians - * **theta2**: *(float)* the second lattice direction, in radians - """ - assert len(thetas) == len(scores), "Size of thetas and scores must match" - - # Get first lattice direction - maxima1 = get_maxima_1D( - scores, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max1 = thetas[maxima1] - scores_max1 = scores[maxima1] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max1), dtype=dtype) - ar_structured["thetas"] = thetas_max1 - ar_structured["scores"] = scores_max1 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta1 = ar_structured["thetas"][index1] # Get direction 1 - - # Apply sin**2 damping - scores_damped = scores * np.sin(thetas - theta1) ** 2 - - # Get second lattice direction - maxima2 = get_maxima_1D( - scores_damped, sigma, minSpacing, minRelativeIntensity - ) # Get maxima - thetas_max2 = thetas[maxima2] - scores_max2 = scores[maxima2] - dtype = np.dtype( - [("thetas", thetas.dtype), ("scores", scores.dtype)] - ) # Sort by intensity - ar_structured = np.empty(len(thetas_max2), dtype=dtype) - ar_structured["thetas"] = thetas_max2 - ar_structured["scores"] = scores_max2 - ar_structured = np.sort(ar_structured, order="scores")[::-1] - theta2 = ar_structured["thetas"][index2] # Get direction 2 - - return theta1, theta2 - - -def get_lattice_vector_lengths( - u_theta, - v_theta, - thetas, - sinogram, - spacing_thresh=1.5, - sigma=1, - minSpacing=2, - minRelativeIntensity=0.1, -): - """ - Gets the lengths of the two lattice vectors from their angles and the sinogram. - - First, finds the spacing between peaks in the sinogram slices projected down the u- - and v- directions, u_proj and v_proj. Then, finds the lengths by taking:: - - |u| = v_proj/sin(u_theta-v_theta) - |v| = u_proj/sin(u_theta-v_theta) - - The most important thresholds for this function are spacing_thresh, which discards - any detected spacing between adjacent radon projection peaks which deviate from the - median spacing by more than this fraction, and minRelativeIntensity, which discards - detected maxima (from which spacings are then calculated) below this threshold - relative to the brightest maximum. 
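As a concrete check of the formulas above: with u_proj = 18 px, v_proj = 20 px, and 60 degrees between the two directions, |u| = 20 / sin(60°) ≈ 23.1 px and |v| = 18 / sin(60°) ≈ 20.8 px. Note the deliberate cross-over: each length comes from the other vector's projected spacing.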
- - Args: - u_theta (float): the angle of u, in radians - v_theta (float): the angle of v, in radians - thetas (ndarray): the angles corresponding to the sinogram - sinogram (ndarray): the sinogram - spacing_thresh (float): ignores spacings which are greater than spacing_thresh - times the median spacing - sigma (float): gaussian blur for local maxima identification - minSpacing (float): minimum spacing for local maxima identification - minRelativeIntensity (float): minumum intensity, relative to the brightest - maximum, for local maxima identification - - Returns: - (2-tuple) A 2-tuple containing: - - * **u_length**: *(float)* the length of u, in pixels - * **v_length**: *(float)* the length of v, in pixels - """ - assert ( - len(thetas) == sinogram.shape[1] - ), "thetas must corresponding to the number of sinogram projection directions." - - # Get u projected spacing - ind = np.argmin(np.abs(thetas - u_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - u_projected_spacing = np.mean(spacings) - - # Get v projected spacing - ind = np.argmin(np.abs(thetas - v_theta)) - sinogram_slice = sinogram[:, ind] - maxima = get_maxima_1D(sinogram_slice, sigma, minSpacing, minRelativeIntensity) - spacings = np.sort(np.arange(sinogram_slice.shape[0])[maxima]) - spacings = spacings[1:] - spacings[:-1] - mask = ( - np.array( - [ - max(i, np.median(spacings)) / min(i, np.median(spacings)) - for i in spacings - ] - ) - < spacing_thresh - ) - spacings = spacings[mask] - v_projected_spacing = np.mean(spacings) - - # Get u and v lengths - sin_uv = np.sin(np.abs(u_theta - v_theta)) - u_length = v_projected_spacing / sin_uv - v_length = u_projected_spacing / sin_uv - - return u_length, v_length diff --git a/py4DSTEM/process/latticevectors/strain.py b/py4DSTEM/process/latticevectors/strain.py deleted file mode 100644 index 6f4000449..000000000 --- a/py4DSTEM/process/latticevectors/strain.py +++ /dev/null @@ -1,231 +0,0 @@ -# Functions for calculating strain from lattice vector maps - -import numpy as np -from numpy.linalg import lstsq - -from py4DSTEM.data import RealSlice - - -def get_reference_g1g2(g1g2_map, mask): - """ - Gets a pair of reference lattice vectors from a region of real space specified by - mask. Takes the median of the lattice vectors in g1g2_map within the specified - region. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. 
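At each scan position, get_strain_from_reference_g1g2 (below) solves for the transformation matrix between the reference and local lattices, then reads the infinitesimal strain components off it. A single-pixel sketch with toy numbers:

import numpy as np
from numpy.linalg import lstsq

M = np.array([[2.0, 0.0], [0.0, 2.0]])           # reference (g1; g2) rows
alpha = np.array([[2.02, 0.01], [-0.01, 1.98]])  # measured lattice, one pixel

beta = lstsq(M, alpha, rcond=None)[0].T
e_xx = 1 - beta[0, 0]                    # -0.01
e_yy = 1 - beta[1, 1]                    # +0.01
e_xy = -(beta[0, 1] + beta[1, 0]) / 2    # 0.0 (symmetrized shear)
theta = (beta[0, 1] - beta[1, 0]) / 2    # -0.005 rad lattice rotation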
- mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever - mask==True - - Returns: - (2-tuple of 2-tuples) A 2-tuple containing: - - * **g1**: *(2-tuple)* first reference lattice vector (x,y) - * **g2**: *(2-tuple)* second reference lattice vector (x,y) - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] - ) - assert mask.dtype == bool - g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) - g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) - g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) - g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) - return (g1x, g1y), (g2x, g2y) - - -def get_strain_from_reference_g1g2(g1g2_map, g1, g2): - """ - Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map - g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real and - diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - g1 (2-tuple): first reference lattice vector (x,y) - g2 (2-tuple): second reference lattice vector (x,y) - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - - # Get RealSlice for output storage - R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape - strain_map = RealSlice( - data=np.zeros((5, R_Nx, R_Ny)), - slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), - name="strain_map", - ) - - # Get reference lattice matrix - g1x, g1y = g1 - g2x, g2y = g2 - M = np.array([[g1x, g1y], [g2x, g2y]]) - - for Rx in range(R_Nx): - for Ry in range(R_Ny): - # Get lattice vectors for DP at Rx,Ry - alpha = np.array( - [ - [ - g1g2_map.get_slice("g1x").data[Rx, Ry], - g1g2_map.get_slice("g1y").data[Rx, Ry], - ], - [ - g1g2_map.get_slice("g2x").data[Rx, Ry], - g1g2_map.get_slice("g2y").data[Rx, Ry], - ], - ] - ) - # Get transformation matrix - beta = lstsq(M, alpha, rcond=None)[0].T - - # Get the infinitesimal strain matrix - strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] - strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] - strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 - strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 - strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ - Rx, Ry - ] - return strain_map - - -def get_strain_from_reference_region(g1g2_map, mask): - """ - Gets a strain map from the reference region of real space 
specified by mask and the - lattice vector map g1g2_map. - - Note that this function will return the strain map oriented with respect to the x/y - axes of diffraction space - to rotate the coordinate system, use - get_rotated_strain_map(). Calibration of the rotational misalignment between real - and diffraction space may also be necessary. - - Args: - g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data - under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for - fit_lattice_vectors_all_DPs() for more information. - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - Returns: - (RealSlice) the strain map; contains the elements of the infinitessimal strain - matrix, in the following 5 arrays: - - * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect - to x - * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect - to y - * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect - to y - * ``strain_map.get_slice('theta')``: rotation of lattice with respect to - reference - * ``strain_map.get_slice('mask')``: 0/False indicates unknown values - - Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical - """ - assert isinstance(g1g2_map, RealSlice) - assert np.all( - [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] - ) - assert mask.dtype == bool - - g1, g2 = get_reference_g1g2(g1g2_map, mask) - strain_map = get_strain_from_reference_g1g2(g1g2_map, g1, g2) - return strain_map - - -def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): - """ - Starting from a strain map defined with respect to the xy coordinate system of - diffraction space, i.e. where exx and eyy are the compression/tension along the Qx - and Qy directions, respectively, get a strain map defined with respect to some other - right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, - xaxis_y). 
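This is the standard in-plane rotation of a symmetric rank-2 tensor. With c = cos(θ) and s = sin(θ), where θ = -arctan2(xaxis_y, xaxis_x), the components transform as e'_xx = c²·e_xx - 2cs·e_xy + s²·e_yy, e'_yy = s²·e_xx + 2cs·e_xy + c²·e_yy, and e'_xy = cs·(e_xx - e_yy) + (c² - s²)·e_xy, which is exactly the array arithmetic in the function body below.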
- - Args: - xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector - along the new x-axis - unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the - infinitessimal strain matrix elements, stored at - * unrotated_strain_map.get_slice('e_xx') - * unrotated_strain_map.get_slice('e_xy') - * unrotated_strain_map.get_slice('e_yy') - * unrotated_strain_map.get_slice('theta') - - Returns: - (RealSlice) the rotated counterpart to unrotated_strain_map, with the - rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate - system - """ - assert isinstance(unrotated_strain_map, RealSlice) - assert np.all( - [ - key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] - for key in unrotated_strain_map.slicelabels - ] - ) - theta = -np.arctan2(xaxis_y, xaxis_x) - cost = np.cos(theta) - sint = np.sin(theta) - cost2 = cost**2 - sint2 = sint**2 - - Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape - rotated_strain_map = RealSlice( - data=np.zeros((5, Rx, Ry)), - slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], - name=unrotated_strain_map.name + "_rotated".format(np.degrees(theta)), - ) - - rotated_strain_map.data[0, :, :] = ( - cost2 * unrotated_strain_map.get_slice("e_xx").data - - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + sint2 * unrotated_strain_map.get_slice("e_yy").data - ) - rotated_strain_map.data[1, :, :] = ( - cost - * sint - * ( - unrotated_strain_map.get_slice("e_xx").data - - unrotated_strain_map.get_slice("e_yy").data - ) - + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data - ) - rotated_strain_map.data[2, :, :] = ( - sint2 * unrotated_strain_map.get_slice("e_xx").data - + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data - + cost2 * unrotated_strain_map.get_slice("e_yy").data - ) - if flip_theta == True: - rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data - else: - rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data - rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data - return rotated_strain_map diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index 178079349..1005a619d 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -3,28 +3,14 @@ _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction -from py4DSTEM.process.phase.iterative_mixedstate_ptychography import ( - MixedstatePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_multislice_ptychography import ( - MultislicePtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import ( - OverlapMagneticTomographicReconstruction, -) -from py4DSTEM.process.phase.iterative_overlap_tomography import ( - OverlapTomographicReconstruction, -) +from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction +from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction +from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction -from 
py4DSTEM.process.phase.iterative_simultaneous_ptychography import ( - SimultaneousPtychographicReconstruction, -) -from py4DSTEM.process.phase.iterative_singleslice_ptychography import ( - SingleslicePtychographicReconstruction, -) -from py4DSTEM.process.phase.parameter_optimize import ( - OptimizationParameter, - PtychographyOptimizer, -) +from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction +from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction +from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer # fmt: on diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index 6d7967550..a0ed485ba 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -8,12 +8,12 @@ import numpy as np from matplotlib.gridspec import GridSpec from mpl_toolkits.axes_grid1 import ImageGrid -from py4DSTEM.visualize import show, show_complex +from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex from scipy.ndimage import rotate try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd @@ -23,7 +23,11 @@ from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( PtychographicConstraints, ) -from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases +from py4DSTEM.process.phase.utils import ( + AffineTransform, + generate_batches, + polar_aliases, +) from py4DSTEM.process.utils import ( electron_wavelength_angstrom, fourier_resample, @@ -56,6 +60,53 @@ def attach_datacube(self, datacube: DataCube): self._datacube = datacube return self + def reinitialize_parameters(self, device: str = None, verbose: bool = None): + """ + Reinitializes common parameters. This is useful when loading a previously-saved + reconstruction (which set device='cpu' and verbose=True for compatibility) , + using different initialization parameters. 
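A usage sketch for the reinitialize_parameters method introduced here, assuming a ptychographic reconstruction previously written to disk (the file name is hypothetical):

import py4DSTEM

# Loading a saved reconstruction forces device="cpu" and verbose=True for
# compatibility; reinitialize_parameters restores the preferred settings
# and returns self, so calls can be chained.
ptycho = py4DSTEM.read("ptycho_reconstruction.h5")   # hypothetical path
ptycho = ptycho.reinitialize_parameters(device="gpu", verbose=False)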
+ + Parameters + ---------- + device: str, optional + If not None, imports and assigns appropriate device modules + verbose: bool, optional + If not None, sets the verbosity to verbose + + Returns + -------- + self: PhaseReconstruction + Self to enable chaining + """ + + if device is not None: + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self._device = device + + if verbose is not None: + self._verbose = verbose + + return self + def set_save_defaults( self, save_datacube: bool = False, @@ -278,7 +329,9 @@ def _extract_intensities_and_calibrations_from_datacube( """ # Copies intensities to device casting to float32 - intensities = datacube.data + xp = self._xp + + intensities = xp.asarray(datacube.data, dtype=xp.float32) self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -295,13 +348,14 @@ def _extract_intensities_and_calibrations_from_datacube( if require_calibrations: raise ValueError("Real-space calibrations must be given in 'A'") - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "real-space calibrations in 'A'" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "real-space calibrations in 'A'" + ), + UserWarning, + ) self._scan_sampling = (1.0, 1.0) self._scan_units = ("pixels",) * 2 @@ -359,13 +413,14 @@ def _extract_intensities_and_calibrations_from_datacube( "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" ) - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "appropriate reciprocal-space calibrations" - ), - UserWarning, - ) + if self._verbose: + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "appropriate reciprocal-space calibrations" + ), + UserWarning, + ) self._angular_sampling = (1.0, 1.0) self._angular_units = ("pixels",) * 2 @@ -448,8 +503,6 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - intensities = xp.asarray(intensities, dtype=xp.float32) - # for ptycho if com_measured: com_measured_x, com_measured_y = com_measured @@ -484,9 +537,14 @@ def _calculate_intensities_center_of_mass( ) if com_shifts is None: + com_measured_x_np = asnumpy(com_measured_x) + com_measured_y_np = asnumpy(com_measured_y) + finite_mask = np.isfinite(com_measured_x_np) + com_shifts = fit_origin( - (asnumpy(com_measured_x), asnumpy(com_measured_y)), + (com_measured_x_np, com_measured_y_np), fitfunction=fit_function, + mask=finite_mask, ) # Fit function to center of mass @@ -494,12 +552,12 @@ def _calculate_intensities_center_of_mass( com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) # fix CoM units - com_normalized_x = (com_measured_x - com_fitted_x) * self._reciprocal_sampling[ - 0 - ] - com_normalized_y = (com_measured_y - com_fitted_y) * self._reciprocal_sampling[ - 1 - ] + com_normalized_x = ( + xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0] + ) + com_normalized_y = ( + xp.nan_to_num(com_measured_y 
- com_fitted_y) * self._reciprocal_sampling[1] + ) return ( com_measured_x, @@ -1077,6 +1135,8 @@ def _normalize_diffraction_intensities( diffraction_intensities, com_fitted_x, com_fitted_y, + crop_patterns, + positions_mask, ): """ Fix diffraction intensities CoM, shift to origin, and take square root @@ -1089,6 +1149,11 @@ def _normalize_diffraction_intensities( Best fit horizontal center of mass gradient com_fitted_y: (Rx,Ry) xp.ndarray Best fit vertical center of mass gradient + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns + when centering + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1101,16 +1166,59 @@ def _normalize_diffraction_intensities( xp = self._xp mean_intensity = 0 - amplitudes = xp.zeros_like(diffraction_intensities) - region_of_interest_shape = diffraction_intensities.shape[-2:] + diffraction_intensities = self._asnumpy(diffraction_intensities) + if positions_mask is not None: + number_of_patterns = np.count_nonzero(positions_mask.ravel()) + else: + number_of_patterns = np.prod(diffraction_intensities.shape[:2]) + + if crop_patterns: + crop_x = int( + np.minimum( + diffraction_intensities.shape[2] - com_fitted_x.max(), + com_fitted_x.min(), + ) + ) + crop_y = int( + np.minimum( + diffraction_intensities.shape[3] - com_fitted_y.max(), + com_fitted_y.min(), + ) + ) + + crop_w = np.minimum(crop_y, crop_x) + region_of_interest_shape = (crop_w * 2, crop_w * 2) + amplitudes = np.zeros( + ( + number_of_patterns, + crop_w * 2, + crop_w * 2, + ), + dtype=np.float32, + ) + + crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_) + crop_mask[:crop_w, :crop_w] = True + crop_mask[-crop_w:, :crop_w] = True + crop_mask[:crop_w:, -crop_w:] = True + crop_mask[-crop_w:, -crop_w:] = True + self._crop_mask = crop_mask + + else: + region_of_interest_shape = diffraction_intensities.shape[-2:] + amplitudes = np.zeros( + (number_of_patterns,) + region_of_interest_shape, dtype=np.float32 + ) com_fitted_x = self._asnumpy(com_fitted_x) com_fitted_y = self._asnumpy(com_fitted_y) - diffraction_intensities = self._asnumpy(diffraction_intensities) - amplitudes = self._asnumpy(amplitudes) + counter = 0 for rx in range(diffraction_intensities.shape[0]): for ry in range(diffraction_intensities.shape[1]): + if positions_mask is not None: + if not positions_mask[rx, ry]: + continue intensities = get_shifted_ar( diffraction_intensities[rx, ry], -com_fitted_x[rx, ry], @@ -1119,16 +1227,71 @@ def _normalize_diffraction_intensities( device="cpu", ) - mean_intensity += np.sum(intensities) - amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0)) + if crop_patterns: + intensities = intensities[crop_mask].reshape( + region_of_interest_shape + ) - amplitudes = xp.asarray(amplitudes, dtype=xp.float32) + mean_intensity += np.sum(intensities) + amplitudes[counter] = np.sqrt(np.maximum(intensities, 0)) + counter += 1 - amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape) + amplitudes = xp.asarray(amplitudes) mean_intensity /= amplitudes.shape[0] return amplitudes, mean_intensity + def show_complex_CoM( + self, + com=None, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot complex-valued CoM image + + Parameters + ---------- + + com = (CoM_x, CoM_y) tuple + If None is specified, uses (self.com_x, self.com_y) instead + cbar: bool, optional + if True, adds colorbar + scalebar: bool, optional + if True, adds 
scalebar to probe + pixelunits: str, optional + units for scalebar, default is A + pixelsize: float, optional + default is scan sampling + """ + + if com is None: + com = (self.com_x, self.com_y) + + if pixelsize is None: + pixelsize = self._scan_sampling[0] + if pixelunits is None: + pixelunits = self._scan_units[0] + + figsize = kwargs.pop("figsize", (6, 6)) + fig, ax = plt.subplots(figsize=figsize) + + complex_com = com[0] + 1j * com[1] + + show_complex( + complex_com, + cbar=cbar, + figax=(fig, ax), + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + **kwargs, + ) + class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints): """ @@ -1185,6 +1348,14 @@ def to_h5(self, group): data=metadata, ) + # saving multiple None positions_mask fix + if self._positions_mask is None: + positions_mask = None + elif self._positions_mask[0] is None: + positions_mask = None + else: + positions_mask = self._positions_mask + # preprocessing metadata self.metadata = Metadata( name="preprocess_metadata", @@ -1196,6 +1367,7 @@ def to_h5(self, group): "num_diffraction_patterns": self._num_diffraction_patterns, "sampling": self.sampling, "angular_sampling": self.angular_sampling, + "positions_mask": positions_mask, }, ) @@ -1309,10 +1481,10 @@ def _get_constructor_args(cls, group): "object_type": instance_md["object_type"], "semiangle_cutoff": instance_md["semiangle_cutoff"], "rolloff": instance_md["rolloff"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], "polar_parameters": polar_params, + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } class_specific_kwargs = {} @@ -1338,6 +1510,7 @@ def _populate_instance(self, group): self._angular_sampling = preprocess_md["angular_sampling"] self._region_of_interest_shape = preprocess_md["region_of_interest_shape"] self._num_diffraction_patterns = preprocess_md["num_diffraction_patterns"] + self._positions_mask = preprocess_md["positions_mask"] # Reconstruction metadata reconstruction_md = _read_metadata(group, "reconstruction_metadata") @@ -1389,7 +1562,9 @@ def _set_polar_parameters(self, parameters: dict): else: raise ValueError("{} not a recognized parameter".format(symbol)) - def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): + def _calculate_scan_positions_in_pixels( + self, positions: np.ndarray, positions_mask + ): """ Method to compute the initial guess of scan positions in pixels. @@ -1398,6 +1573,8 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions: (J,2) np.ndarray or None Input probe positions in Å. If None, a raster scan using experimental parameters is constructed. 
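The positions_mask threaded through the preprocessing methods lets callers drop scan positions entirely. A sketch of building such a mask, and of the filtering the method below applies to the raster grid (toy shapes):

import numpy as np

R_Nx, R_Ny = 64, 64
positions_mask = np.zeros((R_Nx, R_Ny), dtype=bool)
positions_mask[8:-8, 8:-8] = True            # keep only the scan interior

x, y = np.meshgrid(
    np.arange(R_Nx, dtype=float), np.arange(R_Ny, dtype=float), indexing="ij"
)
x, y = x[positions_mask], y[positions_mask]  # (N,) retained positions
print(x.size)                                # 2304 = 48 * 48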
+ positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction Returns ------- @@ -1429,7 +1606,9 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): x = (x - np.ptp(x) / 2) / self.sampling[0] y = (y - np.ptp(y) / 2) / self.sampling[1] x, y = np.meshgrid(x, y, indexing="ij") - + if positions_mask is not None: + x = x[positions_mask] + y = y[positions_mask] else: positions -= np.mean(positions, axis=0) x = positions[:, 0] / self.sampling[1] @@ -1975,6 +2154,7 @@ def plot_position_correction( def _return_fourier_probe( self, probe=None, + remove_initial_probe_aberrations=False, ): """ Returns complex fourier probe shifted to center of array from @@ -1984,6 +2164,8 @@ def _return_fourier_probe( ---------- probe: complex array, optional if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe Returns ------- @@ -1997,11 +2179,17 @@ def _return_fourier_probe( else: probe = xp.asarray(probe, dtype=xp.complex64) - return xp.fft.fftshift(xp.fft.fft2(probe), axes=(-2, -1)) + fourier_probe = xp.fft.fft2(probe) + + if remove_initial_probe_aberrations: + fourier_probe *= xp.conjugate(self._known_aberrations_array) + + return xp.fft.fftshift(fourier_probe, axes=(-2, -1)) def _return_fourier_probe_from_centered_probe( self, probe=None, + remove_initial_probe_aberrations=False, ): """ Returns complex fourier probe shifted to center of array from @@ -2011,6 +2199,8 @@ def _return_fourier_probe_from_centered_probe( ---------- probe: complex array, optional if None is specified, uses self._probe + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe Returns ------- @@ -2018,7 +2208,10 @@ def _return_fourier_probe_from_centered_probe( Fourier-transformed and center-shifted probe. 
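# Editor's note (illustrative sketch, not part of the original diff): the new
# remove_initial_probe_aberrations flag above multiplies the Fourier probe by
# the conjugate of the known aberration surface; for a pure-phase aberration
# exp(i*chi) this cancels it exactly:
import numpy as np

n = 128
kx = np.fft.fftfreq(n)[:, None]
ky = np.fft.fftfreq(n)[None, :]
chi = 1e3 * (kx**2 + ky**2)            # toy defocus-like phase surface
aberrations = np.exp(1j * chi)         # stand-in for _known_aberrations_array

probe = np.zeros((n, n), dtype=complex)
probe[0, 0] = 1.0                      # toy corner-centered probe
fourier_probe = np.fft.fft2(probe) * aberrations  # aberrated Fourier probe

residual = fourier_probe * np.conjugate(aberrations)
assert np.allclose(residual, np.fft.fft2(probe))  # aberrations removed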
""" xp = self._xp - return self._return_fourier_probe(xp.fft.ifftshift(probe, axes=(-2, -1))) + return self._return_fourier_probe( + xp.fft.ifftshift(probe, axes=(-2, -1)), + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) def _return_centered_probe( self, @@ -2071,9 +2264,247 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(asnumpy(obj)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped) + else: + projected_cropped_potential = self.object_cropped + + return projected_cropped_potential + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + + if errors is None: + errors = self._return_self_consistency_errors(max_batch_size=max_batch_size) + + if projected_cropped_potential is None: + projected_cropped_potential = self._return_projected_cropped_potential() + + if kde_sigma is None: + kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0] + + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + ## Kernel Density Estimation + + # rotated basis + angle = ( + self._rotation_best_rad + if self._rotation_best_transpose + else -self._rotation_best_rad + ) + + tf = AffineTransform(angle=angle) + rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp) + + padding = xp.min(rotated_points, axis=0).astype("int") + + # bilinear sampling + pixel_output = np.array(projected_cropped_potential.shape) + asnumpy( + 2 * padding + ) + pixel_size = pixel_output.prod() + + xa = rotated_points[:, 0] + ya = rotated_points[:, 1] + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = 
xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(xp.asarray(errors), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap") + pix_count[pix_count == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap") + pix_output /= pix_count + pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]] + pix_output, _, _ = return_scaled_histogram_ordering( + pix_output.get(), normalize=True + ) + + ## Visualization + if plot_histogram: + spec = GridSpec( + ncols=1, + nrows=2, + height_ratios=[1, 4], + hspace=0.15, + ) + auto_figsize = (4, 5.25) + else: + spec = GridSpec( + ncols=1, + nrows=1, + ) + auto_figsize = (4, 4) + + figsize = kwargs.pop("figsize", auto_figsize) + + fig = plt.figure(figsize=figsize) + + if plot_histogram: + ax_hist = fig.add_subplot(spec[0]) + + counts, bins = np.histogram(errors, bins=50) + ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5) + ax_hist.set_ylabel("Counts") + ax_hist.set_xlabel("Normalized Squared Error") + + ax = fig.add_subplot(spec[-1]) + + cmap = kwargs.pop("cmap", "magma") + vmin = kwargs.pop("vmin", None) + vmax = kwargs.pop("vmax", None) + + projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + ) + + extent = [ + 0, + self.sampling[1] * projected_cropped_potential.shape[1], + self.sampling[0] * projected_cropped_potential.shape[0], + 0, + ] + + ax.imshow( + projected_cropped_potential, + vmin=vmin, + vmax=vmax, + extent=extent, + alpha=1 - pix_output, + cmap=cmap, + **kwargs, + ) + + if plot_contours: + aligned_points = asnumpy(rotated_points - padding) + aligned_points[:, 0] *= self.sampling[0] + aligned_points[:, 1] *= self.sampling[1] + + ax.tricontour( + aligned_points[:, 1], + aligned_points[:, 0], + errors, + colors="grey", + levels=5, + # linestyles='dashed', + linewidths=0.5, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + ax.set_xlim((extent[0], extent[1])) + ax.set_ylim((extent[2], extent[3])) + ax.xaxis.set_ticks_position("bottom") + + spec.tight_layout(fig) + def show_fourier_probe( self, probe=None, + remove_initial_probe_aberrations=False, cbar=True, scalebar=True, pixelsize=None, @@ -2087,6 +2518,8 @@ def show_fourier_probe( ---------- probe: complex array, optional if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe cbar: bool, optional if True, adds colorbar scalebar: bool, optional @@ -2098,10 +2531,11 @@ def show_fourier_probe( """ asnumpy = self._asnumpy - if probe is None: - probe = self.probe_fourier - else: - probe = asnumpy(self._return_fourier_probe(probe)) + probe = asnumpy( + self._return_fourier_probe( + probe, remove_initial_probe_aberrations=remove_initial_probe_aberrations + ) + ) if pixelsize is None: pixelsize = self._reciprocal_sampling[1] @@ -2109,6 +2543,7 @@ def show_fourier_probe( pixelunits = r"$\AA^{-1}$" figsize = kwargs.pop("figsize", (6, 6)) + chroma_boost = kwargs.pop("chroma_boost", 1) fig, ax = plt.subplots(figsize=figsize) show_complex( @@ -2119,6 +2554,7 @@ def show_fourier_probe( pixelsize=pixelsize, 
pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) @@ -2138,22 +2574,16 @@ def show_object_fft(self, obj=None, **kwargs): figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -2166,6 +2596,19 @@ def probe_fourier(self): asnumpy = self._asnumpy return asnumpy(self._return_fourier_probe(self._probe)) + @property + def probe_fourier_residual(self): + """Current probe estimate in Fourier space""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy( + self._return_fourier_probe( + self._probe, remove_initial_probe_aberrations=True + ) + ) + @property def probe_centered(self): """Current probe estimate shifted to the center""" @@ -2218,6 +2661,6 @@ def positions(self): @property def object_cropped(self): - """cropped and rotated object""" + """Cropped and rotated object""" return self._crop_rotate_object_fov(self._object) diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index 02138d738..11adc0c70 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -13,7 +13,7 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = np from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd @@ -195,9 +195,9 @@ def _get_constructor_args(cls, group): "datacube": dc, "initial_object_guess": np.asarray(obj), "energy": instance_md["energy"], - "verbose": instance_md["verbose"], "name": instance_md["name"], - "device": instance_md["device"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility } return kwargs @@ -718,24 +718,26 @@ def reconstruct( xp = self._xp asnumpy = self._asnumpy - if reset is None and hasattr(self, "error"): - warnings.warn( - ( - "Continuing reconstruction from previous result. " - "Use reset=True for a fresh start." - ), - UserWarning, - ) - # Restart if store_iterations and (not hasattr(self, "object_phase_iterations") or reset): self.object_phase_iterations = [] - self.error_iterations = [] if reset: self.error = np.inf + self.error_iterations = [] self._step_size = step_size if step_size is not None else 0.5 self._padded_object_phase = self._padded_object_phase_initial.copy() + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." 
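# Editor's note (illustrative sketch, not part of the original diff): the
# iterative_dpc.py hunk above broadens the cupy guard to catch ImportError as
# well, so a broken GPU install falls back to CPU cleanly. The CPU/GPU
# dispatch pattern these classes use, reduced to a hypothetical helper:
import numpy as np

try:
    import cupy as cp
except (ModuleNotFoundError, ImportError):
    cp = None

def get_array_module(device):
    """Return (xp, asnumpy) for 'cpu' or 'gpu'."""
    if device == "cpu":
        return np, np.asarray
    if device == "gpu" and cp is not None:
        return cp, cp.asnumpy
    raise ValueError(f"device must be 'cpu' or 'gpu' (with cupy), not {device}")

xp, asnumpy = get_array_module("cpu")
print(asnumpy(xp.fft.fft2(xp.ones((4, 4))))[0, 0])  # (16+0j)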
+ ), + UserWarning, + ) + else: + self.error_iterations = [] self.error = getattr(self, "error", np.inf) @@ -770,7 +772,8 @@ def reconstruct( if (new_error > self.error) and backtrack: self._padded_object_phase = previous_iteration self._step_size /= 2 - print(f"Iteration {a0}, step reduced to {self._step_size}") + if self._verbose: + print(f"Iteration {a0}, step reduced to {self._step_size}") continue self.error = new_error @@ -796,6 +799,7 @@ def reconstruct( anti_gridding=anti_gridding, ) + self.error_iterations.append(self.error.item()) if store_iterations: self.object_phase_iterations.append( asnumpy( @@ -804,13 +808,13 @@ def reconstruct( ].copy() ) ) - self.error_iterations.append(self.error.item()) if self._step_size < stopping_criterion: - warnings.warn( - f"Step-size has decreased below stopping criterion {stopping_criterion}.", - UserWarning, - ) + if self._verbose: + warnings.warn( + f"Step-size has decreased below stopping criterion {stopping_criterion}.", + UserWarning, + ) # crop result self._object_phase = self._padded_object_phase[ @@ -840,7 +844,7 @@ def _visualize_last_iteration( If true, the NMSE error plot is displayed """ - figsize = kwargs.pop("figsize", (8, 8)) + figsize = kwargs.pop("figsize", (5, 6)) cmap = kwargs.pop("cmap", "magma") if plot_convergence: @@ -862,7 +866,7 @@ def _visualize_last_iteration( im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs) ax1.set_ylabel(f"x [{self._scan_units[0]}]") ax1.set_xlabel(f"y [{self._scan_units[1]}]") - ax1.set_title(f"DPC Phase Reconstruction - NMSE error: {self.error:.3e}") + ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}") if cbar: divider = make_axes_locatable(ax1) @@ -870,11 +874,11 @@ def _visualize_last_iteration( fig.add_axes(ax_cb) fig.colorbar(im, cax=ax_cb) - if plot_convergence and hasattr(self, "_error_iterations"): - errors = self._error_iterations + if plot_convergence: + errors = self.error_iterations ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -979,7 +983,7 @@ def _visualize_all_iterations( if plot_convergence: ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(len(errors)), errors, **kwargs) - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.set_ylabel("Log NMSE error") ax2.yaxis.tick_right() @@ -990,7 +994,7 @@ def visualize( fig=None, iterations_grid: Tuple[int, int] = None, plot_convergence: bool = True, - cbar: bool = False, + cbar: bool = True, **kwargs, ): """ diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py new file mode 100644 index 000000000..10dc40e00 --- /dev/null +++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py @@ -0,0 +1,3705 @@ +""" +Module for reconstructing phase objects from 4DSTEM datasets using iterative methods, +namely multislice ptychography. 
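# Editor's note (illustrative sketch, not part of the original diff): the
# reconstruct() changes above preserve the backtracking behavior: when the
# error increases, the previous object is kept and the step size is halved
# (now printing only when verbose). The same scheme on a toy problem:
import numpy as np

def backtracking_gd(grad, x0, step=0.5, n_iter=50, stopping_criterion=1e-4):
    x, error = x0, np.inf
    for _ in range(n_iter):
        x_new = x - step * grad(x)
        new_error = float(np.sum(grad(x_new) ** 2))
        if new_error > error:          # error increased: backtrack
            step /= 2                  # keep x, halve the step
            if step < stopping_criterion:
                break
            continue
        x, error = x_new, new_error
    return x

print(backtracking_gd(lambda x: 2 * (x - 3.0), np.array([0.0])))  # ~[3.]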
+""" + +import warnings +from typing import Mapping, Sequence, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable +from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex + +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = None + import os + + # make sure pylops doesn't try to use cupy + os.environ["CUPY_PYLOPS"] = "0" +import pylops # this must follow the exception +from emdfile import Custom, tqdmnd +from py4DSTEM import DataCube +from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction +from py4DSTEM.process.phase.utils import ( + ComplexProbe, + fft_shift, + generate_batches, + polar_aliases, + polar_symbols, + spatial_frequencies, +) +from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import rotate + +warnings.simplefilter(action="always", category=UserWarning) + + +class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction): + """ + Mixed-State Multislice Ptychographic Reconstruction Class. + + Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) + Reconstructed probe dimensions : (N,Sx,Sy) + Reconstructed object dimensions : (T,Px,Py) + + such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes + and (Px,Py) is the padded-object size we position our ROI around in + each of the T slices. + + Parameters + ---------- + energy: float + The electron energy of the wave functions in eV + num_probes: int, optional + Number of mixed-state probes + num_slices: int + Number of slices to use in the forward model + slice_thicknesses: float or Sequence[float] + Slice thicknesses in angstroms. If float, all slices are assigned the same thickness + datacube: DataCube, optional + Input 4D diffraction pattern intensities + semiangle_cutoff: float, optional + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels + rolloff: float, optional + Semiangle rolloff for the initial probe guess + vacuum_probe_intensity: np.ndarray, optional + Vacuum probe to use as intensity aperture for initial probe guess + polar_parameters: dict, optional + Mapping from aberration symbols to their corresponding values. All aberration + magnitudes should be given in Å and angles should be given in radians. + object_padding_px: Tuple[int,int], optional + Pixel dimensions to pad object with + If None, the padding is set to half the probe ROI dimensions + initial_object_guess: np.ndarray, optional + Initial guess for complex-valued object of dimensions (Px,Py) + If None, initialized to 1.0j + initial_probe_guess: np.ndarray, optional + Initial guess for complex-valued probe of dimensions (Sx,Sy). 
If None,
+        initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+    initial_scan_positions: np.ndarray, optional
+        Probe positions in Å for each diffraction intensity
+        If None, initialized to a grid scan
+    theta_x: float
+        x tilt of propagator (in degrees)
+    theta_y: float
+        y tilt of propagator (in degrees)
+    middle_focus: bool
+        If True, adds half the sample thickness to the defocus
+    object_type: str, optional
+        The object can be reconstructed as a real potential ('potential') or a complex
+        object ('complex')
+    positions_mask: np.ndarray, optional
+        Boolean real space mask to select positions in datacube to use for reconstruction
+    verbose: bool, optional
+        If True, class methods will inherit this and print additional information
+    device: str, optional
+        Device the calculation will be performed on. Must be 'cpu' or 'gpu'
+    name: str, optional
+        Class name
+    kwargs:
+        Provide the aberration coefficients as keyword arguments.
+    """
+
+    # Class-specific Metadata
+    _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses")
+
+    def __init__(
+        self,
+        energy: float,
+        num_slices: int,
+        slice_thicknesses: Union[float, Sequence[float]],
+        num_probes: int = None,
+        datacube: DataCube = None,
+        semiangle_cutoff: float = None,
+        semiangle_cutoff_pixels: float = None,
+        rolloff: float = 2.0,
+        vacuum_probe_intensity: np.ndarray = None,
+        polar_parameters: Mapping[str, float] = None,
+        object_padding_px: Tuple[int, int] = None,
+        initial_object_guess: np.ndarray = None,
+        initial_probe_guess: np.ndarray = None,
+        initial_scan_positions: np.ndarray = None,
+        theta_x: float = 0,
+        theta_y: float = 0,
+        middle_focus: bool = False,
+        object_type: str = "complex",
+        positions_mask: np.ndarray = None,
+        verbose: bool = True,
+        device: str = "cpu",
+        name: str = "multi-slice_ptychographic_reconstruction",
+        **kwargs,
+    ):
+        Custom.__init__(self, name=name)
+
+        if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe):
+            if num_probes is None:
+                raise ValueError(
+                    (
+                        "If initial_probe_guess is None, or a ComplexProbe object, "
+                        "num_probes must be specified."
+                    )
+                )
+        else:
+            if len(initial_probe_guess.shape) != 3:
+                raise ValueError(
+                    "Specified initial_probe_guess must have dimensions (N,Sx,Sy)."
+ ) + num_probes = initial_probe_guess.shape[0] + + if device == "cpu": + self._xp = np + self._asnumpy = np.asarray + from scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from scipy.special import erf + + self._erf = erf + elif device == "gpu": + self._xp = cp + self._asnumpy = cp.asnumpy + from cupyx.scipy.ndimage import gaussian_filter + + self._gaussian_filter = gaussian_filter + from cupyx.scipy.special import erf + + self._erf = erf + else: + raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + + for key in kwargs.keys(): + if (key not in polar_symbols) and (key not in polar_aliases.keys()): + raise ValueError("{} not a recognized parameter".format(key)) + + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) + + if polar_parameters is None: + polar_parameters = {} + + polar_parameters.update(kwargs) + self._set_polar_parameters(polar_parameters) + + slice_thicknesses = np.array(slice_thicknesses) + if slice_thicknesses.shape == (): + slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1) + elif slice_thicknesses.shape[0] != (num_slices - 1): + raise ValueError( + ( + f"slice_thicknesses must have length {num_slices - 1}, " + f"not {slice_thicknesses.shape[0]}." + ) + ) + + if object_type != "potential" and object_type != "complex": + raise ValueError( + f"object_type must be either 'potential' or 'complex', not {object_type}" + ) + + self.set_save_defaults() + + # Data + self._datacube = datacube + self._object = initial_object_guess + self._probe = initial_probe_guess + + # Common Metadata + self._vacuum_probe_intensity = vacuum_probe_intensity + self._scan_positions = initial_scan_positions + self._energy = energy + self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels + self._rolloff = rolloff + self._object_type = object_type + self._positions_mask = positions_mask + self._object_padding_px = object_padding_px + self._verbose = verbose + self._device = device + self._preprocessed = False + + # Class-specific Metadata + self._num_probes = num_probes + self._num_slices = num_slices + self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y + + def _precompute_propagator_arrays( + self, + gpts: Tuple[int, int], + sampling: Tuple[float, float], + energy: float, + slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, + ): + """ + Precomputes propagator arrays complex wave-function will be convolved by, + for all slice thicknesses. 
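# Editor's note (illustrative sketch, not part of the original diff): the
# propagator construction implemented just below multiplies Fourier-space
# phase factors: exp(-i*pi*lambda*dz*k^2) for Fresnel propagation plus linear
# ramps for beam tilt. A standalone version (tilts in radians here;
# np.fft.fftfreq stands in for the spatial_frequencies utility):
import numpy as np

def fresnel_propagator(gpts, sampling, wavelength, dz, theta_x=0.0, theta_y=0.0):
    kx = np.fft.fftfreq(gpts[0], d=sampling[0])[:, None]
    ky = np.fft.fftfreq(gpts[1], d=sampling[1])[None, :]
    prop = np.exp(-1.0j * np.pi * wavelength * dz * (kx**2 + ky**2))
    prop *= np.exp(2.0j * np.pi * dz * (kx * np.tan(theta_x) + ky * np.tan(theta_y)))
    return prop.astype(np.complex64)

def propagate(array, propagator):
    # Fourier-convolve a corner-centered wavefunction with the propagator
    return np.fft.ifft2(np.fft.fft2(array) * propagator)

psi = np.ones((64, 64), dtype=np.complex64)
psi = propagate(psi, fresnel_propagator((64, 64), (0.2, 0.2), 0.025, dz=10.0))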
+
+        Parameters
+        ----------
+        gpts: Tuple[int,int]
+            Wavefunction pixel dimensions
+        sampling: Tuple[float,float]
+            Wavefunction sampling in A
+        energy: float
+            The electron energy of the wave functions in eV
+        slice_thicknesses: Sequence[float]
+            Array of slice thicknesses in A
+        theta_x: float
+            x tilt of propagator (in degrees)
+        theta_y: float
+            y tilt of propagator (in degrees)
+
+        Returns
+        -------
+        propagator_arrays: np.ndarray
+            (T,Sx,Sy) shape array storing propagator arrays
+        """
+        xp = self._xp
+
+        # Frequencies
+        kx, ky = spatial_frequencies(gpts, sampling)
+        kx = xp.asarray(kx, dtype=xp.float32)
+        ky = xp.asarray(ky, dtype=xp.float32)
+
+        # Propagators
+        wavelength = electron_wavelength_angstrom(energy)
+        num_slices = slice_thicknesses.shape[0]
+        propagators = xp.empty(
+            (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
+        )
+
+        theta_x = np.deg2rad(theta_x)
+        theta_y = np.deg2rad(theta_y)
+
+        for i, dz in enumerate(slice_thicknesses):
+            propagators[i] = xp.exp(
+                1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
+            )
+            propagators[i] *= xp.exp(
+                1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
+            )
+            propagators[i] *= xp.exp(
+                1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x))
+            )
+            propagators[i] *= xp.exp(
+                1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))
+            )
+
+        return propagators
+
+    def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray):
+        """
+        Propagates array by Fourier convolving array with propagator_array.
+
+        Parameters
+        ----------
+        array: np.ndarray
+            Wavefunction array to be convolved
+        propagator_array: np.ndarray
+            Propagator array to convolve array with
+
+        Returns
+        -------
+        propagated_array: np.ndarray
+            Fourier-convolved array
+        """
+        xp = self._xp
+
+        return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array)
+
+    def preprocess(
+        self,
+        diffraction_intensities_shape: Tuple[int, int] = None,
+        reshaping_method: str = "fourier",
+        probe_roi_shape: Tuple[int, int] = None,
+        dp_mask: np.ndarray = None,
+        fit_function: str = "plane",
+        plot_center_of_mass: str = "default",
+        plot_rotation: bool = True,
+        maximize_divergence: bool = False,
+        rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+        plot_probe_overlaps: bool = True,
+        force_com_rotation: float = None,
+        force_com_transpose: float = None,
+        force_com_shifts: float = None,
+        force_scan_sampling: float = None,
+        force_angular_sampling: float = None,
+        force_reciprocal_sampling: float = None,
+        object_fov_mask: np.ndarray = None,
+        crop_patterns: bool = False,
+        **kwargs,
+    ):
+        """
+        Ptychographic preprocessing step.
+        Calls the base class methods:
+
+        _extract_intensities_and_calibrations_from_datacube,
+        _compute_center_of_mass(),
+        _solve_CoM_rotation(),
+        _normalize_diffraction_intensities()
+        _calculate_scan_positions_in_px()
+
+        Additionally, it initializes a (T,Px,Py) array of 1.0j
+        and a complex probe using the specified polar parameters.
+
+        Parameters
+        ----------
+        diffraction_intensities_shape: Tuple[int,int], optional
+            Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+            If None, no resampling of diffraction intensities is performed
+        reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        probe_roi_shape: (int,int), optional
+            Padded diffraction intensities shape.
+            If None, no padding is performed
+        dp_mask: ndarray, optional
+            Mask for datacube intensities (Qx,Qy)
+        fit_function: str, optional
+            2D fitting function for CoM fitting.
One of 'plane','parabola','bezier_two'
+        plot_center_of_mass: str, optional
+            If 'default', the corrected CoM arrays will be displayed
+            If 'all', the computed and fitted CoM arrays will be displayed
+        plot_rotation: bool, optional
+            If True, the CoM curl minimization search result will be displayed
+        maximize_divergence: bool, optional
+            If True, the divergence of the CoM gradient vector field is maximized
+        rotation_angles_deg: np.ndarray, optional
+            Array of angles in degrees to perform curl minimization over
+        plot_probe_overlaps: bool, optional
+            If True, initial probe overlaps scanned over the object will be displayed
+        force_com_rotation: float (degrees), optional
+            Force relative rotation angle between real and reciprocal space
+        force_com_transpose: bool, optional
+            Force whether diffraction intensities need to be transposed.
+        force_com_shifts: tuple of ndarrays (CoMx, CoMy)
+            Amplitudes come from diffraction patterns shifted with
+            the CoM in the upper left corner for each probe, unless
+            the shift is overwritten.
+        force_scan_sampling: float, optional
+            Override DataCube real space scan pixel size calibrations, in Angstrom
+        force_angular_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in mrad
+        force_reciprocal_sampling: float, optional
+            Override DataCube reciprocal pixel size calibration, in A^-1
+        object_fov_mask: np.ndarray (boolean)
+            Boolean mask of FOV. Used to calculate additional shrinkage of object
+            If None, probe_overlap intensity is thresholded
+        crop_patterns: bool
+            If True, crop patterns to avoid wrap around of patterns when centering
+
+        Returns
+        -------
+        self: MixedstateMultislicePtychographicReconstruction
+            Self to accommodate chaining
+        """
+        xp = self._xp
+        asnumpy = self._asnumpy
+
+        # set additional metadata
+        self._diffraction_intensities_shape = diffraction_intensities_shape
+        self._reshaping_method = reshaping_method
+        self._probe_roi_shape = probe_roi_shape
+        self._dp_mask = dp_mask
+
+        if self._datacube is None:
+            raise ValueError(
+                (
+                    "The preprocess() method requires a DataCube. "
+                    "Please run ptycho.attach_datacube(DataCube) first."
+ ) + ) + + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + + ( + self._datacube, + self._vacuum_probe_intensity, + self._dp_mask, + force_com_shifts, + ) = self._preprocess_datacube_and_vacuum_probe( + self._datacube, + diffraction_intensities_shape=self._diffraction_intensities_shape, + reshaping_method=self._reshaping_method, + probe_roi_shape=self._probe_roi_shape, + vacuum_probe_intensity=self._vacuum_probe_intensity, + dp_mask=self._dp_mask, + com_shifts=force_com_shifts, + ) + + self._intensities = self._extract_intensities_and_calibrations_from_datacube( + self._datacube, + require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, + ) + + ( + self._com_measured_x, + self._com_measured_y, + self._com_fitted_x, + self._com_fitted_y, + self._com_normalized_x, + self._com_normalized_y, + ) = self._calculate_intensities_center_of_mass( + self._intensities, + dp_mask=self._dp_mask, + fit_function=fit_function, + com_shifts=force_com_shifts, + ) + + ( + self._rotation_best_rad, + self._rotation_best_transpose, + self._com_x, + self._com_y, + self.com_x, + self.com_y, + ) = self._solve_for_center_of_mass_relative_rotation( + self._com_measured_x, + self._com_measured_y, + self._com_normalized_x, + self._com_normalized_y, + rotation_angles_deg=rotation_angles_deg, + plot_rotation=plot_rotation, + plot_center_of_mass=plot_center_of_mass, + maximize_divergence=maximize_divergence, + force_com_rotation=force_com_rotation, + force_com_transpose=force_com_transpose, + **kwargs, + ) + + ( + self._amplitudes, + self._mean_diffraction_intensity, + ) = self._normalize_diffraction_intensities( + self._intensities, + self._com_fitted_x, + self._com_fitted_y, + crop_patterns, + self._positions_mask, + ) + + # explicitly delete namespace + self._num_diffraction_patterns = self._amplitudes.shape[0] + self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:]) + del self._intensities + + self._positions_px = self._calculate_scan_positions_in_pixels( + self._scan_positions, self._positions_mask + ) + + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + + # Object Initialization + if self._object is None: + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) + p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( + "int" + ) + q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( + "int" + ) + if self._object_type == "potential": + self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64) + else: + if self._object_type == "potential": + self._object = xp.asarray(self._object, dtype=xp.float32) + elif self._object_type == "complex": + self._object = xp.asarray(self._object, dtype=xp.complex64) + + self._object_initial = self._object.copy() + self._object_type_initial = self._object_type + self._object_shape = self._object.shape[-2:] + + self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32) + self._positions_px_com = 
xp.mean(self._positions_px, axis=0) + self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2 + self._positions_px_com = xp.mean(self._positions_px, axis=0) + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + self._positions_px_initial = self._positions_px.copy() + self._positions_initial = self._positions_px_initial.copy() + self._positions_initial[:, 0] *= self.sampling[0] + self._positions_initial[:, 1] *= self.sampling[1] + + # Vectorized Patches + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Probe Initialization + if self._probe is None or isinstance(self._probe, ComplexProbe): + if self._probe is None: + if self._vacuum_probe_intensity is not None: + self._semiangle_cutoff = np.inf + self._vacuum_probe_intensity = xp.asarray( + self._vacuum_probe_intensity, dtype=xp.float32 + ) + probe_x0, probe_y0 = get_CoM( + self._vacuum_probe_intensity, + device=self._device, + ) + self._vacuum_probe_intensity = get_shifted_ar( + self._vacuum_probe_intensity, + -probe_x0, + -probe_y0, + bilinear=True, + device=self._device, + ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + _probe = ( + ComplexProbe( + gpts=self._region_of_interest_shape, + sampling=self.sampling, + energy=self._energy, + semiangle_cutoff=self._semiangle_cutoff, + rolloff=self._rolloff, + vacuum_probe_intensity=self._vacuum_probe_intensity, + parameters=self._polar_parameters, + device=self._device, + ) + .build() + ._array + ) + + else: + if self._probe._gpts != self._region_of_interest_shape: + raise ValueError() + if hasattr(self._probe, "_array"): + _probe = self._probe._array + else: + self._probe._xp = xp + _probe = self._probe.build()._array + + self._probe = xp.zeros( + (self._num_probes,) + tuple(self._region_of_interest_shape), + dtype=xp.complex64, + ) + sx, sy = self._region_of_interest_shape + self._probe[0] = _probe + + # Randomly shift phase of other probes + for i_probe in range(1, self._num_probes): + shift_x = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) + ) + shift_y = xp.exp( + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) + ) + self._probe[i_probe] = ( + self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] + ) + + # Normalize probe to match mean diffraction intensity + probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2) + self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity) + + else: + self._probe = xp.asarray(self._probe, dtype=xp.complex64) + + self._probe_initial = self._probe.copy() + self._probe_initial_aperture = None # Doesn't really make sense for mixed-state + + self._known_aberrations_array = ComplexProbe( + energy=self._energy, + gpts=self._region_of_interest_shape, + sampling=self.sampling, + parameters=self._polar_parameters, + device=self._device, + )._evaluate_ctf() + + # Precomputed propagator arrays + self._propagator_arrays = self._precompute_propagator_arrays( + self._region_of_interest_shape, + self.sampling, + self._energy, + self._slice_thicknesses, + self._theta_x, + self._theta_y, + ) + + # overlaps + shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) + probe_intensities = xp.abs(shifted_probes) ** 2 + probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities) + probe_overlap = self._gaussian_filter(probe_overlap, 
1.0) + + if object_fov_mask is None: + self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max()) + else: + self._object_fov_mask = np.asarray(object_fov_mask) + self._object_fov_mask_inverse = np.invert(self._object_fov_mask) + + if plot_probe_overlaps: + figsize = kwargs.pop("figsize", (13, 4)) + chroma_boost = kwargs.pop("chroma_boost", 1) + + # initial probe + complex_probe_rgb = Complex2RGB( + self.probe_centered[0], + power=2, + chroma_boost=chroma_boost, + ) + + # propagated + propagated_probe = self._probe[0].copy() + + for s in range(self._num_slices - 1): + propagated_probe = self._propagate_array( + propagated_probe, self._propagator_arrays[s] + ) + complex_propagated_rgb = Complex2RGB( + asnumpy(self._return_centered_probe(propagated_probe)), + power=2, + chroma_boost=chroma_boost, + ) + + extent = [ + 0, + self.sampling[1] * self._object_shape[1], + self.sampling[0] * self._object_shape[0], + 0, + ] + + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize) + + ax1.imshow( + complex_probe_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax1) + cax1 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax1, + chroma_boost=chroma_boost, + ) + ax1.set_ylabel("x [A]") + ax1.set_xlabel("y [A]") + ax1.set_title("Initial probe[0] intensity") + + ax2.imshow( + complex_propagated_rgb, + extent=probe_extent, + ) + + divider = make_axes_locatable(ax2) + cax2 = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(cax2, chroma_boost=chroma_boost) + ax2.set_ylabel("x [A]") + ax2.set_xlabel("y [A]") + ax2.set_title("Propagated probe[0] intensity") + + ax3.imshow( + asnumpy(probe_overlap), + extent=extent, + cmap="Greys_r", + ) + ax3.scatter( + self.positions[:, 1], + self.positions[:, 0], + s=2.5, + color=(1, 0, 0, 1), + ) + ax3.set_ylabel("x [A]") + ax3.set_xlabel("y [A]") + ax3.set_xlim((extent[0], extent[1])) + ax3.set_ylim((extent[2], extent[3])) + ax3.set_title("Object field of view") + + fig.tight_layout() + + self._preprocessed = True + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _overlap_projection(self, current_object, current_probe): + """ + Ptychographic overlap projection method. 
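# Editor's note (illustrative sketch, not part of the original diff): the
# overlap projection implemented below alternates transmission through each
# slice with Fresnel propagation to the next. Stripped of batching and probe
# mixing, the multislice forward model is:
import numpy as np

def multislice_forward(probe, object_slices, propagators):
    # probe: (Sx,Sy); object_slices: (T,Sx,Sy); propagators: (T-1,Sx,Sy)
    psi = probe
    num_slices = object_slices.shape[0]
    for s in range(num_slices):
        psi = object_slices[s] * psi                       # transmit
        if s + 1 < num_slices:                             # propagate
            psi = np.fft.ifft2(np.fft.fft2(psi) * propagators[s])
    return psi                                             # exit wave

T, S = 3, 32
exit_wave = multislice_forward(
    np.ones((S, S), np.complex64),
    np.exp(1j * 0.01 * np.random.rand(T, S, S)).astype(np.complex64),
    np.ones((T - 1, S, S), np.complex64),                  # identity props (demo)
)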
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + """ + + xp = self._xp + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + object_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ] + + num_probe_positions = object_patches.shape[1] + + propagated_shape = ( + self._num_slices, + num_probe_positions, + self._num_probes, + self._region_of_interest_shape[0], + self._region_of_interest_shape[1], + ) + propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype) + propagated_probes[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes = ( + xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes[s + 1] = self._propagate_array( + transmitted_probes, self._propagator_arrays[s] + ) + + return propagated_probes, object_patches, transmitted_probes + + def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes): + """ + Ptychographic fourier projection method for GD method. + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + + Returns + -------- + exit_waves:np.ndarray + Exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + intensity_norm[intensity_norm == 0.0] = np.inf + amplitude_modification = amplitudes / intensity_norm + + fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves + modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap) + + exit_waves = modified_exit_wave - transmitted_probes + + return exit_waves, error + + def _projection_sets_fourier_projection( + self, + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic fourier projection method for DM_AP and RAAR methods. 
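# Editor's note (illustrative sketch, not part of the original diff): the
# gradient-descent Fourier projection just shown enforces the measured
# amplitudes against the *incoherent* sum over probe modes, preserving each
# mode's phase. Reduced to plain NumPy:
import numpy as np

def mixed_state_fourier_projection(amplitudes, transmitted_probes):
    # amplitudes: (B,Qx,Qy) measured sqrt-intensities
    # transmitted_probes: (B,N,Qx,Qy) complex exit waves for N probe modes
    fourier = np.fft.fft2(transmitted_probes)
    norm = np.sqrt(np.sum(np.abs(fourier) ** 2, axis=1))   # incoherent mode sum
    error = np.sum(np.abs(amplitudes - norm) ** 2)
    norm[norm == 0.0] = np.inf                             # avoid division by zero
    modified = np.fft.ifft2((amplitudes / norm)[:, None] * fourier)
    return modified - transmitted_probes, error            # exit-wave difference

waves = np.random.rand(2, 3, 16, 16) + 1j * np.random.rand(2, 3, 16, 16)
exit_waves, err = mixed_state_fourier_projection(np.ones((2, 16, 16)), waves)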
+ Generalized projection using three parameters: a,b,c + + DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha + DM: DM_AP(1.0), AP: DM_AP(0.0) + + RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2 + DM : RAAR(1.0) + + RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2 + DM: RRR(1.0) + + SUPERFLIP : a = 0, b = 1, c = 2 + + Parameters + -------- + amplitudes: np.ndarray + Normalized measured amplitudes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves: np.ndarray + previously estimated exit waves + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + exit_waves:np.ndarray + Updated exit wave difference + error: float + Reconstruction error + """ + + xp = self._xp + projection_x = 1 - projection_a - projection_b + projection_y = 1 - projection_c + + if exit_waves is None: + exit_waves = transmitted_probes.copy() + + fourier_exit_waves = xp.fft.fft2(transmitted_probes) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1)) + error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2) + + factor_to_be_projected = ( + projection_c * transmitted_probes + projection_y * exit_waves + ) + fourier_projected_factor = xp.fft.fft2(factor_to_be_projected) + + intensity_norm_projected = xp.sqrt( + xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1) + ) + intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf + + amplitude_modification = amplitudes / intensity_norm_projected + fourier_projected_factor *= amplitude_modification[:, None] + + projected_factor = xp.fft.ifft2(fourier_projected_factor) + + exit_waves = ( + projection_x * exit_waves + + projection_a * transmitted_probes + + projection_b * projected_factor + ) + + return exit_waves, error + + def _forward( + self, + current_object, + current_probe, + amplitudes, + exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ): + """ + Ptychographic forward operator. + Calls _overlap_projection() and the appropriate _fourier_projection(). + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + amplitudes: np.ndarray + Normalized measured amplitudes + exit_waves: np.ndarray + previously estimated exit waves + use_projection_scheme: bool, + If True, use generalized projection update + projection_a: float + projection_b: float + projection_c: float + + Returns + -------- + propagated_probes: np.ndarray + Shifted probes at each layer + object_patches: np.ndarray + Patched object view + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + exit_waves:np.ndarray + Updated exit_waves + error: float + Reconstruction error + """ + + ( + propagated_probes, + object_patches, + transmitted_probes, + ) = self._overlap_projection(current_object, current_probe) + + if use_projection_scheme: + exit_waves, error = self._projection_sets_fourier_projection( + amplitudes, + transmitted_probes, + exit_waves, + projection_a, + projection_b, + projection_c, + ) + + else: + exit_waves, error = self._gradient_descent_fourier_projection( + amplitudes, transmitted_probes + ) + + return propagated_probes, object_patches, transmitted_probes, exit_waves, error + + def _gradient_descent_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for GD method. 
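# Editor's note (illustrative sketch, not part of the original diff): the
# (a,b,c) parameterization listed in the docstring above unifies the named
# projection algorithms; a small helper making the mapping explicit:
def generalized_projection_params(scheme, coeff=1.0):
    if scheme == "DM_AP":                    # DM: coeff=1.0, AP: coeff=0.0
        return -coeff, 1.0, 1.0 + coeff
    if scheme == "RAAR":                     # DM: coeff=1.0
        return 1.0 - 2.0 * coeff, coeff, 2.0
    if scheme == "RRR":                      # DM: coeff=1.0
        return -coeff, coeff, 2.0
    if scheme == "SUPERFLIP":
        return 0.0, 1.0, 2.0
    raise ValueError(f"unrecognized scheme: {scheme}")

assert generalized_projection_params("DM_AP") == (-1.0, 1.0, 2.0)  # == DM
assert generalized_projection_params("RAAR") == (-1.0, 1.0, 2.0)   # == DM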
+ Computes object and probe update steps. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves[:, i_probe] + ) + ) + ) + else: + object_update += ( + step_size + * self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe] + ) + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] += object_update * probe_normalization + + # back-transmit + exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves = self._propagate_array( + exit_waves, xp.conj(self._propagator_arrays[s - 1]) + ) + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe += ( + step_size + * xp.sum( + exit_waves, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _projection_sets_adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ): + """ + Ptychographic adjoint operator for DM_AP and RAAR methods. + Computes object and probe update steps. 
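# Editor's note (illustrative sketch, not part of the original diff): both
# adjoint variants divide updates by the summed probe (or object) intensity
# through the same regularized reciprocal, which behaves like 1/x where the
# illumination is strong but stays bounded where it vanishes:
import numpy as np

def regularized_inverse(overlap, normalization_min=0.1, eps=1e-16):
    return 1.0 / np.sqrt(
        eps
        + ((1 - normalization_min) * overlap) ** 2
        + (normalization_min * overlap.max()) ** 2
    )

overlap = np.array([0.0, 1e-6, 0.5, 1.0])
print(regularized_inverse(overlap))  # finite everywhere; ~1/x for large x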
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + xp = self._xp + + # careful not to modify exit_waves in-place for projection set methods + exit_waves_copy = exit_waves.copy() + for s in reversed(range(self._num_slices)): + probe = propagated_probes[s] + obj = object_patches[s] + + # object-update + probe_normalization = xp.zeros_like(current_object[s]) + object_update = xp.zeros_like(current_object[s]) + + for i_probe in range(self._num_probes): + probe_normalization += self._sum_overlapping_patches_bincounts( + xp.abs(probe[:, i_probe]) ** 2 + ) + + if self._object_type == "potential": + object_update += self._sum_overlapping_patches_bincounts( + xp.real( + -1j + * xp.conj(obj) + * xp.conj(probe[:, i_probe]) + * exit_waves_copy[:, i_probe] + ) + ) + else: + object_update += self._sum_overlapping_patches_bincounts( + xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe] + ) + + probe_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * probe_normalization) ** 2 + + (normalization_min * xp.max(probe_normalization)) ** 2 + ) + + current_object[s] = object_update * probe_normalization + + # back-transmit + exit_waves_copy *= xp.expand_dims( + xp.conj(obj), axis=1 + ) # / xp.abs(obj) ** 2 + + if s > 0: + # back-propagate + exit_waves_copy = self._propagate_array( + exit_waves_copy, xp.conj(self._propagator_arrays[s - 1]) + ) + + elif not fix_probe: + # probe-update + object_normalization = xp.sum( + (xp.abs(obj) ** 2), + axis=0, + ) + object_normalization = 1 / xp.sqrt( + 1e-16 + + ((1 - normalization_min) * object_normalization) ** 2 + + (normalization_min * xp.max(object_normalization)) ** 2 + ) + + current_probe = ( + xp.sum( + exit_waves_copy, + axis=0, + ) + * object_normalization[None] + ) + + return current_object, current_probe + + def _adjoint( + self, + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + use_projection_scheme: bool, + step_size: float, + normalization_min: float, + fix_probe: bool, + ): + """ + Ptychographic adjoint operator for GD method. + Computes object and probe update steps. 
+ + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + object_patches: np.ndarray + Patched object view + propagated_probes: np.ndarray + Shifted probes at each layer + exit_waves:np.ndarray + Updated exit_waves + use_projection_scheme: bool, + If True, use generalized projection update + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + fix_probe: bool, optional + If True, probe will not be updated + + Returns + -------- + updated_object: np.ndarray + Updated object estimate + updated_probe: np.ndarray + Updated probe estimate + """ + + if use_projection_scheme: + current_object, current_probe = self._projection_sets_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + normalization_min, + fix_probe, + ) + else: + current_object, current_probe = self._gradient_descent_adjoint( + current_object, + current_probe, + object_patches, + propagated_probes, + exit_waves, + step_size, + normalization_min, + fix_probe, + ) + + return current_object, current_probe + + def _position_correction( + self, + current_object, + current_probe, + transmitted_probes, + amplitudes, + current_positions, + positions_step_size, + constrain_position_distance, + ): + """ + Position correction using estimated intensity gradient. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe:np.ndarray + fractionally-shifted probes + transmitted_probes: np.ndarray + Transmitted probes after N-1 propagations and N transmissions + amplitudes: np.ndarray + Measured amplitudes + current_positions: np.ndarray + Current positions estimate + positions_step_size: float + Positions step size + constrain_position_distance: float + Distance to constrain position correction within original + field of view in A + + Returns + -------- + updated_positions: np.ndarray + Updated positions estimate + """ + + xp = self._xp + + # Intensity gradient + exit_waves_fft = xp.fft.fft2(transmitted_probes) + exit_waves_fft_conj = xp.conj(exit_waves_fft) + estimated_intensity = xp.abs(exit_waves_fft) ** 2 + measured_intensity = amplitudes**2 + + flat_shape = (transmitted_probes.shape[0], -1) + difference_intensity = (measured_intensity - estimated_intensity).reshape( + flat_shape + ) + + # Computing perturbed exit waves one at a time to save on memory + + if self._object_type == "potential": + complex_object = xp.exp(1j * current_object) + else: + complex_object = current_object + + # dx + obj_rolled_patches = complex_object[ + :, + (self._vectorized_patch_indices_row + 1) % self._object_shape[0], + self._vectorized_patch_indices_col, + ] + + propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + # dy + obj_rolled_patches = complex_object[ + :, + self._vectorized_patch_indices_row, + (self._vectorized_patch_indices_col + 1) % self._object_shape[1], + ] + + 
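# Editor's note (illustrative sketch, not part of the original diff): the
# perturbed dx/dy intensities assembled here feed a per-position least-squares
# solve below, d = argmin ||A d - b||^2, computed batched via the normal
# equations (A^T A)^-1 A^T b, equivalent to the commented pseudo-inverse form:
import numpy as np

B, K = 5, 100                          # positions, flattened detector pixels
rng = np.random.default_rng(0)
A = rng.standard_normal((B, K, 2))     # intensity partial derivatives
b = rng.standard_normal((B, K, 1))     # intensity mismatch

At = A.swapaxes(-1, -2)                # batched transpose
d = np.linalg.inv(At @ A) @ At @ b     # (B,2,1) position updates

assert np.allclose(d, np.linalg.pinv(A) @ b)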
propagated_probes_perturbed = xp.empty_like(obj_rolled_patches) + propagated_probes_perturbed[0] = fft_shift( + current_probe, self._positions_px_fractional, xp + ) + + for s in range(self._num_slices): + # transmit + transmitted_probes_perturbed = ( + obj_rolled_patches[s] * propagated_probes_perturbed[s] + ) + + # propagate + if s + 1 < self._num_slices: + propagated_probes_perturbed[s + 1] = self._propagate_array( + transmitted_probes_perturbed, self._propagator_arrays[s] + ) + + exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed) + + partial_intensity_dx = 2 * xp.real( + exit_waves_dx_fft * exit_waves_fft_conj + ).reshape(flat_shape) + partial_intensity_dy = 2 * xp.real( + exit_waves_dy_fft * exit_waves_fft_conj + ).reshape(flat_shape) + + coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy)) + + # positions_update = xp.einsum( + # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity + # ) + + coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2) + positions_update = ( + xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix) + @ coefficients_matrix_T + @ difference_intensity[..., None] + ) + + if constrain_position_distance is not None: + constrain_position_distance /= xp.sqrt( + self.sampling[0] ** 2 + self.sampling[1] ** 2 + ) + x1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 0 + ] + y1 = (current_positions - positions_step_size * positions_update[..., 0])[ + :, 1 + ] + x0 = self._positions_px_initial[:, 0] + y0 = self._positions_px_initial[:, 1] + if self._rotation_best_transpose: + x0, y0 = xp.array([y0, x0]) + x1, y1 = xp.array([y1, x1]) + + if self._rotation_best_rad is not None: + rotation_angle = self._rotation_best_rad + x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin( + -rotation_angle + ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle) + x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin( + -rotation_angle + ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle) + + outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + ( + x1 < (xp.min(x0) - constrain_position_distance) + ) + (y1 > (xp.max(y0) + constrain_position_distance)) + ( + y1 < (xp.min(y0) - constrain_position_distance) + ) > 0 + + positions_update[..., 0][outlier_ind] = 0 + + current_positions -= positions_step_size * positions_update[..., 0] + + return current_positions + + def _probe_center_of_mass_constraint(self, current_probe): + """ + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + probe_intensity = xp.abs(current_probe[0]) ** 2 + + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True + ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) + + return shifted_probe + + def _probe_orthogonalization_constraint(self, current_probe): + """ + Ptychographic probe-orthogonalization constraint. + Used to ensure mixed states are orthogonal to each other. 
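# Editor's note (illustrative sketch, not part of the original diff): the
# mixed-state orthogonalization below diagonalizes the pairwise Gram matrix of
# the probe modes, rotates the modes by its eigenvectors, and sorts by
# intensity. Reduced to plain NumPy:
import numpy as np

def orthogonalize_probes(probes):
    n = probes.shape[0]
    flat = probes.reshape(n, -1)
    gram = flat.conj() @ flat.T                      # G[i,j] = sum(conj(P_i)*P_j)
    _, evecs = np.linalg.eigh(gram)
    rotated = np.tensordot(evecs.T, probes, axes=1)  # mix the modes
    order = np.argsort(np.sum(np.abs(rotated) ** 2, axis=(-2, -1)))[::-1]
    return rotated[order]                            # brightest mode first

probes = np.random.rand(3, 16, 16) + 1j * np.random.rand(3, 16, 16)
flat = orthogonalize_probes(probes).reshape(3, -1)
print(np.round(np.abs(flat.conj() @ flat.T), 6))     # ~diagonal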
+ Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 + + Parameters + -------- + current_probe: np.ndarray + Current probe estimate + + Returns + -------- + constrained_probe: np.ndarray + Orthogonalized probe estimate + """ + xp = self._xp + n_probes = self._num_probes + + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] + + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): + """ + 2D Butterworth filter + Used for low/high-pass filtering object. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qya, qxa = xp.meshgrid(qy, qx) + qra = xp.sqrt(qxa**2 + qya**2) + + env = xp.ones_like(qra) + if q_highpass: + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) + if q_lowpass: + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) + + current_object_mean = xp.mean(current_object) + current_object -= current_object_mean + current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None]) + current_object += current_object_mean + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_kz_regularization_constraint( + self, current_object, kz_regularization_gamma + ): + """ + Arctan regularization filter + + Parameters + -------- + current_object: np.ndarray + Current object estimate + kz_regularization_gamma: float + Slice regularization strength + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + current_object = xp.pad( + current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant" + ) + + qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0]) + qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1]) + qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0]) + + kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0] + + qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij") + qz2 = qza**2 * kz_regularization_gamma**2 + qr2 = qxa**2 + qya**2 + + w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2) + + current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w) + current_object = current_object[1:] + + if self._object_type == "potential": + current_object = xp.real(current_object) + + return current_object + + def _object_identical_slices_constraint(self, current_object): + """ + 
Strong regularization forcing all slices to be identical + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + object_mean = current_object.mean(0, keepdims=True) + current_object[:] = object_mean + + return current_object + + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + + def _constraints( + self, + current_object, + current_probe, + current_positions, + fix_com, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, + fix_positions, + global_affine_transformation, + gaussian_filter, + gaussian_filter_sigma, + butterworth_filter, + q_lowpass, + q_highpass, + butterworth_order, + kz_regularization_filter, + kz_regularization_gamma, + identical_slices, + object_positivity, + shrinkage_rad, + 
object_mask, + pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, + orthogonalize_probe, + ): + """ + Ptychographic constraints operator. + + Parameters + -------- + current_object: np.ndarray + Current object estimate + current_probe: np.ndarray + Current probe estimate + current_positions: np.ndarray + Current positions estimate + fix_com: bool + If True, probe CoM is fixed to the center + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool + If True, probe amplitude is constrained by top hat function + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool + If True, probe Fourier amplitude is replaced by initial_probe_aperture + initial_probe_aperture: np.ndarray + Initial probe aperture to use in replacing probe Fourier amplitude + fix_positions: bool + If True, positions are not updated + gaussian_filter: bool + If True, applies real-space gaussian filter in A + gaussian_filter_sigma: float + Standard deviation of gaussian kernel + butterworth_filter: bool + If True, applies fourier-space butterworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter: bool + If True, applies fourier-space arctan regularization filter + kz_regularization_gamma: float + Slice regularization strength + identical_slices: bool + If True, forces all object slices to be identical + object_positivity: bool + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + pure_phase_object: bool + If True, object amplitude is set to unity + tv_denoise_chambolle: bool + If True, performs TV denoising along z + tv_denoise_weight_chambolle: float + weight of tv denoising constraint + tv_denoise_pad_chambolle: bool + if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. 
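For reference, the split-Bregman calls wrapped by _object_denoise_tv_pylops above reduce to the following standalone sketch when both weights are nonzero (toy volume; weights and iteration counts illustrative):

```python
import numpy as np
import pylops

nz, nx, ny = 4, 32, 32
vol = np.random.rand(nz, nx, ny)

Iop = pylops.Identity(nz * nx * ny)
z_gradient = pylops.FirstDerivative((nz, nx, ny), axis=0, edge=False, kind="backward")
xy_laplacian = pylops.Laplacian((nz, nx, ny), axes=(1, 2), edge=False, kind="backward")

denoised = pylops.optimization.sparsity.splitbregman(
    Op=Iop,
    y=vol.ravel(),
    RegsL1=[z_gradient, xy_laplacian],
    niter_outer=10,        # `iterations` above
    niter_inner=1,
    epsRL1s=[0.1, 0.1],    # [z weight, r weight]
    tol=1e-4,
    tau=1.0,
    show=False,
)[0].reshape(nz, nx, ny)
```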
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising + orthogonalize_probe: bool + If True, probe will be orthogonalized + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + constrained_probe: np.ndarray + Constrained probe estimate + constrained_positions: np.ndarray + Constrained positions estimate + """ + + if gaussian_filter: + current_object = self._object_gaussian_constraint( + current_object, gaussian_filter_sigma, pure_phase_object + ) + + if butterworth_filter: + current_object = self._object_butterworth_constraint( + current_object, + q_lowpass, + q_highpass, + butterworth_order, + ) + + if identical_slices: + current_object = self._object_identical_slices_constraint(current_object) + elif kz_regularization_filter: + current_object = self._object_kz_regularization_constraint( + current_object, kz_regularization_gamma + ) + elif tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: + current_object = self._object_denoise_tv_chambolle( + current_object, + tv_denoise_weight_chambolle, + axis=0, + pad_object=tv_denoise_pad_chambolle, + ) + + if shrinkage_rad > 0.0 or object_mask is not None: + current_object = self._object_shrinkage_constraint( + current_object, + shrinkage_rad, + object_mask, + ) + + if self._object_type == "complex": + current_object = self._object_threshold_constraint( + current_object, pure_phase_object + ) + elif object_positivity: + current_object = self._object_positivity_constraint(current_object) + + if fix_com: + current_probe = self._probe_center_of_mass_constraint(current_probe) + + # These constraints don't _really_ make sense for mixed-state + if fix_probe_aperture: + raise NotImplementedError() + elif constrain_probe_fourier_amplitude: + raise NotImplementedError() + if fit_probe_aberrations: + raise NotImplementedError() + if constrain_probe_amplitude: + raise NotImplementedError() + + if orthogonalize_probe: + current_probe = self._probe_orthogonalization_constraint(current_probe) + + if not fix_positions: + current_positions = self._positions_center_of_mass_constraint( + current_positions + ) + + if global_affine_transformation: + current_positions = self._positions_affine_transformation_constraint( + self._positions_px_initial, current_positions + ) + + return current_object, current_probe, current_positions + + def reconstruct( + self, + max_iter: int = 64, + reconstruction_method: str = "gradient-descent", + reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, + max_batch_size: int = None, + seed_random: int = None, + step_size: float = 0.5, + normalization_min: float = 1, + positions_step_size: float = 0.9, + fix_com: bool = True, + orthogonalize_probe: bool = True, + fix_probe_iter: int = 0, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, + fix_positions_iter: int = np.inf, + constrain_position_distance: float = None, + global_affine_transformation: bool = True, + gaussian_filter_sigma: float = None, + gaussian_filter_iter: 
int = np.inf, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, + butterworth_filter_iter: int = np.inf, + q_lowpass: float = None, + q_highpass: float = None, + butterworth_order: float = 2, + kz_regularization_filter_iter: int = np.inf, + kz_regularization_gamma: Union[float, np.ndarray] = None, + identical_slices_iter: int = 0, + object_positivity: bool = True, + shrinkage_rad: float = 0.0, + fix_potential_baseline: bool = True, + pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, + switch_object_iter: int = np.inf, + store_iterations: bool = False, + progress_bar: bool = True, + reset: bool = None, + ): + """ + Ptychographic reconstruction main method. + + Parameters + -------- + max_iter: int, optional + Maximum number of iterations to run + reconstruction_method: str, optional + Specifies which reconstruction algorithm to use, one of: + "generalized-projections", + "DM_AP" (or "difference-map_alternating-projections"), + "RAAR" (or "relaxed-averaged-alternating-reflections"), + "RRR" (or "relax-reflect-reflect"), + "SUPERFLIP" (or "charge-flipping"), or + "GD" (or "gradient_descent") + reconstruction_parameter: float, optional + Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. + max_batch_size: int, optional + Max number of probes to update at once + seed_random: int, optional + Seeds the random number generator, only applicable when max_batch_size is not None + step_size: float, optional + Update step size + normalization_min: float, optional + Probe normalization minimum as a fraction of the maximum overlap intensity + positions_step_size: float, optional + Positions update step size + fix_com: bool, optional + If True, fixes center of mass of probe + fix_probe_iter: int, optional + Number of iterations to run with a fixed probe before updating probe estimate + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float + Relative location of top-hat inflection point, between 0 and 0.5 + constrain_probe_amplitude_relative_width: float + Relative width of top-hat sigmoid, between 0 and 0.5 + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
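The named schemes listed above are all special cases of generalized projections; the (a, b, c) values assigned by the branches further down can be summarized as follows (helper name hypothetical):

```python
def projection_abc(method: str, beta: float):
    """Map a scheme name to generalized-projection coefficients (a, b, c)."""
    return {
        "DM_AP": (-beta, 1, 1 + beta),    # difference map / alternating projections
        "RAAR": (1 - 2 * beta, beta, 2),  # relaxed averaged alternating reflections
        "RRR": (-beta, beta, 2),          # relax-reflect-reflect
        "SUPERFLIP": (0, 1, 2),           # charge flipping, beta unused
    }[method]

projection_abc("RAAR", 0.5)  # -> (0.0, 0.5, 2)
```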
+ fix_positions_iter: int, optional + Number of iterations to run with fixed positions before updating positions estimate + global_affine_transformation: bool, optional + If True, positions are assumed to be a global affine transform from initial scan + gaussian_filter_sigma: float, optional + Standard deviation of gaussian kernel in A + gaussian_filter_iter: int, optional + Number of iterations to run using object smoothness constraint + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions + butterworth_filter_iter: int, optional + Number of iterations to run using high-pass butterworth filter + q_lowpass: float + Cut-off frequency in A^-1 for low-pass butterworth filter + q_highpass: float + Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + kz_regularization_filter_iter: int, optional + Number of iterations to run using kz regularization filter + kz_regularization_gamma: float, optional + kz regularization strength + identical_slices_iter: int, optional + Number of iterations to run using identical slices + object_positivity: bool, optional + If True, forces object to be positive + shrinkage_rad: float + Phase shift in radians to be subtracted from the potential at each iteration + fix_potential_baseline: bool + If True, the potential mean outside the FOV is forced to zero at each iteration + pure_phase_object_iter: int, optional + Number of iterations where object amplitude is set to unity + tv_denoise_iter_chambolle: int, optional + Number of iterations to run with Chambolle TV denoising + tv_denoise_weight_chambolle: float + Weight of TV denoising constraint + tv_denoise_pad_chambolle: bool + If True, pads object at top and bottom with zeros before applying denoising + tv_denoise_iter: int, optional + Number of iterations to run using TV denoising on object + tv_denoise_weights: [float, float] + Denoising weights [z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: int + Number of iterations to run in inner loop of TV denoising + switch_object_iter: int, optional + Iteration to switch object type between 'complex' and 'potential' or between + 'potential' and 'complex' + store_iterations: bool, optional + If True, reconstructed objects and probes are stored at each iteration + progress_bar: bool, optional + If True, reconstruction progress is displayed + reset: bool, optional + If True, previous reconstructions are ignored + + Returns + -------- + self: MultislicePtychographicReconstruction + Self to accommodate chaining + """ + asnumpy = self._asnumpy + xp = self._xp + + # Reconstruction method + + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): + raise ValueError( + ( + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'."
+ ) + ) + + use_projection_scheme = True + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c + step_size = None + elif ( + reconstruction_method == "DM_AP" + or reconstruction_method == "difference-map_alternating-projections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = 1 + projection_c = 1 + reconstruction_parameter + step_size = None + elif ( + reconstruction_method == "RAAR" + or reconstruction_method == "relaxed-averaged-alternating-reflections" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0: + raise ValueError("reconstruction_parameter must be between 0-1.") + + use_projection_scheme = True + projection_a = 1 - 2 * reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "RRR" + or reconstruction_method == "relax-reflect-reflect" + ): + if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0: + raise ValueError("reconstruction_parameter must be between 0-2.") + + use_projection_scheme = True + projection_a = -reconstruction_parameter + projection_b = reconstruction_parameter + projection_c = 2 + step_size = None + elif ( + reconstruction_method == "SUPERFLIP" + or reconstruction_method == "charge-flipping" + ): + use_projection_scheme = True + projection_a = 0 + projection_b = 1 + projection_c = 2 + reconstruction_parameter = None + step_size = None + elif ( + reconstruction_method == "GD" or reconstruction_method == "gradient-descent" + ): + use_projection_scheme = False + projection_a = None + projection_b = None + projection_c = None + reconstruction_parameter = None + else: + raise ValueError( + ( + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " + "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " + "'RRR' (or 'relax-reflect-reflect'), " + "'SUPERFLIP' (or 'charge-flipping'), " + f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}." + ) + ) + + if self._verbose: + if switch_object_iter > max_iter: + first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, " + else: + switch_object_type = ( + "complex" if self._object_type == "potential" else "potential" + ) + first_line = ( + f"Performing {switch_object_iter} iterations using a {self._object_type} object type and " + f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, " + ) + if max_batch_size is not None: + if use_projection_scheme: + raise ValueError( + ( + "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. " + "Use reconstruction_method='GD' or set max_batch_size=None." + ) + ) + else: + print( + ( + first_line + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}, " + f"in batches of max {max_batch_size} measurements." + ) + ) + + else: + if reconstruction_parameter is not None: + if np.array(reconstruction_parameter).shape == (3,): + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}." 
+ ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}." + ) + ) + else: + if step_size is not None: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min}." + ) + ) + else: + print( + ( + first_line + + f"with the {reconstruction_method} algorithm, " + f"with normalization_min: {normalization_min} and step _size: {step_size}." + ) + ) + + # Batching + shuffled_indices = np.arange(self._num_diffraction_patterns) + unshuffled_indices = np.zeros_like(shuffled_indices) + + if max_batch_size is not None: + xp.random.seed(seed_random) + else: + max_batch_size = self._num_diffraction_patterns + + # initialization + if store_iterations and (not hasattr(self, "object_iterations") or reset): + self.object_iterations = [] + self.probe_iterations = [] + + if reset: + self.error_iterations = [] + self._object = self._object_initial.copy() + self._probe = self._probe_initial.copy() + self._positions_px = self._positions_px_initial.copy() + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + self._exit_waves = None + self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf + elif reset is None: + if hasattr(self, "error"): + warnings.warn( + ( + "Continuing reconstruction from previous result. " + "Use reset=True for a fresh start." + ), + UserWarning, + ) + else: + self.error_iterations = [] + self._exit_waves = None + + # main loop + for a0 in tqdmnd( + max_iter, + desc="Reconstructing object and probe", + unit=" iter", + disable=not progress_bar, + ): + error = 0.0 + + if a0 == switch_object_iter: + if self._object_type == "potential": + self._object_type = "complex" + self._object = xp.exp(1j * self._object) + elif self._object_type == "complex": + self._object_type = "potential" + self._object = xp.angle(self._object) + + # randomize + if not use_projection_scheme: + np.random.shuffle(shuffled_indices) + unshuffled_indices[shuffled_indices] = np.arange( + self._num_diffraction_patterns + ) + positions_px = self._positions_px.copy()[shuffled_indices] + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[shuffled_indices[start:end]] + + # forward operator + ( + propagated_probes, + object_patches, + self._transmitted_probes, + self._exit_waves, + batch_error, + ) = self._forward( + self._object, + self._probe, + amplitudes, + self._exit_waves, + use_projection_scheme, + projection_a, + projection_b, + projection_c, + ) + + # adjoint operator + self._object, self._probe = self._adjoint( + self._object, + self._probe, + object_patches, + propagated_probes, + self._exit_waves, + use_projection_scheme=use_projection_scheme, + step_size=step_size, + normalization_min=normalization_min, + fix_probe=a0 < fix_probe_iter, + ) + + # position correction + if a0 >= fix_positions_iter: + positions_px[start:end] = self._position_correction( + self._object, + self._probe[0], + 
self._transmitted_probes[:, 0], + amplitudes, + self._positions_px, + positions_step_size, + constrain_position_distance, + ) + + error += batch_error + + # Normalize Error + error /= self._mean_diffraction_intensity * self._num_diffraction_patterns + + # constraints + self._positions_px = positions_px.copy()[unshuffled_indices] + self._object, self._probe, self._positions_px = self._constraints( + self._object, + self._probe, + self._positions_px, + fix_com=fix_com and a0 >= fix_probe_iter, + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, + fix_positions=a0 < fix_positions_iter, + global_affine_transformation=global_affine_transformation, + gaussian_filter=a0 < gaussian_filter_iter + and gaussian_filter_sigma is not None, + gaussian_filter_sigma=gaussian_filter_sigma, + butterworth_filter=a0 < butterworth_filter_iter + and (q_lowpass is not None or q_highpass is not None), + q_lowpass=q_lowpass, + q_highpass=q_highpass, + butterworth_order=butterworth_order, + kz_regularization_filter=a0 < kz_regularization_filter_iter + and kz_regularization_gamma is not None, + kz_regularization_gamma=kz_regularization_gamma[a0] + if kz_regularization_gamma is not None + and isinstance(kz_regularization_gamma, np.ndarray) + else kz_regularization_gamma, + identical_slices=a0 < identical_slices_iter, + object_positivity=object_positivity, + shrinkage_rad=shrinkage_rad, + object_mask=self._object_fov_mask_inverse + if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 + else None, + pure_phase_object=a0 < pure_phase_object_iter + and self._object_type == "complex", + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, + orthogonalize_probe=orthogonalize_probe, + ) + + self.error_iterations.append(error.item()) + if store_iterations: + self.object_iterations.append(asnumpy(self._object.copy())) + self.probe_iterations.append(self.probe_centered) + + # store result + self.object = asnumpy(self._object) + self.probe = self.probe_centered + self.error = error.item() + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + + return self + + def _visualize_last_iteration_figax( + self, + fig, + object_ax, + convergence_ax, + cbar: bool, + padding: int = 0, + **kwargs, + ): + """ + Displays last reconstructed object on a given fig/ax. 
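The stochastic inner loop above shuffles the positions once per iteration, visits every pattern in batches, and restores scan order before the constraints step. A schematic of the index bookkeeping (the behaviour of py4DSTEM's generate_batches is assumed here):

```python
import numpy as np

num_patterns, max_batch = 10, 4
shuffled = np.arange(num_patterns)
unshuffled = np.zeros_like(shuffled)

np.random.shuffle(shuffled)
unshuffled[shuffled] = np.arange(num_patterns)  # inverse permutation

def generate_batches(n, max_batch):
    # assumed behaviour of the helper used in the main loop
    for start in range(0, n, max_batch):
        yield start, min(start + max_batch, n)

for start, end in generate_batches(num_patterns, max_batch):
    batch_indices = shuffled[start:end]  # patterns updated this inner step

# positions updated in shuffled order are read back in scan order:
# positions_px = positions_px_shuffled[unshuffled]
```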
+ + Parameters + -------- + fig: Figure + Matplotlib figure object_ax lives in + object_ax: Axes + Matplotlib axes to plot reconstructed object in + convergence_ax: Axes, optional + Matplotlib axes to plot convergence plot in + cbar: bool, optional + If true, displays a colorbar + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + cmap = kwargs.pop("cmap", "magma") + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + im = object_ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + if cbar: + divider = make_axes_locatable(object_ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if convergence_ax is not None and hasattr(self, "error_iterations"): + errors = np.array(self.error_iterations) + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = self.error_iterations + + convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs) + + def _visualize_last_iteration( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + padding: int, + **kwargs, + ): + """ + Displays last reconstructed object and probe iterations. + + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + padding : int, optional + Pixels to pad by post rotating-cropping object + + """ + figsize = kwargs.pop("figsize", (8, 5)) + cmap = kwargs.pop("cmap", "magma") + + chroma_boost = kwargs.pop("chroma_boost", 1) + + if self._object_type == "complex": + obj = np.angle(self.object) + else: + obj = self.object + + rotated_object = self._crop_rotate_object_fov( + np.sum(obj, axis=0), padding=padding + ) + rotated_shape = rotated_object.shape + + extent = [ + 0, + self.sampling[1] * rotated_shape[1], + self.sampling[0] * rotated_shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=2, + height_ratios=[4, 1], + hspace=0.15, + width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec( + ncols=2, + nrows=1, + 
width_ratios=[ + (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]), + 1, + ], + wspace=0.35, + ) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + if plot_probe or plot_fourier_probe: + # Object + ax = fig.add_subplot(spec[0, 0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + # Probe + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + + ax = fig.add_subplot(spec[0, 1]) + if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual[0] + else: + probe_array = self.probe_fourier[0] + + probe_array = Complex2RGB( + probe_array, + chroma_boost=chroma_boost, + ) + + ax.set_title("Reconstructed Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + self.probe[0], power=2, chroma_boost=chroma_boost + ) + ax.set_title("Reconstructed probe[0] intensity") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + else: + ax = fig.add_subplot(spec[0]) + im = ax.imshow( + rotated_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if self._object_type == "potential": + ax.set_title("Reconstructed object potential") + elif self._object_type == "complex": + ax.set_title("Reconstructed object phase") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if plot_convergence and hasattr(self, "error_iterations"): + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + errors = np.array(self.error_iterations) + if plot_probe: + ax = fig.add_subplot(spec[1, :]) + else: + ax = fig.add_subplot(spec[1]) + ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax.set_ylabel("NMSE") + ax.set_xlabel("Iteration number") + ax.yaxis.tick_right() + + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") + spec.tight_layout(fig) + + def _visualize_all_iterations( + self, + fig, + cbar: bool, + plot_convergence: bool, + plot_probe: bool, + plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, + iterations_grid: Tuple[int, int], + padding: int, + **kwargs, + ): + """ + Displays all reconstructed object and probe iterations. 
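The width_ratios expression above keeps the object and probe panels the same height: each panel's width is proportional to its extent aspect ratio. A self-contained sketch (extents hypothetical):

```python
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

obj_extent = (0, 40.0, 30.0, 0)    # [left, right, bottom, top] in A
probe_extent = (0, 10.0, 10.0, 0)

obj_aspect = obj_extent[1] / obj_extent[2]
probe_aspect = probe_extent[1] / probe_extent[2]

fig = plt.figure(figsize=(8, 5))
spec = GridSpec(
    ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15,
    width_ratios=[obj_aspect / probe_aspect, 1], wspace=0.35,
)
ax_object = fig.add_subplot(spec[0, 0])
ax_probe = fig.add_subplot(spec[0, 1])
ax_convergence = fig.add_subplot(spec[1, :])
```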
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + padding : int, optional + Pixels to pad by post rotating-cropping object + """ + asnumpy = self._asnumpy + + if not hasattr(self, "object_iterations"): + raise ValueError( + ( + "Object and probe iterations were not saved during reconstruction. " + "Please re-run using store_iterations=True." + ) + ) + + if iterations_grid == "auto": + num_iter = len(self.error_iterations) + + if num_iter == 1: + return self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + cbar=cbar, + padding=padding, + **kwargs, + ) + elif plot_probe or plot_fourier_probe: + iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter) + else: + iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2) + else: + if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2: + raise ValueError() + + auto_figsize = ( + (3 * iterations_grid[1], 3 * iterations_grid[0] + 1) + if plot_convergence + else (3 * iterations_grid[1], 3 * iterations_grid[0]) + ) + figsize = kwargs.pop("figsize", auto_figsize) + cmap = kwargs.pop("cmap", "magma") + + chroma_boost = kwargs.pop("chroma_boost", 1) + + errors = np.array(self.error_iterations) + + objects = [] + object_type = [] + + for obj in self.object_iterations: + if np.iscomplexobj(obj): + obj = np.angle(obj) + object_type.append("phase") + else: + object_type.append("potential") + objects.append( + self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding) + ) + + if plot_probe or plot_fourier_probe: + total_grids = (np.prod(iterations_grid) / 2).astype("int") + probes = self.probe_iterations + else: + total_grids = np.prod(iterations_grid) + max_iter = len(objects) - 1 + grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1)) + + extent = [ + 0, + self.sampling[1] * objects[0].shape[1], + self.sampling[0] * objects[0].shape[0], + 0, + ] + + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] + + if plot_convergence: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0) + else: + spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0) + else: + if plot_probe or plot_fourier_probe: + spec = GridSpec(ncols=1, nrows=2) + else: + spec = GridSpec(ncols=1, nrows=1) + + if fig is None: + fig = plt.figure(figsize=figsize) + + grid = ImageGrid( + fig, + spec[0], + nrows_ncols=(1, iterations_grid[1]) + if (plot_probe or plot_fourier_probe) + else 
iterations_grid, + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + im = ax.imshow( + objects[grid_range[n]], + extent=extent, + cmap=cmap, + **kwargs, + ) + ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + if cbar: + grid.cbar_axes[n].colorbar(im) + + if plot_probe or plot_fourier_probe: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + grid = ImageGrid( + fig, + spec[1], + nrows_ncols=(1, iterations_grid[1]), + axes_pad=(0.75, 0.5) if cbar else 0.5, + cbar_mode="each" if cbar else None, + cbar_pad="2.5%" if cbar else None, + ) + + for n, ax in enumerate(grid): + if plot_fourier_probe: + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: + probe_array = Complex2RGB( + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, + ) + ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + im = ax.imshow( + probe_array, + extent=probe_extent, + ) + + if cbar: + add_colorbar_arg( + grid.cbar_axes[n], + chroma_boost=chroma_boost, + ) + + if plot_convergence: + kwargs.pop("vmin", None) + kwargs.pop("vmax", None) + if plot_probe: + ax2 = fig.add_subplot(spec[2]) + else: + ax2 = fig.add_subplot(spec[1]) + ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) + ax2.set_ylabel("NMSE") + ax2.set_xlabel("Iteration number") + ax2.yaxis.tick_right() + + spec.tight_layout(fig) + + def visualize( + self, + fig=None, + iterations_grid: Tuple[int, int] = None, + plot_convergence: bool = True, + plot_probe: bool = True, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, + cbar: bool = True, + padding: int = 0, + **kwargs, + ): + """ + Displays reconstructed object and probe. 
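The iterations shown in the grid are evenly spaced snapshots selected by the grid_range logic above; only the first total_grids entries are drawn. A worked example with illustrative numbers:

```python
num_stored, total_grids = 17, 8   # stored iterations, panels available
max_iter = num_stored - 1
grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
list(grid_range)  # [0, 2, 4, 6, 8, 10, 12, 14, 16]; first 8 entries are plotted
```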
+ + Parameters + -------- + fig: Figure + Matplotlib figure to place Gridspec in + plot_convergence: bool, optional + If true, the normalized mean squared error (NMSE) plot is displayed + iterations_grid: Tuple[int,int] + Grid dimensions to plot reconstruction iterations + cbar: bool, optional + If true, displays a colorbar + plot_probe: bool + If true, the reconstructed probe intensity is also displayed + plot_fourier_probe: bool, optional + If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes + padding : int, optional + Pixels to pad by post rotating-cropping object + + Returns + -------- + self: PtychographicReconstruction + Self to accommodate chaining + """ + + if iterations_grid is None: + self._visualize_last_iteration( + fig=fig, + plot_convergence=plot_convergence, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + padding=padding, + **kwargs, + ) + else: + self._visualize_all_iterations( + fig=fig, + plot_convergence=plot_convergence, + iterations_grid=iterations_grid, + plot_probe=plot_probe, + plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + cbar=cbar, + padding=padding, + **kwargs, + ) + return self + + def show_fourier_probe( + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, + ): + """ + Plot probe in fourier space + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe + scalebar: bool, optional + if True, adds scalebar to probe + pixelunits: str, optional + units for scalebar, default is A^-1 + pixelsize: float, optional + default is probe reciprocal sampling + """ + asnumpy = self._asnumpy + + if probe is None: + probe = list( + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + ) + else: + if isinstance(probe, np.ndarray) and probe.ndim == 2: + probe = [probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] + + if pixelsize is None: + pixelsize = self._reciprocal_sampling[1] + if pixelunits is None: + pixelunits = r"$\AA^{-1}$" + + chroma_boost = kwargs.pop("chroma_boost", 1) + + show_complex( + probe if len(probe) > 1 else probe[0], + cbar=cbar, + scalebar=scalebar, + pixelsize=pixelsize, + pixelunits=pixelunits, + ticks=False, + chroma_boost=chroma_boost, + **kwargs, + ) + + def show_transmitted_probe( + self, + plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, + **kwargs, + ): + """ + Plots the min, max, and mean transmitted probe after propagation and transmission. 
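The probes shown by show_transmitted_probe below are picked by ranking integrated intensity, as sketched here on a toy complex stack (the mixed-state mode index is dropped for brevity):

```python
import numpy as np

probes = np.random.rand(100, 32, 32) * np.exp(
    2j * np.pi * np.random.rand(100, 32, 32)
)
intensities = np.sum(np.abs(probes) ** 2, axis=(-2, -1))

selected = {
    "mean": probes.mean(0),
    "min": probes[np.argmin(intensities)],
    "max": probes[np.argmax(intensities)],
}
```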
+ + Parameters + ---------- + plot_fourier_probe: boolean, optional + If True, the transmitted probes are also plotted in Fourier space + kwargs: + Passed to show_complex + """ + + xp = self._xp + asnumpy = self._asnumpy + + transmitted_probe_intensities = xp.sum( + xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1) + ) + min_intensity_transmitted = self._transmitted_probes[ + xp.argmin(transmitted_probe_intensities), 0 + ] + max_intensity_transmitted = self._transmitted_probes[ + xp.argmax(transmitted_probe_intensities), 0 + ] + mean_transmitted = self._transmitted_probes[:, 0].mean(0) + probes = [ + asnumpy(self._return_centered_probe(probe)) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + title = [ + "Mean Transmitted Probe", + "Min Intensity Transmitted Probe", + "Max Intensity Transmitted Probe", + ] + + if plot_fourier_probe: + bottom_row = [ + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for probe in [ + mean_transmitted, + min_intensity_transmitted, + max_intensity_transmitted, + ] + ] + probes = [probes, bottom_row] + + title += [ + "Mean Transmitted Fourier Probe", + "Min Intensity Transmitted Fourier Probe", + "Max Intensity Transmitted Fourier Probe", + ] + + title = kwargs.get("title", title) + show_complex( + probes, + title=title, + **kwargs, + ) + + def show_slices( + self, + ms_object=None, + cbar: bool = True, + common_color_scale: bool = True, + padding: int = 0, + num_cols: int = 3, + show_fft: bool = False, + **kwargs, + ): + """ + Displays reconstructed slices of object + + Parameters + -------- + ms_object: nd.array, optional + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + padding: int, optional + Padding to leave uncropped + num_cols: int, optional + Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices + """ + + if ms_object is None: + ms_object = self._object + + rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) + rotated_shape = rotated_object.shape + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + + extent = [ + 0, + self.sampling[1] * rotated_shape[2], + self.sampling[0] * rotated_shape[1], + 0, + ] + + num_rows = np.ceil(self._num_slices / num_cols).astype("int") + wspace = 0.35 if cbar else 0.15 + + axsize = kwargs.pop("axsize", (3, 3)) + cmap = kwargs.pop("cmap", "magma") + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None + vmin = kwargs.pop("vmin", vmin) + vmax = kwargs.pop("vmax", vmax) + + spec = GridSpec( + ncols=num_cols, + nrows=num_rows, + hspace=0.15, + wspace=wspace, + ) + + figsize = (axsize[0] * num_cols, axsize[1] * num_rows) + fig = plt.figure(figsize=figsize) + + for flat_index, obj_slice in enumerate(rotated_object): + row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols)) + ax = fig.add_subplot(spec[row_index, col_index]) + im = 
ax.imshow( + obj_slice, + cmap=cmap, + vmin=vmin, + vmax=vmax, + extent=extent, + **kwargs, + ) + + ax.set_title(f"Slice index: {flat_index}") + + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + + if row_index < num_rows - 1: + ax.set_xticks([]) + else: + ax.set_xlabel("y [A]") + + if col_index > 0: + ax.set_yticks([]) + else: + ax.set_ylabel("x [A]") + + spec.tight_layout(fig) + + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[ + :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) + ] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + 
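Geometry note for show_depth: the requested (x1, y1) -> (x2, y2) line is brought onto a single row of the volume by rotating about the volume centre and rolling. A sketch of the angle and endpoint bookkeeping; the rotate_point convention below is an assumption, the real helper lives in py4DSTEM.process.phase.utils:

```python
import numpy as np

def rotate_point(origin, point, angle):
    # assumed convention: counter-clockwise rotation of `point` about `origin`
    ox, oy = origin
    px, py = point
    qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy)
    qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy)
    return qx, qy

x1, y1, x2, y2 = 10.0, 5.0, 40.0, 55.0   # section endpoints in pixels
if x2 == x1:
    angle = 0
elif y2 == y1:
    angle = np.pi / 2
else:
    angle = np.arctan((x2 - x1) / (y2 - y1))

x0, y0 = 32.0, 32.0                      # volume centre in pixels
x1r, y1r = rotate_point((x0, y0), (x1, y1), angle)
x2r, y2r = rotate_point((x0, y0), (x2, y2), angle)
```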
ax[0].plot( + [y1 * self.sampling[1], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[0]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + + def tune_num_slices_and_thicknesses( + self, + num_slices_guess=None, + thicknesses_guess=None, + num_slices_step_size=1, + thicknesses_step_size=20, + num_slices_values=3, + num_thicknesses_values=3, + update_defocus=False, + max_iter=5, + plot_reconstructions=True, + plot_convergence=True, + return_values=False, + **kwargs, + ): + """ + Run reconstructions over a parameter space of number of slices + and slice thicknesses. Should be run after the preprocess step. + + Parameters + ---------- + num_slices_guess: float, optional + Initial starting guess for number of slices, rounds to nearest integer. + If None, uses current initialized values + thicknesses_guess: float (A), optional + Initial starting guess for thicknesses of slices, assuming the same + thickness for each slice. + If None, uses current initialized values + num_slices_step_size: float, optional + Size of change of number of slices for each step in parameter space + thicknesses_step_size: float (A), optional + Size of change of slice thicknesses for each step in parameter space + num_slices_values: int, optional + Number of num_slices values to test, must be >= 1 + num_thicknesses_values: int, optional + Number of thicknesses values to test, must be >= 1 + update_defocus: bool, optional + If True, updates defocus based on estimated total thickness + max_iter: int, optional + Number of iterations to run in each ptychographic reconstruction + plot_reconstructions: bool, optional + If True, plots the phase of the reconstructed objects + plot_convergence: bool, optional + If True, plots the error for each iteration of each reconstruction + return_values: bool, optional + If True, returns objects, convergence + + Returns + ------- + objects: list + reconstructed objects + convergence: np.ndarray + array of convergence values from reconstructions + """ + + # calculate number of slices and thicknesses values to test + if num_slices_guess is None: + num_slices_guess = self._num_slices + if thicknesses_guess is None: + thicknesses_guess = np.mean(self._slice_thicknesses) + + if num_slices_values == 1: + num_slices_step_size = 0 + + if num_thicknesses_values == 1: + thicknesses_step_size = 0 + + num_slices = np.linspace( + num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2, + num_slices_values, + ) + + thicknesses = np.linspace( + thicknesses_guess + - thicknesses_step_size * (num_thicknesses_values - 1) / 2, + thicknesses_guess + + thicknesses_step_size * (num_thicknesses_values - 1) / 2, + num_thicknesses_values, + ) + + if return_values: + convergence = [] + objects = [] + + # current initialized values + current_verbose = self._verbose + current_num_slices = self._num_slices + current_thicknesses = self._slice_thicknesses + current_rotation_deg = self._rotation_best_rad * 180 / np.pi + current_transpose =
self._rotation_best_transpose + current_defocus = -self._polar_parameters["C10"] + + # Gridspec to plot on + if plot_reconstructions: + if plot_convergence: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values * 2, + height_ratios=[1, 1 / 4] * num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 5 * num_slices_values) + ) + else: + spec = GridSpec( + ncols=num_thicknesses_values, + nrows=num_slices_values, + hspace=0.15, + wspace=0.35, + ) + figsize = kwargs.get( + "figsize", (4 * num_thicknesses_values, 4 * num_slices_values) + ) + + fig = plt.figure(figsize=figsize) + + progress_bar = kwargs.pop("progress_bar", False) + # run loop and plot along the way + self._verbose = False + for flat_index, (slices, thickness) in enumerate( + tqdmnd(num_slices, thicknesses, desc="Tuning number of slices and thicknesses") + ): + slices = int(slices) + self._num_slices = slices + self._slice_thicknesses = np.tile(thickness, slices - 1) + self._probe = None + self._object = None + if update_defocus: + defocus = current_defocus + slices / 2 * thickness + self._polar_parameters["C10"] = -defocus + + self.preprocess( + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + ) + self.reconstruct( + reset=True, + store_iterations=plot_convergence, + max_iter=max_iter, + progress_bar=progress_bar, + **kwargs, + ) + + if plot_reconstructions: + row_index, col_index = np.unravel_index( + flat_index, (num_slices_values, num_thicknesses_values) + ) + + if plot_convergence: + object_ax = fig.add_subplot(spec[row_index * 2, col_index]) + convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=convergence_ax, + cbar=True, + ) + convergence_ax.yaxis.tick_right() + else: + object_ax = fig.add_subplot(spec[row_index, col_index]) + self._visualize_last_iteration_figax( + fig, + object_ax=object_ax, + convergence_ax=None, + cbar=True, + ) + + object_ax.set_title( + f" num slices = {slices:.0f}, slice thickness = {thickness:.1f} A \n error = {self.error:.3e}" + ) + object_ax.set_xticks([]) + object_ax.set_yticks([]) + + if return_values: + objects.append(self.object) + convergence.append(self.error_iterations.copy()) + + # initialize back to pre-tuning values + self._probe = None + self._object = None + self._num_slices = current_num_slices + self._slice_thicknesses = np.tile(current_thicknesses, current_num_slices - 1) + self._polar_parameters["C10"] = -current_defocus + self.preprocess( + force_com_rotation=current_rotation_deg, + force_com_transpose=current_transpose, + plot_center_of_mass=False, + plot_rotation=False, + plot_probe_overlaps=False, + ) + self._verbose = current_verbose + + if plot_reconstructions: + spec.tight_layout(fig) + + if return_values: + return objects, convergence + + def _return_object_fft( + self, + obj=None, + ): + """ + Returns the object FFT shifted to the center of the array + + Parameters + ---------- + obj: array, optional + If None is specified, uses self._object + """ + asnumpy = self._asnumpy + + if obj is None: + obj = self._object + + obj = asnumpy(obj) + if np.iscomplexobj(obj): + obj = np.angle(obj) + + obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_self_consistency_errors( + self, + max_batch_size=None,
+ ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index ceae66cd8..880858f30 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -14,7 +14,7 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = np from emdfile import Custom, tqdmnd @@ -74,6 +74,8 @@ class MixedstatePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -102,6 +104,7 @@ def __init__( initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "mixed-state_ptychographic_reconstruction", @@ -178,6 +181,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -204,6 +208,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -261,6 +266,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. 
Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -284,6 +291,13 @@ def preprocess( ) ) + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + ( self._datacube, self._vacuum_probe_intensity, @@ -349,6 +363,8 @@ def preprocess( self._intensities, self._com_fitted_x, self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -357,7 +373,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -429,6 +445,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) _probe = ( ComplexProbe( @@ -505,19 +525,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -540,23 +554,19 @@ def preprocess( axs[i].imshow( complex_probe_rgb[i], extent=probe_extent, - **kwargs, ) axs[i].set_ylabel("x [A]") axs[i].set_xlabel("y [A]") - axs[i].set_title(f"Initial Probe[{i}]") + axs[i].set_title(f"Initial probe[{i}] intensity") divider = make_axes_locatable(axs[i]) cax = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax, chroma_boost=chroma_boost) axs[-1].imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) axs[-1].scatter( self.positions[:, 1], @@ -568,7 +578,7 @@ def preprocess( axs[-1].set_xlabel("y [A]") axs[-1].set_xlim((extent[0], extent[1])) axs[-1].set_ylim((extent[2], extent[3])) - axs[-1].set_title("Object Field of View") + axs[-1].set_title("Object field of view") fig.tight_layout() @@ -1125,6 +1135,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, orthogonalize_probe, object_positivity, shrinkage_rad, @@ -1183,6 +1196,12 @@ def _constraints( Butterworth filter order. Smaller gives a smoother filter orthogonalize_probe: bool If True, probe will be orthogonalized + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. 
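The `_return_self_consistency_errors` helper added in the hunks above boils down to one metric: for each scan position, the squared difference between the measured Fourier amplitudes and the incoherent sum over probe modes of the model amplitudes, normalized by the mean diffraction intensity. A minimal NumPy sketch of that computation, with hypothetical array shapes and the batching stripped out:

```python
import numpy as np

def self_consistency_errors(measured_amplitudes, exit_waves, mean_intensity):
    # measured_amplitudes: (N, qx, qy) real; exit_waves: (N, modes, qx, qy) complex
    fourier_overlap = np.fft.fft2(exit_waves)
    # incoherent sum over mixed-state probe modes
    model_amplitudes = np.sqrt(np.sum(np.abs(fourier_overlap) ** 2, axis=1))
    errors = np.sum(np.abs(measured_amplitudes - model_amplitudes) ** 2, axis=(-2, -1))
    return errors / mean_intensity
```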
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1213,6 +1232,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1281,6 +1305,7 @@ def reconstruct( constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, + constrain_position_distance: float = None, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, fit_probe_aberrations_iter: int = 0, @@ -1290,6 +1315,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1353,6 +1381,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1373,6 +1403,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. 
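The growing list of `*_iter` keywords above all follow one gating convention inside `reconstruct`: a constraint is applied on iteration `a0` only while `a0` is below its cutoff and only if its weight was actually supplied, e.g. `tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None`. A sketch of that idiom as a standalone helper (the helper itself is hypothetical, not part of this diff):

```python
import numpy as np

def constraint_active(a0, cutoff_iter, weight):
    """True while iteration a0 is below the cutoff and a weight was supplied."""
    return a0 < cutoff_iter and weight is not None

assert constraint_active(5, np.inf, 0.1)       # on: weight given, cutoff not reached
assert not constraint_active(5, np.inf, None)  # off: no weight supplied
assert not constraint_active(50, 10, 0.1)      # off: past the cutoff iteration
```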
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1575,6 +1611,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1667,6 +1705,7 @@ def reconstruct( amplitudes, self._positions_px, positions_step_size, + constrain_position_distance, ) error += batch_error @@ -1707,6 +1746,9 @@ def reconstruct( q_highpass=q_highpass, butterworth_order=butterworth_order, orthogonalize_probe=orthogonalize_probe, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1800,6 +1842,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -1816,13 +1859,15 @@ def _visualize_last_iteration( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe padding : int, optional Pixels to pad by post rotating-cropping object """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1914,30 +1959,38 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual[0] + else: + probe_array = self.probe_fourier[0] + probe_array = Complex2RGB( - self.probe_fourier[0], hue_start=hue_start, invert=invert + probe_array, + chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe[0], hue_start=hue_start, invert=invert + self.probe[0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe[0]") + ax.set_title("Reconstructed probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1970,10 +2023,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -1983,6 +2036,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: 
Tuple[int, int], padding: int, **kwargs, @@ -2004,6 +2058,9 @@ def _visualize_all_iterations( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -2045,8 +2102,8 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2143,35 +2200,37 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]][0] - ) - ), - hue_start=hue_start, - invert=invert, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]][0], hue_start=hue_start, invert=invert + probes[grid_range[n]][0], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2183,7 +2242,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2195,6 +2254,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -2216,6 +2276,9 @@ def visualize( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -2231,6 +2294,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2242,6 +2306,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2250,7 +2315,14 @@ def visualize( return self def show_fourier_probe( - self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs + self, + probe=None, + remove_initial_probe_aberrations=False, + cbar=True, + 
scalebar=True, + pixelsize=None, + pixelunits=None, + **kwargs, ): """ Plot probe in fourier space @@ -2259,6 +2331,8 @@ def show_fourier_probe( ---------- probe: complex array, optional if None is specified, uses the `probe_fourier` property + remove_initial_probe_aberrations: bool, optional + If True, removes initial probe aberrations from Fourier probe scalebar: bool, optional if True, adds scalebar to probe pixelunits: str, optional @@ -2269,22 +2343,88 @@ def show_fourier_probe( asnumpy = self._asnumpy if probe is None: - probe = list(self.probe_fourier) + probe = list( + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + ) else: if isinstance(probe, np.ndarray) and probe.ndim == 2: probe = [probe] - probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe] + probe = [ + asnumpy( + self._return_fourier_probe( + pr, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) + for pr in probe + ] if pixelsize is None: pixelsize = self._reciprocal_sampling[1] if pixelunits is None: pixelunits = r"$\AA^{-1}$" + chroma_boost = kwargs.pop("chroma_boost", 1) + show_complex( probe if len(probe) > 1 else probe[0], + cbar=cbar, scalebar=scalebar, pixelsize=pixelsize, pixelunits=pixelunits, ticks=False, + chroma_boost=chroma_boost, **kwargs, ) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[start:end] + + # Overlaps + _, _, overlap = self._overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap) + intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1)) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index aee383675..39cb62fdd 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -14,9 +14,13 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = np + import os + # make sure pylops doesn't try to use cupy + os.environ["CUPY_PYLOPS"] = "0" +import pylops # this must follow the exception from emdfile import Custom, tqdmnd from py4DSTEM import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction @@ -29,6 +33,7 @@ spatial_frequencies, ) from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar +from scipy.ndimage import 
rotate warnings.simplefilter(action="always", category=UserWarning) @@ -78,9 +83,17 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): initial_scan_positions: np.ndarray, optional Probe positions in Å for each diffraction intensity If None, initialized to a grid scan + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) + middle_focus: bool + if True, adds half the sample thickness to the defocus object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -109,7 +122,11 @@ def __init__( initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, + theta_x: float = 0, + theta_y: float = 0, + middle_focus: bool = False, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "multi-slice_ptychographic_reconstruction", @@ -142,6 +159,25 @@ def __init__( if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized parameter".format(key)) + if np.isscalar(slice_thicknesses): + mean_slice_thickness = slice_thicknesses + else: + mean_slice_thickness = np.mean(slice_thicknesses) + + if middle_focus: + if "defocus" in kwargs: + kwargs["defocus"] += mean_slice_thickness * num_slices / 2 + elif "C10" in kwargs: + kwargs["C10"] -= mean_slice_thickness * num_slices / 2 + elif polar_parameters is not None and "defocus" in polar_parameters: + polar_parameters["defocus"] = ( + polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2 + ) + elif polar_parameters is not None and "C10" in polar_parameters: + polar_parameters["C10"] = ( + polar_parameters["C10"] - mean_slice_thickness * num_slices / 2 + ) + self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols))) if polar_parameters is None: @@ -181,6 +217,7 @@ def __init__( self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type + self._positions_mask = positions_mask self._object_padding_px = object_padding_px self._verbose = verbose self._device = device @@ -189,6 +226,8 @@ def __init__( # Class-specific Metadata self._num_slices = num_slices self._slice_thicknesses = slice_thicknesses + self._theta_x = theta_x + self._theta_y = theta_y def _precompute_propagator_arrays( self, @@ -196,6 +235,8 @@ def _precompute_propagator_arrays( sampling: Tuple[float, float], energy: float, slice_thicknesses: Sequence[float], + theta_x: float, + theta_y: float, ): """ Precomputes propagator arrays complex wave-function will be convolved by, @@ -211,6 +252,10 @@ def _precompute_propagator_arrays( The electron energy of the wave functions in eV slice_thicknesses: Sequence[float] Array of slice thicknesses in A + theta_x: float + x tilt of propagator (in degrees) + theta_y: float + y tilt of propagator (in degrees) Returns ------- @@ -230,6 +275,10 @@ def _precompute_propagator_arrays( propagators = xp.empty( (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64 ) + + theta_x = np.deg2rad(theta_x) + theta_y = np.deg2rad(theta_y) + for i, dz in enumerate(slice_thicknesses): propagators[i] = xp.exp( 1.0j * (-(kx**2)[:, None] * 
np.pi * wavelength * dz) @@ -237,6 +286,12 @@ def _precompute_propagator_arrays( propagators[i] *= xp.exp( 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz) ) + propagators[i] *= xp.exp( + 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x)) + ) + propagators[i] *= xp.exp( + 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y)) + ) return propagators @@ -279,6 +334,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -336,6 +392,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -359,6 +417,13 @@ def preprocess( ) ) + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + ( self._datacube, self._vacuum_probe_intensity, @@ -424,6 +489,8 @@ def preprocess( self._intensities, self._com_fitted_x, self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -432,7 +499,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -503,6 +570,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -559,6 +630,8 @@ def preprocess( self.sampling, self._energy, self._slice_thicknesses, + self._theta_x, + self._theta_y, ) # overlaps @@ -575,19 +648,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -599,10 +666,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -624,38 +689,34 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - 
ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[:, 1], @@ -667,7 +728,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1449,6 +1510,111 @@ def _object_identical_slices_constraint(self, current_object): return current_object + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1481,9 +1647,12 @@ def _constraints( shrinkage_rad, object_mask, pure_phase_object, + tv_denoise_chambolle, + tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle, tv_denoise, - tv_denoise_weight, - tv_denoise_pad, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. 
@@ -1548,12 +1717,19 @@ def _constraints( If not None, used to calculate additional shrinkage using masked-mean of object pure_phase_object: bool If True, object amplitude is set to unity - tv_denoise: bool + tv_denoise_chambolle: bool If True, performs TV denoising along z - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1585,13 +1761,17 @@ def _constraints( current_object, kz_regularization_gamma ) elif tv_denoise: - if self._object_type == "complex": - raise NotImplementedError() + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) + elif tv_denoise_chambolle: current_object = self._object_denoise_tv_chambolle( current_object, - tv_denoise_weight, + tv_denoise_weight_chambolle, axis=0, - pad_object=tv_denoise_pad, + pad_object=tv_denoise_pad_chambolle, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1690,9 +1870,12 @@ def reconstruct( shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, pure_phase_object_iter: int = 0, + tv_denoise_iter_chambolle=np.inf, + tv_denoise_weight_chambolle=None, + tv_denoise_pad_chambolle=True, tv_denoise_iter=np.inf, - tv_denoise_weight=None, - tv_denoise_pad=True, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, switch_object_iter: int = np.inf, store_iterations: bool = False, progress_bar: bool = True, @@ -1751,6 +1934,8 @@ def reconstruct( If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate + constrain_position_distance: float + Distance to constrain position correction within original field of view in A global_affine_transformation: bool, optional If True, positions are assumed to be a global affine transform from initial scan gaussian_filter_sigma: float, optional @@ -1785,12 +1970,19 @@ def reconstruct( If true, the potential mean outside the FOV is forced to zero at each iteration pure_phase_object_iter: int, optional Number of iterations where object amplitude is set to unity - tv_denoise_iter: bool + tv_denoise_iter_chambolle: bool Number of iterations with TV denoisining - tv_denoise_weight: float + tv_denoise_weight_chambolle: float weight of tv denoising constraint - tv_denoise_pad: bool + tv_denoise_pad_chambolle: bool if True, pads object at top and bottom with zeros before applying denoising + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. 
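One subtlety in the renamed constraints above: the dispatch is an `if`/`elif` chain, so when both TV flavours are requested in the same iteration only the pylops split-Bregman branch runs and the Chambolle branch is skipped (the kz-regularization branch, when active, preempts both). A hypothetical `reconstruct` call showing the two keyword families side by side (the instance name `ms_ptycho` is assumed):

```python
ms_ptycho.reconstruct(
    max_iter=32,
    # z-only TV (Chambolle), formerly the un-suffixed tv_denoise_* keywords
    tv_denoise_iter_chambolle=16,
    tv_denoise_weight_chambolle=0.1,
    tv_denoise_pad_chambolle=True,
    # anisotropic split-Bregman TV (pylops); takes precedence while both are active
    tv_denoise_iter=16,
    tv_denoise_weights=[0.05, 0.1],  # [z weight, r weight]
    tv_denoise_inner_iter=40,
)
```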
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising switch_object_iter: int, optional Iteration to switch object type between 'complex' and 'potential' or between 'potential' and 'complex' @@ -1987,6 +2179,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2123,7 +2317,7 @@ def reconstruct( and kz_regularization_gamma is not None, kz_regularization_gamma=kz_regularization_gamma[a0] if kz_regularization_gamma is not None - and type(kz_regularization_gamma) == np.ndarray + and isinstance(kz_regularization_gamma, np.ndarray) else kz_regularization_gamma, identical_slices=a0 < identical_slices_iter, object_positivity=object_positivity, @@ -2133,9 +2327,13 @@ def reconstruct( else None, pure_phase_object=a0 < pure_phase_object_iter and self._object_type == "complex", - tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, - tv_denoise_weight=tv_denoise_weight, - tv_denoise_pad=tv_denoise_pad, + tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle + and tv_denoise_weight_chambolle is not None, + tv_denoise_weight_chambolle=tv_denoise_weight_chambolle, + tv_denoise_pad_chambolle=tv_denoise_pad_chambolle, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None, + tv_denoise_weights=tv_denoise_weights, + tv_denoise_inner_iter=tv_denoise_inner_iter, ) self.error_iterations.append(error.item()) @@ -2226,6 +2424,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -2244,14 +2443,16 @@ def _visualize_last_iteration( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe padding : int, optional Pixels to pad by post rotating-cropping object """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -2346,30 +2547,36 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + probe_array, + chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -2402,10 +2609,10 @@ def _visualize_last_iteration( ax = 
fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2415,6 +2622,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -2436,6 +2644,9 @@ def _visualize_all_iterations( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -2477,8 +2688,8 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2577,35 +2788,35 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] - ) - ), - hue_start=hue_start, - invert=invert, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], power=2, chroma_boost=chroma_boost ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2617,7 +2828,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2629,6 +2840,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -2650,6 +2862,9 @@ def visualize( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -2665,6 +2880,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + 
remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2676,6 +2892,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -2685,6 +2902,7 @@ def visualize( def show_transmitted_probe( self, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations=False, **kwargs, ): """ @@ -2727,7 +2945,12 @@ def show_transmitted_probe( if plot_fourier_probe: bottom_row = [ - asnumpy(self._return_fourier_probe(probe)) + asnumpy( + self._return_fourier_probe( + probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) + ) for probe in [ mean_transmitted, min_intensity_transmitted, @@ -2756,6 +2979,7 @@ def show_slices( common_color_scale: bool = True, padding: int = 0, num_cols: int = 3, + show_fft: bool = False, **kwargs, ): """ @@ -2771,12 +2995,20 @@ def show_slices( Padding to leave uncropped num_cols: int, optional Number of GridSpec columns + show_fft: bool, optional + if True, plots fft of object slices """ if ms_object is None: ms_object = self._object rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding) + if show_fft: + rotated_object = np.abs( + np.fft.fftshift( + np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1) + ) + ) rotated_shape = rotated_object.shape if np.iscomplexobj(rotated_object): @@ -2794,8 +3026,21 @@ def show_slices( axsize = kwargs.pop("axsize", (3, 3)) cmap = kwargs.pop("cmap", "magma") - vmin = np.min(rotated_object) if common_color_scale else None - vmax = np.max(rotated_object) if common_color_scale else None + + if common_color_scale: + vals = np.sort(rotated_object.ravel()) + ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int") + ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int") + ind_vmin = np.max([0, ind_vmin]) + ind_vmax = np.min([len(vals) - 1, ind_vmax]) + vmin = vals[ind_vmin] + vmax = vals[ind_vmax] + if vmax == vmin: + vmin = vals[0] + vmax = vals[-1] + else: + vmax = None + vmin = None vmin = kwargs.pop("vmin", vmin) vmax = kwargs.pop("vmax", vmax) @@ -2841,6 +3086,145 @@ def show_slices( spec.tight_layout(fig) + def show_depth( + self, + x1: float, + x2: float, + y1: float, + y2: float, + specify_calibrated: bool = False, + gaussian_filter_sigma: float = None, + ms_object=None, + cbar: bool = False, + aspect: float = None, + plot_line_profile: bool = False, + **kwargs, + ): + """ + Displays line profile depth section + + Parameters + -------- + x1, x2, y1, y2: floats (pixels) + Line profile for depth section runs from (x1,y1) to (x2,y2) + Specified in pixels unless specify_calibrated is True + specify_calibrated: bool (optional) + If True, specify x1, x2, y1, y2 in A values instead of pixels + gaussian_filter_sigma: float (optional) + Standard deviation of gaussian kernel in A + ms_object: np.array + Object to plot slices of. 
If None, uses current object + cbar: bool, optional + If True, displays a colorbar + aspect: float, optional + aspect ratio for depth profile plot + plot_line_profile: bool + If True, also plots line profile showing where depth profile is taken + """ + if ms_object is not None: + ms_obj = ms_object + else: + ms_obj = self.object_cropped + + if specify_calibrated: + x1 /= self.sampling[0] + x2 /= self.sampling[0] + y1 /= self.sampling[1] + y2 /= self.sampling[1] + + if x2 == x1: + angle = 0 + elif y2 == y1: + angle = np.pi / 2 + else: + angle = np.arctan((x2 - x1) / (y2 - y1)) + + x0 = ms_obj.shape[1] / 2 + y0 = ms_obj.shape[2] / 2 + + if ( + x1 > ms_obj.shape[1] + or x2 > ms_obj.shape[1] + or y1 > ms_obj.shape[2] + or y2 > ms_obj.shape[2] + ): + raise ValueError("depth section must be in field of view of object") + + from py4DSTEM.process.phase.utils import rotate_point + + x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle) + x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle) + + rotated_object = np.roll( + rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)), + -int(x1_0), + axis=1, + ) + + if np.iscomplexobj(rotated_object): + rotated_object = np.angle(rotated_object) + if gaussian_filter_sigma is not None: + from scipy.ndimage import gaussian_filter + + gaussian_filter_sigma /= self.sampling[0] + rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma) + + plot_im = rotated_object[ + :, 0, np.max((0, int(y1_0))) : np.min((int(y2_0), rotated_object.shape[2])) + ] + + extent = [ + 0, + self.sampling[1] * plot_im.shape[1], + self._slice_thicknesses[0] * plot_im.shape[0], + 0, + ] + figsize = kwargs.pop("figsize", (6, 6)) + if not plot_line_profile: + fig, ax = plt.subplots(figsize=figsize) + im = ax.imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax.set_aspect(aspect) + ax.set_xlabel("r [A]") + ax.set_ylabel("z [A]") + ax.set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + else: + extent2 = [ + 0, + self.sampling[1] * ms_obj.shape[2], + self.sampling[0] * ms_obj.shape[1], + 0, + ] + + fig, ax = plt.subplots(2, 1, figsize=figsize) + ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2) + ax[0].plot( + [y1 * self.sampling[1], y2 * self.sampling[1]], + [x1 * self.sampling[0], x2 * self.sampling[0]], + color="red", + ) + ax[0].set_xlabel("y [A]") + ax[0].set_ylabel("x [A]") + ax[0].set_title("Multislice depth profile location") + + im = ax[1].imshow(plot_im, cmap="magma", extent=extent) + if aspect is not None: + ax[1].set_aspect(aspect) + ax[1].set_xlabel("r [A]") + ax[1].set_ylabel("z [A]") + ax[1].set_title("Multislice depth profile") + if cbar: + divider = make_axes_locatable(ax[1]) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + fig.add_axes(ax_cb) + fig.colorbar(im, cax=ax_cb) + plt.tight_layout() + def tune_num_slices_and_thicknesses( self, num_slices_guess=None, @@ -3068,3 +3452,14 @@ def _return_object_fft( obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped).sum(0) + else: + projected_cropped_potential = self.object_cropped.sum(0) + + return projected_cropped_potential diff --git
a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index b09d18ca7..670ea5e40 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -16,8 +16,13 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = np + import os + + # make sure pylops doesn't try to use cupy + os.environ["CUPY_PYLOPS"] = "0" +import pylops # this must follow the exception from emdfile import Custom, tqdmnd from py4DSTEM import DataCube @@ -92,6 +97,8 @@ class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -114,6 +121,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential", + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None, @@ -178,6 +186,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -430,6 +439,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -474,6 +484,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. 
Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -497,15 +509,43 @@ def preprocess( ) ) + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask) + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_tilts, 1, 1) + ) + + if self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array."), + UserWarning, + ) + self._positions_mask = self._positions_mask.astype("bool") + else: + self._positions_mask = [None] * self._num_tilts + # Prepopulate various arrays - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) + if self._positions_mask[0] is None: + num_probes_per_tilt = [0] + for dc in self._datacube: + rx, ry = dc.Rshape + num_probes_per_tilt.append(rx * ry) + + num_probes_per_tilt = np.array(num_probes_per_tilt) + else: + num_probes_per_tilt = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) - self._num_diffraction_patterns = sum(num_probes_per_tilt) - self._cum_probes_per_tilt = np.cumsum(np.array(num_probes_per_tilt)) + self._num_diffraction_patterns = num_probes_per_tilt.sum() + self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) self._mean_diffraction_intensity = [] self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) @@ -594,6 +634,8 @@ def preprocess( intensities, com_fitted_x, com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -613,7 +655,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels @@ -684,6 +726,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -807,19 +853,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -831,10 +871,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -856,38 +894,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe 
intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -899,7 +936,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1679,6 +1716,111 @@ def _divergence_free_constraint(self, vector_field): return vector_field + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. + `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1710,6 +1852,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. 
@@ -1771,6 +1916,15 @@ def _constraints( If True, forces object to be positive shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration + object_mask: np.ndarray (boolean) + If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1822,6 +1976,31 @@ def _constraints( butterworth_order, ) + elif tv_denoise: + current_object[0] = self._object_denoise_tv_pylops( + current_object[0], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[1] = self._object_denoise_tv_pylops( + current_object[1], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[2] = self._object_denoise_tv_pylops( + current_object[2], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + + current_object[3] = self._object_denoise_tv_pylops( + current_object[3], + tv_denoise_weights, + tv_denoise_inner_iter, + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object[0] = self._object_shrinkage_constraint( current_object[0], @@ -1913,6 +2092,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1998,6 +2180,15 @@ def reconstruct( Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. 
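The constraint body referenced above applies the identical TV step to `current_object[0]` through `current_object[3]`, i.e. the electrostatic channel plus the three magnetic components. An equivalent loop formulation (a sketch, not what this diff does) keeps the four calls in lockstep:

```python
# hypothetical refactor of the four explicit component-wise TV calls
for component in range(4):  # electrostatic potential + three magnetic components
    current_object[component] = self._object_denoise_tv_pylops(
        current_object[component],
        tv_denoise_weights,
        tv_denoise_inner_iter,
    )
```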
+ tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ collective_tilt_updates: bool
+ If True, performs collective tilt updates
shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional
@@ -2171,12 +2362,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy()
+ if hasattr(self, "_tf"):
+ del self._tf
if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None
-
elif reset is None: if hasattr(self, "error"): warnings.warn(
@@ -2477,6 +2669,10 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None,
+ tv_denoise=a0 < tv_denoise_iter
+ and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
) # Normalize Error Over Tilts
@@ -2530,6 +2726,9 @@ if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
) self.error_iterations.append(error.item())
@@ -2929,7 +3128,7 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[-1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE")
- ax.set_xlabel("Iteration Number")
+ ax.set_xlabel("Iteration number")
ax.yaxis.tick_right() spec.tight_layout(fig)
@@ -3129,22 +3328,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma")
- vmin = kwargs.pop("vmin", 0)
- vmax = kwargs.pop("vmax", 1)
- power = kwargs.pop("power", 0.2)
- pixelsize = 1 / (object_fft.shape[0] * self.sampling[0])
+ pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
show( object_fft, figsize=figsize, cmap=cmap,
- vmin=vmin,
- vmax=vmax,
scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$",
- power=power,
**kwargs, )
@@ -3168,3 +3361,29 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all)
+
+ def _return_self_consistency_errors(
+ self,
+ max_batch_size=None,
+ ):
+ """Compute the self-consistency errors for each probe position"""
+ raise NotImplementedError()
+
+ def _return_projected_cropped_potential(
+ self,
+ ):
+ """Utility function to accommodate multiple classes"""
+ raise NotImplementedError()
+
+ def show_uncertainty_visualization(
+ self,
+ errors=None,
+ max_batch_size=None,
+ projected_cropped_potential=None,
+ kde_sigma=None,
+ plot_histogram=True,
+ plot_contours=False,
+ **kwargs,
+ ):
+ """Plot uncertainty visualization using self-consistency errors"""
+ raise NotImplementedError()
diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 1f6be1c38..749028b83 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py
@@ -16,8 +16,13 @@ try: import cupy as cp
-except ModuleNotFoundError:
+except (ModuleNotFoundError, ImportError):
cp = np
+ import os
+
+ # make sure pylops doesn't try to use cupy
+ os.environ["CUPY_PYLOPS"] = "0"
+import pylops # this must follow the exception
from emdfile import Custom, tqdmnd from py4DSTEM import DataCube
@@ -55,8 +60,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): The electron
energy of the wave functions in eV num_slices: int Number of slices to use in the forward model
- tilt_angles_deg: Sequence[float]
- List of tilt angles in degrees,
+ tilt_orientation_matrices: Sequence[np.ndarray]
+ List of orientation matrices for each tilt
semiangle_cutoff: float, optional Semiangle cutoff for the initial probe guess in mrad semiangle_cutoff_pixels: float, optional
@@ -87,6 +92,8 @@ class OverlapTomographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex')
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions to ignore in reconstruction
name: str, optional Class name kwargs:
@@ -94,13 +101,14 @@ """
# Class-specific Metadata
- _class_specific_metadata = ("_num_slices", "_tilt_angles_deg")
+ _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices")
+ _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
def __init__( self, energy: float, num_slices: int,
- tilt_angles_deg: Sequence[float],
+ tilt_orientation_matrices: Sequence[np.ndarray],
datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, semiangle_cutoff_pixels: float = None,
@@ -109,6 +117,7 @@ def __init__( polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, object_type: str = "potential",
+ positions_mask: np.ndarray = None,
initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: Sequence[np.ndarray] = None,
@@ -122,22 +131,29 @@ if device == "cpu": self._xp = np self._asnumpy = np.asarray
- from scipy.ndimage import gaussian_filter, rotate, zoom
+ from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom
self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate
+ self._affine_transform = affine_transform
from scipy.special import erf self._erf = erf elif device == "gpu": self._xp = cp self._asnumpy = cp.asnumpy
- from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom
+ from cupyx.scipy.ndimage import (
+ affine_transform,
+ gaussian_filter,
+ rotate,
+ zoom,
+ )
self._gaussian_filter = gaussian_filter self._zoom = zoom self._rotate = rotate
+ self._affine_transform = affine_transform
from cupyx.scipy.special import erf self._erf = erf
@@ -156,7 +172,7 @@ def __init__( polar_parameters.update(kwargs) self._set_polar_parameters(polar_parameters)
- num_tilts = len(tilt_angles_deg)
+ num_tilts = len(tilt_orientation_matrices)
if initial_scan_positions is None: initial_scan_positions = [None] * num_tilts
@@ -179,13 +195,14 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px
+ self._positions_mask = positions_mask
self._verbose = verbose self._device = device self._preprocessed = False
# Class-specific Metadata self._num_slices = num_slices
- self._tilt_angles_deg = tuple(tilt_angles_deg)
+ self._tilt_orientation_matrices = tuple(tilt_orientation_matrices)
self._num_tilts = num_tilts
def _precompute_propagator_arrays(
@@ -323,6 +340,29 @@ def _expand_sliced_object(self, array: np.ndarray, output_z): normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None] return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z]
+ def _rotate_zxy_volume(
+ self,
+ volume_array,
+ rot_matrix,
+ ):
+ """Rotates a (z, x, y)-ordered volume by rot_matrix about the volume center."""
+
+ xp = self._xp
+ affine_transform =
self._affine_transform + swap_zxy_to_xyz = self._swap_zxy_to_xyz + + volume = volume_array.copy() + volume_shape = xp.asarray(volume.shape) + tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz) + + in_center = (volume_shape - 1) / 2 + out_center = tf @ in_center + offset = in_center - out_center + + volume = affine_transform(volume, tf, offset=offset, order=3) + + return volume + def preprocess( self, diffraction_intensities_shape: Tuple[int, int] = None, @@ -340,6 +380,7 @@ def preprocess( force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -384,6 +425,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -407,14 +450,43 @@ def preprocess( ) ) + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask) + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_tilts, 1, 1) + ) + + if self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array."), + UserWarning, + ) + self._positions_mask = self._positions_mask.astype("bool") + else: + self._positions_mask = [None] * self._num_tilts + # Prepopulate various arrays - num_probes_per_tilt = [0] - for dc in self._datacube: - rx, ry = dc.Rshape - num_probes_per_tilt.append(rx * ry) - self._num_diffraction_patterns = sum(num_probes_per_tilt) - self._cum_probes_per_tilt = np.cumsum(np.array(num_probes_per_tilt)) + if self._positions_mask[0] is None: + num_probes_per_tilt = [0] + for dc in self._datacube: + rx, ry = dc.Rshape + num_probes_per_tilt.append(rx * ry) + + num_probes_per_tilt = np.array(num_probes_per_tilt) + else: + num_probes_per_tilt = np.insert( + self._positions_mask.sum(axis=(-2, -1)), 0, 0 + ) + + self._num_diffraction_patterns = num_probes_per_tilt.sum() + self._cum_probes_per_tilt = np.cumsum(num_probes_per_tilt) self._mean_diffraction_intensity = [] self._positions_px_all = np.empty((self._num_diffraction_patterns, 2)) @@ -503,6 +575,8 @@ def preprocess( intensities, com_fitted_x, com_fitted_y, + crop_patterns, + self._positions_mask[tilt_index], ) self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp) @@ -522,7 +596,7 @@ def preprocess( tilt_index + 1 ] ] = self._calculate_scan_positions_in_pixels( - self._scan_positions[tilt_index] + self._scan_positions[tilt_index], self._positions_mask[tilt_index] ) # handle semiangle specified in pixels @@ -593,6 +667,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -663,15 +741,14 @@ def preprocess( # overlaps if object_fov_mask is None: probe_overlap_3D = xp.zeros_like(self._object) + old_rot_matrix = np.eye(3) # identity for tilt_index in np.arange(self._num_tilts): - current_angle_deg = self._tilt_angles_deg[tilt_index] - probe_overlap_3D = self._rotate( + rot_matrix = self._tilt_orientation_matrices[tilt_index] + + probe_overlap_3D = self._rotate_zxy_volume( probe_overlap_3D, - current_angle_deg, - axes=(0, 
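# --- Editor's note: a minimal CPU sketch of the center-preserving rotation
# implemented by `_rotate_zxy_volume` above. `affine_transform` maps *output*
# coordinates through the matrix, so the offset is chosen so the volume
# center maps onto itself. The tilt angle here is illustrative.
import numpy as np
from scipy.ndimage import affine_transform
from scipy.spatial.transform import Rotation

volume = np.random.rand(16, 16, 16)  # (z, x, y)-ordered stand-in volume
swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
rot_matrix = Rotation.from_euler("x", 15, degrees=True).as_matrix()

tf = swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz
in_center = (np.array(volume.shape) - 1) / 2
offset = in_center - tf @ in_center  # keeps the rotation centered

rotated = affine_transform(volume, tf, offset=offset, order=3)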
2), - reshape=False, - order=2, + rot_matrix @ old_rot_matrix.T, ) self._positions_px = self._positions_px_all[ @@ -691,14 +768,12 @@ def preprocess( ) probe_overlap_3D += probe_overlap[None] + old_rot_matrix = rot_matrix - probe_overlap_3D = self._rotate( - probe_overlap_3D, - -current_angle_deg, - axes=(0, 2), - reshape=False, - order=2, - ) + probe_overlap_3D = self._rotate_zxy_volume( + probe_overlap_3D, + old_rot_matrix.T, + ) probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0) self._object_fov_mask = asnumpy( @@ -719,19 +794,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (13, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) # propagated @@ -743,10 +812,8 @@ def preprocess( ) complex_propagated_rgb = Complex2RGB( asnumpy(self._return_centered_probe(propagated_probe)), - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -768,38 +835,37 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( complex_propagated_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax2) cax2 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax2, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax2, + chroma_boost=chroma_boost, ) ax2.set_ylabel("x [A]") ax2.set_xlabel("y [A]") - ax2.set_title("Propagated Probe") + ax2.set_title("Propagated probe intensity") ax3.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="Greys_r", ) ax3.scatter( self.positions[0, :, 1], @@ -811,7 +877,7 @@ def preprocess( ax3.set_xlabel("y [A]") ax3.set_xlim((extent[0], extent[1])) ax3.set_ylim((extent[2], extent[3])) - ax3.set_title("Object Field of View") + ax3.set_title("Object field of view") fig.tight_layout() @@ -1527,6 +1593,111 @@ def _object_butterworth_constraint( current_object += current_object_mean return xp.real(current_object) + def _object_denoise_tv_pylops(self, current_object, weights, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weights : [float, float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + iterations: float + Number of iterations to run in denoising algorithm. 
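# --- Editor's note: the loops above rotate the running volume by the
# *relative* matrix `rot_matrix @ old_rot_matrix.T` instead of rotating back
# to the lab frame after every tilt, so each tilt costs one interpolation
# pass instead of two. A runnable sketch of the bookkeeping (the zxy/xyz
# basis swap is omitted here for brevity):
import numpy as np
from scipy.ndimage import affine_transform

def rotate_volume(vol, rot):  # stand-in for `_rotate_zxy_volume`
    center = (np.array(vol.shape) - 1) / 2
    return affine_transform(vol, rot, offset=center - rot @ center, order=1)

volume = np.zeros((8, 8, 8))
volume[2:6, 2:6, 2:6] = 1.0
tilt_matrices = [np.eye(3)] * 3  # illustrative orientation matrices

old_rot_matrix = np.eye(3)  # identity
for rot_matrix in tilt_matrices:
    volume = rotate_volume(volume, rot_matrix @ old_rot_matrix.T)
    # ... per-tilt processing happens here, in the tilted frame ...
    old_rot_matrix = rot_matrix
volume = rotate_volume(volume, old_rot_matrix.T)  # one return to the lab frame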
+ `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + # zero pad at top and bottom slice + pad_width = ((1, 1), (0, 0), (0, 0)) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + # run tv denoising + nz, nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny * nz) + + if weights[0] == 0: + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[1]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + elif weights[1] == 0: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + l1_regs = [z_gradient] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weights[0]], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + else: + z_gradient = pylops.FirstDerivative( + (nz, nx, ny), axis=0, edge=False, kind="backward" + ) + xy_laplacian = pylops.Laplacian( + (nz, nx, ny), axes=(1, 2), edge=False, kind="backward" + ) + l1_regs = [z_gradient, xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=weights, + tol=1e-4, + tau=1.0, + show=False, + )[0] + + # remove padding + current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1] + + return current_object_tv + def _constraints( self, current_object, @@ -1555,6 +1726,9 @@ def _constraints( object_positivity, shrinkage_rad, object_mask, + tv_denoise, + tv_denoise_weights, + tv_denoise_inner_iter, ): """ Ptychographic constraints operator. @@ -1611,6 +1785,13 @@ def _constraints( Phase shift in radians to be subtracted from the potential at each iteration object_mask: np.ndarray (boolean) If not None, used to calculate additional shrinkage using masked-mean of object + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weights: [float,float] + Denoising weights[z weight, r weight]. The greater `weight`, + the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising Returns -------- @@ -1634,6 +1815,12 @@ def _constraints( q_highpass, butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, + tv_denoise_weights, + tv_denoise_inner_iter, + ) if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( @@ -1723,6 +1910,9 @@ def reconstruct( object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, + tv_denoise_iter=np.inf, + tv_denoise_weights=None, + tv_denoise_inner_iter=40, collective_tilt_updates: bool = False, store_iterations: bool = False, progress_bar: bool = True, @@ -1806,6 +1996,15 @@ def reconstruct( Butterworth filter order. 
Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive
+ tv_denoise_iter: int, optional
+ Number of outer iterations during which TV denoising is applied to the object
+ tv_denoise_weights: [float,float]
+ Denoising weights [z weight, r weight]. The greater `weight`,
+ the more denoising.
+ tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ collective_tilt_updates: bool
+ If True, performs collective tilt updates
shrinkage_rad: float Phase shift in radians to be subtracted from the potential at each iteration store_iterations: bool, optional
@@ -1981,12 +2180,13 @@ def reconstruct( self.error_iterations = [] self._probe = self._probe_initial.copy() self._positions_px_all = self._positions_px_initial_all.copy()
+ if hasattr(self, "_tf"):
+ del self._tf
if use_projection_scheme: self._exit_waves = [None] * self._num_tilts else: self._exit_waves = None
-
elif reset is None: if hasattr(self, "error"): warnings.warn(
@@ -2018,17 +2218,17 @@ def reconstruct( tilt_indices = np.arange(self._num_tilts) np.random.shuffle(tilt_indices)
+ old_rot_matrix = np.eye(3) # identity
+
for tilt_index in tilt_indices: self._active_tilt_index = tilt_index
tilt_error = 0.0
- self._object = self._rotate(
+ rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index]
+ self._object = self._rotate_zxy_volume(
self._object,
- self._tilt_angles_deg[self._active_tilt_index],
- axes=(0, 2),
- reshape=False,
- order=3,
+ rot_matrix @ old_rot_matrix.T,
)
object_sliced = self._project_sliced_object(
@@ -2132,23 +2332,13 @@ def reconstruct( )
if collective_tilt_updates:
- collective_object += self._rotate(
- object_update,
- -self._tilt_angles_deg[self._active_tilt_index],
- axes=(0, 2),
- reshape=False,
- order=3,
+ collective_object += self._rotate_zxy_volume(
+ object_update, rot_matrix.T
)
else: self._object += object_update
- self._object = self._rotate(
- self._object,
- -self._tilt_angles_deg[self._active_tilt_index],
- axes=(0, 2),
- reshape=False,
- order=3,
- )
+ old_rot_matrix = rot_matrix
# Normalize Error tilt_error /= (
@@ -2203,8 +2393,14 @@ def reconstruct( if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None,
+ tv_denoise=a0 < tv_denoise_iter
+ and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
)
+ self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T)
+
# Normalize Error Over Tilts error /= self._num_tilts
@@ -2251,6 +2447,9 @@ if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0 else None,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
) self.error_iterations.append(error.item())
@@ -2399,6 +2598,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool,
+ remove_initial_probe_aberrations: bool,
projection_angle_deg: float, projection_axes: Tuple[int, int], x_lims: Tuple[int, int],
@@ -2420,6 +2620,8 @@ plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed
+ remove_initial_probe_aberrations: bool, optional
+ If True, when plotting the Fourier probe, removes the initial probe aberrations
projection_angle_deg: float Angle in degrees to rotate 3D array around prior to projection projection_axes: tuple(int,int)
@@ -2429,10 +2631,12 @@ def
_visualize_last_iteration( y_lims: tuple(float,float) min/max y indices """ + asnumpy = self._asnumpy + figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + chroma_boost = kwargs.pop("chroma_boost", 1) asnumpy = self._asnumpy @@ -2533,17 +2737,26 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + probe_array, + chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") @@ -2556,7 +2769,10 @@ def _visualize_last_iteration( if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg( + ax_cb, + chroma_boost=chroma_boost, + ) else: ax = fig.add_subplot(spec[0]) im = ax.imshow( @@ -2585,10 +2801,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -2598,6 +2814,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], projection_angle_deg: float, projection_axes: Tuple[int, int], @@ -2620,6 +2837,9 @@ def _visualize_all_iterations( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes iterations_grid: Tuple[int,int] Grid dimensions to plot reconstruction iterations projection_angle_deg: float @@ -2672,8 +2892,8 @@ def _visualize_all_iterations( ) figsize = kwargs.pop("figsize", auto_figsize) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2782,35 +3002,37 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] - ) - ), - hue_start=hue_start, - invert=invert, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) + ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - 
probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2822,7 +3044,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2834,6 +3056,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, projection_angle_deg: float = None, projection_axes: Tuple[int, int] = (0, 2), @@ -2856,6 +3079,9 @@ def visualize( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes iterations_grid: Tuple[int,int] Grid dimensions to plot reconstruction iterations projection_angle_deg: float @@ -2879,6 +3105,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, projection_angle_deg=projection_angle_deg, projection_axes=projection_axes, @@ -2893,6 +3120,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, projection_angle_deg=projection_angle_deg, projection_axes=projection_axes, @@ -2997,22 +3225,16 @@ def show_object_fft( figsize = kwargs.pop("figsize", (6, 6)) cmap = kwargs.pop("cmap", "magma") - vmin = kwargs.pop("vmin", 0) - vmax = kwargs.pop("vmax", 1) - power = kwargs.pop("power", 0.2) - pixelsize = 1 / (object_fft.shape[0] * self.sampling[0]) + pixelsize = 1 / (object_fft.shape[1] * self.sampling[1]) show( object_fft, figsize=figsize, cmap=cmap, - vmin=vmin, - vmax=vmax, scalebar=True, pixelsize=pixelsize, ticks=False, pixelunits=r"$\AA^{-1}$", - power=power, **kwargs, ) @@ -3036,3 +3258,29 @@ def positions(self): positions_all.append(asnumpy(positions)) return np.asarray(positions_all) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + raise NotImplementedError() + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + raise NotImplementedError() + + def show_uncertainty_visualization( + self, + errors=None, + max_batch_size=None, + projected_cropped_potential=None, + kde_sigma=None, + plot_histogram=True, + plot_contours=False, + **kwargs, + ): + """Plot uncertainty visualization using self-consistency errors""" + raise NotImplementedError() diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 7c5896b6a..4877512d7 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ 
b/py4DSTEM/process/phase/iterative_parallax.py @@ -8,22 +8,44 @@ import matplotlib.pyplot as plt import numpy as np -from emdfile import Custom, tqdmnd +from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd from matplotlib.gridspec import GridSpec -from py4DSTEM import DataCube +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable +from py4DSTEM import Calibration, DataCube +from py4DSTEM.preprocess.utils import get_shifted_ar from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import AffineTransform from py4DSTEM.process.utils.cross_correlate import align_images_fourier from py4DSTEM.process.utils.utils import electron_wavelength_angstrom +from py4DSTEM.visualize import show from scipy.linalg import polar +from scipy.optimize import minimize from scipy.special import comb try: import cupy as cp -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): cp = np warnings.simplefilter(action="always", category=UserWarning) +_aberration_names = { + (1, 0): "C1 ", + (1, 2): "stig ", + (2, 1): "coma ", + (2, 3): "trefoil ", + (3, 0): "C3 ", + (3, 2): "stig2 ", + (3, 4): "quadfoil ", + (4, 1): "coma2 ", + (4, 3): "trefoil2 ", + (4, 5): "pentafoil ", + (5, 0): "C5 ", + (5, 2): "stig3 ", + (5, 4): "quadfoil2 ", + (5, 6): "hexafoil ", +} + class ParallaxReconstruction(PhaseReconstruction): """ @@ -35,9 +57,6 @@ class ParallaxReconstruction(PhaseReconstruction): Input 4D diffraction pattern intensities energy: float The electron energy of the wave functions in eV - dp_mean: ndarray, optional - Mean diffraction pattern - If None, get_dp_mean() is used verbose: bool, optional If True, class methods will inherit this and print additional information device: str, optional @@ -73,6 +92,8 @@ def __init__( else: raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}") + self.set_save_defaults() + # Data self._datacube = datacube @@ -86,9 +107,78 @@ def __init__( def to_h5(self, group): """ Wraps datasets and metadata to write in emdfile classes, - notably ... + notably the (subpixel-)aligned BF. 
""" - raise NotImplementedError() + # instantiation metadata + self.metadata = Metadata( + name="instantiation_metadata", + data={ + "energy": self._energy, + "verbose": self._verbose, + "device": self._device, + "object_padding_px": self._object_padding_px, + "name": self.name, + }, + ) + + # preprocessing metadata + self.metadata = Metadata( + name="preprocess_metadata", + data={ + "scan_sampling": self._scan_sampling, + "wavelength": self._wavelength, + }, + ) + + # reconstruction metadata + recon_metadata = {"reconstruction_error": float(self._recon_error)} + + if hasattr(self, "aberration_C1"): + recon_metadata |= { + "aberration_rotation_QR": self.rotation_Q_to_R_rads, + "aberration_transpose": self.transpose, + "aberration_C1": self.aberration_C1, + "aberration_A1x": self.aberration_A1x, + "aberration_A1y": self.aberration_A1y, + } + + if hasattr(self, "_kde_upsample_factor"): + recon_metadata |= { + "kde_upsample_factor": self._kde_upsample_factor, + } + self._subpixel_aligned_BF_emd = Array( + name="subpixel_aligned_BF", + data=self._asnumpy(self._recon_BF_subpixel_aligned), + ) + + if hasattr(self, "aberration_dict"): + self.metadata = Metadata( + name="aberrations_metadata", + data={ + v["aberration name"]: v["value [Ang]"] + for k, v in self.aberration_dict.items() + }, + ) + + self.metadata = Metadata( + name="reconstruction_metadata", + data=recon_metadata, + ) + + self._aligned_BF_emd = Array( + name="aligned_BF", + data=self._asnumpy(self._recon_BF), + ) + + # datacube + if self._save_datacube: + self.metadata = self._datacube.calibration + Custom.to_h5(self, group) + else: + dc = self._datacube + self._datacube = None + Custom.to_h5(self, group) + self._datacube = dc @classmethod def _get_constructor_args(cls, group): @@ -96,14 +186,68 @@ def _get_constructor_args(cls, group): Returns a dictionary of arguments/values to pass to the class' __init__ function """ - raise NotImplementedError() + # Get data + dict_data = cls._get_emd_attr_data(cls, group) + + # Get metadata dictionaries + instance_md = _read_metadata(group, "instantiation_metadata") + + # Fix calibrations bug + if "_datacube" in dict_data: + calibrations_dict = _read_metadata(group, "calibration")._params + cal = Calibration() + cal._params.update(calibrations_dict) + dc = dict_data["_datacube"] + dc.calibration = cal + else: + dc = None + + # Populate args and return + kwargs = { + "datacube": dc, + "energy": instance_md["energy"], + "object_padding_px": instance_md["object_padding_px"], + "name": instance_md["name"], + "verbose": True, # for compatibility + "device": "cpu", # for compatibility + } + + return kwargs def _populate_instance(self, group): """ Sets post-initialization properties, notably some preprocessing meta optional; during read, this method is run after object instantiation. 
""" - raise NotImplementedError() + + xp = self._xp + + # Preprocess metadata + preprocess_md = _read_metadata(group, "preprocess_metadata") + self._scan_sampling = preprocess_md["scan_sampling"] + self._wavelength = preprocess_md["wavelength"] + + # Reconstruction metadata + reconstruction_md = _read_metadata(group, "reconstruction_metadata") + self._recon_error = reconstruction_md["reconstruction_error"] + + # Data + dict_data = Custom._get_emd_attr_data(Custom, group) + + if "aberration_C1" in reconstruction_md.keys: + self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"] + self.transpose = reconstruction_md["aberration_transpose"] + self.aberration_C1 = reconstruction_md["aberration_C1"] + self.aberration_A1x = reconstruction_md["aberration_A1x"] + self.aberration_A1y = reconstruction_md["aberration_A1y"] + + if "kde_upsample_factor" in reconstruction_md.keys: + self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"] + self._recon_BF_subpixel_aligned = xp.asarray( + dict_data["_subpixel_aligned_BF_emd"].data, dtype=xp.float32 + ) + + self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32) def preprocess( self, @@ -111,6 +255,7 @@ def preprocess( threshold_intensity: float = 0.8, normalize_images: bool = True, normalize_order=0, + descan_correct: bool = True, defocus_guess: float = None, rotation_guess: float = None, plot_average_bf: bool = True, @@ -133,6 +278,8 @@ def preprocess( defocus_guess: float, optional Initial guess of defocus value (defocus dF) in A If None, first iteration is assumed to be in-focus + descan_correct: float, optional + If True, aligns bright field stack based on measured descan rotation_guess: float, optional Initial guess of defocus value in degrees If None, first iteration assumed to be 0 @@ -171,7 +318,10 @@ def preprocess( self._datacube, require_calibrations=True, ) - self._intensities = xp.asarray(self._intensities, dtype=xp.float32) + + self._region_of_interest_shape = np.array(self._intensities.shape[-2:]) + self._scan_shape = np.array(self._intensities.shape[:2]) + # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -180,6 +330,45 @@ def preprocess( "dp_mean must match the datacube shape. Try setting dp_mean = None." 
)
+ # descan correct
+ if descan_correct:
+ (
+ _,
+ _,
+ com_fitted_x,
+ com_fitted_y,
+ _,
+ _,
+ ) = self._calculate_intensities_center_of_mass(
+ self._intensities,
+ dp_mask=None,
+ fit_function="plane",
+ com_shifts=None,
+ com_measured=None,
+ )
+
+ com_fitted_x = asnumpy(com_fitted_x)
+ com_fitted_y = asnumpy(com_fitted_y)
+ intensities = asnumpy(self._intensities)
+ intensities_shifted = np.zeros_like(intensities)
+
+ center_x, center_y = self._region_of_interest_shape / 2
+
+ for rx in range(intensities_shifted.shape[0]):
+ for ry in range(intensities_shifted.shape[1]):
+ intensity_shifted = get_shifted_ar(
+ intensities[rx, ry],
+ -com_fitted_x[rx, ry] + center_x,
+ -com_fitted_y[rx, ry] + center_y,
+ bilinear=True,
+ device="cpu",
+ )
+
+ intensities_shifted[rx, ry] = intensity_shifted
+
+ self._intensities = xp.asarray(intensities_shifted, xp.float32)
+ self._dp_mean = self._intensities.mean((0, 1))
+
# select virtual detector pixels self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity) self._num_bf_images = int(xp.count_nonzero(self._dp_mask))
@@ -187,14 +376,16 @@ def preprocess( # diffraction space coordinates self._xy_inds = np.argwhere(self._dp_mask)
- self._kxy = (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None]) * xp.array(
- self._reciprocal_sampling
- )[None]
+ self._kxy = xp.asarray(
+ (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None])
+ * xp.array(self._reciprocal_sampling)[None],
+ dtype=xp.float32,
+ )
self._probe_angles = self._kxy * self._wavelength self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1))
# Window function
- x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1)[1:]
+ x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:]
x -= (x[1] - x[0]) / 2 wx = ( xp.sin(
@@ -205,7 +396,7 @@ ) ** 2 )
- y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1)[1:]
+ y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1, dtype=xp.float32)[1:]
y -= (y[1] - y[0]) / 2 wy = ( xp.sin(
@@ -222,7 +413,8 @@ ( self._grid_scan_shape[0] + self._object_padding_px[0], self._grid_scan_shape[1] + self._object_padding_px[1],
- )
+ ),
+ dtype=xp.float32,
)
self._window_pad[ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
@@ -245,7 +437,8 @@ self._grid_scan_shape[1] + self._object_padding_px[1], )
if normalize_images:
- self._stack_BF = xp.ones(stack_shape)
+ self._stack_BF = xp.ones(stack_shape, dtype=xp.float32)
+ self._stack_BF_no_window = xp.ones(stack_shape, xp.float32)
if normalize_order == 0: all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None]
@@ -259,13 +452,21 @@ self._window_inv[None] + self._window_edge[None] * all_bfs )
+ self._stack_BF_no_window[
+ :,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = all_bfs
+
elif normalize_order == 1:
- x = xp.linspace(-0.5, 0.5, all_bfs.shape[1])
- y = xp.linspace(-0.5, 0.5, all_bfs.shape[2])
+ x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], dtype=xp.float32)
+ y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], dtype=xp.float32)
ya, xa = xp.meshgrid(y, x) basis = np.vstack( (
- xp.ones(xa.size),
+ xp.ones_like(xa),
xa.ravel(), ya.ravel(), )
@@ -285,9 +486,18 @@ basis @ coefs[0], all_bfs.shape[1:3] )
+ self._stack_BF_no_window[
+ a0,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 :
self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3]) + else: all_means = xp.mean(all_bfs, axis=(1, 2)) self._stack_BF = xp.full(stack_shape, all_means[:, None, None]) + self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None]) self._stack_BF[ :, self._object_padding_px[0] // 2 : self._grid_scan_shape[0] @@ -299,9 +509,21 @@ def preprocess( + self._window_edge[None] * all_bfs ) + self._stack_BF_no_window[ + :, + self._object_padding_px[0] // 2 : self._grid_scan_shape[0] + + self._object_padding_px[0] // 2, + self._object_padding_px[1] // 2 : self._grid_scan_shape[1] + + self._object_padding_px[1] // 2, + ] = all_bfs + # Fourier space operators for image shifts qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1) + qx = xp.asarray(qx, dtype=xp.float32) + qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1) + qy = xp.asarray(qy, dtype=xp.float32) + qxa, qya = xp.meshgrid(qx, qy, indexing="ij") self._qx_shift = -2j * xp.pi * qxa self._qy_shift = -2j * xp.pi * qya @@ -336,7 +558,7 @@ def preprocess( del Gs else: - self._xy_shifts = xp.zeros((self._num_bf_images, 2)) + self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) self._stack_mean = xp.mean(self._stack_BF) self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images @@ -365,16 +587,29 @@ def preprocess( self.recon_BF = asnumpy(self._recon_BF) if plot_average_bf: - figsize = kwargs.pop("figsize", (6, 6)) + figsize = kwargs.pop("figsize", (8, 4)) - fig, ax = plt.subplots(figsize=figsize) + fig, ax = plt.subplots(1, 2, figsize=figsize) - self._visualize_figax(fig, ax, **kwargs) + self._visualize_figax(fig, ax[0], **kwargs) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - ax.set_title("Average Bright Field Image") + ax[0].set_ylabel("x [A]") + ax[0].set_xlabel("y [A]") + ax[0].set_title("Average Bright Field Image") + reciprocal_extent = [ + -0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]), + 0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + -0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]), + ] + ax[1].imshow( + self._asnumpy(self._dp_mask), extent=reciprocal_extent, cmap="gray" + ) + ax[1].set_title("DP mask") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + plt.tight_layout() self._preprocessed = True if self._device == "gpu": @@ -506,8 +741,6 @@ def tune_angle_and_defocus( convergence.append(asnumpy(self._recon_error[0])) if plot_convergence: - from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable - fig, ax = plt.subplots() ax.set_title("convergence") im = ax.imshow( @@ -533,9 +766,9 @@ def tune_angle_and_defocus( divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) + fig.colorbar(im, cax=cax) - plt.tight_layout() + fig.tight_layout() if return_values: convergence = np.array(convergence).reshape( @@ -548,7 +781,7 @@ def reconstruct( max_alignment_bin: int = None, min_alignment_bin: int = 1, max_iter_at_min_bin: int = 2, - upsample_factor: int = 8, + cross_correlation_upsample_factor: int = 8, regularizer_matrix_size: Tuple[int, int] = (1, 1), regularize_shifts: bool = True, running_average: bool = True, @@ -570,7 +803,7 @@ def reconstruct( Minimum bin size for bright field alignment max_iter_at_min_bin: int, optional Number of iterations to run at the smallest bin size - upsample_factor: int, optional + 
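# --- Editor's note: the `_qx_shift` / `_qy_shift` operators above apply
# image shifts as linear phase ramps in Fourier space. A minimal NumPy
# sketch of shifting a single image by a subpixel amount this way; the
# shift values are illustrative:
import numpy as np

image = np.random.rand(64, 64).astype(np.float32)
dx, dy = 2.5, -1.25  # shift in pixels

qx = np.fft.fftfreq(image.shape[0], d=1)
qy = np.fft.fftfreq(image.shape[1], d=1)
qxa, qya = np.meshgrid(qx, qy, indexing="ij")
shift_op = np.exp(-2j * np.pi * (qxa * dx + qya * dy))

shifted = np.real(np.fft.ifft2(np.fft.fft2(image) * shift_op))
# as noted in the code above, the FFT round-trip upcasts to float64,
# hence the explicit cast back to float32 in the pipeline
shifted = shifted.astype(np.float32)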
cross_correlation_upsample_factor: int, optional DFT upsample factor for subpixel alignment regularizer_matrix_size: Tuple[int,int], optional Bernstein basis degree used for regularizing shifts @@ -623,7 +856,8 @@ def reconstruct( ( self._num_bf_images, (regularizer_matrix_size[0] + 1) * (regularizer_matrix_size[1] + 1), - ) + ), + dtype=xp.float32, ) for ii in np.arange(regularizer_matrix_size[0] + 1): Bi = ( @@ -708,7 +942,7 @@ def reconstruct( # Sort by radial order, from center to outer edge inds_order = xp.argsort(xp.sum(xy_vals**2, axis=1)) - shifts_update = xp.zeros((self._num_bf_images, 2)) + shifts_update = xp.zeros((self._num_bf_images, 2), dtype=xp.float32) for a1 in tqdmnd( xy_vals.shape[0], @@ -730,7 +964,7 @@ def reconstruct( xy_shift = align_images_fourier( G_ref, G, - upsample_factor=upsample_factor, + upsample_factor=cross_correlation_upsample_factor, device=self._device, ) @@ -777,11 +1011,19 @@ def reconstruct( self._qx_shift[None] * dx[:, None, None] + self._qy_shift[None] * dy[:, None, None] ) + self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op)) self._stack_mask = xp.real( xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op) ) + self._stack_BF = xp.asarray( + self._stack_BF, dtype=xp.float32 + ) # numpy fft upcasts? + self._stack_mask = xp.asarray( + self._stack_mask, dtype=xp.float32 + ) # numpy fft upcasts? + del Gs # Center the shifts @@ -837,59 +1079,757 @@ def reconstruct( return self + def subpixel_alignment( + self, + kde_upsample_factor=None, + kde_sigma=0.125, + plot_upsampled_BF_comparison: bool = True, + plot_upsampled_FFT_comparison: bool = False, + **kwargs, + ): + """ + Upsample and subpixel-align BFs using the measured image shifts. + Uses kernel density estimation (KDE) to align upsampled BFs. + + Parameters + ---------- + kde_upsample_factor: int, optional + Real-space upsampling factor + kde_sigma: float, optional + KDE gaussian kernel bandwidth + plot_upsampled_BF_comparison: bool, optional + If True, the pre/post alignment BF images are plotted for comparison + plot_upsampled_FFT_comparison: bool, optional + If True, the pre/post alignment BF FFTs are plotted for comparison + + """ + xp = self._xp + asnumpy = self._asnumpy + gaussian_filter = self._gaussian_filter + + xy_shifts = self._xy_shifts + BF_size = np.array(self._stack_BF_no_window.shape[-2:]) + + self._DF_upsample_limit = np.max( + 2 * self._region_of_interest_shape / self._scan_shape + ) + self._BF_upsample_limit = ( + 4 * self._kr.max() / self._reciprocal_sampling[0] + ) / self._scan_shape.max() + if self._device == "gpu": + self._BF_upsample_limit = self._BF_upsample_limit.item() + + if kde_upsample_factor is None: + if self._BF_upsample_limit * 3 / 2 > self._DF_upsample_limit: + kde_upsample_factor = self._DF_upsample_limit + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (the " + f"dark-field upsampling limit)." + ), + UserWarning, + ) + + elif self._BF_upsample_limit * 3 / 2 > 1: + kde_upsample_factor = self._BF_upsample_limit * 3 / 2 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the " + f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})." + ), + UserWarning, + ) + else: + kde_upsample_factor = self._DF_upsample_limit * 2 / 3 + + warnings.warn( + ( + f"Upsampling factor set to {kde_upsample_factor:.2f} (2/3 times the " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f})." 
+ ), + UserWarning, + ) + + if kde_upsample_factor < 1: + raise ValueError("kde_upsample_factor must be larger than 1") + + if kde_upsample_factor > self._DF_upsample_limit: + warnings.warn( + ( + "Requested upsampling factor exceeds " + f"dark-field upsampling limit of {self._DF_upsample_limit:.2f}." + ), + UserWarning, + ) + + self._kde_upsample_factor = kde_upsample_factor + pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int") + pixel_size = pixel_output.prod() + + # shifted coordinates + x = xp.arange(BF_size[0]) + y = xp.arange(BF_size[1]) + + xa, ya = xp.meshgrid(x, y, indexing="ij") + xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel() + ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel() + + # bilinear sampling + xF = xp.floor(xa).astype("int") + yF = xp.floor(ya).astype("int") + dx = xa - xF + dy = ya - yF + + # resampling + inds_1D = xp.ravel_multi_index( + xp.hstack( + [ + [xF, yF], + [xF + 1, yF], + [xF, yF + 1], + [xF + 1, yF + 1], + ] + ), + pixel_output, + mode=["wrap", "wrap"], + ) + + weights = xp.hstack( + ( + (1 - dx) * (1 - dy), + (dx) * (1 - dy), + (1 - dx) * (dy), + (dx) * (dy), + ) + ) + + pix_count = xp.reshape( + xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output + ) + pix_output = xp.reshape( + xp.bincount( + inds_1D, + weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4), + minlength=pixel_size, + ), + pixel_output, + ) + + # kernel density estimate + sigma = kde_sigma * self._kde_upsample_factor + pix_count = gaussian_filter(pix_count, sigma) + pix_count[pix_count == 0.0] = np.inf + pix_output = gaussian_filter(pix_output, sigma) + pix_output /= pix_count + + self._recon_BF_subpixel_aligned = pix_output + self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned) + + # plotting + if plot_upsampled_BF_comparison: + if plot_upsampled_FFT_comparison: + figsize = kwargs.pop("figsize", (8, 8)) + fig, axs = plt.subplots(2, 2, figsize=figsize) + else: + figsize = kwargs.pop("figsize", (8, 4)) + fig, axs = plt.subplots(1, 2, figsize=figsize) + + axs = axs.flat + cmap = kwargs.pop("cmap", "magma") + + cropped_object = self._crop_padded_object(self._recon_BF) + cropped_object_aligned = self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] + + axs[0].imshow( + cropped_object, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[0].set_title("Aligned Bright Field") + + axs[1].imshow( + cropped_object_aligned, + extent=extent, + cmap=cmap, + **kwargs, + ) + axs[1].set_title("Upsampled Bright Field") + + for ax in axs[:2]: + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") + + if plot_upsampled_FFT_comparison: + recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF))) + pad_x = np.round( + BF_size[0] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_y = np.round( + BF_size[1] * (self._kde_upsample_factor - 1) / 2 + ).astype("int") + pad_recon_fft = asnumpy( + xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y))) + ) + + upsampled_fft = asnumpy( + xp.fft.fftshift( + xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + ) + ) + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / 
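# --- Editor's note: the KDE upsampling above scatters each shifted BF pixel
# onto the upsampled grid with bilinear weights via `bincount`, then divides
# a Gaussian-smoothed signal by a Gaussian-smoothed density. A minimal 1D
# analogue with illustrative sizes, assuming NumPy/SciPy:
import numpy as np
from scipy.ndimage import gaussian_filter

values = np.random.rand(100)  # stand-in BF intensities
positions = np.random.rand(100) * 100  # shifted coordinates, in output pixels
n_out = 400  # upsampled grid length

xF = np.floor(positions).astype(int)
dx = positions - xF
inds = np.concatenate([xF, xF + 1]) % n_out  # wrap, as in ravel_multi_index
weights = np.concatenate([1 - dx, dx])  # bilinear weights

counts = np.bincount(inds, weights=weights, minlength=n_out)
signal = np.bincount(inds, weights=weights * np.tile(values, 2), minlength=n_out)

sigma = 0.125 * 4  # kde_sigma * upsampling factor
counts = gaussian_filter(counts, sigma)
counts[counts == 0.0] = np.inf  # avoid divide-by-zero in empty bins
estimate = gaussian_filter(signal, sigma) / counts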
self._kde_upsample_factor), + ] + + show( + pad_recon_fft, + figax=(fig, axs[2]), + extent=reciprocal_extent, + cmap="gray", + title="Aligned Bright Field FFT", + **kwargs, + ) + + show( + upsampled_fft, + figax=(fig, axs[3]), + extent=reciprocal_extent, + cmap="gray", + title="Upsampled Bright Field FFT", + **kwargs, + ) + + for ax in axs[2:]: + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.xaxis.set_ticks_position("bottom") + + fig.tight_layout() + def aberration_fit( self, - plot_CTF_compare: bool = False, - plot_dk: float = 0.005, - plot_k_sigma: float = 0.02, + fit_BF_shifts: bool = False, + fit_CTF_FFT: bool = False, + fit_aberrations_max_radial_order: int = 3, + fit_aberrations_max_angular_order: int = 4, + fit_aberrations_min_radial_order: int = 2, + fit_aberrations_min_angular_order: int = 0, + fit_max_thon_rings: int = 6, + fit_power_alpha: float = 2.0, + plot_CTF_comparison: bool = None, + plot_BF_shifts_comparison: bool = None, + upsampled: bool = True, + force_transpose: bool = False, + force_rotation_deg: float = None, ): """ Fit aberrations to the measured image shifts. Parameters ---------- - plot_CTF_compare: bool, optional - If True, the fitted CTF is plotted against the reconstructed frequencies - plot_dk: float, optional - Reciprocal bin-size for polar-averaged FFT - plot_k_sigma: float, optional - sigma to gaussian blur polar-averaged FFT by + fit_BF_shifts: bool + Set to True to fit aberrations to the measured BF shifts directly. + fit_CTF_FFT: bool + Set to True to fit aberrations in the FFT of the (upsampled) BF + image. Note that this method relies on visible zero crossings in the FFT. + fit_aberrations_max_radial_order: int + Max radial order for fitting of aberrations. + fit_aberrations_max_angular_order: int + Max angular order for fitting of aberrations. + fit_aberrations_min_radial_order: int + Min radial order for fitting of aberrations. + fit_aberrations_min_angular_order: int + Min angular order for fitting of aberrations. + fit_max_thon_rings: int + Max number of Thon rings to search for during CTF FFT fitting. + fit_power_alpha: int + Power to raise FFT alpha weighting during CTF FFT fitting. + plot_CTF_comparison: bool, optional + If True, the fitted CTF is plotted against the reconstructed frequencies. + plot_BF_shifts_comparison: bool, optional + If True, the measured vs fitted BF shifts are plotted. + upsampled: bool + If True, and upsampled BF is available, uses that for CTF FFT fitting. + force_transpose: bool + If True, flips the measured x and y shifts. + force_rotation_deg: float + If not None, sets the rotation angle to value in degrees. 
""" xp = self._xp asnumpy = self._asnumpy - gaussian_filter = self._gaussian_filter + + ### First pass # Convert real space shifts to Angstroms - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + if force_transpose is True: + self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( + self._scan_sampling + ) + else: + self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + self.transpose = force_transpose # Solve affine transformation m = asnumpy( xp.linalg.lstsq(self._probe_angles, self._xy_shifts_Ang, rcond=None)[0] ) - m_rotation, m_aberration = polar(m, side="right") - # Convert into rotation and aberration coefficients - self.rotation_Q_to_R_rads = -1 * np.arctan2(m_rotation[1, 0], m_rotation[0, 0]) - if np.abs(np.mod(self.rotation_Q_to_R_rads + np.pi, 2.0 * np.pi) - np.pi) > ( - np.pi * 0.5 - ): - self.rotation_Q_to_R_rads = ( - np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi + if force_rotation_deg is None: + m_rotation, m_aberration = polar(m, side="right") + + if force_transpose: + m_rotation = m_rotation.T + + # Convert into rotation and aberration coefficients + + self.rotation_Q_to_R_rads = -1 * np.arctan2( + m_rotation[1, 0], m_rotation[0, 0] ) - m_aberration = -1.0 * m_aberration + if np.abs( + np.mod(self.rotation_Q_to_R_rads + np.pi, 2.0 * np.pi) - np.pi + ) > (np.pi * 0.5): + self.rotation_Q_to_R_rads = ( + np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi + ) + m_aberration = -1.0 * m_aberration + else: + self.rotation_Q_to_R_rads = np.deg2rad(force_rotation_deg) + c, s = np.cos(self.rotation_Q_to_R_rads), np.sin(self.rotation_Q_to_R_rads) + + m_rotation = np.array([[c, -s], [s, c]]) + if force_transpose: + m_rotation = m_rotation.T + + m_aberration = m_rotation @ m + self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0 - self.aberration_A1x = ( - m_aberration[0, 0] - m_aberration[1, 1] - ) / 2.0 # factor /2 for A1 astigmatism? /4? 
- self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if self._device == "gpu": - xp._default_memory_pool.free_all_blocks() - xp.clear_memo() + if self.transpose: + self.aberration_A1x = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + else: + self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 + self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + + ### Second pass + + # Aberration coefs + mn = [] + + for m in range( + fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order + ): + n_max = np.minimum(fit_aberrations_max_angular_order, m + 1) + for n in range(fit_aberrations_min_angular_order, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + self._aberrations_mn = np.array(mn) + self._aberrations_mn = self._aberrations_mn[ + np.argsort(self._aberrations_mn[:, 1]), : + ] + + sub = self._aberrations_mn[:, 1] > 0 + self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][ + np.argsort(self._aberrations_mn[sub, 0]), : + ] + self._aberrations_mn[~sub, :] = self._aberrations_mn[~sub, :][ + np.argsort(self._aberrations_mn[~sub, 0]), : + ] + self._aberrations_num = self._aberrations_mn.shape[0] + + if plot_CTF_comparison is None: + if fit_CTF_FFT: + plot_CTF_comparison = True + + if plot_BF_shifts_comparison is None: + if fit_BF_shifts: + plot_BF_shifts_comparison = True + + # Thon Rings Fitting + if fit_CTF_FFT or plot_CTF_comparison: + if upsampled and hasattr(self, "_kde_upsample_factor"): + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned)) + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + + reciprocal_extent = [ + -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor), + 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor), + ] + + else: + im_FFT = xp.abs(xp.fft.fft2(self._recon_BF)) + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + upsampled = False + + reciprocal_extent = [ + -0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[1], + 0.5 / self._scan_sampling[0], + -0.5 / self._scan_sampling[0], + ] + + # FFT coordinates + qx = xp.fft.fftfreq(im_FFT.shape[0], sx) + qy = xp.fft.fftfreq(im_FFT.shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha_FFT = xp.sqrt(qr2) * self._wavelength + theta_FFT = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + self._aberrations_basis_FFT = xp.zeros( + (alpha_FFT.size, self._aberrations_num) + ) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) / (m + 1) + ).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1) + ).ravel() + else: + # sin coef + self._aberrations_basis_FFT[:, a0] = ( + alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape_FFT = alpha_FFT.shape + plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1) + angular_mask = np.cos(8.0 * theta_FFT) ** 2 < 0.25 + + # CTF function + def calculate_CTF_FFT(alpha_shape, *coefs): + chi = 
xp.zeros_like(self._aberrations_basis_FFT[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis_FFT[:, a0] + return xp.reshape(chi, alpha_shape) + + # Direct Shifts Fitting + if fit_BF_shifts: + # FFT coordinates + sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0]) + sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1]) + qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx) + qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy) + qx, qy = np.meshgrid(qx, qy, indexing="ij") + + # passive rotation basis by -theta + rotation_angle = -self.rotation_Q_to_R_rads + qx, qy = qx * np.cos(rotation_angle) + qy * np.sin( + rotation_angle + ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle) + + qr2 = qx**2 + qy**2 + u = qx * self._wavelength + v = qy * self._wavelength + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy, qx) + + # Aberration basis + self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num)) + self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + + if n == 0: + # Radially symmetric basis + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel() + self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel() + + elif a == 0: + # cos coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta)) + / (m + 1) + ).ravel() + + else: + # sin coef + self._aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() + self._aberrations_basis_du[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta)) + / (m + 1) + ).ravel() + self._aberrations_basis_dv[:, a0] = ( + alpha ** (m - 1) + * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta)) + / (m + 1) + ).ravel() + + # global scaling + self._aberrations_basis *= 2 * np.pi / self._wavelength + self._aberrations_surface_shape = alpha.shape + + # CTF function + def calculate_CTF(alpha_shape, *coefs): + chi = xp.zeros_like(self._aberrations_basis[:, 0]) + for a0 in range(len(coefs)): + chi += coefs[a0] * self._aberrations_basis[:, a0] + return xp.reshape(chi, alpha_shape) + + # initial coefficients and plotting intensity range mask + self._aberrations_coefs = np.zeros(self._aberrations_num) + + aberrations_mn_list = self._aberrations_mn.tolist() + if [1, 0, 0] in aberrations_mn_list: + ind_C1 = aberrations_mn_list.index([1, 0, 0]) + self._aberrations_coefs[ind_C1] = self.aberration_C1 + + if [1, 2, 0] in aberrations_mn_list: + ind_A1x = aberrations_mn_list.index([1, 2, 0]) + ind_A1y = aberrations_mn_list.index([1, 2, 1]) + self._aberrations_coefs[ind_A1x] = self.aberration_A1x + self._aberrations_coefs[ind_A1y] = self.aberration_A1y + + # Refinement using CTF fitting / Thon rings + if fit_CTF_FFT: + # scoring function to minimize - mean value of zero crossing regions of FFT + def score_CTF(coefs): + im_CTF = xp.abs( + calculate_CTF_FFT(self._aberrations_surface_shape_FFT, 
*coefs) + ) + mask = xp.logical_and( + im_CTF > 0.5 * np.pi, + im_CTF < (max_num_rings + 0.5) * np.pi, + ) + if np.any(mask): + weights = xp.cos(im_CTF[mask]) ** 4 + return asnumpy( + xp.sum( + weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha + ) + / xp.sum(weights) + ) + else: + return np.inf + + for max_num_rings in range(1, fit_max_thon_rings + 1): + # minimization + res = minimize( + score_CTF, + self._aberrations_coefs, + # method = 'Nelder-Mead', + # method = 'CG', + method="BFGS", + tol=1e-8, + ) + self._aberrations_coefs = res.x + + # Refinement using CTF fitting / Thon rings + elif fit_BF_shifts: + # Gradient basis + corner_indices = self._xy_inds - xp.asarray( + self._region_of_interest_shape // 2 + ) + raveled_indices = np.ravel_multi_index( + corner_indices.T, self._region_of_interest_shape, mode="wrap" + ) + gradients = xp.vstack( + ( + self._aberrations_basis_du[raveled_indices, :], + self._aberrations_basis_dv[raveled_indices, :], + ) + ) + + # (Relative) untransposed fit + raveled_shifts = self._xy_shifts_Ang.T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, raveled_shifts, rcond=None + )[:2] + + self._aberrations_coefs = asnumpy(aberrations_coefs) + + if self.transpose: + aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & ( + self._aberrations_mn[:, 2] == 0 + ) + self._aberrations_coefs[aberrations_to_flip] *= -1 + + # Plot the measured/fitted shifts comparison + if plot_BF_shifts_comparison: + measured_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 0] + + measured_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + measured_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = self._xy_shifts_Ang[:, 1] + + fitted_shifts = ( + xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1) + .reshape((2, -1)) + .T + ) + + fitted_shifts_sx = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sx[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 0] + + fitted_shifts_sy = xp.zeros( + self._region_of_interest_shape, dtype=xp.float32 + ) + fitted_shifts_sy[ + self._xy_inds[:, 0], self._xy_inds[:, 1] + ] = fitted_shifts[:, 1] + + max_shift = xp.max( + xp.array( + [ + xp.abs(measured_shifts_sx).max(), + xp.abs(measured_shifts_sy).max(), + xp.abs(fitted_shifts_sx).max(), + xp.abs(fitted_shifts_sy).max(), + ] + ) + ) + + show( + [ + [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)], + [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)], + ], + cmap="PiYG", + vmin=-max_shift, + vmax=max_shift, + intensity_range="absolute", + axsize=(4, 4), + ticks=False, + title=[ + "Measured Vertical Shifts", + "Fitted Vertical Shifts", + "Measured Horizontal Shifts", + "Fitted Horizontal Shifts", + ], + ) + + # Plot the CTF comparison between experiment and fit + if plot_CTF_comparison: + # Generate FFT plotting image + im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha) + int_vals = np.sort(im_scale.ravel()) + int_range = ( + int_vals[np.round(0.02 * im_scale.size).astype("int")], + int_vals[np.round(0.98 * im_scale.size).astype("int")], + ) + int_range = ( + int_range[0], + (int_range[1] - int_range[0]) * 1.0 + int_range[0], + ) + im_scale = np.clip( + (np.fft.fftshift(im_scale) - int_range[0]) + / (int_range[1] - int_range[0]), + 0, + 1, + ) + im_plot = np.tile(im_scale[:, :, None], (1, 1, 3)) + + # Add CTF zero crossings + im_CTF = 
calculate_CTF_FFT( + self._aberrations_surface_shape_FFT, *self._aberrations_coefs + ) + im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4 + im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2 + im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15 + im_CTF[xp.logical_not(plot_mask)] = 0 + + im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask)) + im_plot[:, :, 0] += im_CTF + im_plot[:, :, 1] -= im_CTF + im_plot[:, :, 2] -= im_CTF + im_plot = np.clip(im_plot, 0, 1) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4)) + ax1.imshow( + im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent + ) + ax2.imshow( + np.fft.fftshift(asnumpy(im_CTF_cos)), + cmap="gray", + extent=reciprocal_extent, + ) + + for ax in (ax1, ax2): + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + + ax1.set_title("Aligned Bright Field FFT") + ax2.set_title("Fitted CTF Zero-Crossings") + + fig.tight_layout() + + self.aberration_dict = { + tuple(self._aberrations_mn[a0]): { + "aberration name": _aberration_names.get( + tuple(self._aberrations_mn[a0, :2]), "-" + ).strip(), + "value [Ang]": self._aberrations_coefs[a0], + } + for a0 in range(self._aberrations_num) + } # Print results if self._verbose: + if fit_CTF_FFT or fit_BF_shifts: + print("Initial Aberration coefficients") + print("-------------------------------") print( ( "Rotation of Q w.r.t. R = " @@ -903,99 +1843,105 @@ def aberration_fit( f"{self.aberration_A1y:.0f}) Ang" ) ) - if self.aberration_C1 > 0: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - else: - print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") - print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") - - # Plot the CTF comparison between experiment and fit - if plot_CTF_compare: - # Get polar mean from FFT of BF reconstruction - im_fft = xp.abs(xp.fft.fft2(self._recon_BF)) - - # coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) - kra = xp.sqrt(kx[:, None] ** 2 + ky[None, :] ** 2) - k_max = xp.max(kra) / np.sqrt(2.0) - k_num_bins = int(xp.ceil(k_max / plot_dk)) - k_bins = xp.arange(k_num_bins + 1) * plot_dk - - # histogram - k_ind = kra / plot_dk - kf = np.floor(k_ind).astype("int") - dk = k_ind - kf - sub = kf <= k_num_bins - hist_exp = xp.bincount( - kf[sub], weights=im_fft[sub] * (1 - dk[sub]), minlength=k_num_bins - ) - hist_norm = xp.bincount( - kf[sub], weights=(1 - dk[sub]), minlength=k_num_bins - ) - sub = kf <= k_num_bins - 1 + print(f"Aberration C1 = {self.aberration_C1:.0f} Ang") + print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang") + print(f"Transpose = {self.transpose}") + + if fit_CTF_FFT or fit_BF_shifts: + print() + print("Refined Aberration coefficients") + print("-------------------------------") + print("aberration radial angular dir. 
coefs") + print("name order order Ang ") + print("---------- ------- ------- ---- -----") + + for a0 in range(self._aberrations_mn.shape[0]): + m, n, a = self._aberrations_mn[a0] + name = _aberration_names.get((m, n), " -- ") + if n == 0: + print( + name + + " " + + str(m + 1) + + " 0 - " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + elif a == 0: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " x " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) + else: + print( + name + + " " + + str(m + 1) + + " " + + str(n) + + " y " + + str(np.round(self._aberrations_coefs[a0]).astype("int")) + ) - hist_exp += xp.bincount( - kf[sub] + 1, weights=im_fft[sub] * (dk[sub]), minlength=k_num_bins - ) - hist_norm += xp.bincount( - kf[sub] + 1, weights=(dk[sub]), minlength=k_num_bins - ) + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() - # KDE and normalizing - k_sigma = plot_dk / plot_k_sigma - hist_exp[0] = 0.0 - hist_exp = gaussian_filter(hist_exp, sigma=k_sigma, mode="nearest") - hist_norm = gaussian_filter(hist_norm, sigma=k_sigma, mode="nearest") - hist_exp /= hist_norm + def _calculate_CTF(self, alpha_shape, sampling, *coefs): + xp = self._xp - # CTF comparison - CTF_fit = xp.sin( - (-np.pi * self._wavelength * self.aberration_C1) * k_bins**2 - ) + # FFT coordinates + sx, sy = sampling + qx = xp.fft.fftfreq(alpha_shape[0], sx) + qy = xp.fft.fftfreq(alpha_shape[1], sy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + + alpha = xp.sqrt(qr2) * self._wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.zeros((alpha.size, self._aberrations_num)) + for a0 in range(self._aberrations_num): + m, n, a = self._aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - # plotting input - log scale - min_hist_val = xp.max(hist_exp) * 1e-3 - hist_plot = xp.log(np.maximum(hist_exp, min_hist_val)) - hist_plot -= xp.min(hist_plot) - hist_plot /= xp.max(hist_plot) + # global scaling + aberrations_basis *= 2 * np.pi / self._wavelength - hist_plot = asnumpy(hist_plot) - k_bins = asnumpy(k_bins) - CTF_fit = asnumpy(CTF_fit) + chi = xp.zeros_like(aberrations_basis[:, 0]) - fig, ax = plt.subplots(figsize=(8, 4)) + for a0 in range(len(coefs)): + chi += coefs[a0] * aberrations_basis[:, a0] - ax.fill_between( - k_bins, - hist_plot, - color=(0.7, 0.7, 0.7, 1), - ) - - ax.plot( - k_bins, - np.clip(CTF_fit, 0.0, np.inf), - color=(1, 0, 0, 1), - linewidth=2, - ) - ax.plot( - k_bins, - np.clip(-CTF_fit, 0.0, np.inf), - color=(0, 0.5, 1, 1), - linewidth=2, - ) - ax.set_xlim([0, k_bins[-1]]) - ax.set_ylim([0, 1.05]) + return xp.reshape(chi, alpha_shape) def aberration_correct( self, + use_CTF_fit=None, plot_corrected_phase: bool = True, k_info_limit: float = None, k_info_power: float = 1.0, Wiener_filter=False, - Wiener_signal_noise_ratio=1.0, - Wiener_filter_low_only=False, + Wiener_signal_noise_ratio: float = 1.0, + Wiener_filter_low_only: bool = False, + upsampled: bool = True, **kwargs, ): """ @@ -1003,6 +1949,9 @@ def aberration_correct( Parameters ---------- + use_FFT_fit: bool + Use the CTF fitted to the zero crossings of the FFT. 
+ Default is True plot_corrected_phase: bool, optional If True, the CTF-corrected phase is plotted k_info_limit: float, optional @@ -1028,46 +1977,79 @@ def aberration_correct( ) ) + if upsampled and hasattr(self, "_kde_upsample_factor"): + im = self._recon_BF_subpixel_aligned + sx = self._scan_sampling[0] / self._kde_upsample_factor + sy = self._scan_sampling[1] / self._kde_upsample_factor + else: + upsampled = False + im = self._recon_BF + sx = self._scan_sampling[0] + sy = self._scan_sampling[1] + # Fourier coordinates - kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0]) - ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1]) + kx = xp.fft.fftfreq(im.shape[0], sx) + ky = xp.fft.fftfreq(im.shape[1], sy) kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2 - # CTF - sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + if use_CTF_fit is None: + if hasattr(self, "_aberrations_surface_shape"): + use_CTF_fit = True - if Wiener_filter: - SNR_inv = ( - xp.sqrt( - 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power)) - ) - / Wiener_signal_noise_ratio + if use_CTF_fit: + sin_chi = np.sin( + self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs) ) - CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) - if Wiener_filter_low_only: - # limit Wiener filter to only the part of the CTF before 1st maxima - k_thresh = 1 / xp.sqrt( - 2.0 * self._wavelength * xp.abs(self.aberration_C1) - ) - k_mask = kra2 >= k_thresh**2 - CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) - # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr - - else: - # CTF without tilt correction (beyond the parallax operator) CTF_corr = xp.sign(sin_chi) CTF_corr[0, 0] = 0 # apply correction to mean reconstructed BF image - im_fft_corr = xp.fft.fft2(self._recon_BF) * CTF_corr + im_fft_corr = xp.fft.fft2(im) * CTF_corr # if needed, add low pass filter output image if k_info_limit is not None: im_fft_corr /= 1 + (kra2**k_info_power) / ( (k_info_limit) ** (2 * k_info_power) ) + else: + # CTF + sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2) + + if Wiener_filter: + SNR_inv = ( + xp.sqrt( + 1 + + (kra2**k_info_power) + / ((k_info_limit) ** (2 * k_info_power)) + ) + / Wiener_signal_noise_ratio + ) + CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv) + if Wiener_filter_low_only: + # limit Wiener filter to only the part of the CTF before 1st maxima + k_thresh = 1 / xp.sqrt( + 2.0 * self._wavelength * xp.abs(self.aberration_C1) + ) + k_mask = kra2 >= k_thresh**2 + CTF_corr[k_mask] = xp.sign(sin_chi[k_mask]) + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + else: + # CTF without tilt correction (beyond the parallax operator) + CTF_corr = xp.sign(sin_chi) + CTF_corr[0, 0] = 0 + + # apply correction to mean reconstructed BF image + im_fft_corr = xp.fft.fft2(im) * CTF_corr + + # if needed, add low pass filter output image + if k_info_limit is not None: + im_fft_corr /= 1 + (kra2**k_info_power) / ( + (k_info_limit) ** (2 * k_info_power) + ) # Output phase image self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) @@ -1084,12 +2066,14 @@ def aberration_correct( fig, ax = plt.subplots(figsize=figsize) - cropped_object = self._crop_padded_object(self._recon_phase_corrected) + cropped_object = self._crop_padded_object( + self._recon_phase_corrected, upsampled=upsampled + ) extent = [ 0, - self._scan_sampling[1] * cropped_object.shape[1], - 
self._scan_sampling[0] * cropped_object.shape[0], + sy * cropped_object.shape[1], + sx * cropped_object.shape[0], 0, ] @@ -1246,6 +2230,7 @@ def _crop_padded_object( self, padded_object: np.ndarray, remaining_padding: int = 0, + upsampled: bool = False, ): """ Utility function to crop padded object @@ -1266,8 +2251,19 @@ def _crop_padded_object( asnumpy = self._asnumpy - pad_x = self._object_padding_px[0] // 2 - remaining_padding - pad_y = self._object_padding_px[1] // 2 - remaining_padding + if upsampled: + pad_x = np.round( + self._object_padding_px[0] / 2 * self._kde_upsample_factor + ).astype("int") + pad_y = np.round( + self._object_padding_px[1] / 2 * self._kde_upsample_factor + ).astype("int") + else: + pad_x = self._object_padding_px[0] // 2 + pad_y = self._object_padding_px[1] // 2 + + pad_x -= remaining_padding + pad_y -= remaining_padding return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y]) @@ -1276,6 +2272,7 @@ def _visualize_figax( fig, ax, remaining_padding: int = 0, + upsampled: bool = False, **kwargs, ): """ @@ -1294,14 +2291,31 @@ def _visualize_figax( cmap = kwargs.pop("cmap", "magma") - cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + if upsampled: + cropped_object = self._crop_padded_object( + self._recon_BF_subpixel_aligned, remaining_padding, upsampled + ) - extent = [ - 0, - self._scan_sampling[1] * cropped_object.shape[1], - self._scan_sampling[0] * cropped_object.shape[0], - 0, - ] + extent = [ + 0, + self._scan_sampling[1] + * cropped_object.shape[1] + / self._kde_upsample_factor, + self._scan_sampling[0] + * cropped_object.shape[0] + / self._kde_upsample_factor, + 0, + ] + + else: + cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding) + + extent = [ + 0, + self._scan_sampling[1] * cropped_object.shape[1], + self._scan_sampling[0] * cropped_object.shape[0], + 0, + ] ax.imshow( cropped_object, @@ -1310,10 +2324,11 @@ def _visualize_figax( **kwargs, ) - def _visualize_shifts( + def show_shifts( self, scale_arrows=1, plot_arrow_freq=1, + plot_rotated_shifts=True, **kwargs, ): """ @@ -1330,10 +2345,22 @@ def _visualize_shifts( xp = self._xp asnumpy = self._asnumpy - figsize = kwargs.pop("figsize", (6, 6)) color = kwargs.pop("color", (1, 0, 0, 1)) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + figsize = kwargs.pop("figsize", (8, 4)) + fig, ax = plt.subplots(1, 2, figsize=figsize) + scaling_factor = ( + xp.array(self._reciprocal_sampling) + / xp.array(self._scan_sampling) + * scale_arrows + ) + rotated_shifts = self._xy_shifts_Ang * scaling_factor - fig, ax = plt.subplots(figsize=figsize) + else: + figsize = kwargs.pop("figsize", (4, 4)) + fig, ax = plt.subplots(figsize=figsize) + + shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0] dp_mask_ind = xp.nonzero(self._dp_mask) yy, xx = xp.meshgrid( @@ -1343,29 +2370,68 @@ def _visualize_shifts( masked_ind = xp.logical_and(freq_mask, self._dp_mask) plot_ind = masked_ind[dp_mask_ind] - ax.quiver( - asnumpy(self._kxy[plot_ind, 1]), - asnumpy(self._kxy[plot_ind, 0]), - asnumpy( - self._xy_shifts[plot_ind, 1] - * scale_arrows - * self._reciprocal_sampling[0] - ), - asnumpy( - self._xy_shifts[plot_ind, 0] - * scale_arrows - * self._reciprocal_sampling[1] - ), - color=color, - angles="xy", - scale_units="xy", - scale=1, - **kwargs, - ) - kr_max = xp.max(self._kr) - ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) - ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"): + 
ax[0].quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[0].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[0].set_title("Measured Bright Field Shifts") + ax[0].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[0].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[0].set_aspect("equal") + + # passive coordinate rotation + tf_T = AffineTransform(angle=-self.rotation_Q_to_R_rads) + rotated_kxy = tf_T(self._kxy[plot_ind], xp=xp) + ax[1].quiver( + asnumpy(rotated_kxy[:, 1]), + asnumpy(rotated_kxy[:, 0]), + asnumpy(rotated_shifts[plot_ind, 1]), + asnumpy(rotated_shifts[plot_ind, 0]), + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax[1].set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax[1].set_title("Rotated Bright Field Shifts") + ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]") + ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]") + ax[1].set_aspect("equal") + else: + ax.quiver( + asnumpy(self._kxy[plot_ind, 1]), + asnumpy(self._kxy[plot_ind, 0]), + asnumpy(shifts[plot_ind, 1]), + asnumpy(shifts[plot_ind, 0]), + color=color, + angles="xy", + scale_units="xy", + scale=1, + **kwargs, + ) + + ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max]) + ax.set_title("Measured BF Shifts") + ax.set_ylabel(r"$k_x$ [$A^{-1}$]") + ax.set_xlabel(r"$k_y$ [$A^{-1}$]") + ax.set_aspect("equal") + + fig.tight_layout() def visualize( self, @@ -1391,3 +2457,21 @@ def visualize( ax.set_title("Reconstructed Bright Field Image") return self + + @property + def object_cropped(self): + """cropped object""" + if hasattr(self, "_recon_phase_corrected"): + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_phase_corrected, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_phase_corrected) + else: + if hasattr(self, "_kde_upsample_factor"): + return self._crop_padded_object( + self._recon_BF_subpixel_aligned, upsampled=True + ) + else: + return self._crop_padded_object(self._recon_BF) diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index 67dba6115..59bf61da2 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from py4DSTEM.process.phase.utils import ( array_slice, @@ -8,6 +10,16 @@ ) from py4DSTEM.process.utils import get_CoM +try: + import cupy as cp +except (ModuleNotFoundError, ImportError): + cp = np + import os + + # make sure pylops doesn't try to use cupy + os.environ["CUPY_PYLOPS"] = "0" +import pylops # this must follow the exception + class PtychographicConstraints: """ @@ -183,6 +195,63 @@ def _object_butterworth_constraint( return current_object + def _object_denoise_tv_pylops(self, current_object, weight, iterations): + """ + Performs second order TV denoising along x and y + + Parameters + ---------- + current_object: np.ndarray + Current object estimate + weight : float + Denoising weight. The greater `weight`, the more denoising (at + the expense of fidelity to `input`). + iterations: float + Number of iterations to run in denoising algorithm. 
+ `niter_out` in pylops + + Returns + ------- + constrained_object: np.ndarray + Constrained object estimate + + """ + xp = self._xp + + if xp.iscomplexobj(current_object): + current_object_tv = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) + + else: + nx, ny = current_object.shape + niter_out = iterations + niter_in = 1 + Iop = pylops.Identity(nx * ny) + xy_laplacian = pylops.Laplacian( + (nx, ny), axes=(0, 1), edge=False, kind="backward" + ) + + l1_regs = [xy_laplacian] + + current_object_tv = pylops.optimization.sparsity.splitbregman( + Op=Iop, + y=current_object.ravel(), + RegsL1=l1_regs, + niter_outer=niter_out, + niter_inner=niter_in, + epsRL1s=[weight], + tol=1e-4, + tau=1.0, + show=False, + )[0] + + current_object_tv = current_object_tv.reshape(current_object.shape) + + return current_object_tv + def _object_denoise_tv_chambolle( self, current_object, @@ -229,90 +298,100 @@ def _object_denoise_tv_chambolle( Adapted skimage.restoration.denoise_tv_chambolle. """ xp = self._xp - - current_object_sum = xp.sum(current_object) - if axis is None: - ndim = xp.arange(current_object.ndim).tolist() - elif isinstance(axis, tuple): - ndim = list(axis) + if xp.iscomplexobj(current_object): + updated_object = current_object + warnings.warn( + ("TV denoising is currently only supported for potential objects."), + UserWarning, + ) else: - ndim = [axis] - - if pad_object: - pad_width = ((0, 0),) * current_object.ndim - pad_width = list(pad_width) - for ax in range(len(ndim)): - pad_width[ndim[ax]] = (1, 1) - current_object = xp.pad( - current_object, pad_width=pad_width, mode="constant" + current_object_sum = xp.sum(current_object) + if axis is None: + ndim = xp.arange(current_object.ndim).tolist() + elif isinstance(axis, tuple): + ndim = list(axis) + else: + ndim = [axis] + + if pad_object: + pad_width = ((0, 0),) * current_object.ndim + pad_width = list(pad_width) + for ax in range(len(ndim)): + pad_width[ndim[ax]] = (1, 1) + current_object = xp.pad( + current_object, pad_width=pad_width, mode="constant" + ) + + p = xp.zeros( + (current_object.ndim,) + current_object.shape, + dtype=current_object.dtype, ) + g = xp.zeros_like(p) + d = xp.zeros_like(current_object) + + i = 0 + while i < max_num_iter: + if i > 0: + # d will be the (negative) divergence of p + d = -p.sum(0) + slices_d = [ + slice(None), + ] * current_object.ndim + slices_p = [ + slice(None), + ] * (current_object.ndim + 1) + for ax in range(len(ndim)): + slices_d[ndim[ax]] = slice(1, None) + slices_p[ndim[ax] + 1] = slice(0, -1) + slices_p[0] = ndim[ax] + d[tuple(slices_d)] += p[tuple(slices_p)] + slices_d[ndim[ax]] = slice(None) + slices_p[ndim[ax] + 1] = slice(None) + updated_object = current_object + d + else: + updated_object = current_object + E = (d**2).sum() - p = xp.zeros( - (current_object.ndim,) + current_object.shape, dtype=current_object.dtype - ) - g = xp.zeros_like(p) - d = xp.zeros_like(current_object) - - i = 0 - while i < max_num_iter: - if i > 0: - # d will be the (negative) divergence of p - d = -p.sum(0) - slices_d = [ - slice(None), - ] * current_object.ndim - slices_p = [ + # g stores the gradients of updated_object along each axis + # e.g. 
g[0] is the first order finite difference along axis 0 + slices_g = [ slice(None), ] * (current_object.ndim + 1) for ax in range(len(ndim)): - slices_d[ndim[ax]] = slice(1, None) - slices_p[ndim[ax] + 1] = slice(0, -1) - slices_p[0] = ndim[ax] - d[tuple(slices_d)] += p[tuple(slices_p)] - slices_d[ndim[ax]] = slice(None) - slices_p[ndim[ax] + 1] = slice(None) - updated_object = current_object + d - else: - updated_object = current_object - E = (d**2).sum() - - # g stores the gradients of updated_object along each axis - # e.g. g[0] is the first order finite difference along axis 0 - slices_g = [ - slice(None), - ] * (current_object.ndim + 1) - for ax in range(len(ndim)): - slices_g[ndim[ax] + 1] = slice(0, -1) - slices_g[0] = ndim[ax] - g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) - slices_g[ndim[ax] + 1] = slice(None) - if scaling is not None: - scaling /= xp.max(scaling) - g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] - norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] - E += weight * norm.sum() - tau = 1.0 / (2.0 * len(ndim)) - norm *= tau / weight - norm += 1.0 - p -= tau * g - p /= norm - E /= float(current_object.size) - if i == 0: - E_init = E - E_previous = E - else: - if xp.abs(E_previous - E) < eps * E_init: - break - else: + slices_g[ndim[ax] + 1] = slice(0, -1) + slices_g[0] = ndim[ax] + g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax]) + slices_g[ndim[ax] + 1] = slice(None) + if scaling is not None: + scaling /= xp.max(scaling) + g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis] + norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...] + E += weight * norm.sum() + tau = 1.0 / (2.0 * len(ndim)) + norm *= tau / weight + norm += 1.0 + p -= tau * g + p /= norm + E /= float(current_object.size) + if i == 0: + E_init = E E_previous = E - i += 1 + else: + if xp.abs(E_previous - E) < eps * E_init: + break + else: + E_previous = E + i += 1 - if pad_object: - for ax in range(len(ndim)): - slices = array_slice(ndim[ax], current_object.ndim, 1, -1) - updated_object = updated_object[slices] + if pad_object: + for ax in range(len(ndim)): + slices = array_slice(ndim[ax], current_object.ndim, 1, -1) + updated_object = updated_object[slices] + updated_object = ( + updated_object / xp.sum(updated_object) * current_object_sum + ) - return updated_object / xp.sum(updated_object) * current_object_sum + return updated_object def _probe_center_of_mass_constraint(self, current_probe): """ @@ -364,7 +443,7 @@ def _probe_amplitude_constraint( erf = self._erf probe_intensity = xp.abs(current_probe) ** 2 - # current_probe_sum = xp.sum(probe_intensity) + current_probe_sum = xp.sum(probe_intensity) X = xp.fft.fftfreq(current_probe.shape[0])[:, None] Y = xp.fft.fftfreq(current_probe.shape[1])[None] @@ -374,10 +453,10 @@ def _probe_amplitude_constraint( tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_fourier_amplitude_constraint( self, @@ -406,7 +485,7 @@ def _probe_fourier_amplitude_constraint( xp = self._xp asnumpy = self._asnumpy - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) 
updated_probe_fft, _, _, _ = regularize_probe_amplitude( @@ -419,10 +498,10 @@ def _probe_fourier_amplitude_constraint( updated_probe_fft = xp.asarray(updated_probe_fft) updated_probe = xp.fft.ifft2(updated_probe_fft) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aperture_constraint( self, @@ -444,16 +523,16 @@ def _probe_aperture_constraint( """ xp = self._xp - # current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) updated_probe = xp.fft.ifft2( xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture ) - # updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - # normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe # * normalization + return updated_probe * normalization def _probe_aberration_fitting_constraint( self, @@ -485,16 +564,18 @@ def _probe_aberration_fitting_constraint( fourier_probe = xp.fft.fft2(current_probe) fourier_probe_abs = xp.abs(fourier_probe) sampling = self.sampling + energy = self._energy fitted_angle, _ = fit_aberration_surface( fourier_probe, sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - fourier_probe = fourier_probe_abs * xp.exp(1.0j * fitted_angle) + fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle) current_probe = xp.fft.ifft2(fourier_probe) return current_probe diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index e3713cde1..c8cc5ee3e 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -14,7 +14,7 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ImportError, ModuleNotFoundError): cp = np from emdfile import Custom, tqdmnd @@ -66,6 +66,8 @@ class SimultaneousPtychographicReconstruction(PtychographicReconstruction): object_padding_px: Tuple[int,int], optional Pixel dimensions to pad objects with If None, the padding is set to half the probe ROI dimensions + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction initial_object_guess: np.ndarray, optional Initial guess for complex-valued object of dimensions (Px,Py) If None, initialized to 1.0j @@ -102,6 +104,7 @@ def __init__( vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, object_padding_px: Tuple[int, int] = None, + positions_mask: np.ndarray = None, initial_object_guess: np.ndarray = None, initial_probe_guess: np.ndarray = None, initial_scan_positions: np.ndarray = None, @@ -167,6 +170,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -192,6 +196,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): 
""" @@ -246,6 +251,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -326,6 +333,27 @@ def preprocess( f"simultaneous_measurements_mode must be either '-+', '-0+', or '0+', not {self._simultaneous_measurements_mode}" ) + if self._positions_mask is not None: + self._positions_mask = np.asarray(self._positions_mask) + + if self._positions_mask.ndim == 2: + warnings.warn( + "2D `positions_mask` assumed the same for all measurements.", + UserWarning, + ) + self._positions_mask = np.tile( + self._positions_mask, (self._num_sim_measurements, 1, 1) + ) + + if self._positions_mask.dtype != "bool": + warnings.warn( + "`positions_mask` converted to `bool` array.", + UserWarning, + ) + self._positions_mask = self._positions_mask.astype("bool") + else: + self._positions_mask = [None] * self._num_sim_measurements + if force_com_shifts is None: force_com_shifts = [None, None, None] elif len(force_com_shifts) == self._num_sim_measurements: @@ -338,6 +366,9 @@ def preprocess( ) ) + # Ensure plot_center_of_mass is not in kwargs + kwargs.pop("plot_center_of_mass", None) + # 1st measurement sets rotation angle and transposition ( measurement_0, @@ -404,6 +435,8 @@ def preprocess( intensities_0, com_fitted_x_0, com_fitted_y_0, + crop_patterns, + self._positions_mask[0], ) # explicitly delete namescapes @@ -487,6 +520,8 @@ def preprocess( intensities_1, com_fitted_x_1, com_fitted_y_1, + crop_patterns, + self._positions_mask[1], ) # explicitly delete namescapes @@ -571,6 +606,8 @@ def preprocess( intensities_2, com_fitted_x_2, com_fitted_y_2, + crop_patterns, + self._positions_mask[2], ) # explicitly delete namescapes @@ -610,8 +647,8 @@ def preprocess( self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:]) self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions - ) + self._scan_positions, self._positions_mask[0] + ) # TO-DO: generaltize to per-dataset probe positions # handle semiangle specified in pixels if self._semiangle_cutoff_pixels: @@ -683,6 +720,10 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) self._probe = ( ComplexProbe( @@ -746,19 +787,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -780,23 +815,22 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + cax1, + chroma_boost=chroma_boost, ) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + 
cmap="Greys_r", ) ax2.scatter( self.positions[:, 1], @@ -808,7 +842,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -2232,6 +2266,9 @@ def _constraints( q_highpass_e, q_highpass_m, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, warmup_iteration, object_positivity, shrinkage_rad, @@ -2300,6 +2337,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising warmup_iteration: bool If True, constraints electrostatic object only object_positivity: bool @@ -2349,6 +2392,15 @@ def _constraints( if self._object_type == "complex": magnetic_obj = magnetic_obj.real + if tv_denoise: + electrostatic_obj = self._object_denoise_tv_pylops( + electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) + + if not warmup_iteration: + magnetic_obj = self._object_denoise_tv_pylops( + magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter + ) if shrinkage_rad > 0.0 or object_mask is not None: electrostatic_obj = self._object_shrinkage_constraint( @@ -2446,6 +2498,9 @@ def reconstruct( q_highpass_e: float = None, q_highpass_m: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -2538,6 +2593,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass filtering magnetic object butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. 
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -2748,6 +2809,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = (None,) * self._num_sim_measurements self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -2899,6 +2962,9 @@ def reconstruct( q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -3005,6 +3071,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -3023,14 +3090,15 @@ def _visualize_last_iteration( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ figsize = kwargs.pop("figsize", (12, 5)) cmap_e = kwargs.pop("cmap_e", "magma") cmap_m = kwargs.pop("cmap_m", "PuOr") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) if self._object_type == "complex": obj_e = np.angle(self.object[0]) @@ -3052,6 +3120,8 @@ def _visualize_last_iteration( vmin_m = kwargs.pop("vmin_m", min_m) vmax_m = kwargs.pop("vmax_m", max_m) + chroma_boost = kwargs.pop("chroma_boost", 1) + extent = [ 0, self.sampling[1] * rotated_shape[1], @@ -3155,30 +3225,35 @@ def _visualize_last_iteration( # Probe ax = fig.add_subplot(spec[0, 2]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + probe_array, + chroma_boost=chroma_boost, ) ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, power=2, chroma_boost=chroma_boost ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: # Electrostatic Object @@ -3229,10 +3304,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1, :]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -3242,6 +3317,7 @@ def _visualize_all_iterations( plot_convergence: bool, 
plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -3263,6 +3339,9 @@ def _visualize_all_iterations( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -3275,6 +3354,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -3296,6 +3376,9 @@ def visualize( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -3311,6 +3394,7 @@ def visualize( plot_convergence=plot_convergence, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, @@ -3322,9 +3406,105 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, ) return self + + @property + def self_consistency_errors(self): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Re-initialize fractional positions and vector patches, max_batch_size = None + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + error = xp.sum( + xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + error /= self._mean_diffraction_intensity + + return asnumpy(error) + + def _return_self_consistency_errors( + self, + max_batch_size=None, + ): + """Compute the self-consistency errors for each probe position""" + + xp = self._xp + asnumpy = self._asnumpy + + # Batch-size + if max_batch_size is None: + max_batch_size = self._num_diffraction_patterns + + # Re-initialize fractional positions and vector patches + errors = np.array([]) + positions_px = self._positions_px.copy() + + for start, end in generate_batches( + self._num_diffraction_patterns, max_batch=max_batch_size + ): + # batch indices + self._positions_px = positions_px[start:end] + self._positions_px_fractional = self._positions_px - xp.round( + self._positions_px + ) + ( + self._vectorized_patch_indices_row, + self._vectorized_patch_indices_col, + ) = self._extract_vectorized_patch_indices() + amplitudes = self._amplitudes[0][start:end] + + # Overlaps + _, _, overlap = self._warmup_overlap_projection(self._object, self._probe) + fourier_overlap = xp.fft.fft2(overlap[0]) + + # Normalized mean-squared errors + batch_errors = xp.sum( + xp.abs(amplitudes - 
xp.abs(fourier_overlap)) ** 2, axis=(-2, -1) + ) + errors = np.hstack((errors, batch_errors)) + + self._positions_px = positions_px.copy() + errors /= self._mean_diffraction_intensity + + return asnumpy(errors) + + def _return_projected_cropped_potential( + self, + ): + """Utility function to accommodate multiple classes""" + if self._object_type == "complex": + projected_cropped_potential = np.angle(self.object_cropped[0]) + else: + projected_cropped_potential = self.object_cropped[0] + + return projected_cropped_potential + + @property + def object_cropped(self): + """Cropped and rotated object""" + + obj_e, obj_m = self._object + obj_e = self._crop_rotate_object_fov(obj_e) + obj_m = self._crop_rotate_object_fov(obj_m) + return (obj_e, obj_m) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index df0ef5e1c..36baac21e 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -14,7 +14,7 @@ try: import cupy as cp -except ModuleNotFoundError: +except (ImportError, ModuleNotFoundError): cp = np from emdfile import Custom, tqdmnd @@ -79,6 +79,8 @@ class SingleslicePtychographicReconstruction(PtychographicReconstruction): object_type: str, optional The object can be reconstructed as a real potential ('potential') or a complex object ('complex') + positions_mask: np.ndarray, optional + Boolean real space mask to select positions in datacube to skip for reconstruction name: str, optional Class name kwargs: @@ -102,6 +104,7 @@ def __init__( initial_scan_positions: np.ndarray = None, object_padding_px: Tuple[int, int] = None, object_type: str = "complex", + positions_mask: np.ndarray = None, verbose: bool = True, device: str = "cpu", name: str = "ptychographic_reconstruction", @@ -163,6 +166,7 @@ def __init__( self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px + self._positions_mask = positions_mask self._verbose = verbose self._device = device self._preprocessed = False @@ -188,6 +192,7 @@ def preprocess( force_angular_sampling: float = None, force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, + crop_patterns: bool = False, **kwargs, ): """ @@ -245,6 +250,8 @@ def preprocess( object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. 
Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded + crop_patterns: bool + if True, crop patterns to avoid wrap around of patterns when centering Returns -------- @@ -268,6 +275,13 @@ def preprocess( ) ) + if self._positions_mask is not None and self._positions_mask.dtype != "bool": + warnings.warn( + ("`positions_mask` converted to `bool` array"), + UserWarning, + ) + self._positions_mask = np.asarray(self._positions_mask, dtype="bool") + ( self._datacube, self._vacuum_probe_intensity, @@ -333,6 +347,8 @@ def preprocess( self._intensities, self._com_fitted_x, self._com_fitted_y, + crop_patterns, + self._positions_mask, ) # explicitly delete namespace @@ -341,7 +357,7 @@ def preprocess( del self._intensities self._positions_px = self._calculate_scan_positions_in_pixels( - self._scan_positions + self._scan_positions, self._positions_mask ) # handle semiangle specified in pixels @@ -412,6 +428,11 @@ def preprocess( bilinear=True, device=self._device, ) + if crop_patterns: + self._vacuum_probe_intensity = self._vacuum_probe_intensity[ + self._crop_mask + ].reshape(self._region_of_interest_shape) + self._probe = ( ComplexProbe( gpts=self._region_of_interest_shape, @@ -451,7 +472,6 @@ def preprocess( self._probe_initial = self._probe.copy() self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) - self._known_aberrations_array = ComplexProbe( energy=self._energy, gpts=self._region_of_interest_shape, @@ -474,19 +494,13 @@ def preprocess( if plot_probe_overlaps: figsize = kwargs.pop("figsize", (9, 4)) - cmap = kwargs.pop("cmap", "Greys_r") - vmin = kwargs.pop("vmin", None) - vmax = kwargs.pop("vmax", None) - hue_start = kwargs.pop("hue_start", 0) - invert = kwargs.pop("invert", False) + chroma_boost = kwargs.pop("chroma_boost", 1) # initial probe complex_probe_rgb = Complex2RGB( self.probe_centered, - vmin=vmin, - vmax=vmax, - hue_start=hue_start, - invert=invert, + power=2, + chroma_boost=chroma_boost, ) extent = [ @@ -508,23 +522,19 @@ def preprocess( ax1.imshow( complex_probe_rgb, extent=probe_extent, - **kwargs, ) divider = make_axes_locatable(ax1) cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) + add_colorbar_arg(cax1, chroma_boost=chroma_boost) ax1.set_ylabel("x [A]") ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + ax1.set_title("Initial probe intensity") ax2.imshow( asnumpy(probe_overlap), extent=extent, - cmap=cmap, - **kwargs, + cmap="gray", ) ax2.scatter( self.positions[:, 1], @@ -536,7 +546,7 @@ def preprocess( ax2.set_xlabel("y [A]") ax2.set_xlim((extent[0], extent[1])) ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + ax2.set_title("Object field of view") fig.tight_layout() @@ -1023,6 +1033,9 @@ def _constraints( q_lowpass, q_highpass, butterworth_order, + tv_denoise, + tv_denoise_weight, + tv_denoise_inner_iter, object_positivity, shrinkage_rad, object_mask, @@ -1078,6 +1091,12 @@ def _constraints( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise: bool + If True, applies TV denoising on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. 
+ tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool If True, clips negative potential values shrinkage_rad: float @@ -1108,6 +1127,11 @@ def _constraints( butterworth_order, ) + if tv_denoise: + current_object = self._object_denoise_tv_pylops( + current_object, tv_denoise_weight, tv_denoise_inner_iter + ) + if shrinkage_rad > 0.0 or object_mask is not None: current_object = self._object_shrinkage_constraint( current_object, @@ -1198,6 +1222,9 @@ def reconstruct( q_lowpass: float = None, q_highpass: float = None, butterworth_order: float = 2, + tv_denoise_iter: int = np.inf, + tv_denoise_weight: float = None, + tv_denoise_inner_iter: float = 40, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1284,6 +1311,12 @@ def reconstruct( Cut-off frequency in A^-1 for high-pass butterworth filter butterworth_order: float Butterworth filter order. Smaller gives a smoother filter + tv_denoise_iter: int, optional + Number of iterations to run using tv denoise filter on object + tv_denoise_weight: float + Denoising weight. The greater `weight`, the more denoising. + tv_denoise_inner_iter: float + Number of iterations to run in inner loop of TV denoising object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1486,6 +1519,8 @@ def reconstruct( ) = self._extract_vectorized_patch_indices() self._exit_waves = None self._object_type = self._object_type_initial + if hasattr(self, "_tf"): + del self._tf elif reset is None: if hasattr(self, "error"): warnings.warn( @@ -1618,6 +1653,9 @@ def reconstruct( q_lowpass=q_lowpass, q_highpass=q_highpass, butterworth_order=butterworth_order, + tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None, + tv_denoise_weight=tv_denoise_weight, + tv_denoise_inner_iter=tv_denoise_inner_iter, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -1711,6 +1749,7 @@ def _visualize_last_iteration( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, padding: int, **kwargs, ): @@ -1729,13 +1768,16 @@ def _visualize_last_iteration( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ figsize = kwargs.pop("figsize", (8, 5)) cmap = kwargs.pop("cmap", "magma") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + + chroma_boost = kwargs.pop("chroma_boost", 1) if self._object_type == "complex": obj = np.angle(self.object) @@ -1827,30 +1869,38 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[0, 1]) if plot_fourier_probe: + if remove_initial_probe_aberrations: + probe_array = self.probe_fourier_residual + else: + probe_array = self.probe_fourier + probe_array = Complex2RGB( - self.probe_fourier, hue_start=hue_start, invert=invert + probe_array, + chroma_boost=chroma_boost, ) + ax.set_title("Reconstructed Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - self.probe, hue_start=hue_start, invert=invert + self.probe, + power=2, + chroma_boost=chroma_boost, ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed 
probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: divider = make_axes_locatable(ax) ax_cb = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg(ax_cb, hue_start=hue_start, invert=invert) + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) else: ax = fig.add_subplot(spec[0]) @@ -1883,10 +1933,10 @@ def _visualize_last_iteration( ax = fig.add_subplot(spec[1]) ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax.set_ylabel("NMSE") - ax.set_xlabel("Iteration Number") + ax.set_xlabel("Iteration number") ax.yaxis.tick_right() - fig.suptitle(f"Normalized Mean Squared Error: {self.error:.3e}") + fig.suptitle(f"Normalized mean squared error: {self.error:.3e}") spec.tight_layout(fig) def _visualize_all_iterations( @@ -1896,6 +1946,7 @@ def _visualize_all_iterations( plot_convergence: bool, plot_probe: bool, plot_fourier_probe: bool, + remove_initial_probe_aberrations: bool, iterations_grid: Tuple[int, int], padding: int, **kwargs, @@ -1917,6 +1968,9 @@ def _visualize_all_iterations( If true, the reconstructed complex probe is displayed plot_fourier_probe: bool If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object """ @@ -1957,9 +2011,9 @@ def _visualize_all_iterations( else (3 * iterations_grid[1], 3 * iterations_grid[0]) ) figsize = kwargs.pop("figsize", auto_figsize) - cmap = kwargs.pop("cmap", "inferno") - invert = kwargs.pop("invert", False) - hue_start = kwargs.pop("hue_start", 0) + cmap = kwargs.pop("cmap", "magma") + + chroma_boost = kwargs.pop("chroma_boost", 1) errors = np.array(self.error_iterations) @@ -2057,36 +2111,37 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: - probe_array = Complex2RGB( - asnumpy( - self._return_fourier_probe_from_centered_probe( - probes[grid_range[n]] - ) - ), - hue_start=hue_start, - invert=invert, + probe_array = asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]], + remove_initial_probe_aberrations=remove_initial_probe_aberrations, + ) ) + + probe_array = Complex2RGB(probe_array, chroma_boost=chroma_boost) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") ax.set_ylabel("kx [mrad]") ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( - probes[grid_range[n]], hue_start=hue_start, invert=invert + probes[grid_range[n]], + power=2, + chroma_boost=chroma_boost, ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe intensity") ax.set_ylabel("x [A]") ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, - **kwargs, ) if cbar: add_colorbar_arg( - grid.cbar_axes[n], hue_start=hue_start, invert=invert + grid.cbar_axes[n], + chroma_boost=chroma_boost, ) if plot_convergence: @@ -2098,7 +2153,7 @@ def _visualize_all_iterations( ax2 = fig.add_subplot(spec[1]) ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs) ax2.set_ylabel("NMSE") - ax2.set_xlabel("Iteration Number") + ax2.set_xlabel("Iteration number") ax2.yaxis.tick_right() spec.tight_layout(fig) @@ -2110,6 +2165,7 @@ def visualize( plot_convergence: bool = True, plot_probe: bool = True, plot_fourier_probe: bool = False, + remove_initial_probe_aberrations: bool = False, cbar: bool = True, padding: int = 0, **kwargs, @@ -2131,6 +2187,9 @@ def 
visualize( If true, the reconstructed probe intensity is also displayed plot_fourier_probe: bool, optional If true, the reconstructed complex Fourier probe is displayed + remove_initial_probe_aberrations: bool, optional + If true, when plotting fourier probe, removes initial probe + to visualize changes padding : int, optional Pixels to pad by post rotating-cropping object @@ -2147,6 +2206,7 @@ def visualize( plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, cbar=cbar, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, padding=padding, **kwargs, ) @@ -2157,6 +2217,7 @@ def visualize( iterations_grid=iterations_grid, plot_probe=plot_probe, plot_fourier_probe=plot_fourier_probe, + remove_initial_probe_aberrations=remove_initial_probe_aberrations, cbar=cbar, padding=padding, **kwargs, diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index c2e1d3b77..fc1b59a07 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -8,7 +8,7 @@ try: import cupy as cp from cupyx.scipy.fft import rfft -except ImportError: +except (ImportError, ModuleNotFoundError): cp = None from scipy.fft import dstn, idstn @@ -1543,46 +1543,84 @@ def step_model(radius, sig_0, rad_0, width): def aberrations_basis_function( probe_size, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ + + # Add constant phase shift in basis + mn = [[-1, 0, 0]] + + for m in range(1, max_radial_order): + n_max = np.minimum(max_angular_order, m + 1) + for n in range(0, n_max + 1): + if (m + n) % 2: + mn.append([m, n, 0]) + if n > 0: + mn.append([m, n, 1]) + + aberrations_mn = np.array(mn) + aberrations_mn = aberrations_mn[np.argsort(aberrations_mn[:, 1]), :] + + sub = aberrations_mn[:, 1] > 0 + aberrations_mn[sub, :] = aberrations_mn[sub, :][ + np.argsort(aberrations_mn[sub, 0]), : + ] + aberrations_mn[~sub, :] = aberrations_mn[~sub, :][ + np.argsort(aberrations_mn[~sub, 0]), : + ] + aberrations_num = aberrations_mn.shape[0] + sx, sy = probe_size dx, dy = probe_sampling + wavelength = electron_wavelength_angstrom(energy) + qx = xp.fft.fftfreq(sx, dx) qy = xp.fft.fftfreq(sy, dy) + qr2 = qx[:, None] ** 2 + qy[None, :] ** 2 + alpha = xp.sqrt(qr2) * wavelength + theta = xp.arctan2(qy[None, :], qx[:, None]) + + # Aberration basis + aberrations_basis = xp.ones((alpha.size, aberrations_num)) + + # Skip constant to avoid dividing by zero in normalization + for a0 in range(1, aberrations_num): + m, n, a = aberrations_mn[a0] + if n == 0: + # Radially symmetric basis + aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel() + + elif a == 0: + # cos coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.cos(n * theta) / (m + 1) + ).ravel() + else: + # sin coef + aberrations_basis[:, a0] = ( + alpha ** (m + 1) * xp.sin(n * theta) / (m + 1) + ).ravel() - qxa, qya = xp.meshgrid(qx, qy, indexing="ij") - q2 = qxa**2 + qya**2 - theta = xp.arctan2(qya, qxa) - - basis = [] - index = [] - - for n in range(max_angular_order + 1): - for m in range((max_radial_order - n) // 2 + 1): - basis.append((q2 ** (m + n / 2) * np.cos(n * theta))) - index.append((m, n, 0)) - if n > 0: - basis.append((q2 ** (m + n / 2) * np.sin(n * theta))) - index.append((m, n, 1)) - - basis = xp.array(basis) + # global scaling + aberrations_basis *= 2 * np.pi / wavelength - return basis, index + return aberrations_basis, aberrations_mn def fit_aberration_surface( complex_probe, probe_sampling, + energy, max_angular_order, max_radial_order, xp=np, ): """ """ probe_amp 
= xp.abs(complex_probe) - probe_angle = xp.angle(complex_probe) + probe_angle = -xp.angle(complex_probe) if xp is np: probe_angle = probe_angle.astype(np.float64) @@ -1592,21 +1630,47 @@ def fit_aberration_surface( unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True) unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32) - basis, _ = aberrations_basis_function( + raveled_basis, _ = aberrations_basis_function( complex_probe.shape, probe_sampling, + energy, max_angular_order, max_radial_order, xp=xp, ) - raveled_basis = basis.reshape((basis.shape[0], -1)) raveled_weights = probe_amp.ravel() - Aw = raveled_basis.T * raveled_weights[:, None] + Aw = raveled_basis * raveled_weights[:, None] bw = unwrapped_angle.ravel() * raveled_weights coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0] - fitted_angle = xp.tensordot(coeff, basis, axes=1) + fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape) return fitted_angle, coeff + + +def rotate_point(origin, point, angle): + """ + Rotate a point (x1, y1) counterclockwise by a given angle around + a given origin (x0, y0). + + Parameters + -------- + origin: 2-tuple of floats + (x0, y0) + point: 2-tuple of floats + (x1, y1) + angle: float (radians) + + Returns + -------- + rotated point (2-tuple) + + """ + ox, oy = origin + px, py = point + + qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy) + qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy) + return qx, qy diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py index 3ed00ec13..7df00a235 100644 --- a/py4DSTEM/process/polar/polar_analysis.py +++ b/py4DSTEM/process/polar/polar_analysis.py @@ -2,49 +2,72 @@ import numpy as np import matplotlib.pyplot as plt +from scipy.optimize import curve_fit +from scipy.ndimage import gaussian_filter from emdfile import tqdmnd -def calculate_FEM_global( +def calculate_radial_statistics( self, - use_median_local=False, - use_median_global=False, - plot_results=False, + plot_results_mean=False, + plot_results_var=False, figsize=(8, 4), returnval=False, returnfig=False, progress_bar=True, ): """ - Calculate fluctuation electron microscopy (FEM) statistics, including radial mean, - variance, and normalized variance. This function uses the original FEM definitions, - where the signal is computed pattern-by-pattern. + Calculate the radial statistics used in fluctuation electron microscopy (FEM) + and as an initial step in radial distribution function (RDF) calculation. + The computed quantities are the radial mean, variance, and normalized variance. + + There are several ways the means and variances can be computed. Here we first + compute the mean and standard deviation pattern by pattern, i.e. for + diffraction signal d(x,y; q,theta) we take + + d_mean_all(x,y; q) = \frac{1}{2\pi} \int_{0}^{2\pi} d(x,y; q,\theta) d\theta + d_var_all(x,y; q) = \frac{1}{2\pi} \int_{0}^{2\pi} + \( d(x,y; q,\theta) - d_mean_all(x,y; q) \)^2 d\theta + + Then we find the mean and variance profiles by taking the means of these + quantities over all N_x x N_y scan positions: + + d_mean(q) = \frac{1}{N_x N_y} \sum_{x,y} d_mean_all(x,y; q) + d_var(q) = \frac{1}{N_x N_y} \sum_{x,y} d_var_all(x,y; q) + + and the normalized variance is d_var/d_mean^2. + + This follows the methods described in [@cophus TODO ADD CITATION]. - TODO - finish docstrings, add median statistics. Parameters -------- - self: PolarDatacube - Polar datacube used for measuring FEM properties.
+ plot_results_mean: bool + Toggles plotting the computed radial means + plot_results_var: bool + Toggles plotting the computed radial variances + figsize: 2-tuple + Size of output figures + returnval: bool + Toggles returning the answer. Answers are always stored internally. + returnfig: bool + Toggles returning figures that have been plotted. Only figures for + which `plot_results_*` is True are returned. Returns -------- radial_avg: np.array - Average radial intensity + Optional - returned iff returnval is True. The average radial intensity. radial_var: np.array - Variance in the radial dimension - - + Optional - returned iff returnval is True. The radial variance. + fig_means: 2-tuple (fig,ax) + Optional - returned iff returnfig is True. Plot of the radial means. + fig_var: 2-tuple (fig,ax) + Optional - returned iff returnfig is True. Plot of the radial variances. """ - # Get the dimensioned radial bins - self.scattering_vector = ( - self.radial_bins * self.qstep * self.calibration.get_Q_pixel_size() - ) - self.scattering_vector_units = self.calibration.get_Q_pixel_units() - - # init radial data array + # init radial data arrays self.radial_all = np.zeros( ( self._datacube.shape[0], @@ -52,72 +75,513 @@ def calculate_FEM_global( self.polar_shape[1], ) ) + self.radial_all_std = np.zeros( + ( + self._datacube.shape[0], + self._datacube.shape[1], + self.polar_shape[1], + ) + ) - # Compute the radial mean for each probe position + # Compute the radial mean and standard deviation for each probe position for rx, ry in tqdmnd( self._datacube.shape[0], self._datacube.shape[1], - desc="Global FEM", + desc="Radial statistics", unit=" probe positions", disable=not progress_bar, ): self.radial_all[rx, ry] = np.mean(self.data[rx, ry], axis=0) + self.radial_all_std[rx, ry] = np.sqrt( + np.mean((self.data[rx, ry] - self.radial_all[rx, ry][None]) ** 2, axis=0) + ) - self.radial_avg = np.mean(self.radial_all, axis=(0, 1)) + self.radial_mean = np.mean(self.radial_all, axis=(0, 1)) self.radial_var = np.mean( - (self.radial_all - self.radial_avg[None, None]) ** 2, axis=(0, 1) + (self.radial_all - self.radial_mean[None, None]) ** 2, axis=(0, 1) ) - self.radial_var_norm = self.radial_var / self.radial_avg**2 + + self.radial_var_norm = np.copy(self.radial_var) + sub = self.radial_mean > 0.0 + self.radial_var_norm[sub] /= self.radial_mean[sub] ** 2 + + # prepare answer + statistics = self.radial_mean, self.radial_var, self.radial_var_norm + if returnval: + ans = statistics if not returnfig else [statistics] + else: + ans = None if not returnfig else [] # plot results - if plot_results: + if plot_results_mean: + fig, ax = plot_radial_mean( + self, + figsize=figsize, + returnfig=True, + ) if returnfig: - fig, ax = plot_FEM_global( - self, - figsize=figsize, - returnfig=True, - ) - else: - plot_FEM_global( - self, - figsize=figsize, - ) - - # Return values - if returnval: + ans.append((fig, ax)) + if plot_results_var: + fig, ax = plot_radial_var_norm( + self, + figsize=figsize, + returnfig=True, + ) if returnfig: - return self.radial_avg, self.radial_var, fig, ax - else: - return self.radial_avg, self.radial_var + ans.append((fig, ax)) + + # return + return ans + + +def plot_radial_mean( + self, + log_x=False, + log_y=False, + figsize=(8, 4), + returnfig=False, +): + """ + Plot the radial means. 
+ + Parameters + ---------- + log_x : bool + Toggle log scaling of the x-axis + log_y : bool + Toggle log scaling of the y-axis + figsize : 2-tuple + Size of the output figure + returnfig : bool + Toggle returning the figure + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.radial_mean, + ) + + if log_x: + ax.set_xscale("log") + if log_y: + ax.set_yscale("log") + + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") + ax.set_ylabel("Radial Mean") + if log_x and self.qq[0] == 0.0: + ax.set_xlim((self.qq[1], self.qq[-1])) else: - if returnfig: - return fig, ax - else: - pass + ax.set_xlim((self.qq[0], self.qq[-1])) + + if returnfig: + return fig, ax -def plot_FEM_global( +def plot_radial_var_norm( self, figsize=(8, 4), returnfig=False, ): """ - Plotting function for the global FEM. + Plot the normalized radial variance. + + Parameters + ---------- + figsize : 2-tuple + Size of the output figure + returnfig : bool + Toggle returning the figure + """ fig, ax = plt.subplots(figsize=figsize) ax.plot( - self.scattering_vector, + self.qq, self.radial_var_norm, ) - ax.set_xlabel("Scattering Vector (" + self.scattering_vector_units + ")") + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") ax.set_ylabel("Normalized Variance") - ax.set_xlim((self.scattering_vector[0], self.scattering_vector[-1])) + ax.set_xlim((self.qq[0], self.qq[-1])) if returnfig: return fig, ax +def calculate_pair_dist_function( + self, + k_min=0.05, + k_max=None, + k_width=0.25, + k_lowpass=None, + k_highpass=None, + r_min=0.0, + r_max=20.0, + r_step=0.02, + damp_origin_fluctuations=True, + enforce_positivity=True, + density=None, + plot_background_fits=False, + plot_sf_estimate=False, + plot_reduced_pdf=True, + plot_pdf=False, + figsize=(8, 4), + maxfev=None, + returnval=False, + returnfig=False, +): + """ + Calculate the pair distribution function (PDF). + + First a background is calculated using primarily the signal at the highest + scattering vectors available, given by a sum of two exponentials ~exp(-k^2) + and ~exp(-k^4) modelling the single atom scattering factor plus a constant + offset. Next, the structure factor is computed as + + S(k) = (I(k) - bg(k)) * k / f(k) + + where k is the magnitude of the scattering vector, I(k) is the mean radial + signal, f(k) is the single atom scattering factor, and bg(k) is the total + background signal (i.e. f(k) plus a constant offset). S(k) is masked outside + of the selected fitting region of k-values [k_min,k_max] and low/high pass + filters are optionally applied. The structure factor is then inverted into + the reduced pair distribution function g(r) using + + g(r) = \frac{2}{\pi} \int \sin( 2\pi r k ) S(k) dk + + The value of the integral is (optionally) damped to zero at the origin to + match the physical requirement that g(r) vanish at r = 0. Finally, the + full PDF G(r) is computed if a known density D is provided, using + + G(r) = 1 + [ g(r) / ( 4\pi * D * r * dr ) ] + + This follows the methods described in [@cophus TODO ADD CITATION]. + + + Parameters + ---------- + k_min : number + Minimum scattering vector to include in the calculation + k_max : number or None + Maximum scattering vector to include in the calculation. Note that + this cutoff is used when calculating the structure factor - however it + is *not* used when estimating the background / single atom scattering + factor, which is best estimated from high scattering lengths.
+ k_width : number + The fitting window for the structure factor calculation [k_min,k_max] + includes a damped region at its edges, i.e. the signal is smoothly dampled + to zero in the regions [k_min, k_min+k_width] and [k_max-k_width,k_max] + k_lowpass : number or None + Lowpass filter, in units the scattering vector stepsize (i.e. self.qstep) + k_highpass : number or None + Highpass filter, in units the scattering vector stepsize (i.e. self.qstep) + r_min,r_max,r_step : numbers + Define the real space coordinates r that the PDF g(r) will be computed in. + The coordinates will be np.arange(r_min,r_max,r_step), given in units + inverse to the scattering vector units. + damp_origin_fluctuations : bool + The value of the PDF approaching the origin should be zero, however numerical + instability may result in non-physical finite values there. This flag toggles + damping the value of the PDF to zero near the origin. + enforce_positivity: + Force all pdf values to be >0. + density : number or None + The density of the sample, if known. If this is not provided, only the + reduced PDF is calculated. If this value is provided, the PDF is also + calculated. + plot_background_fits : bool + plot_sf_estimate : bool + plot_reduced_pdf=True : bool + plot_pdf : bool + figsize : 2-tuple + maxfev : integer or None + Max number of iterations to use when fitting the background + returnval: bool + Toggles returning the answer. Answers are always stored internally. + returnfig: bool + Toggles returning figures that have been plotted. Only figures for + which `plot_*` is True are returned. + """ + + # set up coordinates and scaling + k = self.qq + dk = k[1] - k[0] + k2 = k**2 + Ik = self.radial_mean + int_mean = np.mean(Ik) + sub_fit = k >= k_min + + # initial guesses for background coefs + const_bg = np.min(self.radial_mean) / int_mean + int0 = np.median(self.radial_mean) / int_mean - const_bg + sigma0 = np.mean(k) + coefs = [const_bg, int0, sigma0, int0, sigma0] + lb = [0, 0, 0, 0, 0] + ub = [np.inf, np.inf, np.inf, np.inf, np.inf] + # Weight the fit towards high k values + noise_est = k[-1] - k + dk + + # Estimate the mean atomic form factor + background + if maxfev is None: + coefs = curve_fit( + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma=noise_est[sub_fit], + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + )[0] + else: + coefs = curve_fit( + scattering_model, + k2[sub_fit], + Ik[sub_fit] / int_mean, + sigma=noise_est[sub_fit], + p0=coefs, + xtol=1e-8, + bounds=(lb, ub), + maxfev=maxfev, + )[0] + + coefs[0] *= int_mean + coefs[1] *= int_mean + coefs[3] *= int_mean + + # Calculate the mean atomic form factor without a constant offset + # coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4]) + # fk = scattering_model(k2, coefs_fk) + bg = scattering_model(k2, coefs) + fk = bg - coefs[0] + + # mask for structure factor estimate + if k_max is None: + k_max = np.max(k) + mask = np.clip( + np.minimum( + (k - 0.0) / k_width, + (k_max - k) / k_width, + ), + 0, + 1, + ) + mask = np.sin(mask * (np.pi / 2)) + + # Estimate the reduced structure factor S(k) + Sk = (Ik - bg) * k / fk + + # Masking edges of S(k) + mask_sum = np.sum(mask) + Sk = (Sk - np.sum(Sk * mask) / mask_sum) * mask + + # Filtering of S(k) + if k_lowpass is not None and k_lowpass > 0.0: + Sk = gaussian_filter(Sk, sigma=k_lowpass / dk, mode="nearest") + if k_highpass is not None and k_highpass > 0.0: + Sk_lowpass = gaussian_filter(Sk, sigma=k_highpass / dk, mode="nearest") + Sk -= Sk_lowpass + + # Calculate the PDF + r = 
np.arange(r_min, r_max, r_step) + ra, ka = np.meshgrid(r, k) + pdf_reduced = ( + (2 / np.pi) + * dk + * np.sum( + np.sin(2 * np.pi * ra * ka) * Sk[:, None], + axis=0, + ) + ) + + # Damp the unphysical fluctuations at the PDF origin + if damp_origin_fluctuations: + ind_max = np.argmax(pdf_reduced) + r_ind_max = r[ind_max] + r_mask = np.minimum(r / r_ind_max, 1.0) + r_mask = np.sin(r_mask * np.pi / 2) ** 2 + pdf_reduced *= r_mask + + # Store results + self.pdf_r = r + self.pdf_reduced = pdf_reduced + + self.Sk = Sk + self.fk = fk + self.bg = bg + self.offset = coefs[0] + self.Sk_mask = mask + + # if density is provided, we can estimate the full PDF + if density is not None: + pdf = pdf_reduced.copy() + pdf[1:] /= 4 * np.pi * density * r[1:] * (r[1] - r[0]) + pdf += 1 + + # damp and clip values below zero + if damp_origin_fluctuations: + pdf *= r_mask + if enforce_positivity: + pdf = np.maximum(pdf, 0.0) + + # store results + self.pdf = pdf + + # prepare answer + if density is None: + return_values = self.pdf_r, self.pdf_reduced + else: + return_values = self.pdf_r, self.pdf_reduced, self.pdf + if returnval: + ans = return_values if not returnfig else [return_values] + else: + ans = None if not returnfig else [] + + # Plots + if plot_background_fits: + fig, ax = self.plot_background_fits(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + if plot_sf_estimate: + fig, ax = self.plot_sf_estimate(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + if plot_reduced_pdf: + fig, ax = self.plot_reduced_pdf(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + if plot_pdf: + fig, ax = self.plot_pdf(figsize=figsize, returnfig=True) + if returnfig: + ans.append((fig, ax)) + + # return + return ans + + +def plot_background_fits( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.radial_mean, + color="k", + ) + ax.plot( + self.qq, + self.bg, + color="r", + ) + ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")") + ax.set_ylabel("Radial Mean") + ax.set_xlim((self.qq[0], self.qq[-1])) + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("I(k) and Background Fit Estimates") + ax.set_ylim( + ( + np.min(self.radial_mean[self.radial_mean > 0]) * 0.8, + np.max(self.radial_mean * self.Sk_mask) * 1.25, + ) + ) + ax.set_yscale("log") + if returnfig: + return fig, ax + plt.show() + + +def plot_sf_estimate( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.qq, + self.Sk, + color="r", + ) + yr = (np.min(self.Sk), np.max(self.Sk)) + ax.set_ylim( + ( + yr[0] - 0.05 * (yr[1] - yr[0]), + yr[1] + 0.05 * (yr[1] - yr[0]), + ) + ) + ax.set_xlabel("Scattering Vector [A^-1]") + ax.set_ylabel("Reduced Structure Factor") + if returnfig: + return fig, ax + plt.show() + + +def plot_reduced_pdf( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.pdf_r, + self.pdf_reduced, + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Reduced Pair Distribution Function") + if returnfig: + return fig, ax + plt.show() + + +def plot_pdf( + self, + figsize=(8, 4), + returnfig=False, +): + """ + TODO + """ + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + self.pdf_r, + self.pdf, + color="r", + ) + ax.set_xlabel("Radius [A]") + ax.set_ylabel("Pair Distribution Function") + if returnfig: + 
return fig, ax + plt.show() + + # functions for inverting from reduced PDF back to S(k) + + # # invert + # ind_max = np.argmax(pdf_reduced* np.sqrt(r)) + # r_ind_max = r[ind_max-1] + # r_mask = np.minimum(r / (r_ind_max), 1.0) + # r_mask = np.sin(r_mask*np.pi/2)**2 + + # Sk_back_proj = (0.5*r_step)*np.sum( + # np.sin( + # 2*np.pi*ra*ka + # ) * pdf_corr[None,:],# * r_mask[None,:], + # # ) * pdf_corr[None,:],# * r_mask[None,:], + # axis=1, + # ) + + def calculate_FEM_local( self, figsize=(8, 6), @@ -143,4 +607,41 @@ def calculate_FEM_local( """ - 1 + 1 + pass + + +def scattering_model(k2, *coefs): + """ + The scattering model used to fit the PDF background. The fit + function is a constant plus two exponentials - one in k^2 and one + in k^4: + + f(k; c,i0,s0,i1,s1) = + c + i0*exp(k^2/-2*s0^2) + i1*exp(k^4/-2*s1^4) + + Parameters + ---------- + k2 : 1d array + the scattering vector squared + coefs : 5-tuple + Initial guesses at the parameters (c,i0,s0,i1,s1) + """ + coefs = np.squeeze(np.array(coefs)) + + const_bg = coefs[0] + int0 = coefs[1] + sigma0 = coefs[2] + int1 = coefs[3] + sigma1 = coefs[4] + + int_model = ( + const_bg + + int0 * np.exp(k2 / (-2 * sigma0**2)) + + int1 * np.exp(k2**2 / (-2 * sigma1**4)) + ) + + # (int1*sigma1)/(k2 + sigma1**2) + # int1*np.exp(k2/(-2*sigma1**2)) + # int1*np.exp(k2/(-2*sigma1**2)) + + return int_model diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py index 2bfb205fe..4c06d4c09 100644 --- a/py4DSTEM/process/polar/polar_datacube.py +++ b/py4DSTEM/process/polar/polar_datacube.py @@ -94,9 +94,15 @@ def __init__( pass from py4DSTEM.process.polar.polar_analysis import ( - calculate_FEM_global, - plot_FEM_global, + calculate_radial_statistics, + calculate_pair_dist_function, calculate_FEM_local, + plot_radial_mean, + plot_radial_var_norm, + plot_background_fits, + plot_sf_estimate, + plot_reduced_pdf, + plot_pdf, ) from py4DSTEM.process.polar.polar_peaks import ( find_peaks_single_pattern, @@ -286,16 +292,36 @@ def transform(self): these values directly. Length 2 list of ints uses the calibrated origin value at this scan position. None uses the calibrated mean origin. + ellipse : tuple or None + Variable behavior depending on the arg type. Length 3 tuples uses + these values directly (a,b,theta). None uses the calibrated value. mask : boolean array or None A mask applied to the data before transformation. The value of - masked pixels (0's) in the output is determined by `returnval`. Note - that this mask is applied in combination with any mask at - PolarData.mask. - returnval : 'masked' or 'nan' or None - Controls the returned data. 'masked' returns a numpy masked array. - 'nan' returns a normal numpy array with masked pixels set to np.nan. - None returns a 2-tuple of numpy arrays - the transformed data with - masked pixels set to 0, and the transformed mask. + masked pixels in the output is determined by `returnval`. Note that + this mask is applied in combination with any mask at PolarData.mask. + mask_thresh : number + Pixels in the transformed mask with values below this number are + considered masked, and will be populated by the values specified + by `returnval`. + returnval : 'masked' or 'nan' or 'all' or 'zeros' or 'all_zeros' + Controls the returned data, including how un-sampled points + are handled. + - 'masked' returns a numpy masked array. + - 'nan' returns a normal numpy array with unsampled pixels set to + np.nan. 
+ - 'all' returns a 4-tuple of numpy arrays - the transformed data + with unsampled pixels set to 'nan', the normalization array, the + normalization array scaled to account for the q-dependent + sampling density, and the polar boolean mask + - 'zeros' returns a normal numpy with unsampled pixels set to 0 + - 'all_zeros' returns the same 4-tuple as 'all', but with unsampled + pixels in the transformed data array set to zeros. + + Returns + -------- + variable + see `returnval`, above. Default is a masked array representing + the polar transformed data. """ return self._polar_data_getter._transform diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py index be9ae989e..4064fccaf 100644 --- a/py4DSTEM/process/polar/polar_peaks.py +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -167,7 +167,7 @@ def find_peaks_single_pattern( if remove_masked_peaks: peaks = np.delete( peaks, - mask_bool[peaks[:, 0], peaks[:, 1]] == False, + mask_bool[peaks[:, 0], peaks[:, 1]] == False, # noqa: E712 axis=0, ) diff --git a/py4DSTEM/process/strain.py b/py4DSTEM/process/strain.py deleted file mode 100644 index db252f75b..000000000 --- a/py4DSTEM/process/strain.py +++ /dev/null @@ -1,601 +0,0 @@ -# Defines the Strain class - -from typing import Optional - -import matplotlib.pyplot as plt -import numpy as np -from py4DSTEM import PointList -from py4DSTEM.braggvectors import BraggVectors -from py4DSTEM.data import Data, RealSlice -from py4DSTEM.preprocess.utils import get_maxima_2D -from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show - - -class StrainMap(RealSlice, Data): - """ - Stores strain map. - - TODO add docs - - """ - - def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): - """ - TODO - """ - assert isinstance( - braggvectors, BraggVectors - ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" - - # initialize as a RealSlice - RealSlice.__init__( - self, - name=name, - data=np.empty( - ( - 6, - braggvectors.Rshape[0], - braggvectors.Rshape[1], - ) - ), - slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], - ) - - # set up braggvectors - # this assigns the bvs, ensures the origin is calibrated, - # and adds the strainmap to the bvs' tree - self.braggvectors = braggvectors - - # initialize as Data - Data.__init__(self) - - # set calstate - # this property is used only to check to make sure that - # the braggvectors being used throughout a workflow are - # the same. The state of calibration of the vectors is noted - # here, and then checked each time the vectors are used - - # if they differ, an error message and instructions for - # re-calibration are issued - self.calstate = self.braggvectors.calstate - assert self.calstate["center"], "braggvectors must be centered" - # get the BVM - # a new BVM using the current calstate is computed - self.bvm = self.braggvectors.histogram(mode="cal") - - # braggvector properties - - @property - def braggvectors(self): - return self._braggvectors - - @braggvectors.setter - def braggvectors(self, x): - assert isinstance( - x, BraggVectors - ), f".braggvectors must be BraggVectors, not type {type(x)}" - assert ( - x.calibration.origin is not None - ), f"braggvectors must have a calibrated origin" - self._braggvectors = x - self._braggvectors.tree(self, force=True) - - def reset_calstate(self): - """ - Resets the calibration state. 
This recomputes the BVM, and removes any computations - this StrainMap instance has stored, which will need to be recomputed. - """ - for attr in ( - "g0", - "g1", - "g2", - ): - if hasattr(self, attr): - delattr(self, attr) - self.calstate = self.braggvectors.calstate - pass - - # Class methods - - def choose_lattice_vectors( - self, - index_g0, - index_g1, - index_g2, - subpixel="multicorr", - upsample_factor=16, - sigma=0, - minAbsoluteIntensity=0, - minRelativeIntensity=0, - relativeToPeak=0, - minSpacing=0, - edgeBoundary=1, - maxNumPeaks=10, - figsize=(12, 6), - c_indices="lightblue", - c0="g", - c1="r", - c2="r", - c_vectors="r", - c_vectorlabels="w", - size_indices=20, - width_vectors=1, - size_vectorlabels=20, - vis_params={}, - returncalc=False, - returnfig=False, - ): - """ - Choose which lattice vectors to use for strain mapping. - - Overlays the bvm with the points detected via local 2D - maxima detection, plus an index for each point. User selects - 3 points using the overlaid indices, which are identified as - the origin and the termini of the lattice vectors g1 and g2. - - Parameters - ---------- - index_g0 : int - selected index for the origin - index_g1 : int - selected index for g1 - index_g2 :int - selected index for g2 - subpixel : str in ('pixel','poly','multicorr') - See the docstring for py4DSTEM.preprocess.get_maxima_2D - upsample_factor : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - sigma : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - minAbsoluteIntensity : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - minRelativeIntensity : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - relativeToPeak : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - minSpacing : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - edgeBoundary : number - See the py4DSTEM.preprocess.get_maxima_2D docstring - maxNumPeaks : int - See the py4DSTEM.preprocess.get_maxima_2D docstring - figsize : 2-tuple - the size of the figure - c_indices : color - color of the maxima - c0 : color - color of the origin - c1 : color - color of g1 point - c2 : color - color of g2 point - c_vectors : color - color of the g1/g2 vectors - c_vectorlabels : color - color of the vector labels - size_indices : number - size of the indices - width_vectors : number - width of the vectors - size_vectorlabels : number - size of the vector labels - vis_params : dict - additional visualization parameters passed to `show` - returncalc : bool - toggles returning the answer - returnfig : bool - toggles returning the figure - - Returns - ------- - (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter - """ - # validate inputs - for i in (index_g0, index_g1, index_g2): - assert isinstance(i, (int, np.integer)), "indices must be integers!" - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." 
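The method body that follows converts the three selected maxima indices into an origin and two lattice vectors by simple subtraction, and the same selection logic carries over to the new strain module. A minimal standalone sketch of that step, using hypothetical maxima positions and placeholder index choices:

import numpy as np

# Hypothetical maxima positions, e.g. as returned by get_maxima_2D
gx = np.array([64.0, 80.2, 64.1])
gy = np.array([64.0, 64.3, 80.5])
index_g0, index_g1, index_g2 = 0, 1, 2  # placeholder user selections

g0 = gx[index_g0], gy[index_g0]                   # origin
g1 = gx[index_g1] - g0[0], gy[index_g1] - g0[1]   # lattice vector g1
g2 = gx[index_g2] - g0[0], gy[index_g2] - g0[1]   # lattice vector g2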
- - # find the maxima - g = get_maxima_2D( - self.bvm.data, - subpixel=subpixel, - upsample_factor=upsample_factor, - sigma=sigma, - minAbsoluteIntensity=minAbsoluteIntensity, - minRelativeIntensity=minRelativeIntensity, - relativeToPeak=relativeToPeak, - minSpacing=minSpacing, - edgeBoundary=edgeBoundary, - maxNumPeaks=maxNumPeaks, - ) - - # get the lattice vectors - gx, gy = g["x"], g["y"] - g0 = gx[index_g0], gy[index_g0] - g1x = gx[index_g1] - g0[0] - g1y = gy[index_g1] - g0[1] - g2x = gx[index_g2] - g0[0] - g2y = gy[index_g2] - g0[1] - g1, g2 = (g1x, g1y), (g2x, g2y) - - # make the figure - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - show(self.bvm.data, figax=(fig, ax1), **vis_params) - show(self.bvm.data, figax=(fig, ax2), **vis_params) - - # Add indices to left panel - d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} - d0 = { - "x": gx[index_g0], - "y": gy[index_g0], - "size": size_indices, - "color": c0, - "fontweight": "bold", - "labels": [str(index_g0)], - } - d1 = { - "x": gx[index_g1], - "y": gy[index_g1], - "size": size_indices, - "color": c1, - "fontweight": "bold", - "labels": [str(index_g1)], - } - d2 = { - "x": gx[index_g2], - "y": gy[index_g2], - "size": size_indices, - "color": c2, - "fontweight": "bold", - "labels": [str(index_g2)], - } - add_pointlabels(ax1, d) - add_pointlabels(ax1, d0) - add_pointlabels(ax1, d1) - add_pointlabels(ax1, d2) - - # Add vectors to right panel - dg1 = { - "x0": gx[index_g0], - "y0": gy[index_g0], - "vx": g1[0], - "vy": g1[1], - "width": width_vectors, - "color": c_vectors, - "label": r"$g_1$", - "labelsize": size_vectorlabels, - "labelcolor": c_vectorlabels, - } - dg2 = { - "x0": gx[index_g0], - "y0": gy[index_g0], - "vx": g2[0], - "vy": g2[1], - "width": width_vectors, - "color": c_vectors, - "label": r"$g_2$", - "labelsize": size_vectorlabels, - "labelcolor": c_vectorlabels, - } - add_vector(ax2, dg1) - add_vector(ax2, dg2) - - # store vectors - self.g = g - self.g0 = g0 - self.g1 = g1 - self.g2 = g2 - - # return - if returncalc and returnfig: - return (g0, g1, g2), (fig, (ax1, ax2)) - elif returncalc: - return (g0, g1, g2) - elif returnfig: - return (fig, (ax1, ax2)) - else: - return - - def fit_lattice_vectors( - self, - x0=None, - y0=None, - max_peak_spacing=2, - mask=None, - plot=True, - vis_params={}, - returncalc=False, - ): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - Args: - x0 : floagt - x-coord of origin - y0 : float - y-coord of origin - max_peak_spacing: float - Maximum distance from the ideal lattice points - to include a peak for indexing - mask: bool - Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - plot:bool - plot results if tru - vis_params : dict - additional visualization parameters passed to `show` - returncalc : bool - if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map - """ - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." 
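The indexing described above amounts to solving alpha = beta @ M for the integer index matrix M, which is what `index_bragg_directions` does in the new latticevectors module. A self-contained sketch, with made-up basis vectors and peak positions:

import numpy as np

# Made-up basis vectors and measured peak offsets from the origin
g1, g2 = (16.2, 0.3), (0.1, 16.5)
gx = np.array([16.2, 0.1, 16.3, 32.4])
gy = np.array([0.3, 16.5, 16.8, 0.6])

beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]])  # columns are g1, g2
alpha = np.vstack([gx, gy])                        # measured directions
M = np.linalg.lstsq(beta, alpha, rcond=None)[0].T  # least-squares solve
hk = np.round(M).astype(int)                       # integer (h, k) per peak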
- - if x0 is None: - x0 = self.braggvectors.Qshape[0] / 2 - if y0 is None: - y0 = self.braggvectors.Qshape[0] / 2 - - # index braggvectors - from py4DSTEM.process.latticevectors import index_bragg_directions - - _, _, braggdirections = index_bragg_directions( - x0, y0, self.g["x"], self.g["y"], self.g1, self.g2 - ) - - self.braggdirections = braggdirections - - if plot: - self.show_bragg_indexing( - self.bvm, - bragg_directions=braggdirections, - points=True, - **vis_params, - ) - - # add indicies to braggvectors - from py4DSTEM.process.latticevectors import add_indices_to_braggvectors - - bragg_vectors_indexed = add_indices_to_braggvectors( - self.braggvectors, - self.braggdirections, - maxPeakSpacing=max_peak_spacing, - qx_shift=self.braggvectors.Qshape[0] / 2, - qy_shift=self.braggvectors.Qshape[1] / 2, - mask=mask, - ) - - self.bragg_vectors_indexed = bragg_vectors_indexed - - # fit bragg vectors - from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs - - g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) - self.g1g2_map = g1g2_map - - if returncalc: - braggdirections, bragg_vectors_indexed, g1g2_map - - def get_strain( - self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs - ): - """ - mask: nd.array (bool) - Use lattice vectors from g1g2_map scan positions - wherever mask==True. If mask is None gets median strain - map from entire field of view. If mask is not None, gets - reference g1 and g2 from region and then calculates strain. - g_reference: nd.array of form [x,y] - G_reference (tupe): reference coordinate system for - xaxis_x and xaxis_y - flip_theta: bool - If True, flips rotation coordinate system - returncal: bool - It True, returns rotated map - """ - # check the calstate - assert ( - self.calstate == self.braggvectors.calstate - ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." 
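Per scan position, the strain follows from comparing the fitted (g1, g2) pair against the reference pair via a small least-squares solve; this is the core arithmetic of `get_strain_from_reference_g1g2`. A sketch with hypothetical reference and measured vectors:

import numpy as np

M = np.array([[16.0, 0.2], [0.1, 16.4]])       # reference lattice, rows (g1; g2)
alpha = np.array([[16.2, 0.1], [0.05, 15.9]])  # measured lattice, rows (g1; g2)

# Transformation matrix mapping the reference onto the measured lattice
beta = np.linalg.lstsq(M, alpha, rcond=None)[0].T

e_xx = 1 - beta[0, 0]                  # normal strain along x
e_yy = 1 - beta[1, 1]                  # normal strain along y
e_xy = -(beta[0, 1] + beta[1, 0]) / 2  # symmetrized shear strain
theta = (beta[0, 1] - beta[1, 0]) / 2  # lattice rotation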
- - if mask is None: - mask = np.ones(self.g1g2_map.shape, dtype="bool") - - from py4DSTEM.process.latticevectors import get_strain_from_reference_region - - strainmap_g1g2 = get_strain_from_reference_region( - self.g1g2_map, - mask=mask, - ) - else: - from py4DSTEM.process.latticevectors import get_reference_g1g2 - - g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) - - from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 - - strainmap_g1g2 = get_strain_from_reference_g1g2( - self.g1g2_map, g1_ref, g2_ref - ) - - self.strainmap_g1g2 = strainmap_g1g2 - - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) - - from py4DSTEM.process.latticevectors import get_rotated_strain_map - - strainmap_rotated = get_rotated_strain_map( - self.strainmap_g1g2, - xaxis_x=g_reference[0], - xaxis_y=g_reference[1], - flip_theta=flip_theta, - ) - - self.strainmap_rotated = strainmap_rotated - - from py4DSTEM.visualize import show_strain - - figsize = kwargs.pop("figsize", (14, 4)) - vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) - vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) - ticknumber = kwargs.pop("ticknumber", 3) - bkgrd = kwargs.pop("bkgrd", False) - axes_plots = kwargs.pop("axes_plots", ()) - - fig, ax = show_strain( - self.strainmap_rotated, - vrange_exx=vrange_exx, - vrange_theta=vrange_theta, - ticknumber=ticknumber, - axes_plots=axes_plots, - bkgrd=bkgrd, - figsize=figsize, - **kwargs, - returnfig=True, - ) - - if not np.all(mask == True): - ax[0][0].imshow(mask, alpha=0.2, cmap="binary") - ax[0][1].imshow(mask, alpha=0.2, cmap="binary") - ax[1][0].imshow(mask, alpha=0.2, cmap="binary") - ax[1][1].imshow(mask, alpha=0.2, cmap="binary") - - if returncalc: - return self.strainmap_rotated - - def show_lattice_vectors( - ar, - x0, - y0, - g1, - g2, - color="r", - width=1, - labelsize=20, - labelcolor="w", - returnfig=False, - **kwargs, - ): - """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" - fig, ax = show(ar, returnfig=True, **kwargs) - - # Add vectors - dg1 = { - "x0": x0, - "y0": y0, - "vx": g1[0], - "vy": g1[1], - "width": width, - "color": color, - "label": r"$g_1$", - "labelsize": labelsize, - "labelcolor": labelcolor, - } - dg2 = { - "x0": x0, - "y0": y0, - "vx": g2[0], - "vy": g2[1], - "width": width, - "color": color, - "label": r"$g_2$", - "labelsize": labelsize, - "labelcolor": labelcolor, - } - add_vector(ax, dg1) - add_vector(ax, dg2) - - if returnfig: - return fig, ax - else: - plt.show() - return - - def show_bragg_indexing( - self, - ar, - bragg_directions, - voffset=5, - hoffset=0, - color="w", - size=20, - points=True, - pointcolor="r", - pointsize=50, - returnfig=False, - **kwargs, - ): - """ - Shows an array with an overlay describing the Bragg directions - - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. 
- """ - assert isinstance(bragg_directions, PointList) - for k in ("qx", "qy", "h", "k"): - assert k in bragg_directions.data.dtype.fields - - fig, ax = show(ar, returnfig=True, **kwargs) - d = { - "bragg_directions": bragg_directions, - "voffset": voffset, - "hoffset": hoffset, - "color": color, - "size": size, - "points": points, - "pointsize": pointsize, - "pointcolor": pointcolor, - } - add_bragg_index_labels(ax, d) - - if returnfig: - return fig, ax - else: - plt.show() - return - - def copy(self, name=None): - name = name if name is not None else self.name + "_copy" - strainmap_copy = StrainMap(self.braggvectors) - for attr in ( - "g", - "g0", - "g1", - "g2", - "calstate", - "bragg_directions", - "bragg_vectors_indexed", - "g1g2_map", - "strainmap_g1g2", - "strainmap_rotated", - ): - if hasattr(self, attr): - setattr(strainmap_copy, attr, getattr(self, attr)) - - for k in self.metadata.keys(): - strainmap_copy.metadata = self.metadata[k].copy() - return strainmap_copy - - # IO methods - - # read - @classmethod - def _get_constructor_args(cls, group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = RealSlice._get_constructor_args(group) - args = { - "data": ar_constr_args["data"], - "name": ar_constr_args["name"], - } - return args diff --git a/py4DSTEM/process/strain/__init__.py b/py4DSTEM/process/strain/__init__.py new file mode 100644 index 000000000..b487c916b --- /dev/null +++ b/py4DSTEM/process/strain/__init__.py @@ -0,0 +1,10 @@ +from py4DSTEM.process.strain.strain import StrainMap +from py4DSTEM.process.strain.latticevectors import ( + index_bragg_directions, + add_indices_to_braggvectors, + fit_lattice_vectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_strain_from_reference_g1g2, + get_rotated_strain_map, +) diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py new file mode 100644 index 000000000..dcff91709 --- /dev/null +++ b/py4DSTEM/process/strain/latticevectors.py @@ -0,0 +1,469 @@ +# Functions for indexing the Bragg directions + +import numpy as np +from emdfile import PointList, PointListArray, tqdmnd +from numpy.linalg import lstsq +from py4DSTEM.data import RealSlice + + +def index_bragg_directions(x0, y0, gx, gy, g1, g2): + """ + From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of + lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the + reciprocal lattice directions. + + The approach is to solve the matrix equation + ``alpha = beta * M`` + where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions, + beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the + h,k indices. + + Args: + x0 (float): x-coord of origin + y0 (float): y-coord of origin + gx (1d array): x-coord of the reciprocal lattice vectors + gy (1d array): y-coord of the reciprocal lattice vectors + g1 (2-tuple of floats): g1x,g1y + g2 (2-tuple of floats): g2x,g2y + + Returns: + (3-tuple) A 3-tuple containing: + + * **h**: *(ndarray of ints)* first index of the bragg directions + * **k**: *(ndarray of ints)* second index of the bragg directions + * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the + indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and bragg_y + coords 'h' and 'k' contain h and k. 
+ """ + # Get beta, the matrix of lattice vectors + beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]]) + + # Get alpha, the matrix of measured bragg angles + alpha = np.vstack([gx - x0, gy - y0]) + + # Calculate M, the matrix of peak positions + M = lstsq(beta, alpha, rcond=None)[0].T + M = np.round(M).astype(int) + + # Get h,k + h = M[:, 0] + k = M[:, 1] + + # Store in a PointList + coords = [("qx", float), ("qy", float), ("h", int), ("k", int)] + temp_array = np.zeros([], dtype=coords) + bragg_directions = PointList(data=temp_array) + bragg_directions.add_data_by_field((gx, gy, h, k)) + mask = np.zeros(bragg_directions["qx"].shape[0]) + mask[0] = 1 + bragg_directions.remove(mask) + + return h, k, bragg_directions + + +def add_indices_to_braggvectors( + braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None +): + """ + Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, + identify the indices for each peak in the PointListArray braggpeaks. + Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus + three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak + indices with the ints (h,k) and indicating whether the peak was successfully indexed + or not with the bool index_mask. If `mask` is specified, only the locations where + mask is True are indexed. + + Args: + braggpeaks (PointListArray): the braggpeaks to index. Must contain + the coordinates 'qx', 'qy', and 'intensity' + lattice (PointList): the positions (qx,qy) of the (h,k) lattice points. + Must contain the coordinates 'qx', 'qy', 'h', and 'k' + maxPeakSpacing (float): Maximum distance from the ideal lattice points + to include a peak for indexing + qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList + relative to the `braggpeaks` PointListArray + mask (bool): Boolean mask, same shape as the pointlistarray, indicating which + locations should be indexed. This can be used to index different regions of + the scan with different lattices + + Returns: + (PointListArray): The original braggpeaks pointlistarray, with new coordinates + 'h', 'k', containing the indices of each indexable peak. 
+ """ + + # assert isinstance(braggpeaks,BraggVectors) + # assert isinstance(lattice, PointList) + # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) + + if mask is None: + mask = np.ones(braggpeaks.Rshape, dtype=bool) + + assert ( + mask.shape == braggpeaks.Rshape + ), "mask must have same shape as pointlistarray" + assert mask.dtype == bool, "mask must be boolean" + + coords = [ + ("qx", float), + ("qy", float), + ("intensity", float), + ("h", int), + ("k", int), + ] + + indexed_braggpeaks = PointListArray( + dtype=coords, + shape=braggpeaks.Rshape, + ) + + calstate = braggpeaks.calstate + + # loop over all the scan positions + for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + if mask[Rx, Ry]: + pl = braggpeaks.get_vectors( + Rx, + Ry, + center=True, + ellipse=calstate["ellipse"], + rotate=calstate["rotate"], + pixel=False, + ) + for i in range(pl.data.shape[0]): + r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + ( + pl.data["qy"][i] - lattice.data["qy"] + qy_shift + ) ** 2 + ind = np.argmin(r2) + if r2[ind] <= maxPeakSpacing**2: + indexed_braggpeaks[Rx, Ry].add_data_by_field( + ( + pl.data["qx"][i], + pl.data["qy"][i], + pl.data["intensity"][i], + lattice.data["h"][ind], + lattice.data["k"][ind], + ) + ) + + return indexed_braggpeaks + + +def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. + x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (7-tuple) A 7-tuple containing: + + * **x0**: *(float)* the x-coord of the origin of the best-fit lattice. 
+ * **y0**: *(float)* the y-coord of the origin + * **g1x**: *(float)* x-coord of the first lattice vector + * **g1y**: *(float)* y-coord of the first lattice vector + * **g2x**: *(float)* x-coord of the second lattice vector + * **g2y**: *(float)* y-coord of the second lattice vector + * **error**: *(float)* the fit error + """ + assert isinstance(braggpeaks, PointList) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + braggpeaks = braggpeaks.copy() + + # Remove unindexed peaks + if "index_mask" in braggpeaks.dtype.names: + deletemask = braggpeaks.data["index_mask"] == False # noqa:E712 + braggpeaks.remove(deletemask) + + # Check to ensure enough peaks are present + if braggpeaks.length < minNumPeaks: + return None, None, None, None, None, None, None + + # Get M, the matrix of (h,k) indices + h, k = braggpeaks.data["h"], braggpeaks.data["k"] + M = np.vstack((np.ones_like(h, dtype=int), h, k)).T + + # Get alpha, the matrix of measured Bragg peak positions + alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T + + # Get weighted matrices + weights = braggpeaks.data["intensity"] + weighted_M = M * weights[:, np.newaxis] + weighted_alpha = alpha * weights[:, np.newaxis] + + # Solve for lattice vectors + beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0] + x0, y0 = beta[0, 0], beta[0, 1] + g1x, g1y = beta[1, 0], beta[1, 1] + g2x, g2y = beta[2, 0], beta[2, 1] + + # Calculate the error + alpha_calculated = np.matmul(M, beta) + error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1)) + error = np.sum(error * weights) / np.sum(weights) + + return x0, y0, g1x, g1y, g2x, g2y, error + + +def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): + """ + Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some + known (h,k) indexing. + + Args: + braggpeaks (PointList): A 6 coordinate PointList containing the data to fit. + Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a + weighting factor when fitting), 'h','k' (indexing). May optionally also + contain 'index_mask' (bool), indicating which peaks have been successfully + indixed and should be used. 
+ x0 (float): x-coord of the origin + y0 (float): y-coord of the origin + minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks + which can be indexed, return None for all return parameters + + Returns: + (RealSlice): A RealSlice ``g1g2_map`` containing the following 8 arrays: + + * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice + * ``g1g2_map.get_slice('y0')`` y-coord of the origin + * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector + * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector + * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector + * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector + * ``g1g2_map.get_slice('error')`` the fit error + * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful + fits + """ + assert isinstance(braggpeaks, PointListArray) + assert np.all( + [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")] + ) + + # Make RealSlice to contain outputs + slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask") + g1g2_map = RealSlice( + data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])), + slicelabels=slicelabels, + name="g1g2_map", + ) + + # Fit lattice vectors + for Rx, Ry in tqdmnd( + braggpeaks.shape[0], + braggpeaks.shape[1], + desc="Fitting lattice vectors", + unit="DP", + unit_scale=True, + ): + braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) + qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( + braggpeaks_curr, x0, y0, minNumPeaks + ) + # Store data + if g1x is not None: + g1g2_map.get_slice("x0").data[Rx, Ry] = qx0 + g1g2_map.get_slice("y0").data[Rx, Ry] = qy0 + g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x + g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y + g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x + g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y + g1g2_map.get_slice("error").data[Rx, Ry] = error + g1g2_map.get_slice("mask").data[Rx, Ry] = 1 + + return g1g2_map + + +def get_reference_g1g2(g1g2_map, mask): + """ + Gets a pair of reference lattice vectors from a region of real space specified by + mask. Takes the median of the lattice vectors in g1g2_map within the specified + region. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever + mask==True + + Returns: + (2-tuple of 2-tuples) A 2-tuple containing: + + * **g1**: *(2-tuple)* first reference lattice vector (x,y) + * **g2**: *(2-tuple)* second reference lattice vector (x,y) + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")] + ) + assert mask.dtype == bool + g1x = np.median(g1g2_map.get_slice("g1x").data[mask]) + g1y = np.median(g1g2_map.get_slice("g1y").data[mask]) + g2x = np.median(g1g2_map.get_slice("g2x").data[mask]) + g2y = np.median(g1g2_map.get_slice("g2y").data[mask]) + return (g1x, g1y), (g2x, g2y) + + +def get_strain_from_reference_g1g2(g1g2_map, g1, g2): + """ + Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map + g1g2_map. + + Note that this function will return the strain map oriented with respect to the x/y + axes of diffraction space - to rotate the coordinate system, use + get_rotated_strain_map().
Calibration of the rotational misalignment between real and + diffraction space may also be necessary. + + Args: + g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data + under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for + fit_lattice_vectors_all_DPs() for more information. + g1 (2-tuple): first reference lattice vector (x,y) + g2 (2-tuple): second reference lattice vector (x,y) + + Returns: + (RealSlice) the strain map; contains the elements of the infinitessimal strain + matrix, in the following 5 arrays: + + * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect + to x + * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect + to y + * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect + to y + * ``strain_map.get_slice('theta')``: rotation of lattice with respect to + reference + * ``strain_map.get_slice('mask')``: 0/False indicates unknown values + + Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical + """ + assert isinstance(g1g2_map, RealSlice) + assert np.all( + [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")] + ) + + # Get RealSlice for output storage + R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape + strain_map = RealSlice( + data=np.zeros((5, R_Nx, R_Ny)), + slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"), + name="strain_map", + ) + + # Get reference lattice matrix + g1x, g1y = g1 + g2x, g2y = g2 + M = np.array([[g1x, g1y], [g2x, g2y]]) + + for Rx, Ry in tqdmnd( + R_Nx, + R_Ny, + desc="Calculating strain", + unit="DP", + unit_scale=True, + ): + # Get lattice vectors for DP at Rx,Ry + alpha = np.array( + [ + [ + g1g2_map.get_slice("g1x").data[Rx, Ry], + g1g2_map.get_slice("g1y").data[Rx, Ry], + ], + [ + g1g2_map.get_slice("g2x").data[Rx, Ry], + g1g2_map.get_slice("g2y").data[Rx, Ry], + ], + ] + ) + # Get transformation matrix + beta = lstsq(M, alpha, rcond=None)[0].T + + # Get the infinitesimal strain matrix + strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] + strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] + strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 + strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 + strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ + Rx, Ry + ] + return strain_map + + +def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): + """ + Starting from a strain map defined with respect to the xy coordinate system of + diffraction space, i.e. where exx and eyy are the compression/tension along the Qx + and Qy directions, respectively, get a strain map defined with respect to some other + right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, + xaxis_y). 
+ + Args: + xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector + along the new x-axis + unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the + infinitesimal strain matrix elements, stored at + * unrotated_strain_map.get_slice('e_xx') + * unrotated_strain_map.get_slice('e_xy') + * unrotated_strain_map.get_slice('e_yy') + * unrotated_strain_map.get_slice('theta') + + Returns: + (RealSlice) the rotated counterpart to unrotated_strain_map, with the + rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate + system + """ + assert isinstance(unrotated_strain_map, RealSlice) + assert np.all( + [ + key in ["e_xx", "e_xy", "e_yy", "theta", "mask"] + for key in unrotated_strain_map.slicelabels + ] + ) + theta = -np.arctan2(xaxis_y, xaxis_x) + cost = np.cos(theta) + sint = np.sin(theta) + cost2 = cost**2 + sint2 = sint**2 + + Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape + rotated_strain_map = RealSlice( + data=np.zeros((5, Rx, Ry)), + slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"], + name=unrotated_strain_map.name + "_rotated", + ) + + rotated_strain_map.data[0, :, :] = ( + cost2 * unrotated_strain_map.get_slice("e_xx").data + - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + sint2 * unrotated_strain_map.get_slice("e_yy").data + ) + rotated_strain_map.data[1, :, :] = ( + cost + * sint + * ( + unrotated_strain_map.get_slice("e_xx").data + - unrotated_strain_map.get_slice("e_yy").data + ) + + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data + ) + rotated_strain_map.data[2, :, :] = ( + sint2 * unrotated_strain_map.get_slice("e_xx").data + + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data + + cost2 * unrotated_strain_map.get_slice("e_yy").data + ) + if flip_theta is True: + rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data + else: + rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data + rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data + return rotated_strain_map diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py new file mode 100644 index 000000000..621fc820e --- /dev/null +++ b/py4DSTEM/process/strain/strain.py @@ -0,0 +1,1247 @@ +# Defines the Strain class + +import warnings +from typing import Optional + +import matplotlib.pyplot as plt +from matplotlib.patches import Circle +from matplotlib.collections import PatchCollection +from mpl_toolkits.axes_grid1 import make_axes_locatable +import numpy as np +from py4DSTEM import PointList, PointListArray, tqdmnd +from py4DSTEM.braggvectors import BraggVectors +from py4DSTEM.data import Data, RealSlice +from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.process.strain.latticevectors import ( + add_indices_to_braggvectors, + fit_lattice_vectors_all_DPs, + get_reference_g1g2, + get_rotated_strain_map, + get_strain_from_reference_g1g2, + index_bragg_directions, +) +from py4DSTEM.visualize import ( + show, + add_bragg_index_labels, + add_pointlabels, + add_vector, + ax_addaxes, + ax_addaxes_QtoR, +) + + +class StrainMap(RealSlice, Data): + """ + Storage and processing methods for strain maps generated from 4D-STEM data. + + """ + + def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): + """ + Parameters + ---------- + braggvectors : BraggVectors + The Bragg vectors + name : str + The name of the strainmap + + Returns + ------- + A new StrainMap instance.
+        """
+        assert isinstance(
+            braggvectors, BraggVectors
+        ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}"
+
+        # initialize as a RealSlice
+        RealSlice.__init__(
+            self,
+            name=name,
+            data=np.empty(
+                (
+                    6,
+                    braggvectors.Rshape[0],
+                    braggvectors.Rshape[1],
+                )
+            ),
+            slicelabels=["e_xx", "e_yy", "e_xy", "theta", "mask", "error"],
+        )
+
+        # set up braggvectors
+        # this assigns the bvs, ensures the origin is calibrated,
+        # and adds the strainmap to the bvs' tree
+        self.braggvectors = braggvectors
+
+        # initialize as Data
+        Data.__init__(self)
+
+        # set calstate
+        # this property is used only to check to make sure that
+        # the braggvectors being used throughout a workflow are
+        # the same. The state of calibration of the vectors is noted
+        # here, and then checked each time the vectors are used -
+        # if they differ, an error message and instructions for
+        # re-calibration are issued
+        self.calstate = self.braggvectors.calstate
+        assert self.calstate["center"], "braggvectors must be centered"
+        if self.calstate["rotate"] is False:
+            warnings.warn(
+                ("Real to reciprocal space rotation not calibrated"),
+                UserWarning,
+            )
+
+        # get the BVM
+        # a new BVM using the current calstate is computed
+        self.bvm = self.braggvectors.histogram(mode="cal")
+
+    # braggvector properties
+
+    @property
+    def braggvectors(self):
+        return self._braggvectors
+
+    @braggvectors.setter
+    def braggvectors(self, x):
+        assert isinstance(
+            x, BraggVectors
+        ), f".braggvectors must be BraggVectors, not type {type(x)}"
+        assert (
+            x.calibration.origin is not None
+        ), "braggvectors must have a calibrated origin"
+        self._braggvectors = x
+        self._braggvectors.tree(self, force=True)
+
+    @property
+    def rshape(self):
+        return self._braggvectors.Rshape
+
+    @property
+    def qshape(self):
+        return self._braggvectors.Qshape
+
+    @property
+    def origin(self):
+        return self.calibration.get_origin_mean()
+
+    def reset_calstate(self):
+        """
+        Resets the calibration state. This removes any basis vectors this
+        StrainMap instance has stored (g0, g1, g2), which will need to be
+        recomputed, and re-syncs the stored calibration state with that of
+        the braggvectors.
+        """
+        for attr in (
+            "g0",
+            "g1",
+            "g2",
+        ):
+            if hasattr(self, attr):
+                delattr(self, attr)
+        self.calstate = self.braggvectors.calstate
+
+    # Class methods
+
+    def choose_basis_vectors(
+        self,
+        index_g1=None,
+        index_g2=None,
+        index_origin=None,
+        subpixel="multicorr",
+        upsample_factor=16,
+        sigma=0,
+        minAbsoluteIntensity=0,
+        minRelativeIntensity=0,
+        relativeToPeak=0,
+        minSpacing=0,
+        edgeBoundary=1,
+        maxNumPeaks=10,
+        x0=None,
+        y0=None,
+        figsize=(14, 9),
+        c_indices="lightblue",
+        c0="g",
+        c1="r",
+        c2="r",
+        c_vectors="r",
+        c_vectorlabels="w",
+        size_indices=15,
+        width_vectors=1,
+        size_vectorlabels=15,
+        vis_params={},
+        returncalc=False,
+        returnfig=False,
+    ):
+        """
+        Choose basis lattice vectors g1 and g2 for strain mapping.
+
+        Overlays the bvm with the points detected via local 2D
+        maxima detection, plus an index for each point. Three points
+        are selected which correspond to the origin, and the basis
+        reciprocal lattice vectors g1 and g2. By default these are
+        automatically located; the user can override and select these
+        manually using the `index_*` arguments.
+
+        Parameters
+        ----------
+        index_g1 : int
+            selected index for g1
+        index_g2 : int
+            selected index for g2
+        index_origin : int
+            selected index for the origin
+        subpixel : str in ('pixel','poly','multicorr')
+            See the docstring for py4DSTEM.preprocess.get_maxima_2D
+        upsample_factor : int
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        sigma : number
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        minAbsoluteIntensity : number
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        minRelativeIntensity : number
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        relativeToPeak : int
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        minSpacing : number
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        edgeBoundary : number
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        maxNumPeaks : int
+            See the py4DSTEM.preprocess.get_maxima_2D docstring
+        figsize : 2-tuple
+            the size of the figure
+        c_indices : color
+            color of the maxima
+        c0 : color
+            color of the origin
+        c1 : color
+            color of g1 point
+        c2 : color
+            color of g2 point
+        c_vectors : color
+            color of the g1/g2 vectors
+        c_vectorlabels : color
+            color of the vector labels
+        size_indices : number
+            size of the indices
+        width_vectors : number
+            width of the vectors
+        size_vectorlabels : number
+            size of the vector labels
+        vis_params : dict
+            additional visualization parameters passed to `show`
+        returncalc : bool
+            toggles returning the answer
+        returnfig : bool
+            toggles returning the figure
+
+        Returns
+        -------
+        (optional) : None, or (g0,g1,g2,braggdirections) if `returncalc` is
+            True, or (fig,ax) if `returnfig` is True, or both
+        """
+        # validate inputs
+        for i in (index_origin, index_g1, index_g2):
+            assert isinstance(i, (int, np.integer)) or (
+                i is None
+            ), "indices must be integers!"
+        # check the calstate
+        assert (
+            self.calstate == self.braggvectors.calstate
+        ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
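+
+        # Basis selection proceeds in three steps: detect local maxima in the
+        # BVM, guess (or accept user-supplied) indices for the origin and for
+        # g1 and g2, then index all detected maxima on the resulting lattice.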
+ + # find the maxima + + g = get_maxima_2D( + self.bvm.data, + subpixel=subpixel, + upsample_factor=upsample_factor, + sigma=sigma, + minAbsoluteIntensity=minAbsoluteIntensity, + minRelativeIntensity=minRelativeIntensity, + relativeToPeak=relativeToPeak, + minSpacing=minSpacing, + edgeBoundary=edgeBoundary, + maxNumPeaks=maxNumPeaks, + ) + + # guess the origin and g1 g2 vectors if indices aren't provided + if np.any([x is None for x in (index_g1, index_g2, index_origin)]): + # get distances and angles from calibrated origin + g_dists = np.hypot(g["x"] - self.origin[0], g["y"] - self.origin[1]) + g_angles = np.angle( + g["x"] - self.origin[0] + 1j * (g["y"] - self.origin[1]) + ) + + # guess the origin + if index_origin is None: + index_origin = np.argmin(g_dists) + g_dists[index_origin] = 2 * np.max(g_dists) + + # guess g1 + if index_g1 is None: + index_g1 = np.argmin(g_dists) + g_dists[index_g1] = 2 * np.max(g_dists) + + # guess g2 + if index_g2 is None: + angle_scaling = np.cos(g_angles - g_angles[index_g1]) ** 2 + index_g2 = np.argmin(g_dists * (angle_scaling + 0.1)) + + # get the lattice vectors + gx, gy = g["x"], g["y"] + g0 = gx[index_origin], gy[index_origin] + g1x = gx[index_g1] - g0[0] + g1y = gy[index_g1] - g0[1] + g2x = gx[index_g2] - g0[0] + g2y = gy[index_g2] - g0[1] + g1, g2 = (g1x, g1y), (g2x, g2y) + + # index the lattice vectors + _, _, braggdirections = index_bragg_directions( + g0[0], g0[1], g["x"], g["y"], g1, g2 + ) + + # make the figure + fig, ax = plt.subplots(1, 3, figsize=figsize) + show(self.bvm.data, figax=(fig, ax[0]), **vis_params) + show(self.bvm.data, figax=(fig, ax[1]), **vis_params) + self.show_bragg_indexing( + self.bvm.data, + bragg_directions=braggdirections, + points=True, + figax=(fig, ax[2]), + size=size_indices, + **vis_params, + ) + + # Add indices to left panel + d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices} + d0 = { + "x": gx[index_origin], + "y": gy[index_origin], + "size": size_indices, + "color": c0, + "fontweight": "bold", + "labels": [str(index_origin)], + } + d1 = { + "x": gx[index_g1], + "y": gy[index_g1], + "size": size_indices, + "color": c1, + "fontweight": "bold", + "labels": [str(index_g1)], + } + d2 = { + "x": gx[index_g2], + "y": gy[index_g2], + "size": size_indices, + "color": c2, + "fontweight": "bold", + "labels": [str(index_g2)], + } + add_pointlabels(ax[0], d) + add_pointlabels(ax[0], d0) + add_pointlabels(ax[0], d1) + add_pointlabels(ax[0], d2) + + # Add vectors to right panel + dg1 = { + "x0": gx[index_origin], + "y0": gy[index_origin], + "vx": g1[0], + "vy": g1[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_1$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + dg2 = { + "x0": gx[index_origin], + "y0": gy[index_origin], + "vx": g2[0], + "vy": g2[1], + "width": width_vectors, + "color": c_vectors, + "label": r"$g_2$", + "labelsize": size_vectorlabels, + "labelcolor": c_vectorlabels, + } + add_vector(ax[1], dg1) + add_vector(ax[1], dg2) + + # store vectors + self.g = g + self.g0 = g0 + self.g1 = g1 + self.g2 = g2 + + # center the bragg directions and store + braggdirections.data["qx"] -= self.origin[0] + braggdirections.data["qy"] -= self.origin[1] + self.braggdirections = braggdirections + + # return + if returncalc and returnfig: + return (self.g0, self.g1, self.g2, self.braggdirections), (fig, ax) + elif returncalc: + return (self.g0, self.g1, self.g2, self.braggdirections) + elif returnfig: + return (fig, ax) + else: + return + + def set_max_peak_spacing( + self, 
+        max_peak_spacing,
+        returnfig=False,
+        **vis_params,
+    ):
+        """
+        Set the size of the regions of diffraction space in which detected Bragg
+        peaks will be indexed and included in subsequent fitting of basis
+        vectors, and visualize those regions.
+
+        Parameters
+        ----------
+        max_peak_spacing : number
+            The maximum allowable distance in pixels between a detected Bragg peak and
+            the indexed maxima found in `choose_basis_vectors` for the detected
+            peak to be indexed
+        returnfig : bool
+            Toggles returning the figure
+        vis_params : dict
+            Any additional arguments are passed to the `show` function when
+            visualizing the BVM
+        """
+        # set the max peak spacing
+        self.max_peak_spacing = max_peak_spacing
+
+        # make the figure
+        fig, ax = show(
+            self.bvm.data,
+            returnfig=True,
+            **vis_params,
+        )
+
+        # make the circle patch collection
+        patches = []
+        qx = self.braggdirections["qx"]
+        qy = self.braggdirections["qy"]
+        origin = self.origin
+        for idx in range(len(qx)):
+            c = Circle(
+                xy=(qy[idx] + origin[1], qx[idx] + origin[0]),
+                radius=self.max_peak_spacing,
+                edgecolor="r",
+                fill=False,
+            )
+            patches.append(c)
+        pc = PatchCollection(patches, match_original=True)
+
+        # draw the circles
+        ax.add_collection(pc)
+
+        # return
+        if returnfig:
+            return fig, ax
+        else:
+            plt.show()
+
+    def fit_basis_vectors(
+        self, mask=None, max_peak_spacing=None, vis_params={}, returncalc=False
+    ):
+        """
+        Fit the basis lattice vectors to the detected Bragg peaks at each
+        scan position.
+
+        First, the lattice vectors at each scan position are indexed using the
+        basis vectors g1 and g2 specified previously with `choose_basis_vectors`.
+        Detected Bragg peaks which are farther from the set of lattice vectors
+        found in `choose_basis_vectors` than the maximum peak spacing are
+        ignored; the maximum peak spacing can be set previously by calling
+        `set_max_peak_spacing` or by specifying the `max_peak_spacing` argument
+        here. A fit is then performed to refine the values of g1 and g2 at each
+        scan position, fitting the basis vectors to all detected and indexed
+        peaks, weighting the peaks according to their intensity.
+
+        Parameters
+        ----------
+        mask : 2d boolean array
+            A real space shaped Boolean mask indicating scan positions at which
+            to fit the lattice vectors.
+        max_peak_spacing : float
+            Maximum distance from the ideal lattice points to include a peak
+            for indexing
+        vis_params : dict
+            Visualization parameters for showing the max peak spacing; ignored
+            if `max_peak_spacing` is not set
+        returncalc : bool
+            if True, returns bragg_vectors_indexed and g1g2_map
+        """
+        # check the calstate
+        assert (
+            self.calstate == self.braggvectors.calstate
+        ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
+
+        # handle the max peak spacing
+        if max_peak_spacing is not None:
+            self.set_max_peak_spacing(max_peak_spacing, **vis_params)
+        assert hasattr(self, "max_peak_spacing"), "Set the maximum peak spacing!"
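+
+        # Indexing, below, labels each detected peak with the (h,k) of its
+        # nearest indexed direction from `choose_basis_vectors`; peaks farther
+        # than `max_peak_spacing` from every direction are discarded.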
+
+        # index the bragg vectors
+
+        # handle the mask
+        if mask is None:
+            mask = np.ones(self.braggvectors.Rshape, dtype=bool)
+        assert (
+            mask.shape == self.braggvectors.Rshape
+        ), "mask must have same shape as pointlistarray"
+        assert mask.dtype == bool, "mask must be boolean"
+        self.mask = mask
+
+        # set up new braggpeaks PLA
+        indexed_braggpeaks = PointListArray(
+            dtype=[
+                ("qx", float),
+                ("qy", float),
+                ("intensity", float),
+                ("h", int),
+                ("k", int),
+            ],
+            shape=self.braggvectors.Rshape,
+        )
+
+        # loop over all the scan positions
+        # and perform indexing, excluding peaks outside of max_peak_spacing
+        calstate = self.braggvectors.calstate
+        for Rx, Ry in tqdmnd(
+            mask.shape[0],
+            mask.shape[1],
+            desc="Indexing Bragg scattering",
+            unit="DP",
+            unit_scale=True,
+        ):
+            if mask[Rx, Ry]:
+                pl = self.braggvectors.get_vectors(
+                    Rx,
+                    Ry,
+                    center=True,
+                    ellipse=calstate["ellipse"],
+                    rotate=calstate["rotate"],
+                    pixel=False,
+                )
+                for i in range(pl.data.shape[0]):
+                    r = np.hypot(
+                        pl.data["qx"][i] - self.braggdirections.data["qx"],
+                        pl.data["qy"][i] - self.braggdirections.data["qy"],
+                    )
+                    ind = np.argmin(r)
+                    if r[ind] <= self.max_peak_spacing:
+                        indexed_braggpeaks[Rx, Ry].add_data_by_field(
+                            (
+                                pl.data["qx"][i],
+                                pl.data["qy"][i],
+                                pl.data["intensity"][i],
+                                self.braggdirections.data["h"][ind],
+                                self.braggdirections.data["k"][ind],
+                            )
+                        )
+        self.bragg_vectors_indexed = indexed_braggpeaks
+
+        # fit bragg vectors
+        g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed)
+        self.g1g2_map = g1g2_map
+
+        # update the mask
+        g1g2_mask = self.g1g2_map["mask"].data.astype("bool")
+        self.mask = np.logical_and(self.mask, g1g2_mask)
+
+        # return
+        if returncalc:
+            return self.bragg_vectors_indexed, self.g1g2_map
+
+    def get_strain(
+        self, gvects=None, coordinate_rotation=0, returncalc=False, **kwargs
+    ):
+        """
+        Compute the strain as the deviation of the basis reciprocal lattice
+        vectors which have been fit at each scan position with respect to a
+        pair of reference lattice vectors, determined by the argument `gvects`.
+
+        Parameters
+        ----------
+        gvects : None or 2d-array or tuple
+            Specifies how to select the reference lattice vectors. If None,
+            use the median of the fit lattice vectors over the whole dataset.
+            If a 2d array is passed, it should be real space shaped and boolean.
+            In this case, uses the median of the fit lattice vectors in all scan
+            positions where this array is True. Otherwise, should be a length 2
+            tuple of length 2 array/list/tuples, which are used directly as
+            g1 and g2.
+        coordinate_rotation : number
+            Rotate the reference coordinate system counterclockwise by this
+            amount, in degrees
+        returncalc : bool
+            If True, returns the StrainMap, which holds the rotated strain
+            components
+        """
+        # confirm that the calstate hasn't changed
+        assert (
+            self.calstate == self.braggvectors.calstate
+        ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
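+
+        # In `get_strain_from_reference_g1g2`, a transformation matrix beta
+        # mapping the reference vectors onto the fit vectors is found at each
+        # scan position, and the strain components are computed as
+        #   e_xx = 1 - beta[0,0]                e_yy = 1 - beta[1,1]
+        #   e_xy = -(beta[0,1] + beta[1,0])/2   theta = (beta[0,1] - beta[1,0])/2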
+
+        # get the reference g-vectors
+        if gvects is None:
+            g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, self.mask)
+        elif isinstance(gvects, np.ndarray):
+            assert gvects.shape == self.rshape
+            assert gvects.dtype == bool
+            g1_ref, g2_ref = get_reference_g1g2(
+                self.g1g2_map, np.logical_and(gvects, self.mask)
+            )
+        else:
+            g1_ref = np.array(gvects[0])
+            g2_ref = np.array(gvects[1])
+
+        # find the strain
+        strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref)
+        self.strainmap_g1g2 = strainmap_g1g2
+
+        # get the reference coordinate system
+        theta = np.radians(coordinate_rotation)
+        xaxis_x = np.cos(theta)
+        xaxis_y = np.sin(theta)
+        self.coordinate_rotation_degrees = coordinate_rotation
+        self.coordinate_rotation_radians = theta
+
+        # get the strain in the reference coordinates
+        strainmap_rotated = get_rotated_strain_map(
+            self.strainmap_g1g2,
+            xaxis_x=xaxis_x,
+            xaxis_y=xaxis_y,
+            flip_theta=False,
+        )
+
+        # store the data
+        self.data[0] = strainmap_rotated["e_xx"].data
+        self.data[1] = strainmap_rotated["e_yy"].data
+        self.data[2] = strainmap_rotated["e_xy"].data
+        self.data[3] = strainmap_rotated["theta"].data
+        self.data[4] = strainmap_rotated["mask"].data
+
+        # plot the results
+        fig, ax = self.show_strain(
+            **kwargs,
+            returnfig=True,
+        )
+
+        # return
+        if returncalc:
+            return self
+
+    def get_reference_g1g2(self, ROI):
+        """
+        Get reference g1,g2 vectors by taking the median fit vectors
+        in the specified ROI.
+
+        Parameters
+        ----------
+        ROI : real space shaped 2d boolean ndarray
+            Use scan positions where ROI is True
+
+        Returns
+        -------
+        g1_ref,g2_ref : 2 tuple of length 2 ndarrays
+        """
+        g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, ROI)
+        return g1_ref, g2_ref
+
+    def show_strain(
+        self,
+        vrange=[-3, 3],
+        vrange_theta=[-3, 3],
+        vrange_exx=None,
+        vrange_exy=None,
+        vrange_eyy=None,
+        show_cbars=None,
+        bordercolor="k",
+        borderwidth=1,
+        titlesize=18,
+        ticklabelsize=10,
+        ticknumber=5,
+        unitlabelsize=16,
+        cmap="RdBu_r",
+        cmap_theta="PRGn",
+        mask_color="k",
+        color_axes="k",
+        show_legend=True,
+        show_gvects=True,
+        color_gvects="r",
+        legend_camera_length=1.6,
+        scale_gvects=0.6,
+        layout="square",
+        figsize=None,
+        returnfig=False,
+    ):
+        """
+        Display a strain map, showing the 4 strain components
+        (e_xx,e_yy,e_xy,theta), and masking each image with
+        strainmap.get_slice('mask')
+
+        Parameters
+        ----------
+        vrange : length 2 list or tuple
+            The colorbar intensity range for exx,eyy, and exy.
+        vrange_theta : length 2 list or tuple
+            The colorbar intensity range for theta.
+        vrange_exx : length 2 list or tuple
+            The colorbar intensity range for exx; overrides `vrange`
+            for exx
+        vrange_exy : length 2 list or tuple
+            The colorbar intensity range for exy; overrides `vrange`
+            for exy
+        vrange_eyy : length 2 list or tuple
+            The colorbar intensity range for eyy; overrides `vrange`
+            for eyy
+        show_cbars : None or a tuple of strings
+            Show colorbars for the specified axes. Valid strings are
+            'exx', 'eyy', 'exy', and 'theta'.
+        bordercolor : color
+            Color for the image borders
+        borderwidth : number
+            Width of the image borders
+        titlesize : number
+            Size of the image titles
+        ticklabelsize : number
+            Size of the colorbar ticks
+        ticknumber : number
+            Number of ticks on colorbars
+        unitlabelsize : number
+            Size of the units label on the colorbars
+        cmap : colormap
+            Colormap for exx, exy, and eyy
+        cmap_theta : colormap
+            Colormap for theta
+        mask_color : color
+            Color for the background mask
+        color_axes : color
+            Color for the legend coordinate axes
+        show_legend : bool
+            Toggles displaying the legend panel
+        show_gvects : bool
+            Toggles displaying the g-vectors in the legend
+        color_gvects : color
+            Color for the legend g-vectors
+        legend_camera_length : number
+            The distance the legend is viewed from; a smaller number yields
+            a larger legend
+        scale_gvects : number
+            Scaling for the legend g-vectors relative to the coordinate axes
+        layout : str
+            Determines the layout of the grid which the strain components
+            will be plotted in. Must be in ("square", "horizontal", "vertical").
+        figsize : length 2 tuple of numbers
+            Size of the figure
+        returnfig : bool
+            Toggles returning the figure
+        """
+
+        from py4DSTEM.visualize import show_strain
+
+        fig, ax = show_strain(
+            self,
+            vrange=vrange,
+            vrange_theta=vrange_theta,
+            vrange_exx=vrange_exx,
+            vrange_exy=vrange_exy,
+            vrange_eyy=vrange_eyy,
+            show_cbars=show_cbars,
+            bordercolor=bordercolor,
+            borderwidth=borderwidth,
+            titlesize=titlesize,
+            ticklabelsize=ticklabelsize,
+            ticknumber=ticknumber,
+            unitlabelsize=unitlabelsize,
+            cmap=cmap,
+            cmap_theta=cmap_theta,
+            mask_color=mask_color,
+            color_axes=color_axes,
+            show_legend=show_legend,
+            rotation_deg=np.rad2deg(self.coordinate_rotation_radians),
+            show_gvects=show_gvects,
+            g1=self.g1,
+            g2=self.g2,
+            color_gvects=color_gvects,
+            legend_camera_length=legend_camera_length,
+            scale_gvects=scale_gvects,
+            layout=layout,
+            figsize=figsize,
+            returnfig=True,
+        )
+
+        # show/return
+        if not returnfig:
+            plt.show()
+            return
+        else:
+            return fig, ax
+
+    def show_reference_directions(
+        self,
+        im_uncal=None,
+        im_cal=None,
+        color_axes="linen",
+        color_gvects="r",
+        origin_uncal=None,
+        origin_cal=None,
+        camera_length=1.8,
+        visp_uncal={"scaling": "log"},
+        visp_cal={"scaling": "log"},
+        layout="horizontal",
+        titlesize=16,
+        size_labels=14,
+        figsize=None,
+        returnfig=False,
+    ):
+        """
+        Show the reference coordinate system used to compute the strain
+        overlaid over calibrated and uncalibrated diffraction space images.
+
+        The diffraction images used can be specified with the `im_uncal`
+        and `im_cal` arguments, and default to the uncalibrated and calibrated
+        Bragg vector maps.
+
+        Parameters
+        ----------
+        im_uncal : 2d array or None
+            Uncalibrated diffraction space image to display; defaults to
+            the uncalibrated Bragg vector map.
+        im_cal : 2d array or None
+            Calibrated diffraction space image to display; defaults to
+            the calibrated Bragg vector map.
+        color_axes : color
+            The color of the overlaid coordinate axes
+        color_gvects : color
+            The color of the g-vectors
+        origin_uncal : 2-tuple or None
+            Where to place the origin of the coordinate system overlaid on
+            the uncalibrated diffraction image. Defaults to the mean origin
+            from the calibration metadata.
+ origin_cal : 2-tuple or None + Where to place the origin of the coordinate system overlaid on + the calibrated diffraction image. Defaults to the mean origin + from the calibration metadata. + camera_length : number + Determines the length of the overlaid coordinate axes; a smaller + number yields larger axes. + visp_uncal : dict + Visualization parameters for the uncalibrated diffraction image. + visp_cal : dict + Visualization parameters for the calibrated diffraction image. + layout : str; either "horizontal" or "vertical" + Determines the layout of the visualization. + titlesize : number + The size of the plot titles + size_labels : number + The size of the axis labels + figsize : length 2 tuple of numbers or None + Size of the figure + returnfig : bool + Toggles returning the figure + """ + # Set up the figure + assert layout in ("horizontal", "vertical") + + # Set the figsize + if figsize is None: + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) + if layout == "horizontal": + figsize = (10 * ratio, 8 / ratio) + else: + figsize = (8 * ratio, 12 / ratio) + + # Create the figure + if layout == "horizontal": + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + else: + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize) + + # prepare images + if im_uncal is None: + im_uncal = self.braggvectors.histogram(mode="raw") + if im_cal is None: + im_cal = self.braggvectors.histogram(mode="cal") + + # display images + show(im_cal, figax=(fig, ax1), **visp_cal) + show(im_uncal, figax=(fig, ax2), **visp_uncal) + ax1.set_title("Calibrated", size=titlesize) + ax2.set_title("Uncalibrated", size=titlesize) + + # Get the coordinate axes + + # get the directions + + # calibrated + rotation = self.coordinate_rotation_radians + xaxis_cal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_cal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) + + # uncalibrated + QRrot = self.calibration.get_QR_rotation() + rotation = np.sum([self.coordinate_rotation_radians, -QRrot]) + xaxis_uncal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_uncal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) + # inversion + if self.calibration.get_QR_flip(): + xaxis_uncal = np.array([xaxis_uncal[1], xaxis_uncal[0]]) + yaxis_uncal = np.array([yaxis_uncal[1], yaxis_uncal[0]]) + + # set the lengths + Lmean = np.mean([im_cal.shape[0], im_cal.shape[1]]) / 2 + xaxis_cal *= Lmean / camera_length + yaxis_cal *= Lmean / camera_length + xaxis_uncal *= Lmean / camera_length + yaxis_uncal *= Lmean / camera_length + + # Get the g-vectors + + # calibrated + g1_cal = np.array(self.g1) + g2_cal = np.array(self.g2) + + # uncalibrated + R = np.array([[np.cos(QRrot), -np.sin(QRrot)], [np.sin(QRrot), np.cos(QRrot)]]) + g1_uncal = np.matmul(g1_cal, R) + g2_uncal = np.matmul(g2_cal, R) + # inversion + if self.calibration.get_QR_flip(): + g1_uncal = np.array([g1_uncal[1], g1_uncal[0]]) + g2_uncal = np.array([g2_uncal[1], g2_uncal[0]]) + + # Set origin positions + if origin_uncal is None: + origin_uncal = self.calibration.get_origin_mean() + if origin_cal is None: + origin_cal = self.calibration.get_origin_mean() + + # Draw calibrated coordinate axes + coordax_width = Lmean * 2 / 100 + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=xaxis_cal[1], + dy=xaxis_cal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=yaxis_cal[1], + dy=yaxis_cal[0], + 
color=color_axes,
+            length_includes_head=True,
+            width=coordax_width,
+            head_width=coordax_width * 5,
+        )
+        ax1.text(
+            x=origin_cal[1] + xaxis_cal[1] * 1.16,
+            y=origin_cal[0] + xaxis_cal[0] * 1.16,
+            s="x",
+            fontsize=size_labels,
+            color=color_axes,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+        ax1.text(
+            x=origin_cal[1] + yaxis_cal[1] * 1.16,
+            y=origin_cal[0] + yaxis_cal[0] * 1.16,
+            s="y",
+            fontsize=size_labels,
+            color=color_axes,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+
+        # Draw uncalibrated coordinate axes
+        ax2.arrow(
+            x=origin_uncal[1],
+            y=origin_uncal[0],
+            dx=xaxis_uncal[1],
+            dy=xaxis_uncal[0],
+            color=color_axes,
+            length_includes_head=True,
+            width=coordax_width,
+            head_width=coordax_width * 5,
+        )
+        ax2.arrow(
+            x=origin_uncal[1],
+            y=origin_uncal[0],
+            dx=yaxis_uncal[1],
+            dy=yaxis_uncal[0],
+            color=color_axes,
+            length_includes_head=True,
+            width=coordax_width,
+            head_width=coordax_width * 5,
+        )
+        ax2.text(
+            x=origin_uncal[1] + xaxis_uncal[1] * 1.16,
+            y=origin_uncal[0] + xaxis_uncal[0] * 1.16,
+            s="x",
+            fontsize=size_labels,
+            color=color_axes,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+        ax2.text(
+            x=origin_uncal[1] + yaxis_uncal[1] * 1.16,
+            y=origin_uncal[0] + yaxis_uncal[0] * 1.16,
+            s="y",
+            fontsize=size_labels,
+            color=color_axes,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+
+        # Draw the calibrated g-vectors
+
+        # draw the g vectors
+        ax1.arrow(
+            x=origin_cal[1],
+            y=origin_cal[0],
+            dx=g1_cal[1],
+            dy=g1_cal[0],
+            color=color_gvects,
+            length_includes_head=True,
+            width=coordax_width * 0.5,
+            head_width=coordax_width * 2.5,
+        )
+        ax1.arrow(
+            x=origin_cal[1],
+            y=origin_cal[0],
+            dx=g2_cal[1],
+            dy=g2_cal[0],
+            color=color_gvects,
+            length_includes_head=True,
+            width=coordax_width * 0.5,
+            head_width=coordax_width * 2.5,
+        )
+        ax1.text(
+            x=origin_cal[1] + g1_cal[1] * 1.16,
+            y=origin_cal[0] + g1_cal[0] * 1.16,
+            s=r"$g_1$",
+            fontsize=size_labels * 0.88,
+            color=color_gvects,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+        ax1.text(
+            x=origin_cal[1] + g2_cal[1] * 1.16,
+            y=origin_cal[0] + g2_cal[0] * 1.16,
+            s=r"$g_2$",
+            fontsize=size_labels * 0.88,
+            color=color_gvects,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+
+        # Draw the uncalibrated g-vectors
+
+        # draw the g vectors
+        ax2.arrow(
+            x=origin_uncal[1],
+            y=origin_uncal[0],
+            dx=g1_uncal[1],
+            dy=g1_uncal[0],
+            color=color_gvects,
+            length_includes_head=True,
+            width=coordax_width * 0.5,
+            head_width=coordax_width * 2.5,
+        )
+        ax2.arrow(
+            x=origin_uncal[1],
+            y=origin_uncal[0],
+            dx=g2_uncal[1],
+            dy=g2_uncal[0],
+            color=color_gvects,
+            length_includes_head=True,
+            width=coordax_width * 0.5,
+            head_width=coordax_width * 2.5,
+        )
+        ax2.text(
+            x=origin_uncal[1] + g1_uncal[1] * 1.16,
+            y=origin_uncal[0] + g1_uncal[0] * 1.16,
+            s=r"$g_1$",
+            fontsize=size_labels * 0.88,
+            color=color_gvects,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+        ax2.text(
+            x=origin_uncal[1] + g2_uncal[1] * 1.16,
+            y=origin_uncal[0] + g2_uncal[0] * 1.16,
+            s=r"$g_2$",
+            fontsize=size_labels * 0.88,
+            color=color_gvects,
+            horizontalalignment="center",
+            verticalalignment="center",
+        )
+
+        # show/return
+        if not returnfig:
+            plt.show()
+            return
+        else:
+            return fig, (ax1, ax2)
+
+    def show_lattice_vectors(
+        self,
+        ar,
+        x0,
+        y0,
+        g1,
+        g2,
+        color="r",
+        width=1,
+        labelsize=20,
+        labelcolor="w",
+        returnfig=False,
+        **kwargs,
+    ):
+        """
+        Adds the vectors g1,g2 to an
image, with tail positions at (x0,y0).
+        g1 and g2 are 2-tuples (gx,gy).
+        """
+        fig, ax = show(ar, returnfig=True, **kwargs)
+
+        # Add vectors
+        dg1 = {
+            "x0": x0,
+            "y0": y0,
+            "vx": g1[0],
+            "vy": g1[1],
+            "width": width,
+            "color": color,
+            "label": r"$g_1$",
+            "labelsize": labelsize,
+            "labelcolor": labelcolor,
+        }
+        dg2 = {
+            "x0": x0,
+            "y0": y0,
+            "vx": g2[0],
+            "vy": g2[1],
+            "width": width,
+            "color": color,
+            "label": r"$g_2$",
+            "labelsize": labelsize,
+            "labelcolor": labelcolor,
+        }
+        add_vector(ax, dg1)
+        add_vector(ax, dg2)
+
+        if returnfig:
+            return fig, ax
+        else:
+            plt.show()
+            return
+
+    def show_bragg_indexing(
+        self,
+        ar,
+        bragg_directions,
+        voffset=5,
+        hoffset=0,
+        color="w",
+        size=20,
+        points=True,
+        pointcolor="r",
+        pointsize=50,
+        figax=None,
+        returnfig=False,
+        **kwargs,
+    ):
+        """
+        Shows an array with an overlay describing the Bragg directions
+
+        Parameters
+        ----------
+        ar : np.ndarray
+            The display image
+        bragg_directions : PointList
+            The Bragg scattering directions. Must have coordinates
+            'qx','qy','h', and 'k'. Optionally may also have 'l'.
+        """
+        assert isinstance(bragg_directions, PointList)
+        for k in ("qx", "qy", "h", "k"):
+            assert k in bragg_directions.data.dtype.fields
+
+        if figax is None:
+            fig, ax = show(ar, returnfig=True, **kwargs)
+        else:
+            fig = figax[0]
+            ax = figax[1]
+            show(ar, figax=figax, **kwargs)
+
+        d = {
+            "bragg_directions": bragg_directions,
+            "voffset": voffset,
+            "hoffset": hoffset,
+            "color": color,
+            "size": size,
+            "points": points,
+            "pointsize": pointsize,
+            "pointcolor": pointcolor,
+        }
+        add_bragg_index_labels(ax, d)
+
+        if returnfig:
+            return fig, ax
+        else:
+            return
+
+    def copy(self, name=None):
+        name = name if name is not None else self.name + "_copy"
+        strainmap_copy = StrainMap(self.braggvectors)
+        for attr in (
+            "g",
+            "g0",
+            "g1",
+            "g2",
+            "calstate",
+            "braggdirections",
+            "bragg_vectors_indexed",
+            "g1g2_map",
+            "strainmap_g1g2",
+            "strainmap_rotated",
+            "mask",
+        ):
+            if hasattr(self, attr):
+                setattr(strainmap_copy, attr, getattr(self, attr))
+
+        for k in self.metadata.keys():
+            strainmap_copy.metadata = self.metadata[k].copy()
+        return strainmap_copy
+
+    # TODO IO methods
+
+    # read
+    @classmethod
+    def _get_constructor_args(cls, group):
+        """
+        Returns a dictionary of args/values to pass to the class constructor
+        """
+        ar_constr_args = RealSlice._get_constructor_args(group)
+        args = {
+            "data": ar_constr_args["data"],
+            "name": ar_constr_args["name"],
+        }
+        return args
diff --git a/py4DSTEM/process/utils/cross_correlate.py b/py4DSTEM/process/utils/cross_correlate.py
index 50de91e33..df3612658 100644
--- a/py4DSTEM/process/utils/cross_correlate.py
+++ b/py4DSTEM/process/utils/cross_correlate.py
@@ -6,7 +6,7 @@
 try:
     import cupy as cp
-except ModuleNotFoundError:
+except (ModuleNotFoundError, ImportError):
     cp = np
diff --git a/py4DSTEM/process/utils/multicorr.py b/py4DSTEM/process/utils/multicorr.py
index bc07390bb..58c5fc051 100644
--- a/py4DSTEM/process/utils/multicorr.py
+++ b/py4DSTEM/process/utils/multicorr.py
@@ -15,7 +15,7 @@
 try:
     import cupy as cp
-except ModuleNotFoundError:
+except (ModuleNotFoundError, ImportError):
     cp = np
diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py
index 4ef2e1d8a..e2bf4307c 100644
--- a/py4DSTEM/process/utils/utils.py
+++ b/py4DSTEM/process/utils/utils.py
@@ -24,7 +24,7 @@ def clear_output(wait=True):
 try:
     import cupy as cp
-except ModuleNotFoundError:
+except (ModuleNotFoundError, ImportError):
     cp = np
diff --git
a/py4DSTEM/process/wholepatternfit/wp_models.py b/py4DSTEM/process/wholepatternfit/wp_models.py index 3d53c1743..c0907a11e 100644 --- a/py4DSTEM/process/wholepatternfit/wp_models.py +++ b/py4DSTEM/process/wholepatternfit/wp_models.py @@ -1117,7 +1117,7 @@ def __init__( name="Complex Overlapped Disk Lattice", verbose=False, ): - return NotImplementedError( + raise NotImplementedError( "This model type has not been updated for use with the new architecture." ) diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index 26b0b89d5..904dceb29 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -190,7 +190,7 @@ def get_module_states(state_dict: dict) -> dict: # check that all the depencies could be imported i.e. state == True # and set the state of the module to that - module_states[key] = all(temp_lst) == True + module_states[key] = all(temp_lst) is True return module_states @@ -338,7 +338,7 @@ def check_module_functionality(state_dict: dict) -> None: # check that all the depencies could be imported i.e. state == True # and set the state of the module to that - module_states[key] = all(temp_lst) == True + module_states[key] = all(temp_lst) is True # Print out the state of all the modules in colour code for key, val in module_states.items(): @@ -375,12 +375,12 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs): # check that CUDA is detected correctly cuda_availability = cp.cuda.is_available() if cuda_availability: - s = f" CUDA is Available " + s = " CUDA is Available " s = create_success(s) s = f"{s: <80}" print(s) else: - s = f" CUDA is Unavailable " + s = " CUDA is Unavailable " s = create_failure(s) s = f"{s: <80}" print(s) diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 224f1fb74..103751b29 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1 +1 @@ -__version__ = "0.14.4" +__version__ = "0.14.9" diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index 996bb89b3..32baff443 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -408,7 +408,7 @@ def add_ellipses(ax, d): (cent[1], cent[0]), 2 * _b, 2 * _a, - -np.degrees(_theta), + angle=-np.degrees(_theta), color=col, fill=f, alpha=_alpha, @@ -832,7 +832,14 @@ def add_scalebar(ax, d): labelpos_y = y0 # Add line - ax.plot((yi, yf), (xi, xf), lw=width, color=color, alpha=alpha) + ax.plot( + (yi, yf), + (xi, xf), + color=color, + alpha=alpha, + lw=width, + solid_capstyle="butt", + ) # Add label if label: diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index fb99de5ae..8462eec7d 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -8,6 +8,7 @@ from matplotlib.axes import Axes from matplotlib.colors import is_color_like from matplotlib.figure import Figure +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM.data import Calibration, DiffractionSlice, RealSlice from py4DSTEM.visualize.overlay import ( add_annuli, @@ -25,7 +26,7 @@ def show( ar, - figsize=(8, 8), + figsize=(5, 5), cmap="gray", scaling="none", intensity_range="ordered", @@ -75,7 +76,8 @@ def show( theta=None, title=None, show_fft=False, - **kwargs + show_cbar=False, + **kwargs, ): """ General visualization function for 2D arrays. @@ -302,14 +304,15 @@ def show( does not add a scalebar. If a dict is passed, it is propagated to the add_scalebar function which will attempt to use it to overlay a scalebar. 
If True, uses calibraiton or pixelsize/pixelunits for scalebar. If False, no scalebar is added. - show_fft (Bool): if True, plots 2D-fft of array + show_fft (bool): if True, plots 2D-fft of array + show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() Returns: if returnfig==False (default), the figure is plotted and nothing is returned. if returnfig==True, return the figure and the axis. """ - if scalebar == True: + if scalebar is True: scalebar = {} # Alias dep @@ -366,7 +369,9 @@ def show( from py4DSTEM.visualize import show if show_fft: - ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy()))) + n0 = ar.shape + w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None] + ar = np.abs(np.fft.fftshift(np.fft.fft2(w0 * ar.copy()))) for a0 in range(num_images): im = show( ar[a0], @@ -410,7 +415,7 @@ def show( if ( hasattr(ar, "calibration") and (ar.calibration is not None) - and (scalebar != False) + and (scalebar is not False) ): cal = ar.calibration er = ".calibration attribute must be a Calibration instance" @@ -488,12 +493,12 @@ def show( if np.all(np.isnan(_ar)): _ar[:, :] = 0 if intensity_range == "absolute": - if vmin != None: + if vmin is not None: if vmin > 0.0: vmin = np.log(vmin) else: vmin = np.min(_ar[_mask]) - if vmax != None: + if vmax is not None: vmax = np.log(vmax) elif scaling == "power": if power_offset is False: @@ -509,9 +514,9 @@ def show( _ar = np.power(ar.copy(), power) _mask = np.ones_like(_ar.data, dtype=bool) if intensity_range == "absolute": - if vmin != None: + if vmin is not None: vmin = np.power(vmin, power) - if vmax != None: + if vmax is not None: vmax = np.power(vmax, power) else: raise Exception @@ -605,6 +610,10 @@ def show( ax.matshow( mask_display, cmap=cmap, alpha=mask_alpha, vmin=vmin, vmax=vmax ) + if show_cbar: + ax_divider = make_axes_locatable(ax) + c_axis = ax_divider.append_axes("right", size="7%") + fig.colorbar(cax, cax=c_axis) # ...or, plot its histogram else: hist, bin_edges = np.histogram( @@ -895,7 +904,7 @@ def show_Q( gridlabelsize=12, gridlabelcolor="k", alpha=0.35, - **kwargs + **kwargs, ): """ Shows a diffraction space image with options for several overlays to define the scale, @@ -1135,7 +1144,7 @@ def show_rectangles( alpha=0.25, linewidth=2, returnfig=False, - **kwargs + **kwargs, ): """ Visualization function which plots a 2D array with one or more overlayed rectangles. @@ -1188,7 +1197,7 @@ def show_circles( alpha=0.3, linewidth=2, returnfig=False, - **kwargs + **kwargs, ): """ Visualization function which plots a 2D array with one or more overlayed circles. @@ -1243,7 +1252,7 @@ def show_ellipses( alpha=0.3, linewidth=2, returnfig=False, - **kwargs + **kwargs, ): """ Visualization function which plots a 2D array with one or more overlayed ellipses. @@ -1299,7 +1308,7 @@ def show_annuli( alpha=0.3, linewidth=2, returnfig=False, - **kwargs + **kwargs, ): """ Visualization function which plots a 2D array with one or more overlayed annuli. @@ -1351,7 +1360,7 @@ def show_points( open_circles=False, title=None, returnfig=False, - **kwargs + **kwargs, ): """ Plots a 2D array with one or more points. 
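
The `show_fft` path in `show()` now applies a separable Hann window before taking the 2D FFT, which suppresses the streaking that image-edge discontinuities produce in the displayed power spectrum under the FFT's periodic boundary conditions. A minimal standalone sketch of the same computation; the array `ar` below is an illustrative stand-in for the image passed to `show`:

    import numpy as np

    ar = np.random.rand(64, 48)  # stand-in image
    # separable Hann window matching the image shape, built by broadcasting
    w = np.hanning(ar.shape[1]) * np.hanning(ar.shape[0])[:, None]
    # centered magnitude of the windowed FFT, as displayed by show(..., show_fft=True)
    ar_fft = np.abs(np.fft.fftshift(np.fft.fft2(w * ar)))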
diff --git a/py4DSTEM/visualize/vis_RQ.py b/py4DSTEM/visualize/vis_RQ.py index 6c2fbff3c..85c0eb042 100644 --- a/py4DSTEM/visualize/vis_RQ.py +++ b/py4DSTEM/visualize/vis_RQ.py @@ -15,7 +15,7 @@ def show_selected_dp( pointsize=50, pointcolor="r", scaling="log", - **kwargs + **kwargs, ): """ """ dp = datacube.data[rx, ry, :, :] diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index 89f09606a..125a2ce67 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -1,6 +1,5 @@ from matplotlib import cm, colors as mcolors, pyplot as plt import numpy as np -from matplotlib.colors import hsv_to_rgb from matplotlib.patches import Wedge from mpl_toolkits.axes_grid1 import make_axes_locatable from scipy.spatial import Voronoi @@ -17,6 +16,7 @@ ) from py4DSTEM.visualize.vis_grid import show_image_grid from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR +from colorspacious import cspace_convert def show_elliptical_fit( @@ -31,7 +31,7 @@ def show_elliptical_fit( linewidth_ann=2, linewidth_ell=2, returnfig=False, - **kwargs + **kwargs, ): """ Plots an elliptical curve over its annular fit region. @@ -97,7 +97,7 @@ def show_amorphous_ring_fit( ellipse_alpha=0.7, ellipse_lw=2, returnfig=False, - **kwargs + **kwargs, ): """ Display a diffraction pattern with a fit to its amorphous ring, interleaving @@ -156,16 +156,16 @@ def show_amorphous_ring_fit( mask=np.logical_not(mask), mask_color="empty", returnfig=True, - returnclipvals=True, + return_intensity_range=True, **kwargs, ) show( fit, scaling=scaling, figax=(fig, ax), - clipvals="manual", - min=vmin, - max=vmax, + intensity_range="absolute", + vmin=vmin, + vmax=vmax, cmap=cmap_fit, mask=mask, mask_color="empty", @@ -225,7 +225,7 @@ def show_qprofile( ticklabelsize=14, grid=True, label=None, - **kwargs + **kwargs, ): """ Plots a diffraction space radial profile. @@ -302,7 +302,7 @@ def show_voronoi( color_lines="w", max_dist=None, returnfig=False, - **kwargs + **kwargs, ): """ words @@ -375,7 +375,7 @@ def show_class_BPs_grid( axsize=(6, 6), titlesize=0, get_bordercolor=None, - **kwargs + **kwargs, ): """ words @@ -404,308 +404,6 @@ def show_class_BPs_grid( return fig, axs -def show_strain( - strainmap, - vrange_exx, - vrange_theta, - vrange_exy=None, - vrange_eyy=None, - flip_theta=False, - bkgrd=True, - show_cbars=("exx", "eyy", "exy", "theta"), - bordercolor="k", - borderwidth=1, - titlesize=24, - ticklabelsize=16, - ticknumber=5, - unitlabelsize=24, - show_axes=True, - axes_x0=0, - axes_y0=0, - xaxis_x=1, - xaxis_y=0, - axes_length=10, - axes_width=1, - axes_color="r", - xaxis_space="Q", - labelaxes=True, - QR_rotation=0, - axes_labelsize=12, - axes_labelcolor="r", - axes_plots=("exx"), - cmap="RdBu_r", - layout=0, - figsize=(12, 12), - returnfig=False, -): - """ - Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') - - Args: - strainmap (RealSlice): - vrange_exx (length 2 list or tuple): - vrange_theta (length 2 list or tuple): - vrange_exy (length 2 list or tuple): - vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle - bkgrd (bool): - show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a - tuple containing any, all, or none of ('exx','eyy','exy','theta'). 
- bordercolor (color): - borderwidth (number): - titlesize (number): - ticklabelsize (number): - ticknumber (number): number of ticks on colorbars - unitlabelsize (number): - show_axes (bool): - axes_x0 (number): - axes_y0 (number): - xaxis_x (number): - xaxis_y (number): - axes_length (number): - axes_width (number): - axes_color (color): - xaxis_space (string): must be 'Q' or 'R' - labelaxes (bool): - QR_rotation (number): - axes_labelsize (number): - axes_labelcolor (color): - axes_plots (tuple of strings): controls if coordinate axes showing the - orientation of the strain matrices are overlaid over any of the plots. - Must be a tuple of strings containing any, all, or none of - ('exx','eyy','exy','theta'). - cmap (colormap): - layout=0 (int): determines the layout of the grid which the strain components - will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). - figsize (length 2 tuple of numbers): - returnfig (bool): - """ - # Lookup table for different layouts - assert layout in (0, 1, 2) - layout_lookup = { - 0: ["left", "right", "left", "right"], - 1: ["bottom", "bottom", "bottom", "bottom"], - 2: ["right", "right", "right", "right"], - } - layout_p = layout_lookup[layout] - - # Contrast limits - if vrange_exy is None: - vrange_exy = vrange_exx - if vrange_eyy is None: - vrange_eyy = vrange_exx - for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): - assert len(vrange) == 2, "vranges must have length 2" - vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 - vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 - vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 - # theta is plotted in units of degrees - vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( - 180.0 / np.pi - ) - - # Get images - e_xx = np.ma.array( - strainmap.get_slice("e_xx").data, mask=strainmap.get_slice("mask").data == False - ) - e_yy = np.ma.array( - strainmap.get_slice("e_yy").data, mask=strainmap.get_slice("mask").data == False - ) - e_xy = np.ma.array( - strainmap.get_slice("e_xy").data, mask=strainmap.get_slice("mask").data == False - ) - theta = np.ma.array( - strainmap.get_slice("theta").data, - mask=strainmap.get_slice("mask").data == False, - ) - if flip_theta == True: - theta = -theta - - # Plot - if layout == 0: - fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) - elif layout == 1: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) - else: - fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) - cax11 = show( - e_xx, - figax=(fig, ax11), - vmin=vmin_exx, - vmax=vmax_exx, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax12 = show( - e_yy, - figax=(fig, ax12), - vmin=vmin_eyy, - vmax=vmax_eyy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax21 = show( - e_xy, - figax=(fig, ax21), - vmin=vmin_exy, - vmax=vmax_exy, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - cax22 = show( - theta, - figax=(fig, ax22), - vmin=vmin_theta, - vmax=vmax_theta, - intensity_range="absolute", - cmap=cmap, - returncax=True, - ) - ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) - ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) - ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) - ax22.set_title(r"$\theta$", size=titlesize) - - # Add black background - if bkgrd: - mask = np.ma.masked_where( - strainmap.get_slice("mask").data.astype(bool), - np.zeros_like(strainmap.get_slice("mask").data), - ) - 
ax11.matshow(mask, cmap="gray") - ax12.matshow(mask, cmap="gray") - ax21.matshow(mask, cmap="gray") - ax22.matshow(mask, cmap="gray") - - # Colorbars - show_cbars = np.array( - [ - "exx" in show_cbars, - "eyy" in show_cbars, - "exy" in show_cbars, - "theta" in show_cbars, - ] - ) - if np.any(show_cbars): - divider11 = make_axes_locatable(ax11) - divider12 = make_axes_locatable(ax12) - divider21 = make_axes_locatable(ax21) - divider22 = make_axes_locatable(ax22) - cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) - cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) - cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) - cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) - for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( - range(4), - show_cbars, - (cax11, cax12, cax21, cax22), - (cbax11, cbax12, cbax21, cbax22), - (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), - (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), - (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), - ("% ", " %", "% ", r" $^\circ$"), - ): - if show_cbar: - ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) - if ind < 3: - ticklabels = np.round( - np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), - decimals=2, - ).astype(str) - else: - ticklabels = np.round( - np.linspace( - (180 / np.pi) * vmin, - (180 / np.pi) * vmax, - ticknumber, - endpoint=True, - ), - decimals=2, - ).astype(str) - - if tickside in ("left", "right"): - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="vertical" - ) - cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) - cbax.yaxis.set_ticks_position(tickside) - cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) - cbax.yaxis.set_label_position(tickside) - else: - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="horizontal" - ) - cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) - cbax.xaxis.set_ticks_position(tickside) - cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) - cbax.xaxis.set_label_position(tickside) - else: - cbax.axis("off") - - # Add coordinate axes - if show_axes: - assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array( - [ - "exx" in axes_plots, - "eyy" in axes_plots, - "exy" in axes_plots, - "theta" in axes_plots, - ] - ) - for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): - if _show: - if xaxis_space == "R": - ax_addaxes( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - else: - ax_addaxes_QtoR( - _ax, - xaxis_x, - xaxis_y, - axes_length, - axes_x0, - axes_y0, - QR_rotation, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - - # Add borders - if bordercolor is not None: - for ax in (ax11, ax12, ax21, ax22): - for s in ["bottom", "top", "left", "right"]: - ax.spines[s].set_color(bordercolor) - ax.spines[s].set_linewidth(borderwidth) - ax.set_xticks([]) - ax.set_yticks([]) - - if not returnfig: - plt.show() - return - else: - axs = ((ax11, ax12), (ax21, ax22)) - return fig, axs - - def show_pointlabels( ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs ): @@ -732,7 +430,7 @@ def select_point( color_selected="r", size=20, returnfig=False, - **kwargs + **kwargs, ): """ Show enumerated index labels for a set of points, with one selected point highlighted @@ 
-857,7 +555,7 @@ def show_selected_dps( HW=None, figsize_im=(6, 6), figsize_dp=(4, 4), - **kwargs + **kwargs, ): """ Shows two plots: first, a real space image overlaid with colored dots @@ -937,15 +635,20 @@ def show_selected_dps( ) -def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): +def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1): """ complex_data (array): complex array to plot vmin (float) : minimum absolute value vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + power (float) : power to raise amplitude to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.5) """ amp = np.abs(complex_data) + phase = np.angle(complex_data) + + if power is not None: + amp = amp**power + if np.isclose(np.max(amp), np.min(amp)): if vmin is None: vmin = 0 @@ -966,36 +669,40 @@ def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False): amp = np.where(amp < vmin, vmin, amp) amp = np.where(amp > vmax, vmax, amp) + amp = ((amp - vmin) / vmax).clip(1e-16, 1) - phase = np.angle(complex_data) + np.deg2rad(hue_start) - amp /= np.max(amp) - rgb = np.zeros(phase.shape + (3,)) - rgb[..., 0] = 0.5 * (np.sin(phase) + 1) * amp - rgb[..., 1] = 0.5 * (np.sin(phase + np.pi / 2) + 1) * amp - rgb[..., 2] = 0.5 * (-np.sin(phase) + 1) * amp + J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff + C = np.minimum(chroma_boost * 98 * J / 123, 110) + h = np.rad2deg(phase) + 180 - return 1 - rgb if invert else rgb + JCh = np.stack((J, C, h), axis=-1) + rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) + return rgb -def add_colorbar_arg(cax, vmin=None, vmax=None, hue_start=0, invert=False): + +def add_colorbar_arg(cax, chroma_boost=1, c=49, j=61.5): """ - cax : axis to add cbar too - vmin (float) : minimum absolute value - vmax (float) : maximum absolute value - hue_start (float) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme + cax : axis to add cbar to + chroma_boost (float): boosts chroma for higher-contrast (~1-2.25) + c (float) : constant chroma value + j (float) : constant luminance value """ - z = np.exp(1j * np.linspace(-np.pi, np.pi, 200)) - rgb_vals = Complex2RGB(z, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert) + + h = np.linspace(0, 360, 256, endpoint=False) + J = np.full_like(h, j) + C = np.full_like(h, np.minimum(c * chroma_boost, 110)) + JCh = np.stack((J, C, h), axis=-1) + rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1) newcmp = mcolors.ListedColormap(rgb_vals) norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi) - cb1 = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax) - cb1.set_label("arg", rotation=0, ha="center", va="bottom") - cb1.ax.yaxis.set_label_coords(0.5, 1.01) - cb1.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) - cb1.set_ticklabels( + cb.set_label("arg", rotation=0, ha="center", va="bottom") + cb.ax.yaxis.set_label_coords(0.5, 1.01) + cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])) + cb.set_ticklabels( [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"] ) @@ -1004,14 +711,14 @@ def show_complex( ar_complex, vmin=None, vmax=None, + power=None, + chroma_boost=1, cbar=True, scalebar=False, pixelunits="pixels", pixelsize=1, returnfig=False, - hue_start=0, - invert=False, - **kwargs + 
**kwargs, ): """ Function to plot complex arrays @@ -1023,13 +730,13 @@ def show_complex( vmax (float, optional) : maximum absolute value if None, vmin/vmax are set to fractions of the distribution of pixel values in the array, e.g. vmin=0.02 will set the minumum display value to saturate the lower 2% of pixels - cbar (bool, optional) : if True, include color wheel + power (float,optional) : power to raise amplitude to + chroma_boost (float) : boosts chroma for higher-contrast (~1-2.25) + cbar (bool, optional) : if True, include color bar scalebar (bool, optional) : if True, adds scale bar pixelunits (str, optional) : units for scalebar pixelsize (float, optional) : size of one pixel in pixelunits for scalebar returnfig (bool, optional) : if True, the function returns the tuple (figure,axis) - hue_start (float, optional) : rotational offset for colormap (degrees) - inverse (bool) : if True, uses light color scheme Returns: if returnfig==False (default), the figure is plotted and nothing is returned. @@ -1044,7 +751,7 @@ def show_complex( if isinstance(ar_complex, list): if isinstance(ar_complex[0], list): rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for sublist in ar_complex for ar in sublist ] @@ -1053,7 +760,7 @@ def show_complex( else: rgb = [ - Complex2RGB(ar, vmin, vmax, hue_start=hue_start, invert=invert) + Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost) for ar in ar_complex ] if len(rgb[0].shape) == 4: @@ -1064,7 +771,9 @@ def show_complex( W = len(ar_complex) is_grid = True else: - rgb = Complex2RGB(ar_complex, vmin, vmax, hue_start=hue_start, invert=invert) + rgb = Complex2RGB( + ar_complex, vmin, vmax, power=power, chroma_boost=chroma_boost + ) if len(rgb.shape) == 4: is_grid = True H = 1 @@ -1115,37 +824,534 @@ def show_complex( add_scalebar(ax, scalebar) # add color bar - if cbar == True: - ax0 = fig.add_axes([1, 0.35, 0.3, 0.3]) - - # create wheel - AA = 1000 - kx = np.fft.fftshift(np.fft.fftfreq(AA)) - ky = np.fft.fftshift(np.fft.fftfreq(AA)) - kya, kxa = np.meshgrid(ky, kx) - kra = (kya**2 + kxa**2) ** 0.5 - ktheta = np.arctan2(-kxa, kya) - ktheta = kra * np.exp(1j * ktheta) - - # convert to hsv - rgb = Complex2RGB(ktheta, 0, 0.4, hue_start=hue_start, invert=invert) - ind = kra > 0.4 - rgb[ind] = [1, 1, 1] - - # plot - ax0.imshow(rgb) - - # add axes - ax0.axhline(AA / 2, 0, AA, color="k") - ax0.axvline(AA / 2, 0, AA, color="k") - ax0.axis("off") - - label_size = 16 - - ax0.text(AA, AA / 2, 1, fontsize=label_size) - ax0.text(AA / 2, 0, "i", fontsize=label_size) - ax0.text(AA / 2, AA, "-i", fontsize=label_size) - ax0.text(0, AA / 2, -1, fontsize=label_size) - - if returnfig == True: + if cbar: + if is_grid: + for ax_flat in ax.flatten(): + divider = make_axes_locatable(ax_flat) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + else: + divider = make_axes_locatable(ax) + ax_cb = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg(ax_cb, chroma_boost=chroma_boost) + + fig.tight_layout() + + if returnfig: return fig, ax + + +def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False): + """ + Utility function for calculating min and max values for plotting array + based on distribution of pixel values + + Parameters + ---------- + array: np.array + array to be plotted + vmin: float + lower fraction cut off of pixel values + vmax: float + upper fraction cut 
off of pixel values
+    normalize: bool
+        if True, rescales from 0 to 1
+
+    Returns
+    -------
+    scaled_array: np.array
+        array clipped outside vmin and vmax
+    vmin: float
+        lower value to be plotted
+    vmax: float
+        upper value to be plotted
+    """
+
+    if vmin is None:
+        vmin = 0.02
+    if vmax is None:
+        vmax = 0.98
+
+    vals = np.sort(array.ravel())
+    ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
+    ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
+    ind_vmin = np.max([0, ind_vmin])
+    ind_vmax = np.min([len(vals) - 1, ind_vmax])
+    vmin = vals[ind_vmin]
+    vmax = vals[ind_vmax]
+
+    if vmax == vmin:
+        vmin = vals[0]
+        vmax = vals[-1]
+
+    scaled_array = array.copy()
+    scaled_array = np.where(scaled_array < vmin, vmin, scaled_array)
+    scaled_array = np.where(scaled_array > vmax, vmax, scaled_array)
+
+    if normalize:
+        scaled_array -= scaled_array.min()
+        scaled_array /= scaled_array.max()
+        vmin = 0
+        vmax = 1
+
+    return scaled_array, vmin, vmax
+
+
+def show_strain(
+    data,
+    vrange=[-3, 3],
+    vrange_theta=[-3, 3],
+    vrange_exx=None,
+    vrange_exy=None,
+    vrange_eyy=None,
+    show_cbars=None,
+    bordercolor="k",
+    borderwidth=1,
+    titlesize=18,
+    ticklabelsize=10,
+    ticknumber=5,
+    unitlabelsize=16,
+    cmap="RdBu_r",
+    cmap_theta="PRGn",
+    mask_color="k",
+    color_axes="k",
+    show_legend=False,
+    rotation_deg=0,
+    show_gvects=True,
+    g1=None,
+    g2=None,
+    color_gvects="r",
+    legend_camera_length=1.6,
+    scale_gvects=0.6,
+    layout="square",
+    figsize=None,
+    returnfig=False,
+):
+    """
+    Display a strain map, showing the 4 strain components
+    (e_xx,e_yy,e_xy,theta), and masking each image with
+    strainmap.get_slice('mask')
+
+    Parameters
+    ----------
+    data : StrainMap
+        The strain map to display
+    vrange : length 2 list or tuple
+        The colorbar intensity range for exx,eyy, and exy.
+    vrange_theta : length 2 list or tuple
+        The colorbar intensity range for theta.
+    vrange_exx : length 2 list or tuple
+        The colorbar intensity range for exx; overrides `vrange`
+        for exx
+    vrange_exy : length 2 list or tuple
+        The colorbar intensity range for exy; overrides `vrange`
+        for exy
+    vrange_eyy : length 2 list or tuple
+        The colorbar intensity range for eyy; overrides `vrange`
+        for eyy
+    show_cbars : None or a tuple of strings
+        Show colorbars for the specified axes. Valid strings are
+        'exx', 'eyy', 'exy', and 'theta'.
+    bordercolor : color
+        Color for the image borders
+    borderwidth : number
+        Width of the image borders
+    titlesize : number
+        Size of the image titles
+    ticklabelsize : number
+        Size of the colorbar ticks
+    ticknumber : number
+        Number of ticks on colorbars
+    unitlabelsize : number
+        Size of the units label on the colorbars
+    cmap : colormap
+        Colormap for exx, exy, and eyy
+    cmap_theta : colormap
+        Colormap for theta
+    mask_color : color
+        Color for the background mask
+    color_axes : color
+        Color for the legend coordinate axes
+    show_legend : bool
+        Toggles displaying the legend panel
+    show_gvects : bool
+        Toggles displaying the g-vectors in the legend
+    rotation_deg : float
+        coordinate rotation for strainmap in degrees
+    g1 : tuple
+        g1 orientation (x,y)
+    g2 : tuple
+        g2 orientation (x,y)
+    color_gvects : color
+        Color for the legend g-vectors
+    legend_camera_length : number
+        The distance the legend is viewed from; a smaller number yields
+        a larger legend
+    scale_gvects : number
+        Scaling for the legend g-vectors relative to the coordinate axes
+    layout : str
+        Determines the layout of the grid which the strain components
+        will be plotted in. Must be in ("square", "horizontal", "vertical").
+ figsize : length 2 tuple of numbers + Size of the figure + returnfig : bool + Toggles returning the figure + """ + # Lookup table for different layouts + assert layout in ("square", "horizontal", "vertical") + layout_lookup = { + "square": ["left", "right", "left", "right"], + "horizontal": ["bottom", "bottom", "bottom", "bottom"], + "vertical": ["right", "right", "right", "right"], + } + + layout_p = layout_lookup[layout] + + # Set which colorbars to display + if show_cbars is None: + if np.all( + [ + v is None + for v in ( + vrange_exx, + vrange_eyy, + vrange_exy, + ) + ] + ): + show_cbars = ("eyy", "theta") + else: + show_cbars = ("exx", "eyy", "exy", "theta") + else: + assert np.all([v in ("exx", "eyy", "exy", "theta") for v in show_cbars]) + + # Contrast limits + if vrange_exx is None: + vrange_exx = vrange + if vrange_exy is None: + vrange_exy = vrange + if vrange_eyy is None: + vrange_eyy = vrange + for vr in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): + assert len(vr) == 2, "vranges must have length 2" + vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 + vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 + vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 + # theta is plotted in units of degrees + vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( + 180.0 / np.pi + ) + + # Get images; mask out pixels where the strain map's mask is False + mask = data.get_slice("mask").data == False # noqa: E712,E501 + e_xx = np.ma.array(data.get_slice("e_xx").data, mask=mask) + e_yy = np.ma.array(data.get_slice("e_yy").data, mask=mask) + e_xy = np.ma.array(data.get_slice("e_xy").data, mask=mask) + theta = np.ma.array(data.get_slice("theta").data, mask=mask) + + # Plot + + # if figsize hasn't been set, set it based on the + # chosen layout and the image shape + if figsize is None: + ratio = np.sqrt(e_xx.shape[1] / e_xx.shape[0]) + if layout == "square": + figsize = (13 * ratio, 8 / ratio) + elif layout == "horizontal": + figsize = (10 * ratio, 4 / ratio) + else: + figsize = (4 * ratio, 10 / ratio) + + # set up layout + if show_legend: + if layout == "square": + fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots( + 2, 3, figsize=figsize + ) + elif layout == "horizontal": + figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 1, 5, figsize=figsize + ) + else: + figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 5, 1, figsize=figsize + ) + else: + if layout == "square": + fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) + elif layout == "horizontal": + figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) + else: + figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) + + # display images, returning cbar axis references + cax11 = show( + e_xx, + figax=(fig, ax11), + vmin=vmin_exx, + vmax=vmax_exx, + intensity_range="absolute", + cmap=cmap, + mask_color=mask_color, + returncax=True, + ) + cax12 = show( + e_yy, + figax=(fig, ax12), + vmin=vmin_eyy, + vmax=vmax_eyy, + intensity_range="absolute", + cmap=cmap, + mask_color=mask_color, + returncax=True, + ) + cax21 = show(
e_xy, + figax=(fig, ax21), + vmin=vmin_exy, + vmax=vmax_exy, + intensity_range="absolute", + cmap=cmap, + mask_color=mask_color, + returncax=True, + ) + cax22 = show( + theta, + figax=(fig, ax22), + vmin=vmin_theta, + vmax=vmax_theta, + intensity_range="absolute", + cmap=cmap_theta, + mask_color=mask_color, + returncax=True, + ) + ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) + ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) + ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) + ax22.set_title(r"$\theta$", size=titlesize) + + # add colorbars + show_cbars = np.array( + [ + "exx" in show_cbars, + "eyy" in show_cbars, + "exy" in show_cbars, + "theta" in show_cbars, + ] + ) + if np.any(show_cbars): + divider11 = make_axes_locatable(ax11) + divider12 = make_axes_locatable(ax12) + divider21 = make_axes_locatable(ax21) + divider22 = make_axes_locatable(ax22) + cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) + cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) + cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) + cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) + for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( + range(4), + show_cbars, + (cax11, cax12, cax21, cax22), + (cbax11, cbax12, cbax21, cbax22), + (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), + (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), + (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), + ("% ", " %", "% ", r" $^\circ$"), + ): + if show_cbar: + ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) + if ind < 3: + ticklabels = np.round( + np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), + decimals=2, + ).astype(str) + else: + ticklabels = np.round( + np.linspace( + (180 / np.pi) * vmin, + (180 / np.pi) * vmax, + ticknumber, + endpoint=True, + ), + decimals=2, + ).astype(str) + + if tickside in ("left", "right"): + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="vertical" + ) + cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) + cbax.yaxis.set_ticks_position(tickside) + cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) + cbax.yaxis.set_label_position(tickside) + else: + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="horizontal" + ) + cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) + cbax.xaxis.set_ticks_position(tickside) + cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) + cbax.xaxis.set_label_position(tickside) + else: + cbax.axis("off") + + # Add borders + if bordercolor is not None: + for ax in (ax11, ax12, ax21, ax22): + for s in ["bottom", "top", "left", "right"]: + ax.spines[s].set_color(bordercolor) + ax.spines[s].set_linewidth(borderwidth) + ax.set_xticks([]) + ax.set_yticks([]) + + # Legend + if show_legend: + # for layout "square", combine vertical plots on the right end + if layout == "square": + # get gridspec object + gs = ax_legend1.get_gridspec() + # remove last two axes + ax_legend1.remove() + ax_legend2.remove() + # make new axis + ax_legend = fig.add_subplot(gs[:, -1]) + + # get the coordinate axes' directions + rotation = np.deg2rad(rotation_deg) + xaxis_vectx = np.cos(rotation) + xaxis_vecty = np.sin(rotation) + yaxis_vectx = np.cos(rotation + np.pi / 2) + yaxis_vecty = np.sin(rotation + np.pi / 2) + + # make the coordinate axes + ax_legend.arrow( + x=0, + y=0, + dx=xaxis_vecty, + dy=xaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.arrow( + x=0, + y=0, + dx=yaxis_vecty, + dy=yaxis_vectx, + 
color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.text( + x=xaxis_vecty * 1.16, + y=xaxis_vectx * 1.16, + s="x", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=yaxis_vecty * 1.16, + y=yaxis_vectx * 1.16, + s="y", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # make the g-vectors + if show_gvects: + # g1 and g2 must be provided as (x,y) tuples when show_gvects is True + # get the g-vectors directions + g1q = np.array(g1) + g2q = np.array(g2) + g1norm = np.linalg.norm(g1q) + g2norm = np.linalg.norm(g2q) + g1q /= g1norm + g2q /= g2norm + # set the lengths: the longer g-vector is drawn at unit length, + # the shorter one scaled proportionally + g_ratio = g2norm / g1norm + if g_ratio > 1: + g1q /= g_ratio + else: + g2q *= g_ratio + g1_x, g1_y = g1q + g2_x, g2_y = g2q + + # draw the g-vectors + ax_legend.arrow( + x=0, + y=0, + dx=g1_y * scale_gvects, + dy=g1_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.arrow( + x=0, + y=0, + dx=g2_y * scale_gvects, + dy=g2_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.text( + x=g1_y * scale_gvects * 1.2, + y=g1_x * scale_gvects * 1.2, + s=r"$g_1$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=g2_y * scale_gvects * 1.2, + y=g2_x * scale_gvects * 1.2, + s=r"$g_2$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # find center and extent + xmin = np.min([0, 0, xaxis_vectx, yaxis_vectx]) + xmax = np.max([0, 0, xaxis_vectx, yaxis_vectx]) + ymin = np.min([0, 0, xaxis_vecty, yaxis_vecty]) + ymax = np.max([0, 0, xaxis_vecty, yaxis_vecty]) + if show_gvects: + xmin = np.min([xmin, g1_x, g2_x]) + xmax = np.max([xmax, g1_x, g2_x]) + ymin = np.min([ymin, g1_y, g2_y]) + ymax = np.max([ymax, g1_y, g2_y]) + x0 = np.mean([xmin, xmax]) + y0 = np.mean([ymin, ymax]) + xL = (xmax - x0) * legend_camera_length + yL = (ymax - y0) * legend_camera_length + + # set the extent and aspect + ax_legend.set_xlim([y0 - yL, y0 + yL]) + ax_legend.set_ylim([x0 - xL, x0 + xL]) + ax_legend.invert_yaxis() + ax_legend.set_aspect("equal") + ax_legend.axis("off") + + # show/return + if not returnfig: + plt.show() + return + else: + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs diff --git a/setup.py b/setup.py index c3cbbd151..631f23f9a 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ author_email="ben.savitzky@gmail.com", license="GNU GPLv3", keywords="STEM 4DSTEM", - python_requires=">=3.9,<3.12", + python_requires=">=3.9,<3.13", install_requires=[ "numpy >= 1.19", "scipy >= 1.5.2", @@ -37,9 +37,11 @@ "gdown >= 4.7.1", "dask >= 2.3.0", "distributed >= 2.3.0", - "emdfile >= 0.0.13", + "emdfile >= 0.0.14", "mpire >= 2.7.1", "threadpoolctl >= 3.1.0", + "pylops >= 2.1.0", + "colorspacious >= 1.1.2", ], extras_require={ "ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"],
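Usage sketch (illustrative only, not part of the patch): the snippet below exercises the updated show_complex signature, where power and chroma_boost replace the removed hue_start/invert arguments, and the new return_scaled_histogram_ordering helper. The py4DSTEM.visualize import path and the synthetic test array are assumptions made for the example.

import numpy as np
from py4DSTEM.visualize import show_complex, return_scaled_histogram_ordering  # import path assumed

# synthetic complex field: radial amplitude with an azimuthal phase vortex
yy, xx = np.mgrid[-64:64, -64:64]
ar = np.hypot(xx, yy) * np.exp(1j * np.arctan2(yy, xx))

# power < 1 compresses the displayed amplitude range;
# chroma_boost > 1 increases color saturation (typical values 1-2.25)
show_complex(ar, power=0.5, chroma_boost=1.5, cbar=True)

# clip a real-valued image to the 2%-98% window of its pixel distribution,
# returning the clipped array together with the display limits
scaled, vmin, vmax = return_scaled_histogram_ordering(np.abs(ar), vmin=0.02, vmax=0.98)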