Skip to content

Commit

Permalink
update uncertainty viz
Browse files Browse the repository at this point in the history
  • Loading branch information
gvarnavi committed Nov 4, 2023
1 parent 341d879 commit d32b18d
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 31 deletions.
257 changes: 226 additions & 31 deletions py4DSTEM/process/phase/iterative_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import ImageGrid
from py4DSTEM.visualize import show, show_complex
from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex
from scipy.ndimage import rotate

try:
Expand All @@ -23,7 +23,11 @@
from py4DSTEM.process.phase.iterative_ptychographic_constraints import (
PtychographicConstraints,
)
from py4DSTEM.process.phase.utils import AffineTransform, polar_aliases
from py4DSTEM.process.phase.utils import (
AffineTransform,
generate_batches,
polar_aliases,
)
from py4DSTEM.process.utils import (
electron_wavelength_angstrom,
fourier_resample,
Expand Down Expand Up @@ -2237,6 +2241,226 @@ def _return_object_fft(
obj = self._crop_rotate_object_fov(asnumpy(obj))
return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj))))

def _return_self_consistency_errors(
self,
max_batch_size=None,
):
"""Compute the self-consistency errors for each probe position"""

xp = self._xp
asnumpy = self._asnumpy

# Batch-size
if max_batch_size is None:
max_batch_size = self._num_diffraction_patterns

# Re-initialize fractional positions and vector patches
errors = np.array([])
positions_px = self._positions_px.copy()

for start, end in generate_batches(
self._num_diffraction_patterns, max_batch=max_batch_size
):
# batch indices
self._positions_px = positions_px[start:end]
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)
(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()
amplitudes = self._amplitudes[start:end]

# Overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap)

# Normalized mean-squared errors
batch_errors = xp.sum(
xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1)
)
errors = np.hstack((errors, batch_errors))

self._positions_px = positions_px.copy()
errors /= self._mean_diffraction_intensity

return asnumpy(errors)

def show_uncertainty_visualization(
self,
errors=None,
max_batch_size=None,
kde_sigma=None,
plot_histogram=True,
plot_contours=False,
**kwargs,
):
"""Plot uncertainty visualization using self-consistency errors"""

if errors is None:
errors = self._return_self_consistency_errors(max_batch_size=max_batch_size)

if kde_sigma is None:
kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0]

xp = self._xp
asnumpy = self._asnumpy
gaussian_filter = self._gaussian_filter

## Kernel Density Estimation

# rotated basis
angle = (
self._rotation_best_rad
if self._rotation_best_transpose
else -self._rotation_best_rad
)

tf = AffineTransform(angle=angle)
rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp)

padding = xp.min(rotated_points, axis=0).astype("int")

# bilinear sampling
pixel_output = np.array(self.object_cropped.shape) + asnumpy(2 * padding)
pixel_size = pixel_output.prod()

xa = rotated_points[:, 0]
ya = rotated_points[:, 1]

# bilinear sampling
xF = xp.floor(xa).astype("int")
yF = xp.floor(ya).astype("int")
dx = xa - xF
dy = ya - yF

# resampling
inds_1D = xp.ravel_multi_index(
xp.hstack(
[
[xF, yF],
[xF + 1, yF],
[xF, yF + 1],
[xF + 1, yF + 1],
]
),
pixel_output,
mode=["wrap", "wrap"],
)

weights = xp.hstack(
(
(1 - dx) * (1 - dy),
(dx) * (1 - dy),
(1 - dx) * (dy),
(dx) * (dy),
)
)

pix_count = xp.reshape(
xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output
)

pix_output = xp.reshape(
xp.bincount(
inds_1D,
weights=weights * xp.tile(xp.asarray(errors), 4),
minlength=pixel_size,
),
pixel_output,
)

# kernel density estimate
pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap")
pix_count[pix_count == 0.0] = np.inf
pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap")
pix_output /= pix_count
pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]]
pix_output, _, _ = return_scaled_histogram_ordering(
pix_output.get(), normalize=True
)

## Visualization
if plot_histogram:
spec = GridSpec(
ncols=1,
nrows=2,
height_ratios=[1, 4],
hspace=0.15,
)
auto_figsize = (4, 5.25)
else:
spec = GridSpec(
ncols=1,
nrows=1,
)
auto_figsize = (4, 4)

figsize = kwargs.pop("figsize", auto_figsize)

fig = plt.figure(figsize=figsize)

if plot_histogram:
ax_hist = fig.add_subplot(spec[0])

counts, bins = np.histogram(errors, bins=50)
ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5)
ax_hist.set_ylabel("Counts")
ax_hist.set_xlabel("Normalized Squared Error")

ax = fig.add_subplot(spec[-1])

cmap = kwargs.pop("cmap", "magma")
vmin = kwargs.pop("vmin", None)
vmax = kwargs.pop("vmax", None)

cropped_object_angle, vmin, vmax = return_scaled_histogram_ordering(
np.angle(self.object_cropped),
vmin=vmin,
vmax=vmax,
)

extent = [
0,
self.sampling[1] * cropped_object_angle.shape[1],
self.sampling[0] * cropped_object_angle.shape[0],
0,
]

ax.imshow(
cropped_object_angle,
vmin=vmin,
vmax=vmax,
extent=extent,
alpha=1 - pix_output,
cmap=cmap,
**kwargs,
)

if plot_contours:
aligned_points = asnumpy(rotated_points - padding)
aligned_points[:, 0] *= self.sampling[0]
aligned_points[:, 1] *= self.sampling[1]

ax.tricontour(
aligned_points[:, 1],
aligned_points[:, 0],
errors,
colors="grey",
levels=5,
# linestyles='dashed',
linewidths=0.5,
)

ax.set_ylabel("x [A]")
ax.set_xlabel("y [A]")
ax.set_xlim((extent[0], extent[1]))
ax.set_ylim((extent[2], extent[3]))
ax.xaxis.set_ticks_position("bottom")

spec.tight_layout(fig)

def show_fourier_probe(
self,
probe=None,
Expand Down Expand Up @@ -2383,32 +2607,3 @@ def object_cropped(self):
"""Cropped and rotated object"""

return self._crop_rotate_object_fov(self._object)

@property
def self_consistency_errors(self):
"""Compute the self-consistency errors for each probe position"""

xp = self._xp
asnumpy = self._asnumpy

# Re-initialize fractional positions and vector patches, max_batch_size = None
self._positions_px_fractional = self._positions_px - xp.round(
self._positions_px
)

(
self._vectorized_patch_indices_row,
self._vectorized_patch_indices_col,
) = self._extract_vectorized_patch_indices()

# Overlaps
_, _, overlap = self._overlap_projection(self._object, self._probe)
fourier_overlap = xp.fft.fft2(overlap)

# Normalized mean-squared errors
error = xp.sum(
xp.abs(self._amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1)
)
error /= self._mean_diffraction_intensity

return asnumpy(error)
31 changes: 31 additions & 0 deletions py4DSTEM/visualize/vis_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,3 +839,34 @@ def show_complex(

if returnfig:
return fig, ax


def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False):
if vmin is None:
vmin = 0.02
if vmax is None:
vmax = 0.98

vals = np.sort(array.ravel())
ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
ind_vmin = np.max([0, ind_vmin])
ind_vmax = np.min([len(vals) - 1, ind_vmax])
vmin = vals[ind_vmin]
vmax = vals[ind_vmax]

if vmax == vmin:
vmin = vals[0]
vmax = vals[-1]

scaled_array = array.copy()
scaled_array = np.where(scaled_array < vmin, vmin, scaled_array)
scaled_array = np.where(scaled_array > vmax, vmax, scaled_array)

if normalize:
scaled_array -= scaled_array.min()
scaled_array /= scaled_array.max()
vmin = 0
vmax = 1

return scaled_array, vmin, vmax

0 comments on commit d32b18d

Please sign in to comment.