-
Notifications
You must be signed in to change notification settings - Fork 25
/
CyclicGen_train_stage1.py
318 lines (251 loc) · 14.6 KB
/
CyclicGen_train_stage1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""Train a voxel flow model on ucf101 dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dataset
import numpy as np
import os
import tensorflow as tf
from datetime import datetime
from CyclicGen_model import Voxel_flow_model
from utils.image_utils import imwrite
from skimage.measure import compare_ssim as ssim
from vgg16 import Vgg16
FLAGS = tf.app.flags.FLAGS

# Define necessary FLAGS
tf.app.flags.DEFINE_string('train_dir', './CyclicGen_checkpoints_stage1/',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_string('train_image_dir', './voxel_flow_train_image/',
                           """Directory where to output images.""")
tf.app.flags.DEFINE_string('test_image_dir', './voxel_flow_test_image_baseline/',
                           """Directory where to output images.""")
# The __main__ entry point only dispatches 'train' and 'test'; the help text
# previously advertised 'validation', which was never handled.
tf.app.flags.DEFINE_string('subset', 'train',
                           """Either 'train' or 'test'.""")
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', None,
                           """If specified, restore this pretrained model """
                           """before beginning any training.""")
tf.app.flags.DEFINE_integer('max_steps', 10000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_integer(
    'batch_size', 8, 'The number of samples in each batch.')
tf.app.flags.DEFINE_float('initial_learning_rate', 0.0001,
                          """Initial learning rate.""")
# Keep every N-th entry of the training data lists (1 = use all of them).
tf.app.flags.DEFINE_integer('training_data_step', 1, """The step used to reduce training data size""")
def _read_image(filename):
    """Read an image file into a float32 tensor scaled to roughly [-1, 1].

    The file is decoded with 3 channels. `decode_image` yields uint8 by default,
    so values in [0, 255] map to [-1, 1]. No static shape is set here; callers
    are expected to fix the shape (e.g. via `tf.random_crop`).

    Args:
      filename: scalar string tensor holding the image path.

    Returns:
      float32 image tensor with 3 channels.
    """
    raw_bytes = tf.read_file(filename)
    pixels = tf.image.decode_image(raw_bytes, channels=3)
    pixels_float = tf.cast(pixels, dtype=tf.float32)
    return pixels_float / 127.5 - 1.0
def random_scaling(image, seed=1):
    """Resize `image` to a random square whose side is 40%-60% of 256 pixels.

    Args:
      image: image tensor accepted by `tf.image.resize_images`.
      seed: op-level seed for the random scale factor.

    Returns:
      The resized image tensor (side = round(256 * factor), factor in [0.4, 0.6)).
    """
    factor = tf.random_uniform([], 0.4, 0.6, seed=seed)
    side_length = tf.cast(tf.round(256 * factor), tf.int32)
    # Both dimensions use the same rounded length, so the output is square.
    return tf.image.resize_images(image, [side_length, side_length])
def _build_frame_dataset(frame_dataset):
    """Build the shuffled, augmented, endlessly repeating image pipeline for one frame slot.

    Every random op uses seed=1. train() builds all three frame pipelines with this
    same helper, so the shuffle, flips and crops are intended to line up across
    frame1/frame2/frame3 of each triplet. This relies on identical op-level seeds,
    so keep the seeds identical across slots.

    Args:
      frame_dataset: object exposing read_data_list_file() (project dataset.Dataset).

    Returns:
      (data_list, images): the strided list of file names and a tf.data.Dataset
      yielding float32 [256, 256, 3] crops.
    """
    data_list = frame_dataset.read_data_list_file()
    data_list = data_list[::FLAGS.training_data_step]
    images = tf.data.Dataset.from_tensor_slices(tf.constant(data_list))
    # count=None repeats forever; the training loop is bounded by FLAGS.max_steps.
    images = images.apply(
        tf.contrib.data.shuffle_and_repeat(buffer_size=1000000, count=None, seed=1))
    images = images.map(_read_image)
    images = images.map(lambda image: tf.image.random_flip_left_right(image, seed=1))
    images = images.map(lambda image: tf.image.random_flip_up_down(image, seed=1))
    images = images.map(lambda image: tf.random_crop(image, [256, 256, 3], seed=1))
    return data_list, images.prefetch(8)


def train(dataset_frame1, dataset_frame2, dataset_frame3):
    """Trains a model.

    Builds input pipelines for the three frames of each triplet. The outer frames
    (plus their edge maps) are fed to the Cycle_DVF interpolation model, which is
    trained with an L1 reproduction loss against the middle frame. Variables whose
    name contains 'hed' (the Vgg16 edge network) are excluded from optimization.
    They are restored from 'hed_model/new-model.ckpt' unless a full pretrained
    checkpoint is given.

    Args:
      dataset_frame1, dataset_frame2, dataset_frame3: objects exposing
        read_data_list_file() for the first, middle and last frame lists.
    """
    with tf.Graph().as_default():
        # Create input.
        data_list_frame1, pipeline_frame1 = _build_frame_dataset(dataset_frame1)
        _, pipeline_frame2 = _build_frame_dataset(dataset_frame2)
        _, pipeline_frame3 = _build_frame_dataset(dataset_frame3)

        batch_frame1 = pipeline_frame1.batch(FLAGS.batch_size).make_initializable_iterator()
        batch_frame2 = pipeline_frame2.batch(FLAGS.batch_size).make_initializable_iterator()
        batch_frame3 = pipeline_frame3.batch(FLAGS.batch_size).make_initializable_iterator()

        input1 = batch_frame1.get_next()
        input2 = batch_frame2.get_next()
        input3 = batch_frame3.get_next()

        # Edge maps for the outer frames; the second Vgg16 reuses the first one's weights.
        edge_vgg_1 = Vgg16(input1, reuse=None)
        edge_vgg_3 = Vgg16(input3, reuse=True)
        edge_1 = tf.nn.sigmoid(edge_vgg_1.fuse)
        edge_3 = tf.nn.sigmoid(edge_vgg_3.fuse)
        height, width = input1.get_shape().as_list()[1:3]
        edge_1 = tf.reshape(edge_1, [-1, height, width, 1])
        edge_3 = tf.reshape(edge_3, [-1, height, width, 1])

        with tf.variable_scope("Cycle_DVF"):
            model1 = Voxel_flow_model()
            prediction1, _ = model1.inference(tf.concat([input1, input3, edge_1, edge_3], 3))
            reproduction_loss1 = model1.l1loss(prediction1, input2)

        t_vars = tf.trainable_variables()
        print('all layers:')
        for var in t_vars:
            print(var.name)
        # The pretrained edge network ('hed' variables) stays frozen.
        dof_vars = [var for var in t_vars if 'hed' not in var.name]
        print('optimize layers:')
        for var in dof_vars:
            print(var.name)

        total_loss = reproduction_loss1

        # Constant learning rate (no schedule is applied).
        learning_rate = FLAGS.initial_learning_rate
        with tf.variable_scope(tf.get_variable_scope(), reuse=None):
            update_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss, var_list=dof_vars)

        # Summaries register themselves in GraphKeys.SUMMARIES; merge_all collects them.
        tf.summary.scalar('total_loss', total_loss)
        tf.summary.image('input1', input1, 3)
        tf.summary.image('input2', input2, 3)
        tf.summary.image('input3', input3, 3)
        tf.summary.image('edge_1', edge_1, 3)
        tf.summary.image('edge_3', edge_3, 3)
        tf.summary.image('prediction1', prediction1, 3)
        summary_op = tf.summary.merge_all()

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=50)

        iterator_inits = [batch_frame1.initializer, batch_frame2.initializer, batch_frame3.initializer]
        with tf.Session() as sess:
            if FLAGS.pretrained_model_checkpoint_path:
                # Restore every variable (including the edge network) from the checkpoint.
                restorer = tf.train.Saver()
                restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
                print('%s: Pre-trained model restored from %s' %
                      (datetime.now(), FLAGS.pretrained_model_checkpoint_path))
                sess.run(iterator_inits)
            else:
                sess.run([tf.global_variables_initializer()] + iterator_inits)
                # Overwrite the freshly initialized edge-network weights with pretrained ones.
                meta_model_file = 'hed_model/new-model.ckpt'
                saver2 = tf.train.Saver(var_list=[v for v in tf.global_variables() if "hed" in v.name])
                saver2.restore(sess, meta_model_file)

            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph)
            try:
                data_size = len(data_list_frame1)
                # At least 1 so a data list smaller than one batch cannot cause modulo-by-zero.
                epoch_num = max(1, data_size // FLAGS.batch_size)
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

                for step in range(FLAGS.max_steps):
                    _, loss_value = sess.run([update_op, total_loss])

                    if step % epoch_num == 0:
                        print('Epoch Number: %d' % (step // epoch_num))

                    if step % 10 == 0:
                        print("Loss at step %d: %f" % (step, loss_value))

                    if step % 100 == 0:
                        # Runs a separate forward pass, which consumes one batch from every iterator.
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)

                    if step % 2000 == 0 or (step + 1) == FLAGS.max_steps:
                        saver.save(sess, checkpoint_path, global_step=step)
            finally:
                summary_writer.close()
def validate(dataset_frame1, dataset_frame2, dataset_frame3):
    """Placeholder for model validation; not implemented and does nothing.

    Args:
      dataset_frame1: currently unused.
      dataset_frame2: currently unused.
      dataset_frame3: currently unused.

    Returns:
      None.
    """
    return None
def test(dataset_frame1, dataset_frame2, dataset_frame3):
    """Perform test on a trained model.

    Restores the checkpoint named by --pretrained_model_checkpoint_path and
    interpolates the middle frame of every test triplet. Each prediction is written
    to ucf101_interp_ours/<index>/frame_01_CyclicGen.png. The function then prints
    the average PSNR, computed only over motion-mask pixels, and the average
    grayscale SSIM over the full frame. Both are averaged over triplets whose
    motion mask is non-empty.

    Args:
      dataset_frame1, dataset_frame2, dataset_frame3: objects exposing
        read_data_list_file() and process_func(path) for the first, middle
        (ground truth) and last frames.

    Raises:
      ValueError: if --pretrained_model_checkpoint_path is not set. Without a
        checkpoint the model variables are never initialized, so inference
        cannot run.
    """
    def rgb2gray(rgb):
        """Convert a channels-last RGB array to luma using BT.601 weights."""
        return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])

    if not FLAGS.pretrained_model_checkpoint_path:
        raise ValueError('test() requires --pretrained_model_checkpoint_path; '
                         'model variables would otherwise be uninitialized.')

    with tf.Graph().as_default():
        # Input is frame1 and frame3 stacked along channels; target is the middle frame.
        input_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 6))
        target_placeholder = tf.placeholder(tf.float32, shape=(None, 256, 256, 3))

        edge_vgg_1 = Vgg16(input_placeholder[:, :, :, :3], reuse=None)
        edge_vgg_3 = Vgg16(input_placeholder[:, :, :, 3:6], reuse=True)
        edge_1 = tf.nn.sigmoid(edge_vgg_1.fuse)
        edge_3 = tf.nn.sigmoid(edge_vgg_3.fuse)
        height, width = input_placeholder.get_shape().as_list()[1:3]
        edge_1 = tf.reshape(edge_1, [-1, height, width, 1])
        edge_3 = tf.reshape(edge_3, [-1, height, width, 1])

        with tf.variable_scope("Cycle_DVF"):
            model = Voxel_flow_model(is_train=False)
            # inference() returns a tuple (train() unpacks it as prediction, flow);
            # element 0 of the fetched value is the interpolated frame batch.
            prediction = model.inference(tf.concat([input_placeholder, edge_1, edge_3], 3))

        with tf.Session() as sess:
            restorer = tf.train.Saver()
            restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
            print('%s: Pre-trained model restored from %s' %
                  (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

            data_list_frame1 = dataset_frame1.read_data_list_file()
            data_list_frame2 = dataset_frame2.read_data_list_file()
            data_list_frame3 = dataset_frame3.read_data_list_file()

            num_evaluated = 0
            psnr_total = 0.0
            ssim_total = 0.0
            for id_img, frame1_name in enumerate(data_list_frame1):
                frame2_name = data_list_frame2[id_img]
                frame3_name = data_list_frame3[id_img]
                ucf_index = frame1_name[:-12]

                # Each batch holds exactly one image (leading batch dimension of 1).
                batch_data_frame1 = np.array([dataset_frame1.process_func(
                    os.path.join('ucf101_interp_ours', frame1_name)[:-5] + '00.png')])
                batch_data_frame2 = np.array([dataset_frame2.process_func(
                    os.path.join('ucf101_interp_ours', frame2_name)[:-5] + '01_gt.png')])
                batch_data_frame3 = np.array([dataset_frame3.process_func(
                    os.path.join('ucf101_interp_ours', frame3_name)[:-5] + '02.png')])
                batch_data_mask = np.array([dataset_frame3.process_func(
                    os.path.join('motion_masks_ucf101_interp', frame3_name)[:-11] + 'motion_mask.png')])
                # Presumably process_func returns values in [-1, 1]; rescale the mask to [0, 1].
                batch_data_mask = (batch_data_mask + 1.0) / 2.0

                feed_dict = {input_placeholder: np.concatenate((batch_data_frame1, batch_data_frame3), 3),
                             target_placeholder: batch_data_frame2}
                prediction_np, target_np = sess.run([prediction, target_placeholder],
                                                    feed_dict=feed_dict)

                imwrite('ucf101_interp_ours/' + str(ucf_index) + '/frame_01_CyclicGen.png',
                        prediction_np[0][-1, :, :, :])

                print(np.sum(batch_data_mask))
                if np.sum(batch_data_mask) > 0:
                    mask = np.expand_dims(batch_data_mask[0], -1)
                    img_pred = (prediction_np[0][-1] + 1.0) / 2.0
                    img_target = (target_np[-1] + 1.0) / 2.0
                    img_pred_mask = mask * img_pred
                    img_target_mask = mask * img_target
                    # Mean squared error over masked pixels and all 3 channels.
                    mse = np.sum((img_pred_mask - img_target_mask) ** 2) / (3. * np.sum(batch_data_mask))
                    psnr_cur = 20.0 * np.log10(1.0) - 10.0 * np.log10(mse)
                    ssim_cur = ssim(rgb2gray(img_pred), rgb2gray(img_target), data_range=1.0)
                    psnr_total += psnr_cur
                    ssim_total += ssim_cur
                    num_evaluated += 1

    if num_evaluated == 0:
        print("No test frame had a non-empty motion mask; PSNR/SSIM not computed.")
    else:
        print("Overall PSNR: %f db" % (psnr_total / num_evaluated))
        # SSIM is dimensionless, so no unit is printed.
        print("Overall SSIM: %f" % (ssim_total / num_evaluated))
if __name__ == '__main__':
    # --subset -> (CUDA_VISIBLE_DEVICES value, entry point). Training runs on GPU 0;
    # testing hides all GPUs.
    _SUBSET_RUNNERS = {
        'train': ("0", train),
        'test': ("", test),
    }
    if FLAGS.subset not in _SUBSET_RUNNERS:
        # Previously an unknown subset silently did nothing.
        raise ValueError("Unknown --subset %r; expected 'train' or 'test'." % FLAGS.subset)

    visible_devices, run_subset = _SUBSET_RUNNERS[FLAGS.subset]
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    # Lists are data_list/ucf101_<subset>_files_frame<k>.txt for k = 1, 2, 3.
    ucf101_datasets = [
        dataset.Dataset("data_list/ucf101_%s_files_frame%d.txt" % (FLAGS.subset, frame_idx))
        for frame_idx in (1, 2, 3)
    ]
    run_subset(*ucf101_datasets)