util.py
import numpy as np
import torch
import torch.nn.functional as F

def cal_loss(pred, gold, smoothing=True):
    ''' Calculate cross entropy loss, applying label smoothing if enabled. '''
    gold = gold.contiguous().view(-1)

    if smoothing:
        eps = 0.2
        n_class = pred.size(1)

        # Smoothed one-hot targets: 1 - eps on the true class,
        # eps / (n_class - 1) spread over the remaining classes.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')

    return loss
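
# Usage sketch (hypothetical shapes): `pred` holds raw logits of shape (B, C)
# and `gold` holds integer class labels of shape (B,), e.g.
#   pred = torch.randn(4, 10)
#   gold = torch.randint(0, 10, (4,))
#   loss = cal_loss(pred, gold, smoothing=True)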

class IOStream():
    """
    Logger for distributed training on multiple GPUs: only the process with
    rank == 0 writes to the log file; otherwise every rank would produce
    duplicate log entries.
    """

    def __init__(self, path, rank=-1):
        self.rank = rank
        if self.rank == 0:
            self.f = open(path, 'a')

    def cprint(self, text):
        if self.rank == 0:
            # print(text)
            self.f.write(text + '\n')
            self.f.flush()

    def close(self):
        if self.rank == 0:
            self.f.close()
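
# Usage sketch (assumes the caller passes the process rank obtained from
# torch.distributed; the path is hypothetical):
#   io = IOStream('checkpoints/run.log', rank=dist.get_rank())
#   io.cprint('epoch 1: train acc 0.91')
#   io.close()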

def adjust_learning_rate(epoch, opt, optimizer):
    """Decay the initial learning rate by opt.lr_decay_rate after each epoch listed in opt.lr_decay_epochs."""
    steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
    if steps > 0:
        new_lr = opt.lr * (opt.lr_decay_rate ** steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
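
# Example (hypothetical values): with opt.lr = 0.1, opt.lr_decay_rate = 0.1 and
# opt.lr_decay_epochs = [120, 160], the learning rate stays at 0.1 through
# epoch 120, drops to 0.01 from epoch 121, and to 0.001 from epoch 161 onward.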

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
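
# Usage sketch: track a running mean of per-batch losses, weighting each
# batch by its size (variable names are hypothetical):
#   loss_meter = AverageMeter()
#   loss_meter.update(loss.item(), n=batch_size)
#   print(loss_meter.avg)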

class AccuracyMeter(object):
    """Accumulates correct (num_pos) and incorrect (num_neg) prediction counts
    together with the total number of samples seen."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.num_pos = 0
        self.num_neg = 0
        self.total = 0

    def update(self, num_pos, num_neg, n=1):
        self.num_pos += num_pos
        self.num_neg += num_neg
        self.total += n

    def pos_count(self, pred, label):
        # torch.eq(a, b) computes element-wise equality; summing the boolean
        # mask gives the number of correct predictions in the batch.
        results = torch.eq(pred, label)
        return results.sum()
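
# Usage sketch (hypothetical variable names): `pred` are logits of shape (B, C)
# and `label` are integer targets of shape (B,):
#   acc_meter = AccuracyMeter()
#   correct = acc_meter.pos_count(pred.argmax(dim=1), label).item()
#   acc_meter.update(correct, label.size(0) - correct, n=label.size(0))
#   accuracy = acc_meter.num_pos / acc_meter.total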