-
Notifications
You must be signed in to change notification settings - Fork 10
/
conll.py
101 lines (85 loc) · 4.11 KB
/
conll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import re
import tempfile
import subprocess
import operator
import collections
import logging
logger = logging.getLogger(__name__)
BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)") # First line at each document
COREF_RESULTS_REGEX = re.compile(r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tF1: ([0-9.]+)%.*", re.DOTALL)
def get_doc_key(doc_id, part):
return "{}_{}".format(doc_id, int(part))
def output_conll(input_file, output_file, predictions, subtoken_map):
prediction_map = {}
for doc_key, clusters in predictions.items():
start_map = collections.defaultdict(list)
end_map = collections.defaultdict(list)
word_map = collections.defaultdict(list)
for cluster_id, mentions in enumerate(clusters):
for start, end in mentions:
start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end]
if start == end:
word_map[start].append(cluster_id)
else:
start_map[start].append((cluster_id, end))
end_map[end].append((cluster_id, start))
for k,v in start_map.items():
start_map[k] = [cluster_id for cluster_id, end in sorted(v, key=operator.itemgetter(1), reverse=True)]
for k,v in end_map.items():
end_map[k] = [cluster_id for cluster_id, start in sorted(v, key=operator.itemgetter(1), reverse=True)]
prediction_map[doc_key] = (start_map, end_map, word_map)
word_index = 0
for line in input_file.readlines():
row = line.split()
if len(row) == 0:
output_file.write("\n")
elif row[0].startswith("#"):
begin_match = re.match(BEGIN_DOCUMENT_REGEX, line)
if begin_match:
doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
start_map, end_map, word_map = prediction_map[doc_key]
word_index = 0
output_file.write(line)
output_file.write("\n")
else:
assert get_doc_key(row[0], row[1]) == doc_key
coref_list = []
if word_index in end_map:
for cluster_id in end_map[word_index]:
coref_list.append("{})".format(cluster_id))
if word_index in word_map:
for cluster_id in word_map[word_index]:
coref_list.append("({})".format(cluster_id))
if word_index in start_map:
for cluster_id in start_map[word_index]:
coref_list.append("({}".format(cluster_id))
if len(coref_list) == 0:
row[-1] = "-"
else:
row[-1] = "|".join(coref_list)
output_file.write(" ".join(row))
output_file.write("\n")
word_index += 1
def official_conll_eval(gold_path, predicted_path, metric, official_stdout=True):
cmd = ["conll-2012/scorer/v8.01/scorer.pl", metric, gold_path, predicted_path, "none"]
process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
stdout, stderr = process.communicate()
process.wait()
stdout = stdout.decode("utf-8")
if stderr is not None:
logger.error(stderr)
if official_stdout:
logger.info("Official result for {}".format(metric))
logger.info(stdout)
coref_results_match = re.match(COREF_RESULTS_REGEX, stdout)
recall = float(coref_results_match.group(1))
precision = float(coref_results_match.group(2))
f1 = float(coref_results_match.group(3))
return {"r": recall, "p": precision, "f": f1}
def evaluate_conll(gold_path, predictions, subtoken_maps, official_stdout=True):
with tempfile.NamedTemporaryFile(delete=True, mode="w") as prediction_file:
with open(gold_path, "r") as gold_file:
output_conll(gold_file, prediction_file, predictions, subtoken_maps)
# logger.info("Predicted conll file: {}".format(prediction_file.name))
results = {m: official_conll_eval(gold_file.name, prediction_file.name, m, official_stdout) for m in ("muc", "bcub", "ceafe") }
return results