-
Notifications
You must be signed in to change notification settings - Fork 1
/
analyze_alignment.py
executable file
·215 lines (164 loc) · 7 KB
/
analyze_alignment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import sys
import os
import math
import argparse
sys.path.append('cryoem/')
sys.path.append('cryoem/util')
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
# Global plot styling: 100 dpi figures on a dark background.
mpl.rcParams.update({'figure.dpi': 100})
plt.style.use('dark_background')
def main(args):
    """Compare theoretical and experimental particle alignments for WT and TeMet.

    Reads ``temet_<name>.star`` / ``wt_<name>.star`` (theoretical) and
    ``temet_<name>.par`` / ``wt_<name>.par`` (experimental) from
    ``args.inputpath``, computes per-particle angular errors (degrees) and
    positional errors, draws overlaid WT/TeMet histograms of both, saves the
    figure to ``errors.png`` and shows it.

    Args:
        args: parsed command-line namespace with ``inputname`` and
            ``inputpath`` attributes.

    Raises:
        ValueError: if the four particle files do not all have the same
            number of particles.
    """
    which = args.inputname
    # A missing --inputpath means "current directory"; os.path.join also
    # copes with a directory given without a trailing separator.
    path = args.inputpath or ''

    # Read the data
    print('reading theoretical data...')
    temet_theoretical = read_theoretical(_input_file(path, 'temet', which, '.star'))
    wt_theoretical = read_theoretical(_input_file(path, 'wt', which, '.star'))
    _require_same_length(temet_theoretical, wt_theoretical,
                         'theoretical starfiles have inconsistent lengths')
    print('reading experimental data...')
    temet_experimental = read_experimental(_input_file(path, 'temet', which, '.par'))
    wt_experimental = read_experimental(_input_file(path, 'wt', which, '.par'))
    _require_same_length(temet_experimental, wt_experimental,
                         'experimental parfiles have inconsistent lengths')
    _require_same_length(temet_experimental, temet_theoretical,
                         'experimental and theoretical particle files have inconsistent lengths')
    print('read %d particles' % len(temet_theoretical))

    # Compute angle errors
    print('computing angular errors...')
    wt_angle_errors = computeAngleErrors(wt_theoretical.quaternion, wt_experimental.quaternion)
    temet_angle_errors = computeAngleErrors(temet_theoretical.quaternion, temet_experimental.quaternion)

    # Compute position errors
    print('computing positional errors...')
    wt_position_errors = _shift_errors(wt_theoretical, wt_experimental)
    temet_position_errors = _shift_errors(temet_theoretical, temet_experimental)

    print('plotting...')
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    _plot_error_histograms(ax1, wt_angle_errors, temet_angle_errors, bins=180,
                           kind='angle', title='Angle Errors',
                           xlabel="Error Distance (°)")
    _plot_error_histograms(ax2, wt_position_errors, temet_position_errors, bins=100,
                           kind='position', title='Position Errors',
                           xlabel="Error Distance (Å)")
    print('displaying plots...')
    plt.savefig('errors.png')
    plt.show()
    print('done!!!')


def _input_file(directory, prefix, name, extension):
    """Build the path ``<directory>/<prefix>_<name><extension>``."""
    return os.path.join(directory, prefix + '_' + name + extension)


def _require_same_length(a, b, message):
    """Raise ValueError(message) unless *a* and *b* have the same length."""
    if len(a) != len(b):
        raise ValueError(message)


def _shift_errors(theoretical, experimental):
    """Return per-particle distances between the (shiftX, shiftY) columns of two tables."""
    return computeShiftErrors(zip(theoretical.shiftX, theoretical.shiftY),
                              zip(experimental.shiftX, experimental.shiftY))


def _plot_error_histograms(ax, wt_errors, temet_errors, bins, kind, title, xlabel):
    """Overlay WT and TeMet error histograms on *ax* and annotate each mean error."""
    ax.hist(wt_errors, bins=bins, alpha=0.5, color='skyblue', label='wt %s errors' % kind)
    ax.hist(temet_errors, bins=bins, alpha=0.5, color='red', label='temet %s errors' % kind)
    ax.title.set_text(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Number of Errors")
    # Errors are non-negative distances, so their mean is the mean absolute error.
    ax.text(0.1, 0.7,
            "TeMet MAE: %.2f \nWT MAE: %.2f" % (mean(temet_errors), mean(wt_errors)),
            transform=ax.transAxes)
    ax.legend()
def read_theoretical(path, skiprows=21):
    """Read a theoretical (ground-truth) STAR file into a DataFrame.

    We only want (1-indexed) columns 2 (psi), 3 (phi), 4 (theta),
    12 (originX) and 13 (originY); they are renamed psi, phi, theta, shiftX,
    shiftY and converted to float. A ``quaternion`` column is then added by
    passing the angles (degrees, converted to radians) to ``euler2quat`` as
    (phi, theta, psi).

    Args:
        path: path to the STAR file.
        skiprows: number of header lines to skip before the data rows.
            Defaults to 21. BEWARE: STAR header lengths may vary, so check
            this for each input.

    Returns:
        DataFrame with columns psi, phi, theta, shiftX, shiftY, quaternion.
    """
    # sep=r'\s+' is pandas' documented replacement for the deprecated
    # delim_whitespace=True and is still handled by the C parser.
    theoretical = pd.read_csv(path, sep=r'\s+', header=None, skiprows=skiprows, low_memory=False)
    theoretical = theoretical[theoretical.columns[[1, 2, 3, 11, 12]]]
    theoretical.columns = ['psi', 'phi', 'theta', 'shiftX', 'shiftY']
    theoretical = theoretical.astype(float)
    theoretical['quaternion'] = theoretical.apply(
        lambda row: euler2quat(row.phi*np.pi/180, row.theta*np.pi/180, row.psi*np.pi/180),
        axis=1)
    return theoretical
def read_experimental(path, footer_rows=2):
    """Read an experimental parfile into a DataFrame.

    The first line is used as the header (pandas default). 0-indexed columns
    1, 3, 2, 4, 5 are kept and renamed psi, phi, theta, shiftX, shiftY
    (presumably a FREALIGN-style layout PSI THETA PHI SHX SHY -- confirm
    against the producing program). The last *footer_rows* rows are dropped,
    the rest converted to float, and a ``quaternion`` column is added by
    passing the angles (degrees, converted to radians) to ``euler2quat`` as
    (phi, theta, psi).

    Args:
        path: path to the parfile.
        footer_rows: number of trailing non-particle rows to discard.
            BEWARE: defaults to 2; verify for each input.

    Returns:
        DataFrame with columns psi, phi, theta, shiftX, shiftY, quaternion.
    """
    # sep=r'\s+' replaces the deprecated delim_whitespace=True.
    experimental = pd.read_csv(path, sep=r'\s+', low_memory=False)
    experimental = experimental[experimental.columns[[1, 3, 2, 4, 5]]]
    experimental.columns = ['psi', 'phi', 'theta', 'shiftX', 'shiftY']
    # Drop the footer before astype(float); the trailing rows presumably hold
    # non-numeric text. A non-inplace drop avoids mutating a column-subset copy.
    experimental = experimental.drop(experimental.tail(footer_rows).index)
    experimental = experimental.astype(float)
    experimental['quaternion'] = experimental.apply(
        lambda row: euler2quat(row.phi*np.pi/180, row.theta*np.pi/180, row.psi*np.pi/180),
        axis=1)
    return experimental
def euler2quat(alpha, beta, gamma):
    """Return the 4-component quaternion array for Euler angles in radians.

    With half angles a, b, g = alpha/2, beta/2, gamma/2 the components are,
    in order: cos(a+g)cos(b), sin(g-a)sin(b), cos(g-a)sin(b), sin(a+g)cos(b).
    Inputs may be scalars or NumPy arrays (evaluated element-wise).
    """
    half_alpha = alpha / 2
    half_beta = beta / 2
    half_gamma = gamma / 2
    sum_half = half_alpha + half_gamma
    diff_half = half_gamma - half_alpha
    cos_b = np.cos(half_beta)
    sin_b = np.sin(half_beta)
    components = (
        np.cos(sum_half) * cos_b,
        np.sin(diff_half) * sin_b,
        np.cos(diff_half) * sin_b,
        np.sin(sum_half) * cos_b,
    )
    return np.array(components)
# Quaternion to Euler Angles (from https://github.com/asarnow/pyem geom.py)
def quat2euler(q):
    """Convert a 4-component quaternion back to Euler angles in radians.

    Inverse of ``euler2quat`` within the principal ranges of arctan2/arccos.

    Args:
        q: indexable with four components, in the same order euler2quat emits.

    Returns:
        Tuple ``(alpha, beta, gamma)`` in radians.
    """
    ha1 = np.arctan2(q[1], q[2])
    ha2 = np.arctan2(q[3], q[0])
    alpha = ha2 - ha1  # np.arctan2(r21/r20)
    # Clamp: floating-point rounding can push this norm slightly above 1,
    # where arccos would return nan instead of 0.
    beta = 2 * np.arccos(np.clip(np.sqrt(q[0]**2 + q[3]**2), 0.0, 1.0))  # np.arccos*r33
    gamma = ha1 + ha2  # np.arctan2(r12/-r02)
    return alpha, beta, gamma
# Quaternion helpers: inverse, conjugate, and angular distance between two quaternions
def quatInverse(q):
    """Return the multiplicative inverse of quaternion *q* as a list.

    The inverse is the conjugate (first component kept, the other three
    negated) divided by the squared norm, so it is correct for non-unit
    quaternions too; for unit quaternions it matches the conjugate up to
    rounding.

    Raises:
        ValueError: if *q* is the zero quaternion, which has no inverse.
    """
    # The previous version overwrote the squared norm with 1 (a leftover),
    # which only gave the right answer for unit quaternions.
    norm_sq = q[0]**2 + q[1]**2 + q[2]**2 + q[3]**2
    if norm_sq == 0:
        raise ValueError("zero quaternion has no inverse")
    return [q[0]/norm_sq, -q[1]/norm_sq, -q[2]/norm_sq, -q[3]/norm_sq]
def quatConj(q):
    """Return the conjugate of quaternion *q* as a new list.

    The first component is copied unchanged; components 1-3 are negated.
    """
    return [q[0]] + [-component for component in q[1:4]]
def quatDist(a, b, tol=.001):
    """Return the rotation angle, in degrees, between unit quaternions *a* and *b*.

    The angle is ``2 * arccos(|a . b|)``, which lies in [0, 180]. ``q`` and
    ``-q`` describe the same rotation and therefore give distance 0.

    Args:
        a, b: 4-component quaternions; each norm must be within *tol* of 1.
        tol: allowed deviation of each norm from 1 (default .001).

    Returns:
        Angular distance in degrees.

    Raises:
        ValueError: if either input is not a unit quaternion within *tol*.
    """
    # Check to verify that quaternions are unit lengths
    if not _is_unit_quaternion(a, tol):
        raise ValueError("a is not a unit quaternion")
    if not _is_unit_quaternion(b, tol):
        raise ValueError("b is not a unit quaternion")
    s = a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3]
    # Clamp rather than round: round(s, 4) collapsed every error below
    # ~1.15 degrees to 0. Rounding can still push |s| slightly past 1.
    # 2*arccos(|s|) equals arccos(2*s**2 - 1) but is better conditioned
    # for small angles.
    cos_half = min(abs(s), 1.0)
    return 2 * np.arccos(cos_half) * 180 / np.pi


def _is_unit_quaternion(q, tol):
    """Return True when the Euclidean norm of *q* is within *tol* of 1."""
    return abs(math.sqrt(q[0]**2 + q[1]**2 + q[2]**2 + q[3]**2) - 1) < tol
# Given two ordered sequences of quaternions, compute the angular distance (degrees) between each corresponding pair
def computeAngleErrors(A, B):
    """Return ``quatDist(a, b)`` for every corresponding pair of *A* and *B*.

    Pairing is positional (``zip``), so surplus items in the longer input
    are ignored.

    Returns:
        list of angular distances, in input order.
    """
    return [quatDist(q_first, q_second) for q_first, q_second in zip(A, B)]
# A and B are (x,y) tuples
def euclideanDistance(A, B):
    """Return the Euclidean distance between 2-D points *A* and *B*.

    Args:
        A, B: ``(x, y)`` pairs; only the first two items of each are used.
    """
    # math.hypot avoids the overflow/underflow of squaring the deltas by hand.
    return math.hypot(A[0] - B[0], A[1] - B[1])
# theoretical and experimental are arrays of (x,y) tuples
def computeShiftErrors(theoretical, experimental):
    """Return the Euclidean distance between each corresponding pair of points.

    Args:
        theoretical, experimental: iterables of ``(x, y)`` pairs, matched in
            order; surplus items in the longer iterable are ignored.

    Returns:
        list of distances, in input order.
    """
    return list(map(euclideanDistance, theoretical, experimental))
def mean(array):
    """Return the arithmetic mean of the scalar numbers in *array*.

    Accepts any iterable (including generators). Uses ``math.fsum`` so the
    sum of many small errors is correctly rounded.

    Raises:
        ValueError: if *array* is empty.
    """
    values = list(array)
    if not values:
        raise ValueError("mean() of an empty sequence")
    return math.fsum(values) / len(values)
if __name__ == "__main__":
    # --inputname is mandatory (every input file name is built from it);
    # --inputpath defaults to '' (the current directory) instead of None,
    # which previously crashed main() with a TypeError.
    parser = argparse.ArgumentParser()
    parser.add_argument("--inputpath", help="path to input files", type=str, default='')
    parser.add_argument("--inputname", help="name for format wt_<name>.star etc.", type=str,
                        required=True)
    sys.exit(main(parser.parse_args()))