forked from thuml/Transfer-Learning-Library
-
Notifications
You must be signed in to change notification settings - Fork 0
/
co_tuning.py
240 lines (207 loc) · 11.4 KB
/
co_tuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
@author: Yifei Ji, Junguang Jiang
@contact: [email protected], [email protected]
"""
import random
import time
import warnings
import argparse
import shutil
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Subset
import utils
from tllib.regularization.co_tuning import CoTuningLoss, Relationship, Classifier
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.data import ForeverDataIterator
import tllib.vision.datasets as datasets
# Device shared by the whole script (model placement and every batch):
# the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_dataset(dataset_name, root, train_transform, val_transform, sample_rate=100, num_samples_per_classes=None):
    """Build the training, deterministic-training and test datasets.

    The dataset class is looked up by name in ``datasets.__dict__``. Two views
    of the training split are created: one using ``train_transform`` and one
    using ``val_transform`` (the "deterministic" view).

    Args:
        dataset_name: key of the dataset class inside ``datasets.__dict__``.
        root: dataset root directory, forwarded to the dataset constructor.
        train_transform: transform applied to the training view.
        val_transform: transform applied to the deterministic training view
            and to the test split.
        sample_rate: when below 100 it is forwarded to the dataset constructor
            for the training split (the test split always receives 100).
        num_samples_per_classes: when not ``None`` (and ``sample_rate`` is not
            below 100), both training views are cut down to the same random
            subset of at most ``num_samples_per_classes * num_classes`` items.

    Returns:
        ``(train_dataset, determin_train_dataset, test_dataset, num_classes)``.

    NOTE(review): the original indentation of this file was lost; the
    per-class sub-sampling is assumed to belong to the ``sample_rate >= 100``
    path only -- confirm against the upstream script.
    """
    dataset_cls = datasets.__dict__[dataset_name]

    if sample_rate < 100:
        # Sub-sampled training split; the test split is requested in full.
        train_dataset = dataset_cls(root=root, split='train', sample_rate=sample_rate, download=True,
                                    transform=train_transform)
        determin_train_dataset = dataset_cls(root=root, split='train', sample_rate=sample_rate, download=True,
                                             transform=val_transform)
        test_dataset = dataset_cls(root=root, split='test', sample_rate=100, download=True,
                                   transform=val_transform)
        return train_dataset, determin_train_dataset, test_dataset, train_dataset.num_classes

    train_dataset = dataset_cls(root=root, split='train', transform=train_transform)
    determin_train_dataset = dataset_cls(root=root, split='train', transform=val_transform)
    test_dataset = dataset_cls(root=root, split='test', transform=val_transform)
    # Read before any Subset wrapping below, which replaces the dataset object.
    num_classes = train_dataset.num_classes

    if num_samples_per_classes is None:
        return train_dataset, determin_train_dataset, test_dataset, num_classes

    # Draw one random index subset and reuse it for both training views so
    # they keep referring to exactly the same samples.
    indices = list(range(len(train_dataset)))
    random.shuffle(indices)
    keep_count = min(num_samples_per_classes * num_classes, len(train_dataset))
    kept = indices[:keep_count]
    return Subset(train_dataset, kept), Subset(determin_train_dataset, kept), test_dataset, num_classes
def main(args: argparse.Namespace):
    """Train or evaluate a co-tuning classifier as configured by ``args``.

    Steps, in order:
      1. open the logger and optionally seed ``random`` / ``torch``;
      2. build the datasets and data loaders;
      3. build the backbone, the ``Classifier``, the SGD optimizer and the
         multi-step learning-rate scheduler;
      4. in the ``'test'`` phase: load the ``'best'`` checkpoint, validate,
         print the accuracy and return;
      5. otherwise build the ``Relationship`` object from the deterministic
         training loader and train for ``args.epochs`` epochs, saving the
         ``'latest'`` checkpoint every epoch and copying it to ``'best'``
         whenever the validation accuracy improves.

    Args:
        args: parsed command-line options (see the parser in the
            ``__main__`` section of this script).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        # cuDNN autotuning may pick a different convolution algorithm on each
        # run, which defeats the purpose of seeding, so keep it off here.
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    else:
        cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, determin_train_dataset, val_dataset, num_classes = get_dataset(args.data, args.root, train_transform,
                                                                                  val_transform, args.sample_rate,
                                                                                  args.num_samples_per_classes)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    # Unshuffled, complete pass over the training data (val transform); it is
    # only consumed by Relationship below.
    determin_train_loader = DataLoader(determin_train_dataset, batch_size=args.batch_size,
                                       shuffle=False, num_workers=args.workers, drop_last=False)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, head_source=backbone.copy_head(), pool_layer=pool_layer,
                            finetune=args.finetune).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        # Close the logger on this early exit too, as the training path does.
        logger.close()
        return

    # build relationship between source classes and target classes
    source_classifier = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.head_source)
    relationship = Relationship(determin_train_loader, source_classifier, device,
                                os.path.join(logger.root, args.relationship))
    co_tuning_loss = CoTuningLoss()

    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_iter, classifier, optimizer, epoch, relationship, co_tuning_loss, args)
        lr_scheduler.step()

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          epoch: int, relationship, co_tuning_loss, args: argparse.Namespace):
    """Run ``args.iters_per_epoch`` optimisation steps of co-tuning.

    Each step draws one batch, computes the cross-entropy loss on the target
    head output plus ``args.trade_off`` times ``co_tuning_loss`` on the source
    head output, and applies one optimizer update. Running averages of the
    timings, the loss and the target accuracy are printed every
    ``args.print_freq`` iterations.

    Args:
        train_iter: iterator yielding ``(x, label_t)`` batches via ``next``.
        model: network returning ``(y_s, y_t)`` for an input batch.
        optimizer: optimizer stepped once per iteration.
        epoch: epoch index, used only in the progress prefix.
        relationship: object indexed with the batch of target labels; the
            result is passed to ``torch.from_numpy`` to form the source-side
            targets.
        co_tuning_loss: callable ``(y_s, label_s) -> loss``.
        args: options; reads ``iters_per_epoch``, ``trade_off`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, label_t = next(train_iter)

        x = x.to(device)
        # `relationship` is indexed before label_t is moved to `device`.
        # Use `device` rather than a hard-coded .cuda() so the script also
        # runs when the module-level device falls back to the CPU.
        label_s = torch.from_numpy(relationship[label_t]).to(device).float()
        label_t = label_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, y_t = model(x)
        tgt_loss = F.cross_entropy(y_t, label_t)
        src_loss = co_tuning_loss(y_s, label_s)
        loss = tgt_loss + args.trade_off * src_loss

        # measure accuracy and record loss
        losses.update(loss.item(), x.size(0))
        cls_acc = accuracy(y_t, label_t)[0]
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
# Command-line entry point: declare the options, parse them and hand the
# resulting namespace to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Co-Tuning for Finetuning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # The value is used as a key into tllib.vision.datasets; without it the
    # lookup in get_dataset() fails with an opaque KeyError, so argparse
    # rejects a missing -d up front instead.
    parser.add_argument('-d', '--data', metavar='DATA', required=True,
                        help='dataset name (a class exported by tllib.vision.datasets)')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')
    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor. Used in models such as ViT.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--trade-off', default=2.3, type=float,
                        metavar='P', help='the trade-off hyper-parameter for co-tuning loss')
    parser.add_argument("--relationship", type=str, default='relationship.npy',
                        help="Where to save relationship file.")
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='cotuning',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)