diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py index a47a242c5..845ebabf6 100644 --- a/py4DSTEM/braggvectors/braggvector_methods.py +++ b/py4DSTEM/braggvectors/braggvector_methods.py @@ -1,12 +1,14 @@ # BraggVectors methods -import numpy as np -from scipy.ndimage import gaussian_filter -from warnings import warn import inspect +from warnings import warn -from emdfile import Array, Metadata, tqdmnd, _read_metadata +import matplotlib.pyplot as plt +import numpy as np +from emdfile import Array, Metadata, _read_metadata, tqdmnd +from py4DSTEM import show from py4DSTEM.datacube import VirtualImage +from scipy.ndimage import gaussian_filter class BraggVectorMethods: @@ -518,6 +520,7 @@ def fit_origin( mask_check_data=True, plot=True, plot_range=None, + cmap="RdBu_r", returncalc=True, **kwargs, ): @@ -537,6 +540,7 @@ def fit_origin( mask_check_data (bool): Get mask from origin measurements equal to zero. (TODO - replace) plot (bool, optional): plot results plot_range (float): min and max color range for plot (pixels) + cmap (colormap): plotting colormap Returns: (variable): Return value depends on returnfitp. If ``returnfitp==False`` @@ -567,75 +571,98 @@ def fit_origin( robust_thresh=robust_thresh, ) - # try to add to calibration + # try to add update calibration metadata try: - self.calibration.set_origin([qx0_fit, qy0_fit]) + self.calibration.set_origin((qx0_fit, qy0_fit)) + self.setcal() except AttributeError: warn( "No calibration found on this datacube - fit values are not being stored" ) pass - if plot: - from py4DSTEM.visualize import show_image_grid - if mask is None: - qx0_meas, qy0_meas = q_meas - qx0_res_plot = qx0_residuals - qy0_res_plot = qy0_residuals - else: - qx0_meas = np.ma.masked_array(q_meas[0], mask=np.logical_not(mask)) - qy0_meas = np.ma.masked_array(q_meas[1], mask=np.logical_not(mask)) - qx0_res_plot = np.ma.masked_array( - qx0_residuals, mask=np.logical_not(mask) - ) - qy0_res_plot = np.ma.masked_array( - qy0_residuals, mask=np.logical_not(mask) - ) - qx0_mean = np.mean(qx0_fit) - qy0_mean = np.mean(qy0_fit) - - if plot_range is None: - plot_range = 2 * np.max(qx0_fit - qx0_mean) - - cmap = kwargs.get("cmap", "RdBu_r") - kwargs.pop("cmap", None) - axsize = kwargs.get("axsize", (6, 2)) - kwargs.pop("axsize", None) - - show_image_grid( - lambda i: [ - qx0_meas - qx0_mean, - qx0_fit - qx0_mean, - qx0_res_plot, - qy0_meas - qy0_mean, - qy0_fit - qy0_mean, - qy0_res_plot, - ][i], - H=2, - W=3, + # show + if plot: + self.show_origin_fit( + q_meas[0], + q_meas[1], + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask=mask, + plot_range=plot_range, cmap=cmap, - axsize=axsize, - title=[ - "measured origin, x", - "fitorigin, x", - "residuals, x", - "measured origin, y", - "fitorigin, y", - "residuals, y", - ], - vmin=-1 * plot_range, - vmax=1 * plot_range, - intensity_range="absolute", **kwargs, ) - # update calibration metadata - self.calibration.set_origin((qx0_fit, qy0_fit)) - self.setcal() - + # return if returncalc: return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals + def show_origin_fit( + self, + qx0_meas, + qy0_meas, + qx0_fit, + qy0_fit, + qx0_residuals, + qy0_residuals, + mask=None, + plot_range=None, + cmap="RdBu_r", + **kwargs, + ): + # apply mask + if mask is not None: + qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask)) + qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask)) + qx0_residuals = np.ma.masked_array(qx0_residuals, mask=np.logical_not(mask)) + qy0_residuals = np.ma.masked_array(qy0_residuals, mask=np.logical_not(mask)) + qx0_mean = np.mean(qx0_fit) + qy0_mean = np.mean(qy0_fit) + + # set range + if plot_range is None: + plot_range = max( + ( + 1.5 * np.max(np.abs(qx0_fit - qx0_mean)), + 1.5 * np.max(np.abs(qy0_fit - qy0_mean)), + ) + ) + + # set figsize + imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0]) + axsize = (3 * imsize_ratio, 3 / imsize_ratio) + axsize = kwargs.pop("axsize", axsize) + + # plot + fig, ax = show( + [ + [qx0_meas - qx0_mean, qx0_fit - qx0_mean, qx0_residuals], + [qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals], + ], + cmap=cmap, + axsize=axsize, + title=[ + "measured origin, x", + "fitorigin, x", + "residuals, x", + "measured origin, y", + "fitorigin, y", + "residuals, y", + ], + vmin=-1 * plot_range, + vmax=1 * plot_range, + intensity_range="absolute", + show_cbar=True, + returnfig=True, + **kwargs, + ) + plt.tight_layout() + + return + def fit_p_ellipse( self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs ): @@ -771,6 +798,21 @@ def mask_in_R(self, mask, update_inplace=False, returncalc=True): else: return + def to_strainmap(self, name: str = None): + """ + Generate a StrainMap object from the BraggVectors + equivalent to py4DSTEM.StrainMap(braggvectors=braggvectors) + + Args: + name (str, optional): The name of the strainmap. Defaults to None which reverts to default name 'strainmap'. + + Returns: + py4DSTEM.StrainMap: A py4DSTEM StrainMap object generated from the BraggVectors + """ + from py4DSTEM.process.strain import StrainMap + + return StrainMap(self, name) if name else StrainMap(self) + ######### END BraggVectorMethods CLASS ######## diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py index 45c08b9c9..26f8eb8f4 100644 --- a/py4DSTEM/braggvectors/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -200,7 +200,7 @@ def setcal( if pixel is None: pixel = False if c.get_Q_pixel_size() == 1 else True if rotate is None: - rotate = False if c.get_QR_rotflip() is None else True + rotate = False if c.get_QR_rotation() is None else True # validate requested state if center: @@ -210,7 +210,7 @@ def setcal( if pixel: assert c.get_Q_pixel_size() is not None, "Requested calibration not found" if rotate: - assert c.get_QR_rotflip() is not None, "Requested calibration not found" + assert c.get_QR_rotation() is not None, "Requested calibration not found" # set the calibrations self._calstate = { @@ -478,15 +478,15 @@ def _transform( # Q/R rotation if rotate: - flip = cal.get_QR_flip() - theta = cal.get_QR_rotation_degrees() - assert flip is not None, "Requested calibration was not found!" + theta = cal.get_QR_rotation() assert theta is not None, "Requested calibration was not found!" + flip = cal.get_QR_flip() + flip = False if flip is None else flip # rotation matrix R = np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ) - # apply + # rotate and flip if flip: positions = R @ np.vstack((ans["qy"], ans["qx"])) else: diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py index a31f098d4..ffdbfa410 100644 --- a/py4DSTEM/data/calibration.py +++ b/py4DSTEM/data/calibration.py @@ -205,6 +205,7 @@ def __init__( self["R_pixel_size"] = 1 self["Q_pixel_units"] = "pixels" self["R_pixel_units"] = "pixels" + self["QR_flip"] = False # EMD root property @property @@ -666,8 +667,17 @@ def ellipse(self, x): # Q/R-space rotation and flip + @call_calibrate + def set_QR_rotation(self, x): + self._params["QR_rotation"] = x + self._params["QR_rotation_degrees"] = np.degrees(x) + + def get_QR_rotation(self): + return self._get_value("QR_rotation") + @call_calibrate def set_QR_rotation_degrees(self, x): + self._params["QR_rotation"] = np.radians(x) self._params["QR_rotation_degrees"] = x def get_QR_rotation_degrees(self): @@ -689,10 +699,31 @@ def set_QR_rotflip(self, rot_flip): flip (bool): True indicates a Q/R axes flip """ rot, flip = rot_flip + self._params["QR_rotation"] = rot + self._params["QR_rotation_degrees"] = np.degrees(rot) + self._params["QR_flip"] = flip + + @call_calibrate + def set_QR_rotflip_degrees(self, rot_flip): + """ + Args: + rot_flip (tuple), (rot, flip) where: + rot (number): rotation in degrees + flip (bool): True indicates a Q/R axes flip + """ + rot, flip = rot_flip + self._params["QR_rotation"] = np.radians(rot) self._params["QR_rotation_degrees"] = rot self._params["QR_flip"] = flip def get_QR_rotflip(self): + rot = self.get_QR_rotation() + flip = self.get_QR_flip() + if rot is None or flip is None: + return None + return (rot, flip) + + def get_QR_rotflip_degrees(self): rot = self.get_QR_rotation_degrees() flip = self.get_QR_flip() if rot is None or flip is None: diff --git a/py4DSTEM/process/calibration/rotation.py b/py4DSTEM/process/calibration/rotation.py index 2c3e7bb43..aaf8a49ce 100644 --- a/py4DSTEM/process/calibration/rotation.py +++ b/py4DSTEM/process/calibration/rotation.py @@ -2,6 +2,179 @@ import numpy as np from typing import Optional +import matplotlib.pyplot as plt +from py4DSTEM import show + + +def compare_QR_rotation( + im_R, + im_Q, + QR_rotation, + R_rotation=0, + R_position=None, + Q_position=None, + R_pos_anchor="center", + Q_pos_anchor="center", + R_length=0.33, + Q_length=0.33, + R_width=0.001, + Q_width=0.001, + R_head_length_adjust=1, + Q_head_length_adjust=1, + R_head_width_adjust=1, + Q_head_width_adjust=1, + R_color="r", + Q_color="r", + figsize=(10, 5), + returnfig=False, +): + """ + Visualize a rotational offset between an image in real space, e.g. a STEM + virtual image, and an image in diffraction space, e.g. a defocused CBED + shadow image of the same region, by displaying an arrow overlaid over each + of these two images with the specified QR rotation applied. The QR rotation + is defined as the counter-clockwise rotation from real space to diffraction + space, in degrees. + + Parameters + ---------- + im_R : numpy array or other 2D image-like object (e.g. a VirtualImage) + A real space image, e.g. a STEM virtual image + im_Q : numpy array or other 2D image-like object + A diffraction space image, e.g. a defocused CBED image + QR_rotation : number + The counterclockwise rotation from real space to diffraction space, + in degrees + R_rotation : number + The orientation of the arrow drawn in real space, in degrees + R_position : None or 2-tuple + The position of the anchor point for the R-space arrow. If None, defaults + to the center of the image + Q_position : None or 2-tuple + The position of the anchor point for the Q-space arrow. If None, defaults + to the center of the image + R_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the R-space arrow, i.e. the point being specified by + the `R_position` parameter + Q_pos_anchor : 'center' or 'tail' or 'head' + The anchor point for the Q-space arrow, i.e. the point being specified by + the `Q_position` parameter + R_length : number or None + The length of the R-space arrow, as a fraction of the mean size of the + image + Q_length : number or None + The length of the Q-space arrow, as a fraction of the mean size of the + image + R_width : number + The width of the R-space arrow + Q_width : number + The width of the R-space arrow + R_head_length_adjust : number + Scaling factor for the R-space arrow head length + Q_head_length_adjust : number + Scaling factor for the Q-space arrow head length + R_head_width_adjust : number + Scaling factor for the R-space arrow head width + Q_head_width_adjust : number + Scaling factor for the Q-space arrow head width + R_color : color + Color of the R-space arrow + Q_color : color + Color of the Q-space arrow + figsize : 2-tuple + The figure size + returnfig : bool + Toggles returning the figure and axes + """ + # parse inputs + if R_position is None: + R_position = ( + im_R.shape[0] / 2, + im_R.shape[1] / 2, + ) + if Q_position is None: + Q_position = ( + im_Q.shape[0] / 2, + im_Q.shape[1] / 2, + ) + R_length = np.mean(im_R.shape) * R_length + Q_length = np.mean(im_Q.shape) * Q_length + assert R_pos_anchor in ("center", "tail", "head") + assert Q_pos_anchor in ("center", "tail", "head") + + # compute positions + rpos_x, rpos_y = R_position + qpos_x, qpos_y = Q_position + R_rot_rad = np.radians(R_rotation) + Q_rot_rad = np.radians(R_rotation + QR_rotation) + rvecx = np.cos(R_rot_rad) + rvecy = np.sin(R_rot_rad) + qvecx = np.cos(Q_rot_rad) + qvecy = np.sin(Q_rot_rad) + if R_pos_anchor == "center": + x0_r = rpos_x - rvecx * R_length / 2 + y0_r = rpos_y - rvecy * R_length / 2 + x1_r = rpos_x + rvecx * R_length / 2 + y1_r = rpos_y + rvecy * R_length / 2 + elif R_pos_anchor == "tail": + x0_r = rpos_x + y0_r = rpos_y + x1_r = rpos_x + rvecx * R_length + y1_r = rpos_y + rvecy * R_length + elif R_pos_anchor == "head": + x0_r = rpos_x - rvecx * R_length + y0_r = rpos_y - rvecy * R_length + x1_r = rpos_x + y1_r = rpos_y + else: + raise Exception(f"Invalid value for R_pos_anchor {R_pos_anchor}") + if Q_pos_anchor == "center": + x0_q = qpos_x - qvecx * Q_length / 2 + y0_q = qpos_y - qvecy * Q_length / 2 + x1_q = qpos_x + qvecx * Q_length / 2 + y1_q = qpos_y + qvecy * Q_length / 2 + elif Q_pos_anchor == "tail": + x0_q = qpos_x + y0_q = qpos_y + x1_q = qpos_x + qvecx * Q_length + y1_q = qpos_y + qvecy * Q_length + elif Q_pos_anchor == "head": + x0_q = qpos_x - qvecx * Q_length + y0_q = qpos_y - qvecy * Q_length + x1_q = qpos_x + y1_q = qpos_y + else: + raise Exception(f"Invalid value for Q_pos_anchor {Q_pos_anchor}") + + # make the figure + axsize = (figsize[0] / 2, figsize[1]) + fig, axs = show([im_R, im_Q], returnfig=True, axsize=axsize) + axs[0, 0].arrow( + x=y0_r, + y=x0_r, + dx=y1_r - y0_r, + dy=x1_r - x0_r, + color=R_color, + length_includes_head=True, + width=R_width, + head_width=R_length * R_head_width_adjust * 0.072, + head_length=R_length * R_head_length_adjust * 0.1, + ) + axs[0, 1].arrow( + x=y0_q, + y=x0_q, + dx=y1_q - y0_q, + dy=x1_q - x0_q, + color=Q_color, + length_includes_head=True, + width=Q_width, + head_width=Q_length * Q_head_width_adjust * 0.072, + head_length=Q_length * Q_head_length_adjust * 0.1, + ) + if returnfig: + return fig, axs + else: + plt.show() def get_Qvector_from_Rvector(vx, vy, QR_rotation): diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py index 26c8d66a5..ba9bb4fcf 100644 --- a/py4DSTEM/process/strain/latticevectors.py +++ b/py4DSTEM/process/strain/latticevectors.py @@ -258,7 +258,13 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): ) # Fit lattice vectors - for Rx, Ry in tqdmnd(braggpeaks.shape[0], braggpeaks.shape[1]): + for Rx, Ry in tqdmnd( + braggpeaks.shape[0], + braggpeaks.shape[1], + desc="Fitting lattice vectors", + unit="DP", + unit_scale=True, + ): braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry) qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors( braggpeaks_curr, x0, y0, minNumPeaks @@ -359,32 +365,37 @@ def get_strain_from_reference_g1g2(g1g2_map, g1, g2): g2x, g2y = g2 M = np.array([[g1x, g1y], [g2x, g2y]]) - for Rx in range(R_Nx): - for Ry in range(R_Ny): - # Get lattice vectors for DP at Rx,Ry - alpha = np.array( + for Rx, Ry in tqdmnd( + R_Nx, + R_Ny, + desc="Calculating strain", + unit="DP", + unit_scale=True, + ): + # Get lattice vectors for DP at Rx,Ry + alpha = np.array( + [ [ - [ - g1g2_map.get_slice("g1x").data[Rx, Ry], - g1g2_map.get_slice("g1y").data[Rx, Ry], - ], - [ - g1g2_map.get_slice("g2x").data[Rx, Ry], - g1g2_map.get_slice("g2y").data[Rx, Ry], - ], - ] - ) - # Get transformation matrix - beta = lstsq(M, alpha, rcond=None)[0].T - - # Get the infinitesimal strain matrix - strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] - strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] - strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 - strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 - strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ - Rx, Ry + g1g2_map.get_slice("g1x").data[Rx, Ry], + g1g2_map.get_slice("g1y").data[Rx, Ry], + ], + [ + g1g2_map.get_slice("g2x").data[Rx, Ry], + g1g2_map.get_slice("g2y").data[Rx, Ry], + ], ] + ) + # Get transformation matrix + beta = lstsq(M, alpha, rcond=None)[0].T + + # Get the infinitesimal strain matrix + strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0] + strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1] + strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0 + strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0 + strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[ + Rx, Ry + ] return strain_map diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 538c90825..ab8a46a9a 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -4,6 +4,8 @@ from typing import Optional import matplotlib.pyplot as plt +from matplotlib.patches import Circle +from matplotlib.collections import PatchCollection from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np from py4DSTEM import PointList, PointListArray, tqdmnd @@ -18,10 +20,14 @@ get_strain_from_reference_g1g2, index_bragg_directions, ) -from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show -from py4DSTEM.visualize import ax_addaxes, ax_addaxes_QtoR - -warnings.simplefilter(action="always", category=UserWarning) +from py4DSTEM.visualize import ( + show, + add_bragg_index_labels, + add_pointlabels, + add_vector, + ax_addaxes, + ax_addaxes_QtoR, +) class StrainMap(RealSlice, Data): @@ -32,11 +38,16 @@ class StrainMap(RealSlice, Data): def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): """ - Accepts: - braggvectors (BraggVectors): BraggVectors for Strain Map - name (str): the name of the strainmap - Returns: - A new StrainMap instance. + Parameters + ---------- + braggvectors : BraggVectors + The Bragg vectors + name : str + The name of the strainmap + + Returns + ------- + A new StrainMap instance. """ assert isinstance( braggvectors, BraggVectors @@ -112,13 +123,6 @@ def qshape(self): def origin(self): return self.calibration.get_origin_mean() - @property - def mask(self): - try: - return self.g1g2_map["mask"].data.astype("bool") - except: - return np.ones(self.rshape, dtype=bool) - def reset_calstate(self): """ Resets the calibration state. This recomputes the BVM, and removes any computations @@ -136,7 +140,7 @@ def reset_calstate(self): # Class methods - def choose_lattice_vectors( + def choose_basis_vectors( self, index_g1=None, index_g2=None, @@ -167,12 +171,14 @@ def choose_lattice_vectors( returnfig=False, ): """ - Choose which lattice vectors to use for strain mapping. + Choose basis lattice vectors g1 and g2 for strain mapping. Overlays the bvm with the points detected via local 2D - maxima detection, plus an index for each point. User selects - 3 points using the overlaid indices, which are identified as - the origin and the termini of the lattice vectors g1 and g2. + maxima detection, plus an index for each point. Three points + are selected which correspond to the origin, and the basis + reciprocal lattice vectors g1 and g2. By default these are + automatically located; the user can override and select these + manually using the `index_*` arguments. Parameters ---------- @@ -229,7 +235,7 @@ def choose_lattice_vectors( Returns ------- - (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter + (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or the latter two """ # validate inputs for i in (index_origin, index_g1, index_g2): @@ -384,42 +390,114 @@ def choose_lattice_vectors( else: return - def fit_lattice_vectors( + def set_max_peak_spacing( self, - max_peak_spacing=2, - mask=None, - returncalc=False, + max_peak_spacing, + returnfig=False, + **vis_params, ): """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. - - Args: - max_peak_spacing: float - Maximum distance from the ideal lattice points - to include a peak for indexing - mask: bool - Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - returncalc : bool - if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map + Set the size of the regions of diffraction space in which detected Bragg + peaks will be indexed and included in subsequent fitting of basis + vectors, and visualize those regions. + + Parameters + ---------- + max_peak_spacing : number + The maximum allowable distance in pixels between a detected Bragg peak and + the indexed maxima found in `choose_basis_vectors` for the detected + peak to be indexed + returnfig : bool + Toggles returning the figure + vis_params : dict + Any additional arguments are passed to the `show` function when + visualization the BVM + """ + # set the max peak spacing + self.max_peak_spacing = max_peak_spacing + + # make the figure + fig, ax = show( + self.bvm.data, + returnfig=True, + **vis_params, + ) + + # make the circle patch collection + patches = [] + qx = self.braggdirections["qx"] + qy = self.braggdirections["qy"] + origin = self.origin + for idx in range(len(qx)): + c = Circle( + xy=(qy[idx] + origin[1], qx[idx] + origin[0]), + radius=self.max_peak_spacing, + edgecolor="r", + fill=False, + ) + patches.append(c) + pc = PatchCollection(patches, match_original=True) + + # draw the circles + ax.add_collection(pc) + + # return + if returnfig: + return fig, ax + else: + plt.show() + + def fit_basis_vectors( + self, mask=None, max_peak_spacing=None, vis_params={}, returncalc=False + ): + """ + Fit the basis lattice vectors to the detected Bragg peaks at each + scan position. + + First, the lattice vectors at each scan position are indexed using the + basis vectors g1 and g2 specified previously with `choose_basis_vectors` + Detected Bragg peaks which are farther from the set of lattice vectors + found in `choose_basis vectors` than the maximum peak spacing are + ignored; the maximum peak spacing can be set previously by calling + `set_max_peak_spacing` or by specifying the `max_peak_spacing` argument + here. A fit is then performed to refine the values of g1 and g2 at each + scan position, fitting the basis vectors to all detected and indexed + peaks, weighting the peaks according to their intensity. + + Parameters + ---------- + mask : 2d boolean array + A real space shaped Boolean mask indicating scan positions at which + to fit the lattice vectors. + max_peak_spacing : float + Maximum distance from the ideal lattice points to include a peak + for indexing + vis_params : dict + Visualization parameters for showing the max peak spacing; ignored + if `max_peak_spacing` is not set + returncalc : bool + if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map """ # check the calstate assert ( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - ### add indices to the bragg vectors + # handle the max peak spacing + if max_peak_spacing is not None: + self.set_max_peak_spacing(max_peak_spacing, **vis_params) + assert hasattr(self, "max_peak_spacing"), "Set the maximum peak spacing!" - # validate mask + # index the bragg vectors + + # handle the mask if mask is None: mask = np.ones(self.braggvectors.Rshape, dtype=bool) assert ( mask.shape == self.braggvectors.Rshape ), "mask must have same shape as pointlistarray" assert mask.dtype == bool, "mask must be boolean" + self.mask = mask # set up new braggpeaks PLA indexed_braggpeaks = PointListArray( @@ -432,10 +510,17 @@ def fit_lattice_vectors( ], shape=self.braggvectors.Rshape, ) - calstate = self.braggvectors.calstate # loop over all the scan positions - for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]): + # and perform indexing, excluding peaks outside of max_peak_spacing + calstate = self.braggvectors.calstate + for Rx, Ry in tqdmnd( + mask.shape[0], + mask.shape[1], + desc="Indexing Bragg scattering", + unit="DP", + unit_scale=True, + ): if mask[Rx, Ry]: pl = self.braggvectors.get_vectors( Rx, @@ -451,7 +536,7 @@ def fit_lattice_vectors( pl.data["qy"][i] - self.braggdirections.data["qy"], ) ind = np.argmin(r) - if r[ind] <= max_peak_spacing: + if r[ind] <= self.max_peak_spacing: indexed_braggpeaks[Rx, Ry].add_data_by_field( ( pl.data["qx"][i], @@ -463,186 +548,237 @@ def fit_lattice_vectors( ) self.bragg_vectors_indexed = indexed_braggpeaks - ### fit bragg vectors + # fit bragg vectors g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed) self.g1g2_map = g1g2_map + # update the mask + g1g2_mask = self.g1g2_map["mask"].data.astype("bool") + self.mask = np.logical_and(self.mask, g1g2_mask) + # return if returncalc: return self.bragg_vectors_indexed, self.g1g2_map def get_strain( - self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs + self, gvects=None, coordinate_rotation=0, returncalc=False, **kwargs ): """ - mask: nd.array (bool) - Use lattice vectors from g1g2_map scan positions - wherever mask==True. If mask is None gets median strain - map from entire field of view. If mask is not None, gets - reference g1 and g2 from region and then calculates strain. - g_reference: nd.array of form [x,y] - G_reference (tupe): reference coordinate system for - xaxis_x and xaxis_y - flip_theta: bool - If True, flips rotation coordinate system - returncal: bool + Compute the strain as the deviation of the basis reciprocal lattice + vectors which have been fit at each scan position with respect to a + pair of reference lattice vectors, determined by the argument `gvects`. + + Parameters + ---------- + gvects : None or 2d-array or tuple + Specifies how to select the reference lattice vectors. If None, + use the median of the fit lattice vectors over the whole dataset. + If a 2d array is passed, it should be real space shaped and boolean. + In this case, uses the median of the fit lattice vectors in all scan + positions where this array is True. Otherwise, should be a length 2 + tuple of length 2 array/list/tuples, which are used directly as + g1 and g2. + coordinate_rotation : number + Rotate the reference coordinate system counterclockwise by this + amount, in degrees + returncal : bool It True, returns rotated map """ - # check the calstate + # confirm that the calstate hasn't changed assert ( self.calstate == self.braggvectors.calstate ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." - if mask is None: - mask = self.mask - # mask = np.ones(self.g1g2_map.shape, dtype="bool") - # strainmap_g1g2 = get_strain_from_reference_region( - # self.g1g2_map, - # mask=mask, - # ) - - # g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) - # strain_map = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) - # else: - - g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask) + # get the reference g-vectors + if gvects is None: + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, self.mask) + elif isinstance(gvects, np.ndarray): + assert gvects.shape == self.rshape + assert gvects.dtype == bool + g1_ref, g2_ref = get_reference_g1g2( + self.g1g2_map, np.logical_and(gvects, self.mask) + ) + else: + g1_ref = np.array(gvects[0]) + g2_ref = np.array(gvects[1]) + # find the strain strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) - self.strainmap_g1g2 = strainmap_g1g2 - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) + # get the reference coordinate system + theta = np.radians(coordinate_rotation) + xaxis_x = np.cos(theta) + xaxis_y = np.sin(theta) + self.coordinate_rotation_degrees = coordinate_rotation + self.coordinate_rotation_radians = theta + # get the strain in the reference coordinates strainmap_rotated = get_rotated_strain_map( self.strainmap_g1g2, - xaxis_x=g_reference[0], - xaxis_y=g_reference[1], - flip_theta=flip_theta, + xaxis_x=xaxis_x, + xaxis_y=xaxis_y, + flip_theta=False, ) + # store the data self.data[0] = strainmap_rotated["e_xx"].data self.data[1] = strainmap_rotated["e_yy"].data self.data[2] = strainmap_rotated["e_xy"].data self.data[3] = strainmap_rotated["theta"].data self.data[4] = strainmap_rotated["mask"].data - self.g_reference = g_reference - - figsize = kwargs.pop("figsize", (14, 4)) - vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0]) - vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0]) - ticknumber = kwargs.pop("ticknumber", 3) - bkgrd = kwargs.pop("bkgrd", False) - axes_plots = kwargs.pop("axes_plots", ()) + # plot the results fig, ax = self.show_strain( - vrange_exx=vrange_exx, - vrange_theta=vrange_theta, - ticknumber=ticknumber, - axes_plots=axes_plots, - bkgrd=bkgrd, - figsize=figsize, **kwargs, returnfig=True, ) - if not np.all(mask == True): - ax[0][0].imshow(mask, alpha=0.2, cmap="binary") - ax[0][1].imshow(mask, alpha=0.2, cmap="binary") - ax[1][0].imshow(mask, alpha=0.2, cmap="binary") - ax[1][1].imshow(mask, alpha=0.2, cmap="binary") - + # return if returncalc: return self.strainmap + def get_reference_g1g2(self, ROI): + """ + Get reference g1,g2 vectors by taking the median fit vectors + in the specified ROI. + + Parameters + ---------- + ROI : real space shaped 2d boolean ndarray + Use scan positions where ROI is True + + Returns + ------- + g1_ref,g2_ref : 2 tuple of length 2 ndarrays + """ + g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, ROI) + return g1_ref, g2_ref + def show_strain( self, - vrange_exx, - vrange_theta, + vrange=[-3, 3], + vrange_theta=[-3, 3], + vrange_exx=None, vrange_exy=None, vrange_eyy=None, - flip_theta=False, bkgrd=True, - show_cbars=("exx", "eyy", "exy", "theta"), + show_cbars=None, bordercolor="k", borderwidth=1, - titlesize=24, - ticklabelsize=16, + titlesize=18, + ticklabelsize=10, ticknumber=5, - unitlabelsize=24, - show_axes=False, - axes_position=(0, 0), - axes_length=10, - axes_width=1, - axes_color="w", - xaxis_space="Q", - labelaxes=True, - QR_rotation=0, - axes_labelsize=12, - axes_labelcolor="r", - axes_plots=("exx"), + unitlabelsize=16, cmap="RdBu_r", + cmap_theta="PRGn", mask_color="k", - layout=0, - figsize=(12, 12), + color_axes="k", + show_gvects=True, + color_gvects="r", + legend_camera_length=1.6, + scale_gvects=0.6, + layout="square", + figsize=None, returnfig=False, ): """ - Display a strain map, showing the 4 strain components (e_xx,e_yy,e_xy,theta), and - masking each image with strainmap.get_slice('mask') - - Args: - vrange_exx (length 2 list or tuple): - vrange_theta (length 2 list or tuple): - vrange_exy (length 2 list or tuple): - vrange_eyy (length 2 list or tuple): - flip_theta (bool): if True, take negative of angle - bkgrd (bool): - show_cbars (tuple of strings): Show colorbars for the specified axes. Must be a - tuple containing any, all, or none of ('exx','eyy','exy','theta'). - bordercolor (color): - borderwidth (number): - titlesize (number): - ticklabelsize (number): - ticknumber (number): number of ticks on colorbars - unitlabelsize (number): - show_axes (bool): - axes_x0 (number): - axes_y0 (number): - xaxis_x (number): - xaxis_y (number): - axes_length (number): - axes_width (number): - axes_color (color): - xaxis_space (string): must be 'Q' or 'R' - labelaxes (bool): - QR_rotation (number): - axes_labelsize (number): - axes_labelcolor (color): - axes_plots (tuple of strings): controls if coordinate axes showing the - orientation of the strain matrices are overlaid over any of the plots. - Must be a tuple of strings containing any, all, or none of - ('exx','eyy','exy','theta'). - cmap (colormap): - layout=0 (int): determines the layout of the grid which the strain components - will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). - figsize (length 2 tuple of numbers): - returnfig (bool): + Display a strain map, showing the 4 strain components + (e_xx,e_yy,e_xy,theta), and masking each image with + strainmap.get_slice('mask') + + Parameters + ---------- + vrange : length 2 list or tuple + The colorbar intensity range for exx,eyy, and exy. + vrange_theta : length 2 list or tuple + The colorbar intensity range for theta. + vrange_exx : length 2 list or tuple + The colorbar intensity range for exx; overrides `vrange` + for exx + vrange_exy : length 2 list or tuple + The colorbar intensity range for exy; overrides `vrange` + for exy + vrange_eyy : length 2 list or tuple + The colorbar intensity range for eyy; overrides `vrange` + for eyy + bkgrd : bool + Overlay a mask over background pixels + show_cbars : None or a tuple of strings + Show colorbars for the specified axes. Valid strings are + 'exx', 'eyy', 'exy', and 'theta'. + bordercolor : color + Color for the image borders + borderwidth : number + Width of the image borders + titlesize : number + Size of the image titles + ticklabelsize : number + Size of the colorbar ticks + ticknumber : number + Number of ticks on colorbars + unitlabelsize : number + Size of the units label on the colorbars + cmap : colormap + Colormap for exx, exy, and eyy + cmap_theta : colormap + Colormap for theta + mask_color : color + Color for the background mask + color_axes : color + Color for the legend coordinate axes + show_gvects : bool + Toggles displaying the g-vectors in the legend + color_gvects : color + Color for the legend g-vectors + legend_camera_length : number + The distance the legend is viewed from; a smaller number yields + a larger legend + scale_gvects : number + Scaling for the legend g-vectors relative to the coordinate axes + layout : int + Determines the layout of the grid which the strain components + will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). + figsize : length 2 tuple of numbers + Size of the figure + returnfig : bool + Toggles returning the figure """ # Lookup table for different layouts - assert layout in (0, 1, 2) + assert layout in ("square", "horizontal", "vertical") layout_lookup = { - 0: ["left", "right", "left", "right"], - 1: ["bottom", "bottom", "bottom", "bottom"], - 2: ["right", "right", "right", "right"], + "square": ["left", "right", "left", "right"], + "horizontal": ["bottom", "bottom", "bottom", "bottom"], + "vertical": ["right", "right", "right", "right"], } + layout_p = layout_lookup[layout] + # Set which colorbars to display + if show_cbars is None: + if np.all( + [ + v is None + for v in ( + vrange_exx, + vrange_eyy, + vrange_exy, + ) + ] + ): + show_cbars = ("eyy", "theta") + else: + show_cbars = ("exx", "eyy", "exy", "theta") + else: + assert np.all([v in ("exx", "eyy", "exy", "theta") for v in show_cbars]) + # Contrast limits + if vrange_exx is None: + vrange_exx = vrange if vrange_exy is None: - vrange_exy = vrange_exx + vrange_exy = vrange if vrange_eyy is None: - vrange_eyy = vrange_exx + vrange_eyy = vrange for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): assert len(vrange) == 2, "vranges must have length 2" vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 @@ -667,25 +803,35 @@ def show_strain( self.get_slice("theta").data, mask=self.get_slice("mask").data == False, ) - if flip_theta == True: - theta = -theta ## Plot - # modify the figsize according to the image aspect ratio - ratio = np.sqrt(self.rshape[1] / self.rshape[0]) - figsize_mean = np.mean(figsize) - figsize = (figsize_mean * ratio, figsize_mean / ratio) + # if figsize hasn't been set, set it based on the + # chosen layout and the image shape + if figsize is None: + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) + if layout == "square": + figsize = (13 * ratio, 8 / ratio) + elif layout == "horizontal": + figsize = (10 * ratio, 4 / ratio) + else: + figsize = (4 * ratio, 10 / ratio) # set up layout - if layout == 0: - fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) - elif layout == 1: + if layout == "square": + fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots( + 2, 3, figsize=figsize + ) + elif layout == "horizontal": figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 1, 5, figsize=figsize + ) else: figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 5, 1, figsize=figsize + ) # display images, returning cbar axis references cax11 = show( @@ -727,7 +873,7 @@ def show_strain( vmin=vmin_theta, vmax=vmax_theta, intensity_range="absolute", - cmap=cmap, + cmap=cmap_theta, mask=self.mask, mask_color=mask_color, returncax=True, @@ -815,49 +961,6 @@ def show_strain( else: cbax.axis("off") - # Add coordinate axes - if show_axes: - assert xaxis_space in ("R", "Q"), "xaxis_space must be 'R' or 'Q'" - show_which_axes = np.array( - [ - "exx" in axes_plots, - "eyy" in axes_plots, - "exy" in axes_plots, - "theta" in axes_plots, - ] - ) - for _show, _ax in zip(show_which_axes, (ax11, ax12, ax21, ax22)): - if _show: - if xaxis_space == "R": - ax_addaxes( - _ax, - self.g_reference[0], - self.g_reference[1], - axes_length, - axes_position[0], - axes_position[1], - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - else: - ax_addaxes_QtoR( - _ax, - self.g_reference[0], - self.g_reference[1], - axes_length, - axes_position[0], - axes_position[1], - QR_rotation, - width=axes_width, - color=axes_color, - labelaxes=labelaxes, - labelsize=axes_labelsize, - labelcolor=axes_labelcolor, - ) - # Add borders if bordercolor is not None: for ax in (ax11, ax12, ax21, ax22): @@ -867,6 +970,146 @@ def show_strain( ax.set_xticks([]) ax.set_yticks([]) + # Legend + + # for layout "square", combine vertical plots on the right end + if layout == "square": + # get gridspec object + gs = ax_legend1.get_gridspec() + # remove last two axes + ax_legend1.remove() + ax_legend2.remove() + # make new axis + ax_legend = fig.add_subplot(gs[:, -1]) + + # get the coordinate axes' directions + rotation = self.coordinate_rotation_radians + xaxis_vectx = np.cos(rotation) + xaxis_vecty = np.sin(rotation) + yaxis_vectx = np.cos(rotation + np.pi / 2) + yaxis_vecty = np.sin(rotation + np.pi / 2) + + # make the coordinate axes + ax_legend.arrow( + x=0, + y=0, + dx=xaxis_vecty, + dy=xaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.arrow( + x=0, + y=0, + dx=yaxis_vecty, + dy=yaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.text( + x=xaxis_vecty * 1.16, + y=xaxis_vectx * 1.16, + s="x", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=yaxis_vecty * 1.16, + y=yaxis_vectx * 1.16, + s="y", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # make the g-vectors + if show_gvects: + # get the g-vectors directions + g1q = np.array(self.g1) + g2q = np.array(self.g2) + g1norm = np.linalg.norm(g1q) + g2norm = np.linalg.norm(g2q) + g1q /= g1norm + g2q /= g2norm + # set the lengths + g_ratio = g2norm / g1norm + if g_ratio > 1: + g1q /= g_ratio + else: + g2q *= g_ratio + g1_x, g1_y = g1q + g2_x, g2_y = g2q + + # draw the g vectors + ax_legend.arrow( + x=0, + y=0, + dx=g1_y * scale_gvects, + dy=g1_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.arrow( + x=0, + y=0, + dx=g2_y * scale_gvects, + dy=g2_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.text( + x=g1_y * scale_gvects * 1.2, + y=g1_x * scale_gvects * 1.2, + s=r"$g_1$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=g2_y * scale_gvects * 1.2, + y=g2_x * scale_gvects * 1.2, + s=r"$g_2$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # find center and extent + xmin = np.min([0, 0, xaxis_vectx, yaxis_vectx]) + xmax = np.max([0, 0, xaxis_vectx, yaxis_vectx]) + ymin = np.min([0, 0, xaxis_vecty, yaxis_vecty]) + ymax = np.max([0, 0, xaxis_vecty, yaxis_vecty]) + if show_gvects: + xmin = np.min([xmin, g1_x, g2_x]) + xmax = np.max([xmax, g1_x, g2_x]) + ymin = np.min([ymin, g1_y, g2_y]) + ymax = np.max([ymax, g1_y, g2_y]) + x0 = np.mean([xmin, xmax]) + y0 = np.mean([ymin, ymax]) + xL = (xmax - x0) * legend_camera_length + yL = (ymax - y0) * legend_camera_length + + # set the extent and aspect + ax_legend.set_xlim([y0 - yL, y0 + yL]) + ax_legend.set_ylim([x0 - xL, x0 + xL]) + ax_legend.invert_yaxis() + ax_legend.set_aspect("equal") + ax_legend.axis("off") + + # show/return if not returnfig: plt.show() return @@ -874,6 +1117,324 @@ def show_strain( axs = ((ax11, ax12), (ax21, ax22)) return fig, axs + def show_reference_directions( + self, + im_uncal=None, + im_cal=None, + color_axes="linen", + color_gvects="r", + origin_uncal=None, + origin_cal=None, + camera_length=1.8, + visp_uncal={"scaling": "log"}, + visp_cal={"scaling": "log"}, + layout="horizontal", + titlesize=16, + size_labels=14, + figsize=None, + returnfig=False, + ): + """ + Show the reference coordinate system used to compute the strain + overlaid over calibrated and uncalibrated diffraction space images. + + The diffraction images used can be specificied with the `im_uncal` + and `im_cal` arguments, and default to the uncalibrated and calibrated + Bragg vector maps. The `rotate_cal` argument causes the `im_cal` array + to be rotated by -QR rotation from the calibration metadata, so that an + uncalibrated image (like a raw diffraction image or mean or max + diffraction pattern) can be passed to the `im_cal` argument. + + Parameters + ---------- + im_uncal : 2d array or None + Uncalibrated diffraction space image to dispay; defaults to + the maximal diffraction image. + im_cal : 2d array or None + Calibrated diffraction space image to display; defaults to + the calibrated Bragg vector map. + color_axes : color + The color of the overlaid coordinate axes + color_gvects : color + The color of the g-vectors + origin_uncal : 2-tuple or None + Where to place the origin of the coordinate system overlaid on + the uncalibrated diffraction image. Defaults to the mean origin + from the calibration metadata. + origin_cal : 2-tuple or None + Where to place the origin of the coordinate system overlaid on + the calibrated diffraction image. Defaults to the mean origin + from the calibration metadata. + camera_length : number + Determines the length of the overlaid coordinate axes; a smaller + number yields larger axes. + visp_uncal : dict + Visualization parameters for the uncalibrated diffraction image. + visp_cal : dict + Visualization parameters for the calibrated diffraction image. + layout : str; either "horizontal" or "vertical" + Determines the layout of the visualization. + titlesize : number + The size of the plot titles + size_labels : number + The size of the axis labels + figsize : length 2 tuple of numbers or None + Size of the figure + returnfig : bool + Toggles returning the figure + """ + # Set up the figure + assert layout in ("horizontal", "vertical") + + # Set the figsize + if figsize is None: + ratio = np.sqrt(self.rshape[1] / self.rshape[0]) + if layout == "horizontal": + figsize = (10 * ratio, 8 / ratio) + else: + figsize = (8 * ratio, 12 / ratio) + + # Create the figure + if layout == "horizontal": + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) + else: + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize) + + # prepare images + if im_uncal is None: + im_uncal = self.braggvectors.histogram(mode="raw") + if im_cal is None: + im_cal = self.braggvectors.histogram(mode="cal") + + # display images + show(im_cal, figax=(fig, ax1), **visp_cal) + show(im_uncal, figax=(fig, ax2), **visp_uncal) + ax1.set_title("Calibrated", size=titlesize) + ax2.set_title("Uncalibrated", size=titlesize) + + # Get the coordinate axes + + # get the directions + + # calibrated + rotation = self.coordinate_rotation_radians + xaxis_cal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_cal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) + + # uncalibrated + QRrot = self.calibration.get_QR_rotation() + rotation = np.sum([self.coordinate_rotation_radians, -QRrot]) + xaxis_uncal = np.array([np.cos(rotation), np.sin(rotation)]) + yaxis_uncal = np.array( + [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)] + ) + # inversion + if self.calibration.get_QR_flip(): + xaxis_uncal = np.array([xaxis_uncal[1], xaxis_uncal[0]]) + yaxis_uncal = np.array([yaxis_uncal[1], yaxis_uncal[0]]) + + # set the lengths + Lmean = np.mean([im_cal.shape[0], im_cal.shape[1]]) / 2 + xaxis_cal *= Lmean / camera_length + yaxis_cal *= Lmean / camera_length + xaxis_uncal *= Lmean / camera_length + yaxis_uncal *= Lmean / camera_length + + # Get the g-vectors + + # calibrated + g1_cal = np.array(self.g1) + g2_cal = np.array(self.g2) + + # uncalibrated + R = np.array([[np.cos(QRrot), -np.sin(QRrot)], [np.sin(QRrot), np.cos(QRrot)]]) + g1_uncal = np.matmul(g1_cal, R) + g2_uncal = np.matmul(g2_cal, R) + # inversion + if self.calibration.get_QR_flip(): + g1_uncal = np.array([g1_uncal[1], g1_uncal[0]]) + g2_uncal = np.array([g2_uncal[1], g2_uncal[0]]) + + # Set origin positions + if origin_uncal is None: + origin_uncal = self.calibration.get_origin_mean() + if origin_cal is None: + origin_cal = self.calibration.get_origin_mean() + + # Draw calibrated coordinate axes + coordax_width = Lmean * 2 / 100 + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=xaxis_cal[1], + dy=xaxis_cal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=yaxis_cal[1], + dy=yaxis_cal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax1.text( + x=origin_cal[1] + xaxis_cal[1] * 1.16, + y=origin_cal[0] + xaxis_cal[0] * 1.16, + s="x", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax1.text( + x=origin_cal[1] + yaxis_cal[1] * 1.16, + y=origin_cal[0] + yaxis_cal[0] * 1.16, + s="y", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # Draw uncalibrated coordinate axes + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=xaxis_uncal[1], + dy=xaxis_uncal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=yaxis_uncal[1], + dy=yaxis_uncal[0], + color=color_axes, + length_includes_head=True, + width=coordax_width, + head_width=coordax_width * 5, + ) + ax2.text( + x=origin_uncal[1] + xaxis_uncal[1] * 1.16, + y=origin_uncal[0] + xaxis_uncal[0] * 1.16, + s="x", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax2.text( + x=origin_uncal[1] + yaxis_uncal[1] * 1.16, + y=origin_uncal[0] + yaxis_uncal[0] * 1.16, + s="y", + fontsize=size_labels, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # Draw the calibrated g-vectors + + # draw the g vectors + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=g1_cal[1], + dy=g1_cal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax1.arrow( + x=origin_cal[1], + y=origin_cal[0], + dx=g2_cal[1], + dy=g2_cal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax1.text( + x=origin_cal[1] + g1_cal[1] * 1.16, + y=origin_cal[0] + g1_cal[0] * 1.16, + s=r"$g_1$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax1.text( + x=origin_cal[1] + g2_cal[1] * 1.16, + y=origin_cal[0] + g2_cal[0] * 1.16, + s=r"$g_2$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # Draw the uncalibrated g-vectors + + # draw the g vectors + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=g1_uncal[1], + dy=g1_uncal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax2.arrow( + x=origin_uncal[1], + y=origin_uncal[0], + dx=g2_uncal[1], + dy=g2_uncal[0], + color=color_gvects, + length_includes_head=True, + width=coordax_width * 0.5, + head_width=coordax_width * 2.5, + ) + ax2.text( + x=origin_uncal[1] + g1_uncal[1] * 1.16, + y=origin_uncal[0] + g1_uncal[0] * 1.16, + s=r"$g_1$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax2.text( + x=origin_uncal[1] + g2_uncal[1] * 1.16, + y=origin_uncal[0] + g2_uncal[0] * 1.16, + s=r"$g_2$", + fontsize=size_labels * 0.88, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # show/return + if not returnfig: + plt.show() + return + else: + return fig, (ax1, ax2) + def show_lattice_vectors( ar, x0, @@ -887,7 +1448,10 @@ def show_lattice_vectors( returnfig=False, **kwargs, ): - """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy).""" + """ + Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). + g1 and g2 are 2-tuples (gx,gy). + """ fig, ax = show(ar, returnfig=True, **kwargs) # Add vectors @@ -940,10 +1504,13 @@ def show_bragg_indexing( """ Shows an array with an overlay describing the Bragg directions - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. + Parameters + ---------- + ar : np.ndarray + The display image + bragg_directions : PointList + The Bragg scattering directions. Must have coordinates + 'qx','qy','h', and 'k'. Optionally may also have 'l'. """ assert isinstance(bragg_directions, PointList) for k in ("qx", "qy", "h", "k"): @@ -987,6 +1554,7 @@ def copy(self, name=None): "g1g2_map", "strainmap_g1g2", "strainmap_rotated", + "mask", ): if hasattr(self, attr): setattr(strainmap_copy, attr, getattr(self, attr)) diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index b6077c412..00309ec36 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -8,6 +8,7 @@ from matplotlib.axes import Axes from matplotlib.colors import is_color_like from matplotlib.figure import Figure +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from py4DSTEM.data import Calibration, DiffractionSlice, RealSlice from py4DSTEM.visualize.overlay import ( add_annuli, @@ -75,6 +76,7 @@ def show( theta=None, title=None, show_fft=False, + show_cbar=False, **kwargs ): """ @@ -302,7 +304,8 @@ def show( does not add a scalebar. If a dict is passed, it is propagated to the add_scalebar function which will attempt to use it to overlay a scalebar. If True, uses calibraiton or pixelsize/pixelunits for scalebar. If False, no scalebar is added. - show_fft (Bool): if True, plots 2D-fft of array + show_fft (bool): if True, plots 2D-fft of array + show_cbar (bool) : if True, adds cbar **kwargs: any keywords accepted by matplotlib's ax.matshow() Returns: @@ -607,6 +610,10 @@ def show( ax.matshow( mask_display, cmap=cmap, alpha=mask_alpha, vmin=vmin, vmax=vmax ) + if show_cbar: + ax_divider = make_axes_locatable(ax) + c_axis = ax_divider.append_axes("right", size="7%") + fig.colorbar(cax, cax=c_axis) # ...or, plot its histogram else: hist, bin_edges = np.histogram(