forked from Noahs-ARK/soft_patterns
-
Notifications
You must be signed in to change notification settings - Fork 1
/
soft_patterns_test.py
executable file
·134 lines (111 loc) · 4.71 KB
/
soft_patterns_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#!/usr/bin/env python3
"""
Script to evaluate the accuracy of a model.
"""
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from collections import OrderedDict
from soft_patterns import ProbSemiring, MaxPlusSemiring, LogSpaceMaxTimesSemiring, MaxTimesSemiring, BooleanSemiring, ViterbiSemiring, MinPlusSemiring, \
evaluate_accuracy, SoftPatternClassifier, soft_pattern_arg_parser, general_arg_parser
from baselines.cnn import PooledCnnClassifier, max_pool_seq, cnn_arg_parser
from baselines.dan import DanClassifier
from baselines.lstm import AveragingRnnClassifier
import sys
import torch
import numpy as np
from torch.nn import LSTM
from data import vocab_from_text, read_embeddings, read_docs, read_labels
from rnn import Rnn
# Integer index constants with values 0, 1 and 2 respectively. Nothing in this
# script references them; presumably they mirror identically named constants in
# soft_patterns.py (positions in a (score, start, end) record?) — confirm before
# removing or relying on them.
SCORE_IDX, START_IDX_IDX, END_IDX_IDX = range(3)
# TODO: refactor duplicate code with soft_patterns.py
def _parse_pattern_specs(patterns):
    """Parse a ``--patterns`` string such as ``"5-50_4-50"`` into an OrderedDict.

    The string is split on ``"_"`` into items and each item on ``"-"`` into two
    integers ``a-b``; the result maps ``a -> b`` with keys in ascending order.
    Presumably ``a`` is a pattern length and ``b`` the number of patterns of that
    length (TODO confirm against soft_patterns.py).

    Raises ValueError if an item is not exactly two integers.
    """
    pairs = ([int(y) for y in x.split("-")] for x in patterns.split("_"))
    return OrderedDict(sorted(pairs, key=lambda t: t[0]))


def _build_model(args, pattern_specs, num_classes, embeddings, vocab, word_dim):
    """Construct the (untrained) classifier selected by the command-line flags.

    Flags are checked in the order --dan, --bilstm, --cnn; when none is set a
    SoftPatternClassifier is built from ``pattern_specs`` and ``args.semiring``.
    ``pattern_specs``, ``vocab`` and ``word_dim`` are only used on that last path.

    Raises KeyError if ``args.semiring`` is not one of the semiring names below.
    """
    mlp_hidden_dim = args.mlp_hidden_dim
    num_mlp_layers = args.num_mlp_layers

    if args.dan:
        return DanClassifier(mlp_hidden_dim,
                             num_mlp_layers,
                             num_classes,
                             embeddings,
                             args.gpu)
    if args.bilstm:
        return AveragingRnnClassifier(args.hidden_dim,
                                      mlp_hidden_dim,
                                      num_mlp_layers,
                                      num_classes,
                                      embeddings,
                                      cell_type=LSTM,
                                      gpu=args.gpu)
    if args.cnn:
        return PooledCnnClassifier(args.window_size,
                                   args.num_cnn_layers,
                                   args.cnn_hidden_dim,
                                   num_mlp_layers,
                                   mlp_hidden_dim,
                                   num_classes,
                                   embeddings,
                                   pooling=max_pool_seq,
                                   gpu=args.gpu)

    semiring = {
        'ProbSemiring': ProbSemiring,
        'MaxPlusSemiring': MaxPlusSemiring,
        'LogSpaceMaxTimesSemiring': LogSpaceMaxTimesSemiring,
        'MaxTimesSemiring': MaxTimesSemiring,
        'BooleanSemiring': BooleanSemiring,
        'ViterbiSemiring': ViterbiSemiring,
        'MinPlusSemiring': MinPlusSemiring
    }[args.semiring]
    # Optional LSTM encoder handed to the pattern model when --use_rnn is given.
    rnn = Rnn(word_dim, args.hidden_dim, cell_type=LSTM, gpu=args.gpu) if args.use_rnn else None
    # NOTE(review): the positional ``None`` after ``rnn`` fills a constructor slot
    # whose meaning is defined in soft_patterns.py — confirm before reordering.
    return SoftPatternClassifier(pattern_specs, mlp_hidden_dim, num_mlp_layers, num_classes, embeddings, vocab,
                                 semiring, args.bias_scale_param, args.gpu, rnn, None, args.no_sl, args.shared_sl,
                                 args.no_eps, args.eps_scale, args.self_loop_scale)


def _load_weights(model, path, gpu):
    """Load the state_dict saved at ``path`` into ``model`` (in place)."""
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    if gpu:
        state_dict = torch.load(path)
    else:
        # Remap every stored tensor onto the CPU so GPU-saved checkpoints load
        # on machines without CUDA.
        state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)


def main(args):
    """Evaluate a saved classifier on a labelled dataset and print its accuracy.

    Documents are read from ``args.vd`` and labels from ``args.vl``. The model
    type is chosen by --dan / --bilstm / --cnn (soft patterns otherwise), its
    weights are loaded from ``args.input_model``, and the accuracy is printed as
    a percentage. If ``args.num_train_instances`` is set, only that many leading
    (document, label) pairs are evaluated.

    Returns 0, which the ``__main__`` guard passes to ``sys.exit``.
    """
    print(args)
    n = args.num_train_instances

    dev_vocab = vocab_from_text(args.vd)
    print("Dev vocab size:", len(dev_vocab))
    vocab, embeddings, word_dim = read_embeddings(args.embedding_file, dev_vocab)

    if args.seed != -1:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    # Each model type needs a different amount of padding around a document.
    pattern_specs = None
    if args.dan or args.bilstm:
        num_padding_tokens = 1
    elif args.cnn:
        num_padding_tokens = args.window_size - 1
    else:
        pattern_specs = _parse_pattern_specs(args.patterns)
        num_padding_tokens = max(pattern_specs) - 1

    dev_input, _ = read_docs(args.vd, vocab, num_padding_tokens=num_padding_tokens)
    dev_labels = read_labels(args.vl)
    dev_data = list(zip(dev_input, dev_labels))
    if n is not None:
        dev_data = dev_data[:n]

    # Counted over ALL labels (before truncation). NOTE(review): presumably this
    # must equal the class count the checkpoint was trained with, otherwise
    # load_state_dict will likely fail on a shape mismatch.
    num_classes = len(set(dev_labels))
    print("num_classes:", num_classes)

    model = _build_model(args, pattern_specs, num_classes, embeddings, vocab, word_dim)
    _load_weights(model, args.input_model, args.gpu)
    if args.gpu:
        # ``to_cuda`` is an attribute provided by the model classes (defined
        # elsewhere); presumably it moves the model onto the GPU.
        model.to_cuda(model)
    # Switch off training-only behaviour (e.g. dropout) so the reported accuracy
    # is deterministic and reflects inference-time behaviour.
    model.eval()

    test_acc = evaluate_accuracy(model, dev_data, args.batch_size, args.gpu)
    print("Test accuracy: {:>8,.3f}%".format(100 * test_acc))
    return 0
if __name__ == '__main__':
    # Reuse the shared soft-pattern, CNN and general option sets, then add the
    # model-selection switches that only this evaluation script understands.
    parent_parsers = [soft_pattern_arg_parser(), cnn_arg_parser(), general_arg_parser()]
    cli = ArgumentParser(description=__doc__,
                         formatter_class=ArgumentDefaultsHelpFormatter,
                         parents=parent_parsers)
    model_flags = (
        ("--dan", "Dan classifier"),
        ("--cnn", "CNN classifier"),
        ("--bilstm", "BiLSTM classifier"),
    )
    for flag, flag_help in model_flags:
        cli.add_argument(flag, help=flag_help, action='store_true')
    sys.exit(main(cli.parse_args()))