forked from SeuTao/Humpback-Whale-Identification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble.py
executable file
·123 lines (94 loc) · 3.94 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from collections import defaultdict, Counter
from process.data_helper import *
# Scores are scaled by 10 ** (SIGFIGS - 1) before being accumulated in read_models.
SIGFIGS = 6
# load_CLASS_NAME is not defined in this file; presumably it is supplied by the
# star import from process.data_helper -- TODO confirm.
# Usage in this file: id_name_label is indexed by integer label to obtain a
# class name, dict_label is indexed by class name to obtain a label.
id_name_label, dict_label = load_CLASS_NAME()
# Register the extra 'new_whale' class under label -1 in both lookup tables.
dict_label['new_whale'] = -1
id_name_label[-1] = 'new_whale'
def read_models(model_pred, thres, blend=None):
    """Accumulate weighted per-image class scores from prediction files.

    Every file found in each directory of ``model_pred`` is read as text:
    the first line is skipped as a header and each following row is parsed
    as ``<image>,<label> <prob> <label> <prob> ...``.  Each probability is
    multiplied by ``10 ** (SIGFIGS - 1)`` and by the directory's weight and
    added to ``blend[image][dict_label[label]]``.  In addition, every row
    adds the weighted, scaled ``thres`` to the image's ``new_whale`` entry.

    Args:
        model_pred: mapping of prediction directory path -> blending weight.
        thres: score credited to ``new_whale`` for every parsed row
            (converted with ``float``).
        blend: optional accumulator to extend.  Any falsy value (``None``
            or an empty mapping) is replaced by a new
            ``defaultdict(Counter)``.

    Returns:
        The accumulator: image id -> Counter of class label -> score.

    Raises:
        KeyError: a label in a file is missing from ``dict_label``.
        ValueError: a non-blank row does not contain exactly one comma, or
            a probability token is not a valid float.
    """
    # NOTE(review): `os` is not imported explicitly in this file; presumably
    # it arrives through `from process.data_helper import *` -- confirm.
    if not blend:
        blend = defaultdict(Counter)

    # Loop invariants: the scale factor and the new_whale label never change.
    scale = 10 ** (SIGFIGS - 1)
    new_whale_label = dict_label['new_whale']

    file_count = 0
    for model_dir, weight in model_pred.items():
        file_names = os.listdir(model_dir)
        print(model_dir)
        print(len(file_names))
        # Same for every row of this directory, so compute it once.
        new_whale_score = weight * (scale * float(thres))
        for file_name in file_names:
            file_count += 1
            with open(os.path.join(model_dir, file_name), 'r') as f:
                f.readline()  # skip the header row
                for line in f:
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline at
                        # end of file) instead of failing on the unpacking.
                        continue
                    image_id, predictions = line.split(',')
                    tokens = predictions.split(' ')
                    scores = blend[image_id]
                    # Tokens alternate label, probability; zip drops a
                    # dangling unpaired token at the end of the row.
                    for label_name, prob in zip(tokens[0::2], tokens[1::2]):
                        label = dict_label[label_name]
                        scores[label] += weight * (scale * float(prob))
                    # add new_whale
                    scores[new_whale_label] += new_whale_score
    print(file_count)
    return blend
def clalibrate_distribution(blend, num_classes=5005):
    """Report which class labels never appear as a top-1 prediction.

    For every image in ``blend`` the highest-scoring label is taken.  Labels
    in ``range(num_classes)`` that are nobody's top-1 are collected as
    "missing".  Three summary lines are printed: how many labels in
    ``range(num_classes)`` are some image's top-1, how many are missing, and
    how many images have label ``-1`` (``new_whale``) as their top-1.  The
    ``-1`` label is not counted as an id, so the first two numbers always
    sum to ``num_classes`` and match the returned mapping.

    Args:
        blend: mapping of image id -> Counter of class label -> score.
        num_classes: number of regular class labels (``0 .. num_classes-1``).

    Returns:
        ``(blend, missing_ids)`` where ``blend`` is the unmodified input
        object and ``missing_ids`` maps each missing label to ``0``.
    """
    top1_labels = set()
    new_whale_count = 0
    for scores in blend.values():
        # most_common(1) is empty for an empty Counter, so such images are
        # simply skipped.
        for label, _score in scores.most_common(1):
            if label == -1:
                new_whale_count += 1
            top1_labels.add(label)

    missing_ids = {
        label: 0 for label in range(num_classes) if label not in top1_labels
    }

    print('id num:' + str(num_classes - len(missing_ids)))
    print('missing id num:' + str(len(missing_ids)))
    print('new whale num:' + str(new_whale_count))
    return blend, missing_ids
def write_models(blend, file_name, is_top1=False):
    """Write the blended predictions to ``<file_name>.csv``.

    The file starts with the header ``Image,Id`` followed by one row per
    image: the image id, a comma, and the space-separated class names of
    its highest-scoring labels (looked up in ``id_name_label``).  With
    ``is_top1`` only the single best class is written, padded with four
    literal ``None`` entries; otherwise up to five classes are written.
    The number of rows whose text starts with ``new_whale`` is printed.

    Args:
        blend: mapping of image id -> Counter of class label -> score.
        file_name: output path without the ``.csv`` extension.
        is_top1: write only the best class (plus ``None`` padding) per row.

    Returns:
        The path of the written file, ``file_name + '.csv'``.
    """
    csv_path = file_name + '.csv'
    top_k = 1 if is_top1 else 5
    new_whale_rows = 0
    with open(csv_path, 'w') as f:
        f.write('Image,Id\n')
        for image_id, scores in blend.items():
            names = [
                '{}'.format(id_name_label[int(label)])
                for label, _score in scores.most_common(top_k)
            ]
            output = ' '.join(names)
            if is_top1:
                output += ' None None None None'
                # A debug print of an undefined name used to sit here and
                # made every is_top1=True call fail with NameError.
            f.write(','.join([str(image_id), output + '\n']))
            if output.startswith('new_whale'):
                new_whale_rows += 1
    print('new whale num: ' + str(new_whale_rows))
    return csv_path
if __name__ == '__main__':
    # Prediction directory -> blending weight.  Directories whose name
    # contains "pseudo" are weighted 20, the others 10.  Insertion order is
    # the order in which read_models visits the directories.
    model_pred = {
        r'./models/resnet101_fold0_256_512/checkpoint/max_valid_model': 10,
        r'./models/seresnet101_fold0_256_512/checkpoint/max_valid_model': 10,
        r'./models/seresnext101_fold0_256_512/checkpoint/max_valid_model': 10,
        r'./models/resnet101_fold0_512_512/checkpoint/max_valid_model': 10,
        r'./models/seresnet101_fold0_512_512/checkpoint/max_valid_model': 10,
        r'./models/resnet101_fold0_pseudo_256_512/checkpoint/max_valid_model': 20,
        r'./models/seresnet101_fold0_pseudo_256_512/checkpoint/max_valid_model': 20,
        r'./models/resnet101_fold0_pseudo_512_512/checkpoint/max_valid_model': 20,
        r'./models/seresnet101_fold0_pseudo_512_512/checkpoint/max_valid_model': 20,
        r'./models/seresnext101_fold0_pseudo_512_512/checkpoint/max_valid_model': 20,
    }

    # Score credited to 'new_whale' for every prediction row.
    thres = 0.185
    # Output file name without the '.csv' suffix (write_models appends it).
    output_stem = 'final_submission_id_4081_NW_2123'

    # Pipeline: blend all model outputs, print the top-1 label coverage,
    # then write the top-5 submission file.
    blended = read_models(model_pred, thres)
    blended, missing_ids = clalibrate_distribution(blend=blended)
    csv_name = write_models(blended, output_stem, is_top1=False)