evaluate.py
from category_id_map import category_id_to_lv2id
from util import evaluate


def evaluate_submission(result_file, ground_truth_file):
    # Load the ground truth, mapping each video id to its level-2 category id.
    ground_truth = {}
    with open(ground_truth_file, 'r') as f:
        for line in f:
            vid, category_id = line.strip().split(',')
            ground_truth[vid] = category_id_to_lv2id(category_id)

    # Collect predictions and align them with the corresponding ground-truth labels.
    predictions, labels = [], []
    with open(result_file, 'r') as f:
        for line in f:
            vid, category_id = line.strip().split(',')
            if vid not in ground_truth:
                raise Exception(f'ERROR: unknown id {vid} in result.csv')
            predictions.append(category_id_to_lv2id(category_id))
            labels.append(ground_truth[vid])

    # Every ground-truth id must appear exactly once in the result file.
    if len(predictions) != len(ground_truth):
        raise Exception('ERROR: wrong number of lines in result.csv')

    return evaluate(predictions, labels)


if __name__ == '__main__':
    result_file = 'data/result.csv'
    ground_truth_file = 'data/private/ground_truth_test_a.csv'

    result = evaluate_submission(result_file, ground_truth_file)
    print(f'mean F1 score is {result["mean_f1"]}')
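For reference, evaluate is imported from util and is not defined in this file. Below is a minimal sketch of what such a helper could look like, assuming it returns a macro-averaged F1 over the lv2 ids under the 'mean_f1' key (the only key the main block reads). The extra 'accuracy' key and the exact metric choices here are illustrative assumptions, not the repository's actual implementation.

from sklearn.metrics import accuracy_score, f1_score

def evaluate(predictions, labels):
    # Assumed behaviour: macro F1 across level-2 categories plus overall accuracy.
    # The real util.evaluate may compute additional or differently named metrics.
    return {
        'mean_f1': f1_score(labels, predictions, average='macro'),
        'accuracy': accuracy_score(labels, predictions),
    }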