diff --git a/py4DSTEM/process/wholepatternfit/wpf.py b/py4DSTEM/process/wholepatternfit/wpf.py index f206004b4..5d4a4e91f 100644 --- a/py4DSTEM/process/wholepatternfit/wpf.py +++ b/py4DSTEM/process/wholepatternfit/wpf.py @@ -7,6 +7,8 @@ WPFModelType, Parameter, ) +from py4DSTEM.data import RealSlice +from py4DSTEM.process.strain.latticevectors import get_strain_from_reference_g1g2 from typing import Optional import numpy as np @@ -448,7 +450,7 @@ def accept_mean_CBED_fit(self): def get_lattice_maps(self) -> list[RealSlice]: """ - Get the fitted reciprical lattice vectors refined at each scan point. + Get the fitted reciprocal lattice vectors refined at each scan point. Returns ------- @@ -482,6 +484,68 @@ def get_lattice_maps(self) -> list[RealSlice]: return g_maps + def get_strain_maps( + self, + ) -> list[RealSlice]: + """ + Calculate a strain map from the the fitted reciprocal lattice vectors + refined at each scan point. Currently we output strain maps aligned to the + coordinate (qx,qy) directions, and we assume a median reference lattice. + + TODO -allow rotation of Q w.r.t. R coordinate space. + -pass in reference lattice, or a mask to the reference ROI. + + Returns + ------- + strain_maps: list[RealSlice] + RealSlice objects containing the strain data as a function of scan positions, + for each lattice fit in the whole pattern fitting model. + """ + assert hasattr(self, "fit_data"), "Please run fitting first!" + + lattices = [m for m in self.model if WPFModelType.LATTICE in m.model_type] + strain_maps = [] + + for lat in lattices: + # Construct the stack of lattice vectors + g1g2_map = RealSlice( + np.stack( + [ + self.fit_data.data[lat.params["ux"].offset], + self.fit_data.data[lat.params["uy"].offset], + self.fit_data.data[lat.params["vx"].offset], + self.fit_data.data[lat.params["vy"].offset], + np.logical_and( + self.fit_metrics["status"].data + >= 0, # negative status indicates fit error + self.fit_metrics["nfev"].data > 0, + ), + ], + axis=0, + ), + slicelabels=["g1x", "g1y", "g2x", "g2y", "mask"], + name=lat.name, + ) + + # Get the reference lattice vectors + # TODO - update this to allow other refs, ROI, etc. + mask = (g1g2_map.get_slice("mask").data).astype("bool") + g1_ref = ( + np.median(g1g2_map.get_slice("g1x").data[mask]), + np.median(g1g2_map.get_slice("g1y").data[mask]), + ) + g2_ref = ( + np.median(g1g2_map.get_slice("g2x").data[mask]), + np.median(g1g2_map.get_slice("g2y").data[mask]), + ) + + # calculate strain + strain_map = get_strain_from_reference_g1g2(g1g2_map, g1_ref, g2_ref) + strain_map.name = g1g2_map.name + " strain map" + strain_maps.append(strain_map) + + return strain_maps + def _setup_static_data(self): """ Generate basic data that each model can access during the fitting routine