forked from khanrc/pt.darts
-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
114 lines (95 loc) · 5.18 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
""" Config class for search/augment """
import argparse
import os
import genotypes as gt
from functools import partial
import torch
def get_parser(name):
    """Create an ``argparse.ArgumentParser`` (prog=*name*) that shows defaults in help.

    The parser uses ``ArgumentDefaultsHelpFormatter`` and its ``add_argument``
    is replaced by a wrapper that supplies a placeholder ``help=' '`` to every
    argument, so the default value is printed even for options declared
    without their own help text.  A ``help=...`` passed explicitly at the call
    site still takes precedence over the placeholder.

    Args:
        name: Program name shown in usage/help output.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    help_formatter = argparse.ArgumentDefaultsHelpFormatter
    parser = argparse.ArgumentParser(name, formatter_class=help_formatter)
    # Keywords given when the partial is called override the stored ones,
    # so callers can still pass their own help string.
    add_with_blank_help = partial(parser.add_argument, help=' ')
    parser.add_argument = add_with_blank_help
    return parser
def parse_gpus(gpus):
    """Convert a ``--gpus`` command-line value into a list of device indices.

    Args:
        gpus: Either ``'all'`` (case-insensitive; every CUDA device reported
            by ``torch.cuda.device_count()``) or a comma-separated list of
            integer ids such as ``'0,2'``. Whitespace around ids and empty
            entries (e.g. from a trailing comma) are ignored.

    Returns:
        list[int]: The selected device ids. For ``'all'`` this may be empty
        when torch sees no CUDA device.

    Raises:
        ValueError: If an entry is not an integer, or if no ids were given.
    """
    if gpus.strip().lower() == 'all':
        return list(range(torch.cuda.device_count()))

    # Skip empty pieces so '0,1,' or '0, ,1' don't trip int('').
    tokens = (piece.strip() for piece in gpus.split(','))
    ids = [int(tok) for tok in tokens if tok]
    if not ids:
        raise ValueError("no GPU ids given in {!r}".format(gpus))
    return ids
class BaseConfig(argparse.Namespace):
    """``argparse.Namespace`` with helpers to dump every option it holds."""

    def print_params(self, prtf=print):
        """Write all options as ``NAME=value`` lines, sorted by attribute name.

        The list is framed by a leading blank line plus a ``Parameters:``
        header, and a trailing blank line.

        Args:
            prtf: Callable invoked once per output line (default: ``print``),
                e.g. a logger method.
        """
        options = vars(self)
        prtf("")
        prtf("Parameters:")
        for key in sorted(options):
            prtf("{}={}".format(key.upper(), options[key]))
        prtf("")

    def as_markdown(self):
        """ Return configs as markdown format

        Produces a two-column ``|name|value|`` table with one row per
        attribute, sorted by attribute name.
        """
        options = vars(self)
        rows = ["|name|value| \n|-|-| \n"]
        rows.extend("|{}|{}| \n".format(key, options[key]) for key in sorted(options))
        return "".join(rows)
class SearchConfig(BaseConfig):
    """Options for the architecture-search run, parsed from the command line.

    After parsing, these derived attributes are added:
    ``data_path`` (fixed to ``'./data/'``), ``path`` (``searchs/<name>``),
    ``plot_path`` (``<path>/plots``), and ``gpus`` converted from its string
    form to a list of ints via ``parse_gpus``.
    """

    def build_parser(self):
        """Return the argument parser declaring every search option."""
        parser = get_parser("Search config")
        parser.add_argument('--name', default='cifar10')
        parser.add_argument('--dataset', default='cifar10', help='CIFAR10 / MNIST / FashionMNIST')
        parser.add_argument('--batch_size', type=int, default=32, help='batch size')
        parser.add_argument('--w_lr', type=float, default=0.025, help='lr for weights')
        parser.add_argument('--w_lr_min', type=float, default=0.001, help='minimum lr for weights')
        parser.add_argument('--w_momentum', type=float, default=0.9, help='momentum for weights')
        parser.add_argument('--w_weight_decay', type=float, default=3e-4,
                            help='weight decay for weights')
        parser.add_argument('--w_grad_clip', type=float, default=5.,
                            help='gradient clipping for weights')
        parser.add_argument('--print_freq', type=int, default=50, help='print frequency')
        parser.add_argument('--gpus', default='0', help='gpu device ids separated by comma. '
                            '`all` indicates use all gpus.')
        parser.add_argument('--epochs', type=int, default=50, help='# of training epochs')
        parser.add_argument('--init_channels', type=int, default=16)
        parser.add_argument('--layers', type=int, default=8, help='# of layers')
        parser.add_argument('--seed', type=int, default=2, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        parser.add_argument('--alpha_lr', type=float, default=3e-4, help='lr for alpha')
        parser.add_argument('--alpha_weight_decay', type=float, default=1e-3,
                            help='weight decay for alpha')
        return parser

    def __init__(self, argv=None):
        """Parse the search options and derive output paths.

        Args:
            argv: Optional list of argument strings to parse instead of
                ``sys.argv[1:]`` (useful for tests and notebooks). ``None``
                keeps the original command-line behaviour.
        """
        parser = self.build_parser()
        args = parser.parse_args(argv)
        super().__init__(**vars(args))

        self.data_path = './data/'
        self.path = os.path.join('searchs', self.name)
        self.plot_path = os.path.join(self.path, 'plots')
        self.gpus = parse_gpus(self.gpus)
class AugmentConfig(BaseConfig):
    """Options for training a fixed (searched) architecture, parsed from the CLI.

    After parsing, these derived attributes are added:
    ``data_path`` (fixed to ``'./data/'``), ``path`` (``augments/<name>``),
    ``genotype`` converted from its string form by ``gt.from_str``, and
    ``gpus`` converted to a list of ints via ``parse_gpus``.
    """

    def build_parser(self):
        """Return the argument parser declaring every augment option."""
        parser = get_parser("Augment config")
        parser.add_argument('--name', required=True)
        parser.add_argument('--dataset', required=True, help='CIFAR10 / MNIST / FashionMNIST')
        parser.add_argument('--batch_size', type=int, default=96, help='batch size')
        parser.add_argument('--lr', type=float, default=0.025, help='lr for weights')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
        parser.add_argument('--grad_clip', type=float, default=5.,
                            help='gradient clipping for weights')
        parser.add_argument('--print_freq', type=int, default=200, help='print frequency')
        parser.add_argument('--gpus', default='0', help='gpu device ids separated by comma. '
                            '`all` indicates use all gpus.')
        parser.add_argument('--epochs', type=int, default=600, help='# of training epochs')
        parser.add_argument('--init_channels', type=int, default=36)
        parser.add_argument('--layers', type=int, default=20, help='# of layers')
        parser.add_argument('--seed', type=int, default=2, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
        parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
        parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path prob')
        parser.add_argument('--genotype', required=True, help='Cell genotype')
        return parser

    def __init__(self, argv=None):
        """Parse the augment options, derive paths, and decode the genotype.

        Args:
            argv: Optional list of argument strings to parse instead of
                ``sys.argv[1:]`` (useful for tests and notebooks). ``None``
                keeps the original command-line behaviour.
        """
        parser = self.build_parser()
        args = parser.parse_args(argv)
        super().__init__(**vars(args))

        self.data_path = './data/'
        self.path = os.path.join('augments', self.name)
        # NOTE(review): gt.from_str lives in genotypes.py (not visible here).
        # If it eval()s its input, --genotype must come from a trusted
        # source -- confirm before exposing this to untrusted callers.
        self.genotype = gt.from_str(self.genotype)
        self.gpus = parse_gpus(self.gpus)