-
Notifications
You must be signed in to change notification settings - Fork 0
/
3. 图像分割 batchsize 实现
34 lines (26 loc) · 1.28 KB
/
3. 图像分割 batchsize 实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Mini-batch training loop for the saliency / segmentation network.
#
# NOTE(review): `sess`, `train_step`, `cross_entropy_contact`, `xs`, `yg`,
# `yl`, `lr` and `saver` are not defined in this snippet; they are assumed to
# be created by graph-building code that runs before this point -- confirm.

BATCH_SIZE = 4
DATA_DIR = 'H:/data/JSOD/JSOD/train_augmentation/'

# Each CSV row is read positionally: column 0 is the sample's base file name
# (without extension), column 1 is the value fed to `yl`.
train_label = pd.read_csv("H:/data/JSOD/JSOD/train_augmentation_list.csv", header=None)

rate = 0.00001

# Reusable batch buffers. np.ndarray leaves them uninitialised, which is safe
# because every slot is overwritten before a training step consumes them.
img_batch = np.ndarray((BATCH_SIZE, 224, 224, 3), dtype=np.float32)
gt_batch = np.ndarray((BATCH_SIZE, 224, 224, 1), dtype=np.float32)
label_batch = np.ndarray((BATCH_SIZE, 1), dtype=np.float32)

for i in range(200):
    train_label = train_label.sample(frac=1)  # shuffle the order every epoch
    whole_loss = 0.0
    num_batches = 0

    for j, (_, row) in enumerate(train_label.iterrows()):
        slot = j % BATCH_SIZE
        name = row[0]

        img_path = DATA_DIR + name + '.jpg'
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("Cannot read training image: " + img_path)
        img = cv2.resize(img, (224, 224))
        # Per-channel offset subtraction.
        # NOTE(review): these look like the usual VGG BGR channel means -- confirm.
        img = img - [103.939, 116.779, 123.68]
        img_batch[slot] = img

        gt_path = DATA_DIR + name + '.png'
        gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
        if gt is None:
            raise FileNotFoundError("Cannot read ground-truth mask: " + gt_path)
        gt = cv2.resize(gt, (224, 224))
        # Binarise the mask: values above 0.5 after scaling become 1, the rest 0.
        gt = (gt / 255 > 0.5).astype(np.float32)
        gt_batch[slot] = gt[:, :, np.newaxis]

        label_batch[slot] = row[1]

        # Train as soon as the buffers hold a full batch. Doing this at the end
        # of the iteration (rather than at the start of the next one) ensures
        # the final full batch of the epoch is not silently skipped.
        # A trailing partial batch (row count not divisible by BATCH_SIZE) is
        # still dropped, as before.
        if slot == BATCH_SIZE - 1:
            _, loss = sess.run(
                [train_step, cross_entropy_contact],
                feed_dict={xs: img_batch, yg: gt_batch, yl: label_batch, lr: rate},
            )
            whole_loss += loss
            num_batches += 1

    # Average over the number of steps actually run instead of a fixed 5000.
    mean_loss = whole_loss / num_batches if num_batches else 0.0
    print("Epoch %d: %f" % (i, mean_loss))

    # Checkpoint after every epoch.
    saver.save(sess, "my_net/Saliency" + str(int(i)) + ".ckpt")