Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding strain map outputs to whole pattern fitting #589

Merged
merged 7 commits into from
Dec 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion py4DSTEM/process/wholepatternfit/wpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
WPFModelType,
Parameter,
)
from py4DSTEM.data import RealSlice
from py4DSTEM.process.strain.latticevectors import get_strain_from_reference_g1g2

from typing import Optional
import numpy as np
Expand Down Expand Up @@ -448,7 +450,7 @@ def accept_mean_CBED_fit(self):

def get_lattice_maps(self) -> list[RealSlice]:
"""
Get the fitted reciprical lattice vectors refined at each scan point.
Get the fitted reciprocal lattice vectors refined at each scan point.

Returns
-------
Expand Down Expand Up @@ -482,6 +484,68 @@ def get_lattice_maps(self) -> list[RealSlice]:

return g_maps

def get_strain_maps(
self,
) -> list[RealSlice]:
"""
Calculate a strain map from the the fitted reciprocal lattice vectors
refined at each scan point. Currently we output strain maps aligned to the
coordinate (qx,qy) directions, and we assume a median reference lattice.

TODO -allow rotation of Q w.r.t. R coordinate space.
-pass in reference lattice, or a mask to the reference ROI.

Returns
-------
strain_maps: list[RealSlice]
RealSlice objects containing the strain data as a function of scan positions,
for each lattice fit in the whole pattern fitting model.
"""
assert hasattr(self, "fit_data"), "Please run fitting first!"

lattices = [m for m in self.model if WPFModelType.LATTICE in m.model_type]
strain_maps = []

for lat in lattices:
# Construct the stack of lattice vectors
g1g2_map = RealSlice(
np.stack(
[
self.fit_data.data[lat.params["ux"].offset],
self.fit_data.data[lat.params["uy"].offset],
self.fit_data.data[lat.params["vx"].offset],
self.fit_data.data[lat.params["vy"].offset],
np.logical_and(
self.fit_metrics["status"].data
>= 0, # negative status indicates fit error
self.fit_metrics["nfev"].data > 0,
),
],
axis=0,
),
slicelabels=["g1x", "g1y", "g2x", "g2y", "mask"],
name=lat.name,
)

# Get the reference lattice vectors
# TODO - update this to allow other refs, ROI, etc.
mask = (g1g2_map.get_slice("mask").data).astype("bool")
g1_ref = (
np.median(g1g2_map.get_slice("g1x").data[mask]),
np.median(g1g2_map.get_slice("g1y").data[mask]),
)
g2_ref = (
np.median(g1g2_map.get_slice("g2x").data[mask]),
np.median(g1g2_map.get_slice("g2y").data[mask]),
)

# calculate strain
strain_map = get_strain_from_reference_g1g2(g1g2_map, g1_ref, g2_ref)
strain_map.name = g1g2_map.name + " strain map"
strain_maps.append(strain_map)

return strain_maps

def _setup_static_data(self):
"""
Generate basic data that each model can access during the fitting routine
Expand Down