Skip to content

Commit

Permalink
Merge pull request py4dstem#589 from cophus/whole_pattern_fitting
Browse files Browse the repository at this point in the history
Adding strain map outputs to whole pattern fitting

Former-commit-id: 9455b2e
  • Loading branch information
sezelt authored Dec 19, 2023
2 parents b1875e2 + ef5495e commit 050d43a
Showing 1 changed file with 65 additions and 1 deletion.
66 changes: 65 additions & 1 deletion py4DSTEM/process/wholepatternfit/wpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
WPFModelType,
Parameter,
)
from py4DSTEM.data import RealSlice
from py4DSTEM.process.strain.latticevectors import get_strain_from_reference_g1g2

from typing import Optional
import numpy as np
Expand Down Expand Up @@ -448,7 +450,7 @@ def accept_mean_CBED_fit(self):

def get_lattice_maps(self) -> list[RealSlice]:
"""
Get the fitted reciprical lattice vectors refined at each scan point.
Get the fitted reciprocal lattice vectors refined at each scan point.
Returns
-------
Expand Down Expand Up @@ -482,6 +484,68 @@ def get_lattice_maps(self) -> list[RealSlice]:

return g_maps

def get_strain_maps(
self,
) -> list[RealSlice]:
"""
Calculate a strain map from the the fitted reciprocal lattice vectors
refined at each scan point. Currently we output strain maps aligned to the
coordinate (qx,qy) directions, and we assume a median reference lattice.
TODO -allow rotation of Q w.r.t. R coordinate space.
-pass in reference lattice, or a mask to the reference ROI.
Returns
-------
strain_maps: list[RealSlice]
RealSlice objects containing the strain data as a function of scan positions,
for each lattice fit in the whole pattern fitting model.
"""
assert hasattr(self, "fit_data"), "Please run fitting first!"

lattices = [m for m in self.model if WPFModelType.LATTICE in m.model_type]
strain_maps = []

for lat in lattices:
# Construct the stack of lattice vectors
g1g2_map = RealSlice(
np.stack(
[
self.fit_data.data[lat.params["ux"].offset],
self.fit_data.data[lat.params["uy"].offset],
self.fit_data.data[lat.params["vx"].offset],
self.fit_data.data[lat.params["vy"].offset],
np.logical_and(
self.fit_metrics["status"].data
>= 0, # negative status indicates fit error
self.fit_metrics["nfev"].data > 0,
),
],
axis=0,
),
slicelabels=["g1x", "g1y", "g2x", "g2y", "mask"],
name=lat.name,
)

# Get the reference lattice vectors
# TODO - update this to allow other refs, ROI, etc.
mask = (g1g2_map.get_slice("mask").data).astype("bool")
g1_ref = (
np.median(g1g2_map.get_slice("g1x").data[mask]),
np.median(g1g2_map.get_slice("g1y").data[mask]),
)
g2_ref = (
np.median(g1g2_map.get_slice("g2x").data[mask]),
np.median(g1g2_map.get_slice("g2y").data[mask]),
)

# calculate strain
strain_map = get_strain_from_reference_g1g2(g1g2_map, g1_ref, g2_ref)
strain_map.name = g1g2_map.name + " strain map"
strain_maps.append(strain_map)

return strain_maps

def _setup_static_data(self):
"""
Generate basic data that each model can access during the fitting routine
Expand Down

0 comments on commit 050d43a

Please sign in to comment.