forked from thuml/Transfer-Learning-Library
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pseudo_label.py
270 lines (235 loc) · 12.8 KB
/
pseudo_label.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
"""
@author: Baixu Chen
@contact: [email protected]
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
# Compute device shared by the whole script: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def main(args: argparse.Namespace):
    """Build data, model and optimizer from ``args``, then train or evaluate.

    When ``args.phase == 'test'`` the ``best`` checkpoint is loaded, validated once and
    the function returns. Otherwise the classifier is trained for ``args.epochs`` epochs;
    after every epoch it is validated, its weights are saved as the ``latest`` checkpoint,
    and that file is copied to ``best`` whenever top-1 accuracy improves.

    Args:
        args: parsed command-line options (see the argument parser at the bottom of
            this file for the full list of fields that are read here).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is switched on even when a seed was requested above;
    # cuDNN autotuning may select different algorithms between runs — confirm this is intended.
    cudnn.benchmark = True

    # Data loading code
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    # Labeled samples are delivered as a (weak, strong) pair of views; unlabeled samples get the weak view only.
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = weak_augment
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=unlabeled_train_transform,
                          seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    # train() draws a fixed number of batches per epoch via next(), so both loaders are wrapped in iterators.
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                       finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        # NOTE(review): LambdaLR scales each parameter group's initial lr by this factor, and the factor
        # already contains args.lr; presumably get_parameters() hands out unit-scale group lrs — confirm.
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        # Close the logger on this early-exit path too, mirroring the end of the training path.
        logger.close()
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr (get_last_lr() is the supported accessor outside of scheduler.step())
        print(lr_scheduler.get_last_lr())

        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        # The two maxima are tracked independently, so they may come from different epochs.
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` optimizer steps) of pseudo-label training.

    Every step draws one labeled and one unlabeled batch. The supervised loss is the
    cross entropy on the weakly augmented labeled view plus ``args.trade_off_cls_strong``
    times the cross entropy on the strongly augmented view. The self-training loss is
    produced by ``ConfidenceBasedSelfTrainingLoss(args.threshold)`` on the unlabeled
    logits and scaled by ``args.trade_off_self_training``. Both losses are
    back-propagated separately and applied with a single optimizer/scheduler step.

    Args:
        labeled_train_iter: yields ``((x_weak, x_strong), labels)`` batches.
        unlabeled_train_iter: yields ``(x, labels)`` batches; the labels are used only
            to report the accuracy of the pseudo labels, never for the loss.
        model: classifier called as ``model(x)`` to obtain logits.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        args: parsed options; reads ``iters_per_epoch``, ``batch_size``, ``threshold``,
            ``trade_off_cls_strong``, ``trade_off_self_training`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs],
        prefix="Epoch: [{}]".format(epoch))

    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)
    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        x_u, labels_u = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # cross entropy loss
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # Back-propagate now; gradients accumulate with the self-training backward below
        # and are applied together by the single optimizer.step().
        cls_loss.backward()

        # self training loss
        y_u = model(x_u)
        # The same logits serve as both prediction and pseudo-label source.
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u, y_u)
        self_training_loss = args.trade_off_self_training * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)

        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # accuracy of pseudo labels
        # Read the count back as a plain Python number once, so the branch test and the
        # meter weight are host scalars like the weights given to every other meter.
        n_pseudo_labels = mask.sum().item()
        if n_pseudo_labels > 0:
            # Unselected samples (mask == 0) become -1 so they cannot match a label.
            # NOTE(review): assumes mask is 0/1-valued and ground-truth labels are non-negative — confirm.
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Command-line entry point: declare every option, parse the command line, then hand over to main().
    arg_parser = argparse.ArgumentParser(description='Pseudo Label for Semi Supervised Learning')
    add = arg_parser.add_argument
    dataset_names = utils.get_dataset_names()
    model_names = utils.get_model_names()

    # dataset parameters
    add('root', metavar='DIR', help='root path of dataset')
    add('-d', '--data', metavar='DATA', help='dataset: ' + ' | '.join(dataset_names))
    add('--num-samples-per-class', default=4, type=int, help='number of labeled samples per class')
    add('--train-resizing', default='default', type=str)
    add('--val-resizing', default='default', type=str)
    add('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+', help='normalization mean')
    add('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+', help='normalization std')
    add('--auto-augment', default='rand-m10-n2-mstd2', type=str,
        help='AutoAugment policy (default: rand-m10-n2-mstd2)')

    # model parameters
    add('-a', '--arch', metavar='ARCH', default='resnet50', choices=model_names,
        help='backbone architecture: ' + ' | '.join(model_names) + ' (default: resnet50)')
    add('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck')
    add('--no-pool', action='store_true', default=False, help='no pool layer after the feature extractor')
    add('--pretrained-backbone', default=None, type=str,
        help="pretrained checkpoint of the backbone "
             "(default: None, use the ImageNet supervised pretrained backbone)")
    add('--finetune', action='store_true', default=False, help='whether to use 10x smaller lr for backbone')

    # training parameters
    add('--trade-off-cls-strong', default=0.1, type=float,
        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    add('--trade-off-self-training', default=1, type=float,
        help='the trade-off hyper-parameter of self training loss')
    add('--threshold', default=0.95, type=float, help='confidence threshold (default: 0.95)')
    add('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')
    add('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
        help='initial learning rate')
    add('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'], help='learning rate decay strategy')
    add('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    add('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    add('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default:5e-4)')
    add('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    add('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run (default: 40)')
    add('-i', '--iters-per-epoch', default=500, type=int, help='number of iterations per epoch (default: 500)')
    add('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)')
    add('--seed', default=None, type=int, help='seed for initializing training ')
    add("--log", default='pseudo_label', type=str, help="where to save logs, checkpoints and debugging images")
    add("--phase", default='train', type=str, choices=['train', 'test'],
        help="when phase is 'test', only test the model")

    args = arg_parser.parse_args()
    main(args)