-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess-samples.py
155 lines (131 loc) · 5.5 KB
/
preprocess-samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import argparse
import base64
import csv
import os
import random
from typing import Literal
import dominate
# pyright: reportWildcardImportFromLibrary=false
from dominate.tags import *
from dominate.util import raw
import yaml
from coherence.entities.coreferences import coref_annotation, coref_diagram
from main import get_prediction
from nlp.helpers import normalize_quotes
from stylometry.logistic_regression import predict_author
from tem.model import TopicEvolution
from tem.process import get_default_te
from train_tem_metrics import predict_from_tem_metrics
def html_results(
    text: str,
    author: Literal[0, 1, -1],
    te: TopicEvolution,
    entity_diagram: tuple[div, div],
    title: str = 'unBlock Analysis',
) -> str:
    """Render the analysis of one text as a standalone HTML page.

    The page has two columns: the left one holds the verdict, the Topic
    Evolution graph (embedded as a base64 PNG) and the entity occurrence
    chart with its legend; the right one holds the full text, one ``<p>``
    per line of ``text``.

    Args:
        text: The analysed text; split on ``'\\n'`` into paragraphs.
        author: Verdict code used to index the headline list directly:
            ``0`` -> not sure, ``1`` -> machine, ``-1`` -> human (the last
            entry, via negative indexing).
        te: Topic Evolution model; ``te.graph().pipe(format='png')`` must
            return the PNG image as bytes.
        entity_diagram: ``(chart, legend)`` pair of dominate elements.
        title: Used for both the document ``<title>`` and the ``<h1>``.

    Returns:
        The rendered HTML document as a string.
    """
    # b64encode, unlike encodebytes, emits no line breaks, so the payload can
    # be placed in a `data:` URI without embedded newlines.
    te_png = te.graph().pipe(format='png')
    te_img_data = base64.b64encode(te_png).decode('ascii')

    doc = dominate.document(title=title)

    with doc.head:
        # raw prevents escaping of `>` character
        # CSS `#container > * + *` is like tailwind's `space-y` class
        style(raw('''\
body {
padding: 0 4rem;
color: #333;
}
h2 {
border-top: 1px solid lightgray;
padding-top: 1rem
}
#container { display: flex; }
#container > * + * { margin-left: 4rem; }
#left { flex-basis: 65%; }
#right {
color: #666;
flex-basis: 35%;
}
.col > * + * { margin-top: 2rem; }
'''))

    with doc:
        h1(title)
        container = div(id='container')

        # Left column: verdict, topic evolution graph, entity chart.
        left = container.add(div(id='left', className='col'))
        left.add(h2([
            'We are not sure whether the text was written by a human or generated by a machine.',
            'The text was likely generated by a machine.',
            'The text was likely written by a human.',
        ][author]))
        left.add(p('This prediction is based on an analysis of the Topic Evolution graph below & stylometry markers in the text'))
        left.add(h2('Topic Evolution graph'))
        left.add(img(
            src=f'data:image/png;base64,{te_img_data}',
            style='width: 100%',
        ))
        left.add(h2('Entity occurrence chart'))
        left.add(entity_diagram[0])
        left.add(h3('Legend'))
        left.add(entity_diagram[1])

        # Right column: the full text, one paragraph per input line.
        right = container.add(div(id='right', className='col'))
        right.add(h2('Full text'))
        text_container = right.add(div())
        for paragraph in text.split('\n'):
            # dominate doesn't do well with unicode characters so we change the
            # most frequent ones (quotes) into their ascii equivalents
            # NOTE(review): `raw` inserts the paragraph unescaped, so any HTML
            # markup present in `text` ends up verbatim in the page -- confirm
            # the sample sources are trusted.
            text_container.add(p(raw(normalize_quotes(paragraph))))

    return doc.render()
def analyze_samples(databases: list[tuple[str, list[dict[str, str]]]], sets: int, samples: int):
    """Draw random samples, analyse them and write one HTML report per sample.

    Creates ``samples/<set>/<sample>.html`` for ``sets`` sets of ``samples``
    reports each, plus ``samples/.sources.csv`` recording, for every report,
    which database the text came from and what was predicted.

    Samples whose ``'text'`` is empty, or whose analysis raises, are skipped
    and another sample is drawn in their place.

    Args:
        databases: ``(name, entries)`` pairs; every entry must have a
            ``'text'`` key. The lists are consumed in place: drawn entries
            are removed, and exhausted databases are removed from
            ``databases``.
        sets: Number of sample sets (sub-directories) to produce.
        samples: Number of reports per set.

    Raises:
        FileExistsError: If ``samples`` (or a set directory) already exists;
            ``os.makedirs`` is called without ``exist_ok``, so output of an
            earlier run is never overwritten.
        IndexError: If the databases run out of samples.
    """
    directory = os.path.join('samples')
    os.makedirs(directory)

    def draw() -> tuple[str, dict[str, str]]:
        # Pick a database uniformly at random, then an entry uniformly within
        # it. The entry is removed so it can never be drawn twice.
        while databases:
            i = random.randrange(len(databases))
            name, db = databases[i]
            if not db:
                # A database that was handed in empty: discard it and retry.
                databases.pop(i)
                continue
            sample = db.pop(random.randrange(len(db)))
            if not db:
                databases.pop(i)
            return (name, sample)
        raise IndexError('Not enough samples in databases.')

    total = sets * samples
    sources_path = os.path.join(directory, '.sources.csv')
    # newline='' is required by the csv module so that it controls line
    # endings itself; the context manager closes the file even when draw()
    # or a later step raises.
    with open(sources_path, 'w', newline='', encoding='utf-8') as sources_fp:
        sources_writer = csv.writer(sources_fp)
        sources_writer.writerow(['set', 'sample', 'source', 'prediction'])

        for set_id in range(1, sets + 1):
            directory_i = os.path.join(directory, str(set_id))
            os.makedirs(directory_i)

            sampled = 0
            while sampled < samples:
                text_id = sampled + 1
                progress = (set_id - 1) * samples + text_id
                print(f'\r\033[Kanalyzing sample {progress}/{total}', end='')

                source, sample = draw()
                text = sample['text']
                if not text:
                    continue

                try:
                    style_prediction = predict_author(text)
                    te = get_default_te(text)
                    te_prediction = predict_from_tem_metrics(te)
                    author = get_prediction(style_prediction, te_prediction)
                    entity_diagram = coref_diagram(coref_annotation(text))
                except Exception:
                    # Best effort: a sample that cannot be analysed is skipped
                    # and replaced by another draw. `Exception` (rather than a
                    # bare except) keeps Ctrl+C / SystemExit working.
                    continue

                sources_writer.writerow([set_id, text_id, source, ['Not sure', 'Machine', 'Human'][author]])
                # Explicit UTF-8 so texts with characters outside the locale
                # encoding do not abort the run.
                with open(os.path.join(directory_i, f'{text_id}.html'), 'w', encoding='utf-8') as fp:
                    fp.write(html_results(text, author, te, entity_diagram, title=f'unBlock Analysis for text {text_id}'))
                sampled += 1

    print('\ndone!')
if __name__ == '__main__':
    # Command-line entry point: load every YAML database given on the command
    # line, then hand them to analyze_samples together with the requested
    # number of sets and samples per set.
    parser = argparse.ArgumentParser(prog='Analysis preprocessor')
    parser.add_argument('files', type=str, nargs='+', help='input yaml database files to uniformly draw samples from')
    parser.add_argument('--samples', '-m', type=int, required=True, help='number of samples per set')
    parser.add_argument('--sets', '-n', type=int, required=True, help='number of sample sets')
    args = parser.parse_args()

    databases: list[tuple[str, list[dict[str, str]]]] = []
    for i, db in enumerate(args.files):
        # Database name = file name up to its first dot. basename handles the
        # platform's path separator instead of assuming '/'.
        name = os.path.basename(db).split('.')[0]
        print(f'\r\033[Kreading yaml database "{name}" ({i + 1}/{len(args.files)})', end='')
        # Explicit UTF-8 rather than the locale default; safe_load reads
        # directly from the open stream.
        with open(db, 'r', encoding='utf-8') as fp:
            data = yaml.safe_load(fp)
        if not data:
            # An empty YAML document loads as None (or an empty list); there
            # is nothing to draw from, so leave it out.
            print(f'\nskipping empty database "{name}"')
            continue
        databases.append((name, data))
    print('\ndone reading')

    analyze_samples(databases, args.sets, args.samples)