-
Notifications
You must be signed in to change notification settings - Fork 16
/
run.py
542 lines (444 loc) · 23.8 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
# -*- coding: utf-8 -*-
# @Time : 2021/1/25 6:37 PM
# @Author : liuxiyang
from helper import *
from data_loader import *
# sys.path.append('./')
from model.models import *
import traceback
class Runner(object):
    """Driver that loads a knowledge-graph dataset, builds the model and runs training/evaluation.

    All hyper-parameters live on ``self.p``; ``fit`` is the entry point once the
    runner has been constructed.
    """

    def __init__(self, params):
        """
        Constructor of the runner class

        Parameters
        ----------
        params: Namespace of hyper-parameters (accessed through attributes, e.g. ``params.lr``)

        Returns
        -------
        Sets up the logger and device, loads the data and creates the model and optimizer
        """
        self.p = params
        self.logger = get_logger(self.p.name, self.p.log_dir, self.p.config_dir)

        self.logger.info(vars(self.p))
        pprint(vars(self.p))

        if self.p.gpu != '-1' and torch.cuda.is_available():
            self.device = torch.device('cuda')
            # Re-applies the current CUDA RNG state (the state itself is unchanged).
            torch.cuda.set_rng_state(torch.cuda.get_rng_state())
            torch.backends.cudnn.deterministic = True
        else:
            self.device = torch.device('cpu')

        self.load_data()
        self.model = self.add_model(self.p.model, self.p.score_func)
        self.optimizer = self.add_optimizer(self.model.parameters())

    def _read_split(self, split):
        """Return the lower-cased (sub, rel, obj) string triples stored in one split file."""
        path = './data/{}/{}.txt'.format(self.p.dataset, split)
        triples = []
        # The context manager guarantees the file handle is closed after reading.
        with open(path) as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    # A blank line carries no triple; skipping it avoids an unpack error.
                    continue
                sub, rel, obj = map(str.lower, line.split('\t'))
                triples.append((sub, rel, obj))
        return triples

    def load_data(self):
        """
        Reading in raw triples and converts it into a standard format.

        Parameters
        ----------
        self.p.dataset: Takes in the name of the dataset (FB15k-237)

        Returns
        -------
        self.ent2id: Entity to unique identifier mapping
        self.id2ent: Inverse mapping of self.ent2id
        self.rel2id: Relation to unique identifier mapping (includes the '_reverse' relations)
        self.id2rel: Inverse mapping of self.rel2id
        self.p.num_ent: Number of entities in the Knowledge graph
        self.p.num_rel: Number of relations in the Knowledge graph (reverse relations not counted)
        self.p.embed_dim: Embedding dimension used (k_w * k_h when not given)
        self.data['train']: Stores the triples corresponding to training dataset
        self.data['valid']: Stores the triples corresponding to validation dataset
        self.data['test']: Stores the triples corresponding to test dataset
        self.sr2o: (sub, rel) -> objects built from train edges and their reverse edges
        self.sr2o_all: Same as self.sr2o but additionally covering valid and test edges
        self.triples: Per-split list of example dicts fed to the datasets
        self.data_iter: The dataloader for different data splits
        self.edge_index, self.edge_type: Graph structure returned by construct_adj
        """
        splits = ['train', 'test', 'valid']
        # Each file is read exactly once; the string triples are reused below.
        raw = {split: self._read_split(split) for split in splits}

        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in splits:
            for sub, rel, obj in raw[split]:
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        # Reverse relations get the ids following the original relations.
        self.rel2id.update({rel + '_reverse': idx + len(self.rel2id) for idx, rel in enumerate(rel_set)})

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

        self.p.num_ent = len(self.ent2id)
        self.p.num_rel = len(self.rel2id) // 2
        self.p.embed_dim = self.p.k_w * self.p.k_h if self.p.embed_dim is None else self.p.embed_dim

        self.data = ddict(list)
        sr2o = ddict(set)

        for split in splits:
            for sub, rel, obj in raw[split]:
                sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
                self.data[split].append((sub, rel, obj))

                if split == 'train':
                    sr2o[(sub, rel)].add(obj)
                    sr2o[(obj, rel + self.p.num_rel)].add(sub)

        # self.data: all origin train + valid + test triplets
        self.data = dict(self.data)
        # self.sr2o: train origin edges and reverse edges
        self.sr2o = {k: list(v) for k, v in sr2o.items()}

        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.p.num_rel)].add(sub)

        # self.sr2o_all: train + valid + test edges and their reverse edges
        self.sr2o_all = {k: list(v) for k, v in sr2o.items()}
        self.triples = ddict(list)

        if self.p.strategy == 'one_to_n':
            # One example per (sub, rel) pair; the object slot is the placeholder -1.
            for (sub, rel), objs in self.sr2o.items():
                self.triples['train'].append({'triple': (sub, rel, -1), 'label': objs, 'sub_samp': 1})
        else:
            for sub, rel, obj in self.data['train']:
                rel_inv = rel + self.p.num_rel
                sub_samp = len(self.sr2o[(sub, rel)]) + len(self.sr2o[(obj, rel_inv)])
                sub_samp = np.sqrt(1 / sub_samp)

                self.triples['train'].append(
                    {'triple': (sub, rel, obj), 'label': self.sr2o[(sub, rel)], 'sub_samp': sub_samp})
                self.triples['train'].append(
                    {'triple': (obj, rel_inv, sub), 'label': self.sr2o[(obj, rel_inv)], 'sub_samp': sub_samp})

        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.p.num_rel
                self.triples['{}_{}'.format(split, 'tail')].append(
                    {'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
                self.triples['{}_{}'.format(split, 'head')].append(
                    {'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})

        self.triples = dict(self.triples)

        def get_data_loader(dataset_class, split, batch_size, shuffle=True):
            return DataLoader(
                dataset_class(self.triples[split], self.p),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.p.num_workers),
                collate_fn=dataset_class.collate_fn
            )

        self.data_iter = {
            'train': get_data_loader(TrainDataset, 'train', self.p.batch_size),
            'valid_head': get_data_loader(TestDataset, 'valid_head', self.p.test_batch_size),
            'valid_tail': get_data_loader(TestDataset, 'valid_tail', self.p.test_batch_size),
            'test_head': get_data_loader(TestDataset, 'test_head', self.p.test_batch_size),
            'test_tail': get_data_loader(TestDataset, 'test_tail', self.p.test_batch_size),
        }

        self.edge_index, self.edge_type = self.construct_adj()

    def construct_adj(self):
        """
        Builds the edge list of the training graph, including inverse edges

        Parameters
        ----------

        Returns
        -------
        edge_index: LongTensor of shape 2 x 2E holding (source, target) of every edge
        edge_type: LongTensor of length 2E; inverse edges use rel + self.p.num_rel
        """
        train = self.data['train']

        # Original edges first, then the inverse edges in the same order.
        pairs = [(sub, obj) for sub, _, obj in train] + [(obj, sub) for sub, _, obj in train]
        rels = [rel for _, rel, _ in train] + [rel + self.p.num_rel for _, rel, _ in train]

        # edge_index: 2 * 2E, edge_type: 2E * 1
        edge_index = torch.LongTensor(pairs).to(self.device).t()
        edge_type = torch.LongTensor(rels).to(self.device)

        return edge_index, edge_type

    def add_model(self, model, score_func):
        """
        Creates the computational graph

        Parameters
        ----------
        model: Name of the encoder (e.g. 'ragat')
        score_func: Name of the score function (e.g. 'conve')

        Returns
        -------
        The model selected by '<model>_<score_func>' (case-insensitive), moved to self.device

        Raises
        ------
        NotImplementedError: If the combination is not one of the supported models
        """
        model_name = '{}_{}'.format(model, score_func).lower()

        if model_name == 'ragat_transe':
            model = RagatTransE(self.edge_index, self.edge_type, params=self.p)
        elif model_name == 'ragat_distmult':
            model = RagatDistMult(self.edge_index, self.edge_type, params=self.p)
        elif model_name == 'ragat_conve':
            model = RagatConvE(self.edge_index, self.edge_type, params=self.p)
        elif model_name == 'ragat_interacte':
            model = RagatInteractE(self.edge_index, self.edge_type, params=self.p)
        else:
            raise NotImplementedError(model_name)

        model.to(self.device)
        return model

    def add_optimizer(self, parameters):
        """
        Creates an optimizer for training the parameters

        Parameters
        ----------
        parameters: The parameters of the model

        Returns
        -------
        Returns an Adam optimizer using self.p.lr and self.p.l2 (as weight decay)
        """
        return torch.optim.Adam(parameters, lr=self.p.lr, weight_decay=self.p.l2)

    def read_batch(self, batch, split):
        """
        Function to read a batch of data and move the tensors in batch to CPU/GPU

        Parameters
        ----------
        batch: the batch to process
        split: (string) If split == 'train', 'valid' or 'test' split

        Returns
        -------
        For split == 'train': Head, Relation, Tails, labels, neg_ent, sub_samp
        (neg_ent and sub_samp are None unless self.p.strategy == 'one_to_x')
        Otherwise: Head, Relation, Tails, labels
        """
        tensors = [item.to(self.device) for item in batch]

        if split == 'train' and self.p.strategy == 'one_to_x':
            triple, label, neg_ent, sub_samp = tensors
            return triple[:, 0], triple[:, 1], triple[:, 2], label, neg_ent, sub_samp

        triple, label = tensors
        if split == 'train':
            return triple[:, 0], triple[:, 1], triple[:, 2], label, None, None
        return triple[:, 0], triple[:, 1], triple[:, 2], label

    def save_model(self, save_path):
        """
        Function to save a model. It saves the model parameters, best validation scores,
        best epoch corresponding to best validation, state of the optimizer and all arguments for the run.

        Parameters
        ----------
        save_path: path where the model is saved

        Returns
        -------
        """
        state = {
            'state_dict': self.model.state_dict(),
            'best_val': self.best_val,
            'best_epoch': self.best_epoch,
            'optimizer': self.optimizer.state_dict(),
            'args': vars(self.p)
        }
        torch.save(state, save_path)

    def load_model(self, load_path):
        """
        Function to load a saved model

        Restores the model weights, the optimizer state and the best validation
        results (self.best_val / self.best_val_mrr) from the checkpoint.

        Parameters
        ----------
        load_path: path to the saved model

        Returns
        -------
        """
        # map_location lets a checkpoint written on GPU be restored on a CPU-only run.
        state = torch.load(load_path, map_location=self.device)
        self.best_val = state['best_val']
        self.best_val_mrr = self.best_val['mrr']

        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])

    def evaluate(self, split, epoch):
        """
        Function to evaluate the model on validation or test set

        Parameters
        ----------
        split: (string) If split == 'valid' then evaluate on the validation set, else the test set
        epoch: (int) Current epoch count

        Returns
        -------
        results: Output of get_combined_results for the tail ('left') and head ('right')
                 predictions. The keys read here are 'mr', 'mrr', 'hits@1', 'hits@3' and
                 'hits@10', each also with a 'left_' and a 'right_' prefix.
        """
        left_results = self.predict(split=split, mode='tail_batch')
        right_results = self.predict(split=split, mode='head_batch')
        results = get_combined_results(left_results, right_results)

        def metric_line(title, key):
            return '\t{}: Tail : {:.5}, Head : {:.5}, Avg : {:.5}'.format(
                title, results['left_' + key], results['right_' + key], results[key])

        res_mrr = '\n' + metric_line('MRR', 'mrr') + '\n'
        details = '\n'.join([
            metric_line('MR', 'mr'),
            metric_line('Hit-1', 'hits@1'),
            metric_line('Hit-3', 'hits@3'),
            metric_line('Hit-10', 'hits@10'),
        ])

        # Full metrics every 10th epoch and for the test split; only MRR otherwise.
        if (epoch + 1) % 10 == 0 or split == 'test':
            log_res = res_mrr + details
        else:
            log_res = res_mrr

        self.logger.info('[Evaluating Epoch {} {}]: {}'.format(epoch, split, log_res))
        return results

    def predict(self, split='valid', mode='tail_batch'):
        """
        Function to run model evaluation for a given mode

        Parameters
        ----------
        split: (string) If split == 'valid' then evaluate on the validation set, else the test set
        mode: (string): Can be 'head_batch' or 'tail_batch'

        Returns
        -------
        results: Running sums over all evaluated triples (not yet averaged):
            results['count']: Number of ranked triples
            results['mr']: Sum of ranks
            results['mrr']: Sum of reciprocal ranks
            results['hits@k']: Number of triples ranked within the top k, for k = 1..10
        """
        self.model.eval()

        with torch.no_grad():
            results = {}
            eval_iter = iter(self.data_iter['{}_{}'.format(split, mode.split('_')[0])])

            for step, batch in enumerate(eval_iter):
                sub, rel, obj, label = self.read_batch(batch, split)
                pred = self.model.forward(sub, rel)
                b_range = torch.arange(pred.size()[0], device=self.device)
                target_pred = pred[b_range, obj]

                # filter setting: every labelled entity is pushed to a very low score,
                # then the score of the target entity itself is restored.
                # torch.where needs a boolean condition; a uint8 mask is deprecated/rejected.
                pred = torch.where(label.bool(), -torch.ones_like(pred) * 10000000, pred)
                pred[b_range, obj] = target_pred

                # Double argsort turns scores into 0-based ranks (descending score).
                ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1, descending=False)[
                    b_range, obj]
                ranks = ranks.float()

                results['count'] = torch.numel(ranks) + results.get('count', 0.0)
                results['mr'] = torch.sum(ranks).item() + results.get('mr', 0.0)
                results['mrr'] = torch.sum(1.0 / ranks).item() + results.get('mrr', 0.0)
                for k in range(10):
                    key = 'hits@{}'.format(k + 1)
                    results[key] = torch.numel(ranks[ranks <= (k + 1)]) + results.get(key, 0.0)

        return results

    def run_epoch(self, epoch, val_mrr=0):
        """
        Function to run one epoch of training

        Parameters
        ----------
        epoch: current epoch count
        val_mrr: accepted for backward compatibility; not used

        Returns
        -------
        loss: The mean of the batch losses over the epoch
        """
        self.model.train()
        losses = []
        train_iter = iter(self.data_iter['train'])

        for step, batch in enumerate(train_iter):
            self.optimizer.zero_grad()
            sub, rel, obj, label, neg_ent, sub_samp = self.read_batch(batch, 'train')

            pred = self.model.forward(sub, rel, neg_ent)
            loss = self.model.loss(pred, label)

            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())

        loss = np.mean(losses)
        return loss

    def fit(self):
        """
        Function to run training and evaluation of model

        Trains for self.p.max_epochs epochs, checkpoints whenever the validation MRR
        improves, then reloads the best checkpoint and evaluates it on the test split.
        Any exception is logged at ERROR level (with traceback) and not re-raised.

        Parameters
        ----------

        Returns
        -------
        """
        try:
            self.best_val_mrr, self.best_val, self.best_epoch, val_mrr = 0., {}, 0, 0.

            checkpoint_dir = './checkpoints'
            # torch.save does not create missing directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            save_path = os.path.join(checkpoint_dir, self.p.name)

            if self.p.restore:
                self.load_model(save_path)
                self.logger.info('Successfully Loaded previous model')

            for epoch in range(self.p.max_epochs):
                train_loss = self.run_epoch(epoch, val_mrr)
                val_results = self.evaluate('valid', epoch)

                if val_results['mrr'] > self.best_val_mrr:
                    self.best_val = val_results
                    self.best_val_mrr = val_results['mrr']
                    self.best_epoch = epoch
                    self.save_model(save_path)

                self.logger.info(
                    '[Epoch {}]: Training Loss: {:.5}, Best valid MRR: {:.5}\n\n'.format(epoch, train_loss,
                                                                                         self.best_val_mrr))

            self.logger.info('Loading best model, Evaluating on Test data')
            self.load_model(save_path)
            self.evaluate('test', self.best_epoch)
        except Exception as e:
            # Logged at ERROR so a failed run is visible regardless of the configured level;
            # the actual exception type is reported instead of the generic Exception class.
            self.logger.error("%s____%s\n"
                              "traceback.format_exc():____%s" % (type(e).__name__, e, traceback.format_exc()))
if __name__ == '__main__':
    # Script entry point: parse the command line, seed the RNGs, then train and evaluate.
    parser = argparse.ArgumentParser(
        description='Parser For Arguments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add_arg = parser.add_argument

    # General run / optimisation settings
    add_arg('-name', default='testrun', help='Set run name for saving/restoring models')
    add_arg('-data', dest='dataset', default='FB15k-237', help='Dataset to use, default: FB15k-237')
    add_arg('-model', dest='model', default='ragat', help='Model Name')
    add_arg('-score_func', dest='score_func', default='conve', help='Score Function for Link prediction')
    # opn is new hyperparameter
    add_arg('-opn', dest='opn', default='corr', help='Composition Operation to be used in RAGAT')

    add_arg('-batch', dest='batch_size', default=1024, type=int, help='Batch size')
    add_arg('-test_batch', dest='test_batch_size', default=1024, type=int,
            help='Batch size of valid and test data')
    add_arg('-gamma', type=float, default=40.0, help='Margin')
    add_arg('-gpu', type=str, default='0', help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0')
    add_arg('-epoch', dest='max_epochs', type=int, default=500, help='Number of epochs')
    add_arg('-l2', type=float, default=0.0, help='L2 Regularization for Optimizer')
    add_arg('-lr', type=float, default=0.001, help='Starting Learning Rate')
    add_arg('-lbl_smooth', dest='lbl_smooth', type=float, default=0.1, help='Label Smoothing')
    add_arg('-num_workers', type=int, default=10, help='Number of processes to construct batches')
    add_arg('-seed', dest='seed', default=41504, type=int, help='Seed for randomization')

    add_arg('-restore', dest='restore', action='store_true', help='Restore from the previously saved model')
    add_arg('-bias', dest='bias', action='store_true', help='Whether to use bias in the model')

    # Graph encoder dimensions and dropout
    add_arg('-init_dim', dest='init_dim', default=100, type=int,
            help='Initial dimension size for entities and relations')
    add_arg('-gcn_dim', dest='gcn_dim', default=200, type=int, help='Number of hidden units in GCN')
    add_arg('-embed_dim', dest='embed_dim', default=200, type=int,
            help='Embedding dimension to give as input to score function')
    add_arg('-gcn_layer', dest='gcn_layer', default=1, type=int, help='Number of GCN Layers to use')
    add_arg('-gcn_drop', dest='dropout', default=0.3, type=float, help='Dropout to use in GCN Layer')
    add_arg('-hid_drop', dest='hid_drop', default=0.3, type=float, help='Dropout after GCN')

    # ConvE specific hyperparameters
    add_arg('-hid_drop2', dest='hid_drop2', default=0.3, type=float, help='ConvE: Hidden dropout')
    add_arg('-feat_drop', dest='feat_drop', default=0.3, type=float, help='ConvE: Feature Dropout')
    add_arg('-k_w', dest='k_w', default=10, type=int, help='ConvE: k_w')
    add_arg('-k_h', dest='k_h', default=20, type=int, help='ConvE: k_h')
    add_arg('-num_filt', dest='num_filt', default=200, type=int,
            help='ConvE: Number of filters in convolution')
    add_arg('-ker_sz', dest='ker_sz', default=7, type=int, help='ConvE: Kernel size to use')

    add_arg('-logdir', dest='log_dir', default='./log/', help='Log directory')
    add_arg('-config', dest='config_dir', default='./config/', help='Config directory')

    # InteractE hyperparameters
    add_arg('-neg_num', dest="neg_num", default=1000, type=int,
            help='Number of negative samples to use for loss calculation')
    add_arg("-strategy", type=str, default='one_to_n', help='Training strategy to use')
    add_arg('-form', type=str, default='plain', help='The reshaping form to use')
    add_arg('-ik_w', dest="ik_w", default=10, type=int, help='Width of the reshaped matrix')
    add_arg('-ik_h', dest="ik_h", default=20, type=int, help='Height of the reshaped matrix')
    add_arg('-inum_filt', dest="inum_filt", default=200, type=int, help='Number of filters in convolution')
    add_arg('-iker_sz', dest="iker_sz", default=9, type=int, help='Kernel size to use')
    add_arg('-iperm', dest="iperm", default=1, type=int, help='Number of Feature rearrangement to use')
    add_arg('-iinp_drop', dest="iinp_drop", default=0.3, type=float, help='Dropout for Input layer')
    add_arg('-ifeat_drop', dest="ifeat_drop", default=0.3, type=float, help='Dropout for Feature')
    add_arg('-ihid_drop', dest="ihid_drop", default=0.3, type=float, help='Dropout for Hidden layer')

    # Attention settings
    add_arg('-attention', dest="att", help="Whether to use attention layer")
    add_arg('-head_num', dest="head_num", default=2, type=int, help="Number of attention heads")

    args = parser.parse_args()

    # A fresh run gets a date/time suffix; a restored run keeps the given name so the
    # existing checkpoint can be found.
    if not args.restore:
        args.name = '_'.join([args.name, time.strftime('%d_%m_%Y'), time.strftime('%H:%M:%S')])

    set_gpu(args.gpu)

    # Seed numpy and torch (plus every CUDA device when available).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    runner = Runner(args)
    runner.fit()