import argparse
import os
import numpy as np
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import train_func as tf
from augmentloader import AugmentLoader
from loss import MaximalCodingRateReduction
import utils
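# Self-supervised pipeline: a backbone network maps augmented views of each
# training image to features, and the MaximalCodingRateReduction loss groups
# views of the same image together (via the indices supplied by AugmentLoader).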
parser = argparse.ArgumentParser(description='Unsupervised Learning')
parser.add_argument('--arch', type=str, default='resnet18',
                    help='architecture for deep neural network (default: resnet18)')
parser.add_argument('--fd', type=int, default=32,
                    help='dimension of the feature space (default: 32)')
parser.add_argument('--data', type=str, default='cifar10',
                    help='dataset for training (default: cifar10)')
parser.add_argument('--epo', type=int, default=50,
                    help='number of epochs for training (default: 50)')
parser.add_argument('--bs', type=int, default=1000,
                    help='input batch size for training (default: 1000)')
parser.add_argument('--aug', type=int, default=50,
                    help='number of augmentations per mini-batch (default: 50)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--mom', type=float, default=0.9,
                    help='momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=5e-4,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--gam1', type=float, default=1.0,
                    help='gamma1 for tuning empirical loss (default: 1.0)')
parser.add_argument('--gam2', type=float, default=10,
                    help='gamma2 for tuning empirical loss (default: 10)')
parser.add_argument('--eps', type=float, default=2,
                    help='eps squared (default: 2)')
parser.add_argument('--tail', type=str, default='',
                    help='extra information to add to folder name')
parser.add_argument('--transform', type=str, default='default',
                    help='transform applied to trainset (default: default)')
parser.add_argument('--sampler', type=str, default='random',
                    help='sampler used in augmentloader (default: random)')
parser.add_argument('--pretrain_dir', type=str, default=None,
                    help='load pretrained checkpoint for assigning labels')
parser.add_argument('--pretrain_epo', type=int, default=None,
                    help='load pretrained epoch for assigning labels')
parser.add_argument('--save_dir', type=str, default='./saved_models/',
                    help='base directory for saving PyTorch model (default: ./saved_models/)')
parser.add_argument('--data_dir', type=str, default='./data/',
                    help='base directory for the dataset (default: ./data/)')
args = parser.parse_args()
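# Example invocation (a sketch using the defaults above; adjust flags to your setup):
#   python train_selfsup.py --arch resnet18 --data cifar10 --fd 32 --epo 50 --bs 1000 --aug 50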
## Pipelines Setup
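# The run directory name encodes every hyperparameter, so runs with different
# settings are saved side by side instead of overwriting each other.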
model_dir = os.path.join(args.save_dir,
                         'selfsup_{}+{}_{}_epo{}_bs{}_aug{}+{}_lr{}_mom{}_wd{}_gam1{}_gam2{}_eps{}{}'.format(
                             args.arch, args.fd, args.data, args.epo, args.bs, args.aug, args.transform,
                             args.lr, args.mom, args.wd, args.gam1, args.gam2, args.eps, args.tail))
utils.init_pipeline(model_dir)
## Prepare for Training
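# Warm-start from a pretrained checkpoint when --pretrain_dir is given;
# otherwise build a fresh args.arch backbone with feature dimension args.fd.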
if args.pretrain_dir is not None:
    net, _ = tf.load_checkpoint(args.pretrain_dir, args.pretrain_epo)
    utils.update_params(model_dir, args.pretrain_dir)
else:
    net = tf.load_architectures(args.arch, args.fd)
transforms = tf.load_transforms(args.transform)
trainset = tf.load_trainset(args.data, path=args.data_dir)
trainloader = AugmentLoader(trainset,
                            transforms=transforms,
                            sampler=args.sampler,
                            batch_size=args.bs,
                            num_aug=args.aug)
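# Each mini-batch is expected to carry args.aug augmented views per sampled image;
# the loader also yields batch_idx, which identifies the source image of every
# view and is what the loss groups on in the training loop below.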
criterion = MaximalCodingRateReduction(gam1=args.gam1, gam2=args.gam2, eps=args.eps)
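# MCR2 objective (see loss.py): expand the coding rate of the whole feature batch
# while compressing features that share a batch_idx; gam1/gam2 weight the empirical
# terms, and eps is the squared-error distortion budget (per the --eps help text).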
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.mom, weight_decay=args.wd)
scheduler = lr_scheduler.MultiStepLR(optimizer, [30, 60], gamma=0.1)
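# The learning rate decays by 10x at epochs 30 and 60; with the default --epo 50,
# only the first milestone is reached.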
utils.save_params(model_dir, vars(args))
## Training
for epoch in range(args.epo):
    for step, (batch_imgs, _, batch_idx) in enumerate(trainloader):
        # Forward pass: embed all augmented views, then score them with the MCR2 loss.
        batch_features = net(batch_imgs.cuda())
        loss, loss_empi, loss_theo = criterion(batch_features, batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the empirical/theoretical loss terms and checkpoint periodically.
        utils.save_state(model_dir, epoch, step, loss.item(), *loss_empi, *loss_theo)
        if step % 20 == 0:
            utils.save_ckpt(model_dir, net, epoch)
    scheduler.step()
    utils.save_ckpt(model_dir, net, epoch)
print("training complete.")