forked from Guzaiwang/CE-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
framework.py
executable file
·83 lines (65 loc) · 2.6 KB
/
framework.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
from torch.autograd import Variable as V
import cv2
import numpy as np
class MyFrame():
    """Thin training/inference wrapper around a CUDA network.

    The wrapper owns the network (wrapped in ``DataParallel``), an Adam
    optimizer and a loss object, and exposes helpers to stage a batch,
    run one optimisation step, predict binary masks, save/load weights
    and change the learning rate.
    """

    def __init__(self, net, loss, lr=2e-4, evalmode=False):
        """Build the network, optimizer and loss.

        Args:
            net: Zero-argument callable (typically the model class) that
                returns the network module.
            loss: Zero-argument callable returning the loss object; the
                loss is later invoked as ``loss(mask, pred)``.
            lr: Initial learning rate for Adam.
            evalmode: When true, every ``BatchNorm2d`` layer is switched
                to eval mode right after construction.
        """
        self.net = net().cuda()
        # A concrete list (not a lazy ``range``) of every visible GPU.
        self.net = torch.nn.DataParallel(
            self.net, device_ids=list(range(torch.cuda.device_count())))
        self.optimizer = torch.optim.Adam(params=self.net.parameters(), lr=lr)
        self.loss = loss()
        # Remembered so ``update_lr`` can report and scale the previous value.
        self.old_lr = lr
        if evalmode:
            for module in self.net.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()

    @staticmethod
    def _binarize(mask):
        """Threshold ``mask`` in place at 0.5 (``> 0.5`` -> 1, else 0) and return it."""
        # Compute both selections before writing so the freshly written
        # values can never be re-classified.
        high = mask > 0.5
        low = mask <= 0.5
        mask[high] = 1
        mask[low] = 0
        return mask

    def set_input(self, img_batch, mask_batch=None, img_id=None):
        """Stage a batch (and optionally its masks / ids) for later calls."""
        self.img = img_batch
        self.mask = mask_batch
        self.img_id = img_id

    def test_one_img(self, img):
        """Run the network on ``img`` and return the binarised, squeezed mask.

        ``img`` is passed to the network as-is, so it must already be on
        the device and in the layout the network expects.
        """
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            pred = self.net.forward(img)
        mask = pred.squeeze().detach().cpu().numpy()
        return self._binarize(mask)

    def test_batch(self):
        """Predict binary masks for the staged batch.

        Returns:
            Tuple ``(mask, img_id)`` where ``mask`` is the network output
            with axis 1 squeezed out and thresholded at 0.5, and
            ``img_id`` is whatever was given to ``set_input``.
        """
        self.forward(volatile=True)
        # ``volatile`` no longer disables autograd in current PyTorch;
        # ``no_grad`` is what actually keeps inference memory bounded.
        with torch.no_grad():
            pred = self.net.forward(self.img)
        mask = pred.detach().cpu().numpy().squeeze(1)
        return self._binarize(mask), self.img_id

    def test_one_img_from_path(self, path):
        """Read the image at ``path`` and return its binarised predicted mask.

        Raises:
            IOError: If OpenCV cannot read an image from ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError('cannot read image: %s' % path)
        # Rescale pixel values from [0, 255] to [-1.6, 1.6].
        img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
        # OpenCV yields HxWxC; 2-D convolutional layers expect a
        # channel-first batch, i.e. 1xCxHxW.
        img = torch.Tensor(img).permute(2, 0, 1).unsqueeze(0).contiguous().cuda()
        with torch.no_grad():
            pred = self.net.forward(img)
        mask = pred.squeeze().detach().cpu().numpy()
        return self._binarize(mask)

    def forward(self, volatile=False):
        """Move the staged image (and mask, if any) to the GPU.

        Args:
            volatile: Accepted for backward compatibility only. The old
                ``Variable(volatile=...)`` flag has no effect in current
                PyTorch; inference methods use ``torch.no_grad()`` instead.
        """
        self.img = self.img.cuda()
        if self.mask is not None:
            self.mask = self.mask.cuda()

    def optimize(self):
        """Run one optimisation step on the staged batch.

        Returns:
            Tuple ``(loss, pred)``: the detached loss value and the raw
            network prediction for the batch.
        """
        self.forward()
        self.optimizer.zero_grad()
        pred = self.net.forward(self.img)
        # The loss object is called with the target first, prediction second.
        loss = self.loss(self.mask, pred)
        loss.backward()
        self.optimizer.step()
        return loss.data, pred

    def save(self, path):
        """Write the network's ``state_dict`` to ``path``."""
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` previously written by ``save`` from ``path``."""
        self.net.load_state_dict(torch.load(path))

    def update_lr(self, new_lr, mylog, factor=False):
        """Set a new learning rate on every optimizer parameter group.

        Args:
            new_lr: The new learning rate, or the divisor applied to the
                current rate when ``factor`` is true.
            mylog: Log sink; when it has a ``write`` method the update
                message is appended to it as one line.
            factor: Interpret ``new_lr`` as a divisor of the current rate.
        """
        if factor:
            new_lr = self.old_lr / new_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        message = 'update learning rate: %f -> %f' % (self.old_lr, new_lr)
        # Write to the log itself rather than printing the log object.
        if hasattr(mylog, 'write'):
            mylog.write(message + '\n')
        print(message)
        self.old_lr = new_lr