Merge branch 'phase_contrast' of github.com:py4dstem/py4DSTEM into phase_contrast
gvarnavi committed Nov 3, 2023
2 parents 2e59d1c + 8323bc8 commit 341d879
Showing 10 changed files with 269 additions and 57 deletions.
22 changes: 14 additions & 8 deletions py4DSTEM/braggvectors/braggvector_methods.py
@@ -552,14 +552,20 @@ def fit_origin(
from py4DSTEM.process.calibration import fit_origin

if mask_check_data is True:
# TODO - replace this bad hack for the mask for the origin fit
mask = np.logical_not(q_meas[0] == 0)
qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(
tuple(q_meas),
mask=mask,
)
else:
qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(tuple(q_meas))
data_mask = np.logical_not(q_meas[0] == 0)
if mask is None:
mask = data_mask
else:
mask = np.logical_and(mask, data_mask)

qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(
tuple(q_meas),
mask=mask,
fitfunction=fitfunction,
robust=robust,
robust_steps=robust_steps,
robust_thresh=robust_thresh,
)

# try to add to calibration
try:
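The fit_origin change above drops the hard-coded mask and instead combines a data-derived mask (positions where the measured origin is nonzero) with any user-supplied mask, then forwards the fit options (fitfunction, robust, robust_steps, robust_thresh) to fit_origin. A minimal NumPy sketch of just that mask-combination rule, using made-up stand-in arrays rather than py4DSTEM objects:

import numpy as np

# Stand-in measured qx-origin map; zeros mark scan positions where no
# origin could be measured (illustrative data only).
qx_meas = np.random.rand(4, 5)
qx_meas[0, 0] = 0.0

# Optional user-supplied mask (True = include in the fit).
user_mask = np.ones((4, 5), dtype=bool)
user_mask[3, 4] = False

# Same combination rule as the diff: drop zero-valued measurements,
# then AND with the user mask if one was given.
data_mask = np.logical_not(qx_meas == 0)
mask = data_mask if user_mask is None else np.logical_and(user_mask, data_mask)

print(mask.sum(), "of", mask.size, "positions kept for the origin fit")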
38 changes: 26 additions & 12 deletions py4DSTEM/process/phase/iterative_base_class.py
@@ -1132,6 +1132,7 @@ def _normalize_diffraction_intensities(
com_fitted_x,
com_fitted_y,
crop_patterns,
positions_mask,
):
"""
Fix diffraction intensities CoM, shift to origin, and take square root
@@ -1147,6 +1148,8 @@
crop_patterns: bool
if True, crop patterns to avoid wrap around of patterns
when centering
positions_mask: np.ndarray, optional
Boolean real space mask to select positions in datacube to use for reconstruction
Returns
-------
@@ -1160,6 +1163,11 @@
mean_intensity = 0

diffraction_intensities = self._asnumpy(diffraction_intensities)
if positions_mask is not None:
number_of_patterns = np.count_nonzero(self._positions_mask.ravel())
else:
number_of_patterns = np.prod(diffraction_intensities.shape[:2])

if crop_patterns:
crop_x = int(
np.minimum(
@@ -1178,8 +1186,7 @@
region_of_interest_shape = (crop_w * 2, crop_w * 2)
amplitudes = np.zeros(
(
diffraction_intensities.shape[0],
diffraction_intensities.shape[1],
number_of_patterns,
crop_w * 2,
crop_w * 2,
),
@@ -1195,13 +1202,19 @@

else:
region_of_interest_shape = diffraction_intensities.shape[-2:]
amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32)
amplitudes = np.zeros(
(number_of_patterns,) + region_of_interest_shape, dtype=np.float32
)

com_fitted_x = self._asnumpy(com_fitted_x)
com_fitted_y = self._asnumpy(com_fitted_y)

counter = 0
for rx in range(diffraction_intensities.shape[0]):
for ry in range(diffraction_intensities.shape[1]):
if positions_mask is not None:
if not self._positions_mask[rx, ry]:
continue
intensities = get_shifted_ar(
diffraction_intensities[rx, ry],
-com_fitted_x[rx, ry],
@@ -1216,9 +1229,9 @@
)

mean_intensity += np.sum(intensities)
amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0))
amplitudes[counter] = np.sqrt(np.maximum(intensities, 0))
counter += 1

amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape)
amplitudes = xp.asarray(amplitudes)
mean_intensity /= amplitudes.shape[0]
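In words, when a positions_mask is supplied the amplitude stack is allocated from the count of unmasked scan positions and filled by skipping masked ones, so it comes out already flattened to (number_of_patterns, Qx, Qy). A self-contained sketch of that selection logic with toy arrays (names and shapes are illustrative, not the class's own attributes):

import numpy as np

rng = np.random.default_rng(0)
intensities = rng.random((6, 8, 32, 32))      # toy (Rx, Ry, Qx, Qy) datacube
positions_mask = rng.random((6, 8)) > 0.25    # True = use this scan position

number_of_patterns = np.count_nonzero(positions_mask)
amplitudes = np.zeros(
    (number_of_patterns,) + intensities.shape[-2:], dtype=np.float32
)

counter = 0
for rx in range(intensities.shape[0]):
    for ry in range(intensities.shape[1]):
        if not positions_mask[rx, ry]:
            continue                          # masked-out position: skip entirely
        amplitudes[counter] = np.sqrt(np.maximum(intensities[rx, ry], 0))
        counter += 1

assert counter == number_of_patterns          # stack is already flattened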

@@ -1535,7 +1548,9 @@ def _set_polar_parameters(self, parameters: dict):
else:
raise ValueError("{} not a recognized parameter".format(symbol))

def _calculate_scan_positions_in_pixels(self, positions: np.ndarray):
def _calculate_scan_positions_in_pixels(
self, positions: np.ndarray, positions_mask
):
"""
Method to compute the initial guess of scan positions in pixels.
@@ -1544,6 +1559,8 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray):
positions: (J,2) np.ndarray or None
Input probe positions in Å.
If None, a raster scan using experimental parameters is constructed.
positions_mask: np.ndarray, optional
Boolean real space mask to select positions in datacube to use for reconstruction
Returns
-------
@@ -1592,6 +1609,9 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray):
positions = np.array([x.ravel(), y.ravel()]).T
positions -= np.min(positions, axis=0)

if positions_mask is not None:
positions = positions[positions_mask.ravel()]

if self._object_padding_px is None:
float_padding = self._region_of_interest_shape / 2
self._object_padding_px = (float_padding, float_padding)
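The scan positions are filtered with the same mask so the position list stays aligned, entry for entry, with the reduced amplitude stack. A short sketch on a toy raster grid (illustrative values, not the class's attributes):

import numpy as np

nx, ny = 6, 8
x, y = np.meshgrid(np.arange(nx), np.arange(ny), indexing="ij")
positions = np.array([x.ravel(), y.ravel()], dtype=float).T   # (J, 2) raster scan

positions_mask = np.ones((nx, ny), dtype=bool)
positions_mask[0, :] = False                 # e.g. drop the first scan row

positions = positions[positions_mask.ravel()]
print(positions.shape)                       # (40, 2): same order as the amplitudes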
@@ -2286,22 +2306,16 @@ def show_object_fft(self, obj=None, **kwargs):

figsize = kwargs.pop("figsize", (6, 6))
cmap = kwargs.pop("cmap", "magma")
vmin = kwargs.pop("vmin", 0)
vmax = kwargs.pop("vmax", 1)
power = kwargs.pop("power", 0.2)

pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
show(
object_fft,
figsize=figsize,
cmap=cmap,
vmin=vmin,
vmax=vmax,
scalebar=True,
pixelsize=pixelsize,
ticks=False,
pixelunits=r"$\AA^{-1}$",
power=power,
**kwargs,
)

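With vmin and vmax no longer popped, all remaining display settings pass straight through **kwargs to show, and the FFT amplitude is shown with a power-law stretch (power=0.2 by default). Assuming show's power keyword simply raises the displayed array to that exponent, the visual effect can be approximated with plain NumPy and matplotlib on a random stand-in object:

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
obj = rng.random((256, 256))                          # stand-in reconstructed object
obj_fft = np.abs(np.fft.fftshift(np.fft.fft2(obj)))   # FFT amplitude

power = 0.2                                           # default introduced by the diff
plt.imshow(obj_fft**power, cmap="magma")
plt.axis("off")
plt.show()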
@@ -82,9 +82,17 @@ class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction):
initial_scan_positions: np.ndarray, optional
Probe positions in Å for each diffraction intensity
If None, initialized to a grid scan
theta_x: float
x tilt of propagator (in degrees)
theta_y: float
y tilt of propagator (in degrees)
middle_focus: bool
if True, adds half the sample thickness to the defocus
object_type: str, optional
The object can be reconstructed as a real potential ('potential') or a complex
object ('complex')
positions_mask: np.ndarray, optional
Boolean real space mask to select positions in datacube to use for reconstruction
verbose: bool, optional
If True, class methods will inherit this and print additional information
device: str, optional
@@ -114,7 +122,11 @@ def __init__(
initial_object_guess: np.ndarray = None,
initial_probe_guess: np.ndarray = None,
initial_scan_positions: np.ndarray = None,
theta_x: float = 0,
theta_y: float = 0,
middle_focus: bool = False,
object_type: str = "complex",
positions_mask: np.ndarray = None,
verbose: bool = True,
device: str = "cpu",
name: str = "multi-slice_ptychographic_reconstruction",
@@ -162,6 +174,25 @@ def __init__(
if (key not in polar_symbols) and (key not in polar_aliases.keys()):
raise ValueError("{} not a recognized parameter".format(key))

if np.isscalar(slice_thicknesses):
mean_slice_thickness = slice_thicknesses
else:
mean_slice_thickness = np.mean(slice_thicknesses)

if middle_focus:
if "defocus" in kwargs:
kwargs["defocus"] += mean_slice_thickness * num_slices / 2
elif "C10" in kwargs:
kwargs["C10"] -= mean_slice_thickness * num_slices / 2
elif polar_parameters is not None and "defocus" in polar_parameters:
polar_parameters["defocus"] = (
polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2
)
elif polar_parameters is not None and "C10" in polar_parameters:
polar_parameters["C10"] = (
polar_parameters["C10"] - mean_slice_thickness * num_slices / 2
)
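In other words, middle_focus reinterprets the supplied defocus as referring to the mid-plane of the sample, so half the total thickness (mean_slice_thickness * num_slices / 2) is added to defocus, or subtracted from C10 since C10 carries the opposite sign in this convention. A small worked example with illustrative numbers:

import numpy as np

num_slices = 20
slice_thicknesses = 2.0          # Å per slice (scalar or per-slice array)

mean_slice_thickness = (
    slice_thicknesses
    if np.isscalar(slice_thicknesses)
    else np.mean(slice_thicknesses)
)
half_thickness = mean_slice_thickness * num_slices / 2   # 20 Å here

defocus = 50.0                   # Å, value the user would pass in
defocus_middle_focus = defocus + half_thickness          # 70.0 Å
C10_middle_focus = -defocus - half_thickness             # -70.0 Å (C10 = -defocus)

print(defocus_middle_focus, C10_middle_focus)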

self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))

if polar_parameters is None:
@@ -186,6 +217,13 @@ def __init__(
f"object_type must be either 'potential' or 'complex', not {object_type}"
)

if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
)
positions_mask = np.asarray(positions_mask, dtype="bool")

self.set_save_defaults()

# Data
@@ -201,6 +239,7 @@ def __init__(
self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
self._rolloff = rolloff
self._object_type = object_type
self._positions_mask = positions_mask
self._object_padding_px = object_padding_px
self._verbose = verbose
self._device = device
@@ -210,13 +249,17 @@ def __init__(
self._num_probes = num_probes
self._num_slices = num_slices
self._slice_thicknesses = slice_thicknesses
self._theta_x = theta_x
self._theta_y = theta_y

def _precompute_propagator_arrays(
self,
gpts: Tuple[int, int],
sampling: Tuple[float, float],
energy: float,
slice_thicknesses: Sequence[float],
theta_x: float,
theta_y: float,
):
"""
Precomputes the propagator arrays that the complex wave function will be convolved with,
@@ -232,6 +275,10 @@ def _precompute_propagator_arrays(
The electron energy of the wave functions in eV
slice_thicknesses: Sequence[float]
Array of slice thicknesses in A
theta_x: float
x tilt of propagator (in degrees)
theta_y: float
y tilt of propagator (in degrees)
Returns
-------
@@ -251,13 +298,23 @@
propagators = xp.empty(
(num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
)

theta_x = np.deg2rad(theta_x)
theta_y = np.deg2rad(theta_y)

for i, dz in enumerate(slice_thicknesses):
propagators[i] = xp.exp(
1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(
1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
)
propagators[i] *= xp.exp(
1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x))
)
propagators[i] *= xp.exp(
1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))
)

return propagators
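The two added phase factors implement a small-angle beam tilt on top of the usual Fresnel term exp(-i pi lambda dz k^2): each slice propagation picks up exp(2 pi i dz kx tan(theta_x)) and the analogous ky factor, which shifts the wave laterally by roughly dz*tan(theta) per slice. A standalone NumPy sketch of one such tilted propagator (a plain function, not the class method; grid, sampling, and wavelength values are illustrative):

import numpy as np

def tilted_propagator(gpts, sampling, wavelength, dz, theta_x_deg, theta_y_deg):
    """Single-slice Fresnel propagator with a small-angle tilt, as in the diff."""
    kx = np.fft.fftfreq(gpts[0], sampling[0])
    ky = np.fft.fftfreq(gpts[1], sampling[1])
    tx, ty = np.deg2rad(theta_x_deg), np.deg2rad(theta_y_deg)

    prop = np.exp(-1.0j * np.pi * wavelength * dz * (kx**2)[:, None])
    prop = prop * np.exp(-1.0j * np.pi * wavelength * dz * (ky**2)[None])
    prop = prop * np.exp(2.0j * np.pi * dz * np.tan(tx) * kx[:, None])   # x tilt
    prop = prop * np.exp(2.0j * np.pi * dz * np.tan(ty) * ky[None])      # y tilt
    return prop.astype(np.complex64)

# 2 Å slices at ~200 kV (lambda ≈ 0.0251 Å), 1 mrad tilt about x
p = tilted_propagator((128, 128), (0.1, 0.1), 2.51e-2, 2.0, np.rad2deg(1e-3), 0.0)
print(p.shape, p.dtype)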

@@ -445,7 +502,11 @@ def preprocess(
self._amplitudes,
self._mean_diffraction_intensity,
) = self._normalize_diffraction_intensities(
self._intensities, self._com_fitted_x, self._com_fitted_y, crop_patterns
self._intensities,
self._com_fitted_x,
self._com_fitted_y,
crop_patterns,
self._positions_mask,
)

# explicitly delete namespace
@@ -454,7 +515,7 @@
del self._intensities

self._positions_px = self._calculate_scan_positions_in_pixels(
self._scan_positions
self._scan_positions, self._positions_mask
)

# handle semiangle specified in pixels
@@ -597,6 +658,8 @@ def preprocess(
self.sampling,
self._energy,
self._slice_thicknesses,
self._theta_x,
self._theta_y,
)

# overlaps
@@ -3060,6 +3123,7 @@ def show_slices(
common_color_scale: bool = True,
padding: int = 0,
num_cols: int = 3,
show_fft: bool = False,
**kwargs,
):
"""
@@ -3075,12 +3139,20 @@
Padding to leave uncropped
num_cols: int, optional
Number of GridSpec columns
show_fft: bool, optional
if True, plots fft of object slices
"""

if ms_object is None:
ms_object = self._object

rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding)
if show_fft:
rotated_object = np.abs(
np.fft.fftshift(
np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1)
)
)
rotated_shape = rotated_object.shape

if np.iscomplexobj(rotated_object):
@@ -3098,8 +3170,21 @@

axsize = kwargs.pop("axsize", (3, 3))
cmap = kwargs.pop("cmap", "magma")
vmin = np.min(rotated_object) if common_color_scale else None
vmax = np.max(rotated_object) if common_color_scale else None

if common_color_scale:
vals = np.sort(rotated_object.ravel())
ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int")
ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int")
ind_vmin = np.max([0, ind_vmin])
ind_vmax = np.min([len(vals) - 1, ind_vmax])
vmin = vals[ind_vmin]
vmax = vals[ind_vmax]
if vmax == vmin:
vmin = vals[0]
vmax = vals[-1]
else:
vmax = None
vmin = None
vmin = kwargs.pop("vmin", vmin)
vmax = kwargs.pop("vmax", vmax)

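Two changes to show_slices are shown above: show_fft=True displays the FFT amplitude of each cropped slice, and the common color scale now clips to roughly the 2nd and 98th percentiles of the pixel values instead of the raw min/max, so a single outlier no longer sets the contrast. A compact NumPy equivalent of that clipping on stand-in data (np.percentile would behave almost the same; the diff indexes sorted values directly):

import numpy as np

rng = np.random.default_rng(0)
rotated_object = rng.normal(size=(4, 64, 64))   # stand-in for the cropped slices
rotated_object[0, 0, 0] = 1e3                   # hot pixel that would ruin min/max

vals = np.sort(rotated_object.ravel())
ind_vmin = max(0, int(round((vals.size - 1) * 0.02)))
ind_vmax = min(vals.size - 1, int(round((vals.size - 1) * 0.98)))
vmin, vmax = vals[ind_vmin], vals[ind_vmax]
if vmax == vmin:                                # degenerate (constant) object
    vmin, vmax = vals[0], vals[-1]

print(vmin, vmax)                               # the outlier no longer sets the scale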