Skip to content

Commit

Permalink
Thnks fr th Mmr(s)
Browse files Browse the repository at this point in the history
  • Loading branch information
smribet committed Nov 1, 2023
1 parent 67e15e7 commit 2d48616
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 15 deletions.
24 changes: 16 additions & 8 deletions py4DSTEM/process/phase/iterative_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,12 @@ def _normalize_diffraction_intensities(
mean_intensity = 0

diffraction_intensities = self._asnumpy(diffraction_intensities)
if positions_mask is not None:
number_of_patterns = np.count_nonzero(self._positions_mask.ravel())
sx, sy = np.where(~self._positions_mask)
else:
number_of_patterns = np.prod(diffraction_intensities.shape[:2])

if crop_patterns:
crop_x = int(
np.minimum(
Expand All @@ -1181,8 +1187,7 @@ def _normalize_diffraction_intensities(
region_of_interest_shape = (crop_w * 2, crop_w * 2)
amplitudes = np.zeros(
(
diffraction_intensities.shape[0],
diffraction_intensities.shape[1],
number_of_patterns,
crop_w * 2,
crop_w * 2,
),
Expand All @@ -1198,13 +1203,19 @@ def _normalize_diffraction_intensities(

else:
region_of_interest_shape = diffraction_intensities.shape[-2:]
amplitudes = np.zeros(diffraction_intensities.shape, dtype=np.float32)
amplitudes = np.zeros(
(number_of_patterns,) + region_of_interest_shape, dtype=np.float32
)

com_fitted_x = self._asnumpy(com_fitted_x)
com_fitted_y = self._asnumpy(com_fitted_y)

counter = 0
for rx in range(diffraction_intensities.shape[0]):
for ry in range(diffraction_intensities.shape[1]):
if positions_mask is not None:
if rx in sx and ry in sy:
continue
intensities = get_shifted_ar(
diffraction_intensities[rx, ry],
-com_fitted_x[rx, ry],
Expand All @@ -1219,13 +1230,10 @@ def _normalize_diffraction_intensities(
)

mean_intensity += np.sum(intensities)
amplitudes[rx, ry] = np.sqrt(np.maximum(intensities, 0))
amplitudes[counter] = np.sqrt(np.maximum(intensities, 0))
counter += 1

amplitudes = xp.reshape(amplitudes, (-1,) + region_of_interest_shape)
amplitudes = xp.asarray(amplitudes)
if positions_mask is not None:
amplitudes = amplitudes[positions_mask.ravel()]

mean_intensity /= amplitudes.shape[0]

return amplitudes, mean_intensity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def __init__(
f"object_type must be either 'potential' or 'complex', not {object_type}"
)

if positions_mask.dtype != "bool":
if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(
raise ValueError(
f"object_type must be either 'potential' or 'complex', not {object_type}"
)
if positions_mask.dtype != "bool":
if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def __init__(
raise ValueError(
f"object_type must be either 'potential' or 'complex', not {object_type}"
)
if positions_mask.dtype != "bool":
if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(
if object_type != "potential":
raise NotImplementedError()

if positions_mask.dtype != "bool":
if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
Expand Down
2 changes: 1 addition & 1 deletion py4DSTEM/process/phase/iterative_overlap_tomography.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(
if object_type != "potential":
raise NotImplementedError()

if positions_mask.dtype != "bool":
if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
raise ValueError(
f"object_type must be either 'potential' or 'complex', not {object_type}"
)
if positions_mask.dtype != "bool":
if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(
f"object_type must be either 'potential' or 'complex', not {object_type}"
)

if positions_mask.dtype != "bool":
if positions_mask is not None and positions_mask.dtype != "bool":
warnings.warn(
("`positions_mask` converted to `bool` array"),
UserWarning,
Expand Down

0 comments on commit 2d48616

Please sign in to comment.