Skip to content

Commit

Permalink
Black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
cophus committed Dec 18, 2023
1 parent dcaee94 commit 5c16a63
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions py4DSTEM/process/wholepatternfit/wpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ def get_strain_maps(
) -> list[RealSlice]:
"""
Calculate a strain map from the the fitted reciprocal lattice vectors
refined at each scan point. Currently we output strain maps aligned to the
coordinate (qx,qy) directions, and we assume a median reference lattice.
refined at each scan point. Currently we output strain maps aligned to the
coordinate (qx,qy) directions, and we assume a median reference lattice.
TODO -allow rotation of Q w.r.t. R coordinate space.
-pass in reference lattice, or a mask to the reference ROI.
Expand All @@ -503,7 +503,6 @@ def get_strain_maps(
"""
assert hasattr(self, "fit_data"), "Please run fitting first!"


lattices = [m for m in self.model if WPFModelType.LATTICE in m.model_type]
strain_maps = []

Expand All @@ -517,7 +516,8 @@ def get_strain_maps(
self.fit_data.data[lat.params["vx"].offset],
self.fit_data.data[lat.params["vy"].offset],
np.logical_and(
self.fit_metrics["status"].data >= 0, # negative status indicates fit error
self.fit_metrics["status"].data
>= 0, # negative status indicates fit error
self.fit_metrics["nfev"].data > 0,
),
],
Expand All @@ -527,28 +527,28 @@ def get_strain_maps(
name=lat.name,
)

fig,ax = plt.subplots(figsize=(4,4))
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(
g1g2_map['mask'].data.astype('float'),
vmin = 0,
vmax = 1,
)
g1g2_map["mask"].data.astype("float"),
vmin=0,
vmax=1,
)

# Get the reference lattice vectors
# TODO - update this to allow other refs, ROI, etc.
mask = (g1g2_map.get_slice("mask").data).astype('bool')
mask = (g1g2_map.get_slice("mask").data).astype("bool")
g1_ref = (
np.median(g1g2_map.get_slice("g1x").data[mask]),
np.median(g1g2_map.get_slice("g1y").data[mask]),
np.median(g1g2_map.get_slice("g1y").data[mask]),
)
g2_ref = (
np.median(g1g2_map.get_slice("g2x").data[mask]),
np.median(g1g2_map.get_slice("g2y").data[mask]),
np.median(g1g2_map.get_slice("g2y").data[mask]),
)

# calculate strain
strain_map = get_strain_from_reference_g1g2(g1g2_map, g1_ref, g2_ref)
strain_map.name = g1g2_map.name + ' strain map'
strain_map.name = g1g2_map.name + " strain map"
strain_maps.append(strain_map)

return strain_maps
Expand Down

0 comments on commit 5c16a63

Please sign in to comment.