-
Notifications
You must be signed in to change notification settings - Fork 1
/
training.py
89 lines (71 loc) · 3.35 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import os
import ssl

import pytorch_lightning as pl
import torch

from trainers import CellTrainer
from dataprep import load_train_val, MyDataset
def get_args():
    """Build the command-line parser for a training run.

    Every option is optional and accepts ``nargs='?'``: passing a flag with
    no value yields its ``const``. For the integer options that is ``1``;
    for the string options it is the option's own default.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(description='Produces training arguments:')
    # argparse does NOT run ``const`` through ``type``, so a string option
    # must have a string const. The previous ``const=1`` made a bare
    # ``--arch`` / ``--encoder_name`` / ``--weights`` produce the int 1,
    # which then broke the string concatenation that builds run names.
    parser.add_argument("--arch", default="FPN", type=str, nargs='?', const="FPN",
                        help="model architecture name (default: %(default)s)")
    parser.add_argument("--encoder_name", default="resnet18", type=str, nargs='?', const="resnet18",
                        help="encoder backbone name (default: %(default)s)")
    parser.add_argument("--devices", default=0, type=int, nargs='?', const=1,
                        help="GPU index(es); a multi-digit value is split into one "
                             "index per digit, e.g. 12 -> GPUs 1 and 2 (default: %(default)s)")
    parser.add_argument("--batch_size", default=35, type=int, nargs='?', const=1,
                        help="training batch size (default: %(default)s)")
    parser.add_argument("--in_channels", default=1, type=int, nargs='?', const=1,
                        help="number of input image channels (default: %(default)s)")
    parser.add_argument("--out_channels", default=5, type=int, nargs='?', const=1,
                        help="number of output channels (default: %(default)s)")
    parser.add_argument("--max_epochs", default=1, type=int, nargs='?', const=1,
                        help="number of training epochs (default: %(default)s)")
    # NOTE(review): parsed, but the script's training call does not forward it.
    parser.add_argument("--encoder_depth", default=5, type=int, nargs='?', const=1,
                        help="encoder depth (default: %(default)s)")
    parser.add_argument("--weights", default="imagenet", type=str, nargs='?', const="imagenet",
                        help="pretrained encoder weights name (default: %(default)s)")
    return parser
def multi_class_train(main_path, arch, encoder_name, batch_size, in_channels, out_channels, devices, max_epochs, weigths):
    """Train a ``CellTrainer`` model and save it under ``main_path``.

    Reads data from ``<main_path>pngs/train/`` and ``<main_path>pngs/valid/``.
    Writes TensorBoard logs to ``<main_path>lightning_logs/<run name>/``.
    Saves the trained model with ``torch.save`` to ``<main_path>models/``.
    Paths are built by plain string concatenation, so ``main_path`` should
    end with a "/".

    Args:
        main_path: root data directory, ending with "/".
        arch: architecture name, forwarded to ``CellTrainer``.
        encoder_name: encoder name, forwarded to ``CellTrainer``.
        batch_size: forwarded to ``load_train_val``.
        in_channels: forwarded to ``CellTrainer``.
        out_channels: forwarded to ``CellTrainer``.
        devices: ``pl.Trainer`` device spec; only used when CUDA is available.
        max_epochs: number of training epochs.
        weigths: (sic) weights name, forwarded to ``CellTrainer`` and
            appended to the log run name.
    """
    model_dic = main_path + "models/"
    log_dir = main_path + "lightning_logs/"

    # Shared run-name fields. The saved model filename deliberately omits
    # ``weigths`` (unchanged from before, so existing loaders still find it);
    # the log run name includes it.
    run_fields = [arch, encoder_name, str(in_channels), str(out_channels),
                  str(max_epochs), str(batch_size)]
    name = "_".join(run_fields + [weigths])
    model_file = model_dic + "_".join(run_fields)

    # torch.save does not create parent directories; create the model folder
    # up front so a missing folder cannot discard a finished training run.
    os.makedirs(model_dic, exist_ok=True)

    train_set = MyDataset(main_path + "pngs/train/")
    valid_set = MyDataset(main_path + "pngs/valid/")
    train_data, valid_data = load_train_val(train_set, valid_set, batch_size)

    model = CellTrainer(arch, encoder_name, in_channels, out_channels, weigths)
    tb_logger = pl.loggers.TensorBoardLogger(save_dir=log_dir + name + "/", flush_secs=10)

    if torch.cuda.is_available():
        trainer = pl.Trainer(accelerator="gpu", devices=devices,
                             max_epochs=max_epochs, logger=tb_logger)
    else:
        # ``devices`` is not passed on CPU.
        trainer = pl.Trainer(accelerator="cpu", max_epochs=max_epochs, logger=tb_logger)

    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=valid_data)
    # Pickles the whole model object (not just its state_dict).
    torch.save(model, model_file)
def parse_devices(devices):
    """Expand the integer ``--devices`` value into a ``pl.Trainer`` device spec.

    A non-negative value is split into one GPU index per decimal digit:
    ``0 -> [0]`` and ``12 -> [1, 2]``. Because the CLI parses the value as
    an int, a leading 0 is lost (``01 -> [1]``).

    A negative value, such as ``-1``, is returned unchanged. Lightning
    documents ``-1`` as "all available devices". Previously a negative
    value crashed with ValueError on the ``-`` sign.
    """
    if devices < 0:
        return devices
    return [int(digit) for digit in str(devices)]


if __name__ == '__main__':
    # SECURITY NOTE(review): this disables TLS certificate verification for
    # every HTTPS request made by this process (presumably so pretrained
    # encoder weights can be downloaded despite a certificate problem -
    # confirm). Remove it once the environment has valid certificates.
    ssl._create_default_https_context = ssl._create_unverified_context
    main_path = "./data/"
    args = get_args().parse_args()
    print(args)
    # args.encoder_depth is parsed but not forwarded to training.
    multi_class_train(main_path,
                      args.arch,
                      args.encoder_name,
                      args.batch_size,
                      args.in_channels,
                      args.out_channels,
                      parse_devices(args.devices),
                      args.max_epochs,
                      args.weights)