Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I've got a blank stress tensor, baby, and I'll write your strain #551

Merged
merged 17 commits into from
Nov 6, 2023
160 changes: 101 additions & 59 deletions py4DSTEM/braggvectors/braggvector_methods.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# BraggVectors methods

import numpy as np
from scipy.ndimage import gaussian_filter
from warnings import warn
import inspect
from warnings import warn

from emdfile import Array, Metadata, tqdmnd, _read_metadata
import matplotlib.pyplot as plt
import numpy as np
from emdfile import Array, Metadata, _read_metadata, tqdmnd
from py4DSTEM import show
from py4DSTEM.datacube import VirtualImage
from scipy.ndimage import gaussian_filter


class BraggVectorMethods:
Expand Down Expand Up @@ -518,6 +520,7 @@ def fit_origin(
mask_check_data=True,
plot=True,
plot_range=None,
cmap="RdBu_r",
returncalc=True,
**kwargs,
):
Expand All @@ -537,6 +540,7 @@ def fit_origin(
mask_check_data (bool): Get mask from origin measurements equal to zero. (TODO - replace)
plot (bool, optional): plot results
plot_range (float): min and max color range for plot (pixels)
cmap (colormap): plotting colormap

Returns:
(variable): Return value depends on returnfitp. If ``returnfitp==False``
Expand All @@ -561,75 +565,98 @@ def fit_origin(
else:
qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(tuple(q_meas))

# try to add to calibration
# try to add update calibration metadata
try:
self.calibration.set_origin([qx0_fit, qy0_fit])
self.calibration.set_origin((qx0_fit, qy0_fit))
self.setcal()
except AttributeError:
warn(
"No calibration found on this datacube - fit values are not being stored"
)
pass
if plot:
from py4DSTEM.visualize import show_image_grid

if mask is None:
qx0_meas, qy0_meas = q_meas
qx0_res_plot = qx0_residuals
qy0_res_plot = qy0_residuals
else:
qx0_meas = np.ma.masked_array(q_meas[0], mask=np.logical_not(mask))
qy0_meas = np.ma.masked_array(q_meas[1], mask=np.logical_not(mask))
qx0_res_plot = np.ma.masked_array(
qx0_residuals, mask=np.logical_not(mask)
)
qy0_res_plot = np.ma.masked_array(
qy0_residuals, mask=np.logical_not(mask)
)
qx0_mean = np.mean(qx0_fit)
qy0_mean = np.mean(qy0_fit)

if plot_range is None:
plot_range = 2 * np.max(qx0_fit - qx0_mean)

cmap = kwargs.get("cmap", "RdBu_r")
kwargs.pop("cmap", None)
axsize = kwargs.get("axsize", (6, 2))
kwargs.pop("axsize", None)

show_image_grid(
lambda i: [
qx0_meas - qx0_mean,
qx0_fit - qx0_mean,
qx0_res_plot,
qy0_meas - qy0_mean,
qy0_fit - qy0_mean,
qy0_res_plot,
][i],
H=2,
W=3,
# show
if plot:
self.show_origin_fit(
q_meas[0],
q_meas[1],
qx0_fit,
qy0_fit,
qx0_residuals,
qy0_residuals,
mask=mask,
plot_range=plot_range,
cmap=cmap,
axsize=axsize,
title=[
"measured origin, x",
"fitorigin, x",
"residuals, x",
"measured origin, y",
"fitorigin, y",
"residuals, y",
],
vmin=-1 * plot_range,
vmax=1 * plot_range,
intensity_range="absolute",
**kwargs,
)

# update calibration metadata
self.calibration.set_origin((qx0_fit, qy0_fit))
self.setcal()

# return
if returncalc:
return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals

def show_origin_fit(
self,
qx0_meas,
qy0_meas,
qx0_fit,
qy0_fit,
qx0_residuals,
qy0_residuals,
mask=None,
plot_range=None,
cmap="RdBu_r",
**kwargs,
):
# apply mask
if mask is not None:
qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask))
qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask))
qx0_residuals = np.ma.masked_array(qx0_residuals, mask=np.logical_not(mask))
qy0_residuals = np.ma.masked_array(qy0_residuals, mask=np.logical_not(mask))
qx0_mean = np.mean(qx0_fit)
qy0_mean = np.mean(qy0_fit)

# set range
if plot_range is None:
plot_range = max(
(
1.5 * np.max(np.abs(qx0_fit - qx0_mean)),
1.5 * np.max(np.abs(qy0_fit - qy0_mean)),
)
)

# set figsize
imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0])
axsize = (3 * imsize_ratio, 3 / imsize_ratio)
axsize = kwargs.pop("axsize", axsize)

# plot
fig, ax = show(
[
[qx0_meas - qx0_mean, qx0_fit - qx0_mean, qx0_residuals],
[qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals],
],
cmap=cmap,
axsize=axsize,
title=[
"measured origin, x",
"fitorigin, x",
"residuals, x",
"measured origin, y",
"fitorigin, y",
"residuals, y",
],
vmin=-1 * plot_range,
vmax=1 * plot_range,
intensity_range="absolute",
show_cbar=True,
returnfig=True,
**kwargs,
)
plt.tight_layout()

return

def fit_p_ellipse(
self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs
):
Expand Down Expand Up @@ -765,6 +792,21 @@ def mask_in_R(self, mask, update_inplace=False, returncalc=True):
else:
return

def to_strainmap(self, name: str = None):
"""
Generate a StrainMap object from the BraggVectors
equivalent to py4DSTEM.StrainMap(braggvectors=braggvectors)

Args:
name (str, optional): The name of the strainmap. Defaults to None which reverts to default name 'strainmap'.

Returns:
py4DSTEM.StrainMap: A py4DSTEM StrainMap object generated from the BraggVectors
"""
from py4DSTEM.process.strain import StrainMap

return StrainMap(self, name) if name else StrainMap(self)


######### END BraggVectorMethods CLASS ########

Expand Down
12 changes: 6 additions & 6 deletions py4DSTEM/braggvectors/braggvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def setcal(
if pixel is None:
pixel = False if c.get_Q_pixel_size() == 1 else True
if rotate is None:
rotate = False if c.get_QR_rotflip() is None else True
rotate = False if c.get_QR_rotation() is None else True

# validate requested state
if center:
Expand All @@ -210,7 +210,7 @@ def setcal(
if pixel:
assert c.get_Q_pixel_size() is not None, "Requested calibration not found"
if rotate:
assert c.get_QR_rotflip() is not None, "Requested calibration not found"
assert c.get_QR_rotation() is not None, "Requested calibration not found"

# set the calibrations
self._calstate = {
Expand Down Expand Up @@ -478,15 +478,15 @@ def _transform(

# Q/R rotation
if rotate:
flip = cal.get_QR_flip()
theta = cal.get_QR_rotation_degrees()
assert flip is not None, "Requested calibration was not found!"
theta = cal.get_QR_rotation()
assert theta is not None, "Requested calibration was not found!"
flip = cal.get_QR_flip()
flip = False if flip is None else flip
# rotation matrix
R = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
# apply
# rotate and flip
if flip:
positions = R @ np.vstack((ans["qy"], ans["qx"]))
else:
Expand Down
31 changes: 31 additions & 0 deletions py4DSTEM/data/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __init__(
self["R_pixel_size"] = 1
self["Q_pixel_units"] = "pixels"
self["R_pixel_units"] = "pixels"
self["QR_flip"] = False

# EMD root property
@property
Expand Down Expand Up @@ -666,8 +667,17 @@ def ellipse(self, x):

# Q/R-space rotation and flip

@call_calibrate
def set_QR_rotation(self, x):
self._params["QR_rotation"] = x
self._params["QR_rotation_degrees"] = np.degrees(x)

def get_QR_rotation(self):
return self._get_value("QR_rotation")

@call_calibrate
def set_QR_rotation_degrees(self, x):
self._params["QR_rotation"] = np.radians(x)
self._params["QR_rotation_degrees"] = x

def get_QR_rotation_degrees(self):
Expand All @@ -689,10 +699,31 @@ def set_QR_rotflip(self, rot_flip):
flip (bool): True indicates a Q/R axes flip
"""
rot, flip = rot_flip
self._params["QR_rotation"] = rot
self._params["QR_rotation_degrees"] = np.degrees(rot)
self._params["QR_flip"] = flip

@call_calibrate
def set_QR_rotflip_degrees(self, rot_flip):
"""
Args:
rot_flip (tuple), (rot, flip) where:
rot (number): rotation in degrees
flip (bool): True indicates a Q/R axes flip
"""
rot, flip = rot_flip
self._params["QR_rotation"] = np.radians(rot)
self._params["QR_rotation_degrees"] = rot
self._params["QR_flip"] = flip

def get_QR_rotflip(self):
rot = self.get_QR_rotation()
flip = self.get_QR_flip()
if rot is None or flip is None:
return None
return (rot, flip)

def get_QR_rotflip_degrees(self):
rot = self.get_QR_rotation_degrees()
flip = self.get_QR_flip()
if rot is None or flip is None:
Expand Down
Loading