forked from Guzaiwang/CE-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
107 lines (84 loc) · 3.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import cv2
import os
from time import time
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
from networks.cenet import CE_Net_
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder
from Visualizer import Visualizer
import Constants
import image_utils
# Select the graphics card(s) CUDA is allowed to use. A CUDA_VISIBLE_DEVICES
# value already present in the environment (shell, job scheduler) is respected;
# "8" is only this script's fallback. Edit the fallback to match your machine.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "8")
def _log(mylog, *args):
    """Write one message to the log file ``mylog`` and echo it to stdout."""
    print(*args, file=mylog)
    print(*args)


def CE_Net_Train():
    """Train CE-Net on the 'DRIVE' dataset rooted at ``Constants.ROOT``.

    Runs up to ``Constants.TOTAL_EPOCH`` epochs. After each epoch the mean
    batch loss is logged to ``logs/<NAME>.log`` and stdout; whenever it
    improves on the best loss so far the solver state is saved to
    ``./weights/<NAME>.th``. Training stops early once the loss has failed to
    improve for more than ``Constants.NUM_EARLY_STOP`` consecutive epochs, or
    when a learning-rate update is due but ``solver.old_lr`` is already below
    5e-7.
    """
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom
    viz = Visualizer(env=NAME)

    # NOTE(review): 2e-4 is presumably the initial learning rate -- confirm
    # against MyFrame's signature.
    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD

    # For different 2D medical image segmentation tasks, please specify the
    # dataset which you use, e.g. datasets='DRIVE' for retinal vessel detection.
    dataset = ImageFolder(root_path=Constants.ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        num_workers=4)

    log_path = 'logs/' + NAME + '.log'
    weights_path = './weights/' + NAME + '.th'
    # Make sure the output directories exist so a first run does not crash
    # when opening the log or saving weights.
    for directory in ('logs', 'weights'):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    tic = time()
    no_optim = 0  # consecutive epochs without a new best loss
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS

    # The context manager closes the log file even if training raises.
    with open(log_path, 'w') as mylog:
        for epoch in range(1, Constants.TOTAL_EPOCH + 1):
            train_epoch_loss = 0
            for img, mask in data_loader:
                solver.set_input(img, mask)
                train_loss, pred = solver.optimize()
                train_epoch_loss += train_loss

            # Show the original image, ground truth and prediction of the last
            # batch on visdom.
            # NOTE(review): indentation was lost in the reviewed source; this
            # is assumed to run once per epoch, not once per batch -- confirm.
            # NOTE(review): (img + 1.6) / 3.2 * 255 presumably undoes the
            # dataset's input normalisation -- confirm against ImageFolder.
            show_image = (img + 1.6) / 3.2 * 255.
            viz.img(name='images', img_=show_image[0, :, :, :])
            viz.img(name='labels', img_=mask[0, :, :, :])
            viz.img(name='prediction', img_=pred[0, :, :, :])

            # Mean loss per batch for this epoch.
            train_epoch_loss = train_epoch_loss / len(data_loader)

            _log(mylog, '********')
            _log(mylog, 'epoch:', epoch, ' time:', int(time() - tic))
            _log(mylog, 'train_loss:', train_epoch_loss)
            _log(mylog, 'SHAPE:', Constants.Image_size)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
                solver.save(weights_path)

            if no_optim > Constants.NUM_EARLY_STOP:
                _log(mylog, 'early stop at %d epoch' % epoch)
                break

            if no_optim > Constants.NUM_UPDATE_LR:
                if solver.old_lr < 5e-7:
                    break
                # Restart from the best saved weights before changing the
                # learning rate.
                # NOTE(review): update_lr(2.0, factor=True) presumably divides
                # the rate by 2 -- confirm in MyFrame.
                solver.load(weights_path)
                solver.update_lr(2.0, factor=True, mylog=mylog)
            mylog.flush()

        _log(mylog, 'Finish!')
if __name__ == '__main__':
    # torch.__version__ is a string attribute, not a function; calling it
    # raises TypeError before training can start.
    print(torch.__version__)
    CE_Net_Train()