# set_utils.py (forked from Orienfish/SCALE)
import math
import numpy as np
import os
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from networks.resnet_big import ResNet, Bottleneck, BasicBlock
from networks.resnet_pnn import resnet18_pnn
from networks.resnet_big import ConvEncoder

dataset_num_classes = {
    'nabird': 555,
    'oxford_pets': 37,
    'cub200': 200,
    'caltech101': 101,
    'stanford_dogs': 120,
    'voc2007': 21,
    'cifar10': 10,
    'cifar100': 20,
    'imagenet': 1000,
    'tinyimagenet': 100,  # temp setting
    'stream51': 51,
    'core50': 50,
    'mnist': 10
}


def create_model(model_type: str,
                 method: str,
                 dataset: str,
                 **kwargs):
    """Create a backbone of the given model_type; dataset and method only
    matter for the progressive (pnn) variant of resnet18."""
    if model_type == 'resnet18':
        if method == 'pnn':
            model = resnet18_pnn(dataset_num_classes[dataset])
        else:
            model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    elif model_type == 'resnet34':
        model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    elif model_type == 'resnet50':
        model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    elif model_type == 'resnet101':
        model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    elif model_type == 'cnn':
        model = ConvEncoder()
    else:
        raise ValueError(model_type)
    return model
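

# Illustrative example (not from the original repo): building a ResNet-18
# progressive network for CIFAR-10. Extra keyword arguments are forwarded to
# the ResNet constructor in networks/resnet_big.py.
#
#   model = create_model('resnet18', 'pnn', 'cifar10')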


def load_student_backbone(model_type: str,
                          method: str,
                          dataset: str,
                          ckpt: str,
                          **kwargs):
    """
    Load a student model of model_type and its pretrained weights from ckpt.
    """
    # Set up the model and wrap it for (multi-)GPU training.
    model = create_model(model_type, method, dataset, **kwargs)
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()
        cudnn.benchmark = True

    if ckpt is not None:
        state = torch.load(ckpt)
        state_dict = {}
        for k in list(state.keys()):
            # Skip the classification head; prefix the remaining keys with
            # 'module.' so they match the DataParallel-wrapped model.
            if k.startswith("fc."):
                continue
            state_dict['module.' + k] = state[k]
            del state[k]
        model.load_state_dict(state_dict)
    return model
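

# Illustrative example (checkpoint path is hypothetical): loading a ResNet-18
# student from a saved state dict, or pass ckpt=None to start from randomly
# initialized weights.
#
#   backbone = load_student_backbone('resnet18', 'pnn', 'cifar10',
#                                    ckpt='save/student_ckpt.pth')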


def set_constant_learning_rate(lr, optimizer):
    """Set every parameter group of the optimizer to a fixed learning rate."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def adjust_learning_rate(args, optimizer, epoch):
    """Apply cosine annealing or step decay to the learning rate for this epoch."""
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
            1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the learning rate from warmup_from to warmup_to over the warm-up epochs."""
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


def set_optimizer(lr, momentum, weight_decay, model, **kwargs):
    """Build an SGD optimizer over the backbone and, if given, the criterion's heads."""
    parameters = [{
        'name': 'backbone',
        'params': [param for name, param in model.named_parameters()],
    }]
    # Include the criterion's head parameters among the parameters to optimize.
    if 'criterion' in kwargs:
        parameters.append({
            'name': 'heads',
            'params': [param for name, param in kwargs['criterion'].named_parameters()],
        })
    optimizer = optim.SGD(parameters,
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)
    return optimizer
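

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline): wires the helpers
# above into a dummy schedule so the learning-rate behaviour can be inspected.
# The hyperparameter values and batch count are illustrative assumptions, and
# the 'cnn' backbone relies on this repository's ConvEncoder definition.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(learning_rate=0.05, cosine=True, lr_decay_rate=0.1,
                     lr_decay_epochs=[30, 40], epochs=50,
                     warm=True, warm_epochs=5, warmup_from=0.01, warmup_to=0.05)

    # 'none' is a placeholder method name; it is ignored for non-resnet18 models.
    model = create_model('cnn', 'none', 'cifar10')
    optimizer = set_optimizer(args.learning_rate, momentum=0.9,
                              weight_decay=1e-4, model=model)

    total_batches = 100  # assumed number of batches per epoch
    for epoch in range(1, args.epochs + 1):
        lr = adjust_learning_rate(args, optimizer, epoch)
        # During warm-up, the per-batch ramp overrides the epoch-level schedule.
        warmup_learning_rate(args, epoch, 0, total_batches, optimizer)
        print('epoch %3d: lr = %.5f' % (epoch, optimizer.param_groups[0]['lr']))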