diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index dcfd8f504..6ebb9962e 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -1341,7 +1341,18 @@ def aberration_fit( ### First pass # Convert real space shifts to Angstroms - self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) + + if force_transpose is None: + self.transpose_detected = False + else: + self.transpose_detected = force_transpose + + if force_transpose is True: + self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array( + self._scan_sampling + ) + else: + self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling) # Solve affine transformation m = asnumpy( @@ -1362,11 +1373,6 @@ def aberration_fit( self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0 self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 - if force_transpose is None: - self.transpose_detected = False - else: - self.transpose_detected = force_transpose - ### Second pass # Aberration coefs @@ -1590,8 +1596,15 @@ def score_CTF(coefs): ) ) - if force_transpose is None or force_transpose is True: - # Transposed fit + # (Relative) untransposed fit + tf = AffineTransform(angle=self.rotation_Q_to_R_rads) + rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() + aberrations_coefs, res = xp.linalg.lstsq( + gradients, rotated_shifts, rcond=None + )[:2] + + if force_transpose is None: + # (Relative) transposed fit transposed_shifts = xp.flip(self._xy_shifts_Ang, axis=1) m_T = asnumpy( xp.linalg.lstsq(self._probe_angles, transposed_shifts, rcond=None)[ @@ -1615,19 +1628,10 @@ def score_CTF(coefs): gradients, rotated_shifts_T, rcond=None )[:2] - if force_transpose is None or force_transpose is False: - # Untransposed fit - tf = AffineTransform(angle=self.rotation_Q_to_R_rads) - rotated_shifts = tf(self._xy_shifts_Ang, xp=xp).T.ravel() - aberrations_coefs, res = xp.linalg.lstsq( - gradients, rotated_shifts, rcond=None - )[:2] - - if force_transpose is None: # Compare fits if res_T.sum() < res.sum(): self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self.transpose_detected = True + self.transpose_detected = not self.transpose_detected self._aberrations_coefs = asnumpy(aberrations_coefs_T) self._rotated_shifts = rotated_shifts_T @@ -1638,15 +1642,6 @@ def score_CTF(coefs): ), UserWarning, ) - else: - self._aberrations_coefs = asnumpy(aberrations_coefs) - self._rotated_shifts = rotated_shifts - - elif force_transpose is True: - self.rotation_Q_to_R_rads = rotation_Q_to_R_rads_T - self._aberrations_coefs = asnumpy(aberrations_coefs_T) - self._rotated_shifts = rotated_shifts_T - else: self._aberrations_coefs = asnumpy(aberrations_coefs) self._rotated_shifts = rotated_shifts